In [1]:
!pip install pandas



In [2]:
import torch

print("PyTorch version:", torch.__version__)


PyTorch version: 2.5.1+cu124


In [3]:
import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F
import torchvision.models as models
import time

In [4]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
print(device)

cuda


In [6]:
# Data augmentation for training
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally
    transforms.RandomCrop(32, padding=4),  # Randomly crop the image
    transforms.Resize(224),  # Resize to 224x224 for ResNet
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# No augmentation for validation and test
val_test_transform = transforms.Compose([
    transforms.Resize(224),  # Resize to 224x224 for ResNet
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load CIFAR-10 dataset
train_val_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=val_test_transform)

# Split train_val_dataset into train and validation sets (80% train, 20% validation)
train_size = int(0.8 * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size
train_dataset, val_dataset = random_split(train_val_dataset, [train_size, val_size])

# Apply val_test_transform to the validation set
val_dataset.dataset.transform = val_test_transform

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:08<00:00, 21.2MB/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [7]:
# Load pretrained ResNet-50 (Teacher Model)
teacher = models.resnet50(pretrained=False)

# Modify the final fully connected layer for 10 classes (CIFAR-10)
teacher.fc = nn.Linear(teacher.fc.in_features, 10)
# Move models to device
teacher = teacher.to(device)



In [8]:

model_path = '/kaggle/input/teacherc10/pytorch/default/1/Best_Teacher.pth'
# Load the model weights
teacher.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  teacher.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [9]:
# Load pretrained ResNet-18 (Student Model)
student = models.resnet18(pretrained=True)
# Modify the final fully connected layer for 10 classes (CIFAR-10)
student.fc = nn.Linear(student.fc.in_features, 10)
student = student.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 84.0MB/s]


In [10]:

model_path = '/kaggle/input/studentc10/pytorch/default/1/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [11]:
# Logits normalization function
def normalize(logit):
    mean = logit.mean(dim=-1, keepdim=True)
    stdv = logit.std(dim=-1, keepdim=True)
    return (logit - mean) / (1e-7 + stdv)


In [12]:
# CA-KLD Loss for Classification
def cakld_loss(student_logits, teacher_logits, beta_prob):
    # Forward KL (student || teacher)
    student_log_prob = F.log_softmax(student_logits, dim=1)
    teacher_prob = F.softmax(teacher_logits, dim=1)
    forward_kl = F.kl_div(student_log_prob, teacher_prob, reduction='batchmean')

    # Reverse KL (teacher || student)
    teacher_log_prob = F.log_softmax(teacher_logits, dim=1)
    student_prob = F.softmax(student_logits, dim=1)
    reverse_kl = F.kl_div(teacher_log_prob, student_prob, reduction='batchmean')

    # Combined KL loss
    kl_loss = beta_prob * reverse_kl + (1 - beta_prob) * forward_kl
    return kl_loss


In [13]:
def evaluate(model, test_loader, device):
    model = model.to(device)  # Ensure model is on the correct device
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


In [14]:
def calculate_sparsity(model):
    total_zeros = 0
    total_params = 0
    for name, param in model.named_parameters():
        if 'weight' in name:
            total_zeros += torch.sum(param == 0).item()
            total_params += param.numel()
    return total_zeros / total_params

In [15]:
import torch
import time
def measure_inference_time(model, test_loader, num_runs=5):
    device = torch.device('cpu')
    model.eval()
    model.to(device)

    # Warm-up (one batch to avoid startup cost)
    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = inputs.to(device)
            _ = model(inputs)
            break

    total_time = 0
    total_images = 0

    with torch.no_grad():
        for _ in range(num_runs):
            for inputs, _ in test_loader:
                inputs = inputs.to(device)
                batch_size = inputs.size(0)
                start_time = time.time()
                _ = model(inputs)
                end_time = time.time()

                total_time += (end_time - start_time)
                total_images += batch_size

    avg_time_per_image = total_time / total_images
    return avg_time_per_image


In [16]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

def calculate_model_size(model, filename="temp.pth"):
    torch.save(model.state_dict(), filename)
    size = os.path.getsize(filename) / (1024 * 1024)  # Size in MB
    os.remove(filename)
    return size

def compare_model_sizes(teacher, student, pruned_student):
    # Count parameters
    teacher_params = count_parameters(teacher)
    student_params = count_parameters(student)
    pruned_params = count_parameters(pruned_student)
    
    # Calculate disk size
    teacher_size = calculate_model_size(teacher, "teacher.pth")
    student_size = calculate_model_size(student, "student.pth")
    pruned_size = calculate_model_size(pruned_student, "pruned_student.pth")
    
    # Print comparison
    print("\n--- Model Size Comparison ---")
    print(f"Teacher Model: {teacher_params} parameters, {teacher_size:.2f} MB")
    print(f"Student Model (Before Pruning): {student_params} parameters, {student_size:.2f} MB")
    print(f"Student Model (After Pruning): {pruned_params} parameters, {pruned_size:.2f} MB")
    
    # Calculate compression ratio
    compression_ratio = student_size / pruned_size
    print(f"\nCompression Ratio: {compression_ratio:.2f}x")

In [17]:
def train_model(model, train_loader, val_loader, epochs=10, lr=0.001, patience=3):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    
    best_val_accuracy = 0.0
    best_model_state = None
    patience_counter = 0  # Counter for early stopping
    
    for epoch in range(epochs):
        print(epoch)
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        # Evaluate on the validation set
        val_accuracy = evaluate(model, val_loader, device)
        print(f"Epoch {epoch+1}/{epochs} | Loss: {running_loss/len(train_loader):.4f} | Val Accuracy: {val_accuracy:.2f}%")
        
        # Early stopping logic
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_model_state = model.state_dict()
            patience_counter = 0  # Reset patience counter
            torch.save(model.state_dict(), 'best_teacher_model.pth')  # Save the best model
            print(f" New best model saved with validation accuracy: {best_val_accuracy:.2f}%")
        else:
            patience_counter += 1
            print(f" No improvement in validation accuracy ({patience_counter}/{patience})")
            
            # Stop training if no improvement for 'patience' epochs
            if patience_counter >= patience:
                print(f"\nEarly stopping triggered! No improvement for {patience} epochs.")
                break
    
    # Load the best model state
    model.load_state_dict(torch.load('best_teacher_model.pth'))
    print("\nLoading the best model for final evaluation.")
    
    # Evaluate on the test set
    test_accuracy = evaluate(model, test_loader, device)
    print(f"Test Accuracy with Best Model: {test_accuracy:.2f}%")
    
    return model



In [None]:

tch = train_model(
    teacher, train_loader, val_loader,
    epochs=50, lr=0.001, patience=5
)


0
Epoch 1/50 | Loss: 2.1086 | Val Accuracy: 27.72%
 New best model saved with validation accuracy: 27.72%
1
Epoch 2/50 | Loss: 1.8309 | Val Accuracy: 31.25%
 New best model saved with validation accuracy: 31.25%
2
Epoch 3/50 | Loss: 1.6877 | Val Accuracy: 37.42%
 New best model saved with validation accuracy: 37.42%
3
Epoch 4/50 | Loss: 1.5755 | Val Accuracy: 43.93%
 New best model saved with validation accuracy: 43.93%
4
Epoch 5/50 | Loss: 1.5001 | Val Accuracy: 45.49%
 New best model saved with validation accuracy: 45.49%
5
Epoch 6/50 | Loss: 1.4215 | Val Accuracy: 48.98%
 New best model saved with validation accuracy: 48.98%
6
Epoch 7/50 | Loss: 1.3491 | Val Accuracy: 50.04%
 New best model saved with validation accuracy: 50.04%
7
Epoch 8/50 | Loss: 1.2896 | Val Accuracy: 51.79%
 New best model saved with validation accuracy: 51.79%
8
Epoch 9/50 | Loss: 1.2194 | Val Accuracy: 53.57%
 New best model saved with validation accuracy: 53.57%
9
Epoch 10/50 | Loss: 1.1558 | Val Accuracy: 5

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def _get_gt_mask(logits, target):
    target = target.reshape(-1)
    mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
    return mask

def _get_other_mask(logits, target):
    target = target.reshape(-1)
    mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()
    return mask

def cat_mask(tensor, mask1, mask2):
    t1 = (tensor * mask1).sum(dim=1, keepdim=True)
    t2 = (tensor * mask2).sum(dim=1, keepdim=True)
    return torch.cat([t1, t2], dim=1)

class DKDloss(nn.Module):
    def __init__(self):
        super(DKDloss, self).__init__()

    def forward(self, logits_student, logits_teacher, target, alpha, beta, temperature):
        # Get masks for ground-truth and other classes
        gt_mask = _get_gt_mask(logits_student, target)
        other_mask = _get_other_mask(logits_student, target)
        
        # Compute softened probabilities
        pred_student = F.softmax(logits_student / temperature, dim=1)
        pred_teacher = F.softmax(logits_teacher / temperature, dim=1)

        

        # Two-class transformation using GT and OTHER
        pred_student = cat_mask(pred_student, gt_mask, other_mask)
        pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)

        # True class KD loss
        log_pred_student = torch.log(pred_student)
        tckd_loss = F.kl_div(
            log_pred_student, pred_teacher, reduction='sum'
        ) * (temperature ** 2) / target.shape[0]

        # Non-ground-truth KD loss (mask GT with large value)
        pred_teacher_part2 = F.softmax(
            logits_teacher / temperature - 1000.0 * gt_mask, dim=1
        )
        log_pred_student_part2 = F.log_softmax(
            logits_student / temperature - 1000.0 * gt_mask, dim=1
        )
        nckd_loss = F.kl_div(
            log_pred_student_part2, pred_teacher_part2, reduction='sum'
        ) * (temperature ** 2) / target.shape[0]

        # Weighted sum
        loss = alpha * tckd_loss + beta * nckd_loss
        return loss


In [19]:
import torch
import torch.nn.functional as F

def compute_gradient_importance(
    teacher, student, data_loader, device, temperature=4.0, alpha=0.5, beta_prob=0.5, accumulation_epochs=3
):

    # Initialize DKD loss
    dkd_loss_fn = DKDloss()

    importance_scores = {}
    for name, param in student.named_parameters():
        if 'weight' in name and len(param.shape) == 4:  # Conv weights only
            importance_scores[name] = torch.zeros_like(param.data, device=device)

    teacher.to(device).eval()
    student.to(device).train()

    momentum = 0.9
    accumulated_batches = 0

    for epoch in range(accumulation_epochs):
        print(f"Accumulation Epoch {epoch+1}/{accumulation_epochs}")
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            student.zero_grad()

            with torch.no_grad():
                teacher_logits = teacher(inputs)

            student_logits = student(inputs)

            # Compute DKD loss + optional CE
            distillation_loss = dkd_loss_fn(student_logits, teacher_logits, labels, alpha, beta_prob, temperature)
            ce_loss = F.cross_entropy(student_logits, labels)
            loss = alpha * distillation_loss + (1 - alpha) * ce_loss

            loss.backward()

            accumulated_batches += 1
            for name, param in student.named_parameters():
                if name in importance_scores and param.grad is not None:
                    grad_product = (param.data * param.grad).abs_()
                    if accumulated_batches == 1:
                        importance_scores[name] = grad_product
                    else:
                        importance_scores[name] = momentum * importance_scores[name] + (1 - momentum) * grad_product

    for name in importance_scores:
        importance_scores[name] /= (1 - momentum ** accumulated_batches)

    return importance_scores


In [20]:
def gradient_based_global_prune(model, importance_scores, prune_ratio=0.95):
    all_scores = torch.cat([score.flatten() for score in importance_scores.values()])
    threshold = torch.topk(all_scores, k=int(prune_ratio * all_scores.numel()), largest=False)[0][-1]

    for name, param in model.named_parameters():
        if name in importance_scores:
            mask = (importance_scores[name] > threshold).float()
            param.data.mul_(mask)

    return model


In [21]:
import torch
import torch.nn.functional as F
import torch.optim as optim

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

def retrain_with_sparsity(student, train_loader, val_loader, epochs=5, save_path="retrained_student_model.pt", patience=3):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = optim.SGD(student.parameters(), lr=0.01, momentum=0.9)

    # 1. Store masks AND zero momentum buffers for pruned weights
    masks = {}
    for name, param in student.named_parameters():
        if 'weight' in name and param.dim() == 4:  # Consider only conv layers
            mask = (param != 0).float().to(device)
            masks[name] = mask
            # Zero momentum buffers for pruned weights
            if optimizer.state.get(param, None) and 'momentum_buffer' in optimizer.state[param]:
                optimizer.state[param]['momentum_buffer'] *= mask

    student = student.to(device)
    best_val_acc = 0.0
    best_model = None
    patience_counter = 0  # Counter for early stopping

    # 2. Add gradient clipping to prevent NaN
    max_grad_norm = 1.0

    for epoch in range(epochs):
        student.train()
        total_loss = 0.0
        correct, total = 0, 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = student(inputs)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()

            # Apply masks to gradients
            for name, param in student.named_parameters():
                if name in masks:
                    param.grad.data *= masks[name]

            # Gradient clipping before optimizer step
            torch.nn.utils.clip_grad_norm_(student.parameters(), max_grad_norm)

            optimizer.step()

            # Reapply masks and update momentum buffers
            for name, param in student.named_parameters():
                if name in masks:
                    param.data *= masks[name]
                    if optimizer.state.get(param, None) and 'momentum_buffer' in optimizer.state[param]:
                        optimizer.state[param]['momentum_buffer'] *= masks[name]

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

        train_loss = total_loss / len(train_loader)
        train_acc = 100.0 * correct / total

        # Validation phase
        student.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = student(inputs)
                loss = F.cross_entropy(outputs, labels)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_correct += predicted.eq(labels).sum().item()
                val_total += labels.size(0)

        val_loss /= len(val_loader)
        val_acc = 100.0 * val_correct / val_total

        # Track best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model = student.state_dict()
            torch.save(best_model, save_path)
            patience_counter = 0  # Reset patience counter
            print(f"New best model saved with Val Accuracy: {best_val_acc:.2f}%")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}. No improvement for {patience} epochs.")
                break  # Stop training

        # Print results
        sparsity = calculate_sparsity(student)
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Validation Loss: {val_loss:.4f} | Validation Acc: {val_acc:.2f}% | Sparsity: {sparsity*100:.2f}%\n")

    print(f"Best Validation Accuracy: {best_val_acc:.2f}% | Best Model Saved at: {save_path}")
    return student

In [22]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import time

# KD training with DKD loss and mask-based momentum handling
def retrain_with_KD(teacher, student, train_loader, val_loader, epochs=50,
                    temperature=5.0, alpha=0.5, beta_prob=0.5, patience=5,
                    save_path="student_before_pruning.pth"):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = optim.SGD(student.parameters(), lr=0.01, momentum=0.9)

    # Initialize DKD loss function
    dkd_loss_fn = DKDloss()

    # Store masks and zero momentum buffers
    masks = {}
    for name, param in student.named_parameters():
        if 'weight' in name and param.dim() == 4:
            mask = (param != 0).float().to(device)
            masks[name] = mask
            if optimizer.state.get(param, None) and 'momentum_buffer' in optimizer.state[param]:
                optimizer.state[param]['momentum_buffer'] *= mask

    teacher = teacher.to(device).eval()
    student = student.to(device)

    best_val_acc = 0.0
    best_model_state = None
    patience_counter = 0
    start_time = time.time()

    for epoch in range(epochs):
        student.train()
        total_loss, correct, total = 0.0, 0, 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            with torch.no_grad():
                teacher_logits = teacher(inputs)

            student_logits = student(inputs)

            # DKD loss
            kd_loss = dkd_loss_fn(student_logits, teacher_logits, labels, alpha, beta_prob, temperature)
            ce_loss = F.cross_entropy(student_logits, labels)

            # Total loss (optional to keep CE blended with KD)
            loss = alpha * kd_loss + (1 - alpha) * ce_loss

            loss.backward()
            optimizer.step()

            # Reapply masks and update momentum
            for name, param in student.named_parameters():
                if name in masks:
                    param.data *= masks[name]
                    if optimizer.state.get(param, None) and 'momentum_buffer' in optimizer.state[param]:
                        optimizer.state[param]['momentum_buffer'] *= masks[name]

            total_loss += loss.item()
            _, predicted = student_logits.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

        train_loss = total_loss / len(train_loader)
        train_acc = 100.0 * correct / total

        # Validation
        student.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = student(inputs)
                loss = F.cross_entropy(outputs, labels)
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_correct += predicted.eq(labels).sum().item()
                val_total += labels.size(0)

        val_loss /= len(val_loader)
        val_acc = 100.0 * val_correct / val_total
        sparsity = calculate_sparsity(student) * 100.0  # Assuming this function is defined

        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Sparsity: {sparsity:.2f}%")

        # Early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = student.state_dict()
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}. No improvement for {patience} epochs.")
                break

    # Save best model
    student.load_state_dict(best_model_state)
    torch.save(student.state_dict(), save_path)
    print(f"Student model saved before pruning at: {save_path}")
    total_time = time.time() - start_time
    print(f"Total Training Time: {total_time // 60:.0f}m {total_time % 60:.0f}s")

    return student


In [23]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Training function with KD + CA-KLD and logits normalization
def train_kd_pruning(teacher, student, train_loader, val_loader, epochs=50, temperature=5.0, alpha=0.5,
                     beta_prob=0.5, patience=5, save_path="student_before_pruning.pth"):
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = optim.SGD(student.parameters(), lr=0.01, momentum=0.9)

    teacher = teacher.to(device)
    student = student.to(device)
    teacher.eval()  # Freeze teacher

    best_val_acc = 0.0
    best_model_state = None
    patience_counter = 0
    start_time = time.time()

    for epoch in range(epochs):
        student.train()
        total_loss = 0.0
        correct, total = 0, 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            with torch.no_grad():
                teacher_logits = teacher(inputs)

            student_logits = student(inputs)

            # Temperature scaling
            teacher_logits_temp = teacher_logits / temperature
            student_logits_temp = student_logits / temperature

            # Logits normalization
            teacher_logits_temp = normalize(teacher_logits_temp)
            student_logits_temp = normalize(student_logits_temp)

            # CA-KLD loss (normalized logits)
            distillation_loss = cakld_loss(student_logits_temp, teacher_logits_temp, beta_prob) * (temperature ** 2)

            # Cross-entropy loss
            ground_truth_loss = F.cross_entropy(student_logits, labels)

            # Combined loss
            loss = alpha * distillation_loss + (1 - alpha) * ground_truth_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = student_logits.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

        train_loss = total_loss / len(train_loader)
        train_acc = 100.0 * correct / total

        # Validation accuracy
        val_acc = evaluate(student, val_loader, device)

        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | "
              f"Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")

        # Early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = student.state_dict()
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}. No improvement for {patience} epochs.")
                break

    # Load best model state and save
    student.load_state_dict(best_model_state)
    torch.save(student.state_dict(), save_path)
    print(f"Student model saved before pruning at: {save_path}")

    total_time = time.time() - start_time
    print(f"Total Training Time: {total_time // 60:.0f}m {total_time % 60:.0f}s")

    return student

In [21]:

student = train_kd_pruning(
    teacher, student, train_loader, val_loader,
    epochs=50, temperature=5.0, alpha=0.5,beta_prob=0.5, patience=5,save_path="student_before_pruning.pth"
)


Epoch 1/50 | Train Loss: 2.6007 | Train Acc: 83.86% | Val Acc: 88.68%
Epoch 2/50 | Train Loss: 0.8082 | Train Acc: 94.44% | Val Acc: 92.16%
Epoch 3/50 | Train Loss: 0.4497 | Train Acc: 97.41% | Val Acc: 94.04%
Epoch 4/50 | Train Loss: 0.2881 | Train Acc: 98.71% | Val Acc: 94.84%
Epoch 5/50 | Train Loss: 0.2302 | Train Acc: 99.09% | Val Acc: 95.51%
Epoch 6/50 | Train Loss: 0.1968 | Train Acc: 99.22% | Val Acc: 95.55%
Epoch 7/50 | Train Loss: 0.1769 | Train Acc: 99.28% | Val Acc: 95.57%
Epoch 8/50 | Train Loss: 0.1595 | Train Acc: 99.32% | Val Acc: 95.84%
Epoch 9/50 | Train Loss: 0.1477 | Train Acc: 99.35% | Val Acc: 96.00%
Epoch 10/50 | Train Loss: 0.1395 | Train Acc: 99.33% | Val Acc: 95.79%
Epoch 11/50 | Train Loss: 0.1308 | Train Acc: 99.36% | Val Acc: 96.12%
Epoch 12/50 | Train Loss: 0.1276 | Train Acc: 99.30% | Val Acc: 95.96%
Epoch 13/50 | Train Loss: 0.1185 | Train Acc: 99.39% | Val Acc: 96.05%
Epoch 14/50 | Train Loss: 0.1156 | Train Acc: 99.32% | Val Acc: 96.18%
Epoch 15/50 | T

In [24]:
# Calculate sparsity
sparsity = calculate_sparsity(student)
print(f"Sparsity Before Pruning: {sparsity * 100:.2f}%")

teacher_accuracy = evaluate(teacher, test_loader, device)
student_accuracy = evaluate(student, test_loader, device)
print(f"Teacher Model Test Accuracy: {teacher_accuracy:.2f}%")
print(f"Student Model Test Accuracy Before Pruning: {student_accuracy:.2f}%")

Sparsity Before Pruning: 0.00%
Teacher Model Test Accuracy: 95.41%
Student Model Test Accuracy Before Pruning: 95.92%


## 36% Sparsity

In [25]:

model_path = '/kaggle/input/studentc10/pytorch/default/1/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [26]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=3.0, alpha=0.7,beta_prob=0.5, accumulation_epochs=3
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()
pruned_student = gradient_based_global_prune(student, importance_scores, prune_ratio=0.3608)
total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")
student = student.to(device)


Calculating Important Scores
Accumulation Epoch 1/3
Accumulation Epoch 2/3
Accumulation Epoch 3/3
Total Time take to calculate Important scores: 6m 12s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [27]:

start_time = time.time()
pruned_student = retrain_with_KD(
    teacher, pruned_student, train_loader, val_loader,
    epochs=50, temperature=3.0, alpha=0.7, beta_prob=0.5,patience=5,save_path="pruned_student_retrain_KD_90%.pth"
)
end_time = time.time()
elapsed_time = end_time - start_time

print(f"Retraining completed in {elapsed_time / 60:.2f} minutes ({elapsed_time:.2f} seconds)")

Epoch 1/50 | Train Loss: 0.3257 | Train Acc: 98.62% | Val Loss: 0.0408 | Val Acc: 98.66% | Sparsity: 36.05%
Epoch 2/50 | Train Loss: 0.2050 | Train Acc: 99.22% | Val Loss: 0.0391 | Val Acc: 98.75% | Sparsity: 36.05%
Epoch 3/50 | Train Loss: 0.1685 | Train Acc: 99.36% | Val Loss: 0.0393 | Val Acc: 98.70% | Sparsity: 36.05%
Epoch 4/50 | Train Loss: 0.1478 | Train Acc: 99.38% | Val Loss: 0.0410 | Val Acc: 98.64% | Sparsity: 36.05%
Epoch 5/50 | Train Loss: 0.1356 | Train Acc: 99.42% | Val Loss: 0.0378 | Val Acc: 98.78% | Sparsity: 36.05%
Epoch 6/50 | Train Loss: 0.1186 | Train Acc: 99.47% | Val Loss: 0.0377 | Val Acc: 98.73% | Sparsity: 36.05%
Epoch 7/50 | Train Loss: 0.1135 | Train Acc: 99.49% | Val Loss: 0.0408 | Val Acc: 98.62% | Sparsity: 36.05%
Epoch 8/50 | Train Loss: 0.1037 | Train Acc: 99.49% | Val Loss: 0.0390 | Val Acc: 98.75% | Sparsity: 36.05%
Epoch 9/50 | Train Loss: 0.0996 | Train Acc: 99.49% | Val Loss: 0.0388 | Val Acc: 98.71% | Sparsity: 36.05%
Epoch 10/50 | Train Loss: 0.

In [28]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f"Pruned Student Model Test Accuracy(After Retrain): {student_accuracy:.2f}%")

Pruned Student Model Test Accuracy(After Retrain): 95.92%


In [28]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f"Pruned Student Model Test Accuracy(After Retrain): {student_accuracy:.2f}%")

Pruned Student Model Test Accuracy(After Retrain): 95.88%


In [29]:

model_path = '/kaggle/input/studentc10/pytorch/default/1/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [30]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=5.0, alpha=0.7,beta_prob=0.5, accumulation_epochs=3
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()
pruned_student = gradient_based_global_prune(student, importance_scores, prune_ratio=0.3608)
total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")
student = student.to(device)


Calculating Important Scores
Accumulation Epoch 1/3
Accumulation Epoch 2/3
Accumulation Epoch 3/3
Total Time take to calculate Important scores: 6m 13s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [31]:

start_time = time.time()
pruned_student = retrain_with_KD(
    teacher, pruned_student, train_loader, val_loader,
    epochs=50, temperature=5.0, alpha=0.7, beta_prob=0.5,patience=5,save_path="pruned_student_retrain_KD_90%.pth"
)
end_time = time.time()
elapsed_time = end_time - start_time

print(f"Retraining completed in {elapsed_time / 60:.2f} minutes ({elapsed_time:.2f} seconds)")

Epoch 1/50 | Train Loss: 0.3343 | Train Acc: 98.32% | Val Loss: 0.0463 | Val Acc: 98.57% | Sparsity: 36.05%
Epoch 2/50 | Train Loss: 0.2191 | Train Acc: 99.07% | Val Loss: 0.0410 | Val Acc: 98.69% | Sparsity: 36.05%
Epoch 3/50 | Train Loss: 0.1704 | Train Acc: 99.10% | Val Loss: 0.0371 | Val Acc: 98.85% | Sparsity: 36.05%
Epoch 4/50 | Train Loss: 0.1479 | Train Acc: 99.17% | Val Loss: 0.0353 | Val Acc: 98.91% | Sparsity: 36.05%
Epoch 5/50 | Train Loss: 0.1356 | Train Acc: 99.12% | Val Loss: 0.0339 | Val Acc: 98.91% | Sparsity: 36.05%
Epoch 6/50 | Train Loss: 0.1266 | Train Acc: 99.21% | Val Loss: 0.0336 | Val Acc: 98.92% | Sparsity: 36.05%
Epoch 7/50 | Train Loss: 0.1181 | Train Acc: 99.21% | Val Loss: 0.0334 | Val Acc: 98.88% | Sparsity: 36.05%
Epoch 8/50 | Train Loss: 0.1157 | Train Acc: 99.12% | Val Loss: 0.0346 | Val Acc: 98.88% | Sparsity: 36.05%
Epoch 9/50 | Train Loss: 0.1079 | Train Acc: 99.19% | Val Loss: 0.0339 | Val Acc: 98.92% | Sparsity: 36.05%
Epoch 10/50 | Train Loss: 0.

In [32]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f"Pruned Student Model Test Accuracy(After Retrain): {student_accuracy:.2f}%")

Pruned Student Model Test Accuracy(After Retrain): 95.99%


# 59% of Sparsity

In [33]:

model_path = '/kaggle/input/studentc10/pytorch/default/1/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [34]:
student_accuracy = evaluate(student, test_loader, device)
print(f"Student Model Test Accuracy: {student_accuracy:.2f}%")

Student Model Test Accuracy: 95.92%


In [35]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=5.0, alpha=0.7, accumulation_epochs=3
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()

pruned_student = gradient_based_global_prune(student, importance_scores, prune_ratio=0.5909)

total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")
student = student.to(device)


Calculating Important Scores
Accumulation Epoch 1/3
Accumulation Epoch 2/3
Accumulation Epoch 3/3
Total Time take to calculate Important scores: 6m 13s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [36]:

start_time = time.time()
pruned_student = retrain_with_KD(
    teacher, pruned_student, train_loader, val_loader,
    epochs=50, temperature=5.0, alpha=0.7, beta_prob=0.5,patience=7,save_path="pruned_student_retrain_KD_59%.pth"
)
end_time = time.time()
elapsed_time = end_time - start_time

print(f"Retraining completed in {elapsed_time / 60:.2f} minutes ({elapsed_time:.2f} seconds)")

Epoch 1/50 | Train Loss: 0.3888 | Train Acc: 98.21% | Val Loss: 0.0616 | Val Acc: 98.08% | Sparsity: 59.04%
Epoch 2/50 | Train Loss: 0.2468 | Train Acc: 98.93% | Val Loss: 0.0458 | Val Acc: 98.59% | Sparsity: 59.04%
Epoch 3/50 | Train Loss: 0.1878 | Train Acc: 99.10% | Val Loss: 0.0359 | Val Acc: 98.88% | Sparsity: 59.04%
Epoch 4/50 | Train Loss: 0.1631 | Train Acc: 99.11% | Val Loss: 0.0359 | Val Acc: 98.84% | Sparsity: 59.04%
Epoch 5/50 | Train Loss: 0.1499 | Train Acc: 99.16% | Val Loss: 0.0380 | Val Acc: 98.79% | Sparsity: 59.04%
Epoch 6/50 | Train Loss: 0.1387 | Train Acc: 99.15% | Val Loss: 0.0358 | Val Acc: 98.86% | Sparsity: 59.04%
Epoch 7/50 | Train Loss: 0.1306 | Train Acc: 99.19% | Val Loss: 0.0385 | Val Acc: 98.83% | Sparsity: 59.04%
Epoch 8/50 | Train Loss: 0.1266 | Train Acc: 99.15% | Val Loss: 0.0346 | Val Acc: 98.92% | Sparsity: 59.04%
Epoch 9/50 | Train Loss: 0.1209 | Train Acc: 99.12% | Val Loss: 0.0363 | Val Acc: 98.77% | Sparsity: 59.04%
Epoch 10/50 | Train Loss: 0.

In [37]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f"Pruned Student Model Test Accuracy(Retrain with KD): {student_accuracy:.2f}%")

Pruned Student Model Test Accuracy(Retrain with KD): 95.74%


In [40]:

model_path = '/kaggle/input/studentc10/pytorch/default/1/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [41]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=3.0, alpha=0.7, accumulation_epochs=3
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()

pruned_student = gradient_based_global_prune(student, importance_scores, prune_ratio=0.5909)

total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")
student = student.to(device)


Calculating Important Scores
Accumulation Epoch 1/3
Accumulation Epoch 2/3
Accumulation Epoch 3/3
Total Time take to calculate Important scores: 6m 13s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [42]:

start_time = time.time()
pruned_student = retrain_with_KD(
    teacher, pruned_student, train_loader, val_loader,
    epochs=50, temperature=3.0, alpha=0.7, beta_prob=0.5,patience=5,save_path="pruned_student_retrain_KD_59%.pth"
)
end_time = time.time()
elapsed_time = end_time - start_time

print(f"Retraining completed in {elapsed_time / 60:.2f} minutes ({elapsed_time:.2f} seconds)")

Epoch 1/50 | Train Loss: 0.1116 | Train Acc: 98.67% | Val Loss: 0.0344 | Val Acc: 98.82% | Sparsity: 59.04%
Epoch 2/50 | Train Loss: 0.0722 | Train Acc: 99.25% | Val Loss: 0.0335 | Val Acc: 98.86% | Sparsity: 59.04%
Epoch 3/50 | Train Loss: 0.0621 | Train Acc: 99.37% | Val Loss: 0.0328 | Val Acc: 98.91% | Sparsity: 59.04%
Epoch 4/50 | Train Loss: 0.0568 | Train Acc: 99.38% | Val Loss: 0.0329 | Val Acc: 98.84% | Sparsity: 59.04%
Epoch 5/50 | Train Loss: 0.0531 | Train Acc: 99.42% | Val Loss: 0.0315 | Val Acc: 98.86% | Sparsity: 59.04%
Epoch 6/50 | Train Loss: 0.0512 | Train Acc: 99.44% | Val Loss: 0.0332 | Val Acc: 98.86% | Sparsity: 59.04%
Epoch 7/50 | Train Loss: 0.0493 | Train Acc: 99.40% | Val Loss: 0.0319 | Val Acc: 98.87% | Sparsity: 59.04%
Epoch 8/50 | Train Loss: 0.0472 | Train Acc: 99.41% | Val Loss: 0.0328 | Val Acc: 98.84% | Sparsity: 59.04%
Early stopping triggered at epoch 8. No improvement for 5 epochs.
Student model saved before pruning at: pruned_student_retrain_KD_59%.p

In [43]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f"Pruned Student Model Test Accuracy(Retrain with KD): {student_accuracy:.2f}%")

Pruned Student Model Test Accuracy(Retrain with KD): 95.86%


# 79% Sparsity

In [50]:

model_path = '/kaggle/input/studentc10/pytorch/default/1/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [51]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=5.0, alpha=0.7, accumulation_epochs=3
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()

pruned_student = gradient_based_global_prune(student, importance_scores, prune_ratio=0.7907)

total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")
student = student.to(device)


Calculating Important Scores
Accumulation Epoch 1/3
Accumulation Epoch 2/3
Accumulation Epoch 3/3
Total Time take to calculate Important scores: 6m 13s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [52]:

start_time = time.time()
pruned_student = retrain_with_KD(
    teacher, pruned_student, train_loader, val_loader,
    epochs=50, temperature=5.0, alpha=0.7, beta_prob=0.5,patience=5,save_path="pruned_student_retrain_KD_90%.pth"
)
end_time = time.time()
elapsed_time = end_time - start_time

print(f"Retraining completed in {elapsed_time / 60:.2f} minutes ({elapsed_time:.2f} seconds)")

Epoch 1/50 | Train Loss: 0.6026 | Train Acc: 97.14% | Val Loss: 0.1608 | Val Acc: 95.19% | Sparsity: 79.00%
Epoch 2/50 | Train Loss: 0.4047 | Train Acc: 98.20% | Val Loss: 0.0835 | Val Acc: 97.56% | Sparsity: 79.00%
Epoch 3/50 | Train Loss: 0.2879 | Train Acc: 98.81% | Val Loss: 0.0574 | Val Acc: 98.22% | Sparsity: 79.00%
Epoch 4/50 | Train Loss: 0.2247 | Train Acc: 99.06% | Val Loss: 0.0547 | Val Acc: 98.17% | Sparsity: 79.00%
Epoch 5/50 | Train Loss: 0.2027 | Train Acc: 99.11% | Val Loss: 0.0532 | Val Acc: 98.29% | Sparsity: 79.00%
Epoch 6/50 | Train Loss: 0.1794 | Train Acc: 99.12% | Val Loss: 0.0483 | Val Acc: 98.47% | Sparsity: 79.00%
Epoch 7/50 | Train Loss: 0.1676 | Train Acc: 99.16% | Val Loss: 0.0471 | Val Acc: 98.57% | Sparsity: 79.00%
Epoch 8/50 | Train Loss: 0.1598 | Train Acc: 99.14% | Val Loss: 0.0526 | Val Acc: 98.41% | Sparsity: 79.00%
Epoch 9/50 | Train Loss: 0.1503 | Train Acc: 99.13% | Val Loss: 0.0490 | Val Acc: 98.39% | Sparsity: 79.00%
Epoch 10/50 | Train Loss: 0.

In [53]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f"Pruned Student Model Test Accuracy(Retrain with KD): {student_accuracy:.2f}%")

Pruned Student Model Test Accuracy(Retrain with KD): 95.43%


In [54]:

model_path = '/kaggle/input/studentc10/pytorch/default/1/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [55]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=3.0, alpha=0.7, accumulation_epochs=3
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()

pruned_student = gradient_based_global_prune(student, importance_scores, prune_ratio=0.7907)

total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")
student = student.to(device)


Calculating Important Scores
Accumulation Epoch 1/3
Accumulation Epoch 2/3
Accumulation Epoch 3/3
Total Time take to calculate Important scores: 6m 14s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [56]:

start_time = time.time()
pruned_student = retrain_with_KD(
    teacher, pruned_student, train_loader, val_loader,
    epochs=50, temperature=3.0, alpha=0.7, beta_prob=0.5,patience=5,save_path="pruned_student_retrain_KD_90%.pth"
)
end_time = time.time()
elapsed_time = end_time - start_time

print(f"Retraining completed in {elapsed_time / 60:.2f} minutes ({elapsed_time:.2f} seconds)")

Epoch 1/50 | Train Loss: 0.1642 | Train Acc: 98.21% | Val Loss: 0.0512 | Val Acc: 98.47% | Sparsity: 79.00%
Epoch 2/50 | Train Loss: 0.1008 | Train Acc: 99.06% | Val Loss: 0.0448 | Val Acc: 98.56% | Sparsity: 79.00%
Epoch 3/50 | Train Loss: 0.0844 | Train Acc: 99.25% | Val Loss: 0.0434 | Val Acc: 98.66% | Sparsity: 79.00%
Epoch 4/50 | Train Loss: 0.0736 | Train Acc: 99.34% | Val Loss: 0.0422 | Val Acc: 98.60% | Sparsity: 79.00%
Epoch 5/50 | Train Loss: 0.0694 | Train Acc: 99.34% | Val Loss: 0.0406 | Val Acc: 98.59% | Sparsity: 79.00%
Epoch 6/50 | Train Loss: 0.0656 | Train Acc: 99.33% | Val Loss: 0.0411 | Val Acc: 98.59% | Sparsity: 79.00%
Epoch 7/50 | Train Loss: 0.0617 | Train Acc: 99.42% | Val Loss: 0.0411 | Val Acc: 98.63% | Sparsity: 79.00%
Epoch 8/50 | Train Loss: 0.0597 | Train Acc: 99.42% | Val Loss: 0.0410 | Val Acc: 98.65% | Sparsity: 79.00%
Early stopping triggered at epoch 8. No improvement for 5 epochs.
Student model saved before pruning at: pruned_student_retrain_KD_90%.p

In [57]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f"Pruned Student Model Test Accuracy(Retrain with KD): {student_accuracy:.2f}%")

Pruned Student Model Test Accuracy(Retrain with KD): 95.51%


# 90% of Sparsity

In [68]:

model_path = '/kaggle/input/studentc10/pytorch/default/1/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [69]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=5.0, alpha=0.7, accumulation_epochs=3
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()

pruned_student = gradient_based_global_prune(student, importance_scores, prune_ratio=0.9008)

total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")
student = student.to(device)


Calculating Important Scores
Accumulation Epoch 1/3
Accumulation Epoch 2/3
Accumulation Epoch 3/3
Total Time take to calculate Important scores: 6m 14s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [70]:

start_time = time.time()
pruned_student = retrain_with_KD(
    teacher, pruned_student, train_loader, val_loader,
    epochs=50, temperature=5.0, alpha=0.7, beta_prob=0.5,patience=5,save_path="pruned_student_retrain_KD_90%.pth"
)
end_time = time.time()
elapsed_time = end_time - start_time

print(f"Retraining completed in {elapsed_time / 60:.2f} minutes ({elapsed_time:.2f} seconds)")

Epoch 1/50 | Train Loss: 1.3333 | Train Acc: 92.86% | Val Loss: 0.2295 | Val Acc: 93.52% | Sparsity: 90.00%
Epoch 2/50 | Train Loss: 0.6924 | Train Acc: 96.58% | Val Loss: 0.1507 | Val Acc: 95.26% | Sparsity: 90.00%
Epoch 3/50 | Train Loss: 0.4653 | Train Acc: 97.99% | Val Loss: 0.1247 | Val Acc: 96.37% | Sparsity: 90.00%
Epoch 4/50 | Train Loss: 0.3618 | Train Acc: 98.66% | Val Loss: 0.1529 | Val Acc: 95.47% | Sparsity: 90.00%
Epoch 5/50 | Train Loss: 0.2994 | Train Acc: 98.92% | Val Loss: 0.0997 | Val Acc: 96.76% | Sparsity: 90.00%
Epoch 6/50 | Train Loss: 0.2598 | Train Acc: 98.99% | Val Loss: 0.0973 | Val Acc: 96.98% | Sparsity: 90.00%
Epoch 7/50 | Train Loss: 0.2344 | Train Acc: 99.10% | Val Loss: 0.0897 | Val Acc: 97.05% | Sparsity: 90.00%
Epoch 8/50 | Train Loss: 0.2190 | Train Acc: 99.10% | Val Loss: 0.0905 | Val Acc: 97.17% | Sparsity: 90.00%
Epoch 9/50 | Train Loss: 0.2068 | Train Acc: 99.14% | Val Loss: 0.0867 | Val Acc: 97.26% | Sparsity: 90.00%
Epoch 10/50 | Train Loss: 0.

In [71]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f" Retrained Pruned Student Model Test Accuracy: {student_accuracy:.2f}%")

 Retrained Pruned Student Model Test Accuracy: 94.83%


In [72]:

model_path = '/kaggle/input/studentc10/pytorch/default/1/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [73]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=3.0, alpha=0.7, accumulation_epochs=3
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()

pruned_student = gradient_based_global_prune(student, importance_scores, prune_ratio=0.9008)

total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")



Calculating Important Scores
Accumulation Epoch 1/3
Accumulation Epoch 2/3
Accumulation Epoch 3/3
Total Time take to calculate Important scores: 6m 13s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [74]:

start_time = time.time()
pruned_student = retrain_with_KD(
    teacher, pruned_student, train_loader, val_loader,
    epochs=50, temperature=3.0, alpha=0.7, beta_prob=0.5,patience=5,save_path="pruned_student_retrain_KD_90%.pth"
)
end_time = time.time()
elapsed_time = end_time - start_time

print(f"Retraining completed in {elapsed_time / 60:.2f} minutes ({elapsed_time:.2f} seconds)")

Epoch 1/50 | Train Loss: 0.3225 | Train Acc: 95.89% | Val Loss: 0.1079 | Val Acc: 96.82% | Sparsity: 90.00%
Epoch 2/50 | Train Loss: 0.1664 | Train Acc: 98.38% | Val Loss: 0.0850 | Val Acc: 97.34% | Sparsity: 90.00%
Epoch 3/50 | Train Loss: 0.1255 | Train Acc: 98.95% | Val Loss: 0.0721 | Val Acc: 97.82% | Sparsity: 90.00%
Epoch 4/50 | Train Loss: 0.1072 | Train Acc: 99.17% | Val Loss: 0.0723 | Val Acc: 97.80% | Sparsity: 90.00%
Epoch 5/50 | Train Loss: 0.0966 | Train Acc: 99.25% | Val Loss: 0.0694 | Val Acc: 97.91% | Sparsity: 90.00%
Epoch 6/50 | Train Loss: 0.0908 | Train Acc: 99.24% | Val Loss: 0.0664 | Val Acc: 97.99% | Sparsity: 90.00%
Epoch 7/50 | Train Loss: 0.0858 | Train Acc: 99.33% | Val Loss: 0.0679 | Val Acc: 98.04% | Sparsity: 90.00%
Epoch 8/50 | Train Loss: 0.0811 | Train Acc: 99.28% | Val Loss: 0.0665 | Val Acc: 97.95% | Sparsity: 90.00%
Epoch 9/50 | Train Loss: 0.0772 | Train Acc: 99.33% | Val Loss: 0.0645 | Val Acc: 97.96% | Sparsity: 90.00%
Epoch 10/50 | Train Loss: 0.

In [75]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f" Retrained Pruned Student Model Test Accuracy: {student_accuracy:.2f}%")

 Retrained Pruned Student Model Test Accuracy: 94.62%


## 95% Sparsity

In [None]:
sparsity = calculate_sparsity(pruned_student)
print(f"Sparsity After Pruning: {sparsity * 100:.2f}%")

In [22]:

model_path = '/kaggle/input/studentc10/pytorch/default/1/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [23]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=5.0, alpha=0.7, accumulation_epochs=3
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()

pruned_student = gradient_based_global_prune(student, importance_scores, prune_ratio=0.95096)

total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")



Calculating Important Scores
Accumulation Epoch 1/3
Accumulation Epoch 2/3
Accumulation Epoch 3/3
Total Time take to calculate Important scores: 6m 13s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [24]:
sparsity = calculate_sparsity(pruned_student)
print(f"Sparsity After Pruning: {sparsity * 100:.2f}%")

Sparsity After Pruning: 95.01%


In [25]:

start_time = time.time()
pruned_student = retrain_with_KD(
    teacher, pruned_student, train_loader, val_loader,
    epochs=50, temperature=5.0, alpha=0.7, beta_prob=0.5,patience=5,save_path="pruned_student_retrain_KD_90%.pth"
)
end_time = time.time()
elapsed_time = end_time - start_time

print(f"Retraining completed in {elapsed_time / 60:.2f} minutes ({elapsed_time:.2f} seconds)")

Epoch 1/50 | Train Loss: 2.2595 | Train Acc: 87.98% | Val Loss: 0.3583 | Val Acc: 90.33% | Sparsity: 95.01%
Epoch 2/50 | Train Loss: 1.1815 | Train Acc: 93.55% | Val Loss: 0.2888 | Val Acc: 92.66% | Sparsity: 95.01%
Epoch 3/50 | Train Loss: 0.8662 | Train Acc: 95.48% | Val Loss: 0.2892 | Val Acc: 92.74% | Sparsity: 95.01%
Epoch 4/50 | Train Loss: 0.6655 | Train Acc: 96.81% | Val Loss: 0.2717 | Val Acc: 93.29% | Sparsity: 95.01%
Epoch 5/50 | Train Loss: 0.5171 | Train Acc: 97.92% | Val Loss: 0.2075 | Val Acc: 94.24% | Sparsity: 95.01%
Epoch 6/50 | Train Loss: 0.4235 | Train Acc: 98.37% | Val Loss: 0.1786 | Val Acc: 95.17% | Sparsity: 95.01%
Epoch 7/50 | Train Loss: 0.3598 | Train Acc: 98.69% | Val Loss: 0.1761 | Val Acc: 95.11% | Sparsity: 95.01%
Epoch 8/50 | Train Loss: 0.3150 | Train Acc: 98.91% | Val Loss: 0.1653 | Val Acc: 95.26% | Sparsity: 95.01%
Epoch 9/50 | Train Loss: 0.2926 | Train Acc: 98.99% | Val Loss: 0.1698 | Val Acc: 95.35% | Sparsity: 95.01%
Epoch 10/50 | Train Loss: 0.

In [26]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f" Retrained Pruned Student Model Test Accuracy: {student_accuracy:.2f}%")

 Retrained Pruned Student Model Test Accuracy: 94.34%


In [27]:

model_path = '/kaggle/input/studentc10/pytorch/default/1/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [28]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=3.0, alpha=0.7, accumulation_epochs=3
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()

pruned_student = gradient_based_global_prune(student, importance_scores, prune_ratio=0.95096)

total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")



Calculating Important Scores
Accumulation Epoch 1/3
Accumulation Epoch 2/3
Accumulation Epoch 3/3
Total Time take to calculate Important scores: 6m 12s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [29]:

start_time = time.time()
pruned_student = retrain_with_KD(
    teacher, pruned_student, train_loader, val_loader,
    epochs=50, temperature=3.0, alpha=0.7, beta_prob=0.5,patience=5,save_path="pruned_student_retrain_KD_90%.pth"
)
end_time = time.time()
elapsed_time = end_time - start_time

print(f"Retraining completed in {elapsed_time / 60:.2f} minutes ({elapsed_time:.2f} seconds)")

Epoch 1/50 | Train Loss: 0.6771 | Train Acc: 90.86% | Val Loss: 0.1898 | Val Acc: 94.28% | Sparsity: 95.01%
Epoch 2/50 | Train Loss: 0.3212 | Train Acc: 95.88% | Val Loss: 0.1843 | Val Acc: 94.60% | Sparsity: 95.01%
Epoch 3/50 | Train Loss: 0.2287 | Train Acc: 97.63% | Val Loss: 0.1554 | Val Acc: 95.30% | Sparsity: 95.01%
Epoch 4/50 | Train Loss: 0.1770 | Train Acc: 98.34% | Val Loss: 0.1448 | Val Acc: 95.48% | Sparsity: 95.01%
Epoch 5/50 | Train Loss: 0.1452 | Train Acc: 98.78% | Val Loss: 0.1452 | Val Acc: 95.65% | Sparsity: 95.01%
Epoch 6/50 | Train Loss: 0.1301 | Train Acc: 99.06% | Val Loss: 0.1384 | Val Acc: 95.84% | Sparsity: 95.01%
Epoch 7/50 | Train Loss: 0.1184 | Train Acc: 99.13% | Val Loss: 0.1346 | Val Acc: 95.88% | Sparsity: 95.01%
Epoch 8/50 | Train Loss: 0.1118 | Train Acc: 99.17% | Val Loss: 0.1289 | Val Acc: 96.04% | Sparsity: 95.01%
Epoch 9/50 | Train Loss: 0.1041 | Train Acc: 99.21% | Val Loss: 0.1264 | Val Acc: 95.95% | Sparsity: 95.01%
Epoch 10/50 | Train Loss: 0.

In [30]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f" Retrained Pruned Student Model Test Accuracy: {student_accuracy:.2f}%")

 Retrained Pruned Student Model Test Accuracy: 93.86%
