In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
from torch.utils import data
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

In [2]:
from tqdm import tqdm
def train(model, train_loader, test_loader, epochs, lr, writer, device=None):
    """Train ``model`` with Adam + cross-entropy and evaluate it after every epoch.

    Logging:
      * every training batch: ``writer.add_scalar("loss/train", loss, global_step)``
        where ``global_step`` counts training batches across all epochs;
      * every test batch: ``writer.add_scalar("loss/test", loss, test_step)`` with
        its own running test-batch counter;
      * every epoch: train/test accuracy is printed and logged via
        ``writer.add_scalars("acc", {"train": ..., "test": ...}, epoch)``.

    Args:
        model: ``nn.Module`` mapping a batch ``X`` to class logits.
        train_loader: iterable of ``(X, y)`` batches used for optimisation.
        test_loader: iterable of ``(X, y)`` batches used for evaluation only.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        writer: object with ``add_scalar``/``add_scalars`` (e.g. ``SummaryWriter``).
        device: device to run on; defaults to CUDA when available, else CPU.

    An accuracy whose loader yielded no samples is reported as NaN instead of
    raising ``ZeroDivisionError``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    global_step = 0
    test_step = 0

    def _ratio(num, den):
        # NaN rather than a ZeroDivisionError when a loader is empty.
        return num / den if den else float("nan")

    for epoch in range(epochs):
        print(f"epoch: {epoch+1}/{epochs}")
        model.train()
        train_num = train_right = test_num = test_right = 0
        for X, y in tqdm(train_loader):
            X, y = X.to(device), y.to(device)
            y_hat = model(X)
            l = criterion(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()

            writer.add_scalar("loss/train", l.item(), global_step)
            global_step += 1
            train_num += len(y)
            # predicted class = index of the largest logit
            train_right += (y == y_hat.argmax(dim=1)).sum().item()

        model.eval()
        with torch.no_grad():
            for X, y in test_loader:
                X, y = X.to(device), y.to(device)
                y_hat = model(X)
                l = criterion(y_hat, y)
                test_right += (y == y_hat.argmax(dim=1)).sum().item()
                test_num += len(y)

                writer.add_scalar("loss/test", l.item(), test_step)
                test_step += 1

        train_acc = _ratio(train_right, train_num)
        test_acc = _ratio(test_right, test_num)
        print(f"accuracy/train: {train_acc}")
        print(f"accuracy/test: {test_acc}")
        writer.add_scalars("acc", {'train': train_acc, 'test': test_acc}, epoch)

In [3]:
class Residual(nn.Module):
    """Two-conv residual unit: ``relu(bn2(conv2(relu(bn1(conv1(X))))) + shortcut(X))``.

    Args:
        in_dim: input channels.
        out_dim: output channels.
        use_1d: force a 1x1 convolution (``conv3``) on the shortcut path.
            A projection is also added automatically whenever an identity
            shortcut could not match the main path's output shape
            (``in_dim != out_dim`` or ``strides != 1``).
        strides: stride of ``conv1`` and of the shortcut projection.
    """

    def __init__(self, in_dim, out_dim, use_1d=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)
        # An identity shortcut only works when channels and resolution are
        # unchanged; otherwise `Y += X` in forward() fails with a shape mismatch.
        if use_1d or in_dim != out_dim or strides != 1:
            self.conv3 = nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=strides)
        else:
            self.conv3 = None

        self.bn1 = nn.BatchNorm2d(out_dim)
        self.bn2 = nn.BatchNorm2d(out_dim)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # explicit None check instead of relying on nn.Module truthiness
        if self.conv3 is not None:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)

def resnet_block(in_dim, out_dim, blocks, first_block=False):
    """Build the list of ``Residual`` units for one ResNet stage.

    Args:
        in_dim: channels entering the stage.
        out_dim: channels produced by every unit in the stage.
        blocks: number of ``Residual`` units.
        first_block: when False, the first unit downsamples with stride 2 and
            projects ``in_dim -> out_dim`` through a 1x1 conv; when True, the
            stage keeps its input resolution.

    Returns:
        list of ``Residual`` modules, ready for ``nn.Sequential(*...)``.
    """
    blk = []
    for i in range(blocks):
        if i == 0 and not first_block:
            # downsampling entry unit of the stage
            blk.append(Residual(in_dim, out_dim,
                                use_1d=True, strides=2))
        elif i == 0:
            # Non-downsampling first stage: the first unit must still accept
            # in_dim channels, projecting them only when they differ from out_dim.
            blk.append(Residual(in_dim, out_dim, use_1d=in_dim != out_dim))
        else:
            blk.append(Residual(out_dim, out_dim))
    return blk

class ResNet(nn.Module):
    """ResNet-18-style classifier: a conv stem plus four stages of two residual units.

    Args:
        in_channels: number of input image channels (default 1, grayscale).
        num_classes: number of output logits (default 10).

    The shape comments below assume a 224x224 input. The notebook resizes to
    96x96, for which every stage output is 96/224 of the size shown: 48 -> 24
    after the stem, then 24, 12, 6, 3 through b2..b5.

    NOTE: the stages are registered both as ``b1``..``b5`` and inside
    ``self.net``, so ``state_dict()`` stores each stage's tensors under two key
    prefixes. This is kept as-is so previously saved checkpoints still load.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        # stem: (C, 224, 224) -> conv -> (64, 112, 112) -> maxpool -> (64, 56, 56)
        self.b1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
                                nn.BatchNorm2d(64), nn.ReLU(),
                                nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                                )
        self.b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))  # (64, 56, 56)
        self.b3 = nn.Sequential(*resnet_block(64, 128, 2))   # (128, 28, 28)
        self.b4 = nn.Sequential(*resnet_block(128, 256, 2))  # (256, 14, 14)
        self.b5 = nn.Sequential(*resnet_block(256, 512, 2))  # (512, 7, 7)

        self.net = nn.Sequential(self.b1, self.b2, self.b3,
                                 self.b4, self.b5,
                                 nn.AdaptiveAvgPool2d((1, 1)),  # any HxW -> 1x1
                                 nn.Flatten(),
                                 nn.Linear(512, num_classes))

    def forward(self, X):
        """Map a batch ``X`` of shape (N, in_channels, H, W) to logits (N, num_classes)."""
        return self.net(X)

In [4]:
# Resize target: 224 is the size the ResNet shape comments assume; 96 is the
# size actually used for training here.
DATA_DIR = "./dataset"
IMAGE_SIZE = 96
batch_size = 512

trans = transforms.Compose([transforms.Resize(IMAGE_SIZE), transforms.ToTensor()])
# download=True skips the download when the files already exist under DATA_DIR,
# and lets the notebook run on a fresh checkout (download=False raises there).
train_data = torchvision.datasets.FashionMNIST(DATA_DIR, train=True, transform=trans, download=True)
test_data = torchvision.datasets.FashionMNIST(DATA_DIR, train=False, transform=trans, download=True)

# Only the training split is shuffled; evaluation order does not affect accuracy.
train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [5]:
import os

# Select the GPU before anything touches CUDA in this kernel (train() calls
# torch.cuda.is_available()); changes made after CUDA is initialised are ignored.
# setdefault lets an externally exported CUDA_VISIBLE_DEVICES take precedence.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

SEED = 0
EPOCHS = 10
LR = 0.001

torch.manual_seed(SEED)  # seeds weight init and DataLoader shuffling
model = ResNet()
writer = SummaryWriter(log_dir="./resnet_log")
try:
    train(model, train_loader, test_loader, epochs=EPOCHS, lr=LR, writer=writer)
finally:
    # Close the event file even if training is interrupted or raises.
    writer.close()

epoch: 1/10


100%|██████████| 118/118 [00:15<00:00,  7.69it/s]


accuracy/train: 0.8337166666666667
aaccuracy/test: 0.8773
epoch: 2/10


100%|██████████| 118/118 [00:15<00:00,  7.86it/s]


accuracy/train: 0.9075166666666666
aaccuracy/test: 0.881
epoch: 3/10


100%|██████████| 118/118 [00:15<00:00,  7.77it/s]


accuracy/train: 0.9195666666666666
aaccuracy/test: 0.9033
epoch: 4/10


100%|██████████| 118/118 [00:15<00:00,  7.66it/s]


accuracy/train: 0.93355
aaccuracy/test: 0.917
epoch: 5/10


100%|██████████| 118/118 [00:15<00:00,  7.82it/s]


accuracy/train: 0.9404333333333333
aaccuracy/test: 0.916
epoch: 6/10


100%|██████████| 118/118 [00:15<00:00,  7.81it/s]


accuracy/train: 0.9473166666666667
aaccuracy/test: 0.9222
epoch: 7/10


100%|██████████| 118/118 [00:14<00:00,  8.09it/s]


accuracy/train: 0.9537333333333333
aaccuracy/test: 0.927
epoch: 8/10


100%|██████████| 118/118 [00:15<00:00,  7.74it/s]


accuracy/train: 0.9584666666666667
aaccuracy/test: 0.9248
epoch: 9/10


100%|██████████| 118/118 [00:14<00:00,  7.87it/s]


accuracy/train: 0.9687333333333333
aaccuracy/test: 0.9241
epoch: 10/10


100%|██████████| 118/118 [00:15<00:00,  7.78it/s]


accuracy/train: 0.97285
aaccuracy/test: 0.919


In [6]:
# Persist only the parameter/buffer dict; reload with
# `ResNet().load_state_dict(torch.load(CKPT_PATH))`.
CKPT_PATH = "./resnet.ckpt"
torch.save(model.state_dict(), CKPT_PATH)