In [2]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

import os
import argparse

# Per-metric history lists, keyed by metric name.
# NOTE(review): nothing in this cell appends to these lists — confirm whether
# a later cell uses them (e.g. for plotting) or whether they can be removed.
data_legend = dict(loss=[], acc=[])


def parameter_ctr(model):
    """Count, print and return the number of trainable parameters in ``model``.

    Only parameters with ``requires_grad`` set are counted; frozen ones are
    skipped. The total is printed as ``Parameters:  <n>``.
    """
    trainable_total = sum(
        weights.numel()
        for _, weights in model.named_parameters()
        if weights.requires_grad
    )
    print("Parameters: ", trainable_total)
    return trainable_total


class BasicBlock(nn.Module):
    """Residual block: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus shortcut.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first 3x3 conv (and of the projection shortcut).

    The shortcut is an identity (empty ``nn.Sequential``) when the stride is 1
    and the channel count is unchanged; otherwise it is a strided 1x1 conv
    followed by BN so the residual matches the main path's shape. ReLU is
    applied after the residual addition.
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        shape_changes = stride != 1 or in_planes != planes
        if shape_changes:
            # Project the input so it can be added to the main path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + residual)

# ResNet
class ResNet(nn.Module):
    """ResNet with a 3x3 conv stem followed by ``N`` stages of residual blocks.

    Args:
        P: kernel size of the average pool applied before the classifier
            (``F.avg_pool2d`` uses the kernel size as the stride by default).
        N: number of stages; the stride table below holds 4 entries, so at
            most 4 stages are supported.
        _C: channel width of the stem and the first stage; every later stage
            doubles the width.
        block: callable ``block(in_planes, planes, stride)`` returning an
            ``nn.Module``.
        num_blocks: blocks per stage (indexed 0..N-1).
        num_classes: output size of the final linear layer.
    """

    def __init__(self, P, N, _C, block, num_blocks, num_classes=10):
        super().__init__()
        # Stage 0 keeps the spatial size; each later stage downsamples by 2.
        stage_strides = [1, 2, 2, 2]
        self.in_planes = _C
        self.N = N
        self.P = P
        self.conv1 = nn.Conv2d(3, _C, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(_C)
        self.layers1 = nn.ModuleList()

        width = _C
        for stage in range(N):
            if stage:
                width *= 2
            self.layers1.append(
                self._make_layer(block, width, num_blocks[stage], stage_strides[stage])
            )
        self.linear = nn.Linear(width, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest stride 1.

        Note: ``num_blocks <= 0`` still yields a single block, because the
        stride list always contains the leading ``stride`` entry.
        Updates ``self.in_planes`` to ``planes`` for the next stage.
        """
        block_strides = [stride, *[1] * (num_blocks - 1)]
        stage = []
        for block_stride in block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes
        return nn.Sequential(*stage)

    def forward(self, x):
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in self.layers1:
            features = stage(features)
        pooled = F.avg_pool2d(features, self.P)
        return self.linear(torch.flatten(pooled, 1))


def device_info():
    """Return ``'cuda'`` if PyTorch can see a CUDA device, otherwise ``'cpu'``."""
    return 'cuda' if torch.cuda.is_available() else 'cpu'


# Target device string for the model and every batch moved with `.to(device)`.
device = device_info()


def project1_model(P, N, _C, block, num_blocks, num_classes=10):
    """Build the project's ResNet.

    Args:
        P: average-pool kernel size applied before the classifier.
        N: number of residual stages.
        _C: channel width of the first stage (doubled for each later stage).
        block: factory ``block(in_planes, planes, stride)`` for one block.
        num_blocks: number of blocks in each stage.
        num_classes: classifier output size. Defaults to 10, the same default
            ``ResNet`` uses, so existing calls behave as before.

    Returns:
        A freshly initialised ``ResNet`` instance.
    """
    return ResNet(P, N, _C, block, num_blocks, num_classes=num_classes)

# Run configuration: "lr" is the Adam learning rate; "resume" reloads the
# checkpoint at CHECKPOINT_PATH before training.
args = {"lr": 0.001, "resume": False}

# Training only runs when the model has fewer trainable parameters than this.
MAX_PARAMETERS = 5000000
# Epochs trained per run, counted from start_epoch.
NUM_EPOCHS = 50
CHECKPOINT_DIR = 'checkpoint'
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, 'checkpoint.pth')

if torch.cuda.is_available():
    print("\n====== RUNNING ON GPU ======\n")
else:
    print("\n====== RUNNING ON CPU ======\n")

top_acc = 0
start_epoch = 0

# Data
print('Augmenting Data...')
# Pad by 4 px *before* the random 32x32 crop (padding=4) so training images
# stay 32x32 like the test images. Cropping first and padding afterwards
# would always feed 40x40 images with a fixed black border.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

transform_test = transforms.ToTensor()

train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=128, shuffle=True, num_workers=2)

test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=100, shuffle=False, num_workers=2)


classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('\nBuilding model...')

net = project1_model(3, 4, 32, BasicBlock, [4, 4, 4, 3])
total_parameters = parameter_ctr(net)

if total_parameters < MAX_PARAMETERS:
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    # Loading Checkpoint
    if args['resume']:
        print('Loading from checkpoint...')
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.isfile(CHECKPOINT_PATH):
            raise FileNotFoundError('Checkpoint Empty: ' + CHECKPOINT_PATH + ' not found')
        # map_location lets a checkpoint saved on GPU be resumed on CPU.
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
        # NOTE(review): the state dict is saved from `net` as wrapped above, so a
        # checkpoint written under DataParallel ("module."-prefixed keys) will
        # not load into an unwrapped CPU model, and vice versa.
        net.load_state_dict(checkpoint['net'])
        top_acc = checkpoint['acc']
        # The stored epoch already finished; continue with the next one.
        start_epoch = checkpoint['epoch'] + 1

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=args['lr'])
    # NOTE(review): T_max=200 is longer than NUM_EPOCHS (50), so one run covers
    # only the start of the cosine cycle. Optimizer and scheduler state are
    # not checkpointed either, so a resumed run restarts both.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

    # Training
    def train(epoch):
        """Train `net` for one epoch over train_loader.

        Prints the running mean loss and accuracy every 100 batches.
        """
        print('\nEpoch: ', epoch)
        net.train()
        train_loss = 0
        correct = 0
        total = 0

        for batch_no, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if batch_no % 100 == 0:
                print('Training batch ', batch_no, ' of ', len(train_loader), ' | LOSS: ', train_loss / (batch_no + 1), ' | ACCURACY: ', 100 * correct / total, '%')

    def test(epoch):
        """Evaluate `net` on test_loader without gradients.

        Prints running loss/accuracy every 20 batches. When the epoch's
        accuracy beats the global `top_acc`, it saves `net`'s state dict, the
        accuracy and the epoch to CHECKPOINT_PATH, then updates `top_acc`.
        """
        global top_acc
        net.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_no, (inputs, targets) in enumerate(test_loader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                if batch_no % 20 == 0:
                    print('Testing batch ', batch_no, ' of ', len(test_loader), ' | LOSS: ', test_loss / (batch_no + 1), ' | ACCURACY: ', 100 * correct / total, '%')

        # Save checkpoint.
        curr_acc = 100 * correct / total
        if curr_acc > top_acc:
            print('Saving Checkpoint...')
            state = {
                'net': net.state_dict(),
                'acc': curr_acc,
                'epoch': epoch,
            }
            os.makedirs(CHECKPOINT_DIR, exist_ok=True)
            torch.save(state, CHECKPOINT_PATH)
            top_acc = curr_acc

    for epoch in range(start_epoch, start_epoch + NUM_EPOCHS):
        train(epoch)
        test(epoch)
        scheduler.step()
else:
    # Make the skipped run visible instead of silently printing accuracy 0.
    print('Skipping training: model has', total_parameters,
          'trainable parameters; the budget is below', MAX_PARAMETERS)

print('Highest Accuracy: ', top_acc)



Augmenting Data...
Files already downloaded and verified
Files already downloaded and verified

Building model...
Parameters:  4754218
Highest Accuracy:  0
