In [None]:
%load_ext autoreload
%autoreload 2
%env CUDA_VISIBLE_DEVICES=0
import os, sys, time
sys.path.insert(0, '..')
import lib

import math
import numpy as np
from copy import deepcopy
import torch, torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline
# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and 3.8
# removed the old names, so plt.style.use('seaborn-darkgrid') fails there.
# Use whichever of the two names this matplotlib version provides.
for _style_name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
# Embed TrueType (Type 42) fonts in saved PDF/PS figures.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42

# For reproducibility: a fresh seed is drawn (and printed) on every run unless the
# SEED environment variable is set, e.g. to the printed value of an earlier run.
import random
seed = int(os.environ['SEED']) if 'SEED' in os.environ else random.randint(0, 2 ** 32 - 1)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print(seed)

# Settings

In [None]:
model_type = 'fixup_resnet'

# CIFAR-100 data loading
data_dir = './data'
train_batch_size, valid_batch_size, test_batch_size = 128, 128, 64
num_workers = 3
pin_memory = True

num_classes = 100
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

loss_function = F.cross_entropy

# Parameters of the learned (DIMAML) initializers
num_quantiles = 100

# Meta-training (MAML) schedule
max_steps = 3000
inner_loop_steps = 200

loss_kwargs = {'reduction': 'mean'}

# validation_steps counts the points first_val_step, first_val_step + loss_interval,
# ..., inner_loop_steps; it is the number of validation batches sampled per meta step
# (presumably one per validation point inside maml.forward).
first_val_step = 40
loss_interval = 40

assert (inner_loop_steps - first_val_step) % loss_interval == 0
validation_steps = (inner_loop_steps - first_val_step) // loss_interval + 1

# Inner-loop optimizer
learning_rate = 0.1
inner_optimizer_type = 'momentum'
inner_optimizer_kwargs = {
    'lr': learning_rate,
    'momentum': 0.9,
    'nesterov': True,
    'weight_decay': 0.0005,
}

# Meta optimizer
meta_betas = (0.9, 0.997)
meta_learning_rate = 0.001
meta_grad_clip = 10.

checkpoint_steps = 3
# None means "fresh run": the logs directory must not exist yet (checked below);
# it is also handed to the Trainer as recovery_step.
recovery_step = None

# Extra keyword arguments forwarded to trainer.train_on_batch -> maml.forward.
kwargs = {
    'first_valid_step': first_val_step,
    'valid_loss_interval': loss_interval,
    'loss_kwargs': loss_kwargs,
}

In [None]:
# The experiment name encodes the main hyperparameters and the seed.
exp_name = "_".join([
    "PLIF_FixupResNet18_CIFAR100",
    f"{inner_optimizer_type}",
    f"steps{inner_loop_steps}",
    f"interval{loss_interval}",
    f"tr_bs{train_batch_size}",
    f"val_bs{valid_batch_size}",
    f"seed_{seed}",
])
print("Experiment name: ", exp_name)

logs_path = f"./logs/{exp_name}"
# A fresh run (recovery_step is None) must not reuse an existing logs directory.
assert recovery_step is not None or not os.path.exists(logs_path)
# To discard a previous run with the same name: !rm -rf {logs_path}

## Prepare CIFAR100

In [None]:
from torchvision import transforms, datasets
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import TensorDataset, DataLoader

# CIFAR-100 channel statistics, shared by the train and eval pipelines.
cifar100_normalize = transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))

# Training: random crop with 4px padding + horizontal flip; evaluation: no augmentation.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    cifar100_normalize,
])
eval_transform = transforms.Compose([transforms.ToTensor(), cifar100_normalize])

# Train and validation both read the official training set; the validation copy
# uses the non-augmenting transform. test_set is the official test set.
train_dataset = datasets.CIFAR100(root=data_dir, train=True, transform=train_transform, download=True)
valid_dataset = datasets.CIFAR100(root=data_dir, train=True, transform=eval_transform, download=True)
test_set = datasets.CIFAR100(root=data_dir, train=False, transform=eval_transform, download=True)

# Random split of the training indices (numpy global RNG): the first `split`
# go to training, the rest to validation.
split = 40000
num_train = len(train_dataset)
indices = list(range(num_train))
np.random.shuffle(indices)
train_idx = indices[:split]
valid_idx = indices[split:]

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

_loader_opts = dict(num_workers=num_workers, pin_memory=pin_memory)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=train_batch_size, sampler=train_sampler, **_loader_opts
)
valid_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=valid_batch_size, sampler=valid_sampler, **_loader_opts
)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=test_batch_size, shuffle=False, **_loader_opts
)

## Create the model and meta-optimizer

In [None]:
# Inner-loop optimizer handed to PLIF_MAML.
optimizer = lib.make_inner_optimizer(inner_optimizer_type, **inner_optimizer_kwargs)
model = lib.models.fixup_resnet.FixupResNet18(num_classes=num_classes)
maml = lib.PLIF_MAML(
    model,
    model_type,
    optimizer=optimizer,
    checkpoint_steps=checkpoint_steps,
    loss_function=loss_function,
    num_quantiles=num_quantiles,
)
maml = maml.to(device)

## Trainer

In [None]:
def samples_batches(dataloader, num_batches):
    """Collect the first ``num_batches`` ``(x, y)`` batches of ``dataloader``.

    Args:
        dataloader: any iterable of ``(x_batch, y_batch)`` pairs (e.g. a DataLoader);
            a fresh iterator is taken from it.
        num_batches: maximum number of batches to take; values <= 0 take none.

    Returns:
        ``(x_batches, y_batches)``: two lists, shorter than ``num_batches`` if the
        iterable runs out first.
    """
    from itertools import islice

    x_batches, y_batches = [], []
    # islice stops *before* pulling batch number `num_batches`; the former
    # enumerate/break loop loaded (and augmented) one extra batch only to drop it.
    for x_batch, y_batch in islice(dataloader, max(num_batches, 0)):
        x_batches.append(x_batch)
        y_batches.append(y_batch)
    return x_batches, y_batches


class TrainerResNet(lib.Trainer):
    """Meta-trainer for the Fixup-ResNet experiments.

    ``train_on_batch`` performs one meta-gradient step on the initializers of
    ``self.maml``. ``evaluate_metrics`` compares a model sampled from the untrained
    initializers with one sampled from the current (learned) initializers.
    ``meta_optimizer``, ``meta_grad_clip``, ``writer``, ``total_steps``, ``device``
    and ``record`` are provided by ``lib.Trainer`` (not visible here).
    """

    def train_on_batch(self, train_loader, valid_loader, prefix='train/', **kwargs):
        """ Performs a single meta-gradient update and reports metrics.

        Draws ``inner_loop_steps`` training batches and ``validation_steps``
        validation batches (module-level settings), unrolls ``self.maml`` on them
        and back-propagates the mean validation loss into the initializers.
        Non-finite gradients are zeroed; finite ones are clipped to
        ``self.meta_grad_clip`` when it is set.

        Args:
            train_loader, valid_loader: iterables of ``(x, y)`` batches.
            prefix: tag prefix of the logged scalars.
            **kwargs: forwarded to ``self.maml.forward``.

        Returns:
            The result of ``self.record(train_loss=..., valid_loss=..., prefix=prefix)``.
        """
        # Sample train and val batches
        x_batches, y_batches = samples_batches(train_loader, inner_loop_steps)
        x_val_batches, y_val_batches = samples_batches(valid_loader, validation_steps)

        # Perform a meta training step
        self.meta_optimizer.zero_grad()
        with lib.training_mode(self.maml, is_train=True):
            self.maml.resample_parameters()
            updated_model, train_loss_history, valid_loss_history, *etc = \
                self.maml.forward(x_batches, y_batches, x_val_batches, y_val_batches, **kwargs)
            train_loss = torch.cat(train_loss_history).mean()
            valid_loss = torch.cat(valid_loss_history).mean() if len(valid_loss_history) > 0 else torch.zeros(1)
            valid_loss.backward()

        # Check gradients
        grad_norm = lib.utils.total_norm_frobenius(self.maml.initializers.parameters())
        self.writer.add_scalar(prefix + "grad_norm", grad_norm, self.total_steps)
        bad_grad = not math.isfinite(grad_norm)

        if bad_grad:
            # Replace NaN/Inf entries with zeros so the meta step stays finite.
            # Healthy gradients never enter this branch, even when clipping is off.
            print("Fix bad grad. Loss {} | Grad {}".format(train_loss.item(), grad_norm))
            for param in self.maml.initializers.parameters():
                if param.grad is not None:  # parameters outside the graph have no grad
                    param.grad = torch.where(torch.isfinite(param.grad),
                                             param.grad, torch.zeros_like(param.grad))
        elif self.meta_grad_clip:
            nn.utils.clip_grad_norm_(list(self.maml.initializers.parameters()), self.meta_grad_clip)
        self.meta_optimizer.step()

        return self.record(train_loss=train_loss.item(),
                           valid_loss=valid_loss.item(), prefix=prefix)

    def evaluate_metrics(self, train_loader, test_loader, prefix='val/', **kwargs):
        """ Trains baseline and learned-init models for one epoch and logs test metrics.

        The baseline model is sampled from ``self.maml.untrained_initializers`` and the
        learned model from the current initializers. Both are trained with
        ``eval_model`` and their curves are plotted. For the learned model, the
        summed train/test loss curves ("AUC"), the final test loss and the final test
        error are logged under ``prefix``.

        Args:
            train_loader: loader for the one-epoch training runs.
            test_loader: loader for the test-set evaluations.
            prefix: tag prefix of the logged scalars.
            **kwargs: ``test_interval`` (optional) sets the number of iterations
                between test evaluations (default 40, ``eval_model``'s default).
                Other keys are ignored.
        """
        # The training loop passes test_interval=...; honor it instead of dropping it.
        test_loss_interval = kwargs.get('test_interval', 40)
        torch.cuda.empty_cache()

        print('Baseline')
        self.maml.resample_parameters(initializers=self.maml.untrained_initializers, is_final=True)
        base_model = deepcopy(self.maml.model)
        base_train_loss_history, base_test_loss_history, base_test_error_history = \
            eval_model(base_model, train_loader, test_loader, epochs=1,
                       test_loss_interval=test_loss_interval, device=self.device)

        print('Ours')
        self.maml.resample_parameters(is_final=True)
        maml_model = deepcopy(self.maml.model)
        maml_train_loss_history, maml_test_loss_history, maml_test_error_history = \
            eval_model(maml_model, train_loader, test_loader, epochs=1,
                       test_loss_interval=test_loss_interval, device=self.device)

        lib.utils.resnet_draw_plots(base_train_loss_history, base_test_loss_history,
                                    base_test_error_history, maml_train_loss_history,
                                    maml_test_loss_history, maml_test_error_history)

        self.writer.add_scalar(prefix + "train_AUC", sum(maml_train_loss_history), self.total_steps)
        self.writer.add_scalar(prefix + "test_AUC", sum(maml_test_loss_history), self.total_steps)
        self.writer.add_scalar(prefix + "test_loss", maml_test_loss_history[-1], self.total_steps)
        self.writer.add_scalar(prefix + "test_cls_error", maml_test_error_history[-1], self.total_steps)

In [None]:
########################
# Generate Train Batch #
########################
            
def generate_train_batches(train_loader, batches_in_epoch=150, batch_size=None, loader_workers=None):
    """Freeze the first ``batches_in_epoch`` batches of ``train_loader`` into a
    reshuffling in-memory DataLoader.

    The (already augmented) batches are materialized once in a ``TensorDataset``.
    Every pass over the returned loader therefore sees the same samples, in a new
    order.

    Args:
        train_loader: iterable of ``(x_batch, y_batch)`` tensor pairs.
        batches_in_epoch: number of batches to take from ``train_loader``.
        batch_size: batch size of the returned loader; defaults to the
            module-level ``train_batch_size``.
        loader_workers: ``num_workers`` of the returned loader; defaults to the
            module-level ``num_workers``.

    Returns:
        A shuffling ``DataLoader`` over the collected samples.

    Raises:
        ValueError: if ``batches_in_epoch < 1`` or ``train_loader`` yields fewer
            than ``batches_in_epoch`` batches.
    """
    from itertools import islice

    if batches_in_epoch < 1:
        raise ValueError(f"batches_in_epoch must be >= 1, got {batches_in_epoch}")
    if batch_size is None:
        batch_size = train_batch_size
    if loader_workers is None:
        loader_workers = num_workers

    x_batches, y_batches = [], []
    # islice avoids loading one extra batch that would immediately be discarded.
    for x_batch, y_batch in islice(train_loader, batches_in_epoch):
        x_batches.append(x_batch)
        y_batches.append(y_batch)

    if len(x_batches) != batches_in_epoch:
        raise ValueError(f"train_loader yielded {len(x_batches)} batches, "
                         f"expected {batches_in_epoch}")

    local_dataset = TensorDataset(torch.cat(x_batches, dim=0), torch.cat(y_batches, dim=0))
    return DataLoader(local_dataset, batch_size=batch_size,
                      shuffle=True, num_workers=loader_workers)
        

##################
# Eval functions #
##################

def adjust_learning_rate(optimizer, epoch, milestones=[30, 50]):
    """Step schedule: divide the rate by 10 at ``milestones[0]`` and again at ``milestones[1]``.

    A group whose ``'initial_lr'`` equals the module-level ``learning_rate`` gets the
    returned rate. Any other group is set to its own ``'initial_lr'`` scaled by 1,
    1/10 or 1/100, depending on which milestones ``epoch`` has reached.
    ``milestones`` is only read, never mutated.

    Returns:
        The scheduled rate derived from ``learning_rate``.
    """
    reached_first = epoch >= milestones[0]
    reached_second = epoch >= milestones[1]

    scheduled_lr = learning_rate
    for reached in (reached_first, reached_second):
        if reached:
            scheduled_lr /= 10

    for group in optimizer.param_groups:
        base_lr = group['initial_lr']
        if base_lr == learning_rate:
            group['lr'] = scheduled_lr
        elif not reached_first:
            group['lr'] = base_lr
        elif not reached_second:
            group['lr'] = base_lr / 10.
        else:
            group['lr'] = base_lr / 100.
    return scheduled_lr


@torch.no_grad()
def compute_test_loss(model, loss_function, test_loader, device='cuda'):
    """Average loss and classification error of ``model`` over ``test_loader``.

    Runs without autograd and with ``model`` in eval mode. The model's previous
    train/eval mode is restored afterwards, even if an exception is raised.

    Args:
        model: classifier; the predicted class is the argmax over the last dim of
            its output.
        loss_function: ``loss_function(preds, targets)``, assumed to return the
            batch *mean*. It is re-weighted by batch size to get a per-sample average.
        test_loader: iterable of ``(x, y)`` batches.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, error_rate)`` as Python floats, averaged over the samples
        actually yielded by ``test_loader``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    total_loss, num_errors, num_samples = 0., 0., 0
    try:
        for x_test, y_test in test_loader:
            x_test, y_test = x_test.to(device), y_test.to(device)
            preds = model(x_test)
            batch_size = x_test.shape[0]
            total_loss = total_loss + loss_function(preds, y_test) * batch_size
            num_errors = num_errors + (y_test != preds.argmax(dim=-1)).sum().float()
            num_samples += batch_size
    finally:
        model.train(was_training)
    # Normalize by the samples actually seen. len(test_loader.dataset) over-counts
    # when the loader uses a subset sampler (as valid_loader does) or drop_last.
    if num_samples == 0:
        raise ValueError("test_loader yielded no samples")
    return (total_loss / num_samples).item(), (num_errors / num_samples).item()


def eval_model(model, train_loader, test_loader, epochs=3, test_loss_interval=40, device='cuda',
               max_grad_norm=4.):
    """Train ``model`` in place and record test metrics along the way.

    The optimizer comes from ``lib.optimizers.make_eval_inner_optimizer``, built
    from the module-level ``maml``, ``inner_optimizer_type`` and
    ``inner_optimizer_kwargs``. ``adjust_learning_rate`` applies its step schedule
    at the start of every epoch.

    Args:
        model: network to train. It is modified in place and its train/eval mode
            is restored at the end.
        train_loader: iterable of ``(x, y)`` training batches.
        test_loader: loader passed to ``compute_test_loss``.
        epochs: number of passes over ``train_loader``.
        test_loss_interval: metrics are recorded after the first iteration and then
            after every ``test_loss_interval``-th iteration, counted across epochs.
        device: device the batches are moved to.
        max_grad_norm: gradient-norm clipping threshold applied before every step
            (4 by default, as before). ``None`` disables clipping.

    Returns:
        ``(train_loss_history, test_loss_history, test_error_history)``: the loss of
        the current training batch and the full test loss / error at each
        recording point.
    """
    optimizer = lib.optimizers.make_eval_inner_optimizer(
        maml, model, inner_optimizer_type,
        **inner_optimizer_kwargs
    )
    # adjust_learning_rate scales each group relative to its 'initial_lr'.
    for param_group in optimizer.param_groups:
        param_group['initial_lr'] = learning_rate

    train_loss_history, test_loss_history, test_error_history = [], [], []
    training_mode = model.training

    total_iters = 0
    for epoch in range(epochs):
        model.train()
        adjust_learning_rate(optimizer, epoch)
        for x_batch, y_batch in train_loader:
            optimizer.zero_grad()
            preds = model(x_batch.to(device))
            loss = loss_function(preds, y_batch.to(device))
            loss.backward()

            if max_grad_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

            if total_iters == 0 or (total_iters + 1) % test_loss_interval == 0:
                train_loss_history.append(loss.item())
                # compute_test_loss switches the model to eval mode itself.
                test_loss, test_error = compute_test_loss(model, loss_function, test_loader, device=device)
                print("Epoch {} | Train Loss {:.4f} | Test Loss {:.4f} | Classification Error {:.4f}"\
                      .format(epoch, loss.item(), test_loss, test_error))
                test_loss_history.append(test_loss)
                test_error_history.append(test_error)
                model.train()

            total_iters += 1

    model.train(training_mode)
    return train_loss_history, test_loss_history, test_error_history

In [None]:
# Per-meta-step loss histories, appended to by the training loop below.
train_loss_history, valid_loss_history = [], []

trainer = TrainerResNet(
    maml,
    meta_lr=meta_learning_rate,
    meta_betas=meta_betas,
    meta_grad_clip=meta_grad_clip,
    exp_name=exp_name,
    recovery_step=recovery_step,
)

## Training

In [None]:
from IPython.display import clear_output

t0 = time.time()


def _plot_loss_curves(train_history, valid_history, span=50):
    """Show the raw per-step train/valid losses (faint scatter) with their moving averages."""
    fig, axes = plt.subplots(1, 2, figsize=[16, 5])
    panels = (('Train Loss over time', train_history), ('Valid Loss over time', valid_history))
    for ax, (title, history) in zip(axes, panels):
        ax.set_title(title)
        ax.plot(lib.utils.moving_average(history, span=span))
        ax.scatter(range(len(history)), history, alpha=0.1)
    plt.show()
    plt.close(fig)  # release the figure on non-inline backends as well


# Meta-training loop. Note the `<=`: steps 0..max_steps inclusive are run, so with
# max_steps=3000 the last step also triggers the evaluation and checkpoint below.
while trainer.total_steps <= max_steps:
    lib.free_memory()
    metrics = trainer.train_on_batch(train_loader, valid_loader, **kwargs)
    train_loss = metrics['train_loss']
    valid_loss = metrics['valid_loss']
    train_loss_history.append(train_loss)
    valid_loss_history.append(valid_loss)

    if trainer.total_steps % 10 == 0:
        clear_output(True)
        print("Step: %d | Time: %f | Train Loss %.5f | Valid loss %.5f"
              % (trainer.total_steps, time.time()-t0, train_loss, valid_loss))
        _plot_loss_curves(train_loss_history, valid_loss_history)
        # Evaluate on a frozen, reshuffled set of inner_loop_steps training batches.
        local_train_loader = generate_train_batches(train_loader, inner_loop_steps)
        trainer.evaluate_metrics(local_train_loader, test_loader, test_interval=20)
        lib.utils.resnet_visualize_quantile_functions(maml)
        t0 = time.time()

    if trainer.total_steps % 100 == 0:
        trainer.save_model()

    trainer.total_steps += 1

In [None]:
# Final look at the trained initializers (the same helper the training loop calls every 10 steps).
lib.utils.resnet_visualize_quantile_functions(maml)

# Evaluation

In [None]:
# Fresh random seed for the evaluation runs; it is printed so the run can be repeated.
seed = random.randint(0, 2 ** 32 - 1)
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)
print(seed)

In [None]:
def gradient_quotient(loss, params, eps=1e-5):
    """MetaInit gradient quotient of ``loss`` with respect to ``params``.

    With g = dL/dp and Hg = d(sum ||g||^2 / 2)/dp (the Hessian-gradient product),
    returns the mean over all parameter entries of
    |(g - Hg) / (g + eps * s) - 1|, where s = +1 for g >= 0 and -1 otherwise.
    Both gradients are taken with create_graph=True, so the result can itself be
    differentiated.
    """
    grads = torch.autograd.grad(loss, params, retain_graph=True, create_graph=True)
    half_sq_norm = sum([(g ** 2).sum() / 2 for g in grads])
    hess_grads = torch.autograd.grad(half_sq_norm, params, retain_graph=True, create_graph=True)

    total = 0
    for g, hg in zip(grads, hess_grads):
        signed_eps = eps * (2 * (g >= 0).float() - 1).detach()
        total = total + ((g - hg) / (g + signed_eps) - 1).abs().sum()

    num_entries = sum([p.data.nelement() for p in params])
    return total / num_entries


def metainit(model, criterion, x_size, y_size, lr=0.1, momentum=0.9, steps=200, eps=1e-5):
    """MetaInit: rescale the norms of ``model``'s weight tensors, in place, to reduce
    the gradient quotient on random data.

    Only parameters with >= 2 dims and a finite, non-zero std are touched. Only
    their norm changes: each step moves it by sign-SGD with momentum on the
    gradient-quotient objective.

    Args:
        model: network to initialize; its train/eval mode is restored afterwards.
        criterion: ``criterion(model(x), y)`` loss.
        x_size: shape of the random N(0, 1) input batch (first dim = batch size).
        y_size: number of classes; targets are drawn uniformly from ``[0, y_size)``.
        lr, momentum: step size and momentum of the norm updates.
        steps: number of updates.
        eps: stabilizer used in the quotient and in the norm rescaling.
    """
    was_training = model.training
    model.eval()
    # Draw the random batches on the model's own device (this used to be a
    # hard-coded .cuda(), which failed on CPU and ignored the model's device).
    device = next(model.parameters()).device
    params = [p for p in model.parameters()
              if p.requires_grad and len(p.size()) >= 2 and
              math.isfinite(p.std().item()) and p.std().item() > 0]
    memory = [0] * len(params)
    for i in range(steps):
        inputs = torch.empty(*x_size, device=device).normal_(0, 1)
        target = torch.randint(0, y_size, (x_size[0],), device=device)
        loss = criterion(model(inputs), target)
        gq = gradient_quotient(loss, list(model.parameters()), eps)

        grad = torch.autograd.grad(gq, params)
        for j, (p, g_all) in enumerate(zip(params, grad)):
            norm = p.data.norm().item()
            # Sign of the directional derivative along p decides whether the norm grows.
            g = torch.sign((p.data * g_all).sum() / norm)
            memory[j] = momentum * memory[j] - lr * g.item()
            new_norm = norm + memory[j]
            p.data.mul_(new_norm / (norm + eps))
        print("%d/GQ = %.2f" % (i, gq.item()))
    model.train(was_training)

In [None]:
def genOrthgonal(dim):
    """Random ``dim x dim`` orthogonal matrix.

    QR-decomposes a standard-normal matrix, then flips each column of Q by the
    sign of the matching diagonal entry of R. This is the standard sign
    correction that makes the factorization unique.
    """
    a = torch.zeros((dim, dim)).normal_(0, 1)
    # torch.qr is deprecated; torch.linalg.qr defaults to the same 'reduced' mode.
    q, r = torch.linalg.qr(a)
    d = torch.diagonal(r).sign()
    d[d == 0] = 1  # a zero pivot would otherwise wipe out a whole column
    # Broadcasting multiplies column j of q by d[j].
    return q * d

def makeDeltaOrthogonal(weights, gain):
    """Delta-orthogonal initialization of a convolution kernel, in place.

    Zeroes ``weights`` and writes ``gain`` times the top-left (dim0 x dim1) block of
    a random orthogonal matrix into the spatial center tap
    ``[:, :, size(2)//2, size(3)//2]``. This assumes a 4-D (out, in, kh, kw) tensor,
    as the caller filters on. A warning is printed when dim0 < dim1.
    """
    out_channels, in_channels = weights.size(0), weights.size(1)
    if out_channels < in_channels:
        print("In_filters should not be greater than out_filters.")
    weights.data.fill_(0)
    q = genOrthgonal(max(out_channels, in_channels))
    center_h = weights.size(2) // 2
    center_w = weights.size(3) // 2
    with torch.no_grad():
        weights[:, :, center_h, center_w] = q[:out_channels, :in_channels]
        weights.mul_(gain)

## Eval TinyImageNet

In [None]:
# Tiny-ImageNet settings use their own names so that this cell does not overwrite
# the CIFAR-100 `data_dir` or the integer `num_workers` that generate_train_batches
# and the CIFAR loaders read (previously `num_workers` was rebound to a dict here).
ti_data_dir = 'data/tiny-imagenet-200/'
ti_splits = ['train', 'val', 'test']
ti_num_workers = {'train': 0, 'val': 0, 'test': 0}

ti_normalize = transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262])
ti_eval_transform = transforms.Compose([transforms.ToTensor(), ti_normalize])
data_transforms = {
    'train': transforms.Compose([
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(),
        ti_normalize,
    ]),
    'val': ti_eval_transform,
    'test': ti_eval_transform,
}
image_datasets = {x: datasets.ImageFolder(os.path.join(ti_data_dir, x), data_transforms[x])
                  for x in ti_splits}
# Only the training split is shuffled; the averaged test metrics do not depend on order.
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=128,
                                              shuffle=(x == 'train'),
                                              num_workers=ti_num_workers[x])
               for x in ti_splits}
dataset_sizes = {x: len(image_datasets[x]) for x in ti_splits}

In [None]:
ti_batches_in_epoch = len(dataloaders['train'])
# 782 = ceil(100000 / 128), assuming the standard 100k-image Tiny-ImageNet train split.
assert ti_batches_in_epoch == 782
num_reruns = 10
ti_num_classes = 200
ti_epochs = 70

# Test loss / error curves of every rerun, one pair of lists per initialization method.
reruns_base_test_loss_history = []
reruns_base_test_error_history = []

reruns_metainit_test_loss_history = []
reruns_metainit_test_error_history = []

reruns_maml_test_loss_history = []
reruns_maml_test_error_history = []

reruns_deltaorthogonal_test_loss_history = []
reruns_deltaorthogonal_test_error_history = []


def _with_zero_head(model):
    """Replace ``model.fc`` with a zero-initialized 512 -> ``ti_num_classes`` linear layer; returns ``model``."""
    model.fc = nn.Linear(in_features=512, out_features=ti_num_classes, bias=True).to(device)
    nn.init.constant_(model.fc.weight, 0)
    nn.init.constant_(model.fc.bias, 0)
    return model


def _train_on_tiny_imagenet(model):
    """Train ``model`` for ``ti_epochs`` epochs, testing after the first iteration and at the end of each epoch."""
    return eval_model(model, dataloaders['train'], dataloaders['test'],
                      epochs=ti_epochs, test_loss_interval=ti_batches_in_epoch, device=device)


for i in range(num_reruns):
    print(f"Rerun {i}")

    # Baseline: parameters sampled from maml.untrained_initializers, fresh zero head.
    print("Baseline")
    maml.resample_parameters(initializers=maml.untrained_initializers, is_final=True)
    base_model = _with_zero_head(deepcopy(maml.model))
    base_train_loss_history, base_test_loss_history, base_test_error_history = \
        _train_on_tiny_imagenet(base_model)
    reruns_base_test_loss_history.append(base_test_loss_history)
    reruns_base_test_error_history.append(base_test_error_history)

    # DIMAML: parameters sampled from the meta-learned initializers, fresh zero head.
    print("DIMAML")
    maml.resample_parameters(is_final=True)
    maml_model = _with_zero_head(deepcopy(maml.model))
    maml_train_loss_history, maml_test_loss_history, maml_test_error_history = \
        _train_on_tiny_imagenet(maml_model)
    reruns_maml_test_loss_history.append(maml_test_loss_history)
    reruns_maml_test_error_history.append(maml_test_error_history)

    # MetaInit on random inputs shaped like a 64-image training batch
    # (a real batch is fetched only for its shape).
    print("MetaInit")
    batch_x, _ = next(iter(dataloaders['train']))
    batch_x = batch_x[:64]
    metainit_model = lib.models.FixupResNet18(num_classes=ti_num_classes, default_init=True).to(device)
    metainit(metainit_model, loss_function, batch_x.shape, ti_num_classes)
    metainit_train_loss_history, metainit_test_loss_history, metainit_test_error_history = \
        _train_on_tiny_imagenet(metainit_model)
    reruns_metainit_test_loss_history.append(metainit_test_loss_history)
    reruns_metainit_test_error_history.append(metainit_test_error_history)

    # Delta-orthogonal init of every parameter with >= 4 dims (the conv kernels).
    print("DeltaOrthogonal")
    deltaorthogonal_model = lib.models.FixupResNet18(num_classes=ti_num_classes).to(device)
    for param in deltaorthogonal_model.parameters():
        if len(param.size()) >= 4:
            makeDeltaOrthogonal(param, nn.init.calculate_gain('leaky_relu'))
    deltaorthogonal_train_loss_history, deltaorthogonal_test_loss_history, deltaorthogonal_test_error_history = \
        _train_on_tiny_imagenet(deltaorthogonal_model)
    reruns_deltaorthogonal_test_loss_history.append(deltaorthogonal_test_loss_history)
    reruns_deltaorthogonal_test_error_history.append(deltaorthogonal_test_error_history)

In [None]:
def _mean_std(reruns):
    """Per-checkpoint mean and sample standard deviation (ddof=1) across reruns."""
    stacked = np.array(reruns)
    return stacked.mean(0), stacked.std(0, ddof=1)


base_mean, base_std = _mean_std(reruns_base_test_error_history)
maml_mean, maml_std = _mean_std(reruns_maml_test_error_history)
metainit_mean, metainit_std = _mean_std(reruns_metainit_test_error_history)
deltaorthogonal_mean, deltaorthogonal_std = _mean_std(reruns_deltaorthogonal_test_error_history)