In [1]:
from IPython.core.interactiveshell import InteractiveShell 
InteractiveShell.ast_node_interactivity = "all"

In [4]:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

In [None]:
import torch
from torch.utils.tensorboard import SummaryWriter

### dataset 

In [5]:
import torchvision

train_data = torchvision.datasets.CIFAR10(root='../data',train=True,transform=torchvision.transforms.ToTensor(),
                                            download=True)
test_data = torchvision.datasets.CIFAR10(root='../data',train=False,transform=torchvision.transforms.ToTensor(),
                                            download=True)

# ----> len dataset
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练集的长度：{}".format(train_data_size))
print("测试集的长度：{}".format(test_data_size))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data\cifar-10-python.tar.gz


170499072it [00:56, 3043713.77it/s]                               


Extracting ../data\cifar-10-python.tar.gz to ../data
Files already downloaded and verified
训练集的长度：50000
测试集的长度：10000


### dataloader

In [7]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)

### 搭建神经网络

In [9]:
import torch.nn as nn

class FanNet(nn.Module):
    
    def __init__(self):
        super(FanNet,self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 5, padding=2)
        self.maxpool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 32, 5, padding=2)
        self.maxpool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(32, 64, 5, padding=2)
        self.maxpool3 = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(1024, 64)
        self.linear2 = nn.Linear(64, 10)

    def forward(self,x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x

if __name__ == '__main__':
    fanfan = FanNet()
    input = torch.ones((64,3,32,32))
    output = fanfan(input)
    print(output.shape)

torch.Size([64, 10])


### Train and Test

In [14]:
# Loss
loss1 = nn.CrossEntropyLoss()

# optimer
learning_rate = 0.01
optimer = torch.optim.Adam(fanfan.parameters(),lr=learning_rate)

# train_model parameters setting
epoch = 10
total_train_step = 0
total_test_step = 0

# tensorboard start
writer = SummaryWriter('../loss_train')

for i in range(epoch):
    print("----------第{}轮训练开始----------".format(i))
    
    ## train start
    for img, label in train_dataloader:
        output = fanfan(img)
        loss = loss1(output,label)

        ### 开始优化
        optimer.zero_grad()
        loss.backward()
        optimer.step()

        total_train_step += 1
        if total_train_step % 50 == 0:
            print("训练次数：{}, loss:{}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)




----------第0轮训练开始----------
训练次数：1, loss:4.677282810211182
训练次数：2, loss:102.27971649169922
训练次数：3, loss:32.91666793823242
训练次数：4, loss:101.23944854736328
训练次数：5, loss:85.64004516601562
训练次数：6, loss:17.29160499572754
训练次数：7, loss:34.875160217285156
训练次数：8, loss:31.705135345458984
训练次数：9, loss:48.421817779541016
训练次数：10, loss:30.37312889099121
训练次数：11, loss:52.78559494018555
训练次数：12, loss:35.105491638183594
训练次数：13, loss:25.28563117980957
训练次数：14, loss:32.901119232177734
训练次数：15, loss:193.61170959472656
训练次数：16, loss:252.36770629882812
训练次数：17, loss:104.58283233642578
训练次数：18, loss:38.1594123840332
训练次数：19, loss:20.039886474609375
训练次数：20, loss:17.333147048950195
训练次数：21, loss:32.87016677856445
训练次数：22, loss:89.71215057373047
训练次数：23, loss:32.59584426879883
训练次数：24, loss:52.150474548339844
训练次数：25, loss:14.243913650512695
训练次数：26, loss:63.50657653808594
训练次数：27, loss:18.99108123779297
训练次数：28, loss:14.867518424987793
训练次数：29, loss:7.814864158630371
训练次数：30, loss:10.393378257751465
训练次数：3

KeyboardInterrupt: 

### test


In [None]:
# test 更准确的说是验证，计算的是整个数据集的loss
total_test_loss = 0
with torch.no_grad():
    for img, label in test_dataloader:
        output = fanfan(img)
        loss = loss1(output, label)
        optimer.zero_grad()
        loss.backward()
        optimer.step()
        total_test_step += 1
        total_test_loss += loss

print("整体测试集上的Loss:{}".format(total_test_loss))
writer.add_scalar("test_loss", total_test_loss, total_test_step)
total_test_step += 1

torch.save(fanfan, "fanfan_{}.pt".format(i))
print("model save already")


### Metric

In [None]:
# Accuracy, AUC, Recall, F1, Pricision






