In [1]:
import numpy as np
import os
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from transformers import PreTrainedModel, AutoModelForSequenceClassification, T5ForConditionalGeneration
from transformers import AutoTokenizer
import datasets

import math
from matplotlib import pyplot as plt


from typing import List, Optional, Tuple, Union
from transformers import BertForSequenceClassification
import transformers
from transformers.modeling_outputs import SequenceClassifierOutput

### Finetune

In [2]:
from transformers import AutoModelForSequenceClassification, T5ForConditionalGeneration
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import sys
sys.path.append('..')
from modeling_rmt import RMTEncoderForSequenceClassification
from modeling_rmt import RMTEncoderDecoderForConditionalGeneration

In [3]:
from transformers import AutoTokenizer, AutoModelForTokenClassification

In [58]:
# RMT settings for this run: a single segment and no memory tokens.
num_segments, num_mem_tokens = 1, 0

# Everything in this notebook runs on CPU.
# device = torch.device(3)
device = 'cpu'

In [74]:
# Backbone checkpoint name (the BERT alternative is kept for the encoder-only experiments).
# model_name = 'bert-base-cased'
model_name = 't5-base'

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Keyword arguments forwarded to the RMT wrapper constructor (`**rmt_config`).
rmt_config = dict(
    num_mem_tokens=num_mem_tokens,
    max_n_segments=num_segments,
    tokenizer=tokenizer,
    # memory_layers='all',
    # share_memory_layers=True,
    # reconstruction_loss_coef=0.1,
    segment_ordering='regular',
    input_size=512,
    bptt_depth=-1,
    sum_loss=False,
)

# Only the plain backbone is built here; the RMT wrappers below are left commented out.
# base_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
base_model = T5ForConditionalGeneration.from_pretrained(model_name)
# rmt = RMTEncoderDecoderForConditionalGeneration(base_model, **rmt_config)
# rmt = RMTEncoderForSequenceClassification(base_model, **rmt_config)
# rmt = RMTEncoderTBPTT(base_model, **rmt_config)

In [60]:
# Loss of the RMT wrapper on the sampled batch.
# NOTE(review): `rmt`, `sample_input_ids` and `kwargs` are only assigned in cells further
# down the notebook, so this cell fails on a fresh top-to-bottom run.
rmt_out = rmt(sample_input_ids, **kwargs)
rmt_out.loss

tensor(1.2175, grad_fn=<NllLossBackward>)

In [None]:
# Loss of the plain base model on the same inputs, for comparison with the RMT loss.
# NOTE(review): `sample_input_ids` and `kwargs` come from the batch-sampling cells further down.
base_out = base_model(sample_input_ids, **kwargs)
base_out['loss']

In [75]:
# Directory of the fine-tuned run to inspect (path is relative to the notebook's working directory).
cpt_path = "../../runs/framework/qasper/t5-base/lr5e-05_constant_with_warmup_adamw_wd1e-03_512-1024-{1}seg_memNA_bs32_iters5000_regular/run_10/"
# cpt_path = '../../runs/framework/contract_nli/bert-base-cased/lr1e-05_linear_adamw_wd1e-03_1024-512-{1}seg_mem0_bs32_iters5000_regular/run_4/'
model_cpt = os.path.join(cpt_path, "model_best.pth")
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
cpt = torch.load(model_cpt, map_location='cpu')
# The weights are read from the checkpoint's 'model_state_dict' entry and loaded into the
# plain backbone; loading into the RMT wrapper is left commented out.
base_model.load_state_dict(cpt['model_state_dict'])
# rmt.load_state_dict(cpt['model_state_dict'])

<All keys matched successfully>

In [64]:
# RMT loss on the sampled batch, placed after the checkpoint-loading cell.
# NOTE(review): this cell's execution counter (64) is lower than the checkpoint cell's (75),
# so the recorded output may predate the loaded weights - confirm by re-running in order.
rmt_out = rmt(sample_input_ids, **kwargs)
rmt_out.loss

tensor(1.4465, grad_fn=<NllLossBackward>)

In [82]:
# Base-model loss on the sampled batch after the fine-tuned checkpoint was loaded into it.
base_out = base_model(sample_input_ids, **kwargs)
base_out['loss']

tensor(0.4951, grad_fn=<NllLossBackward>)

In [28]:
# Vocabulary size stored in the model config; the resize cell below adds 10 to this value.
base_model.config.vocab_size

32128

In [19]:
# Call tie_weights() on the base model before the embedding resize in the next cell.
base_model.tie_weights()

In [20]:
# Grow the token-embedding matrix by 10 rows beyond the configured vocabulary size.
# NOTE(review): 10 is a magic number - presumably room for extra special/memory tokens; confirm.
base_model.resize_token_embeddings(base_model.config.vocab_size + 10)

Embedding(32138, 768)

In [22]:
# Call tie_weights() again, this time after the embedding resize above.
base_model.tie_weights()

In [23]:
# Re-check the base-model loss on the same batch after resizing the embeddings.
base_out = base_model(sample_input_ids, **kwargs)
base_out['loss']

tensor(9.3226, grad_fn=<NllLossBackward>)

In [285]:
# Wrap the base model in the RMT encoder-decoder class, passing the settings from rmt_config.
# This is the only place in the notebook where `rmt` is actually assigned.
rmt = RMTEncoderDecoderForConditionalGeneration(base_model, **rmt_config)

In [286]:
# Loss of the model held inside the wrapper (`rmt.model`), called directly without the RMT forward.
base_out = rmt.model(sample_input_ids, **kwargs)
base_out['loss']

tensor(6.2180, grad_fn=<NllLossBackward>)

In [287]:
# Loss of `base_model` called directly, to compare against `rmt.model` in the previous cell.
base_out = base_model(sample_input_ids, **kwargs)
base_out['loss']

tensor(6.2180, grad_fn=<NllLossBackward>)

In [57]:
# Full RMT forward pass with hidden states and attentions explicitly switched off; show the loss.
out = rmt(sample_input_ids, **kwargs, output_hidden_states=False, output_attentions=False)
out.loss

tensor(1.4465, grad_fn=<NllLossBackward>)

In [250]:
# Show the 'loss_0', 'loss_1' and 'loss_2' entries of the RMT output.
# NOTE(review): presumably one loss per segment - confirm against modeling_rmt. The cell above
# this one would need to be run with a configuration that produces these three keys.
out['loss_0'], out['loss_1'], out['loss_2']

(tensor(16.0130, grad_fn=<NllLossBackward>),
 tensor(15.4830, grad_fn=<NllLossBackward>),
 tensor(14.7046, grad_fn=<NllLossBackward>))

### Load dataset

In [77]:
class Holder:
    """Bare attribute container used as a stand-in for parsed command-line args."""

    def __init__(self):
        """Take no arguments and set no attributes."""

In [78]:
# Token limits for the encoder-decoder experiment; collate_fn passes them to the
# tokenizer as max_length.
input_seq_len = 512
target_seq_len = 1024
batch_size = 2

# Instantiate Holder instead of binding the class itself, so the settings live on this
# object rather than being written onto the Holder class (where anything else that
# uses Holder would see and overwrite them).
args = Holder()
args.target_seq_len = target_seq_len
args.input_seq_len = input_seq_len
args.input_prefix = ''
device = 'cpu'

### Encoder-decoder

In [79]:
# Options passed to every tokenizer.batch_encode_plus call in collate_fn.
encode_plus_kwargs = dict(truncation=True, padding='longest', pad_to_multiple_of=1)

# No generation overrides are set for this run.
# generate_kwargs = {'max_length': args.target_seq_len, 'min_length': args.target_seq_len}
generate_kwargs = dict()

global_attention_first_token = False  # should be True for LED

def collate_fn(batch):
    """Collate a list of seq2seq examples into one tokenized batch.

    Each example is a dict with 'id', 'input' and either 'output' (one target
    string) or 'outputs' (a list of target strings; only the first one is
    tokenized as the label). Reads the notebook-level `args`, `tokenizer` and
    `encode_plus_kwargs`.

    Returns the tokenizer's encoding of the inputs with three entries added:
    'labels' (target token ids with padding positions replaced by -100),
    'id' (list of example ids) and 'target_text' (the untruncated targets).

    Raises:
        RuntimeError: if the tokenizer output contains 'global_attention_mask'.
    """
    # Strings are pre-cut to 10 characters per allowed token, since overly
    # long strings may slow down tokenization.
    max_source_chars = args.input_seq_len * 10
    max_target_chars = args.target_seq_len * 10

    source_texts = [example['input'][:max_source_chars] for example in batch]

    # Validation examples can carry several references under 'outputs'; the
    # loss is computed against the first one only.
    target_key = 'outputs' if 'outputs' in batch[0] else 'output'
    raw_targets = [example[target_key] for example in batch]
    if target_key == 'outputs':
        label_texts = [candidates[0][:max_target_chars] for candidates in raw_targets]
    else:
        label_texts = [text[:max_target_chars] for text in raw_targets]

    if args.input_prefix:
        source_texts = [args.input_prefix + text for text in source_texts]

    features = tokenizer.batch_encode_plus(
        list(source_texts),
        max_length=args.input_seq_len,
        return_tensors='pt',
        **encode_plus_kwargs,
    )
    with tokenizer.as_target_tokenizer():
        encoded_targets = tokenizer.batch_encode_plus(
            list(label_texts),
            max_length=args.target_seq_len,
            return_tensors='pt',
            **encode_plus_kwargs,
        )
    labels = encoded_targets.input_ids
    # Padding positions are replaced by -100 in the label tensor.
    labels[labels == tokenizer.pad_token_id] = -100

    features['labels'] = labels
    features['id'] = [example['id'] for example in batch]
    features['target_text'] = raw_targets
    if 'global_attention_mask' in features:
        raise RuntimeError('What global attention mask for Longformer and LongformerEncoder-Decoder should be?')
    return features

In [80]:
# Load the SCROLLS task and build train/validation loaders that share collate_fn.
task_name = 'qasper'
dataset = datasets.load_dataset('tau/scrolls', task_name)
train_dataset = dataset['train']

train_sampler = RandomSampler(train_dataset,)
kwargs = {'pin_memory': True, 'num_workers': 0}
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                                collate_fn=collate_fn, **kwargs)

valid_dataset = dataset['validation']
valid_sampler = RandomSampler(valid_dataset)
# The validation loader must use the sampler built over the validation split; the
# train sampler yields indices for the train split, which has a different length.
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, sampler=valid_sampler,
                                collate_fn=collate_fn, **kwargs)

Reusing dataset scrolls (/home/bulatov/.cache/huggingface/datasets/tau___scrolls/qasper/1.0.0/672021d5d8e1edff998a6ea7a5bff35fdfd0ae243e7cf6a8c88a57a04afb46ac)


  0%|          | 0/3 [00:00<?, ?it/s]

In [81]:
# Pull a single batch from the training loader to feed the forward-pass cells.
gen = iter(train_dataloader)
sample = next(gen)

# 'id' and 'target_text' are plain Python lists added by collate_fn, so they are
# removed before the remaining entries are moved to the device.
# NOTE(review): the name `id` shadows the builtin id() for the rest of the session.
if 'id' in sample:
    id = sample.pop('id')
if 'target_text' in sample:
    tgt_text = sample.pop('target_text')

for k in sample:
    sample[k] = sample[k].to(device)
    
# input_ids is passed positionally to the models; everything left in `sample` is passed
# as keyword arguments. This rebinds `kwargs`, which previously held the DataLoader options.
sample_input_ids = sample.pop('input_ids').to(device)
kwargs = sample

### Encoder

In [53]:
# Token limits for the encoder (classification) experiment; input_seq_len is used as the
# tokenizer max_length in the classification settings.
input_seq_len = 1024
target_seq_len = 512
batch_size = 2

# Instantiate Holder instead of binding the class itself, so the settings live on this
# object rather than being written onto the Holder class (where anything else that
# uses Holder would see and overwrite them).
args = Holder()
args.target_seq_len = target_seq_len
args.input_seq_len = input_seq_len
args.input_prefix = ''
device = 'cpu'

In [54]:
# Tokenizer options for the classification task; max_length is taken from args.
encode_plus_kwargs = dict(
    max_length=args.input_seq_len,
    truncation=True,
    padding='longest',
    pad_to_multiple_of=1,
)
generate_kwargs = dict()

# Class name -> integer label id, numbered in listing order.
labels_map = {
    name: index
    for index, name in enumerate(('Contradiction', 'Entailment', 'Not mentioned'))
}
num_labels = len(labels_map)

def collate_fn(batch):
    """Collate a list of classification examples into one tokenized batch.

    Each example is a dict with an 'input' string and an 'output' string naming
    the class. Reads the notebook-level `args`, `tokenizer`, `encode_plus_kwargs`
    and `labels_map`.

    NOTE: running this cell replaces the seq2seq `collate_fn` defined earlier in
    the notebook, because both definitions share the same name.

    Returns the tokenizer's encoding of the inputs with a 'labels' entry added:
    a 1-D int64 tensor of class ids looked up in `labels_map`.

    Raises:
        KeyError: if an example's (truncated) 'output' is not a key of `labels_map`.
    """
    # cut too long strings because they may slow down tokenization
    inputs = [b['input'][:args.input_seq_len * 10] for b in batch]
    labels = [b['output'][:args.target_seq_len * 10] for b in batch]
    if args.input_prefix:
        inputs = [args.input_prefix + inp for inp in inputs]
    features = tokenizer.batch_encode_plus(list(inputs), return_tensors='pt', **encode_plus_kwargs)
    # Build the label tensor directly with an explicit dtype: going through np.array gives
    # a platform-dependent integer width and a float64 tensor for an empty batch.
    features['labels'] = torch.tensor([labels_map[t] for t in labels], dtype=torch.long)
    return features

In [55]:
# Load the SCROLLS task and build train/validation loaders that share collate_fn.
task_name = 'contract_nli'
dataset = datasets.load_dataset('tau/scrolls', task_name)
train_dataset = dataset['train']

train_sampler = RandomSampler(train_dataset,)
kwargs = {'pin_memory': True, 'num_workers': 0}
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                                collate_fn=collate_fn, **kwargs)

valid_dataset = dataset['validation']
valid_sampler = RandomSampler(valid_dataset)
# The validation loader must use the sampler built over the validation split; the
# train sampler yields indices for the train split, which has a different length.
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, sampler=valid_sampler,
                                collate_fn=collate_fn, **kwargs)

Reusing dataset scrolls (/home/bulatov/.cache/huggingface/datasets/tau___scrolls/contract_nli/1.0.0/672021d5d8e1edff998a6ea7a5bff35fdfd0ae243e7cf6a8c88a57a04afb46ac)


  0%|          | 0/3 [00:00<?, ?it/s]

In [56]:
# Pull a single batch from the training loader to feed the forward-pass cells.
gen = iter(train_dataloader)
sample = next(gen)

# Remove any non-tensor entries before moving the batch to the device. The classification
# collate_fn in this section does not add 'id' or 'target_text', so both checks are
# no-ops for its batches.
# NOTE(review): the name `id` shadows the builtin id() if the first branch ever runs.
if 'id' in sample:
    id = sample.pop('id')
if 'target_text' in sample:
    tgt_text = sample.pop('target_text')

# NOTE(review): `rmt` must already exist here; it is only assigned in the Finetune section.
rmt.to(device)
for k in sample:
    sample[k] = sample[k].to(device)
    
# input_ids is passed positionally to the models; everything left in `sample` is passed
# as keyword arguments. This rebinds `kwargs`, which previously held the DataLoader options.
sample_input_ids = sample.pop('input_ids').to(device)
kwargs = sample

In [30]:
# RMT forward pass on the classification batch, requesting hidden states and attentions.
# NOTE(review): the output saved with this cell is a ValueError (input batch size 2 vs
# target batch size 12), so the cell does not run cleanly as saved.
out = rmt(sample_input_ids, **kwargs, output_hidden_states=True, output_attentions = True)
out.keys()

ValueError: Expected input batch_size (2) to match target batch_size (12).

# Toy T-BPTT (truncated backpropagation through time) example
Source: https://discuss.pytorch.org/t/truncated-backprop-data-clarification/34854

In [None]:
class TBPTT():
    """Truncated back-propagation through time for a one-step recurrent module.

    The module is called once per sequence element as
    `output, new_state = one_step_module(inp, state)`. Every `k1` steps a loss
    is computed on the latest output and gradients are propagated back through
    at most `k2` of the most recent steps, one step at a time.

    Args:
        one_step_module: callable mapping (inp, state) to (output, new_state).
        loss_module: callable mapping (output, target) to a scalar loss.
        k1: update period, in steps.
        k2: number of steps to back-propagate through.
        optimizer: optimizer stepped after every update. May be None, in which
            case `train` only accumulates gradients in the module parameters.

    NOTE(review): when k1 < k2, graphs built before an optimizer step are
    back-propagated again after the parameters were updated in place; recent
    PyTorch versions may reject this with an in-place modification error -
    confirm against the installed version.
    """

    def __init__(self, one_step_module, loss_module, k1, k2, optimizer):
        self.one_step_module = one_step_module
        self.loss_module = loss_module
        self.k1 = k1  # update period
        self.k2 = k2  # steps to use in bptt
        # Consecutive backward windows overlap only when the period is shorter than the window.
        self.retain_graph = k1 < k2
        # You can also remove all the optimizer code here, and the
        # train function will just accumulate all the gradients in
        # one_step_module parameters
        self.optimizer = optimizer

    def train(self, input_sequence, init_state):
        """Run the module over `input_sequence`, updating every `k1` steps.

        Args:
            input_sequence: iterable of (inp, target) pairs.
            init_state: initial state tensor passed to the first step.

        Prints the wall-clock time of each backward pass.
        """
        # Local import: the notebook's import cell does not bring in `time`.
        import time

        # Use the optimizer handed to the constructor instead of whatever global
        # happens to be named `optimizer` in the notebook.
        optimizer = self.optimizer

        # Each entry is (detached input state, state produced from it).
        states = [(None, init_state)]
        for j, (inp, target) in enumerate(input_sequence):

            # Cut the graph between steps; the gradient is passed across the cut by hand below.
            state = states[-1][1].detach()
            state.requires_grad = True
            output, new_state = self.one_step_module(inp, state)
            states.append((state, new_state))

            while len(states) > self.k2:
                # Delete stuff that is too old
                del states[0]

            if (j + 1) % self.k1 == 0:
                loss = self.loss_module(output, target)

                if optimizer is not None:
                    optimizer.zero_grad()
                # backprop last module (keep graph only if they ever overlap)
                start = time.perf_counter()
                loss.backward(retain_graph=self.retain_graph)
                for i in range(self.k2 - 1):
                    # if we get all the way back to the "init_state", stop
                    if states[-i-2][0] is None:
                        break
                    # Feed the gradient collected on a step's detached input state
                    # into the previous step's output state.
                    curr_grad = states[-i-1][0].grad
                    states[-i-2][1].backward(curr_grad, retain_graph=self.retain_graph)
                print("bw: {}".format(time.perf_counter() - start))
                if optimizer is not None:
                    optimizer.step()



# Toy problem size: number of steps in the sequence and width of the input/state vectors.
seq_len, layer_size = 20, 50

# Global forward-call counter; MyMod.forward reads it for its messages and increments it.
idx = 0

class MyMod(torch.nn.Module):
    """Toy one-step recurrent module used to trace the T-BPTT forward/backward order.

    A single linear layer maps the concatenated (input, state) vector to a
    vector of the same width, which is split into the step output and the new
    state. Relies on the notebook-level `layer_size` and on the global counter
    `idx`, which is incremented on every forward call.

    `torch.nn` is referenced through the `torch` import because the notebook
    never imports `nn` on its own.
    """

    def __init__(self):
        super(MyMod, self).__init__()
        self.lin = torch.nn.Linear(2*layer_size, 2*layer_size)

    def forward(self, inp, state):
        """Return (out, new_state) for one step and print a trace message.

        Args:
            inp: tensor concatenated with `state` along dim 1.
            state: previous state tensor.
        """
        global idx
        full_out = self.lin(torch.cat([inp, state], 1))
        # out, new_state = full_out.chunk(2, dim=1)
        # First `layer_size` columns are the output, the next `layer_size` the new state.
        out = full_out.narrow(1, 0, layer_size)
        new_state = full_out.narrow(1, layer_size, layer_size)

        def get_pr(idx_val):
            # Bind the current step number so the hook reports the step it belongs to.
            def pr(*args):
                print("doing backward {}".format(idx_val))
            return pr

        # The hooks fire when a gradient reaches the tensor during backward.
        new_state.register_hook(get_pr(idx))
        out.register_hook(get_pr(idx))
        print("doing fw {}".format(idx))
        idx += 1
        return out, new_state


# Build the toy module and loss. `torch.nn` is used explicitly because the notebook
# never imports `nn` on its own.
one_step_module = MyMod()
loss_module = torch.nn.MSELoss()
# One random (input, target) pair, repeated seq_len times: every step sees the same tensors.
input_sequence = [(torch.rand(200, layer_size), torch.rand(200, layer_size))] * seq_len

optimizer = torch.optim.SGD(one_step_module.parameters(), lr=1e-3)

# Update every 5 steps, back-propagating through up to 7 steps.
runner = TBPTT(one_step_module, loss_module, 5, 7, optimizer)

runner.train(input_sequence, torch.zeros(200, layer_size))
print("done")