In [1]:
# from huggingface_hub import hf_hub_download
# scrolls_metric_path = hf_hub_download(repo_id="datasets/tau/scrolls", filename="metrics/scrolls.py")

In [2]:
import numpy as np
import os
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from transformers import PreTrainedModel, AutoModelForSequenceClassification
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
import datasets

import math
from matplotlib import pyplot as plt


from typing import List, Optional, Tuple, Union
from transformers import BertForSequenceClassification
import transformers
from transformers.modeling_outputs import SequenceClassifierOutput

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"]="1"

### Finetune

In [4]:
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import sys
sys.path.append('..')

### load dataset 

In [5]:
class Holder:
    """Empty attribute container used as a lightweight stand-in for parsed CLI args."""

    def __init__(self):
        # Intentionally stateless: callers attach whatever attributes they need.
        return None

In [6]:
# Total number of tokens per training example / generation target.
input_seq_len = 512
target_seq_len = 512

# Memory tokens are placed on both sides of every segment, so the usable
# block size computed further down is input_size - 2 * num_mem_tokens.
num_mem_tokens = 2
input_size = 128

batch_size = 2

# Instantiate the holder: assigning the bare class would store every setting
# as a class attribute shared by all Holder objects.
args = Holder()
args.target_seq_len = target_seq_len
args.input_seq_len = input_seq_len
args.num_mem_tokens = num_mem_tokens
args.input_size = input_size
args.input_prefix = ''
args.block_size = None
args.task_name = 'wikitext-2-v1'

device = 'cpu'

In [7]:
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)



In [8]:
from itertools import chain

# Load the WikiText configuration selected by args.task_name.
raw_datasets = datasets.load_dataset('wikitext', args.task_name)
column_names = raw_datasets["train"].column_names
# Use the column named "text" when present, otherwise fall back to the first column.
text_column_name = "text" if "text" in column_names else column_names[0]

def tokenize_function(examples):
    """Run the module-level tokenizer over the text column of a batch of rows."""
    texts = examples[text_column_name]
    return tokenizer(texts)

# Tokenize every split in batches; the original columns are passed to
# remove_columns so the mapped dataset is built from the tokenizer output.
tokenized_datasets = raw_datasets.map(
    tokenize_function,
    batched=True,
    remove_columns=column_names,
    desc="Running tokenizer on dataset",
)

# Tokens available to real text in one segment: the model input size minus
# the memory tokens, which are counted twice (2 * num_mem_tokens).
block_size = args.input_size 
if args.num_mem_tokens is not None:
    block_size -= 2 * args.num_mem_tokens
# Number of preceding tokens each training sample carries in addition to its own block.
history_size = args.input_seq_len - block_size

def group_texts(examples, block_size, history_size=None):
    """Concatenate a batch of tokenized rows and cut it into strided chunks.

    Every column of ``examples`` is flattened into one long sequence. A chunk
    is emitted for each start offset ``i`` in ``range(0, total, block_size)``
    and spans ``[max(0, i - history_size), i + block_size)``. With
    ``history_size=None`` the chunks do not overlap.

    Args:
        examples: mapping from column name to a list of token lists; must
            contain ``"input_ids"``.
        block_size: stride between chunk starts and number of new tokens per chunk.
        history_size: number of preceding tokens prepended to each chunk, or
            ``None`` for no history.

    Returns:
        dict with the chunked columns plus ``"labels"``, an independent copy
        of the chunked ``"input_ids"``.
    """
    concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])

    # No history is the same as a look-back of zero tokens, so both cases
    # share one slicing rule.
    lookback = 0 if history_size is None else history_size
    starts = range(0, total_length, block_size)

    result = {
        k: [t[max(0, i - lookback) : i + block_size] for i in starts]
        for k, t in concatenated_examples.items()
    }
    # Copy each chunk so later in-place edits of labels cannot leak into input_ids.
    result["labels"] = [list(chunk) for chunk in result["input_ids"]]
    return result

Found cached dataset wikitext (/home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-389b922bfc5fe729.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-6067a66e735cfbb1.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-941845a5470f2db7.arrow


In [9]:
# Value used to pad input_ids: the tokenizer's pad id when one is defined, otherwise its EOS id.
id_pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

In [10]:
block_size, history_size

(124, 388)

In [11]:
from torch.nn.utils.rnn import pad_sequence

# def collate_fn(batch):
#     input_ids = [torch.tensor(b['input_ids'][::-1]) for b in batch]
#     labels = [torch.tensor(b['labels'][::-1]) for b in batch]
#     attention_mask = [torch.tensor(b['attention_mask'][::-1]) for b in batch]
#     input_ids = pad_sequence(input_ids, padding_value=id_pad_value).T.flip(1)
#     labels = pad_sequence(labels, padding_value=-100).T.flip(1)
#     attention_mask = pad_sequence(attention_mask, padding_value=0).T.flip(1)

#     collated = {'input_ids': input_ids,
#                 'labels': labels, 
#                 'attention_mask': attention_mask}

#     if input_ids.shape[1] != block_size:
#         labels_mask = torch.ones_like(input_ids, dtype=bool)
#         labels_mask[:, :-block_size] = False
#         collated['labels_mask'] = labels_mask

#     return collated


def collate_fn(batch):
    """Right-pad a list of grouped samples into (batch, seq) tensors.

    Each sample provides ``input_ids``, ``labels`` and ``attention_mask``.
    Padding uses ``id_pad_value`` for inputs, ``-100`` for labels, ``0`` for
    the attention mask and ``False`` for ``labels_mask``, which is ``True``
    on every original (non-padded) label position.
    """
    def _as_tensors(key):
        return [torch.tensor(sample[key]) for sample in batch]

    def _pad(rows, pad_value):
        # pad_sequence returns (seq, batch) by default; transpose to (batch, seq).
        return pad_sequence(rows, padding_value=pad_value).T

    id_rows = _as_tensors('input_ids')
    label_rows = _as_tensors('labels')
    label_mask_rows = [torch.ones_like(row, dtype=bool) for row in label_rows]
    attention_rows = _as_tensors('attention_mask')

    collated = {
        'input_ids': _pad(id_rows, id_pad_value),
        'labels': _pad(label_rows, -100),
        'labels_mask': _pad(label_mask_rows, False),
        'attention_mask': _pad(attention_rows, 0),
    }

    # if args.vary_n_segments:
    #     n_segments = np.random.randint(1, args.max_n_segments + 1)
    #     n_tokens = n_segments * block_size
    #     for k in collated:
    #         collated[k] = collated[k][:, -n_tokens:]

    return collated


# Training chunks carry history_size tokens of left context in addition to
# their own block; validation chunks are plain non-overlapping blocks.
train_dataset = tokenized_datasets["train"].map(lambda x: group_texts(x, block_size, history_size), 
                                        batched=True, desc=f"Grouping train in chunks of {block_size} and history {history_size}")
valid_dataset = tokenized_datasets["validation"].map(lambda x: group_texts(x, block_size), 
                                        batched=True, desc=f"Grouping valid in chunks of {block_size}")


# shuffle train data each epoch (one loop over train_dataset)
# train_sampler = DistributedSampler(train_dataset, rank=hvd.rank(), num_replicas=hvd.size(), shuffle=True,
#                                     drop_last=False, seed=args.seed)
# per_worker_batch_size = args.batch_size * args.gradient_accumulation_steps
# global_batch_size = per_worker_batch_size * hvd.size()

# train_sampler = RandomSampler(train_dataset)
# NOTE(review): both samplers above are commented out and no shuffle/sampler
# argument is passed below, so this loader does not shuffle.
kwargs = {'pin_memory': True}#, 'num_workers': args.data_n_workers}
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, **kwargs)

Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-1504f9373e317eca.arrow
Loading cached processed dataset at /home/bulatov/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-c6da793e710ea6d8.arrow


In [12]:
b = [train_dataset[i] for i in range(4)]

In [13]:
for k in b[0]:
    b[0][k] = b[0][k][:124]

In [14]:
batch = collate_fn(b)

In [15]:
gen = iter(train_dataloader)
batch = next(gen)
batch = next(gen)
batch = next(gen)
batch['input_ids'].shape

torch.Size([2, 512])

In [16]:
# raw_datasets['train'][1]

### Model

In [17]:
import math
import torch
from torch.nn import CrossEntropyLoss

In [18]:
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

class MemoryCell(torch.nn.Module):
    """Processes one segment with a base language model framed by memory embeddings.

    The segment's embeddings are laid out as ``[memory | segment | memory]``
    before being passed to the base model. After the forward pass, the
    last-layer hidden states at the trailing memory positions are returned as
    the new memory state, and the memory positions are stripped from the logits.

    Args:
        base_model: model exposing ``get_input_embeddings()`` and accepting
            ``inputs_embeds`` / ``attention_mask`` / ``output_hidden_states``.
        num_mem_tokens: number of memory embeddings placed on each side.
    """

    def __init__(self, base_model, num_mem_tokens):
        super().__init__()
        self.model = base_model
        self.create_memory(num_mem_tokens)

    def create_memory(self, num_mem_tokens):
        """Create the trainable initial memory, scaled to the embedding matrix's std."""
        self.num_mem_tokens = num_mem_tokens
        embeddings = self.model.get_input_embeddings()
        # Take the width from the embedding matrix itself so the cell is not
        # tied to configs that happen to expose a GPT-2 style ``n_embd`` field.
        hidden_size = embeddings.weight.shape[1]
        memory_weights = torch.randn((num_mem_tokens, hidden_size)) * embeddings.weight.data.std()
        self.register_parameter('memory', torch.nn.Parameter(memory_weights, requires_grad=True))

        self.read_memory_position = range(num_mem_tokens)
        self.write_memory_position = range(-num_mem_tokens, 0)

    def set_memory(self, input_shape):
        """Return the initial memory repeated ``input_shape[0]`` times along a new leading dim."""
        memory = self.memory.repeat(input_shape[0], 1, 1)
        return memory

    def forward(self, input_ids=None, memory_state=None, **kwargs):
        """Run the base model on one segment.

        Args:
            input_ids: token ids of the segment; may be ``None`` when
                ``inputs_embeds`` is supplied through ``kwargs``.
            memory_state: memory carried over from the previous segment, or
                ``None`` to start from the learned initial memory.
            **kwargs: forwarded to the base model (e.g. ``attention_mask``).

        Returns:
            ``(out, new_memory_state)`` as produced by ``process_output``.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given
                and the memory has to be initialised.
        """
        if memory_state is None:
            # The batch size is read from whichever input tensor was provided.
            reference = input_ids if input_ids is not None else kwargs.get('inputs_embeds')
            if reference is None:
                raise ValueError('Either input_ids or inputs_embeds must be provided')
            memory_state = self.set_memory(reference.shape)

        seg_kwargs = self.process_input(input_ids, memory_state, **kwargs)
        out = self.model(**seg_kwargs)
        out, new_memory_state = self.process_output(out, **kwargs)

        return out, new_memory_state

    def process_input(self, input_ids, memory_state, **kwargs):
        """Build the base-model kwargs with memory embeddings around the segment."""
        seg_kwargs = dict(**kwargs)

        inputs_embeds = kwargs.get('inputs_embeds')
        if inputs_embeds is None:
            inputs_embeds = self.model.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat([memory_state, inputs_embeds, memory_state], dim=1)

        # The model is driven purely by embeddings once memory has been inserted.
        seg_kwargs['input_ids'] = None
        seg_kwargs['inputs_embeds'] = inputs_embeds
        if kwargs.get('attention_mask') is not None:
            seg_kwargs['attention_mask'] = self.pad_attention_mask(kwargs['attention_mask'], inputs_embeds.shape)
        # Hidden states are always requested because the new memory is read from them.
        seg_kwargs['output_hidden_states'] = True
        return seg_kwargs

    def pad_attention_mask(self, attention_mask, shape):
        """Extend ``attention_mask`` with ones over the memory positions on both sides."""
        if self.num_mem_tokens in {0, None}:
            return attention_mask

        mask = torch.ones(shape[:2], dtype=torch.int64, device=attention_mask.device)
        mask[:, self.num_mem_tokens:-self.num_mem_tokens] = attention_mask
        return mask

    def process_output(self, model_outputs, **kwargs):
        """Split base-model outputs into segment outputs and the new memory state.

        With memory enabled, logits (and hidden states, when
        ``output_hidden_states`` was requested by the caller) are cropped to
        the segment positions and the new memory is the last-layer hidden
        state at the trailing memory positions. Without memory the model
        outputs are returned unchanged and the memory state is ``None``.
        """
        if self.num_mem_tokens in {0, None}:
            return model_outputs, None

        segment_span = slice(self.num_mem_tokens, -self.num_mem_tokens)
        out = CausalLMOutputWithCrossAttentions()
        memory_state = model_outputs.hidden_states[-1][:, -self.num_mem_tokens:]
        out['logits'] = model_outputs.logits[:, segment_span]

        if kwargs.get('output_hidden_states'):
            out['hidden_states'] = [lh[:, segment_span] for lh in model_outputs.hidden_states]
        if kwargs.get('output_attentions'):
            out['attentions'] = model_outputs['attentions']

        return out, memory_state


In [19]:
class RecurrentWrapper(torch.nn.Module):
    """Feeds a long input to a memory cell segment by segment.

    The input tensors are split along dim 1, each segment is passed to
    ``memory_cell`` together with the memory returned for the previous
    segment, and the per-segment outputs are concatenated back together.

    Settings are read from ``rmt_kwargs`` with ``dict.get``:
        segment_size: maximum segment length used by ``split_tensor``.
        segment_alignment: ``'left'`` (also used when missing), ``'right'``
            or ``'center'``.
        k2, max_n_segments: gradient-truncation settings, see
            ``manage_gradients``.
    """

    def __init__(self, memory_cell, **rmt_kwargs):
        super().__init__()
        self.memory_cell = memory_cell
        self.rmt_config = rmt_kwargs

    def forward(self, input_ids, labels=None, labels_mask=None, inputs_embeds=None, attention_mask=None, output_attentions=None, output_hidden_states=None):
        """Run all segments through the memory cell and assemble the output.

        Returns the object built by ``process_outputs``: it always holds
        ``logits``, holds ``loss`` when ``labels`` is given and
        ``hidden_states`` when ``output_hidden_states`` is truthy.
        """
        memory_state = None
        segmented = self.segment(input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask)

        cell_outputs = []
        for seg_num, segment in enumerate(segmented):
            cell_out, memory_state = self.memory_cell(**segment, memory_state=memory_state, output_hidden_states=True)
            cell_outputs.append(cell_out)
            # Rebind the loop variable: detaching must affect the tensor that is
            # handed to the next segment, otherwise truncation has no effect.
            memory_state = self._truncate_memory(memory_state, seg_num)

        out = self.process_outputs(cell_outputs, labels=labels,
                                   labels_mask=labels_mask,
                                   output_attentions=output_attentions,
                                   output_hidden_states=output_hidden_states)
        return out

    def segment(self, **kwargs):
        """Split every non-None tensor and regroup the pieces into per-segment dicts."""
        segments = []
        for k, tensor in kwargs.items():
            if tensor is not None:
                k_segments = self.split_tensor(tensor)
                for s, k_seg in enumerate(k_segments):
                    if s < len(segments):
                        segments[s][k] = k_seg
                    else:
                        segments.append({k: k_seg})

        return segments

    def split_tensor(self, tensor):
        """Split ``tensor`` along dim 1 according to the configured alignment.

        ``'left'`` / missing: full segments first, the remainder last.
        ``'right'``: the remainder first, full segments after it.
        ``'center'``: ``torch.chunk`` into ``ceil(len / segment_size)`` pieces.

        Raises:
            NotImplementedError: for any other alignment value.
        """
        align = self.rmt_config.get('segment_alignment')
        segment_size = self.rmt_config.get('segment_size')
        length = tensor.shape[1]
        if align in {'left', None}:
            split_inds = list(range(0, length, segment_size)) + [length]
            segments = [tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])]
        elif align == 'right':
            split_inds = (list(range(length, 0, -segment_size)) + [0])[::-1]
            segments = [tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])]
        elif align == 'center':
            n_seg = math.ceil(length / segment_size)
            segments = torch.chunk(tensor, n_seg, dim=1)
        else:
            raise NotImplementedError
        return segments

    def process_outputs(self, cell_outputs, **kwargs):
        """Concatenate per-segment outputs and optionally compute the LM loss.

        The loss is next-token cross entropy: logits at position ``t`` are
        scored against ``labels`` at ``t + 1``. When ``labels_mask`` is given,
        the prediction made at position ``t`` is kept only if
        ``labels_mask[..., t]`` is true.
        """
        out = CausalLMOutputWithCrossAttentions()
        full_logits = torch.cat([o.logits for o in cell_outputs], dim=1)

        labels = kwargs.get('labels')
        if labels is not None:
            shift_labels = labels[..., 1:].contiguous()
            shift_logits = full_logits[..., :-1, :].contiguous()
            flat_labels = shift_labels.view(-1)
            flat_logits = shift_logits.view(-1, shift_logits.size(-1))

            labels_mask = kwargs.get('labels_mask')
            if labels_mask is not None:
                keep = labels_mask[..., :-1].contiguous().view(-1)
                flat_labels = flat_labels[keep]
                flat_logits = flat_logits[keep]

            out['loss'] = CrossEntropyLoss()(flat_logits, flat_labels)

        out['logits'] = full_logits
        if kwargs.get('output_hidden_states'):
            # Only pay for concatenating every layer when the caller asked for it.
            out['hidden_states'] = tuple(
                torch.cat(layer_hs, dim=1)
                for layer_hs in zip(*[o.hidden_states for o in cell_outputs])
            )

        return out

    def manage_gradients(self, memory_state, seg_num):
        """Tell whether gradients should keep flowing through ``memory_state``.

        Returns ``True`` (keep gradients) for the first segment, when ``k2``
        is ``-1`` or unset, or when ``seg_num + k2 > max_n_segments``;
        otherwise ``False``. ``max_n_segments`` must be set whenever ``k2`` is
        a positive number, because the two are compared directly.

        ``memory_state`` is accepted for interface compatibility and is not
        modified here; ``forward`` performs the detaching.
        """
        k2 = self.rmt_config.get('k2')
        max_n_segments = self.rmt_config.get('max_n_segments')
        if seg_num == 0 or k2 in {-1, None} or seg_num + k2 > max_n_segments:
            return True
        return False

    def _truncate_memory(self, memory_state, seg_num):
        """Return ``memory_state``, detached when ``manage_gradients`` says to cut gradients."""
        keep_gradients = self.manage_gradients(memory_state, seg_num)
        if keep_gradients or memory_state is None:
            return memory_state
        return memory_state.detach()

In [20]:

# NOTE(review): this rebinds the global num_mem_tokens to 10, while the cell
# below is built with num_mem_tokens=2 — confirm which value is intended.
num_mem_tokens = 10
# device = torch.device(3)
device = 'cpu'

# NOTE(review): rmt_config is defined here but is not passed to MemoryCell on
# the line below; verify whether it is still needed.
rmt_config = {'num_mem_tokens': 10, 
               #  'max_n_segments': 1,
               #  'segment_alignment': 'right',
               #  'tokenizer': tokenizer,
               #  'memory_layers': 'all', 
               #  'share_memory_layers': True,
               #  'reconstruction_loss_coef': 0.1,
               #  'k1': -1, 'k2': 3,
               #  'segment_ordering': 'regular',
               #  'input_size': 1024, 
               #  'bptt_depth': -1, 
               #  'sum_loss': False,
             }

# Load the pretrained causal LM and wrap it with two memory tokens per side.
base_model = AutoModelForCausalLM.from_pretrained(model_name)
cell = MemoryCell(base_model, num_mem_tokens=2)



In [21]:
# Split inputs into segments of at most 124 tokens (the block_size value
# printed earlier) using torch.chunk-based 'center' alignment.
rmt = RecurrentWrapper(cell, max_n_segments=5, segment_size=124, segment_alignment='center')

In [22]:
gen = iter(train_dataloader)
batch = next(gen)
batch = next(gen)
batch = next(gen)
# batch.pop('labels_mask')
# batch.pop('labels')
1

1

In [23]:
rmt_out = rmt(**batch)

In [24]:
memory_state.shape

NameError: name 'memory_state' is not defined

ModuleAttributeError: 'RecurrentWrapper' object has no attribute 'generate'

In [25]:
gen

<torch.utils.data.dataloader._SingleProcessDataLoaderIter at 0x7f7bec027eb0>

In [23]:
rmt_out.keys()

odict_keys(['loss', 'logits'])

In [24]:
for k in rmt_out:
    if 'loss' in k:
        print(k, rmt_out[k] )

loss tensor(11.4472, grad_fn=<NllLossBackward>)


In [309]:
segmented = rmt.segment(input_ids=input_ids)

In [312]:
[s['input_ids'].shape for s in segmented]

[torch.Size([2, 103]),
 torch.Size([2, 103]),
 torch.Size([2, 103]),
 torch.Size([2, 103]),
 torch.Size([2, 100])]

In [269]:
# rmt_out = rmt(**batch)

# Manual replay of the RecurrentWrapper.forward loop for interactive inspection;
# `self` is bound to the wrapper so the method body can be pasted as-is.
self = rmt
input_ids = batch['input_ids']

memory_state = None
segmented = self.segment(input_ids=input_ids, inputs_embeds=None, attention_mask=None)

cell_outputs = []
for seg_num, segment in enumerate(segmented):
    cell_out, memory_state = self.memory_cell(**segment, memory_state=memory_state, output_hidden_states=True)#**batch)
    cell_outputs.append(cell_out)
    # NOTE(review): the return value is discarded and memory_state is not
    # reassigned, so the next iteration receives the memory exactly as the
    # cell returned it.
    self.manage_gradients(memory_state, seg_num)

# out = self.process_outputs(cell_outputs)

In [215]:
input_ids.shape

torch.Size([2, 248])

In [297]:
torch.chunk(torch.ones(10), 3)

(tensor([1., 1., 1., 1.]), tensor([1., 1., 1., 1.]), tensor([1., 1.]))

### seq2seq

In [None]:
# Register '[GEN]' as a special token and remember its id; the seq2seq
# collate function places it between the prompt and the answer.
tokenizer.add_tokens('[GEN]', special_tokens=True)
gen_token = tokenizer.encode('[GEN]')[0]

# Resize the embedding matrix to the enlarged vocabulary so the new id has an embedding row.
rmt.memory_cell.model.resize_token_embeddings(len(tokenizer))

# gen_token = tokenizer.eos_token_id

Embedding(50258, 768)

In [None]:
tokenizer.special_tokens_map

{'bos_token': '<|endoftext|>',
 'eos_token': '<|endoftext|>',
 'unk_token': '<|endoftext|>'}

In [None]:
global_attention_first_token = False  # should be True for LED
encode_plus_kwargs = {'truncation': True, 'padding': 'longest', 'pad_to_multiple_of': 1}
# generate_kwargs = {'max_length': args.target_seq_len, 'min_length': args.target_seq_len}
# No generation overrides are currently set.
generate_kwargs = {}

from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """Collate seq2seq samples into ``prompt [GEN] answer`` language-model batches.

    Each sample provides an ``'input'`` and an ``'output'`` string. Both are
    cut to ``args.input_seq_len * 10`` characters, tokenized, and joined as
    ``prompt_ids + [gen_token] + answer_ids``, with the prompt truncated so
    the row fits in ``input_size`` tokens whenever the answer leaves room.
    Rows are right-padded to the longest row in the batch.

    Returns:
        dict with
        ``input_ids`` / ``labels``: the same (batch, seq) tensor of token ids;
        ``labels_mask``: True on the ``[GEN]`` token and the answer tokens;
        ``attention_mask``: True on every non-padding position.
    """
    char_budget = args.input_seq_len * 10
    inputs = tokenizer.batch_encode_plus([b['input'][:char_budget] for b in batch], padding=False)
    labels = tokenizer.batch_encode_plus([b['output'][:char_budget] for b in batch], padding=False)

    # Fall back to EOS when the tokenizer has no pad token, so padding does
    # not depend on another cell having set tokenizer.pad_token first.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    rows = []
    for prompt_ids, answer_ids in zip(inputs['input_ids'], labels['input_ids']):
        # Clamp at zero: a negative bound would slice from the end of the
        # prompt and keep most of it instead of dropping it.
        prompt_budget = max(input_size - len(answer_ids) - 1, 0)
        rows.append(torch.tensor(prompt_ids[:prompt_budget] + [gen_token] + answer_ids))
    lengths = [len(row) for row in rows]
    full_inputs = pad_sequence(rows, padding_value=pad_id).T

    labels_mask = torch.zeros_like(full_inputs).bool()
    attention_mask = torch.zeros_like(full_inputs).bool()
    for row, (length, answer_ids) in enumerate(zip(lengths, labels['input_ids'])):
        # Anchor the answer span to the row's real length: rows are padded on
        # the right, so counting from the padded end would mark padding.
        labels_mask[row, length - len(answer_ids) - 1:length] = True
        # Built from lengths rather than `!= pad_id`, so genuine EOS tokens in
        # the text stay visible when EOS doubles as the pad token.
        attention_mask[row, :length] = True

    collated = {}
    collated['input_ids'] = collated['labels'] = full_inputs
    collated['labels_mask'] = labels_mask
    collated['attention_mask'] = attention_mask

    return collated
    

In [None]:
# Load the SCROLLS task used for the seq2seq experiment.
seq2seq_task_name = 'quality'
dataset = datasets.load_dataset('tau/scrolls', seq2seq_task_name)
# NOTE: train_dataset, kwargs and train_dataloader rebind the names used for
# the WikiText pipeline earlier in the notebook.
train_dataset = dataset['train']
# shuffle train data each epoch (one loop over train_dataset)
# train_sampler = DistributedSampler(train_dataset, shuffle=True, drop_last=False, seed=args.seed)
kwargs = {'pin_memory': True}
train_dataloader = DataLoader(train_dataset, batch_size=2,
                                collate_fn=collate_fn, **kwargs)



Downloading readme: 0.00B [00:00, ?B/s]

Found cached dataset scrolls (/home/bulatov/.cache/huggingface/datasets/tau___scrolls/quality/1.0.0/672021d5d8e1edff998a6ea7a5bff35fdfd0ae243e7cf6a8c88a57a04afb46ac)


  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Reuse EOS as the pad token; the seq2seq collate function reads tokenizer.pad_token_id when padding.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
gen = iter(train_dataloader)
batch = next(gen)

Token indices sequence length is longer than the specified maximum sequence length for this model (1291 > 1024). Running this sequence through the model will result in indexing errors


In [None]:
batch.keys()

dict_keys(['input_ids', 'labels', 'labels_mask', 'attention_mask'])

In [25]:
batch['input_ids'].shape

torch.Size([2, 512])

In [26]:
batch['labels'].shape

torch.Size([2, 512])

In [42]:
tokenizer.batch_decode(batch['input_ids'])

['Why is Si retirement so significant to the Space Exploration Team? \n\n (A) There aren’t enough working people in the world. They won’t be able to find a replacement.\n (B) As one of two remaining spacemen, it would likely mean the defunding and shut down of the Space Exploration Team.\n (C) Training new spacemen is costly and time consuming. They won’t have anyone else ready after him.\n (D) His retirement may inspire others to[GEN]Training new spacemen is costly and time consuming. They won’t have anyone else ready after him.',
 'What makes Gubelin an outlier in the present day?\n\n (A) He is much older than the rest of the population.\n (B) He refuses new operations that could improve his health.\n (C) His mind is still active, and he values hard work.\n (D) He still wears glasses and value objects like the gold watch given to Si.\n\n\nSPACEMAN ON A SPREE\n\n\n\n\n   BY MACK REYNOLDS\n\n\n\n\n   Illustrated by Nodel\n\n\n\n\n[GEN]He still wears glasses and value objects like the g

In [43]:
tokenizer.batch_decode([c[i] for c, i in zip(batch['input_ids'], batch['labels_mask'])])

['[GEN]Training new spacemen is costly and time consuming. They won’t have anyone else ready after him.',
 '[GEN]He still wears glasses and value objects like the gold watch given to Si.']

In [44]:
tokenizer.batch_decode([c[i] for c, i in zip(batch['input_ids'], batch['attention_mask'])])

['Why is Si retirement so significant to the Space Exploration Team? \n\n (A) There aren’t enough working people in the world. They won’t be able to find a replacement.\n (B) As one of two remaining spacemen, it would likely mean the defunding and shut down of the Space Exploration Team.\n (C) Training new spacemen is costly and time consuming. They won’t have anyone else ready after him.\n (D) His retirement may inspire others to[GEN]Training new spacemen is costly and time consuming. They won’t have anyone else ready after him.',
 'What makes Gubelin an outlier in the present day?\n\n (A) He is much older than the rest of the population.\n (B) He refuses new operations that could improve his health.\n (C) His mind is still active, and he values hard work.\n (D) He still wears glasses and value objects like the gold watch given to Si.\n\n\nSPACEMAN ON A SPREE\n\n\n\n\n   BY MACK REYNOLDS\n\n\n\n\n   Illustrated by Nodel\n\n\n\n\n[GEN]He still wears glasses and value objects like the g

In [48]:
# Manual replay of RecurrentWrapper.forward on the seq2seq batch; `self` is
# bound to the wrapper so the method body can be pasted as-is.
self = rmt
input_ids = batch['input_ids']
# labels is None here, so process_outputs skips the loss computation.
labels = None
labels_mask = batch['labels_mask']
attention_mask = batch['attention_mask']
output_attentions = output_hidden_states = True
inputs_embeds = None

memory_state = None
segmented = self.segment(input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask)

cell_outputs = []
for seg_num, segment in enumerate(segmented):
    cell_out, memory_state = self.memory_cell(**segment, memory_state=memory_state, output_hidden_states=True)
    cell_outputs.append(cell_out)
    # NOTE(review): the return value is discarded and memory_state is not reassigned here.
    self.manage_gradients(memory_state, seg_num)

out = self.process_outputs(cell_outputs, labels=labels, 
                            labels_mask=labels_mask,
                            output_attentions=output_attentions, 
                            output_hidden_states=output_hidden_states)

# gen = rmt.generate(batch['input_ids'])


In [57]:
tokenizer.decode(input_ids[0][labels_mask[0]])

'[GEN]Training new spacemen is costly and time consuming. They won’t have anyone else ready after him.'

In [58]:
tokenizer.decode(input_ids[0][~labels_mask[0]])

'Why is Si retirement so significant to the Space Exploration Team? \n\n (A) There aren’t enough working people in the world. They won’t be able to find a replacement.\n (B) As one of two remaining spacemen, it would likely mean the defunding and shut down of the Space Exploration Team.\n (C) Training new spacemen is costly and time consuming. They won’t have anyone else ready after him.\n (D) His retirement may inspire others to'

In [61]:
tasks = [ids[~mask] for ids, mask in zip(input_ids, labels_mask)]
answers = [ids[mask] for ids, mask in zip(input_ids, labels_mask)]

In [63]:
[len(t) for t in tasks]

[105, 112]

In [72]:
tasks = [torch.cat((ids[~mask], torch.tensor([gen_token]))) for ids, mask in zip(input_ids, labels_mask)]

In [74]:
tokenizer.decode(tasks[0])

'Why is Si retirement so significant to the Space Exploration Team? \n\n (A) There aren’t enough working people in the world. They won’t be able to find a replacement.\n (B) As one of two remaining spacemen, it would likely mean the defunding and shut down of the Space Exploration Team.\n (C) Training new spacemen is costly and time consuming. They won’t have anyone else ready after him.\n (D) His retirement may inspire others to[GEN]'

In [None]:
# NOTE(review): this function takes `self` and reads `self.config`, `self.forward` and several
# `self._prepare_*` / `self._get_*` helpers, so it is meant to be used as a method of a model class.
# It also relies on `inspect`, `warnings`, `logger`, `Iterable`, `Callable`, `LogitsProcessorList`,
# `StoppingCriteriaList`, `Constraint`, `BeamSearchScorer`, `ConstrainedBeamSearchScorer` and the
# `*Output` types being in scope -- confirm they are imported before running this cell.
def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        do_sample: Optional[bool] = None,
        early_stopping: Optional[bool] = None,
        num_beams: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        typical_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        bad_words_ids: Optional[Iterable[int]] = None,
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        encoder_no_repeat_ngram_size: Optional[int] = None,
        num_return_sequences: Optional[int] = None,
        max_time: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        use_cache: Optional[bool] = None,
        num_beam_groups: Optional[int] = None,
        diversity_penalty: Optional[float] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        constraints: Optional[List[Constraint]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        remove_invalid_values: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
        r"""
        Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
        multinomial sampling, beam-search decoding, beam-search multinomial sampling, diverse (group) beam-search
        decoding and constrained beam-search decoding.

        Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name inside
        the [`PretrainedConfig`] of the model. The default values indicated are the default values of those config.

        Most of these parameters are explained in more detail in [this blog
        post](https://huggingface.co/blog/how-to-generate).

        Parameters:
            inputs (`torch.Tensor` of shape `(batch_size, sequence_length)`, `(batch_size, sequence_length,
            feature_dim)` or `(batch_size, num_channels, height, width)`, *optional*):
                The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
                method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
                should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
                `input_ids`, `input_values`, `input_features`, or `pixel_values`.
            max_length (`int`, *optional*, defaults to `model.config.max_length`):
                The maximum length of the sequence to be generated.
            max_new_tokens (`int`, *optional*, defaults to None):
                The maximum number of tokens to generate, not counting the tokens already present in the prompt. Use
                either `max_new_tokens` or `max_length` but not both, they serve the same purpose. If both are given,
                `max_length` takes priority and a warning is emitted.
            min_length (`int`, *optional*, defaults to 10):
                The minimum length of the sequence to be generated.
            do_sample (`bool`, *optional*, defaults to `False`):
                Whether or not to use sampling ; use greedy decoding otherwise. Must be a real `bool`: the generation
                mode is selected with `is True` / `is False` checks.
            early_stopping (`bool`, *optional*, defaults to `False`):
                Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
            num_beams (`int`, *optional*, defaults to 1):
                Number of beams for beam search. 1 means no beam search.
            temperature (`float`, *optional*, defaults to 1.0):
                The value used to modulate the next token probabilities.
            top_k (`int`, *optional*, defaults to 50):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`, *optional*, defaults to 1.0):
                If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
                are kept for generation.
            typical_p (`float`, *optional*):
                Forwarded to the logits warper that is built for the sampling modes. NOTE(review): presumably the
                probability mass kept by typical decoding -- confirm against the installed transformers version.
            repetition_penalty (`float`, *optional*, defaults to 1.0):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token. If it is undefined but `eos_token_id` is known, `eos_token_id` is used
                for padding and a warning is logged.
            bos_token_id (`int`, *optional*):
                The id of the *beginning-of-sequence* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            length_penalty (`float`, *optional*, defaults to 1.0):
                Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
                model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
                sequences.
            no_repeat_ngram_size (`int`, *optional*, defaults to 0):
                If set to int > 0, all ngrams of that size can only occur once.
            encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
                If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
                `decoder_input_ids`.
            bad_words_ids(`List[List[int]]`, *optional*):
                List of token ids that are not allowed to be generated. In order to get the token ids of the words that
                should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
                add_special_tokens=False).input_ids`.
            num_return_sequences(`int`, *optional*, defaults to 1):
                The number of independently computed returned sequences for each element in the batch.
            max_time(`float`, *optional*, defaults to None):
                The maximum amount of time you allow the computation to run for in seconds. generation will still
                finish the current pass after allocated time has been passed.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens
                that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape
                as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask)
            decoder_start_token_id (`int`, *optional*):
                If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
            use_cache: (`bool`, *optional*, defaults to `True`):
                Whether or not the model should use the past last key/values attentions (if applicable to the model) to
                speed up decoding.
            num_beam_groups (`int`, *optional*, defaults to 1):
                Number of groups to divide `num_beams` into in order to ensure diversity among different groups of
                beams. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
            diversity_penalty (`float`, *optional*, defaults to 0.0):
                This value is subtracted from a beam's score if it generates a token same as any beam from other group
                at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is
                enabled.
            prefix_allowed_tokens_fn: (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
                If provided, this function constraints the beam search to allowed tokens only at each step. If not
                provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
                `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
                on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
                for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
                Retrieval](https://arxiv.org/abs/2010.00904).
            logits_processor (`LogitsProcessorList`, *optional*):
                 Custom logits processors that complement the default logits processors built from arguments and a
                 model's config. If a logit processor is passed that is already created with the arguments or a model's
                 config an error is thrown. This feature is intended for advanced users. `None` (the default) is
                 treated as an empty list created for this call.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                 Custom stopping criteria that complement the default stopping criteria built from arguments and a
                 model's config. If a stopping criteria is passed that is already created with the arguments or a
                 model's config an error is thrown. This feature is intended for advanced users. `None` (the default)
                 is treated as an empty list created for this call.
            constraints (`List[Constraint]`, *optional*):
                 Custom constraints that can be added to the generation to ensure that the output will contain the use
                 of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
            forced_bos_token_id (`int`, *optional*):
                The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful
                for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be
                the target language token.
            forced_eos_token_id (`int`, *optional*):
                The id of the token to force as the last generated token when `max_length` is reached.
            remove_invalid_values (`bool`, *optional*):
                Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to
                crash. Note that using `remove_invalid_values` can slow down generation.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
                is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
                should be prefixed with *decoder_*.

        Return:
            [`~file_utils.ModelOutput`] or `torch.LongTensor`: A [`~file_utils.ModelOutput`] (if
            `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.

                If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
                [`~file_utils.ModelOutput`] types are:

                    - [`~generation_utils.GreedySearchDecoderOnlyOutput`],
                    - [`~generation_utils.SampleDecoderOnlyOutput`],
                    - [`~generation_utils.BeamSearchDecoderOnlyOutput`],
                    - [`~generation_utils.BeamSampleDecoderOnlyOutput`]

                If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
                [`~file_utils.ModelOutput`] types are:

                    - [`~generation_utils.GreedySearchEncoderDecoderOutput`],
                    - [`~generation_utils.SampleEncoderDecoderOutput`],
                    - [`~generation_utils.BeamSearchEncoderDecoderOutput`],
                    - [`~generation_utils.BeamSampleEncoderDecoderOutput`]

        Raises:
            ValueError: If the combination of `num_beams`, `num_beam_groups`, `do_sample`, `num_return_sequences` and
                `constraints` is not supported by the selected generation mode, if a beam-based mode has no
                `max_length` stopping criterion, or if no generation mode matches the resolved settings at all.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
        >>> # do greedy decoding without providing a prompt
        >>> outputs = model.generate(max_length=40)
        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> document = (
        ...     "at least two people were killed in a suspected bomb attack on a passenger bus "
        ...     "in the strife-torn southern philippines on monday , the military said."
        ... )
        >>> # encode input context
        >>> input_ids = tokenizer(document, return_tensors="pt").input_ids
        >>> # generate 3 independent sequences using beam search decoding (5 beams)
        >>> # with T5 encoder-decoder model conditioned on short news article.
        >>> outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3)
        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
        >>> input_context = "The dog"
        >>> # encode input context
        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
        >>> # generate 3 candidates using sampling
        >>> outputs = model.generate(input_ids=input_ids, max_length=20, num_return_sequences=3, do_sample=True)
        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
        >>> model = AutoModelForCausalLM.from_pretrained("ctrl")
        >>> # "Legal" is one of the control codes for ctrl
        >>> input_context = "Legal My neighbor is"
        >>> # encode input context
        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
        >>> outputs = model.generate(input_ids=input_ids, max_length=20, repetition_penalty=1.2)
        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> input_context = "My cute dog"
        >>> # get tokens of words that should not be generated
        >>> bad_words_ids = tokenizer(
        ...     ["idiot", "stupid", "shut up"], add_prefix_space=True, add_special_tokens=False
        ... ).input_ids
        >>> # encode input context
        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
        >>> # generate sequences without allowing bad_words to be generated
        >>> outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
        ```"""
        # 1. Set generation parameters if not already defined
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )

        # Build the custom processor / criteria containers per call: instances created in the signature
        # would be evaluated once at definition time and shared by every invocation.
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        if eos_token_id is None and hasattr(self.config, "decoder"):
            eos_token_id = self.config.decoder.eos_token_id

        if pad_token_id is None and eos_token_id is not None:
            # special case if pad_token_id is not defined
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
            pad_token_id = eos_token_id

        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # 2. Define model inputs
        # inputs_tensor has to be defined
        # model_input_name is defined if model-specific keyword input is passed
        # otherwise model_input_name is None
        # all model-specific keyword inputs are removed from `model_kwargs`
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs)
        batch_size = inputs_tensor.shape[0]

        # 3. Define other model kwargs
        model_kwargs["output_attentions"] = output_attentions
        model_kwargs["output_hidden_states"] = output_hidden_states
        model_kwargs["use_cache"] = use_cache

        # Only build a default attention mask when `forward` accepts one and the encoder has not run yet.
        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs

        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, pad_token_id, eos_token_id
            )

        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created
            # and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name
            )

        # 4. Prepare `input_ids` which will be used for auto-regressive generation
        if self.config.is_encoder_decoder:
            input_ids = self._prepare_decoder_input_ids_for_generation(
                batch_size,
                decoder_start_token_id=decoder_start_token_id,
                bos_token_id=bos_token_id,
                model_kwargs=model_kwargs,
            )
        else:
            # if decoder-only then inputs_tensor has to be `input_ids`
            input_ids = inputs_tensor

        # 5. Prepare `max_length` depending on other stopping criteria
        # if `max_new_tokens` is passed, but not `max_length` -> `max_length` is the prompt length plus `max_new_tokens`
        if max_length is None and max_new_tokens is not None:
            max_length = max_new_tokens + input_ids.shape[-1]
        elif max_length is not None and max_new_tokens is not None:
            # Both are set, this is odd, raise a warning
            warnings.warn(
                "Both `max_length` and `max_new_tokens` have been set "
                f"but they serve the same purpose. `max_length` {max_length} "
                f"will take priority over `max_new_tokens` {max_new_tokens}.",
                UserWarning,
            )
        # default to config if still None
        max_length = max_length if max_length is not None else self.config.max_length

        if input_ids.shape[-1] >= max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids.shape[-1]}, but ``max_length`` is set to {max_length}. "
                "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
            )

        # 6. determine generation mode
        # The flags below are mutually exclusive; `do_sample` is compared by identity, so it must be a real bool.
        is_constraint_gen_mode = constraints is not None
        is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and constraints is None
        is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and constraints is None
        is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and constraints is None
        is_beam_sample_gen_mode = (
            (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and constraints is None
        )
        is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and constraints is None

        if num_beam_groups > num_beams:
            raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
        if is_group_beam_gen_mode and do_sample is True:
            raise ValueError(
                "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
            )

        # 7. prepare distribution pre_processing samplers
        logits_processor = self._get_logits_processor(
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
            encoder_input_ids=inputs_tensor,
            bad_words_ids=bad_words_ids,
            min_length=min_length,
            max_length=max_length,
            eos_token_id=eos_token_id,
            forced_bos_token_id=forced_bos_token_id,
            forced_eos_token_id=forced_eos_token_id,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            num_beams=num_beams,
            num_beam_groups=num_beam_groups,
            diversity_penalty=diversity_penalty,
            remove_invalid_values=remove_invalid_values,
            logits_processor=logits_processor,
        )

        # 8. prepare stopping criteria
        stopping_criteria = self._get_stopping_criteria(
            max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
        )

        # 9. go into different generation modes
        if is_greedy_gen_mode:
            if num_return_sequences > 1:
                raise ValueError(
                    f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
                )

            # 10. run greedy search
            return self.greedy_search(
                input_ids,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_sample_gen_mode:
            # 10. prepare logits warper
            logits_warper = self._get_logits_warper(
                top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
            )

            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 12. run sample
            return self.sample(
                input_ids,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_beam_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            # 10. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                device=self.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_beam_sample_gen_mode:
            # 10. prepare logits warper
            logits_warper = self._get_logits_warper(
                top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
            )

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
            # 11. prepare beam search scorer
            # Each requested return sequence gets its own beam search, hence the enlarged batch size.
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size * num_return_sequences,
                num_beams=num_beams,
                device=self.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
            )

            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_beams * num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 13. run beam sample
            return self.beam_sample(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_group_beam_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if num_beams % num_beam_groups != 0:
                raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            # 10. prepare beam search scorer
            # NOTE(review): this is the only scorer that receives `max_length`; the other beam modes do not
            # pass it -- confirm whether the installed `BeamSearchScorer` still uses this argument.
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                max_length=stopping_criteria.max_length,
                device=self.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
                num_beam_groups=num_beam_groups,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.group_beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_constraint_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            if num_beams <= 1:
                raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")

            if do_sample:
                raise ValueError("`do_sample` needs to be false for constrained generation.")

            if num_beam_groups is not None and num_beam_groups > 1:
                raise ValueError("`num_beam_groups` not supported yet for constrained generation.")

            # 10. prepare beam search scorer
            constrained_beam_scorer = ConstrainedBeamSearchScorer(
                constraints=constraints,
                batch_size=batch_size,
                num_beams=num_beams,
                device=self.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.constrained_beam_search(
                input_ids,
                constrained_beam_scorer=constrained_beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        # No mode matched (e.g. `do_sample` is not a real bool, so neither the `is True` nor the `is False`
        # checks above succeeded). Fail loudly instead of silently returning `None`.
        raise ValueError(
            "Could not determine a generation mode from "
            f"`num_beams`={num_beams}, `num_beam_groups`={num_beam_groups} and `do_sample`={do_sample!r}. "
            "Make sure that `do_sample` is a bool and that `num_beams` is at least 1."
        )

In [75]:
# Inspect the raw token ids of the task prompts; in the output below the first prompt ends with
# id 50257, presumably the appended `gen_token` -- confirm.
tasks

[tensor([ 5195,   318, 15638, 10737,   523,  2383,   284,   262,  4687, 36806,
          4816,    30,   220,   628,   357,    32,     8,  1318,  3588,   447,
           247,    83,  1576,  1762,   661,   287,   262,   995,    13,  1119,
          1839,   447,   247,    83,   307,  1498,   284,  1064,   257,  9014,
            13,   198,   357,    33,     8,  1081,   530,   286,   734,  5637,
         34752,  8952,    11,   340,   561,  1884,  1612,   262, 47613,   278,
           290,  4423,   866,   286,   262,  4687, 36806,  4816,    13,   198,
           357,    34,     8, 13614,   649, 34752,  8952,   318, 16378,   290,
           640, 18587,    13,  1119,  1839,   447,   247,    83,   423,  2687,
          2073,  3492,   706,   683,    13,   198,   357,    35,     8,  2399,
         10737,   743, 18330,  1854,   284, 50257]),
 tensor([ 2061,  1838,   402,   549, 27176,   281,   503,  2505,   287,   262,
          1944,  1110,    30,   628,   357,    32,     8,   679,   318,   881,

In [53]:
# Shape of the logits in `out`; the recorded output is (2, 128, 50258) -- presumably
# (batch, tokens per segment, vocabulary size), confirm.
out.logits.shape

torch.Size([2, 128, 50258])

In [101]:
# Loss stored on `out`; the recorded value still carries a grad_fn, i.e. it is attached to the autograd graph.
out.loss

tensor(60.2067, grad_fn=<NllLossBackward>)

In [105]:
# Same inspection as in an earlier cell, executed again later (execution count 105). The recorded value
# differs, so `out` was presumably recomputed in between -- confirm by re-running the notebook top to bottom.
out.loss

tensor(61.4377, grad_fn=<NllLossBackward>)