In [1]:
import numpy as np
import os
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from transformers import PreTrainedModel, AutoModelForSequenceClassification, T5ForConditionalGeneration
from transformers import AutoTokenizer
import datasets

import math
from matplotlib import pyplot as plt


from typing import List, Optional, Tuple, Union
from transformers import BertForSequenceClassification
import transformers
from transformers.modeling_outputs import SequenceClassifierOutput

### toy example


In [25]:
class RMT(torch.nn.Module):
    def __init__(self, mem_size=10, input_size=512, dim=100, out_size=2, 
                #  k1=None, # backward is done every k1 segments
                #  k2=None, # backward uses k2 last segments
                 ):
        super().__init__()
        self.embedding = torch.nn.Embedding(input_size, dim)
        self.weight = torch.nn.Linear(dim, dim)
        self.cls = torch.nn.Linear(dim, out_size)
        
        memory = torch.nn.Parameter(torch.randn(mem_size, dim), requires_grad=True)
        self.register_parameter('memory', memory)
        
        self.dim = dim
        self.mem_size = mem_size
        self.out_size = out_size
        # self.k1 = k1
        # self.k2 = k2
    
    def _forward(self, embedded, labels):
        hiddens = self.weight(embedded)        
        processed_hiddens = hiddens.mean(dim=1)
        logits = self.cls(processed_hiddens)
                
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(logits.softmax(dim=1), labels.squeeze(1))
        
        return loss, hiddens

    def forward(self, segments, labels):
        
        init_memory = self.memory.unsqueeze(0).repeat(segments[0].shape[0], 1, 1)

        memory_states = [(None, init_memory)]
        losses = []
        for seg_num, X in enumerate(segments):
            memory = memory_states[-1][1].detach()
            memory.requires_grad = True

            embedded = self.embedding(X)
            embedded[:, :self.mem_size] = memory
            
            loss, hiddens = self._forward(embedded, labels)
            new_memory = hiddens[:, :self.mem_size]

            losses.append(loss)

            memory_states.append((memory, new_memory))

        self.memory_states = memory_states
        
        return losses

In [26]:
num_classes = 2
batch_size = 2

n_segments = 5
mem_size = 10

input_size = 512
# out_size = 1
dim = 100

labels = torch.randint(num_classes, (batch_size, 1))


segments = [torch.randint(100, (batch_size, input_size)) for _ in range(n_segments)]

rmt = RMT(mem_size, input_size=input_size, dim=dim, out_size=num_classes)#, k1=n_segments, k2=1)

In [27]:
losses = rmt(segments, labels)

In [28]:
k1 = n_segments
k2 = 1

In [None]:
def truncated_backward(losses, memory_states,
                        k1=-1, # update is done every k1 segments
                        k2=-1, # backward uses k2 last segments
                        ):
    losses[-1].backward(retain_graph=False)
    for i in range(k2 - 1 if k2 != -1 else len(rmt.memory_states) - 1):
        # if we get all the way back to the "init_memory", stop
        if rmt.memory_states[-i-2][0] is None:
            break
        curr_grad = rmt.memory_states[-i-1][0].grad
        memory_states[-i-2][1].backward(curr_grad, retain_graph=False)

    if k1 != -1:
        raise NotImplementedError
        # if (j+1)%self.k1 == 0:
        #     loss = self.loss_module(output, target)

        #     optimizer.zero_grad()
        #     # backprop last module (keep graph only if they ever overlap)
        #     start = time.time()
        #     loss.backward(retain_graph=self.retain_graph)
        #     for i in range(self.k2-1):
        #         # if we get all the way back to the "init_state", stop
        #         if states[-i-2][0] is None:
        #             break
        #         curr_grad = states[-i-1][0].grad
        #         states[-i-2][1].backward(curr_grad, retain_graph=self.retain_graph)
        #     optimizer.step()

In [244]:
rmt.memory.grad

In [224]:
losses = rmt(segments, labels)
losses[0].backward()
print('memory gradient: {}\nweight grad: {}\ncls grad: {}'.format(rmt.memory.grad.mean(), rmt.weight.weight.grad.mean(), rmt.cls.weight.grad.mean()))

AttributeError: 'NoneType' object has no attribute 'mean'

In [210]:
rmt.zero_grad()
grad_from_loss = []
for i, seg in enumerate(segments):
    losses = rmt(segments, labels)
    losses[i].backward()
    grad_from_loss.append(torch.clone(rmt.memory.grad))

In [217]:
(grad_from_loss[0] + grad_from_loss[1] - grad_from_loss[2]).abs().mean(), grad_from_loss[0].abs().mean()

(tensor(2.0285e-05), tensor(2.1588e-05))

In [70]:
rmt.zero_grad()
print('memory gradient: {}\nweight grad: {}\ncls grad: {}'.format(rmt.memory.grad.mean(), rmt.weight.weight.grad.mean(), rmt.cls.weight.grad.mean()))

memory gradient: 0.0
weight grad: 0.0
cls grad: 0.0


In [197]:
losses = rmt(segments, labels)
losses[-1].backward()
print('memory gradient: {}\nweight grad: {}\ncls grad: {}'.format(rmt.memory.grad.mean(), rmt.weight.weight.grad.mean(), rmt.cls.weight.grad.mean()))

memory gradient: 9.23935294849798e-06
weight grad: -0.00039320866926573217
cls grad: -1.19209286886246e-09


In [198]:
losses

[tensor(0.6648, grad_fn=<NllLossBackward>),
 tensor(0.6615, grad_fn=<NllLossBackward>),
 tensor(0.6628, grad_fn=<NllLossBackward>),
 tensor(0.6644, grad_fn=<NllLossBackward>),
 tensor(0.6562, grad_fn=<NllLossBackward>)]

In [15]:
rmt = RMT(mem_size, input_size=input_size, dim=dim, out_size=num_classes)
losses = rmt(segments, labels)
losses[0].backward()
print('memory gradient: {}\nweight grad: {}\ncls grad: {}'.format(rmt.memory.grad.mean(), rmt.weight.weight.grad.mean(), rmt.cls.weight.grad.mean()))

memory gradient: -9.052487257577013e-07
weight grad: 5.410353696788661e-05
cls grad: 3.725290215195187e-11


In [None]:
rmt = RMT(mem_size, input_size=input_size, dim=dim, out_size=num_classes)
losses = rmt(segments, labels)
losses[0].backward()
print('memory gradient: {}\nweight grad: {}\ncls grad: {}'.format(rmt.memory.grad.mean(), rmt.weight.weight.grad.mean(), rmt.cls.weight.grad.mean()))

memory gradient: -9.052487257577013e-07
weight grad: 5.410353696788661e-05
cls grad: 3.725290215195187e-11


In [16]:
rmt.zero_grad()

In [17]:
print('memory gradient: {}\nweight grad: {}\ncls grad: {}'.format(rmt.memory.grad.mean(), rmt.weight.weight.grad.mean(), rmt.cls.weight.grad.mean()))

memory gradient: -9.052487257577013e-07
weight grad: 0.0
cls grad: 0.0


### Finetune

In [2]:
from transformers import AutoModelForSequenceClassification, T5ForConditionalGeneration
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import sys
sys.path.append('..')
from modeling_rmt import RMTEncoderForSequenceClassification
from modeling_rmt import RMTEncoderDecoderForConditionalGeneration

In [3]:
from transformers import AutoTokenizer, AutoModelForTokenClassification

In [4]:
import torch
import torch.nn.functional as F
from modeling_rmt.base import RMTBaseModel
class RMTEncoderTBPTT(RMTEncoderForSequenceClassification):

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None):
        kwargs = {'attention_mask': attention_mask, 'token_type_ids': token_type_ids,
                  'position_ids': position_ids, 'inputs_embeds': inputs_embeds,
                  'labels': labels, 'output_attentions': output_attentions,
                  'output_hidden_states': output_hidden_states, 'return_dict': return_dict,
                  }

        init_memory = self.set_memory(input_ids.shape)
        segmented = self.pad_and_segment(input_ids)
        if self.num_mem_tokens == 0:
            segmented = segmented[-1:]

        memory_states = [(None, init_memory)]
        base_model_outputs = []
        for seg_num, segment_input_ids in enumerate(segmented):
            seg_kwargs, non_empty_mask = self.prepare_kwargs(segment_input_ids, kwargs)
            if sum(non_empty_mask) < 1:
                raise NotImplementedError
            
            if sum(non_empty_mask) == 0:
                continue
            
            memory = memory_states[-1][1].detach()
            memory.requires_grad = True
            # new_memory = memory
            # new_memory.requires_grad = True, False
            # print('memory', memory.mean(), memory.std(), memory)
            seg_kwargs['inputs_embeds'][:, self.memory_position] = memory[non_empty_mask]
            out = self.model(**seg_kwargs)
            base_model_outputs.append(out)
            
            self.memory_states = memory_states
            new_memory = out.hidden_states[-1][:, self.memory_position]
            memory_states.append((memory, new_memory))
        
        self.memory_states = memory_states

        out = self.process_outputs(base_model_outputs, output_attentions, output_hidden_states)
        return out
    
    def truncated_backward(self, k1, k2):
        memory_states = self.memory_states
        if k1 != -1:
            raise NotImplementedError
        
        for i in range(k2 - 1 if k2 != -1 else len(memory_states)):
            curr_grad = memory_states[-i-1][0].grad
            memory_states[-i-2][1].backward(curr_grad, retain_graph=False)

            # if we get all the way back to the "init_memory", stop
            if memory_states[-i-2][0] is None:
                break

In [5]:
for _ in range(0):
    print(12)

In [6]:
num_segments = 2
num_mem_tokens = 10
# device = torch.device(3)
device = 'cpu'

In [7]:
model_name = 'bert-base-cased'
# model_name = 't5-base'

tokenizer = AutoTokenizer.from_pretrained(model_name)

rmt_config = {'num_mem_tokens': num_mem_tokens, 
                'max_n_segments': num_segments,
                'tokenizer': tokenizer,
               #  'memory_layers': 'all', 
               #  'share_memory_layers': True,
               #  'reconstruction_loss_coef': 0.1,
                'segment_ordering': 'regular',
                'input_size': 512, 
                'bptt_depth': -1, 
                'sum_loss': False,
             }

base_model1 = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
base_model2 = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
# base_model = T5ForConditionalGeneration.from_pretrained(model_name)
# rmt = RMTEncoderDecoderForConditionalGeneration(base_model, **rmt_config)
rmt1 = RMTEncoderForSequenceClassification(base_model1, **rmt_config)
# rmt2 = RMTEncoderForSequenceClassification(base_model2, **rmt_config)
rmt2 = RMTEncoderTBPTT(base_model2, **rmt_config)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

In [44]:
# rmt2.memory_states[0]

In [45]:
rmt1.model.embeddings.weight.data = rmt2.model.embeddings.weight.data
rmt1.model.classifier.weight.data = rmt2.model.classifier.weight.data

In [46]:
# cpt_path = "../../runs/framework/qasper/t5-base/lr5e-05_constant_with_warmup_adamw_wd1e-03_512-1024-{1}seg_memNA_bs32_iters5000_regular/run_10/"
# model_cpt = os.path.join(cpt_path, "model_best.pth")
# cpt = torch.load(model_cpt, map_location='cpu')
# base_model.load_state_dict(cpt['model_state_dict'])

In [47]:
# for (n1, p1), (n2, p2) in zip(rmt1.named_parameters(), rmt2.named_parameters()):
#     # if getattr(rmt2, n) != p:
#     #     print(n)
#     if (p1 == p2).float().mean() < 1:
#         print(n1)

In [55]:
out1 = rmt1(sample_input_ids, **kwargs, output_hidden_states=False, output_attentions=False)
out2 = rmt2(sample_input_ids, **kwargs, output_hidden_states=False, output_attentions=False)
out1.loss, out2.loss

(tensor(0.9533, grad_fn=<NllLossBackward>),
 tensor(0.9533, grad_fn=<NllLossBackward>))

In [35]:
rmt2.memory_states[-2][1].grad

  rmt2.memory_states[-2][1].grad


In [49]:
# out = rmt(sample_input_ids, **kwargs, output_hidden_states=False, output_attentions=False)
# out.keys()

In [50]:
out1.loss.backward()

In [58]:
out2.loss.backward()

In [52]:
rmt1.model.embeddings.weight.grad[rmt1.mem_token_ids] - rmt2.model.embeddings.weight.grad[rmt1.mem_token_ids]

tensor([[-1.9897e-03, -2.3091e-03,  1.8419e-03,  ...,  2.9464e-03,
         -6.8678e-04, -5.5236e-06],
        [ 4.1039e-04, -2.2907e-03, -4.0704e-03,  ..., -2.6411e-03,
         -4.5289e-04,  1.3317e-03],
        [-4.1274e-04, -3.8551e-04, -2.8217e-03,  ...,  4.2580e-04,
         -1.1218e-03, -1.6782e-03],
        ...,
        [ 4.0664e-05, -2.9215e-03,  1.5561e-03,  ...,  8.5625e-04,
          3.1595e-04, -1.4923e-03],
        [ 2.7264e-03,  4.9132e-04, -6.6655e-03,  ...,  4.7209e-03,
         -1.2776e-03, -4.0603e-03],
        [ 3.0761e-03, -3.3016e-03, -6.4506e-03,  ..., -3.9267e-03,
         -7.8163e-03, -9.2225e-04]])

In [59]:
rmt2.truncated_backward(k1=-1, k2=2)

In [60]:
rmt1.model.embeddings.weight.grad[rmt1.mem_token_ids] - rmt2.model.embeddings.weight.grad[rmt1.mem_token_ids]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [66]:
len(memory_states)

4

In [64]:
# rmt.truncated_backward(k1=-1, k2=1)

k1 = -1
k2 = 1
self = rmt
memory_states = self.memory_states
if k1 != -1:
    raise NotImplementedError


In [74]:
i = 1

In [75]:
memory_states[-i-2][0].grad

In [77]:

curr_grad = memory_states[-i-1][0].grad

memory_states[-i-2][1].backward(curr_grad, retain_graph=False)

In [78]:
memory_states[-i-2][0].grad

tensor([[[ 8.7065e-04,  1.6680e-04, -1.2168e-03,  ..., -1.7849e-04,
           5.4806e-04,  2.0493e-03],
         [-1.5455e-03, -1.3318e-03, -6.9536e-03,  ..., -2.6797e-03,
          -6.9923e-03,  1.3776e-03],
         [ 7.5854e-03,  3.5826e-03, -2.6303e-03,  ..., -2.0050e-03,
           1.8138e-03, -3.5395e-03],
         ...,
         [-2.2143e-03, -1.2456e-03,  3.8501e-03,  ...,  2.7255e-03,
          -1.9535e-03,  2.2969e-04],
         [ 6.5699e-03,  3.4641e-03, -2.5996e-03,  ..., -3.7634e-03,
           4.2661e-03, -3.9519e-03],
         [ 5.8664e-03, -5.6410e-03,  4.6935e-03,  ...,  1.3502e-04,
          -4.6932e-03, -3.7547e-03]],

        [[-1.1527e-03, -1.1127e-06,  2.6598e-03,  ..., -7.1365e-03,
           2.9748e-04,  5.6316e-03],
         [ 1.7341e-03,  5.5398e-03,  6.3473e-03,  ...,  1.0567e-03,
           2.5097e-03,  4.1832e-03],
         [-2.7789e-03, -1.8359e-03,  2.4206e-03,  ...,  1.8355e-03,
           9.1351e-04,  1.0039e-02],
         ...,
         [-1.6219e-02, -1

In [None]:

for i in range(k2 - 1 if k2 != -1 else len(memory_states)):
    print()
    curr_grad = memory_states[-i-1][0].grad
    memory_states[-i-2][1].backward(curr_grad, retain_graph=False)

    # if we get all the way back to the "init_memory", stop
    if memory_states[-i-2][0] is None:
        break

In [53]:
rmt.model.embeddings.weight.shape

torch.Size([29006, 768])

In [57]:
rmt.model.embeddings.weight.grad[rmt.mem_token_ids]#.abs().sum()

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [60]:
rmt.model.embeddings.weight.grad[sample_input_ids]#.abs().sum()

tensor([[[-7.0348e-04, -5.1936e-03, -1.7022e-03,  ..., -4.8533e-04,
          -1.4215e-03,  2.2574e-03],
         [-4.5974e-03,  5.1353e-03, -5.1554e-05,  ...,  9.8471e-03,
          -8.8535e-04,  6.8188e-03],
         [-1.9443e-03,  7.3353e-04, -3.4847e-03,  ...,  3.4695e-04,
           1.9930e-03, -3.3250e-03],
         ...,
         [ 1.4958e-03, -1.0761e-02,  2.7553e-02,  ...,  1.3110e-02,
           8.6204e-03,  1.2122e-02],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-4.0009e-03,  5.6785e-03, -3.3421e-03,  ..., -1.6336e-02,
          -6.6896e-04,  8.5934e-03]],

        [[-7.0348e-04, -5.1936e-03, -1.7022e-03,  ..., -4.8533e-04,
          -1.4215e-03,  2.2574e-03],
         [-8.5180e-04,  4.4820e-04, -3.6088e-04,  ..., -7.8554e-05,
          -2.9245e-05, -6.2273e-04],
         [-1.1675e-03,  9.3001e-04, -1.3836e-04,  ...,  5.5215e-04,
           1.0142e-03,  3.8922e-04],
         ...,
         [-1.2561e-03, -3

In [None]:
rmt.model.embeddings.weight.grad.abs().sum()

tensor(251.0535)

In [250]:
out['loss_0'], out['loss_1'], out['loss_2']

(tensor(16.0130, grad_fn=<NllLossBackward>),
 tensor(15.4830, grad_fn=<NllLossBackward>),
 tensor(14.7046, grad_fn=<NllLossBackward>))

### load dataset 

In [7]:
class Holder:
    def __init__(self):
        pass

In [8]:
input_seq_len = 512
target_seq_len = 1024
batch_size = 2

args = Holder
args.target_seq_len = target_seq_len
args.input_seq_len = input_seq_len
args.input_prefix = ''
device = 'cpu'

### Encoder-decoder

In [9]:
global_attention_first_token = False  # should be True for LED
encode_plus_kwargs = {'truncation': True, 'padding': 'longest', 'pad_to_multiple_of': 1}
# generate_kwargs = {'max_length': args.target_seq_len, 'min_length': args.target_seq_len}
generate_kwargs = {}

def collate_fn(batch):
    # cut too long strings because they may slow down tokenization
    inputs = [b['input'][:args.input_seq_len * 10] for b in batch]
    if 'outputs' in batch[0]:
        # if we have more than 1 label per example (only in valid) take only one of them
        # to compute loss on valid
        labels = [b['outputs'][0][:args.target_seq_len * 10] for b in batch]
    else:
        labels = [b['output'][:args.target_seq_len * 10] for b in batch]
    if args.input_prefix:
        inputs = [args.input_prefix + inp for inp in inputs]
    features = tokenizer.batch_encode_plus(list(inputs), max_length=args.input_seq_len, return_tensors='pt',
                                           **encode_plus_kwargs)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer.batch_encode_plus(list(labels), max_length=args.target_seq_len, return_tensors='pt',
                                             **encode_plus_kwargs).input_ids
    labels[labels == tokenizer.pad_token_id] = -100
    features['labels'] = labels
    features['id'] = [b['id'] for b in batch]
    if 'outputs' in batch[0]:
        features['target_text'] = [b['outputs'] for b in batch]
    else:
        features['target_text'] = [b['output'] for b in batch]
    if 'global_attention_mask' in features:
        raise RuntimeError('What global attention mask for Longformer and LongformerEncoder-Decoder should be?')
    return features

In [16]:
task_name = 'qasper'
dataset = datasets.load_dataset('tau/scrolls', task_name)
train_dataset = dataset['train']

train_sampler = RandomSampler(train_dataset,)
kwargs = {'pin_memory': True, 'num_workers': 0}
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                                collate_fn=collate_fn, **kwargs)

valid_dataset = dataset['validation']
valid_sampler = RandomSampler(valid_dataset)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, sampler=train_sampler,
                                collate_fn=collate_fn, **kwargs)

Reusing dataset scrolls (/home/bulatov/.cache/huggingface/datasets/tau___scrolls/qasper/1.0.0/672021d5d8e1edff998a6ea7a5bff35fdfd0ae243e7cf6a8c88a57a04afb46ac)


  0%|          | 0/3 [00:00<?, ?it/s]

In [17]:
gen = iter(train_dataloader)
sample = next(gen)

if 'id' in sample:
    id = sample.pop('id')
if 'target_text' in sample:
    tgt_text = sample.pop('target_text')

for k in sample:
    sample[k] = sample[k].to(device)
    
sample_input_ids = sample.pop('input_ids').to(device)
kwargs = sample

### Encoder

In [19]:
input_seq_len = 1536
target_seq_len = 3
batch_size = 2

args = Holder
args.target_seq_len = target_seq_len
args.input_seq_len = input_seq_len
args.input_prefix = ''
device = 'cpu'

In [20]:
encode_plus_kwargs = {'max_length': args.input_seq_len,
                        'truncation': True,
                        'padding': 'longest',
                        'pad_to_multiple_of': 1}
generate_kwargs = {}
labels_map = {'Contradiction': 0, 'Entailment': 1, 'Not mentioned': 2}
num_labels = len(labels_map)

def collate_fn(batch):
    # cut too long strings because they may slow down tokenization
    inputs = [b['input'][:args.input_seq_len * 10] for b in batch]
    labels = [b['output'][:args.target_seq_len * 10] for b in batch]
    if args.input_prefix:
        inputs = [args.input_prefix + inp for inp in inputs]
    features = tokenizer.batch_encode_plus(list(inputs), return_tensors='pt', **encode_plus_kwargs)
    labels = np.array([labels_map[t] for t in labels])
    features['labels'] = torch.from_numpy(labels)
    return features

In [21]:
import datasets
task_name = 'contract_nli'
dataset = datasets.load_dataset('tau/scrolls', task_name)
train_dataset = dataset['train']

train_sampler = RandomSampler(train_dataset,)
kwargs = {'pin_memory': True, 'num_workers': 0}
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                                collate_fn=collate_fn, **kwargs)

valid_dataset = dataset['validation']
valid_sampler = RandomSampler(valid_dataset)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, sampler=train_sampler,
                                collate_fn=collate_fn, **kwargs)

Reusing dataset scrolls (/home/bulatov/.cache/huggingface/datasets/tau___scrolls/contract_nli/1.0.0/672021d5d8e1edff998a6ea7a5bff35fdfd0ae243e7cf6a8c88a57a04afb46ac)


  0%|          | 0/3 [00:00<?, ?it/s]

In [22]:
gen = iter(train_dataloader)
sample = next(gen)

if 'id' in sample:
    id = sample.pop('id')
if 'target_text' in sample:
    tgt_text = sample.pop('target_text')

# rmt.to(device)
for k in sample:
    sample[k] = sample[k].to(device)
    
sample_input_ids = sample.pop('input_ids').to(device)
kwargs = sample

In [30]:
out = rmt(sample_input_ids, **kwargs, output_hidden_states=True, output_attentions = True)
out.keys()

ValueError: Expected input batch_size (2) to match target batch_size (12).

# toy T-BPTT example from source
https://discuss.pytorch.org/t/truncated-backprop-data-clarification/34854

In [24]:
class TBPTT():
    def __init__(self, one_step_module, loss_module, k1, k2, optimizer):
        self.one_step_module = one_step_module
        self.loss_module = loss_module
        self.k1 = k1 # update period
        self.k2 = k2 # steps to use in bptt
        self.retain_graph = k1 < k2
        # You can also remove all the optimizer code here, and the
        # train function will just accumulate all the gradients in
        # one_step_module parameters
        self.optimizer = optimizer

    def train(self, input_sequence, init_state):
        states = [(None, init_state)]
        for j, (inp, target) in enumerate(input_sequence):

            state = states[-1][1].detach()
            state.requires_grad=True
            output, new_state = self.one_step_module(inp, state)
            states.append((state, new_state))

            while len(states) > self.k2:
                # Delete stuff that is too old
                del states[0]

            if (j+1)%self.k1 == 0:
                loss = self.loss_module(output, target)

                optimizer.zero_grad()
                # backprop last module (keep graph only if they ever overlap)
                start = time.time()
                loss.backward(retain_graph=self.retain_graph)
                for i in range(self.k2-1):
                    # if we get all the way back to the "init_state", stop
                    if states[-i-2][0] is None:
                        break
                    curr_grad = states[-i-1][0].grad
                    states[-i-2][1].backward(curr_grad, retain_graph=self.retain_graph)
                print("bw: {}".format(time.time()-start))
                optimizer.step()



seq_len = 20
layer_size = 50

idx = 0

class MyMod(nn.Module):
    def __init__(self):
        super(MyMod, self).__init__()
        self.lin = nn.Linear(2*layer_size, 2*layer_size)

    def forward(self, inp, state):
        global idx
        full_out = self.lin(torch.cat([inp, state], 1))
        # out, new_state = full_out.chunk(2, dim=1)
        out = full_out.narrow(1, 0, layer_size)
        new_state = full_out.narrow(1, layer_size, layer_size)
        def get_pr(idx_val):
            def pr(*args):
                print("doing backward {}".format(idx_val))
            return pr
        new_state.register_hook(get_pr(idx))
        out.register_hook(get_pr(idx))
        print("doing fw {}".format(idx))
        idx += 1
        return out, new_state


one_step_module = MyMod()
loss_module = nn.MSELoss()
input_sequence = [(torch.rand(200, layer_size), torch.rand(200, layer_size))] * seq_len

optimizer = torch.optim.SGD(one_step_module.parameters(), lr=1e-3)

runner = TBPTT(one_step_module, loss_module, 5, 7, optimizer)

runner.train(input_sequence, torch.zeros(200, layer_size))
print("done")

NameError: name 'nn' is not defined