In [26]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import time
import copy

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [27]:
TEACHER_EPOCHS = 20       # Epochs for teacher model
STUDENT_EPOCHS = 30       # Epochs for student models
BATCH_SIZE = 128
LEARNING_RATE = 0.001
TEMPERATURE = 4.0
ALPHA = 0.7

In [28]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE,
                         shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=BATCH_SIZE,
                        shuffle=False, num_workers=2)

print("Data loaded successfully.")

Files already downloaded and verified
Files already downloaded and verified
Data loaded successfully.


In [29]:
def train_model(model, trainloader, optimizer, criterion, epoch):
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(trainloader, desc=f"Epoch {epoch+1}")
    for inputs, labels in progress_bar:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        progress_bar.set_postfix(loss=running_loss / len(progress_bar))


def evaluate_model(model, testloader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


def train_distilled_student(teacher, student, trainloader, optimizer, epoch, T, alpha):
    teacher.eval()
    student.train()
    running_loss = 0.0
    progress_bar = tqdm(trainloader, desc=f"Epoch {epoch+1} (Distill)")
    for inputs, labels in progress_bar:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        with torch.no_grad():
            teacher_logits = teacher(inputs)

        student_logits = student(inputs)

        soft_loss = nn.KLDivLoss(reduction='batchmean')(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1)
        ) * (T * T)

        hard_loss = F.cross_entropy(student_logits, labels)
        total_loss = (alpha * soft_loss) + ((1 - alpha) * hard_loss)

        total_loss.backward()
        optimizer.step()
        running_loss += total_loss.item()
        progress_bar.set_postfix(loss=running_loss / len(progress_bar))

In [30]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def measure_inference_speed(model, testloader):
    model.eval()
    total_time = 0.0
    num_images = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs = inputs.to(device)
            start = time.time()
            _ = model(inputs)
            if device.type == 'cuda':
                torch.cuda.synchronize()
            total_time += (time.time() - start)
            num_images += labels.size(0)
    return total_time / num_images

In [31]:
class TeacherModel8Layer(nn.Module):
    """
    Custom 8-Layer CNN Model (Designed for thesis)
    - 8 Convolutional layers (with BatchNorm and ReLU)
    - MaxPooling layers to downsample features
    - Fully connected classifier
    """

    def __init__(self, num_classes=10):
        super(TeacherModel8Layer, self).__init__()

        # --------------------------
        # Feature extractor (8 Conv layers)
        # --------------------------

        # Conv Block 1
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        # Conv Block 2
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # Conv Block 3
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        # Conv Block 4
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(256)

        # Conv Block 5
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(512)

        # Conv Block 6
        self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.bn6 = nn.BatchNorm2d(512)

        # Conv Block 7
        self.conv7 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.bn7 = nn.BatchNorm2d(256)

        # Conv Block 8
        self.conv8 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.bn8 = nn.BatchNorm2d(128)

        # Adaptive pooling to ensure consistent output size
        self.avgpool = nn.AdaptiveAvgPool2d((4, 4))

        # --------------------------
        # Fully connected layers (classifier)
        # --------------------------
        self.fc1 = nn.Linear(128 * 4 * 4, 1024)  # Fully connected layer 1
        self.fc2 = nn.Linear(1024, 256)          # Fully connected layer 2
        self.fc3 = nn.Linear(256, num_classes)   # Output layer

        # Dropout layer for regularization
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        # --------------------------
        # Feature extraction (8 Convolutional blocks)
        # --------------------------

        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        # Reduce the max poolings
        x = F.max_pool2d(x, kernel_size=2, stride=2)

        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)

        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)

        x = self.conv5(x)
        x = self.bn5(x)
        x = F.relu(x)

        x = self.conv6(x)
        x = self.bn6(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)

        x = self.conv7(x)
        x = self.bn7(x)
        x = F.relu(x)

        x = self.conv8(x)
        x = self.bn8(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)

        # --------------------------
        # Flatten and classification
        # --------------------------
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

In [32]:
teacher_model = TeacherModel8Layer().to(device)
criterion = nn.CrossEntropyLoss()
optimizer_teacher = optim.Adam(teacher_model.parameters(), lr=LEARNING_RATE)

best_teacher_acc = 0.0
best_teacher_state = None

for epoch in range(TEACHER_EPOCHS):
    train_model(teacher_model, trainloader, optimizer_teacher, criterion, epoch)
    acc = evaluate_model(teacher_model, testloader)
    print(f"Teacher Accuracy after Epoch {epoch+1}: {acc:.2f}%")
    if acc > best_teacher_acc:
        best_teacher_acc = acc
        best_teacher_state = copy.deepcopy(teacher_model.state_dict())

torch.save(best_teacher_state, "teacher_best.pth")
teacher_model.load_state_dict(best_teacher_state)
print(f"✅ Final Teacher Accuracy: {best_teacher_acc:.2f}%")

Epoch 1: 100%|██████████| 391/391 [00:24<00:00, 16.05it/s, loss=1.81]   


Teacher Accuracy after Epoch 1: 39.39%


Epoch 2: 100%|██████████| 391/391 [00:24<00:00, 15.80it/s, loss=1.41]   


Teacher Accuracy after Epoch 2: 53.05%


Epoch 3: 100%|██████████| 391/391 [00:23<00:00, 16.45it/s, loss=1.15]   


Teacher Accuracy after Epoch 3: 62.19%


Epoch 4: 100%|██████████| 391/391 [00:23<00:00, 16.83it/s, loss=0.982]  


Teacher Accuracy after Epoch 4: 65.96%


Epoch 5: 100%|██████████| 391/391 [00:24<00:00, 16.27it/s, loss=0.864]  


Teacher Accuracy after Epoch 5: 71.93%


Epoch 6: 100%|██████████| 391/391 [00:23<00:00, 16.53it/s, loss=0.788]  


Teacher Accuracy after Epoch 6: 75.33%


Epoch 7: 100%|██████████| 391/391 [00:23<00:00, 16.62it/s, loss=0.725]  


Teacher Accuracy after Epoch 7: 77.60%


Epoch 8: 100%|██████████| 391/391 [00:23<00:00, 16.51it/s, loss=0.67]   


Teacher Accuracy after Epoch 8: 77.14%


Epoch 9: 100%|██████████| 391/391 [00:23<00:00, 16.60it/s, loss=0.628]  


Teacher Accuracy after Epoch 9: 78.37%


Epoch 10: 100%|██████████| 391/391 [00:23<00:00, 16.44it/s, loss=0.591]  


Teacher Accuracy after Epoch 10: 80.10%


Epoch 11: 100%|██████████| 391/391 [00:23<00:00, 16.59it/s, loss=0.557]  


Teacher Accuracy after Epoch 11: 79.92%


Epoch 12: 100%|██████████| 391/391 [00:24<00:00, 16.11it/s, loss=0.54]   


Teacher Accuracy after Epoch 12: 81.39%


Epoch 13: 100%|██████████| 391/391 [00:23<00:00, 16.30it/s, loss=0.509]  


Teacher Accuracy after Epoch 13: 82.14%


Epoch 14: 100%|██████████| 391/391 [00:24<00:00, 15.67it/s, loss=0.475]  


Teacher Accuracy after Epoch 14: 83.08%


Epoch 15: 100%|██████████| 391/391 [00:23<00:00, 16.62it/s, loss=0.457]   


Teacher Accuracy after Epoch 15: 82.53%


Epoch 16: 100%|██████████| 391/391 [00:23<00:00, 16.40it/s, loss=0.442]  


Teacher Accuracy after Epoch 16: 82.86%


Epoch 17: 100%|██████████| 391/391 [00:23<00:00, 16.48it/s, loss=0.42]    


Teacher Accuracy after Epoch 17: 83.39%


Epoch 18: 100%|██████████| 391/391 [00:24<00:00, 16.07it/s, loss=0.396]   


Teacher Accuracy after Epoch 18: 84.53%


Epoch 19: 100%|██████████| 391/391 [00:23<00:00, 16.79it/s, loss=0.387]  


Teacher Accuracy after Epoch 19: 84.66%


Epoch 20: 100%|██████████| 391/391 [00:23<00:00, 16.62it/s, loss=0.365]   


Teacher Accuracy after Epoch 20: 83.10%
✅ Final Teacher Accuracy: 84.66%


In [33]:
class StudentModel8Layer(nn.Module):
    """
    Custom 8-layer CNN for Student Model
    - 8 Conv layers
    - ONLY 3 max-pools to avoid shrinking to 0 size
    - Lightweight for distillation
    """

    def __init__(self, num_classes=10):
        super(StudentModel8Layer, self).__init__()

        # ---- Block 1 ----
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(32)

        # ---- Block 2 ----
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn3   = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn4   = nn.BatchNorm2d(64)

        # ---- Block 3 ----
        self.conv5 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn5   = nn.BatchNorm2d(128)
        self.conv6 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn6   = nn.BatchNorm2d(128)

        # ---- Block 4 ----
        self.conv7 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn7   = nn.BatchNorm2d(256)
        self.conv8 = nn.Conv2d(256, 256, 3, padding=1)
        self.bn8   = nn.BatchNorm2d(256)

        # Safe, fixed-size output before FC
        self.avgpool = nn.AdaptiveAvgPool2d((4, 4))

        # Classifier
        self.fc1 = nn.Linear(256 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)

        self.dropout = nn.Dropout(0.4)

    def forward(self, x):

        # Block 1
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)     # 32 → 16

        # Block 2
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.max_pool2d(x, 2)     # 16 → 8

        # Block 3
        x = F.relu(self.bn5(self.conv5(x)))
        x = F.relu(self.bn6(self.conv6(x)))
        x = F.max_pool2d(x, 2)     # 8 → 4

        # Block 4 (no pooling here — safe!)
        x = F.relu(self.bn7(self.conv7(x)))
        x = F.relu(self.bn8(self.conv8(x)))

        # Final features
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x

In [34]:
student_conventional = StudentModel8Layer().to(device)
criterion = nn.CrossEntropyLoss()
optimizer_student_conv = optim.Adam(student_conventional.parameters(), lr=LEARNING_RATE)

best_student_conv_acc = 0.0

for epoch in range(STUDENT_EPOCHS):
    train_model(student_conventional, trainloader, optimizer_student_conv, criterion, epoch)
    acc = evaluate_model(student_conventional, testloader)
    print(f"Conventional Student Accuracy after Epoch {epoch+1}: {acc:.2f}%")
    if acc > best_student_conv_acc:
        best_student_conv_acc = acc

torch.save(student_conventional.state_dict(), "student_conventional_best.pth")
print(f"✅ Final Conventional Student Accuracy: {best_student_conv_acc:.2f}%")

Epoch 1: 100%|██████████| 391/391 [00:23<00:00, 16.29it/s, loss=1.54]  


Conventional Student Accuracy after Epoch 1: 58.25%


Epoch 2: 100%|██████████| 391/391 [00:23<00:00, 16.37it/s, loss=1.09]   


Conventional Student Accuracy after Epoch 2: 64.97%


Epoch 3: 100%|██████████| 391/391 [00:23<00:00, 16.48it/s, loss=0.917]  


Conventional Student Accuracy after Epoch 3: 63.94%


Epoch 4: 100%|██████████| 391/391 [00:24<00:00, 16.14it/s, loss=0.796]  


Conventional Student Accuracy after Epoch 4: 73.40%


Epoch 5: 100%|██████████| 391/391 [00:24<00:00, 15.93it/s, loss=0.705]  


Conventional Student Accuracy after Epoch 5: 76.07%


Epoch 6: 100%|██████████| 391/391 [00:23<00:00, 16.47it/s, loss=0.638]  


Conventional Student Accuracy after Epoch 6: 76.64%


Epoch 7: 100%|██████████| 391/391 [00:23<00:00, 16.41it/s, loss=0.59]   


Conventional Student Accuracy after Epoch 7: 79.11%


Epoch 8: 100%|██████████| 391/391 [00:23<00:00, 16.66it/s, loss=0.549]  


Conventional Student Accuracy after Epoch 8: 80.07%


Epoch 9: 100%|██████████| 391/391 [00:23<00:00, 16.44it/s, loss=0.512]  


Conventional Student Accuracy after Epoch 9: 82.19%


Epoch 10: 100%|██████████| 391/391 [00:23<00:00, 16.40it/s, loss=0.481]  


Conventional Student Accuracy after Epoch 10: 82.11%


Epoch 11: 100%|██████████| 391/391 [00:23<00:00, 16.30it/s, loss=0.454]   


Conventional Student Accuracy after Epoch 11: 82.68%


Epoch 12: 100%|██████████| 391/391 [00:25<00:00, 15.04it/s, loss=0.435]  


Conventional Student Accuracy after Epoch 12: 84.52%


Epoch 13: 100%|██████████| 391/391 [00:24<00:00, 15.83it/s, loss=0.414]   


Conventional Student Accuracy after Epoch 13: 85.70%


Epoch 14: 100%|██████████| 391/391 [00:25<00:00, 15.54it/s, loss=0.389]   


Conventional Student Accuracy after Epoch 14: 84.51%


Epoch 15: 100%|██████████| 391/391 [00:22<00:00, 17.29it/s, loss=0.368]   


Conventional Student Accuracy after Epoch 15: 84.96%


Epoch 16: 100%|██████████| 391/391 [00:24<00:00, 15.94it/s, loss=0.354]   


Conventional Student Accuracy after Epoch 16: 85.34%


Epoch 17: 100%|██████████| 391/391 [00:23<00:00, 16.41it/s, loss=0.339]   


Conventional Student Accuracy after Epoch 17: 86.11%


Epoch 18: 100%|██████████| 391/391 [00:23<00:00, 16.35it/s, loss=0.322]   


Conventional Student Accuracy after Epoch 18: 85.09%


Epoch 19: 100%|██████████| 391/391 [00:23<00:00, 16.43it/s, loss=0.312]   


Conventional Student Accuracy after Epoch 19: 85.04%


Epoch 20: 100%|██████████| 391/391 [00:24<00:00, 16.09it/s, loss=0.302]  


Conventional Student Accuracy after Epoch 20: 86.81%


Epoch 21: 100%|██████████| 391/391 [00:24<00:00, 16.21it/s, loss=0.287]   


Conventional Student Accuracy after Epoch 21: 86.26%


Epoch 22: 100%|██████████| 391/391 [00:23<00:00, 16.64it/s, loss=0.282]  


Conventional Student Accuracy after Epoch 22: 86.94%


Epoch 23: 100%|██████████| 391/391 [00:24<00:00, 15.70it/s, loss=0.268]   


Conventional Student Accuracy after Epoch 23: 86.70%


Epoch 24: 100%|██████████| 391/391 [00:24<00:00, 16.05it/s, loss=0.26]    


Conventional Student Accuracy after Epoch 24: 87.24%


Epoch 25: 100%|██████████| 391/391 [00:23<00:00, 16.29it/s, loss=0.249]   


Conventional Student Accuracy after Epoch 25: 87.51%


Epoch 26: 100%|██████████| 391/391 [00:25<00:00, 15.62it/s, loss=0.239]   


Conventional Student Accuracy after Epoch 26: 87.38%


Epoch 27: 100%|██████████| 391/391 [00:24<00:00, 15.69it/s, loss=0.231]   


Conventional Student Accuracy after Epoch 27: 87.02%


Epoch 28: 100%|██████████| 391/391 [00:24<00:00, 16.15it/s, loss=0.222]   


Conventional Student Accuracy after Epoch 28: 87.65%


Epoch 29: 100%|██████████| 391/391 [00:23<00:00, 16.31it/s, loss=0.215]   


Conventional Student Accuracy after Epoch 29: 87.42%


Epoch 30: 100%|██████████| 391/391 [00:24<00:00, 16.28it/s, loss=0.21]    


Conventional Student Accuracy after Epoch 30: 87.55%
✅ Final Conventional Student Accuracy: 87.65%


In [35]:
class DeepCustomCNN(nn.Module):
    def __init__(self, num_classes=10, base_filters=32, dropout=0.3):
        super(DeepCustomCNN, self).__init__()

        # Block 1
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, base_filters, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_filters),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_filters, base_filters, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_filters),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout(dropout)
        )

        # Block 2
        self.conv2 = nn.Sequential(
            nn.Conv2d(base_filters, base_filters * 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_filters * 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_filters * 2, base_filters * 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_filters * 2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout(dropout)
        )

        # Block 3
        self.conv3 = nn.Sequential(
            nn.Conv2d(base_filters * 2, base_filters * 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_filters * 4),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_filters * 4, base_filters * 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_filters * 4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout(dropout)
        )

        # Block 4 — deeper layer
        self.conv4 = nn.Sequential(
            nn.Conv2d(base_filters * 4, base_filters * 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_filters * 8),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_filters * 8, base_filters * 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_filters * 8),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout(dropout)
        )

        # Fully connected layers
        self.fc1 = nn.Linear(base_filters * 8 * 2 * 2, 256)  # assuming 32x32 input (CIFAR-like)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)
        self.dropout_fc = nn.Dropout(dropout)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout_fc(x)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [40]:
student_distilled = DeepCustomCNN(num_classes=10, dropout=0.3).to(device)
optimizer_student_distill = optim.Adam(student_distilled.parameters(), lr=LEARNING_RATE)

best_student_distill_acc = 0.0

for epoch in range(STUDENT_EPOCHS):
    train_distilled_student(
        teacher=teacher_model,
        student=student_distilled,
        trainloader=trainloader,
        optimizer=optimizer_student_distill,
        epoch=epoch,
        T=TEMPERATURE,
        alpha=ALPHA
    )
    acc = evaluate_model(student_distilled, testloader)
    print(f"Distilled Student Accuracy after Epoch {epoch+1}: {acc:.2f}%")
    if acc > best_student_distill_acc:
        best_student_distill_acc = acc

torch.save(student_distilled.state_dict(), "student_distilled_best.pth")
print(f"✅ Final Distilled Student Accuracy: {best_student_distill_acc:.2f}%")

Epoch 1 (Distill): 100%|██████████| 391/391 [00:24<00:00, 15.77it/s, loss=4.47]  


Distilled Student Accuracy after Epoch 1: 53.87%


Epoch 2 (Distill): 100%|██████████| 391/391 [00:24<00:00, 15.84it/s, loss=2.69]   


Distilled Student Accuracy after Epoch 2: 61.09%


Epoch 3 (Distill): 100%|██████████| 391/391 [00:23<00:00, 16.48it/s, loss=2.11]   


Distilled Student Accuracy after Epoch 3: 67.63%


Epoch 4 (Distill): 100%|██████████| 391/391 [00:24<00:00, 15.91it/s, loss=1.75]   


Distilled Student Accuracy after Epoch 4: 71.88%


Epoch 5 (Distill): 100%|██████████| 391/391 [00:24<00:00, 16.06it/s, loss=1.52]   


Distilled Student Accuracy after Epoch 5: 73.11%


Epoch 6 (Distill): 100%|██████████| 391/391 [00:23<00:00, 16.53it/s, loss=1.36]   


Distilled Student Accuracy after Epoch 6: 77.77%


Epoch 7 (Distill): 100%|██████████| 391/391 [00:24<00:00, 16.09it/s, loss=1.26]   


Distilled Student Accuracy after Epoch 7: 78.58%


Epoch 8 (Distill): 100%|██████████| 391/391 [00:24<00:00, 16.18it/s, loss=1.18]   


Distilled Student Accuracy after Epoch 8: 78.92%


Epoch 9 (Distill): 100%|██████████| 391/391 [00:24<00:00, 15.81it/s, loss=1.1]    


Distilled Student Accuracy after Epoch 9: 80.37%


Epoch 10 (Distill): 100%|██████████| 391/391 [00:24<00:00, 15.75it/s, loss=1.06]   


Distilled Student Accuracy after Epoch 10: 80.83%


Epoch 11 (Distill): 100%|██████████| 391/391 [00:24<00:00, 15.78it/s, loss=1.02]   


Distilled Student Accuracy after Epoch 11: 81.55%


Epoch 12 (Distill): 100%|██████████| 391/391 [00:24<00:00, 15.69it/s, loss=0.993]  


Distilled Student Accuracy after Epoch 12: 81.64%


Epoch 13 (Distill): 100%|██████████| 391/391 [00:24<00:00, 16.07it/s, loss=0.946]  


Distilled Student Accuracy after Epoch 13: 80.91%


Epoch 14 (Distill): 100%|██████████| 391/391 [00:24<00:00, 15.89it/s, loss=0.909]  


Distilled Student Accuracy after Epoch 14: 81.61%


Epoch 15 (Distill): 100%|██████████| 391/391 [00:24<00:00, 16.20it/s, loss=0.889]  


Distilled Student Accuracy after Epoch 15: 82.11%


Epoch 16 (Distill): 100%|██████████| 391/391 [00:24<00:00, 16.18it/s, loss=0.864]  


Distilled Student Accuracy after Epoch 16: 82.94%


Epoch 17 (Distill): 100%|██████████| 391/391 [00:24<00:00, 16.15it/s, loss=0.848]  


Distilled Student Accuracy after Epoch 17: 82.48%


Epoch 18 (Distill): 100%|██████████| 391/391 [00:24<00:00, 16.07it/s, loss=0.822]  


Distilled Student Accuracy after Epoch 18: 83.02%


Epoch 19 (Distill): 100%|██████████| 391/391 [00:24<00:00, 16.22it/s, loss=0.815]  


Distilled Student Accuracy after Epoch 19: 83.27%


Epoch 20 (Distill): 100%|██████████| 391/391 [00:24<00:00, 16.12it/s, loss=0.798]  


Distilled Student Accuracy after Epoch 20: 84.03%


Epoch 21 (Distill): 100%|██████████| 391/391 [00:24<00:00, 15.79it/s, loss=0.778]  


Distilled Student Accuracy after Epoch 21: 83.38%


Epoch 22 (Distill): 100%|██████████| 391/391 [00:24<00:00, 15.75it/s, loss=0.769]  


Distilled Student Accuracy after Epoch 22: 83.93%


Epoch 23 (Distill): 100%|██████████| 391/391 [00:24<00:00, 15.71it/s, loss=0.759]  


Distilled Student Accuracy after Epoch 23: 84.00%


Epoch 24 (Distill): 100%|██████████| 391/391 [00:24<00:00, 15.96it/s, loss=0.752]  


Distilled Student Accuracy after Epoch 24: 84.40%


Epoch 25 (Distill): 100%|██████████| 391/391 [06:25<00:00,  1.01it/s, loss=0.73]     


Distilled Student Accuracy after Epoch 25: 84.69%


Epoch 26 (Distill): 100%|██████████| 391/391 [00:24<00:00, 15.86it/s, loss=0.725]  


Distilled Student Accuracy after Epoch 26: 83.80%


Epoch 27 (Distill): 100%|██████████| 391/391 [00:24<00:00, 16.10it/s, loss=0.725] 


Distilled Student Accuracy after Epoch 27: 85.14%


Epoch 28 (Distill): 100%|██████████| 391/391 [00:25<00:00, 15.29it/s, loss=0.708]  


Distilled Student Accuracy after Epoch 28: 85.16%


Epoch 29 (Distill): 100%|██████████| 391/391 [00:24<00:00, 15.99it/s, loss=0.702]  


Distilled Student Accuracy after Epoch 29: 84.05%


Epoch 30 (Distill): 100%|██████████| 391/391 [00:25<00:00, 15.47it/s, loss=0.693]  


Distilled Student Accuracy after Epoch 30: 83.82%
✅ Final Distilled Student Accuracy: 85.16%


In [42]:
teacher_params = count_parameters(teacher_model)
student_params = count_parameters(student_conventional)
teacher_speed = measure_inference_speed(teacher_model, testloader)
student_conv_speed = measure_inference_speed(student_conventional, testloader)
student_distill_speed = measure_inference_speed(student_distilled, testloader)

print("\n" + "="*80)
print(f"{'Model':<25} | {'Accuracy (%)':<15} | {'Parameters':<15} | {'Avg. Inference (s/img)':<25}")
print("-"*80)
print(f"{'1. Teacher Model':<25} | {best_teacher_acc:<15.2f} | {teacher_params:<15,} | {teacher_speed:<25.8f}")
print(f"{'2. Student (Conventional)':<25} | {best_student_conv_acc:<15.2f} | {student_params:<15,} | {student_conv_speed:<25.8f}")
print(f"{'3. Student (Distilled)':<25} | {best_student_distill_acc:<15.2f} | {student_params:<15,} | {student_distill_speed:<25.8f}")
print("="*80)

# ✅ Success Criterion Check
if best_student_distill_acc < best_student_conv_acc:
    print(f"\n✅ Success: Distilled Student ({best_student_distill_acc:.2f}%) < Conventional ({best_student_conv_acc:.2f}%)")
else:
    print(f"\n❌ Distilled Student ({best_student_distill_acc:.2f}%) did not outperform Conventional ({best_student_conv_acc:.2f}%)")
    print("Try tuning Alpha, Temperature, or learning rate.")


Model                     | Accuracy (%)    | Parameters      | Avg. Inference (s/img)   
--------------------------------------------------------------------------------
1. Teacher Model          | 84.66           | 7,770,250       | 0.00007840               
2. Student (Conventional) | 87.65           | 3,276,970       | 0.00006650               
3. Student (Distilled)    | 85.16           | 3,276,970       | 0.00006391               

✅ Success: Distilled Student (85.16%) < Conventional (87.65%)
