In [1]:
import torch
import numpy as np
import pandas as pd
import os
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.utils.data as data_utils
import torchvision
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import gc

  warn(f"Failed to load image Python extension: {e}")


In [2]:
# strcmp strcpy strstr strchr strtok strcat memcpy memmove restrict

In [3]:
# Load the Fashion-MNIST dataset.
# NOTE: download=False, so the dataset files are expected to already exist
# under the 'fashion-mnist' root directory.
train_transform = transforms.Compose([
    # Optional augmentations, currently disabled:
#     transforms.RandomResizedCrop((28, 28)),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomVerticalFlip(),
#     transforms.RandomRotation(90),
#     transforms.RandomGrayscale(0.1), # convert to grayscale with probability 0.1
#     transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
    transforms.ToTensor()
])

train_data = datasets.FashionMNIST(root='fashion-mnist', train=True, 
                                   transform=train_transform,
                                   download=False)

test_transform = transforms.Compose([
    # Optional preprocessing, currently disabled:
#     transforms.Resize((28, 28)),
#     transforms.Normalize((0.49, 0.48, 0.44), (0.2, 0.22, 0.21)),
    transforms.ToTensor()
])

test_data = datasets.FashionMNIST(root='fashion-mnist', train=False, 
                                  transform=test_transform, download=False)

batch_size = 64

# Training batches are reshuffled every epoch.
train_loader = data_utils.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

# Evaluation does not benefit from shuffling; a fixed order keeps the test
# batches (and anything logged from them) reproducible between runs.
test_loader = data_utils.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

In [4]:
class CNN(torch.nn.Module):
    """Single conv-block baseline classifier for single-channel images.

    The feature extractor is Conv(5x5) -> BatchNorm -> ReLU -> MaxPool(2).
    The pooled map is flattened and passed to one linear layer that is sized
    for a 14x14x32 feature map (i.e. 28x28 inputs) and emits 10 class scores.
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.conv = torch.nn.Sequential(
            # padding=2 makes the 5x5 convolution size-preserving (28 -> 28),
            # so the 2x2 pool produces the 14x14 map `self.fc` expects.
            # With padding=1 the map shrank to 26 -> 13 and `fc` raised a
            # shape-mismatch error on the first forward pass.
            torch.nn.Conv2d(1, 32, kernel_size=5, padding=2),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
        )
        self.fc = torch.nn.Linear(14 * 14 * 32, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: tensor of shape (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 10) holding log-softmax scores.
        """
        out = self.conv(x)
        # Flatten everything except the batch dimension.
        out = torch.flatten(out, start_dim=1)
        out = self.fc(out)
        return F.log_softmax(out, dim=1)

In [5]:
class ResBlock(torch.nn.Module):
    """Residual unit: two 3x3 conv/batch-norm stages plus a skip connection.

    Args:
        in_channel: number of channels of the incoming feature map.
        out_channel: number of channels produced by the block.
        stride: stride of the first convolution (and of the projection
            shortcut, when one is needed).
    """

    def __init__(self, in_channel, out_channel, stride=1):
        super().__init__()
        nn = torch.nn

        # Main path: conv-bn-relu followed by conv-bn (no activation yet).
        self.layer = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
        )

        # Skip path: identity when shapes already agree, otherwise a
        # conv-bn projection that matches channel count and resolution.
        needs_projection = stride > 1 or in_channel != out_channel
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel,
                          kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_channel),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Add the main and skip paths, then apply ReLU."""
        residual = self.layer(x)
        skip = self.shortcut(x)
        return F.relu(residual + skip)

class ResNet(torch.nn.Module):
    """Small residual network for single-channel images with 10 outputs.

    A 3x3 stem is followed by four stages of two `ResBlock`s each; every
    stage doubles the channel count (32 -> 64 -> 128 -> 256 -> 512) and
    halves the resolution. A 2x2 max-pool and a linear layer finish the model.
    The linear layer expects 512 features, which holds when the pooled map is
    1x1 (e.g. 28x28 inputs).
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )
        # Registers layer1..layer4, each with stride 2 and two blocks.
        widths = (32, 64, 128, 256, 512)
        for idx in range(1, len(widths)):
            stage = self.make_layer(ResBlock, widths[idx - 1], widths[idx], 2, 2)
            setattr(self, f'layer{idx}', stage)
        self.mp = nn.MaxPool2d(2)
        self.fc = nn.Linear(512, 10)

    @staticmethod
    def make_layer(block, in_channel, out_channel, stride, num_block):
        """Stack `num_block` blocks into a Sequential.

        Only the first block receives `in_channel` and `stride`; the
        remaining ones map `out_channel -> out_channel` with stride 1.
        """
        stacked = [
            block(in_channel if idx == 0 else out_channel,
                  out_channel,
                  stride if idx == 0 else 1)
            for idx in range(num_block)
        ]
        return torch.nn.Sequential(*stacked)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10)."""
        feats = self.conv(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        feats = self.mp(feats)
        # Collapse channel and spatial dims into one feature vector per sample.
        feats = feats.view(feats.size(0), -1)
        logits = self.fc(feats)
        return F.log_softmax(logits, dim=1)

In [6]:
class MobileNet(torch.nn.Module):
    """Depthwise-separable convolution network with 10 raw-score outputs.

    A 3x3 stem (1 -> 32 channels) is followed by eight depthwise-separable
    units, `conv_dpw1` .. `conv_dpw8`. Odd-numbered units keep the channel
    count with stride 1; even-numbered units double it with stride 2. The
    result is average-pooled, flattened and classified by a linear layer that
    expects 512 features (a 1x1 pooled map, e.g. for 28x28 inputs).
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )

        # (in_channel, out_channel, stride) for conv_dpw1 .. conv_dpw8.
        unit_specs = [
            (32, 32, 1), (32, 64, 2),
            (64, 64, 1), (64, 128, 2),
            (128, 128, 1), (128, 256, 2),
            (256, 256, 1), (256, 512, 2),
        ]
        for number, (cin, cout, step) in enumerate(unit_specs, start=1):
            setattr(self, f'conv_dpw{number}', self.conv_dw(cin, cout, step))

        self.fc = nn.Linear(512, 10)

    @staticmethod
    def conv_dw(in_channel, out_channel, stride):
        """Build one depthwise-separable unit.

        A per-channel 3x3 convolution (groups == in_channel) carrying the
        stride is followed by a 1x1 convolution that changes the channel
        count; each convolution is followed by BatchNorm and ReLU.
        """
        nn = torch.nn
        depthwise = [
            nn.Conv2d(in_channel, in_channel, kernel_size=3,
                      stride=stride, padding=1, groups=in_channel),
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),
        ]
        pointwise = [
            nn.Conv2d(in_channel, out_channel, kernel_size=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
        ]
        return nn.Sequential(*depthwise, *pointwise)

    def forward(self, x):
        """Return unnormalized class scores of shape (batch, 10)."""
        feats = self.conv(x)
        units = (
            self.conv_dpw1, self.conv_dpw2, self.conv_dpw3, self.conv_dpw4,
            self.conv_dpw5, self.conv_dpw6, self.conv_dpw7, self.conv_dpw8,
        )
        for unit in units:
            feats = unit(feats)
        feats = F.avg_pool2d(feats, 2)
        feats = feats.view(feats.size(0), -1)
        return self.fc(feats)
        

In [7]:
def CBR(in_channel, out_channel, kernel_size, stride=1):
    """Build a Conv2d -> BatchNorm2d -> ReLU unit.

    Padding is `kernel_size // 2`, so with an odd kernel and stride 1 the
    spatial size of the input is preserved.

    Args:
        in_channel: channels of the incoming feature map.
        out_channel: channels produced by the convolution.
        kernel_size: side length of the square kernel.
        stride: convolution stride (default 1).

    Returns:
        A `torch.nn.Sequential` holding the three layers in order.
    """
    half_kernel = kernel_size // 2
    conv = torch.nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size=kernel_size,
        stride=stride,
        padding=half_kernel,
    )
    norm = torch.nn.BatchNorm2d(out_channel)
    activation = torch.nn.ReLU()
    return torch.nn.Sequential(conv, norm, activation)

class BaseInception(torch.nn.Module):
    """Four-branch Inception-style block built from `CBR` units.

    Args:
        in_channel: channels of the incoming feature map.
        out_channel_list: output widths of the four branches, in order.
            Branches 2 and 3 first reduce to half of their output width
            with a 1x1 convolution.

    The branch outputs are concatenated along the channel axis, so the block
    emits `out_channel_list[0] + ... + out_channel_list[3]` channels.
    """

    def __init__(self, in_channel, out_channel_list):
        super().__init__()
        widths = out_channel_list

        # Branch 1: plain 1x1 convolution.
        self.branch1 = CBR(in_channel, widths[0], 1)

        # Branch 2: 1x1 reduction, then one 3x3 convolution.
        reduced2 = widths[1] // 2
        self.branch2 = torch.nn.Sequential(
            CBR(in_channel, reduced2, 1),
            CBR(reduced2, widths[1], 3),
        )

        # Branch 3: 1x1 reduction, then two stacked 3x3 convolutions.
        reduced3 = widths[2] // 2
        self.branch3 = torch.nn.Sequential(
            CBR(in_channel, reduced3, 1),
            CBR(reduced3, widths[2], 3),
            CBR(widths[2], widths[2], 3),
        )

        # Branch 4: size-preserving 3x3 max-pool, then a 1x1 convolution.
        self.branch4 = torch.nn.Sequential(
            torch.nn.MaxPool2d(3, stride=1, padding=1),
            CBR(in_channel, widths[3], 1),
        )

    def forward(self, x):
        """Run all branches on `x` and concatenate them on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], dim=1)

class InceptionNet(torch.nn.Module):
    """Compact Inception-style classifier with 10 raw-score outputs.

    Three blocks each end in a stride-2 `CBR`, so the resolution is halved
    three times. The last two blocks start with a `BaseInception` whose
    concatenated branch widths (128 and 256 channels) feed the following
    `CBR`. The linear layer expects a 2x2x512 map after average pooling,
    which is what 28x28 inputs produce.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        stem = [CBR(1, 32, 3, 1), CBR(32, 64, 3, 2)]
        self.block1 = nn.Sequential(*stem)

        # Branch widths 32 + 64 + 16 + 16 = 128 channels into the CBR.
        middle = [BaseInception(64, [32, 64, 16, 16]), CBR(128, 256, 3, 2)]
        self.block2 = nn.Sequential(*middle)

        # Branch widths 64 + 128 + 32 + 32 = 256 channels into the CBR.
        last = [BaseInception(256, [64, 128, 32, 32]), CBR(256, 512, 3, 2)]
        self.block3 = nn.Sequential(*last)

        self.avg_pool = nn.AvgPool2d(2)
        self.fc = nn.Linear(2*2*512, 10)

    def forward(self, x):
        """Return unnormalized class scores of shape (batch, 10)."""
        feats = x
        for block in (self.block1, self.block2, self.block3):
            feats = block(feats)
        feats = self.avg_pool(feats)
        feats = feats.view(feats.size(0), -1)
        return self.fc(feats)
        

In [8]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# cnn = InceptionNet()
# cnn = cnn.to(device)
# # loss
# loss_func = torch.nn.CrossEntropyLoss()

# optimizer = torch.optim.Adam(cnn.parameters(), lr=0.01)

# # Every step_size epochs, the learning rate decays to gamma times its previous value
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

# if not os.path.exists('log'):
#     os.mkdir('log')
# writer = SummaryWriter('log')

# step_train = 0
# step_test = 0
# for epoch in range(10):
#     for i, im_data in enumerate(train_loader):
#         images, labels = im_data
#         images = images.to(device)
#         labels = labels.to(device)
#         outputs = cnn(images)
#         loss = loss_func(outputs, labels)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         print(f'epoch is {epoch+1}, ite is {i}/{len(train_data) // batch_size}, loss is {loss.item()}')
#         _, pred = outputs.max(1)
#         correct = torch.sum(pred == labels).item()
#         print('train correct', 100 * correct / batch_size)
#         writer.add_scalar('train loss', loss.item(), global_step=step_train)
#         writer.add_scalar('train correct', 100 * correct / batch_size, global_step=step_train)
# #         im = torchvision.utils.make_grid(images)
# #         writer.add_image('train im', im, global_step=step_train)
#         step_train += 1
# #         del images, labels
# #         gc.collect()
# #         torch.cuda.empty_cache()
        
#     # save models
#     if not os.path.exists('models'):
#         os.makedirs('models')
#     torch.save(cnn.state_dict(), fr'models\test_{epoch}.pth')
#     scheduler.step()
# #     print('lr is: ', optimizer.state_dict()['param_groups'][0]['lr'])
    
#     loss, correct = 0, 0
#     for i, im_data in enumerate(test_loader):
#         images, labels = im_data
#         images = images.to(device)
#         labels = labels.to(device)
#         outputs = cnn(images)
#         loss += loss_func(outputs, labels).item()
#         _, pred = outputs.max(1)
#         correct += torch.sum(pred == labels).item()
# #         print('train correct', 100 * correct / batch_size)
#         im = torchvision.utils.make_grid(images)
#         writer.add_image('test im', im, global_step=step_test)
#         step_test += 1
# #         del images, labels
# #         gc.collect()
# #         torch.cuda.empty_cache()
        
#     correct /= len(test_data)
#     loss /= len(test_data) // batch_size
#     writer.add_scalar('test loss', loss, global_step=epoch + 1)
#     writer.add_scalar('test correct', 100 * correct, global_step=epoch + 1)
    
# writer.close()