In [1]:
import os
import numpy as np
import cv2
import pickle
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchsummary import summary
import torch.nn.utils as utils

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
!ls
!tar -xvzf dataset.tar.gz

dataset.tar.gz	my_model_weights_1.pt  sample_data
train_images.pkl
train_labels.pkl
val_images.pkl
val_labels.pkl


In [4]:
# load train
train_images = pickle.load(open('train_images.pkl', 'rb'))
train_labels = pickle.load(open('train_labels.pkl', 'rb'))
# load val
val_images = pickle.load(open('val_images.pkl', 'rb'))
val_labels = pickle.load(open('val_labels.pkl', 'rb'))

In [5]:
train_images = torch.tensor(train_images, dtype=torch.float32)
val_images = torch.tensor(val_images, dtype=torch.float32)

train_images = train_images.permute(0, 3, 1, 2)
val_images = val_images.permute(0, 3, 1, 2)

In [6]:
train_dataset = TensorDataset(train_images,
                              torch.tensor(train_labels.squeeze(), dtype=torch.long))
val_dataset = TensorDataset(val_images,
                            torch.tensor(val_labels.squeeze(), dtype=torch.long))

In [7]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

In [8]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()

        self.model = nn.Sequential(
            # First block: Conv -> ReLU -> Conv -> ReLU -> MaxPool -> Dropout
            nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=0, bias=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.25),

            # Second block: Conv -> ReLU -> Conv -> ReLU -> MaxPool -> Dropout
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=0, bias=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.25),

            # Flatten layer
            nn.Flatten(),

            # Fully connected block: Dense -> ReLU -> Dropout -> Dense -> Softmax
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 5),
        )

    def forward(self, x):
        return self.model(x)

# Running Iterative Pruning

Given some "sparsity" value, we will set the weights of the model for the values to 0 train for n steps and then update the process iteratively.

In [64]:
model = ConvNet()

model.load_state_dict(torch.load('my_model_weights_1.pt'))
model.to(device)

criterion = torch.nn.CrossEntropyLoss()

In [65]:
sparsity = 0.95
N = 500 # number of steps of training
lr = 0.001

optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-6)

In [66]:
# getting lists of weights in the model
weights = []
with torch.no_grad():  # disable gradient tracking for efficiency
  for name, param in model.named_parameters():
    if "weight" in name:  # only apply to weights, skip biases
      weights.append(name)

In [67]:
def train_n_steps(model, train_loader, optimizer, criterion, device,
                  n_steps = N, s = sparsity):


    # setting sparsity
    with torch.no_grad():  # disable gradient tracking for efficiency
      # weight threshold to prune
      threshold = torch.cat([torch.flatten(torch.absolute(model.get_parameter(i))) for i in weights]).quantile(s)
      #print(threshold)
      for name, param in model.named_parameters():
          if "weight" in name:  # only apply to weights, skip biases
              param[torch.absolute(param) < threshold] = 0

    model.train()  # Set model to training mode
    running_loss = 0.0
    correct = 0
    total = 0

    steps = 0



    # Progress bar for the training loop
    train_loader_tqdm = tqdm(train_loader, desc="Training", leave=False, total=n_steps)

    for inputs, labels in train_loader_tqdm:

        steps += 1

        if steps > n_steps:
            break

        # Zero the parameter gradients
        optimizer.zero_grad()  # Zero the parameter gradients
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Track loss and accuracy
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

        # Update tqdm description with current loss and accuracy
        train_loader_tqdm.set_postfix(loss=running_loss / n_steps, accuracy=100 * correct / total)

    train_accuracy = 100 * correct / total
    train_loss = running_loss / n_steps
    return train_loss, train_accuracy

In [68]:
def validate(model, val_loader, criterion, device):
    model.eval()  # Set model to evaluation mode
    val_loss = 0.0
    correct = 0
    total = 0

    # Progress bar for the validation loop
    val_loader_tqdm = tqdm(val_loader, desc="Validation", leave=False)

    with torch.no_grad():  # Disable gradient calculations for validation
        for inputs, labels in val_loader_tqdm:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Track loss and accuracy
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            # Update tqdm description with current validation loss and accuracy
            val_loader_tqdm.set_postfix(loss=val_loss / total, accuracy=100 * correct / total)

    val_accuracy = 100 * correct / total
    val_loss = val_loss / len(val_loader)
    return val_loss, val_accuracy

In [69]:
# Main training loop

train_loss_hist = []
train_acc_hist = []
val_loss_hist = []
val_acc_hist = []

num_iter = 50
for iter in range(num_iter):
    print(f"Pruning Iteration {iter+1}/{num_iter}")

    # Training
    train_loss, train_accuracy = train_n_steps(model, train_loader, optimizer, criterion, device)

    # Validation
    val_loss, val_accuracy = validate(model, val_loader, criterion, device)

    # Keep History of Output
    train_loss_hist.append(train_loss)
    train_acc_hist.append(train_accuracy)
    val_loss_hist.append(val_loss)
    val_acc_hist.append(val_accuracy)

    # Print epoch results
    print(f'Iter [{iter+1}/{num_iter}], '
          f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%')

Pruning Iteration 1/50




Iter [1/50], Train Loss: 1.1580, Train Acc: 53.02%, Val Loss: 1.0268, Val Acc: 58.30%
Pruning Iteration 2/50




Iter [2/50], Train Loss: 1.1149, Train Acc: 55.19%, Val Loss: 1.0062, Val Acc: 59.17%
Pruning Iteration 3/50




Iter [3/50], Train Loss: 1.0815, Train Acc: 57.19%, Val Loss: 1.0553, Val Acc: 56.51%
Pruning Iteration 4/50




Iter [4/50], Train Loss: 1.0590, Train Acc: 57.96%, Val Loss: 1.0481, Val Acc: 56.63%
Pruning Iteration 5/50




Iter [5/50], Train Loss: 1.0388, Train Acc: 59.51%, Val Loss: 0.9098, Val Acc: 64.24%
Pruning Iteration 6/50




Iter [6/50], Train Loss: 1.0156, Train Acc: 60.67%, Val Loss: 0.9038, Val Acc: 64.67%
Pruning Iteration 7/50




Iter [7/50], Train Loss: 1.0089, Train Acc: 60.51%, Val Loss: 0.8809, Val Acc: 64.79%
Pruning Iteration 8/50




Iter [8/50], Train Loss: 1.0064, Train Acc: 60.48%, Val Loss: 0.9850, Val Acc: 61.23%
Pruning Iteration 9/50




Iter [9/50], Train Loss: 0.9796, Train Acc: 61.74%, Val Loss: 0.8838, Val Acc: 65.70%
Pruning Iteration 10/50




Iter [10/50], Train Loss: 0.9820, Train Acc: 61.30%, Val Loss: 0.9013, Val Acc: 65.70%
Pruning Iteration 11/50




Iter [11/50], Train Loss: 0.9646, Train Acc: 62.12%, Val Loss: 0.9394, Val Acc: 62.61%
Pruning Iteration 12/50




Iter [12/50], Train Loss: 0.9612, Train Acc: 62.23%, Val Loss: 0.8829, Val Acc: 65.39%
Pruning Iteration 13/50




Iter [13/50], Train Loss: 0.9537, Train Acc: 62.84%, Val Loss: 0.9465, Val Acc: 61.86%
Pruning Iteration 14/50




Iter [14/50], Train Loss: 0.9588, Train Acc: 62.78%, Val Loss: 0.8998, Val Acc: 65.86%
Pruning Iteration 15/50




Iter [15/50], Train Loss: 0.9391, Train Acc: 63.79%, Val Loss: 0.9266, Val Acc: 63.05%
Pruning Iteration 16/50




Iter [16/50], Train Loss: 0.9431, Train Acc: 63.34%, Val Loss: 0.9317, Val Acc: 62.93%
Pruning Iteration 17/50




Iter [17/50], Train Loss: 0.9388, Train Acc: 63.65%, Val Loss: 0.9041, Val Acc: 64.36%
Pruning Iteration 18/50




Iter [18/50], Train Loss: 0.9271, Train Acc: 64.10%, Val Loss: 0.8539, Val Acc: 66.42%
Pruning Iteration 19/50




Iter [19/50], Train Loss: 0.9253, Train Acc: 64.03%, Val Loss: 0.8909, Val Acc: 65.54%
Pruning Iteration 20/50




Iter [20/50], Train Loss: 0.9331, Train Acc: 63.84%, Val Loss: 0.8709, Val Acc: 65.90%
Pruning Iteration 21/50




Iter [21/50], Train Loss: 0.9041, Train Acc: 65.01%, Val Loss: 0.8812, Val Acc: 65.98%
Pruning Iteration 22/50




Iter [22/50], Train Loss: 0.9055, Train Acc: 65.06%, Val Loss: 0.8477, Val Acc: 66.14%
Pruning Iteration 23/50




Iter [23/50], Train Loss: 0.9132, Train Acc: 64.53%, Val Loss: 0.8594, Val Acc: 66.57%
Pruning Iteration 24/50




Iter [24/50], Train Loss: 0.9044, Train Acc: 65.03%, Val Loss: 0.8596, Val Acc: 65.58%
Pruning Iteration 25/50




Iter [25/50], Train Loss: 0.9037, Train Acc: 65.26%, Val Loss: 0.8869, Val Acc: 65.19%
Pruning Iteration 26/50




Iter [26/50], Train Loss: 0.9064, Train Acc: 64.95%, Val Loss: 0.8514, Val Acc: 65.98%
Pruning Iteration 27/50




Iter [27/50], Train Loss: 0.8967, Train Acc: 65.47%, Val Loss: 0.8565, Val Acc: 66.77%
Pruning Iteration 28/50




Iter [28/50], Train Loss: 0.8924, Train Acc: 65.77%, Val Loss: 0.8843, Val Acc: 65.54%
Pruning Iteration 29/50




Iter [29/50], Train Loss: 0.8943, Train Acc: 65.84%, Val Loss: 0.8653, Val Acc: 67.09%
Pruning Iteration 30/50




Iter [30/50], Train Loss: 0.8879, Train Acc: 65.49%, Val Loss: 0.8413, Val Acc: 67.05%
Pruning Iteration 31/50




Iter [31/50], Train Loss: 0.9014, Train Acc: 65.36%, Val Loss: 0.8584, Val Acc: 66.61%
Pruning Iteration 32/50




Iter [32/50], Train Loss: 0.8891, Train Acc: 65.74%, Val Loss: 0.8039, Val Acc: 68.59%
Pruning Iteration 33/50




Iter [33/50], Train Loss: 0.8805, Train Acc: 66.19%, Val Loss: 0.8688, Val Acc: 66.46%
Pruning Iteration 34/50




Iter [34/50], Train Loss: 0.8857, Train Acc: 66.30%, Val Loss: 0.8301, Val Acc: 68.16%
Pruning Iteration 35/50




Iter [35/50], Train Loss: 0.8780, Train Acc: 66.05%, Val Loss: 0.8443, Val Acc: 67.29%
Pruning Iteration 36/50




Iter [36/50], Train Loss: 0.8885, Train Acc: 65.57%, Val Loss: 0.8456, Val Acc: 67.13%
Pruning Iteration 37/50




Iter [37/50], Train Loss: 0.8754, Train Acc: 66.45%, Val Loss: 0.8609, Val Acc: 67.37%
Pruning Iteration 38/50




Iter [38/50], Train Loss: 0.8766, Train Acc: 65.97%, Val Loss: 0.8293, Val Acc: 67.29%
Pruning Iteration 39/50




Iter [39/50], Train Loss: 0.8673, Train Acc: 66.79%, Val Loss: 0.8293, Val Acc: 67.92%
Pruning Iteration 40/50




Iter [40/50], Train Loss: 0.8703, Train Acc: 66.71%, Val Loss: 0.8378, Val Acc: 67.21%
Pruning Iteration 41/50




Iter [41/50], Train Loss: 0.8821, Train Acc: 66.01%, Val Loss: 0.8186, Val Acc: 68.24%
Pruning Iteration 42/50




Iter [42/50], Train Loss: 0.8706, Train Acc: 66.59%, Val Loss: 0.8131, Val Acc: 68.24%
Pruning Iteration 43/50




Iter [43/50], Train Loss: 0.8561, Train Acc: 66.78%, Val Loss: 0.8478, Val Acc: 67.33%
Pruning Iteration 44/50




Iter [44/50], Train Loss: 0.8710, Train Acc: 66.64%, Val Loss: 0.8540, Val Acc: 66.89%
Pruning Iteration 45/50




Iter [45/50], Train Loss: 0.8612, Train Acc: 66.71%, Val Loss: 0.8262, Val Acc: 69.03%
Pruning Iteration 46/50




Iter [46/50], Train Loss: 0.8706, Train Acc: 66.29%, Val Loss: 0.7892, Val Acc: 68.99%
Pruning Iteration 47/50




Iter [47/50], Train Loss: 0.8620, Train Acc: 67.26%, Val Loss: 0.8125, Val Acc: 68.83%
Pruning Iteration 48/50




Iter [48/50], Train Loss: 0.8602, Train Acc: 67.19%, Val Loss: 0.8245, Val Acc: 67.49%
Pruning Iteration 49/50




Iter [49/50], Train Loss: 0.8582, Train Acc: 66.95%, Val Loss: 0.8178, Val Acc: 67.49%
Pruning Iteration 50/50


                                                                                        

Iter [50/50], Train Loss: 0.8549, Train Acc: 67.12%, Val Loss: 0.8429, Val Acc: 66.53%




In [71]:
# setting sparsity
#with torch.no_grad():  # disable gradient tracking for efficiency
  # weight threshold to prune
#  threshold = torch.cat([torch.flatten(torch.absolute(model.get_parameter(i))) for i in weights]).quantile(sparsity)
#  for name, param in model.named_parameters():
#      if "weight" in name:  # only apply to weights, skip biases
#          param[torch.absolute(param) < threshold] = 0
validate(model, val_loader, criterion, device)



(0.8428579821616788, 66.53465346534654)

In [72]:
model_name = f'pruning_l1_iterative_sparisity{sparsity}_N{N}_lr{lr}_iter{iter}.pt'
torch.save(model.state_dict(), model_name, _use_new_zipfile_serialization=False)