<h1><span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Train" data-toc-modified-id="Train-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Train</a></span></li><li><span><a href="#Evaluate-drawdown" data-toc-modified-id="Evaluate-drawdown-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Evaluate drawdown</a></span></li></ul></div>

__This notebook__ fine-tunes a pre-trained resnet18 model with editable training.

__Prepare data:__
* Download the ImageNet training and validation datasets
* Make sure the class folders are named "000", "001", ... "010", "011", ... and not "0", "1", ..., "10", "11", ...
    * rename them if necessary
* Run `imagenet_preprocess_logits.ipynb` to prepare fine-tuning metadata.

__Training:__
* Set environment variables and paths in the next cell
* Run all cells :)

In [6]:
# Automatically reload imported modules before each cell runs, so edits to
# project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Placeholder: replace YOURDEVICEHERE with the GPU id(s) to expose before running.
%env CUDA_VISIBLE_DEVICES=YOURDEVICEHERE

# Data locations used by the cells below; adjust them to your setup.
traindir = '../../imagenet/train'  # path to train ImageFolder
valdir = '../../imagenet/val'      # path to validation ImageFolder
logits_path = './imagenet_logits/' # see imagenet_preprocess_logits

import os, sys, time
sys.path.insert(0, '..')

import lib
import src

import numpy as np
import torch, torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

import random
import time

# Seed the Python, NumPy and PyTorch random number generators with the same fixed value.
for seed_rng in (random.seed, np.random.seed, torch.random.manual_seed):
    seed_rng(42)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cpu' if not torch.cuda.is_available() else 'cuda'

# Experiment name = prefix + current UTC timestamp (year.month.day_hour:minute:second).
run_prefix = 'code_test'
utc_fields = time.gmtime()[:6]
experiment_name = '{}_{}.{:0>2d}.{:0>2d}_{:0>2d}:{:0>2d}:{:0>2d}'.format(run_prefix, *utc_fields)
print(experiment_name)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
env: CUDA_VISIBLE_DEVICES=YOURDEVICEHERE
code_test_2021.03.02_18:28:06


In [2]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

# Per-channel mean/std applied to every image after conversion to a tensor.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Training samples come with pre-computed logits stored under `logits_path`
# (see imagenet_preprocess_logits); each item is (image, label, logits).
train_dataset = lib.ImageAndLogitsFolder(
    traindir,
    transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]),
    logits_prefix = logits_path
)

batch_size = 128

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    num_workers=12, pin_memory=True)

# Validation images use a deterministic resize + center crop (no augmentation).
val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]))

# Loader over the FULL validation set.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size, shuffle=False,
    num_workers=32, pin_memory=True)

# !!!IMPORTANT!!!
# We use 10% of validation samples for faster validation, please use full validation set to measure "final" error rate
#
# Only every 10th sample is decoded and kept in memory. This yields exactly the
# same tensors as concatenating the whole (unshuffled) validation set and
# slicing it with [::10], without materialising all validation images first.
val_subset = torch.utils.data.Subset(val_dataset, list(range(0, len(val_dataset), 10)))
val_subset_loader = torch.utils.data.DataLoader(
    val_subset,
    batch_size=batch_size, shuffle=False,
    num_workers=32, pin_memory=True)

X_test, y_test = map(torch.cat, zip(*val_subset_loader))

In [3]:
import torchvision

# Pre-trained ResNet-18 whose layers are re-assembled below.
resnet = torchvision.models.resnet18(pretrained=True)

# Optimizer handed to the editable block; its `beta` is wrapped in nn.Parameter.
optimizer = lib.IngraphRMSProp(learning_rate=1e-4, beta=nn.Parameter(torch.as_tensor(0.5)))

# ResNet layers up to and including global average pooling, in their original order.
feature_layers = [
    resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,
    resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4,
    resnet.avgpool,
]

# Full network: ResNet features -> flatten -> editable residual block (512 -> 4096 -> 512)
# -> the original ResNet classification layer.
model = lib.SequentialWithEditable(
    *feature_layers,
    lib.Flatten(),
    lib.Editable(
        lib.Residual(nn.Linear(512, 4096), nn.ELU(), nn.Linear(4096, 512)),
        loss_function=lib.contrastive_cross_entropy,
        optimizer=optimizer,
        max_steps=10,  # presumably the cap on edit optimization steps; confirm in lib.Editable
    ),
    resnet.fc,
).to(device)

In [4]:
def classification_error(model, X_test, y_test):
    """Compute the classification error of `model` on (X_test, y_test).

    The model is wrapped so that each input batch is moved to the global
    `device` before the forward pass, and the evaluation runs inside
    `lib.training_mode(model, is_train=False)`. Batching and the error
    computation itself are delegated to `lib.classification_error`, which is
    called with device='cpu' and batch_size=128.

    :param model: callable network to evaluate
    :param X_test: evaluation inputs
    :param y_test: evaluation targets
    :returns: whatever `lib.classification_error` returns for these inputs
    """
    def predict_on_device(x):
        return model(x.to(device))

    with lib.training_mode(model, is_train=False):
        error = lib.classification_error(
            lib.Lambda(predict_on_device),
            X_test, y_test,
            device='cpu', batch_size=128,
        )
    return error

In [None]:
# Parameters of `model.editable.module[0]` are trained with a larger learning rate
# than the rest of the network.
# NOTE(review): `[0]` selects only the first sub-module of the editable block - confirm
# that this is the intended set of "new" parameters.
#
# Keep them in a list: torch optimizers need a collection whose order is the same on
# every run (a set's iteration order is not), otherwise saved optimizer state can be
# matched to the wrong parameters when a checkpoint is reloaded.
new_params = list(model.editable.module[0].parameters())
new_param_ids = {id(param) for param in new_params}
old_params = [param for param in model.parameters() if id(param) not in new_param_ids]

# One SGD optimizer per parameter group, stepped together through lib.OptimizerList.
training_opt = lib.OptimizerList(
    torch.optim.SGD(old_params, lr=1e-5, momentum=0.9, weight_decay=1e-4),
    torch.optim.SGD(new_params, lr=1e-3, momentum=0.9, weight_decay=1e-4),
)

trainer = lib.DistillationEditableTrainer(model,
          stability_coeff=0.03, editability_coeff=0.03,
          experiment_name=experiment_name,
          error_function=classification_error,
          opt=training_opt, max_norm=10)

# Log the trainer configuration as text (newlines converted to <br> for rendering).
trainer.writer.add_text("trainer", repr(trainer).replace('\n', '<br>'))

In [None]:
from tqdm import tqdm_notebook, tnrange
from IPython.display import clear_output

# Learning params
eval_batch_cd = 500  # evaluate on the validation subset every `eval_batch_cd` training steps
val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))
min_error, min_drawdown = val_metrics['base_error'], val_metrics['drawdown']
# Patience for early stopping. The training loop increments the counter once per
# evaluation round (not once per pass over the dataset).
early_stopping_epochs = 500
number_of_epochs_without_improvement = 0

# Number of target classes, computed once instead of on every generated edit.
num_classes = int(y_test.max()) + 1


def make_edit_generator():
    """Yield an endless stream of (image, random target label) edit requests.

    Images are drawn one at a time from `train_dataset` in shuffled order; the
    target label is sampled uniformly from [0, num_classes) and has the same
    shape and dtype as the true label. Both tensors are placed on `device`.
    """
    while True:
        for xb, yb, lg in torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2):
            yield xb.to(device), torch.randint_like(yb, low=0, high=num_classes, device=device)


# The training loop pulls edits with `next(edit_generator)`.
edit_generator = make_edit_generator()

### Train

In [None]:
# Train until neither the validation error nor the drawdown has improved for more than
# `early_stopping_epochs` consecutive evaluation rounds.
# A flag is required because `break` only leaves the inner `for` loop; with a bare
# `while True` the outer loop would simply start the next epoch and never stop.
stop_training = False
while not stop_training:

    for x_batch, y_batch, logits in tqdm_notebook(train_loader):
        # One training step on a batch of images and logits, plus one edit request.
        trainer.step(x_batch.to(device), logits.to(device), *next(edit_generator))

        if trainer.total_steps % eval_batch_cd == 0:
            val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))
            clear_output(True)

            error_rate, drawdown = val_metrics['base_error'], val_metrics['drawdown']

            # Counted up front; reset below if either metric improved this round.
            number_of_epochs_without_improvement += 1

            if error_rate < min_error:
                trainer.save_checkpoint(tag='best_val_error')
                min_error = error_rate
                number_of_epochs_without_improvement = 0

            if drawdown < min_drawdown:
                trainer.save_checkpoint(tag='best_drawdown')
                min_drawdown = drawdown
                number_of_epochs_without_improvement = 0

            # Always keep a rolling checkpoint and prune older temporary ones.
            trainer.save_checkpoint()
            trainer.remove_old_temp_checkpoints()

            if number_of_epochs_without_improvement > early_stopping_epochs:
                stop_training = True
                break

### Evaluate drawdown

__Note:__ this code evaluates quality on 10% of the validation set. In the paper, we use this subset when evaluating drawdown, but we measure the base error on all 50k validation samples.

In [None]:
# edit quality

from lib import evaluate_quality

# Fixed seed so the same edit samples and target labels are drawn on every run.
np.random.seed(9)

# Pick up to 1000 distinct samples from the validation subset as edit inputs.
indices = np.random.permutation(len(X_test))[:1000]
X_edit = X_test[indices].clone().to(device)

# Random target label for each edit, drawn uniformly from [0, num_classes).
# Use the tensor's own max() and convert to a plain int instead of iterating the
# tensor element by element with the Python builtin.
num_classes = int(y_test.max()) + 1
y_edit = torch.tensor(np.random.randint(0, num_classes, size=y_test[indices].shape), device=device)

metrics = evaluate_quality(model, X_test, y_test, X_edit, y_edit,
                           error_function=classification_error, progressbar=tqdm_notebook)

# Print the metrics in alphabetical order of their names.
for key in sorted(metrics.keys()):
    print('{}\t:{:.5}'.format(key, metrics[key]))
