In [1]:
import torch
import torchvision 
import torchvision.transforms as transforms 
import torchvision.models as models 
import torch.optim as optim 
import torch.nn as nn 
import torch.nn.functional as F

In [12]:
def train_and_evaluate_scratch(trainloader, testloader, model, optimizer, scheduler, criterion, num_epochs, model_path): 
    lowest_test_loss = 1000.0 
    for epoch in range(num_epochs): 
        running_loss, train_corrects, train_total = 0.0, 0, 0
        model.train() 
        for i, data in enumerate(trainloader, 0): 
            inputs, labels = data 
            if torch.cuda.is_available(): 
                inputs, labels = inputs.cuda(), labels.cuda()
            
            optimizer.zero_grad() 
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward() 
            optimizer.step()

            running_loss += loss.item()
            predicted_class = outputs.data.max(1, keepdim = True)[1]
            train_corrects += predicted_class.eq(labels.data.view_as(predicted_class)).cpu().sum()
            train_total += labels.size(0)
    
        # Evaluation 
        test_corrects, test_total, test_running_loss = evaluate(model, testloader)

        scheduler.step()
        if test_running_loss/test_total < lowest_test_loss: 
            torch.save(model.state_dict, model_path)
            lowest_test_loss = test_running_loss/test_total
 
        print('[{}], train_loss: {}, test_loss: {}, train_accuracy: {} %, test_accuracy: {} %'.format(epoch+1, running_loss/train_total, test_running_loss/test_total, 
        train_corrects*100/train_total, test_corrects*100/test_total))


def evaluate(model, testloader):   
    criterion = nn.CrossEntropyLoss() 
    test_running_loss, test_corrects, test_total = 0.0, 0, 0 
    model.eval() 
    with torch.no_grad(): 
        for data in testloader:
            inputs, labels = data
            if torch.cuda.is_available(): 
                inputs, labels = inputs.cuda(), labels.cuda()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_running_loss += loss.item() 
            predicted_class = outputs.data.max(1, keepdim = True)[1]
            test_corrects += predicted_class.eq(labels.data.view_as(predicted_class)).cpu().sum()
            test_total += labels.size(0)

    return test_corrects, test_total, test_running_loss


def get_generalist_model(pretrained_generalist_path, trainloader, testloader): 

    generalist_model_name = 'cifar100_' + 'resnet20'
    generalist_model = torch.hub.load("chenyaofo/pytorch-cifar-models", generalist_model_name, pretrained = True)

    if pretrained_generalist_path is not None: 
        print("Using pre-trained generalist model")
        generalist_state_dict = torch.load(pretrained_generalist_path, map_location = torch.device('cpu'))
        generalist_model.load_state_dict(generalist_state_dict)
        if torch.cuda.is_available():
            generalist_model = generalist_model.cuda()
    else: 
        print("No pretrained path available. Training a new model as the generalist")
        if torch.cuda.is_available(): 
            generalist_model = generalist_model.cuda() 
        optimizer = optim.SGD(generalist_model.parameters(), lr = 0.001, momentum = 0.9, nesterov = True, weight_decay = 5e-4)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 20, gamma = 0.1)
        criterion = nn.CrossEntropyLoss() 
        model_path = 'teacher_model_cifar100.pth'
        train_and_evaluate_scratch(trainloader, testloader, generalist_model, optimizer, scheduler, criterion, 10, model_path)
    
    return generalist_model 

In [13]:
transformation = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))])
    
trainset = torchvision.datasets.CIFAR100(root = './data', train = True, download = True, transform = transformation)
testset = torchvision.datasets.CIFAR100(root = './data', train = False, download = True, transform = transformation)
trainloader = torch.utils.data.DataLoader(trainset, batch_size = 128, shuffle = True)
testloader = torch.utils.data.DataLoader(testset, batch_size = 1, shuffle = False)

generalist_model = get_generalist_model(None, trainloader, testloader)

Files already downloaded and verified
Files already downloaded and verified


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


No pretrained path available. Training a new model as the generalist
[1], train_loss: 0.0025461179679632188, test_loss: 1.1700654629195906, train_accuracy: 91.08200073242188 %, test_accuracy: 69.02999877929688 %
[2], train_loss: 0.0021299282842874526, test_loss: 1.174413313458801, train_accuracy: 93.16400146484375 %, test_accuracy: 68.51000213623047 %
[3], train_loss: 0.0019759627851843836, test_loss: 1.1756201090892617, train_accuracy: 93.97000122070312 %, test_accuracy: 68.56999969482422 %
[4], train_loss: 0.001867857117652893, test_loss: 1.1808805552328463, train_accuracy: 94.6259994506836 %, test_accuracy: 68.51000213623047 %
[5], train_loss: 0.0017691494822502135, test_loss: 1.1918881003540096, train_accuracy: 95.1520004272461 %, test_accuracy: 68.45999908447266 %
[6], train_loss: 0.0016891899898648262, test_loss: 1.1948901089004291, train_accuracy: 95.6780014038086 %, test_accuracy: 68.37000274658203 %
[7], train_loss: 0.0016106324490904808, test_loss: 1.200008534161732, train_ac