In [None]:
import os
import torch
import torchvision
import numpy as np
import data
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import random
from IRE import AdmWIRE
from model import Transformer
from copy import deepcopy

%load_ext autoreload
%autoreload 2

def setup_seed(seed):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), then configures cuDNN for repeatable runs.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Restrict cuDNN to deterministic kernels ...
    torch.backends.cudnn.deterministic = True
    # ... and switch off its autotuner, which could otherwise select a
    # different convolution algorithm from one run to the next.
    torch.backends.cudnn.benchmark = False


setup_seed(41)

In [None]:
# Restrict this process to GPUs 0 and 1. The variable is read when CUDA is
# initialised, so this has to run before the first CUDA call.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' # 1 or 2 gpus
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the dataset through the project-local `data` module. Other cells read
# `corpus.train` (token tensor) and `corpus.dictionary` (vocabulary).
corpus = data.Corpus('wikitext-2')
# NOTE(review): `scaler` is not referenced by any other visible cell (the
# training loop calls `loss.backward()` directly, without autocast or
# `scaler.scale`) -- confirm whether mixed precision is still intended.
scaler = torch.cuda.amp.GradScaler()

In [None]:
def batchify(data, bsz, device):
    """Arrange a 1-D token stream into ``bsz`` parallel columns.

    The stream is cut into ``bsz`` equally long contiguous pieces; trailing
    tokens that do not fill a complete row are dropped. Each piece becomes one
    column of the result, so consecutive rows hold consecutive tokens of every
    piece.

    Args:
        data: 1-D tensor of token ids.
        bsz: Number of columns (batch size).
        device: Device the result is moved to.

    Returns:
        Contiguous tensor of shape ``(data.size(0) // bsz, bsz)`` on ``device``.
    """
    # Size everything from the tensor that was passed in. The previous version
    # read the length of the global ``corpus.train`` here, which silently
    # truncated any other split to the training-set length.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    trimmed = data.narrow(0, 0, nbatch * bsz)
    # Explicit (bsz, nbatch) instead of (bsz, -1) so that an input shorter
    # than ``bsz`` yields an empty (0, bsz) result instead of a reshape error.
    columns = trimmed.view(bsz, nbatch).t().contiguous()
    return columns.to(device)
        
def get_batch(source, i, bptt):
    """Cut one input/target pair out of a batchified tensor.

    Args:
        source: Tensor whose first dimension is the position in the sequence.
        i: Row at which the input window starts.
        bptt: Maximum number of rows in the window.

    Returns:
        ``(data, target)`` where ``data`` is ``source[i : i + n]`` and
        ``target`` is the same window shifted one row ahead and flattened to
        1-D; ``n`` is ``bptt`` clipped so the shifted window stays in range.
    """
    rows_left = len(source) - 1 - i
    window = min(bptt, rows_left)
    stop = i + window
    inputs = source[i:stop]
    # Targets are the next-row values for every input position.
    shifted = source[i + 1:stop + 1]
    return inputs, shifted.view(-1)

def get_lr(it, min_lr, learning_rate, warmup_iters, lr_decay_iters):
    """Learning rate at iteration ``it``: linear warmup, then cosine decay.

    Args:
        it: Current iteration number.
        min_lr: Floor the rate decays to.
        learning_rate: Peak rate, reached at the end of warmup.
        warmup_iters: Number of linear warmup iterations.
        lr_decay_iters: Iteration at which the rate reaches ``min_lr``.

    Returns:
        The learning rate to use at iteration ``it``.
    """
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) at or beyond lr_decay_iters, return the min learning rate. `>=`
    #    (rather than `>`) gives the same value at it == lr_decay_iters, where
    #    the cosine coefficient is exactly 0, and avoids a division by zero
    #    when warmup_iters == lr_decay_iters.
    if it >= lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)

In [None]:
# ---- Model / data dimensions ----
# Vocabulary size: number of entries in the corpus dictionary. The model's
# output size is set equal to it below.
d_input = len(corpus.dictionary)
# Passed to the Transformer constructor and also used by the training loop as
# the number of rows per training chunk.
max_seq_length = 1000
# NOTE(review): `epochs` is not read by any visible cell; the training loop is
# bounded by `max_iters` -- confirm whether it can be removed.
epochs = 20000
# Upper bound on training iterations; also passed to get_lr as lr_decay_iters.
max_iters = 100000
d_output = d_input
# Remaining Transformer constructor arguments (project-local `model` module).
d_model = 128
num_heads = 8
num_layers = 2
d_ff = 512
dropout = 0.0
# Number of columns produced by batchify.
batch_size = 32
# NOTE(review): `decay_lr` is not read by any visible cell; the training loop
# always applies the get_lr schedule -- confirm.
decay_lr = True
# ---- Optimizer / schedule ----
max_lr = 6e-4
min_lr = max_lr / 20
# Length of the linear LR warmup; the training loop also waits this many
# iterations before it starts calling optimizer.update_mask().
warmup_iters = 3000
weight_decay = 5e-4
beta1 = 0.9
beta2 = 0.95
###### hyperparameters for IRE
rank = 0.01  # candidate values: 0.1, 0.01, 0.001
prog = 5.0  # candidate values: 1, 2, 5, ...
beta = 0.0  # momentum of the Fisher estimate; 0.0 here
prog_decay = True  # True or False: cosine decay of prog
######

# A log line is printed every `log_interval` epochs.
log_interval = 10

model = Transformer(d_input, d_output, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)
# Wrap the model in DataParallel when more than one GPU is visible.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)

# AdmWIRE is the project-local optimizer from the IRE module; it is given the
# AdamW class together with the IRE and AdamW hyperparameters.
# NOTE(review): presumably it instantiates `base_optimizer` internally --
# confirm against IRE.py.
base_optimizer = torch.optim.AdamW
optimizer = AdmWIRE(model, base_optimizer, rank=rank, prog=prog, beta=beta, prog_decay=prog_decay, 
                     lr=max_lr, betas=(beta1,beta2), weight_decay=weight_decay, eps=1e-16)

# Training tokens arranged as (rows, batch_size) on the target device.
train_data = batchify(corpus.train, batch_size, device)

In [None]:
loss_train = []  # loss of the last batch of every epoch
lrs = []  # learning rate applied at every iteration

# NOTE: `iter` shadows the builtin of the same name; the name is kept because
# it is notebook-global state.
iter = 1
epoch = 0

while iter < max_iters:

    model.train()
    t0 = time.time()

    for i in range(0, train_data.size(0) - 1, max_seq_length):

        # Bound to `inputs` rather than `data`: the latter would overwrite the
        # imported `data` module and break a re-run of the corpus-loading cell.
        inputs, targets = get_batch(train_data, i, max_seq_length)

        if iter >= warmup_iters and iter % 10 == 1:
            # after warm-up phase, estimate the projection each 10 iters
            optimizer.zero_grad(set_to_none=True)
            with torch.enable_grad():
                logits = model(inputs.t()).view(-1, d_input)
                # Labels are sampled from the model's own predictive
                # distribution instead of being taken from `targets`.
                samp_dist = torch.distributions.Categorical(logits=logits)
                y_sample = samp_dist.sample()
                loss = F.cross_entropy(logits, y_sample, ignore_index=-1)
                loss.backward()
                # Project-local optimizer hook; runs while the gradients of
                # the sampled-label loss are still populated.
                optimizer.update_mask()
                optimizer.zero_grad(set_to_none=True)

        # Apply the scheduled learning rate to every parameter group.
        lr = get_lr(iter, min_lr, max_lr, warmup_iters, max_iters)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        lrs.append(lr)

        # Regular training step on the true next-token targets.
        output = model(inputs.t()).view(-1, d_input)
        loss = F.cross_entropy(output, targets, ignore_index=-1)
        loss.backward()
        optimizer.descent_step(lr, max_lr)
        optimizer.zero_grad(set_to_none=True)

        iter += 1
        # Honour `max_iters` inside the epoch as well; the `while` condition
        # alone is only checked once per full pass over the data.
        if iter >= max_iters:
            break

    t1 = time.time()
    dt = t1 - t0

    # Only the final batch of the epoch is recorded.
    lossf = loss.item()
    loss_train.append(lossf)

    if epoch % log_interval == 0:
        print(f"epoch {epoch}: loss {lossf:.4f}, lr {lr:.2e}, time {dt*1000:.2f}ms")

    epoch += 1

In [None]:
import matplotlib.pyplot as plt

# `loss_train` receives one value per epoch (the loss of that epoch's last
# batch), so the x-axis counts epochs even though the variable is named `iters`.
iters = range(1, len(loss_train)+1)
fig, ax = plt.subplots()
ax.plot(iters, loss_train, color='C0', label='train loss')
ax.set_xlabel('epoch')
ax.set_ylabel('cross-entropy loss')
# Without a legend the `label` passed to plot() is never shown.
ax.legend();

In [None]:
# `lrs` receives one value per training iteration.
iters = range(1, len(lrs)+1)
fig, ax = plt.subplots()
# This curve is the learning-rate schedule; it was previously labelled
# 'train loss'.
ax.plot(iters, lrs, color='C0', label='learning rate')
ax.set_xlabel('iteration')
ax.set_ylabel('learning rate')
ax.legend();