## Genetic Pruning with Combined Fitness (Inspired by Reinikainen 2024)


Source: https://trepo.tuni.fi/bitstream/handle/10024/158561/ReinikainenSamuli.pdf?sequence=2

In [17]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import random, copy

In [18]:
# Load MNIST: build the train and test datasets, then wrap each in a batched loader.
transform = transforms.Compose([transforms.ToTensor()])
train_data, test_data = (
    datasets.MNIST("./data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)
# Only the training split is shuffled; the test split keeps its stored order.
train_loader, test_loader = (
    DataLoader(split, batch_size=256, shuffle=reshuffle)
    for split, reshuffle in ((train_data, True), (test_data, False))
)

In [19]:
# Model definition
class MLP(nn.Module):
    """Fully connected classifier with two hidden ReLU layers.

    With the default sizes the network is 784 -> 256 -> 128 -> 10.

    Args:
        in_features: Number of values per sample once it has been flattened.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=28 * 28, hidden1=256, hidden2=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)

    def forward(self, x):
        """Return raw (un-normalized) logits of shape ``(batch, num_classes)``.

        Every dimension of `x` after the first is flattened into one feature
        axis before the first linear layer.
        """
        # flatten() copies when `x` is not contiguous (e.g. a transposed
        # batch); view() would raise a RuntimeError in that case.
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


model = MLP()

In [20]:
# Train baseline
def train_model(m, epochs=1, loader=None, lr=1e-3):
    """Train `m` in place with Adam and cross-entropy loss.

    Args:
        m: Module that maps a batch of inputs to class logits.
        epochs: Number of full passes over `loader`.
        loader: Iterable yielding ``(inputs, labels)`` batches. When omitted,
            the notebook-level ``train_loader`` is used.
        lr: Learning rate handed to Adam.

    Returns:
        None. `m` is updated in place and left in training mode.
    """
    if loader is None:
        loader = train_loader
    m.train()
    opt = optim.Adam(m.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for X, y in loader:
            opt.zero_grad()
            loss_fn(m(X), y).backward()
            opt.step()

train_model(model, 1)
# Score the baseline in eval mode and with autograd disabled, so no gradient
# graph is built for each test batch.
model.eval()
with torch.no_grad():
    baseline_acc = sum((model(X).argmax(1) == y).sum().item() for X, y in test_loader) / len(test_data)
print("Baseline accuracy:", baseline_acc)

Baseline accuracy: 0.9402


In [21]:
# Flat binary mask encoding & unstructured pruning
def apply_mask(mdl, mask):
    """Return a pruned deep copy of `mdl`; `mdl` itself is not modified.

    `mask` holds one entry per scalar parameter, laid out in the order of
    ``mdl.parameters()`` (biases are therefore masked as well as weights).
    An entry equal to 0 sets the matching parameter to zero; any other value
    leaves it unchanged.

    Args:
        mdl: Module to prune.
        mask: Flat array-like (NumPy array, list, ...) with exactly as many
            entries as `mdl` has scalar parameters.

    Returns:
        A copy of `mdl` whose masked parameters are zero. Each parameter keeps
        its original dtype and device.

    Raises:
        ValueError: If the mask length differs from the parameter count.
    """
    # asarray() makes plain lists behave like arrays in the `== 0` comparison.
    mask = np.asarray(mask).reshape(-1)
    pruned = copy.deepcopy(mdl)
    total = sum(p.numel() for p in pruned.parameters())
    if mask.size != total:
        raise ValueError(
            f"mask has {mask.size} entries but the model has {total} parameters"
        )
    idx = 0
    with torch.no_grad():
        for p in pruned.parameters():
            num = p.numel()
            drop = torch.from_numpy(mask[idx:idx + num] == 0).to(p.device)
            # Zero in place rather than rebuilding the tensor from NumPy, so
            # the parameter's dtype and device are preserved.
            p.masked_fill_(drop.view(p.shape), 0)
            idx += num
    return pruned

def random_mask(model, sparsity):
    """Return a randomly shuffled 0/1 vector with one entry per scalar parameter.

    Args:
        model: Module whose parameters are counted to size the mask.
        sparsity: Fraction of entries set to 0, between 0 and 1 inclusive.

    Returns:
        1-D integer NumPy array containing ``round(total * (1 - sparsity))``
        ones and zeros elsewhere, shuffled with NumPy's global RNG.

    Raises:
        ValueError: If `sparsity` is outside ``[0, 1]``.
    """
    if not 0 <= sparsity <= 1:
        # A value above 1 would give a negative slice bound below and
        # silently set most bits to 1.
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity!r}")
    total = sum(p.numel() for p in model.parameters())
    # Round instead of truncating: 1 - 0.9 is slightly below 0.1 in floating
    # point, so int() alone would keep one weight too few.
    keep = int(round(total * (1 - sparsity)))
    bits = np.zeros(total, dtype=int)
    bits[:keep] = 1
    np.random.shuffle(bits)
    return bits

In [22]:
# Fitness evaluation with combined score
def evaluate(mdl, loader=None):
    """Return the top-1 accuracy of `mdl` over `loader`.

    Args:
        mdl: Module that maps a batch of inputs to class logits.
        loader: Iterable yielding ``(inputs, labels)`` batches. When omitted,
            the notebook-level ``test_loader`` is used.

    Returns:
        Fraction of samples whose arg-max logit equals the label.

    Raises:
        ValueError: If `loader` yields no samples.

    Note:
        `mdl` is switched to eval mode and left that way.
    """
    if loader is None:
        loader = test_loader
    mdl.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for X, y in loader:
            correct += (mdl(X).argmax(1) == y).sum().item()
            # Count the labels actually scored, so the result is correct for
            # any loader rather than only the global test set.
            seen += y.numel()
    if seen == 0:
        raise ValueError("loader yielded no samples")
    return correct / seen

def fitness(mask, sparsity_weight=0.5):
    """Score a pruning mask; higher is better.

    The score is ``accuracy + sparsity_weight * sparsity``, where accuracy
    comes from ``evaluate`` on the notebook-level ``model`` pruned with `mask`
    and sparsity is the fraction of zero entries in `mask`.

    Args:
        mask: Flat 0/1 array with one entry per scalar model parameter.
        sparsity_weight: Weight of the sparsity term relative to accuracy.

    Returns:
        The combined score as a float.
    """
    pruned = apply_mask(model, mask)
    acc = evaluate(pruned)
    sparsity = 1 - np.mean(mask)
    # Sparsity is rewarded, not penalized: the GA maximizes this score, so
    # subtracting the term would make the unpruned network the optimum.
    return acc + sparsity_weight * sparsity

In [23]:
# Genetic algorithm
POP=20
GEN=10
SPARSITY_TARGET=0.6

pop = [random_mask(model, SPARSITY_TARGET) for _ in range(POP)]
# Each individual is scored once. Survivors carry their score into the next
# generation instead of being re-evaluated (assumes fitness() returns the
# same value for the same mask).
scores = [fitness(ind) for ind in pop]

for gen in range(GEN):
    # Selection: keep the better half (rounded up), in ascending score order.
    keep_idx = np.argsort(scores)[-((POP + 1) // 2):]
    top_half = [pop[i] for i in keep_idx]

    children = []
    while len(top_half) + len(children) < POP:
        # Single-point crossover between two distinct survivors;
        # concatenate() allocates a new array, so parents are never mutated.
        p1, p2 = random.sample(top_half, 2)
        point = random.randint(0, len(p1) - 1)
        child = np.concatenate([p1[:point], p2[point:]])
        # mutate: with 10% probability flip exactly one bit of the child
        if random.random() < 0.1:
            r = random.randint(0, len(child) - 1)
            child[r] = 1 - child[r]
        children.append(child)

    pop = top_half + children
    scores = [scores[i] for i in keep_idx] + [fitness(c) for c in children]

best = pop[int(np.argmax(scores))]
best_pruned = apply_mask(model, best)
best_acc = evaluate(best_pruned)
print("Best pruned accuracy:", best_acc)

Best pruned accuracy: 0.9058
