In [1]:
import numpy as np
import os
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from transformers import PreTrainedModel, AutoModelForSequenceClassification, T5ForConditionalGeneration
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
import datasets

import math
from matplotlib import pyplot as plt


from typing import List, Optional, Tuple, Union
from transformers import BertForSequenceClassification
import transformers
from transformers.modeling_outputs import SequenceClassifierOutput

### Finetune

In [2]:
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import sys
sys.path.append('..')

In [3]:
import math
import torch
import torch.nn.functional as F

class RMTBaseModel(torch.nn.Module):
    """Base wrapper turning a Hugging Face model into a Recurrent Memory Transformer (RMT).

    Long inputs are split into segments of ``self.segment_size`` tokens. ``num_mem_tokens``
    new ids are appended after the base vocabulary and their embeddings act as memory.
    Subclasses implement ``forward`` and ``pad_add_special_tokens``.

    Args:
        base_model: model exposing ``config.vocab_size``, ``resize_token_embeddings`` and
            ``get_input_embeddings`` (presumably a ``transformers.PreTrainedModel``).
        **rmt_kwargs: forwarded to :meth:`set_params`; must contain ``num_mem_tokens``,
            ``tokenizer`` and ``input_size``. Segmentation also reads ``max_n_segments`` and
            the optional ``segment_alignment``; :meth:`process_outputs` reads ``sum_loss``.
    """

    def __init__(self, base_model, **rmt_kwargs):
        super().__init__()
        self.model = base_model
        self.set_params(**rmt_kwargs)

    def set_params(self, num_mem_tokens, tokenizer, **rmt_config):
        """Register special tokens and memory tokens, then derive ``self.segment_size``."""
        self.rmt_config = rmt_config
        self.extract_special_tokens(tokenizer)
        self.extend_word_embeddings(num_mem_tokens, tokenizer)

        # Tokens left per segment after reserving memory and the tokenizer's special tokens.
        self.segment_size = rmt_config['input_size'] - num_mem_tokens - tokenizer.num_special_tokens_to_add()
        if 'sep_token' in tokenizer.special_tokens_map:
            # one more slot is reserved when the tokenizer's special_tokens_map has a sep_token
            self.segment_size -= 1

    def set_memory(self, input_shape):
        """Return the memory-token embeddings repeated over the batch: (input_shape[0], num_mem_tokens, emb_dim)."""
        memory = self.model.embeddings(self.mem_token_ids)
        memory = memory.repeat(input_shape[0], 1, 1)
        return memory

    def extract_special_tokens(self, tokenizer):
        """Record the pad id and register cls/sep/eos/bos ids as 1-element buffers.

        A token the tokenizer does not define is stored as the attribute value ``None``.
        ``self.special_token_ids`` collects the ids stripped by :meth:`pad_and_segment`;
        it may contain ``None`` (e.g. when there is no pad token).
        """
        self.pad_token_id = tokenizer.pad_token_id
        self.special_token_ids = [tokenizer.pad_token_id]
        for token in ['cls_token', 'sep_token', 'eos_token', 'bos_token']:
            token_id = getattr(tokenizer, f'{token}_id')
            if token_id is not None:
                self.register_buffer(token, torch.tensor([token_id]))
                self.special_token_ids.append(token_id)
            else:
                setattr(self, token, None)

    def extend_word_embeddings(self, num_mem_tokens, tokenizer):
        """Append ``num_mem_tokens`` ids after the vocabulary and resize the input embeddings.

        Also records where memory sits inside a segment (index 1 when the tokenizer has a
        cls or bos token, else 0) and aliases the input-embedding module as
        ``self.model.embeddings``.
        """
        vocab_size = self.model.config.vocab_size
        extended_vocab_size = vocab_size + num_mem_tokens
        self.num_mem_tokens = num_mem_tokens
        self.register_buffer('mem_token_ids', torch.arange(vocab_size, vocab_size + num_mem_tokens))
        self.model.resize_token_embeddings(extended_vocab_size)

        special_tokens = tokenizer.special_tokens_map
        mem_start_ind = int('cls_token' in special_tokens or 'bos_token' in special_tokens)
        self.memory_position = range(mem_start_ind, mem_start_ind + num_mem_tokens)
        self.model.embeddings = self.model.get_input_embeddings()

    def forward(self, **kwargs):
        raise NotImplementedError

    def _split_indices(self, seq_len):
        """Return segment boundaries ``[0, ..., seq_len]`` for ``rmt_config['segment_alignment']``.

        'right' (default) leaves the short segment first, 'left' leaves it last, and 'center'
        spreads tokens evenly. An empty sequence yields ``[0]``, i.e. no segments.
        """
        align = self.rmt_config.get('segment_alignment')
        if align in {'right', None}:
            return (list(range(seq_len, 0, -self.segment_size)) + [0])[::-1]
        if align == 'left':
            return list(range(0, seq_len, self.segment_size)) + [seq_len]
        if align == 'center':
            if seq_len == 0:
                # nothing to split; computing the step below would divide by zero
                return [0]
            n_seg = math.ceil(seq_len / self.segment_size)
            return list(range(0, seq_len, math.ceil(seq_len / n_seg))) + [seq_len]
        raise NotImplementedError

    def pad_and_segment(self, input_ids):
        """Split each sequence into segments and transpose the result to segment-major order.

        For every sequence: strip ``self.special_token_ids``, truncate to
        ``segment_size * max_n_segments`` tokens, split per the configured alignment, and pass
        each piece through ``pad_add_special_tokens``. Sequences with fewer segments are
        left-padded with ``None`` markers.

        Returns:
            list of ``max_n_segments`` lists; entry ``i`` holds, per batch sample, the i-th
            segment tensor or ``None``.
        """
        max_n_segments = self.rmt_config['max_n_segments']
        drop_ids = [t for t in self.special_token_ids if t is not None]

        segmented_batch = []
        for seq in input_ids:
            if drop_ids:  # torch.stack([]) raises when the tokenizer defines no special ids
                drop_mask = torch.any(torch.stack([seq == t for t in drop_ids]), dim=0)
                seq = seq[~drop_mask]
            seq = seq[:self.segment_size * max_n_segments]

            split_inds = self._split_indices(len(seq))
            input_segments = [seq[start:end] for (start, end) in zip(split_inds, split_inds[1:])]
            input_segments = [self.pad_add_special_tokens(t, self.rmt_config['input_size']) for t in input_segments]

            # add empty segment markers if needed
            n_empty_segments = max_n_segments - len(input_segments)
            input_segments = [None] * n_empty_segments + input_segments

            segmented_batch.append(input_segments)

        return [[sample[seg_num] for sample in segmented_batch]
                for seg_num in range(max_n_segments)]

    def pad_add_special_tokens(self, *args, **kwargs):
        """Abstract: called as ``pad_add_special_tokens(segment_tensor, input_size)`` by pad_and_segment."""
        raise NotImplementedError

    def prepare_kwargs(self, segment_input_ids, kwargs):
        """Build base-model kwargs for one segment.

        Stacks the non-``None`` segments, embeds them, replaces ``input_ids`` by
        ``inputs_embeds``, filters ``labels`` to non-empty samples and derives the attention
        mask (and token-type ids when supplied).

        Returns:
            (seg_kwargs, non_empty_mask); ``seg_kwargs`` is ``None`` when every sample is empty.
        """
        seg_kwargs = dict(**kwargs)
        non_empty_mask = [s is not None for s in segment_input_ids]
        if sum(non_empty_mask) == 0:
            return None, non_empty_mask

        input_ids = torch.stack([s for s in segment_input_ids if s is not None])
        inputs_embeds = self.model.embeddings(input_ids)

        seg_kwargs['input_ids'] = None
        seg_kwargs['inputs_embeds'] = inputs_embeds
        if seg_kwargs.get('labels') is not None:
            seg_kwargs['labels'] = seg_kwargs['labels'][non_empty_mask]
        seg_kwargs['attention_mask'] = self.get_attention_mask(input_ids)
        if seg_kwargs.get('token_type_ids') is not None:
            seg_kwargs['token_type_ids'] = self.get_token_type_ids(input_ids)
        seg_kwargs['output_hidden_states'] = True

        return seg_kwargs, non_empty_mask

    def process_outputs(self, model_outputs, output_attentions, output_hidden_states):
        """Merge per-segment outputs into the last segment's output object.

        Per-segment losses (and attentions / hidden states when requested) are copied under
        ``'<key>_<seg_num>'``. With ``rmt_config['sum_loss']`` truthy, ``loss`` becomes the mean
        of segment losses. Note: ``model_outputs[-1]`` is mutated in place and returned.
        """
        rmt_out = model_outputs[-1]

        segment_keys = ['loss']
        if output_attentions:
            segment_keys.append('attentions')
        if output_hidden_states:
            segment_keys.append('hidden_states')

        extracted = {}
        for seg_num, out in enumerate(model_outputs):
            for key, value in out.items():
                if any(sk in key for sk in segment_keys):
                    extracted[f'{key}_{seg_num}'] = value

        # .get: a config without 'sum_loss' used to raise KeyError
        if self.rmt_config.get('sum_loss'):
            losses = [out['loss'] for out in model_outputs]
            extracted['loss'] = torch.stack(losses).mean(dim=0)

        for key, value in extracted.items():
            rmt_out[key] = value

        # drop unnecessary hiddens to save memory
        if not output_hidden_states:
            for key in list(rmt_out.keys()):
                if 'hidden_state' in key:
                    rmt_out[key] = None

        return rmt_out

    def get_token_type_ids(self, tensor):
        """All-zero token-type ids shaped like ``tensor``."""
        return torch.zeros_like(tensor)

    def get_attention_mask(self, tensor):
        """1 for every id in ``tensor`` except ``self.pad_token_id`` positions, which get 0."""
        mask = torch.ones_like(tensor)
        mask[tensor == self.pad_token_id] = 0
        return mask

In [4]:
# class RMTDecoderLMHead(RMTBaseModel):
#     def set_params(self, num_mem_tokens, tokenizer, **rmt_config):
#         self.rmt_config = rmt_config
#         self.extract_special_tokens(tokenizer)
#         self.create_memory(num_mem_tokens)

#         self.segment_size = rmt_config['input_size'] - 2 * num_mem_tokens - tokenizer.num_special_tokens_to_add()
#         if 'sep_token' in tokenizer.special_tokens_map:
#             self.segment_size -= 1

#     def create_memory(self, num_mem_tokens):
#         self.num_mem_tokens = num_mem_tokens
#         embeddings = self.model.get_input_embeddings()
#         memory_weights = torch.randn((num_mem_tokens, self.model.config.n_embd)) * embeddings.weight.data.std()
#         self.register_parameter('memory', torch.nn.Parameter(memory_weights, requires_grad=True))

#         self.read_memory_position = range(num_mem_tokens)
#         self.write_memory_position = range(-num_mem_tokens, 0)

#     def set_memory(self, input_shape):
#         memory = self.memory.repeat(input_shape[0], 1, 1)
#         return memory

#     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
#                 inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None):
#         kwargs = {'attention_mask': attention_mask, 'token_type_ids': token_type_ids,
#                   'position_ids': position_ids, 'inputs_embeds': inputs_embeds,
#                   'labels': labels, 'output_attentions': output_attentions,
#                   'output_hidden_states': output_hidden_states, 'return_dict': return_dict,
#                   }

#         if not hasattr(self, 'memory_states') or self.memory_states is None:
#             init_memory = self.set_memory(input_ids.shape)
#             self.memory_states = [(None, init_memory)]
        
#         memory = self.memory_states[-1][1].detach()#.to(input_ids.device)
#         memory.requires_grad = True

#         segment_input_ids = self.pad_and_segment(input_ids)[0]
#         seg_kwargs, non_empty_mask = self.prepare_kwargs(segment_input_ids, memory, kwargs)
        
#         labels = seg_kwargs.pop('labels')
#         out = self.model(**seg_kwargs)
        
#         new_memory = out.hidden_states[-1][:, self.write_memory_position]
#         self.memory_states.append((memory, new_memory))
#         self.trim_memory_states()

#         ### Calculate loss excluding memory 
#         lm_logits = out.logits[:, self.num_mem_tokens:-self.num_mem_tokens]
#         # Shift so that tokens < n predict n
#         shift_logits = lm_logits[..., :-1, :].contiguous()
#         shift_labels = labels[..., 1:].contiguous()
#         # Flatten the tokens
#         loss_fct = CrossEntropyLoss()
#         out['loss'] = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

#         return out

#     def pad_add_special_tokens(self, tensor, segment_size):
#         # pad_size = segment_size - tensor.shape[0]
#         # if pad_size > 0:
#         #     tensor = F.pad(tensor, (0, pad_size))
#         return tensor
    
#     def prepare_kwargs(self, segment_input_ids, memory, kwargs):
#         seg_kwargs = dict(**kwargs)
#         non_empty_mask = [s is not None for s in segment_input_ids]
#         if sum(non_empty_mask) == 0:
#             return None, non_empty_mask
            
#         input_ids = torch.stack([s for s in segment_input_ids if s is not None])
#         inputs_embeds = self.model.get_input_embeddings()(input_ids)
#         inputs_embeds = torch.cat([memory, inputs_embeds, memory], dim=1)

#         seg_kwargs['input_ids'] = None
#         seg_kwargs['inputs_embeds'] = inputs_embeds
#         if seg_kwargs.get('labels') is not None:
#             seg_kwargs['labels'] = seg_kwargs['labels'][non_empty_mask]
#         seg_kwargs['attention_mask'] = self.get_attention_mask(inputs_embeds)
#         # if seg_kwargs.get('token_type_ids') is not None:
#         #     seg_kwargs['token_type_ids'] = self.get_token_type_ids(inputs_embeds)
#         seg_kwargs['output_hidden_states'] = True

#         return seg_kwargs, non_empty_mask
    
#     def get_attention_mask(self, tensor):
#         mask = torch.ones(*tensor.shape[:2], dtype=torch.int64).to(tensor.device)
#         mask[tensor == self.pad_token_id] = 0
#         return mask

#     def train(self, *args, **kwargs):
#         self.memory_states = None
#         super().train(*args, **kwargs)

#     def eval(self, *args, **kwargs):
#         self.memory_states = None
#         super().eval(*args, **kwargs)

#     def trim_memory_states(self):
#         k2 = self.rmt_config.get('k2')
#         if not k2 or k2 == -1:
#             return 
#         while len(self.memory_states) > k2:
#             del self.memory_states[0]

#     def truncated_backward(self, k1, k2):
#         memory_states = self.memory_states
#         if k1 != -1:
#             raise NotImplementedError
        
#         for i in range(k2 - 1 if k2 != -1 else len(memory_states)):
#             curr_grad = memory_states[-i-1][0].grad
#             memory_states[-i-2][1].backward(curr_grad, retain_graph=False)#k2>2)

#             # if we get all the way back to the "init_memory", stop
#             if memory_states[-i-2][0] is None:
#                 break

In [5]:
# from torch.nn import CrossEntropyLoss
# from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

# class RMTDecoderLMHeadMultiSeg(RMTBaseModel):
#     def set_params(self, num_mem_tokens, tokenizer, **rmt_config):
#         self.rmt_config = rmt_config
#         self.extract_special_tokens(tokenizer)
#         self.create_memory(num_mem_tokens)

#         self.segment_size = rmt_config['input_size'] - 2 * num_mem_tokens - tokenizer.num_special_tokens_to_add()
#         if 'sep_token' in tokenizer.special_tokens_map:
#             self.segment_size -= 1

#     def create_memory(self, num_mem_tokens):
#         self.num_mem_tokens = num_mem_tokens
#         embeddings = self.model.get_input_embeddings()
#         memory_weights = torch.randn((num_mem_tokens, self.model.config.n_embd)) * embeddings.weight.data.std()
#         self.register_parameter('memory', torch.nn.Parameter(memory_weights, requires_grad=True))

#         self.read_memory_position = range(num_mem_tokens)
#         self.write_memory_position = range(-num_mem_tokens, 0)

#     def set_memory(self, input_shape):
#         if self.training or not hasattr(self, 'memory_state'):
#             memory = self.memory.repeat(input_shape[0], 1, 1)
#         else:
#             memory = self.memory_state[:input_shape[0]]
#         return memory
    
#     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
#                 inputs_embeds=None, labels=None, labels_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None):
#         kwargs = {'attention_mask': attention_mask, 'token_type_ids': token_type_ids,
#                   'position_ids': position_ids, 'inputs_embeds': inputs_embeds,
#                   'labels_mask': labels_mask, #'pos_weight': pos_weight,
#                   'labels': labels, 'output_attentions': output_attentions,
#                   'output_hidden_states': output_hidden_states, 'return_dict': return_dict,
#                   }

#         memory = self.set_memory(input_ids.shape)
#         segmented = self.pad_and_segment(input_ids, labels)

#         base_model_outputs = []
#         for seg_num, segment in enumerate(zip(*segmented)):                
#             if self.rmt_config['bptt_depth'] != -1:
#                 raise NotImplementedError

#             seg_kwargs, non_empty_mask = self.prepare_kwargs(segment, memory, kwargs)
#             if sum(non_empty_mask) == 0:
#                 continue
            
#             seg_kwargs['inputs_embeds'][:, self.read_memory_position] = memory[non_empty_mask]
#             seg_kwargs['inputs_embeds'][:, self.write_memory_position] = memory[non_empty_mask]
#             out = self.model(**seg_kwargs)
#             base_model_outputs.append(out)
            
#             memory[non_empty_mask] = out.hidden_states[-1][:, self.write_memory_position]

#         self.memory_state = memory
#         if self.training:
#             del self.memory_state
#         return base_model_outputs
#         out = self.process_outputs(base_model_outputs, kwargs)
#         return out
    
#     def pad_and_segment(self, input_ids, labels=None):
#         segmented_batch = []
#         segmented_batch_labels = []

#         if labels is None:
#             labels = [None] * input_ids.shape[0]
#         batch_labels = labels

#         for seq, labels in zip(input_ids, batch_labels):
#             # if seq.shape[0] != self.segment_size * self.rmt_config['max_n_segments']:
#             #     raise(ValueError(f"Inputs shape {input_ids.shape} does not match {self.segment_size}*{self.rmt_config['max_n_segments']}"))
#             # if labels is not None:
#             #     labels = labels[:self.segment_size * self.rmt_config['max_n_segments']]

#             align = self.rmt_config.get('segment_alignment')
#             if align in {'right', None}:
#                 split_inds = (list(range(len(seq), 0, -self.segment_size)) + [0])[::-1]
#             elif align == 'left':
#                 split_inds = list(range(0, len(seq), self.segment_size)) + [len(seq)]
#             elif align == 'center':
#                 n_seg = math.ceil(len(seq) / self.segment_size)
#                 split_inds = list(range(0, len(seq), math.ceil(len(seq) / n_seg))) + [len(seq)]
#             else:
#                 raise NotImplementedError

#             input_segments = [seq[start:end] for (start, end) in zip(split_inds, split_inds[1:])]
#             # add empty segment markers if needed
#             n_empty_segments = self.rmt_config['max_n_segments'] - len(input_segments)
#             input_segments = [None] * n_empty_segments + input_segments
#             segmented_batch.append(input_segments)

#             if labels is not None:
#                 labels_segments = [labels[start:end] for (start, end) in zip(split_inds, split_inds[1:])]
#                 labels_segments = [None] * n_empty_segments + labels_segments
#                 segmented_batch_labels.append(labels_segments)

#         segmented_batch = [[sample[seg_num] for sample in segmented_batch]
#                            for seg_num in range(self.rmt_config['max_n_segments'])]
#         segmented_batch_labels = [[sample[seg_num] for sample in segmented_batch_labels]
#                                   for seg_num in range(self.rmt_config['max_n_segments'])]

#         return segmented_batch, segmented_batch_labels

#     def prepare_kwargs(self, segment, memory, kwargs):
#         segment_input_ids, segment_labels = segment
#         seg_kwargs = dict(**kwargs)
#         non_empty_mask = [s is not None for s in segment_input_ids]
#         if sum(non_empty_mask) == 0:
#             return None, non_empty_mask

#         input_ids = torch.stack([s for s in segment_input_ids if s is not None])
#         inputs_embeds = self.model.get_input_embeddings()(input_ids)
#         inputs_embeds = torch.cat([memory, inputs_embeds, memory], dim=1)

#         seg_kwargs['input_ids'] = None
#         seg_kwargs['inputs_embeds'] = inputs_embeds
#         seg_kwargs['attention_mask'] = self.get_attention_mask(inputs_embeds)
#         seg_kwargs['output_hidden_states'] = True
#         if seg_kwargs['labels'] is not None:
#             labels = torch.stack([el for el, m in zip(segment_labels, non_empty_mask) if m])
#             memory_labels = torch.ones((labels.shape[0], self.num_mem_tokens), dtype=labels.dtype, device=labels.device) * -100
#             seg_kwargs['labels'] = torch.cat((memory_labels, labels, memory_labels), dim=1)
#             seg_kwargs['labels'][:, self.num_mem_tokens] = -100
#         seg_kwargs.pop('labels_mask')

#         return seg_kwargs, non_empty_mask
    
#     def get_attention_mask(self, tensor):
#         mask = torch.ones(*tensor.shape[:2], dtype=torch.int64).to(tensor.device)
#         mask[tensor == self.pad_token_id] = 0
#         return mask
    
#     def process_outputs(self, model_outputs, kwargs):
#         full_logits = torch.cat([o.logits[:, self.num_mem_tokens:-self.num_mem_tokens] for o in model_outputs], dim=1)
#         truncated_hs = [[lh[:, self.num_mem_tokens:-self.num_mem_tokens] for lh in o.hidden_states] for o in model_outputs]
#         full_hidden_states = tuple([torch.cat(layer_hs, dim=1) for layer_hs in zip(*truncated_hs)])
#         full_labels = kwargs.get('labels')

#         rmt_out = CausalLMOutputWithCrossAttentions()
#         if kwargs.get('labels') is not None:
#             # Shift so that tokens < n predict n
#             shift_logits = full_logits[..., :-1, :].contiguous()
#             shift_labels = full_labels[..., 1:].contiguous()

#             # if full_labels.shape[1] != self.segment_size * self.rmt_config['max_n_segments']:
#             #     raise(ValueError(f"Labels shape {full_labels.shape} does not match {self.segment_size}*{self.rmt_config['max_n_segments']}"))
#             # Flatten the tokens
#             loss_fct = CrossEntropyLoss(reduction='none')
#             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
#             labels_mask = kwargs.get('labels_mask')
#             if labels_mask is None:
#                 rmt_out['loss'] = loss.mean()
#             else:
#                 shift_mask = labels_mask[..., 1:].contiguous()
#                 rmt_out['loss'] = loss[shift_mask.view(-1)].mean()

#         rmt_out['logits'] = full_logits
#         segment_keys = ['loss', 'logits']
#         if kwargs.get('output_attentions'):
#             segment_keys.append('attentions')
#         if kwargs.get('output_hidden_states'):
#             segment_keys.append('hidden_states')
#             rmt_out['hidden_states'] = full_hidden_states

#         for seg_num, out in enumerate(model_outputs):
#             for key, value in out.items():
#                 if any([sk in key for sk in segment_keys]):
#                     rmt_out[f'{key}_{seg_num}'] = value

#         return rmt_out 

In [193]:
import math
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

class RMTDecoderLMHeadMultiSeg(RMTBaseModel):
    """Recurrent Memory Transformer for decoder-only causal LMs (GPT-2 style).

    Each segment is fed to the base model as ``[memory | segment tokens | memory]``. The last-layer
    hidden states at the trailing memory positions become the memory for the next segment.
    Memory is a learned ``torch.nn.Parameter`` (``self.memory``) rather than extra vocabulary ids.

    rmt_config keys read by this class: ``input_size``, ``max_n_segments``, ``k2`` and the
    optional ``segment_alignment`` / ``reinit_mem_each_fwd``.
    """

    def set_params(self, num_mem_tokens, tokenizer, **rmt_config):
        """Register special tokens, create the memory parameter and derive ``self.segment_size``."""
        self.rmt_config = rmt_config
        self.extract_special_tokens(tokenizer)
        self.create_memory(num_mem_tokens)

        # memory occupies 2 * num_mem_tokens slots per segment (read copy + write copy)
        self.segment_size = rmt_config['input_size'] - 2 * num_mem_tokens - tokenizer.num_special_tokens_to_add()
        if 'sep_token' in tokenizer.special_tokens_map:
            self.segment_size -= 1

    def create_memory(self, num_mem_tokens):
        """Create ``self.memory`` (num_mem_tokens, emb_dim), randomly initialised at the scale of the input embeddings."""
        self.num_mem_tokens = num_mem_tokens
        embeddings = self.model.get_input_embeddings()
        # width from the embedding matrix itself, so configs without GPT-2's `n_embd` also work
        hidden_size = embeddings.weight.shape[-1]
        memory_weights = torch.randn((num_mem_tokens, hidden_size)) * embeddings.weight.data.std()
        self.register_parameter('memory', torch.nn.Parameter(memory_weights, requires_grad=True))

        self.read_memory_position = range(num_mem_tokens)
        self.write_memory_position = range(-num_mem_tokens, 0)

    def set_memory(self, input_shape):
        """Return the memory for a new forward pass: (input_shape[0], num_mem_tokens, emb_dim).

        A fresh copy of ``self.memory`` is used in training mode, when ``reinit_mem_each_fwd`` is
        set, for single-segment configs, when no state was kept from a previous call, or when the
        kept state has fewer rows than the new batch. Otherwise the state left by the previous
        (evaluation) call is continued.
        """
        batch_size = input_shape[0]
        stored = getattr(self, 'memory_state', None)
        create_memory = self.training \
                        or self.rmt_config.get('reinit_mem_each_fwd') \
                        or stored is None \
                        or self.rmt_config['max_n_segments'] == 1 \
                        or stored.shape[0] < batch_size
        if create_memory:
            return self.memory.repeat(batch_size, 1, 1)
        return stored[:batch_size]

    def detach_memory(self, seg_num):
        """Return True if the memory entering segment ``seg_num`` should be detached (gradient cut).

        Never at segment 0, never when ``k2`` is -1/None, and not once
        ``seg_num + k2 > max_n_segments``, so gradients flow through memory only across the final segments.
        """
        k2, max_n_segments = self.rmt_config.get('k2'), self.rmt_config.get('max_n_segments')
        if seg_num == 0 \
            or k2 in {-1, None} \
            or seg_num + k2 > max_n_segments:
            return False
        return True

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                inputs_embeds=None, labels=None, labels_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None):
        """Run the base model segment by segment, threading memory between segments.

        NOTE(review): ``attention_mask`` and ``inputs_embeds`` are overwritten per segment in
        ``prepare_kwargs``; ``token_type_ids``/``position_ids`` are passed through unsegmented and
        ``head_mask`` is ignored.

        Returns:
            ``CausalLMOutputWithCrossAttentions`` built by :meth:`process_outputs`.
        """
        kwargs = {'attention_mask': attention_mask, 'token_type_ids': token_type_ids,
                  'position_ids': position_ids, 'inputs_embeds': inputs_embeds,
                  'labels_mask': labels_mask,
                  'labels': labels, 'output_attentions': output_attentions,
                  'output_hidden_states': output_hidden_states, 'return_dict': return_dict,
                  }

        memory = self.set_memory(input_ids.shape)
        segmented = self.pad_and_segment(input_ids, labels)

        base_model_outputs = []
        for seg_num, segment in enumerate(zip(*segmented)):

            seg_kwargs, non_empty_mask = self.prepare_kwargs(segment, memory, kwargs)
            if sum(non_empty_mask) == 0:
                continue

            if self.detach_memory(seg_num):
                memory = memory.detach()

            seg_kwargs['inputs_embeds'][:, self.read_memory_position] = memory[non_empty_mask]
            seg_kwargs['inputs_embeds'][:, self.write_memory_position] = memory[non_empty_mask]
            out = self.model(**seg_kwargs)
            base_model_outputs.append(out)

            memory[non_empty_mask] = out.hidden_states[-1][:, self.write_memory_position]

        # Carry memory across calls only in eval mode; training always re-initialises it, and stale
        # training memory must not leak into the first eval call. Detached so this call's graph is freed.
        self.memory_state = None if self.training else memory.detach()
        out = self.process_outputs(base_model_outputs, kwargs)
        return out

    def pad_and_segment(self, input_ids, labels=None):
        """Split every sequence (and its labels) into segments, transposed to segment-major order.

        Unlike the base class, special tokens are not removed and nothing is truncated.

        Returns:
            (segmented_batch, segmented_batch_labels): each a list of ``max_n_segments`` lists
            with one entry per sample (a tensor slice, or ``None`` for leading empty segments).
            When ``labels`` is None, the inner lists of ``segmented_batch_labels`` are empty.
        """
        segmented_batch = []
        segmented_batch_labels = []

        if labels is None:
            labels = [None] * input_ids.shape[0]
        for seq, seq_labels in zip(input_ids, labels):

            align = self.rmt_config.get('segment_alignment')
            if align in {'right', None}:
                split_inds = (list(range(len(seq), 0, -self.segment_size)) + [0])[::-1]
            elif align == 'left':
                split_inds = list(range(0, len(seq), self.segment_size)) + [len(seq)]
            elif align == 'center':
                n_seg = math.ceil(len(seq) / self.segment_size)
                if n_seg == 0:
                    split_inds = [0]  # empty sequence: no segments (avoids division by zero)
                else:
                    split_inds = list(range(0, len(seq), math.ceil(len(seq) / n_seg))) + [len(seq)]
            else:
                raise NotImplementedError
            input_segments = [seq[start:end] for (start, end) in zip(split_inds, split_inds[1:])]
            # add empty segment markers if needed
            # NOTE(review): with more than max_n_segments segments this is negative and only the
            # first max_n_segments segments survive the transpose below (the tail is dropped).
            n_empty_segments = self.rmt_config['max_n_segments'] - len(input_segments)
            input_segments = [None] * n_empty_segments + input_segments
            segmented_batch.append(input_segments)

            if seq_labels is not None:
                labels_segments = [seq_labels[start:end] for (start, end) in zip(split_inds, split_inds[1:])]
                labels_segments = [None] * n_empty_segments + labels_segments
                segmented_batch_labels.append(labels_segments)

        segmented_batch = [[sample[seg_num] for sample in segmented_batch]
                           for seg_num in range(self.rmt_config['max_n_segments'])]
        segmented_batch_labels = [[sample[seg_num] for sample in segmented_batch_labels]
                                  for seg_num in range(self.rmt_config['max_n_segments'])]

        return segmented_batch, segmented_batch_labels

    def prepare_kwargs(self, segment, memory, kwargs):
        """Build base-model kwargs for one segment: embeds ``[memory | tokens | memory]``.

        Labels get -100 at every memory position and at the first real token of the segment.
        ``labels_mask`` is removed (it is consumed by :meth:`process_outputs`).

        Returns:
            (seg_kwargs, non_empty_mask); ``seg_kwargs`` is ``None`` when every sample is empty.
        """
        segment_input_ids, segment_labels = segment
        seg_kwargs = dict(**kwargs)
        non_empty_mask = [s is not None for s in segment_input_ids]
        if sum(non_empty_mask) == 0:
            return None, non_empty_mask

        input_ids = torch.stack([s for s in segment_input_ids if s is not None])
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        # memory rows must match the non-empty subset stacked above
        segment_memory = memory[non_empty_mask]
        inputs_embeds = torch.cat([segment_memory, inputs_embeds, segment_memory], dim=1)

        seg_kwargs['input_ids'] = None
        seg_kwargs['inputs_embeds'] = inputs_embeds
        seg_kwargs['attention_mask'] = self.get_attention_mask(inputs_embeds)
        seg_kwargs['output_hidden_states'] = True
        if seg_kwargs['labels'] is not None:
            labels = torch.stack([el for el, m in zip(segment_labels, non_empty_mask) if m])
            memory_labels = torch.ones((labels.shape[0], self.num_mem_tokens), dtype=labels.dtype, device=labels.device) * -100
            seg_kwargs['labels'] = torch.cat((memory_labels, labels, memory_labels), dim=1)
            seg_kwargs['labels'][:, self.num_mem_tokens] = -100
        seg_kwargs.pop('labels_mask')

        return seg_kwargs, non_empty_mask

    def get_attention_mask(self, tensor):
        """All-ones (batch, seq) int64 mask for ``inputs_embeds`` of shape (batch, seq, emb_dim).

        Pad tokens cannot be recognised from embeddings. The previous comparison against
        ``pad_token_id`` was a no-op for a None pad id and raised IndexError otherwise.
        NOTE(review): padded inputs are therefore attended to.
        """
        return torch.ones(tensor.shape[:2], dtype=torch.int64, device=tensor.device)

    def process_outputs(self, model_outputs, kwargs):
        """Concatenate segment logits (memory positions removed), compute the LM loss, attach per-segment entries."""
        mem = self.num_mem_tokens
        # drop read/write memory positions; `[:, 0:-0]` would be empty, hence `None` when mem == 0
        tokens = slice(mem, -mem if mem else None)
        full_logits = torch.cat([o.logits[:, tokens] for o in model_outputs], dim=1)

        rmt_out = CausalLMOutputWithCrossAttentions()
        full_labels = kwargs.get('labels')
        if full_labels is not None:
            shift_labels = full_labels[..., 1:].contiguous()
            shift_logits = full_logits[..., :-1, :].contiguous()
            flat_labels = shift_labels.view(-1)
            flat_logits = shift_logits.view(-1, shift_logits.size(-1))

            loss_fct = CrossEntropyLoss()
            labels_mask = kwargs.get('labels_mask')
            if labels_mask is not None:
                # NOTE(review): the mask is sliced [:-1] while labels are sliced [1:], so mask[i] selects the
                # logit at position i (predicting token i+1) — confirm this matches how labels_mask is built.
                # .bool(): a 0/1 integer mask would otherwise act as gather indices
                shift_mask = labels_mask[..., :-1].contiguous().bool()
                flat_labels = flat_labels[shift_mask.view(-1)]
                flat_logits = flat_logits[shift_mask.view(-1)]

            rmt_out['loss'] = loss_fct(flat_logits, flat_labels)

        rmt_out['logits'] = full_logits
        segment_keys = ['loss', 'logits']
        if kwargs.get('output_attentions'):
            segment_keys.append('attentions')
        if kwargs.get('output_hidden_states'):
            segment_keys.append('hidden_states')
            # built only when requested: one concatenated tensor per layer, memory positions removed
            truncated_hs = [[lh[:, tokens] for lh in o.hidden_states] for o in model_outputs]
            rmt_out['hidden_states'] = tuple(torch.cat(layer_hs, dim=1) for layer_hs in zip(*truncated_hs))

        for seg_num, out in enumerate(model_outputs):
            for key, value in out.items():
                if any(sk in key for sk in segment_keys):
                    rmt_out[f'{key}_{seg_num}'] = value

        return rmt_out

# -----------------------------------------------------------------------------------

In [194]:
# Run configuration; the commented lines are settings from earlier experiments.
# num_segments = 1
# num_mem_tokens = 2
# device = torch.device(3)
# NOTE(review): `device` is not referenced by any later visible cell — the model and batches are never moved to it.
device = 'cpu'

In [195]:
from transformers import AutoModel
model_name = 'gpt2'

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Keys read by RMTDecoderLMHeadMultiSeg: num_mem_tokens, tokenizer, input_size, max_n_segments, k2
# (plus the optional segment_alignment / reinit_mem_each_fwd). offset_position_ids, k1,
# segment_ordering, bptt_depth and sum_loss are stored in rmt.rmt_config but not read by that class.
rmt_config = {'num_mem_tokens': 1,
              'max_n_segments': 5,
              # 'segment_alignment': 'right',
              'tokenizer': tokenizer,
              # 'memory_layers': 'all',
              # 'share_memory_layers': True,
              # 'reconstruction_loss_coef': 0.1,
              'offset_position_ids': True,
              'k1': -1, 'k2': 3,
              'segment_ordering': 'regular',
              'input_size': 128,
              'bptt_depth': -1,
              'sum_loss': False,
              }

# The base LM must be created in this cell so that Restart & Run All does not hit a NameError on `base_model`.
base_model = AutoModelForCausalLM.from_pretrained(model_name)
rmt = RMTDecoderLMHeadMultiSeg(base_model, **rmt_config)

In [167]:
# Sanity check: print the names of parameters containing 'memory' (the captured output shows only `memory`).
for n, p in rmt.named_parameters():
    if 'memory' in n:
        print(n)

memory


In [168]:
# Switch to eval mode and straight back; net effect: the model is left in training mode.
rmt.eval()
rmt.train()
# trailing literal so the cell displays `1` (presumably to hide the module repr returned by train())
1

1

In [240]:
# NOTE(review): `batch` is never created in the visible cells; it must come from an earlier
# (deleted or external) data-loading cell — TODO add that cell so Restart & Run All works.
input_ids, labels = (batch[name].clone() for name in ('input_ids', 'labels'))
labels_mask = batch['labels_mask']
# the whole batch (including labels and labels_mask) is forwarded as keyword arguments
kwargs = dict(**batch)

In [242]:
# Full multi-segment forward pass with every entry of `batch` as keyword arguments.
rmt_out = rmt(**kwargs)

In [243]:
# Every loss entry of the output: the aggregate `loss` plus per-segment `loss_<i>`.
[item for item in rmt_out.items() if 'loss' in item[0]]

[('loss', tensor(2.7514, grad_fn=<NllLossBackward>)),
 ('loss_0', tensor(2.5994, grad_fn=<NllLossBackward>)),
 ('loss_1', tensor(2.7056, grad_fn=<NllLossBackward>)),
 ('loss_2', tensor(2.5862, grad_fn=<NllLossBackward>)),
 ('loss_3', tensor(2.7514, grad_fn=<NllLossBackward>))]

In [None]:
# Duplicate of the previous loss-inspection cell. It is marked In [None] (not executed in this
# session), so the captured output below is stale.
[(key, rmt_out[key]) for key in rmt_out if 'loss' in key]

[('loss', tensor(11.3409, grad_fn=<NllLossBackward>)),
 ('loss_0', tensor(2.9731, grad_fn=<NllLossBackward>)),
 ('loss_1', tensor(2.6483, grad_fn=<NllLossBackward>)),
 ('loss_2', tensor(2.8106, grad_fn=<NllLossBackward>)),
 ('loss_3', tensor(2.6017, grad_fn=<NllLossBackward>))]

In [186]:
# Restore weights from an earlier wikitext-2 / gpt2 run (path is relative to the notebook's working directory).
# NOTE(review): torch.load unpickles arbitrary Python objects — only load checkpoints you produced or trust.
cpt_path = "../../runs/lm_long/wikitext-2-v1/gpt2/lr5e-05_linear_adamw_wd1e-03_630-128-5x128_mem1_bs32_regular_bptt-5_from_cpt_4-5/run_1/model_best.pth"
cpt = torch.load(cpt_path, map_location='cpu')
# the RMT weights are stored under the 'model_state_dict' key of the loaded checkpoint
rmt.load_state_dict(cpt['model_state_dict'])

<All keys matched successfully>

In [154]:
# input_ids = valid_batch['input_ids'].clone()#[:, :512]
# labels = valid_batch['labels'].clone()#[:, :512]
# labels_mask = None
# kwargs = dict(**valid_batch)
# kwargs['labels_mask'] = None

In [155]:
# Shape of the raw (unsegmented) token batch: (batch, sequence length).
input_ids.shape

torch.Size([2, 432])

In [170]:
def detach_memory(self, seg_num, n_segments=None):
    """Decide whether the memory should be detached before segment ``seg_num``.

    Implements segment-level truncated BPTT. The memory is detached before every
    segment with ``seg_num > 0`` and ``seg_num + k2 <= total``. The last detach
    therefore happens right before segment ``total - k2``, and gradients flow
    through the memory only across the remaining segments.

    Args:
        self: object exposing an ``rmt_config`` dict; reads ``'k2'`` and
            ``'max_n_segments'``.
        seg_num: 0-based index of the segment about to be processed.
        n_segments: optional actual number of segments in the current input.
            When given, it replaces ``rmt_config['max_n_segments']`` as ``total``,
            so truncation stays correct for inputs shorter than the maximum.

    Returns:
        True if the memory should be detached, otherwise False. The result is
        always False in these cases:
            * the first segment;
            * ``k2`` is ``-1`` or ``None`` (full BPTT);
            * the segment count is unknown (no truncation can be computed).
    """
    k2 = self.rmt_config.get('k2')
    total = n_segments if n_segments is not None else self.rmt_config.get('max_n_segments')
    if seg_num == 0 or k2 in {-1, None} or total is None:
        return False
    return seg_num + k2 <= total

In [171]:
# Step-by-step replica of the RMT multi-segment forward pass, run at notebook top
# level for debugging. `self` is apparently rebound to the model so that
# method-body code can be pasted here unchanged.
self = rmt
# Initial memory: memory-token embeddings repeated once per batch row (set_memory).
memory = self.set_memory(input_ids.shape)
# pad_and_segment is defined elsewhere; presumably it returns parallel sequences
# of per-segment inputs/labels that are zipped below — confirm.
segmented = self.pad_and_segment(input_ids, labels)

base_model_outputs = []
for seg_num, segment in enumerate(zip(*segmented)):

    # prepare_kwargs (defined elsewhere) builds base-model kwargs for this segment.
    # non_empty_mask presumably marks batch rows that actually have this segment.
    seg_kwargs, non_empty_mask = self.prepare_kwargs(segment, memory, kwargs)
    if sum(non_empty_mask) == 0:
        # raise(ValueError('nonemptymask is zero'))
        continue

    print('seg_num, k2, detach_memory', seg_num, self.rmt_config.get('k2'), detach_memory(self, seg_num))
    # Truncated BPTT: cut the gradient path through the memory when requested.
    if detach_memory(self, seg_num):
        memory = memory.detach()

    # Place the current memory into both the read and the write memory slots.
    seg_kwargs['inputs_embeds'][:, self.read_memory_position] = memory[non_empty_mask]
    seg_kwargs['inputs_embeds'][:, self.write_memory_position] = memory[non_empty_mask]
    out = self.model(**seg_kwargs)
    base_model_outputs.append(out)

    # New memory = last-layer hidden states at the write positions (active rows only).
    # NOTE(review): this is an in-place write; after `memory.detach()` the detached
    # tensor shares storage with the previous memory tensor.
    memory[non_empty_mask] = out.hidden_states[-1][:, self.write_memory_position]

self.memory_state = memory
# return (base_model_outputs, kwargs)
# Merge the per-segment outputs (process_outputs is defined elsewhere).
out = self.process_outputs(base_model_outputs, kwargs)
# return out
# return out

seg_num, k2, detach_memory 0 3 False
seg_num, k2, detach_memory 1 3 True
seg_num, k2, detach_memory 2 3 False
seg_num, k2, detach_memory 3 3 False


In [172]:
out.loss.backward()

In [53]:
memory.shape

torch.Size([2, 10, 768])

In [175]:
base_model_outputs[0].loss, base_model_outputs[1].loss, base_model_outputs[2].loss, base_model_outputs[3].loss

(tensor(4.9269, grad_fn=<NllLossBackward>),
 tensor(10.3476, grad_fn=<NllLossBackward>),
 tensor(10.0047, grad_fn=<NllLossBackward>),
 tensor(11.6276, grad_fn=<NllLossBackward>))

In [127]:
# Toy setup for checking gradient flow under truncated BPTT. It has three constant
# 8-element segments (all 1s, all 2s, all 3s), a learnable 2-element memory, and
# one weight per segment.
segments = [torch.ones(8) * i for i in range(1, 4)]
memory = torch.nn.Parameter(torch.ones(2) * 5, requires_grad=True)
# NOTE: `[p] * 3` repeats ONE Parameter object, so ws[0], ws[1] and ws[2] are the
# same tensor. Their .grad values are therefore identical. This is presumably meant
# to mimic one model shared across all segments — confirm.
ws = [torch.nn.Parameter(torch.ones(10) * 2, requires_grad=True)] * 3
# w = torch.nn.Parameter(torch.ones(10) * 2, requires_grad=True)
# w = torch.nn.Parameter(torch.ones(10) * 2, requires_grad=True)

In [128]:
# Toy model of segment-level truncated BPTT, step by step:
#   * a 2-element memory is prepended to each segment;
#   * the segment weight `w` is added;
#   * the first two outputs become the next memory.
# The memory is detached before segment i whenever i != 0 and
# i + k2 <= len(segments). This is the same rule as `detach_memory` with
# max_n_segments = len(segments).
mem = memory.repeat(1)  # differentiable copy of the `memory` parameter

outputs = []
losses = []
k2 = 1
for i, (seg, w) in enumerate(zip(segments, ws)):
    detach = i + k2 <= len(segments) and i != 0
    print('i, k2, max_n_segments', i,  k2, len(segments),  detach)
    if detach:
        mem = mem.detach()
    inp = torch.cat((mem, seg))
    # `inp` is 1-D, so the former `inp.T` was a no-op. `.T` on non-2-D tensors is
    # deprecated in recent PyTorch releases, so `w` is added directly.
    out = inp + w
    mem = out[:2]  # first two outputs carry over as the next memory
    loss = out.mean()

    outputs.append(out)
    losses.append(loss)

# Only the last 10 concatenated values enter the loss. That is the final segment's
# output when each `out` has 10 elements (2-element memory + 8-element segment).
total_loss = torch.cat(outputs)[-10:].mean()

i, k2, max_n_segments 0 1 3 False
i, k2, max_n_segments 1 1 3 True
i, k2, max_n_segments 2 1 3 True


In [129]:
[w.grad for w in ws]

[None, None, None]

In [130]:
total_loss.backward()

In [120]:
# k2 = 3
[w.grad for w in ws], memory.grad

([tensor([0.3000, 0.3000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
          0.1000]),
  tensor([0.3000, 0.3000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
          0.1000]),
  tensor([0.3000, 0.3000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
          0.1000])],
 tensor([0.1000, 0.1000]))

In [126]:
# k2 = 2
[w.grad for w in ws], memory.grad

([tensor([0.2000, 0.2000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
          0.1000]),
  tensor([0.2000, 0.2000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
          0.1000]),
  tensor([0.2000, 0.2000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
          0.1000])],
 tensor([0., 0.]))

In [131]:
# k2 = 1
[w.grad for w in ws], memory.grad

([tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
          0.1000]),
  tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
          0.1000]),
  tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
          0.1000])],
 tensor([0., 0.]))

In [248]:
# # k2 = 0
# [w.grad for w in ws]

In [196]:
mem.grad, memory.grad

  mem.grad, memory.grad


(None, tensor([0.1000, 0.1000]))

In [162]:
segments

[tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 tensor([2., 2., 2., 2., 2., 2., 2., 2.]),
 tensor([3., 3., 3., 3., 3., 3., 3., 3.])]

In [159]:
ws[0].shape

torch.Size([10])

In [114]:
rmt_out.keys()

odict_keys(['loss', 'logits', 'hidden_states', 'loss_0', 'logits_0', 'hidden_states_0', 'attentions_0'])

In [115]:
rmt_out.logits.shape, len(rmt_out.hidden_states), rmt_out.hidden_states[0].shape

(torch.Size([1, 51, 50257]), 13, torch.Size([1, 51, 768]))

### Load dataset

In [234]:
class Holder:
    """Minimal attribute container used as a stand-in for parsed CLI ``args``."""


max_n_segments = 4
input_size = 128
num_mem_tokens = 1

# Each input_size-token segment gives up 2 * num_mem_tokens positions to memory
# tokens (presumably the read and write memory slots). Total text length is
# max_n_segments segments of the remaining tokens.
input_seq_len = target_seq_len = max_n_segments * (input_size - 2 * num_mem_tokens)
batch_size = 2

# Instantiate the holder. The former `args = Holder` bound the class itself, so
# every attribute below was written onto the class object.
args = Holder()
args.max_n_segments = max_n_segments
args.target_seq_len = target_seq_len
args.input_seq_len = input_seq_len
args.input_prefix = ''
args.block_size = None

device = 'cpu'

### Decoder

In [235]:
# NOTE(review): star import. It presumably also provides `chain`, which group_texts
# uses but no visible cell imports explicitly. Prefer explicit imports.
from lm_experiments_tools.lm_datasets import *
# Load WikiText-2 (v1) train/validation/test splits (served from the local HF cache here).
raw_datasets = datasets.load_dataset('wikitext', 'wikitext-2-v1')
# train_dataset, _ = get_lm_datasets(raw_datasets, tokenizer, block_size=args.input_seq_len)
# _, valid_dataset = get_lm_datasets(raw_datasets, tokenizer, block_size=input_size - 2 * num_mem_tokens)

Found cached dataset wikitext (/home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


  0%|          | 0/3 [00:00<?, ?it/s]

In [236]:
# Number of new text tokens per chunk: one segment minus its memory-token positions.
block_size = input_size - 2 * num_mem_tokens
# Tokens of preceding context prepended to each training chunk, so a training
# sample spans at most args.input_seq_len tokens (history + block_size new tokens).
history_size = args.input_seq_len - block_size

column_names = raw_datasets["train"].column_names
# Use the "text" column when present, otherwise fall back to the first column.
text_column_name = "text" if "text" in column_names else column_names[0]

def tokenize_function(examples):
    """Tokenize the text column of a batch of rows with the global ``tokenizer``."""
    texts = examples[text_column_name]
    return tokenizer(texts)

# Tokenize every split in batches. The original columns are removed, so only the
# tokenizer outputs remain as columns.
tokenized_datasets = raw_datasets.map(
    tokenize_function,
    batched=True,
    remove_columns=column_names,
    desc="Running tokenizer on dataset",
)

def group_texts(examples, block_size, history_size=None):
    """Concatenate a batch of tokenized rows and re-split it into fixed-size chunks.

    Every column of ``examples`` (e.g. ``input_ids``, ``attention_mask``) is
    concatenated across rows, then cut at offsets ``0, block_size, 2*block_size, ...``.
    The trailing partial chunk is kept, so the last chunk may be shorter than
    ``block_size``.

    Args:
        examples: mapping column name -> list of token lists (one batched
            ``datasets.map`` batch). Must contain ``"input_ids"``; all columns are
            assumed to have equal concatenated length.
        block_size: number of new tokens per chunk; also the stride between chunks.
        history_size: if given, the chunk starting at offset ``i`` is extended
            to the left with up to ``history_size`` preceding tokens. It then covers
            ``[max(0, i - history_size), i + block_size)``.

    Returns:
        dict with the same columns split into chunks, plus ``"labels"``, an
        independent copy of ``"input_ids"``.
    """
    # Explicit import: `chain` was previously only reachable through a star import.
    from itertools import chain

    concatenated_examples = {k: list(chain.from_iterable(t)) for k, t in examples.items()}
    total_length = len(concatenated_examples["input_ids"])

    def window_start(i):
        # Without history a chunk starts at its own offset. With history it reaches
        # back up to `history_size` tokens, clipped at the start of the batch.
        return i if history_size is None else max(0, i - history_size)

    starts = range(0, total_length, block_size)
    result = {
        k: [t[window_start(i): i + block_size] for i in starts]
        for k, t in concatenated_examples.items()
    }
    # Copy each chunk so labels never alias the input_ids lists.
    result["labels"] = [list(ids) for ids in result["input_ids"]]
    return result


# Training chunks carry up to `history_size` tokens of left context, giving samples
# of growing length up to args.input_seq_len. Validation and test chunks are plain
# block_size chunks with no history.
# NOTE(review): with batched=True, `datasets.map` calls group_texts once per batch
# of rows (1000 by default). Concatenation, history and the short trailing chunk are
# therefore per map-batch rather than per split — confirm this is intended.
train_dataset = tokenized_datasets["train"].map(lambda x: group_texts(x, block_size, history_size), 
                                        batched=True, desc=f"Grouping train in chunks of {block_size} and history {history_size}")
valid_dataset = tokenized_datasets["validation"].map(lambda x: group_texts(x, block_size), 
                                        batched=True, desc=f"Grouping valid in chunks of {block_size}")
test_dataset = tokenized_datasets["test"].map(lambda x: group_texts(x, block_size), 
                                        batched=True, desc=f"Grouping test in chunks of {block_size}")

Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-389b922bfc5fe729.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-6067a66e735cfbb1.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-941845a5470f2db7.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-b4e1ff576225fa59.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-3f80d227394578e9.a

In [237]:
len(train_dataset)

19244

In [215]:
len(train_dataset[0]['input_ids']), len(valid_dataset[0]['input_ids']), len(test_dataset[0]['input_ids'])

(126, 126, 126)

In [216]:
len(train_dataset[1]['input_ids']), len(valid_dataset[1]['input_ids']), len(test_dataset[1]['input_ids'])

(252, 126, 126)

In [217]:
len(train_dataset[2]['input_ids']), len(valid_dataset[2]['input_ids']), len(test_dataset[2]['input_ids'])

(378, 126, 126)

In [218]:
len(train_dataset[-1]['input_ids']), len(valid_dataset[-1]['input_ids']), len(test_dataset[-1]['input_ids'])

(385, 37, 8)

In [219]:
from torch.nn.utils.rnn import pad_sequence
id_pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
def collate_fn(batch):
    """Left-pad a list of examples into batch tensors for the RMT model.

    Each example's fields are padded on the LEFT so all sequences end aligned:
        * ``input_ids`` with ``id_pad_value``;
        * ``labels`` with ``-100``;
        * ``attention_mask`` with ``0``.
    When the padded length differs from the global ``block_size``, a boolean
    ``labels_mask`` is added. It is True only on the last ``block_size`` positions.
    """
    def left_pad(key, pad_value):
        # Reverse each sequence, right-pad batch-first, then flip back = left padding.
        reversed_seqs = [torch.tensor(example[key][::-1]) for example in batch]
        return pad_sequence(reversed_seqs, batch_first=True, padding_value=pad_value).flip(1)

    input_ids = left_pad('input_ids', id_pad_value)
    collated = {
        'input_ids': input_ids,
        'labels': left_pad('labels', -100),
        'attention_mask': left_pad('attention_mask', 0),
    }

    if input_ids.shape[1] != block_size:
        # Only the trailing block_size tokens (the new, non-history part) are scored.
        labels_mask = torch.ones_like(input_ids, dtype=torch.bool)
        labels_mask[:, :-block_size] = False
        collated['labels_mask'] = labels_mask

    return collated

In [220]:
# dataloader for RMT
# batch sample i is a continuation of sample i of the previous batch
class alignedDataLoader(DataLoader):
    """DataLoader in which sample ``i`` of each batch continues sample ``i`` of the
    previous batch (needed to carry RMT memory across batches).

    How batches are formed:
        * The first ``len(dataset) // batch_size * batch_size`` indices are laid
          out row-major as a ``(batch_size, n_batches)`` grid.
        * Batch ``j`` is column ``j``, so each row walks a contiguous stretch of
          the dataset.
        * Trailing samples that do not fill a whole row are dropped.

    NOTE: iteration runs in the calling process. The base DataLoader's sampler,
    ``num_workers`` and ``pin_memory`` settings are not used.
    """

    def _aligned_indices(self):
        # One row per batch position; contiguous dataset indices along each row.
        n_usable = len(self.dataset) // self.batch_size * self.batch_size
        return np.arange(n_usable).reshape(self.batch_size, -1)

    def __iter__(self):
        all_inds = self._aligned_indices()
        for batch_ind in range(all_inds.shape[1]):
            batch = [self.dataset[int(ind)] for ind in all_inds[:, batch_ind]]
            yield self.collate_fn(batch)

    def __len__(self):
        # Number of batches __iter__ actually yields. The inherited __len__ rounds up
        # when drop_last=False, which would not match.
        return len(self.dataset) // self.batch_size


# DataLoader options shared by both loaders. The name `loader_kwargs` (not `kwargs`)
# avoids clobbering the `kwargs` dict of model inputs used by the forward-pass cells.
loader_kwargs = {'pin_memory': True, 'num_workers': 1}

# Plain loader over the history-augmented training chunks (shuffle defaults to False).
train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                              collate_fn=collate_fn, **loader_kwargs)

# Aligned loader: sample i of each validation batch continues sample i of the
# previous batch.
valid_dataloader = alignedDataLoader(valid_dataset, batch_size=batch_size,
                                     collate_fn=collate_fn, drop_last=True, **loader_kwargs)

In [221]:
src = [train_dataset[i] for i in range(2)]

In [223]:
gen = iter(train_dataloader)
batch = next(gen)

In [238]:
batch = next(gen)
batch = next(gen)
batch = next(gen)


In [31]:
valid_gen = iter(valid_dataloader)
valid_batch = next(valid_gen)

In [36]:
valid_batch['input_ids'].shape

torch.Size([2, 108])

In [26]:
batch['labels']

tensor([[ -100,  -100,  -100,  ...,   416,  5609,   511],
        [  796,   569, 18354,  ...,  2478,  2233,   284]])

In [27]:
batch['labels'][:, 1004:]

tensor([[  796,   569, 18354,  ...,   416,  5609,   511],
        [ 8686,  4282,   764,  ...,  2478,  2233,   284]])

In [25]:
batch['labels_mask'][:, 1004:]

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])

In [119]:
tokenizer.decode(batch['input_ids'][:, 1004:][0])

' = Valkyria Chronicles III = \n Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3, lit. Valkyria of the Battlefield 3 ), commonly referred to as Valkyria Chronicles III outside Japan, is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable. Released in January 2011 in Japan, it is the third game in the Valkyria series. <unk> the same fusion of tactical and real @-@ time gameplay as its predecessors, the story runs parallel to the first game and follows the " Nameless ", a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven ". \n The game began development in 2010, carrying over a large portion of the work done on Valkyria Chronicles II. While it retained the standard features of the series, it also underwent multiple adjustments, such as making the game more <unk> for series newcomers. Character designer <

In [53]:
batch = next(gen)


In [54]:
batch['input_ids']

tensor([[8686, 4282,  764,  ...,  262, 2656,  837],
        [ 262, 3859, 1445,  ..., 3670,  373, 7867]])

In [162]:
gen = iter(valid_dataloader)
batch = next(gen)

In [163]:
tokenizer.decode(batch['input_ids'][0][-50:])

'-@ eastern Atlantic Ocean from northern Norway to the Azores and Morocco, not including the Baltic Sea. It is also present in most of the Mediterranean Sea, only missing from the section east of Crete, and along only the north @-@'

In [164]:
tokenizer.decode(batch['input_ids'][1][-50:])

' the <unk> and <unk> powers of nearly every one, <unk> their opponents... \n The side left <unk> and travelled north where they played <unk> in Masterton. The match was won 10 – 8, and'

In [165]:
batch = next(gen)

In [166]:
tokenizer.decode(batch['input_ids'][0][:50])

' west coast of the Black Sea. The <unk> populations are found in the Norwegian <unk> <unk> and <unk>, inside the Arctic Circle. \n The species can be divided into four genetically distinct populations, one widespread population,'

In [168]:
tokenizer.decode(batch['input_ids'][1][:50])

' the next day they faced Wellington, who they also defeated. The fixture against Wellington was nearly abandoned because Scott and the Wellington Rugby Union could not agree on a venue ; the match went ahead only when the Wellington officials agreed to cede the <unk>'

In [21]:
batch.keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

In [127]:
batch['input_ids'][0][:10]

tensor([  796,   569, 18354,  7496, 17740,  6711,   796,   220,   198,  2311])

In [128]:
batch['labels'][0][:10]

tensor([  796,   569, 18354,  7496, 17740,  6711,   796,   220,   198,  2311])

In [130]:
train_dataset[0].keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

In [131]:
train_dataset[0]['input_ids'][:10]

[796, 569, 18354, 7496, 17740, 6711, 796, 220, 198, 2311]

In [132]:
train_dataset[0]['labels'][:10]

[796, 569, 18354, 7496, 17740, 6711, 796, 220, 198, 2311]

In [18]:
batch['input_ids'].shape

torch.Size([2, 2008])

In [19]:
tokenizer.decode(batch['input_ids'][0])

' = Valkyria Chronicles III = \n Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3, lit. Valkyria of the Battlefield 3 ), commonly referred to as Valkyria Chronicles III outside Japan, is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable. Released in January 2011 in Japan, it is the third game in the Valkyria series. <unk> the same fusion of tactical and real @-@ time gameplay as its predecessors, the story runs parallel to the first game and follows the " Nameless ", a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven ". \n The game began development in 2010, carrying over a large portion of the work done on Valkyria Chronicles II. While it retained the standard features of the series, it also underwent multiple adjustments, such as making the game more <unk> for series newcomers. Character designer <

In [77]:
for batch in valid_dataloader:
    if batch['input_ids'].shape[0] != 2:
        print(batch['input_ids'].shape)

In [13]:
batch.keys(), batch['input_ids'].shape

(dict_keys(['input_ids', 'attention_mask', 'labels']), torch.Size([2, 1014]))

In [21]:
out = base_model(**batch)

KeyboardInterrupt: 