In [2]:
import torch
import torchvision 
import torchvision.transforms as transforms 
import torchvision.models as models 
import torch.optim as optim 
import torch.nn as nn 
import torch.nn.functional as F

In [6]:
def train_and_evaluate_scratch(trainloader, testloader, model, optimizer, scheduler, criterion, num_epochs, model_path, device): 
    lowest_test_loss = 1000.0 
    for epoch in range(num_epochs): 
        running_loss, train_corrects, train_total = 0.0, 0, 0
        model.train() 
        for inputs, labels in trainloader: 
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad() 
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward() 
            optimizer.step()

            running_loss += loss.item()
            predicted_class = outputs.data.max(1, keepdim = True)[1]
            train_corrects += predicted_class.eq(labels.data.view_as(predicted_class)).cpu().sum()
            train_total += labels.size(0)
    
        # Evaluation 
        test_corrects, test_total, test_running_loss = evaluate(model, testloader, device)

        scheduler.step()
        if test_running_loss/test_total < lowest_test_loss: 
            torch.save(model.state_dict(), model_path)
            lowest_test_loss = test_running_loss/test_total
 
        print(f'[{epoch + 1}], train_loss: {running_loss/train_total:.4f}, test_loss: {test_running_loss/test_total:.4f}, train_accuracy: {train_corrects*100/train_total:.2f} %, test_accuracy: {test_corrects*100/test_total:.2f} %')

def evaluate(model, testloader, device):   
    criterion = nn.CrossEntropyLoss() 
    test_running_loss, test_corrects, test_total = 0.0, 0, 0 
    model.eval() 
    with torch.no_grad(): 
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device) 

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_running_loss += loss.item() 
            predicted_class = outputs.data.max(1, keepdim = True)[1]
            test_corrects += predicted_class.eq(labels.data.view_as(predicted_class)).cpu().sum()
            test_total += labels.size(0)

    return test_corrects, test_total, test_running_loss

def get_generalist_model(pretrained_generalist_path, trainloader, testloader): 

    generalist_model_name = 'cifar100_' + 'resnet20'
    generalist_model = torch.hub.load("chenyaofo/pytorch-cifar-models", generalist_model_name, pretrained = True)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    generalist_model.to(device)
    
    if pretrained_generalist_path is None: 
        print("No pretrained path available. Training a new model as the generalist")
        optimizer = optim.SGD(generalist_model.parameters(), lr = 0.001, momentum = 0.9, nesterov = True, weight_decay = 5e-4)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 20, gamma = 0.1)
        criterion = nn.CrossEntropyLoss() 
        model_path = "teacher_model_cifar100.pth"
        pretrained_generalist_path = model_path 
        num_epochs = 10 
        train_and_evaluate_scratch(trainloader, testloader, generalist_model, optimizer, scheduler, criterion, num_epochs, model_path, device)
    
    generalist_state_dict = torch.load(pretrained_generalist_path, map_location = device)
    generalist_model.load_state_dict(generalist_state_dict)
    
    return generalist_model 

In [7]:
transformation = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))])
    
trainset = torchvision.datasets.CIFAR100(root = './data', train = True, download = True, transform = transformation)
testset = torchvision.datasets.CIFAR100(root = './data', train = False, download = True, transform = transformation)
trainloader = torch.utils.data.DataLoader(trainset, batch_size = 128, shuffle = True)
testloader = torch.utils.data.DataLoader(testset, batch_size = 1, shuffle = False)

generalist_model = get_generalist_model(None, trainloader, testloader)

Files already downloaded and verified
Files already downloaded and verified


Using cache found in C:\Users\yeewenli/.cache\torch\hub\chenyaofo_pytorch-cifar-models_master


No pretrained path available. Training a new model as the generalist
[1], train_loss: 0.0026, test_loss: 1.1691, train_accuracy: 90.96 %, test_accuracy: 68.60 %
[2], train_loss: 0.0021, test_loss: 1.1714, train_accuracy: 93.13 %, test_accuracy: 68.72 %
[3], train_loss: 0.0020, test_loss: 1.1802, train_accuracy: 94.07 %, test_accuracy: 68.68 %
[4], train_loss: 0.0019, test_loss: 1.1843, train_accuracy: 94.71 %, test_accuracy: 68.62 %
[5], train_loss: 0.0018, test_loss: 1.1840, train_accuracy: 95.27 %, test_accuracy: 68.45 %
[6], train_loss: 0.0017, test_loss: 1.1960, train_accuracy: 95.73 %, test_accuracy: 68.37 %
[7], train_loss: 0.0016, test_loss: 1.1967, train_accuracy: 95.94 %, test_accuracy: 68.48 %
[8], train_loss: 0.0015, test_loss: 1.1977, train_accuracy: 96.48 %, test_accuracy: 68.08 %
[9], train_loss: 0.0015, test_loss: 1.2068, train_accuracy: 96.69 %, test_accuracy: 68.17 %
[10], train_loss: 0.0014, test_loss: 1.2207, train_accuracy: 97.02 %, test_accuracy: 67.92 %
