<a href="https://colab.research.google.com/github/xSakix/AI_colab_notebooks/blob/master/reformer_pytorch_cuda.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install the third-party packages imported further down (torch, reformer_pytorch, transformers).
# NOTE(review): none of the versions are pinned; the captured output of this cell shows
# reformer_pytorch 0.12.7 being installed - consider pinning to keep the run reproducible.
!pip install torch
!pip install reformer_pytorch
!pip install transformers

Collecting reformer_pytorch
  Downloading https://files.pythonhosted.org/packages/c7/76/e16c3f0904011223e8c4a853d3b08a300db74c4a90a4a983f1a7d934fd63/reformer_pytorch-0.12.7.tar.gz
Collecting revtorch>=0.2.4
  Downloading https://files.pythonhosted.org/packages/7b/7f/6b2247e5ce4b8969dedfcaec064c59ce0417cddbe638bfa6169ff586eaea/revtorch-0.2.4.tar.gz
Building wheels for collected packages: reformer-pytorch, revtorch
  Building wheel for reformer-pytorch (setup.py) ... [?25l[?25hdone
  Created wheel for reformer-pytorch: filename=reformer_pytorch-0.12.7-cp36-none-any.whl size=8720 sha256=52665f37f9e968fc527845fabe1a43680e12aa72078735ba10c4a3c146628af5
  Stored in directory: /root/.cache/pip/wheels/61/b8/d4/a72dab74c922c6cb6544a50f5853b548071e1cb33eb76fda13
  Building wheel for revtorch (setup.py) ... [?25l[?25hdone
  Created wheel for revtorch: filename=revtorch-0.2.4-cp36-none-any.whl size=5750 sha256=ace4657a8c417356555591e23f3073000098a35d4026e1dd880f119724b167cf
  Stored in directo

In [7]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime → "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Wed Feb 12 08:08:33 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.48.02    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   41C    P0    32W / 250W |  13261MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
+-------

In [6]:
# Mount Google Drive at /content/drive; the checkpoint directory and the training data
# read by later cells both live under '/content/drive/My Drive'.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [8]:
# load model file and epoch
import os
import re

# Directory on the mounted Drive where the training loop writes its checkpoints.
MODEL_SAVE_DIR = '/content/drive/My Drive/model_saves'


def checkpoint_epoch(filename):
  """Return the step number embedded in a checkpoint file name.

  A name is only trusted when it contains exactly one run of digits
  (e.g. 'epoch-17700.pt' -> 17700); any other name yields -1 so that it
  sorts behind every properly named checkpoint.
  """
  numbers = re.findall(r'\d+', filename)
  return int(numbers[0]) if len(numbers) == 1 else -1


# A missing directory simply means there is nothing to resume from.
files = []
if os.path.isdir(MODEL_SAVE_DIR):
  files = [f for f in os.listdir(MODEL_SAVE_DIR) if f.startswith('epoch')]
else:
  print('checkpoint directory not found, starting from scratch:', MODEL_SAVE_DIR)

last_model_file = None
epochs_run = 0
if files:
  # Pick the newest checkpoint by its numeric step. A plain string sort is wrong
  # here: lexicographically 'epoch-9900.pt' > 'epoch-17700.pt'.
  newest = max(files, key=checkpoint_epoch)
  last_model_file = os.path.join(MODEL_SAVE_DIR, newest)
  print(last_model_file)
  # Names without a single unambiguous number keep the counter at 0.
  epochs_run = max(checkpoint_epoch(newest), 0)
  print('number of epochs run:',epochs_run)


/content/drive/My Drive/model_saves/epoch-17700.pt
number of epochs run: 17700


In [0]:
from reformer_pytorch import ReformerLM

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import os
from transformers import BertTokenizer, AdamW, get_linear_schedule_with_warmup

# constants

NUM_BATCHES               = 100_000  # upper bound of the training loop's step counter
BATCH_SIZE                = 8        # sequences per DataLoader batch
GRADIENT_ACCUMULATE_EVERY = 4        # backward passes per optimizer step
LEARNING_RATE             = 3e-4     # initial learning rate handed to the optimizer
VALIDATE_EVERY            = 100      # checkpoint + validate every N steps
GENERATE_EVERY            = 500      # sample text every N steps
GENERATE_LENGTH           = 512      # number of tokens sampled per generation
SEQ_LEN                   = 4096     # model context length and training window size

# helpers

def cycle(loader):
    """Yield the items of `loader` forever, restarting it each time it is exhausted."""
    while True:
        yield from loader

def get_top_p(logits, top_p=0.9):
    """Apply nucleus (top-p) filtering to `logits` in place and return it.

    Tokens are ranked by logit; walking down the ranking, tokens are kept until
    the cumulative softmax probability first exceeds `top_p` (the token that
    crosses the threshold is kept, and the top-ranked token is always kept).
    Every other entry of `logits` is overwritten with -inf.

    The only caller in this file passes a 1-D vector of logits.
    """
    ranked_logits, ranked_indices = torch.sort(logits, descending=True)
    cumulative = F.softmax(ranked_logits, dim=-1).cumsum(dim=-1)

    over_budget = cumulative > top_p
    # Shift the mask one slot to the right: position i is dropped only when the
    # mass *before* it already exceeded top_p, so slot 0 is never dropped.
    drop_mask = torch.zeros_like(over_budget)
    drop_mask[..., 1:] = over_budget[..., :-1]

    logits[ranked_indices[drop_mask]] = float('-inf')
    return logits

def sample_next_token(logits, top_p=0.9, temperature = 1.0):
    """Draw one token id from the model output at the last sequence position.

    Takes `logits[0, -1, :]`, divides it by `temperature`, restricts it with
    `get_top_p`, and samples a single index from the resulting softmax
    distribution. Returns the tensor produced by `torch.multinomial(..., 1)`.
    """
    last_step = logits[0, -1, :] / temperature
    nucleus = get_top_p(last_step, top_p=top_p)
    return torch.multinomial(F.softmax(nucleus, dim=-1), 1)

def decode_token(token):
    """Map a single token id to the character with that code point."""
    # chr() already returns a str, so no further conversion is needed.
    return chr(token)

def decode_tokens(tokens):
    """Decode an iterable of token ids into one string, one character per id."""
    return ''.join(decode_token(tok) for tok in tokens)

# instantiate model

# Character-level Reformer language model (256-symbol vocabulary, context of SEQ_LEN).
# An earlier, since-disabled configuration differed only in using n_hashes = 4.
model = ReformerLM(
    dim = 512,
    depth = 6,
    max_seq_len = SEQ_LEN,
    num_tokens = 256,
    heads = 8,
    bucket_size = 64,
    n_hashes = 8,
    ff_chunks = 10,
    lsh_dropout = 0.1,
    weight_tie = True,
    causal = True,
    use_full_attn = False # set this to true for comparison with full attention
)

# Resume from the newest checkpoint found by the checkpoint-discovery cell, if any.
if last_model_file is not None:
  # Deserialize onto the CPU first: without map_location a checkpoint saved from a
  # CUDA model is restored straight onto the GPU, so the weights would briefly
  # exist there twice once the model itself is moved below.
  model.load_state_dict(torch.load(last_model_file, map_location='cpu'))

model.cuda()


# prepare enwik8 data

# Read the whole gzip-compressed corpus as raw bytes; the file is closed as soon as
# the read finishes.
with gzip.open('/content/drive/My Drive/model_data/merged.gz') as file:
    raw_bytes = file.read()

# One integer per byte. np.frombuffer decodes the buffer in a single vectorized
# pass instead of building an intermediate Python list element by element; the
# int64 cast keeps the dtype NumPy gives a list of Python ints on a 64-bit Linux
# runtime and yields a writable array for torch.from_numpy.
X = np.frombuffer(raw_bytes, dtype=np.uint8).astype(np.int64)

# First 80% of the stream is the training split, the remaining 20% is validation.
si = int(len(X)-len(X)*0.2)
trX, vaX = np.split(X, [si])
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

class TextSamplerDataset(Dataset):
    """Dataset that serves random fixed-length windows of a 1-D token tensor.

    Each item is an (input, target) pair of length `seq_len`, where the target
    is the input shifted one position to the right. The requested index is
    ignored: every access draws a fresh random start offset.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.seq_len = seq_len
        self.data = data

    def __len__(self):
        # Number of non-overlapping windows that fit into the data.
        return self.data.shape[0] // self.seq_len

    def __getitem__(self, index):
        # Highest (exclusive) start that still leaves room for seq_len + 1 tokens.
        upper = self.data.size(0) - self.seq_len - 1
        start = torch.randint(0, upper, (1,))
        window = self.data[start: start + self.seq_len + 1].long()
        inputs, targets = window[:-1], window[1:]
        # Both halves are moved to the GPU before being returned.
        return inputs.cuda(), targets.cuda()

# Wrap each split in a random-window sampler.
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)

# Endless batch streams: cycle() restarts a DataLoader whenever it runs dry.
train_batches = DataLoader(train_dataset, batch_size = BATCH_SIZE)
val_batches   = DataLoader(val_dataset, batch_size = BATCH_SIZE)
train_loader  = cycle(train_batches)
val_loader    = cycle(val_batches)

# Report the dataset sizes, one per line.
print(len(train_dataset), len(val_dataset), sep='\n')

# optimizer
# NOTE: this assignment rebinds the name `optim`, which the import section bound to
# the torch.optim module; from here on `optim` is the AdamW instance.
optim = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE,amsgrad=True)

# Restore the optimizer state written by the training loop, if present.
if os.path.exists('/content/drive/My Drive/model_saves/optim.pt'):
  optim.load_state_dict(torch.load('/content/drive/My Drive/model_saves/optim.pt'))

# scheduler
# Linear decay from LEARNING_RATE to 0 with no warm-up. The training loop calls
# scheduler.step() once per optimizer step and its step counter is bounded by
# NUM_BATCHES, so that is the schedule's horizon. (The previous value,
# len(train_dataset) // GRADIENT_ACCUMULATE_EVERY * NUM_BATCHES, was orders of
# magnitude larger than the number of steps ever taken, so the learning rate
# stayed essentially constant.)
scheduler = get_linear_schedule_with_warmup(
            optim,
            num_warmup_steps=0,
            num_training_steps=NUM_BATCHES
        )

# Restore the scheduler position written by the training loop, if present.
if os.path.exists('/content/drive/My Drive/model_saves/scheduler.pt'):
  scheduler.load_state_dict(torch.load('/content/drive/My Drive/model_saves/scheduler.pt'))

# training

def get_batch_loss(model, data):
    """Return the mean cross-entropy of `model` on one (inputs, targets) batch.

    The model output has its dimensions 1 and 2 swapped before the loss call so
    that the class dimension sits where F.cross_entropy expects it.
    """
    inputs, targets = data
    logits = model(inputs)
    class_first = torch.transpose(logits, 1, 2)
    return F.cross_entropy(class_first, targets, reduction='mean')

# Main training loop. The step counter starts at `epochs_run` (the number parsed from
# the resumed checkpoint's file name, or 0) and runs up to NUM_BATCHES.
for i in tqdm.tqdm(range(epochs_run, NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # Gradient accumulation: backward() is called on GRADIENT_ACCUMULATE_EVERY batches
    # before a single optimizer step. The per-batch loss is not rescaled, so the
    # accumulated gradient is the sum over the batches.
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = get_batch_loss(model, next(train_loader))
        loss.backward()

    # Only the loss of the last accumulated batch is reported.
    print(f'training loss: {loss.item()}')
    # Clip the total gradient norm to 0.5, then update weights and advance the LR schedule.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()
    scheduler.step()

    if i % VALIDATE_EVERY == 0:
        # Checkpoint model, optimizer and scheduler to Drive. The model file name carries
        # the step index `i`; optimizer and scheduler files are overwritten each time.
        # NOTE(review): the checkpoint is written after step `i` has been applied, while
        # a resumed run starts its counter at `i` again, so that step is repeated.
        torch.save(model.state_dict(), os.path.join('/content/drive/My Drive/model_saves', 'epoch-{}.pt'.format(i)))
        torch.save(optim.state_dict(),'/content/drive/My Drive/model_saves/optim.pt')
        torch.save(scheduler.state_dict(),'/content/drive/My Drive/model_saves/scheduler.pt')
        # Validation loss is estimated from a single validation batch.
        model.eval()
        with torch.no_grad():
            loss = get_batch_loss(model, next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        # Sample GENERATE_LENGTH tokens as a qualitative check of the model.
        model.eval()
        with torch.no_grad():
            # Take the input half of one validation item as the prompt.
            inp, _ = random.choice(val_dataset)
            output_str = ''
            prime = decode_tokens(inp)

            # print(f'%s \n\n %s', (prime, '*' * 100))
            print(prime)
            print('*'*100)

            for _ in tqdm.tqdm(range(GENERATE_LENGTH), desc='generating'):
                # Run the model on the current window (with a batch dimension added)
                # and sample the next token from the last position.
                logits = model(inp[None, :])
                next_token = sample_next_token(logits)
                output_str += decode_token(next_token)
                # Slide the window: drop the oldest token, append the sampled one.
                inp = torch.cat((inp[1:], next_token), dim=0)

            print(output_str)


training:   0%|          | 0/82300 [00:00<?, ?it/s][A

39035
9758
training loss: 1.1895207166671753



training:   0%|          | 1/82300 [00:16<387:44:07, 16.96s/it][A

validation loss: 0.9503108263015747



training:   0%|          | 2/82300 [00:32<378:18:03, 16.55s/it][A

training loss: 1.060807228088379



training:   0%|          | 3/82300 [00:48<371:45:24, 16.26s/it][A

training loss: 1.1756713390350342



training:   0%|          | 4/82300 [01:03<367:05:06, 16.06s/it][A

training loss: 1.0568010807037354



training:   0%|          | 5/82300 [01:19<363:50:36, 15.92s/it][A

training loss: 1.0997706651687622



training:   0%|          | 6/82300 [01:34<361:30:13, 15.81s/it][A

training loss: 1.2273523807525635



training:   0%|          | 7/82300 [01:50<359:55:57, 15.75s/it][A

training loss: 1.4169782400131226



training:   0%|          | 8/82300 [02:06<358:46:14, 15.70s/it][A

training loss: 0.7619534730911255



training:   0%|          | 9/82300 [02:21<358:02:10, 15.66s/it][A

training loss: 1.2530237436294556



training:   0%|          | 10/82300 [02:37<357:28:16, 15.64s/it][A

training loss: 0.981792688369751



training:   0%|          | 11/82300 [02:52<357:11:26, 15.63s/it][A

training loss: 1.2237261533737183



training:   0%|          | 12/82300 [03:08<356:50:50, 15.61s/it][A

training loss: 1.235406756401062



training:   0%|          | 13/82300 [03:23<356:38:08, 15.60s/it][A

training loss: 0.9258752465248108



training:   0%|          | 14/82300 [03:39<356:29:27, 15.60s/it][A

training loss: 1.2219524383544922



training:   0%|          | 15/82300 [03:55<356:26:09, 15.59s/it][A

training loss: 1.0885881185531616



training:   0%|          | 16/82300 [04:10<356:24:25, 15.59s/it][A

training loss: 1.0124324560165405



training:   0%|          | 17/82300 [04:26<356:24:03, 15.59s/it][A

training loss: 1.3582189083099365



training:   0%|          | 18/82300 [04:41<356:22:01, 15.59s/it][A

training loss: 1.1168711185455322



training:   0%|          | 19/82300 [04:57<356:22:03, 15.59s/it][A

training loss: 1.303199291229248



training:   0%|          | 20/82300 [05:13<356:14:07, 15.59s/it][A

training loss: 0.8306032419204712



training:   0%|          | 21/82300 [05:28<356:14:02, 15.59s/it][A

training loss: 1.0929125547409058



training:   0%|          | 22/82300 [05:44<356:11:27, 15.58s/it][A

training loss: 1.0613465309143066



training:   0%|          | 23/82300 [05:59<356:18:08, 15.59s/it][A

training loss: 1.2249786853790283



training:   0%|          | 24/82300 [06:15<356:15:14, 15.59s/it][A

training loss: 1.2791032791137695



training:   0%|          | 25/82300 [06:31<356:58:04, 15.62s/it][A

training loss: 1.027601957321167



training:   0%|          | 26/82300 [06:46<356:49:30, 15.61s/it][A

training loss: 1.3565394878387451



training:   0%|          | 27/82300 [07:02<356:44:05, 15.61s/it][A

training loss: 1.2776069641113281



training:   0%|          | 28/82300 [07:17<356:32:01, 15.60s/it][A

training loss: 0.9583160877227783



training:   0%|          | 29/82300 [07:33<356:38:51, 15.61s/it][A

training loss: 1.2230029106140137



training:   0%|          | 30/82300 [07:49<356:41:22, 15.61s/it][A

training loss: 0.9350292682647705



training:   0%|          | 31/82300 [08:04<356:29:49, 15.60s/it][A

training loss: 1.1295626163482666



training:   0%|          | 32/82300 [08:20<356:24:58, 15.60s/it][A

training loss: 1.1711156368255615



training:   0%|          | 33/82300 [08:35<356:18:04, 15.59s/it][A

training loss: 1.4645940065383911



training:   0%|          | 34/82300 [08:51<356:21:32, 15.59s/it][A

training loss: 1.174204707145691



training:   0%|          | 35/82300 [09:07<356:13:14, 15.59s/it][A

training loss: 1.0782009363174438



training:   0%|          | 36/82300 [09:22<356:13:21, 15.59s/it][A

training loss: 0.9799803495407104



training:   0%|          | 37/82300 [09:38<356:10:22, 15.59s/it][A

training loss: 1.180436611175537



training:   0%|          | 38/82300 [09:53<356:11:10, 15.59s/it][A

training loss: 0.8830055594444275



training:   0%|          | 39/82300 [10:09<356:06:48, 15.58s/it][A

training loss: 1.0835990905761719



training:   0%|          | 40/82300 [10:24<356:02:50, 15.58s/it][A

training loss: 1.0766915082931519



training:   0%|          | 41/82300 [10:40<356:01:57, 15.58s/it][A

training loss: 1.2683428525924683



training:   0%|          | 42/82300 [10:56<356:06:24, 15.58s/it][A

training loss: 1.3210864067077637



training:   0%|          | 43/82300 [11:11<355:59:25, 15.58s/it][A

training loss: 1.126894474029541



training:   0%|          | 44/82300 [11:27<355:59:21, 15.58s/it][A

training loss: 1.3516286611557007



training:   0%|          | 45/82300 [11:42<355:57:31, 15.58s/it][A

training loss: 1.2718085050582886



training:   0%|          | 46/82300 [11:58<356:06:08, 15.59s/it][A

training loss: 1.5456156730651855



training:   0%|          | 47/82300 [12:14<356:01:52, 15.58s/it][A

training loss: 1.1696194410324097



training:   0%|          | 48/82300 [12:29<356:03:27, 15.58s/it][A

training loss: 1.1387717723846436



training:   0%|          | 49/82300 [12:45<356:01:29, 15.58s/it][A

training loss: 1.1308152675628662
