In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader,random_split
import matplotlib.pyplot as plt


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])


training_data = datasets.FashionMNIST(root='/home/manchik-pt7714/Documents/ML Tasks/data/temp/',train = True,transform=transform,download=False)

train_size = int(0.8*len(training_data))
val_size = len(training_data)-train_size

train_data,val_data=random_split(training_data,[train_size,val_size])

train_loader = DataLoader(train_data,batch_size=128,shuffle=True)
val_loader = DataLoader(val_data,batch_size=128,shuffle=False)

test_data = datasets.FashionMNIST(root='/home/manchik-pt7714/Documents/ML Tasks/data/temp/', train=False, transform=transform, download=True)

test_loader = DataLoader(test_data, batch_size=64, shuffle=False)


class SimpleNN(nn.Module):
    def __init__(self, norm_type="batch"):
        super(SimpleNN, self).__init__()
        self.norm_type = norm_type

        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        
        if norm_type == "batch":
            self.norm1 = nn.BatchNorm2d(32)
            self.norm2 = nn.BatchNorm2d(64)
        elif norm_type == "layer":
            self.norm1 = nn.LayerNorm([32, 26, 26])
            self.norm2 = nn.LayerNorm([64, 24, 24])
        elif norm_type == "group":
            self.norm1 = nn.GroupNorm(4, 32)  
            self.norm2 = nn.GroupNorm(8, 64)  
        
        self.fc1 = nn.Linear(64 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = torch.relu(self.norm1(self.conv1(x)))
        x = torch.relu(self.norm2(self.conv2(x)))
        x = torch.flatten(x, 1) 
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


def train_and_evaluate(norm_type):
    model = SimpleNN(norm_type)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    
    train_losses = []
    val_accuracies = []

    for epoch in range(10):
        model.train()
        epoch_loss = 0

        for images, labels in train_loader:
            images, labels = images, labels
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        train_losses.append(epoch_loss / len(train_loader))


        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images, labels
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_accuracies.append(correct / total)

        print(f"Epoch {epoch+1}/{10} - Loss: {train_losses[-1]:.4f}, Validation Accuracy: {val_accuracies[-1]:.4f}")
    
    return train_losses, val_accuracies

normalization_methods = ["batch", "layer", "group"]
results = {}

for norm in normalization_methods:
    print(f"\nTraining with {norm} normalization...")
    train_losses, val_accuracies = train_and_evaluate(norm)
    results[norm] = (train_losses, val_accuracies)


plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
for norm, (train_losses, _) in results.items():
    plt.plot(range(1, 11), train_losses, label=f"{norm.capitalize()} Norm")
plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

plt.subplot(1, 2, 2)
for norm, (_, val_accuracies) in results.items():
    plt.plot(range(1, 11), val_accuracies, label=f"{norm.capitalize()} Norm")
plt.title("Validation Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()

plt.tight_layout()
plt.show()



Training with batch normalization...


RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x36864 and 2304x128)