In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np

from torch.utils.data import DataLoader
from torchvision import datasets
from resnet import ResNet20, ResNet, Bottleneck
from datetime import datetime
from tqdm.notebook import tqdm

In [None]:
# Run on the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Training / pruning configuration. Entry order matters: the hyperparameters.txt
# dump iterates this dict in insertion order.
hyperparameters = dict(
    epochs=300,
    lr=0.1,
    lr_min=1e-6,             # passed as eta_min to CosineAnnealingLR
    momentum=0.9,
    weight_decay=5e-4,
    # NOTE(review): 391 == ceil(50000 / 128), i.e. the number of *batches* per
    # epoch for 50k images at batch size 128 -- confirm a batch *size* of 391
    # is really intended.
    batch_size=391,
    sparsity_type="feather",
    dataset='cifar100',      # 'cifar10' or 'cifar100'
    model_type='rn50',       # 'rn20' or 'rn50'
    lr_decay="cosine",       # 'cosine' or 'linear' (epoch step decay)
    # NOTE(review): T_max < epochs. CosineAnnealingLR is periodic, so the LR
    # starts rising again after T_max epochs -- confirm this is intended.
    T_max=280,
    final_rate=0.9,          # target fraction of prunable weights to zero
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Feather:
    """Shared pruning state consulted by every SparseConv / SparseFc wrapper.

    Attributes:
        gth: global magnitude threshold; weights with |w| <= gth are mapped to 0
            by the forward transform. A value of 0 disables the transform in
            the wrappers.
        theta: gradient multiplier applied to pruned weights in the backward pass.
    """

    def __init__(self, gth, theta):
        # Pruner.update_thresh overwrites `gth` in place, so every wrapper that
        # holds this object sees the new threshold immediately.
        self.gth, self.theta = gth, theta

    def forward(self, w):
        """Return the Feather-transformed version of weight tensor `w`."""
        return Feather_aux.apply(w, self.gth, self.theta)

class Feather_aux(torch.autograd.Function):
    """Feather pruning op: smooth cubic shrinkage forward, scaled straight-through backward.

    Forward: w -> sign(w) * (|w|^3 - gth^3)^(1/3) where |w|^3 > gth^3, else 0.
    Backward: the incoming gradient passes through unchanged for surviving
    weights (|w| > gth) and is multiplied by `theta` for pruned ones.
    """

    @staticmethod
    def forward(ctx, w, gth, theta):
        magnitude = torch.abs(w)
        # Per-element gradient scale consumed by backward.
        ctx.aux = torch.where(magnitude > gth, 1.0, theta)
        power = 3
        excess = magnitude ** power - gth ** power
        # excess ** (1/3) is NaN where excess < 0; those entries are discarded by the where().
        return torch.where(excess > 0, torch.sign(w) * excess ** (1 / power), 0.0)

    @staticmethod
    def backward(ctx, grad_out):
        # No gradients for the scalar arguments gth and theta.
        return ctx.aux * grad_out, None, None

class SparseConv(nn.Module):
    """Wraps an nn.Conv2d so its weight passes through a shared Feather transform.

    The wrapped conv keeps owning the raw parameters. The Feather output is used
    only for this forward pass and is never written back to `conv.weight`.
    """

    def __init__(self, conv, feather):
        super(SparseConv, self).__init__()
        self.conv = conv
        self.feather = feather

    def forward(self, x):
        w = self.conv.weight
        # gth == 0 means pruning has not started yet: use the raw weight as-is.
        if self.feather.gth > 0:
            w = self.feather.forward(w)

        # Delegate to nn.Conv2d's own convolution routine so stride, padding,
        # dilation, groups and padding_mode are all honoured exactly as in
        # conv(x). A hand-written F.conv2d call previously dropped `dilation`
        # and any non-'zeros' padding_mode, silently changing the result.
        return self.conv._conv_forward(x, w, self.conv.bias)


class SparseFc(nn.Module):
    """Wraps an nn.Linear so its weight passes through a shared Feather transform.

    The raw parameters stay on `fc`. The transformed weight exists only for the
    duration of the forward call.
    """

    def __init__(self, fc, feather):
        super().__init__()
        self.fc = fc
        self.feather = feather

    def forward(self, x):
        weight = self.fc.weight
        if self.feather.gth > 0:  # threshold 0 => pruning inactive, keep raw weight
            weight = self.feather.forward(weight)
        return F.linear(x, weight, bias=self.fc.bias)

def iter_sparsify(m, feather, pthres=0):
    """Recursively wrap, in place, every Conv2d / Linear of `m` with a Feather wrapper.

    Args:
        m: module whose descendants are wrapped (mutated in place).
        feather: shared Feather object handed to every wrapper.
        pthres: minimum number of weight elements a layer needs to be wrapped.

    Layers that are already SparseConv / SparseFc are not wrapped again; they
    are re-pointed at `feather`. This makes the call safe to repeat, e.g. when
    the Pruner cell is re-run on the same model.
    """
    for name, child in m.named_children():
        if isinstance(child, (SparseConv, SparseFc)):
            # Recursing here would wrap the inner conv/fc a second time and break
            # the wrapper's `.conv.weight` / `.fc.weight` access.
            child.feather = feather
            continue

        iter_sparsify(child, feather, pthres)

        # weight.numel() == in * out * kh * kw / groups for Conv2d, in * out for Linear.
        if isinstance(child, nn.Conv2d):
            if child.weight.numel() >= pthres:
                m.__setattr__(name, SparseConv(child, feather))
        elif isinstance(child, nn.Linear):
            if child.weight.numel() >= pthres:
                m.__setattr__(name, SparseFc(child, feather))


def iter_desparsify(m, feather):
    """Recursively replace SparseConv / SparseFc wrappers in `m` by their inner layers.

    The Feather transform is baked into the inner layer's weight first, so the
    unwrapped model computes what the wrapped one did. The change is in place
    and irreversible.
    """
    for name, child in m.named_children():
        iter_desparsify(child, feather)

        if isinstance(child, SparseConv):
            inner = child.conv
        elif isinstance(child, SparseFc):
            inner = child.fc
        else:
            continue

        # Mirror the wrappers' forward, which only applies the transform when
        # gth > 0. Applying it at gth == 0 would push the weights through |w|**3
        # and a cube root, perturbing them numerically for no reason.
        if feather.gth > 0:
            inner.weight.data = feather.forward(inner.weight.data)
        m.__setattr__(name, inner)

def get_params(model):
    """Split `model`'s parameters into (weight-decayed, non-decayed) lists.

    BatchNorm2d affine parameters and every parameter whose name contains
    'bias' go into the second list. Both lists keep named_parameters() order.
    """
    bn_param_ids = set()
    for _, module in model.named_modules():
        if isinstance(module, torch.nn.modules.batchnorm.BatchNorm2d):
            bn_param_ids.update((id(module.weight), id(module.bias)))

    decayed, not_decayed = [], []
    for pname, param in model.named_parameters():
        bucket = not_decayed if (id(param) in bn_param_ids or 'bias' in pname) else decayed
        bucket.append(param)
    return decayed, not_decayed

def get_prunable_weights_cnt(model):
    """Count the weights held by SparseConv / SparseFc wrappers and build slice offsets.

    Returns:
        (total, offsets), where `offsets` is [0, 0, n1, n1 + n2, ...] for
        wrapped layers of sizes n1, n2, ... in named_modules() order. Layer k
        (1-based) occupies offsets[k]:offsets[k + 1] of a flat buffer, which is
        how get_global_thresh indexes it.
    """
    sizes = [0]
    for _, layer in model.named_modules():
        if isinstance(layer, SparseConv):
            sizes.append(layer.conv.weight.numel())
        elif isinstance(layer, SparseFc):
            sizes.append(layer.fc.weight.numel())

    offsets = [0]
    running = 0
    for size in sizes:
        running += size
        offsets.append(running)

    return sum(sizes), offsets

def calc_thresh(w, ratio):
    """Return the `ratio`-quantile of 1-D tensor `w` as a Python float.

    Uses linear interpolation between the two nearest order statistics, like
    torch.quantile's default. It is implemented with a plain sort, presumably
    to handle inputs larger than torch.quantile accepts.
    """
    ordered = torch.sort(w)[0]
    pos = (len(ordered) - 1) * ratio
    lo = int(np.floor(pos))
    hi = int(np.ceil(pos))
    lower, upper = ordered[lo], ordered[hi]
    frac = pos - lo
    return (lower + (upper - lower) * frac).item()

def get_global_thresh(model, device, st_batch, prunable_weights_cnt, idx_list):
    """Return the global magnitude threshold that zeroes a `st_batch` fraction of prunable weights.

    All SparseConv / SparseFc weights are gathered into one flat buffer laid out
    by `idx_list` (see get_prunable_weights_cnt). The `st_batch`-quantile of
    their absolute values is then taken.
    """
    # Allocate directly on the target device. torch.empty(...).to(device) built
    # a CPU buffer and copied it to the device on every training step.
    w_total = torch.empty(prunable_weights_cnt, device=device)
    i = 1  # idx_list[0] is a sentinel; layer k lives at idx_list[k]:idx_list[k + 1]
    for _, layer in model.named_modules():
        if isinstance(layer, SparseConv):
            w = layer.conv.weight
        elif isinstance(layer, SparseFc):
            w = layer.fc.weight
        else:
            continue
        w_total[idx_list[i]:idx_list[i + 1]] = w.detach().flatten()
        i += 1

    return calc_thresh(torch.abs(w_total), st_batch)

def pruning_scheduler(final_rate, nbatches, ntotalsteps, t1):
    """Build the per-step target-sparsity schedule (length `ntotalsteps` when t1*nbatches <= ntotalsteps/2).

    - Steps [0, t1*nbatches): 0 (dense warm-up).
    - Then a cubic ramp kf - kf * (1 - progress)^3 up to step floor(ntotalsteps / 2).
    - Then `final_rate` for the remaining steps.
    """
    start = t1 * nbatches
    stop = int(np.floor(ntotalsteps * 0.5))

    steps = np.arange(start, stop)
    progress = (steps - start) / (stop - start)

    warmup = np.zeros(start)
    ramp = final_rate - final_rate * (1 - progress) ** 3
    plateau = final_rate * np.ones(ntotalsteps - stop)
    return np.hstack((warmup, ramp, plateau))

def get_theta(final_rate):
    """Gradient scale for pruned weights: 0.5 when the target sparsity exceeds 95%, else 1.0."""
    return 0.5 if final_rate > 0.95 else 1.0


class Pruner:
    """Global magnitude pruner built on the Feather reparametrisation.

    Construction wraps every Conv2d / Linear of `model` in place (iter_sparsify)
    with wrappers that share one Feather object. It also precomputes a per-step
    target-sparsity schedule of `nbatches * epochs` entries.

    Args:
        model: network to prune; modified in place.
        device: device on which the global threshold is computed.
        final_rate: target fraction of prunable weights to zero.
        nbatches: optimizer steps per epoch (len(train_loader), not the batch size).
        epochs: number of training epochs.
        pthres: minimum weight count for a layer to be wrapped.
        t1: number of initial epochs without pruning.
    """

    def __init__(self, model, device, final_rate, nbatches, epochs, pthres=0, t1=0):
        theta = get_theta(final_rate)
        self.ntotalsteps = nbatches * epochs
        self.step_idx = 0

        # gth == 0 is treated by the wrappers as "pruning not started".
        self.feather = Feather(gth=0.0, theta=theta)
        self.device = device
        self.t1 = t1

        self.model = model
        iter_sparsify(m=self.model, feather=self.feather, pthres=pthres)

        self.prunable_weights_cnt, self.idx_list = get_prunable_weights_cnt(self.model)
        self.pscheduler = pruning_scheduler(final_rate, nbatches, self.ntotalsteps, self.t1)

    def update_thresh(self, end_of_batch=False):
        """Set the shared threshold for the current schedule step.

        Normal calls (once per optimizer step) use the current step's rate and
        then advance the step counter. With end_of_batch=True, used after an
        epoch's last optimizer step, the threshold is recomputed for the most
        recent step's rate without advancing.
        """
        idx = self.step_idx - 1 if end_of_batch else self.step_idx
        # Clamp into the schedule. Running more steps than scheduled reuses the
        # final rate instead of raising IndexError. An end-of-epoch refresh
        # before any step uses entry 0 instead of wrapping to pscheduler[-1].
        idx = min(max(idx, 0), len(self.pscheduler) - 1)
        st_batch = self.pscheduler[idx]

        new_gth = 0.0
        if st_batch > 0:
            new_gth = get_global_thresh(self.model, self.device, st_batch, self.prunable_weights_cnt, self.idx_list)

        self.feather.gth = new_gth
        if not end_of_batch:
            self.step_idx += 1

    def print_sparsity(self):
        """Print per-layer sparsity at the current threshold and return the global percentage.

        A weight counts as pruned when |w| <= gth (F.hardshrink zeroes exactly
        those). Uses this pruner's own model and threshold rather than notebook
        globals.
        """
        local_zeros_cnt = 0
        th = self.feather.gth
        with torch.no_grad():  # reporting only; don't build an autograd graph
            for name, layer in self.model.named_modules():
                if isinstance(layer, SparseConv):
                    w = layer.conv.weight
                elif isinstance(layer, SparseFc):
                    w = layer.fc.weight
                else:
                    continue

                nw = F.hardshrink(w, th)
                tsparsity = (nw == 0).float().sum().item()
                local_zeros_cnt += tsparsity

                tnum = nw.numel()
                print(f'{name}'.ljust(40), f'#w: {int(tnum)}'.ljust(11), f'| sparsity: {round(100.0 * tsparsity / tnum, 2)}%'.ljust(18))

        if self.prunable_weights_cnt == 0:
            return 0.0
        return 100 * float(local_zeros_cnt) / float(self.prunable_weights_cnt)

    def desparsify(self):
        """Bake the Feather transform into the weights and remove the wrappers (in place, irreversible)."""
        iter_desparsify(self.model, self.feather)


In [None]:
data_type = hyperparameters['dataset']

print(f'Data type: {data_type}')


# NOTE(review): these are the widely quoted CIFAR-10 mean/std; they are used for
# CIFAR-100 as well. Any later evaluation must reuse this exact normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2023, 0.1994, 0.201]),
])

transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2023, 0.1994, 0.201]),
])


# Fail fast on a typo instead of hitting a NameError on train_dataset below.
if data_type == "cifar10":
    dataset_cls = datasets.CIFAR10
elif data_type == 'cifar100':
    dataset_cls = datasets.CIFAR100
else:
    raise ValueError(f"Unsupported dataset {data_type!r}; expected 'cifar10' or 'cifar100'")

train_dataset = dataset_cls(root='./data', train=True, download=True, transform=transform_train)
test_dataset = dataset_cls(root='./data', train=False, download=True, transform=transform_val)


train_loader = DataLoader(train_dataset, batch_size=hyperparameters['batch_size'], shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=hyperparameters['batch_size'], shuffle=False, num_workers=2)

In [None]:
model_type = hyperparameters['model_type']
classes = 100 if data_type == 'cifar100' else 10

print(f'Model: {model_type}')

if model_type == 'rn20':
    resnet_model = ResNet20(classes)
elif model_type == 'rn50':
    resnet_model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=classes)
else:
    # Without this, resnet_model would be undefined and later cells would fail with a NameError.
    raise ValueError(f"Unsupported model_type {model_type!r}; expected 'rn20' or 'rn50'")
resnet_model.to(device)

In [None]:
criterion = nn.CrossEntropyLoss()

# SGD over every parameter, so weight decay also applies to BN and bias terms.
# get_params() can split those off into a no-decay group, but it is not used here.
optimizer = optim.SGD(
    resnet_model.parameters(),
    lr=hyperparameters['lr'],
    momentum=hyperparameters['momentum'],
    weight_decay=hyperparameters['weight_decay'],
)

In [None]:
# Start the manual step-decay tracker from the configured base LR; it was
# hard-coded to 0.1 and would silently diverge from hyperparameters['lr'].
current_learning_rate = hyperparameters['lr']

decay_type = hyperparameters['lr_decay']
if decay_type == 'linear':
    # Despite the name this is step decay: the training loop multiplies the LR
    # by DECAY after each epoch listed in DECAY_EPOCHS.
    DECAY = 0.2
    DECAY_EPOCHS = [60, 120, 160]
elif decay_type == 'cosine':
    # NOTE(review): CosineAnnealingLR is periodic; with T_max < epochs the LR
    # rises again after T_max epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=hyperparameters['T_max'], eta_min=hyperparameters['lr_min'])
else:
    # Otherwise the training loop would silently run at a constant LR.
    raise ValueError(f"Unsupported lr_decay {decay_type!r}; expected 'linear' or 'cosine'")

print(f'LR schedule: {decay_type}')

In [None]:
# `nbatches` must be the number of optimizer steps per epoch, len(train_loader),
# NOT the batch size. The pruning schedule has nbatches * epochs entries and
# update_thresh() consumes one per training step. Passing batch_size (391) while
# e.g. 50k training images yield only 128 batches per epoch stretched the
# schedule ~3x, so the target sparsity was never reached.
pruner = Pruner(resnet_model,
                device,
                final_rate=hyperparameters['final_rate'],
                nbatches=len(train_loader),
                epochs=hyperparameters['epochs'])

In [None]:
def train(model, train_loader, criterion, optimizer, epoch, log_file):
    """Run one training epoch, advancing the notebook-level `pruner` once per batch.

    Uses the globals `pruner` and `device`. Writes one summary line to
    `log_file` and returns (mean batch loss, accuracy in %).
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for batch_idx, (inputs, targets) in enumerate(pbar):
        # Step the pruning schedule and recompute the global threshold before the forward pass.
        pruner.update_thresh()

        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(targets).sum().item()
        total += targets.size(0)

        pbar.set_postfix(loss=running_loss/(batch_idx+1), accuracy=100.0 * correct / total)

    avg_loss = running_loss / len(train_loader)
    accuracy = 100.0 * correct / total

    # Refresh the threshold for the post-step weights *before* reporting. The
    # logged sparsity then matches the threshold the subsequent evaluation uses,
    # instead of the stale one from before the last optimizer step.
    pruner.update_thresh(end_of_batch=True)

    sparsity = pruner.print_sparsity()
    log_file.write(f'Epoch [{epoch+1}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%, Sparsity: {sparsity:.2f}%\n')

    return avg_loss, accuracy

In [None]:
def test(model, test_loader, criterion, log_file):
    """Evaluate `model` on `test_loader`; log and return (mean batch loss, accuracy in %).

    Uses the notebook-level `device`.
    """
    model.eval()
    correct = 0
    total = 0
    test_loss = 0.0

    pbar = tqdm(test_loader, desc="Testing")

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(pbar):
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()

            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)

            # test_loss sums per-batch *mean* losses, so divide by the batch count.
            # Dividing by (total + batch size) mixed units and double-counted the
            # current batch.
            pbar.set_postfix(loss=test_loss/(batch_idx+1), accuracy=100.0 * correct / total)

    avg_test_loss = test_loss / len(test_loader)
    accuracy = 100.0 * correct / total
    log_file.write(f'Test Loss: {avg_test_loss:.4f}, Accuracy: {accuracy:.2f}%\n')

    return avg_test_loss, accuracy


In [None]:
# Record the run configuration as "key: value" lines next to the training log.
hyperparameter_file = os.path.join('./', 'hyperparameters.txt')
with open(hyperparameter_file, mode='w') as f:
    for key, value in hyperparameters.items():
        print(f"{key}: {value}", file=f)

log_file_path = os.path.join('./', 'training_log.txt')

In [None]:
with open(log_file_path, 'w') as log_file:
    log_file.write(f"Training started at {datetime.now()}\n")

    best_accuracy = 0.0
    model_checkpoint_path = os.path.join('./', "best_model.pth")

    for epoch in range(hyperparameters['epochs']):
        train_loss, train_accuracy = train(pruner.model, train_loader, criterion, optimizer, epoch, log_file)
        test_loss, test_accuracy = test(pruner.model, test_loader, criterion, log_file)

        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            # The checkpoint holds the wrapper-keyed raw weights (e.g. '...conv.weight').
            # Desparsifying here is not an option: it removes the wrappers and
            # stops pruning for the remaining epochs.
            torch.save(pruner.model.state_dict(), model_checkpoint_path)
            print(f"Saved best model at epoch {epoch+1} with accuracy: {best_accuracy:.2f}%")

        if decay_type == 'linear':
            if epoch+1 in DECAY_EPOCHS:
                current_learning_rate = current_learning_rate * DECAY
                for param_group in optimizer.param_groups:
                    param_group['lr'] = current_learning_rate
                print("Current learning rate has decayed to %f" %current_learning_rate)
        elif decay_type == 'cosine':
            scheduler.step()
            curr_lr = scheduler.get_last_lr()[0]
            print(f"Current learning rate has decayed to {curr_lr:.6f}")

        # Flush once per epoch so progress in the log can be followed during a long run.
        log_file.flush()

    log_file.write(f"Training completed at {datetime.now()}\n")
    log_file.write(f"Best model accuracy: {best_accuracy:.2f}%\n")

In [None]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# Un-normalised CIFAR-10 test images (pixel values in [0, 1]).
# NOTE(review): the model above may be trained on CIFAR-100, and `images` is not
# used by the cells shown here -- confirm this cell is still needed.
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
loader = DataLoader(dataset=dataset, batch_size=20, shuffle=False)
images, _ = next(iter(loader))  # first batch: [20, 3, 32, 32]

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
import torch

def calculate_sparsity(tensor):
    """Return (fraction_of_zeros, num_zeros, num_elements) for `tensor`.

    None and empty tensors yield (0.0, 0, 0) instead of raising ZeroDivisionError.
    """
    if tensor is None:
        return 0.0, 0, 0
    total_elements = tensor.numel()
    if total_elements == 0:
        return 0.0, 0, 0
    num_zeros = torch.sum(tensor == 0).item()
    return num_zeros / total_elements, num_zeros, total_elements

# `pruner.model` and `resnet_model` are the same object (the Pruner wraps the
# model in place). Reloading pruner.model's state_dict into resnet_model was
# therefore a no-op, and that round trip has been dropped.

total_zeros = 0
total_params = 0

# Raw stored-parameter sparsity. The wrappers never write Feather's zeros back
# into the parameters, so before hard pruning this mostly counts exact zeros.
# Includes every parameter whose name contains 'weight' (conv, linear, BN).
for name, param in pruner.model.named_parameters():
    if 'weight' in name and param.requires_grad:
        sparsity, zeros, total = calculate_sparsity(param.data)
        print(f"{name:40s}: sparsity = {sparsity:.4f}")
        total_zeros += zeros
        total_params += total

# Effective per-layer sparsity: weights with |w| <= gth are zero in the forward pass.
th = pruner.feather.gth
with torch.no_grad():
    for name, layer in pruner.model.named_modules():
        if isinstance(layer, SparseConv):
            w = layer.conv.weight
        elif isinstance(layer, SparseFc):
            w = layer.fc.weight
        else:
            continue

        nw = F.hardshrink(w, th)
        tsparsity = (nw == 0).float().sum().item()

        tnum = nw.numel()
        print(f'{name}'.ljust(40), f'#w: {int(tnum)}'.ljust(11), f'| sparsity: {round(100.0 * tsparsity / tnum, 2)}%'.ljust(18))

total_sparsity = total_zeros / total_params if total_params else 0.0
print(f"\nTotal model sparsity: {total_sparsity:.4f} ({total_zeros:,} zero weights out of {total_params:,})")


In [None]:
import torch.nn.functional as F

def hard_prune_model_weights(model, threshold):
    """Permanently zero, in place, every wrapped weight with |w| <= threshold.

    Uses `<=` so the zeroed set matches Feather's forward (|w| == gth maps to 0)
    and the F.hardshrink-based sparsity reports. With `<`, a weight lying
    exactly on the threshold could survive; the threshold is an interpolated
    quantile of the weights and can coincide with one.
    """
    for name, layer in model.named_modules():
        if isinstance(layer, SparseConv):
            kind, weight = "SparseConv", layer.conv.weight.data
        elif isinstance(layer, SparseFc):
            kind, weight = "SparseFc", layer.fc.weight.data
        else:
            continue
        mask = weight.abs() <= threshold
        weight[mask] = 0.0
        print(f"[Pruned] {name} ({kind}): Set {mask.sum().item()} weights to zero")

hard_prune_model_weights(resnet_model, pruner.feather.gth)
# The wrappers are still in place, so surviving weights here are the raw values.
# The Feather shrinkage is only applied at forward time (or by desparsify()).
torch.save(resnet_model.state_dict(), "resnet_pruned.pth")

In [None]:
import torch
from collections import OrderedDict

# Plain (un-wrapped) copy of the trained architecture.
if model_type == 'rn20':
    f_m = ResNet20(classes)
else:
    f_m = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=classes)

# When gth > 0, SparseConv/SparseFc feed feather.forward(w) -- not the raw w --
# to the convolution/linear op. A plain model therefore needs those *effective*
# weights. Copying the raw parameters would leave surviving weights un-shrunk,
# so the exported model would not match the trained one.
effective_weights = {}
with torch.no_grad():
    for name, layer in resnet_model.named_modules():
        if isinstance(layer, SparseConv):
            key, weight = f"{name}.conv.weight", layer.conv.weight
        elif isinstance(layer, SparseFc):
            key, weight = f"{name}.fc.weight", layer.fc.weight
        else:
            continue
        if layer.feather.gth > 0:
            weight = layer.feather.forward(weight)
        effective_weights[key] = weight

source_params = list(resnet_model.state_dict().items())
target_params = list(f_m.state_dict().items())

assert len(source_params) == len(target_params), "Mismatch in number of parameters"

# The wrappers only insert a '.conv' / '.fc' segment into parameter names, so
# the two state dicts line up positionally. Shapes are checked to catch any drift.
new_state_dict = OrderedDict()
for (t_name, t_param), (s_name, s_param) in zip(target_params, source_params):
    value = effective_weights.get(s_name, s_param)
    assert value.shape == t_param.shape, f"Shape mismatch: {s_name} -> {t_name}"
    new_state_dict[t_name] = value.detach().clone()

f_m.load_state_dict(new_state_dict)

In [None]:
# f_m was just populated above, so reloading its own state_dict (as before) was a
# no-op; report its weight sparsity directly.
total_zeros = 0
total_params = 0

# Counts every parameter whose name contains 'weight' (conv, linear and BN weights).
for name, param in f_m.named_parameters():
    if 'weight' in name and param.requires_grad:
        sparsity, zeros, total = calculate_sparsity(param.data)
        print(f"{name:40s}: sparsity = {sparsity:.4f}")
        total_zeros += zeros
        total_params += total

# The totals were previously accumulated but never reported.
if total_params:
    print(f"\nTotal model sparsity: {total_zeros / total_params:.4f} ({total_zeros:,} zero weights out of {total_params:,})")

In [None]:
# Plain-ResNet state_dict (no wrapper key prefixes), loadable without the notebook's wrapper classes.
torch.save(obj=f_m.state_dict(), f='./final_pruned_model.pth')

In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Evaluation inputs must be normalised exactly like the training inputs. The
# previous CIFAR-100 statistics (mean 0.5071/0.4865/0.4409, std
# 0.2673/0.2564/0.2762) did not match the normalisation the model was trained
# with, which skews the reported accuracy. Reuse the training-time validation
# transform so the two cannot drift apart.
transform_test = transform_val

# Evaluate on the dataset the model was trained on instead of hard-coding CIFAR-100.
eval_dataset_cls = datasets.CIFAR100 if data_type == 'cifar100' else datasets.CIFAR10
test_dataset = eval_dataset_cls(root='./data', train=False, download=True, transform=transform_test)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

def evaluate(model, dataloader, criterion):
    """Return (mean per-sample loss, top-1 accuracy in %) of `model` over `dataloader`.

    Runs in eval mode under no_grad on the device holding the model's first
    parameter. Batch losses are weighted by batch size, so a smaller final
    batch is averaged correctly.
    """
    model.eval()
    device = next(model.parameters()).device
    n_seen = n_correct = 0
    loss_sum = 0.0

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            loss_sum += criterion(logits, labels).item() * images.size(0)

            predicted = logits.max(dim=1).indices
            n_correct += (predicted == labels).sum().item()
            n_seen += labels.size(0)

    return loss_sum / n_seen, 100.0 * n_correct / n_seen

criterion = nn.CrossEntropyLoss()

# Accuracy of the hard-pruned model while the Feather wrappers are still in place.
test_loss, test_accuracy = evaluate(model=resnet_model, dataloader=test_loader, criterion=criterion)
print(f"Test Accuracy: {test_accuracy:.2f}% | Test Loss: {test_loss:.4f}")

print(resnet_model)


In [None]:
# Pickles the whole module object, including the notebook-defined SparseConv /
# SparseFc wrappers. NOTE(review): loading it presumably requires those classes
# to be importable under the same names -- prefer the state_dict files for portability.
torch.save(obj=resnet_model, f="pruned_full_model.pth")


In [None]:
# Bake the Feather transform into the weights and swap each wrapper back to its
# plain Conv2d / Linear (in place, irreversible); then re-check accuracy.
pruner.desparsify()
print(resnet_model)

test_loss, test_accuracy = evaluate(model=resnet_model, dataloader=test_loader, criterion=criterion)
print(f"Test Accuracy: {test_accuracy:.2f}% | Test Loss: {test_loss:.4f}")


In [None]:
print(pruner.feather.gth)  # final global magnitude threshold
# pruner.model is the same object as resnet_model, so this matches r1.pth below.
torch.save(obj=pruner.model.state_dict(), f="mymodel.pth")

In [None]:
# Desparsified (plain-module) weights of the final pruned model.
torch.save(obj=resnet_model.state_dict(), f="r1.pth")