# Using Flan T5 to determine how accurate different parts of articles can predict keywords

In [1]:
import sys
!{sys.executable} -m pip install torch torchvision torchaudio datasets scikit-learn transformers rapidfuzz --quiet
import json
import os
import torch
import re
from rapidfuzz import fuzz
from datasets import Dataset
from sklearn.metrics import precision_score, recall_score, f1_score
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    TrainingArguments,
    Trainer,
    T5Tokenizer,
    T5ForConditionalGeneration,
    EarlyStoppingCallback,
)

distutils: /home/rpuranda/.local/lib/python3.9/site-packages
sysconfig: /home/rpuranda/.local/lib64/python3.9/site-packages[0m
user = True
home = None
root = None
prefix = None[0m


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_data(file_paths):
    data = []
    for file_path in file_paths:
        with open(file_path, "r") as f:
            for line in f:
                try:
                    entry = json.loads(line)
                    data.append(entry)
                except json.JSONDecodeError as e:
                    print(f"Skipping bad line in {file_path}: {e}")
    return data

In [3]:
def prepare_dataset(data, input_field):
    return [
        {
            "input": f"Extract keywords: {item[input_field]}",
            "target": ", ".join(item["keywords"])
        }
        for item in data
        if input_field in item and "keywords" in item and isinstance(item["keywords"], list)
    ]

In [4]:
def tokenize_data(dataset, tokenizer, max_input_length=512, max_target_length=32):
    inputs = [item["input"] for item in dataset]
    targets = [item["target"] for item in dataset]

    model_inputs = tokenizer(inputs, max_length=max_input_length, padding="max_length", truncation=True)

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, padding="max_length", truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return Dataset.from_dict(model_inputs)

In [6]:
def preprocess_function(examples, tokenizer, source_col, target_col, max_input_length=512, max_target_length=128):
    model_inputs = tokenizer(
        examples[source_col],
        max_length=max_input_length,
        padding="max_length",
        truncation=True,
    )

    labels = tokenizer(
        examples[target_col],
        max_length=max_target_length,
        padding="max_length",
        truncation=True,
    )["input_ids"]

    model_inputs["labels"] = labels
    return model_inputs

def train_model(model_name, train_data, val_data, output_dir):
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)

    # Convert raw pandas dataframes/lists to Hugging Face datasets
    train_dataset = Dataset.from_list(train_data)
    val_dataset = Dataset.from_list(val_data)

    # Replace with the actual column names of your dataset
    source_col = "input"   # e.g., "text" or "source"
    target_col = "target"  # e.g., "summary" or "label"

    tokenized_train = train_dataset.map(lambda x: preprocess_function(x, tokenizer, source_col, target_col), batched=True)
    tokenized_val = val_dataset.map(lambda x: preprocess_function(x, tokenizer, source_col, target_col), batched=True)

    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="steps",
        eval_steps=500,
        logging_steps=500,
        save_steps=500,
        save_total_limit=1,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )

    trainer.train()
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    return model, tokenizer

In [7]:
def get_file_paths(prefix, start, end):
    return [f"{prefix}{i:04d}" for i in range(start, end + 1)]

In [8]:
def clean_keyword(kw):
    return re.sub(r'\W+', '', kw.lower())

In [48]:
def is_fuzzy_match(pred, true_keywords):
    for true_kw in true_keywords:
        if pred in true_kw or true_kw in pred:
            return True
    return False

In [49]:
def compute_metrics(preds, refs):
    precision_list = []
    recall_list = []
    f1_list = []
    average_precision = 0

    for pred, ref in zip(preds, refs):
        pred = set(clean_keyword(k) for k in pred)
        ref = set(clean_keyword(k) for k in ref)
        true_positives = 0
        precisions = []
        for i, pred_kw in enumerate(pred):
            if is_fuzzy_match(pred_kw, ref):
                true_positives += 1
                precision_at_i = true_positives / (i + 1)
                precisions.append(precision_at_i)

        if not precisions:
            continue
        average_precision += sum(precisions) / len(ref) if ref else 0
        precision = true_positives / len(pred) if pred else 0
        recall = true_positives / len(ref) if ref else 0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0

        print("precision: ", precision, " recall: ", recall, " f1: ", f1)

        precision_list.append(precision)
        recall_list.append(recall)
        f1_list.append(f1)

    return {
        "precision": sum(precision_list) / len(precision_list),
        "recall": sum(recall_list) / len(recall_list),
        "f1": sum(f1_list) / len(f1_list),
        "map": average_precision / len(precision_list)
    }

In [12]:
def evaluate_order_agnostic(model, tokenizer, dataset):
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    predictions = []
    references = []

    for sample in dataset:
        inputs = tokenizer(sample["input"], return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        input_ids = inputs["input_ids"].to(device) 

        with torch.no_grad():
            outputs = model.generate(input_ids=input_ids, max_length=64)

        pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
        ref = sample["target"]

        pred = set(pred.split(", "))
        ref = set(ref.split(", "))
        
        predictions.append(pred)
        references.append(ref)

        pred = set(clean_keyword(k) for k in pred)
        ref = set(clean_keyword(k) for k in ref)
        true_positives = len(pred & ref)
        precision = true_positives / len(pred) if pred else 0
        recall = true_positives / len(ref) if ref else 0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0

        print("precision: ", precision, " recall: ", recall, " f1: ", f1)
        # print("Predictions: ", set(pred.split(", ")))
        # print("References: ", set(ref.split(", ")))

    return predictions, references

In [13]:
model_name = "google/flan-t5-small"
train_files = get_file_paths("./data/training-data-chunk-", 0, 7)
test_files = get_file_paths("./data/training-data-chunk-", 8, 9)

print("Loading data...")
train_data = load_data(train_files)
test_data = load_data(test_files)

Loading data...


In [14]:
print("Preparing abstract...")
train_abstract = prepare_dataset(train_data, "abstract_content")
print(f"Training with {len(train_abstract)} abstract samples")
test_abstract = prepare_dataset(test_data, "abstract_content")
print(f"Testing with {len(test_abstract)} abstract samples")

Preparing abstract...
Training with 8000 abstract samples
Testing with 2000 abstract samples


In [15]:
print("Preparing body...")
train_body = prepare_dataset(train_data, "content")
print(f"Training with {len(train_body)} body samples")
test_body = prepare_dataset(test_data, "content")
print(f"Testing with {len(test_body)} body samples")

Preparing body...
Training with 8000 body samples
Testing with 2000 body samples


In [16]:
print("Training abstract model...")
abstract_model, abstract_tokenizer = train_model(model_name, train_abstract, test_abstract, "flan_t5_abstract")

Training abstract model...


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Map: 100%|██████████| 8000/8000 [00:10<00:00, 741.95 examples/s]
Map: 100%|██████████| 2000/2000 [00:02<00:00, 741.39 examples/s]
  trainer = Trainer(
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss,Validation Loss
500,3.7452,0.882611
1000,0.8832,0.81525
1500,0.8425,0.806781


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


In [17]:
print("Training body model...")
body_model, body_tokenizer = train_model(model_name, train_body, test_body, "flan_t5_body")

Training body model...


Map: 100%|██████████| 8000/8000 [05:34<00:00, 23.93 examples/s]
Map: 100%|██████████| 2000/2000 [01:23<00:00, 23.81 examples/s]
  trainer = Trainer(


Step,Training Loss,Validation Loss
500,3.8403,0.91899
1000,0.9244,0.856146
1500,0.886,0.848029


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


In [23]:
print("Evaluating abstract-only model...")
abstract_predictions, abstract_references = evaluate_order_agnostic(abstract_model, abstract_tokenizer, test_abstract)

Evaluating abstract-only model...
precision:  0.0  recall:  0.0  f1:  0
precision:  0.3333333333333333  recall:  0.0625  f1:  0.10526315789473684
precision:  0.25  recall:  0.2  f1:  0.22222222222222224
precision:  0.0  recall:  0.0  f1:  0
precision:  0.25  recall:  0.25  f1:  0.25
precision:  0.0  recall:  0.0  f1:  0
precision:  0.0  recall:  0.0  f1:  0
precision:  0.5  recall:  0.25  f1:  0.3333333333333333
precision:  0.0  recall:  0.0  f1:  0
precision:  0.16666666666666666  recall:  0.06666666666666667  f1:  0.09523809523809522
precision:  0.2  recall:  0.2  f1:  0.20000000000000004
precision:  0.0  recall:  0.0  f1:  0
precision:  0.5  recall:  0.3333333333333333  f1:  0.4
precision:  0.0  recall:  0.0  f1:  0


KeyboardInterrupt: 

In [50]:
abstract_metrics = compute_metrics(abstract_predictions, abstract_references)
print(abstract_metrics)

precision:  0.2  recall:  1.0  f1:  0.33333333333333337
precision:  0.6666666666666666  recall:  0.125  f1:  0.21052631578947367
precision:  0.5  recall:  0.4  f1:  0.4444444444444445
precision:  1.0  recall:  1.0  f1:  1.0
precision:  0.25  recall:  0.25  f1:  0.25
precision:  1.0  recall:  0.25  f1:  0.4
precision:  0.3333333333333333  recall:  0.3333333333333333  f1:  0.3333333333333333
precision:  1.0  recall:  0.5  f1:  0.6666666666666666
precision:  0.6666666666666666  recall:  0.10526315789473684  f1:  0.18181818181818182
precision:  1.0  recall:  0.4  f1:  0.5714285714285715
precision:  0.4  recall:  0.4  f1:  0.4000000000000001
precision:  0.25  recall:  0.3333333333333333  f1:  0.28571428571428575
precision:  0.5  recall:  0.3333333333333333  f1:  0.4
precision:  0.6666666666666666  recall:  0.25  f1:  0.36363636363636365
precision:  0.3333333333333333  recall:  0.2  f1:  0.25
precision:  0.3333333333333333  recall:  1.0  f1:  0.5
precision:  0.3333333333333333  recall:  0.16

In [51]:
print("Evaluating body-only model...")
body_predictions, body_references = evaluate_order_agnostic(body_model, body_tokenizer, test_body)

Evaluating body-only model...
precision:  0.0  recall:  0.0  f1:  0
precision:  0.5  recall:  0.0625  f1:  0.1111111111111111
precision:  0.0  recall:  0.0  f1:  0
precision:  0.0  recall:  0.0  f1:  0
precision:  0.0  recall:  0.0  f1:  0
precision:  0.0  recall:  0.0  f1:  0
precision:  0.0  recall:  0.0  f1:  0
precision:  0.0  recall:  0.0  f1:  0
precision:  0.0  recall:  0.0  f1:  0
precision:  1.0  recall:  0.06666666666666667  f1:  0.125
precision:  0.0  recall:  0.0  f1:  0
precision:  0.0  recall:  0.0  f1:  0
precision:  0.5  recall:  0.3333333333333333  f1:  0.4
precision:  0.0  recall:  0.0  f1:  0
precision:  0.0  recall:  0.0  f1:  0
precision:  0.0  recall:  0.0  f1:  0
precision:  0.0  recall:  0.0  f1:  0
precision:  0.0  recall:  0.0  f1:  0
precision:  0.0  recall:  0.0  f1:  0
precision:  0.25  recall:  0.16666666666666666  f1:  0.2
precision:  0.0  recall:  0.0  f1:  0
precision:  0.0  recall:  0.0  f1:  0
precision:  0.3333333333333333  recall:  0.2  f1:  0.25
pr

In [52]:
body_metrics = compute_metrics(body_predictions, body_references)
print(body_metrics)

precision:  1.0  recall:  0.125  f1:  0.2222222222222222
precision:  0.25  recall:  0.2  f1:  0.22222222222222224
precision:  0.5  recall:  0.25  f1:  0.3333333333333333
precision:  0.5  recall:  0.25  f1:  0.3333333333333333
precision:  0.5  recall:  0.3333333333333333  f1:  0.4
precision:  0.6666666666666666  recall:  0.10526315789473684  f1:  0.18181818181818182
precision:  1.0  recall:  0.06666666666666667  f1:  0.125
precision:  0.5  recall:  0.2  f1:  0.28571428571428575
precision:  0.5  recall:  0.3333333333333333  f1:  0.4
precision:  0.3333333333333333  recall:  0.25  f1:  0.28571428571428575
precision:  0.3333333333333333  recall:  0.07692307692307693  f1:  0.125
precision:  0.5  recall:  0.2  f1:  0.28571428571428575
precision:  0.2  recall:  0.16666666666666666  f1:  0.1818181818181818
precision:  0.25  recall:  0.16666666666666666  f1:  0.2
precision:  0.5  recall:  1.0  f1:  0.6666666666666666
precision:  0.6666666666666666  recall:  0.4  f1:  0.5
precision:  0.5  recall: