# **Lottery Ticket Hypothesis**

Within this notebook, I would like to demonstrate how I can catch the winning lottery in order to find subnetwork inside a large network.

## **1. Preparation**

Load packages needed and a dataset

In [2]:
# Load packages
import torch
import torch.nn as nn
from tqdm import tqdm
from torchvision import datasets
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

In [16]:
# Load a dataset (MNIST dataset)
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

testing_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

# Save training and testing data
torch.save(training_data, "./data/mnist_training_data.pt")
torch.save(testing_data, "./data/mnist_testing_data.pt")

100%|██████████| 26.4M/26.4M [01:21<00:00, 324kB/s] 
100%|██████████| 29.5k/29.5k [00:00<00:00, 169kB/s]
100%|██████████| 4.42M/4.42M [00:12<00:00, 350kB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 4.27MB/s]


In [3]:
# Load the training and testing dataset
training_data = torch.load("./data/mnist_training_data.pt", weights_only=False)
testing_data = torch.load("./data/mnist_testing_data.pt", weights_only=False)

In [4]:
# Wrap training and testing data into data loader
training_dataloader = DataLoader(training_data, batch_size=128, shuffle=True)
testing_dataloader = DataLoader(testing_data, batch_size=128, shuffle=True)

## **MLP Model**

In [5]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(784, 256)
        self.linear_2 = nn.Linear(256, 128)
        self.linear_3 = nn.Linear(128, 32)
        self.linear_4 = nn.Linear(32, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        y = self.linear_1(x)
        y = self.relu(y)
        y = self.linear_2(y)
        y = self.relu(y)
        y = self.linear_3(y)
        y = self.relu(y)
        output = self.linear_4(y)

        return output

In [6]:
model = MLP()
model

MLP(
  (linear_1): Linear(in_features=784, out_features=256, bias=True)
  (linear_2): Linear(in_features=256, out_features=128, bias=True)
  (linear_3): Linear(in_features=128, out_features=32, bias=True)
  (linear_4): Linear(in_features=32, out_features=10, bias=True)
  (relu): ReLU()
)

In [7]:
# Define optimizer and loss function
optimizer = torch.optim.Adam(lr=0.00001, params=model.parameters())
loss_fn = nn.CrossEntropyLoss()

In [8]:
def train(model,
          training_data_loader,
          testing_data_loader,
          optimizer,
          loss_fn,
          epoch):
    for ep in range(epoch):
        # Training
        model.train()
        # Batch processing
        for x_batch, y_batch in training_data_loader:
            optimizer.zero_grad()
            x_batch = x_batch.view(x_batch.size(0), -1)
            prediction = model.forward(x_batch)
            loss = loss_fn(prediction, y_batch)
            loss.backward()
            optimizer.step()

        # Evaluation
        model.eval()
        with torch.no_grad():
            for x_test_batch, y_test_batch in testing_data_loader:
                x_test_batch = x_test_batch.view(x_test_batch.size(0), -1)
                eval_output = model.forward(x_test_batch)
            
            # Calculate metrics
            loss_test = loss_fn(eval_output, y_test_batch)
            eval_output_labeled = torch.argmax(eval_output, dim=1)
            accuracy = (eval_output_labeled == y_test_batch).sum().item() / len(eval_output_labeled)

        print(f"Epoch {ep+1} ==> Loss = {loss_test} | Accuracy = {round(accuracy, 2)}")


In [9]:
train(
    model,
    training_dataloader,
    testing_dataloader,
    optimizer,
    loss_fn,
    20
)

Epoch 1 ==> Loss = 1.9174922704696655 | Accuracy = 0.44
Epoch 2 ==> Loss = 1.6862765550613403 | Accuracy = 0.38
Epoch 3 ==> Loss = 1.3466004133224487 | Accuracy = 0.5
Epoch 4 ==> Loss = 0.9653546810150146 | Accuracy = 0.62
Epoch 5 ==> Loss = 0.9416233897209167 | Accuracy = 0.81


KeyboardInterrupt: 