In [1]:
!pip install transformers
!pip install datasets


Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/ed/d5/f4157a376b8a79489a76ce6cfe147f4f3be1e029b7144fa7b8432e8acb26/transformers-4.4.2-py3-none-any.whl (2.0MB)
[K     |████████████████████████████████| 2.0MB 6.5MB/s 
Collecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/71/23/2ddc317b2121117bf34dd00f5b0de194158f2a44ee2bf5e47c7166878a97/tokenizers-0.10.1-cp37-cp37m-manylinux2010_x86_64.whl (3.2MB)
[K     |████████████████████████████████| 3.2MB 25.9MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 37.3MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp37-none-any.whl size=893262 sha256=5003dd2e698

In [2]:
from transformers import BartTokenizer, BartForConditionalGeneration

In [4]:
# Load the pretrained BART-base seq2seq model and its matching tokenizer, both
# from the same 'facebook/bart-base' checkpoint (model first, then tokenizer).
model, tokenizer = (
    BartForConditionalGeneration.from_pretrained('facebook/bart-base'),
    BartTokenizer.from_pretrained('facebook/bart-base'),
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=557941479.0, style=ProgressStyle(descri…

KeyboardInterrupt: ignored

In [None]:
from datasets import load_dataset

# GLUCOSE's "train" split is used for fine-tuning and its "test" split for evaluation.
train_dataset, eval_dataset = (load_dataset("glucose", split=split_name) for split_name in ("train", "test"))

In [7]:
#tokenize a concrete item in the dictionary
#tokenized_dataset = train_dataset.map(lambda batch: tokenizer(batch["story"], truncation=True, padding=True), batched=True)

In [8]:
import torch


def build_glucose_pairs(dataset):
  '''
    Flatten GLUCOSE rows into parallel (source, target) lists, one pair per annotated dimension.

    We follow the template used by the Enc-Dec model in the GLUCOSE paper:
      source: "#d: S*[X]"     target: the general causal rule for dimension d
        d: dim (1..10)
        S: story
        X: selected_sentence
    The source and the target are kept separate. The model must learn to *generate* the
    rule from the story, so the rule cannot also appear in the encoder input.
    Dimensions whose `{d}_generalNL` field is the literal string 'escaped' are skipped.

    Returns:
      (sources, targets): two lists of str with the same length.
  '''
  sources, targets = [], []
  for item in dataset:
    for dim in range(1, 11):
      rule = item[f"{dim}_generalNL"]
      if rule == 'escaped':
        continue
      sources.append(f'#{dim}: {item["story"]}*[{item["selected_sentence"]}]')
      targets.append(rule)
  return sources, targets


def encode_glucose_pairs(sources, targets, tok):
  '''
    Tokenize `sources` as encoder inputs and `targets` as `labels` for seq2seq training.

    Without a `labels` entry, BartForConditionalGeneration computes no loss and Trainer
    cannot train. Padding positions in `labels` are replaced by -100, the index that the
    cross-entropy loss ignores, so padding does not contribute to the loss.
  '''
  encodings = tok(sources, padding=True, truncation=True)
  label_ids = tok(targets, padding=True, truncation=True)["input_ids"]
  pad_id = tok.pad_token_id
  encodings["labels"] = [[t if t != pad_id else -100 for t in seq] for seq in label_ids]
  return encodings


# Encoder inputs ("#d: S*[X]") and their target rules, per split.
adjusted_training_dataset, adjusted_training_targets = build_glucose_pairs(train_dataset)
adjusted_eval_dataset, adjusted_eval_targets = build_glucose_pairs(eval_dataset)

train_encodings = encode_glucose_pairs(adjusted_training_dataset, adjusted_training_targets, tokenizer)
eval_encodings = encode_glucose_pairs(adjusted_eval_dataset, adjusted_eval_targets, tokenizer)

In [None]:
class AdjustedGLUCOSEDataset(torch.utils.data.Dataset):
  '''
    Map-style torch Dataset over the output of a batched tokenizer call.

    `encodings` maps feature names (e.g. input_ids, attention_mask, labels) to
    per-example sequences. Every feature is presumably the same length and padded.
    Either a transformers BatchEncoding or a plain dict is accepted.
    Item `idx` is a dict of tensors, one per feature.
  '''
  def __init__(self, encodings):
    self.encodings = encodings

  def __getitem__(self, idx):
    # Take the idx-th example out of every feature and convert it to a tensor.
    return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

  def __len__(self):
    # Subscript rather than attribute access, so plain dicts work as well as BatchEncoding.
    return len(self.encodings["input_ids"])

# Wrap each tokenized split in AdjustedGLUCOSEDataset (train first, then eval).
# Trainer builds its own DataLoaders from these, so none are constructed here.
train_data, eval_data = (AdjustedGLUCOSEDataset(split_encodings) for split_encodings in (train_encodings, eval_encodings))



In [None]:
from transformers import default_data_collator

# Every feature was padded to a common length at tokenization time, so the default
# collator is enough: it simply stacks the per-example tensors into a batch.
# DataCollatorForLanguageModeling was wrong here for two reasons:
#   - it has no `batch_size` argument, so the call raises TypeError;
#   - it applies a masked-LM objective, not seq2seq.
# Batch size belongs in TrainingArguments (per_device_train_batch_size).
data_collator = default_data_collator

In [1]:
def metrics(pred):
  '''
    compute_metrics callback for Trainer: perplexity of the reference labels under the model.

    Trainer passes an EvalPrediction with `predictions` and `label_ids` and no `loss` field,
    so the loss is recomputed here from the logits. For seq2seq models `predictions` may
    be a tuple whose first element is the logits.

    Perplexity is exp(mean token-level cross-entropy). Label positions equal to -100
    (padding) are ignored.

    NOTE: using this callback makes Trainer keep every eval logit in memory
    (n_examples x seq_len x vocab_size). For large eval sets prefer
    math.exp(trainer.evaluate()["eval_loss"]).

    Returns:
      dict with a float under 'perplexity' (could be extended to incorporate more metrics).
  '''
  logits = pred.predictions[0] if isinstance(pred.predictions, tuple) else pred.predictions
  logits = torch.as_tensor(logits).float()
  labels = torch.as_tensor(pred.label_ids).long()
  loss = torch.nn.functional.cross_entropy(
      logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=-100)
  # Perplexity is just the exponential of the cross-entropy.
  return {'perplexity': torch.exp(loss).item()}


In [2]:
import math

from transformers import Trainer, TrainingArguments

# Fine-tune the model with Trainer; it builds its own DataLoaders from the datasets.
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=1,              # total number of training epochs
    # per_device_train_batch_size / per_device_eval_batch_size are left at the library default;
    # set them here (not on the collator) to change the batch size.
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=600,
    do_train=True,                   # informational only; Trainer does not read these flags,
    do_eval=True,                    # so evaluation is run explicitly below
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    # The collator is a Trainer argument; TrainingArguments rejects it with a TypeError.
    data_collator=data_collator,
    # compute_metrics is intentionally omitted. With it, Trainer would keep all eval logits
    # (n_examples x seq_len x vocab_size) in memory. Perplexity comes from eval_loss instead.
)

trainer.train()
trainer.save_model('./results')

# Perplexity = exp(mean cross-entropy) on the eval split.
eval_results = trainer.evaluate()
eval_results['perplexity'] = math.exp(eval_results['eval_loss'])
eval_results

ModuleNotFoundError: ignored