In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt
class KSOM(nn.Module):
    """Prototype-distance layer followed by a linear classifier.

    Despite the name, this module performs no Kohonen-style competitive or
    neighbourhood update: ``weights`` is an ordinary ``nn.Parameter`` that is
    trained by whatever optimizer the caller attaches.

    Args:
        input_size: number of features in each input row.
        output_size: number of prototype units (rows of ``weights``).
        sigma: width parameter of the distance kernel used in ``forward``.
        num_classes: number of logits produced by ``classifier``.
    """

    def __init__(self, input_size, output_size, sigma=1.0, num_classes=10):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.sigma = sigma
        # Small random prototypes, one row per unit: (output_size, input_size).
        self.weights = nn.Parameter(0.01 * torch.randn(output_size, input_size))
        self.classifier = nn.Linear(output_size, num_classes)

    def forward(self, x):
        """Map a batch of feature rows to class logits.

        Every row of ``x`` is compared with every prototype by Euclidean
        distance (``torch.cdist``), so ``x`` needs ``input_size`` columns.
        Each distance d becomes the activation exp(-d / (2 * sigma**2)).
        Note that d is used as-is, not squared. The activations are then
        passed through the linear ``classifier``.
        """
        width = 2 * self.sigma ** 2
        activations = torch.exp(-torch.cdist(x, self.weights) / width)
        return self.classifier(activations)


def classify(input_data, som):
    """Return, for each row, the index of the model's highest output.

    ``som`` is called on ``input_data`` and the arg-max is taken along
    dimension 1, so the model output must be at least 2-D. When several
    outputs tie, the first maximal index is returned.
    """
    scores = som(input_data)
    return scores.argmax(dim=1)


def test_accuracy(test_loader, som):
    """Return the fraction of samples in ``test_loader`` that ``som`` classifies correctly.

    Args:
        test_loader: iterable of ``(images, labels)`` batches.
        som: model whose output rows are class scores; predictions are
            taken with ``classify``.

    Returns:
        ``correct / total`` as a float.

    Raises:
        ZeroDivisionError: if ``test_loader`` yields no samples.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            # Flatten each sample to a single row, whatever its image shape.
            # A fixed 28*28 width would split larger (e.g. multi-channel or
            # resized) images into several bogus rows and break the model.
            images = images.reshape(images.size(0), -1)
            predicted_labels = classify(images, som)
            total += labels.size(0)
            correct += (predicted_labels == labels).sum().item()
    return correct / total


# Training script: resize/normalize the car images, train KSOM end-to-end with
# SGD + cross-entropy, and plot the mean training loss per epoch.
uniform_size = (224, 224)

transform = transforms.Compose([
    transforms.Resize(uniform_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])


# Each split directory must contain one sub-directory per class (ImageFolder layout).
data_root = '/Users/sivaprasanth/Documents/DL/Ex4/stanford_cars'
train_dataset = ImageFolder(root=f'{data_root}/cars_train', transform=transform)
test_dataset = ImageFolder(root=f'{data_root}/cars_test', transform=transform)
batch_size = 100
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Features per flattened sample: 3 channels x resized height x width.
input_size = 3 * uniform_size[0] * uniform_size[1]
output_size = 10  # number of prototype units (rows of som.weights)
# ImageFolder labels run from 0 to len(classes) - 1, so the classifier needs
# one logit per class folder; a hard-coded 10 makes CrossEntropyLoss reject
# targets as soon as the dataset has more than 10 classes.
num_classes = len(train_dataset.classes)
# NOTE(review): with ~150k input features the distances d fed to
# exp(-d / (2 * sigma**2)) are probably in the hundreds, which can underflow
# float32 to 0 for sigma = 1.0 and leave only the classifier bias learning;
# consider tuning sigma — verify empirically.
sigma = 1.0
som = KSOM(input_size, output_size, sigma, num_classes=num_classes)
epochs = 10
lr = 0.1
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(som.parameters(), lr=lr)
epoch_losses = []
for epoch in range(epochs):
    running_loss = 0.0
    for images, labels in train_loader:
        # One row per sample: (batch, C, H, W) -> (batch, input_size).
        images = images.view(-1, input_size)
        optimizer.zero_grad()
        outputs = som(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Mean of the per-batch losses for this epoch.
    epoch_loss = running_loss / len(train_loader)
    epoch_losses.append(epoch_loss)
    print(f"Epoch {epoch+1}, Loss: {epoch_loss}")
plt.plot(range(1, epochs + 1), epoch_losses)
plt.title("Training Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.show()