In [5]:
import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F
import torchvision.models as models
import time

In [6]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
# Data augmentation for training
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally
    transforms.RandomCrop(32, padding=4),  # Randomly crop the image
    transforms.Resize(224),  # Resize to 224x224 for ResNet
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# No augmentation for validation and test
val_test_transform = transforms.Compose([
    transforms.Resize(224),  # Resize to 224x224 for ResNet
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load CIFAR-10 dataset
train_val_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=val_test_transform)

# Split train_val_dataset into train and validation sets (80% train, 20% validation)
train_size = int(0.8 * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size
train_dataset, val_dataset = random_split(train_val_dataset, [train_size, val_size])

# Apply val_test_transform to the validation set
val_dataset.dataset.transform = val_test_transform

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:05<00:00, 32.6MB/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [8]:
# Load pretrained ResNet-50 (Teacher Model)
teacher = models.resnet50(pretrained=True)

# Modify the final fully connected layer for 10 classes (CIFAR-10)
teacher.fc = nn.Linear(teacher.fc.in_features, 10)
# Move models to device
teacher = teacher.to(device)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 187MB/s]


In [9]:

model_path = '/kaggle/input/best_teacher_res50/pytorch/default/1/Best_Teacher.pth'
# Load the model weights
teacher.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  teacher.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [10]:
# Load pretrained ResNet-18 (Student Model)
student = models.resnet18(pretrained=True)
# Modify the final fully connected layer for 10 classes (CIFAR-10)
student.fc = nn.Linear(student.fc.in_features, 10)
student = student.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 189MB/s]


In [11]:
def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total
    

In [12]:
def calculate_sparsity(model):
    total_zeros = 0
    total_params = 0
    for name, param in model.named_parameters():
        if 'weight' in name:
            total_zeros += torch.sum(param == 0).item()
            total_params += param.numel()
    return total_zeros / total_params

In [13]:
def measure_inference_time(model, test_loader, device, num_runs=5):
    model.eval()
    model.to(device)
    
    # Warm-up (to avoid initial overhead)
    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = inputs.to(device)
            _ = model(inputs)
            break  # Only one batch for warm-up
    
    # Measure inference time
    total_time = 0
    with torch.no_grad():
        for _ in range(num_runs):
            for inputs, _ in test_loader:
                inputs = inputs.to(device)
                start_time = time.time()  # Start timer
                _ = model(inputs)
                end_time = time.time()  # End timer
                total_time += (end_time - start_time)
    
    # Average inference time per batch
    avg_time_per_batch = total_time / (num_runs * len(test_loader))
    return avg_time_per_batch

In [14]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

def calculate_model_size(model, filename="temp.pth"):
    torch.save(model.state_dict(), filename)
    size = os.path.getsize(filename) / (1024 * 1024)  # Size in MB
    os.remove(filename)
    return size

def compare_model_sizes(teacher, student, pruned_student):
    # Count parameters
    teacher_params = count_parameters(teacher)
    student_params = count_parameters(student)
    pruned_params = count_parameters(pruned_student)
    
    # Calculate disk size
    teacher_size = calculate_model_size(teacher, "teacher.pth")
    student_size = calculate_model_size(student, "student.pth")
    pruned_size = calculate_model_size(pruned_student, "pruned_student.pth")
    
    # Print comparison
    print("\n--- Model Size Comparison ---")
    print(f"Teacher Model: {teacher_params} parameters, {teacher_size:.2f} MB")
    print(f"Student Model (Before Pruning): {student_params} parameters, {student_size:.2f} MB")
    print(f"Student Model (After Pruning): {pruned_params} parameters, {pruned_size:.2f} MB")
    
    # Calculate compression ratio
    compression_ratio = student_size / pruned_size
    print(f"\nCompression Ratio: {compression_ratio:.2f}x")

In [15]:
def train_model(model, train_loader, val_loader, epochs=10, lr=0.001, patience=3):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    
    best_val_accuracy = 0.0
    best_model_state = None
    patience_counter = 0  # Counter for early stopping
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        # Evaluate on the validation set
        val_accuracy = evaluate(model, val_loader, device)
        print(f"Epoch {epoch+1}/{epochs} | Loss: {running_loss/len(train_loader):.4f} | Val Accuracy: {val_accuracy:.2f}%")
        
        # Early stopping logic
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_model_state = model.state_dict()
            patience_counter = 0  # Reset patience counter
            torch.save(model.state_dict(), 'best_teacher_model.pth')  # Save the best model
            print(f" New best model saved with validation accuracy: {best_val_accuracy:.2f}%")
        else:
            patience_counter += 1
            print(f" No improvement in validation accuracy ({patience_counter}/{patience})")
            
            # Stop training if no improvement for 'patience' epochs
            if patience_counter >= patience:
                print(f"\nEarly stopping triggered! No improvement for {patience} epochs.")
                break
    
    # Load the best model state
    model.load_state_dict(torch.load('best_teacher_model.pth'))
    print("\nLoading the best model for final evaluation.")
    
    # Evaluate on the test set
    test_accuracy = evaluate(model, test_loader, device)
    print(f"Test Accuracy with Best Model: {test_accuracy:.2f}%")
    
    return model



In [None]:
# Fine-tune the teacher model
teacher = train_model(teacher, train_loader, val_loader, epochs=200, lr=0.001, patience=5)

In [16]:

def compute_gradient_importance(teacher, student, data_loader, device, temperature=4.0, alpha=0.5):
    importance_scores = {}
    
    # Initialize importance storage for all Conv2d weights
    for name, param in student.named_parameters():
        if 'weight' in name and isinstance(param, nn.Parameter) and len(param.shape) == 4:
            importance_scores[name] = torch.zeros_like(param.data, device=device)
    
    teacher.to(device)
    student.to(device)
    
    for inputs, labels in data_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        student.zero_grad()
        
        with torch.no_grad():
            # Apply temperature scaling to teacher logits
            teacher_logits = teacher(inputs)
            teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
            
        # Apply temperature scaling to student logits
        student_logits = student(inputs)
        student_probs = F.softmax(student_logits / temperature, dim=1)
        
        # Compute KL divergence loss (distillation loss)
        kl_loss = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=1),
            teacher_probs,
            reduction='batchmean'
        ) * (temperature ** 2)  # Scale by temperature squared
        
        # Compute log-likelihood loss (ground truth alignment)
        log_likelihood_loss = F.cross_entropy(student_logits, labels)
        
        # Combine losses
        loss = alpha * kl_loss + (1 - alpha) * log_likelihood_loss
        loss.backward()
        
        # Accumulate absolute gradients for weights
        for name, param in student.named_parameters():
            if name in importance_scores and param.grad is not None:
                importance_scores[name] += param.grad.abs().detach()
    
    # Average across batches
    for name in importance_scores:
        importance_scores[name] /= len(data_loader)
        
    return importance_scores

In [17]:
def gradient_based_unstructured_prune(model, importance_scores, prune_ratio=0.3):
    """
    True unstructured pruning using per-weight importance scores
    """
    for name, param in model.named_parameters():
        if name in importance_scores:
            scores = importance_scores[name]
            n_prune = int(prune_ratio * scores.numel())
            
            if n_prune > 0:
                # Flatten and find threshold
                flat_scores = scores.flatten()
                threshold = torch.topk(flat_scores, k=n_prune, largest=False)[0][-1]
                
                # Create mask and apply pruning
                mask = (scores > threshold).float()
                param.data.mul_(mask)
    
    return model

def prune_model(model, importance_scores, prune_ratio=0.3, pruning_type='structured'):
    """
    Prune the model using either structured or unstructured pruning.
    """
    if pruning_type == 'structured':
        return channel_prune(model, importance_scores, prune_ratio)
    elif pruning_type == 'unstructured':
        return gradient_based_unstructured_prune(model, importance_scores, prune_ratio)
    else:
        raise ValueError("Invalid pruning type. Choose 'structured' or 'unstructured'.")

In [18]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

def train_kd_pruning(teacher, student, train_loader, val_loader, epochs=50, pruning_type='unstructured', temperature=5.0, alpha=0.5, patience=5, save_path="student_before_pruning.pth"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = optim.SGD(student.parameters(), lr=0.01, momentum=0.9)
    
    teacher = teacher.to(device)
    student = student.to(device)
    
    best_val_acc = 0.0
    best_model_state = None
    patience_counter = 0
    start_time = time.time()
    
    # Training loop
    for epoch in range(epochs):
        student.train()
        total_loss = 0.0
        correct, total = 0, 0
        
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            with torch.no_grad():
                # Apply temperature scaling to teacher logits
                teacher_logits = teacher(inputs)
                teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
            
            # Apply temperature scaling to student logits
            student_logits = student(inputs)
            student_probs = F.log_softmax(student_logits / temperature, dim=1)
            
            # Compute distillation loss (KL divergence)
            distillation_loss = F.kl_div(student_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)
            
            # Compute ground truth loss (cross-entropy)
            ground_truth_loss = F.cross_entropy(student_logits, labels)
            
            # Combine losses using alpha
            loss = alpha * distillation_loss + (1 - alpha) * ground_truth_loss
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            _, predicted = student_logits.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
        
        train_loss = total_loss / len(train_loader)
        train_acc = 100.0 * correct / total
        
        # Validation
        val_acc = evaluate(student, val_loader, device)
        
        # Print metrics
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
              f"Val Acc: {val_acc:.2f}%")
        
        # Early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = student.state_dict()
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}. No improvement for {patience} epochs.")
                break
    
    # Load the best model state
    student.load_state_dict(best_model_state)
    
    # Save the student model before pruning
    torch.save(student.state_dict(), save_path)
    print(f"Student model saved before pruning at: {save_path}")
    # Print total training time
    total_time = time.time() - start_time
    print(f"Total Training Time: {total_time // 60:.0f}m {total_time % 60:.0f}s")
    return student

In [19]:
# # Train the student model using KD and pruning
# pruned_student_unstructured = train_kd_pruning(teacher, student, train_loader, test_loader,epochs=5, pruning_type='unstructured')
student = train_kd_pruning(
    teacher, student, train_loader, val_loader,
    epochs=50, pruning_type='unstructured', temperature=5.0, alpha=0.5, patience=5,save_path="student_before_pruning.pth"
)


Epoch 1/50 | Train Loss: 1.0725 | Train Acc: 89.02% | Val Acc: 94.53%
Epoch 2/50 | Train Loss: 0.3656 | Train Acc: 96.88% | Val Acc: 95.38%
Epoch 3/50 | Train Loss: 0.2222 | Train Acc: 98.80% | Val Acc: 96.00%
Epoch 4/50 | Train Loss: 0.1594 | Train Acc: 99.40% | Val Acc: 96.40%
Epoch 5/50 | Train Loss: 0.1279 | Train Acc: 99.61% | Val Acc: 96.42%
Epoch 6/50 | Train Loss: 0.1094 | Train Acc: 99.58% | Val Acc: 96.29%
Epoch 7/50 | Train Loss: 0.0987 | Train Acc: 99.61% | Val Acc: 96.42%
Epoch 8/50 | Train Loss: 0.0893 | Train Acc: 99.66% | Val Acc: 96.44%
Epoch 9/50 | Train Loss: 0.0832 | Train Acc: 99.65% | Val Acc: 96.61%
Epoch 10/50 | Train Loss: 0.0783 | Train Acc: 99.67% | Val Acc: 96.50%
Epoch 11/50 | Train Loss: 0.0743 | Train Acc: 99.66% | Val Acc: 96.51%
Epoch 12/50 | Train Loss: 0.0712 | Train Acc: 99.64% | Val Acc: 96.47%
Epoch 13/50 | Train Loss: 0.0678 | Train Acc: 99.67% | Val Acc: 96.54%
Epoch 14/50 | Train Loss: 0.0656 | Train Acc: 99.68% | Val Acc: 96.31%
Early stopping 

In [17]:
# Measure inference times
teacher_inference_time = measure_inference_time(teacher, test_loader, device)
student_inference_time = measure_inference_time(student, test_loader, device)
print(f"Teacher Model Inference Time: {teacher_inference_time * 1000:.2f} ms per batch")
print(f"Student Model Inference Time(Before Pruning): {student_inference_time * 1000:.2f} ms per batch")

Teacher Model Inference Time: 7.38 ms per batch
Student Model Inference Time(Before Pruning): 3.58 ms per batch


In [19]:
# Calculate sparsity
sparsity = calculate_sparsity(student)
print(f"Sparsity Before Pruning: {sparsity * 100:.2f}%")

teacher_accuracy = evaluate(teacher, test_loader, device)
student_accuracy = evaluate(student, test_loader, device)
print(f"Teacher Model Test Accuracy: {teacher_accuracy:.2f}%")
print(f"Student Model Test Accuracy Before Pruning: {student_accuracy:.2f}%")

Sparsity Before Pruning: 0.00%
Teacher Model Test Accuracy: 95.43%
Student Model Test Accuracy Before Pruning: 95.67%


In [20]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=5.0, alpha=0.5
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()
pruned_student = prune_model(student, importance_scores, prune_ratio=0.9,pruning_type='unstructured')
total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")
student = student.to(device)


Calculating Important Scores
Total Time take to calculate Important scores: 1m 56s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [22]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f"Student Model Test Accuracy After Pruning: {student_accuracy:.2f}%")
sparsity = calculate_sparsity(pruned_student)
print(f"Sparsity After Pruning: {sparsity * 100:.2f}%")

Student Model Test Accuracy After Pruning: 10.00%
Sparsity After Pruning: 89.92%


In [23]:
# Measure inference times
pruned_student_inference_time = measure_inference_time(pruned_student, test_loader, device)
print(f"Student Model Inference Time(After Pruning): {pruned_student_inference_time * 1000:.2f} ms per batch")

Student Model Inference Time(After Pruning): 3.50 ms per batch


In [24]:
torch.save(pruned_student.state_dict(), "pruned_student_unstructured_90%.pth")
print("Model saved as pruned_student_unstructured.pth")


Model saved as pruned_student_unstructured.pth


In [20]:
import torch
import torch.nn.functional as F
import torch.optim as optim

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

def retrain_with_sparsity(student, train_loader, val_loader, epochs=5, save_path="retrained_student_model.pt", patience=3):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = optim.SGD(student.parameters(), lr=0.01, momentum=0.9)

    # 1. Store masks AND zero momentum buffers for pruned weights
    masks = {}
    for name, param in student.named_parameters():
        if 'weight' in name and param.dim() == 4:  # Consider only conv layers
            mask = (param != 0).float().to(device)
            masks[name] = mask
            # Zero momentum buffers for pruned weights
            if optimizer.state.get(param, None) and 'momentum_buffer' in optimizer.state[param]:
                optimizer.state[param]['momentum_buffer'] *= mask

    student = student.to(device)
    best_val_acc = 0.0
    best_model = None
    patience_counter = 0  # Counter for early stopping

    # 2. Add gradient clipping to prevent NaN
    max_grad_norm = 1.0

    for epoch in range(epochs):
        student.train()
        total_loss = 0.0
        correct, total = 0, 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = student(inputs)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()

            # Apply masks to gradients
            for name, param in student.named_parameters():
                if name in masks:
                    param.grad.data *= masks[name]

            # Gradient clipping before optimizer step
            torch.nn.utils.clip_grad_norm_(student.parameters(), max_grad_norm)

            optimizer.step()

            # Reapply masks and update momentum buffers
            for name, param in student.named_parameters():
                if name in masks:
                    param.data *= masks[name]
                    if optimizer.state.get(param, None) and 'momentum_buffer' in optimizer.state[param]:
                        optimizer.state[param]['momentum_buffer'] *= masks[name]

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

        train_loss = total_loss / len(train_loader)
        train_acc = 100.0 * correct / total

        # Validation phase
        student.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = student(inputs)
                loss = F.cross_entropy(outputs, labels)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_correct += predicted.eq(labels).sum().item()
                val_total += labels.size(0)

        val_loss /= len(val_loader)
        val_acc = 100.0 * val_correct / val_total

        # Track best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model = student.state_dict()
            torch.save(best_model, save_path)
            patience_counter = 0  # Reset patience counter
            print(f"New best model saved with Val Accuracy: {best_val_acc:.2f}%")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}. No improvement for {patience} epochs.")
                break  # Stop training

        # Print results
        sparsity = calculate_sparsity(student)
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Validation Loss: {val_loss:.4f} | Validation Acc: {val_acc:.2f}% | Sparsity: {sparsity*100:.2f}%\n")

    print(f"Best Validation Accuracy: {best_val_acc:.2f}% | Best Model Saved at: {save_path}")
    return student

# 90% of Sparsity

In [27]:
student = retrain_with_sparsity(
    pruned_student, train_loader, val_loader,
    epochs=200,  save_path='/kaggle/working/retrained_student_model.pt',patience=20
)

New best model saved with Val Accuracy: 40.53%
Epoch 1/200 | Train Loss: 1.8917 | Train Acc: 27.86%
Validation Loss: 1.7769 | Validation Acc: 40.53% | Sparsity: 89.92%

New best model saved with Val Accuracy: 55.63%
Epoch 2/200 | Train Loss: 1.3447 | Train Acc: 49.77%
Validation Loss: 1.2166 | Validation Acc: 55.63% | Sparsity: 89.92%

New best model saved with Val Accuracy: 59.49%
Epoch 3/200 | Train Loss: 1.0961 | Train Acc: 60.29%
Validation Loss: 1.1460 | Validation Acc: 59.49% | Sparsity: 89.92%

New best model saved with Val Accuracy: 64.64%
Epoch 4/200 | Train Loss: 0.9631 | Train Acc: 65.70%
Validation Loss: 0.9741 | Validation Acc: 64.64% | Sparsity: 89.92%

New best model saved with Val Accuracy: 67.98%
Epoch 5/200 | Train Loss: 0.8731 | Train Acc: 69.03%
Validation Loss: 0.8931 | Validation Acc: 67.98% | Sparsity: 89.92%

Epoch 6/200 | Train Loss: 0.7990 | Train Acc: 71.61%
Validation Loss: 0.9894 | Validation Acc: 65.41% | Sparsity: 89.92%

New best model saved with Val Acc

In [29]:
student_accuracy = evaluate(student, test_loader, device)
print(f"Pruned Student Model Test Accuracy: {student_accuracy:.2f}%")

Pruned Student Model Test Accuracy: 78.14%


In [31]:
sparsity = calculate_sparsity(student)
print(f"Sparsity: {sparsity * 100:.2f}%")

Sparsity: 89.92%


In [33]:

pruned_student_inference_time = measure_inference_time(student, test_loader, device)
print(f"Pruned Student Model Inference Time: {pruned_student_inference_time * 1000:.2f} ms per batch")


Pruned Student Model Inference Time: 3.59 ms per batch


In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import time

def retrain_with_KD(teacher, student, train_loader, val_loader, epochs=50, temperature=5.0, alpha=0.5, patience=5, save_path="student_before_pruning.pth"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = optim.SGD(student.parameters(), lr=0.01, momentum=0.9)
     # 1. Store masks AND zero momentum buffers for pruned weights
    masks = {}
    for name, param in student.named_parameters():
        if 'weight' in name and param.dim() == 4:  # Consider only conv layers
            mask = (param != 0).float().to(device)
            masks[name] = mask
            # Zero momentum buffers for pruned weights
            if optimizer.state.get(param, None) and 'momentum_buffer' in optimizer.state[param]:
                optimizer.state[param]['momentum_buffer'] *= mask
    
    teacher = teacher.to(device)
    student = student.to(device)
    
    best_val_acc = 0.0
    best_val_loss = float("inf")
    best_model_state = None
    patience_counter = 0
    start_time = time.time()
    
    # Training loop
    for epoch in range(epochs):
        student.train()
        total_loss, correct, total = 0.0, 0, 0
        
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            with torch.no_grad():
                # Apply temperature scaling to teacher logits
                teacher_logits = teacher(inputs)
                teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
            
            # Apply temperature scaling to student logits
            student_logits = student(inputs)
            student_probs = F.log_softmax(student_logits / temperature, dim=1)
            
            # Compute distillation loss (KL divergence)
            distillation_loss = F.kl_div(student_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)
            
            # Compute ground truth loss (cross-entropy)
            ground_truth_loss = F.cross_entropy(student_logits, labels)
            
            # Combine losses using alpha
            loss = alpha * distillation_loss + (1 - alpha) * ground_truth_loss
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Reapply masks and update momentum buffers
            for name, param in student.named_parameters():
                if name in masks:
                    param.data *= masks[name]
                    if optimizer.state.get(param, None) and 'momentum_buffer' in optimizer.state[param]:
                        optimizer.state[param]['momentum_buffer'] *= masks[name]
            
            total_loss += loss.item()
            _, predicted = student_logits.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
        
        train_loss = total_loss / len(train_loader)
        train_acc = 100.0 * correct / total
        
        # Validation
        student.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = student(inputs)
                loss = F.cross_entropy(outputs, labels)
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_correct += predicted.eq(labels).sum().item()
                val_total += labels.size(0)
        
        val_loss /= len(val_loader)
        val_acc = 100.0 * val_correct / val_total
        sparsity = calculate_sparsity(student)*100.0  # Function assumed to be defined
        
        # Print metrics
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Sparsity: {sparsity:.2f}%")
        
        # Early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_val_loss = val_loss
            best_model_state = student.state_dict()
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}. No improvement for {patience} epochs.")
                break
    
    # Load the best model state
    student.load_state_dict(best_model_state)
    
    # Save the student model before pruning
    torch.save(student.state_dict(), save_path)
    print(f"Student model saved before pruning at: {save_path}")
    
    # Print total training time
    total_time = time.time() - start_time
    print(f"Total Training Time: {total_time // 60:.0f}m {total_time % 60:.0f}s")
    
    return student


# Retrained with KD

In [22]:
# Load pretrained ResNet-18 (Student Model)
student = models.resnet18(pretrained=True)
# Modify the final fully connected layer for 10 classes (CIFAR-10)
student.fc = nn.Linear(student.fc.in_features, 10)
student = student.to(device)

In [23]:

model_path = '/kaggle/working/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [29]:
student_accuracy = evaluate(student, test_loader, device)
print(f"Student Model Test Accuracy: {student_accuracy:.2f}%")

Student Model Test Accuracy: 95.57%


In [24]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=5.0, alpha=0.5
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()

pruned_student = prune_model(student, importance_scores, prune_ratio=0.9,pruning_type='unstructured')

total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")
student = student.to(device)


Calculating Important Scores
Total Time take to calculate Important scores: 2m 3s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [25]:
pruned_student = retrain_with_KD(
    teacher, pruned_student, train_loader, val_loader,
    epochs=50, temperature=5.0, alpha=0.5, patience=5,save_path="pruned_student_retrain_KD_90%.pth"
)

Epoch 1/50 | Train Loss: 6.3299 | Train Acc: 26.78% | Val Loss: 3.1997 | Val Acc: 28.57% | Sparsity: 89.92%
Epoch 2/50 | Train Loss: 3.9530 | Train Acc: 55.78% | Val Loss: 1.4134 | Val Acc: 58.39% | Sparsity: 89.92%
Epoch 3/50 | Train Loss: 2.7632 | Train Acc: 69.99% | Val Loss: 1.0868 | Val Acc: 69.74% | Sparsity: 89.92%
Epoch 4/50 | Train Loss: 2.0934 | Train Acc: 77.49% | Val Loss: 1.0235 | Val Acc: 73.26% | Sparsity: 89.92%
Epoch 5/50 | Train Loss: 1.6975 | Train Acc: 81.78% | Val Loss: 0.7131 | Val Acc: 79.18% | Sparsity: 89.92%
Epoch 6/50 | Train Loss: 1.4217 | Train Acc: 84.96% | Val Loss: 0.6641 | Val Acc: 80.76% | Sparsity: 89.92%
Epoch 7/50 | Train Loss: 1.2235 | Train Acc: 87.29% | Val Loss: 0.6359 | Val Acc: 82.20% | Sparsity: 89.92%
Epoch 8/50 | Train Loss: 1.0331 | Train Acc: 89.29% | Val Loss: 0.6695 | Val Acc: 82.39% | Sparsity: 89.92%
Epoch 9/50 | Train Loss: 0.8772 | Train Acc: 91.58% | Val Loss: 0.5704 | Val Acc: 83.87% | Sparsity: 89.92%
Epoch 10/50 | Train Loss: 0.

In [26]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f" Retrained Pruned Student Model Test Accuracy: {student_accuracy:.2f}%")

 Retrained Pruned Student Model Test Accuracy: 86.54%


# 80% of Sparsity

In [40]:

model_path = '/kaggle/input/best_student_resnet18/pytorch/default/1/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [41]:
# Measure inference times
teacher_inference_time = measure_inference_time(teacher, test_loader, device)
student_inference_time = measure_inference_time(student, test_loader, device)
print(f"Teacher Model Inference Time: {teacher_inference_time * 1000:.2f} ms per batch")
print(f"Student Model Inference Time(Before Pruning): {student_inference_time * 1000:.2f} ms per batch")

Teacher Model Inference Time: 0.06 ms per batch
Student Model Inference Time(Before Pruning): 0.03 ms per batch


In [42]:
# Calculate sparsity
sparsity = calculate_sparsity(student)
print(f"Sparsity Before Pruning: {sparsity * 100:.2f}%")

teacher_accuracy = evaluate(teacher, test_loader, device)
student_accuracy = evaluate(student, test_loader, device)
print(f"Teacher Model Test Accuracy: {teacher_accuracy:.2f}%")
print(f"Student Model Test Accuracy Before Pruning: {student_accuracy:.2f}%")

Sparsity Before Pruning: 0.00%
Teacher Model Test Accuracy: 95.43%
Student Model Test Accuracy Before Pruning: 95.67%


In [43]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=5.0, alpha=0.5
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()
pruned_student = prune_model(student, importance_scores, prune_ratio=0.81,pruning_type='unstructured')
total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")
student = student.to(device)


Calculating Important Scores
Total Time take to calculate Important scores: 1m 57s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [44]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f"Student Model Test Accuracy After Pruning: {student_accuracy:.2f}%")
sparsity = calculate_sparsity(pruned_student)
print(f"Sparsity After Pruning: {sparsity * 100:.2f}%")

Student Model Test Accuracy After Pruning: 10.00%
Sparsity After Pruning: 80.93%


In [45]:
# Measure inference times
pruned_student_inference_time = measure_inference_time(pruned_student, test_loader, device)
print(f"Student Model Inference Time(After Pruning): {pruned_student_inference_time * 1000:.2f} ms per batch")

Student Model Inference Time(After Pruning): 0.03 ms per batch


In [46]:
retrained_student = retrain_with_sparsity(
    pruned_student, train_loader, val_loader,
    epochs=200,  save_path='/kaggle/working/retrained_student_model.pt',patience=20
)

New best model saved with Val Accuracy: 33.91%
Epoch 1/200 | Train Loss: 1.8738 | Train Acc: 28.42%
Validation Loss: 1.7155 | Validation Acc: 33.91% | Sparsity: 80.93%

New best model saved with Val Accuracy: 56.33%
Epoch 2/200 | Train Loss: 1.3000 | Train Acc: 51.73%
Validation Loss: 1.2148 | Validation Acc: 56.33% | Sparsity: 80.93%

New best model saved with Val Accuracy: 58.58%
Epoch 3/200 | Train Loss: 1.0094 | Train Acc: 63.38%
Validation Loss: 1.2865 | Validation Acc: 58.58% | Sparsity: 80.93%

New best model saved with Val Accuracy: 69.31%
Epoch 4/200 | Train Loss: 0.8325 | Train Acc: 70.33%
Validation Loss: 0.8858 | Validation Acc: 69.31% | Sparsity: 80.93%

New best model saved with Val Accuracy: 71.70%
Epoch 5/200 | Train Loss: 0.7150 | Train Acc: 74.51%
Validation Loss: 0.8166 | Validation Acc: 71.70% | Sparsity: 80.93%

New best model saved with Val Accuracy: 74.47%
Epoch 6/200 | Train Loss: 0.6213 | Train Acc: 78.24%
Validation Loss: 0.7517 | Validation Acc: 74.47% | Spar

In [47]:
student_accuracy = evaluate(retrained_student, test_loader, device)
print(f" Retrained Pruned Student Model Test Accuracy: {student_accuracy:.2f}%")

 Retrained Pruned Student Model Test Accuracy: 84.97%


In [49]:
# Measure inference times
pruned_student_inference_time = measure_inference_time(retrained_student, test_loader, device)
print(f" Retrained pruned Student Model Inference Time: {pruned_student_inference_time * 1000:.2f} ms per batch")

 Retrained pruned Student Model Inference Time: 3.60 ms per batch


In [50]:
torch.save(retrained_student.state_dict(), "pruned_retrained_student_unstructured_80%.pth")
print("Model saved as pruned_student_unstructured.pth")


Model saved as pruned_student_unstructured.pth


# Retrain with KD

In [35]:
# Load pretrained ResNet-18 (Student Model)
student = models.resnet18(pretrained=True)
# Modify the final fully connected layer for 10 classes (CIFAR-10)
student.fc = nn.Linear(student.fc.in_features, 10)
student = student.to(device)

In [36]:

model_path = '/kaggle/working/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [37]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=5.0, alpha=0.5
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()
pruned_student = prune_model(student, importance_scores, prune_ratio=0.81,pruning_type='unstructured')
total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")
student = student.to(device)


Calculating Important Scores
Total Time take to calculate Important scores: 2m 3s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [38]:
pruned_student = retrain_with_KD(
    teacher, pruned_student, train_loader, val_loader,
    epochs=50, temperature=5.0, alpha=0.5, patience=5,save_path="pruned_student_retrain_KD_90%.pth"
)

Epoch 1/50 | Train Loss: 6.1550 | Train Acc: 28.46% | Val Loss: 1.6843 | Val Acc: 47.54% | Sparsity: 80.93%
Epoch 2/50 | Train Loss: 3.0885 | Train Acc: 66.45% | Val Loss: 1.0557 | Val Acc: 71.67% | Sparsity: 80.93%
Epoch 3/50 | Train Loss: 1.6390 | Train Acc: 82.29% | Val Loss: 0.6846 | Val Acc: 80.70% | Sparsity: 80.93%
Epoch 4/50 | Train Loss: 1.1037 | Train Acc: 88.38% | Val Loss: 0.5155 | Val Acc: 85.17% | Sparsity: 80.93%
Epoch 5/50 | Train Loss: 0.7999 | Train Acc: 92.11% | Val Loss: 0.4193 | Val Acc: 87.88% | Sparsity: 80.93%
Epoch 6/50 | Train Loss: 0.5911 | Train Acc: 94.84% | Val Loss: 0.4706 | Val Acc: 87.34% | Sparsity: 80.93%
Epoch 7/50 | Train Loss: 0.4457 | Train Acc: 96.77% | Val Loss: 0.4302 | Val Acc: 88.77% | Sparsity: 80.93%
Epoch 8/50 | Train Loss: 0.3355 | Train Acc: 98.09% | Val Loss: 0.3366 | Val Acc: 90.30% | Sparsity: 80.93%
Epoch 9/50 | Train Loss: 0.2710 | Train Acc: 98.86% | Val Loss: 0.3163 | Val Acc: 90.94% | Sparsity: 80.93%
Epoch 10/50 | Train Loss: 0.

In [39]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f" Retrained Pruned Student Model Test Accuracy: {student_accuracy:.2f}%")

 Retrained Pruned Student Model Test Accuracy: 91.02%


# 70% Sparsity

In [17]:
# Load pretrained ResNet-18 (Student Model)
student = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Modify the final fully connected layer for 10 classes (CIFAR-10)
student.fc = nn.Linear(student.fc.in_features, 10)
student = student.to(device)

In [18]:

model_path = '/kaggle/input/best_student_resnet18/pytorch/default/1/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [19]:
# Measure inference times
teacher_inference_time = measure_inference_time(teacher, test_loader, device)
student_inference_time = measure_inference_time(student, test_loader, device)
print(f"Teacher Model Inference Time: {teacher_inference_time * 1000:.2f} ms per batch")
print(f"Student Model Inference Time(Before Pruning): {student_inference_time * 1000:.2f} ms per batch")

Teacher Model Inference Time: 7.26 ms per batch
Student Model Inference Time(Before Pruning): 3.59 ms per batch


In [20]:
# Calculate sparsity
sparsity = calculate_sparsity(student)
print(f"Sparsity Before Pruning: {sparsity * 100:.2f}%")

teacher_accuracy = evaluate(teacher, test_loader, device)
student_accuracy = evaluate(student, test_loader, device)
print(f"Teacher Model Test Accuracy: {teacher_accuracy:.2f}%")
print(f"Student Model Test Accuracy Before Pruning: {student_accuracy:.2f}%")

Sparsity Before Pruning: 0.00%
Teacher Model Test Accuracy: 95.41%
Student Model Test Accuracy Before Pruning: 95.67%


In [22]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=5.0, alpha=0.5
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()
pruned_student = prune_model(student, importance_scores, prune_ratio=0.707,pruning_type='unstructured')
total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")
student = student.to(device)


Calculating Important Scores
Total Time take to calculate Important scores: 1m 57s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [23]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f"Student Model Test Accuracy After Pruning: {student_accuracy:.2f}%")
sparsity = calculate_sparsity(pruned_student)
print(f"Sparsity After Pruning: {sparsity * 100:.2f}%")

Student Model Test Accuracy After Pruning: 10.00%
Sparsity After Pruning: 70.64%


In [24]:
# Measure inference times
pruned_student_inference_time = measure_inference_time(pruned_student, test_loader, device)
print(f"Student Model Inference Time(After Pruning): {pruned_student_inference_time * 1000:.2f} ms per batch")

Student Model Inference Time(After Pruning): 3.59 ms per batch


In [27]:
retrained_student = retrain_with_sparsity(
    pruned_student, train_loader, val_loader,
    epochs=200,  save_path='/kaggle/working/retrained_student_model.pt',patience=20
)

New best model saved with Val Accuracy: 49.47%
Epoch 1/200 | Train Loss: 1.7996 | Train Acc: 33.07%
Validation Loss: 1.3904 | Validation Acc: 49.47% | Sparsity: 70.64%

New best model saved with Val Accuracy: 84.40%
Epoch 2/200 | Train Loss: 0.7987 | Train Acc: 71.81%
Validation Loss: 0.4583 | Validation Acc: 84.40% | Sparsity: 70.64%

New best model saved with Val Accuracy: 89.49%
Epoch 3/200 | Train Loss: 0.3213 | Train Acc: 88.86%
Validation Loss: 0.3136 | Validation Acc: 89.49% | Sparsity: 70.64%

New best model saved with Val Accuracy: 90.77%
Epoch 4/200 | Train Loss: 0.1862 | Train Acc: 93.76%
Validation Loss: 0.2806 | Validation Acc: 90.77% | Sparsity: 70.64%

New best model saved with Val Accuracy: 90.94%
Epoch 5/200 | Train Loss: 0.1086 | Train Acc: 96.47%
Validation Loss: 0.2844 | Validation Acc: 90.94% | Sparsity: 70.64%

New best model saved with Val Accuracy: 90.99%
Epoch 6/200 | Train Loss: 0.0692 | Train Acc: 97.80%
Validation Loss: 0.3126 | Validation Acc: 90.99% | Spar

In [31]:
student_accuracy = evaluate(retrained_student, test_loader, device)
print(f" Retrained Pruned Student Model Test Accuracy: {student_accuracy:.2f}%")

 Retrained Pruned Student Model Test Accuracy: 92.82%


In [32]:
# Measure inference times
pruned_student_inference_time = measure_inference_time(retrained_student, test_loader, device)
print(f" Retrained pruned Student Model Inference Time: {pruned_student_inference_time * 1000:.2f} ms per batch")

 Retrained pruned Student Model Inference Time: 3.55 ms per batch


In [33]:
torch.save(retrained_student.state_dict(), "pruned_retrained_student_unstructured_80%.pth")
print("Model saved as pruned_student_unstructured.pth")


Model saved as pruned_student_unstructured.pth


# Retrain with KD

In [40]:
# Load pretrained ResNet-18 (Student Model)
student = models.resnet18(pretrained=True)
# Modify the final fully connected layer for 10 classes (CIFAR-10)
student.fc = nn.Linear(student.fc.in_features, 10)
student = student.to(device)

In [41]:

model_path = '/kaggle/working/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [42]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=5.0, alpha=0.5
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()
pruned_student = prune_model(student, importance_scores, prune_ratio=0.7,pruning_type='unstructured')
total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")
student = student.to(device)


Calculating Important Scores
Total Time take to calculate Important scores: 2m 3s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [43]:
pruned_student = retrain_with_KD(
    teacher, pruned_student, train_loader, val_loader,
    epochs=50, temperature=5.0, alpha=0.5, patience=5,save_path="pruned_student_retrain_KD_90%.pth"
)

Epoch 1/50 | Train Loss: 4.3070 | Train Acc: 50.84% | Val Loss: 0.4773 | Val Acc: 86.18% | Sparsity: 69.94%
Epoch 2/50 | Train Loss: 0.8262 | Train Acc: 91.53% | Val Loss: 0.2946 | Val Acc: 90.96% | Sparsity: 69.94%
Epoch 3/50 | Train Loss: 0.4553 | Train Acc: 96.08% | Val Loss: 0.2525 | Val Acc: 92.12% | Sparsity: 69.94%
Epoch 4/50 | Train Loss: 0.3031 | Train Acc: 98.11% | Val Loss: 0.2116 | Val Acc: 93.69% | Sparsity: 69.94%
Epoch 5/50 | Train Loss: 0.2198 | Train Acc: 99.04% | Val Loss: 0.1828 | Val Acc: 94.35% | Sparsity: 69.94%
Epoch 6/50 | Train Loss: 0.1791 | Train Acc: 99.34% | Val Loss: 0.1847 | Val Acc: 94.25% | Sparsity: 69.94%
Epoch 7/50 | Train Loss: 0.1540 | Train Acc: 99.43% | Val Loss: 0.1893 | Val Acc: 94.27% | Sparsity: 69.94%
Epoch 8/50 | Train Loss: 0.1371 | Train Acc: 99.50% | Val Loss: 0.1770 | Val Acc: 94.39% | Sparsity: 69.94%
Epoch 9/50 | Train Loss: 0.1254 | Train Acc: 99.55% | Val Loss: 0.1723 | Val Acc: 94.53% | Sparsity: 69.94%
Epoch 10/50 | Train Loss: 0.

In [44]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f" Retrained Pruned Studen|t Model Test Accuracy: {student_accuracy:.2f}%")

 Retrained Pruned Studen|t Model Test Accuracy: 94.23%


# 50% Sparsity

In [27]:
# Load pretrained ResNet-18 (Student Model)
student = models.resnet18(pretrained=True)
# Modify the final fully connected layer for 10 classes (CIFAR-10)
student.fc = nn.Linear(student.fc.in_features, 10)
student = student.to(device)

In [29]:

model_path = '/kaggle/working/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [16]:
# Measure inference times
teacher_inference_time = measure_inference_time(teacher, test_loader, device)
student_inference_time = measure_inference_time(student, test_loader, device)
print(f"Teacher Model Inference Time: {teacher_inference_time * 1000:.2f} ms per batch")
print(f"Student Model Inference Time(Before Pruning): {student_inference_time * 1000:.2f} ms per batch")

Teacher Model Inference Time: 7.42 ms per batch
Student Model Inference Time(Before Pruning): 3.67 ms per batch


In [17]:
# Calculate sparsity
sparsity = calculate_sparsity(student)
print(f"Sparsity Before Pruning: {sparsity * 100:.2f}%")

teacher_accuracy = evaluate(teacher, test_loader, device)
student_accuracy = evaluate(student, test_loader, device)
print(f"Teacher Model Test Accuracy: {teacher_accuracy:.2f}%")
print(f"Student Model Test Accuracy Before Pruning: {student_accuracy:.2f}%")

Sparsity Before Pruning: 0.00%
Teacher Model Test Accuracy: 95.41%
Student Model Test Accuracy Before Pruning: 95.67%


In [18]:
# Calculate sparsity
sparsity = calculate_sparsity(student)
print(f"Sparsity Before Pruning: {sparsity * 100:.2f}%")

teacher_accuracy = evaluate(teacher, test_loader, device)
student_accuracy = evaluate(student, test_loader, device)
print(f"Teacher Model Test Accuracy: {teacher_accuracy:.2f}%")
print(f"Student Model Test Accuracy Before Pruning: {student_accuracy:.2f}%")

Sparsity Before Pruning: 0.00%
Teacher Model Test Accuracy: 95.41%
Student Model Test Accuracy Before Pruning: 95.67%


In [30]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=5.0, alpha=0.5
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()
pruned_student = prune_model(student, importance_scores, prune_ratio=0.505,pruning_type='unstructured')
total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")
student = student.to(device)


Calculating Important Scores
Total Time take to calculate Important scores: 2m 3s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [20]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f"Student Model Test Accuracy After Pruning: {student_accuracy:.2f}%")
sparsity = calculate_sparsity(pruned_student)
print(f"Sparsity After Pruning: {sparsity * 100:.2f}%")

Student Model Test Accuracy After Pruning: 10.00%
Sparsity After Pruning: 50.46%


In [21]:
# Calculate sparsity
sparsity = calculate_sparsity(student)
print(f"Sparsity Before Pruning: {sparsity * 100:.2f}%")

teacher_accuracy = evaluate(teacher, test_loader, device)
student_accuracy = evaluate(student, test_loader, device)
print(f"Teacher Model Test Accuracy: {teacher_accuracy:.2f}%")
print(f"Student Model Test Accuracy Before Pruning: {student_accuracy:.2f}%")

Sparsity Before Pruning: 50.46%
Teacher Model Test Accuracy: 95.41%
Student Model Test Accuracy Before Pruning: 10.00%


In [23]:
# Measure inference times
pruned_student_inference_time = measure_inference_time(pruned_student, test_loader, device)
print(f"Student Model Inference Time(After Pruning): {pruned_student_inference_time * 1000:.2f} ms per batch")

Student Model Inference Time(After Pruning): 3.70 ms per batch


In [31]:
retrained_student = retrain_with_sparsity(
    pruned_student, train_loader, val_loader,
    epochs=200,  save_path='/kaggle/working/retrained_student_model_50%.pt',patience=10
)

New best model saved with Val Accuracy: 94.11%
Epoch 1/200 | Train Loss: 0.3921 | Train Acc: 86.39%
Validation Loss: 0.1804 | Validation Acc: 94.11% | Sparsity: 50.46%

Epoch 2/200 | Train Loss: 0.0599 | Train Acc: 98.13%
Validation Loss: 0.1886 | Validation Acc: 94.07% | Sparsity: 50.46%

Epoch 3/200 | Train Loss: 0.0241 | Train Acc: 99.31%
Validation Loss: 0.2069 | Validation Acc: 94.00% | Sparsity: 50.46%

New best model saved with Val Accuracy: 94.35%
Epoch 4/200 | Train Loss: 0.0117 | Train Acc: 99.72%
Validation Loss: 0.2074 | Validation Acc: 94.35% | Sparsity: 50.46%

New best model saved with Val Accuracy: 94.49%
Epoch 5/200 | Train Loss: 0.0063 | Train Acc: 99.86%
Validation Loss: 0.2083 | Validation Acc: 94.49% | Sparsity: 50.46%

Epoch 6/200 | Train Loss: 0.0043 | Train Acc: 99.91%
Validation Loss: 0.2430 | Validation Acc: 94.21% | Sparsity: 50.46%

New best model saved with Val Accuracy: 94.82%
Epoch 7/200 | Train Loss: 0.0025 | Train Acc: 99.96%
Validation Loss: 0.2095 | V

In [32]:
student_accuracy = evaluate(retrained_student, test_loader, device)
print(f" Retrained Pruned Studen|t Model Test Accuracy: {student_accuracy:.2f}%")

 Retrained Pruned Studen|t Model Test Accuracy: 94.90%


# Retrain with KD

In [45]:
# Load pretrained ResNet-18 (Student Model)
student = models.resnet18(pretrained=True)
# Modify the final fully connected layer for 10 classes (CIFAR-10)
student.fc = nn.Linear(student.fc.in_features, 10)
student = student.to(device)

In [47]:

model_path = '/kaggle/working/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [50]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=5.0, alpha=0.5
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()
pruned_student = prune_model(student, importance_scores, prune_ratio=0.505,pruning_type='unstructured')
total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")
student = student.to(device)


Calculating Important Scores
Total Time take to calculate Important scores: 2m 3s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [51]:
pruned_student = retrain_with_KD(
    teacher, pruned_student, train_loader, val_loader,
    epochs=50, temperature=5.0, alpha=0.5, patience=5,save_path="pruned_student_retrain_KD_90%.pth"
)

Epoch 1/50 | Train Loss: 0.7917 | Train Acc: 91.99% | Val Loss: 0.1938 | Val Acc: 94.12% | Sparsity: 50.46%
Epoch 2/50 | Train Loss: 0.2538 | Train Acc: 98.33% | Val Loss: 0.1604 | Val Acc: 94.54% | Sparsity: 50.46%
Epoch 3/50 | Train Loss: 0.1655 | Train Acc: 99.33% | Val Loss: 0.1389 | Val Acc: 95.47% | Sparsity: 50.46%
Epoch 4/50 | Train Loss: 0.1290 | Train Acc: 99.51% | Val Loss: 0.1383 | Val Acc: 95.28% | Sparsity: 50.46%
Epoch 5/50 | Train Loss: 0.1090 | Train Acc: 99.53% | Val Loss: 0.1360 | Val Acc: 95.49% | Sparsity: 50.46%
Epoch 6/50 | Train Loss: 0.0967 | Train Acc: 99.61% | Val Loss: 0.1358 | Val Acc: 95.51% | Sparsity: 50.46%
Epoch 7/50 | Train Loss: 0.0890 | Train Acc: 99.64% | Val Loss: 0.1376 | Val Acc: 95.44% | Sparsity: 50.46%
Epoch 8/50 | Train Loss: 0.0839 | Train Acc: 99.61% | Val Loss: 0.1305 | Val Acc: 95.74% | Sparsity: 50.46%
Epoch 9/50 | Train Loss: 0.0787 | Train Acc: 99.64% | Val Loss: 0.1319 | Val Acc: 95.56% | Sparsity: 50.46%
Epoch 10/50 | Train Loss: 0.

In [52]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f" Retrained Pruned Studen|t Model Test Accuracy: {student_accuracy:.2f}%")

 Retrained Pruned Studen|t Model Test Accuracy: 95.28%


# Retrain with 20% Training Data

In [16]:
# Load pretrained ResNet-18 (Student Model)
student = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Modify the final fully connected layer for 10 classes (CIFAR-10)
student.fc = nn.Linear(student.fc.in_features, 10)
student = student.to(device)

In [17]:

model_path = '/kaggle/input/best_student_resnet18/pytorch/default/1/student_before_pruning.pth'
# Load the model weights
student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

  student.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [None]:
# Measure inference times
teacher_inference_time = measure_inference_time(teacher, test_loader, device)
student_inference_time = measure_inference_time(student, test_loader, device)
print(f"Teacher Model Inference Time: {teacher_inference_time * 1000:.2f} ms per batch")
print(f"Student Model Inference Time(Before Pruning): {student_inference_time * 1000:.2f} ms per batch")

In [None]:
# Calculate sparsity
sparsity = calculate_sparsity(student)
print(f"Sparsity Before Pruning: {sparsity * 100:.2f}%")

teacher_accuracy = evaluate(teacher, test_loader, device)
student_accuracy = evaluate(student, test_loader, device)
print(f"Teacher Model Test Accuracy: {teacher_accuracy:.2f}%")
print(f"Student Model Test Accuracy Before Pruning: {student_accuracy:.2f}%")

In [None]:
# Calculate sparsity
sparsity = calculate_sparsity(student)
print(f"Sparsity Before Pruning: {sparsity * 100:.2f}%")

teacher_accuracy = evaluate(teacher, test_loader, device)
student_accuracy = evaluate(student, test_loader, device)
print(f"Teacher Model Test Accuracy: {teacher_accuracy:.2f}%")
print(f"Student Model Test Accuracy Before Pruning: {student_accuracy:.2f}%")

In [18]:
# Pruning
print("Calculating Important Scores")
start_time = time.time()
importance_scores = compute_gradient_importance(
    teacher, student, train_loader, device, temperature=5.0, alpha=0.5
)
total_time = time.time() - start_time
print(f"Total Time take to calculate Important scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")

print("Pruning the model")
start_time = time.time()
pruned_student = prune_model(student, importance_scores, prune_ratio=0.9,pruning_type='unstructured')
total_time = time.time() - start_time
print(f"Total Time take to prune the model scores: {total_time // 60:.0f}m {total_time % 60:.0f}s")
student = student.to(device)


Calculating Important Scores
Total Time take to calculate Important scores: 0m 32s
Pruning the model
Total Time take to prune the model scores: 0m 0s


In [19]:
student_accuracy = evaluate(pruned_student, test_loader, device)
print(f"Student Model Test Accuracy After Pruning: {student_accuracy:.2f}%")
sparsity = calculate_sparsity(pruned_student)
print(f"Sparsity After Pruning: {sparsity * 100:.2f}%")

Student Model Test Accuracy After Pruning: 10.00%
Sparsity After Pruning: 89.92%


In [None]:

teacher_accuracy = evaluate(teacher, test_loader, device)
student_accuracy = evaluate(student, test_loader, device)
print(f"Teacher Model Test Accuracy: {teacher_accuracy:.2f}%")
print(f"Student Model Test Accuracy Before Pruning: {student_accuracy:.2f}%")

In [None]:
# Measure inference times
pruned_student_inference_time = measure_inference_time(pruned_student, test_loader, device)
print(f"Student Model Inference Time(After Pruning): {pruned_student_inference_time * 1000:.2f} ms per batch")

In [20]:
from torch.utils.data import random_split, DataLoader
from torchvision import datasets, transforms

# Data augmentation for training
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally
    transforms.RandomCrop(32, padding=4),  # Randomly crop the image
    transforms.Resize(224),  # Resize to 224x224 for ResNet
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# No augmentation for validation and test
val_test_transform = transforms.Compose([
    transforms.Resize(224),  # Resize to 224x224 for ResNet
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load CIFAR-10 dataset
train_val_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=val_test_transform)

# Split train_val_dataset into train (20%) and validation (20% of original)
train_size = int(0.2 * len(train_val_dataset))  # 20% of 50,000 = 10,000 samples
remaining_size = len(train_val_dataset) - train_size

# Use random_split to preserve class distribution
train_dataset, remaining_data = random_split(train_val_dataset, [train_size, remaining_size])

# Split remaining_data into validation (20% of original)
val_size = int(0.2 * len(train_val_dataset))  # 20% of 50,000 = 10,000 samples
_, val_dataset = random_split(remaining_data, [remaining_size - val_size, val_size])

# Apply val_test_transform to the validation set
val_dataset.dataset.transform = val_test_transform

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Print dataset sizes
print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

Files already downloaded and verified
Files already downloaded and verified
Train dataset size: 10000
Validation dataset size: 10000
Test dataset size: 10000


In [22]:
retrained_student = retrain_with_sparsity(
    pruned_student, train_loader, val_loader,
    epochs=200,  save_path='/kaggle/working/retrained_student_model_90%_20%_Train_Data.pt',patience=20
)

New best model saved with Val Accuracy: 19.90%
Epoch 1/200 | Train Loss: 2.1372 | Train Acc: 18.24%
Validation Loss: 2.1238 | Validation Acc: 19.90% | Sparsity: 89.92%

New best model saved with Val Accuracy: 21.76%
Epoch 2/200 | Train Loss: 1.9587 | Train Acc: 21.95%
Validation Loss: 1.9937 | Validation Acc: 21.76% | Sparsity: 89.92%

Epoch 3/200 | Train Loss: 1.9046 | Train Acc: 23.17%
Validation Loss: 2.1815 | Validation Acc: 20.42% | Sparsity: 89.92%

New best model saved with Val Accuracy: 27.65%
Epoch 4/200 | Train Loss: 1.8598 | Train Acc: 26.08%
Validation Loss: 1.8995 | Validation Acc: 27.65% | Sparsity: 89.92%

New best model saved with Val Accuracy: 30.07%
Epoch 5/200 | Train Loss: 1.7397 | Train Acc: 31.92%
Validation Loss: 1.8803 | Validation Acc: 30.07% | Sparsity: 89.92%

Epoch 6/200 | Train Loss: 1.6891 | Train Acc: 34.55%
Validation Loss: 2.2614 | Validation Acc: 25.26% | Sparsity: 89.92%

New best model saved with Val Accuracy: 32.22%
Epoch 7/200 | Train Loss: 1.6415 

In [24]:
student_accuracy = evaluate(retrained_student, test_loader, device)
print(f" Retrained Pruned Student Model Test Accuracy: {student_accuracy:.2f}%")

 Retrained Pruned Student Model Test Accuracy: 75.18%
