<a href="https://colab.research.google.com/github/tykwak-deepbio/colab-experiment/blob/master/image_classification_test_CIFAR_10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import datetime
import numpy as np
import math

In [0]:
class SE2d(nn.Module):
    def __init__(self, in_channel, intra_channel):
        super(SE2d, self).__init__()
        self.conv1 = nn.Conv2d(in_channel, intra_channel, 1)
        self.conv2 = nn.Conv2d(intra_channel, in_channel, 1)

    def forward(self, x):
        a = F.avg_pool2d(x, x.size()[2:])
        a = F.relu(self.conv1(a))
        a = torch.sigmoid(self.conv2(a))
        x = x * a.expand_as(x)
        return x

class AllConvNet(nn.Module):
    def __init__(self):
        super(AllConvNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 5, padding=2)
        self.bn1   = nn.BatchNorm2d(64)
        self.se1   = SE2d(64, 8)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1, stride=2)
        self.drop2 = nn.Dropout2d(p=0.5)
        self.conv3 = nn.Conv2d(64, 128, 5, padding=2)
        self.bn3   = nn.BatchNorm2d(128)
        self.se3   = SE2d(128, 16)
        self.conv4 = nn.Conv2d(128, 128, 3, padding=1, stride=2)
        self.drop4 = nn.Dropout2d(p=0.5)
        self.conv5 = nn.Conv2d(128, 256, 5, padding=2)
        self.bn5   = nn.BatchNorm2d(256)
        self.se5   = SE2d(256, 32)
        self.conv6 = nn.Conv2d(256, 10, 1)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.se1(x)
        x = F.relu(self.conv2(x))
        x = self.drop2(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.se3(x)
        x = F.relu(self.conv4(x))
        x = self.drop4(x)
        x = F.relu(self.bn5(self.conv5(x)))
        x = self.se5(x)
        x = F.relu(self.conv6(x))
        x = F.avg_pool2d(x, x.size()[2:]).squeeze(-1).squeeze(-1)
        return x

class MeanTeacher(nn.Module):
    def __init__(self, model, var1=0.0001, var2=0.0001, weight=0.9):
        assert var1 >= 0
        assert var2 >= 0
        assert weight > 0 and weight < 1
        super(MeanTeacher, self).__init__()
        self.student = model
        self.teacher = (model.__class__)()
        self.stdev1 = math.sqrt(var1)
        self.stdev2 = math.sqrt(var2)
        self.weight = weight
        # copy student to teacher
        dict_student = self.student.state_dict()
        dict_teacher = self.teacher.state_dict()
        for key in dict_teacher:
            dict_teacher[key] = dict_student[key] 
        self.teacher.load_state_dict(dict_teacher)

    def forward(self, x):
        if self.training:
            self.teacher.requires_grad = False
            size = x.size()
            x1 = x + torch.randn(size, device=x.device) * self.stdev1    # TODO: check model device
            x2 = x + torch.randn(size, device=x.device) * self.stdev2    # TODO: check model device
            y1 = self.student(x1)
            y2 = self.teacher(x2)
            cl = (y2 - y1).pow(2).sum()
            return y1, cl
        else:
            y = self.teacher(x)
            return y, 0

    def update(self):
        if self.training:
            dict_student = self.student.state_dict()
            dict_teacher = self.teacher.state_dict()
            for key in dict_teacher:
                dict_teacher[key] = dict_teacher[key] * self.weight + dict_student[key] * (1 - self.weight)
            self.teacher.load_state_dict(dict_teacher)

net = MeanTeacher(AllConvNet(), var1=0.0001, var2=0.0001, weight=0.99).cuda()


In [0]:
mean = (0.4913997551666284, 0.48215855929893703, 0.4465309133731618)
std = (0.24703225141799082, 0.24348516474564, 0.26158783926049628)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2, pin_memory=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4, nesterov=True)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150, 200], gamma=0.1)


  0%|          | 0/170498071 [00:00<?, ?it/s]

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


170500096it [00:01, 87701024.05it/s]                               


Files already downloaded and verified


In [0]:
timestamp = str(datetime.datetime.now()).split('.')
print('(%s) Training Started' % timestamp[0])

def wl(epoch):
    if epoch < 10:
        return 0.01
    elif epoch < 20:
        return 0.02
    elif epoch < 30:
        return 0.04
    elif epoch < 40:
        return 0.08
    elif epoch < 50:
        return 0.16
    elif epoch < 60:
        return 0.32
    elif epoch < 70:
        return 0.64
    else:
        return 1.0

for epoch in range(250):
    scheduler.step()

    running_loss = 0.0
    total = 0
    net.train()
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()
        optimizer.zero_grad()
        outputs, loss2 = net(inputs)
        if i % 10 < 1:
            loss = criterion(outputs, labels) + wl(epoch) * loss2
        else:
            loss = wl(epoch) * loss2
        loss.backward()
        optimizer.step()
        net.update()
        running_loss += loss.item()
        total += labels.size(0)
    timestamp = str(datetime.datetime.now()).split('.')
    print('(%s) [%03d] {Tr} loss: %.6f' % (timestamp[0], epoch + 1, 4 * running_loss / total), end='')

    running_loss = 0.0
    correct = 0
    total = 0
    net.eval()
    for data in testloader:
        images, labels = data
        images = images.cuda()
        labels = labels.cuda()
        outputs, loss2 = net(images)
        loss = criterion(outputs, labels)
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(' {Va} loss: %.6f acc: %.2f%%' % (128 * running_loss / total, 100.0 * correct / total))

timestamp = str(datetime.datetime.now()).split('.')
print('(%s) Training Finished' % timestamp[0])



(2019-06-14 02:53:37) Training Started
(2019-06-14 02:57:14) [001] {Tr} loss: 0.229345 {Va} loss: 2.255415 acc: 17.23%
(2019-06-14 03:00:56) [002] {Tr} loss: 0.226892 {Va} loss: 2.232613 acc: 19.15%
(2019-06-14 03:04:38) [003] {Tr} loss: 0.226528 {Va} loss: 2.233479 acc: 19.41%
(2019-06-14 03:08:20) [004] {Tr} loss: 0.226083 {Va} loss: 2.225054 acc: 20.01%
(2019-06-14 03:12:01) [005] {Tr} loss: 0.225124 {Va} loss: 2.208807 acc: 20.11%
(2019-06-14 03:15:43) [006] {Tr} loss: 0.225321 {Va} loss: 2.202458 acc: 22.45%
(2019-06-14 03:19:26) [007] {Tr} loss: 0.224510 {Va} loss: 2.190380 acc: 22.62%
(2019-06-14 03:23:08) [008] {Tr} loss: 0.223842 {Va} loss: 2.163603 acc: 23.59%
(2019-06-14 03:26:50) [009] {Tr} loss: 0.222468 {Va} loss: 2.146799 acc: 23.03%
(2019-06-14 03:30:32) [010] {Tr} loss: 0.222092 {Va} loss: 2.142327 acc: 23.80%
(2019-06-14 03:34:15) [011] {Tr} loss: 0.222109 {Va} loss: 2.153796 acc: 24.18%
(2019-06-14 03:37:58) [012] {Tr} loss: 0.222412 {Va} loss: 2.149233 acc: 23.33%
(