In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import collections 

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

aug_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=4,
                                          shuffle=True, num_workers=2)

test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


## Линейный классификатор

In [3]:
class LinearNet(torch.nn.Module):
    def __init__(self):
        super(LinearNet, self).__init__()
        self.inp = torch.nn.Linear(32 * 32 * 3, 128)
        self.out = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 32 * 32 * 3)
        x = self.inp(x)
        return self.out(x)

### Линейный классификатор выдает не очень хорошее качество, поэтому я решил использовать CNN

In [4]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.conv2_bn = nn.BatchNorm2d(16)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc1_bn = nn.BatchNorm1d(120)
        self.out = nn.Linear(120, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2_bn(self.conv2(x))))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1_bn(self.fc1(x)))
        return self.out(x)

In [5]:
def train(model, train_loader, optimizer, criterion, loss_vector, accuracy_vector):
    model.train()
    
    train_loss, train_correct = 0.0, 0.0
    running_loss, correct = 0.0, 0.0

    for i, (inputs, labels) in enumerate(train_loader, 0):
        optimizer.zero_grad()
        outputs = model(inputs)

        pred = outputs.data.max(1)[1] # get the index of the max log-probability
        correct += pred.eq(labels.data).sum().item()

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2500 == 2499:
            print('[%d, %d] (%d%%)\t loss: %.3f,\t accuracy: %.3f' %
                  (i + 1, len(train_loader), 100. * (i+1) / len(train_loader), 
                   running_loss / 2500, 100. * correct / (2500*labels.size(0)) ))
            train_correct += correct
            train_loss += running_loss
            correct = 0.0
            running_loss = 0.0

    train_loss /= len(train_loader)
    loss_vector.append(train_loss)

    accuracy = 100. * train_correct / len(train_loader.dataset)
    accuracy_vector.append(accuracy)

In [6]:
def validate(model, validation_loader, criterion, loss_vector, accuracy_vector):
    model.eval()
    
    val_loss, val_correct = 0.0, 0.0

    for i, (inputs, labels) in enumerate(validation_loader, 0):
        outputs = model(inputs)
        val_loss += criterion(outputs, labels).item()
        pred = outputs.data.max(1)[1] # get the index of the max log-probability
        val_correct += pred.eq(labels.data).sum().item()

    val_loss /= len(validation_loader)
    loss_vector.append(val_loss)

    accuracy = 100. * val_correct / len(validation_loader.dataset)
    accuracy_vector.append(accuracy)
    print('Test average loss: %.3f\t accuracy: (%.3f%%)\n' % (val_loss, accuracy))

# Pre-training

In [44]:
momentum_update = lambda m,fk,fq: { key: m*fq[key] + (1-m)*fk[key] for key in fq.keys()}

def pretrain(epoches, f_q, train_loader, optimizer, criterion, t, m, K, model):
    f_k = model()
    f_q.train()
    f_k.train()
    
    # initializing queue 
    queue = []
    for i, (x, _) in enumerate(train_loader, 0):
        if (i*x.size(0) >= K):
            break
        queue.append(f_k(x).detach().numpy())
    queue = torch.Tensor(queue)
    
    C = queue.size(2) # num of classes 
  
    for epoch in range(epoches):
        f_q.train()
        print("Pre-train Epoch %d" % (epoch + 1))
        running_loss, correct = 0.0, 0.0
        
        f_k.load_state_dict(f_q.state_dict()) # initialize
        for i, (x, _) in enumerate(train_loader, 0):
            N = x.size(0)
            # augmentation
            # TODO: лучше запилить аугментацию в DataLoader
            x_q, x_k = torch.Tensor(x.size()),torch.Tensor(x.size())
            for j in np.arange(4):
                x_q[j] = aug_transform(x[j])
                x_k[j] = aug_transform(x[j])
            
            q = f_q(x_q) # queries: NxC
            k = f_k(x_k) # keys: NxC
            k = k.detach() # no gradient to keys
            
            # positive logits: Nx1
            l_pos = torch.matmul(q.view(N,1,C), k.view(N,C,1)).view(N, 1)
            # negative logits: NxK
            l_neg = torch.mm(q.view(N,C), queue.view(K,C).transpose(0,1))
            # logits: Nx(1+K)
            logits = torch.cat((l_pos, l_neg), dim=1)
            
            labels = torch.LongTensor(N).zero_()
            loss = criterion(logits/t, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            
            # momentum update: key network
            f_k.load_state_dict(momentum_update(m, f_k.state_dict(), f_q.state_dict()))
            # update dictionary
            queue[:-1] = queue[1:] # dequeue the earliest minibatch
            queue[-1] = k # enqueue the current minibatch

            # print statistics
            running_loss += loss.item()
            
            if i % 2500 == 2499:
                print('[%d, %d] (%d%%)\t loss: %.3f,\t' %
                      (i + 1, len(train_loader), 100. * (i+1) / len(train_loader), 
                       running_loss / 2500 ))
                correct = 0.0
                running_loss = 0.0
        validate(f_q, test_loader, nn.CrossEntropyLoss(), [], [])

    print('Finished Pre-training')

In [45]:
pretrained_model = Net()

pretrained_optimizer = optim.SGD(pretrained_model.parameters(), lr=0.0001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

pretrain(10, pretrained_model, train_loader, pretrained_optimizer, criterion, 0.07, 0.999, 40, Net)
torch.save(pretrained_model.state_dict(), "./pretrained_model.pth")

Pre-train Epoch 1
[2500, 12500] (20%)	 loss: 1.930,	
[5000, 12500] (40%)	 loss: 1.371,	
[7500, 12500] (60%)	 loss: 1.269,	
[10000, 12500] (80%)	 loss: 1.155,	
[12500, 12500] (100%)	 loss: 1.109,	
Test average loss: 2.650	 accuracy: (9.670%)

Pre-train Epoch 2
[2500, 12500] (20%)	 loss: 1.056,	
[5000, 12500] (40%)	 loss: 1.055,	
[7500, 12500] (60%)	 loss: 0.960,	
[10000, 12500] (80%)	 loss: 0.942,	
[12500, 12500] (100%)	 loss: 0.908,	
Test average loss: 2.725	 accuracy: (9.940%)

Pre-train Epoch 3
[2500, 12500] (20%)	 loss: 0.899,	
[5000, 12500] (40%)	 loss: 0.864,	
[7500, 12500] (60%)	 loss: 0.872,	
[10000, 12500] (80%)	 loss: 0.812,	
[12500, 12500] (100%)	 loss: 0.822,	
Test average loss: 2.736	 accuracy: (4.920%)

Pre-train Epoch 4
[2500, 12500] (20%)	 loss: 0.762,	
[5000, 12500] (40%)	 loss: 0.758,	
[7500, 12500] (60%)	 loss: 0.773,	
[10000, 12500] (80%)	 loss: 0.756,	
[12500, 12500] (100%)	 loss: 0.758,	
Test average loss: 2.601	 accuracy: (9.950%)

Pre-train Epoch 5
[2500, 12500] 

## Standard model

In [9]:
model = Net()
train_loss_vector, test_loss_vector, train_accuracy_vector, test_accuracy_vector = [],[],[],[]

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    print("Epoch %d" % (epoch + 1))
    train(model, train_loader, optimizer, criterion, train_loss_vector, train_accuracy_vector)
    validate(model, test_loader, criterion, test_loss_vector, test_accuracy_vector)
print('Finished Training')

Epoch 1
[2500, 12500] (20%)	 loss: 1.880,	 accuracy: 31.260
[5000, 12500] (40%)	 loss: 1.715,	 accuracy: 38.300
[7500, 12500] (60%)	 loss: 1.650,	 accuracy: 41.440
[10000, 12500] (80%)	 loss: 1.618,	 accuracy: 42.190
[12500, 12500] (100%)	 loss: 1.577,	 accuracy: 44.510
Test average loss: 1.339	 accuracy: (52.370%)

Epoch 2
[2500, 12500] (20%)	 loss: 1.534,	 accuracy: 45.940
[5000, 12500] (40%)	 loss: 1.525,	 accuracy: 46.740
[7500, 12500] (60%)	 loss: 1.501,	 accuracy: 47.440
[10000, 12500] (80%)	 loss: 1.493,	 accuracy: 47.500
[12500, 12500] (100%)	 loss: 1.486,	 accuracy: 48.320
Test average loss: 1.215	 accuracy: (56.950%)

Epoch 3
[2500, 12500] (20%)	 loss: 1.455,	 accuracy: 49.510
[5000, 12500] (40%)	 loss: 1.446,	 accuracy: 49.630
[7500, 12500] (60%)	 loss: 1.429,	 accuracy: 49.650
[10000, 12500] (80%)	 loss: 1.438,	 accuracy: 49.580
[12500, 12500] (100%)	 loss: 1.414,	 accuracy: 50.830
Test average loss: 1.154	 accuracy: (59.570%)

Epoch 4
[2500, 12500] (20%)	 loss: 1.384,	 acc

## Pretrained model

In [46]:
pret_model = Net()
pret_model.load_state_dict(torch.load("./pretrained_model.pth"))
pret_loss_vector, pret_test_loss_vector, pret_accuracy_vector, pret_test_accuracy_vector = [],[],[],[]

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(pret_model.parameters(), lr=0.001)

for epoch in range(10):
    print("Epoch %d" % (epoch + 1))
    train(pret_model, train_loader, optimizer, criterion, pret_loss_vector, pret_accuracy_vector)
    validate(pret_model, test_loader, criterion, pret_test_loss_vector, pret_test_accuracy_vector)
print('Finished Training')

Epoch 1
[2500, 12500] (20%)	 loss: 1.904,	 accuracy: 31.420
[5000, 12500] (40%)	 loss: 1.736,	 accuracy: 37.870
[7500, 12500] (60%)	 loss: 1.702,	 accuracy: 39.370
[10000, 12500] (80%)	 loss: 1.668,	 accuracy: 40.870
[12500, 12500] (100%)	 loss: 1.628,	 accuracy: 41.460
Test average loss: 1.403	 accuracy: (50.150%)

Epoch 2
[2500, 12500] (20%)	 loss: 1.580,	 accuracy: 43.670
[5000, 12500] (40%)	 loss: 1.583,	 accuracy: 44.230
[7500, 12500] (60%)	 loss: 1.571,	 accuracy: 45.180
[10000, 12500] (80%)	 loss: 1.560,	 accuracy: 44.880
[12500, 12500] (100%)	 loss: 1.532,	 accuracy: 46.300
Test average loss: 1.278	 accuracy: (54.560%)

Epoch 3
[2500, 12500] (20%)	 loss: 1.521,	 accuracy: 46.870
[5000, 12500] (40%)	 loss: 1.500,	 accuracy: 48.180
[7500, 12500] (60%)	 loss: 1.506,	 accuracy: 47.840
[10000, 12500] (80%)	 loss: 1.489,	 accuracy: 48.270
[12500, 12500] (100%)	 loss: 1.477,	 accuracy: 48.320
Test average loss: 1.232	 accuracy: (56.090%)

Epoch 4
[2500, 12500] (20%)	 loss: 1.416,	 acc

In [29]:
import plotly.graph_objects as go

In [30]:
def draw_graph(strm, ptrm, stm, ptm, met):
    x = np.array([k for k in np.arange(11)])

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x, y=strm, mode='lines', name="Standard training model", line_color='rgb(110,200,250)', line_width=3))
    fig.add_trace(go.Scatter(x=x, y=ptrm, mode='lines', name="Pretrained training model ", line_color='rgb(250,110,200)', line_width=3))
    fig.add_trace(go.Scatter(x=x, y=stm, mode='lines', name="Standard test model", line_color='rgb(80,100,180)', line_width=3))
    fig.add_trace(go.Scatter(x=x, y=ptm, mode='lines', name="Pretrained test model", line_color='rgb(210,70,100)', line_width=3))
    fig.update_layout(
        xaxis=go.layout.XAxis(title="Epoch"),
        yaxis=go.layout.YAxis(title=met)
    )
    fig.show()

## Loss

In [47]:
draw_graph(train_loss_vector, pret_loss_vector, test_loss_vector, pret_test_loss_vector, "Loss")

## Accuracy

In [48]:
draw_graph(train_accuracy_vector, pret_accuracy_vector, test_accuracy_vector, pret_test_accuracy_vector, "Accuracy")

In [33]:
torch.save(model.state_dict(), "./st_model.pth")

In [34]:
torch.save(pret_model.state_dict(), "./pret_model.pth")