In [None]:
# imports
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pickle
import torch
import torchvision
import torchmetrics
import time

In [None]:
# Utils

# Create a single directory if nothing exists at that path yet
def create_dir(addr):
    """Create directory ``addr`` unless the path already exists.

    Only the last path component is created (``os.mkdir``), so the parent
    directory must already exist. If anything (directory or file) is already at
    ``addr`` the call is a no-op.
    """
    if os.path.exists(addr):
        return
    os.mkdir(addr)

# Delete everything inside a folder (the folder itself is kept)
def remove_folder_contents(folder):
    """Recursively delete every entry inside ``folder``, keeping ``folder`` itself.

    Symbolic links are removed as links and never followed, so a link that
    points at a directory elsewhere cannot cause files outside ``folder`` to be
    deleted. Errors on individual entries are printed and skipped (best effort).

    Args:
        folder: directory whose contents should be removed.
    """
    for the_file in os.listdir(folder):
        file_path = os.path.join(folder, the_file)
        try:
            # Test for links first: isfile/isdir follow symlinks, so a link to a
            # directory would otherwise be recursed into (emptying its target),
            # and a dangling link would match neither branch and be left behind.
            if os.path.islink(file_path) or os.path.isfile(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                remove_folder_contents(file_path)
                os.rmdir(file_path)
        except Exception as e:
            print(e)

# Addresses

# Raw data (dataset inputs; these folders are read, never created)
raw_address = "../input/indian-birds/Birds_25"
raw_train = "../input/indian-birds/Birds_25/train"
raw_test = "../input/indian-birds/Birds_25/test"
raw_val = "../input/indian-birds/Birds_25/val"

# Temp Address
temp_address = "temp"

# Results
result_address = "results/"
base_model_result = "results/base_model"
bn_model_result = "results/bn_model"
in_model_result = "results/in_model"
bin_model_result = "results/bin_model"
ln_model_result = "results/ln_model"
gn_model_result = "results/gn_model"
nn_model_result = "results/nn_model"
bn_model_bsize8_result = "results/bn_model_bsize8"
gn_model_bsize8_result = "results/gn_model_bsize8"
bn_model_bsize128_result = "results/bn_model_bsize128"
gn_model_bsize128_result = "results/gn_model_bsize128"
overall_result_address = "results/overall"
bn_vs_base_result = "results/bn_base"
bn_vs_gn_bsize_result = "results/bn_gn_bsize"

# Init Structure
# The raw-data folders are inputs and must already exist. Creating them here
# would only yield empty folders, and with them an empty dataset later on.
for data_dir in (raw_address, raw_train, raw_test, raw_val):
    if not os.path.isdir(data_dir):
        raise FileNotFoundError(f"Dataset directory not found: {data_dir}")

create_dir(temp_address)
create_dir(result_address)
create_dir(overall_result_address)
create_dir(bn_vs_base_result)
create_dir(bn_vs_gn_bsize_result)
# remove_folder_contents(result_address)

# Constants
random_seed = 68
n = 2            # number of ResBlocks per Resnet stage
r = 25           # number of classes: size of the classifier output and of the metrics
batch_size = 32  # batch size of the default data loaders

# Setting Random Seed (torch CPU/CUDA and NumPy)
torch.manual_seed(random_seed)
np.random.seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)

In [None]:
# cuda: use the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Working with {device}")

In [None]:
# Data Loader

# Finding Category: one class per sub-directory of the training split. Sorting
# fixes the class -> label-index mapping. Non-directory entries (e.g. stray
# files) are skipped because DataSet lists each category entry as a folder.
category = sorted(
    entry for entry in os.listdir(raw_train)
    if os.path.isdir(os.path.join(raw_train, entry))
)
# The classifier head and the metrics are sized with `r`; the labels DataSet
# produces (indices into `category`) must fit that size.
if len(category) != r:
    raise ValueError(f"Expected {r} class folders in {raw_train}, found {len(category)}")

# Transformation for preprocessing
# NOTE(review): no Resize/crop is applied. The default DataLoader collate needs
# every image in a batch to have the same size; confirm the dataset is uniform.

# Random flips + autocontrast; only used by a DataSet built with augment=True
transform_augment = torchvision.transforms.Compose([
                    torchvision.transforms.RandomHorizontalFlip(),
                    torchvision.transforms.RandomVerticalFlip(),
                    torchvision.transforms.RandomAutocontrast(),
                    torchvision.transforms.ToTensor()
                    ])

# Plain PIL -> tensor conversion
transform_normal = torchvision.transforms.Compose([
                   torchvision.transforms.ToTensor()
                   ])

class DataSet(torch.utils.data.Dataset):
    """Image-folder dataset laid out as ``address/<class name>/<image file>``.

    Class names come from ``categories`` (default: the module-level
    ``category`` list). The label of an image is the index of its class in that
    list. File names are sorted within each class, so sample order is
    deterministic: all images of class 0, then class 1, and so on.

    Args:
        address: root directory of one split (train / val / test).
        augment: use ``transform_augment`` if True, else ``transform_normal``.
        categories: optional explicit list of class folder names.

    Items are ``(image tensor cast to torch.float, int label)``.
    """
    def __init__(self, address, augment = True, categories = None):
        self.address = address
        self.categories = list(category if categories is None else categories)
        # Sorted file names per class (kept as an attribute for compatibility)
        self.cat_list = [sorted(os.listdir(os.path.join(self.address, cat))) for cat in self.categories]
        # Flat (path, label) table built once, so __getitem__ is a direct lookup
        # rather than a scan over the classes.
        self.samples = [
            (os.path.join(self.address, cat, file_name), label)
            for label, (cat, files) in enumerate(zip(self.categories, self.cat_list))
            for file_name in files
        ]
        self.transform = transform_augment if augment else transform_normal

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Python-style negative indices; anything else out of range is an IndexError
        if idx < 0:
            idx += len(self.samples)
        if not 0 <= idx < len(self.samples):
            raise IndexError(f"index out of range for dataset of size {len(self.samples)}")
        img_path, label = self.samples[idx]
        img = torchvision.datasets.folder.default_loader(img_path)
        x = self.transform(img).to(torch.float)
        return x, label

# Datasets and Data Loaders for the three splits. Augmentation is disabled
# everywhere; only the training loader shuffles.
loader_options = dict(batch_size=batch_size, num_workers = 4)

# Training
dataset_train = DataSet(raw_train, augment=False)
data_loader_train = torch.utils.data.DataLoader(dataset_train, shuffle=True, **loader_options)

# Validation
dataset_val = DataSet(raw_val, augment=False)
data_loader_val = torch.utils.data.DataLoader(dataset_val, shuffle=False, **loader_options)

# Test
dataset_test = DataSet(raw_test, augment=False)
data_loader_test = torch.utils.data.DataLoader(dataset_test, shuffle=False, **loader_options)

In [None]:
# Model

class ResBlock(torch.nn.Module):
    """Two-convolution residual block.

    Main path: conv -> norm -> ReLU -> conv -> norm. A shortcut is added before
    the final ReLU. When the channel count or the stride changes, the shortcut
    is a strided convolution (``conv_project``) so both paths have equal shape.

    Args:
        in_channel: channels of the input.
        out_channel: channels of the output.
        norm_func: factory called as ``norm_func(out_channel, dtype=torch.float)``
            that returns the normalization module (e.g. ``torch.nn.BatchNorm2d``).
        kernel_size: convolution kernel size. Padding is ``kernel_size // 2``, so
            odd sizes keep both paths spatially aligned.
        stride: stride of the first convolution and of the shortcut.
    """
    def __init__(self, in_channel, out_channel, norm_func, kernel_size=3, stride=1):
        super(ResBlock, self).__init__()
        # Padding follows the kernel size (equals 1 for the default 3x3), so a
        # different kernel_size does not break the residual addition.
        padding = kernel_size // 2

        self.conv1 = torch.nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding=padding, dtype=torch.float)
        self.norm1 = norm_func(out_channel, dtype=torch.float)
        self.activation1 = torch.nn.ReLU()

        self.conv2 = torch.nn.Conv2d(out_channel, out_channel, kernel_size, stride=1, padding=padding, dtype=torch.float)
        self.norm2 = norm_func(out_channel, dtype=torch.float)
        self.activation2 = torch.nn.ReLU()

        # A projection is needed whenever the shortcut's shape would differ
        self.project = (in_channel != out_channel) or (stride != 1)
        if self.project:
            self.conv_project = torch.nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding=padding, dtype=torch.float)

    def forward(self, x):
        res = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.activation1(x)

        x = self.conv2(x)
        x = self.norm2(x)
        # Out-of-place add: safe whatever norm_func returns (it may hand back a
        # tensor that autograd still needs).
        x = x + (self.conv_project(res) if self.project else res)
        x = self.activation2(x)

        return x

class Resnet(torch.nn.Module):
    """Small three-stage residual classifier (16 -> 32 -> 64 channels).

    Args:
        n: ResBlocks per stage. Stage 1 has exactly ``n`` blocks at 16 channels.
            Stages 2 and 3 start with one stride-2 block that widens the
            channels, followed by ``n - 1`` more blocks.
        r: number of output logits (size of the final Linear layer).
        norm_func: factory called as ``norm_func(channels, dtype=torch.float)``
            for every normalization layer.
    """
    def __init__(self, n, r, norm_func = torch.nn.BatchNorm2d):
        super(Resnet, self).__init__()

        self.norm = norm_func

        # Stem: 3x3 convolution from RGB to 16 channels, then norm + ReLU
        self.input_layer = torch.nn.Sequential(
            torch.nn.Conv2d(3, 16, 3, 1, padding=1, dtype=torch.float),
            self.norm(16, dtype=torch.float),
            torch.nn.ReLU(),
        )

        # Stage 1: n shape-preserving blocks
        self.hidden_layer1 = torch.nn.Sequential(*[ResBlock(16, 16, self.norm) for _ in range(n)])
        # Stages 2 and 3: down-sampling block followed by n - 1 blocks
        self.hidden_layer2 = self._stage(16, 32, n)
        self.hidden_layer3 = self._stage(32, 64, n)

        # Global average pooling to one 64-d vector per sample
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.flatten = torch.nn.Flatten()

        # Classifier head
        self.output_layer = torch.nn.Linear(64, r, dtype=torch.float)

    def _stage(self, in_channel, out_channel, n):
        """One stage: a stride-2 ResBlock followed by ``n - 1`` ResBlocks."""
        first = ResBlock(in_channel, out_channel, self.norm, stride = 2)
        rest = [ResBlock(out_channel, out_channel, self.norm) for _ in range(n - 1)]
        return torch.nn.Sequential(first, *rest)

    def forward(self, x):
        pipeline = (
            self.input_layer,
            self.hidden_layer1,
            self.hidden_layer2,
            self.hidden_layer3,
            self.pool,
            self.flatten,
            self.output_layer,
        )
        for layer in pipeline:
            x = layer(x)
        return x


In [None]:
# Training Model

def train(model, data_loader, save_addr, num_epoch = 50, learning_rate = 1e-3, overwrite = False):
    """Train ``model`` with Adam + cross-entropy and checkpoint every epoch.

    Each epoch writes three files with zero-padded names (``00.pt``, ``01.pt``, ...):
      * ``save_addr/model``: the model state dict,
      * ``save_addr/optim``: the optimizer state dict,
      * ``save_addr/loss``: a tensor ``[mean batch loss, accuracy, f1_micro, f1_macro]``.

    Unless ``overwrite`` is True, an epoch whose model and metric files both
    exist is loaded instead of retrained, so an interrupted run resumes where it
    stopped. The optimizer state is restored too when its file exists.

    Args:
        model: module to train; assumed to already live on ``device``.
        data_loader: iterable of ``(images, labels)`` batches.
        save_addr: directory for checkpoints and metrics (created if missing).
        num_epoch: number of epochs.
        learning_rate: Adam learning rate.
        overwrite: retrain epochs even if their checkpoints exist.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    # Creating save folders
    create_dir(save_addr)
    model_addr = os.path.join(save_addr, 'model')
    create_dir(model_addr)
    loss_addr = os.path.join(save_addr, 'loss')
    create_dir(loss_addr)
    optim_addr = os.path.join(save_addr, 'optim')
    create_dir(optim_addr)

    # Setting model to train mode
    model.train()

    # Parameters for training
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
    start_time = time.time()

    # Metrics
    metric_f1_micro = torchmetrics.classification.MulticlassF1Score(num_classes = r, average = 'micro').to(device)
    metric_f1_macro = torchmetrics.classification.MulticlassF1Score(num_classes = r, average = 'macro').to(device)
    metric_accuracy = torchmetrics.classification.Accuracy(task = 'multiclass', num_classes = r).to(device)

    for epoch in range(num_epoch):
        # Zero-padded file names. validate/plot/report rely on sorted() order,
        # which is chronological only for fewer than 100 epochs.
        file_name = f'{epoch:02d}.pt'
        epoch_addr = os.path.join(model_addr, file_name)
        epoch_loss_addr = os.path.join(loss_addr, file_name)
        epoch_optim_addr = os.path.join(optim_addr, file_name)

        # Loading previous epoch. The metrics file is written last, so its
        # presence marks a completed epoch.
        if not overwrite and os.path.exists(epoch_addr) and os.path.exists(epoch_loss_addr):
            # map_location: a checkpoint written on GPU can be resumed on CPU and vice versa
            model.load_state_dict(torch.load(epoch_addr, map_location=device))
            # Restore Adam's moment estimates so resuming continues the same
            # optimisation instead of restarting it. Older runs have no such file.
            if os.path.exists(epoch_optim_addr):
                optimizer.load_state_dict(torch.load(epoch_optim_addr, map_location=device))
            model.train()
            loss_arr = torch.load(epoch_loss_addr)
            epoch_loss = loss_arr[0].item()
            accuracy = loss_arr[1].item()
            f1_micro = loss_arr[2].item()
            f1_macro = loss_arr[3].item()

            print(f"Epoch: {epoch} Loaded\t\tLoss: {epoch_loss}\tAccuracy: {accuracy}\tf1_micro: {f1_micro}\tf1_macro: {f1_macro}")
            continue

        # Training next epoch
        batch_ct = 0
        epoch_loss = 0.0
        labels, label_preds = [], []
        for x, y in data_loader:
            # To Device
            x = x.to(device)
            y = y.to(device)

            # Predictions and loss
            y_pred = model(x)
            loss = loss_fn(y_pred, y)

            # Back Propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Bookkeeping. Per-batch integer labels/predictions are collected in
            # lists and concatenated once (no repeated torch.cat, no float promotion).
            epoch_loss += loss.item()
            batch_ct += 1
            labels.append(y)
            label_preds.append(torch.argmax(y_pred.detach(), dim=1))

        if batch_ct == 0:
            raise ValueError("data_loader yielded no batches")

        # Computing Metrics over the whole epoch
        label_arr = torch.cat(labels)
        label_pred_arr = torch.cat(label_preds)
        f1_micro = metric_f1_micro(label_pred_arr, label_arr).item()
        f1_macro = metric_f1_macro(label_pred_arr, label_arr).item()
        accuracy = metric_accuracy(label_pred_arr, label_arr).item()
        mean_loss = epoch_loss / batch_ct
        loss_arr = torch.tensor([mean_loss, accuracy, f1_micro, f1_macro])

        # Saving checkpoint after each epoch; the metrics file goes last (see above)
        torch.save(model.state_dict(), epoch_addr)
        torch.save(optimizer.state_dict(), epoch_optim_addr)
        torch.save(loss_arr, epoch_loss_addr)

        print(f"Epoch: {epoch}\tLoss: {mean_loss}\tAccuracy: {accuracy}\tf1_micro: {f1_micro}\tf1_macro: {f1_macro}\tTime: {time.time() - start_time}")


In [None]:
# Validate Model

def validate(model, data_loader, save_addr, load_addr, overwrite = False):
    """Evaluate every checkpoint written by ``train`` on ``data_loader``.

    Checkpoints are read from ``load_addr/model`` in sorted file-name order. For
    each one, a tensor ``[mean batch loss, accuracy, f1_micro, f1_macro]`` is
    saved under the same file name in ``save_addr``. Existing metric files are
    reused (and only printed) unless ``overwrite`` is True. Every checkpoint that
    is evaluated is loaded into ``model`` and ``model`` is put in eval mode;
    neither change is undone afterwards.

    Args:
        model: module to evaluate; assumed to already live on ``device``.
        data_loader: iterable of ``(images, labels)`` batches.
        save_addr: directory for the per-checkpoint metrics (created if missing).
        load_addr: the ``save_addr`` that was given to ``train``.
        overwrite: recompute metrics even if they already exist.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    # Creating save folder
    create_dir(save_addr)
    model_addr_load = os.path.join(load_addr, 'model')

    # Parameters for evaluation
    loss_fn = torch.nn.CrossEntropyLoss()
    start_time = time.time()

    # Metrics
    metric_f1_micro = torchmetrics.classification.MulticlassF1Score(num_classes = r, average = 'micro').to(device)
    metric_f1_macro = torchmetrics.classification.MulticlassF1Score(num_classes = r, average = 'macro').to(device)
    metric_accuracy = torchmetrics.classification.Accuracy(task = 'multiclass', num_classes = r).to(device)

    for epoch, param_addr in enumerate(sorted(os.listdir(model_addr_load))):
        # Save Address
        epoch_loss_addr = os.path.join(save_addr, param_addr)

        # Checking if already present
        if not overwrite and os.path.exists(epoch_loss_addr):
            loss_arr = torch.load(epoch_loss_addr)
            epoch_loss = loss_arr[0].item()
            accuracy = loss_arr[1].item()
            f1_micro = loss_arr[2].item()
            f1_macro = loss_arr[3].item()

            print(f"Epoch: {epoch} Loaded\t\tLoss: {epoch_loss}\tAccuracy: {accuracy}\tf1_micro: {f1_micro}\tf1_macro: {f1_macro}")
            continue

        # Loading Model (map_location: GPU checkpoints also load on CPU)
        model.load_state_dict(torch.load(os.path.join(model_addr_load, param_addr), map_location=device))

        # Evaluate Model and freezing it
        model.eval()
        batch_ct = 0
        epoch_loss = 0.0
        labels, label_preds = [], []
        with torch.no_grad():
            for x, y in data_loader:
                # To Device
                x = x.to(device)
                y = y.to(device)

                # Predictions and loss
                y_pred = model(x)
                epoch_loss += loss_fn(y_pred, y).item()
                batch_ct += 1

                # Collect integer labels/predictions; concatenated once below
                labels.append(y)
                label_preds.append(torch.argmax(y_pred, dim=1))

        if batch_ct == 0:
            raise ValueError("data_loader yielded no batches")

        # Computing Metrics (loss is the mean of per-batch mean losses)
        label_arr = torch.cat(labels)
        label_pred_arr = torch.cat(label_preds)
        f1_micro = metric_f1_micro(label_pred_arr, label_arr).item()
        f1_macro = metric_f1_macro(label_pred_arr, label_arr).item()
        accuracy = metric_accuracy(label_pred_arr, label_arr).item()
        mean_loss = epoch_loss / batch_ct
        loss_arr = torch.tensor([mean_loss, accuracy, f1_micro, f1_macro])

        # Saving metrics for this checkpoint
        torch.save(loss_arr, epoch_loss_addr)

        # Log
        print(f"Epoch: {epoch}\tLoss: {mean_loss}\tAccuracy: {accuracy}\tf1_micro: {f1_micro}\tf1_macro: {f1_macro}\tTime: {time.time() - start_time}")


In [None]:
# Plot
def save(arr_x, arr_train, arr_val, address, title, y_label, y_lim = None):
    """Save a train-vs-validation line plot to ``address``.

    Args:
        arr_x: x values (epochs).
        arr_train: training curve.
        arr_val: validation curve.
        address: output path passed to ``savefig``.
        title: figure title.
        y_label: y-axis label.
        y_lim: optional ``(low, high)`` y-axis limits; ignored when falsy.
    """
    fig, ax = plt.subplots()
    for series, series_label in ((arr_train, 'Train'), (arr_val, 'Validation')):
        ax.plot(arr_x, series, label = series_label)
    ax.set_xlabel("num_epochs")
    ax.set_ylabel(y_label)
    ax.set_title(title)
    if y_lim:
        low, high = y_lim[0], y_lim[1]
        ax.set_ylim(low, high)
    ax.legend()
    plt.savefig(address)

def plot(train_address, val_address, save_address, name):
    """Save train/validation curves for one run and select the best epoch.

    Every file in ``train_address`` (sorted by name) is paired with the file of
    the same name in ``val_address``. Each holds one epoch's metric vector as
    written by ``train``/``validate``: [loss, accuracy, f1_micro, f1_macro].
    Four figures ('loss', 'acc', 'micro', 'macro') are written into
    ``save_address``; accuracy is plotted in percent.

    Returns:
        (index of the epoch with the lowest validation loss,
         train metrics stacked per epoch, validation metrics stacked per epoch)
    """
    epoch_files = sorted(os.listdir(train_address))
    train_metrics = np.stack([np.array(torch.load(os.path.join(train_address, f))) for f in epoch_files])
    val_metrics = np.stack([np.array(torch.load(os.path.join(val_address, f))) for f in epoch_files])

    epochs = np.arange(1, len(train_metrics) + 1)
    # (metric column, scale, output file, title suffix, y label)
    curves = (
        (0, 1, 'loss', "Cross Entropy Loss", "Loss"),
        (1, 100, 'acc', "Accuracy", "Accuracy"),
        (2, 1, 'micro', "F1 Micro Score", "f1_micro"),
        (3, 1, 'macro', "F1 Macro Score", "f1_macro"),
    )
    for col, scale, file_name, title_suffix, y_label in curves:
        save(epochs,
             scale * train_metrics[:, col],
             scale * val_metrics[:, col],
             os.path.join(save_address, file_name),
             f"{name}: {title_suffix}",
             y_label)

    best_epoch = np.argmin(val_metrics[:, 0])
    return best_epoch, train_metrics, val_metrics

# Report
def report(address, ind):
    """Write the metrics of one epoch to ``address/result.txt``.

    ``ind`` indexes the sorted file listing of each of the ``loss`` (training),
    ``val`` and ``test`` sub-folders of ``address``. Each file holds
    [loss, accuracy, f1_micro, f1_macro].

    Returns:
        The three loaded metric arrays, in training / validation / testing order.
    """
    # (sub-folder, section header written to the report)
    sections = (
        ('loss', "Training:\n"),
        ('val', "\nValidation:\n"),
        ('test', "\nTesting:\n"),
    )

    metric_arr = []
    for folder_name, _ in sections:
        folder_address = os.path.join(address, folder_name)
        chosen = sorted(os.listdir(folder_address))[ind]
        metric_arr.append(np.array(torch.load(os.path.join(folder_address, chosen))))

    lines = []
    for (_, header), values in zip(sections, metric_arr):
        lines.append(header)
        lines.append(f"\tLoss:\t\t{values[0]}\n")
        lines.append(f"\tAccuracy:\t{values[1]}\n")
        lines.append(f"\tF1_Micro:\t{values[2]}\n")
        lines.append(f"\tF1_Macro:\t{values[3]}\n")

    with open(os.path.join(address, 'result.txt'), 'w') as file:
        file.writelines(lines)

    return metric_arr


In [None]:
# Model Function

def model_func(model, address, name, train_loader=None, val_loader=None, test_loader=None):
    """Run the full pipeline for one model configuration.

    Trains ``model`` (checkpoints under ``address``), evaluates every checkpoint
    on the validation and test loaders, saves the train/validation plots, and
    writes ``result.txt`` for the epoch with the lowest validation loss.

    Args:
        model: the network, already on ``device``.
        address: result directory for this configuration.
        name: human-readable name used in plot titles.
        train_loader, val_loader, test_loader: optional data loaders. When
            omitted, the module-level ``data_loader_train`` / ``data_loader_val``
            / ``data_loader_test`` current at call time are used.

    Returns:
        (train metrics per epoch, validation metrics per epoch,
         [train, val, test] metrics of the best epoch)
    """
    # Explicit loaders avoid depending on whichever globals were bound last
    train_loader = data_loader_train if train_loader is None else train_loader
    val_loader = data_loader_val if val_loader is None else val_loader
    test_loader = data_loader_test if test_loader is None else test_loader

    train(model, train_loader, address)
    validate(model, val_loader, os.path.join(address, 'val'), address)
    validate(model, test_loader, os.path.join(address, 'test'), address)
    best_model_ind, train_metric, val_metric = plot(os.path.join(address, 'loss'), os.path.join(address, 'val'), address, name)
    best_model_metric = report(address, best_model_ind)
    return train_metric, val_metric, best_model_metric

# Base Model

# Reference run: the same Resnet with PyTorch's built-in BatchNorm2d (the default)
resnet_base_model = Resnet(n, r, norm_func=torch.nn.BatchNorm2d).to(device)
(base_model_metric_train,
 base_model_metric_val,
 base_model_metric_best) = model_func(model=resnet_base_model, address=base_model_result, name='Base Model')

In [None]:
# Batch Normalization

class BN_norm(torch.nn.Module):
    """Batch normalization over (N, H, W) per channel, implemented from scratch.

    In training mode the input is normalized with the current batch's
    per-channel mean and biased variance, and the running estimates are updated
    as ``running = momentum * running + (1 - momentum) * batch``. With
    ``momentum=0.9`` this matches PyTorch's ``momentum=0.1`` convention. In eval
    mode the running estimates are used. A learnable per-channel scale (``std``)
    and shift (``mean``) are applied afterwards.

    Args:
        channel_num: number of channels C of the (N, C, H, W) input.
        dtype: dtype of the parameters and buffers.
        momentum: weight kept on the old running statistics.
        epsilon: added to the variance for numerical stability.
    """
    def __init__(self, channel_num, dtype = torch.float, momentum = 0.9, epsilon = 1e-8):
        super(BN_norm, self).__init__()

        # Class Constants
        self.dtype = dtype
        self.momentum = momentum
        self.epsilon = epsilon

        # Model Parameters: `std` is the scale (gamma), `mean` the shift (beta).
        # The names are kept so existing checkpoints still load.
        self.mean = torch.nn.Parameter(torch.zeros(channel_num, dtype=dtype))
        self.std = torch.nn.Parameter(torch.ones(channel_num, dtype=dtype))

        # Running Parameters (shaped to broadcast against (N, C, H, W))
        self.register_buffer('running_mean', torch.zeros((1, channel_num, 1, 1), dtype=dtype))
        self.register_buffer('running_var', torch.ones((1, channel_num, 1, 1), dtype=dtype))

    def forward(self, x):
        if self.training:
            # Calculating Batch Stats
            batch_mean = torch.mean(x, (0, 2, 3), keepdim=True)
            batch_var = torch.var(x, (0, 2, 3), unbiased=False, keepdim=True)

            # Normalizing Input
            x = (x - batch_mean)/torch.sqrt(batch_var + self.epsilon)

            # Updating running statistics in place and outside autograd. Assigning
            # the raw expressions would give the buffers a grad_fn, and each step's
            # update would keep a reference to the previous step's graph nodes.
            with torch.no_grad():
                self.running_mean.mul_(self.momentum).add_(batch_mean, alpha=1 - self.momentum)
                self.running_var.mul_(self.momentum).add_(batch_var, alpha=1 - self.momentum)

        else:
            # Normalizing Input with the running estimates
            x = (x - self.running_mean)/torch.sqrt(self.running_var + self.epsilon)

        # Scale and shift
        x = x*self.std.view(1, -1, 1, 1) + self.mean.view(1, -1, 1, 1)

        return x
    
# Resnet with the from-scratch batch normalization
resnet_bn_model = Resnet(n, r, norm_func=BN_norm).to(device)
(bn_model_metric_train,
 bn_model_metric_val,
 bn_model_metric_best) = model_func(model=resnet_bn_model, address=bn_model_result, name='BN Model')

In [None]:
# Instance Normalization

class IN_norm(torch.nn.Module):
    """Instance normalization with no learnable parameters.

    Every (sample, channel) plane of an (N, C, H, W) input is shifted to zero
    mean and scaled to unit (biased) variance over its H x W positions. Training
    and evaluation behave the same.

    Args:
        channel_num: accepted for interface compatibility with the other norm
            layers; unused.
        dtype: stored as ``self.dtype``; not otherwise used.
        epsilon: added to the variance for numerical stability.
    """
    def __init__(self, channel_num, dtype = torch.float,  epsilon = 1e-8):
        super(IN_norm, self).__init__()
        self.epsilon = epsilon
        self.dtype = dtype

    def forward(self, x):
        spatial_dims = (2, 3)
        plane_mean = x.mean(spatial_dims, keepdim=True)
        plane_var = x.var(spatial_dims, unbiased=False, keepdim=True)
        return (x - plane_mean) / (plane_var + self.epsilon).sqrt()
    
# Resnet with instance normalization
resnet_in_model = Resnet(n, r, norm_func=IN_norm).to(device)
(in_model_metric_train,
 in_model_metric_val,
 in_model_metric_best) = model_func(model=resnet_in_model, address=in_model_result, name='IN Model')

In [None]:
# Batch Instance Normalization

class BIN_norm(torch.nn.Module):
    """Batch-Instance normalization: a learnable per-channel blend of BN and IN.

    ``out = (rho * BN(x) + (1 - rho) * IN(x)) * std + mean``. ``rho`` is clamped
    to [0, 1] at the start of every forward pass. The BN branch uses batch
    statistics in training mode and the running estimates in eval mode.

    Args:
        channel_num: number of channels C of the (N, C, H, W) input.
        dtype: dtype of the parameters and buffers.
        momentum: weight kept on the old running statistics.
        epsilon: added to the variances for numerical stability.
    """
    def __init__(self, channel_num, dtype = torch.float, momentum = 0.9, epsilon = 1e-8):
        super(BIN_norm, self).__init__()

        # Class Constants
        self.dtype = dtype
        self.momentum = momentum
        self.epsilon = epsilon

        # Model Parameters: shift (`mean`), scale (`std`), BN/IN mixing weight (`rho`)
        self.mean = torch.nn.Parameter(torch.zeros((1, channel_num, 1, 1), dtype=dtype))
        self.std = torch.nn.Parameter(torch.ones((1, channel_num, 1, 1), dtype=dtype))
        self.rho = torch.nn.Parameter(torch.rand((1, channel_num, 1, 1), dtype=dtype))

        # Running Parameters
        self.register_buffer('running_mean', torch.zeros((1, channel_num, 1, 1), dtype=dtype))
        self.register_buffer('running_var', torch.ones((1, channel_num, 1, 1), dtype=dtype))

    def forward(self, x):
        # Clipping rho in place. Re-wrapping it in a new torch.nn.Parameter would
        # replace the object the optimizer holds, so rho would never be updated.
        with torch.no_grad():
            self.rho.clamp_(0, 1)

        # Batch Normalization
        if self.training:
            # Calculating Batch Stats
            batch_mean = torch.mean(x, (0, 2, 3), keepdim=True)
            batch_var = torch.var(x, (0, 2, 3), unbiased=False, keepdim=True)

            # Normalizing Input
            x_bn = (x - batch_mean)/torch.sqrt(batch_var + self.epsilon)

            # Updating running statistics in place and outside autograd, so the
            # buffers never pick up a grad_fn linking successive steps' graphs.
            with torch.no_grad():
                self.running_mean.mul_(self.momentum).add_(batch_mean, alpha=1 - self.momentum)
                self.running_var.mul_(self.momentum).add_(batch_var, alpha=1 - self.momentum)

        else:
            # Normalizing Input with the running estimates
            x_bn = (x - self.running_mean)/torch.sqrt(self.running_var + self.epsilon)

        # Instance Normalization: per-sample, per-channel statistics over H x W
        inst_mean = torch.mean(x, (2, 3), keepdim=True)
        inst_var = torch.var(x, (2, 3), unbiased=False, keepdim=True)
        x_in = (x - inst_mean)/torch.sqrt(inst_var + self.epsilon)

        # Combining BN and IN
        x = self.rho*x_bn + (1-self.rho)*x_in

        # Scale and shift
        x = x*self.std + self.mean

        return x
    
# Resnet with batch-instance normalization
resnet_bin_model = Resnet(n, r, norm_func=BIN_norm).to(device)
(bin_model_metric_train,
 bin_model_metric_val,
 bin_model_metric_best) = model_func(model=resnet_bin_model, address=bin_model_result, name='BIN Model')
    

In [None]:
# Layer Normalization

class LN_norm(torch.nn.Module):
    """Layer normalization over (C, H, W) of each sample.

    Each sample is normalized with its own mean and biased variance over all
    channels and positions. A learnable per-channel scale (``std``) and shift
    (``mean``) are then applied. Training and evaluation behave the same.

    Args:
        channel_num: number of channels C (size of the affine parameters).
        dtype: dtype of the parameters.
        epsilon: added to the variance for numerical stability.
    """
    def __init__(self, channel_num, dtype = torch.float, epsilon = 1e-8):
        super(LN_norm, self).__init__()

        self.dtype = dtype
        self.epsilon = epsilon

        affine_shape = (1, channel_num, 1, 1)
        self.mean = torch.nn.Parameter(torch.zeros(affine_shape, dtype=self.dtype))
        self.std = torch.nn.Parameter(torch.ones(affine_shape, dtype=self.dtype))

    def forward(self, x):
        reduce_dims = (1, 2, 3)
        sample_mean = x.mean(reduce_dims, keepdim=True)
        sample_var = x.var(reduce_dims, unbiased=False, keepdim=True)
        normed = (x - sample_mean) / (sample_var + self.epsilon).sqrt()
        return normed * self.std + self.mean


# Resnet with layer normalization
resnet_ln_model = Resnet(n, r, norm_func=LN_norm).to(device)
(ln_model_metric_train,
 ln_model_metric_val,
 ln_model_metric_best) = model_func(model=resnet_ln_model, address=ln_model_result, name='LN Model')

In [None]:
# Group Normalization

class GN_norm(torch.nn.Module):
    """Group normalization implemented from scratch.

    The C channels are split into ``group_num`` contiguous groups. Each
    (sample, group) is normalized with its own mean and biased variance over the
    group's channels and all H x W positions. A learnable per-channel scale
    (``std``) and shift (``mean``) are then applied.

    Args:
        channel_num: number of channels C; must be divisible by ``group_num``.
        group_num: number of channel groups (positive).
        dtype: dtype of the parameters.
        epsilon: added to the variance for numerical stability.

    Raises:
        ValueError: if ``group_num`` is not positive or does not divide ``channel_num``.
    """
    def __init__(self, channel_num, group_num = 8, dtype = torch.float, epsilon = 1e-8):
        super(GN_norm, self).__init__()

        # Channels are reshaped into equal groups in forward(). Reject shapes that
        # cannot be split here, rather than failing later inside reshape.
        if group_num <= 0 or channel_num % group_num != 0:
            raise ValueError(f"channel_num ({channel_num}) must be divisible by group_num ({group_num})")

        # Class Constants
        self.dtype = dtype
        self.epsilon = epsilon
        self.g = group_num

        # Model Parameters
        self.mean = torch.nn.Parameter(torch.zeros((1, channel_num, 1, 1), dtype=self.dtype))
        self.std = torch.nn.Parameter(torch.ones((1, channel_num, 1, 1), dtype=self.dtype))

    def forward(self, x):
        # Reshaping to (N, groups, channels per group, H, W)
        batch, channels, height, width = x.shape
        x = torch.reshape(x, (batch, self.g, channels // self.g, height, width))

        # Calculating per-group Stats
        grp_mean = torch.mean(x, (2, 3, 4), keepdim=True)
        grp_var = torch.var(x, (2, 3, 4), unbiased=False, keepdim=True)

        # Normalizing Input
        x = (x - grp_mean)/torch.sqrt(grp_var + self.epsilon)

        # Reshaping to original shape
        x = torch.reshape(x, (batch, channels, height, width))

        # Scale and shift
        x = x*self.std + self.mean

        return x

# Resnet with group normalization (default 8 groups)
resnet_gn_model = Resnet(n, r, norm_func=GN_norm).to(device)
(gn_model_metric_train,
 gn_model_metric_val,
 gn_model_metric_best) = model_func(model=resnet_gn_model, address=gn_model_result, name='GN Model')

In [None]:
# No Normalization

class NN_norm(torch.nn.Module):
    """Identity "normalization", used as the no-normalization baseline.

    Returns the input cast to ``dtype`` (``Tensor.to`` returns the input itself
    when it already has that dtype). ``channel_num`` is accepted only for
    interface compatibility with the other norm layers.
    """
    def __init__(self, channel_num, dtype = torch.float):
        super(NN_norm, self).__init__()
        self.dtype = dtype

    def forward(self, x):
        target_dtype = self.dtype
        return x.to(dtype=target_dtype)

# Resnet without any normalization
resnet_nn_model = Resnet(n, r, norm_func=NN_norm).to(device)
(nn_model_metric_train,
 nn_model_metric_val,
 nn_model_metric_best) = model_func(model=resnet_nn_model, address=nn_model_result, name='NN Model')

In [None]:
# Variation with Batch Size
# BN and GN models are retrained from scratch at batch sizes 8 and 128.
# model_func reads the module-level data_loader_* variables by default, so they
# are rebound before each pair of runs.

def _make_loaders(bsize):
    """Return (train, val, test) DataLoaders over the existing datasets with batch size ``bsize``."""
    return (
        torch.utils.data.DataLoader(dataset_train, batch_size=bsize, shuffle=True, num_workers = 4),
        torch.utils.data.DataLoader(dataset_val, batch_size=bsize, shuffle=False, num_workers = 4),
        torch.utils.data.DataLoader(dataset_test, batch_size=bsize, shuffle=False, num_workers = 4),
    )

# Batch Size: 8
data_loader_train, data_loader_val, data_loader_test = _make_loaders(8)

resnet_bn_model_bsize8 = Resnet(n, r, norm_func=BN_norm).to(device)
(bn_model_bsize8_metric_train,
 bn_model_bsize8_metric_val,
 bn_model_bsize8_metric_best) = model_func(resnet_bn_model_bsize8, bn_model_bsize8_result, 'BN Model, Batch Size = 8')

resnet_gn_model_bsize8 = Resnet(n, r, norm_func=GN_norm).to(device)
(gn_model_bsize8_metric_train,
 gn_model_bsize8_metric_val,
 gn_model_bsize8_metric_best) = model_func(resnet_gn_model_bsize8, gn_model_bsize8_result, 'GN Model, Batch Size = 8')

# Batch Size: 128
data_loader_train, data_loader_val, data_loader_test = _make_loaders(128)

resnet_bn_model_bsize128 = Resnet(n, r, norm_func=BN_norm).to(device)
(bn_model_bsize128_metric_train,
 bn_model_bsize128_metric_val,
 bn_model_bsize128_metric_best) = model_func(resnet_bn_model_bsize128, bn_model_bsize128_result, 'BN Model, Batch Size = 128')

resnet_gn_model_bsize128 = Resnet(n, r, norm_func=GN_norm).to(device)
(gn_model_bsize128_metric_train,
 gn_model_bsize128_metric_val,
 gn_model_bsize128_metric_best) = model_func(resnet_gn_model_bsize128, gn_model_bsize128_result, 'GN Model, Batch Size = 128')


In [None]:
# Save Overall Plots
def save_overall(address, arr_x, metric, legend, title, y_label, y_lim = None):
    """Plot one curve per run (row of ``metric``) against ``arr_x`` and save it.

    Args:
        address: output path passed to ``savefig``.
        arr_x: x values (epochs).
        metric: 2-D array; row ``i`` is the curve of run ``i``.
        legend: label for each row, indexed like ``metric``.
        title: figure title.
        y_label: y-axis label.
        y_lim: optional ``(low, high)`` y-axis limits; ignored when falsy.
    """
    fig, ax = plt.subplots()
    for row_idx, curve in enumerate(metric):
        ax.plot(arr_x, curve, label = legend[row_idx])
    ax.set_xlabel("num_epochs")
    ax.set_ylabel(y_label)
    ax.set_title(title)
    if y_lim:
        ax.set_ylim(y_lim[0], y_lim[1])
    ax.legend()
    plt.savefig(address)

def save_comp(overall_result_address, arr_x, metric_train, metric_val, legend):
    """Save train and validation comparison plots for every metric.

    ``metric_train`` / ``metric_val`` are indexed (run, epoch, metric), with
    metric columns [loss, accuracy, f1_micro, f1_macro]. Accuracy is plotted in
    percent. Eight figures are written into ``overall_result_address``.
    """
    # (file name, metrics, column, scale, title, y label), in the original save order
    figures = (
        ('train_acc', metric_train, 1, 100, 'Train Accuracy', "Accuracy"),
        ('val_acc', metric_val, 1, 100, 'Validation Accuracy', "Accuracy"),
        ('train_loss', metric_train, 0, 1, 'Train Loss', "Loss"),
        ('val_loss', metric_val, 0, 1, 'Validation Loss', "Loss"),
        ('train_f1_micro', metric_train, 2, 1, 'Train F1 Micro', "F1 Micro"),
        ('val_f1_micro', metric_val, 2, 1, 'Validation F1 Micro', "F1 Micro"),
        ('train_f1_macro', metric_train, 3, 1, 'Train F1 Macro', "F1 Macro"),
        ('val_f1_macro', metric_val, 3, 1, 'Validation F1 Macro', "F1 Macro"),
    )
    for file_name, metrics, col, scale, title, y_label in figures:
        save_overall(os.path.join(overall_result_address, file_name), arr_x,
                     scale * metrics[:, :, col], legend, title, y_label)

# Overall Metric: compare every from-scratch normalization variant (batch size 32)
norm_runs = {
    'BN Model': (bn_model_metric_train, bn_model_metric_val, bn_model_metric_best),
    'IN Model': (in_model_metric_train, in_model_metric_val, in_model_metric_best),
    'BIN Model': (bin_model_metric_train, bin_model_metric_val, bin_model_metric_best),
    'LN Model': (ln_model_metric_train, ln_model_metric_val, ln_model_metric_best),
    'GN Model': (gn_model_metric_train, gn_model_metric_val, gn_model_metric_best),
    'NN Model': (nn_model_metric_train, nn_model_metric_val, nn_model_metric_best),
}
legend = list(norm_runs)
metric_train = np.stack([run[0] for run in norm_runs.values()])
metric_val = np.stack([run[1] for run in norm_runs.values()])
metric_best = np.stack([run[2] for run in norm_runs.values()])
arr_x = np.arange(1, metric_train.shape[1]+1)
save_comp(overall_result_address, arr_x, metric_train, metric_val, legend)

In [None]:
# Compare Base Model (PyTorch BatchNorm2d) with the from-scratch BN model
bn_base_runs = {
    'Pytorch Batchnorm': (base_model_metric_train, base_model_metric_val, base_model_metric_best),
    'Batchnorm From Scratch': (bn_model_metric_train, bn_model_metric_val, bn_model_metric_best),
}
legend = list(bn_base_runs)
metric_train = np.stack([run[0] for run in bn_base_runs.values()])
metric_val = np.stack([run[1] for run in bn_base_runs.values()])
metric_best = np.stack([run[2] for run in bn_base_runs.values()])
arr_x = np.arange(1, metric_train.shape[1]+1)
save_comp(bn_vs_base_result, arr_x, metric_train, metric_val, legend)


In [None]:
# Effect of Batch Size: BN vs GN at batch sizes 8 and 128
bsize_runs = {
    'BN, Batch Size = 8': (bn_model_bsize8_metric_train, bn_model_bsize8_metric_val, bn_model_bsize8_metric_best),
    'BN, Batch Size = 128': (bn_model_bsize128_metric_train, bn_model_bsize128_metric_val, bn_model_bsize128_metric_best),
    'GN, Batch Size = 8': (gn_model_bsize8_metric_train, gn_model_bsize8_metric_val, gn_model_bsize8_metric_best),
    'GN, Batch Size = 128': (gn_model_bsize128_metric_train, gn_model_bsize128_metric_val, gn_model_bsize128_metric_best),
}
legend = list(bsize_runs)
metric_train = np.stack([run[0] for run in bsize_runs.values()])
metric_val = np.stack([run[1] for run in bsize_runs.values()])
metric_best = np.stack([run[2] for run in bsize_runs.values()])
arr_x = np.arange(1, metric_train.shape[1]+1)
save_comp(bn_vs_gn_bsize_result, arr_x, metric_train, metric_val, legend)