__This notebook__ trains an editable ResNet-18 (with `layer3` as the editable parameters) from scratch on the CIFAR-10 dataset.

In [1]:
%load_ext autoreload
%autoreload 2
%env CUDA_VISIBLE_DEVICES=YOURDEVICEHERE
import os, sys, time
sys.path.insert(0, '..')
import lib

import numpy as np
import torch, torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

import random

# Seed the Python, NumPy and PyTorch random generators with one shared value.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)

import time
from resnet import ResNet18

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Tag the run with the current UTC time: <name>_<YYYY>.<MM>.<DD>_<hh>:<mm>:<ss>
# (every field except the year is zero-padded to two digits).
utc_fields = time.gmtime()[:6]
experiment_name = '{}_{}.{:0>2d}.{:0>2d}_{:0>2d}:{:0>2d}:{:0>2d}'.format('editable_layer3', *utc_fields)
print(experiment_name)
print("PyTorch version:", torch.__version__)

env: CUDA_VISIBLE_DEVICES=1
editable_layer3_2019.09.19_23:06:14
PyTorch version: 1.1.0


In [2]:
from torchvision import transforms, datasets

# Normalization statistics shared by the train and test pipelines, one value per
# image channel (presumably the CIFAR-10 training-set mean/std -- TODO confirm).
channel_mean = (0.4914, 0.4822, 0.4465)
channel_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop after 4-pixel padding and a random
# horizontal flip, followed by tensor conversion and normalization.
train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
]
transform_train = transforms.Compose(train_steps)

# Evaluation pipeline: no augmentation, only tensor conversion and normalization.
test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
]
transform_test = transforms.Compose(test_steps)

# Training split, served in shuffled mini-batches of 128.
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

# Test split, served in its stored order in mini-batches of 100.
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Walk the test loader once and concatenate its batches, so the whole test
# split is available as a single input tensor and a single label tensor.
X_test, y_test = (torch.cat(parts) for parts in zip(*testloader))

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Network that gets wrapped into the editable model.
backbone = ResNet18()

# Optimizer handed to lib.Editable; `beta` is an nn.Parameter initialised to 0.5.
edit_optimizer = lib.IngraphRMSProp(
    learning_rate=1e-3,
    beta=nn.Parameter(torch.tensor(0.5, dtype=torch.float32)),
)

# Only the parameters of `layer3` are exposed as editable.
# NOTE(review): max_steps=10 presumably caps the number of edit steps -- confirm in lib.Editable.
model = lib.Editable(
    module=backbone,
    loss_function=lib.contrastive_cross_entropy,
    get_editable_parameters=lambda module: module.layer3.parameters(),
    optimizer=edit_optimizer,
    max_steps=10,
).to(device)

# Trainer with cross-entropy as the loss passed in second position.
# NOTE(review): max_norm=10 looks like a gradient-norm clipping bound -- confirm in lib.EditableTrainer.
trainer = lib.EditableTrainer(model, F.cross_entropy, experiment_name=experiment_name, max_norm=10)

# Log the trainer's repr as text, with newlines turned into <br> tags.
trainer_description = repr(trainer).replace('\n', '<br>')
trainer.writer.add_text("trainer", trainer_description)

In [None]:
from tqdm import tqdm_notebook, tnrange
from IPython.display import clear_output

# Early-stopping bookkeeping: tolerated number of epochs without a new best
# value, and the counter of such epochs so far.
early_stopping_epochs = 500
number_of_epochs_without_improvement = 0

# Metrics measured before the training loop starts are the initial "best" values.
val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))
min_error = val_metrics['base_error']
min_drawdown = val_metrics['drawdown']

def edit_generator():
    """Yield an endless stream of single-sample edit requests.

    Each item is a pair ``(x, y)``: one sample from ``trainset`` moved to
    ``device``, and a random integer label in ``[0, len(classes))`` created on
    ``device`` with the same shape as that sample's true label. The random
    label is drawn independently of the true label, so the two may coincide.

    Whenever the shuffled one-sample loader is exhausted, a new one is built,
    so the stream never ends.
    """
    while True:
        single_sample_loader = torch.utils.data.DataLoader(
            trainset, batch_size=1, shuffle=True, num_workers=2,
        )
        for image, true_label in single_sample_loader:
            image_on_device = image.to(device)
            random_label = torch.randint_like(true_label, low=0, high=len(classes), device=device)
            yield image_on_device, random_label


# From here on the name refers to the generator instance that is consumed with next().
edit_generator = edit_generator()


# Train epoch by epoch until neither validation metric has reached a new best
# for more than `early_stopping_epochs` consecutive epochs.
while True:
    # One pass over the training set; every step also receives one edit pair.
    for x_batch, y_batch in tqdm_notebook(trainloader):
        inputs = x_batch.to(device)
        targets = y_batch.to(device)
        edit_input, edit_target = next(edit_generator)
        trainer.step(inputs, targets, edit_input, edit_target)

    # Evaluate on the full test split, then redraw the cell output.
    val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))
    clear_output(True)

    error_rate = val_metrics['base_error']
    drawdown = val_metrics['drawdown']

    # A new best in either metric gets its own tagged checkpoint and counts
    # as an improvement for early stopping.
    improved = False

    if error_rate < min_error:
        trainer.save_checkpoint(tag='best_val_error')
        min_error = error_rate
        improved = True

    if drawdown < min_drawdown:
        trainer.save_checkpoint(tag='best_drawdown')
        min_drawdown = drawdown
        improved = True

    if improved:
        number_of_epochs_without_improvement = 0
    else:
        number_of_epochs_without_improvement += 1

    # Untagged checkpoint every epoch, followed by the trainer's cleanup of
    # old temporary checkpoints.
    trainer.save_checkpoint()
    trainer.remove_old_temp_checkpoints()

    if number_of_epochs_without_improvement > early_stopping_epochs:
        break

In [None]:
from lib import evaluate_quality

# Re-seed NumPy so the edit subset and its random target labels are the same
# on every run of this cell.
np.random.seed(9)

# Pick 1000 distinct test samples to act as edit requests.
indices = np.random.permutation(len(X_test))[:1000]
X_edit = X_test[indices].clone().to(device)

# Draw one random target class per selected sample, independently of its true
# label. The class count comes from `classes` instead of a hard-coded 10.
y_edit = torch.tensor(np.random.randint(0, len(classes), size=y_test[indices].shape), device=device)

# Evaluate the Editable wrapper built earlier in this notebook. It is bound to
# `model`; the previously used name `editable_model` is never defined here and
# raised NameError on a fresh kernel.
metrics = evaluate_quality(model, X_test, y_test, X_edit, y_edit, batch_size=512)

# Print the metrics in alphabetical key order with 5 significant digits.
for key in sorted(metrics.keys()):
    print('{}\t:{:.5}'.format(key, metrics[key]))