In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import DataLoader
from torchvision import datasets
from resnet import ResNet20, ResNet, Bottleneck
from datetime import datetime
from tqdm.notebook import tqdm

In [None]:
# Run configuration. Every cell below reads its settings from this dict, and the
# whole dict is dumped to hyperparameters.txt before training starts.
hyperparameters = dict(
    epochs=100,
    lr=0.1,
    lr_min=1e-6,            # eta_min for the cosine schedule
    momentum=0.9,
    weight_decay=5e-4,
    batch_size=128,
    sparsity_type="sandwich",
    dataset='cifar100',     # 'cifar10' or 'cifar100'
    model_type='rn50',      # 'rn20' or 'rn50'
    lr_decay="cosine",      # 'linear' (step) or 'cosine'
    T_max=100,              # cosine schedule period, in scheduler steps (one per epoch)
)

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Build the CIFAR train/test datasets and loaders selected by hyperparameters['dataset'].
data_type = hyperparameters['dataset']
data_path = "./"
print(f'Data type: {data_type}')

# NOTE(review): these are the commonly quoted CIFAR-10 channel statistics, and they
# are applied to CIFAR-100 as well. Keep them if ./base.pth was trained with this
# exact normalization; confirm before switching to CIFAR-100-specific values.
normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                 std=[0.2023, 0.1994, 0.201])

# Training-time augmentation: random 32x32 crop from a 4-pixel-padded image + horizontal flip.
transform_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Evaluation: no augmentation, same normalization.
transform_val = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

if data_type == "cifar10":
    dataset_cls = datasets.CIFAR10
elif data_type == 'cifar100':
    dataset_cls = datasets.CIFAR100
else:
    # Fail loudly here instead of with a NameError on train_dataset further down.
    raise ValueError(f"Unsupported dataset: {data_type!r} (expected 'cifar10' or 'cifar100')")

train_dataset = dataset_cls(root=data_path, train=True, download=True, transform=transform_train)
test_dataset = dataset_cls(root=data_path, train=False, download=True, transform=transform_val)

train_loader = DataLoader(train_dataset, batch_size=hyperparameters['batch_size'], shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=hyperparameters['batch_size'], shuffle=False, num_workers=2)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Pruner:
    """Zero out Conv2d weights of a model according to two sparsity schemes.

    * ``apply_surprisal_sparsity``: for every ``nn.Conv2d`` whose qualified module
      name contains ``"conv2"``, keep ``N`` weights out of every group of ``M``
      consecutive weights of each output filter. Weights are ranked by the score
      ``-p * log(p)``, where ``p = |w| / sum(|w| in the group)``.
    * ``apply_nm_sparsity``: for every ``nn.Conv2d`` whose qualified module name
      contains ``"conv1"``, keep the ``n`` largest-magnitude weights of every group
      of ``m`` consecutive weights (default 2:4).

    Both methods only multiply the existing weights by a 0/1 mask. No mask is
    stored, so later optimizer updates may make pruned weights non-zero again.
    """

    def __init__(self, model, N=10, M=100, prune_high_entropy=True):
        """
        Args:
            model: ``nn.Module`` whose ``Conv2d`` layers are pruned.
            N: number of weights kept per group in ``apply_surprisal_sparsity``.
            M: group size for ``apply_surprisal_sparsity``.
            prune_high_entropy: selection rule for ``apply_surprisal_sparsity``.
                True keeps the N highest-scoring weights of each group.
                False zeroes the M - N highest-scoring weights, keeping the N
                lowest-scoring ones.
                NOTE(review): the name reads inverted relative to what True
                does; kept for compatibility with existing calls.

        Raises:
            ValueError: if ``M <= 0`` or ``N`` is not in ``[0, M]``.
        """
        if M <= 0 or not 0 <= N <= M:
            raise ValueError(f"Need M > 0 and 0 <= N <= M, got N={N}, M={M}")
        self.model = model
        self.N = N
        self.M = M
        self.prune_high_entropy = prune_high_entropy

    def apply_surprisal_sparsity(self):
        """Keep N of every M weights per output filter of each "conv2" Conv2d layer.

        Each filter's flattened weights are zero-padded to a multiple of M, so
        groups never span two filters. Padding is discarded afterwards.
        NOTE(review): a weight that holds (almost) all of its group's magnitude
        has p ~= 1 and therefore scores ~0, the same as zero-valued weights.
        """
        keep_top = self.prune_high_entropy
        k = self.N if keep_top else self.M - self.N

        for name, module in self.model.named_modules():
            if not (isinstance(module, nn.Conv2d) and "conv2" in name):
                continue

            weight = module.weight.data
            n_filters = weight.shape[0]
            flat = weight.reshape(n_filters, -1)          # (filters, C*H*W)
            length = flat.shape[1]

            pad_len = (-length) % self.M
            if pad_len:
                flat = F.pad(flat, (0, pad_len), mode='constant', value=0.0)

            groups = flat.reshape(n_filters, -1, self.M)  # (filters, n_groups, M)
            abs_groups = groups.abs()
            probs = abs_groups / (abs_groups.sum(dim=-1, keepdim=True) + 1e-8)
            scores = -probs * torch.log(probs + 1e-10)

            top_idx = torch.topk(scores, k, dim=-1).indices
            if keep_top:
                # Keep only the k highest-scoring positions.
                mask = torch.zeros_like(groups).scatter_(-1, top_idx, 1.0)
            else:
                # Drop the k highest-scoring positions, keep the rest.
                mask = torch.ones_like(groups).scatter_(-1, top_idx, 0.0)

            # Drop the padding columns and apply the mask in place.
            mask = mask.reshape(n_filters, -1)[:, :length]
            weight.mul_(mask.reshape(weight.shape))

    def apply_nm_sparsity(self, n=2, m=4):
        """Apply n:m magnitude sparsity (default 2:4) to each "conv1" Conv2d layer.

        Groups of ``m`` are formed over the whole flattened weight tensor. A
        group can therefore span two output filters when a filter's weight
        count is not a multiple of ``m``.

        Args:
            n: number of largest-|w| weights kept per group.
            m: group size.
        """
        for name, module in self.model.named_modules():
            if not (isinstance(module, nn.Conv2d) and "conv1" in name):
                continue

            weight = module.weight.data
            flat = weight.reshape(-1)
            length = flat.numel()

            pad_len = (-length) % m
            if pad_len:
                flat = F.pad(flat, (0, pad_len))

            groups = flat.reshape(-1, m)
            top_idx = torch.topk(groups.abs(), k=n, dim=1).indices
            mask = torch.zeros_like(groups).scatter_(1, top_idx, 1.0)

            weight.mul_(mask.reshape(-1)[:length].reshape(weight.shape))

    def print_sparsity(self):
        """Print per-layer and overall zero-weight fractions of all Conv2d layers.

        Returns:
            Overall fraction of zero weights across all Conv2d layers
            (0.0 if the model has none).
        """
        tot_params, tot_zeros = 0, 0
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                weight = module.weight.data
                n_params = weight.numel()
                n_zeros = torch.sum(weight == 0).item()
                tot_params += n_params
                tot_zeros += n_zeros
                print(f"{name}: Total Params = {n_params}. Zero Params = {n_zeros}. Sparsity = {n_zeros / n_params:.2%}")

        if tot_params == 0:
            print("No Conv2d layers found.")
            return 0.0

        overall = tot_zeros / tot_params
        print(f"Overall: Total Params = {tot_params}. Zero Params = {tot_zeros}. Sparsity = {overall:.2%}")
        return overall

In [None]:
def train(model, train_loader, criterion, optimizer, epoch, log_file):
    """Run one training epoch and append a summary line to ``log_file``.

    Args:
        model: network to train; put into train mode.
        train_loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss called as ``criterion(outputs, targets)``.
        optimizer: stepped once per batch.
        epoch: zero-based epoch index (shown and logged as ``epoch + 1``).
        log_file: writable text file receiving one summary line.

    Returns:
        ``(avg_loss, accuracy)``: the mean of ``loss.item()`` over batches, and
        the percentage of samples whose arg-max prediction matches the target.

    Relies on the notebook-level ``device`` and ``tqdm``.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for batch_idx, (inputs, targets) in enumerate(pbar):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(targets).sum().item()
        total += targets.size(0)

        pbar.set_postfix(loss=running_loss/(batch_idx+1), accuracy=100.0 * correct / total)

    avg_loss = running_loss / len(train_loader)
    accuracy = 100.0 * correct / total
    log_file.write(f'Epoch [{epoch+1}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%\n')
    # Flush the log itself (not stdout) so progress survives a crash mid-training.
    log_file.flush()

    return avg_loss, accuracy


In [None]:
def test(model, test_loader, criterion, log_file):
    """Evaluate ``model`` on ``test_loader`` without gradients and log the result.

    Args:
        model: network to evaluate; put into eval mode.
        test_loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss called as ``criterion(outputs, targets)``.
        log_file: writable text file receiving one summary line.

    Returns:
        ``(avg_test_loss, accuracy)``: the mean of ``loss.item()`` over batches,
        and the percentage of samples whose arg-max prediction matches the target.

    Relies on the notebook-level ``device`` and ``tqdm``.
    """
    model.eval()
    correct = 0
    total = 0
    test_loss = 0.0

    pbar = tqdm(test_loader, desc="Testing")

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(pbar):
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()

            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)

            # test_loss is a sum of per-batch losses, so average over batches
            # (matching train() and the value logged below), not over samples.
            pbar.set_postfix(loss=test_loss/(batch_idx+1), accuracy=100.0 * correct / total)

    avg_test_loss = test_loss / len(test_loader)
    accuracy = 100.0 * correct / total
    log_file.write(f'Test Loss: {avg_test_loss:.4f}, Accuracy: {accuracy:.2f}%\n')
    log_file.flush()

    return avg_test_loss, accuracy

In [None]:
# Build the ResNet selected by hyperparameters['model_type'] and load the base checkpoint.
model_type = hyperparameters['model_type']
classes = 100 if data_type == 'cifar100' else 10

if model_type == 'rn20':
    resnet_model = ResNet20(classes)
elif model_type == 'rn50':
    # Bottleneck blocks with a [3, 4, 6, 3] layout per stage (the ResNet-50 configuration).
    resnet_model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=classes)
else:
    raise ValueError(f"Unsupported model_type: {model_type!r} (expected 'rn20' or 'rn50')")

resnet_model.to(device)

# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
resnet_model.load_state_dict(torch.load("./base.pth", map_location=device))

In [None]:
# Cross-entropy classification loss, optimised with SGD + momentum and L2 weight decay
# taken from the config.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    resnet_model.parameters(),
    lr=hyperparameters['lr'],
    momentum=hyperparameters['momentum'],
    weight_decay=hyperparameters['weight_decay'],
)

In [None]:
# Starting LR for the manual 'linear' (step) schedule. It is read from the config
# so it matches the LR the optimizer was created with.
current_learning_rate = hyperparameters['lr']

decay_type = hyperparameters['lr_decay']
if decay_type == 'linear':
    # Multiply the LR by DECAY at the end of every epoch listed in DECAY_EPOCHS.
    # NOTE(review): with the current config (epochs=100) only the 60 milestone is reached.
    DECAY = 0.2
    DECAY_EPOCHS = [60, 120, 160]
elif decay_type == 'cosine':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=hyperparameters['T_max'], eta_min=hyperparameters['lr_min'])
else:
    # Otherwise the training loop would silently run with a constant LR.
    raise ValueError(f"Unsupported lr_decay: {decay_type!r} (expected 'linear' or 'cosine')")

print(f'LR schedule: {decay_type}')

In [None]:
# Persist the run configuration as "key: value" lines, and choose where the training log goes.
hyperparameter_file = os.path.join("./", 'hyperparameters.txt')
with open(hyperparameter_file, 'w') as f:
    f.writelines(f"{key}: {value}\n" for key, value in hyperparameters.items())

log_file_path = os.path.join("./", 'training_log.txt')

In [None]:
# Surprisal pruning keeps 10 of every 100 weights per filter of the "conv2" layers.
pruner = Pruner(model=resnet_model, N=10, M=100)

In [None]:
# Per epoch: train, re-impose sparsity, evaluate the pruned network, checkpoint on
# a new best test accuracy, then advance the LR schedule. Summaries go to the log file.
with open(log_file_path, 'w') as log_file:
    log_file.write(f"Training started at {datetime.now()}\n")

    best_accuracy = 0.0

    for epoch in range(hyperparameters['epochs']):
        train_loss, train_accuracy = train(resnet_model, train_loader, criterion, optimizer, epoch, log_file)

        # Dense SGD updates can revive pruned weights, so prune again after every epoch.
        pruner.apply_surprisal_sparsity()
        pruner.apply_nm_sparsity()

        # Accuracy and sparsity are measured on the pruned model.
        test_loss, test_accuracy = test(resnet_model, test_loader, criterion, log_file)
        pruner.print_sparsity()

        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            model_checkpoint_path = os.path.join("./", "best_model.pth")
            torch.save(resnet_model.state_dict(), model_checkpoint_path)
            print(f"Saved best model at epoch {epoch+1} with accuracy: {best_accuracy:.2f}%")

        # The two schedules are mutually exclusive (decay_type holds one value).
        if decay_type == 'cosine':
            scheduler.step()
            curr_lr = scheduler.get_last_lr()[0]
            print(f"Current learning rate has decayed to {curr_lr:.6f}")
        elif decay_type == 'linear' and epoch + 1 in DECAY_EPOCHS:
            current_learning_rate *= DECAY
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_learning_rate
            print("Current learning rate has decayed to %f" % current_learning_rate)

    log_file.write(f"Training completed at {datetime.now()}\n")
    log_file.write(f"Best model accuracy: {best_accuracy:.2f}%\n")
