In [1]:
! pip install -q  torch  torchvision

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [3]:
class SimpleMLP(nn.Module):
    """Three-layer fully connected classifier.

    The input is flattened (all dimensions except the batch dimension) and
    passed through Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        in_features: Number of values per sample after flattening.
            Defaults to 28*28.
        hidden_features: Width of both hidden layers. Defaults to 512.
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, in_features=28 * 28, hidden_features=512, num_classes=10):
        super().__init__()
        self.flatten = nn.Flatten()
        # Attribute names and layer order are kept so existing state dicts
        # (keys such as "layer_stack.0.weight") still load.
        self.layer_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Return the raw (un-normalised) logits for a batch ``x``."""
        x = self.flatten(x)
        return self.layer_stack(x)


In [4]:
# Per-image preprocessing pipeline: convert to a tensor, then normalise
# with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# MNIST train / test splits from torchvision, downloaded into ./data when needed.
train_dataset = datasets.MNIST(
    root='./data', train=True, transform=transform, download=True
)
test_dataset = datasets.MNIST(
    root='./data', train=False, transform=transform, download=True
)

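# Plain (unsorted) data loaders, kept for reference only; they are disabled
# because the loaders created below iterate over the label-sorted subsets instead.
# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False)
# test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)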

from torch.utils.data import Subset

# Sort the dataset by labels
def sort_dataset_by_labels(dataset):
    """Return a view of ``dataset`` whose samples are ordered by ascending label.

    Args:
        dataset: A map-style dataset exposing a ``targets`` attribute (tensor
            or sequence) holding one label per sample.

    Returns:
        A ``Subset`` of ``dataset``. Samples that share a label keep their
        original relative order.
    """
    targets = torch.as_tensor(dataset.targets)
    # One vectorised stable sort instead of a Python-level sort whose key
    # indexes the tensor once per element; stable=True reproduces the
    # tie-breaking of list.sort().
    indices = torch.sort(targets, stable=True).indices.tolist()
    return Subset(dataset, indices)

# Label-sorted views of both splits (train first, then test).
sorted_train_dataset, sorted_test_dataset = (
    sort_dataset_by_labels(split) for split in (train_dataset, test_dataset)
)

# Batches of 64 served in stored order (shuffle=False), i.e. in label order.
# NOTE(review): with label-sorted, unshuffled data the first batches contain a
# single class only (see the label printout in the next cell) -- presumably
# intentional for this experiment; confirm.
train_loader, test_loader = (
    DataLoader(split, batch_size=64, shuffle=False)
    for split in (sorted_train_dataset, sorted_test_dataset)
)

In [5]:
# Sanity check of the loader ordering: gather labels batch by batch and stop
# as soon as at least 100 of them have been collected.
labels_list = []
for images, labels in train_loader:
    labels_list += labels.tolist()  # tensor of labels -> plain Python ints
    if 100 <= len(labels_list):
        break

# Show exactly the first 100 collected labels
print("Labels for the first 100 items:", labels_list[:100])



Labels for the first 100 items: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [14]:
# Rough size comparison of the classic MLP vs. the "bio" model.
# Only products of layer dimensions (weight-matrix entries) are summed;
# no bias terms are included in either figure.
# NOTE(review): the bio term uses 768 inputs while the classic MLP uses
# 784 (= 28*28) -- confirm that 768 is intended.
{
    'classic': sum((784 * 512, 512 * 512, 512 * 10)),
    'bio': sum((768 * 2000, 2000 * 2000)) / 10,
}

{'classic': 668672, 'bio': 553600.0}

In [15]:
# Same comparison as the previous cell, with the "bio" figure additionally
# scaled by 0.5. Bias terms are again not counted.
{
    'classic': sum((784 * 512, 512 * 512, 512 * 10)),
    'bio': 0.5 * sum((768 * 2000, 2000 * 2000)) / 10,
}

{'classic': 668672, 'bio': 276800.0}

In [6]:
# Device configuration: use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model, loss, and optimizer.
# The model is moved to its device *before* the optimizer is constructed (the
# order recommended by the torch.optim documentation), so the optimizer is
# built from the parameters exactly as they will be trained.
model = SimpleMLP().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Last expression of the cell: display the module summary.
model


SimpleMLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layer_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

In [10]:
def train(dataloader, model, loss_fn, optimizer, device=None, log_interval=10):
    """Run one optimisation pass over every batch of ``dataloader``.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; ``dataloader.dataset`` is
            only used to report the total sample count in the progress line.
        model: Module to train; it is switched to training mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to. When ``None`` (default) the
            device of the model's first parameter is used, so the function
            does not depend on a notebook-global ``device`` variable.
        log_interval: Print the loss every ``log_interval`` batches
            (default 10). A falsy value disables progress output.

    Returns:
        None.
    """
    if device is None:
        # Follow the model; a parameter-free model leaves batches where they are.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None

    model.train()
    for batch, (X, y) in enumerate(dataloader):
        if device is not None:
            X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_interval and batch % log_interval == 0:
            # Separate name for the Python float so the loss tensor is not shadowed.
            loss_value, current = loss.item(), batch * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{len(dataloader.dataset):>5d}]")

# Training the model: every epoch is one full pass of train() over train_loader.
# train() prints its own progress lines; this loop only adds the epoch header.
epochs = 5
for epoch in range(epochs):
    # epoch is 0-based, the header shown to the reader is 1-based.
    print(f"Epoch {epoch+1}\n-------------------------------")
    train(train_loader, model, loss_fn, optimizer)
print("Training done!")


Epoch 1
-------------------------------
loss: 5.933360  [    0/60000]
loss: 5.528784  [  640/60000]
loss: 4.458291  [ 1280/60000]
loss: 2.067710  [ 1920/60000]
loss: 0.340005  [ 2560/60000]
loss: 0.070007  [ 3200/60000]
loss: 0.192108  [ 3840/60000]
loss: 0.000535  [ 4480/60000]
loss: 0.003097  [ 5120/60000]
loss: 0.000879  [ 5760/60000]
loss: 2.957772  [ 6400/60000]
loss: 0.981879  [ 7040/60000]
loss: 0.066049  [ 7680/60000]
loss: 0.002220  [ 8320/60000]
loss: 0.000964  [ 8960/60000]
loss: 0.001016  [ 9600/60000]
loss: 0.000382  [10240/60000]
loss: 0.000440  [10880/60000]
loss: 0.000999  [11520/60000]
loss: 0.000275  [12160/60000]
loss: 7.257589  [12800/60000]
loss: 1.045243  [13440/60000]
loss: 0.005094  [14080/60000]
loss: 0.000268  [14720/60000]
loss: 0.000016  [15360/60000]
loss: 0.000016  [16000/60000]
loss: 0.000030  [16640/60000]
loss: 0.000024  [17280/60000]
loss: 0.000003  [17920/60000]
loss: 0.074737  [18560/60000]
loss: 3.331630  [19200/60000]
loss: 2.298141  [19840/60000]


In [8]:
def validate(dataloader, model, loss_fn, device=None):
    """Evaluate ``model`` on ``dataloader`` and report accuracy and loss.

    Args:
        dataloader: Iterable of ``(X, y)`` batches exposing ``.dataset`` and
            ``len()``.
        model: Module to evaluate; it is switched to evaluation mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        device: Device the batches are moved to. When ``None`` (default) the
            device of the model's first parameter is used, so the function
            does not depend on a notebook-global ``device`` variable.

    Returns:
        Accuracy in percent (float). The printed loss is averaged over
        batches, the accuracy over samples.
    """
    if device is None:
        # Follow the model; a parameter-free model leaves batches where they are.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None

    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    loss, correct = 0, 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for X, y in dataloader:
            if device is not None:
                X, y = X.to(device), y.to(device)
            pred = model(X)
            loss += loss_fn(pred, y).item()
            # Count the samples whose highest logit matches the label.
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    loss /= num_batches
    correct /= size
    accuracy = 100 * correct
    print(f"Test Error: \n Accuracy: {(accuracy):>0.1f}%, Avg loss: {loss:>8f} \n")
    return accuracy


In [9]:
# Evaluate on the test loader; the returned accuracy (in percent) is the
# cell's displayed value.
validate(dataloader=test_loader, model=model, loss_fn=loss_fn)


Test Error: 
 Accuracy: 10.1%, Avg loss: 5.157917 



10.11