<a href="https://colab.research.google.com/github/pulkitdixit/cnn_ResNet_CIFAR100/blob/master/CIFAR100_Pretrained_Pytorch_GPU_GoogleColab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Mount the user's Google Drive inside the Colab VM at /content/gdrive.
# NOTE(review): the dataset `root` defined below is a relative path
# ('gdrive/My Drive/...'), so it presumably relies on the notebook's working
# directory being /content -- confirm before running elsewhere.
from google.colab import drive
drive.mount('/content/gdrive')

In [0]:
import torch
import torch.nn as nn
from torch.utils import model_zoo
from torch.hub import load_state_dict_from_url
import torchvision
from torchvision import models
import torchvision.transforms as transforms

In [0]:
# Run configuration shared by the cells below.
# Directory handed to the CIFAR100 dataset constructors (loaded with download=False).
root = 'gdrive/My Drive/Google Colab/'
# Mini-batch size used by both the train and test DataLoaders.
batch_size = 100
# Initial learning rate passed to the RMSprop optimizer.
learn_rate = 0.001
# StepLR settings: multiply the learning rate by `scheduler_gamma`
# every `scheduler_step_size` scheduler steps.
scheduler_step_size = 5
scheduler_gamma = 0.5
# Number of passes of the train/evaluate loop.
num_epochs = 50

In [0]:
# Per-sample preprocessing: two random augmentations followed by conversion
# to a tensor, applied in the order listed.
_train_steps = [
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
]
transform_train = transforms.Compose(_train_steps)

In [0]:
# CIFAR-100 splits read from `root` (download=False: the files must already be there).
# The training split keeps the random augmentation pipeline.
train_dataset = torchvision.datasets.CIFAR100(root=root, train=True, transform=transform_train, download=False)

# The test split must be evaluated deterministically: applying the random
# rotation/flip from `transform_train` here would make the reported test
# accuracy noisy and different on every run. Only the tensor conversion is kept.
transform_test = transforms.ToTensor()
test_dataset = torchvision.datasets.CIFAR100(root=root, train=False, transform=transform_test, download=False)

In [0]:
# Batch iterators over the two splits. Both share batch size and worker
# count; only the training loader reshuffles its samples.
_loader_kwargs = dict(batch_size=batch_size, num_workers=8)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [0]:
def upsample(x):
    """Enlarge a batch of images 7x in height and width with bilinear interpolation.

    Args:
        x: 4-D tensor laid out as (N, C, H, W), which is what bilinear mode requires.

    Returns:
        A tensor of shape (N, C, 7*H, 7*W) on the same device as `x`.
    """
    # Call the functional form directly instead of constructing a fresh
    # nn.Upsample module for every batch. align_corners=False is the value the
    # unset default resolves to, so results are unchanged; spelling it out
    # avoids the "default upsampling behavior" warning that older PyTorch
    # releases emit when it is omitted.
    return nn.functional.interpolate(
        x, scale_factor=7, mode='bilinear', align_corners=False)

In [0]:
def resnet18(pretrained=True, progress=True, layers=(2, 4, 4, 2)):
    """Build a BasicBlock ResNet and optionally seed it from the ResNet-18 checkpoint.

    Args:
        pretrained: If true, download the torchvision ResNet-18 checkpoint and
            copy every parameter whose name matches into the new model.
        progress: Forwarded to `load_state_dict_from_url` (download progress bar).
        layers: Number of BasicBlocks in each of the four stages. The default
            (2, 4, 4, 2) is deeper than the checkpoint's architecture, so the
            additional blocks have no counterpart in the checkpoint and keep
            their random initialisation. Pass (2, 2, 2, 2) for an architecture
            that matches the checkpoint exactly.

    Returns:
        The constructed `torchvision.models.resnet.ResNet` instance.
    """
    import warnings

    model = torchvision.models.resnet.ResNet(
        torchvision.models.resnet.BasicBlock, list(layers))
    if pretrained:
        state_dict = load_state_dict_from_url(
            'https://download.pytorch.org/models/resnet18-5c106cde.pth',
            progress=progress)
        # strict=False is required because the model may contain blocks that
        # the checkpoint lacks. Report the mismatch instead of hiding it, so it
        # is visible which parameters did NOT receive pretrained values.
        result = model.load_state_dict(state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                'Pretrained weights only partially loaded: '
                '{} model parameters/buffers have no checkpoint entry and stay '
                'randomly initialised; {} checkpoint entries were unused.'.format(
                    len(result.missing_keys), len(result.unexpected_keys)))
    return model

In [0]:
# Build the network, seeded with the pretrained checkpoint where names match.
model = resnet18(pretrained = True)

# torchvision's ResNet constructor defaults to a 1000-way classifier, but
# CIFAR-100 has 100 classes. Replace the head so the output size matches the
# label space; the new layer starts from random weights and is trained here.
# This must happen before the model is moved to the GPU and before the
# optimizer captures model.parameters().
num_classes = 100
model.fc = nn.Linear(model.fc.in_features, num_classes)

#Enable GPU:
use_cuda = True
if use_cuda and torch.cuda.is_available():
    model.cuda()

# Cross-entropy on the raw logits, RMSprop updates, and a step schedule that
# scales the learning rate by `scheduler_gamma` every `scheduler_step_size` steps.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(model.parameters(),
                                lr = learn_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size = scheduler_step_size,
                                            gamma = scheduler_gamma)

In [0]:
# Per-epoch accuracy history (fractions in [0, 1]).
train_acc_list = []
test_acc_list = []

# Resolve the target device once. The condition mirrors the one used when the
# model was (optionally) moved to the GPU, so inputs always follow the model.
device = torch.device('cuda' if use_cuda and torch.cuda.is_available() else 'cpu')

for epochs in range(num_epochs):
    # ---- training pass ----
    model.train()
    correct = 0
    total = 0
    print('Current epoch: \t\t', epochs+1, '/', num_epochs)
    for images, labels in train_loader:
        # Upsample on every device: previously this only happened inside the
        # CUDA branch, so a CPU run silently fed un-resized images to the model.
        images = upsample(images.to(device))
        labels = labels.to(device)

        outputs = model(images)
        loss = loss_fn(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running accuracy over the epoch; detach so no graph is kept alive.
        _, predicted = torch.max(outputs.detach(), 1)
        total = total + labels.size(0)
        correct = correct + (predicted == labels).sum().item()

    # Advance the LR schedule only after this epoch's optimizer steps. Calling
    # it at the top of the epoch (before any optimizer.step()) shifts the
    # schedule by one epoch on PyTorch >= 1.1 and triggers a warning.
    scheduler.step()

    train_acc = correct/total
    print('Training accuracy: \t', train_acc)
    train_acc_list.append(train_acc)

    # ---- evaluation pass ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = upsample(images.to(device))
            labels = labels.to(device)

            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total = total + labels.size(0)
            correct = correct + (predicted == labels).sum().item()
    test_acc = correct/total
    print('Test Accuracy: \t\t', test_acc)
    print('**************************************************')
    test_acc_list.append(test_acc)