In [71]:
import torch
import torch.nn as nn
import torchvision
# import utils
from typing import List
from torchvision import transforms, models, datasets
import torch
from torch.utils.data import DataLoader, Subset, random_split, ConcatDataset
import numpy as np
import random
from os.path import exists
import os
from typing import List, Dict, Tuple
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import seaborn as sns
from sklearn.utils.multiclass import unique_labels
from sklearn.metrics import accuracy_score, confusion_matrix

In [72]:
use_gpu = True
device = torch.device("cuda:0" if torch.cuda.is_available() and use_gpu else "cpu")
# device = torch.device("cpu")
device

device(type='cuda', index=0)

In [73]:
import torch

# Check if CUDA (GPU support) is available
print("CUDA available: ", torch.cuda.is_available())

# Check if code is running on GPU
print("Code running on GPU: ", torch.cuda.is_initialized())


CUDA available:  True
Code running on GPU:  True


## Data Loader

In [74]:
def get_oxford_splits(
    batch_size: int,
    data_loader_seed: int = 111,
    pin_memory: bool = True,
    num_workers: int = 2,
    ):
    K = 5
    num_support = 80
    num_query = 20

    def seed_worker(worker_id):
        # worker_seed = torch.initial_seed() % 2 ** 32
        np.random.seed(data_loader_seed)
        random.seed(data_loader_seed)
    g = torch.Generator()
    g.manual_seed(data_loader_seed)

    support_classes = list(np.arange(num_support))
    query_classes = list(np.arange(num_query) + num_support)

    img_dim = 64

    train_transforms = transforms.Compose([
        transforms.Resize((img_dim, img_dim)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
    test_transforms = transforms.Compose([
        transforms.Resize((img_dim, img_dim)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
    validation_transforms = transforms.Compose([
        transforms.Resize((img_dim, img_dim)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])

    data_path = f'/content/data'

    train_ds_full = datasets.Flowers102(root=data_path, split="train", download=True, transform=train_transforms)
    val_ds_full = datasets.Flowers102(root=data_path, split="val", download=True, transform=validation_transforms)
    test_ds_full = datasets.Flowers102(root=data_path, split="test", download=True, transform=test_transforms)

    train_indxs_support = torch.where(torch.isin(torch.tensor(train_ds_full._labels), torch.asarray(support_classes)))[0]
    val_indxs_support = torch.where(torch.isin(torch.tensor(val_ds_full._labels), torch.asarray(support_classes)))[0]
    test_indxs_support = torch.where(torch.isin(torch.tensor(test_ds_full._labels), torch.asarray(support_classes)))[0]

    train_ds_subset_support = torch.utils.data.Subset(train_ds_full, train_indxs_support)
    val_ds_subset_support = torch.utils.data.Subset(val_ds_full, val_indxs_support)
    test_ds_subset_support = torch.utils.data.Subset(test_ds_full, test_indxs_support)

    merged_dataset = ConcatDataset([train_ds_subset_support, val_ds_subset_support, test_ds_subset_support])
    ### A, B
    train_ds_support, test_ds_support = torch.utils.data.random_split(merged_dataset, [0.75, 0.25], generator=torch.Generator().manual_seed(42))
    ###

    train_indxs_query = torch.where(torch.isin(torch.tensor(train_ds_full._labels), torch.asarray(query_classes)))[0]
    N = 10
    starting_indices = np.arange(0, len(train_indxs_query), N)
    train_indxs_query = np.hstack([train_indxs_query[i:i+K] for i in starting_indices if i + K <= len(train_indxs_query)])
    ### C
    train_ds_query = torch.utils.data.Subset(train_ds_full, train_indxs_query)
    ###

    val_indxs_query = torch.where(torch.isin(torch.tensor(val_ds_full._labels), torch.asarray(query_classes)))[0]
    test_indxs_query = torch.where(torch.isin(torch.tensor(test_ds_full._labels), torch.asarray(query_classes)))[0]
    val_ds_subset_query = torch.utils.data.Subset(val_ds_full, val_indxs_query)
    test_ds_subset_query = torch.utils.data.Subset(test_ds_full, test_indxs_query)

    test_ds_query_full = ConcatDataset([val_ds_subset_query, test_ds_subset_query])
    ### D
    _, test_ds_query = torch.utils.data.random_split(test_ds_query_full, [0.7, 0.3], generator=torch.Generator().manual_seed(42))
    ###

    ### E
    test_all = ConcatDataset([test_ds_query, test_ds_support])


    A_train_dl = DataLoader(
        train_ds_support,
        batch_size = batch_size,
        shuffle=True,
        worker_init_fn=seed_worker,
        generator=g,
        drop_last=False,
        pin_memory=pin_memory,
        num_workers=num_workers
    )
    A_test_dl = DataLoader(
        test_ds_support,
        batch_size = batch_size,
        shuffle=True,
        worker_init_fn=seed_worker,
        generator=g,
        drop_last=False,
        pin_memory=pin_memory,
        num_workers=num_workers
    )

    B_train_dl = DataLoader(
        train_ds_query,
        batch_size = batch_size,
        shuffle=True,
        worker_init_fn=seed_worker,
        generator=g,
        drop_last=False,
        pin_memory=pin_memory,
        num_workers=num_workers
    )
    B_test_dl = DataLoader(
        test_ds_query,
        batch_size = batch_size,
        shuffle=True,
        worker_init_fn=seed_worker,
        generator=g,
        drop_last=False,
        pin_memory=pin_memory,
        num_workers=num_workers
    )
    test_all = DataLoader(
        test_all,
        batch_size = batch_size,
        shuffle=True,
        worker_init_fn=seed_worker,
        generator=g,
        drop_last=False,
        pin_memory=pin_memory,
        num_workers=num_workers
    )

    return A_train_dl, A_test_dl, B_train_dl, B_test_dl, test_all

## Split data

In [75]:
A_train_dl, A_test_dl, B_train_dl, B_test_dl, test_all = get_oxford_splits(
    batch_size=128,
    data_loader_seed=111,
    pin_memory=False,
    num_workers=1
)

## plot output

In [76]:
def make_dir(dir_name: str):
    """
    creates directory "dir_name" if it doesn't exists
    """
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)

def custom_plot_training_stats(
        acc_hist,
        loss_hist,
        phase_list,
        title: str,
        dir: str,
        name: str = 'acc_loss'):
    fig, (ax1, ax2) = plt.subplots(nrows = 1, ncols = 2, figsize=[14, 6], dpi=100)

    for phase in phase_list:
        lowest_loss_x = np.argmin(np.array(loss_hist[phase]))
        lowest_loss_y = loss_hist[phase][lowest_loss_x]

        ax1.annotate("{:.4f}".format(lowest_loss_y), [lowest_loss_x, lowest_loss_y])
        ax1.plot(loss_hist[phase], '-x', label=f'{phase} loss', markevery = [lowest_loss_x])

        ax1.set_xlabel(xlabel='epochs')
        ax1.set_ylabel(ylabel='loss')

        ax1.grid(color = 'green', linestyle = '--', linewidth = 0.5, alpha=0.75)
        ax1.legend()
        ax1.label_outer()

    # acc:
    for phase in phase_list:
        highest_acc_x = np.argmax(np.array(acc_hist[phase]))
        highest_acc_y = acc_hist[phase][highest_acc_x]

        ax2.annotate("{:.4f}".format(highest_acc_y), [highest_acc_x, highest_acc_y])
        ax2.plot(acc_hist[phase], '-x', label=f'{phase} acc', markevery = [highest_acc_x])

        ax2.set_xlabel(xlabel='epochs')
        ax2.set_ylabel(ylabel='acc')

        ax2.grid(color = 'green', linestyle = '--', linewidth = 0.5, alpha=0.75)
        ax2.legend()
        #ax2.label_outer()

    fig.suptitle(f'{title}')

    make_dir(dir)
    plt.savefig(f'{dir}/{name}.jpg')
    plt.clf()

def plot_conf(labels, preds, title, dir_, name):
    """
    labels: an [N, ] array containing true labels for N samples
    preds: an [N, ] array containing predications for N samples

    saves confusion matrix plot of the given prediction and true labels in 'dir_/name.jpg'
    """

    conf = confusion_matrix(labels, preds)

    plt.clf()
    cm = conf.astype('float') / conf.sum(axis=1)[:, np.newaxis]
    cmap = sns.light_palette("navy", as_cmap=True)
    plt.figure(figsize=(20, 20))
    sns.heatmap(cm, annot=False, cmap=cmap, fmt=".2f", cbar=False)
    plt.title(f'{title}')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    make_dir(dir_)
    plt.savefig(f'{dir_}/{name}')


## Data Analysis

In [77]:
len(A_train_dl.dataset), len(A_test_dl.dataset), len(B_train_dl.dataset), len(B_test_dl.dataset),

(4617, 1538, 100, 518)

In [78]:
len(A_train_dl), len(B_train_dl), len(B_train_dl), len(B_test_dl)

(37, 1, 1, 5)

In [79]:
A_train_dl.dataset[0][0].shape

torch.Size([3, 64, 64])

In [80]:
for data_indx, (input, target) in enumerate(A_train_dl.dataset):
    if data_indx < 3:
        print('data index: ', data_indx)
        print('input: ', input.shape, ', type of input: ', type(input))
        print('target: ', target, ', type of target: ', type(target))
    else:
        break

data index:  0
input:  torch.Size([3, 64, 64]) , type of input:  <class 'torch.Tensor'>
target:  46 , type of target:  <class 'int'>
data index:  1
input:  torch.Size([3, 64, 64]) , type of input:  <class 'torch.Tensor'>
target:  54 , type of target:  <class 'int'>
data index:  2
input:  torch.Size([3, 64, 64]) , type of input:  <class 'torch.Tensor'>
target:  17 , type of target:  <class 'int'>


## Making network

In [81]:
class CNN(nn.Module):
    def __init__(self, num_classes=80):
        super(CNN, self).__init__()

        self.flatten = nn.Flatten()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()

        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(64, 96, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(96)
        self.relu3 = nn.ReLU()

        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4 = nn.Conv2d(96, 128, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.relu4 = nn.ReLU()

        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.relu5 = nn.ReLU()

        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.flatten = nn.Flatten()

        self.fc = nn.Linear(256*4*4, num_classes)

    def forward(self, inputs, debug=False):
        # conv 1
        conv1 = self.conv1(inputs)
        # conv1 = self.conv1(inputs.unsqueeze(0))
        bn1 = self.bn1(conv1)
        relu1 = self.relu1(bn1)

        # conv 2
        conv2_1 = self.conv2(relu1)
        bn2_1 = self.bn2(conv2_1)
        relu2_1 = self.relu2(bn2_1)

        conv2_2 = self.conv2(relu2_1)
        bn2_2 = self.bn2(conv2_2)
        relu2_2 = self.relu2(bn2_2)

        conv2_3 = self.conv2(relu2_2)
        bn2_3 = self.bn2(conv2_3)
        relu2_3 = self.relu2(bn2_3)

        conv2_4 = self.conv2(relu2_3)
        bn2_4 = self.bn2(conv2_4)
        relu2_4 = self.relu2(bn2_4)

        # pool 1
        pool1 = self.pool1(relu2_4)

        # conv 3
        conv3_1 = self.conv3(pool1)
        bn3_1 = self.bn3(conv3_1)
        relu3_1 = self.relu3(bn3_1)

        conv3_2 = self.conv3_2(relu3_1)
        bn3_2 = self.bn3(conv3_2)
        relu3_2 = self.relu3(bn3_2)

        conv3_3 = self.conv3_2(relu3_2)
        bn3_3 = self.bn3(conv3_3)
        relu3_3 = self.relu3(bn3_3)

        conv3_4 = self.conv3_2(relu3_3)
        bn3_4 = self.bn3(conv3_4)
        relu3_4 = self.relu3(bn3_4)

        # pool 2
        pool2 = self.pool2(relu3_4)

        # conv 4
        conv4_1 = self.conv4(pool2)
        bn4_1 = self.bn4(conv4_1)
        relu4_1 = self.relu4(bn4_1)

        conv4_2 = self.conv4_2(relu4_1)
        bn4_2 = self.bn4(conv4_2)
        relu4_2 = self.relu4(bn4_2)

        conv4_3 = self.conv4_2(relu4_2)
        bn4_3 = self.bn4(conv4_3)
        relu4_3 = self.relu4(bn4_3)

        conv4_4 = self.conv4_2(relu4_3)
        bn4_4 = self.bn4(conv4_4)
        relu4_4 = self.relu4(bn4_4)

        # pool 3
        pool3 = self.pool3(relu4_4)

        # conv 5
        conv5_1 = self.conv5(pool3)
        bn5_1 = self.bn5(conv5_1)
        relu5_1 = self.relu5(bn5_1)

        conv5_2 = self.conv5_2(relu5_1)
        bn5_2 = self.bn5(conv5_2)
        relu5_2 = self.relu5(bn5_2)

        conv5_3 = self.conv5_2(relu5_2)
        bn5_3 = self.bn5(conv5_3)
        relu5_3 = self.relu5(bn5_3)

        conv5_4 = self.conv5_2(relu5_3)
        bn5_4 = self.bn5(conv5_4)
        relu5_4 = self.relu5(bn5_4)

        # pool 4
        pool4 = self.pool4(relu5_4)

        # fc
        flatten = self.flatten(pool4)
        fc = self.fc(flatten)

        return(fc)

In [82]:
# DataLoader.datasets

## Train model

In [83]:
def train_one_epoch(model: nn.Module, optim: torch.optim.Optimizer,
                    dataloader: DataLoader, loss_fn):

    num_samples = len(dataloader.dataset)
    num_batches = len(dataloader)
    running_corrects = 0
    running_loss = 0.0

    model.train()

    for batch_indx, (inputs, targets) in enumerate(dataloader): # Get a batch of Data
        inputs = inputs.to(device)
        targets = targets.to(device)

        outputs = model(inputs) # Forward Pass
        loss = loss_fn(outputs, targets) # Compute Loss

        loss.backward() # Compute Gradients
        optim.step() # Update parameters
        optim.zero_grad() # zero the parameter's gradients

        _, preds = torch.max(outputs, dim=1)
        running_corrects += torch.sum(preds == targets).cpu()
        running_loss += loss.item()

        print(f"-- >> preds:\n{preds}")
        print(f"-- >> targets:\n{targets}")

        # print(f"< TRAIN >--< {batch_indx} >-------------------------------")
        # print(f"out:\n{outputs}")
        # print(f"tar:\n{target_tensor}")
        # print(f"running_corrects:\n{running_corrects}")
        # print(f"running_loss:\n{running_loss}")

    epoch_acc = (running_corrects / num_samples) * 100
    print(running_corrects, num_samples)
    epoch_loss = (running_loss / num_batches)

    return epoch_acc, epoch_loss

## Test model

In [84]:
def test_model(model: nn.Module,
               dataloader: DataLoader, loss_fn):

    num_samples = len(dataloader.dataset)
    num_batches = len(dataloader)
    running_corrects = 0
    running_loss = 0.0

    # we call `model.eval()` to set dropout and batch normalization layers to evaluation mode before running inference.
    model.eval()

    with torch.no_grad():
        for batch_indx, (inputs, targets) in enumerate(dataloader): # Get a batch of Data
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs) # Forward Pass
            loss = loss_fn(outputs, targets) # Compute Loss

            _, preds = torch.max(outputs, 1) #
            running_corrects += torch.sum(preds == targets).cpu()
            running_loss += loss.item()

            # print(f"< TEST >--< {batch_indx} >-------------------------------")
            # print(f"out:\n{outputs}")
            # print(f"tar:\n{target_tensor}")
            # print(f"running_corrects:\n{running_corrects}")
            # print(f"running_loss:\n{running_loss}")

    test_acc = (running_corrects / num_samples) * 100
    test_loss = (running_loss / num_batches)

    return test_acc, test_loss

## Evaluate model

In [85]:
def evaluate():
    num_epochs = 150
    learning_rate = 0.005

    full_dataloaders = {
        'train': A_train_dl,
        'test': A_test_dl
    }

    model = CNN()
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    cross_entropy = nn.CrossEntropyLoss()

    acc_history = {'train': [], 'test': []}
    loss_history = {'train': [], 'test': []}

    for epoch in range(num_epochs):
        train_acc, train_loss = train_one_epoch(model=model, optim=optimizer, dataloader=full_dataloaders['train'], loss_fn=cross_entropy)
        test_acc, test_loss = test_model(model=model, dataloader=full_dataloaders['test'], loss_fn=cross_entropy)

        acc_history['train'].append(train_acc)
        acc_history['test'].append(test_acc)
        loss_history['train'].append(train_loss)
        loss_history['test'].append(test_loss)

        print(f"---------< epoch: {epoch} >---------")
        print(f"acc_history:\n{acc_history}")
        print(f"loss_history:\n{loss_history}")

    custom_plot_training_stats(acc_history, loss_history, ['train', 'test'], title='demp', dir='demo_plots')

    return (acc_history, loss_history)

In [86]:
acc_history, loss_history = evaluate()

-- >> preds:
tensor([16, 16, 16, 16, 16, 79, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
        16, 63, 16, 16, 16, 16, 16, 16,  1, 16, 16, 16, 16, 10,  0, 16, 16, 16,
        16, 59, 16, 16, 14, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 10, 16,
         8, 16,  8, 16, 49, 16, 16, 16, 16, 14, 59,  0, 16, 16, 16, 16, 16, 16,
        16,  8, 16, 59, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 10, 73,  1, 16,
        16, 16, 59, 16,  8, 16, 16, 16, 16, 16, 16, 16, 24, 16, 16, 16, 16, 49,
        16, 16, 16, 16, 16, 16, 16, 16, 16, 59, 16, 16, 16,  8, 16, 16, 16, 16,
        16, 16], device='cuda:0')
-- >> targets:
tensor([10, 11, 27, 45, 68, 13, 11, 41, 76, 37, 33, 45, 72, 77, 79, 51, 16, 64,
        51, 49, 15, 38, 38, 36, 76, 29, 57, 10, 39, 76, 76, 79, 47, 56, 27, 77,
        51, 75, 64, 50, 50, 50, 22, 67, 74, 50,  1, 39,  9, 73, 23, 77, 10, 76,
        42, 42, 52, 45, 42, 71, 52, 55, 79, 58, 28, 36, 18, 76, 30, 13, 42, 36,
        11, 16, 13, 14, 17, 62, 29, 67, 59, 40,  6, 72, 36

<Figure size 1400x600 with 0 Axes>

In [87]:
loss_history

{'train': [8.576656676627493, 5.664602125013197],
 'test': [4.316541891831618, 4.556376567253699]}

In [88]:
acc_history

{'train': [tensor(3.5088), tensor(4.2019)],
 'test': [tensor(3.9662), tensor(4.3563)]}