In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.models import VGG16_Weights
from tqdm import tqdm
import os


def test(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

def train_teacher(model, train_loader, optimizer, epoch, device):
    """Run one epoch of plain cross-entropy training on ``model``.

    Args:
        model: network to train; switched to train mode here.
        train_loader: iterable of ``(inputs, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: zero-based epoch index, used only for the progress-bar label.
        device: device each batch is moved to.

    Returns:
        Mean training loss over the epoch's batches (``0.0`` for an empty
        loader). Existing callers ignore the return value.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    progress = tqdm(train_loader, desc=f"Teacher Training Epoch {epoch+1}", leave=False)
    for num_batches, (inputs, labels) in enumerate(progress, start=1):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Average over the batches seen so far; dividing by len(progress) (the
        # total batch count) under-reports the loss until the final batch.
        progress.set_postfix(loss=f"{running_loss / num_batches:.4f}")
    return running_loss / num_batches if num_batches else 0.0


# Use the GPU whenever PyTorch can see one, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# One per-channel normalisation shared by both splits. Presumably these are the
# CIFAR-100 training-set mean/std — TODO confirm the source of the constants.
_cifar_norm = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

# Training split: random 32x32 crops from a 4-pixel-padded image plus
# horizontal flips for augmentation, then tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _cifar_norm,
])
# Evaluation split: no augmentation.
transform_test = transforms.Compose([transforms.ToTensor(), _cifar_norm])

_data_root = './data'
trainset = torchvision.datasets.CIFAR100(
    root=_data_root, train=True, download=True, transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR100(
    root=_data_root, train=False, download=True, transform=transform_test,
)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)



# Teacher: a VGG-16 with a 100-way head. It is either loaded from a cached
# best-checkpoint or fine-tuned from ImageNet-pretrained convolutional features.
teacher_save_path = './best_teacher_vgg16_cifar100.pth'
teacher_model = torchvision.models.vgg16(weights=None, num_classes=100).to(device)

if os.path.exists(teacher_save_path):
    print(f"--- Loading existing teacher from {teacher_save_path} ---")
    # map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
    teacher_model.load_state_dict(torch.load(teacher_save_path, map_location=device))
else:
    print("--- Fine-tuning a new teacher model ---")
    # Reuse only the pretrained `features` stack; the 100-way classifier head
    # keeps its fresh initialisation.
    vgg16_pretrained = torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
    teacher_model.features = vgg16_pretrained.features
    teacher_model = teacher_model.to(device)
    teacher_epochs = 10
    optimizer = optim.SGD(teacher_model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=teacher_epochs)
    best_teacher_acc = 0.0
    for epoch in range(teacher_epochs):
        train_teacher(teacher_model, trainloader, optimizer, epoch, device)
        acc = test(teacher_model, testloader, device)
        print(f"Teacher Epoch {epoch+1}, Accuracy: {acc:.2f}%")
        if acc > best_teacher_acc:
            best_teacher_acc = acc
            torch.save(teacher_model.state_dict(), teacher_save_path)
            print("New best accuracy. Model saved.")
        # Step once per epoch so the cosine schedule actually anneals the LR
        # (T_max is expressed in epochs).
        scheduler.step()
    print(f"--- Finished Teacher Fine-Tuning. Best Accuracy: {best_teacher_acc:.2f}% ---")
    # Distil from the best checkpoint, matching the load branch above, rather
    # than from whatever weights the final epoch left in memory.
    if os.path.exists(teacher_save_path):
        teacher_model.load_state_dict(torch.load(teacher_save_path, map_location=device))

teacher_acc = test(teacher_model, testloader, device)
print(f"Using teacher with accuracy: {teacher_acc:.2f}%")
# `test` already switched to eval mode; kept explicit because the teacher must
# stay in eval mode for the whole distillation run.
teacher_model.eval()


def get_student_model(num_classes=100, target_device=None):
    """Build a freshly initialised (no pretrained weights) VGG-11 student.

    Args:
        num_classes: size of the classifier output. Defaults to 100, matching
            the CIFAR-100 data used in this notebook.
        target_device: device to place the model on. Defaults to the
            module-level ``device`` selected at the top of the notebook.

    Returns:
        The VGG-11 model, already moved to the target device.
    """
    if target_device is None:
        target_device = device
    return torchvision.models.vgg11(weights=None, num_classes=num_classes).to(target_device)

def loss_fn_kd(outputs, labels, teacher_outputs, alpha, temperature):
    """Knowledge-distillation loss mixing soft teacher targets with hard labels.

    With student logits ``s``, teacher logits ``t``, temperature ``T`` and
    labels ``y``, returns::

        alpha * T**2 * KL(softmax(t / T) || softmax(s / T))
            + (1 - alpha) * cross_entropy(s, y)

    The KL term uses ``batchmean`` reduction (sum divided by batch size). The
    ``T**2`` factor keeps its gradient scale comparable to the hard-label term.
    """
    t_squared = temperature * temperature
    student_log_probs = F.log_softmax(outputs / temperature, dim=1)
    teacher_probs = F.softmax(teacher_outputs / temperature, dim=1)
    distill_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * t_squared
    label_term = F.cross_entropy(outputs, labels)
    return alpha * distill_term + (1 - alpha) * label_term


def train_student_lm(model, train_loader, optimizer, epoch, teacher_model, device,
                     alpha=0.5, temperature=4.0):
    """Run one epoch of logit-matching distillation from ``teacher_model`` into ``model``.

    Args:
        model: student network; switched to train mode here.
        train_loader: iterable of ``(inputs, labels)`` batches.
        optimizer: optimizer over the student's parameters.
        epoch: zero-based epoch index, used only for the progress-bar label.
        teacher_model: network whose logits supply the soft targets; it is run
            under ``torch.no_grad`` and never updated.
        device: device each batch is moved to.
        alpha: weight of the soft (distillation) term in ``loss_fn_kd``.
        temperature: softmax temperature passed to ``loss_fn_kd``.

    Returns:
        Mean distillation loss over the epoch's batches (``0.0`` for an empty
        loader). Existing callers ignore the return value.
    """
    model.train()
    # Keep the teacher in eval mode so train-only layers (e.g. dropout, if the
    # architecture has any) do not add noise to the soft targets.
    teacher_model.eval()
    running_loss = 0.0
    num_batches = 0
    progress = tqdm(train_loader, desc=f"Student Epoch {epoch+1} [LM]", leave=False)
    for num_batches, (inputs, labels) in enumerate(progress, start=1):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        with torch.no_grad():
            teacher_outputs = teacher_model(inputs)
        loss = loss_fn_kd(outputs, labels, teacher_outputs, alpha=alpha, temperature=temperature)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix(loss=f"{batch_loss:.4f}")
    return running_loss / num_batches if num_batches else 0.0


def run_lm_experiment(teacher_model, device, num_epochs=200, train_loader=None,
                      test_loader=None, save_path='./best_student_lm.pth'):
    """Train a fresh student by logit matching against ``teacher_model``.

    Each epoch runs one distillation pass over the training batches and then
    evaluates on the test batches. The student is checkpointed whenever test
    accuracy improves. The LR follows a cosine schedule over ``num_epochs``.

    Args:
        teacher_model: network whose logits supply the soft targets.
        device: device for the student model and the batches.
        num_epochs: number of epochs; also the cosine-annealing period.
        train_loader: training batches; defaults to the module-level ``trainloader``.
        test_loader: evaluation batches; defaults to the module-level ``testloader``.
        save_path: file the best student ``state_dict`` is written to.

    Returns:
        Best test accuracy (percent) observed over all epochs.

    NOTE(review): the checkpoint is selected on the test split, so the returned
    "best" accuracy is optimistically biased. A held-out validation split would
    give an unbiased estimate.
    """
    if train_loader is None:
        train_loader = trainloader
    if test_loader is None:
        test_loader = testloader

    print("\n--- Running Experiment: LOGIT MATCHING (LM) ---")
    # Honour the `device` argument rather than relying on the module-level
    # default used inside get_student_model.
    student_model = get_student_model().to(device)

    optimizer = optim.SGD(student_model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    best_acc = 0.0
    for epoch in range(num_epochs):
        train_student_lm(student_model, train_loader, optimizer, epoch, teacher_model, device)
        acc = test(student_model, test_loader, device)

        # Log the first epoch, every 10th epoch and the last one.
        if (epoch + 1) % 10 == 0 or epoch == 0 or epoch == num_epochs - 1:
            print(f"Epoch {epoch+1}/{num_epochs}, Accuracy: {acc:.2f}%")

        if acc > best_acc:
            best_acc = acc
            print(f"New best accuracy: {best_acc:.2f}%. Saving student model to {save_path}")
            torch.save(student_model.state_dict(), save_path)

        # One scheduler step per epoch, matching T_max=num_epochs.
        scheduler.step()

    print(f"Finished Training for LM. Best Accuracy: {best_acc:.2f}%")
    print(f"Best LM student model saved to {save_path}")
    return best_acc

# Distil the teacher into a fresh student, then report the best test accuracy.
LM_EPOCHS = 200
lm_accuracy = run_lm_experiment(teacher_model, device, num_epochs=LM_EPOCHS)

lm_summary = f"Logit Matching (LM) Best Accuracy: {lm_accuracy:.2f}%"
print("\n--- Final Result for Logit Matching ---")
print(lm_summary)

Using device: cuda
--- Loading existing teacher from ./best_teacher_vgg16_cifar100.pth ---
Using teacher with accuracy: 69.01%

--- Running Experiment: LOGIT MATCHING (LM) ---


                                                                                    

Epoch 1/200, Accuracy: 3.03%
New best accuracy: 3.03%. Saving student model to ./best_student_lm.pth


                                                                                    

New best accuracy: 7.19%. Saving student model to ./best_student_lm.pth


                                                                                    

New best accuracy: 8.75%. Saving student model to ./best_student_lm.pth


                                                                                    

New best accuracy: 10.10%. Saving student model to ./best_student_lm.pth


                                                                                    

New best accuracy: 11.41%. Saving student model to ./best_student_lm.pth


                                                                                    

New best accuracy: 12.26%. Saving student model to ./best_student_lm.pth


                                                                                    

New best accuracy: 13.98%. Saving student model to ./best_student_lm.pth


                                                                                    

New best accuracy: 15.44%. Saving student model to ./best_student_lm.pth


                                                                                    

New best accuracy: 16.49%. Saving student model to ./best_student_lm.pth


                                                                                     

Epoch 10/200, Accuracy: 16.00%


                                                                                     

New best accuracy: 17.09%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 18.32%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 18.99%. Saving student model to ./best_student_lm.pth


                                                                                     

Epoch 20/200, Accuracy: 17.81%


                                                                                     

New best accuracy: 19.65%. Saving student model to ./best_student_lm.pth


                                                                                     

Epoch 30/200, Accuracy: 19.21%


                                                                                     

New best accuracy: 20.03%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 22.21%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 23.23%. Saving student model to ./best_student_lm.pth


                                                                                     

Epoch 40/200, Accuracy: 24.81%
New best accuracy: 24.81%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 25.14%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 26.34%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 28.84%. Saving student model to ./best_student_lm.pth


                                                                                     

Epoch 50/200, Accuracy: 27.79%


                                                                                     

New best accuracy: 29.35%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 29.46%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 31.69%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 32.00%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 32.71%. Saving student model to ./best_student_lm.pth


                                                                                     

Epoch 60/200, Accuracy: 35.63%
New best accuracy: 35.63%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 36.10%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 38.82%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 41.27%. Saving student model to ./best_student_lm.pth


                                                                                     

Epoch 70/200, Accuracy: 42.47%
New best accuracy: 42.47%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 42.96%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 45.49%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 46.78%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 48.48%. Saving student model to ./best_student_lm.pth


                                                                                     

Epoch 80/200, Accuracy: 48.28%


                                                                                     

New best accuracy: 48.94%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 49.07%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 49.40%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 50.60%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 51.27%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 51.69%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 52.53%. Saving student model to ./best_student_lm.pth


                                                                                     

Epoch 90/200, Accuracy: 52.55%
New best accuracy: 52.55%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 54.21%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 54.39%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 55.63%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 56.54%. Saving student model to ./best_student_lm.pth


                                                                                     

New best accuracy: 56.84%. Saving student model to ./best_student_lm.pth


                                                                                      

Epoch 100/200, Accuracy: 58.92%
New best accuracy: 58.92%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 59.01%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 59.76%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 60.29%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 61.39%. Saving student model to ./best_student_lm.pth


                                                                                      

Epoch 110/200, Accuracy: 60.96%


                                                                                      

New best accuracy: 61.83%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 62.90%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 63.17%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 63.41%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 64.18%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 65.37%. Saving student model to ./best_student_lm.pth


                                                                                      

Epoch 120/200, Accuracy: 64.42%


                                                                                      

New best accuracy: 65.80%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 66.29%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 66.53%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 66.72%. Saving student model to ./best_student_lm.pth


                                                                                      

Epoch 130/200, Accuracy: 66.76%
New best accuracy: 66.76%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 67.44%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 68.38%. Saving student model to ./best_student_lm.pth


                                                                                      

Epoch 140/200, Accuracy: 68.57%
New best accuracy: 68.57%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 68.77%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 69.25%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 69.34%. Saving student model to ./best_student_lm.pth


                                                                                      

Epoch 150/200, Accuracy: 69.76%
New best accuracy: 69.76%. Saving student model to ./best_student_lm.pth


                                                                                      

New best accuracy: 70.26%. Saving student model to ./best_student_lm.pth


                                                                                          

Epoch 160/200, Accuracy: 70.14%


                                                                                      

New best accuracy: 70.85%. Saving student model to ./best_student_lm.pth


                                                                                      

KeyboardInterrupt: 