In [None]:
%load_ext autoreload
%autoreload 2
%env CUDA_VISIBLE_DEVICES=0
import os, sys, time
sys.path.insert(0, '..')
import lib

import math
import numpy as np
from copy import deepcopy
import torch, torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline
# Plot styling; fonttype 42 selects TrueType fonts for PDF/PS exports.
plt.style.use('seaborn-darkgrid')
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42

# For reproducibility: a fresh seed is drawn on every run, applied to the
# Python, NumPy and PyTorch generators, and printed below so it can be recorded
# (it also becomes part of the experiment name).
import random
seed = random.randint(0, 2**32 - 1)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Request deterministic cuDNN kernels and switch off its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print(seed)

# Setting

In [None]:
model_type = 'lstm'

# ---- Language model ----
emb_size, hidden_size = 128, 256
sequence_length = 100

# ---- Dataset ----
data_dir = './data'
train_batch_size = valid_batch_size = test_batch_size = 128

loss_function = F.nll_loss
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ---- MAML ----
max_steps = 3000
inner_loop_epochs = 1
inner_loop_steps_in_epoch = 200
inner_loop_steps = inner_loop_epochs * inner_loop_steps_in_epoch
meta_grad_clip = 10.

# The validation loss is taken first at `first_val_step` and then every
# `loss_interval` inner steps; the schedule must end exactly on the last step.
first_val_step = loss_interval = 40
_full_intervals, _leftover = divmod(inner_loop_steps - first_val_step, loss_interval)
assert _leftover == 0
validation_steps = _full_intervals + 1

# ---- Inner optimizer ----
inner_optimizer_type = 'adam'
learning_rate = 0.001
inner_optimizer_kwargs = {
    'lr': learning_rate,
    'betas': (0.9, 0.999),
    'weight_decay': 0.0,
    'eps': 1e-8,
}

# ---- Meta optimizer ----
meta_learning_rate = 0.0003
meta_betas = (0.9, 0.997)

checkpoint_steps = 6
recovery_step = None

# Extra keyword arguments forwarded to every `trainer.train_on_batch` call.
kwargs = {
    'first_valid_step': first_val_step,
    'valid_loss_interval': loss_interval,
}

In [None]:
# The run name encodes the main hyperparameters and the seed, so every run
# gets its own log directory.
exp_name = (
    f"LSTM_LM_{model_type}_PTb_{inner_optimizer_type}"
    f"_steps{inner_loop_steps}_interval{loss_interval}"
    f"_tr_bs{train_batch_size}_val_bs{valid_batch_size}_seed_{seed}"
)

print("Experiment name: ", exp_name)

logs_path = "./logs/{}".format(exp_name)
# A fresh run (no recovery step) must not reuse an existing log directory.
assert not (recovery_step is None and os.path.exists(logs_path))
# To discard old logs manually: !rm -rf {logs_path}

## Prepare PennTreebank

In [None]:
import torchtext

data_dir = 'data/'

# `tokenize=list` splits the raw text into single characters.
PTb_TEXT = torchtext.data.Field(lower=False, tokenize=list)
PTb_train, PTb_valid, PTb_test = torchtext.datasets.PennTreebank.splits(PTb_TEXT, root=data_dir)
PTb_TEXT.build_vocab(PTb_train)

PTb_voc_size = len(PTb_TEXT.vocab)

# Both loaders share every option except the split and its batch size.
PTb_train_loader, PTb_valid_loader = (
    torchtext.data.BPTTIterator(split, split_batch_size, sequence_length,
                                train=True, device=device, repeat=True, shuffle=True)
    for split, split_batch_size in ((PTb_train, train_batch_size),
                                    (PTb_valid, valid_batch_size))
)

# Encode the test text with the train vocabulary; symbols missing from it map to id 0.
PTb_test_ids = [PTb_TEXT.vocab.stoi.get(symbol, 0) for symbol in PTb_test.examples[0].text]
full_PTb_test_ids = torch.tensor(PTb_test_ids)
# The first tenth of the test text, used for the periodic evaluation during meta-training.
part_PTb_test_ids = torch.tensor(PTb_test_ids[:len(PTb_test_ids) // 10])

### Utils

In [None]:
def samples_batches(dataloader, num_batches):
    """ Collects up to `num_batches` batches from `dataloader`.

    Every batch must expose `text` and `target` tensors; both are transposed
    with `.t()` before being stored.

    Exactly `num_batches` items are requested from the loader (fewer if it runs
    out), so no surplus batch is built and thrown away and a one-shot iterator
    is left positioned right after the last returned batch.

    :param dataloader: iterable of batches with `.text` / `.target` attributes
    :param num_batches: maximum number of batches to collect
    :returns: (x_batches, y_batches), two lists of transposed tensors
    """
    x_batches, y_batches = [], []
    # zip() advances range() first, so it stops before touching the loader again.
    for _, batch in zip(range(num_batches), dataloader):
        x_batches.append(batch.text.t())
        y_batches.append(batch.target.t())
    return x_batches, y_batches


def compute_loss(logp_next, batch_targets, **kwargs):
    """ Mean negative log-likelihood over all target positions.

    :param logp_next: log-probabilities whose last axis indexes the classes
    :param batch_targets: integer class ids, one per leading position of `logp_next`
    :param kwargs: accepted and ignored
    :returns: scalar tensor
    """
    num_classes = logp_next.shape[-1]
    flat_logp = logp_next.reshape(-1, num_classes)
    flat_targets = batch_targets.reshape(-1)
    per_position_nll = F.nll_loss(flat_logp, flat_targets, reduction='none')
    return per_position_nll.mean()


@torch.no_grad()
def compute_test_loss(model, loss_function, test_ids, device=None, **kwargs):
    """ Next-token loss of `model` on one long id sequence, computed without gradients.

    The sequence is fed as a single batch row: ids[:-1] are the inputs and
    ids[1:] the targets.

    :param model: callable mapping a (1, T) id tensor to log-probabilities
    :param loss_function: callable (log-probabilities, targets) -> loss
    :param test_ids: 1-D tensor of token ids
    :param device: where to run; defaults to the device of the model's first
        parameter (or of `test_ids` when the model has no parameters)
    :param kwargs: accepted and ignored
    :returns: whatever `loss_function` returns
    """
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = test_ids.device if first_param is None else first_param.device
    test_ids = test_ids.to(device)
    logp_next = model(test_ids[:-1][None])
    return loss_function(logp_next, test_ids[1:][None])

## Create the model and meta-optimizer

In [None]:
# Inner-loop optimizer built from the type and kwargs chosen in the settings cell.
optimizer = lib.make_inner_optimizer(inner_optimizer_type, **inner_optimizer_kwargs)
# Language model sized for the PennTreebank vocabulary.
model = lib.models.language_model.LanguageModel(PTb_voc_size, emb_size, hid_size=hidden_size)
# MAML wrapper around the model; it uses `compute_loss` as its loss function.
# NOTE(review): the meaning of `checkpoint_steps` is defined inside lib.MAML — confirm there.
maml = lib.MAML(model, model_type, optimizer=optimizer, 
    checkpoint_steps=checkpoint_steps,
    loss_function=compute_loss
).to(device)

## Trainer

In [None]:
class TrainerLM(lib.Trainer):
    """ Meta-trainer for the language-model experiment.

    On top of what `lib.Trainer` provides (`self.maml`, `self.meta_optimizer`,
    `self.writer`, `self.total_steps`, `self.meta_grad_clip`, `self.device`,
    `self.record`), the methods read notebook-level names:
    `inner_loop_steps_in_epoch`, `validation_steps`, `samples_batches`
    and `eval_model`.
    """

    def train_on_batch(self, train_loader, valid_loader, prefix='train/', **kwargs):
        """ Performs a single gradient update and reports metrics

        :param train_loader: source of inner-loop training batches
        :param valid_loader: source of the last validation batch
        :param prefix: prefix for the logged metric names
        :param kwargs: forwarded to `self.maml.forward`
        :returns: the result of `self.record(...)` for the train and valid loss
        """
        # Prepare train data
        x_batches, y_batches = samples_batches(train_loader, inner_loop_steps_in_epoch)

        # Due to a little amount of validation data,
        # validation batches are sampled from both remaining train and valid sets
        # NOTE(review): each samples_batches call starts a new `for` pass over the
        # loader; whether these batches differ from `x_batches` depends on the
        # loader's __iter__ — confirm.
        x_val_batches, y_val_batches = samples_batches(train_loader, validation_steps - 1)
        x_tmp_batches, y_tmp_batches = samples_batches(valid_loader, 1)
        x_val_batches.extend(x_tmp_batches)
        y_val_batches.extend(y_tmp_batches)

        # Perform step (always through the trainer's own MAML instance)
        self.meta_optimizer.zero_grad()
        with lib.training_mode(self.maml, is_train=True):
            self.maml.resample_parameters()
            _updated_model, train_loss_history, valid_loss_history, *_ = \
                self.maml.forward(x_batches, y_batches, x_val_batches, y_val_batches, **kwargs)
            train_loss = torch.cat(train_loss_history).mean()
            valid_loss = torch.cat(valid_loss_history).mean()
        valid_loss.backward()

        # Check gradients
        grad_norm = lib.utils.total_norm_frobenius(self.maml.initializers.parameters())
        self.writer.add_scalar(prefix + "grad_norm", grad_norm, self.total_steps)
        bad_grad = not math.isfinite(grad_norm)

        if bad_grad:
            # Zero out only the non-finite entries so the step can still be applied.
            print("Fix bad grad. Loss {} | Grad {}".format(train_loss.item(), grad_norm))
            for param in self.maml.initializers.parameters():
                if param.grad is None:
                    continue  # parameter received no gradient this step
                param.grad = torch.where(torch.isfinite(param.grad),
                                         param.grad, torch.zeros_like(param.grad))
        elif self.meta_grad_clip:
            # Healthy gradients are clipped only when clipping is enabled.
            nn.utils.clip_grad_norm_(list(self.maml.initializers.parameters()), self.meta_grad_clip)

        self.meta_optimizer.step()
        return self.record(train_loss=train_loss.item(),
                           valid_loss=valid_loss.item(), prefix=prefix)

    def evaluate_metrics(self, train_loader, test_loader, prefix='val/', **kwargs):
        """ Trains a baseline and a meta-initialized model with `eval_model`, plots both, logs metrics

        :param train_loader: training batches passed to `eval_model`
        :param test_loader: passed to `eval_model` as its `test_ids` argument
            (the training cell passes a 1-D tensor of test token ids here)
        :param prefix: prefix for the logged metric names
        :param kwargs: forwarded to `eval_model`
        """
        torch.cuda.empty_cache()

        print('Baseline')
        self.maml.resample_parameters(initializers=self.maml.untrained_initializers, is_final=True)
        base_model = deepcopy(self.maml.model)
        base_train_loss_history, base_test_loss_history = eval_model(base_model, train_loader, test_loader,
                                                                     device=self.device, **kwargs)
        print('DIMAML')
        self.maml.resample_parameters(is_final=True)
        maml_model = deepcopy(self.maml.model)
        maml_train_loss_history, maml_test_loss_history = eval_model(maml_model, train_loader, test_loader,
                                                                     device=self.device, **kwargs)
        lib.utils.lm_draw_plots(base_train_loss_history, base_test_loss_history,
                                maml_train_loss_history, maml_test_loss_history)

        # "AUC" here is the plain sum of the recorded losses.
        self.writer.add_scalar(prefix + "train_AUC", sum(maml_train_loss_history), self.total_steps)
        self.writer.add_scalar(prefix + "test_AUC", sum(maml_test_loss_history), self.total_steps)
        self.writer.add_scalar(prefix + "test_loss", maml_test_loss_history[-1], self.total_steps)

In [None]:
##################
# Eval functions #
##################

def adjust_learning_rate(optimizer, epoch, milestone=80, base_lr=None, decay_factor=10., **kwargs):
    """ Sets the optimizer's learning rate: `base_lr` before `milestone`, `base_lr / decay_factor` from it on

    :param optimizer: optimizer whose `param_groups` get their 'lr' overwritten
    :param epoch: current epoch index
    :param milestone: first epoch that uses the decayed learning rate
    :param base_lr: undecayed learning rate; defaults to the notebook-level `learning_rate`
    :param decay_factor: divisor applied once the milestone is reached
    :param kwargs: accepted and ignored
    :returns: the learning rate that was set
    """
    if base_lr is None:
        base_lr = learning_rate
    lr = base_lr / decay_factor if milestone <= epoch else base_lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def eval_model(model, train_loader, test_ids, epochs=3, 
               test_loss_interval=20, mode='train', device='cuda', **kwargs):
    """ Trains `model` on `train_loader` and periodically measures its loss on `test_ids`

    Uses the notebook-level `maml`, `inner_optimizer_type`, `inner_optimizer_kwargs`
    and `inner_loop_steps`. The test loss is taken after the first iteration and
    after every `test_loss_interval`-th iteration.

    :param model: model to train in place
    :param train_loader: sized iterable of batches with `.text` / `.target`
    :param test_ids: 1-D tensor of test token ids
    :param epochs: stop once this many passes over `train_loader` are done
    :param test_loss_interval: number of iterations between test evaluations
    :param mode: with 'train', training also stops after about `inner_loop_steps` iterations
    :param device: forwarded to `compute_test_loss`
    :param kwargs: forwarded to `adjust_learning_rate` and `compute_test_loss`
    :returns: (train_loss_history, test_loss_history), two lists of Python floats
    """
    optimizer = lib.optimizers.make_eval_inner_optimizer(
        maml, model, inner_optimizer_type, 
        **inner_optimizer_kwargs
    )
    # Train loop
    train_loss_history = []
    test_loss_history = []

    training_mode = model.training
    batches_per_epoch = len(train_loader)
    total_iters = 0
    epoch = 0
    model.train()

    for batch in train_loader:
        adjust_learning_rate(optimizer, epoch, **kwargs)
        epoch = (total_iters + 1) // batches_per_epoch

        optimizer.zero_grad()
        preds = model(batch.text.t())
        loss = compute_loss(preds, batch.target.t())
        loss.backward()
        optimizer.step()
        train_loss_history.append(loss.item())

        if (total_iters == 0) or (total_iters + 1) % test_loss_interval == 0:
            model.eval()
            # Store a plain float: a list of 0-d (possibly CUDA) tensors cannot be
            # turned into a NumPy array or plotted reliably later on.
            test_loss = float(compute_test_loss(model, compute_loss, test_ids, device=device, **kwargs))
            bpc = test_loss * math.log2(math.e)
            print(f"Epoch {epoch} | Iteration {total_iters + 1} | Loss {test_loss:.4} | bpc {bpc:.4}")
            test_loss_history.append(test_loss)
            model.train()

        if epoch >= epochs: break
        # NOTE(review): this check runs before the increment, so 'train' mode performs
        # inner_loop_steps + 1 updates — confirm whether the extra step is intended.
        if mode == 'train' and total_iters >= inner_loop_steps: break
        total_iters += 1

    # Restore whatever train/eval mode the model had on entry.
    model.train(training_mode)
    return train_loss_history, test_loss_history

In [None]:
# Per-step meta-training losses; the training loop appends to and plots these lists.
train_loss_history = []
valid_loss_history = []

# Meta-trainer around `maml`, configured from the settings cell.
# NOTE(review): `recovery_step` presumably selects a saved step to resume from — confirm in lib.Trainer.
trainer = TrainerLM(maml, meta_lr=meta_learning_rate, 
                    meta_betas=meta_betas, meta_grad_clip=meta_grad_clip,
                    exp_name=exp_name, recovery_step=recovery_step)

## Training

In [None]:
from IPython.display import clear_output

lib.free_memory()
t0 = time.time()

# Meta-training: one meta-update per iteration, a progress report with plots
# and an evaluation every 20 steps, a saved model every 100 steps.
while trainer.total_steps <= max_steps:
    metrics = trainer.train_on_batch(PTb_train_loader, PTb_valid_loader, **kwargs)

    train_loss = metrics['train_loss']
    train_loss_history.append(train_loss)

    valid_loss = metrics['valid_loss']
    valid_loss_history.append(valid_loss)

    if not trainer.total_steps % 20:
        clear_output(True)
        print("Step %d | Time: %f | Train Loss %.5f | Valid loss %.5f" % 
              (trainer.total_steps, time.time()-t0, train_loss, valid_loss))
        plt.figure(figsize=[16, 5])
        # One panel per loss: smoothed curve on top of the raw per-step values.
        loss_panels = (('Train Loss over time', train_loss_history),
                       ('Valid Loss over time', valid_loss_history))
        for panel_id, (panel_title, loss_history) in enumerate(loss_panels, start=1):
            plt.subplot(1, 2, panel_id)
            plt.title(panel_title)
            plt.plot(lib.utils.moving_average(loss_history, span=50))
            plt.scatter(range(len(loss_history)), loss_history, alpha=0.1)
        plt.show()
        trainer.evaluate_metrics(PTb_train_loader, part_PTb_test_ids, 
                                 test_loss_interval=loss_interval)
        lib.utils.lm_visualize_pdf(maml)
        # Restart the timer so the next report covers only the next 20 steps.
        t0 = time.time()

    if not trainer.total_steps % 100:
        trainer.save_model()
    trainer.total_steps += 1

In [None]:
# Visualise the meta-trained `maml` once more after training has finished
# (what exactly is drawn is defined by lib.utils.lm_visualize_pdf).
lib.utils.lm_visualize_pdf(maml)

# Evaluation

In [None]:
# From here on the deterministic cuDNN mode set at the top of the notebook is
# switched off and the cuDNN auto-tuner is enabled for the long evaluation runs.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

def gradient_quotient(loss, params, eps=1e-5): 
    """ Differentiable gradient-quotient statistic used by `metainit`.

    With g the gradient of `loss` w.r.t. `params` and h the gradient of
    ||g||^2 / 2 w.r.t. `params`, returns the mean over all parameter elements of
    |(g - h) / (g + eps * sign(g)) - 1|, where sign(0) counts as +1.

    :param loss: scalar tensor whose graph reaches every tensor in `params`
    :param params: list of tensors to differentiate with respect to
    :param eps: keeps the denominator away from zero
    :returns: scalar tensor that can itself be differentiated
    """
    first_order = torch.autograd.grad(loss, params, retain_graph=True, create_graph=True)
    half_sq_norm = sum((g ** 2).sum() / 2 for g in first_order)
    second_order = torch.autograd.grad(half_sq_norm, params,
                                       retain_graph=True, create_graph=True)

    total = 0
    for g, h in zip(first_order, second_order):
        # +eps where g >= 0 and -eps elsewhere; detached so it acts as a constant.
        signed_eps = eps * (2 * (g >= 0).float() - 1).detach()
        total = total + ((g - h) / (g + signed_eps) - 1).abs().sum()

    num_elements = sum(p.data.nelement() for p in params)
    return total / num_elements

              
def metainit(model, criterion, x_size, y_size, lr=0.1, momentum=0.9, steps=150, eps=1e-5):
    """ Rescales the norms of the model's weight tensors to reduce the gradient quotient.

    Each step draws random token sequences, computes `gradient_quotient` of the
    resulting loss and moves the norm of every trainable tensor with two or more
    dimensions by a signed, momentum-smoothed step of size `lr`. Only the scale
    of a tensor changes, never its direction. The model is modified in place and
    left in eval mode.

    :param model: model to re-initialize; must have at least one parameter
    :param criterion: callable (model output, targets) -> scalar loss
    :param x_size: (batch size, sequence length) of the inputs fed to the model
    :param y_size: exclusive upper bound of the sampled token ids
    :param lr: step size for the norm updates
    :param momentum: momentum of the norm updates
    :param steps: number of update steps
    :param eps: forwarded to `gradient_quotient`
    """
    model.eval()
    trainable = [p for p in model.parameters() if p.requires_grad]
    params = [p for p in trainable if len(p.size()) >= 2]
    # Run on whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    memory = [0] * len(params)
    for i in range(steps):
        # One extra position so that inputs and next-token targets can both be sliced out.
        sequences = torch.randint(0, y_size, torch.Size([x_size[0], x_size[1] + 1])).to(device)
        inputs, target = sequences[:, :-1], sequences[:, 1:]
        loss = criterion(model(inputs), target)
        gq = gradient_quotient(loss, trainable, eps)

        grad = torch.autograd.grad(gq, params)
        for j, (p, g_all) in enumerate(zip(params, grad)):
            norm = p.data.norm().item()
            if norm == 0:
                continue  # an all-zero tensor has no scale to adjust
            # Sign of the derivative of gq with respect to the norm of p.
            g = torch.sign((p.data * g_all).sum() / norm)
            memory[j] = momentum * memory[j] - lr * g.item()
            new_norm = norm + memory[j]
            p.data.mul_(new_norm / norm)
        print("%d/GQ = %.2f" % (i, gq.item()))
              
              
def lstm_orthogonal(cell, gain=1):
    """ Resets `cell` and gives its recurrent weights a per-gate orthogonal initialization.

    Every parameter whose name contains 'weight_hh' is processed, so multi-layer,
    bidirectional and bias-free LSTMs are handled too. If no parameter has such
    a name, the second of exactly four parameters is used instead.

    :param cell: module with `reset_parameters()` and a `hidden_size` attribute
    :param gain: scaling factor passed to `torch.nn.init.orthogonal_`
    """
    cell.reset_parameters()

    recurrent_weights = [weight for name, weight in cell.named_parameters()
                         if 'weight_hh' in name]
    if not recurrent_weights:
        # Positional fallback: (input weights, recurrent weights, two biases).
        _, hh, _, _ = list(cell.parameters())
        recurrent_weights = [hh]

    # orthogonal initialization of recurrent weights, one hidden_size block of rows at a time
    for hh in recurrent_weights:
        for start in range(0, hh.size(0), cell.hidden_size):
            torch.nn.init.orthogonal_(hh[start:start + cell.hidden_size], gain=gain)

## Evaluation on WikiText2

In [None]:
# `tokenize=list` splits the raw text into single characters.
WikiTEXT = torchtext.data.Field(lower=False, tokenize=list)

# load corpora, each dataset only contains one long "example" with all text in that example
wikitext2_train, wikitext2_valid, wikitext2_test = torchtext.datasets.WikiText2.splits(
    WikiTEXT, root=data_dir,
    train='wiki.train.raw',
    validation='wiki.valid.raw',
    test='wiki.test.raw',
)
WikiTEXT.build_vocab(wikitext2_train)
wikitext2_voc_size = len(WikiTEXT.vocab)

wikitext2_train_loader = torchtext.data.BPTTIterator(
    wikitext2_train, train_batch_size, sequence_length,
    train=True, device=device, repeat=True, shuffle=True,
)
wikitext2_valid_loader = torchtext.data.BPTTIterator(
    wikitext2_valid, valid_batch_size, sequence_length,
    train=False, device=device, repeat=False, shuffle=True,
)

# Encode the test text with the train vocabulary; symbols missing from it map to id 0.
wikitext2_test_ids = torch.tensor(
    [WikiTEXT.vocab.stoi.get(symbol, 0) for symbol in wikitext2_test.examples[0].text]
)

# For MetaInit: one training batch, whose shape is later passed to `metainit`
batch = next(iter(wikitext2_train_loader))
text, target = batch.text.t(), batch.target.t()

# Tune voc_size: replace the vocabulary-sized layers to match WikiText2,
# then re-run the model's own weight initialization.
maml.model.emb_vectors = nn.Embedding(wikitext2_voc_size, emb_size).to(device)
maml.model.logits = nn.Linear(hidden_size, wikitext2_voc_size).to(device)
maml.model.init_weights()

In [None]:
num_reruns = 10
wikitext2_batches_in_epoch = len(wikitext2_train_loader)

reruns_base, reruns_orthogonal = [], []
reruns_metainit, reruns_dimaml = [], []


def _train_on_wikitext2(model):
    """ Trains `model` for 100 WikiText2 epochs, measuring the test loss every 10 epochs.

    :returns: (train_loss_history, test_loss_history) as produced by `eval_model`
    """
    return eval_model(model, wikitext2_train_loader, wikitext2_test_ids, epochs=100,
                      device=device, mode='eval',
                      test_loss_interval=10*wikitext2_batches_in_epoch)


def _to_bpc(reruns):
    """ Converts per-run test-loss histories from nats to a float array in bits.

    Every entry goes through float(), so histories holding 0-d tensors
    (including CUDA tensors) convert as well as histories of plain floats.
    """
    return np.array([[float(loss) for loss in run] for run in reruns]) * math.log2(math.e)


# Every rerun trains four differently initialized copies of the model, always in
# the same order: DIMAML, baseline, orthogonal, MetaInit.
for rerun_id in range(num_reruns):
    print(f"Rerun #{rerun_id}")
    print('DIMAML')
    maml.resample_parameters(is_final=True)
    maml_model = deepcopy(maml.model)
    maml_train_loss_history, maml_test_loss_history = _train_on_wikitext2(maml_model)
    reruns_dimaml.append(maml_test_loss_history)

    print('Baseline')
    maml.resample_parameters(initializers=maml.untrained_initializers, is_final=True)
    base_model = deepcopy(maml.model)
    base_train_loss_history, base_test_loss_history = _train_on_wikitext2(base_model)
    reruns_base.append(base_test_loss_history)

    print("Orthogonal")
    maml.resample_parameters(initializers=maml.untrained_initializers, is_final=True)
    orthogonal_model = deepcopy(maml.model)
    lstm_orthogonal(orthogonal_model.lstm1)
    lstm_orthogonal(orthogonal_model.lstm2)
    orthogonal_train_loss_history, orthogonal_test_loss_history = _train_on_wikitext2(orthogonal_model)
    reruns_orthogonal.append(orthogonal_test_loss_history)

    print("Metainit")
    maml.resample_parameters(initializers=maml.untrained_initializers, is_final=True)
    metainit_model = deepcopy(maml.model)
    metainit(metainit_model, compute_loss, text.shape, wikitext2_voc_size)
    metainit_train_loss_history, metainit_test_loss_history = _train_on_wikitext2(metainit_model)
    reruns_metainit.append(metainit_test_loss_history)

# Rows are reruns, columns are the recorded test evaluations, values are bits per character.
reruns_base = _to_bpc(reruns_base)
reruns_dimaml = _to_bpc(reruns_dimaml)
reruns_metainit = _to_bpc(reruns_metainit)
reruns_orthogonal = _to_bpc(reruns_orthogonal)

In [None]:
# Mean and sample standard deviation (ddof=1) over reruns of the test bpc.
# Column 0 of each array is the evaluation after the first iteration; with one
# evaluation every 10 epochs, columns 1, 5 and 10 are epochs 10, 50 and 100.
# NOTE(review): the MetaInit rows used to read `fixed_reruns_metainit`, which is
# not defined anywhere in this notebook (NameError on a fresh kernel). It was
# presumably a cleaned-up copy of `reruns_metainit` from a deleted cell; if runs
# were filtered there, that filtering has to be restored explicitly.
summary_columns = ((10, 1), (50, 5), (100, 10))
summary_methods = (
    ("Baseline", reruns_base),
    ("DIMAML", reruns_dimaml),
    ("MetaInit", reruns_metainit),
    ("Orthogonal", reruns_orthogonal),
)

for method_id, (method_name, method_reruns) in enumerate(summary_methods):
    if method_id:
        print()
    method_mean = method_reruns.mean(0)
    method_std = method_reruns.std(0, ddof=1)
    for epoch_label, column in summary_columns:
        print(f"{method_name} {epoch_label} epoch: ", method_mean[column], method_std[column])