# Import

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


from tqdm import tqdm

In [None]:
import pandas as pd

# Setting Model

In [None]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Load the seq2seq checkpoint from the Kaggle input directory and move it to
# the selected device.
model = AutoModelForSeq2SeqLM.from_pretrained("/kaggle/input/non-model/best_1_model")
model.to(device)

In [None]:
# The tokenizer comes from the public hub model, not from the checkpoint
# directory the weights were loaded from.
# NOTE(review): presumably the checkpoint was fine-tuned from this base model — confirm.
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-3")

In [None]:
# Register the two marker tokens with the tokenizer, then resize the model's
# token embeddings to the tokenizer's new length.
# NOTE(review): the markers are both assigned to `additional_special_tokens`
# and added through `add_tokens`; `add_special_tokens(...)` is the documented
# single-call route. Confirm the current form yields the ids and decode
# behaviour the checkpoint was trained with before changing it.
new_token = ['<extra>','</extra>']
tokenizer.additional_special_tokens= new_token
tokenizer.add_tokens(new_token)
model.resize_token_embeddings(len(tokenizer))

In [None]:
# string = '<extra> hello </extra>'

In [None]:
# tokenizer(string)

In [None]:
# Token limit applied when the source documents are tokenized below.
encoder_max_length = 512
# Token limit for targets; not used by the prediction cells and redefined with
# the same value in the training section.
decoder_max_length = 64

In [None]:
class TopicDataset(Dataset):
    """Map-style dataset over a tokenizer output.

    Collects the rows of ``data['input_ids']`` and ``data['attention_mask']``
    into two parallel lists and serves them as ``(input_ids, attention_mask)``
    pairs by position.
    """

    def __init__(self, data):
        # Keep a handle on the raw encoding alongside the per-row lists.
        self.input = data
        self.input_ids = []
        self.attention_mask = []
        for row, ids in enumerate(self.input['input_ids']):
            self.input_ids.append(ids)
            self.attention_mask.append(self.input['attention_mask'][row])

    def __len__(self):
        """Number of stored rows."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, attention_mask)`` pair stored at ``idx``."""
        return self.input_ids[idx], self.attention_mask[idx]
    
        
    

# Prepare data

In [None]:
# Load the preprocessed splits; the first CSV column becomes the index.
df_train = pd.read_csv('/kaggle/input/pol-dataset/preprocessed_train.csv', index_col=0)
df_test = pd.read_csv('/kaggle/input/pol-dataset/preprocessed_test (1).csv', index_col=0)
df_valid = pd.read_csv('/kaggle/input/pol-dataset/preprocessed_validation.csv', index_col=0)


def _encode_documents(frame):
    """Tokenize the ``document`` column into fixed-length PyTorch tensors."""
    return tokenizer(
        frame['document'].tolist(),
        padding='max_length',
        max_length=encoder_max_length,
        truncation=True,
        return_tensors="pt",
    )


def _sequential_loader(dataset):
    """Batch ``dataset`` four rows at a time, keeping the original row order."""
    # shuffle=False matters: predictions are later written back to the frames
    # positionally.
    return DataLoader(dataset, batch_size=4, shuffle=False)


data_train = _encode_documents(df_train)
data_test = _encode_documents(df_test)
data_val = _encode_documents(df_valid)

dataset_train = TopicDataset(data_train)
dataset_test = TopicDataset(data_test)
dataset_val = TopicDataset(data_val)

train_dataloader = _sequential_loader(dataset_train)
test_dataloader = _sequential_loader(dataset_test)
val_dataloader = _sequential_loader(dataset_val)

In [None]:
df_train

# Test model

In [None]:
# Pull the first batch from the (unshuffled) training loader as a smoke test.
x = next(iter(train_dataloader))

In [None]:
# A batch is (input_ids, attention_mask); only the input ids are moved to the device here.
inputs = x[0].to(device)

In [None]:
# Smoke test: generate from the ids alone (no attention mask is passed here).
model.generate(inputs)

# Predict

In [None]:
def pred(dataloader):
    """Generate text for every batch of ``dataloader`` with the global ``model``.

    Args:
        dataloader: Iterable yielding ``(input_ids, attention_mask)`` batches,
            such as a ``DataLoader`` over ``TopicDataset``.

    Returns:
        list[str]: One decoded string per input row, in iteration order, with
        special tokens stripped.
    """
    model.eval()
    results = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device)

            # The loaders built in this notebook pad every row to a fixed
            # length, so hand the mask to generate() explicitly (as
            # generate_summary does) instead of leaving padding handling to
            # the library's defaults.
            outputs = model.generate(input_ids, attention_mask=attention_mask)
            results.extend(
                tokenizer.batch_decode(outputs, skip_special_tokens=True)
            )
    return results
        

In [None]:
# Run generation over each split; each result list follows its loader's row order.
results_train = pred(train_dataloader)
results_test = pred(test_dataloader)
results_val = pred(val_dataloader)


In [None]:
# Each result list lines up row-for-row with its frame because the loaders
# were built with shuffle=False.
_split_outputs = (
    (df_train, results_train, 'after_bart_train.csv'),
    (df_test, results_test, 'after_bart_test.csv'),
    (df_valid, results_val, 'after_bart_val.csv'),
)

# Attach the generated text to every frame first ...
for _frame, _generated, _ in _split_outputs:
    _frame['text'] = _generated

# ... then write each augmented frame to the working directory.
for _frame, _, _csv_name in _split_outputs:
    _frame.to_csv(_csv_name)


In [None]:
df_train

# Train BART2

## Import

In [None]:
import torch
import numpy as np
import datasets

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)

from tabulate import tabulate
import nltk
from datetime import datetime

In [None]:
! pip install transformers
! pip install datasets
! pip install sentencepiece
! pip install rouge_score

## Prepare data

In [None]:
def format_data(df_input):
    """Return a lower-cased ``document``/``summary`` frame built from ``df_input``.

    The ``text`` column is taken as the document and ``summary`` as the
    summary; all other columns are dropped. The index is preserved.

    Args:
        df_input: DataFrame with string columns ``text`` and ``summary``.

    Returns:
        A new DataFrame with columns ``document`` and ``summary``. The input
        frame is left untouched.

    Raises:
        KeyError: If ``text`` or ``summary`` is missing from ``df_input``.
    """
    # Work on an explicit copy: renaming and assigning columns on the bare
    # column-selection result is chained assignment on a slice of the caller's
    # frame.
    formatted = df_input[['text', 'summary']].copy()
    formatted.columns = ["document", "summary"]
    formatted['document'] = formatted['document'].str.lower()
    formatted['summary'] = formatted['summary'].str.lower()
    return formatted

In [None]:
# Build the document/summary frames for each split from the frames that now
# carry the generated 'text' column.
train = format_data(df_train)
validation = format_data(df_valid)
test = format_data(df_test)

## Model and tokenizer

In [None]:
# Start the second-stage model from the public hub checkpoint. This rebinds
# the global `model` and `tokenizer` that the prediction cells above used.
model_name = "sshleifer/distilbart-xsum-12-3"


model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Set model parameters or use the default
# print(model.config)

# Token limits used when tokenizing source documents and target summaries.
encoder_max_length = 512  # demo
decoder_max_length = 64

# Add extra token

In [None]:
# Register the two marker tokens with the freshly loaded tokenizer, then
# resize the model's token embeddings to the tokenizer's new length.
# NOTE(review): the markers are both assigned to `additional_special_tokens`
# and added through `add_tokens`; `add_special_tokens(...)` is the documented
# single-call route. Confirm which decode behaviour (markers kept or skipped)
# is intended before changing it.
new_token = ['<extra>','</extra>']
tokenizer.additional_special_tokens= new_token
tokenizer.add_tokens(new_token)
model.resize_token_embeddings(len(tokenizer))


In [None]:
# Load the "prtiger" dataset splits (the first CSV column becomes the index).
# This rebinds df_train and df_test, replacing the frames loaded earlier; the
# validation split goes to df_val, so df_valid keeps its earlier contents.
df_train = pd.read_csv('/kaggle/input/prtiger/train.csv',index_col=0)
df_test = pd.read_csv('/kaggle/input/prtiger/test.csv',index_col=0)
df_val = pd.read_csv('/kaggle/input/prtiger/valid.csv',index_col=0)

In [None]:
# Wrap each split's current document in <extra> ... </extra> markers and
# append the 'text' column of the matching prtiger frame.
# NOTE(review): Series addition aligns on index labels, so rows only pair up
# correctly if both frames share the same index — confirm against the CSVs.
train['document']='<extra> '+train['document']+' </extra> '+df_train['text']
test['document']='<extra> '+test['document']+' </extra> '+df_test['text']
validation['document']='<extra> '+validation['document']+' </extra> '+df_val['text']

In [None]:
import datasets
from datasets import Dataset, DatasetDict

In [None]:
# Convert the pandas frames into Hugging Face datasets. `Dataset` here is
# `datasets.Dataset` (imported just above), which shadows the torch `Dataset`
# imported at the top of the notebook.
train_data_custom = Dataset.from_pandas(train)
validation_data_custom = Dataset.from_pandas(validation)
test_data_custom = Dataset.from_pandas(test)

In [None]:
def flatten(example):
    """Project ``example`` onto its ``document`` and ``summary`` fields.

    Raises ``KeyError`` if either field is absent.
    """
    return {field: example[field] for field in ("document", "summary")}

In [None]:
# Run flatten over every example of each split.
train_data_txt = train_data_custom.map(flatten)
validation_data_txt = validation_data_custom.map(flatten)
test_data_txt = test_data_custom.map(flatten)

# Tokenize

In [None]:
def batch_tokenize_preprocess(batch, tokenizer, max_source_length, max_target_length):
    """Tokenize a batch of documents and summaries into model features.

    Args:
        batch: Mapping with ``"document"`` and ``"summary"`` lists of strings.
        tokenizer: Callable tokenizer exposing ``pad_token_id``.
        max_source_length: Pad/truncate length for documents.
        max_target_length: Pad/truncate length for summaries.

    Returns:
        dict: Every field of the tokenized documents plus ``"labels"``, the
        summary token ids with padding positions replaced by ``-100``.
    """
    documents = batch["document"]
    summaries = batch["summary"]

    def _encode(texts, limit):
        return tokenizer(texts, padding="max_length", truncation=True, max_length=limit)

    encoded_documents = _encode(documents, max_source_length)
    encoded_summaries = _encode(summaries, max_target_length)

    features = dict(encoded_documents.items())
    # -100 marks positions the loss should skip, so padding does not count.
    features["labels"] = [
        [token if token != tokenizer.pad_token_id else -100 for token in row]
        for row in encoded_summaries["input_ids"]
    ]
    return features


def _tokenize_split(split):
    """Tokenize one text split and drop that split's own raw columns."""
    return split.map(
        lambda batch: batch_tokenize_preprocess(
            batch, tokenizer, encoder_max_length, decoder_max_length
        ),
        batched=True,
        # Always remove the columns of the split being mapped; the test split
        # previously borrowed the validation split's column list.
        remove_columns=split.column_names,
    )


train_data = _tokenize_split(train_data_txt)
validation_data = _tokenize_split(validation_data_txt)
test_data = _tokenize_split(test_data_txt)

# Train Setting

## Metrics

In [None]:
# Borrowed from https://github.com/huggingface/transformers/blob/master/examples/seq2seq/run_summarization.py

# Fetch the sentence-tokenizer data that nltk.sent_tokenize relies on.
nltk.download("punkt", quiet=True)

# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets`
# releases in favour of the separate `evaluate` package — confirm the
# installed version still provides it.
metric = datasets.load_metric("rouge")


def postprocess_text(preds, labels):
    """Strip each text and put one sentence per line.

    Both lists get the same treatment: surrounding whitespace is removed and
    the sentences found by ``nltk.sent_tokenize`` are joined with newlines,
    the layout rougeLSum expects.

    Returns:
        tuple[list[str], list[str]]: The processed ``(preds, labels)``.
    """
    stripped_preds = [text.strip() for text in preds]
    stripped_labels = [text.strip() for text in labels]

    def _sentence_per_line(texts):
        return ["\n".join(nltk.sent_tokenize(text)) for text in texts]

    return _sentence_per_line(stripped_preds), _sentence_per_line(stripped_labels)


def compute_metrics(eval_preds):
    """Decode predictions and references and score them with ROUGE.

    Args:
        eval_preds: ``(preds, labels)`` pair of token-id arrays; ``preds`` may
            be a tuple whose first element holds the generated ids.

    Returns:
        dict: For every ROUGE key, the mid f-measure scaled to 0-100, plus
        ``gen_len`` (mean count of non-pad tokens per prediction). All values
        are rounded to four decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # -100 is not a decodable token id. Labels carry it by construction, and
    # generated ids can come back padded with it too, so both are mapped back
    # to the pad token before decoding (the upstream run_summarization example
    # applies the same guard).
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Keep only the mid f-measure of each ROUGE variant, as a percentage.
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    # Generated length = number of non-pad tokens in each prediction.
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

## Training arguments

In [None]:
# Training configuration for the second-stage model.
training_args = Seq2SeqTrainingArguments(
    # NOTE(review): absolute path at the filesystem root, while the final model
    # is saved under /kaggle/working — confirm this location is intended.
    output_dir="/bart_2_output",
    seed = 42,
    data_seed = 42,
    num_train_epochs=1,  # demo
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=4,  # demo
    per_device_eval_batch_size=4,
    warmup_steps=500,
    weight_decay=0.1,
    label_smoothing_factor=0.1,
    # Evaluate by generating sequences; compute_metrics decodes them as token ids.
    predict_with_generate=True,
    logging_dir="logs",
    # Logging, evaluation and checkpointing share the same 6000-step cadence.
    logging_steps=6000,
    evaluation_strategy="steps",
    save_strategy = "steps",
    eval_steps = 6000,
    save_steps = 6000,
    save_total_limit = 5,
    load_best_model_at_end=True,
)

# Collator built from the tokenizer and model that are being trained.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Trainer for fine-tuning: trains on train_data, evaluates on validation_data.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=validation_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
!wandb login wandbid

# Train

In [None]:
# Fine-tune `model` on train_data using the schedule set in training_args.
trainer.train()

# Evaluate Valid

In [None]:
# Evaluate on the trainer's eval_dataset, i.e. validation_data.
trainer.evaluate()

# Save Model

In [None]:
# Persist the trainer's current model under the Kaggle working directory.
# NOTE(review): with load_best_model_at_end=True this is presumably the best
# checkpoint rather than the last one — confirm.
trainer.save_model("/kaggle/working/bart_2_4v5")

# Evaluate Test

In [None]:
# Second trainer that reuses the same model, arguments and collator but
# evaluates on the test split instead of the validation split.
tester = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=test_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Evaluate on the tester's eval_dataset, i.e. test_data.
tester.evaluate()

# Evaluation

In [None]:
def generate_summary(test_samples, model):
    """Generate summaries for the ``document`` field of ``test_samples``.

    Documents are tokenized with the global ``tokenizer`` (padded/truncated to
    ``encoder_max_length``), moved to the model's device and passed to
    ``model.generate`` together with their attention mask.

    Returns:
        tuple: ``(outputs, output_str)`` — the generated token ids and their
        decoded strings with special tokens removed.
    """
    encoded = tokenizer(
        test_samples["document"],
        padding="max_length",
        truncation=True,
        max_length=encoder_max_length,
        return_tensors="pt",
    )
    ids_on_device = encoded.input_ids.to(model.device)
    mask_on_device = encoded.attention_mask.to(model.device)

    generated = model.generate(ids_on_device, attention_mask=mask_on_device)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return generated, decoded


# model_before_tuning = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Inspect the first 16 validation examples.
validation_samples = validation_data_txt.select(range(16))

# summaries_before_tuning = generate_summary(validation_samples, model_before_tuning)[1]
# Keep only the decoded strings (second element of the returned pair).
summaries_after_tuning = generate_summary(validation_samples, model)[1]

In [None]:
# Side-by-side view: each generated summary next to its reference.
print(
    tabulate(
        zip(
            range(len(summaries_after_tuning)),
            summaries_after_tuning,
            validation_samples["summary"],
        ),
        # The third column holds the reference summaries, not pre-tuning
        # output, so it is labelled as the target (matching the test table).
        headers=["Id", "Summary after", "Summary target"],
    )
)
print("\nTarget summaries:\n")
print(
    tabulate(list(enumerate(validation_samples["summary"])), headers=["Id", "Target summary"])
)
print("\nSource documents:\n")
print(tabulate(list(enumerate(validation_samples["document"])), headers=["Id", "Document"]))

In [None]:
# Same inspection on the first 16 test examples; keep only the decoded strings.
test_samples = test_data_txt.select(range(16))
test_summaries_after_tuning = generate_summary(test_samples, model)[1]

In [None]:
# Side-by-side view: each generated test summary next to its reference.
print(
    tabulate(
        zip(
            range(len(test_summaries_after_tuning)),
            test_summaries_after_tuning,
            test_samples["summary"],
        ),
        headers=["Id", "Summary predict", "Summary target"],
    )
)
# print("\nTarget summaries:\n")
# print(
#     tabulate(list(enumerate(test_samples["summary"])), headers=["Id", "Target summary"])
# )
# Source documents for the same rows, for reference.
print("\nSource documents:\n")
print(tabulate(list(enumerate(test_samples["document"])), headers=["Id", "Document"]))