In [3]:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.models import VGG16_Weights
from tqdm import tqdm
import os


def test(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(), transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)


teacher_save_path = './best_teacher_vgg16_cifar100.pth'
teacher_model = torchvision.models.vgg16(weights=None, num_classes=100).to(device)

if not os.path.exists(teacher_save_path):
    print("FATAL ERROR: Teacher model 'best_teacher_vgg16_cifar100.pth' not found. Please run the fine-tuning script first.")
    exit()

print(f"--- Loading existing teacher from {teacher_save_path} ---")
teacher_model.load_state_dict(torch.load(teacher_save_path))
teacher_acc = test(teacher_model, testloader, device)
print(f"Using teacher with accuracy: {teacher_acc:.2f}%")
teacher_model.eval()


def get_student_model():
    return torchvision.models.vgg11(weights=None, num_classes=100).to(device)


class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes=100, smoothing=0.1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
    def forward(self, pred, target):
        pred = pred.log_softmax(dim=-1)
        with torch.no_grad():
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (pred.size(1) - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=-1))


def dkd_loss(student_logits, teacher_logits, labels, alpha, beta, temperature):
    gt_mask = torch.zeros_like(student_logits).scatter_(1, labels.unsqueeze(1), 1).bool()
    other_mask = ~gt_mask
    pred_student_gt = F.log_softmax(student_logits / temperature, dim=1)
    pred_teacher_gt = F.softmax(teacher_logits / temperature, dim=1)
    tckd = (F.kl_div(pred_student_gt, pred_teacher_gt, reduction='none') * (temperature**2)).masked_select(gt_mask).mean()
    pred_student_other = F.log_softmax(student_logits - 1000 * gt_mask, dim=1)
    pred_teacher_other = F.softmax(teacher_logits - 1000 * gt_mask, dim=1)
    nckd = (F.kl_div(pred_student_other, pred_teacher_other, reduction='none') * (temperature**2)).masked_select(other_mask).mean()
    return alpha * tckd + beta * nckd


def train_student(model, train_loader, optimizer, epoch, loss_type, teacher_model, device):
    model.train()
    progress = tqdm(train_loader, desc=f"Student Epoch {epoch+1} [{loss_type.upper()}]", leave=False)

    if loss_type == 'ls':
        criterion = LabelSmoothingLoss()
    else:
        criterion = nn.CrossEntropyLoss()

    for inputs, labels in progress:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)

        if loss_type == 'dkd':
            with torch.no_grad():
                teacher_outputs = teacher_model(inputs)
            ce_loss = criterion(outputs, labels)
            distill_loss = dkd_loss(outputs, teacher_outputs, labels, alpha=1.0, beta=8.0, temperature=4.0)
            loss = ce_loss + distill_loss
        elif loss_type == 'ls':
            loss = criterion(outputs, labels)
        else: # 'ce' for baseline
            loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()
        progress.set_postfix(loss=f"{loss.item():.4f}")


def run_experiment(loss_type, teacher_model, device, num_epochs=200):
    print(f"\n--- Running Experiment: {loss_type.upper()} ---")
    student_model = get_student_model()
    student_save_path = f'./best_student_{loss_type}.pth'
    
    optimizer = optim.SGD(student_model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    best_acc = 0.0
    for epoch in range(num_epochs):
        train_student(student_model, trainloader, optimizer, epoch, loss_type, teacher_model, device)
        acc = test(student_model, testloader, device)
        
        
        print(f"Epoch {epoch+1}/{num_epochs}, Accuracy: {acc:.2f}%")
        
        if acc > best_acc:
            best_acc = acc
            print(f"New best accuracy: {best_acc:.2f}%. Saving student model to {student_save_path}")
            torch.save(student_model.state_dict(), student_save_path)
            
        scheduler.step()
        
    print(f"Finished Training for {loss_type.upper()}. Best Accuracy: {best_acc:.2f}%")
    return best_acc


results = {}
STUDENT_EPOCHS = 40

results['ls'] = run_experiment('ls', teacher_model, device, STUDENT_EPOCHS)
results['dkd'] = run_experiment('dkd', teacher_model, device, STUDENT_EPOCHS)

try:
   
    lm_student = get_student_model()
    lm_student.load_state_dict(torch.load('./best_student_lm.pth'))
    lm_accuracy = test(lm_student, testloader, device)
except FileNotFoundError:
    lm_accuracy = "Not Found" 

print("\n\n--- Final Results Summary for Task 1 ---")
if lm_accuracy != "Not Found":
    print(f"Logit Matching (LM) Accuracy:   {lm_accuracy:.2f}%")
else:
    print("Logit Matching (LM) Accuracy:   Could not find saved model 'best_student_lm.pth'")

print(f"Baseline Student (CE) Accuracy: {results.get('baseline', 0):.2f}%")
print(f"Label Smoothing (LS) Accuracy:  {results.get('ls', 0):.2f}%")
print(f"Decoupled KD (DKD) Accuracy:    {results.get('dkd', 0):.2f}%")

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified
--- Loading existing teacher from ./best_teacher_vgg16_cifar100.pth ---


  teacher_model.load_state_dict(torch.load(teacher_save_path))


Using teacher with accuracy: 69.01%

--- Running Experiment: LS ---


                                                                                    

Epoch 1/40, Accuracy: 2.05%
New best accuracy: 2.05%. Saving student model to ./best_student_ls.pth


                                                                                    

Epoch 2/40, Accuracy: 6.36%
New best accuracy: 6.36%. Saving student model to ./best_student_ls.pth


                                                                                    

Epoch 3/40, Accuracy: 10.02%
New best accuracy: 10.02%. Saving student model to ./best_student_ls.pth


                                                                                    

Epoch 4/40, Accuracy: 14.77%
New best accuracy: 14.77%. Saving student model to ./best_student_ls.pth


                                                                                    

Epoch 5/40, Accuracy: 17.91%
New best accuracy: 17.91%. Saving student model to ./best_student_ls.pth


                                                                                    

Epoch 6/40, Accuracy: 22.77%
New best accuracy: 22.77%. Saving student model to ./best_student_ls.pth


                                                                                    

Epoch 7/40, Accuracy: 25.58%
New best accuracy: 25.58%. Saving student model to ./best_student_ls.pth


                                                                                    

Epoch 8/40, Accuracy: 28.70%
New best accuracy: 28.70%. Saving student model to ./best_student_ls.pth


                                                                                    

Epoch 9/40, Accuracy: 31.38%
New best accuracy: 31.38%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 10/40, Accuracy: 34.85%
New best accuracy: 34.85%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 11/40, Accuracy: 38.40%
New best accuracy: 38.40%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 12/40, Accuracy: 38.85%
New best accuracy: 38.85%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 13/40, Accuracy: 42.06%
New best accuracy: 42.06%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 14/40, Accuracy: 45.01%
New best accuracy: 45.01%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 15/40, Accuracy: 45.32%
New best accuracy: 45.32%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 16/40, Accuracy: 46.97%
New best accuracy: 46.97%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 17/40, Accuracy: 49.46%
New best accuracy: 49.46%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 18/40, Accuracy: 50.53%
New best accuracy: 50.53%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 19/40, Accuracy: 52.37%
New best accuracy: 52.37%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 20/40, Accuracy: 53.68%
New best accuracy: 53.68%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 21/40, Accuracy: 54.18%
New best accuracy: 54.18%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 22/40, Accuracy: 55.44%
New best accuracy: 55.44%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 23/40, Accuracy: 56.11%
New best accuracy: 56.11%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 24/40, Accuracy: 56.07%


                                                                                     

Epoch 25/40, Accuracy: 58.96%
New best accuracy: 58.96%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 26/40, Accuracy: 59.53%
New best accuracy: 59.53%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 27/40, Accuracy: 60.36%
New best accuracy: 60.36%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 28/40, Accuracy: 61.31%
New best accuracy: 61.31%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 29/40, Accuracy: 61.86%
New best accuracy: 61.86%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 30/40, Accuracy: 62.42%
New best accuracy: 62.42%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 31/40, Accuracy: 62.45%
New best accuracy: 62.45%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 32/40, Accuracy: 63.10%
New best accuracy: 63.10%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 33/40, Accuracy: 63.76%
New best accuracy: 63.76%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 34/40, Accuracy: 64.04%
New best accuracy: 64.04%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 35/40, Accuracy: 63.97%


                                                                                     

Epoch 36/40, Accuracy: 64.56%
New best accuracy: 64.56%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 37/40, Accuracy: 64.68%
New best accuracy: 64.68%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 38/40, Accuracy: 65.04%
New best accuracy: 65.04%. Saving student model to ./best_student_ls.pth


                                                                                     

Epoch 39/40, Accuracy: 65.04%


                                                                                     

Epoch 40/40, Accuracy: 64.84%
Finished Training for LS. Best Accuracy: 65.04%

--- Running Experiment: DKD ---


                                                                                      

Epoch 1/40, Accuracy: 1.00%
New best accuracy: 1.00%. Saving student model to ./best_student_dkd.pth


                                                                                                

Epoch 2/40, Accuracy: 1.00%


                                                                                  

Epoch 3/40, Accuracy: 1.00%


                                                                                 

KeyboardInterrupt: 

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import os
torch.cuda.empty_cache()



def test(model, test_loader, device):
   
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)



teacher_save_path = './best_teacher_vgg16_cifar100.pth'
teacher_model = torchvision.models.vgg16(weights=None, num_classes=100).to(device)

if not os.path.exists(teacher_save_path):
    print(f"FATAL ERROR: Teacher model '{teacher_save_path}' not found. Please ensure the pre-trained teacher model is available.")
    exit()

print(f"--- Loading existing teacher from {teacher_save_path} ---")
teacher_model.load_state_dict(torch.load(teacher_save_path))
teacher_acc = test(teacher_model, testloader, device)
print(f"Using teacher with accuracy: {teacher_acc:.2f}%")
teacher_model.eval()



def get_student_model():
   
    return torchvision.models.vgg11(weights=None, num_classes=100).to(device)

def dkd_loss(student_logits, teacher_logits, labels, temperature):
 
    gt_mask = torch.zeros_like(student_logits).scatter_(1, labels.unsqueeze(1), 1).bool()
    other_mask = ~gt_mask

    pred_student_gt = F.log_softmax(student_logits / temperature, dim=1)
    pred_teacher_gt = F.softmax(teacher_logits / temperature, dim=1)
    tckd = (F.kl_div(pred_student_gt, pred_teacher_gt, reduction='none') * (temperature**2)).masked_select(gt_mask).mean()


    student_logits_other = student_logits.masked_fill(gt_mask, -float('inf'))
    teacher_logits_other = teacher_logits.masked_fill(gt_mask, -float('inf'))
    pred_student_other = F.log_softmax(student_logits_other / temperature, dim=1)
    pred_teacher_other = F.softmax(teacher_logits_other / temperature, dim=1)
    nckd = (F.kl_div(pred_student_other, pred_teacher_other, reduction='none') * (temperature**2)).masked_select(other_mask).mean()

    return tckd, nckd


def train_student_dkd(model, train_loader, optimizer, epoch, teacher_model, device, alpha, beta, temperature):
  
    model.train()
    progress = tqdm(train_loader, desc=f"Student Epoch {epoch+1} [DKD]", leave=False)
    criterion = nn.CrossEntropyLoss()

    for inputs, labels in progress:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        
        outputs = model(inputs)
        with torch.no_grad():
            teacher_outputs = teacher_model(inputs)
            
        ce_loss = criterion(outputs, labels)
        
        tckd, nckd = dkd_loss(outputs, teacher_outputs, labels, temperature)
        
        loss = ce_loss + (alpha * tckd) + (beta * nckd)
        
        loss.backward()
        optimizer.step()
        progress.set_postfix(loss=f"{loss.item():.4f}")



def run_dkd_experiment(teacher_model, device, num_epochs=40):
    
    print(f"\n--- Running Experiment: DECOUPLED KD (DKD) ---")
    student_model = get_student_model()
    student_save_path = './best_student_dkd.pth'
    
    optimizer = optim.SGD(student_model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
   
    alpha = 1.0
    beta = 8.0
    temperature = 4.0

    best_acc = 0.0
    for epoch in range(num_epochs):
        train_student_dkd(student_model, trainloader, optimizer, epoch, teacher_model, device, alpha, beta, temperature)
        acc = test(student_model, testloader, device)
        
        print(f"Epoch {epoch+1}/{num_epochs}, Accuracy: {acc:.2f}%")
        
        if acc > best_acc:
            best_acc = acc
            print(f"New best accuracy: {best_acc:.2f}%. Saving student model to {student_save_path}")
            torch.save(student_model.state_dict(), student_save_path)
            
        scheduler.step()
        
    print(f"\nFinished Training for DKD. Best Accuracy: {best_acc:.2f}%")
    print(f"Best DKD student model saved to {student_save_path}")
    return best_acc


STUDENT_EPOCHS = 40


dkd_accuracy = run_dkd_experiment(teacher_model, device, STUDENT_EPOCHS)

print("\n\n--- Final Result for DKD ---")
print(f"Decoupled KD (DKD) Best Accuracy: {dkd_accuracy:.2f}%")

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified
--- Loading existing teacher from ./best_teacher_vgg16_cifar100.pth ---


  teacher_model.load_state_dict(torch.load(teacher_save_path))


Using teacher with accuracy: 69.01%

--- Running Experiment: DECOUPLED KD (DKD) ---


                                                                                      

Epoch 1/40, Accuracy: 3.62%
New best accuracy: 3.62%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 2/40, Accuracy: 5.98%
New best accuracy: 5.98%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 3/40, Accuracy: 10.11%
New best accuracy: 10.11%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 4/40, Accuracy: 15.20%
New best accuracy: 15.20%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 5/40, Accuracy: 19.30%
New best accuracy: 19.30%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 6/40, Accuracy: 25.35%
New best accuracy: 25.35%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 7/40, Accuracy: 27.39%
New best accuracy: 27.39%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 8/40, Accuracy: 30.60%
New best accuracy: 30.60%. Saving student model to ./best_student_dkd.pth


                                                                                     

Epoch 9/40, Accuracy: 31.39%
New best accuracy: 31.39%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 10/40, Accuracy: 36.47%
New best accuracy: 36.47%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 11/40, Accuracy: 37.41%
New best accuracy: 37.41%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 12/40, Accuracy: 38.94%
New best accuracy: 38.94%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 13/40, Accuracy: 40.84%
New best accuracy: 40.84%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 14/40, Accuracy: 43.65%
New best accuracy: 43.65%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 15/40, Accuracy: 45.20%
New best accuracy: 45.20%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 16/40, Accuracy: 45.24%
New best accuracy: 45.24%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 17/40, Accuracy: 48.74%
New best accuracy: 48.74%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 18/40, Accuracy: 50.49%
New best accuracy: 50.49%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 19/40, Accuracy: 49.96%


                                                                                      

Epoch 20/40, Accuracy: 50.32%


                                                                                      

Epoch 21/40, Accuracy: 53.59%
New best accuracy: 53.59%. Saving student model to ./best_student_dkd.pth


                                                                                      

Epoch 22/40, Accuracy: 54.25%
New best accuracy: 54.25%. Saving student model to ./best_student_dkd.pth


                                                                                       

Epoch 23/40, Accuracy: 54.54%
New best accuracy: 54.54%. Saving student model to ./best_student_dkd.pth


                                                                                       

Epoch 24/40, Accuracy: 55.03%
New best accuracy: 55.03%. Saving student model to ./best_student_dkd.pth


                                                                                       

Epoch 25/40, Accuracy: 55.47%
New best accuracy: 55.47%. Saving student model to ./best_student_dkd.pth


                                                                                       

Epoch 26/40, Accuracy: 56.47%
New best accuracy: 56.47%. Saving student model to ./best_student_dkd.pth


                                                                                       

Epoch 27/40, Accuracy: 56.62%
New best accuracy: 56.62%. Saving student model to ./best_student_dkd.pth


                                                                                       

Epoch 28/40, Accuracy: 57.77%
New best accuracy: 57.77%. Saving student model to ./best_student_dkd.pth


                                                                                       

Epoch 29/40, Accuracy: 58.34%
New best accuracy: 58.34%. Saving student model to ./best_student_dkd.pth


                                                                                       

Epoch 30/40, Accuracy: 58.65%
New best accuracy: 58.65%. Saving student model to ./best_student_dkd.pth


                                                                                       

Epoch 31/40, Accuracy: 59.35%
New best accuracy: 59.35%. Saving student model to ./best_student_dkd.pth


                                                                                       

Epoch 32/40, Accuracy: 60.03%
New best accuracy: 60.03%. Saving student model to ./best_student_dkd.pth


                                                                                       

Epoch 33/40, Accuracy: 59.73%


                                                                                       

Epoch 34/40, Accuracy: 60.64%
New best accuracy: 60.64%. Saving student model to ./best_student_dkd.pth


                                                                                       

Epoch 35/40, Accuracy: 61.16%
New best accuracy: 61.16%. Saving student model to ./best_student_dkd.pth


                                                                                       

Epoch 36/40, Accuracy: 61.04%


                                                                                       

Epoch 37/40, Accuracy: 61.00%


                                                                                       

Epoch 38/40, Accuracy: 61.44%
New best accuracy: 61.44%. Saving student model to ./best_student_dkd.pth


                                                                                       

Epoch 39/40, Accuracy: 61.52%
New best accuracy: 61.52%. Saving student model to ./best_student_dkd.pth


                                                                                       

Epoch 40/40, Accuracy: 61.55%
New best accuracy: 61.55%. Saving student model to ./best_student_dkd.pth

Finished Training for DKD. Best Accuracy: 61.55%
Best DKD student model saved to ./best_student_dkd.pth


--- Final Result for DKD ---
Decoupled KD (DKD) Best Accuracy: 61.55%


In [9]:
torch.cuda.empty_cache()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import os
torch.cuda.empty_cache()


def test(model, test_loader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

# --- DATA SETUP ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)


def get_student_model():
    return torchvision.models.vgg11(weights=None, num_classes=100).to(device)

def train_independent_student(model, train_loader, optimizer, epoch, device):
    model.train()
    criterion = nn.CrossEntropyLoss()
    progress = tqdm(train_loader, desc=f"Student Epoch {epoch+1} [Independent]", leave=False)
    for inputs, labels in progress:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        progress.set_postfix(loss=f"{loss.item():.4f}")


def run_si_experiment(num_epochs=200):
    print("\n--- Running Experiment: INDEPENDENT STUDENT (SI) ---")
    student_model = get_student_model()
    student_save_path = './best_student_si.pth'
    
   
    optimizer = optim.SGD(student_model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    best_acc = 0.0
    for epoch in range(num_epochs):
        train_independent_student(student_model, trainloader, optimizer, epoch, device)
        acc = test(student_model, testloader, device)
        print(f"Epoch {epoch+1}/{num_epochs}, Accuracy: {acc:.2f}%")
        
        if acc > best_acc:
            best_acc = acc
            print(f"New best accuracy: {best_acc:.2f}%. Saving model to {student_save_path}")
            torch.save(student_model.state_dict(), student_save_path)
            
        scheduler.step()
        
    print(f"\nFinished Training for SI. Best Accuracy: {best_acc:.2f}%")
    return best_acc


STUDENT_EPOCHS = 200 
si_accuracy = run_si_experiment(STUDENT_EPOCHS)
print(f"\n--- Final Result for Independent Student ---")
print(f"Independent Student (SI) Best Accuracy: {si_accuracy:.2f}%")

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified

--- Running Experiment: INDEPENDENT STUDENT (SI) ---


                                                                                             

Epoch 1/200, Accuracy: 2.40%
New best accuracy: 2.40%. Saving model to ./best_student_si.pth


                                                                                             

Epoch 2/200, Accuracy: 7.41%
New best accuracy: 7.41%. Saving model to ./best_student_si.pth


                                                                                             

Epoch 3/200, Accuracy: 10.58%
New best accuracy: 10.58%. Saving model to ./best_student_si.pth


                                                                                             

Epoch 4/200, Accuracy: 14.98%
New best accuracy: 14.98%. Saving model to ./best_student_si.pth


                                                                                             

Epoch 5/200, Accuracy: 18.51%
New best accuracy: 18.51%. Saving model to ./best_student_si.pth


                                                                                             

Epoch 6/200, Accuracy: 21.76%
New best accuracy: 21.76%. Saving model to ./best_student_si.pth


                                                                                             

Epoch 7/200, Accuracy: 24.59%
New best accuracy: 24.59%. Saving model to ./best_student_si.pth


                                                                                             

Epoch 8/200, Accuracy: 29.03%
New best accuracy: 29.03%. Saving model to ./best_student_si.pth


                                                                                             

Epoch 9/200, Accuracy: 30.05%
New best accuracy: 30.05%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 10/200, Accuracy: 32.64%
New best accuracy: 32.64%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 11/200, Accuracy: 33.26%
New best accuracy: 33.26%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 12/200, Accuracy: 36.07%
New best accuracy: 36.07%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 13/200, Accuracy: 36.81%
New best accuracy: 36.81%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 14/200, Accuracy: 38.11%
New best accuracy: 38.11%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 15/200, Accuracy: 37.81%


                                                                                              

Epoch 16/200, Accuracy: 39.85%
New best accuracy: 39.85%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 17/200, Accuracy: 41.80%
New best accuracy: 41.80%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 18/200, Accuracy: 41.73%


                                                                                              

Epoch 19/200, Accuracy: 43.54%
New best accuracy: 43.54%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 20/200, Accuracy: 41.26%


                                                                                              

Epoch 21/200, Accuracy: 43.35%


                                                                                              

Epoch 22/200, Accuracy: 45.00%
New best accuracy: 45.00%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 23/200, Accuracy: 45.45%
New best accuracy: 45.45%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 24/200, Accuracy: 45.40%


                                                                                              

Epoch 25/200, Accuracy: 44.39%


                                                                                              

Epoch 26/200, Accuracy: 45.16%


                                                                                              

Epoch 27/200, Accuracy: 44.06%


                                                                                              

Epoch 28/200, Accuracy: 45.65%
New best accuracy: 45.65%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 29/200, Accuracy: 46.32%
New best accuracy: 46.32%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 30/200, Accuracy: 47.10%
New best accuracy: 47.10%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 31/200, Accuracy: 48.33%
New best accuracy: 48.33%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 32/200, Accuracy: 48.20%


                                                                                              

Epoch 33/200, Accuracy: 47.17%


                                                                                              

Epoch 34/200, Accuracy: 46.34%


                                                                                              

Epoch 35/200, Accuracy: 47.20%


                                                                                              

Epoch 36/200, Accuracy: 49.04%
New best accuracy: 49.04%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 37/200, Accuracy: 47.14%


                                                                                              

Epoch 38/200, Accuracy: 48.97%


                                                                                              

Epoch 39/200, Accuracy: 49.64%
New best accuracy: 49.64%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 40/200, Accuracy: 47.12%


                                                                                              

Epoch 41/200, Accuracy: 49.77%
New best accuracy: 49.77%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 42/200, Accuracy: 49.69%


                                                                                              

Epoch 43/200, Accuracy: 50.13%
New best accuracy: 50.13%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 44/200, Accuracy: 49.74%


                                                                                              

Epoch 45/200, Accuracy: 49.45%


                                                                                              

Epoch 46/200, Accuracy: 50.94%
New best accuracy: 50.94%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 47/200, Accuracy: 50.19%


                                                                                              

Epoch 48/200, Accuracy: 52.24%
New best accuracy: 52.24%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 49/200, Accuracy: 50.49%


                                                                                              

Epoch 50/200, Accuracy: 51.76%


                                                                                              

Epoch 51/200, Accuracy: 51.64%


                                                                                              

Epoch 52/200, Accuracy: 51.37%


                                                                                              

Epoch 53/200, Accuracy: 52.64%
New best accuracy: 52.64%. Saving model to ./best_student_si.pth


                                                                                              

Epoch 54/200, Accuracy: 51.19%


                                                                                              

Epoch 55/200, Accuracy: 52.37%


                                                                                              

Epoch 56/200, Accuracy: 52.67%
New best accuracy: 52.67%. Saving model to ./best_student_si.pth


Student Epoch 57 [Independent]:  91%|█████████ | 355/391 [00:27<00:02, 14.53it/s, loss=1.6422]

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import os
torch.cuda.empty_cache()

# --- UTILITY FUNCTIONS & DATA SETUP ---
def test(model, test_loader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs, _ = model(inputs) # Model will now return features
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)


# --- WRAPPER MODELS TO EXTRACT FEATURES ---
class VGG11_Hint(nn.Module):
    def __init__(self):
        super(VGG11_Hint, self).__init__()
        base_model = torchvision.models.vgg11(weights=None, num_classes=100)
        # We define hint_layer using a specific index from the features module
        self.features = base_model.features
        self.hint_layer_index = 15 # After 4th conv block
        self.avgpool = base_model.avgpool
        self.classifier = base_model.classifier

    def forward(self, x):
        hint_features = self.features[:self.hint_layer_index + 1](x)
        out = self.features[self.hint_layer_index + 1:](hint_features)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        logits = self.classifier(out)
        return logits, hint_features

# The VGG16_Guided wrapper is no longer needed with the new loading strategy
# but we define which layer we want to extract features from
TEACHER_GUIDED_LAYER_INDEX = 23 # After 4th conv block in VGG16


# --- LOAD TEACHER (CORRECTED METHOD) ---
teacher_save_path = './best_teacher_vgg16_cifar100.pth'
if not os.path.exists(teacher_save_path):
    print("FATAL ERROR: Teacher model not found.")
    exit()

# Step 1: Create a standard VGG16 model, which has the correct architecture
teacher_model = torchvision.models.vgg16(weights=None, num_classes=100).to(device)

# Step 2: Load the state_dict into this correctly structured model
teacher_model.load_state_dict(torch.load(teacher_save_path))
teacher_model.eval()
print("Teacher model loaded successfully.")

# --- LOSS FUNCTIONS ---
def loss_fn_kd(outputs, labels, teacher_outputs, alpha, temperature):
    hard_loss = F.cross_entropy(outputs, labels)
    soft_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(outputs / temperature, dim=1),
                                                   F.softmax(teacher_outputs / temperature, dim=1)) * (temperature * temperature)
    return alpha * hard_loss + (1 - alpha) * soft_loss

def hint_loss_function(student_hint, teacher_guided, regressor):
    student_regr = regressor(student_hint)
    return F.mse_loss(student_regr, teacher_guided)

# --- TRAINING LOOP ---
def train_student_hints(student, regressor, train_loader, optimizer, epoch, teacher, device, hint_lambda):
    student.train()
    regressor.train()
    progress = tqdm(train_loader, desc=f"Student Epoch {epoch+1} [Hints]", leave=False)

    # Hook to capture teacher's intermediate features
    teacher_features = {}
    def get_teacher_features(name):
        def hook(model, input, output):
            teacher_features[name] = output.detach()
        return hook
    
    teacher.features[TEACHER_GUIDED_LAYER_INDEX].register_forward_hook(get_teacher_features('guided'))

    for inputs, labels in progress:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        
        # Student forward pass
        student_logits, student_hint_features = student(inputs)
        
        # Teacher forward pass (will trigger the hook)
        with torch.no_grad():
            teacher_logits = teacher(inputs)
            teacher_guided_features = teacher_features['guided']

        # Standard Knowledge Distillation Loss
        kd_loss = loss_fn_kd(student_logits, labels, teacher_logits, alpha=0.5, temperature=4.0)
        # Hint Loss
        hint_loss = hint_loss_function(student_hint_features, teacher_guided_features, regressor)
        
        # Total Loss
        loss = kd_loss + (hint_lambda * hint_loss)
        
        loss.backward()
        optimizer.step()
        progress.set_postfix(loss=f"{loss.item():.4f}")

# --- EXPERIMENT RUNNER ---
def run_hints_experiment(teacher_model, num_epochs=200):
    print("\n--- Running Experiment: HINT-BASED DISTILLATION (HINTS) ---")
    student_model = VGG11_Hint().to(device)
    student_save_path = './best_student_hints.pth'

    # Student VGG11 hint (after block 4) is [B, 512, 4, 4]
    # Teacher VGG16 guided (after block 4) is [B, 512, 4, 4]
    regressor = nn.Conv2d(512, 512, kernel_size=1).to(device)

    # --- PROACTIVE FIX: REDUCED LEARNING RATE ---
    # The combination of a high LR and a large hint_lambda can cause instability.
    # Reducing the LR to 0.01 is a much safer starting point.
    optimizer = optim.SGD(
        list(student_model.parameters()) + list(regressor.parameters()),
        lr=0.01, momentum=0.9, weight_decay=5e-4
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    best_acc = 0.0
    HINT_LAMBDA = 100.0 # This weight for the hint loss often needs tuning

    for epoch in range(num_epochs):
        train_student_hints(student_model, regressor, trainloader, optimizer, epoch, teacher_model, device, HINT_LAMBDA)
        acc = test(student_model, testloader, device)
        print(f"Epoch {epoch+1}/{num_epochs}, Accuracy: {acc:.2f}%")
        
        if acc > best_acc:
            best_acc = acc
            print(f"New best accuracy: {best_acc:.2f}%. Saving model to {student_save_path}")
            # Save only the student's weights, not the wrapper
            torch.save(student_model.state_dict(), student_save_path)
            
        scheduler.step()
        
    print(f"\nFinished Training for Hints. Best Accuracy: {best_acc:.2f}%")
    return best_acc


STUDENT_EPOCHS = 40 
hints_accuracy = run_hints_experiment(teacher_model, STUDENT_EPOCHS)
print(f"\n--- Final Result for Hint-based Distillation ---")
print(f"Hints Best Accuracy: {hints_accuracy:.2f}%")

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified


  teacher_model.load_state_dict(torch.load(teacher_save_path))


Teacher model loaded successfully.

--- Running Experiment: HINT-BASED DISTILLATION (HINTS) ---


                                                                                        

Epoch 1/40, Accuracy: 13.37%
New best accuracy: 13.37%. Saving model to ./best_student_hints.pth


                                                                                        

Epoch 2/40, Accuracy: 25.49%
New best accuracy: 25.49%. Saving model to ./best_student_hints.pth


                                                                                        

Epoch 3/40, Accuracy: 39.40%
New best accuracy: 39.40%. Saving model to ./best_student_hints.pth


                                                                                        

Epoch 4/40, Accuracy: 45.10%
New best accuracy: 45.10%. Saving model to ./best_student_hints.pth


                                                                                        

Epoch 5/40, Accuracy: 49.00%
New best accuracy: 49.00%. Saving model to ./best_student_hints.pth


                                                                                        

Epoch 6/40, Accuracy: 52.99%
New best accuracy: 52.99%. Saving model to ./best_student_hints.pth


                                                                                        

Epoch 7/40, Accuracy: 55.56%
New best accuracy: 55.56%. Saving model to ./best_student_hints.pth


                                                                                        

Epoch 8/40, Accuracy: 57.09%
New best accuracy: 57.09%. Saving model to ./best_student_hints.pth


                                                                                        

Epoch 9/40, Accuracy: 57.89%
New best accuracy: 57.89%. Saving model to ./best_student_hints.pth


                                                                                         

Epoch 10/40, Accuracy: 58.68%
New best accuracy: 58.68%. Saving model to ./best_student_hints.pth


                                                                                         

Epoch 11/40, Accuracy: 60.76%
New best accuracy: 60.76%. Saving model to ./best_student_hints.pth


                                                                                         

Epoch 12/40, Accuracy: 61.49%
New best accuracy: 61.49%. Saving model to ./best_student_hints.pth


                                                                                         

Epoch 13/40, Accuracy: 63.00%
New best accuracy: 63.00%. Saving model to ./best_student_hints.pth


                                                                                         

Epoch 14/40, Accuracy: 62.84%


                                                                                         

Epoch 15/40, Accuracy: 63.34%
New best accuracy: 63.34%. Saving model to ./best_student_hints.pth


                                                                                         

Epoch 16/40, Accuracy: 63.96%
New best accuracy: 63.96%. Saving model to ./best_student_hints.pth


                                                                                         

Epoch 17/40, Accuracy: 64.68%
New best accuracy: 64.68%. Saving model to ./best_student_hints.pth


                                                                                         

Epoch 18/40, Accuracy: 64.83%
New best accuracy: 64.83%. Saving model to ./best_student_hints.pth


                                                                                         

Epoch 19/40, Accuracy: 65.38%
New best accuracy: 65.38%. Saving model to ./best_student_hints.pth


                                                                                         

Epoch 20/40, Accuracy: 65.97%
New best accuracy: 65.97%. Saving model to ./best_student_hints.pth


                                                                                         

Epoch 21/40, Accuracy: 66.40%
New best accuracy: 66.40%. Saving model to ./best_student_hints.pth


                                                                                         

Epoch 22/40, Accuracy: 66.29%


                                                                                         

Epoch 23/40, Accuracy: 66.52%
New best accuracy: 66.52%. Saving model to ./best_student_hints.pth


                                                                                         

Epoch 24/40, Accuracy: 67.11%
New best accuracy: 67.11%. Saving model to ./best_student_hints.pth


                                                                                         

Epoch 25/40, Accuracy: 67.09%


                                                                                         

Epoch 26/40, Accuracy: 67.15%
New best accuracy: 67.15%. Saving model to ./best_student_hints.pth


                                                                                         

Epoch 27/40, Accuracy: 67.57%
New best accuracy: 67.57%. Saving model to ./best_student_hints.pth


                                                                                         

Epoch 28/40, Accuracy: 68.03%
New best accuracy: 68.03%. Saving model to ./best_student_hints.pth


                                                                                        

Epoch 29/40, Accuracy: 67.76%


                                                                                        

Epoch 30/40, Accuracy: 68.15%
New best accuracy: 68.15%. Saving model to ./best_student_hints.pth


                                                                                        

Epoch 31/40, Accuracy: 68.15%


                                                                                        

Epoch 32/40, Accuracy: 68.33%
New best accuracy: 68.33%. Saving model to ./best_student_hints.pth


                                                                                        

Epoch 33/40, Accuracy: 68.56%
New best accuracy: 68.56%. Saving model to ./best_student_hints.pth


                                                                                        

Epoch 34/40, Accuracy: 68.54%


                                                                                        

Epoch 35/40, Accuracy: 68.69%
New best accuracy: 68.69%. Saving model to ./best_student_hints.pth


                                                                                        

Epoch 36/40, Accuracy: 68.92%
New best accuracy: 68.92%. Saving model to ./best_student_hints.pth


                                                                                        

Epoch 37/40, Accuracy: 69.01%
New best accuracy: 69.01%. Saving model to ./best_student_hints.pth


                                                                                        

Epoch 38/40, Accuracy: 68.79%


                                                                                        

Epoch 39/40, Accuracy: 68.86%


                                                                                        

Epoch 40/40, Accuracy: 68.85%

Finished Training for Hints. Best Accuracy: 69.01%

--- Final Result for Hint-based Distillation ---
Hints Best Accuracy: 69.01%


In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import os

torch.cuda.empty_cache()

# --- UTILITY FUNCTIONS & DATA SETUP (No changes) ---
def test(model, test_loader, device):
    student_wrapper = model
    student_wrapper.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            _, outputs = student_wrapper.student(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(), transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)


# --- THE FIX IS IN THIS MODULE ---
class VGG_Extractor(nn.Module):
    def __init__(self, vgg_model):
        super(VGG_Extractor, self).__init__()
        self.features = vgg_model.features
        self.avgpool = vgg_model.avgpool
        self.classifier = vgg_model.classifier
        self.feat_dim = 512 # The true feature dimension before the classifier

    def forward(self, x):
        # Extract features from the convolutional backbone
        pre_pool_feat = self.features(x)
        
        # This is the feature vector we want for CRD.
        # Its shape is (B, 512, 1, 1) for CIFAR, so flattening gives (B, 512).
        crd_feat = torch.flatten(pre_pool_feat, 1)
        
        # Now, continue the original forward pass to get the logits,
        # even if the architecture is flawed (as it is for the teacher).
        post_pool_feat = self.avgpool(pre_pool_feat)
        flat_feat_for_classifier = torch.flatten(post_pool_feat, 1)
        logits = self.classifier(flat_feat_for_classifier)
        
        # Return the CORRECT feature vector and the final logits
        return crd_feat, logits

class ProjectionHead(nn.Module):
    def __init__(self, in_dim, out_dim=128):
        super(ProjectionHead, self).__init__()
        self.head = nn.Sequential(
            nn.Linear(in_dim, in_dim), nn.ReLU(inplace=True), nn.Linear(in_dim, out_dim)
        )
    def forward(self, x):
        return F.normalize(self.head(x), dim=1)

class CrdStudentWrapper(nn.Module):
    def __init__(self):
        super(CrdStudentWrapper, self).__init__()
        base_student = torchvision.models.vgg11(weights=None, num_classes=100)
        base_student.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        base_student.classifier[0] = nn.Linear(512, 4096)
        self.student = VGG_Extractor(base_student)
        self.projector = ProjectionHead(in_dim=self.student.feat_dim)

    def forward(self, x):
        feat, logits = self.student(x)
        proj_feat = self.projector(feat)
        return proj_feat, logits

def contrastive_loss(student_proj, teacher_proj, temperature=0.07):
    batch_size = student_proj.shape[0]
    sim_s_t = torch.matmul(student_proj, teacher_proj.T) / temperature
    sim_t_s = torch.matmul(teacher_proj, student_proj.T) / temperature
    labels = torch.arange(batch_size).long().to(student_proj.device)
    loss = F.cross_entropy(sim_s_t, labels) + F.cross_entropy(sim_t_s, labels)
    return loss

# --- LOAD TEACHER ---
teacher_save_path = './best_teacher_vgg16_cifar100.pth'
if not os.path.exists(teacher_save_path):
    print("FATAL ERROR: Teacher model not found.")
    exit()
base_teacher_model = torchvision.models.vgg16(weights=None, num_classes=100)
base_teacher_model.load_state_dict(torch.load(teacher_save_path))
teacher_model = VGG_Extractor(base_teacher_model).to(device)
teacher_projector = ProjectionHead(in_dim=teacher_model.feat_dim).to(device)
teacher_model.eval()
teacher_projector.eval()
print("Teacher model and projector loaded successfully.")

# --- LOSS & TRAINING DEFINITIONS ---
def loss_fn_kd(outputs, labels, teacher_outputs, alpha, temperature):
    hard_loss = F.cross_entropy(outputs, labels)
    soft_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(outputs / temperature, dim=1),
                                                   F.softmax(teacher_outputs / temperature, dim=1)) * (temperature * temperature)
    return alpha * hard_loss + (1 - alpha) * soft_loss

def train_student_crd(student_wrapper, train_loader, optimizer, epoch, teacher, teacher_proj, device, crd_lambda):
    student_wrapper.train()
    progress = tqdm(train_loader, desc=f"Student Epoch {epoch+1} [CRD]", leave=False)

    for inputs, labels in progress:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        student_proj_feat, student_logits = student_wrapper(inputs)
        
        with torch.no_grad():
            teacher_feat, teacher_logits = teacher(inputs)
            teacher_proj_feat = teacher_proj(teacher_feat)

        kd_loss = loss_fn_kd(student_logits, labels, teacher_logits, alpha=0.5, temperature=4.0)
        crd_loss = contrastive_loss(student_proj_feat, teacher_proj_feat, temperature=0.1)
        loss = kd_loss + (crd_lambda * crd_loss)
        
        loss.backward()
        optimizer.step()
        progress.set_postfix(loss=f"{loss.item():.4f}")

# --- EXPERIMENT RUNNER ---
def run_crd_experiment(teacher_model, teacher_projector, num_epochs=200):
    print("\n--- Running Experiment: CONTRASTIVE REPRESENTATION DISTILLATION (CRD) ---")
    student_wrapper = CrdStudentWrapper().to(device)
    student_save_path = './best_student_crd.pth'

    optimizer = optim.SGD(student_wrapper.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    best_acc = 0.0
    CRD_LAMBDA = 0.8
    for epoch in range(num_epochs):
        train_student_crd(student_wrapper, trainloader, optimizer, epoch, teacher_model, teacher_projector, device, CRD_LAMBDA)
        acc = test(student_wrapper, testloader, device)
        print(f"Epoch {epoch+1}/{num_epochs}, Accuracy: {acc:.2f}%")
        
        if acc > best_acc:
            best_acc = acc
            print(f"New best accuracy: {best_acc:.2f}%. Saving model to {student_save_path}")
            torch.save(student_wrapper.student.state_dict(), student_save_path)
            
        scheduler.step()
        
    print(f"\nFinished Training for CRD. Best Accuracy: {best_acc:.2f}%")
    return best_acc

# --- MAIN EXECUTION ---
STUDENT_EPOCHS = 40
crd_accuracy = run_crd_experiment(teacher_model, teacher_projector, STUDENT_EPOCHS)
print(f"\n--- Final Result for Contrastive Distillation ---")
print(f"CRD Best Accuracy: {crd_accuracy:.2f}%")

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified


  base_teacher_model.load_state_dict(torch.load(teacher_save_path))


Teacher model and projector loaded successfully.

--- Running Experiment: CONTRASTIVE REPRESENTATION DISTILLATION (CRD) ---


                                                                                      

Epoch 1/40, Accuracy: 15.36%
New best accuracy: 15.36%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 2/40, Accuracy: 24.32%
New best accuracy: 24.32%. Saving model to ./best_student_crd.pth


                                                                                     

Epoch 3/40, Accuracy: 31.60%
New best accuracy: 31.60%. Saving model to ./best_student_crd.pth


                                                                                     

Epoch 4/40, Accuracy: 38.17%
New best accuracy: 38.17%. Saving model to ./best_student_crd.pth


                                                                                     

Epoch 5/40, Accuracy: 41.89%
New best accuracy: 41.89%. Saving model to ./best_student_crd.pth


                                                                                     

Epoch 6/40, Accuracy: 44.03%
New best accuracy: 44.03%. Saving model to ./best_student_crd.pth


                                                                                     

Epoch 7/40, Accuracy: 48.72%
New best accuracy: 48.72%. Saving model to ./best_student_crd.pth


                                                                                     

Epoch 8/40, Accuracy: 51.35%
New best accuracy: 51.35%. Saving model to ./best_student_crd.pth


                                                                                     

Epoch 9/40, Accuracy: 52.06%
New best accuracy: 52.06%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 10/40, Accuracy: 54.38%
New best accuracy: 54.38%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 11/40, Accuracy: 54.95%
New best accuracy: 54.95%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 12/40, Accuracy: 57.23%
New best accuracy: 57.23%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 13/40, Accuracy: 57.84%
New best accuracy: 57.84%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 14/40, Accuracy: 59.07%
New best accuracy: 59.07%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 15/40, Accuracy: 59.71%
New best accuracy: 59.71%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 16/40, Accuracy: 60.99%
New best accuracy: 60.99%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 17/40, Accuracy: 61.32%
New best accuracy: 61.32%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 18/40, Accuracy: 62.17%
New best accuracy: 62.17%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 19/40, Accuracy: 62.45%
New best accuracy: 62.45%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 20/40, Accuracy: 62.17%


                                                                                      

Epoch 21/40, Accuracy: 63.48%
New best accuracy: 63.48%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 22/40, Accuracy: 63.78%
New best accuracy: 63.78%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 23/40, Accuracy: 64.14%
New best accuracy: 64.14%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 24/40, Accuracy: 64.11%


                                                                                      

Epoch 25/40, Accuracy: 64.26%
New best accuracy: 64.26%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 26/40, Accuracy: 64.83%
New best accuracy: 64.83%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 27/40, Accuracy: 64.85%
New best accuracy: 64.85%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 28/40, Accuracy: 65.01%
New best accuracy: 65.01%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 29/40, Accuracy: 64.93%


                                                                                      

Epoch 30/40, Accuracy: 65.71%
New best accuracy: 65.71%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 31/40, Accuracy: 65.38%


                                                                                      

Epoch 32/40, Accuracy: 65.53%


                                                                                      

Epoch 33/40, Accuracy: 65.62%


                                                                                      

Epoch 34/40, Accuracy: 65.80%
New best accuracy: 65.80%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 35/40, Accuracy: 65.75%


                                                                                      

Epoch 36/40, Accuracy: 65.96%
New best accuracy: 65.96%. Saving model to ./best_student_crd.pth


                                                                                      

Epoch 37/40, Accuracy: 65.84%


                                                                                      

Epoch 38/40, Accuracy: 65.85%


                                                                                      

Epoch 39/40, Accuracy: 65.89%


                                                                                      

Epoch 40/40, Accuracy: 65.86%

Finished Training for CRD. Best Accuracy: 65.96%

--- Final Result for Contrastive Distillation ---
CRD Best Accuracy: 65.96%
