<a href="https://colab.research.google.com/github/xSakix/AI_colab_notebooks/blob/master/reformer_pytorch_cuda.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip install torch
!pip install reformer_pytorch
!pip install transformers



In [0]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime → "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Fri Feb  7 17:49:49 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.48.02    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   63C    P0    35W / 250W |      0MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

In [0]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [7]:
# load model file and epoch
import os
import re
files = [f for f in os.listdir('/content/drive/My Drive/model_saves') if f.startswith('epoch')]
last_model_file = None
epochs_run = 0
if len(files) > 0:
  files.sort(reverse=True)
  last_model_file = os.path.join('/content/drive/My Drive/model_saves',files[0])
  print(last_model_file)
  epochs = re.findall(r'\d+',files[0])
  epochs_run = 0
  if len(epochs) == 1:
    epochs_run = int(epochs[0])
  print('number of epochs run:',epochs_run)


/content/drive/My Drive/model_saves/epoch-2000.pt
number of epochs run: 2000


In [0]:
from reformer_pytorch import ReformerLM

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import os
from transformers import BertTokenizer, AdamW, get_linear_schedule_with_warmup

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 8
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 3e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 512
SEQ_LEN = 4096

# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data

def get_top_p(logits, top_p=0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    indices_to_remove = sorted_indices[sorted_indices_to_remove]
    logits[indices_to_remove] = float('-inf')
    return logits

def sample_next_token(logits, top_p=0.9, temperature = 1.0):
    logits = logits[0, -1, :] / temperature
    filtered_logits = get_top_p(logits, top_p=top_p)

    probs = F.softmax(filtered_logits, dim=-1)
    return torch.multinomial(probs, 1)

def decode_token(token):
    return str(chr(token))

def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# instantiate model

model = ReformerLM(
    dim = 512,
    depth = 6,
    max_seq_len = SEQ_LEN,
    num_tokens = 256,
    heads = 8,
    bucket_size = 64,
    n_hashes = 8,
    ff_chunks = 10,
    lsh_dropout = 0.1,
    weight_tie = True,
    causal = True,
    use_full_attn = False # set this to true for comparison with full attention
)

# model = ReformerLM(
#     dim = 512,
#     depth = 6,
#     max_seq_len = SEQ_LEN,
#     num_tokens = 256,
#     heads = 8,
#     bucket_size = 64,
#     n_hashes = 4,
#     ff_chunks = 10,
#     lsh_dropout = 0.1,
#     weight_tie = True,
#     causal = True,
#     use_full_attn = False # set this to true for comparison with full attention
# )

if last_model_file is not None:
  model.load_state_dict(torch.load(last_model_file ))

model.cuda()


# prepare enwik8 data

with gzip.open('/content/drive/My Drive/model_data/merged.gz') as file:
    X = np.array([int(c) for c in file.read()])
    si = int(len(X)-len(X)*0.2)
    trX, vaX = np.split(X, [si])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq[0:-1].cuda(), full_seq[1:].cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

print(len(train_dataset))
print(len(val_dataset))

# optimizer
# optimizer.load_state_dict(torch.load('optimizer.pt'))
# scheduler.load_state_dict(torch.load('scheduler.pt'))

optim = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE,amsgrad=True)

if os.path.exists('/content/drive/My Drive/model_saves/optim.pt'):
  optim.load_state_dict(torch.load('/content/drive/My Drive/model_saves/optim.pt'))

#scheduler

# scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=VALIDATE_EVERY, gamma=0.1)

scheduler = get_linear_schedule_with_warmup(
            optim,
            num_warmup_steps=0,
            num_training_steps=len(train_dataset) // GRADIENT_ACCUMULATE_EVERY * NUM_BATCHES
        )

if os.path.exists('/content/drive/My Drive/model_saves/scheduler.pt'):
  scheduler.load_state_dict(torch.load('/content/drive/My Drive/model_saves/scheduler.pt'))

# training

def get_batch_loss(model, data):
    x, y = data
    pred = model(x)
    return F.cross_entropy(pred.transpose(1, 2), y, reduction='mean')

for i in tqdm.tqdm(range(epochs_run, NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = get_batch_loss(model, next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()
    scheduler.step()

    if i % VALIDATE_EVERY == 0:
        torch.save(model.state_dict(), os.path.join('/content/drive/My Drive/model_saves', 'epoch-{}.pt'.format(i)))
        torch.save(optim.state_dict(),'/content/drive/My Drive/model_saves/optim.pt')
        torch.save(scheduler.state_dict(),'/content/drive/My Drive/model_saves/scheduler.pt')
        model.eval()
        with torch.no_grad():
            loss = get_batch_loss(model, next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            inp, _ = random.choice(val_dataset)
            output_str = ''
            prime = decode_tokens(inp)

            # print(f'%s \n\n %s', (prime, '*' * 100))
            print(prime)
            print('*'*100)

            for _ in tqdm.tqdm(range(GENERATE_LENGTH), desc='generating'):
                logits = model(inp[None, :])
                next_token = sample_next_token(logits)
                output_str += decode_token(next_token)
                inp = torch.cat((inp[1:], next_token), dim=0)

            print(output_str)


training:   0%|          | 0/98000 [00:00<?, ?it/s][A

18140
4535
training loss: 2.463907241821289




generating:   0%|          | 0/512 [00:00<?, ?it/s][A[A

validation loss: 2.53680682182312
romazdenie ako natlakovu akciu namierenu proti vlastnikovi
spolocnosti. Sme presvedceni o tom, ze organizatori podujatia
zneuzivaju dovercivost ludi, siria nepravdive informacie a ich
motivaciou nie je pomoc, ale osobny financny prospech, co napokon priznali
aj v mediach, uviedla v stanovisku.
Spolocnost ECO-INVEST ako aj majitel opakuju, ze sa tymto a ani
ziadnym inym sposobom nenechaju tlacit k rokovaniam o mimosudnych
dohodach a vyplateni roznych sum. Greksova upozornuje, ze privatizacnu
povinnost poskytnut podiel v papiernach zamestnancom, ktoru ziadaju
aktivisti, si akcionar splnil uz pred viac ako patnastimi rokmi. "Preto
nevidime dovod vobec s panom Stefanom Gavornikom (predseda obcianskeho
zdruzenia  pozn. redakcie) a Pavlom Korytarom (pravny zastupca
obcianskeho zdruzenia  pozn. redakcie) o tejto vykonstruovanej
poziadavke diskutovat a ani ju politizovat, " dodala zastupkyna
spolocnosti.
Greksova taktiez doplnila, ze aktivisti, 



generating:   0%|          | 1/512 [00:00<01:46,  4.80it/s][A[A

generating:   0%|          | 2/512 [00:00<01:47,  4.74it/s][A[A

generating:   1%|          | 3/512 [00:00<01:47,  4.75it/s][A[A

generating:   1%|          | 4/512 [00:00<01:47,  4.74it/s][A[A

generating:   1%|          | 5/512 [00:01<01:46,  4.76it/s][A[A

generating:   1%|          | 6/512 [00:01<01:46,  4.74it/s][A[A

generating:   1%|▏         | 7/512 [00:01<01:46,  4.73it/s][A[A

generating:   2%|▏         | 8/512 [00:01<01:47,  4.70it/s][A[A

generating:   2%|▏         | 9/512 [00:01<01:46,  4.72it/s][A[A

generating:   2%|▏         | 10/512 [00:02<01:46,  4.73it/s][A[A

generating:   2%|▏         | 11/512 [00:02<01:45,  4.73it/s][A[A

generating:   2%|▏         | 12/512 [00:02<01:45,  4.75it/s][A[A

generating:   3%|▎         | 13/512 [00:02<01:44,  4.76it/s][A[A

generating:   3%|▎         | 14/512 [00:02<01:44,  4.75it/s][A[A

generating:   3%|▎         | 15/512 [00:03<01:44,  4.75