In [1]:
import numpy as np
import os
import sys
import torch
import datasets
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from itertools import chain
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

sys.path.append('..')

### Load model

In [2]:
# Checkpoint identifier; the same name is used for both the tokenizer and the model
# so that the vocabulary matches the weights.
model_name = 'gpt2'
# Load the tokenizer and the causal-LM model through the transformers Auto* factories.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)



### Prepare dataset

In [3]:
class Holder:
    """Bare attribute container.

    Defines no fields or behaviour of its own; it only carries whatever
    attributes are assigned to it after creation (used below to mimic a
    parsed-arguments object).
    """

In [4]:
# Total number of tokens per training sample (current block plus its history);
# history_size is later derived as input_seq_len - block_size.
input_seq_len = 4096
target_seq_len = 4096

# block_size is later derived from these two: input_size - 2 * num_mem_tokens.
num_mem_tokens = 10
input_size = 1024

batch_size = 2

# Create an *instance* to hold the settings. Binding the class itself
# (`args = Holder`) would store every setting as a class attribute shared by
# any Holder created afterwards.
args = Holder()
args.target_seq_len = target_seq_len
args.input_seq_len = input_seq_len
args.num_mem_tokens = num_mem_tokens
args.input_size = input_size
args.input_prefix = ''
args.block_size = None
# Passed as the configuration name to datasets.load_dataset('wikitext', ...).
args.task_name = 'wikitext-2-v1'

device = 'cpu'

In [5]:

# Load the WikiText configuration selected in the config cell.
raw_datasets = datasets.load_dataset('wikitext', args.task_name)
column_names = raw_datasets["train"].column_names

# Prefer the "text" column; otherwise fall back to the first available column.
if "text" in column_names:
    text_column_name = "text"
else:
    text_column_name = column_names[0]

def tokenize_function(examples):
    """Tokenize the text column of a batch of dataset rows.

    Args:
        examples: batch mapping as passed by ``Dataset.map(batched=True)``;
            only the ``text_column_name`` entry is read.

    Returns:
        Whatever the module-level ``tokenizer`` returns for those texts.
    """
    texts = examples[text_column_name]
    return tokenizer(texts)

# Tokenize every split; the original text columns are dropped so only the
# tokenizer outputs remain.
tokenized_datasets = raw_datasets.map(
    tokenize_function,
    remove_columns=column_names,
    batched=True,
    desc="Running tokenizer on dataset",
)

# Tokens of real text per segment: the model window minus twice the number of
# memory tokens (nothing is subtracted when num_mem_tokens is None).
# NOTE(review): the factor 2 presumably reserves room for both read and write
# memory slots -- confirm against the RMT wrapper.
block_size = args.input_size - (0 if args.num_mem_tokens is None else 2 * args.num_mem_tokens)
# Whatever remains of the full input length is filled with preceding context.
history_size = args.input_seq_len - block_size

def group_texts(examples, block_size, history_size=None):
    """Concatenate every column of a tokenized batch and re-cut it into chunks.

    Args:
        examples: mapping from column name to a list of token lists. All
            columns are assumed to have the same total length; an
            ``"input_ids"`` column is required.
        block_size: number of new tokens contributed by each chunk; also the
            stride between consecutive chunk starts.
        history_size: when given, each chunk is additionally prefixed with up
            to this many tokens that precede it in the concatenated stream
            (fewer near the start). ``None`` yields plain non-overlapping
            chunks.

    Returns:
        dict with the same columns re-chunked, plus ``"labels"``, a shallow
        copy of the chunked ``"input_ids"`` list. The final chunk is kept even
        when it holds fewer than ``block_size`` new tokens.
    """
    concatenated = {name: list(chain(*rows)) for name, rows in examples.items()}
    total_length = len(concatenated[list(examples.keys())[0]])

    def chunk_start(i):
        # Without history a chunk starts at its own offset; with history it
        # reaches back history_size tokens, clamped at the stream start.
        # Plain two-argument max() -- no throwaway set per chunk.
        return i if history_size is None else max(0, i - history_size)

    result = {
        name: [tokens[chunk_start(i): i + block_size] for i in range(0, total_length, block_size)]
        for name, tokens in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

from torch.nn.utils.rnn import pad_sequence
# Pad with the tokenizer's pad id when it has one, otherwise reuse the EOS id.
id_pad_value = tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id
def collate_fn(batch):
    """Left-pad a list of samples into equally sized 2-D tensors.

    Args:
        batch: list of dicts, each with 'input_ids', 'labels' and
            'attention_mask' token lists.

    Returns:
        dict with 'input_ids', 'labels' and 'attention_mask' tensors of shape
        (len(batch), longest sample). Padding goes on the left and uses
        ``id_pad_value``, -100 and 0 respectively. When the padded width
        differs from ``block_size`` a boolean 'labels_mask' is added that is
        True only for the last ``block_size`` positions of every row.
    """
    def left_pad(field, fill):
        # pad_sequence pads at the end, so each sequence is reversed first and
        # the padded result is flipped back to move the padding to the front.
        reversed_rows = [torch.tensor(sample[field][::-1]) for sample in batch]
        return pad_sequence(reversed_rows, padding_value=fill).T.flip(1)

    collated = {
        'input_ids': left_pad('input_ids', id_pad_value),
        'labels': left_pad('labels', -100),
        'attention_mask': left_pad('attention_mask', 0),
    }

    padded_ids = collated['input_ids']
    if padded_ids.shape[1] != block_size:
        keep = torch.zeros_like(padded_ids, dtype=bool)
        keep[:, -block_size:] = True
        collated['labels_mask'] = keep

    return collated


# Training chunks carry up to history_size preceding tokens; validation chunks
# are plain non-overlapping blocks.
train_dataset = tokenized_datasets["train"].map(lambda x: group_texts(x, block_size, history_size), 
                                        batched=True, desc=f"Grouping train in chunks of {block_size} and history {history_size}")
valid_dataset = tokenized_datasets["validation"].map(lambda x: group_texts(x, block_size), 
                                        batched=True, desc=f"Grouping valid in chunks of {block_size}")


# No sampler and no shuffle flag are passed, so batches come out in dataset
# order. The following cells rely on that order when they pick specific batches.
# Page-locked host memory only helps host-to-GPU copies: enable it only when
# the run targets a non-CPU device and CUDA is actually present, instead of
# unconditionally (which is useless on CPU and warns on hosts without CUDA).
kwargs = {'pin_memory': str(device) != 'cpu' and torch.cuda.is_available()}
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, **kwargs)

Found cached dataset wikitext (/home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-389b922bfc5fe729.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-6067a66e735cfbb1.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-941845a5470f2db7.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-c59c474634171e07.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-414595b9a286c906.a

In [6]:
gen = iter(train_dataloader)
# The first batch is drawn and immediately overwritten; the second one is kept.
# NOTE(review): presumably because the earliest chunks have little history and
# are therefore shorter -- confirm.
batch = next(gen)
batch = next(gen)
# The batch is passed to the plain model as keyword arguments in the next cell,
# so the extra 'labels_mask' entry is removed. collate_fn only adds that key
# when the padded width differs from block_size, hence the default: a batch
# without it must not raise KeyError.
batch.pop('labels_mask', None)
batch['input_ids'].shape

torch.Size([2, 4016])

In [7]:
# Feed the long batch straight to the unwrapped model. The IndexError is the
# point of this demonstration, so it is caught and reported instead of
# aborting the notebook; any other exception still propagates.
# NOTE(review): presumably raised because the sequence is longer than the
# model's position-embedding table -- confirm.
try: 
    out = model(**batch)
except IndexError:
    print('Error: Input size too large!')

Error: Input size too large!


### Add RMT!

In [11]:
from modeling_rmt.language_modeling import RMTDecoderLMHeadMultiSeg as RMTWrapper

In [12]:
# Memory-token count and window size are read from the shared config instead of
# being repeated as literals: block_size for the data pipeline is derived from
# the very same two values, so they cannot drift apart.
# NOTE(review): max_n_segments=8 is kept as in the original; its exact meaning
# is defined by the RMT wrapper -- confirm before tuning.
rmt_config = {'num_mem_tokens': args.num_mem_tokens, 
              'max_n_segments': 8,
              'tokenizer': tokenizer,
              'input_size': args.input_size, 
            }

# Wrap the already loaded base model.
rmt = RMTWrapper(model, **rmt_config)

In [13]:
# Take the next batch from the same iterator; unlike the earlier cell,
# 'labels_mask' is left in place here.
batch = next(gen)

In [14]:
# Run the RMT-wrapped model on the whole batch (every collated field is passed
# as a keyword argument); the returned mapping is inspected in the next cell.
out = rmt(**batch)

In [17]:
# Print every entry of the output whose key mentions 'loss', as a Python float.
loss_keys = [key for key in out if 'loss' in key]
for key in loss_keys:
    print(key, out[key].item())

loss 10.841404914855957
loss_0 4.384130001068115
loss_1 10.640803337097168
loss_2 11.2599458694458
loss_3 11.477508544921875
loss_4 10.841404914855957


### Success! 
Let's teach the model to use memory.