In [1]:
import numpy as np
import os
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from transformers import PreTrainedModel, AutoModelForSequenceClassification, T5ForConditionalGeneration
from transformers import AutoTokenizer
import datasets

import math
from matplotlib import pyplot as plt


from typing import List, Optional, Tuple, Union
from transformers import BertForSequenceClassification
import transformers
from transformers.modeling_outputs import SequenceClassifierOutput

### Finetune

In [2]:
from transformers import AutoModelForSequenceClassification
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import sys
sys.path.append('..')
# from modeling_rmt import RMTEncoderForSequenceClassification
# from modeling_rmt_enc_dec import RMTEncoderDecoderForConditionalGeneration

In [4]:
# import torch
# import torch.nn.functional as F
# # from .base import RMTBaseModel
# sys.path.append('..')

# from modeling_rmt.conditional_generation import RMTEncoderDecoderForConditionalGeneration

# class RMTEncoderDecoderMemoryOutput(RMTEncoderDecoderForConditionalGeneration):
#     def forward(self, input_ids, attention_mask=None, position_ids=None, head_mask=None,
#                 inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None):
#         kwargs = {'attention_mask': attention_mask,
#                 #   'position_ids': position_ids, 
#                   'inputs_embeds': inputs_embeds,
#                   'labels': labels, 'output_attentions': output_attentions,
#                   'output_hidden_states': output_hidden_states, 'return_dict': return_dict,
#                   }

#         memory = self.set_memory(input_ids.shape)
#         segmented = self.pad_and_segment(input_ids)

#         memories = []
#         base_model_outputs = []
#         for seg_num, segment_input_ids in enumerate(segmented):                
#             if self.rmt_config['bptt_depth'] != -1:
#                 raise NotImplementedError

#             seg_kwargs, non_empty_mask = self.prepare_kwargs(segment_input_ids, kwargs)
#             if sum(non_empty_mask) == 0:
#                 continue
            
#             seg_kwargs['inputs_embeds'][:, self.memory_position] = memory[non_empty_mask]
#             out = self.model(**seg_kwargs)
#             base_model_outputs.append(out)
            
#             memory[non_empty_mask] = out.encoder_hidden_states[-1][:, self.memory_position]
#             memories.append(torch.clone(memory))

#             if seg_num == len(segmented) - 1:
#                 memories = torch.cat(memories, dim=1)
#                 decoder_input_ids = self.model._shift_right(labels)

#                 out = self.model.decoder(input_ids=decoder_input_ids, encoder_hidden_states=memories)
#                 base_model_outputs.append(out)

#         out = self.process_outputs(base_model_outputs, output_attentions, output_hidden_states)
#         return out

#     def generate(self, input_ids, attention_mask=None, position_ids=None, head_mask=None,
#                 inputs_embeds=None, output_attentions=None, output_hidden_states=None, return_dict=None,
#                 min_length=None, max_length=None):
#         kwargs = {'attention_mask': attention_mask,
#                   'inputs_embeds': inputs_embeds,
#                   'output_attentions': output_attentions,
#                   'output_hidden_states': output_hidden_states, 'return_dict': return_dict,
#                   'min_length': min_length, 'max_length': max_length
#                   }

#         memory = self.set_memory(input_ids.shape)
#         segmented = self.pad_and_segment(input_ids)

#         memories = []
#         for seg_num, segment_input_ids in enumerate(segmented):                
#             if self.rmt_config['bptt_depth'] != -1:
#                 raise NotImplementedError

#             seg_kwargs, non_empty_mask = self.prepare_kwargs(segment_input_ids, kwargs)
#             seg_kwargs['inputs_embeds'][:, self.memory_position] = memory[non_empty_mask]
#             if sum(non_empty_mask) == 0:
#                 continue

#             for param in ['min_length', 'max_length']:
#                 if param in seg_kwargs:
#                     seg_kwargs.pop(param)
                    
#             out = self.model.encoder(**seg_kwargs)
#             memory[non_empty_mask] = out.last_hidden_state[:, self.memory_position]
#             memories.append(torch.clone(memory))

#             if seg_num == len(segmented) - 1:                
#                 memories = torch.cat(memories, dim=1)
#                 out = self.model.generate(**seg_kwargs, encoder_hidden_states=memories)

#         return out

In [54]:
from modeling_rmt.conditional_generation import *
from torch.nn import CrossEntropyLoss
class RMTEncoderDecoderFullMemoryLastSeg(RMTEncoderDecoderMemoryLayers):
    def forward(self, input_ids, attention_mask=None, position_ids=None, head_mask=None,
                inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None):
        kwargs = {'attention_mask': attention_mask,
                #   'position_ids': position_ids, 
                  'inputs_embeds': inputs_embeds,
                  'labels': labels, 'output_attentions': output_attentions,
                  'output_hidden_states': output_hidden_states, 'return_dict': return_dict,
                  }

        memory = self.set_memory(input_ids.shape)
        segmented = self.pad_and_segment(input_ids)

        memories = []
        base_model_outputs = []
        for seg_num, segment_input_ids in enumerate(segmented):                
            if self.rmt_config['bptt_depth'] != -1:
                raise NotImplementedError

            seg_kwargs, non_empty_mask = self.prepare_kwargs(segment_input_ids, kwargs)
            if sum(non_empty_mask) == 0:
                continue
            
            seg_kwargs['inputs_embeds'][:, self.memory_position] = memory[non_empty_mask]
            out = self.model(**seg_kwargs)
            base_model_outputs.append(out)
            
            memory[non_empty_mask] = out.encoder_hidden_states[-1][:, self.memory_position]
            memories.append(torch.clone(memory))

        hidden_states = torch.cat(memories[:-1] + [out.encoder_hidden_states[-1]], dim=1)
        decoder_input_ids = self.model._shift_right(labels)
        decoder_outputs = self.model.decoder(input_ids=decoder_input_ids, 
                                             encoder_hidden_states=hidden_states, 
                                             output_hidden_states=output_hidden_states, 
                                             output_attentions=output_attentions)
        base_model_outputs.append(decoder_outputs)

        out = self.process_outputs(base_model_outputs, output_attentions, output_hidden_states)

        sequence_output = decoder_outputs[0]
        # Set device for model parallelism
        if self.model.model_parallel:
            torch.cuda.set_device(self.model.encoder.first_device)
            self.model.lm_head = self.model.lm_head.to(self.model.encoder.first_device)
            sequence_output = sequence_output.to(self.model.lm_head.weight.device)

        if self.model.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model.model_dim**-0.5)

        lm_logits = self.model.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
        
        out['loss'] = loss

        return out

    def generate(self, input_ids, attention_mask=None, position_ids=None, head_mask=None,
                inputs_embeds=None, output_attentions=None, output_hidden_states=None, return_dict=None,
                min_length=None, max_length=None):
        kwargs = {'attention_mask': attention_mask,
                  'inputs_embeds': inputs_embeds,
                  'output_attentions': output_attentions,
                  'output_hidden_states': output_hidden_states, 'return_dict': return_dict,
                  'min_length': min_length, 'max_length': max_length
                  }

        memory = self.set_memory(input_ids.shape)
        segmented = self.pad_and_segment(input_ids)

        memories = []
        for seg_num, segment_input_ids in enumerate(segmented):                
            if self.rmt_config['bptt_depth'] != -1:
                raise NotImplementedError

            seg_kwargs, non_empty_mask = self.prepare_kwargs(segment_input_ids, kwargs)
            if sum(non_empty_mask) == 0:
                continue
            seg_kwargs['inputs_embeds'][:, self.memory_position] = memory[non_empty_mask]

            for param in ['min_length', 'max_length']:
                if param in seg_kwargs:
                    seg_kwargs.pop(param)
                    
            out = self.model.encoder(**seg_kwargs)
            memory[non_empty_mask] = out.last_hidden_state[:, self.memory_position]
            memories.append(torch.clone(memory))

        hidden_states = torch.cat(memories[:-1] + [out.last_hidden_state], dim=1)
        out = self.model.generate(**seg_kwargs, encoder_hidden_states=hidden_states)

        return out

In [55]:
num_segments = 2
num_mem_tokens = 10
# device = torch.device(3)
device = 'cpu'

In [56]:
from transformers import AutoTokenizer, AutoModelForTokenClassification

In [57]:
from rmt_utils.encoder_decoder.horizontal_memory import horizontal_memory_forward as memory_forward_func
# model_name = "facebook/bart-base"
model_name = 't5-small'

tokenizer = AutoTokenizer.from_pretrained(model_name)

rmt_config = {'num_mem_tokens': 10, 
                'max_n_segments': 3,
               #  'segment_alignment': 'right',
                'tokenizer': tokenizer,
                'memory_layers': 'all', 
               #  'memory_forward_func': memory_layers_func,
                'share_memory_layers': True,
                'reconstruction_loss_coef': 0.1,
                'segment_ordering': 'regular',
                'input_size': 512, 
                'bptt_depth': -1, 
                'sum_loss': False,
             }

base_model = T5ForConditionalGeneration.from_pretrained(model_name)

# rmt = RMTEncoderDecoderMemoryOutput(base_model, **rmt_config)
rmt = RMTEncoderDecoderFullMemoryLastSeg(base_model, **rmt_config)
# rmt = RMTEncoderDecoderForConditionalGeneration(base_model, **rmt_config)
# rmt = RMTEncoderDecoderMemoryLoss(base_model, **rmt_config) # does not work
# rmt = RMTEncoderDecoderMemoryLayers(base_model, **rmt_config)
# rmt = RMTEncoderDecoderHorizontalMemory(base_model, **rmt_config)
# rmt.to(device)



In [58]:
out = rmt(sample_input_ids, **kwargs)#, output_hidden_states=True, output_attentions=True)
out.keys()

odict_keys(['last_hidden_state', 'past_key_values', 'hidden_states', 'attentions', 'cross_attentions', 'loss_0', 'decoder_hidden_states_0', 'decoder_attentions_0', 'cross_attentions_0', 'encoder_hidden_states_0', 'encoder_attentions_0', 'loss_1', 'decoder_hidden_states_1', 'decoder_attentions_1', 'cross_attentions_1', 'encoder_hidden_states_1', 'encoder_attentions_1', 'loss_2', 'decoder_hidden_states_2', 'decoder_attentions_2', 'cross_attentions_2', 'encoder_hidden_states_2', 'encoder_attentions_2', 'hidden_states_3', 'attentions_3', 'cross_attentions_3', 'loss'])

In [51]:
out.encoder_last_hidden_state.shape

torch.Size([2, 512, 512])

In [61]:
out.cross_attentions_3[-1].shape

torch.Size([2, 8, 5, 532])

In [59]:
out.encoder_attentions[-1].shape, out.cross_attentions[-1].shape, out.decoder_attentions[-1].shape

AttributeError: 'BaseModelOutputWithPastAndCrossAttentions' object has no attribute 'encoder_attentions'

In [52]:
out.keys()

odict_keys(['loss', 'logits', 'past_key_values', 'decoder_hidden_states', 'decoder_attentions', 'cross_attentions', 'encoder_last_hidden_state', 'encoder_hidden_states', 'encoder_attentions', 'loss_0', 'decoder_hidden_states_0', 'decoder_attentions_0', 'cross_attentions_0', 'encoder_hidden_states_0', 'encoder_attentions_0', 'loss_1', 'decoder_hidden_states_1', 'decoder_attentions_1', 'cross_attentions_1', 'encoder_hidden_states_1', 'encoder_attentions_1', 'loss_2', 'decoder_hidden_states_2', 'decoder_attentions_2', 'cross_attentions_2', 'encoder_hidden_states_2', 'encoder_attentions_2'])

In [43]:
self = rmt
input_ids = sample_input_ids.clone()
labels = sample['labels']
output_attentions = True
output_hidden_states=True

kwargs = dict(**sample)
kwargs['output_attentions'] = True
kwargs['output_hidden_states'] = True

memory = self.set_memory(input_ids.shape)
segmented = self.pad_and_segment(input_ids)

memories = []
base_model_outputs = []
for seg_num, segment_input_ids in enumerate(segmented):                
    if self.rmt_config['bptt_depth'] != -1:
        raise NotImplementedError

    seg_kwargs, non_empty_mask = self.prepare_kwargs(segment_input_ids, kwargs)
    if sum(non_empty_mask) == 0:
        continue
    
    seg_kwargs['inputs_embeds'][:, self.memory_position] = memory[non_empty_mask]
    out = self.model(**seg_kwargs)
    base_model_outputs.append(out)
    
    memory[non_empty_mask] = out.encoder_hidden_states[-1][:, self.memory_position]
    memories.append(torch.clone(memory))

hidden_states = torch.cat(memories[:-1] + [out.encoder_hidden_states[-1]], dim=1)
decoder_input_ids = self.model._shift_right(labels)
decoder_outputs = self.model.decoder(input_ids=decoder_input_ids, encoder_hidden_states=hidden_states, output_hidden_states=True, output_attentions=True)
base_model_outputs.append(decoder_outputs)

out = self.process_outputs(base_model_outputs, output_attentions, output_hidden_states)

sequence_output = decoder_outputs[0]
# Set device for model parallelism
if self.model.model_parallel:
    torch.cuda.set_device(self.model.encoder.first_device)
    self.model.lm_head = self.model.lm_head.to(self.model.encoder.first_device)
    sequence_output = sequence_output.to(self.model.lm_head.weight.device)

if self.model.config.tie_word_embeddings:
    # Rescale output before projecting on vocab
    # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
    sequence_output = sequence_output * (self.model.model_dim**-0.5)

lm_logits = self.model.lm_head(sequence_output)

loss = None
if labels is not None:
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

out['loss'] = loss

In [41]:
decoder_outputs.keys()

odict_keys(['last_hidden_state', 'past_key_values', 'hidden_states', 'attentions', 'cross_attentions'])

In [42]:
decoder_outputs.cross_attentions[-1].shape

torch.Size([2, 8, 5, 532])

In [34]:
hidden_states.shape

torch.Size([2, 532, 512])

### load dataset 

In [10]:
class Holder:
    def __init__(self):
        pass

In [11]:
input_seq_len = 1536
target_seq_len = 1024
batch_size = 2

args = Holder
args.target_seq_len = target_seq_len
args.input_seq_len = input_seq_len
args.input_prefix = ''
device = 'cpu'

### Encoder-decoder

In [12]:
global_attention_first_token = False  # should be True for LED
encode_plus_kwargs = {'truncation': True, 'padding': 'longest', 'pad_to_multiple_of': 1}
# generate_kwargs = {'max_length': args.target_seq_len, 'min_length': args.target_seq_len}
generate_kwargs = {}

def collate_fn(batch):
    # cut too long strings because they may slow down tokenization
    inputs = [b['input'][:args.input_seq_len * 10] for b in batch]
    if 'outputs' in batch[0]:
        # if we have more than 1 label per example (only in valid) take only one of them
        # to compute loss on valid
        labels = [b['outputs'][0][:args.target_seq_len * 10] for b in batch]
    else:
        labels = [b['output'][:args.target_seq_len * 10] for b in batch]
    if args.input_prefix:
        inputs = [args.input_prefix + inp for inp in inputs]
    features = tokenizer.batch_encode_plus(list(inputs), max_length=args.input_seq_len, return_tensors='pt',
                                           **encode_plus_kwargs)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer.batch_encode_plus(list(labels), max_length=args.target_seq_len, return_tensors='pt',
                                             **encode_plus_kwargs).input_ids
    labels[labels == tokenizer.pad_token_id] = -100
    features['labels'] = labels
    features['id'] = [b['id'] for b in batch]
    if 'outputs' in batch[0]:
        features['target_text'] = [b['outputs'] for b in batch]
    else:
        features['target_text'] = [b['output'] for b in batch]
    if 'global_attention_mask' in features:
        raise RuntimeError('What global attention mask for Longformer and LongformerEncoder-Decoder should be?')
    return features

In [13]:
task_name = 'qasper'
dataset = datasets.load_dataset('tau/scrolls', task_name)
train_dataset = dataset['train']

train_sampler = RandomSampler(train_dataset,)
kwargs = {'pin_memory': True, 'num_workers': 0}
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                                collate_fn=collate_fn, **kwargs)

valid_dataset = dataset['validation']
valid_sampler = RandomSampler(valid_dataset)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, sampler=train_sampler,
                                collate_fn=collate_fn, **kwargs)

Reusing dataset scrolls (/home/bulatov/.cache/huggingface/datasets/tau___scrolls/qasper/1.0.0/672021d5d8e1edff998a6ea7a5bff35fdfd0ae243e7cf6a8c88a57a04afb46ac)


  0%|          | 0/3 [00:00<?, ?it/s]

In [14]:
gen = iter(train_dataloader)
sample = next(gen)

if 'id' in sample:
    id = sample.pop('id')
if 'target_text' in sample:
    tgt_text = sample.pop('target_text')

rmt.to(device)
for k in sample:
    sample[k] = sample[k].to(device)
    
sample_input_ids = sample.pop('input_ids').to(device)
kwargs = sample

### Encoder

In [31]:
input_seq_len = 1536
target_seq_len = 3
batch_size = 2

args = Holder
args.target_seq_len = target_seq_len
args.input_seq_len = input_seq_len
args.input_prefix = ''
device = 'cpu'

In [32]:
encode_plus_kwargs = {'max_length': args.input_seq_len,
                        'truncation': True,
                        'padding': 'longest',
                        'pad_to_multiple_of': 1}
generate_kwargs = {}
labels_map = {'Contradiction': 0, 'Entailment': 1, 'Not mentioned': 2}
num_labels = len(labels_map)

def collate_fn(batch):
    # cut too long strings because they may slow down tokenization
    inputs = [b['input'][:args.input_seq_len * 10] for b in batch]
    labels = [b['output'][:args.target_seq_len * 10] for b in batch]
    if args.input_prefix:
        inputs = [args.input_prefix + inp for inp in inputs]
    features = tokenizer.batch_encode_plus(list(inputs), return_tensors='pt', **encode_plus_kwargs)
    labels = np.array([labels_map[t] for t in labels])
    features['labels'] = torch.from_numpy(labels)
    return features

In [33]:
task_name = 'contract_nli'
dataset = datasets.load_dataset('tau/scrolls', task_name)
train_dataset = dataset['train']

train_sampler = RandomSampler(train_dataset,)
kwargs = {'pin_memory': True, 'num_workers': 0}
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                                collate_fn=collate_fn, **kwargs)

valid_dataset = dataset['validation']
valid_sampler = RandomSampler(valid_dataset)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, sampler=train_sampler,
                                collate_fn=collate_fn, **kwargs)

Reusing dataset scrolls (/home/bulatov/.cache/huggingface/datasets/tau___scrolls/contract_nli/1.0.0/672021d5d8e1edff998a6ea7a5bff35fdfd0ae243e7cf6a8c88a57a04afb46ac)


  0%|          | 0/3 [00:00<?, ?it/s]

In [34]:
gen = iter(train_dataloader)
sample = next(gen)

if 'id' in sample:
    id = sample.pop('id')
if 'target_text' in sample:
    tgt_text = sample.pop('target_text')

rmt.to(device)
for k in sample:
    sample[k] = sample[k].to(device)
    
sample_input_ids = sample.pop('input_ids').to(device)
kwargs = sample

In [24]:
out = rmt(sample_input_ids, **kwargs, output_hidden_states=True, output_attentions = True)
out.keys()

ModuleAttributeError: 'BertForSequenceClassification' object has no attribute 'encoder'

### replace forward signature

In [None]:
import inspect 
from functools import wraps

model_name = "google/bert_uncased_L-4_H-256_A-4"

def decorate(func, source):
    @wraps(source)
    def decorated(*args, **kwargs):
        return func(*args, **kwargs)
    return decorated

class RMT(torch.nn.Module):
    def __init__(self, base_model, **rmt_kwargs):
        super().__init__()
        self.model = base_model
        self.forward = decorate(self.forward, self.model.forward)

    def forward(self, new_rmt_arg, input_ids, **kwargs):
        pass

base_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
rmt = RMT(base_model, **rmt_config)
inspect.signature(rmt.forward)

In [49]:
from functools import wraps

def wrap_func(method):
    @wraps(method)
    def _impl(self, *method_args, **method_kwargs):
        method_output = method(self, *method_args, **method_kwargs)
        return method_output
    return _impl
    # return 

class RMTEncoderForSequenceClassification(RMTBaseModel):
    def __init__(self, base_model, **rmt_kwargs):
        super().__init__(base_model, **rmt_kwargs)

    @wrap_func
    def forward(self, input_ids, **kwargs):
        memory = self.set_memory()
        memory = memory.repeat(input_ids.shape[0], 1, 1)
        segmented = self.pad_and_segment(input_ids)

        losses = []
        for seg_num, segment_input_ids in enumerate(segmented):                
            if (self.rmt_config['bptt_depth'] > -1) and (len(segmented) - seg_num > self.rmt_config['bptt_depth']): 
                memory = memory.detach()

            seg_kwargs = dict(**kwargs)
            seg_kwargs['output_hidden_states'] = True

            non_empty_mask = [s is not None for s in segment_input_ids]
            if sum(non_empty_mask) == 0:
                continue
            input_ids = torch.stack([s for s in segment_input_ids if s is not None])
            attention_mask = self.get_attention_mask(input_ids)
            token_type_ids = self.get_token_type_ids(input_ids)
            seg_kwargs['labels'] = seg_kwargs['labels'][non_empty_mask]

            inputs_embeds = self.model.embeddings(input_ids)
            inputs_embeds[:, self.memory_position] = memory[non_empty_mask]

            seg_kwargs['input_ids'] = None
            seg_kwargs['inputs_embeds'] = inputs_embeds
            seg_kwargs['attention_mask'] = attention_mask
            seg_kwargs['token_type_ids'] = token_type_ids

            out = self.model(**seg_kwargs)
            memory[non_empty_mask] = out.hidden_states[-1][:, self.memory_position]

            losses.append(out['loss'])

        # drop unnecessary hiddens to save memory
        if not kwargs.get('output_hidden_states'):
            for key in out.keys():
                if 'hidden_state' in key:
                    out[key] = None

        for i, l in enumerate(losses):
            out[f'loss_{i}'] = l.mean()

        if self.rmt_config['sum_loss']:
            out['loss'] = torch.stack(losses).sum(dim=0)

        return out
        
    def pad_add_special_tokens(self, tensor, segment_size):
        input_elements = []
        input_elements += [self.cls_token, self.mem_token_ids, self.sep_token, tensor, self.sep_token]
        tensor = torch.cat(input_elements)

        pad_size = segment_size - tensor.shape[0]
        if pad_size > 0:
            tensor = F.pad(tensor, (0, pad_size))
        return tensor
    
    def get_token_type_ids(self, tensor):
        return torch.zeros_like(tensor)


num_segments = 2
num_mem_tokens = 10
# device = torch.device(3)
device = 'cpu'


from rmt_utils.encoder.memory_layers import memory_layers_forward as memory_layers_func
# from rmt_utils.encoder.memory_layers import deberta_memory_layers_forward as memory_layers_func

model_name = "google/bert_uncased_L-4_H-256_A-4"
# model_name = "microsoft/deberta-v3-base"
# model_name = 'google/electra-base-discriminator'
# model_name = "google/bert_uncased_L-4_H-256_A-4"

tokenizer = AutoTokenizer.from_pretrained(model_name)

rmt_config = {'num_mem_tokens': 5, 
                'max_n_segments': 3,
               #  'segment_alignment': 'right',
                'tokenizer': tokenizer,
                'memory_layers': 'all', 
                'memory_forward_func': memory_layers_func,
                'share_memory_layers': True,
                'reconstruction_loss_coef': 1,
                'segment_ordering': 'regular',
                'input_size': 512, 
                'bptt_depth': -1, 
                'sum_loss': False,
             }

base_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
rmt = RMTEncoderForSequenceClassification(base_model, **rmt_config)
# rmt = RMTEncoderMemoryLayers(base_model, **rmt_config)
# rmt = RMTEncoderMLMMemLoss(base_model, **rmt_config)
# rmt.to(device)

import inspect
inspect.signature(rmt.forward)

Some weights of the model checkpoint at google/bert_uncased_L-4_H-256_A-4 were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification w

<Signature (input_ids, **kwargs)>

### MLM head for input decoding 

In [140]:
self.model.bert.encoder.layer[-1]

BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=256, out_features=256, bias=True)
      (key): Linear(in_features=256, out_features=256, bias=True)
      (value): Linear(in_features=256, out_features=256, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=256, out_features=256, bias=True)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=256, out_features=1024, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): BertOutput(
    (dense): Linear(in_features=1024, out_features=256, bias=True)
    (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [247]:
self.rec_attn = copy.deepcopy(self.model.base_model.encoder.layer[-1])
self.rec_cls = torch.nn.Linear(self.model.config.hidden_size, self.model.config.vocab_size)

In [146]:
rec_kwargs = dict(**kwargs)
rec_kwargs.pop('labels')
# rec_kwargs.pop('token_type')

tensor([2, 1])

In [269]:
mlm_prob = 0.15

segmented = self.pad_and_segment(sample_input_ids)
previous_input_ids = segmented[0]

inputs = torch.stack(previous_input_ids)
input_embeddings = self.model.embeddings(inputs)

out = self.model(inputs_embeds=input_embeddings, output_hidden_states=True)
memory_outputs = out['hidden_states'][-1][:, self.memory_position]


In [271]:
def reconstruction_forward(self, memory_outputs, previous_input_ids):
    
    inputs = torch.stack(previous_input_ids)
    input_embeddings = self.model.embeddings(inputs)
    input_embeddings[:, self.memory_position] = memory_outputs

    token_inds = list(range(self.num_mem_tokens + 2, input_embeddings.shape[1] - 1))
    mask_inds = np.random.choice(token_inds, round(len(token_inds) * mlm_prob))
    attention_mask = torch.ones(input_embeddings.shape[1])
    attention_mask[mask_inds] = 0

    rec_attn_out = self.rec_attn(input_embeddings)
    rec_logits = self.rec_cls(rec_attn_out[0])

    loss_fct = CrossEntropyLoss(ignore_index=-100)
    reconstruction_loss = loss_fct(rec_logits.view(-1, rec_logits.size(-1)), inputs.view(-1))
    
    return reconstruction_loss

In [272]:
reconstruction_forward(self, memory_outputs=memory_outputs, previous_input_ids=segmented[0])

tensor(10.4209, grad_fn=<NllLossBackward>)

In [254]:
token_inds = list(range(self.num_mem_tokens + 2, input_embeddings.shape[1] - 1))
mask_inds = np.random.choice(token_inds, round(len(token_inds) * mlm_prob))

In [266]:
attention_mask = torch.ones(input_embeddings.shape[1])
attention_mask[mask_inds] = 0


In [267]:
rec_attn_out = self.rec_attn(input_embeddings)
rec_logits = self.rec_cls(rec_attn_out[0])
rec_logits.shape

torch.Size([2, 512, 30527])

In [268]:
loss_fct = CrossEntropyLoss(ignore_index=-100)
reconstruction_loss = loss_fct(rec_logits.view(-1, rec_logits.size(-1)), inputs.view(-1))
reconstruction_loss

tensor(10.4209, grad_fn=<NllLossBackward>)

In [207]:
input_embeddings.shape

torch.Size([2, 512, 256])

In [239]:
token_inds = list(range(self.num_mem_tokens + 2, input_embeddings.shape[1] - 1))

In [240]:
np.random.shuffle(token_inds)
mask_inds = token_inds[: round(len(token_inds) * mlm_prob) ]

In [None]:
input_a

In [None]:
random_mask_inds = torch.randa


In [203]:
input_embeddings

tensor([[[-0.0159,  0.0027,  0.0078,  ...,  0.0175, -0.0240,  0.0109],
         [ 0.0315, -0.0055,  0.0017,  ...,  0.0132, -0.0214, -0.0158],
         [ 0.0016, -0.0042, -0.0412,  ...,  0.0192,  0.0082, -0.0007],
         ...,
         [-0.0986,  0.0014, -0.0430,  ..., -0.0016, -0.0158, -0.0046],
         [-0.0986,  0.0014, -0.0430,  ..., -0.0016, -0.0158, -0.0046],
         [-0.0986,  0.0014, -0.0430,  ..., -0.0016, -0.0158, -0.0046]],

        [[-0.0159,  0.0027,  0.0078,  ...,  0.0175, -0.0240,  0.0109],
         [ 0.0315, -0.0055,  0.0017,  ...,  0.0132, -0.0214, -0.0158],
         [ 0.0016, -0.0042, -0.0412,  ...,  0.0192,  0.0082, -0.0007],
         ...,
         [-0.0226,  0.0497,  0.0308,  ..., -0.0470, -0.0116,  0.0216],
         [-0.0381, -0.0252,  0.0037,  ...,  0.0464,  0.0336,  0.0329],
         [-0.0637, -0.0239,  0.0430,  ..., -0.0894,  0.0181,  0.0181]]],
       grad_fn=<EmbeddingBackward>)

In [148]:
rec_kwargs['token_type_ids'].shapea

torch.Size([2, 1536])

In [142]:
sample_input_ids

tensor([[ 101, 4909, 2283,  ...,    0,    0,    0],
        [ 101, 4909, 2283,  ..., 2023, 3820,  102]])

In [201]:
# from torch.nn import CrossEntropyLoss
# # def segment_reconstruction_forward(self, segmented, hidden_states):

# hidden_states = rec_kwargs['inputs_embeds']
# previous_input_ids = segmented[-2]
# non_empty_mask = [s is not None for s in previous_input_ids]
# if sum(non_empty_mask) == 0:
#     raise ValueError

# previous_input_ids = torch.stack(previous_input_ids)[non_empty_mask]
# reconstructor_input = hidden_states[non_empty_mask]

# rec_attn_out = self.rec_attn(reconstructor_input)
# rec_logits = self.rec_cls(rec_attn_out[0])

# loss_fct = CrossEntropyLoss(ignore_index=-100)
# reconstruction_loss = loss_fct(rec_logits.view(-1, rec_logits.size(-1)), previous_input_ids.view(-1))

### segment to memory attribution

In [18]:
self = rmt
input_ids = sample_input_ids

In [19]:
# memory = self.set_memory()
# memory = memory.repeat(input_ids.shape[0], 1, 1)
# segmented = self.pad_and_segment(input_ids)

# losses = {}
# memories = []
# inputs = []
# non_memory_position = [i for i in range(self.rmt_config['input_size']) if i not in self.memory_position]

# for seg_num, segment_input_ids in enumerate(segmented):
#     if (self.rmt_config['bptt_depth'] > -1) and (len(segmented) - seg_num > self.rmt_config['bptt_depth']): 
#         memory = memory.detach()

#     seg_kwargs = dict(**kwargs)
#     seg_kwargs['output_hidden_states'] = True
    
#     non_empty_mask = [s is not None for s in segment_input_ids]
#     if sum(non_empty_mask) == 0:
#         continue
#     input_ids = torch.stack(segment_input_ids)[non_empty_mask]
#     attention_mask = self.get_attention_mask(input_ids)
#     token_type_ids = self.get_token_type_ids(input_ids)
#     seg_kwargs['labels'] = seg_kwargs['labels'][non_empty_mask]

#     inputs_embeds = self.embeddings(input_ids)
#     inputs_embeds[:, self.memory_position] = memory[non_empty_mask]

#     seg_kwargs['inputs_embeds'] = inputs_embeds
#     seg_kwargs['attention_mask'] = attention_mask
        
#     out = self.model.forward(**seg_kwargs)

#     memory[non_empty_mask] = out.encoder_hidden_states[-1][:, self.memory_position]
    
#     memories.append(torch.clone(memory[non_empty_mask].detach()))
#     inputs.append(out.encoder_hidden_states[-1][:, non_memory_position])

#     losses[f'loss_{seg_num}'] = out['loss']

# memory_out = out.encoder_last_hidden_state[:, self.memory_position]
# reconstruction_loss = self.segment_reconstruction_forward(segmented, memory_out)
# out['reconstruction_loss'] = reconstruction_loss

# # drop unnecessary hiddens to save memory
# # if not kwargs.get('output_hidden_states'):
# #     for key in out.keys():
# #         if 'hidden_state' in key:
# #             out[key] = None
            
# for k, loss in losses.items():
#     out[k] = loss

# if self.rmt_config['sum_loss']:
#     out['loss'] = torch.stack(losses).sum(dim=0)

# rec_coef = self.rmt_config['reconstruction_loss_coef']
# out['loss'] = reconstruction_loss * rec_coef + out['loss'] * (1 - rec_coef)