In [1]:
import numpy as np
import os
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from transformers import PreTrainedModel, AutoModelForSequenceClassification, T5ForConditionalGeneration
from transformers import AutoTokenizer
import datasets

import math
from matplotlib import pyplot as plt


from typing import List, Optional, Tuple, Union
from transformers import BertForSequenceClassification
import transformers
from transformers.modeling_outputs import SequenceClassifierOutput

### Finetune

In [2]:
import sys
sys.path.append('..')
from transformers import AutoModelForSequenceClassification, AutoModelForCausalLM
# from modeling_rmt
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from modeling_rmt.language_modeling import *
# from modeling_rmt_enc_dec import RMTEncoderDecoderForConditionalGeneration

In [142]:
num_segments = 16
num_mem_tokens = 2
# device = torch.device(3)
device = 'cpu'

In [143]:
model_name = 'gpt2'

tokenizer = AutoTokenizer.from_pretrained(model_name)

rmt_config = {'num_mem_tokens': num_mem_tokens, 
                'max_n_segments': num_segments,
                'tokenizer': tokenizer,
                'segment_ordering': 'regular',
                'input_size': 128, 
                'bptt_depth': -1, 
                'sum_loss': False,
             }

base_model = AutoModelForCausalLM.from_pretrained(model_name)
rmt = RMTDecoderLMHeadMultiSeg(base_model, **rmt_config)

rmt.to(device)
pass



In [144]:
cpt = torch.load('../../runs/arxiv/gpt2/linear_adamw_wd1e-03_620-128-5x128_mem2_bs32_regular_bptt-5_from_cpt_4-5/run_4/model_best.pth', map_location='cpu')

In [145]:
rmt.load_state_dict(cpt['model_state_dict'])

<All keys matched successfully>

### load dataset 

In [146]:
class Holder:
    def __init__(self):
        pass

In [147]:
input_seq_len = 128
target_seq_len = 128
batch_size = 1

args = Holder
args.batch_size = batch_size
args.num_mem_tokens = 2
args.max_n_segments = 16
args.input_size = 128
args.input_seq_len = (args.input_size - 2*args.num_mem_tokens) * args.max_n_segments
args.target_seq_len = args.input_seq_len
args.input_prefix = ''
device = 'cpu'

In [148]:
import random
from datasets import Dataset, load_dataset, load_from_disk

In [149]:
block_size = args.input_size 
if args.num_mem_tokens is not None:
    block_size -= 2 * args.num_mem_tokens
# if args.xl_cache_size is not None:
#     block_size -= args.xl_cache_size
history_size = args.input_seq_len - block_size

class segmentDataLoaderOTF(DataLoader):
    def __init__(self, dataset, block_size, history_size, max_samples=None, shuffle=False, *args, **kwargs):
        super().__init__(dataset, *args, **kwargs)
        self.block_size = block_size
        self.history_size = history_size
        self.max_samples = max_samples
        self.shuffle = shuffle
            
    def get_samples(self, document):
        input_ids, attention_mask = document['input_ids'], document['attention_mask']
        samples = [input_ids[max({0, start - self.history_size}): start + self.block_size] for start in range(0, len(input_ids), self.block_size)]
        return samples
    
    def __iter__(self):
        inds = list(range(len(self.dataset)))

        if self.max_samples is not None:
            inds = inds[:self.max_samples]

        if self.shuffle: 
            random.shuffle(inds)

        doc_ind = 0
        samples = []

        while True:
            if doc_ind >= len(inds):
                break
            try: 
                while len(samples) < self.batch_size:
                    document = self.dataset[inds[doc_ind]]
                    doc_ind += 1
                    samples += self.get_samples(document)
                    if doc_ind >= len(inds):
                        raise(StopIteration)
            except(StopIteration):
                pass

            batch, samples = samples[:self.batch_size], samples[self.batch_size:]
            yield self.collate_fn(batch)
        

from torch.nn.utils.rnn import pad_sequence
id_pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
def collate_fn(batch):
    input_ids = labels = [torch.tensor(b[::-1]) for b in batch]
    attention_mask = [torch.ones_like(b, dtype=int) for b in input_ids]
    input_ids = pad_sequence(input_ids, padding_value=id_pad_value).T.flip(1)
    labels = pad_sequence(labels, padding_value=-100).T.flip(1)
    attention_mask = pad_sequence(attention_mask, padding_value=0).T.flip(1)

    collated = {'input_ids': input_ids,
                'labels': labels, 
                'attention_mask': attention_mask}
    
    if input_ids.shape[1] != block_size:
        labels_mask = torch.ones_like(input_ids, dtype=bool)
        labels_mask[:, :-block_size] = False
        collated['labels_mask'] = labels_mask
    
    return collated

# valid_dataset = load_from_disk('/home/bulatov/bulatov/datasets/arxiv_pile/tokenized/valid')
train_dataset = load_from_disk('/home/bulatov/bulatov/datasets/arxiv_pile/tokenized/train')
kwargs = {'pin_memory': True}#, 'num_workers': args.data_n_workers}
dataloader = segmentDataLoaderOTF(valid_dataset, batch_size=3, 
                                #   sampler=train_sampler,
                                block_size=block_size, 
                                history_size=history_size, 
                                shuffle=True,
                                collate_fn=collate_fn, **kwargs)

In [150]:
gen = iter(dataloader)
for _ in range(16):
    batch = next(gen)

In [151]:
batch['input_ids'].shape

torch.Size([3, 1984])

In [155]:
out = rmt(**batch, output_hidden_states=True)

In [156]:
len(out.hidden_states)

13

In [168]:
for _ in range(16):
    batch = next(gen)

out = rmt(**batch, output_hidden_states=True)

In [169]:
data = {}

for i in range(16):
    data[f'read_memory_{i}'] = out[f'hidden_states_{i}'][-1][:, :2].detach().numpy().tolist()
    data[f'write_memory_{i}'] = out[f'hidden_states_{i}'][-1][:, -2:].detach().numpy().tolist()

In [170]:
for k in out:
    if 'loss' in k:
        data[k] = out[k].item()

In [171]:
data['input_text'] = tokenizer.batch_decode(batch['input_ids'])
data['input_ids'] = batch['input_ids'].numpy().tolist()

In [172]:
data.keys()

dict_keys(['read_memory_0', 'write_memory_0', 'read_memory_1', 'write_memory_1', 'read_memory_2', 'write_memory_2', 'read_memory_3', 'write_memory_3', 'read_memory_4', 'write_memory_4', 'read_memory_5', 'write_memory_5', 'read_memory_6', 'write_memory_6', 'read_memory_7', 'write_memory_7', 'read_memory_8', 'write_memory_8', 'read_memory_9', 'write_memory_9', 'read_memory_10', 'write_memory_10', 'read_memory_11', 'write_memory_11', 'read_memory_12', 'write_memory_12', 'read_memory_13', 'write_memory_13', 'read_memory_14', 'write_memory_14', 'read_memory_15', 'write_memory_15', 'loss', 'loss_0', 'loss_1', 'loss_2', 'loss_3', 'loss_4', 'loss_5', 'loss_6', 'loss_7', 'loss_8', 'loss_9', 'loss_10', 'loss_11', 'loss_12', 'loss_13', 'loss_14', 'loss_15', 'input_text', 'input_ids'])

In [165]:
batch['input_ids']

tensor([[39280,  4663,    61,  ...,    13,  2102,    11],
        [17641,   326,   612,  ..., 21054,    92, 38016],
        [21017,   317,   309,  ...,    92, 38016,  1797]])

In [173]:
batch['input_ids']

tensor([[ 262, 4006,  287,  ...,   13, 1849,   56],
        [  12, 1671, 7305,  ...,   11,  564,  246],
        [ 360,   12, 1671,  ..., 1849,   50, 2788]])

In [174]:
import json
with open('data/memories_16seg_s2.json', 'w') as f:
    json.dump(data, f)

In [178]:
ls data/memories_16seg_s2.json

data/memories_16seg_s2.json


In [177]:
pwd

'/cephfs/home/bulatov/bulatov/RMT_light/framework/notebooks'

In [180]:
ls /cephfs/home/bulatov/bulatov/RMT_light/framework/notebooks/data/memories_16seg_s2.json

/cephfs/home/bulatov/bulatov/RMT_light/framework/notebooks/data/memories_16seg_s2.json


In [57]:
for k in out:
    if 'loss' in k:
        print(k, out[k].item())

loss 2.6600282192230225
loss_0 2.9729745388031006
loss_1 2.677720069885254
loss_2 2.739013910293579
loss_3 2.523183822631836
loss_4 2.5851669311523438
loss_5 2.6139843463897705
loss_6 3.1057839393615723
loss_7 2.6600282192230225


In [118]:
b1 = dict(**batch)
b2 = dict(**batch)

In [119]:
for k in b1:
    b1[k] = b1[k][:1]

for k in b2:
    b2[k] = b2[k][:1]

In [128]:
# def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
#                 inputs_embeds=None, labels=None, labels_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None):
#         kwargs = {'attention_mask': attention_mask, 'token_type_ids': token_type_ids,
#                   'position_ids': position_ids, 'inputs_embeds': inputs_embeds,
#                   'labels_mask': labels_mask, #'pos_weight': pos_weight,
#                   'labels': labels, 'output_attentions': output_attentions,
#                   'output_hidden_states': output_hidden_states, 'return_dict': return_dict,
#                   }


self = rmt
kwargs = dict(**b1)
input_ids = kwargs.pop('input_ids')
labels = kwargs['labels']

memory = self.set_memory(input_ids.shape)
segmented = self.pad_and_segment(input_ids, labels)

inp_memories = []
base_model_outputs = []
for seg_num, segment in enumerate(zip(*segmented)):
    

    seg_kwargs, non_empty_mask = self.prepare_kwargs(segment, memory, kwargs)
    if sum(non_empty_mask) == 0:
        continue
    
    if self.detach_memory(seg_num):
        memory = memory.detach()

    inp_memories.append(memory)

    seg_kwargs['inputs_embeds'][:, self.read_memory_position] = memory[non_empty_mask]
    seg_kwargs['inputs_embeds'][:, self.write_memory_position] = memory[non_empty_mask]
    out = self.model(**seg_kwargs)
    base_model_outputs.append(out)
    
    memory[non_empty_mask] = out.hidden_states[-1][:, self.write_memory_position]

self.memory_state = memory

out1 = self.process_outputs(base_model_outputs, kwargs)
# return out

In [129]:
for k in out1:
    if 'loss' in k:
        print(k, out1[k].item())

loss 2.3547987937927246
loss_0 3.582904100418091
loss_1 2.427748203277588
loss_2 2.5168142318725586
loss_3 2.9418818950653076
loss_4 2.3158202171325684
loss_5 1.8513410091400146
loss_6 3.7193567752838135
loss_7 2.3547987937927246


In [None]:
len(inp_memories)

8

In [130]:
# with true memories 

self = rmt
kwargs = dict(**b2)
input_ids = kwargs.pop('input_ids')
labels = kwargs['labels']

memory = self.set_memory(input_ids.shape)
segmented = self.pad_and_segment(input_ids, labels)

# inp_memories = []
base_model_outputs = []
for seg_num, segment in enumerate(zip(*segmented)):
    

    seg_kwargs, non_empty_mask = self.prepare_kwargs(segment, memory, kwargs)
    if sum(non_empty_mask) == 0:
        continue
    
    if self.detach_memory(seg_num):
        memory = memory.detach()

    # inp_memories.append(memory)
    # memory = inp_memories[seg_num]

    seg_kwargs['inputs_embeds'][:, self.read_memory_position] = memory[non_empty_mask]
    seg_kwargs['inputs_embeds'][:, self.write_memory_position] = memory[non_empty_mask]
    out = self.model(**seg_kwargs)
    base_model_outputs.append(out)
    
    memory[non_empty_mask] = out.hidden_states[-1][:, self.write_memory_position]

self.memory_state = memory

out2 = self.process_outputs(base_model_outputs, kwargs)
for k in out2:
    if 'loss' in k:
        print(k, out2[k].item())

loss 2.3547987937927246
loss_0 3.582904100418091
loss_1 2.427748203277588
loss_2 2.5168142318725586
loss_3 2.9418818950653076
loss_4 2.3158202171325684
loss_5 1.8513410091400146
loss_6 3.7193567752838135
loss_7 2.3547987937927246


In [131]:
# with fake memories 

self = rmt
kwargs = dict(**b2)
input_ids = kwargs.pop('input_ids')
labels = kwargs['labels']

memory = self.set_memory(input_ids.shape)
segmented = self.pad_and_segment(input_ids, labels)

# inp_memories = []
base_model_outputs = []
for seg_num, segment in enumerate(zip(*segmented)):
    

    seg_kwargs, non_empty_mask = self.prepare_kwargs(segment, memory, kwargs)
    if sum(non_empty_mask) == 0:
        continue
    
    if self.detach_memory(seg_num):
        memory = memory.detach()

    # inp_memories.append(memory)
    memory = inp_memories[seg_num]

    seg_kwargs['inputs_embeds'][:, self.read_memory_position] = memory[non_empty_mask]
    seg_kwargs['inputs_embeds'][:, self.write_memory_position] = memory[non_empty_mask]
    out = self.model(**seg_kwargs)
    base_model_outputs.append(out)
    
    memory[non_empty_mask] = out.hidden_states[-1][:, self.write_memory_position]

self.memory_state = memory

out2f = self.process_outputs(base_model_outputs, kwargs)
for k in out2f:
    if 'loss' in k:
        print(k, out2f[k].item())

loss 2.3726823329925537
loss_0 3.3955602645874023
loss_1 2.281308889389038
loss_2 2.4400248527526855
loss_3 2.8076670169830322
loss_4 2.3518617153167725
loss_5 1.9259015321731567
loss_6 3.7944767475128174
loss_7 2.3726823329925537


In [135]:
print('key - loss - fake memory loss')
for k in out2f:
    if 'loss' in k:
        print(k, round(out2[k].item(), 2), round(out2f[k].item(), 2))

key - loss - fake memory loss
loss 2.35 2.37
loss_0 3.58 3.4
loss_1 2.43 2.28
loss_2 2.52 2.44
loss_3 2.94 2.81
loss_4 2.32 2.35
loss_5 1.85 1.93
loss_6 3.72 3.79
loss_7 2.35 2.37
