In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline

In [2]:
# ---- experiment configuration ----

# Reproducibility / backend
seed = 1
torch.backends.cudnn.benchmark = True   # enable cuDNN benchmark mode

# Data loading
batch_size, test_batch_size = 800, 300

# SGD hyper-parameters
lr, momentum, weight_decay = 0.01, 0.9, 5e-4

# Training schedule and logging
epochs = 200
log_interval = 16       # the training loop prints the loss every `log_interval` batches
save_model = False      # when True, the weights are saved once training has finished

# Loss function shared by the training and evaluation loops
criterion = nn.CrossEntropyLoss()

In [3]:
# ResNet inner blocks
class planeBlock(nn.Module):
    """Pre-activation residual block with a dropout between its two 3x3 convs.

    Main path: BN -> ReLU -> conv3x3 -> BN -> ReLU -> Dropout2d(0.5) -> conv3x3.
    The result is added to a shortcut of the block input (no ReLU after the sum).

    Args:
        inplane: number of channels of the input feature map.
        outplane: number of channels produced by both convolutions.
        stride: stride of the first convolution and of the projection shortcut.
        padding: padding applied by both 3x3 convolutions.
        downsample: shortcut selector, only consulted when the input and
            output shapes already match (``stride == 1`` and
            ``inplane == outplane``):

            * ``False`` (default) / ``None`` -> plain identity shortcut,
            * ``True`` -> force a 1x1-conv + BatchNorm projection,
            * an ``nn.Module`` -> used as the shortcut as-is.

            When the shapes differ a 1x1-conv + BatchNorm projection is
            always built, regardless of this argument.
    """

    def __init__(self, inplane, outplane, stride=1, padding=1, downsample=False):
        super(planeBlock, self).__init__()
        # NOTE: sub-modules are created in a fixed order so that state_dict keys
        # and the random initialisation sequence stay stable.
        self.conv1 = nn.Conv2d(inplane, outplane, kernel_size=3, stride=stride, padding=padding, bias=True)
        # bn1 normalises the block *input* (pre-activation), hence `inplane` channels.
        self.bn1 = nn.BatchNorm2d(inplane)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout2d(0.5)
        self.conv2 = nn.Conv2d(outplane, outplane, kernel_size=3, stride=1, padding=padding, bias=True)
        self.bn2 = nn.BatchNorm2d(outplane)

        # Shortcut: project whenever the main path changes the tensor shape.
        shape_changes = stride != 1 or inplane != outplane
        if not shape_changes and isinstance(downsample, nn.Module):
            # Caller supplied its own shortcut module.
            self.downsample = downsample
        elif shape_changes or downsample:
            # A truthy non-module flag used to be stored as-is and then *called*
            # in forward(), which raised "'bool' object is not callable".
            self.downsample = nn.Sequential(
                nn.Conv2d(inplane, outplane, kernel_size=1, stride=stride, bias=True),
                nn.BatchNorm2d(outplane),
            )
        else:
            self.downsample = None  # identity shortcut

        # Construction trace: one "<in> <out>" line per block.
        print(inplane, outplane,)

    def forward(self, x):
        """Apply the residual block to ``x`` and return the summed feature map."""
        identity = x  # input shortcut

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.conv2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        return out

In [4]:
class resNet(nn.Module):
    """ResNet-style classifier assembled from ``planeBlock`` residual blocks.

    Layout: a 3x3 conv stem (3 -> 64 channels) with BatchNorm, ReLU, a
    stride-2 max-pool and Dropout2d(0.3); 17 ``planeBlock`` instances in four
    stages (64, 128, 256 and 512 channels); global average pooling; and a
    two-layer classifier head (512 -> 256 -> ``num_classes``).

    Args:
        num_classes: size of the output layer. When omitted, the
            notebook-level global ``num_classes`` is used, which keeps the
            original ``resNet()`` call working.
    """

    # (output channels, number of blocks, stride of the first block) per stage.
    _STAGES = ((64, 3, 1), (128, 5, 2), (256, 6, 2), (512, 3, 2))

    def __init__(self, num_classes=None):
        super(resNet, self).__init__()

        if num_classes is None:
            # Backward-compatible fallback: read the notebook-level global that
            # the dataset cell defines.
            try:
                num_classes = globals()["num_classes"]
            except KeyError:
                raise NameError(
                    "num_classes is not defined; pass resNet(num_classes=...) "
                    "or run the dataset cell first"
                ) from None

        # top convolution
        self.base = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Dropout2d(0.3),
        )

        # inner blocks: only the first block of a stage changes width / stride
        layers = []
        inplane = 64
        for outplane, num_blocks, first_stride in self._STAGES:
            for block_idx in range(num_blocks):
                stride = first_stride if block_idx == 0 else 1
                layers.append(planeBlock(inplane, outplane, stride=stride))
                inplane = outplane
        self.planes = nn.Sequential(*layers)

        # pooling layer: collapse each feature map to a single value
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # output side
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return the classifier outputs (one row of ``num_classes`` scores per sample)."""
        x = self.base(x)
        x = self.planes(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # (batch, 512, 1, 1) -> (batch, 512)
        x = self.classifier(x)
        return x


In [5]:
def train(model, device, train_loader, optimizer, log_interval, epoch, epochs, loss_fn=None):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: network to optimise; it is switched to training mode.
        device: device each batch is moved to before the forward pass.
        train_loader: iterable of ``(data, target)`` batches that also exposes
            ``__len__`` and a ``dataset`` attribute (used for progress output).
        optimizer: optimiser stepping the model parameters.
        log_interval: print the loss every ``log_interval`` batches; ``0`` or
            ``None`` disables the progress output.
        epoch: index of the current epoch (only used in the log line).
        epochs: total number of epochs (only used in the log line).
        loss_fn: loss callable ``loss_fn(output, target)``. Defaults to the
            notebook-level ``criterion`` when omitted.
    """
    if loss_fn is None:
        loss_fn = criterion  # notebook-level default loss

    model.train()
    num_batches = len(train_loader)
    num_samples = len(train_loader.dataset)

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

        # progress display for training (guarded so log_interval=0 does not divide by zero)
        if log_interval and batch_idx % log_interval == 0:
            print('Train Epoch: {}/{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, epochs, batch_idx * len(data), num_samples,
                100. * batch_idx / num_batches, loss.item()))


In [6]:
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).sum().item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    # proccessing display for test   
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [7]:
# use device to Cuda or CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# extra DataLoader options are only passed when a GPU is available
kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}

torch.manual_seed(seed)


# get datasets (both splits share the same preprocessing pipeline)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_dataset = datasets.CIFAR100('./data_c100', train=True, download=True,
                                  transform=transform)

test_dataset = datasets.CIFAR100('./data_c100', train=False, download=True,
                                 transform=transform)


# classes count
# Read the labels straight from the dataset object instead of indexing every
# sample, which would decode and transform all training images just to
# discard them. Fall back to the full scan if `targets` is unavailable.
labels = getattr(train_dataset, 'targets', None)
if labels is None:
    labels = [sample[1] for sample in train_dataset]
labels = [int(label) for label in labels]
num_classes = max(labels) + 1

# show datasets shape
image = train_dataset[0]
print(image[0].size())
print('Labels:', num_classes)

Files already downloaded and verified
Files already downloaded and verified
torch.Size([3, 32, 32])
Labels: 100


In [8]:
# wrap both splits in DataLoaders (training data is reshuffled, test data is not)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, **kwargs
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=test_batch_size, shuffle=False, **kwargs
)

# sanity check: fetch a single training batch and show its tensor shapes
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    images, labels = first_batch
    print(images.size())
    print(images[0].size())
    print(labels.size())

torch.Size([800, 3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([800])


In [9]:
# instantiate the network on the selected device and attach an SGD optimiser
model = resNet()
model = model.to(device)

optimizer = optim.SGD(
    model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
)

# model run
def run():
    """Train and evaluate the global model for `epochs` epochs.

    Each epoch runs one training pass followed by one evaluation pass. After
    the last epoch the weights are written to "resnet.pt" if `save_model`
    is set.
    """
    for finished in range(epochs):
        current_epoch = finished + 1  # epochs are reported 1-based
        train(model, device, train_loader, optimizer, log_interval, current_epoch, epochs)
        test(model, device, test_loader)

    if not save_model:
        return
    torch.save(model.state_dict(), "resnet.pt")

64 64
64 64
64 64
64 128
128 128
128 128
128 128
128 128
128 256
256 256
256 256
256 256
256 256
256 256
256 512
512 512
512 512


In [10]:
%%time
# Timed cell: report the selected device, then launch the full train/evaluate loop.
print('Use device is', str(device).upper())

run()

Use device is CUDA

Test set: Average loss: 0.0150, Accuracy: 463/10000 (5%)


Test set: Average loss: 0.0143, Accuracy: 778/10000 (8%)


Test set: Average loss: 0.0136, Accuracy: 1078/10000 (11%)


Test set: Average loss: 0.0130, Accuracy: 1293/10000 (13%)


Test set: Average loss: 0.0126, Accuracy: 1452/10000 (15%)


Test set: Average loss: 0.0122, Accuracy: 1668/10000 (17%)


Test set: Average loss: 0.0119, Accuracy: 1793/10000 (18%)


Test set: Average loss: 0.0114, Accuracy: 2054/10000 (21%)


Test set: Average loss: 0.0113, Accuracy: 2083/10000 (21%)


Test set: Average loss: 0.0107, Accuracy: 2367/10000 (24%)


Test set: Average loss: 0.0106, Accuracy: 2483/10000 (25%)


Test set: Average loss: 0.0103, Accuracy: 2612/10000 (26%)


Test set: Average loss: 0.0100, Accuracy: 2777/10000 (28%)


Test set: Average loss: 0.0096, Accuracy: 2956/10000 (30%)


Test set: Average loss: 0.0094, Accuracy: 3050/10000 (30%)


Test set: Average loss: 0.0091, Accuracy: 3223/10000 (32%)


Test set


Test set: Average loss: 0.0074, Accuracy: 4283/10000 (43%)


Test set: Average loss: 0.0075, Accuracy: 4280/10000 (43%)


Test set: Average loss: 0.0072, Accuracy: 4398/10000 (44%)


Test set: Average loss: 0.0073, Accuracy: 4357/10000 (44%)


Test set: Average loss: 0.0072, Accuracy: 4485/10000 (45%)


Test set: Average loss: 0.0072, Accuracy: 4515/10000 (45%)


Test set: Average loss: 0.0070, Accuracy: 4642/10000 (46%)


Test set: Average loss: 0.0071, Accuracy: 4615/10000 (46%)


Test set: Average loss: 0.0072, Accuracy: 4535/10000 (45%)


Test set: Average loss: 0.0071, Accuracy: 4605/10000 (46%)


Test set: Average loss: 0.0071, Accuracy: 4634/10000 (46%)


Test set: Average loss: 0.0071, Accuracy: 4666/10000 (47%)


Test set: Average loss: 0.0071, Accuracy: 4617/10000 (46%)


Test set: Average loss: 0.0070, Accuracy: 4670/10000 (47%)


Test set: Average loss: 0.0072, Accuracy: 4601/10000 (46%)


Test set: Average loss: 0.0072, Accuracy: 4656/10000 (47%)


Test set: Average loss:


Test set: Average loss: 0.0073, Accuracy: 4838/10000 (48%)


Test set: Average loss: 0.0075, Accuracy: 4829/10000 (48%)


Test set: Average loss: 0.0075, Accuracy: 4780/10000 (48%)


Test set: Average loss: 0.0077, Accuracy: 4819/10000 (48%)


Test set: Average loss: 0.0075, Accuracy: 4838/10000 (48%)


Test set: Average loss: 0.0076, Accuracy: 4831/10000 (48%)


Test set: Average loss: 0.0076, Accuracy: 4828/10000 (48%)


Test set: Average loss: 0.0079, Accuracy: 4763/10000 (48%)


Test set: Average loss: 0.0077, Accuracy: 4803/10000 (48%)


Test set: Average loss: 0.0078, Accuracy: 4794/10000 (48%)


Test set: Average loss: 0.0078, Accuracy: 4821/10000 (48%)


Test set: Average loss: 0.0081, Accuracy: 4733/10000 (47%)


Test set: Average loss: 0.0080, Accuracy: 4768/10000 (48%)


Test set: Average loss: 0.0082, Accuracy: 4756/10000 (48%)


Test set: Average loss: 0.0083, Accuracy: 4651/10000 (47%)


Test set: Average loss: 0.0080, Accuracy: 4835/10000 (48%)


Test set: Average loss:


Test set: Average loss: 0.0086, Accuracy: 4866/10000 (49%)


Test set: Average loss: 0.0087, Accuracy: 4847/10000 (48%)


Test set: Average loss: 0.0088, Accuracy: 4843/10000 (48%)


Test set: Average loss: 0.0087, Accuracy: 4827/10000 (48%)


Test set: Average loss: 0.0088, Accuracy: 4857/10000 (49%)


Test set: Average loss: 0.0088, Accuracy: 4850/10000 (48%)


Test set: Average loss: 0.0089, Accuracy: 4818/10000 (48%)


Test set: Average loss: 0.0089, Accuracy: 4837/10000 (48%)


Test set: Average loss: 0.0089, Accuracy: 4822/10000 (48%)


Test set: Average loss: 0.0089, Accuracy: 4838/10000 (48%)


Test set: Average loss: 0.0089, Accuracy: 4829/10000 (48%)


Test set: Average loss: 0.0091, Accuracy: 4815/10000 (48%)


Test set: Average loss: 0.0090, Accuracy: 4857/10000 (49%)


Test set: Average loss: 0.0092, Accuracy: 4780/10000 (48%)


Test set: Average loss: 0.0090, Accuracy: 4857/10000 (49%)


Test set: Average loss: 0.0090, Accuracy: 4884/10000 (49%)


Test set: Average loss:


Test set: Average loss: 0.0094, Accuracy: 4850/10000 (48%)


Test set: Average loss: 0.0094, Accuracy: 4812/10000 (48%)


Test set: Average loss: 0.0093, Accuracy: 4912/10000 (49%)


Test set: Average loss: 0.0097, Accuracy: 4777/10000 (48%)


Test set: Average loss: 0.0094, Accuracy: 4840/10000 (48%)


Test set: Average loss: 0.0093, Accuracy: 4909/10000 (49%)


Test set: Average loss: 0.0093, Accuracy: 4878/10000 (49%)


Test set: Average loss: 0.0094, Accuracy: 4901/10000 (49%)


Test set: Average loss: 0.0094, Accuracy: 4956/10000 (50%)


Test set: Average loss: 0.0096, Accuracy: 4797/10000 (48%)


Test set: Average loss: 0.0095, Accuracy: 4871/10000 (49%)


Test set: Average loss: 0.0095, Accuracy: 4894/10000 (49%)


Test set: Average loss: 0.0094, Accuracy: 4888/10000 (49%)


Test set: Average loss: 0.0094, Accuracy: 4931/10000 (49%)


Test set: Average loss: 0.0095, Accuracy: 4905/10000 (49%)


Test set: Average loss: 0.0094, Accuracy: 4886/10000 (49%)


Test set: Average loss:


Test set: Average loss: 0.0093, Accuracy: 4950/10000 (50%)


Test set: Average loss: 0.0095, Accuracy: 4892/10000 (49%)


Test set: Average loss: 0.0095, Accuracy: 4901/10000 (49%)


Test set: Average loss: 0.0094, Accuracy: 4886/10000 (49%)


Test set: Average loss: 0.0096, Accuracy: 4882/10000 (49%)


Test set: Average loss: 0.0095, Accuracy: 4877/10000 (49%)


Test set: Average loss: 0.0094, Accuracy: 4916/10000 (49%)


Test set: Average loss: 0.0095, Accuracy: 4893/10000 (49%)


Test set: Average loss: 0.0096, Accuracy: 4928/10000 (49%)


Test set: Average loss: 0.0096, Accuracy: 4862/10000 (49%)


Test set: Average loss: 0.0096, Accuracy: 4889/10000 (49%)


Test set: Average loss: 0.0095, Accuracy: 4914/10000 (49%)


Test set: Average loss: 0.0095, Accuracy: 4928/10000 (49%)


Test set: Average loss: 0.0097, Accuracy: 4839/10000 (48%)


Test set: Average loss: 0.0095, Accuracy: 4901/10000 (49%)


Test set: Average loss: 0.0094, Accuracy: 4952/10000 (50%)


Test set: Average loss:


Test set: Average loss: 0.0095, Accuracy: 4921/10000 (49%)


Test set: Average loss: 0.0094, Accuracy: 4954/10000 (50%)


Test set: Average loss: 0.0095, Accuracy: 4870/10000 (49%)


Test set: Average loss: 0.0095, Accuracy: 4927/10000 (49%)


Test set: Average loss: 0.0097, Accuracy: 4862/10000 (49%)


Test set: Average loss: 0.0096, Accuracy: 4929/10000 (49%)


Test set: Average loss: 0.0096, Accuracy: 4907/10000 (49%)


Test set: Average loss: 0.0096, Accuracy: 4889/10000 (49%)


Test set: Average loss: 0.0095, Accuracy: 4928/10000 (49%)


Test set: Average loss: 0.0095, Accuracy: 4913/10000 (49%)


Test set: Average loss: 0.0096, Accuracy: 4897/10000 (49%)


Test set: Average loss: 0.0097, Accuracy: 4822/10000 (48%)


Test set: Average loss: 0.0096, Accuracy: 4900/10000 (49%)


Test set: Average loss: 0.0096, Accuracy: 4887/10000 (49%)


Test set: Average loss: 0.0095, Accuracy: 4920/10000 (49%)


Test set: Average loss: 0.0095, Accuracy: 4919/10000 (49%)


Test set: Average loss: