In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

In [None]:
# Build NVIDIA apex with its C++/CUDA extensions from source.
# NOTE(review): the training code below uses PyTorch's native mixed precision
# (torch.cuda.amp.autocast / GradScaler), which does not need apex; this step may be
# removable unless performer-pytorch or something else depends on apex. Confirm.
!git clone https://github.com/NVIDIA/apex
!cd apex; pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

In [None]:
!pip install performer-pytorch

## Performer Model

In [None]:
# Try enwik dataset
from performer_pytorch import PerformerLM
# Calculates loss
from performer_pytorch.autoregressive_wrapper import AutoregressiveWrapper

import random
#import tqdm
from tqdm.notebook import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# Mixed Precision in PyTorch: - Throws Apex Error (or not?)
from torch.cuda.amp import autocast, GradScaler

# constants
SEED = 42                      # seeds the python / numpy / torch RNGs below for reproducible runs
NUM_BATCHES = 10               # optimizer steps; use int(1e5) for a full training run
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4  # micro-batches whose gradients are accumulated per optimizer step
LEARNING_RATE = 1e-4
VALIDATE_EVERY  = 100          # steps between validation-loss checks (step 0 included)
GENERATE_EVERY  = 500          # steps between sample generations; never reached if NUM_BATCHES <= 500
GENERATE_LENGTH = 2048         # number of tokens to generate per sample
SEQ_LEN = 4096                 # model context length and training window size

# Seed every RNG the notebook uses (random.choice, np, torch.randint / model init)
# before the model is instantiated so weights and sampled windows are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# helpers
def cycle(loader):
    """Yield items from ``loader`` endlessly, re-iterating it after each full pass.

    Unlike ``itertools.cycle`` this does not cache items: the loader is iterated
    afresh on every pass, so a DataLoader re-runs its sampling each time.

    Raises:
        ValueError: if a complete pass over ``loader`` yields nothing (an empty
            loader, or a one-shot iterator that is already exhausted). Without this
            check the ``while True`` loop would spin forever without yielding.
    """
    while True:
        produced = False
        for data in loader:
            produced = True
            yield data
        if not produced:
            raise ValueError('cycle(): loader yielded no items; cannot cycle an empty iterable')

def decode_token(token):
    """
    Turn one integer token into a printable character via ``chr``.

    Codes below 32 (ASCII control characters) are rendered as a space, chr(32);
    e.g. decode_token(97) -> 'a', decode_token(10) -> ' '.
    """
    code = token if token > 32 else 32
    return chr(code)

def decode_tokens(tokens):
    """Decode an iterable of integer tokens into a string, one character per token."""
    return ''.join(decode_token(tok) for tok in tokens)


# instantiate model
# Causal Performer over a 256-symbol vocabulary (the training data is uint8 bytes),
# wrapped in AutoregressiveWrapper, which the training loop calls with
# return_loss=True and uses for .generate().
model = AutoregressiveWrapper(
    PerformerLM(
        num_tokens=256,
        dim=512,
        depth=6,
        max_seq_len=SEQ_LEN,
        heads=8,
        causal=True,
        reversible=True,
        nb_features=256,
        use_scalenorm=True,
        # NOTE(review): one entry per layer (6 == depth); presumably the number of
        # heads in that layer using *local* attention out of heads=8. Confirm
        # against the performer-pytorch README.
        local_attn_heads=(8, 8, 8, 6, 4, 2),
    )
)
model.cuda()


## Data Preparation

In [None]:
# prepare enwik8 data
# Read raw bytes rather than text. Every element of X is then one byte in 0..255
# (matching num_tokens=256), with no locale decoding or newline translation. The
# 90/10 split point is computed on the byte count; the previous version used the
# *character* count of the decoded text, which is wrong for non-ASCII files.
with open("../input/enwikidataset/enwikinews-dataset.txt", "rb") as fh:
    raw_bytes = fh.read()

# np.frombuffer replaces the deprecated binary mode of np.fromstring. The .copy()
# makes the array writable so torch.from_numpy does not warn about a read-only buffer.
X = np.frombuffer(raw_bytes, dtype=np.uint8).copy()
trX, vaX = np.split(X, [int(len(X) * 0.9)])
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
    
class TextSamplerDataset(Dataset):
    """Dataset of randomly positioned fixed-length windows over a 1-D token tensor.

    ``__getitem__`` ignores ``index`` and returns a freshly sampled window of
    ``seq_len + 1`` consecutive tokens, cast to int64. The extra token presumably
    lets the wrapper form shifted next-token targets.

    Args:
        data: 1-D tensor of token ids.
        seq_len: context length; each sample holds ``seq_len + 1`` tokens.
        device: device samples are moved to. Defaults to ``'cuda'`` (the original
            behaviour); pass ``'cpu'`` without a GPU, or ``None`` to keep samples
            on ``data``'s device.

    Raises:
        ValueError: if ``data`` has fewer than ``seq_len + 1`` tokens.
    """

    def __init__(self, data, seq_len, device='cuda'):
        super().__init__()
        if data.size(0) < seq_len + 1:
            raise ValueError(
                f'need at least seq_len + 1 = {seq_len + 1} tokens, got {data.size(0)}'
            )
        self.data = data
        self.seq_len = seq_len
        self.device = device

    def __getitem__(self, index):
        # randint's upper bound is exclusive, so starts cover [0, N - seq_len - 1].
        # The largest start gives a window ending exactly on the final token, so
        # every token can be sampled.
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)).item()
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq if self.device is None else full_seq.to(self.device)

    def __len__(self):
        # Nominal epoch length: the number of non-overlapping seq_len windows in data.
        return self.data.size(0) // self.seq_len
    
# Wrap each split in a random-window sampler, then turn its finite DataLoader into
# an endless iterator that the training loop consumes with next().
train_dataset, val_dataset = (
    TextSamplerDataset(split, SEQ_LEN) for split in (data_train, data_val)
)
train_loader, val_loader = (
    cycle(DataLoader(ds, batch_size=BATCH_SIZE)) for ds in (train_dataset, val_dataset)
)

# optimizer
# NOTE: this rebinds `optim`, shadowing the `torch.optim` module alias imported
# earlier; the training loop refers to this Adam instance through that name.
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Loss scaler used together with autocast() for mixed-precision training.
scaler = GradScaler()

## Training

In [None]:

for i in tqdm(range(NUM_BATCHES), desc='training'):
    model.train()

    # Gradient accumulation: gradients from GRADIENT_ACCUMULATE_EVERY micro-batches
    # are summed (the loss is not divided by the count) before one optimizer step.
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        with autocast():
            loss = model(next(train_loader), return_loss = True)
        scaler.scale(loss).backward()

    # Reports only the last micro-batch's loss.
    print(f'training loss: {loss.item()}')

    # Unscale before clipping so the 0.5 max-norm applies to the true gradients,
    # then let the scaler step (it skips the step if inf/NaN grads were found).
    scaler.unscale_(optim)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    scaler.step(optim)
    scaler.update()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0 and i != 0:
        model.eval()
        # Prime with a random validation window minus its last token.
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        # Show the prime followed by a separator line of 100 asterisks.
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)

In [None]:
# Recorded output of the training cell above (10 training steps per run), kept as
# comments so this cell is valid Python and survives Restart & Run All.
# Summary: with AMP ~60 s, without AMP ~57 s, so there was no speedup from AMP in this run.
#
# W/ AMP, 60secs
#
# training loss: 5.719467639923096
# validation loss: 5.679171562194824
# training loss: 5.688153266906738
# training loss: 5.64174747467041
# training loss: 5.6018266677856445
# training loss: 5.5692853927612305
# training loss: 5.528021335601807
# training loss: 5.4963812828063965
# training loss: 5.4679107666015625
# training loss: 5.419412136077881
# training loss: 5.399522304534912
#
#
# W/o AMP, 57 secs
#
# training loss: 5.67160701751709
# validation loss: 5.638372421264648
# training loss: 5.621777057647705
# training loss: 5.606544494628906
# training loss: 5.550500869750977
# training loss: 5.5118408203125
# training loss: 5.481593132019043
# training loss: 5.444958686828613
# training loss: 5.410432815551758
# training loss: 5.354680061340332
# training loss: 5.319333553314209