# BART Fine-tuning for Reddit Relation Triples

### Parse the Reddit relation text files

`reddit_train_article_relation.txt`, `reddit_validation_article_relation.txt`, `reddit_test_article_relation.txt`

In [None]:
# Parse a relation file into parallel (relations, sentence) lists for seq2seq training.
def parse_relations(rel_pth: str):
    """Parse a Reddit relation file into model inputs and target sentences.

    The file is tab-separated and read line by line:

    * A line starting with ``S`` introduces a new sentence; its tab-separated
      field 3 (0-based) is the sentence text.
    * Every other non-blank line is a relation of the most recent sentence.
      A leading ``R`` tag field is dropped, the remaining tab-separated fields
      are joined with single spaces and terminated with ``.``.
    * Blank lines are ignored.

    Args:
        rel_pth: Path to the relation file (read as UTF-8).

    Returns:
        ``(relations, sentences)``: two lists of equal length, where
        ``relations[i]`` is the space-joined relation text of ``sentences[i]``
        (an empty string when that sentence has no relation lines).

    Raises:
        ValueError: if a sentence line has fewer than four tab-separated
            fields, or a relation line appears before any sentence line.
    """
    relations = []  # one list of relation strings per sentence
    sentences = []
    with open(rel_pth, 'r', encoding='utf-8') as f:
        for line_no, raw in enumerate(f, start=1):
            # Strip only the line terminator; slicing off the last character would
            # eat real text on a final line that lacks a trailing newline.
            line = raw.rstrip('\r\n')
            if not line.strip():
                continue
            if line[0] == 'S':
                fields = line.split('\t')
                if len(fields) < 4:
                    raise ValueError(
                        f"{rel_pth}:{line_no}: malformed sentence line: {line!r}"
                    )
                sentences.append(fields[3])
                # Pair every sentence with its own relation list so both outputs
                # always stay aligned, even if sentence ids repeat or skip.
                relations.append([])
                continue
            if not relations:
                raise ValueError(
                    f"{rel_pth}:{line_no}: relation line before any sentence line"
                )
            fields = line.strip().split('\t')
            # Drop only an exact 'R' tag field; str.lstrip("R\t") strips a character
            # set and would also remove a leading 'R' of the subject (e.g. "Reddit").
            if fields[0] == 'R':
                fields = fields[1:]
            relations[-1].append(' '.join(fields) + '.')
    return [' '.join(rels) for rels in relations], sentences

In [None]:
import pandas as pd
from datasets import Dataset, DatasetDict

# Directory containing the three split files; change it when not running on Colab.
data_dir = "/content"

# Each split becomes a two-column Dataset: model input ("relations") and target ("sentence").
train_rel, train_stn = parse_relations(f"{data_dir}/reddit_train_article_relation.txt")
train_df = pd.DataFrame({'relations': train_rel, 'sentence': train_stn})
train_dataset = Dataset.from_pandas(train_df)

validation_rel, validation_stn = parse_relations(f"{data_dir}/reddit_validation_article_relation.txt")
validation_df = pd.DataFrame({'relations': validation_rel, 'sentence': validation_stn})
validation_dataset = Dataset.from_pandas(validation_df)

test_rel, test_stn = parse_relations(f"{data_dir}/reddit_test_article_relation.txt")
test_df = pd.DataFrame({'relations': test_rel, 'sentence': test_stn})
test_dataset = Dataset.from_pandas(test_df)

reddit_dataset = DatasetDict({'train': train_dataset, 'validation': validation_dataset, 'test': test_dataset})
reddit_dataset

### Load the pretrained model and tokenizer

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_checkpoint = 'facebook/bart-base'
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)

In [None]:
max_input_length = 256
max_target_length = 32


def preprocess_function(examples):
    """Tokenize one batch for seq2seq training.

    The ``"relations"`` strings become the encoder inputs (truncated to
    ``max_input_length`` tokens) and the ``"sentence"`` strings become the
    ``labels`` (their ``input_ids``, truncated to ``max_target_length`` tokens).
    No padding is applied here.
    """
    encoded = tokenizer(
        examples["relations"], max_length=max_input_length, truncation=True
    )
    target_ids = tokenizer(
        examples["sentence"], max_length=max_target_length, truncation=True
    )["input_ids"]
    encoded["labels"] = target_ids
    return encoded

In [None]:
# Tokenize every split and drop the raw text columns in the same pass,
# leaving only the tokenizer outputs and the labels.
reddit_tokenized = reddit_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=reddit_dataset["train"].column_names,
)
reddit_tokenized

In [None]:
from transformers import Seq2SeqTrainingArguments

batch_size = 64
num_train_epochs = 10
# Log the training loss roughly once per epoch. Floor division yields 0 when the
# training split is smaller than one batch, which TrainingArguments rejects for
# step-based logging, so clamp to at least one step.
logging_steps = max(1, len(reddit_tokenized["train"]) // batch_size)
model_name = model_checkpoint.split("/")[-1]

args = Seq2SeqTrainingArguments(
    output_dir=f"{model_name}-finetuned-reddit",
    evaluation_strategy="epoch",
    learning_rate=5.6e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    # Let evaluation-time generation run as long as the training labels can be,
    # instead of falling back to the model's default generation length.
    generation_max_length=max_target_length,
    logging_steps=logging_steps,
)

In [None]:
import numpy as np
from datasets import load_metric

# ROUGE scorer used by `compute_metrics` during evaluation.
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets` releases
# (the suggested replacement is `evaluate.load("rouge")`) — confirm the installed version.
rouge_score = load_metric("rouge")


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Decode generated summaries into text
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    # Decode reference summaries into text
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # ROUGE expects a newline after each sentence
    decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels]
    # Compute ROUGE scores
    result = rouge_score.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Extract the median scores
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    return {k: round(v, 4) for k, v in result.items()}

In [None]:
from transformers import DataCollatorForSeq2Seq

# Batch the tokenized examples (inputs and labels) for the Seq2SeqTrainer below;
# the model is handed over so the collator can use it when preparing batches.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [None]:
from transformers import Seq2SeqTrainer

# Wire together the model, hyper-parameters, tokenized train/validation splits,
# batch collator and ROUGE metric.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=reddit_tokenized["train"],
    eval_dataset=reddit_tokenized["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune for `num_train_epochs` epochs; `args` sets evaluation_strategy="epoch",
# so the validation split is scored (via `compute_metrics`) after each epoch.
trainer.train()

In [None]:
import time

# Timestamped directory so successive runs get separate folders; the prefix is
# derived from `model_name` so it stays correct if `model_checkpoint` changes.
save_dir = f"{model_name}-finetuned-reddit-{time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime())}"
trainer.save_model(save_dir)

### Inference

In [None]:
import torch
from transformers import pipeline

# Run on the first GPU when one is available, otherwise on CPU (device=-1),
# instead of hard-coding 'cuda:0'.
summarizer = pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)

In [None]:
def print_summary(dataset, idx, summarizer, split="test"):
    """Print the relations, reference sentence and generated sentence for one example.

    Args:
        dataset: Mapping from split name to an indexable collection of rows that
            have ``"relations"`` and ``"sentence"`` fields (e.g. ``reddit_dataset``).
        idx: Row index within ``split``.
        summarizer: Callable taking the relation text and returning a list whose
            first element has a ``"summary_text"`` key (e.g. a summarization pipeline).
        split: Split to read the example from; defaults to ``"test"``.
    """
    print(f"\n>>> {idx}")
    # Fetch the row once instead of re-indexing the dataset for every field.
    example = dataset[split][idx]
    relations = example["relations"]
    sentence = example["sentence"]
    result = summarizer(relations)[0]["summary_text"]
    print(f"\n>>> Relations: {relations}")
    print(f"\n>>> Sentence: {sentence}")
    print(f"\n>>> Result: {result}")

In [None]:
# Show up to 20 test examples; min() avoids an IndexError on a smaller test split.
for i in range(min(20, len(reddit_dataset["test"]))):
    print_summary(reddit_dataset, i, summarizer)