# Pruning / Lottery Ticket Hypothesis (LTH) on CIFAR-100

In [53]:
import copy
import json
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchinfo import summary
from torchvision import transforms
from tqdm import tqdm

In [2]:
# Local run: leave `drive` as None. On Colab, uncomment the two lines below to mount Google Drive
# (the import rebinds `drive`, which switches `path` to the Drive project folder in a later cell).
drive = None
# from google.colab import drive
# drive.mount('/content/drive')

In [3]:
# Project root: checkpoints/ and train_logs/ are read/written under it, and it is appended to sys.path
# (presumably so the local `constants` / `utils` / `models` modules import). Replaced when running on Colab.
path = "./"

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# With Google Drive mounted (Colab), the project lives on Drive; otherwise keep the local path set above.
path = "/content/drive/MyDrive/self-learn/pruning" if drive is not None else path

In [5]:
# Make the project directory importable. Guarded so re-running this cell doesn't append duplicates.
if path not in sys.path:
    sys.path.append(path)

In [6]:
from constants import *
from utils import set_seed, train_data, val_data, \
                    train_loader, val_loader, fine_labels, invTrans
from models import get_model_and_optimizer

# Seed RNGs via the project helper from `utils` (presumably seeds with SEED from `constants` — confirm).
set_seed()

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# Tag used in checkpoint / log file names. NOTE: built from PRUNE_PCT as it is *at this point*
# (from `constants`); reassigning PRUNE_PCT in a later cell does not update it.
MODEL_NAME = f"CNN_CIFAR_100_PRUNE_PCT_{PRUNE_PCT}"
print("Model Name:", MODEL_NAME)

Model Name: CNN_CIFAR_100_PRUNE_PCT_80


In [8]:
# Run whose checkpoint is loaded as the starting point (PRUNE_PCT_0 — presumably the unpruned baseline).
# Plain string: the previous f-string had no placeholders.
LOAD_MODEL_NAME = "CNN_CIFAR_100_PRUNE_PCT_0"

# Data

In [9]:
# # expected: (BATCH_SIZE, 3, 32, 32), picture of mountain

# batch = next(iter(train_loader))
# print(batch[0].shape)
# test_idx = 42
# plt.imshow(batch[0][test_idx].permute(1,2,0))
# plt.title(f'{fine_labels[batch[1][test_idx]]}')

# Pruning utils

In [10]:
def init_params(m):
    """
    Initializes the parameters of module `m` in place; meant for `model.apply(init_params)`.

    - nn.Conv2d:      Xavier-normal weight; bias (if present) ~ N(0, 1).
    - nn.BatchNorm2d: weight ~ N(1, 0.02); bias = 0 (skipped when the layer has no affine params).
    - nn.Linear:      Xavier-normal weight; bias (if present) ~ N(0, 1).

    Any other module type (including nn.BatchNorm1d) is left at PyTorch's default initialization.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            nn.init.normal_(m.bias.data)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight/bias set to None
        if m.weight is not None:
            nn.init.normal_(m.weight.data, mean=1, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight.data)
        # nn.Linear(bias=False) has m.bias = None
        if m.bias is not None:
            nn.init.normal_(m.bias.data)

In [11]:
def reset_params(m, mask, init_state):
    """
    Rewinds the parameters of model `m` to `init_state`, keeping only the surviving (unmasked) weights.

    Args:
        m: model whose parameters are overwritten in place.
        mask: list of arrays, one per parameter whose name contains "weight", in
            `m.named_parameters()` order (as built by `generate_init_mask` / `prune_by_percent`);
            0 marks a pruned entry, 1 a surviving one.
        init_state: state dict holding the values to rewind to (e.g. a deep copy of `m.state_dict()`).

    Weights become `init_state[name] * mask`; biases are restored from `init_state` unmasked.
    Values are copied into the existing parameter tensors, so later training never writes back
    into `init_state`.
    """
    step = 0
    with torch.no_grad():
        for name, param in m.named_parameters():
            if "weight" in name:
                keep = torch.as_tensor(mask[step], dtype=param.dtype, device=param.device)
                param.copy_(init_state[name].to(param.device) * keep)
                step += 1
            if "bias" in name:
                # copy_ (not `param.data = init_state[name]`) avoids aliasing the snapshot tensor
                param.copy_(init_state[name])

In [12]:
def generate_init_mask(m):
    """
    Generates an all-ones (nothing pruned) mask matching the weight parameters of model `m`.

    Returns:
        mask: list with one numpy array per parameter of `m` whose name contains "weight", in
            `m.named_parameters()` order; each array has that parameter's shape and dtype and
            is filled with ones.
    """
    # Build ones from shape/dtype directly rather than copying every weight tensor to the CPU.
    return [
        torch.ones(param.shape, dtype=param.dtype, device="cpu").numpy()
        for name, param in m.named_parameters()
        if 'weight' in name
    ]

In [13]:
def get_layer_weight_names(m):
    """
    Returns the names of the layers of model `m` that own a weight parameter
    (e.g. "conv1.weight" -> "conv1"), in `m.named_parameters()` order, i.e. aligned
    index-for-index with the pruning mask.
    """
    return [
        name.split('.weight')[0]
        for name, _ in m.named_parameters()
        if 'weight' in name
    ]

In [14]:
def get_num_weight_params_by_layer(mask):
    """
    Counts the surviving weights per layer from a binary pruning mask.

    Args:
        mask: list of arrays (one per weight layer); nonzero entries mark surviving weights.
    Returns:
        List with the number of nonzero entries of each mask array, in the same order.
    """
    return list(map(np.count_nonzero, mask))

In [15]:
def prune_by_percent(m, mask, pct):
    """
    Prunes `pct`% of the surviving weights of model `m`, layer by layer, updating `mask` in place.

    For every parameter whose name contains "weight" (in `m.named_parameters()` order, matching
    `mask`), the `pct`-th percentile of the absolute values of its currently nonzero entries is
    the cutoff; entries with magnitude below it are zeroed in both the parameter and its mask.
    Layers with no nonzero weights left are skipped (their mask is unchanged).

    Raises:
        AssertionError: if `pct` is not a number in [0, 100].
    """
    assert isinstance(pct, (int, float)) and 0 <= pct <= 100, "`pct` must be a numeric value between 0 and 100 (inclusive)."
    step = 0
    for name, param in m.named_parameters():
        if 'weight' not in name:
            continue

        p_data_all = param.data.cpu().numpy()
        # flattened nonzero (surviving) weights
        p_data = p_data_all[np.nonzero(p_data_all)]
        if p_data.size == 0:
            # Fully pruned layer: np.percentile would raise on an empty array.
            step += 1
            continue

        cutoff_val = np.percentile(np.abs(p_data), pct)  # percentile calculated on surviving params
        new_mask = np.where(np.abs(p_data_all) < cutoff_val, 0, mask[step])

        with torch.no_grad():
            param.copy_(torch.from_numpy(p_data_all * new_mask))
        mask[step] = new_mask
        step += 1

In [16]:
def eval(model, val_loader, criterion, device):
    """
    Evaluates `model` on `val_loader` without gradient tracking.

    Note: this name shadows Python's builtin `eval`. The model is left in eval mode; callers
    that continue training must call `model.train()` again.

    Args:
        model: model to evaluate.
        val_loader: iterable of (img, label) batches.
        criterion: loss function applied per batch.
        device: device the batches are moved to.
    Returns:
        (val_loss, val_acc): mean of the per-batch losses, and the fraction of samples whose
        argmax prediction equals the label. Both are NaN if the loader yields no samples.
    """
    val_losses = []
    correct = 0
    total = 0
    model.eval()

    with torch.no_grad():
        for img, label in val_loader:
            img, label = img.to(device), label.to(device)
            out = model(img)

            loss_eval = criterion(out, label)
            val_losses.append(loss_eval.item())

            pred = out.argmax(dim=1, keepdim=True)
            correct += pred.eq(label.view_as(pred)).sum().item()
            # Count real samples: the final batch may be smaller than the nominal batch size.
            total += label.size(0)

    val_loss = np.mean(val_losses) if val_losses else float("nan")
    val_acc = correct / total if total else float("nan")

    return val_loss, val_acc

In [17]:
# ## archived: initial 100 epoch training
# def initial_train(model, train_loader, val_loader, optimizer, criterion, device):
#     model.train()
#     train_losses, val_losses = [], []
#     val_accuracies = []
#     for epoch in range(EPOCHS):
        
#         print(f"Epoch {epoch+1}/{EPOCHS}")

#         # compute val acc every epoch
#         val_loss, val_acc = eval(model, val_loader, criterion, device)
#         val_losses.append(val_loss)
#         val_accuracies.append(val_acc)
#         print(f"Val Loss: {val_loss:.3f} | Val Acc: {val_acc:.3f}")
#         model.train()
        
#         for step, (img, label) in enumerate(train_loader):

#             img, label = img.to(device), label.to(device)
#             optimizer.zero_grad()
#             out = model(img)
#             loss = criterion(out, label)
#             train_losses.append(loss.item()) # every step
#             loss.backward()
    
#             # Monitoring overall gradient norm
#             grads = [
#                     param.grad.detach().flatten()
#                     for param in model.parameters()
#                     if param.grad is not None
#                 ]
#             norm = torch.cat(grads).norm()
            
#             optimizer.step()
            
#             if step % PRINT_ITERS == 0:
#                 print(f"Step: {step}/{len(train_loader)} | Running Average Loss: {np.mean(train_losses):.3f} | Grad Norm: {norm:.2f}")
            
#         torch.save(
#             {
#                 "model_state_dict": model.state_dict(),
#                 "optimizer_state_dict": optimizer.state_dict(),
#             },
#             f"{path}/checkpoints/{MODEL_NAME}_EPOCH_{epoch+1}_SEED_{SEED}.pt",
#         )

#         with open(
#             f"{path}/train_logs/{MODEL_NAME}_SEED_{SEED}_train_losses.json", "w"
#         ) as f:
#             json.dump(train_losses, f)

#         with open(
#             f"{path}/train_logs/{MODEL_NAME}_SEED_{SEED}_val_losses.json", "w"
#         ) as f2:
#             json.dump(val_losses, f2)

#         with open(
#             f"{path}/train_logs/{MODEL_NAME}_SEED_{SEED}_val_accuracies.json", "w"
#         ) as f3:
#             json.dump(val_accuracies, f3)

#     return train_losses, val_losses, val_accuracies

In [18]:
# set_seed()
# model = Net().to(device)
# model.apply(init_params)

# optimizer = optim.AdamW(model.parameters(), lr=LR)
# criterion = nn.CrossEntropyLoss()

# # CPU: ~10 min/epoch, T4: ~45 sec
# train_losses, val_losses, val_accuracies = initial_train(model, train_loader, val_loader, optimizer, criterion, device)

# Pruning (not yet tested)

In [19]:
def prune_train(model, train_loader, optimizer, criterion, model_pruned_pct, device):
    """
    Retrains a pruned model for `EPOCHS` epochs while keeping its pruned weights at zero.

    A weight entry counts as pruned if its magnitude is below `EPS` when this call starts. Those
    entries get zero gradient and are re-zeroed after every optimizer step, since optimizer state
    (e.g. AdamW's running moment estimates from earlier steps) can move a weight even when its
    current gradient is zero.

    Each epoch first evaluates on the notebook-level `val_loader`, then trains on `train_loader`,
    then overwrites a checkpoint under `{path}/checkpoints/` and JSON logs under
    `{path}/train_logs/`, named from `MODEL_NAME`, the rounded `model_pruned_pct` and `SEED`.

    Relies on notebook globals: EPOCHS, EPS, PRINT_ITERS, MODEL_NAME, SEED, path, val_loader, eval.

    Args:
        model: pruned model to retrain (modified in place).
        train_loader: iterable of (img, label) training batches.
        optimizer: optimizer over `model`'s parameters.
        criterion: loss function.
        model_pruned_pct: percentage of weights pruned; used for logging and file names only.
        device: device the batches are moved to.
    Returns:
        (train_losses, val_losses, val_accuracies): per-step training losses, and per-epoch
        validation loss / accuracy measured at the start of each epoch.
    """
    print(f"Training model with {model_pruned_pct:.2f}% pruned")
    run_name = f"{MODEL_NAME}_CURR_PRUNE_PCT_{model_pruned_pct:.0f}_SEED_{SEED}"

    # Snapshot pruned positions once, before any update, so the pruned set stays fixed during training.
    pruned = {
        name: p.detach().abs() < EPS
        for name, p in model.named_parameters()
        if 'weight' in name
    }

    model.train()
    train_losses, val_losses = [], []
    val_accuracies = []
    for epoch in range(EPOCHS):

        print(f"Epoch {epoch+1}/{EPOCHS}")

        # compute val acc every epoch
        val_loss, val_acc = eval(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)
        print(f"Val Loss: {val_loss:.3f} | Val Acc: {val_acc:.3f}")
        model.train()

        for step, (img, label) in enumerate(train_loader):
            img, label = img.to(device), label.to(device)
            optimizer.zero_grad()
            out = model(img)
            loss = criterion(out, label)
            train_losses.append(loss.item())  # every step
            loss.backward()

            log_step = step % PRINT_ITERS == 0
            if log_step:
                # Overall gradient norm (before pruned entries are masked), only when it is printed
                grads = [
                    param.grad.detach().flatten()
                    for param in model.parameters()
                    if param.grad is not None
                ]
                norm = torch.cat(grads).norm().item() if grads else 0.0

            # Disallow pruned weights from receiving gradient updates
            for name, p in model.named_parameters():
                if name in pruned and p.grad is not None:
                    p.grad.masked_fill_(pruned[name], 0.0)

            optimizer.step()

            # ...and undo any movement the optimizer state still applied to them
            with torch.no_grad():
                for name, p in model.named_parameters():
                    if name in pruned:
                        p.masked_fill_(pruned[name], 0.0)

            if log_step:
                print(f"Step: {step}/{len(train_loader)} | Running Average Loss: {np.mean(train_losses):.3f} | Grad Norm: {norm:.2f}")

        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            f"{path}/checkpoints/{run_name}.pt",
        )

        for log_name, values in (
            ("train_losses", train_losses),
            ("val_losses", val_losses),
            ("val_accuracies", val_accuracies),
        ):
            with open(f"{path}/train_logs/{run_name}_{log_name}.json", "w") as f:
                json.dump(values, f)

    return train_losses, val_losses, val_accuracies

# Driver code

In [80]:
LOAD_EPOCH = 100

# Rebuild the architecture, then restore the weights saved by the LOAD_MODEL_NAME run at epoch LOAD_EPOCH.
model, _ = get_model_and_optimizer()
checkpoint = torch.load(
    f"{path}/checkpoints/{LOAD_MODEL_NAME}_EPOCH_{LOAD_EPOCH}_SEED_{SEED}.pt",
    map_location=device,
)
model.load_state_dict(checkpoint["model_state_dict"])
del checkpoint  # drop the extra reference to the loaded tensors (incl. any optimizer state)
model.to(device)
print('Model loaded')

Model loaded


In [81]:
# Snapshot of the parameters that `reset_params` rewinds surviving weights to.
# NOTE(review): this copies the weights just loaded from the EPOCH_{LOAD_EPOCH} checkpoint, i.e. trained
# weights, not the original initialization that Lottery-Ticket-style rewinding usually targets.
# If LTH rewinding is intended, snapshot an initial / early-epoch checkpoint here instead — confirm.
init_state = copy.deepcopy(model.state_dict())
# All-ones mask (nothing pruned yet): one array per parameter whose name contains "weight"
mask = generate_init_mask(model)

layer_names = get_layer_weight_names(model)
# Per-layer and total weight counts of the unpruned model, used as denominators when logging
init_num_weight_params_by_layer = get_num_weight_params_by_layer(mask)
init_num_weight_params = sum(init_num_weight_params_by_layer)

# NOTE(review): as written, this single AdamW instance would be passed to every `prune_train` call in the
# prune loop, so its moment estimates carry over between rounds; re-create it per round if unintended.
optimizer = optim.AdamW(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

In [82]:
# orig_model_metrics = eval(model, val_loader, criterion, device)

# Summary table: one row per pruning iteration
df = pd.DataFrame(
    columns=["prune_iter", "pct_params", "pct_pruned", "val_loss", "val_acc"],
)

In [66]:
###### UNCOMMENT later

## populate with original model
# df.loc[0] = {"prune_iter": 0, "pct_params": 100.00, "pct_pruned": 0.00,
#              "val_loss": orig_model_metrics[0], "val_acc": orig_model_metrics[1]}
# df

## Temporary constants and unverified code (for testing)

In [83]:
# NOTE: PRUNE_PCT (and possibly the others) also come from the star-imported `constants`; these
# assignments override them. MODEL_NAME was already built from the earlier PRUNE_PCT value.
EPOCHS = 1

PRUNE_PCT = 80    # overall % of weights pruned after all rounds
PRUNE_ITERS = 5   # number of prune/rewind rounds
# Per-round rate p such that (1 - p) ** PRUNE_ITERS == 1 - PRUNE_PCT / 100
PRUNE_ITER_PCT = 100 * (1 - (1 - PRUNE_PCT / 100) ** (1 / PRUNE_ITERS))
EPS = 1e-7        # magnitude below which prune_train treats a weight as pruned
PRINT_ITERS = 5  # frequency to print train loss

# –––––––––

In [84]:
for prune_iter in range(PRUNE_ITERS):
    round_num = prune_iter + 1
    print(f"Prune Iteration {round_num}/{PRUNE_ITERS}")

    # Drop PRUNE_ITER_PCT% of each layer's surviving weights by magnitude (updates `mask` in place),
    # then rewind the survivors to `init_state`
    prune_by_percent(model, mask, PRUNE_ITER_PCT)
    reset_params(model, mask, init_state)

    # How much of the network survives after this round
    num_weight_params_by_layer = get_num_weight_params_by_layer(mask)
    num_surviving = sum(num_weight_params_by_layer)
    model_param_pct = (num_surviving / init_num_weight_params) * 100
    model_pruned_pct = 100 - model_param_pct
    print(f"Number of weight parameters: {num_surviving}/{init_num_weight_params} ≈ {model_param_pct:.2f}%")

    # Per-layer "surviving/original" counts as strings, e.g. '72/734'
    params_props = [
        f"{n_left}/{n_orig}"
        for n_left, n_orig in zip(num_weight_params_by_layer, init_num_weight_params_by_layer)
    ]
    print("Weight breakdown:")
    for name, prop in zip(layer_names, params_props):
        print(f"  {name: <15}: {prop}")
    print()

    # TRAIN HERE, for j iterations
    # train_losses, val_losses, val_accuracies = prune_train(model, train_loader, optimizer, criterion, model_pruned_pct, device)

    # TODO: how to set j? simply train to convergence again?

    # evaluate and save statistics
    # model_metrics = eval(model, val_loader, criterion, device)
    # df.loc[prune_iter+1] = {"prune_iter": prune_iter+1, "pct_params": model_param_pct, "pct_pruned": model_pruned_pct,
    #                         "val_loss": model_metrics[0], "val_acc": model_metrics[1]}

    ## TEMP: fixed placeholder metrics to avoid the time-consuming evaluation
    df.loc[round_num] = {"prune_iter": round_num, "pct_params": model_param_pct, "pct_pruned": model_pruned_pct,
                         "val_loss": 0.8172, "val_acc": 0.487325}

Prune Iteration 1/5
Number of weight parameters: 1323156/1825600 ≈ 72.48%
Weight breakdown:
  conv1          : 1252/1728
  conv2          : 53436/73728
  conv3          : 213746/294912
  conv4          : 854985/1179648
  batchnorm2d_1  : 93/128
  batchnorm2d_2  : 371/512
  fc1            : 189996/262144
  fc2            : 9277/12800

Prune Iteration 2/5
Number of weight parameters: 958996/1825600 ≈ 52.53%
Weight breakdown:
  conv1          : 907/1728
  conv2          : 38729/73728
  conv3          : 154919/294912
  conv4          : 619676/1179648
  batchnorm2d_1  : 67/128
  batchnorm2d_2  : 269/512
  fc1            : 137705/262144
  fc2            : 6724/12800

Prune Iteration 3/5
Number of weight parameters: 695059/1825600 ≈ 38.07%
Weight breakdown:
  conv1          : 657/1728
  conv2          : 28070/73728
  conv3          : 112282/294912
  conv4          : 449128/1179648
  batchnorm2d_1  : 48/128
  batchnorm2d_2  : 195/512
  fc1            : 99806/262144
  fc2            : 4873/1280

In [85]:
df  # one row per pruning iteration; val_loss / val_acc are the TEMP placeholder values set in the prune loop

Unnamed: 0,prune_iter,pct_params,pct_pruned,val_loss,val_acc
1,1,72.47787,27.52213,0.8172,0.487325
2,2,52.530456,47.469544,0.8172,0.487325
3,3,38.072908,61.927092,0.8172,0.487325
4,4,27.59438,72.40562,0.8172,0.487325
5,5,19.999836,80.000164,0.8172,0.487325


# Plotting

### Part I: Training vs. validation loss of checkpoints at each pruning level

### Part II: Validation accuracy across the pruned model checkpoints