In [6]:
# ===============================================================
#              OPTION A - GENETIC PRUNING NOTEBOOK
#        Inspired by Reinikainen (2024) - Transformer Pruning
# ===============================================================

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import time
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import random, copy, pandas as pd

# ===============================================================
# 1. DATA LOADING (MNIST)
# ===============================================================

transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_data  = datasets.MNIST("./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=256, shuffle=True)
test_loader  = DataLoader(test_data, batch_size=256, shuffle=False)

# ===============================================================
# 2. MODEL DEFINITION
# Small MLP used as a proxy to reproduce LLM pruning behaviour
# ===============================================================

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

model = MLP()

# ===============================================================
# 3. TRAINING FUNCTION (returns training time)
# ===============================================================

def train_model(m, epochs=1):
    m.train()
    opt = optim.Adam(m.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    start = time.time()
    
    for _ in range(epochs):
        for X, y in train_loader:
            opt.zero_grad()
            loss = loss_fn(m(X), y)
            loss.backward()
            opt.step()

    end = time.time()
    return end - start

# ---------------------------------------------------------------
# Train baseline model
# ---------------------------------------------------------------

baseline_train_time = train_model(model, 1)

# ===============================================================
# 4. BASELINE EVALUATION
# ===============================================================

def evaluate(mdl):
    mdl.eval()
    correct = 0
    with torch.no_grad():
        for X, y in test_loader:
            correct += (mdl(X).argmax(1) == y).sum().item()
    return correct / len(test_data)

baseline_acc = evaluate(model)
print("Baseline accuracy:", baseline_acc)
print("Baseline training time (s):", baseline_train_time)

# ===============================================================
# 5. MODEL SIZE FUNCTIONS
# ===============================================================

def count_parameters(model):
    total = sum(p.numel() for p in model.parameters())
    nonzero = sum((p != 0).sum().item() for p in model.parameters())
    return total, nonzero

def model_size_mb(model):
    total_bytes = 0
    for p in model.parameters():
        total_bytes += p.nelement() * p.element_size()
    return total_bytes / (1024*1024)

# Baseline model size
baseline_total, baseline_nonzero = count_parameters(model)
baseline_size_mb = model_size_mb(model)

print("Baseline total params:", baseline_total)
print("Baseline non-zero params:", baseline_nonzero)
print("Baseline size (MB):", baseline_size_mb)

# ===============================================================
# 6. MASK ENCODING FOR PRUNING (Binary masks)
# ===============================================================

def apply_mask(mdl, mask):
    pruned = copy.deepcopy(mdl)
    idx = 0
    for p in pruned.parameters():
        num = p.numel()
        flat = p.data.view(-1).cpu().numpy()
        mask_slice = mask[idx:idx+num]
        flat[mask_slice == 0] = 0
        p.data = torch.from_numpy(flat.reshape(p.shape)).float()
        idx += num
    return pruned

def random_mask(model, sparsity):
    total = sum(p.numel() for p in model.parameters())
    keep = int(total * (1 - sparsity))
    bits = np.zeros(total, dtype=int)
    bits[:keep] = 1
    np.random.shuffle(bits)
    return bits

# ===============================================================
# 7. FITNESS FUNCTION (Accuracy - Sparsity Penalty)
# ===============================================================

def fitness(mask):
    pruned = apply_mask(model, mask)
    acc = evaluate(pruned)
    sparsity = 1 - np.mean(mask)
    return acc - 2.0 * abs(sparsity - SPARSITY_TARGET)   # weighted combination (Reinikainen-inspired)

# ===============================================================
# 8. GENETIC ALGORITHM
# ===============================================================

POP = 40  
GEN = 15
SPARSITY_TARGET = 0.60  # 60% pruning

# Initial population
population = [random_mask(model, SPARSITY_TARGET) for _ in range(POP)]

for gen in range(GEN):
    scores = [fitness(ind) for ind in population]
    best_idx = int(np.argmax(scores))
    print(f"Generation {gen+1}/{GEN} | Best fitness: {max(scores):.4f}")

    # Select top half
    top_half = [population[i] for i in np.argsort(scores)[-POP//2:]]

    # Generate offspring
    new_population = top_half.copy()

    while len(new_population) < POP:
        p1, p2 = random.sample(top_half, 2)
        point = random.randint(0, len(p1)-1)
        child = np.concatenate([p1[:point], p2[point:]])

        # Mutation
        if random.random() < 0.1:
            r = random.randint(0, len(child)-1)
            child[r] = 1 - child[r]

        new_population.append(child)

    population = new_population

# ===============================================================
# 9. BEST PRUNED MODEL EVALUATION
# ===============================================================

best_mask = population[np.argmax([fitness(ind) for ind in population])]
best_pruned_model = apply_mask(model, best_mask)
best_acc = evaluate(best_pruned_model)

print("Best pruned accuracy:", best_acc)

# ---- Fine-tuning pruned model ----
pruned_train_time = train_model(best_pruned_model, 1)

print("Pruned fine-tuning time (s):", pruned_train_time)

# ===============================================================
# 10. PRUNED MODEL SIZE
# ===============================================================

pruned_total, pruned_nonzero = count_parameters(best_pruned_model)
pruned_size_mb = model_size_mb(best_pruned_model)
sparsity_final = 1 - (pruned_nonzero / pruned_total)

print("Pruned total params:", pruned_total)
print("Pruned non-zero params:", pruned_nonzero)
print("Pruned size (MB):", pruned_size_mb)
print("Final sparsity:", sparsity_final)

# ===============================================================
# 11. SUMMARY TABLE
# ===============================================================

summary = pd.DataFrame({
    "Model": ["Baseline", "GA Pruned"],
    "Accuracy": [baseline_acc, best_acc],
    "Total Params": [baseline_total, pruned_total],
    "Non-zero Params": [baseline_nonzero, pruned_nonzero],
    "Size (MB)": [baseline_size_mb, pruned_size_mb],
    "Train Time (s)": [baseline_train_time, pruned_train_time],
    "Sparsity": [0.0, sparsity_final]
})

summary

Baseline accuracy: 0.9396
Baseline training time (s): 2.067254066467285
Baseline total params: 235146
Baseline non-zero params: 235146
Baseline size (MB): 0.8970108032226562
Generation 1/15 | Best fitness: 0.8501
Generation 2/15 | Best fitness: 0.8524
Generation 3/15 | Best fitness: 0.8524
Generation 4/15 | Best fitness: 0.8791
Generation 5/15 | Best fitness: 0.8791
Generation 6/15 | Best fitness: 0.8791
Generation 7/15 | Best fitness: 0.8856
Generation 8/15 | Best fitness: 0.8856
Generation 9/15 | Best fitness: 0.8880
Generation 10/15 | Best fitness: 0.8897
Generation 11/15 | Best fitness: 0.8920
Generation 12/15 | Best fitness: 0.8976
Generation 13/15 | Best fitness: 0.8995
Generation 14/15 | Best fitness: 0.8995
Generation 15/15 | Best fitness: 0.8995
Best pruned accuracy: 0.9011
Pruned fine-tuning time (s): 2.0492663383483887
Pruned total params: 235146
Pruned non-zero params: 220248
Pruned size (MB): 0.8970108032226562
Final sparsity: 0.06335638284299971


Unnamed: 0,Model,Accuracy,Total Params,Non-zero Params,Size (MB),Train Time (s),Sparsity
0,Baseline,0.9396,235146,235146,0.897011,2.067254,0.0
1,GA Pruned,0.9011,235146,220248,0.897011,2.049266,0.063356
