In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchsummary import summary
from imgaug import augmenters as iaa
import numpy as np
import numpy as np
from imgaug import augmenters as iaa

In [None]:
# Print the GPU(s) visible to this runtime (IPython shell escape).
# NOTE(review): the absolute /opt/bin path is environment-specific (presumably a
# hosted Colab-style runtime) — TODO confirm; plain `!nvidia-smi` would rely on PATH.
!/opt/bin/nvidia-smi

In [None]:
# Architecture hyperparameters.
#   res_block_nums        -> used by project1_model() as ResNet's num_blocks
#   channel_nums          -> read by ResNet for the width of each residual stage
#   avg_pool_kernel_size  -> read by ResNet for the final average pool
# Conv kernel sizes are fixed inside BasicBlock (3x3 convs, 1x1 projection shortcut).
# These must match the architecture model_state.pt was saved with, because
# load_state_dict is strict by default and rejects mismatched keys/shapes.

# Number of BasicBlocks in each of the four residual stages (layer1..layer4).
res_block_nums = [2, 2, 4, 1]
# Output channels of each of the four residual stages.
channel_nums = [48, 96, 192, 384]
# Kernel size of the final average pool. For 32x32 inputs the last stage emits
# 4x4 feature maps (three stride-2 stages), so 4 pools them down to 1x1.
avg_pool_kernel_size = 4

# ResNet builds exactly four stages and indexes both lists by stage, so a length
# mismatch would otherwise surface later as an IndexError or be silently ignored.
assert len(res_block_nums) == len(channel_nums) == 4, \
    "res_block_nums and channel_nums must each have one entry per stage (4)"

In [None]:
class BasicBlock(nn.Module):
    """Two-convolution residual block.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first 3x3 conv and of the projection shortcut.

    When the stride is not 1 or the channel count changes, the shortcut is a
    1x1 conv + BatchNorm projection so both branches have matching shapes;
    otherwise it is an empty ``nn.Sequential`` (identity).
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        # Main branch. Attribute names are state_dict keys, so they must stay
        # unchanged for saved checkpoints to load.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Shortcut branch: project only when the shapes of the two branches differ.
        shape_changes = stride != 1 or in_planes != planes
        if shape_changes:
            projection = [
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            ]
        else:
            projection = []
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        residual = self.shortcut(x)
        return F.relu(hidden + residual)



class ResNet(nn.Module):
    """CIFAR-style ResNet: 3x3 stem, four residual stages, average pool, linear head.

    Args:
        block: residual block class, called as ``block(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: number of output logits.
        channels: output channels of each of the four stages. Defaults to the
            notebook-level ``channel_nums`` (resolved once, at construction).
        pool_kernel_size: kernel size of the final ``F.avg_pool2d``. Defaults to
            the notebook-level ``avg_pool_kernel_size`` (resolved once, at construction).
        stem_planes: output channels of the 3x3 stem conv. Defaults to 64, the
            original hard-coded width, so existing checkpoints still load.

    Raises:
        ValueError: if ``num_blocks`` or ``channels`` does not have exactly 4 entries.

    The linear head has ``channels[3]`` inputs, so the pooled feature map must be
    1x1. Stages 2-4 use stride 2; assuming each block downsamples by its stride
    (true of BasicBlock), 32x32 inputs reach the pool as 4x4 maps.
    """

    def __init__(self, block, num_blocks, num_classes=10, channels=None,
                 pool_kernel_size=None, stem_planes=64):
        super(ResNet, self).__init__()
        # Defaults come from the notebook config, but are resolved here once so the
        # built model no longer reads module-level globals inside forward().
        if channels is None:
            channels = channel_nums
        if pool_kernel_size is None:
            pool_kernel_size = avg_pool_kernel_size
        # Exactly four stages are built below and both lists are indexed by stage.
        if len(num_blocks) != 4 or len(channels) != 4:
            raise ValueError(
                "ResNet expects 4 stages; got len(num_blocks)=%d, len(channels)=%d"
                % (len(num_blocks), len(channels)))
        self.pool_kernel_size = pool_kernel_size
        # Running input width; _make_layer reads it and advances it per stage.
        self.in_planes = stem_planes

        # Attribute names below are state_dict keys; keep them stable so saved
        # checkpoints continue to load.
        self.conv1 = nn.Conv2d(3, stem_planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(stem_planes)
        self.layer1 = self._make_layer(block, channels[0], num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, channels[1], num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, channels[2], num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, channels[3], num_blocks[3], stride=2)
        self.linear = nn.Linear(channels[3], num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest use stride 1.

        Side effect: sets ``self.in_planes`` to ``planes`` so the next stage's
        first block receives the correct input width.
        """
        block_strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in block_strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, self.pool_kernel_size)
        # (N, C, 1, 1) -> (N, C) for the linear head.
        out = torch.flatten(out, 1)
        return self.linear(out)

def project1_model_reference():
    """Return the reference layout: 3, 4, 5 and 3 BasicBlocks in the four stages.

    All other ResNet arguments are left at their defaults (e.g. num_classes=10).
    """
    reference_depths = [3, 4, 5, 3]
    return ResNet(block=BasicBlock, num_blocks=reference_depths)

def project1_model():
    """Return the project ResNet, with stage depths taken from ``res_block_nums``."""
    stage_depths = res_block_nums
    return ResNet(block=BasicBlock, num_blocks=stage_depths)

In [None]:
# Build the model and restore the trained weights.
# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_STATE_PATH = 'model_state.pt'

resnet = project1_model().to(device)

# map_location places every saved tensor on `device`, so a checkpoint written on
# a GPU can still be read on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load(MODEL_STATE_PATH, map_location=device)
resnet.load_state_dict(state_dict)

In [None]:
# List every parameter/buffer of the loaded model together with its shape.
# state_dict() constructs a new mapping on each call, so fetch it once and
# iterate its items instead of calling it again for every key.
for param_name, param_value in resnet.state_dict().items():
    print(param_name, "\t", param_value.size())

## Evaluation only: measure the loaded model's accuracy on the CIFAR-10 test set

In [None]:

# *** PERFORM NORMALIZATION BEFORE TESTING ***
# NOTE(review): these mean/std values must match the normalization used when
# model_state.pt was trained (training code not visible here) — TODO confirm.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

BATCH_SIZE = 256

test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testDataLoader = torch.utils.data.DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)

# Evaluate on whatever device the model's weights live on instead of hard-coding
# .cuda(), so inputs and weights can never end up on different devices.
eval_device = next(resnet.parameters()).device
resnet.eval()  # BatchNorm uses its running statistics in eval mode

corrects = torch.zeros((), dtype=torch.long, device=eval_device)
# One no_grad context for the whole pass: no autograd graph is recorded.
with torch.no_grad():
    for inputs, labels in testDataLoader:
        inputs = inputs.to(eval_device)
        labels = labels.to(eval_device)
        preds = resnet(inputs).argmax(dim=1)
        corrects += (preds == labels).sum()

n_test = len(testDataLoader.dataset)
n_correct = corrects.item()
# Guard the (degenerate) empty-dataset case instead of dividing by zero.
test_accuracy = n_correct / n_test if n_test else float('nan')
print(f"Test accuracy: {test_accuracy:.4f} ({n_correct}/{n_test})")