In [None]:
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torchvision.ops import DropBlock2d
from torchvision.ops import stochastic_depth
from torchvision.datasets import FashionMNIST


# from torchvision.transforms import v2

#==============================================================================================================================================================
# RESNET ARCHITECTURE CLASS
#==============================================================================================================================================================


# Define a BasicBlock for the ResNet-like architecture
class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers with a skip connection.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
        stride: Kept for signature compatibility with existing callers. It does
            not affect the output shape; the spatial stride of the block is
            ``stride1``.
        stride1: Stride of the first 3x3 convolution (2 halves height/width).

    When the block changes the spatial size (``stride1 != 1``) or the channel
    count, the skip path is a 1x1 conv (stride ``stride1``) + BatchNorm
    projection so that it matches the main branch; otherwise it is an identity.
    """

    def __init__(self, in_channels, out_channels, stride=1, stride1=2):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # The projection must downsample by exactly the same stride as conv1.
        # The previous rule (hard-coded stride 2, only for 128/256/512 output
        # channels) produced residual shape mismatches for any other setting,
        # e.g. BasicBlock(64, 64) with the default stride1=2.
        if stride1 != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.downsample = nn.Identity()

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + downsample(x))."""
        residual = self.downsample(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x += residual
        x = self.relu(x)
        return x


# Define the Deformable Convolution Block
class DeformConvBlock(nn.Module):
    """Modulated deformable 3x3 convolution followed by BatchNorm and ReLU.

    Spatial size is preserved (stride 1, padding 1); channels go from
    ``in_channels`` to ``out_channels``. Per-position sampling offsets and
    modulation masks are predicted from the input by two small 3x3 convs.
    """

    def __init__(self, in_channels, out_channels):
        super(DeformConvBlock, self).__init__()
        kernel_points = 3 * 3  # sampling points of the 3x3 deformable kernel
        # 2 * 9 = 18 channels: one (dy, dx) offset pair per kernel sampling point.
        self.offset_conv = nn.Conv2d(in_channels, 2 * kernel_points, kernel_size=3, padding=1)
        # Holds the kernel weight/stride/padding/dilation used by deform_conv2d below.
        self.deform_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        # 9 channels: one modulation scalar per kernel sampling point. Predicted
        # separately instead of re-using (slicing) the offset channels.
        self.mask_conv = nn.Conv2d(in_channels, kernel_points, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        offset = self.offset_conv(x)
        # Squash the modulation mask into [0, 1] so it scales (never flips)
        # the contribution of each sampling point.
        mask = torch.sigmoid(self.mask_conv(x))

        x = torchvision.ops.deform_conv2d(input=x,
                                          offset=offset,
                                          weight=self.deform_conv.weight,
                                          bias=None,  # deform_conv was created with bias=False
                                          padding=self.deform_conv.padding,
                                          mask=mask,
                                          stride=self.deform_conv.stride,
                                          dilation=self.deform_conv.dilation)
        x = self.bn(x)
        x = self.relu(x)
        return x

# Define a BasicBlock with Stochastic Depth for the ResNet-like architecture
class StoDepthBasicBlock(BasicBlock):
    """BasicBlock whose residual branch is randomly dropped (stochastic depth).

    Args:
        in_channels, out_channels, stride, stride1: forwarded to BasicBlock.
        p: drop probability passed to ``torchvision.ops.stochastic_depth``.
        mode: ``"batch"`` or ``"row"``, passed to ``stochastic_depth``.
    """

    def __init__(self, in_channels, out_channels, stride=1, stride1=2, p=0.5, mode="batch"):
        super().__init__(in_channels, out_channels, stride, stride1)
        self.p = p
        self.mode = mode

    def forward(self, x):
        residual = self.downsample(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        # stochastic_depth defaults to training=True; without forwarding the
        # module's mode the branch would also be dropped during eval/test.
        x = stochastic_depth(input=x, p=self.p, mode=self.mode, training=self.training)
        x += residual
        x = self.relu(x)
        return x

# Define a BasicBlock with DropBlock for the ResNet-like architecture
class DropBlockBasicBlock(BasicBlock):
    """BasicBlock that applies DropBlock2d to the residual branch before the skip sum.

    Args:
        in_channels, out_channels, stride, stride1: forwarded to BasicBlock.
        p, block_size: forwarded to ``torchvision.ops.DropBlock2d``.
    """

    def __init__(self, in_channels, out_channels, stride=1, stride1=2, p=0.1, block_size=1):
        super().__init__(in_channels, out_channels, stride, stride1)
        # Registered as a submodule, so it follows this block's train()/eval() mode.
        self.dropblock = DropBlock2d(p, block_size)

    def forward(self, x):
        shortcut = self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.dropblock(self.bn2(self.conv2(out)))
        out += shortcut
        return self.relu(out)

# Define the ResNet18 architecture
class ResNet18(nn.Module):
    """ImageNet-pretrained torchvision ResNet-18 for 1-channel input, truncated after layer3.

    Args:
        num_classes: number of output logits of the final linear layer.
    """

    def __init__(self, num_classes):
        super(ResNet18, self).__init__()

        # `pretrained=True` is deprecated since torchvision 0.13; IMAGENET1K_V1
        # is the checkpoint it used to load for resnet18.
        self.model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Single-channel (grayscale) stem; this layer is freshly initialised.
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # layer4 is removed, so the classifier receives 256 pooled features.
        self.model.layer4 = nn.Identity()
        self.model.fc = nn.Linear(256, num_classes)

    def make_layer(self, in_channels, out_channels, blocks, stride1, stride=1):
        """Build a stage of ``blocks`` BasicBlocks (not called by ``__init__``).

        The first block maps in_channels -> out_channels with stride ``stride1``;
        the remaining blocks keep out_channels and use stride1=1.
        """
        layers = []
        layers.append(BasicBlock(in_channels, out_channels, stride, stride1))
        for _ in range(1, blocks):
            layers.append(BasicBlock(out_channels, out_channels, stride, stride1=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# Define the CustomResNet architecture that combines ResNet-18 with DeformConvBlock
class CustomResNet(nn.Module):
    """Pretrained ResNet-18 (1-channel stem, truncated after layer3) with optional deformable stages.

    Args:
        num_classes: number of output logits.
        layers_to_deform: collection of stage indices among 1, 2 and 3. Each
            listed stage is replaced by a freshly initialised stage containing
            a DeformConvBlock right after its first BasicBlock; other values
            are ignored.
    """

    def __init__(self, num_classes, layers_to_deform):
        super(CustomResNet, self).__init__()

        # `pretrained=True` is deprecated since torchvision 0.13; IMAGENET1K_V1
        # is the checkpoint it used to load for resnet18.
        self.model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Single-channel (grayscale) stem; this layer is freshly initialised.
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

        if 1 in layers_to_deform:
            self.model.layer1 = self.make_layer_with_deform(64, 64, blocks=2, stride=1, stride1=1)
        if 2 in layers_to_deform:
            self.model.layer2 = self.make_layer_with_deform(64, 128, blocks=2, stride=1, stride1=2)
        if 3 in layers_to_deform:
            self.model.layer3 = self.make_layer_with_deform(128, 256, blocks=2, stride=1, stride1=2)
        # layer4 is removed, so the classifier receives 256 pooled features.
        self.model.layer4 = nn.Identity()
        self.model.fc = nn.Linear(256, num_classes)

    def make_layer(self, in_channels, out_channels, blocks, stride1, stride=1):
        """Build a stage of ``blocks`` BasicBlocks (not called by ``__init__``)."""
        layers = []
        layers.append(BasicBlock(in_channels, out_channels, stride, stride1))
        for _ in range(1, blocks):
            layers.append(BasicBlock(out_channels, out_channels, stride, stride1=1))
        return nn.Sequential(*layers)

    def make_layer_with_deform(self, in_channels, out_channels, blocks, stride, stride1):
        """Build a stage of ``blocks`` modules: BasicBlock, DeformConvBlock, then blocks-2 BasicBlocks.

        Only the first BasicBlock changes channels / resolution (stride ``stride1``).
        """
        layers = [
            BasicBlock(in_channels, out_channels, stride, stride1),
            DeformConvBlock(out_channels, out_channels),
        ]
        # The trailing blocks must keep the resolution; BasicBlock's default
        # stride1=2 would downsample here (or break the residual shape).
        for _ in range(1, blocks - 1):
            layers.append(BasicBlock(out_channels, out_channels, stride1=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# Define the CustomResNet architecture that uses Stochastic Depth
class CustomResNetStoDepth(CustomResNet):
    """ResNet-18 layout whose four stages are rebuilt from StoDepthBasicBlock.

    Args:
        num_classes: number of output logits.
        p: drop probability passed to every StoDepthBasicBlock.
        mode: stochastic-depth mode passed to every StoDepthBasicBlock.

    Note: conv1, every stage and fc are replaced by freshly initialised
    modules, so little of the pretrained checkpoint survives.
    """

    def __init__(self, num_classes, p, mode):
        # The parent already builds the pretrained ResNet-18 with the
        # single-channel stem; reuse it instead of loading the checkpoint twice.
        super(CustomResNetStoDepth, self).__init__(num_classes, [])
        self.p = p
        self.mode = mode
        self.model.layer1 = self.make_layer_stodepth(64, 64, blocks=2, stride1=1)
        self.model.layer2 = self.make_layer_stodepth(64, 128, blocks=2, stride1=2)
        self.model.layer3 = self.make_layer_stodepth(128, 256, blocks=2, stride1=2)
        # layer4 is restored here, so the classifier sees 512 features.
        self.model.layer4 = self.make_layer_stodepth(256, 512, blocks=2, stride1=2)
        self.model.fc = nn.Linear(512, num_classes)

    def make_layer_stodepth(self, in_channels, out_channels, blocks, stride1, stride=1):
        """Build a stage of ``blocks`` StoDepthBasicBlocks; only the first uses ``stride1``."""
        layers = []
        layers.append(StoDepthBasicBlock(in_channels, out_channels, stride, stride1, p=self.p, mode=self.mode))
        for _ in range(1, blocks):
            layers.append(StoDepthBasicBlock(out_channels, out_channels, stride, stride1=1, p=self.p, mode=self.mode))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# Define CustomResNet architecture that uses DropBlock
class CustomResNetDropBlock(CustomResNet):
    """ResNet-18 layout whose four stages are rebuilt from DropBlockBasicBlock.

    Args:
        num_classes: number of output logits.
        p: DropBlock probability passed to every DropBlockBasicBlock.
        block_size: DropBlock block size passed to every DropBlockBasicBlock.

    Note: conv1, every stage and fc are replaced by freshly initialised
    modules, so little of the pretrained checkpoint survives.
    """

    def __init__(self, num_classes, p, block_size):
        # The parent already builds the pretrained ResNet-18 with the
        # single-channel stem; reuse it instead of loading the checkpoint twice.
        super(CustomResNetDropBlock, self).__init__(num_classes, [])
        self.p = p
        self.block_size = block_size
        self.model.layer1 = self.make_layer_dropblock(64, 64, blocks=2, stride1=1)
        self.model.layer2 = self.make_layer_dropblock(64, 128, blocks=2, stride1=2)
        self.model.layer3 = self.make_layer_dropblock(128, 256, blocks=2, stride1=2)
        # layer4 is restored here, so the classifier sees 512 features.
        self.model.layer4 = self.make_layer_dropblock(256, 512, blocks=2, stride1=2)
        self.model.fc = nn.Linear(512, num_classes)

    def make_layer_dropblock(self, in_channels, out_channels, blocks, stride1, stride=1):
        """Build a stage of ``blocks`` DropBlockBasicBlocks; only the first uses ``stride1``."""
        layers = []
        layers.append(DropBlockBasicBlock(in_channels, out_channels, stride, stride1, p=self.p, block_size=self.block_size))
        for _ in range(1, blocks):
            layers.append(DropBlockBasicBlock(out_channels, out_channels, stride, stride1=1, p=self.p, block_size=self.block_size))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)








#==============================================================================================================================================================
# DATALOADING & TRAINING FUNCTION
#==============================================================================================================================================================


# Return the train, validation and test dataloaders
def fashionmnist_dataloader(batch_size, basic_aug=True, val_size=0.2, random_state=42, root='./data'):
    """Build FashionMNIST train / validation / test DataLoaders at 224x224 resolution.

    The 60k official training images are split into train and validation by
    index (``val_size`` fraction, seeded with ``random_state``); the official
    10k test images form the test set. Inputs are normalised with the mean/std
    of the raw training pixels.

    Args:
        batch_size: batch size of all three loaders.
        basic_aug: if True, the training split uses random resized crops and
            horizontal/vertical flips; validation and test never do.
        val_size: fraction of the training images held out for validation.
        random_state: seed of the train/validation split.
        root: directory the dataset is downloaded to / read from.

    Returns:
        (train_loader, val_loader, test_loader); only train_loader shuffles.
    """
    from torch.utils.data import Subset

    # Pixel statistics of the raw uint8 training images, rescaled to the
    # [0, 1] range that ToTensor produces.
    raw_pixels = FashionMNIST(root=root, train=True, download=True).data.float() / 255
    mean, std = raw_pixels.mean().item(), raw_pixels.std().item()

    transform = transforms.Compose([
        transforms.Resize((224, 224), antialias=True),
        transforms.ToTensor(),
        transforms.Normalize((mean,), (std,)),
    ])

    if basic_aug:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size=(224, 224), antialias=True),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((mean,), (std,)),
        ])
    else:
        train_transform = transform

    # Two views of the same training images: (possibly) augmented for
    # training, plain for validation.
    train_view = FashionMNIST(root=root, train=True, transform=train_transform, download=True)
    val_view = FashionMNIST(root=root, train=True, transform=transform, download=True)
    test_dataset = FashionMNIST(root=root, train=False, transform=transform, download=True)

    # Split indices only, so transforms run lazily every time a sample is
    # loaded. Splitting the Dataset itself would materialise every transformed
    # 224x224 image once (freezing the random augmentation and using roughly
    # 10 GB for the training split).
    train_idx, val_idx = train_test_split(
        list(range(len(train_view))), test_size=val_size, random_state=random_state
    )

    train_loader = DataLoader(Subset(train_view, train_idx), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(Subset(val_view, val_idx), batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader



# MixUp technique for Data Augmentation
def mixup(x, y, alpha):
    """Blend every sample of a batch with a randomly permuted partner (MixUp).

    One mixing weight ``lam ~ Beta(alpha, alpha)`` is drawn per batch and the
    inputs become ``lam * x + (1 - lam) * x[perm]``.

    Args:
        x: input batch; the first dimension is the batch dimension.
        y: integer class labels, one per sample.
        alpha: Beta concentration. ``alpha <= 0`` disables mixing.

    Returns:
        (mixed_x, mixed_y). ``mixed_y`` is a LongTensor holding, per sample, the
        label of the image that contributes more (the original label when
        ``lam >= 0.5``, otherwise the partner's), mirroring cutmix's
        majority-label rule.

    Note: class ids are categorical and cannot be averaged; the previous
    ``lam * y + (1 - lam) * y[perm]`` truncated to long produced unrelated
    classes. Textbook MixUp uses soft targets / a lam-weighted loss; this keeps
    the hard-label interface used by train_mixup.
    """
    if alpha <= 0:
        return x, y.long()

    lam = torch.distributions.beta.Beta(alpha, alpha).sample()
    index = torch.randperm(x.size(0))
    mixed_x = lam * x + (1 - lam) * x[index]
    mixed_y = y if lam >= 0.5 else y[index]

    return mixed_x, mixed_y.long()




# CutMix technique for Data Augmentation
def cutmix(images, labels, alpha):
    """Paste a random rectangle from another batch image into each image (CutMix).

    For every sample ``i`` a partner ``j`` is drawn uniformly from the batch
    (``j`` may equal ``i``) and ``lam ~ Beta(alpha, alpha)``. The patch spans
    ``(1 - lam)`` of the image height and width and is placed uniformly at
    random entirely inside the image. The hard label of the result is the label
    of whichever image covers the larger area (ties keep sample ``i``'s label).

    Args:
        images: tensor of shape (batch, channels, height, width). Not modified.
        labels: tensor of per-sample class labels. Not modified.
        alpha: Beta concentration. ``alpha <= 0`` returns unmodified copies.

    Returns:
        (mixed_images, mixed_labels) with the same shapes and dtypes as the inputs.
    """
    batch_size, _, height, width = images.shape
    mixed_images = images.clone()
    mixed_labels = labels.clone()
    if alpha <= 0:
        return mixed_images, mixed_labels

    beta = torch.distributions.beta.Beta(alpha, alpha)
    total_area = height * width

    for i in range(batch_size):
        j = torch.randint(0, batch_size, (1,)).item()
        lam = beta.sample().item()

        cut_h = int(height * (1 - lam))
        cut_w = int(width * (1 - lam))
        # "+ 1" makes the last valid origin reachable; it also keeps randint's
        # range non-empty for a full-size patch (lam ~ 0), which used to crash.
        y1 = torch.randint(0, height - cut_h + 1, (1,)).item()
        x1 = torch.randint(0, width - cut_w + 1, (1,)).item()
        mixed_images[i, :, y1:y1 + cut_h, x1:x1 + cut_w] = images[j, :, y1:y1 + cut_h, x1:x1 + cut_w]

        # Decide the label by patch area rather than by comparing pixel values:
        # pixels that merely happen to be equal in both images (e.g. a shared
        # background) must not be credited to either image.
        patch_area = cut_h * cut_w
        if total_area - patch_area < patch_area:
            mixed_labels[i] = labels[j]

    return mixed_images, mixed_labels


#----------------------------------------------------------------------------------------------------------------------------------


# Normal train function
def train(model, criterion, train_loader, optimizer, epoch, device):
    """Run one training epoch and report its mean batch loss and accuracy.

    Args:
        model: network to train (switched to train mode).
        criterion: loss callable ``criterion(output, target)``.
        train_loader: iterable of ``(data, target)`` batches with a ``len``.
        optimizer: optimizer stepping the model's parameters.
        epoch: epoch number, used only in the printed summary.
        device: device every batch is moved to.

    Returns:
        (accuracy, average_loss): accuracy in percent over all samples seen and
        the mean of the per-batch losses.
    """
    model.train()

    loss_sum, hits, seen = 0.0, 0, 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        model.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        hits += (logits.max(dim=1).indices == labels).sum().item()
        seen += len(inputs)

    average_loss = loss_sum / len(train_loader)
    accuracy = 100. * hits / seen
    print(f'Train Epoch: {epoch}\tAverage Loss: {average_loss:.6f}\tAccuracy: {accuracy:.2f}%')
    return accuracy, average_loss


# MixUp train function
def train_mixup(model, criterion, train_loader, optimizer, epoch, alpha, device):
    """Run one training epoch on MixUp-augmented batches.

    Each batch goes through ``mixup(data, target, alpha)`` before being moved
    to ``device``; loss and accuracy are measured against the labels returned
    by ``mixup``.

    Returns:
        (accuracy, average_loss): accuracy in percent over all samples seen and
        the mean of the per-batch losses.
    """
    model.train()

    loss_sum = 0.0
    hits = 0
    seen = 0

    for raw_inputs, raw_labels in train_loader:
        inputs, labels = mixup(raw_inputs, raw_labels, alpha)
        inputs = inputs.to(device)
        labels = labels.to(device)

        model.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        hits += (logits.max(dim=1).indices == labels).sum().item()
        seen += len(inputs)

    average_loss = loss_sum / len(train_loader)
    accuracy = 100. * hits / seen
    print(f'Train Epoch: {epoch}\tAverage Loss: {average_loss:.6f}\tAccuracy: {accuracy:.2f}%')
    return accuracy, average_loss


# CutMix train function
def train_cutmix(model, criterion, train_loader, optimizer, epoch,alpha, device):
    """Run one training epoch on CutMix-augmented batches.

    Each batch goes through ``cutmix(data, target, alpha)`` before being moved
    to ``device``; loss and accuracy are measured against the labels returned
    by ``cutmix``.

    Returns:
        (accuracy, average_loss): accuracy in percent over all samples seen and
        the mean of the per-batch losses.
    """
    model.train()

    loss_sum = 0.0
    hits = 0
    seen = 0

    for raw_inputs, raw_labels in train_loader:
        inputs, labels = cutmix(raw_inputs, raw_labels, alpha)
        inputs = inputs.to(device)
        labels = labels.to(device)

        model.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        hits += (logits.max(dim=1).indices == labels).sum().item()
        seen += len(inputs)

    average_loss = loss_sum / len(train_loader)
    accuracy = 100. * hits / seen
    print(f'Train Epoch: {epoch}\tAverage Loss: {average_loss:.6f}\tAccuracy: {accuracy:.2f}%')
    return accuracy, average_loss


# Validation Function
def val(model, criterion, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` in eval mode without gradient tracking.

    Returns:
        (test_acc, average_loss): accuracy in percent over all samples and the
        mean of the per-batch losses.
    """
    model.eval()

    running_loss = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            running_loss += criterion(logits, labels).item()
            n_correct += (logits.max(dim=1).indices == labels).sum().item()
            n_seen += len(inputs)

    average_loss = running_loss / len(val_loader)
    test_acc = 100. * n_correct / n_seen
    print(f'Val Loss: {average_loss:.4f}, Val Accuracy: {test_acc:.2f}%')
    return test_acc, average_loss

def criterion():
    """Return a new, default-configured ``nn.CrossEntropyLoss`` instance."""
    return nn.CrossEntropyLoss()