In [1]:
import numpy as np
import os
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from transformers import PreTrainedModel, AutoModelForSequenceClassification, T5ForConditionalGeneration
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
import datasets

import math
from matplotlib import pyplot as plt


from typing import List, Optional, Tuple, Union
from transformers import BertForSequenceClassification
import transformers
from transformers.modeling_outputs import SequenceClassifierOutput

### Finetune

In [2]:
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import sys
sys.path.append('..')

In [3]:
import math
import torch
import torch.nn.functional as F

class RMTBaseModel(torch.nn.Module):
    """Base wrapper that processes long inputs segment by segment with memory tokens.

    The wrapped ``base_model`` is stored as ``self.model``.  All remaining
    keyword arguments are forwarded to :meth:`set_params`; everything except
    ``num_mem_tokens`` and ``tokenizer`` is kept verbatim in ``self.rmt_config``.
    Config keys read by this class: ``input_size``, ``max_n_segments``,
    ``sum_loss`` and the optional ``segment_alignment``.

    Subclasses must implement :meth:`forward` and :meth:`pad_add_special_tokens`.
    """

    def __init__(self, base_model, **rmt_kwargs):
        super().__init__()
        self.model = base_model
        self.set_params(**rmt_kwargs)

    def set_params(self, num_mem_tokens, tokenizer, **rmt_config):
        """Store the config, register special/memory tokens and compute ``segment_size``.

        ``segment_size`` is ``input_size`` minus the memory tokens and the
        special tokens the tokenizer adds, minus one more position when the
        tokenizer defines a ``sep_token``.
        """
        self.rmt_config = rmt_config
        self.extract_special_tokens(tokenizer)
        self.extend_word_embeddings(num_mem_tokens, tokenizer)

        self.segment_size = rmt_config['input_size'] - num_mem_tokens - tokenizer.num_special_tokens_to_add()
        if 'sep_token' in tokenizer.special_tokens_map:
            self.segment_size -= 1

    def set_memory(self, input_shape):
        """Return the memory-token embeddings repeated ``input_shape[0]`` times along dim 0."""
        memory = self.model.embeddings(self.mem_token_ids)
        memory = memory.repeat(input_shape[0], 1, 1)
        return memory

    def extract_special_tokens(self, tokenizer):
        """Record the tokenizer's special token ids.

        ``self.special_token_ids`` starts with the pad id (which may be ``None``)
        and is extended with every cls/sep/eos/bos id that is defined.  Each
        defined id is also registered as a one-element buffer named after the
        token (e.g. ``self.cls_token``); undefined ones become ``None`` attributes.
        """
        self.pad_token_id = tokenizer.pad_token_id
        self.special_token_ids = [tokenizer.pad_token_id]
        for token in ['cls_token', 'sep_token', 'eos_token', 'bos_token']:
            token_id = getattr(tokenizer, f'{token}_id')
            if token_id is not None:
                self.register_buffer(token, torch.tensor([token_id]))
                self.special_token_ids.append(token_id)
            else:
                setattr(self, token, None)

    def extend_word_embeddings(self, num_mem_tokens, tokenizer):
        """Append ``num_mem_tokens`` new ids to the vocabulary and resize the embeddings.

        The new ids are stored in the ``mem_token_ids`` buffer.
        ``self.memory_position`` is the range of sequence positions reserved for
        memory; it starts at 1 when the tokenizer has a cls or bos token, else at 0.
        """
        vocab_size = self.model.config.vocab_size
        extended_vocab_size = vocab_size + num_mem_tokens
        self.num_mem_tokens = num_mem_tokens
        self.register_buffer('mem_token_ids', torch.arange(vocab_size, vocab_size + num_mem_tokens))
        self.model.resize_token_embeddings(extended_vocab_size)

        special_tokens = tokenizer.special_tokens_map
        mem_start_ind = int('cls_token' in special_tokens or 'bos_token' in special_tokens)
        self.memory_position = range(mem_start_ind, mem_start_ind + num_mem_tokens)
        # Shortcut used by set_memory / prepare_kwargs to embed token ids.
        self.model.embeddings = self.model.get_input_embeddings()

    def forward(self, *args, **kwargs):
        """Abstract: subclasses define how the segments are run through the model."""
        raise NotImplementedError

    def pad_and_segment(self, input_ids):
        """Split every sequence of ``input_ids`` into at most ``max_n_segments`` segments.

        Special tokens are removed first, the sequence is truncated to
        ``segment_size * max_n_segments`` tokens and then cut according to
        ``rmt_config['segment_alignment']`` (``'right'``/``None``, ``'left'`` or
        ``'center'``).  Each piece is passed to :meth:`pad_add_special_tokens`.

        Returns:
            A list of length ``max_n_segments``; element ``k`` is a list with
            one entry per sample holding that sample's k-th segment, or
            ``None`` where the sample has fewer segments (empty slots come first).

        Raises:
            NotImplementedError: for an unknown ``segment_alignment`` value.
        """
        segmented_batch = []
        # The pad id (and others) may be None; compare only against defined ids.
        special_ids = [t for t in self.special_token_ids if t is not None]
        for seq in input_ids:
            if special_ids:
                drop_mask = torch.any(torch.stack([seq == t for t in special_ids]), dim=0)
                seq = seq[~drop_mask]
            seq = seq[:self.segment_size * self.rmt_config['max_n_segments']]

            align = self.rmt_config.get('segment_alignment')
            if align in {'right', None}:
                # Walk backwards so that any short segment ends up first.
                split_inds = (list(range(len(seq), 0, -self.segment_size)) + [0])[::-1]
            elif align == 'left':
                split_inds = list(range(0, len(seq), self.segment_size)) + [len(seq)]
            elif align == 'center':
                if len(seq) == 0:
                    # Same result the other alignments give for an empty
                    # sequence; avoids dividing by zero segments below.
                    split_inds = [0]
                else:
                    n_seg = math.ceil(len(seq) / self.segment_size)
                    split_inds = list(range(0, len(seq), math.ceil(len(seq) / n_seg))) + [len(seq)]
            else:
                raise NotImplementedError

            input_segments = [seq[start:end] for (start, end) in zip(split_inds, split_inds[1:])]
            input_segments = [self.pad_add_special_tokens(t, self.rmt_config['input_size']) for t in input_segments]

            # add empty segment markers if needed
            n_empty_segments = self.rmt_config['max_n_segments'] - len(input_segments)
            input_segments = [None] * n_empty_segments + input_segments

            segmented_batch.append(input_segments)

        # Transpose from [sample][segment] to [segment][sample].
        segmented_batch = [[sample[seg_num] for sample in segmented_batch]
                           for seg_num in range(self.rmt_config['max_n_segments'])]
        return segmented_batch

    def pad_add_special_tokens(self, *args, **kwargs):
        """Abstract: pad one segment and add special tokens.

        ``pad_and_segment`` calls this positionally as ``(segment, input_size)``,
        so the stub accepts positional arguments as well.
        """
        raise NotImplementedError

    def prepare_kwargs(self, segment_input_ids, kwargs):
        """Build the model kwargs for one segment.

        Args:
            segment_input_ids: one entry per sample, a tensor or ``None``.
            kwargs: the forward kwargs shared by all segments (not modified).

        Returns:
            ``(seg_kwargs, non_empty_mask)``.  ``seg_kwargs`` is ``None`` when
            every sample's segment is empty; otherwise it contains
            ``inputs_embeds`` (``input_ids`` set to ``None``), an attention
            mask, labels restricted to non-empty samples, and
            ``output_hidden_states=True``.
        """
        seg_kwargs = dict(**kwargs)
        non_empty_mask = [s is not None for s in segment_input_ids]
        if sum(non_empty_mask) == 0:
            return None, non_empty_mask

        input_ids = torch.stack([s for s in segment_input_ids if s is not None])
        inputs_embeds = self.model.embeddings(input_ids)

        seg_kwargs['input_ids'] = None
        seg_kwargs['inputs_embeds'] = inputs_embeds
        if seg_kwargs.get('labels') is not None:
            seg_kwargs['labels'] = seg_kwargs['labels'][non_empty_mask]
        seg_kwargs['attention_mask'] = self.get_attention_mask(input_ids)
        if seg_kwargs.get('token_type_ids') is not None:
            seg_kwargs['token_type_ids'] = self.get_token_type_ids(input_ids)
        # Hidden states are always requested; process_outputs drops them
        # again when the caller did not ask for them.
        seg_kwargs['output_hidden_states'] = True

        return seg_kwargs, non_empty_mask

    def process_outputs(self, model_outputs, output_attentions, output_hidden_states):
        """Merge per-segment outputs into the output object of the last segment.

        Per-segment ``loss`` (and, if requested, ``attentions`` /
        ``hidden_states``) entries are copied as ``'<key>_<seg_num>'``.  When
        ``rmt_config['sum_loss']`` is truthy, ``loss`` is replaced by the mean
        of the segment losses (note: a mean, despite the option's name).
        """
        rmt_out = model_outputs[-1]

        segment_keys = ['loss']
        if output_attentions:
            segment_keys.append('attentions')
        if output_hidden_states:
            segment_keys.append('hidden_states')

        extracted = {}
        for seg_num, out in enumerate(model_outputs):
            for key, value in out.items():
                if any(sk in key for sk in segment_keys):
                    extracted[f'{key}_{seg_num}'] = value

        if self.rmt_config['sum_loss']:
            losses = [out['loss'] for out in model_outputs]
            extracted['loss'] = torch.stack(losses).mean(dim=0)

        for key, value in extracted.items():
            rmt_out[key] = value

        # drop unnecessary hiddens to save memory
        if not output_hidden_states:
            # Snapshot the keys: entries are reassigned while looping.
            for key in list(rmt_out.keys()):
                if 'hidden_state' in key:
                    rmt_out[key] = None

        return rmt_out

    def get_token_type_ids(self, tensor):
        """Return all-zero token type ids shaped like ``tensor``."""
        return torch.zeros_like(tensor)

    def get_attention_mask(self, tensor):
        """Return a mask shaped like ``tensor``: 0 where it equals the pad id, 1 elsewhere."""
        mask = torch.ones_like(tensor)
        mask[tensor == self.pad_token_id] = 0
        return mask

In [4]:
import math
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

class RMTDecoderLMHeadMultiSeg(RMTBaseModel):
    """Decoder-only RMT wrapper that runs a causal LM over consecutive segments.

    Every segment is fed to the model as ``[memory, segment tokens, memory]``
    embeddings.  The hidden states found at the trailing (write) memory
    positions after a segment become the memory for the next one.  The
    per-segment logits (memory positions removed) are concatenated and a
    single language-modelling loss is computed over the full sequence.
    """

    def set_params(self, num_mem_tokens, tokenizer, **rmt_config):
        """Store the config, create the memory parameter and compute ``segment_size``.

        Memory occupies ``2 * num_mem_tokens`` positions (read + write copies),
        which is subtracted from ``input_size`` together with the tokenizer's
        special tokens (and one more position if it defines a ``sep_token``).
        """
        self.rmt_config = rmt_config
        self.extract_special_tokens(tokenizer)
        self.create_memory(num_mem_tokens)

        self.segment_size = rmt_config['input_size'] - 2 * num_mem_tokens - tokenizer.num_special_tokens_to_add()
        if 'sep_token' in tokenizer.special_tokens_map:
            self.segment_size -= 1

    def create_memory(self, num_mem_tokens):
        """Register the trainable ``memory`` parameter and the read/write positions.

        The memory is drawn from a standard normal and scaled by the standard
        deviation of the input embedding weights.  Read memory sits at the
        first ``num_mem_tokens`` positions, write memory at the last ones.
        """
        self.num_mem_tokens = num_mem_tokens
        embeddings = self.model.get_input_embeddings()
        memory_weights = torch.randn((num_mem_tokens, self.model.config.n_embd)) * embeddings.weight.data.std()
        self.register_parameter('memory', torch.nn.Parameter(memory_weights, requires_grad=True))

        self.read_memory_position = range(num_mem_tokens)
        self.write_memory_position = range(-num_mem_tokens, 0)

    def set_memory(self, input_shape):
        """Return the initial memory for a batch of ``input_shape[0]`` samples.

        A fresh copy of the ``memory`` parameter is used when training, when
        ``reinit_mem_each_fwd`` is set, when only one segment is configured,
        or when no usable state from a previous forward pass exists.
        Otherwise the first rows of the stored ``memory_state`` are reused.
        """
        create_memory = self.training \
                        or self.rmt_config.get('reinit_mem_each_fwd') \
                        or not hasattr(self, 'memory_state') \
                        or self.rmt_config['max_n_segments'] == 1 \
                        or self.memory_state.shape[0] < input_shape[0]
        # The last condition covers a stored state that came from a smaller
        # batch: slicing it would yield too few rows for the current batch.
        if create_memory:
            memory = self.memory.repeat(input_shape[0], 1, 1)
        else:
            memory = self.memory_state[:input_shape[0]]
        return memory

    def detach_memory(self, seg_num):
        """Tell whether memory should be detached before segment ``seg_num``.

        Returns ``False`` for the first segment, when ``k2`` is ``-1``/``None``,
        or when ``seg_num + k2 > max_n_segments``, i.e. gradients are only kept
        through the last ``k2`` segments.
        """
        k2, max_n_segments = self.rmt_config.get('k2'), self.rmt_config.get('max_n_segments')
        if seg_num == 0 \
            or k2 in {-1, None} \
            or seg_num + k2 > max_n_segments:
            return False
        return True

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                inputs_embeds=None, labels=None, labels_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None):
        """Run the model over all segments of ``input_ids``, passing memory between them.

        Returns the object built by :meth:`process_outputs`.  The memory left
        after the last segment is kept in ``self.memory_state``.
        """
        kwargs = {'attention_mask': attention_mask, 'token_type_ids': token_type_ids,
                  'position_ids': position_ids, 'inputs_embeds': inputs_embeds,
                  'labels_mask': labels_mask, #'pos_weight': pos_weight,
                  'labels': labels, 'output_attentions': output_attentions,
                  'output_hidden_states': output_hidden_states, 'return_dict': return_dict,
                  }

        memory = self.set_memory(input_ids.shape)
        segmented = self.pad_and_segment(input_ids, labels)

        base_model_outputs = []
        for seg_num, segment in enumerate(zip(*segmented)):

            seg_kwargs, non_empty_mask = self.prepare_kwargs(segment, memory, kwargs)
            if sum(non_empty_mask) == 0:
                continue

            if self.detach_memory(seg_num):
                memory = memory.detach()

            # Write the (possibly detached) memory into both memory slots.
            seg_kwargs['inputs_embeds'][:, self.read_memory_position] = memory[non_empty_mask]
            seg_kwargs['inputs_embeds'][:, self.write_memory_position] = memory[non_empty_mask]
            out = self.model(**seg_kwargs)
            base_model_outputs.append(out)

            # The last-layer states at the write positions are the new memory.
            memory[non_empty_mask] = out.hidden_states[-1][:, self.write_memory_position]

        self.memory_state = memory
        # return (base_model_outputs, kwargs)
        out = self.process_outputs(base_model_outputs, kwargs)
        return out

    def pad_and_segment(self, input_ids, labels=None):
        """Split inputs (and labels, when given) into at most ``max_n_segments`` pieces.

        Unlike the base class, no tokens are dropped, truncated or padded here.

        Returns:
            ``(segmented_batch, segmented_batch_labels)``, both indexed as
            ``[segment][sample]`` with ``None`` marking missing leading
            segments.  Without labels the inner label lists are empty.

        Raises:
            NotImplementedError: for an unknown ``segment_alignment`` value.
        """
        # NOTE(review): a sequence that yields more than max_n_segments pieces
        # is not truncated here; only the first max_n_segments pieces are
        # kept by the transposition below. Confirm this is intended.
        segmented_batch = []
        segmented_batch_labels = []

        if labels is None:
            labels = [None] * input_ids.shape[0]
        for seq, seq_labels in zip(input_ids, labels):

            align = self.rmt_config.get('segment_alignment')
            if align in {'right', None}:
                split_inds = (list(range(len(seq), 0, -self.segment_size)) + [0])[::-1]
            elif align == 'left':
                split_inds = list(range(0, len(seq), self.segment_size)) + [len(seq)]
            elif align == 'center':
                if len(seq) == 0:
                    # Matches what the other alignments give for empty input
                    # and avoids dividing by zero segments.
                    split_inds = [0]
                else:
                    n_seg = math.ceil(len(seq) / self.segment_size)
                    split_inds = list(range(0, len(seq), math.ceil(len(seq) / n_seg))) + [len(seq)]
            else:
                raise NotImplementedError
            input_segments = [seq[start:end] for (start, end) in zip(split_inds, split_inds[1:])]
            # add empty segment markers if needed
            n_empty_segments = self.rmt_config['max_n_segments'] - len(input_segments)
            input_segments = [None] * n_empty_segments + input_segments
            segmented_batch.append(input_segments)

            if seq_labels is not None:
                labels_segments = [seq_labels[start:end] for (start, end) in zip(split_inds, split_inds[1:])]
                labels_segments = [None] * n_empty_segments + labels_segments
                segmented_batch_labels.append(labels_segments)

        # Transpose from [sample][segment] to [segment][sample].
        segmented_batch = [[sample[seg_num] for sample in segmented_batch]
                           for seg_num in range(self.rmt_config['max_n_segments'])]
        segmented_batch_labels = [[sample[seg_num] for sample in segmented_batch_labels]
                                  for seg_num in range(self.rmt_config['max_n_segments'])]

        return segmented_batch, segmented_batch_labels

    def prepare_kwargs(self, segment, memory, kwargs):
        """Build the model kwargs for one segment.

        Args:
            segment: ``(segment_input_ids, segment_labels)``, lists with one
                entry per sample (``None`` for an empty segment).
            memory: current memory, one row per sample of the batch.
            kwargs: forward kwargs shared by all segments (not modified).

        Returns:
            ``(seg_kwargs, non_empty_mask)``; ``seg_kwargs`` is ``None`` when
            all samples are empty for this segment.
        """
        segment_input_ids, segment_labels = segment
        seg_kwargs = dict(**kwargs)
        non_empty_mask = [s is not None for s in segment_input_ids]
        if sum(non_empty_mask) == 0:
            return None, non_empty_mask

        input_ids = torch.stack([s for s in segment_input_ids if s is not None])
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        # Only rows of non-empty samples are embedded, so memory has to be
        # restricted to the same rows before concatenating.
        segment_memory = memory[non_empty_mask]
        inputs_embeds = torch.cat([segment_memory, inputs_embeds, segment_memory], dim=1)

        seg_kwargs['input_ids'] = None
        seg_kwargs['inputs_embeds'] = inputs_embeds
        seg_kwargs['attention_mask'] = self.get_attention_mask(inputs_embeds)
        seg_kwargs['output_hidden_states'] = True
        if seg_kwargs['labels'] is not None:
            labels = torch.stack([el for el, m in zip(segment_labels, non_empty_mask) if m])
            # -100 labels at the memory positions.
            memory_labels = torch.ones((labels.shape[0], self.num_mem_tokens), dtype=labels.dtype, device=labels.device) * -100
            seg_kwargs['labels'] = torch.cat((memory_labels, labels, memory_labels), dim=1)
            # The first real token of the segment also gets -100; presumably so
            # it is not predicted from the memory positions. TODO confirm.
            seg_kwargs['labels'][:, self.num_mem_tokens] = -100
        # labels_mask is only used by process_outputs, not by the base model.
        seg_kwargs.pop('labels_mask')

        return seg_kwargs, non_empty_mask

    def get_attention_mask(self, tensor):
        """Return an int64 mask of ones over the first two dims of ``tensor``.

        Positions equal to the pad id are zeroed only when ``tensor`` is a 2-D
        tensor of ids.  For embeddings (as passed by ``prepare_kwargs``) a
        comparison with the pad id is meaningless and would produce a mask
        that cannot index the 2-D result, so all positions stay attended.
        """
        mask = torch.ones(*tensor.shape[:2], dtype=torch.int64).to(tensor.device)
        if tensor.dim() == 2 and self.pad_token_id is not None:
            mask[tensor == self.pad_token_id] = 0
        return mask

    def process_outputs(self, model_outputs, kwargs):
        """Combine the per-segment outputs into one ``CausalLMOutputWithCrossAttentions``.

        Memory positions are cut from logits and hidden states before the
        segments are concatenated along the sequence axis.  When labels were
        given, a shifted cross-entropy loss is computed over the full
        sequence, optionally restricted by ``labels_mask``.  Per-segment
        entries are added as ``'<key>_<seg_num>'``.
        """
        if self.num_mem_tokens in {0, None}:
            full_logits = torch.cat([o.logits for o in model_outputs], dim=1)
            truncated_hs = [[lh for lh in o.hidden_states] for o in model_outputs]
        else:
            full_logits = torch.cat([o.logits[:, self.num_mem_tokens:-self.num_mem_tokens] for o in model_outputs], dim=1)
            truncated_hs = [[lh[:, self.num_mem_tokens:-self.num_mem_tokens] for lh in o.hidden_states] for o in model_outputs]
        # Concatenate segment-wise hidden states layer by layer.
        full_hidden_states = tuple([torch.cat(layer_hs, dim=1) for layer_hs in zip(*truncated_hs)])

        rmt_out = CausalLMOutputWithCrossAttentions()
        full_labels = kwargs.get('labels')
        if full_labels is not None:
            # Position t predicts token t + 1.
            shift_labels = full_labels[..., 1:].contiguous()
            shift_logits = full_logits[..., :-1, :].contiguous()
            flat_labels = shift_labels.view(-1)
            flat_logits = shift_logits.view(-1, shift_logits.size(-1))

            loss_fct = CrossEntropyLoss()
            labels_mask = kwargs.get('labels_mask')
            if labels_mask is not None:
                shift_mask = labels_mask[..., :-1].contiguous()
                flat_labels = flat_labels[shift_mask.view(-1)]
                flat_logits = flat_logits[shift_mask.view(-1)]

            rmt_out['loss'] = loss_fct(flat_logits, flat_labels)

        rmt_out['logits'] = full_logits
        segment_keys = ['loss', 'logits']
        if kwargs.get('output_attentions'):
            segment_keys.append('attentions')
        if kwargs.get('output_hidden_states'):
            segment_keys.append('hidden_states')
            rmt_out['hidden_states'] = full_hidden_states

        for seg_num, out in enumerate(model_outputs):
            for key, value in out.items():
                if any(sk in key for sk in segment_keys):
                    rmt_out[f'{key}_{seg_num}'] = value

        return rmt_out

In [5]:
import torch

from typing import List, Optional, Tuple, Union
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

def xl_memory_forward(self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        rmt_parent = None
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
    """Transformer forward pass with a per-layer cache of previous hidden states.

    Meant to be bound as the ``forward`` of the base transformer: it reads
    ``self.wte``, ``self.wpe``, ``self.h``, ``self.drop`` and ``self.ln_f``.
    For every layer ``i`` the states stored in
    ``rmt_parent.memory_storage['xl_cache'][i]`` are prepended along the
    sequence axis before the block runs and removed from its output
    afterwards.  The layer output (without the trailing
    ``rmt_parent.num_mem_tokens`` positions, limited to the last
    ``rmt_parent.rmt_config['xl_cache_size']`` positions, detached) then
    replaces the cache entry.

    Args:
        rmt_parent: wrapper object providing ``memory_storage``,
            ``rmt_config`` and ``num_mem_tokens``.  Required.
        Remaining arguments are the usual transformer forward arguments.

    Returns:
        A ``BaseModelOutputWithPastAndCrossAttentions`` or, when
        ``return_dict`` is false, a tuple of its non-``None`` fields.

    Raises:
        ValueError: if ``rmt_parent`` is missing, if both or neither of
            ``input_ids`` / ``inputs_embeds`` are given, or if the batch size
            is not positive while an attention mask is supplied.
    """
    if rmt_parent is None:
        # Fail early with a clear message instead of an AttributeError below.
        raise ValueError("xl_memory_forward requires `rmt_parent` to access the xl cache")

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        batch_size = input_ids.shape[0]
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
        batch_size = inputs_embeds.shape[0]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    device = input_ids.device if input_ids is not None else inputs_embeds.device

    if token_type_ids is not None:
        token_type_ids = token_type_ids.view(-1, input_shape[-1])
    if position_ids is not None:
        position_ids = position_ids.view(-1, input_shape[-1])

    if past_key_values is None:
        past_length = 0
        past_key_values = tuple([None] * len(self.h))
    else:
        past_length = past_key_values[0][0].size(-2)
    if position_ids is None:
        position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

    # GPT2Attention mask.
    if attention_mask is not None:
        if batch_size <= 0:
            raise ValueError("batch_size has to be defined and > 0")
        attention_mask = attention_mask.view(batch_size, -1)
        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        attention_mask = attention_mask[:, None, None, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        attention_mask = (1.0 - attention_mask) * -10000.0

    # If a 2D or 3D attention mask is provided for the cross-attention
    # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
    if self.config.add_cross_attention and encoder_hidden_states is not None:
        encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
        encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
        encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
    else:
        encoder_attention_mask = None

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # head_mask has shape n_layer x batch x n_heads x N x N
    head_mask = self.get_head_mask(head_mask, self.config.n_layer)

    if inputs_embeds is None:
        inputs_embeds = self.wte(input_ids)
    position_embeds = self.wpe(position_ids)
    hidden_states = inputs_embeds + position_embeds

    if token_type_ids is not None:
        token_type_embeds = self.wte(token_type_ids)
        hidden_states = hidden_states + token_type_embeds

    hidden_states = self.drop(hidden_states)

    output_shape = input_shape + (hidden_states.size(-1),)

    presents = () if use_cache else None
    all_self_attentions = () if output_attentions else None
    all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
    all_hidden_states = () if output_hidden_states else None

    # init xl cache
    xl_cache = rmt_parent.memory_storage['xl_cache']
    # NOTE(review): a cache size of 0 makes the `[-0:]` slice below keep the
    # whole segment rather than nothing. Confirm whether 0 should disable caching.
    xl_cache_size = rmt_parent.rmt_config['xl_cache_size']

    if 0 in xl_cache and attention_mask is not None:
        # Extend the mask on the left so it also covers the cached positions;
        # the first cached-length columns of the mask are reused for them.
        layer_attention_mask = torch.cat((attention_mask[:, :, :, :xl_cache[0].shape[1]], attention_mask), dim=-1)
    else:
        layer_attention_mask = attention_mask
    for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

        # Model parallel
        if self.model_parallel:
            torch.cuda.set_device(hidden_states.device)
            # Ensure layer_past is on same device as hidden_states (might not be correct)
            if layer_past is not None:
                layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
            # Ensure that attention_mask is always on the same device as hidden_states
            if attention_mask is not None:
                attention_mask = attention_mask.to(hidden_states.device)
            # The blocks receive layer_attention_mask, so it has to follow too.
            if layer_attention_mask is not None:
                layer_attention_mask = layer_attention_mask.to(hidden_states.device)
            if isinstance(head_mask, torch.Tensor):
                head_mask = head_mask.to(hidden_states.device)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Prepend this layer's cached states.  Done before choosing between the
        # checkpointed and the regular path so both see the same input and the
        # cached prefix stripped below really exists in the block output.
        cache_len = 0
        if i in xl_cache:
            layer_cache = xl_cache[i]
            non_empty_mask = rmt_parent.memory_storage['non_empty_mask']
            if non_empty_mask is not None:
                layer_cache = layer_cache[non_empty_mask]
            cache_len = layer_cache.shape[1]
            hidden_states = torch.cat((layer_cache, hidden_states), dim=1)

        if self.gradient_checkpointing and self.training:

            # if use_cache:
            #     logger.warning(
            #         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            #     )
            #     use_cache = False

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, use_cache, output_attentions)

                return custom_forward

            outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                None,
                layer_attention_mask,
                head_mask[i],
                encoder_hidden_states,
                encoder_attention_mask,
            )
        else:
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=layer_attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

        hidden_states = outputs[0]

        # Remove the positions that belong to the cached prefix.
        if cache_len:
            hidden_states = hidden_states[:, cache_len:]

        # update layer cache
        xl_cache[i] = hidden_states
        if rmt_parent.num_mem_tokens not in {0, None}:
            # Leave the trailing memory positions out of the cache.
            xl_cache[i] = xl_cache[i][:, :-rmt_parent.num_mem_tokens]
        xl_cache[i] = xl_cache[i][:, -xl_cache_size:].detach()

        if use_cache is True:
            presents = presents + (outputs[1],)

        if output_attentions:
            all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
            if self.config.add_cross_attention:
                all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

        # Model Parallel: If it's the last layer for that device, put things on the next device
        if self.model_parallel:
            for k, v in self.device_map.items():
                if i == v[-1] and "cuda:" + str(k) != self.last_device:
                    hidden_states = hidden_states.to("cuda:" + str(k + 1))

    hidden_states = self.ln_f(hidden_states)

    hidden_states = hidden_states.view(output_shape)
    # Add last hidden state
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
            if v is not None
        )

    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=presents,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
        cross_attentions=all_cross_attentions,
    )

In [6]:
import types
import copy
import re
class RMTDecoderXLCache(RMTDecoderLMHeadMultiSeg):
    """Multi-segment decoder whose base transformer keeps a per-layer cache.

    The base model's ``forward`` is replaced by a function that receives this
    wrapper as ``rmt_parent`` and exchanges state with it through
    ``self.memory_storage`` (keys ``'xl_cache'`` and ``'non_empty_mask'``).
    """

    def set_params(self, num_mem_tokens, tokenizer, **rmt_config):
        """Apply the parent setup, patch the base forward and shrink ``segment_size``.

        A truthy ``xl_cache_size`` is subtracted from ``segment_size``.
        """
        super().set_params(num_mem_tokens, tokenizer, **rmt_config)
        self.override_encoder_forward(rmt_config.get('xl_forward_func'))
        cache_size = rmt_config.get('xl_cache_size')
        if cache_size:
            self.segment_size -= cache_size

    def override_encoder_forward(self, xl_forward_func):
        """Bind ``xl_forward_func`` (with ``rmt_parent=self``) as the base model's forward.

        When no function is given, ``xl_forward`` from
        ``rmt_utils.decoder.transformer_xl`` is imported and used.
        """
        if xl_forward_func is None:
            from rmt_utils.decoder.transformer_xl import xl_forward as xl_forward_func

        def forward_with_parent(*args, **kwargs):
            # Every call gets a reference to this wrapper as `rmt_parent`.
            return xl_forward_func(*args, **kwargs, rmt_parent=self)

        backbone = self.model.base_model
        backbone.forward = types.MethodType(forward_with_parent, backbone)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                inputs_embeds=None, labels=None, labels_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None):
        """Process all segments, starting from an empty cache on every call.

        Returns the object built by ``process_outputs``; the final memory is
        kept in ``self.memory_state``.
        """
        kwargs = dict(
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            labels_mask=labels_mask,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        memory = self.set_memory(input_ids.shape)
        segmented = self.pad_and_segment(input_ids, labels)

        # Fresh storage for this call: empty cache, no mask recorded yet.
        self.memory_storage = {'xl_cache': dict(), 'non_empty_mask': None}
        segment_outputs = []
        for seg_num, segment in enumerate(zip(*segmented)):
            seg_kwargs, non_empty_mask = self.prepare_kwargs(segment, memory, kwargs)
            if not any(non_empty_mask):
                continue

            if self.detach_memory(seg_num):
                memory = memory.detach()

            # Fill the read slots first, then the write slots, with current memory.
            for positions in (self.read_memory_position, self.write_memory_position):
                seg_kwargs['inputs_embeds'][:, positions] = memory[non_empty_mask]

            seg_out = self.model(**seg_kwargs)
            segment_outputs.append(seg_out)

            # The mask is recorded only after the model call for this segment.
            self.memory_storage['non_empty_mask'] = non_empty_mask
            memory[non_empty_mask] = seg_out.hidden_states[-1][:, self.write_memory_position]

        self.memory_state = memory

        return self.process_outputs(segment_outputs, kwargs)


# -----------------------------------------------------------------------------------

In [7]:
# num_segments = 1
# num_mem_tokens = 2
# device = torch.device(3)
device = 'cpu'

In [36]:
from transformers import AutoModel
# Build the tokenizer, the RMT configuration and the wrapped model.
model_name = 'gpt2'

tokenizer = AutoTokenizer.from_pretrained(model_name)

# `num_mem_tokens` and `tokenizer` are consumed by `set_params`; every other
# key ends up in `rmt.rmt_config`.
rmt_config = {'num_mem_tokens': 0, 
                'max_n_segments': 1,
               #  'segment_alignment': 'right',
                'tokenizer': tokenizer,
               #  'memory_layers': 'all', 
               #  'share_memory_layers': True,
               #  'reconstruction_loss_coef': 0.1,
               #  'offset_position_ids': True,
                'xl_cache_size': 64,
                'xl_forward_func': xl_memory_forward,
                'k1': -1, 'k2': 3,
                'segment_ordering': 'regular',
                'input_size': 128, 
                'bptt_depth': -1, 
                'sum_loss': False,
             }

# The base model has to be created here: without this line `base_model` is
# undefined on a fresh kernel and the wrapper construction below fails.
base_model = AutoModelForCausalLM.from_pretrained(model_name)
# rmt = RMTDecoderLMHeadMultiSeg(base_model, **rmt_config)
rmt = RMTDecoderXLCache(base_model, **rmt_config)




# イ

In [None]:
# Print the name of every parameter of `rmt` whose name contains 'memory'.
for param_name in filter(lambda name: 'memory' in name,
                         (name for name, _ in rmt.named_parameters())):
    print(param_name)

memory


In [35]:
rmt.eval()
# rmt.train()
1

1

In [27]:
# Pull the tensors of one batch and build the kwargs for a forward pass.
# NOTE(review): `batch` is not defined in any cell above this one; it has to
# come from a data-loading cell executed earlier in the kernel session.
# Confirm the notebook still runs top to bottom on a fresh kernel.
input_ids = batch['input_ids'].clone()#[:, :512]
labels = batch['labels'].clone()#[:, :512]

labels_mask = batch['labels_mask']
# rmt_out_0 = rmt(input_ids, labels=labels, labels_mask=labels_mask, output_hidden_states=True, output_attentions=True)
# rmt_out.loss
# Shallow copy of the batch mapping, later unpacked as `rmt(**kwargs, ...)`.
kwargs = dict(**batch)

In [28]:
segmented = rmt.pad_and_segment(input_ids, labels)

In [37]:
len(segmented[0]), segmented[0][0][0].shape, segmented[0][1][0].shape, segmented[0][2][0].shape

AttributeError: 'NoneType' object has no attribute 'shape'

In [38]:
# rmt.memory_storage['xl_cache'][0].shape

In [39]:
# rmt.memory_storage['xl_cache'][11].shape

In [30]:
# rmt_out = rmt(**kwargs, output_hidden_states=True)
# NOTE(review): this unpacking expects `rmt(...)` to return a 2-tuple of
# (per-segment outputs, kwargs). `RMTDecoderXLCache.forward` as defined above
# returns the processed output object instead, so the commented call above
# matches the current class definition. Confirm which variant the following
# inspection cells are meant to use.
base_model_outputs, kw = rmt(**kwargs, output_hidden_states=True)

In [31]:
len(base_model_outputs)

4

In [32]:
kwargs['labels'].shape

torch.Size([2, 256])

In [33]:
sum([l.logits.shape[1] for l in base_model_outputs])

256

In [34]:
sum([l.logits.shape[1] for l in base_model_outputs])

256

In [160]:
sum([l.hidden_states[0].shape[1] for l in base_model_outputs])

312

In [161]:
[l.hidden_states[0].shape[1] for l in base_model_outputs]

[56, 64, 64, 64, 64]

In [157]:
base_model_outputs[0].hidden_states[0].shape

torch.Size([2, 56, 768])

In [153]:
sum([l.logits.shape[1] for l in base_model_outputs])

312

In [None]:
kwarg

In [118]:
base_model_outputs[0].logits.shape

torch.Size([2, 28, 50257])

In [119]:
base_model_outputs[1].logits.shape

torch.Size([2, 68, 50257])

In [120]:
base_model_outputs[-1].logits.shape

torch.Size([2, 68, 50257])

In [111]:
rmt_out.keys()

odict_keys(['loss', 'logits', 'hidden_states', 'loss_0', 'logits_0', 'hidden_states_0', 'loss_1', 'logits_1', 'hidden_states_1', 'loss_2', 'logits_2', 'hidden_states_2', 'loss_3', 'logits_3', 'hidden_states_3', 'loss_4', 'logits_4', 'hidden_states_4'])

In [114]:
rmt.memory_storage['xl_cache'][0].shape

torch.Size([2, 10, 768])

In [43]:
len(rmt_out.hidden_states), rmt_out.hidden_states[0].shape

(13, torch.Size([2, 504, 768]))

In [49]:
len(rmt_out.hidden_states_0), rmt_out.hidden_states_0[0].shape

(13, torch.Size([2, 12, 768]))

In [51]:
xl_cache_size = 8

In [31]:
[(key, rmt_out[key]) for key in rmt_out if 'loss' in key]

[('loss', tensor(14.0483, grad_fn=<NllLossBackward>)),
 ('loss_0', tensor(5.7048, grad_fn=<NllLossBackward>)),
 ('loss_1', tensor(11.0391, grad_fn=<NllLossBackward>)),
 ('loss_2', tensor(11.2370, grad_fn=<NllLossBackward>)),
 ('loss_3', tensor(13.3765, grad_fn=<NllLossBackward>)),
 ('loss_4', tensor(14.0265, grad_fn=<NllLossBackward>))]

In [None]:
[(key, rmt_out[key]) for key in rmt_out if 'loss' in key]

[('loss', tensor(11.3409, grad_fn=<NllLossBackward>)),
 ('loss_0', tensor(2.9731, grad_fn=<NllLossBackward>)),
 ('loss_1', tensor(2.6483, grad_fn=<NllLossBackward>)),
 ('loss_2', tensor(2.8106, grad_fn=<NllLossBackward>)),
 ('loss_3', tensor(2.6017, grad_fn=<NllLossBackward>))]

In [186]:
# Restore model weights from a training checkpoint.
# NOTE(review): the path is hard-coded relative to the notebook location;
# consider moving it to a configuration cell.
cpt_path = "../../runs/lm_long/wikitext-2-v1/gpt2/lr5e-05_linear_adamw_wd1e-03_630-128-5x128_mem1_bs32_regular_bptt-5_from_cpt_4-5/run_1/model_best.pth"
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source.
cpt = torch.load(cpt_path, map_location='cpu')
# The state dict is stored under the 'model_state_dict' key of the checkpoint.
rmt.load_state_dict(cpt['model_state_dict'])

<All keys matched successfully>

In [115]:
rmt_out.logits.shape, len(rmt_out.hidden_states), rmt_out.hidden_states[0].shape

(torch.Size([1, 51, 50257]), 13, torch.Size([1, 51, 768]))

### Load dataset

In [10]:
class Holder:
    """Empty attribute container; settings are attached to an instance of it below."""


# Segment layout used by the cells that follow.
max_n_segments = 4
input_size = 128
num_mem_tokens = 32

# Token budget per sample: `max_n_segments` chunks of
# `input_size - 2 * num_mem_tokens` tokens each.
input_seq_len = target_seq_len = max_n_segments * (input_size - 2 * num_mem_tokens)
batch_size = 2

# Instantiate the holder: assigning the class object itself would turn every
# setting into a class attribute shared by all Holder instances.
args = Holder()
args.max_n_segments = max_n_segments
args.target_seq_len = target_seq_len
args.input_seq_len = input_seq_len
args.input_prefix = ''
args.block_size = None

device = 'cpu'

### Decoder

In [11]:
from itertools import chain
def get_lm_datasets(raw_datasets, tokenizer, block_size):
    """Tokenize a raw text dataset and regroup it into chunks of ``block_size`` tokens.

    Adapted from the Hugging Face example script
    https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py

    Args:
        raw_datasets: object indexable by split name ("train", "validation") that
            exposes ``map``; its "train" split must expose ``column_names``.
        tokenizer: callable applied to a batch of texts from the text column.
        block_size: number of tokens per chunk.

    Returns:
        Tuple ``(train_dataset, valid_dataset)`` read from the "train" and
        "validation" splits of the grouped dataset.
    """
    train_columns = raw_datasets["train"].column_names
    # Prefer a column literally called "text"; otherwise fall back to the first column.
    if "text" in train_columns:
        text_column_name = "text"
    else:
        text_column_name = train_columns[0]

    def _tokenize(batch):
        return tokenizer(batch[text_column_name])

    def _chunk(batch):
        # Flatten every field of the batch into a single token stream.
        flat = {name: list(chain(*batch[name])) for name in batch.keys()}
        first_key = list(batch.keys())[0]
        usable = len(flat[first_key])
        # Drop the trailing remainder, but only when at least one full chunk exists.
        if usable >= block_size:
            usable = (usable // block_size) * block_size
        chunked = {}
        for name, tokens in flat.items():
            chunked[name] = [tokens[start : start + block_size] for start in range(0, usable, block_size)]
        chunked["labels"] = chunked["input_ids"].copy()
        return chunked

    # With `batched=True` the remainder is dropped per mapped batch of texts,
    # not once for the whole dataset.
    tokenized = raw_datasets.map(
        _tokenize,
        batched=True,
        remove_columns=train_columns,
        desc="Running tokenizer on dataset",
    )
    grouped = tokenized.map(
        _chunk,
        batched=True,
        desc=f"Grouping texts in chunks of {block_size}",
    )

    return grouped["train"], grouped["validation"]

In [12]:
# Load the 'wikitext-2-v1' configuration of the 'wikitext' dataset via the `datasets` package.
# from lm_experiments_tools.lm_datasets import load_dataset
raw_datasets = datasets.load_dataset('wikitext', 'wikitext-2-v1')
# Commented-out alternative that builds the splits with get_lm_datasets:
# train_dataset, _ = get_lm_datasets(raw_datasets, tokenizer, block_size=args.input_seq_len)
# _, valid_dataset = get_lm_datasets(raw_datasets, tokenizer, block_size=input_size - 2 * num_mem_tokens)

Found cached dataset wikitext (/home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


  0%|          | 0/3 [00:00<?, ?it/s]

In [13]:
# Tokens per chunk, and how many preceding tokens a training sample may carry as history.
block_size = input_size - 2 * num_mem_tokens
history_size = args.input_seq_len - block_size

column_names = raw_datasets["train"].column_names
# Prefer a column literally called "text"; otherwise use the first column.
if "text" in column_names:
    text_column_name = "text"
else:
    text_column_name = column_names[0]


def tokenize_function(examples):
    """Apply the notebook-level ``tokenizer`` to the text column of a batch of examples."""
    texts = examples[text_column_name]
    return tokenizer(texts)


tokenized_datasets = raw_datasets.map(
    tokenize_function,
    remove_columns=column_names,
    batched=True,
    desc="Running tokenizer on dataset",
)

def group_texts(examples, block_size, history_size=None):
    """Concatenate a batch of tokenized texts and cut it into chunks.

    Chunk boundaries are placed every ``block_size`` tokens. The last chunk keeps
    whatever remainder is left (it is not dropped or padded).

    Args:
        examples: dict mapping field name to a list of token lists; must contain
            an "input_ids" field, since "labels" is derived from it.
        block_size: distance in tokens between consecutive chunk starts.
        history_size: when given, each chunk is extended to the left by up to this
            many preceding tokens (clipped at the start of the concatenated
            stream). ``None`` means no history.

    Returns:
        Dict with the same fields as ``examples``, each a list of chunks, plus a
        "labels" field that is a shallow copy of the "input_ids" chunk list.
    """
    concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])

    # A missing history is the same as a history of zero tokens, so a single
    # slicing expression covers both cases.
    history = 0 if history_size is None else history_size
    result = {
        k: [t[max(0, i - history) : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


# Training chunks are built with `history_size` tokens of left context;
# validation and test chunks are plain `block_size` chunks without history.
train_dataset = tokenized_datasets["train"].map(
    lambda examples: group_texts(examples, block_size, history_size),
    batched=True,
    desc=f"Grouping train in chunks of {block_size} and history {history_size}",
)
valid_dataset = tokenized_datasets["validation"].map(
    lambda examples: group_texts(examples, block_size),
    batched=True,
    desc=f"Grouping valid in chunks of {block_size}",
)
test_dataset = tokenized_datasets["test"].map(
    lambda examples: group_texts(examples, block_size),
    batched=True,
    desc=f"Grouping test in chunks of {block_size}",
)

Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-389b922bfc5fe729.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-6067a66e735cfbb1.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-941845a5470f2db7.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-867d7b23bfc3f054.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-fdadcdfe8a04e3dc.a

In [14]:
len(train_dataset)

37869

In [15]:
len(train_dataset[0]['input_ids']), len(valid_dataset[0]['input_ids']), len(test_dataset[0]['input_ids'])

(64, 64, 64)

In [16]:
len(train_dataset[1]['input_ids']), len(valid_dataset[1]['input_ids']), len(test_dataset[1]['input_ids'])

(128, 64, 64)

In [17]:
len(train_dataset[2]['input_ids']), len(valid_dataset[2]['input_ids']), len(test_dataset[2]['input_ids'])

(192, 64, 64)

In [18]:
len(train_dataset[-1]['input_ids']), len(valid_dataset[-1]['input_ids']), len(test_dataset[-1]['input_ids'])

(219, 51, 58)

In [19]:
from torch.nn.utils.rnn import pad_sequence
# Padding id for input_ids: the tokenizer's pad token, or its EOS token when no pad token is set.
id_pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id


def collate_fn(batch):
    """Left-pad a list of samples into batch tensors.

    Each sample must provide 'input_ids', 'labels' and 'attention_mask' lists.
    Padding values are ``id_pad_value`` for input ids, -100 for labels and 0 for
    the attention mask. When the padded length differs from the notebook-level
    ``block_size``, a boolean 'labels_mask' is added that is True only on the
    last ``block_size`` positions.
    """
    def _left_pad(field, fill):
        # Reverse each sequence, right-pad, then flip back: the padding ends up on the left.
        reversed_rows = [torch.tensor(sample[field][::-1]) for sample in batch]
        return pad_sequence(reversed_rows, padding_value=fill).T.flip(1)

    collated = {
        'input_ids': _left_pad('input_ids', id_pad_value),
        'labels': _left_pad('labels', -100),
        'attention_mask': _left_pad('attention_mask', 0),
    }

    padded_ids = collated['input_ids']
    if padded_ids.shape[1] != block_size:
        labels_mask = torch.ones_like(padded_ids, dtype=bool)
        labels_mask[:, :-block_size] = False
        collated['labels_mask'] = labels_mask

    return collated

In [20]:
# dataloader for RMT
# batch sample i is a continuation of sample i of the previous batch
class alignedDataLoader(DataLoader):
    """DataLoader whose batch row ``i`` continues row ``i`` of the previous batch.

    The first ``(len(dataset) // batch_size) * batch_size`` dataset indices are
    laid out as a ``(batch_size, n_steps)`` grid, so each row is a contiguous
    run of the dataset. Step ``t`` yields column ``t`` of that grid; any
    trailing samples that do not fill the grid are skipped.

    Iteration indexes ``self.dataset`` directly and calls ``self.collate_fn``;
    the sampler and worker settings given to the constructor are not consulted
    here.
    """

    def __iter__(self):
        # Use the loader's own batch size throughout, so the grid always matches
        # the value passed to the constructor rather than a notebook global.
        n_rows = self.batch_size
        n_steps = len(self.dataset) // n_rows
        all_inds = np.arange(n_steps * n_rows).reshape(n_rows, n_steps)
        for batch_ind in range(n_steps):
            batch = [self.dataset[int(ind)] for ind in all_inds[:, batch_ind]]
            yield self.collate_fn(batch)


# Loader options shared by both dataloaders below.
kwargs = dict(pin_memory=True, num_workers=1)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=collate_fn,
    **kwargs,
)

valid_dataloader = alignedDataLoader(
    valid_dataset,
    batch_size=batch_size,
    collate_fn=collate_fn,
    drop_last=True,
    **kwargs,
)

In [21]:
src = [train_dataset[i] for i in range(2)]

In [22]:
gen = iter(train_dataloader)
batch = next(gen)

In [23]:
batch = next(gen)
batch = next(gen)
batch = next(gen)


In [24]:
valid_gen = iter(valid_dataloader)
valid_batch = next(valid_gen)

In [25]:
valid_batch['input_ids'].shape

torch.Size([2, 64])

In [26]:
batch['labels']

tensor([[  319,   569, 18354,  7496, 17740,  2873,   764,  2893,   340, 17383,
           262,  3210,  3033,   286,   262,  2168,   837,   340,   635, 25289,
          3294, 16895,   837,   884,   355,  1642,   262,   983,   517,  1279,
          2954,    29,   329,  2168, 29661,   764, 15684, 11915,  1279,  2954,
            29,  8835,    73,   280,   290, 26777,  7286, 13704, 13231, 43354,
          1111,  4504,   422,  2180, 12784,   837,  1863,   351,   569, 18354,
          7496, 17740,  2873,  3437, 33687,  5303, 18024,  6909,   764,   317,
          1588,  1074,   286,  8786, 12118,   262,  4226,   764,   383,   983,
           705,    82,  4756,  7505,   373, 23568,   416,  1737,   705,    77,
           764,   220,   198,   632,  1138,   351,  3967,  4200,   287,  2869,
           837,   290,   373, 15342,   416,  1111,  4960,   290,  8830,  9188,
           764,  2293,  2650,   837,   340,  2722, 41496,  2695,   837,  1863,
           351,   281,  9902,  8313,   287,  3389,  

In [27]:
batch['labels'][:, 1004:]

tensor([[  796,   569, 18354,  ...,   416,  5609,   511],
        [ 8686,  4282,   764,  ...,  2478,  2233,   284]])

In [25]:
batch['labels_mask'][:, 1004:]

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])

In [119]:
tokenizer.decode(batch['input_ids'][:, 1004:][0])

' = Valkyria Chronicles III = \n Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3, lit. Valkyria of the Battlefield 3 ), commonly referred to as Valkyria Chronicles III outside Japan, is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable. Released in January 2011 in Japan, it is the third game in the Valkyria series. <unk> the same fusion of tactical and real @-@ time gameplay as its predecessors, the story runs parallel to the first game and follows the " Nameless ", a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven ". \n The game began development in 2010, carrying over a large portion of the work done on Valkyria Chronicles II. While it retained the standard features of the series, it also underwent multiple adjustments, such as making the game more <unk> for series newcomers. Character designer <

In [53]:
batch = next(gen)


In [54]:
batch['input_ids']

tensor([[8686, 4282,  764,  ...,  262, 2656,  837],
        [ 262, 3859, 1445,  ..., 3670,  373, 7867]])

In [162]:
gen = iter(valid_dataloader)
batch = next(gen)

In [163]:
tokenizer.decode(batch['input_ids'][0][-50:])

'-@ eastern Atlantic Ocean from northern Norway to the Azores and Morocco, not including the Baltic Sea. It is also present in most of the Mediterranean Sea, only missing from the section east of Crete, and along only the north @-@'

In [164]:
tokenizer.decode(batch['input_ids'][1][-50:])

' the <unk> and <unk> powers of nearly every one, <unk> their opponents... \n The side left <unk> and travelled north where they played <unk> in Masterton. The match was won 10 – 8, and'

In [165]:
batch = next(gen)

In [166]:
tokenizer.decode(batch['input_ids'][0][:50])

' west coast of the Black Sea. The <unk> populations are found in the Norwegian <unk> <unk> and <unk>, inside the Arctic Circle. \n The species can be divided into four genetically distinct populations, one widespread population,'

In [168]:
tokenizer.decode(batch['input_ids'][1][:50])

' the next day they faced Wellington, who they also defeated. The fixture against Wellington was nearly abandoned because Scott and the Wellington Rugby Union could not agree on a venue ; the match went ahead only when the Wellington officials agreed to cede the <unk>'

In [21]:
batch.keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

In [127]:
batch['input_ids'][0][:10]

tensor([  796,   569, 18354,  7496, 17740,  6711,   796,   220,   198,  2311])

In [128]:
batch['labels'][0][:10]

tensor([  796,   569, 18354,  7496, 17740,  6711,   796,   220,   198,  2311])

In [130]:
train_dataset[0].keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

In [131]:
train_dataset[0]['input_ids'][:10]

[796, 569, 18354, 7496, 17740, 6711, 796, 220, 198, 2311]

In [132]:
train_dataset[0]['labels'][:10]

[796, 569, 18354, 7496, 17740, 6711, 796, 220, 198, 2311]

In [18]:
batch['input_ids'].shape

torch.Size([2, 2008])

In [19]:
tokenizer.decode(batch['input_ids'][0])

' = Valkyria Chronicles III = \n Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3, lit. Valkyria of the Battlefield 3 ), commonly referred to as Valkyria Chronicles III outside Japan, is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable. Released in January 2011 in Japan, it is the third game in the Valkyria series. <unk> the same fusion of tactical and real @-@ time gameplay as its predecessors, the story runs parallel to the first game and follows the " Nameless ", a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven ". \n The game began development in 2010, carrying over a large portion of the work done on Valkyria Chronicles II. While it retained the standard features of the series, it also underwent multiple adjustments, such as making the game more <unk> for series newcomers. Character designer <

In [77]:
for batch in valid_dataloader:
    if batch['input_ids'].shape[0] != 2:
        print(batch['input_ids'].shape)

In [13]:
batch.keys(), batch['input_ids'].shape

(dict_keys(['input_ids', 'attention_mask', 'labels']), torch.Size([2, 1014]))

In [21]:
out = base_model(**batch)

KeyboardInterrupt: 