In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from tqdm.notebook import tqdm, trange

import copy
import random
import time
import os

In [None]:
# constants — everything a reader might tune lives in this cell
ROOT = ".data"  # NOTE(review): not read anywhere visible; Data downloads to "../data"
# Fraction of the CIFAR-10 training set kept for *training*; the remaining
# 1 - VALID_RATIO becomes the validation split (see Data.__init__). The name is
# historical — it is the train fraction, not the validation fraction.
VALID_RATIO = 0.9
BATCH_SIZE = 64
DROPOUT = 0.05
OUTPUT_CLASSES = 10
LEARNING_RATE = 0.01
EPOCHS = 200
SEED = 1234

# Seed every RNG the notebook uses so the train/valid split, lazy weight
# initialisation and dropout masks are reproducible across Restart & Run All.
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [None]:
class Data():
    """Downloads CIFAR-10, normalises it with per-channel training-set
    statistics, and exposes ``train_loader``, ``valid_loader`` and
    ``test_loader``."""

    def __init__(self, valid_ratio, batch_size, root="../data", num_workers=2):
        """Download, split and normalise the data and build the DataLoaders.

        Args:
            valid_ratio: fraction of the CIFAR-10 training set kept for
                *training*; the remaining ``1 - valid_ratio`` becomes the
                validation split. (Despite the name, this is the train
                fraction.) Must lie strictly between 0 and 1.
            batch_size: batch size used by all three loaders.
            root: directory the dataset is downloaded to / read from.
            num_workers: worker processes per DataLoader.

        Raises:
            ValueError: if ``valid_ratio`` is not strictly between 0 and 1.
        """
        # Imported locally on purpose: the notebook binds the global name
        # ``data`` to a ``Data`` instance, which shadows the ``torch.utils.data``
        # alias and would make a second ``Data(...)`` call fail.
        from torch.utils.data import DataLoader, random_split

        if not 0 < valid_ratio < 1:
            raise ValueError(f"valid_ratio must be in (0, 1), got {valid_ratio!r}")

        # Statistics are taken over the full training set, before the split,
        # so the validation images contribute to them as well.
        raw_train = datasets.CIFAR10(root=root, train=True, download=True)
        means, stds = self._channel_stats(raw_train.data)

        normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=means, std=stds),
        ])
        # Train and eval pipelines are currently identical; two names are kept
        # so augmentation can later be added to the training pipeline only.
        train_transforms = normalize
        test_transforms = normalize

        train_full = datasets.CIFAR10(root=root, train=True, download=True,
                                      transform=train_transforms)
        test_data = datasets.CIFAR10(root=root, train=False, download=True,
                                     transform=test_transforms)

        n_train_examples = int(len(train_full) * valid_ratio)
        n_valid_examples = len(train_full) - n_train_examples
        train_data, valid_data = random_split(
            train_full, [n_train_examples, n_valid_examples])

        # Both Subsets wrap the same ``train_full``; deep-copy the validation
        # Subset so assigning its transform does not alter the training one.
        valid_data = copy.deepcopy(valid_data)
        valid_data.dataset.transform = test_transforms

        self.train_loader = DataLoader(train_data, shuffle=True,
                                       batch_size=batch_size,
                                       num_workers=num_workers)
        self.valid_loader = DataLoader(valid_data, batch_size=batch_size,
                                       num_workers=num_workers)
        self.test_loader = DataLoader(test_data, batch_size=batch_size,
                                      num_workers=num_workers)

    @staticmethod
    def _channel_stats(pixels):
        """Per-channel mean and standard deviation, rescaled from 0..255 to 0..1.

        Args:
            pixels: array or tensor whose LAST axis indexes channels, holding
                values on the 0..255 scale.

        Returns:
            ``(means, stds)``: two lists of Python floats, one entry per
            channel. ``stds`` uses ``torch.Tensor.std``'s default (unbiased)
            estimator.
        """
        pixels = torch.as_tensor(pixels)
        means, stds = [], []
        for c in range(pixels.shape[-1]):
            # Convert one channel at a time so only one float copy is alive.
            channel = pixels[..., c].float()
            means.append(channel.mean().item() / 255)
            stds.append(channel.std().item() / 255)
        return means, stds

In [None]:
class Residual(nn.Module):
    """Two-convolution ResNet block with an identity or 1x1-projection shortcut.

    Args:
        num_channels: output channels of both 3x3 convolutions and of the
            optional 1x1 projection.
        use_1x1conv: when True the shortcut is a 1x1 convolution with stride
            ``strides``, so it can match a main path that changes channel
            count or spatial size.
        strides: stride of the first 3x3 convolution (and of the projection).
        dp: dropout probability applied after each ReLU.
    """

    def __init__(self, num_channels, use_1x1conv=False, strides=1, dp=0.1):
        super().__init__()
        # Lazy layers infer their input channel count on the first forward.
        self.conv1 = nn.LazyConv2d(num_channels, kernel_size=3, padding=1,
                                   stride=strides)
        self.conv2 = nn.LazyConv2d(num_channels, kernel_size=3, padding=1)
        self.conv3 = (
            nn.LazyConv2d(num_channels, kernel_size=1, stride=strides)
            if use_1x1conv
            else None
        )
        self.bn1 = nn.LazyBatchNorm2d()
        self.bn2 = nn.LazyBatchNorm2d()
        self.dp = nn.Dropout(dp)

    def forward(self, X):
        # Main path: conv-BN-ReLU-dropout, then conv-BN (no activation yet).
        out = self.conv1(X)
        out = self.dp(F.relu(self.bn1(out)))
        out = self.bn2(self.conv2(out))
        # Shortcut is evaluated after the main path, as before, so lazy
        # layers are materialised in the same order.
        shortcut = X if self.conv3 is None else self.conv3(X)
        return self.dp(F.relu(out + shortcut))

In [None]:
class ResNet(nn.Module):
    """Generic ResNet: stem (``b1``), one residual stage per ``arch`` entry,
    then global average pooling and a linear classifier.

    Args:
        arch: sequence of ``(num_residuals, num_channels)`` tuples, one per stage.
        lr: stored as ``self.lr``; nothing in this class reads it.
        num_classes: output size of the final linear layer.
        dp: dropout probability shared by the stem and every residual block.
    """

    def __init__(self, arch, lr=0.1, num_classes=10, dp=0.1):
        super().__init__()

        # Set before b1()/block() run, because both read self.dp.
        self.dp = dp
        self.lr = lr

        self.net = nn.Sequential(self.b1())
        # Stages are registered as b2, b3, ... (state_dict keys depend on it).
        for offset, stage_spec in enumerate(arch, start=2):
            stage = self.block(*stage_spec, first_block=(offset == 2))
            self.net.add_module(f'b{offset}', stage)

        head = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.LazyLinear(num_classes),
        )
        self.net.add_module('last', head)

    def b1(self):
        """Stem: 7x7/2 conv, BN, ReLU, dropout, 3x3/2 max-pool.

        NOTE(review): this ImageNet-style stem shrinks 32x32 CIFAR images to
        8x8 before the first residual stage.
        """
        layers = [
            nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),
            nn.LazyBatchNorm2d(),
            nn.ReLU(),
            nn.Dropout(self.dp),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        return nn.Sequential(*layers)

    def block(self, num_residuals, num_channels, first_block=False):
        """Build one stage of ``num_residuals`` residual blocks.

        Outside the first stage, the opening block uses stride 2 and a 1x1
        projection shortcut; all other blocks keep the shape unchanged.
        """
        def make_residual(idx):
            if idx == 0 and not first_block:
                return Residual(num_channels, use_1x1conv=True, strides=2,
                                dp=self.dp)
            return Residual(num_channels, dp=self.dp)

        return nn.Sequential(*(make_residual(idx) for idx in range(num_residuals)))

    def forward(self, X):
        return self.net(X)


class ResNet18(ResNet):
    """ResNet-18 layout: four stages of two residual blocks (64/128/256/512 ch)."""

    def __init__(self, lr=0.1, num_classes=10, dp=0.1):
        stages = ((2, 64), (2, 128), (2, 256), (2, 512))
        super().__init__(stages, lr, num_classes, dp)

In [None]:
class Trainer():
    """Runs training and evaluation epochs for a classification model.

    Args:
        model: module returning per-class scores; ``calculate_accuracy``
            takes ``argmax`` over dim 1, so (batch, classes) is assumed.
        data: object with ``train_loader`` and ``valid_loader`` attributes
            that yield ``(x, y)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss, called as ``criterion(y_pred, y)``.
        device: device each batch is moved to.
    """

    def __init__(self, model, data, optimizer, criterion, device):
        self.model = model
        self.data = data
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device

    def calculate_accuracy(self, y_pred, y):
        """Fraction of rows whose arg-max over dim 1 equals ``y`` (0-dim tensor)."""
        top_pred = y_pred.argmax(1, keepdim=True)
        correct = top_pred.eq(y.view_as(top_pred)).sum()
        return correct.float() / y.shape[0]

    def count_parameters(self, model):
        """Number of scalar parameters in ``model`` with ``requires_grad`` set."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    def _run_epoch(self, loader, desc, training):
        """Iterate ``loader`` once, stepping the optimizer when ``training``.

        Returns:
            (mean loss, mean accuracy), each averaged over batches (every
            batch weighted equally, including a smaller final batch).
        """
        epoch_loss = 0.0
        epoch_acc = 0.0

        for x, y in tqdm(loader, desc=desc, leave=False):
            x = x.to(self.device)
            y = y.to(self.device)

            if training:
                self.optimizer.zero_grad()

            y_pred = self.model(x)
            loss = self.criterion(y_pred, y)
            acc = self.calculate_accuracy(y_pred, y)

            if training:
                loss.backward()
                self.optimizer.step()

            epoch_loss += loss.item()
            epoch_acc += acc.item()

        return epoch_loss / len(loader), epoch_acc / len(loader)

    def train(self):
        """Run one training epoch over ``data.train_loader``; return (loss, acc)."""
        self.model.train()
        return self._run_epoch(self.data.train_loader, "Training", training=True)

    def evaluate(self, loader=None):
        """Evaluate the model in eval mode without gradients; return (loss, acc).

        Args:
            loader: batches to evaluate; defaults to ``data.valid_loader``.
                Pass e.g. ``data.test_loader`` for a final test-set score.
        """
        if loader is None:
            loader = self.data.valid_loader

        self.model.eval()
        with torch.no_grad():
            return self._run_epoch(loader, "Evaluating", training=False)

    @staticmethod
    def epoch_time(start_time, end_time):
        """Split the elapsed seconds into whole (minutes, seconds).

        A staticmethod so it works when called through an instance
        (``trainer.epoch_time(a, b)``) as well as through the class.
        """
        elapsed_time = end_time - start_time
        elapsed_mins = int(elapsed_time / 60)
        elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
        return elapsed_mins, elapsed_secs

In [None]:
def init_cnn(module: nn.Module) -> None:
    """Xavier-uniform initialise the weight of plain ``nn.Linear``/``nn.Conv2d`` layers.

    Uses an exact type match, so it must be applied after the dry run below:
    lazy layers only take on their concrete class once their parameters have
    been materialised by a forward pass.
    """
    if type(module) in (nn.Linear, nn.Conv2d):
        nn.init.xavier_uniform_(module.weight)

# NOTE(review): ``data`` rebinds the ``torch.utils.data`` alias imported in the
# first cell; consider a distinct name such as ``cifar``.
data = Data(VALID_RATIO, BATCH_SIZE)
model = ResNet18(lr=LEARNING_RATE, num_classes=OUTPUT_CLASSES, dp=DROPOUT)

criterion = nn.CrossEntropyLoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)
model = model.to(device)
criterion = criterion.to(device)

# Dry run: one real batch materialises every lazy layer's parameters.
# The output is discarded, so no autograd graph is needed.
with torch.no_grad():
    model(next(iter(data.train_loader))[0].to(device))
model.net.apply(init_cnn)

# Create the optimizer only after the parameters exist on their final device
# and are initialised, and give it the configured learning rate.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
trainer = Trainer(model, data, optimizer, criterion, device)

best_valid_loss = float('inf')

for epoch in trange(EPOCHS, desc="Epochs"):

    start_time = time.monotonic()

    train_loss, train_acc = trainer.train()
    valid_loss, valid_acc = trainer.evaluate()

    # Checkpoint whenever the validation loss improves.
    # NOTE(review): the file name says IMAGENET, but this notebook trains on CIFAR-10.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'IMAGENETMODEL.pt')

    # Whole minutes and leftover whole seconds spent on this epoch.
    elapsed = time.monotonic() - start_time
    epoch_mins, epoch_secs = divmod(int(elapsed), 60)

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')