# Import Libraries

In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use("ggplot")

import torch
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.utils.data.sampler import WeightedRandomSampler

from torchvision.models import resnet18

# Seed torch's global RNG (CPU and all CUDA devices) so weight initialisation
# and DataLoader shuffling are reproducible on Restart & Run All. cuDNN kernels
# may still introduce small run-to-run differences on GPU.
SEED = 0
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", DEVICE.upper())

Running on device: CUDA


# Accuracy Metric

In [2]:
def accuracy(net, loader, device=None):
    """Return the top-1 accuracy of ``net`` over every sample yielded by ``loader``.

    Args:
        net: a ``torch.nn.Module`` mapping a batch of inputs to per-class scores;
            the predicted class is the arg-max over dimension 1.
        loader: an iterable of ``(inputs, targets)`` batches (e.g. a
            ``DataLoader``); ``targets`` are assumed to be integer class
            indices -- TODO confirm against ``get_cifar10_data``.
        device: where each batch is moved before the forward pass. Defaults to
            the notebook-level ``DEVICE`` so existing calls behave as before.

    Returns:
        The fraction of correctly classified samples, a float in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples (accuracy is undefined).
    """
    if device is None:
        device = DEVICE
    # Evaluate in inference mode (e.g. BatchNorm uses its running statistics),
    # then put the network back into whatever mode the caller had it in.
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        # Gradients are never needed here; disabling autograd avoids keeping
        # every batch's activations alive.
        with torch.no_grad():
            for inputs, targets in loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
    finally:
        net.train(was_training)
    if total == 0:
        raise ValueError("accuracy() got a loader that yielded no samples")
    return correct / total

# Training

In [3]:
def training(net, train_set, epochs=32, batch_size=128, lr=0.05, momentum=0.9,
             weight_decay=4e-3, device=None):
    """Train ``net`` on ``train_set`` with SGD and cosine learning-rate annealing.

    Args:
        net: the ``torch.nn.Module`` to train; it is updated in place.
        train_set: a map-style dataset of ``(input, target)`` pairs, where
            ``target`` is assumed to be an integer class index -- TODO confirm.
        epochs: number of full passes over ``train_set``; also the cosine
            schedule's ``T_max``, so the learning rate anneals over the whole run.
        batch_size: mini-batch size; batches are reshuffled every epoch.
        lr, momentum, weight_decay: SGD hyper-parameters.
        device: where each batch is moved. Defaults to the notebook-level
            ``DEVICE`` so existing calls behave as before.

    Returns:
        The same ``net`` object, left in eval mode.
    """
    if device is None:
        device = DEVICE
    data_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum,
                          weight_decay=weight_decay)
    # Stepped once per epoch (below), so T_max=epochs spans the whole run.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    net.train()
    for _ in range(epochs):
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            # Mean negative log-softmax probability of the target class: the same
            # value the previous hand-rolled one-hot @ log_softmax product gave,
            # but without hard-coding the number of classes.
            loss = nn.functional.cross_entropy(outputs, targets)

            loss.backward()
            optimizer.step()

        scheduler.step()
    net.eval()
    return net

In [4]:
from load_cifar_script import get_cifar10_data

data_loaders = get_cifar10_data()

# NOTE(review): torchvision >= 0.13 deprecates ``pretrained=`` in favour of
# ``weights=None``; switch once the installed torchvision version is confirmed.
net = resnet18(pretrained=False, num_classes=10)
net.to(DEVICE)

# Train on the concatenation of the datasets behind the "retain_train" and
# "compact_train" loaders.
train_set = ConcatDataset([
    data_loaders["retain_train"].dataset,
    data_loaders["compact_train"].dataset,
])
net = training(net, train_set)

# torch.save does not create missing directories, so make sure ./tmp exists
# before writing the checkpoint (it fails on a fresh checkout otherwise).
CHECKPOINT_PATH = Path("./tmp/checkpoint.pth")
CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(net.state_dict(), CHECKPOINT_PATH)

retain_accuracy = accuracy(net, data_loaders["retain_train"])
test_accuracy = accuracy(net, data_loaders["retain_test"])
# Display both metrics as the cell's output rather than leaving them unreported.
{"retain_accuracy": retain_accuracy, "test_accuracy": test_accuracy}

Files already downloaded and verified
Files already downloaded and verified
