In [30]:
# see https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/BERT2BERT_for_CNN_Dailymail.ipynb

In [25]:
import datasets
import transformers
import torch

In [2]:
from transformers import BertTokenizerFast

In [26]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
	device = torch.device("mps")

In [None]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
tokenizer.bos_token = tokenizer.cls_token
tokenizer.eos_token = tokenizer.sep_token

In [13]:
# %%capture
arxiv_summarization_train = datasets.load_dataset("ccdv/arxiv-summarization",split="train")

In [None]:
# %%capture
arxiv_summarization_validation = datasets.load_dataset("ccdv/arxiv-summarization",split="validation[:10%]")

In [8]:
# arxiv_summarization = arxiv_summarization.train_test_split(test_size=0.2)

In [15]:
batch_size=16
encoder_max_length=512
decoder_max_length=128

def process_data_to_model_inputs(batch):
  # tokenize the inputs and labels
  inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=encoder_max_length)
  outputs = tokenizer(batch["abstract"], padding="max_length", truncation=True, max_length=decoder_max_length)

  batch["input_ids"] = inputs.input_ids
  batch["attention_mask"] = inputs.attention_mask
  batch["decoder_input_ids"] = outputs.input_ids
  batch["decoder_attention_mask"] = outputs.attention_mask
  batch["labels"] = outputs.input_ids.copy()

  # because BERT automatically shifts the labels, the labels correspond exactly to `decoder_input_ids`. 
  # We have to make sure that the PAD token is ignored
  batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]]

  return batch


In [17]:
arxiv_summarization_train = arxiv_summarization_train.map(
    process_data_to_model_inputs, 
    batched=True, 
    batch_size=batch_size, 
    remove_columns=["article", "abstract"]
)

Map:   0%|          | 0/203037 [00:00<?, ? examples/s]

In [18]:
arxiv_summarization_train.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

In [20]:
arxiv_summarization_validation = arxiv_summarization_validation.map(
    process_data_to_model_inputs, 
    batched=True, 
    batch_size=batch_size, 
    remove_columns=["article", "abstract"]
)

Map:   0%|          | 0/644 [00:00<?, ? examples/s]

In [21]:
arxiv_summarization_validation.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

In [22]:
from transformers import EncoderDecoderModel

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [27]:
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased").to(device=device)

  with safe_open(checkpoint_file, framework="pt") as f:
  return self.fget.__get__(instance, owner)()
  storage = cls(wrap_storage=untyped_storage)
  with safe_open(filename, framework="pt", device=device) as f:
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequence

In [28]:
# set special tokens
bert2bert.config.decoder_start_token_id = tokenizer.bos_token_id
bert2bert.config.eos_token_id = tokenizer.eos_token_id
bert2bert.config.pad_token_id = tokenizer.pad_token_id

In [29]:
# sensible parameters for beam search
bert2bert.config.vocab_size = bert2bert.config.decoder.vocab_size
bert2bert.config.max_length = 142
bert2bert.config.min_length = 56
bert2bert.config.no_repeat_ngram_size = 3
bert2bert.config.early_stopping = True
bert2bert.config.length_penalty = 2.0
bert2bert.config.num_beams = 4

In [32]:
from transformers import Seq2SeqTrainer
from transformers import TrainingArguments
from dataclasses import dataclass, field
from typing import Optional

In [33]:
@dataclass
class Seq2SeqTrainingArguments(TrainingArguments):
    label_smoothing: Optional[float] = field(
        default=0.0, metadata={"help": "The label smoothing epsilon to apply (if not zero)."}
    )
    sortish_sampler: bool = field(default=False, metadata={"help": "Whether to SortishSamler or not."})
    predict_with_generate: bool = field(
        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
    )
    adafactor: bool = field(default=False, metadata={"help": "whether to use adafactor"})
    encoder_layerdrop: Optional[float] = field(
        default=None, metadata={"help": "Encoder layer dropout probability. Goes into model.config."}
    )
    decoder_layerdrop: Optional[float] = field(
        default=None, metadata={"help": "Decoder layer dropout probability. Goes into model.config."}
    )
    dropout: Optional[float] = field(default=None, metadata={"help": "Dropout probability. Goes into model.config."})
    attention_dropout: Optional[float] = field(
        default=None, metadata={"help": "Attention dropout probability. Goes into model.config."}
    )
    lr_scheduler: Optional[str] = field(
        default="linear", metadata={"help": f"Which lr scheduler to use."}
    )

In [34]:
# load rouge for validation
rouge = datasets.load_metric("rouge")
# rouge = transformers.evaluate.load("rouge")

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # all unnecessary tokens are removed
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

  rouge = datasets.load_metric("rouge")


Downloading builder script:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

In [49]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    evaluation_strategy='epoch',
    do_train=True,
    do_eval=True,
    logging_steps=2,  # set to 1000 for full training
    save_steps=16,  # set to 500 for full training
    eval_steps=4,  # set to 8000 for full training
    warmup_steps=1,  # set to 2000 for full training
    # max_steps=16, # delete for full training
    max_steps=32,
    overwrite_output_dir=True,
    save_total_limit=3
    # fp16=True, 
)


PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [50]:
trainer = Seq2SeqTrainer(
    model=bert2bert,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=arxiv_summarization_train,
    eval_dataset=arxiv_summarization_validation,
)
trainer.train()

max_steps is given, it will override any value given in num_train_epochs
***** Running training *****
  Num examples = 203037
  Num Epochs = 1
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 32
  Number of trainable parameters = 247363386


  0%|          | 0/32 [00:00<?, ?it/s]



{'loss': 0.2851, 'learning_rate': 4.8387096774193554e-05, 'epoch': 0.0}
{'loss': 0.4964, 'learning_rate': 4.516129032258064e-05, 'epoch': 0.0}
{'loss': 0.4215, 'learning_rate': 4.1935483870967746e-05, 'epoch': 0.0}
{'loss': 0.1672, 'learning_rate': 3.870967741935484e-05, 'epoch': 0.0}
{'loss': 0.1361, 'learning_rate': 3.548387096774194e-05, 'epoch': 0.0}
{'loss': 0.166, 'learning_rate': 3.2258064516129034e-05, 'epoch': 0.0}
{'loss': 0.0531, 'learning_rate': 2.9032258064516133e-05, 'epoch': 0.0}


Saving model checkpoint to ./results/checkpoint-16
Configuration saved in ./results/checkpoint-16/config.json


{'loss': 0.0563, 'learning_rate': 2.5806451612903226e-05, 'epoch': 0.0}


Model weights saved in ./results/checkpoint-16/pytorch_model.bin


{'loss': 0.1168, 'learning_rate': 2.258064516129032e-05, 'epoch': 0.0}
{'loss': 0.0435, 'learning_rate': 1.935483870967742e-05, 'epoch': 0.0}
{'loss': 0.0838, 'learning_rate': 1.6129032258064517e-05, 'epoch': 0.0}
{'loss': 0.0175, 'learning_rate': 1.2903225806451613e-05, 'epoch': 0.0}
{'loss': 0.0275, 'learning_rate': 9.67741935483871e-06, 'epoch': 0.0}
{'loss': 0.0329, 'learning_rate': 6.451612903225806e-06, 'epoch': 0.0}
{'loss': 0.0181, 'learning_rate': 3.225806451612903e-06, 'epoch': 0.0}


Saving model checkpoint to ./results/checkpoint-32
Configuration saved in ./results/checkpoint-32/config.json


{'loss': 0.0389, 'learning_rate': 0.0, 'epoch': 0.0}


Model weights saved in ./results/checkpoint-32/pytorch_model.bin


In [87]:
device = "cpu"

In [1]:
# %%capture
import datasets
from transformers import BertTokenizer, EncoderDecoderModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = EncoderDecoderModel.from_pretrained("./results/checkpoint-32").to(device)

In [None]:
arxiv_summarization_test = datasets.load_dataset("ccdv/arxiv-summarization",split="test")

In [90]:
arxiv_summarization_test[0]

{'article': 'for about 20 years the problem of properties of short - term changes of solar activity has been considered extensively . \n many investigators studied the short - term periodicities of the various indices of solar activity . \n several periodicities were detected , but the periodicities about 155 days and from the interval of @xmath3 $ ] days ( @xmath4 $ ] years ) are mentioned most often . \n first of them was discovered by @xcite in the occurence rate of gamma - ray flares detected by the gamma - ray spectrometer aboard the _ solar maximum mission ( smm ) . \n this periodicity was confirmed for other solar flares data and for the same time period @xcite . \n it was also found in proton flares during solar cycles 19 and 20 @xcite , but it was not found in the solar flares data during solar cycles 22 @xcite . \n _    several autors confirmed above results for the daily sunspot area data . @xcite studied the sunspot data from 18741984 . \n she found the 155-day periodicity 

In [91]:
batch_size = 16

In [92]:
miniminitestdataset = arxiv_summarization_test.select(range(16))

In [93]:
def generate_summary(batch):
    # Tokenizer will automatically set [BOS] <text> [EOS]
    # cut off at BERT max length 512
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    outputs = model.generate(input_ids, attention_mask=attention_mask)

    # all special tokens including will be removed
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    batch["pred"] = output_str

    return batch

In [94]:
# results = arxiv_summarization_test.map(generate_summary, batched=True, batch_size=batch_size, remove_columns=["article"])
results = miniminitestdataset.map(generate_summary, batched=True, batch_size=batch_size, remove_columns=["article"])

Map:   0%|          | 0/16 [00:00<?, ? examples/s]



In [96]:
pred_str = results["pred"]


In [98]:
label_str = results["abstract"]

In [99]:
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid


In [100]:
print(rouge_output)

Score(precision=0.0, recall=0.0, fmeasure=0.0)


In [102]:
label_str[0]

'the short - term periodicities of the daily sunspot area fluctuations from august 1923 to october 1933 are discussed . for these data \n the correlative analysis indicates negative correlation for the periodicity of about @xmath0 days , but the power spectrum analysis indicates a statistically significant peak in this time interval . \n a new method of the diagnosis of an echo - effect in spectrum is proposed and it is stated that the 155-day periodicity is a harmonic of the periodicities from the interval of @xmath1 $ ] days .    the autocorrelation functions for the daily sunspot area fluctuations and for the fluctuations of the one rotation time interval in the northern hemisphere , separately for the whole solar cycle 16 and for the maximum activity period of this cycle do not show differences , especially in the interval of @xmath2 $ ] days . \n it proves against the thesis of the existence of strong positive fluctuations of the about @xmath0-day interval in the maximum activity 

In [103]:
pred_str[0]

'we we we our our our their their their our their your your your their their your their your our their own own own yours yours yours theirs theirs theirs yours yours ours ours ours theirs theirs ours ours yours yours mine mine mine yours yours your your our our your your my my my our our his your your yer yer yer your your his his your my your your yo yo yo your your ya ya ya your your her her her his his his our our ours ours mine mine hers hers hers yours yours hers hers ours ours hers hers theirs yours mine yours mine ours ours our our my my his his ps ps ps ins ins insinsinsinsininininsins ins ins outs outs outs ins instenstenstens'