In [1]:
import os
import sys
# Make the local 'resnet' directory importable so `from resnet_18 import *` (last line of this cell) resolves.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), 'resnet'))
sys.path.append(parent_dir)

from tqdm import tqdm
from ranger import Ranger

import torch
# NOTE(review): the next two lines bind the same name `nn`; one of them is redundant.
from torch import nn
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.nn.utils.prune as prune
import torch.nn.init as init

import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import Normalize
from torchmetrics import Accuracy
import torchvision.utils as vutils

import torch.optim as optim

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import pickle
import random
import copy
import gc

import warnings
# NOTE(review): this silences *every* warning (including PyTorch/torchvision deprecation
# notices) for the rest of the session — consider filtering specific categories instead.
warnings.filterwarnings('ignore')

# Star import from the local resnet package; presumably provides `resnet_18` used when
# building the model — consider importing the needed names explicitly.
from resnet_18 import *

In [2]:
# Execute the helper notebook and pull its names into this namespace. The helpers used later
# (checkdir, generate_mask, prune_network, original_initialization, print_nonzeros, train,
# test) are not defined in this notebook and presumably come from here — confirm.
%run utils.ipynb

In [3]:
# Apply seaborn's dark-grid theme to all figures drawn later in the notebook.
sns.set_style(style='darkgrid')

In [4]:
# Report whether PyTorch can see a CUDA device; the `device` cell below depends on this.
cuda_available = torch.cuda.is_available()
print(cuda_available)

True


In [5]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
# Mini-batch size shared by the training and validation loaders.
batch_size = 128
# Imagenette v2 split roots, each holding one sub-folder per class as ImageFolder expects.
train_path, val_path = '../datasets/imagenette2/train', '../datasets/imagenette2/val'

In [7]:
# ImageNet channel statistics, shared by both pipelines so the two cannot drift apart.
imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
# Worker processes per DataLoader.
num_workers = 5

# Training pipeline: random-resized 224px crops plus horizontal flips for augmentation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    imagenet_normalize,
])

# Validation pipeline: deterministic resize to 224x224. Resize runs on the PIL image
# *before* ToTensor so it takes the same antialiased PIL path as the training crops;
# resizing the tensor instead may skip antialiasing depending on the torchvision version,
# giving a train/eval preprocessing mismatch.
test_transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    imagenet_normalize,
])

train_dataloader = DataLoader(datasets.ImageFolder(train_path, transform=train_transform),
                              batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)

test_dataloader = DataLoader(datasets.ImageFolder(val_path, transform=test_transform),
                             batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True)

In [8]:
# Human-readable Imagenette labels. Assumed to follow ImageFolder's class-index order
# (sorted folder names) — confirm against `train_dataloader.dataset.classes`.
classes = ('tench', 'springer', 'cassette_player', 'chain_saw', 'church', 'French_horn', 'garbage_truck', 'gas_pump', 'golf_ball', 'parachute')


In [9]:
# Seed every RNG this notebook touches. The lottery-ticket procedure rewinds weights to the
# random initialisation drawn below, so that draw (and DataLoader shuffling) should be
# reproducible across runs. Bit-exact GPU results additionally need deterministic cuDNN settings.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# ResNet-18 trained from scratch. `filter` / `filter_layer` are options of the project's
# resnet_18 factory; their semantics are defined in the resnet package, not here.
model = resnet_18(pretrained=False, filter='None', filter_layer=0).to(device)

# Learning rate for the optimizer (also used when it is re-created for each pruning round).
learning_rate = 1e-3
# Pruning round to start from: the outer loop runs over range(start_iteration, prune_iterations).
start_iteration = 0
# Training epochs per pruning round (the inner loop trains one epoch per step); also the
# period T_max of the cosine learning-rate schedule.
end_iteration = 30
# Refresh the progress-bar text every `print_frequency` epochs.
print_frequency = 1
# Evaluate on the validation set every `valid_frequency` epochs.
valid_frequency = 1
# Percentage of weights removed in each pruning round (passed to prune_network).
prune_percent = 20
# Number of pruning rounds; round 0 trains the dense, unpruned network.
prune_iterations = 30


In [10]:
# Ensure the checkpoint directory exists before anything is written to it.
checkdir(f"{os.getcwd()}/models/")
# Deep-copy the untrained weights: after every pruning round the surviving weights are
# rewound to exactly these values (see original_initialization in the training loop).
initial_state_dict = copy.deepcopy(model.state_dict())
# NOTE(review): despite the file name, this pickles the whole nn.Module, not just its state dict.
torch.save(model, f"{os.getcwd()}/models/initial_state_dict_lt.pth.tar")

In [11]:
# Build the initial pruning mask for `model` with generate_mask (defined outside this notebook,
# presumably in utils.ipynb). Presumably all-ones, i.e. nothing pruned yet — confirm.
# The mask is later updated by prune_network and used to rewind weights each round.
mask = generate_mask(model)

In [12]:
# Loss function. Moved to `device` instead of a hard-coded `.cuda()` so the notebook also
# runs on CPU-only machines, where `device` falls back to "cpu".
criterion = nn.CrossEntropyLoss(reduction="mean").to(device)
# Ranger optimizer for the dense (round 0) training run.
optimizer = Ranger(model.parameters(), lr=learning_rate, weight_decay=1e-4, eps=1e-06)
# Cosine-annealed learning rate with period `end_iteration` scheduler steps.
# NOTE: a scheduler is bound to one optimizer instance; any re-created optimizer needs its own.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=end_iteration)

Ranger optimizer loaded. 
Gradient Centralization usage = True
GC applied to both conv and fc layers


In [13]:
# Best validation accuracy within the current pruning round (reset after each round).
best_accuracy = 0
# best_accuracies[r] receives the best validation accuracy reached in pruning round r.
best_accuracies = np.zeros(shape=prune_iterations, dtype=float)
# NOTE(review): `step` is not referenced anywhere in the visible notebook.
step = 0

In [None]:
# Lottery-ticket loop. Round 0 trains the dense network. Every later round:
#   1. prunes `prune_percent` of the weights via prune_network (updating `mask`),
#   2. rewinds the surviving weights to their initial values, and
#   3. retrains for `end_iteration` epochs with a fresh optimizer AND a fresh LR scheduler.
# The best validation accuracy of each round is recorded in `best_accuracies`.

# The checkpoint directory only needs to be ensured once, not on every improvement.
checkdir(f"{os.getcwd()}/models/")

for prune_iter in range(start_iteration, prune_iterations):
    if prune_iter != 0:
        # Prune (presumably the smallest-magnitude weights — see prune_network) and update the mask.
        mask = prune_network(prune_percent, mask, model)

        # Rewind the remaining (unpruned) weights to their values from before any training.
        original_initialization(mask, model, initial_state_dict)

    # Fresh optimizer state for this round, plus a scheduler bound to *this* optimizer.
    # Previously only the optimizer was re-created, so `scheduler` kept pointing at the
    # discarded first optimizer and every round after the first ran without the cosine
    # schedule (assuming train() calls scheduler.step()). Re-creating both every round
    # (round 0 included) also keeps the cell correct when it is re-run.
    optimizer = Ranger(model.parameters(), lr=learning_rate, weight_decay=1e-4, eps=1e-06)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=end_iteration)

    print(f"\nPruning Level ({prune_iter}/{prune_iterations}):")

    # Per-layer table of non-zero weights after this round's pruning.
    print_nonzeros(model)
    progress_bar = tqdm(range(end_iteration))

    # Retrain the (pruned, rewound) network for one round.
    for iteration in progress_bar:
        # Iteration 0 always satisfies this, so `accuracy` is defined before it is printed.
        if iteration % valid_frequency == 0:
            accuracy = test(model, test_dataloader, criterion)

            # Checkpoint the best model seen so far in this pruning round.
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                torch.save(model, f"{os.getcwd()}/models/{prune_iter}_model_lt.pth.tar")

        # One epoch of training.
        loss = train(model, train_dataloader, optimizer, criterion, scheduler)

        if iteration % print_frequency == 0:
            progress_bar.set_description(
                f'Train Epoch: {iteration}/{end_iteration} Loss: {loss:.6f} Accuracy: {accuracy:.2f}% Best Accuracy: {best_accuracy:.2f}%')

    # Record this round's best accuracy, then reset the tracker for the next round.
    best_accuracies[prune_iter] = best_accuracy
    best_accuracy = 0