In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler
#from torch.utils.tensorboard import SummaryWriter

import numpy as np
from reformer_pytorch import Reformer, ReformerLM
from transformers import BertTokenizer, AdamW, PreTrainedTokenizer

import re
import os
from tqdm import tqdm, tqdm_notebook
from glob import glob

import json
import pickle
import shutil

## Shorter max length to test faster on current hardware

In [2]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
tokenizer.max_len = 128

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…




In [3]:
model = ReformerLM(
    num_tokens = tokenizer.vocab_size,
    dim = 512,
    depth = 6,
    heads = 8,
    max_seq_len = tokenizer.max_len,
    causal = True
)

In [4]:
test = 'Hello, my dog is cute'

In [5]:
tok = tokenizer.encode(test, max_length=tokenizer.max_len, add_special_tokens=True)
tok = torch.tensor(tok, dtype=torch.long)
tok.shape

torch.Size([8])

In [6]:
tokenizer.decode(tok)

'[CLS] Hello, my dog is cute [SEP]'

In [7]:
def mask_tokens(inputs: torch.Tensor, tokenizer, mlm_probability=0.15, pad=True):
    """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
    labels = inputs.clone()
    # mlm_probability defaults to 0.15 in Bert
    probability_matrix = torch.full(labels.shape, mlm_probability)
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    if tokenizer._pad_token is not None:
        padding_mask = labels.eq(tokenizer.pad_token_id)
        probability_matrix.masked_fill_(padding_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # We only compute loss on masked tokens

    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    if pad:
        input_pads = tokenizer.max_len - inputs.shape[-1]
        label_pads = tokenizer.max_len - labels.shape[-1]
        
        inputs = F.pad(inputs, pad=(0,input_pads), value=tokenizer.pad_token_id)
        labels = F.pad(labels, pad=(0,label_pads), value=tokenizer.pad_token_id)
    
    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels


In [8]:
inputs, labels = mask_tokens(tok.unsqueeze(0), tokenizer, pad=True)

In [9]:
tokenizer.decode(inputs.squeeze(0))

'[CLS] Hello, my dog is cute [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [10]:
tokenizer.decode(labels.squeeze(0))

'[UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

# Predictions in shape [batch_size, max_seq_len, vocab_size]

In [11]:
pred = model(inputs)
pred.shape

torch.Size([1, 128, 28996])

In [12]:
tokenizer.decode(torch.argmax(pred, dim=-1).squeeze(0))

'Shuả touching Tonga Streetsj Parachute Sinatra dove Tim Hollandesthisky Mina varied crude trousers climbsleadingrta Betsy dove ≠ Christiminhed surreal McGill posters ／hs Kazan doveов hostess favoritesyland Slater ：rta rootedporated surreal trousers residencesmada Angelina HC Yu possessing Providence entire Freeway arrives grayÈ nm surreal Yesest Guangdongu fabricated ： ballotshipsitureread residences dryly Holland Unity sought trouserspathianchemistry penetrate pursuit mountainous sailsreadnamshū obsession 2001 ecosystem certain山ibrating obligations top surrealhalesrgoamo Ā HC HC joins Cope Dolores fabricated groups joins Ā surreal Depression chick Entry surrealiturerta surreal Garcia blogے lenses surreal Upper badge Layton Ally șest HC addressrta muster'

In [13]:
loss_fn = nn.CrossEntropyLoss()  # -100 index = padding token

In [14]:
masked_lm_loss = loss_fn(pred.view(-1, tokenizer.vocab_size), labels.view(-1))
masked_lm_loss

tensor(12.1266, grad_fn=<NllLossBackward>)

# Number of Wiki Files

In [15]:
wiki_filepath = './data/enwiki'
wikifiles = []

for root, dirs, files in os.walk(wiki_filepath):
    for file in files:
        wikifiles.append(os.path.join(root, file))
print(f'Total Files: {len(wikifiles)}')

Total Files: 13638


## Moving all of the files to a common directory

It looks like all of the files are named similarly, so running this now overwrites files and only gives you
100 total files. This is good for testing, but we need a better solution long-term

for wk in tqdm(wikifiles):
    filename = wk.split('\\')[-1]
    dirname = wk.split('\\')[-2]
    filename = f'{dirname}_{filename}'
    
    shutil.move(wk, f'D:/data/enwiki/{filename}')

In [273]:
tmp_work_dir = 'D:/data/enwiki/tmp_work_dir'

In [274]:
tmp_files = [file for file in glob(f'{tmp_work_dir}/*')]

In [16]:
class WikiDataset(Dataset):

    def __init__(self, path="", prefix="train"):
        
        assert os.path.isdir(path)

        self.documents = []
        filename_list = os.listdir(path)
        for file in filename_list:
            path_to_file = os.path.join(path, file)
            if not os.path.isfile(path_to_file):
                continue
            self.documents.append(path_to_file)

    def __len__(self):
        """ Returns the number of documents. """
        return len(self.documents)

    def __getitem__(self, idx):
        document_path = self.documents[idx]
        document_name = document_path.split("/")[-1]
        
        items = []
        
        with open(document_path, encoding="utf-8") as source:
            raw_text = source.readlines()
            for obj in raw_text:
                text = json.loads(obj)['text']
                text = re.sub('\\n', ' ', text)
                text = re.sub('\s+', ' ', text)
                items.append(text)
        
        return items

In [17]:
dataset = WikiDataset(path=wiki_filepath)
dataloader = DataLoader(dataset, shuffle=True, batch_size=4)

In [18]:
dataset_len = len(dataset)
train_len = int(dataset_len * 0.1)
eval_len = dataset_len - train_len

train_dataset, eval_dataset = torch.utils.data.random_split(dataset, (train_len, eval_len))

In [19]:
from math import sqrt
lr = lambda t: min(1e-2, 1 / sqrt(t))

In [285]:
train = train.dataset
test = test.dataset

In [253]:
# dataloader loads all wiki files, we them need to iterate over them
# by a batch size, so we need to create a new dataloader

# in order to calculate the masked_lm_loss, we need to identify where
# the masked tokens are in the inputs, and where the non-masked tokens
# are in the labels. These indices should match, so we'll assert that
# to double check that our mask_tokens function works.

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
epochs = 1
total_steps = 0
total_loss = 0.0
step_loss = 0.0
losses = []

model.train()
model.to(device)
optimizer = AdamW(params=model.parameters())

writer = SummaryWriter(log_dir='./tb_logs')

for epoch in range(epochs):
    for step, batch in tqdm(enumerate(dataloader), desc='DataLoader'):
        for data in batch:

            # tokenizing
            inputs = torch.cat(
                    [
                        tokenizer.encode(
                            data[i], 
                            add_special_tokens=True, 
                            max_length=tokenizer.max_len, 
                            pad_to_max_length=True, 
                            return_tensors='pt') \
                        for i in range(len(data))
                    ]
                )

            inputs, labels = mask_tokens(inputs, tokenizer)
            inputs, labels = inputs.to(device), labels.to(device)
            output = model(inputs)

            # calculating loss
            loss_mx = labels != -100

            output = output[loss_mx].view(-1, tokenizer.vocab_size)
            labels = labels[loss_mx].view(-1)

            loss = loss_fn(output, labels)
            losses.append(loss.item())
            step_loss += loss.item()
            total_steps += 1
        
        writer.add_scalar(tag='Train/Loss', scalar_value=loss.item(), global_step=total_steps)
        
        total_loss += step_loss
        writer.add_scalar(tag='[Avg] Train/Loss', scalar_value=total_loss / total_steps, global_step=total_steps)
    writer.close()

















DataLoader: 0it [00:00, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A[A















DataLoader: 1it [01:17, 77.89s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A[A

KeyboardInterrupt: 

# DeepSpeed experimentation

In [28]:
import deepspeed

In [39]:
deepspeed_config = {
    "train_batch_size": 8,
    "gradient_accumulation_steps": 1,
    "steps_per_print": 1,
    "zero_optimization": True,
    "disable_allgather": True,
    "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.00015,
      "max_grad_norm": 1.0
        }
    },

    "fp16": {
        "enabled": True,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1
    }
}

In [40]:
model_engine, optimizer, _, _ = deepspeed.initialize(args=deepspeed_config,
                                                     model=model,
                                                     model_parameters=params)

DeepSpeed info: version=0.1.0, git-hash=bf2689a, git-branch=master


AssertionError: DeepSpeed requires integer command line parameter --local_rank

In [35]:
torch.distributed.get_rank()

AssertionError: Default process group is not initialized