
Fine Tuning Example for summarization #1

Closed
FatemehMashhadi opened this issue Nov 29, 2021 · 5 comments

@FatemehMashhadi

Hi,
Thanks for publishing this model.

1- Is there an example of fine-tuning it for summarization?
2- How can I fine-tune it on my own dataset with masked LM in PyTorch?

Thanks

@sajjjadayobi (Owner)

Hi, I think you need to use an encoder-decoder model for summarization, whereas this model is just an encoder like BERT. One possible option is to join two instances of this model and train them end-to-end. For more information, take a look at the Hugging Face EncoderDecoderModel.
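
For example, a minimal sketch of warm-starting a seq2seq model from two copies of this checkpoint (the path mirrors the one used later in this thread; which special tokens you pick depends on the tokenizer):

from transformers import AutoTokenizer, EncoderDecoderModel

checkpoint = "./distil-bigbird-fa-zwnj"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# tie two copies of the same encoder-only checkpoint into one encoder-decoder model
model = EncoderDecoderModel.from_encoder_decoder_pretrained(checkpoint, checkpoint)

# generation-related special tokens must be set explicitly and must not be None
# (use bos/eos instead of cls/sep if that is what the tokenizer defines)
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.eos_token_id = tokenizer.sep_token_id
model.config.pad_token_id = tokenizer.pad_token_id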

What makes you want to train this with MLM? Is it because you want to increase accuracy on your downstream task through intermediate fine-tuning?

@FatemehMashhadi (Author) commented Dec 1, 2021

1- Yes, that's right. I have fine-tuned the following models on the PN-summary (Persian) dataset for summarization:
Bigbird2Bigbird
Bigbird2BERT
Bigbird2Roberta
None of them worked well; it seems the model cannot be trained at all, whereas a BERT2BERT model trained on the same dataset gave good results. I also extracted 2,000 news texts with more than 1,500 tokens and trained the above three models on them, but the results were still not good at all.

One way to strengthen the pretrained Bigbird model would be to continue its pretraining for a few more epochs, but I could not find example PyTorch code for continuing the training.
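
What I have in mind is something roughly like the sketch below (file names and hyperparameters are placeholders, not anything from this repo); I just don't know whether this is the right way to do it:

import datasets
from transformers import (AutoTokenizer, BigBirdForMaskedLM,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)

checkpoint = "./distil-bigbird-fa-zwnj"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = BigBirdForMaskedLM.from_pretrained(checkpoint)

# placeholder corpus: a plain-text file with one document per line
corpus = datasets.load_dataset("text", data_files="corpus.txt", split="train")
tokenized = corpus.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=2048),
    batched=True, remove_columns=["text"],
)

# the collator masks 15% of the tokens on the fly (standard MLM objective)
collator = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15)

args = TrainingArguments(output_dir="./bigbird-mlm-continued", num_train_epochs=2,
                         per_device_train_batch_size=2, fp16=True)
Trainer(model=model, args=args, data_collator=collator, train_dataset=tokenized).train()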

2- One of the state-of-the-art approaches for summarizing long texts is the BigBirdPegasusForConditionalGeneration architecture. Given that the Persian models
alireza7/PEGASUS-persian-base
and
alireza7/PEGASUS-persian-base-PN-summary
already exist, do you have any idea how to obtain a BigBirdPegasusForConditionalGeneration model?

3- Is there any script for converting Bigbird from PyTorch to TensorFlow?

Thank you for your guidance

@sajjjadayobi (Owner)

  1. I don't know why the model doesn't work. I would need to take a look at the code in order to give my opinion about the problem. Do you use Huggingface for training?

  2. I've heard about the Persian PEGASUS, but that model was trained with normal (full) attention, which doesn't work for longer sequences. If you want a BigBirdPegasus, you would have to pretrain it from scratch, and as you know that is expensive (see the sketch after this list). 🤕

    • If you would like to train such a model, I am available to collaborate.
  3. Yes, this model can easily be trained with pure PyTorch (take a look at the fine-tuning example in the readme), but unfortunately 🤗 has not provided a TF version of Bigbird yet.
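
A rough idea of what "from scratch" would mean (a sketch only; the vocab size here is a placeholder and all weights start out randomly initialized):

from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration

# placeholder config values; a real run needs a Persian tokenizer and a matching vocab_size
config = BigBirdPegasusConfig(
    vocab_size=50000,
    max_position_embeddings=4096,
    attention_type="block_sparse",  # the sparse attention that makes long inputs feasible
)
model = BigBirdPegasusForConditionalGeneration(config)  # randomly initialized, needs full pretraining
print(f"{model.num_parameters():,} parameters to pretrain")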

@FatemehMashhadi (Author) commented Dec 4, 2021

1- Yes, I use Huggingface.

import logging
from transformers import BigBirdModel, BigBirdForCausalLM, AutoTokenizer, EncoderDecoderModel, Trainer, TrainingArguments
import datasets

logging.basicConfig(level=logging.INFO)
model = EncoderDecoderModel.from_encoder_decoder_pretrained("./distil-bigbird-fa-zwnj", "./roberta-fa-zwnj-base")
tokenizer = AutoTokenizer.from_pretrained("./distil-bigbird-fa-zwnj")
train_dataset = datasets.load_dataset('csv', data_files='./train_1000.csv', split="train")
val_dataset = datasets.load_dataset('csv', data_files='./dev.csv', split="train")
rouge = datasets.load_metric("rouge", experiment_id=0)  # the old `nlp` package is now `datasets`

model.encoder.config.gradient_checkpointing = True
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.max_length = 142
model.config.min_length = 56
model.config.no_repeat_ngram_size = 3
model.config.early_stopping = True
model.config.length_penalty = 2.0
model.config.num_beams = 4
encoder_length = 2048
decoder_length = 128
batch_size = 4

def map_to_encoder_decoder_inputs(batch):
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=encoder_length)
    outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=decoder_length)
    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    batch["global_attention_mask"] = [[1 if i < 128 else 0 for i in range(sequence_length)] for sequence_length in len(inputs.input_ids) * [encoder_length]]
    batch["decoder_input_ids"] = outputs.input_ids
    batch["labels"] = outputs.input_ids.copy()
    batch["labels"] = [
        [-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
    ]
    batch["decoder_attention_mask"] = outputs.attention_mask
    assert all([len(x) == encoder_length for x in inputs.input_ids])
    assert all([len(x) == decoder_length for x in outputs.input_ids])

    return batch


def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.eos_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }



train_dataset = train_dataset.map(
    map_to_encoder_decoder_inputs, batched=True, batch_size=batch_size, remove_columns=["article", "highlights"],
)
train_dataset.set_format(
    type="torch", columns=["input_ids", "attention_mask", "global_attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

val_dataset = val_dataset.map(
    map_to_encoder_decoder_inputs, batched=True, batch_size=batch_size, remove_columns=["article", "highlights"],
)
val_dataset.set_format(
    type="torch", columns=["input_ids", "global_attention_mask", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

training_args = TrainingArguments(
    output_dir="./",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    save_strategy="epoch",
    #predict_from_generate=True,
    #evaluate_during_training=True,
    do_train=True,
    do_eval=True,
    logging_steps=1000,
    save_steps=2000,
    eval_steps=2000,
    overwrite_output_dir=True,
    warmup_steps=1000,
    save_total_limit=3,
    fp16=True,
    num_train_epochs=1,
)

trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

trainer.train()

This is my code for summarization; it does not train well and does not give good results.
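
One thing I am not sure about: compute_metrics decodes pred.predictions as token ids, which only holds if evaluation actually runs generation. With mainline transformers that would mean Seq2SeqTrainer with predict_with_generate, roughly (reusing the variables defined above):

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

seq2seq_args = Seq2SeqTrainingArguments(
    output_dir="./",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,   # decode with generate() so compute_metrics receives token ids
    evaluation_strategy="steps",
    eval_steps=2000,
    logging_steps=1000,
    warmup_steps=1000,
    fp16=True,
    num_train_epochs=1,
)

seq2seq_trainer = Seq2SeqTrainer(
    model=model,
    args=seq2seq_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)
seq2seq_trainer.train()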

2- Yes, I am going to train this model. How can I get guidance from you? Thanks.

@sajjjadayobi (Owner)

It would be better to discuss this somewhere else; this is my Telegram ID.
