In [2]:
import json
import math
import os
from typing import List, Optional, Tuple, Union

import datasets
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from matplotlib import pyplot as plt
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import BertForSequenceClassification
from transformers import PreTrainedModel, AutoModelForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput

In [3]:
# Set CUDA_VISIBLE_DEVICES to "1" for this process.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

### Finetune

In [4]:
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import sys
# Add the parent directory to the module search path.
sys.path.append('..')

### load dataset 

In [5]:
class Holder:
    """Empty attribute container; settings are attached to it after creation."""

In [6]:
# Length settings used by the data pipeline below.
input_seq_len = 512
target_seq_len = 512

# The block-size computation later subtracts 2 * num_mem_tokens from input_size.
num_mem_tokens = 2
input_size = 128

batch_size = 2

# Instantiate the holder: assigning the bare class (`args = Holder`) would turn
# every setting into a class attribute shared by all Holder instances.
args = Holder()
args.target_seq_len = target_seq_len
args.input_seq_len = input_seq_len
args.num_mem_tokens = num_mem_tokens
args.input_size = input_size
args.input_prefix = ''
args.block_size = None
args.task_name = 'wikitext-2-v1'

device = 'cpu'

In [7]:
# Load the tokenizer that belongs to the checkpoint named by `model_name`.
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)



In [8]:
from itertools import chain

# Load the WikiText configuration selected by args.task_name.
raw_datasets = datasets.load_dataset('wikitext', args.task_name)
column_names = raw_datasets["train"].column_names
# Use the column called "text" when present, otherwise fall back to the first column.
text_column_name = "text" if "text" in column_names else column_names[0]

def tokenize_function(examples):
    """Run the notebook-level tokenizer over the text column of a batch of rows."""
    texts = examples[text_column_name]
    return tokenizer(texts)

# Tokenize all splits in batches; the original columns are removed from the result.
tokenized_datasets = raw_datasets.map(
    tokenize_function,
    batched=True,
    remove_columns=column_names,
    desc="Running tokenizer on dataset",
)

# Tokens left per segment after reserving room for memory tokens.
block_size = args.input_size 
if args.num_mem_tokens is not None:
    # MemoryCell concatenates the memory both before and after the segment, hence the factor 2.
    block_size -= 2 * args.num_mem_tokens
# Number of preceding tokens that group_texts prepends to each training sample.
history_size = args.input_seq_len - block_size

def group_texts(examples, block_size, history_size=None):
    """Concatenate every column of a batch and cut it into windows.

    Windows start every ``block_size`` tokens and end ``block_size`` tokens after
    their start; the final window may be shorter. When ``history_size`` is given,
    each window additionally reaches back up to ``history_size`` tokens (never
    before position 0).

    Args:
        examples: mapping from column name to a list of token lists; must contain
            ``"input_ids"``.
        block_size: stride between window starts and length of the new part of a window.
        history_size: optional number of preceding tokens to include in each window.

    Returns:
        Dict with the windowed columns plus ``"labels"``, a shallow copy of the
        windowed ``"input_ids"`` list.
    """
    flat = {}
    for key in examples.keys():
        flat[key] = list(chain.from_iterable(examples[key]))
    total_length = len(flat[list(examples.keys())[0]])

    # No history is the same as looking back zero tokens.
    lookback = 0 if history_size is None else history_size
    windows = [
        (max(0, start - lookback), start + block_size)
        for start in range(0, total_length, block_size)
    ]

    result = {key: [tokens[lo:hi] for lo, hi in windows] for key, tokens in flat.items()}
    result["labels"] = result["input_ids"].copy()
    return result

Found cached dataset wikitext (/home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-389b922bfc5fe729.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-6067a66e735cfbb1.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-941845a5470f2db7.arrow


In [None]:
# Padding id for input_ids: the tokenizer's pad token if it has one, otherwise its EOS token.
id_pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

In [None]:
# Display the derived block size and history length.
block_size, history_size

(124, 388)

In [None]:
from torch.nn.utils.rnn import pad_sequence

# def collate_fn(batch):
#     input_ids = [torch.tensor(b['input_ids'][::-1]) for b in batch]
#     labels = [torch.tensor(b['labels'][::-1]) for b in batch]
#     attention_mask = [torch.tensor(b['attention_mask'][::-1]) for b in batch]
#     input_ids = pad_sequence(input_ids, padding_value=id_pad_value).T.flip(1)
#     labels = pad_sequence(labels, padding_value=-100).T.flip(1)
#     attention_mask = pad_sequence(attention_mask, padding_value=0).T.flip(1)

#     collated = {'input_ids': input_ids,
#                 'labels': labels, 
#                 'attention_mask': attention_mask}

#     if input_ids.shape[1] != block_size:
#         labels_mask = torch.ones_like(input_ids, dtype=bool)
#         labels_mask[:, :-block_size] = False
#         collated['labels_mask'] = labels_mask

#     return collated


def collate_fn(batch):
    """Right-pad a list of tokenized samples into batch tensors of equal length.

    Args:
        batch: list of dicts with ``'input_ids'``, ``'labels'`` and
            ``'attention_mask'`` sequences.

    Returns:
        Dict of tensors with the batch in the first dimension:
        ``input_ids`` (padded with ``id_pad_value``), ``labels`` (padded with
        -100), ``labels_mask`` (True on real label positions, False on padding)
        and ``attention_mask`` (padded with 0).
    """
    def _pad(seqs, fill):
        # pad_sequence stacks the batch along dim 1 by default; transpose to batch-first.
        return pad_sequence(seqs, padding_value=fill).T

    ids = [torch.tensor(sample['input_ids']) for sample in batch]
    targets = [torch.tensor(sample['labels']) for sample in batch]
    target_flags = [torch.ones_like(t, dtype=bool) for t in targets]
    attn = [torch.tensor(sample['attention_mask']) for sample in batch]

    return {
        'input_ids': _pad(ids, id_pad_value),
        'labels': _pad(targets, -100),
        'labels_mask': _pad(target_flags, False),
        'attention_mask': _pad(attn, 0),
    }


# Training samples carry up to history_size preceding tokens; validation samples are plain blocks.
train_dataset = tokenized_datasets["train"].map(lambda x: group_texts(x, block_size, history_size), 
                                        batched=True, desc=f"Grouping train in chunks of {block_size} and history {history_size}")
valid_dataset = tokenized_datasets["validation"].map(lambda x: group_texts(x, block_size), 
                                        batched=True, desc=f"Grouping valid in chunks of {block_size}")


# shuffle train data each epoch (one loop over train_dataset)
# train_sampler = DistributedSampler(train_dataset, rank=hvd.rank(), num_replicas=hvd.size(), shuffle=True,
#                                     drop_last=False, seed=args.seed)
# per_worker_batch_size = args.batch_size * args.gradient_accumulation_steps
# global_batch_size = per_worker_batch_size * hvd.size()

# train_sampler = RandomSampler(train_dataset)
# No sampler is passed, so batches are built with the DataLoader defaults and collate_fn.
kwargs = {'pin_memory': True}#, 'num_workers': args.data_n_workers}
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, **kwargs)

Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-1504f9373e317eca.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-c6da793e710ea6d8.arrow


In [None]:
# Take the first four grouped training samples for a manual collate check.
b = [train_dataset[i] for i in range(4)]

In [None]:
# Cut every field of the first sample to 124 items (presumably to give the batch unequal lengths).
for k in b[0]:
    b[0][k] = b[0][k][:124]

In [None]:
# Collate the hand-picked samples.
batch = collate_fn(b)

In [None]:
# Advance to the third batch of the loader and display the shape of its input_ids.
gen = iter(train_dataloader)
batch = next(gen)
batch = next(gen)
batch = next(gen)
batch['input_ids'].shape

torch.Size([2, 512])

In [None]:
# raw_datasets['train'][1]

### Model

In [None]:
import math
import torch
from torch.nn import CrossEntropyLoss

In [None]:
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

class MemoryCell(torch.nn.Module):
    """Wrap a language model so that every input segment is framed by memory embeddings.

    ``num_mem_tokens`` vectors are concatenated before and after the segment's token
    embeddings. After the base model has run, the last-layer hidden states at the
    trailing memory positions are returned as the next memory state and the memory
    positions are cut from the logits.
    """

    def __init__(self, base_model, num_mem_tokens):
        """
        Args:
            base_model: model exposing ``get_input_embeddings()`` that is called with
                ``inputs_embeds`` / ``attention_mask`` / ``output_hidden_states``.
            num_mem_tokens: number of memory vectors; ``0`` or ``None`` disables memory.
        """
        super().__init__()
        self.model = base_model
        self.create_memory(num_mem_tokens)

    def create_memory(self, num_mem_tokens):
        """Create the trainable initial memory, scaled like the input embeddings."""
        self.num_mem_tokens = num_mem_tokens
        # Treat None like 0 so a memory-less cell can be built (the other methods
        # already accept both values as "no memory").
        n_mem = num_mem_tokens or 0
        embeddings = self.model.get_input_embeddings()
        # Memory is concatenated with the input embeddings, so its width must equal
        # the embedding width; read it from the matrix instead of a model-specific
        # config field such as GPT-2's `n_embd`.
        embed_dim = embeddings.weight.shape[-1]
        memory_weights = torch.randn((n_mem, embed_dim)) * embeddings.weight.data.std()
        self.register_parameter('memory', torch.nn.Parameter(memory_weights, requires_grad=True))

        self.read_memory_position = range(n_mem)
        self.write_memory_position = range(-n_mem, 0)

    def set_memory(self, input_shape):
        """Return the initial memory repeated ``input_shape[0]`` times along a new first dim."""
        memory = self.memory.repeat(input_shape[0], 1, 1)
        return memory

    def forward(self, input_ids, memory_state=None, **kwargs):
        """Run one segment through the base model.

        Args:
            input_ids: token ids of the segment; the first dimension is used as batch size.
            memory_state: memory produced by the previous segment, or ``None`` to start
                from the learned initial memory.
            **kwargs: forwarded to the base model (``attention_mask`` is extended to
                cover the memory positions).

        Returns:
            ``(out, new_memory_state)``; see ``process_output``.
        """
        if memory_state is None:
            memory_state = self.set_memory(input_ids.shape)

        seg_kwargs = self.process_input(input_ids, memory_state, **kwargs)
        out = self.model(**seg_kwargs)
        out, new_memory_state = self.process_output(out, **kwargs)

        return out, new_memory_state

    def process_input(self, input_ids, memory_state, **kwargs):
        """Build the keyword arguments for the base model call.

        The segment embeddings (taken from ``kwargs['inputs_embeds']`` or looked up
        from ``input_ids``) are placed between two copies of ``memory_state``.
        ``input_ids`` is set to ``None`` and hidden states are always requested,
        because the new memory is read from them.
        """
        seg_kwargs = dict(**kwargs)

        inputs_embeds = kwargs.get('inputs_embeds')
        if inputs_embeds is None:
            inputs_embeds = self.model.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat([memory_state, inputs_embeds, memory_state], dim=1)

        seg_kwargs['input_ids'] = None
        seg_kwargs['inputs_embeds'] = inputs_embeds
        if kwargs.get('attention_mask') is not None:
            seg_kwargs['attention_mask'] = self.pad_attention_mask(kwargs['attention_mask'], inputs_embeds.shape)
        seg_kwargs['output_hidden_states'] = True
        return seg_kwargs

    def pad_attention_mask(self, attention_mask, shape):
        """Extend ``attention_mask`` so both memory blocks are attended to.

        Args:
            attention_mask: mask for the segment tokens only.
            shape: shape of the memory-augmented embeddings; its first two entries
                give the size of the returned mask.

        Returns:
            The input mask unchanged when memory is disabled, otherwise an int64 mask
            with ones at the memory positions and ``attention_mask`` in between.
        """
        if self.num_mem_tokens in {0, None}:
            return attention_mask
        mask = torch.ones(*shape[:2], dtype=torch.int64, device=attention_mask.device)
        mask[:, self.num_mem_tokens:-self.num_mem_tokens] = attention_mask
        return mask

    def process_output(self, model_outputs, **kwargs):
        """Split the base-model output into segment output and new memory state.

        With memory enabled, the new memory is the last-layer hidden state at the
        final ``num_mem_tokens`` positions, and logits (plus hidden states, when
        ``output_hidden_states`` was requested by the caller) are cropped to the
        segment positions. With memory disabled the model output is returned
        untouched and the memory state is ``None``.
        """
        if self.num_mem_tokens not in {0, None}:
            out = CausalLMOutputWithCrossAttentions()
            memory_state = model_outputs.hidden_states[-1][:, -self.num_mem_tokens:]
            out['logits'] = model_outputs.logits[:, self.num_mem_tokens:-self.num_mem_tokens]

            if kwargs.get('output_hidden_states'):
                out['hidden_states'] = [lh[:, self.num_mem_tokens:-self.num_mem_tokens] for lh in model_outputs.hidden_states]
            if kwargs.get('output_attentions'):
                out['attentions'] = model_outputs['attentions']
        else:
            memory_state = None
            out = model_outputs

        return out, memory_state


In [None]:
class RecurrentWrapper(torch.nn.Module):
    """Run a memory cell over consecutive segments of a long input.

    The input tensors are split along dim 1, the cell is called once per segment
    with the memory state returned by the previous call, and the per-segment
    outputs are concatenated back into one output.

    Recognised ``rmt_kwargs``:
        segment_size: maximum segment length (required).
        segment_alignment: ``'left'`` (default), ``'right'`` or ``'center'``.
        k2, max_n_segments: gradient-truncation settings, see ``manage_gradients``.
    """

    def __init__(self, memory_cell, **rmt_kwargs):
        super().__init__()
        self.memory_cell = memory_cell
        self.rmt_config = rmt_kwargs

    def forward(self, input_ids, labels=None, labels_mask=None, inputs_embeds=None, attention_mask=None, output_attentions=None, output_hidden_states=None):
        """Process the whole input segment by segment and assemble the result.

        Returns:
            Output object with ``logits`` for all segments, ``loss`` when ``labels``
            is given, and ``hidden_states`` when ``output_hidden_states`` is truthy.
        """
        memory_state = None
        segmented = self.segment(input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask)

        cell_outputs = []
        for seg_num, segment in enumerate(segmented):
            cell_out, memory_state = self.memory_cell(**segment, memory_state=memory_state, output_hidden_states=True)
            cell_outputs.append(cell_out)
            # Actually apply the truncation decision: the detached tensor has to
            # replace the state that is fed to the next segment.
            memory_state = self._maybe_detach(memory_state, seg_num)

        out = self.process_outputs(cell_outputs, labels=labels,
                                   labels_mask=labels_mask,
                                   output_attentions=output_attentions,
                                   output_hidden_states=output_hidden_states)
        return out

    def segment(self, **kwargs):
        """Split every non-None tensor argument and regroup the pieces per segment.

        Returns:
            List with one dict per segment, mapping argument name to its slice.
        """
        segments = []
        for k, tensor in kwargs.items():
            if tensor is not None:
                k_segments = self.split_tensor(tensor)
                for s, k_seg in enumerate(k_segments):
                    if s < len(segments):
                        segments[s][k] = k_seg
                    else:
                        segments.append({k: k_seg})

        return segments

    def split_tensor(self, tensor):
        """Split ``tensor`` along dim 1 into pieces of at most ``segment_size``.

        ``'left'`` (or no alignment) leaves the short piece at the end, ``'right'``
        leaves it at the start, ``'center'`` uses ``torch.chunk`` over
        ``ceil(length / segment_size)`` chunks.

        Raises:
            ValueError: if ``segment_size`` is not configured.
            NotImplementedError: for an unknown alignment.
        """
        align = self.rmt_config.get('segment_alignment')
        segment_size = self.rmt_config.get('segment_size')
        if segment_size is None:
            raise ValueError("RecurrentWrapper requires a 'segment_size' setting to split inputs")

        length = tensor.shape[1]
        if align in {'left', None}:
            split_inds = list(range(0, length, segment_size)) + [length]
            segments = [tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])]
        elif align == 'right':
            split_inds = (list(range(length, 0, -segment_size)) + [0])[::-1]
            segments = [tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])]
        elif align == 'center':
            n_seg = math.ceil(length / segment_size)
            segments = torch.chunk(tensor, n_seg, dim=1)
        else:
            raise NotImplementedError
        return segments

    def process_outputs(self, cell_outputs, **kwargs):
        """Concatenate per-segment outputs and compute the language-modelling loss.

        Keyword Args:
            labels: target ids aligned with the input; position ``t`` of the logits
                is scored against ``labels[t + 1]``.
            labels_mask: optional boolean mask over input positions; only logits at
                masked-in positions (all but the last one) contribute to the loss.
            output_hidden_states: when truthy, concatenated hidden states are returned.
        """
        out = CausalLMOutputWithCrossAttentions()
        full_logits = torch.cat([o.logits for o in cell_outputs], dim=1)

        labels = kwargs.get('labels')
        if labels is not None:
            shift_labels = labels[..., 1:].contiguous()
            shift_logits = full_logits[..., :-1, :].contiguous()
            flat_labels = shift_labels.view(-1)
            flat_logits = shift_logits.view(-1, shift_logits.size(-1))

            loss_fct = CrossEntropyLoss()
            labels_mask = kwargs.get('labels_mask')
            if labels_mask is not None:
                # The mask refers to the position that makes the prediction, so it is
                # cut like the logits; .bool() keeps this a mask even for 0/1 integers.
                flat_mask = labels_mask[..., :-1].contiguous().view(-1).bool()

                flat_labels = flat_labels[flat_mask]
                flat_logits = flat_logits[flat_mask]

            out['loss'] = loss_fct(flat_logits, flat_labels)

        out['logits'] = full_logits
        if kwargs.get('output_hidden_states'):
            # One tensor per layer, concatenated over the segments.
            out['hidden_states'] = tuple(
                torch.cat(layer_hs, dim=1)
                for layer_hs in zip(*[o.hidden_states for o in cell_outputs])
            )

        return out

    def manage_gradients(self, memory_state, seg_num):
        """Decide whether gradients may flow through the memory after ``seg_num``.

        Returns:
            ``True`` to keep the graph: for the first segment, when ``k2`` is unset or
            -1, when ``max_n_segments`` is unset, or when ``seg_num + k2`` exceeds
            ``max_n_segments``. ``False`` means the memory state should be detached.
        """
        k2, max_n_segments = self.rmt_config.get('k2'), self.rmt_config.get('max_n_segments')
        if seg_num == 0 or k2 in {-1, None} or max_n_segments is None:
            return True
        return seg_num + k2 > max_n_segments

    def _maybe_detach(self, memory_state, seg_num):
        """Return ``memory_state``, detached when ``manage_gradients`` asks for truncation."""
        if memory_state is None or self.manage_gradients(memory_state, seg_num):
            return memory_state
        return memory_state.detach()

In [None]:

# NOTE(review): this value (10) is not what the cell below is built with — MemoryCell
# receives num_mem_tokens=2 as a literal. Confirm which one is intended.
num_mem_tokens = 10
# device = torch.device(3)
device = 'cpu'

# NOTE(review): rmt_config is not passed to MemoryCell or RecurrentWrapper in the cells shown here.
rmt_config = {'num_mem_tokens': 10, 
               #  'max_n_segments': 1,
               #  'segment_alignment': 'right',
               #  'tokenizer': tokenizer,
               #  'memory_layers': 'all', 
               #  'share_memory_layers': True,
               #  'reconstruction_loss_coef': 0.1,
               #  'k1': -1, 'k2': 3,
               #  'segment_ordering': 'regular',
               #  'input_size': 1024, 
               #  'bptt_depth': -1, 
               #  'sum_loss': False,
             }

# Load the pretrained causal LM and wrap it in a memory cell with two memory tokens.
base_model = AutoModelForCausalLM.from_pretrained(model_name)
cell = MemoryCell(base_model, num_mem_tokens=2)



In [None]:
# segment_size=124 equals the block_size shown earlier; 'center' alignment splits with torch.chunk.
rmt = RecurrentWrapper(cell, max_n_segments=5, segment_size=124, segment_alignment='center')

In [None]:
# Take the third batch and drop labels_mask, so the wrapper computes the loss without a mask.
gen = iter(train_dataloader)
batch = next(gen)
batch = next(gen)
batch = next(gen)
batch.pop('labels_mask')
# batch.pop('labels')
# Bare expression: it is the value this cell displays.
1

1

In [None]:
# Forward pass of the recurrent wrapper on the language-modelling batch.
rmt_out = rmt(**batch)

In [None]:
# List the fields present in the wrapper output.
rmt_out.keys()

odict_keys(['loss', 'logits'])

# Contract-NLI

In [None]:
from datasets import Dataset

class CNLIDataset(Dataset):
    """Indexable view over a JSON file whose top-level value holds the samples."""

    def __init__(self, json_path):
        """Load all samples from ``json_path`` into memory.

        Args:
            json_path: path of the JSON file to read.
        """
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(json_path, 'r', encoding='utf-8') as f:
            self.samples = json.load(f)

    def __getitem__(self, item):
        """Return the sample stored under index/key ``item``."""
        return self.samples[item]

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

In [None]:
# NOTE(review): absolute, machine-specific path — consider reading it from a configurable data directory.
dataset = CNLIDataset('/cephfs/home/bulatov/bulatov/datasets/contract_nli/contract-nli/test_processed.json')

[Parallel(n_jobs=12)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=12)]: Done   1 tasks      | elapsed:    0.0s
[Parallel(n_jobs=12)]: Batch computation too fast (0.0032s.) Setting batch_size=2.
[Parallel(n_jobs=12)]: Done   8 tasks      | elapsed:    0.0s
[Parallel(n_jobs=12)]: Done  17 tasks      | elapsed:    0.0s
[Parallel(n_jobs=12)]: Batch computation too fast (0.0160s.) Setting batch_size=4.
[Parallel(n_jobs=12)]: Done  28 tasks      | elapsed:    0.0s
[Parallel(n_jobs=12)]: Done  50 tasks      | elapsed:    0.0s
[Parallel(n_jobs=12)]: Done 123 out of 123 | elapsed:    0.1s finished


In [None]:
# Display the configured input sequence length.
args.input_seq_len

512

In [None]:
from torch.nn.utils.rnn import pad_sequence

# Reuse EOS as the padding token and register the [GEN] separator that marks
# where the answer starts.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_tokens('[GEN]', special_tokens=True)
gen_token = tokenizer.encode('[GEN]')[0]
id_pad_value = tokenizer.eos_token_id

# The new token id lies past the end of the pretrained embedding table; grow the
# table here so a top-to-bottom run does not fail on the first forward pass that
# contains [GEN].
rmt.memory_cell.model.resize_token_embeddings(len(tokenizer))

def collate_fn(batch):
    """Build ``context [GEN] answer`` training sequences for Contract-NLI samples.

    Args:
        batch: list of dicts with string fields ``'context'`` and ``'answer'``.

    Returns:
        Dict of batch-first tensors:
        ``input_ids`` / ``labels`` (the same tensor, right-padded with
        ``id_pad_value``), ``labels_mask`` (True on ``[GEN]`` and the answer
        tokens) and ``attention_mask`` (True on every real token).
    """
    # Character-level pre-truncation before tokenizing.
    inputs = [b['context'][:args.input_seq_len * 10] for b in batch]
    labels = [b['answer'][:args.input_seq_len * 10] for b in batch]

    inputs = tokenizer.batch_encode_plus(list(inputs), padding=False)
    labels = tokenizer.batch_encode_plus(list(labels), padding=False)

    full_inputs, labels_mask, attention_mask = [], [], []
    for inp, lab in zip(inputs['input_ids'], labels['input_ids']):
        # Clamp at 0: a negative bound would slice from the end and keep most of
        # the context instead of dropping it.
        budget = max(0, args.input_seq_len - len(lab) - 1)
        inp = inp[:budget]
        full_inputs.append(torch.tensor(inp + [gen_token] + lab))
        labels_mask.append(torch.tensor([False] * len(inp) + [True] * (len(lab) + 1)))
        # Derive the mask from the true length: the pad id is the EOS id, so
        # comparing against it would also hide genuine EOS tokens.
        attention_mask.append(torch.ones(len(inp) + len(lab) + 1, dtype=torch.bool))

    full_inputs = pad_sequence(full_inputs, padding_value=id_pad_value).T
    labels_mask = pad_sequence(labels_mask, padding_value=False).T
    attention_mask = pad_sequence(attention_mask, padding_value=False).T

    collated = {}
    collated['input_ids'] = collated['labels'] = full_inputs
    collated['labels_mask'] = labels_mask
    collated['attention_mask'] = attention_mask
    return collated

In [None]:
# Collate the first four Contract-NLI samples.
batch = collate_fn([dataset[i] for i in range(4)])
# batch = [dataset[i] for i in range(4)]

In [None]:
# Split every tensor of the batch into segments.
segmented = rmt.segment(**batch)

In [None]:
# Display the input_ids shape of each segment.
[s['input_ids'].shape for s in segmented]

[torch.Size([4, 103]),
 torch.Size([4, 103]),
 torch.Size([4, 103]),
 torch.Size([4, 103]),
 torch.Size([4, 100])]

In [None]:
# Take labels out of the batch: without them the wrapper computes no loss.
labels = batch.pop('labels')
rmt_out = rmt(**batch)

In [None]:
# Display the shape of the concatenated logits.
rmt_out.logits.shape

torch.Size([4, 512, 50258])

In [140]:
# Select the logits at positions flagged by labels_mask and display the resulting shape.
rmt_out.logits[batch['labels_mask']].shape

torch.Size([20, 50258])

In [142]:
# Same selection applied to the labels.
labels[batch['labels_mask']].shape

torch.Size([20])

In [143]:
# Decode a hand-entered list of token ids back to text.
tokenizer.decode([ 3673,    44,  1463,   276, 50256])

'NotMentioned<|endoftext|>'

In [144]:
# Decode a second hand-entered list of token ids.
tokenizer.decode([14539,   603,   434])

'Entailment'

In [None]:
# rmt_out = rmt(**batch)

# Step through RecurrentWrapper.forward by hand; `self` is bound to the wrapper
# so the method body can be pasted unchanged.
self = rmt
input_ids = batch['input_ids']

memory_state = None
# Only input_ids is segmented here; no embeddings or attention mask are passed.
segmented = self.segment(input_ids=input_ids, inputs_embeds=None, attention_mask=None)

cell_outputs = []
for seg_num, segment in enumerate(segmented):
    # The memory state returned for one segment is fed into the next call.
    cell_out, memory_state = self.memory_cell(**segment, memory_state=memory_state, output_hidden_states=True)#**batch)
    cell_outputs.append(cell_out)
    # The return value is discarded here.
    self.manage_gradients(memory_state, seg_num)

# out = self.process_outputs(cell_outputs)

In [None]:
# Display the shape of the ids used in the manual run.
input_ids.shape

torch.Size([2, 248])

In [None]:
# Check how torch.chunk divides 10 elements when asked for 3 chunks.
torch.chunk(torch.ones(10), 3)

(tensor([1., 1., 1., 1.]), tensor([1., 1., 1., 1.]), tensor([1., 1.]))

### seq2seq

In [28]:
from huggingface_hub import hf_hub_download
# scrolls_metric_path = hf_hub_download(repo_id="datasets/tau/scrolls", filename="metrics/scrolls.py")

# Register the [GEN] separator. This call had been fused onto the end of the
# comment above and therefore never ran.
tokenizer.add_tokens('[GEN]', special_tokens=True)
gen_token = tokenizer.encode('[GEN]')[0]

# Grow the embedding table to cover the added token; the returned embedding
# module is the value this cell displays.
rmt.memory_cell.model.resize_token_embeddings(len(tokenizer))

# gen_token = tokenizer.eos_token_id

Embedding(50258, 768)

In [29]:
# Display the tokenizer's special-token mapping.
tokenizer.special_tokens_map

{'bos_token': '<|endoftext|>',
 'eos_token': '<|endoftext|>',
 'unk_token': '<|endoftext|>'}

In [30]:
global_attention_first_token = False  # should be True for LED
encode_plus_kwargs = {'truncation': True, 'padding': 'longest', 'pad_to_multiple_of': 1}
# generate_kwargs = {'max_length': args.target_seq_len, 'min_length': args.target_seq_len}
generate_kwargs = {}

from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """Build ``input [GEN] output`` training sequences for the seq2seq samples.

    Args:
        batch: list of dicts with string fields ``'input'`` and ``'output'``.

    Returns:
        Dict of batch-first tensors:
        ``input_ids`` / ``labels`` (the same tensor, right-padded with the
        tokenizer's pad id), ``labels_mask`` (True on ``[GEN]`` and the output
        tokens) and ``attention_mask`` (True on every real token).
    """
    # Character-level pre-truncation before tokenizing.
    inputs = [b['input'][:args.input_seq_len * 10] for b in batch]
    labels = [b['output'][:args.input_seq_len * 10] for b in batch]

    inputs = tokenizer.batch_encode_plus(list(inputs), padding=False)
    labels = tokenizer.batch_encode_plus(list(labels), padding=False)

    full_inputs, labels_mask, attention_mask = [], [], []
    for i, l in zip(inputs['input_ids'], labels['input_ids']):
        # Clamp at 0: a negative bound would slice from the end and keep most of
        # the prompt instead of dropping it.
        prompt = i[:max(0, input_size - len(l) - 1)]
        full_inputs.append(torch.tensor(prompt + [gen_token] + l))
        # Build the masks per sample, before padding. Sequences are padded on the
        # right, so marking the last len(l) + 1 positions of the padded row would
        # flag padding instead of the answer for every shorter sample.
        labels_mask.append(torch.tensor([False] * len(prompt) + [True] * (len(l) + 1)))
        # The pad id may coincide with EOS, so the attention mask comes from the
        # true length rather than from a comparison with the pad id.
        attention_mask.append(torch.ones(len(prompt) + len(l) + 1, dtype=torch.bool))

    full_inputs = pad_sequence(full_inputs, padding_value=tokenizer.pad_token_id).T
    labels_mask = pad_sequence(labels_mask, padding_value=False).T
    attention_mask = pad_sequence(attention_mask, padding_value=False).T

    collated = {}
    collated['input_ids'] = collated['labels'] = full_inputs
    collated['labels_mask'] = labels_mask
    collated['attention_mask'] = attention_mask

    return collated
    

In [31]:
# Load the chosen SCROLLS configuration and batch its train split with the seq2seq collate_fn.
seq2seq_task_name = 'quality'
dataset = datasets.load_dataset('tau/scrolls', seq2seq_task_name)
train_dataset = dataset['train']
# shuffle train data each epoch (one loop over train_dataset)
# train_sampler = DistributedSampler(train_dataset, shuffle=True, drop_last=False, seed=args.seed)
kwargs = {'pin_memory': True}
train_dataloader = DataLoader(train_dataset, batch_size=2,
                                collate_fn=collate_fn, **kwargs)

Found cached dataset scrolls (/home/bulatov/.cache/huggingface/datasets/tau___scrolls/quality/1.0.0/672021d5d8e1edff998a6ea7a5bff35fdfd0ae243e7cf6a8c88a57a04afb46ac)


  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Reuse EOS as the pad token (the special-token map shown above lists no pad token).
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Fetch the first seq2seq batch.
gen = iter(train_dataloader)
batch = next(gen)

In [None]:
# List the fields produced by collate_fn.
batch.keys()

dict_keys(['input_ids', 'labels', 'labels_mask', 'attention_mask'])

In [None]:
# Display the shape of the padded input ids.
batch['input_ids'].shape

torch.Size([2, 128])

In [None]:
# Display the shape of the labels (collate_fn stores the same tensor as input_ids).
batch['labels'].shape

torch.Size([2, 128])

In [None]:
# Decode the full input sequences back to text.
tokenizer.batch_decode(batch['input_ids'])

['Why is Si retirement so significant to the Space Exploration Team? \n\n (A) There aren’t enough working people in the world. They won’t be able to find a replacement.\n (B) As one of two remaining spacemen, it would likely mean the defunding and shut down of the Space Exploration Team.\n (C) Training new spacemen is costly and time consuming. They won’t have anyone else ready after him.\n (D) His retirement may inspire others to[GEN]Training new spacemen is costly and time consuming. They won’t have anyone else ready after him.',
 'What makes Gubelin an outlier in the present day?\n\n (A) He is much older than the rest of the population.\n (B) He refuses new operations that could improve his health.\n (C) His mind is still active, and he values hard work.\n (D) He still wears glasses and value objects like the gold watch given to Si.\n\n\nSPACEMAN ON A SPREE\n\n\n\n\n   BY MACK REYNOLDS\n\n\n\n\n   Illustrated by Nodel\n\n\n\n\n[GEN]He still wears glasses and value objects like the g

In [None]:
# Decode only the positions selected by labels_mask.
tokenizer.batch_decode([c[i] for c, i in zip(batch['input_ids'], batch['labels_mask'])])

['[GEN]Training new spacemen is costly and time consuming. They won’t have anyone else ready after him.',
 '[GEN]He still wears glasses and value objects like the gold watch given to Si.']

In [None]:
# Decode only the positions selected by attention_mask.
tokenizer.batch_decode([c[i] for c, i in zip(batch['input_ids'], batch['attention_mask'])])

['Why is Si retirement so significant to the Space Exploration Team? \n\n (A) There aren’t enough working people in the world. They won’t be able to find a replacement.\n (B) As one of two remaining spacemen, it would likely mean the defunding and shut down of the Space Exploration Team.\n (C) Training new spacemen is costly and time consuming. They won’t have anyone else ready after him.\n (D) His retirement may inspire others toTraining new spacemen is costly and time consuming. They won’t have anyone else ready after him.',
 'What makes Gubelin an outlier in the present day?\n\n (A) He is much older than the rest of the population.\n (B) He refuses new operations that could improve his health.\n (C) His mind is still active, and he values hard work.\n (D) He still wears glasses and value objects like the gold watch given to Si.\n\n\nSPACEMAN ON A SPREE\n\n\n\n\n   BY MACK REYNOLDS\n\n\n\n\n   Illustrated by Nodel\n\n\n\n\nHe still wears glasses and value objects like the gold watch 

In [None]:
# Forward pass with labels and labels_mask in the batch, so a loss is returned.
out = rmt(**batch)

In [None]:
# Display the loss of the forward pass.
out.loss

tensor(60.2067, grad_fn=<NllLossBackward>)

In [None]:
# Display the loss again (the recorded value differs from the previous cell's output).
out.loss

tensor(61.4377, grad_fn=<NllLossBackward>)