In [85]:
import torch
from transformers import BartForConditionalGeneration, BartConfig, BartTokenizerFast

In [86]:
model_class = BartForConditionalGeneration
config_class = BartConfig
model_fp = 'facebook/bart-base'

config = config_class.from_pretrained(model_fp)  

model = model_class.from_pretrained(model_fp)
tokenizer = BartTokenizerFast.from_pretrained(model_fp)

In [87]:
example_english_phrase = "The fifth beaker has 2 orange " 
batch = tokenizer(example_english_phrase, return_tensors='pt')
generated_ids = model.generate(batch['input_ids'], 
                               pad_token_id=tokenizer.pad_token_id, 
                               eos_token_id=tokenizer.eos_token_id, #len(tokenizer) - 1, 
                               bos_token_id=tokenizer.bos_token_id,  
                              )

In [88]:
tokenizer.batch_decode(generated_ids, skip_special_tokens=False)

['</s><s>The fifth beaker has 2 orange </s>']

In [89]:
prompt = ["Hello world"]
output_text = ["This is the output of my first program in C<pad>"]

In [90]:
model = model.train()

In [91]:
encoded = tokenizer.batch_encode_plus(prompt, return_tensors='pt', padding=True)
decoded = tokenizer.batch_encode_plus(output_text, return_tensors='pt', padding=True, add_special_tokens=False)
# model(encoded)

In [97]:
from transformers.models.bart.modeling_bart import shift_tokens_right

In [98]:
def shift_tokens_right(input_ids, pad_token_id):
  """ Shift input ids one token to the right, and wrap the last non pad token (usually <eos>).
      This is taken directly from modeling_bart.py
  """
  prev_output_tokens = input_ids.clone()
  index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
  prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
  prev_output_tokens[:, 1:] = input_ids[:, :-1]
  return prev_output_tokens

In [99]:
print(decoded['input_ids'])

tensor([[ 713,   16,    5, 4195,    9,  127,   78,  586,   11,  230,    1]])


In [100]:
output = shift_tokens_right(decoded['input_ids'], pad_token_id=tokenizer.pad_token_id)

In [101]:
tokenizer.batch_decode(output)

[' CThis is the output of my first program in C']