### 0. Initial Setting

In [1]:
# %%capture
# !pip install datasets==1.0.2
# !pip install transformers==4.2.1

In [2]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

cache_dir = "/data4/yoomcache"
model_cache_dir = os.path.join(cache_dir, 'huggingface')
data_cache_dir = os.path.join(cache_dir, 'datasets')
checkpoint_dir = os.path.join(cache_dir, 'checkpoint')

import logging
logging.getLogger().setLevel(logging.CRITICAL)
logging.basicConfig(level=logging.INFO)


import torch
from datasets import load_dataset, load_metric, load_from_disk
from transformers import BertTokenizer, RobertaTokenizer, GPT2Tokenizer
from transformers import AutoConfig, EncoderDecoderConfig, EncoderDecoderModel
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

import wandb
wandb.init(project="testing-roberta2gpt", entity="yoom618")

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33myoom618[0m (use `wandb login --relogin` to force relogin)


### 1. Initialize Model

In [3]:
config_encoder = AutoConfig.from_pretrained("roberta-base", cache_dir=model_cache_dir)
config_decoder = AutoConfig.from_pretrained("gpt2", cache_dir=model_cache_dir)
config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder, cache_dir=model_cache_dir)
model = EncoderDecoderModel(config=config)
# model.save_pretrained("roberta2gpt", cache_dir=model_cache_dir)
# model = EncoderDecoderModel.from_pretrained("roberta2gpt", cache_dir=model_cache_dir)

model.encoder.encoder.layer = model.encoder.encoder.layer[:6]
model.decoder.transformer.h = model.decoder.transformer.h[-6:]

In [4]:
encoder_tokenizer = RobertaTokenizer.from_pretrained("roberta-base", cache_dir=model_cache_dir)
encoder_tokenizer.bos_token = encoder_tokenizer.cls_token  # CLS token will work as BOS token
encoder_tokenizer.eos_token = encoder_tokenizer.sep_token  # SEP token will work as EOS token

# make sure GPT2 appends EOS in begin and end
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
    outputs = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
    return outputs

GPT2Tokenizer.build_inputs_with_special_tokens = build_inputs_with_special_tokens
decoder_tokenizer = GPT2Tokenizer.from_pretrained("gpt2", cache_dir=model_cache_dir)
# set pad_token_id to unk_token_id -> be careful here as unk_token_id == eos_token_id == bos_token_id
decoder_tokenizer.pad_token = decoder_tokenizer.unk_token


model.config.decoder_start_token_id = encoder_tokenizer.cls_token_id
model.config.eos_token_id = encoder_tokenizer.sep_token_id
model.config.pad_token_id = encoder_tokenizer.pad_token_id
model.config.vocab_size = model.config.encoder.vocab_size


# set decoding params
model.config.decoder_start_token_id = decoder_tokenizer.bos_token_id
model.config.eos_token_id = decoder_tokenizer.eos_token_id
model.config.max_length = 142
model.config.min_length = 56
model.config.no_repeat_ngram_size = 3
model.early_stopping = True
model.length_penalty = 2.0
model.num_beams = 4

In [5]:
# Freeze decoder parameters
for param in model.decoder.parameters():
    param.requires_grad = False


### 2. Preparing Dataset

In [6]:
# map data correctly
def map_to_encoder_decoder_inputs(batch):    # Tokenizer will automatically set [BOS] <text> [EOS] 
    encoder_length, decoder_length = 512, 128
    inputs = encoder_tokenizer(batch["article"], 
                               padding="max_length", 
                               truncation=True, 
                               max_length=encoder_length)
    outputs = decoder_tokenizer(batch["highlights"], 
                                padding="max_length", 
                                truncation=True, 
                                max_length=decoder_length)
    
    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    batch["decoder_input_ids"] = outputs.input_ids
    batch["labels"] = outputs.input_ids.copy()
    batch["decoder_attention_mask"] = outputs.attention_mask

    # complicated list comprehension here because pad_token_id alone is not good enough to know whether label should be excluded or not
    batch["labels"] = -100 if batch["decoder_attention_mask"] == 0 else batch["labels"]

    assert len(inputs.input_ids) == encoder_length
    assert len(outputs.input_ids) == decoder_length

    return batch

In [7]:
if os.path.exists(os.path.join(cache_dir, 'preprocessed/train')):
    train_dataset = load_from_disk(os.path.join(cache_dir, 'preprocessed/train'))
else:
    train_dataset = load_dataset("ccdv/cnn_dailymail", "3.0.0", split="train", cache_dir=data_cache_dir)
    train_dataset = train_dataset.map(
        map_to_encoder_decoder_inputs, 
        # batched=True, 
        # batch_size=batch_size, 
        remove_columns=['id', 'article', 'highlights'],
    )
    train_dataset.set_format(
        type="torch", 
        columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
    )
    
    train_dataset.save_to_disk(os.path.join(cache_dir, 'preprocessed/train'))


if os.path.exists(os.path.join(cache_dir, 'preprocessed/val')):
    val_dataset = load_from_disk(os.path.join(cache_dir, 'preprocessed/val'))
else:
    val_dataset = load_dataset("ccdv/cnn_dailymail", "3.0.0", split="validation", cache_dir=data_cache_dir)
    val_dataset = val_dataset.map(
        map_to_encoder_decoder_inputs, 
        # batched=True, 
        # batch_size=batch_size, 
        remove_columns=['id', 'article', 'highlights'],
    )
    val_dataset.set_format(
        type="torch", 
        columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
    )
    val_dataset.save_to_disk(os.path.join(cache_dir, 'preprocessed/val'))

### Training Model

In [8]:
# load rouge for validation
rouge = load_metric("rouge")
# rouge = load_metric("rouge", experiment_id=1)

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # all unnecessary tokens are removed
    pred_str = decoder_tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = decoder_tokenizer.eos_token_id
    label_str = decoder_tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

In [9]:
batch_size = 16

# set training arguments - these params are not really tuned, feel free to change
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    output_dir=os.path.join(checkpoint_dir, "roberta2gpt"),
    # do_train=True,
    # do_eval=True,
    # do_predict=True,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=4,
#     learning_rate=1e-4, weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, max_grad_norm=1.0,
    num_train_epochs=100,
    max_steps=-1,
    # lr_scheduler_type='linear', warmup_ratio=0.0, 
    
    logging_strategy='steps',
    save_strategy='steps',
    evaluation_strategy='steps',
    logging_steps=1000,
    save_steps=2000,
    eval_steps=1000,
    warmup_steps=10000,
    save_total_limit=3,
    overwrite_output_dir=True,
)

# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# start training
trainer.train()

***** Running training *****
  Num examples = 28711
  Num Epochs = 100
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 179500
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Step,Training Loss,Validation Loss,Rouge2 Precision,Rouge2 Recall,Rouge2 Fmeasure
1000,5.5206,3.66548,0.0,0.0,0.0
2000,4.7231,3.675316,0.0,0.0,0.0
3000,4.7173,3.67219,0.0,0.0,0.0
4000,4.7192,3.659882,0.0,0.0,0.0
5000,4.713,3.661578,0.0,0.0,0.0
6000,4.7066,3.668557,0.0,0.0,0.0
7000,4.5457,3.296617,0.0,0.0,0.0
8000,4.3451,3.272192,0.0,0.0,0.0
9000,4.2349,3.183884,0.0,0.0,0.0
10000,4.138,3.083094,0.0,0.0,0.0


***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-2000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-2000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-2000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/roberta2gpt/checkpoint-20000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-4000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-4000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-4000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/roberta2g

***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-18000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-18000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-18000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/roberta2gpt/checkpoint-12000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-20000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-20000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-20000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/rob

***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-34000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-34000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-34000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/roberta2gpt/checkpoint-28000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-36000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-36000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-36000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/rob

***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-50000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-50000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-50000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/roberta2gpt/checkpoint-44000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-52000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-52000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-52000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/rob

***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-66000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-66000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-66000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/roberta2gpt/checkpoint-60000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-68000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-68000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-68000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/rob

***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-82000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-82000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-82000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/roberta2gpt/checkpoint-76000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-84000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-84000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-84000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/rob

***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-98000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-98000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-98000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/roberta2gpt/checkpoint-92000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-100000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-100000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-100000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/

***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-114000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-114000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-114000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/roberta2gpt/checkpoint-108000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-116000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-116000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-116000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpo

***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-128000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-128000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-128000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/roberta2gpt/checkpoint-122000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-130000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-130000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-130000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpo

***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-144000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-144000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-144000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/roberta2gpt/checkpoint-138000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-146000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-146000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-146000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpo

***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-158000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-158000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-158000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/roberta2gpt/checkpoint-152000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-160000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-160000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-160000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpo

***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-174000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-174000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-174000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/roberta2gpt/checkpoint-168000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
***** Running Evaluation *****
  Num examples = 1337
  Batch size = 4
Saving model checkpoint to /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-176000
Configuration saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-176000/config.json
Model weights saved in /data4/yoomcache/checkpoint/roberta2gpt/checkpoint-176000/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpo

TrainOutput(global_step=179500, training_loss=2.7927890720208044, metrics={'train_runtime': 210481.0711, 'train_samples_per_second': 13.641, 'train_steps_per_second': 0.853, 'total_flos': 8.805162054057984e+17, 'train_loss': 2.7927890720208044, 'epoch': 100.0})

https://huggingface.co/patrickvonplaten/bert2gpt2-cnn_dailymail-fp16

In [10]:
model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")

bert_tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

# CLS token will work as BOS token
bert_tokenizer.bos_token = bert_tokenizer.cls_token

# SEP token will work as EOS token
bert_tokenizer.eos_token = bert_tokenizer.sep_token


# make sure GPT2 appends EOS in begin and end
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
    outputs = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
    return outputs


GPT2Tokenizer.build_inputs_with_special_tokens = build_inputs_with_special_tokens
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# set pad_token_id to unk_token_id -> be careful here as unk_token_id == eos_token_id == bos_token_id
gpt2_tokenizer.pad_token = gpt2_tokenizer.unk_token


# set decoding params
model.config.decoder_start_token_id = gpt2_tokenizer.bos_token_id
model.config.eos_token_id = gpt2_tokenizer.eos_token_id
model.config.max_length = 142
model.config.min_length = 56
model.config.no_repeat_ngram_size = 3
model.early_stopping = True
model.length_penalty = 2.0
model.num_beams = 4

test_dataset = load_dataset("cnn_dailymail", "3.0.0", split="test")
batch_size = 4


# map data correctly
def generate_summary(batch):
    # Tokenizer will automatically set [BOS] <text> [EOS]
    # cut off at BERT max length 512
    inputs = bert_tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    input_ids = inputs.input_ids.to("cuda")
    attention_mask = inputs.attention_mask.to("cuda")

    outputs = model.generate(input_ids, attention_mask=attention_mask)

    # all special tokens including will be removed
    output_str = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)

    batch["pred"] = output_str

    return batch


results = test_dataset.map(generate_summary, batched=True, batch_size=batch_size, remove_columns=["article"])

loading configuration file https://huggingface.co/patrickvonplaten/bert2gpt2-cnn_dailymail-fp16/resolve/main/config.json from cache at /home/yoomin/.cache/huggingface/transformers/7944e7bc294d39091dcbf5fc5c6fc0d3d32cda1fc3e6208912c82482c489a888.2f0b414aab7a259e6c2cb673dc14da52f6b46ac536982d5c10f29c8fec3fadef
You are using a model of type encoder_decoder to instantiate a model of type encoder-decoder. This is not supported for all configurations of models and can yield errors.
Model config EncoderDecoderConfig {
  "architectures": [
    "EncoderDecoderModel"
  ],
  "decoder": {
    "_name_or_path": "",
    "activation_function": "gelu_new",
    "add_cross_attention": true,
    "architectures": [
      "GPT2LMHeadModel"
    ],
    "attn_pdrop": 0.1,
    "bad_words_ids": null,
    "bos_token_id": 50256,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": f

NameError: name 'device' is not defined

In [None]:
# load rouge for validation
rouge = load_metric("rouge")

pred_str = results["pred"]
label_str = results["highlights"]

rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

print(rouge_output)