In [None]:
pip install torch torchvision matplotlib

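CIFAR-10 CNN: train the same small CNN with Tangma, Swish, GELU and ReLU activations and compare their loss, accuracy and epoch time.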

In [None]:
import copy
import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")



class Tangma(nn.Module):
    """Learnable activation ``x * tanh(x + alpha) + gamma * x``.

    ``alpha`` shifts the input of the tanh gate and ``gamma`` scales an added
    linear term. Both are scalar (0-dim) learnable parameters that start at
    zero, so a freshly constructed module computes ``x * tanh(x)``.
    """

    def __init__(self):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(()))
        self.gamma = nn.Parameter(torch.zeros(()))

    def forward(self, x):
        gate = torch.tanh(x + self.alpha)
        return x * gate + self.gamma * x



class Swish(nn.Module):
    """Swish activation ``x * sigmoid(x)``; it has no learnable parameters."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate



class CIFAR_CNN(nn.Module):
    """Small three-stage CNN classifier for 3x32x32 (RGB) CIFAR-10 images.

    Each stage is a 3x3 conv (padding 1, size-preserving), the activation, and
    a 2x2 max-pool. Three pools take the 32x32 input to 4x4, so the flattened
    feature vector has 128 * 4 * 4 entries. This vector feeds a 512-unit hidden
    layer (activation + dropout 0.5) and then a 10-way linear layer. The output
    is raw logits, meant for ``nn.CrossEntropyLoss``.

    Args:
        activation: module applied after every conv layer and after ``fc1``.
            The same instance is used at all four sites. Any learnable
            parameters it owns (e.g. ``Tangma``'s alpha/gamma) are therefore
            shared across layers. They are also registered as part of this
            model, so an optimizer over ``model.parameters()`` trains them.
    """

    def __init__(self, activation):
        # Previously `nn.module` (AttributeError at class creation) and
        # `super(CNNModel, self)` (an undefined name) -- both fixed here.
        super().__init__()

        self.activation = activation

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)  # RGB 32x32 input, 3x3 kernel
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # A single pooling module is reused after every conv stage (it has no state).
        self.pool = nn.MaxPool2d(2, 2)

        self.dropout = nn.Dropout(0.5)

        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Map a batch of images ``(N, 3, 32, 32)`` to class logits ``(N, 10)``."""
        x = self.pool(self.activation(self.conv1(x)))  # -> (N, 32, 16, 16)
        x = self.pool(self.activation(self.conv2(x)))  # -> (N, 64, 8, 8)
        x = self.pool(self.activation(self.conv3(x)))  # -> (N, 128, 4, 4)
        x = torch.flatten(x, 1)  # keep the batch dim, flatten the rest
        x = self.dropout(self.activation(self.fc1(x)))
        return self.fc2(x)


# Backward-compatible alias: existing code instantiates the model as `CNNModel`.
CNNModel = CIFAR_CNN



# Preprocessing: ToTensor turns [0, 255] pixels into floats in [0.0, 1.0].
# Normalize then applies (x - 0.5) / 0.5 per channel, mapping values into
# [-1.0, 1.0]. The 0.5/0.5 pair is a fixed symmetric rescale, not a per-dataset
# mean/std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 train and test splits (train first, then test), stored under ./data
# and downloaded when missing.
trainset, testset = (
    torchvision.datasets.CIFAR10(
        root="./data", train=is_train, download=True, transform=transform
    )
    for is_train in (True, False)
)

# Batches of 128 for both splits. Only the training split is reshuffled each
# epoch; the test split (used for validation) keeps a fixed order.
trainloader = DataLoader(dataset=trainset, batch_size=128, shuffle=True)
testloader = DataLoader(dataset=testset, batch_size=128, shuffle=False)

# training
def _train_one_epoch(model, loader, optimizer, criterion, run_device):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss."""
    model.train()  # training mode, e.g. dropout active
    running_loss = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(run_device), labels.to(run_device)
        optimizer.zero_grad()  # gradients would otherwise accumulate across steps
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def _evaluate(model, loader, criterion, run_device):
    """Return ``(mean per-batch loss, accuracy in percent)`` of ``model`` on ``loader``."""
    model.eval()  # inference mode, e.g. dropout disabled
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(run_device), labels.to(run_device)
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item()
            # Predicted class = index of the largest logit along the class dim.
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return total_loss / len(loader), 100 * correct / total


def train_model(
    model,
    name,
    epochs=10,
    *,
    lr=0.001,
    train_loader=None,
    val_loader=None,
    run_device=None,
):
    """Train ``model`` with Adam + cross-entropy, validating after every epoch.

    Args:
        model: classifier whose outputs go to ``nn.CrossEntropyLoss``
            (logits with classes along dim 1).
        name: label used only in the per-epoch progress line.
        epochs: number of passes over the training loader.
        lr: Adam learning rate (default matches the original 0.001).
        train_loader: iterable of ``(inputs, labels)`` batches. Defaults to the
            module-level ``trainloader``.
        val_loader: batches used for validation. Defaults to the module-level
            ``testloader`` (the CIFAR-10 test split doubles as validation set).
        run_device: device to train on. Defaults to the module-level ``device``.

    Returns:
        ``(train_losses, val_losses, val_accuracies, epoch_times)``, each a list
        with one entry per epoch:
        - Losses are means of per-batch losses.
        - Accuracy is in percent.
        - Each time, in seconds, covers both the training and the validation
          pass of that epoch.
    """
    # Resolve defaults lazily so the globals are only needed when not overridden.
    if run_device is None:
        run_device = device
    if train_loader is None:
        train_loader = trainloader
    if val_loader is None:
        val_loader = testloader

    model.to(run_device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    train_losses, val_losses, val_accuracies, epoch_times = [], [], [], []
    for epoch in range(1, epochs + 1):
        start_time = time.time()
        avg_train_loss = _train_one_epoch(
            model, train_loader, optimizer, criterion, run_device
        )
        avg_val_loss, val_acc = _evaluate(model, val_loader, criterion, run_device)
        elapsed = time.time() - start_time

        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        val_accuracies.append(val_acc)
        epoch_times.append(elapsed)

        # Report against the real epoch count (was hard-coded as "/10").
        print(
            f"[{name}] Epoch {epoch}/{epochs} - Train Loss: {avg_train_loss:.4f}, "
            f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%, "
            f"Time: {elapsed:.2f}s"
        )

    return train_losses, val_losses, val_accuracies, epoch_times

# Activation modules to compare, keyed by the label used in logs and plots.
# Insertion order (Tangma, Swish, GELU, ReLU) is the order runs are trained in.
activations = dict(
    Tangma=Tangma(),
    Swish=Swish(),
    GELU=nn.GELU(),
    ReLU=nn.ReLU(),
)



# run: train one fresh model per activation under identical starting conditions
results = {}
for name, act in activations.items():
    print(f"\nTraining with {name} activation")
    # Reset the RNG before each run. Every model then gets the same weight
    # init, dropout stream and shuffle order, so the curves differ mainly
    # because of the activation. CUDA kernels may still add small
    # nondeterminism.
    torch.manual_seed(0)
    # Deep-copy the activation so a learnable one (Tangma) starts from its
    # initial parameters each time this cell runs. Otherwise it would continue
    # from values trained into the shared instance held in `activations`.
    model = CIFAR_CNN(copy.deepcopy(act))
    results[name] = train_model(model, name)



# plot metrics: one figure per metric, one curve per activation.
# Entries follow the order of the tuple returned by train_model():
# (metric name used in the title and file name, y-axis label with units).
for i, (metric, ylabel) in enumerate(
    [
        ("Train Loss", "Train Loss"),
        ("Val Loss", "Val Loss"),
        ("Val Accuracy", "Val Accuracy (%)"),
        ("Time", "Time (s)"),
    ]
):
    plt.figure(figsize=(8, 5))
    for name, history in results.items():
        series = history[i]
        # train_model logs epochs 1-based; plot them the same way (not 0-based indices).
        plt.plot(range(1, len(series) + 1), series, label=name)
    plt.title(f"{metric} per Epoch")
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(f"cifar10_{metric.lower().replace(' ', '_')}.png")
    plt.show()