<a href="https://colab.research.google.com/github/shivani-202/CS-Deep-Learning-Assignment/blob/main/activation_func.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Examine the performance of various activation functions on MNIST, Fashion-MNIST, CIFAR-10, and CIFAR-100.
For training, use the ResNet, LeNet, MobileNet, and AlexNet architectures with different learnable depths.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score
import torchvision.models as models
import matplotlib.pyplot as plt

# Activation function mapping
class _BipolarSigmoid(nn.Module):
    """Bipolar sigmoid, ``2 * sigmoid(x) - 1``."""

    def forward(self, x):
        return 2 * torch.sigmoid(x) - 1


class _ESwish(nn.Module):
    """E-swish, ``beta * x * sigmoid(x)``."""

    def __init__(self, beta=1.5):
        super().__init__()
        self.beta = beta

    def forward(self, x):
        return self.beta * x * torch.sigmoid(x)


def activation_func(name, param=0.01):
    """Return a freshly constructed activation module for ``name``.

    Every result is an ``nn.Module``, so it can be assigned as a child of
    another module (for example over an existing ``.relu`` attribute or an
    ``nn.Sequential`` slot). Plain callables are rejected by ``nn.Module``
    in those positions, which is why the two custom activations are wrapped
    in small module classes instead of lambdas.

    Args:
        name: Key of the activation, e.g. ``"relu"``, ``"gelu"``, ``"e_swish"``.
        param: Negative slope for ``"leaky_relu"`` and ``alpha`` for
            ``"elu"`` and ``"telu"``; ignored by the other activations.

    Returns:
        A new ``nn.Module`` instance; learnable activations such as PReLU
        therefore never share parameters between call sites.

    Raises:
        KeyError: If ``name`` is not a known activation.
    """
    # Factories instead of instances: only the requested module is built.
    factories = {
        "sigmoid": nn.Sigmoid,
        "bipolar_sigmoid": _BipolarSigmoid,
        "tanh": nn.Tanh,
        "relu": nn.ReLU,
        "leaky_relu": lambda: nn.LeakyReLU(param),
        "param_relu": nn.PReLU,
        "elu": lambda: nn.ELU(param),
        "softmax": lambda: nn.Softmax(dim=1),
        "gelu": nn.GELU,
        "selu": nn.SELU,
        "mish": nn.Mish,
        "softplus": nn.Softplus,
        "swish": nn.SiLU,
        "e_swish": _ESwish,
        # NOTE(review): the key says "telu" but the module is CELU -- confirm
        # which activation was intended before comparing results.
        "telu": lambda: nn.CELU(param),
    }
    try:
        factory = factories[name]
    except KeyError:
        raise KeyError(
            f"Unknown activation {name!r}; expected one of {sorted(factories)}"
        ) from None
    return factory()


def get_resnet_model(activation, num_classes=10, input_channels=3):
    """Build a pretrained ResNet-18 that uses ``activation`` everywhere.

    Args:
        activation: Activation key understood by ``activation_func``.
        num_classes: Size of the replacement classification head.
        input_channels: Number of channels in the input images.

    Returns:
        The modified torchvision ResNet-18.
    """
    model = models.resnet18(pretrained=True)
    # The stem conv and the head are rebuilt for the requested channel and
    # class counts, so these two layers start from a fresh initialisation.
    model.conv1 = nn.Conv2d(input_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # Swap every ReLU (stem and all residual blocks), not only the first
    # block of each stage; otherwise most of the network still runs ReLU and
    # the comparison between activations is not meaningful.
    # Iterate over a snapshot because children are replaced while walking.
    for parent in list(model.modules()):
        for child_name, child in list(parent.named_children()):
            if isinstance(child, nn.ReLU):
                setattr(parent, child_name, activation_func(activation))
    return model

def get_alexnet_model(activation, num_classes=10, input_channels=3):
    """Build a pretrained AlexNet that uses ``activation`` everywhere.

    Args:
        activation: Activation key understood by ``activation_func``.
        num_classes: Size of the replacement final classifier layer.
        input_channels: Number of channels in the input images.

    Returns:
        The modified torchvision AlexNet.

    NOTE(review): torchvision documents a minimum input size of 63x63 for
    AlexNet; 28x28 or 32x32 images are expected to fail in the pooling
    layers unless they are resized upstream -- confirm against the data
    pipeline.
    """
    model = models.alexnet(pretrained=True)
    model.features[0] = nn.Conv2d(input_channels, 64, kernel_size=11, stride=4, padding=2)
    model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)

    # Replace activations by type rather than by index: index 4 of the
    # torchvision classifier is a Linear layer, so overwriting it removes a
    # layer and leaves every ReLU in place.
    # Iterate over a snapshot because children are replaced while walking.
    for parent in list(model.modules()):
        for child_name, child in list(parent.named_children()):
            if isinstance(child, nn.ReLU):
                setattr(parent, child_name, activation_func(activation))
    return model

def get_mobilenet_model(activation, num_classes=10, input_channels=3):
    """Build a pretrained MobileNetV2 that uses ``activation`` everywhere.

    Args:
        activation: Activation key understood by ``activation_func``.
        num_classes: Size of the replacement classifier layer.
        input_channels: Number of channels in the input images.

    Returns:
        The modified torchvision MobileNetV2.
    """
    model = models.mobilenet_v2(pretrained=True)
    model.features[0][0] = nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1, bias=False)
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

    # Setting a ``relu`` attribute on the Conv2d has no effect on the forward
    # pass. The network's non-linearities are separate ReLU6 modules inside
    # each conv block, so those are what must be replaced.
    # Iterate over a snapshot because children are replaced while walking.
    for parent in list(model.modules()):
        for child_name, child in list(parent.named_children()):
            if isinstance(child, (nn.ReLU6, nn.ReLU)):
                setattr(parent, child_name, activation_func(activation))
    return model

# Get dataloaders for datasets
def get_dataloader(dataset_name, batch_size=64):
    """Create train and test ``DataLoader`` objects for a named dataset.

    Images are converted to tensors and normalised per channel with mean 0.5
    and std 0.5. Data is downloaded to ``./data`` when missing.

    Args:
        dataset_name: One of ``"mnist"``, ``"fashion-mnist"``, ``"cifar10"``
            or ``"cifar100"``.
        batch_size: Batch size used by both loaders.

    Returns:
        ``(trainloader, testloader)``; only the training loader shuffles.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    # dataset key -> (torchvision.datasets class name, image channels)
    specs = {
        "mnist": ("MNIST", 1),
        "fashion-mnist": ("FashionMNIST", 1),
        "cifar10": ("CIFAR10", 3),
        "cifar100": ("CIFAR100", 3),
    }
    # Fail early with a clear message instead of an UnboundLocalError later.
    if dataset_name not in specs:
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; expected one of {sorted(specs)}"
        )
    class_name, channels = specs[dataset_name]

    steps = []
    if channels == 1:
        steps.append(transforms.Grayscale(num_output_channels=1))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize((0.5,) * channels, (0.5,) * channels))
    transform = transforms.Compose(steps)

    dataset_cls = getattr(torchvision.datasets, class_name)
    trainset = dataset_cls(root='./data', train=True, download=True, transform=transform)
    testset = dataset_cls(root='./data', train=False, download=True, transform=transform)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
    return trainloader, testloader

# Training function
def train(model, trainloader, testloader, criterion, optimizer, writer, num_epochs=10, dataset_name="Dataset"):
    """Train ``model`` and record per-epoch training loss and accuracy.

    Args:
        model: Network to optimise; it is put into training mode.
        trainloader: Sized iterable of ``(images, labels)`` batches.
        testloader: Currently unused; kept so existing callers keep working.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimiser bound to ``model``'s parameters.
        writer: Object with ``add_scalar(tag, value, step)``.
        num_epochs: Number of passes over ``trainloader``.
        dataset_name: Suffix for the ``Loss/`` and ``Accuracy/`` tags.

    Returns:
        ``(all_loss, all_acc)``: mean batch loss and training accuracy for
        each epoch, in order.

    Raises:
        ValueError: If ``trainloader`` contains no batches.
    """
    num_batches = len(trainloader)
    if num_batches == 0:
        # Would otherwise surface as a ZeroDivisionError after the first epoch.
        raise ValueError("trainloader is empty; nothing to train on")

    # Batches are moved to wherever the model lives so a model placed on an
    # accelerator by the caller does not hit a device mismatch.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    model.train()
    all_loss, all_acc = [], []

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        seen = 0
        for images, labels in trainloader:
            if device is not None:
                images = images.to(device)
                labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Count matches on tensors instead of accumulating every
            # prediction and label in Python lists.
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            seen += labels.numel()

        epoch_loss = running_loss / num_batches
        acc = correct / seen if seen else 0.0

        all_loss.append(epoch_loss)
        all_acc.append(acc)

        writer.add_scalar(f"Loss/{dataset_name}", epoch_loss, epoch)
        writer.add_scalar(f"Accuracy/{dataset_name}", acc, epoch)

        print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {epoch_loss:.4f} | Accuracy: {acc:.4f}")

    return all_loss, all_acc


def plot_performance(activation_names, all_loss, all_acc):
    """Show loss and accuracy curves side by side, one line per run.

    Args:
        activation_names: Legend label for each run; indexed in step with the
            curves in ``all_loss`` and ``all_acc``.
        all_loss: Sequence of per-epoch loss curves, one per run.
        all_acc: Sequence of per-epoch accuracy curves, one per run.
    """
    # (curves, panel title, y-axis label) for the left and right panels.
    panels = (
        (all_loss, 'Loss for different Activation Functions', 'Loss'),
        (all_acc, 'Accuracy for different Activation Functions', 'Accuracy'),
    )

    plt.figure(figsize=(12, 6))
    for position, (curves, heading, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for run_index, curve in enumerate(curves):
            plt.plot(curve, label=activation_names[run_index])
        plt.title(heading)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()

# Main script
if __name__ == "__main__":
    # Sweep every dataset x activation x architecture combination, logging
    # each run to TensorBoard and plotting all curves at the end.
    writer = SummaryWriter("runs/activation_experiment")
    datasets = ["mnist", "fashion-mnist", "cifar10", "cifar100"]
    activation_functions = ["sigmoid", "relu", "tanh", "leaky_relu", "elu", "softmax", "gelu", "selu", "mish", "swish", "bipolar_sigmoid", "e_swish", "param_relu", "telu", "softplus"]

    models_to_train = [get_resnet_model, get_alexnet_model, get_mobilenet_model]

    all_loss = []
    all_acc = []
    activation_names = []

    try:
        for dataset in datasets:
            # Loaders and shape settings depend only on the dataset, so build
            # them once instead of once per activation/model pair.
            trainloader, testloader = get_dataloader(dataset)
            input_channels = 1 if dataset in ["mnist", "fashion-mnist"] else 3
            num_classes = 10 if dataset != "cifar100" else 100

            for activation in activation_functions:
                for model_fn in models_to_train:
                    print(f"\nTraining with {activation.upper()} on {dataset.upper()} using {model_fn.__name__} model...")
                    model = model_fn(activation, num_classes=num_classes, input_channels=input_channels)

                    criterion = nn.CrossEntropyLoss()
                    optimizer = optim.Adam(model.parameters(), lr=0.001)

                    # A distinct tag per run keeps different activations and
                    # models from being written onto the same scalar series.
                    run_name = f"{dataset}/{activation}/{model_fn.__name__}"
                    loss, acc = train(model, trainloader, testloader, criterion, optimizer, writer, num_epochs=10, dataset_name=run_name)
                    all_loss.append(loss)
                    all_acc.append(acc)
                    # Include the dataset so legend entries are unique.
                    activation_names.append(f"{dataset}_{activation}_{model_fn.__name__}")
    finally:
        # Flush and close the event file even if a run fails part-way.
        writer.close()

    plot_performance(activation_names, all_loss, all_acc)



Training with SIGMOID on MNIST using get_resnet_model model...
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 10.6MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 349kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 3.20MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 4.15MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 109MB/s]
