# TitleFlare: Fine-Tuned T5 Model for AI-Powered Title Generation

## Project Overview
**TitleFlare** is an advanced title generation system that transforms long-form articles into catchy, context-aware, and creative headlines.  
Built on Google’s `t5-base` model and fine-tuned on a curated dataset of Medium articles, this project demonstrates how encoder-decoder architectures can be adapted for abstractive headline generation using the `transformers` library and the `Seq2SeqTrainer` framework.

> **Note**: Model training and testing were conducted on **Google Colab**, utilizing GPU acceleration to speed up experimentation and evaluation.

This project highlights the capability of fine-tuned generative models to assist in real-world content creation tasks by automating the generation of high-quality article titles.

## Objectives
- Fine-tune the `t5-base` model for headline generation.
- Ensure contextual accuracy and creativity in generated titles.
- Evaluate model quality using **ROUGE**.
- Handle long and complex article inputs efficiently through preprocessing.

## Dataset
- **Source**: [Kaggle - Medium Articles Dataset](https://www.kaggle.com/datasets/fabiochiusano/medium-articles)  
- **Content**: Medium articles with titles and full body text.  
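- **Cleaning**: Entries missing a title or body text were removed, along with articles shorter than 500 characters and titles shorter than 20 characters.  
- **Splits**: Training, validation, and test sets were created by random sampling, with 3,000 articles each reserved for validation and testing.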

## Technical Stack
- **Model Architecture**: T5-base  
- **Training Framework**: Transformers `Seq2SeqTrainer` API  
- **Evaluation Metric**: ROUGE  
- **Tools Used**: PyTorch, Google Colab, TensorBoard, KaggleHub  
- **Preprocessing**: Custom text cleaner + sentence filtering to improve input quality

#### Environment Setup & Requirements Installation

Installs dependencies and prepares the environment by downloading necessary files and packages for training and evaluation.

In [1]:
!pip install datasets transformers evaluate rouge-score nltk



In [2]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\OMEN\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [3]:
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\OMEN\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [9]:
pip install kagglehub

Collecting kagglehub
  Downloading kagglehub-0.3.4-py3-none-any.whl.metadata (22 kB)
Downloading kagglehub-0.3.4-py3-none-any.whl (43 kB)
Installing collected packages: kagglehub
Successfully installed kagglehub-0.3.4
Note: you may need to restart the kernel to use updated packages.


#### Import Core Libraries

Imports key libraries from the `transformers` and `datasets` modules for model training and evaluation.

In [4]:
import transformers
from datasets import load_dataset

#### List Available Kaggle Datasets

Lists public Kaggle datasets to confirm that the Kaggle CLI is installed and authenticated before downloading.

In [5]:
!kaggle datasets list

ref                                                        title                                               size  lastUpdated          downloadCount  voteCount  usabilityRating  
---------------------------------------------------------  -------------------------------------------------  -----  -------------------  -------------  ---------  ---------------  
bhadramohit/customer-shopping-latest-trends-dataset        Customer Shopping (Latest Trends) Dataset           76KB  2024-11-23 15:26:12           7752        143  1.0              
malaiarasugraj/global-health-statistics                    Global Health Statistics                            44MB  2024-11-27 10:52:27           1949         27  1.0              
mujtabamatin/air-quality-and-pollution-assessment          Air Quality and Pollution Assessment                84KB  2024-12-04 15:29:51           2304         43  1.0              
hopesb/student-depression-dataset                          Student Depression Dataset.    

#### Download Medium Articles Dataset from Kaggle

Downloads the Medium Articles dataset ZIP file directly using its Kaggle identifier.

In [9]:
!kaggle datasets download -d fabiochiusano/medium-articles

Dataset URL: https://www.kaggle.com/datasets/fabiochiusano/medium-articles
License(s): CC0-1.0
Downloading medium-articles.zip to C:\Users\OMEN\TitleFlare





#### Load Dataset Directory via KaggleHub

Retrieves the full path to the downloaded dataset using KaggleHub, which simplifies data access within the Colab environment.

In [10]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("fabiochiusano/medium-articles")

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/fabiochiusano/medium-articles?dataset_version_number=2...


100%|████████████████████████████████████████████████████████████████████████████████| 369M/369M [12:44<00:00, 506kB/s]


Extracting files...
Path to dataset files: C:\Users\OMEN\.cache\kagglehub\datasets\fabiochiusano\medium-articles\versions\2


#### Load Dataset into Hugging Face Format

Loads the Medium Articles CSV data using Hugging Face’s `load_dataset` for easier preprocessing and training integration.

In [10]:
medium_datasets = load_dataset("csv", data_files="medium-articles.zip")


#### Filter Incomplete Records

Removes any dataset entries that are missing either the article `text` or the `title` to ensure clean training data.

In [12]:
medium_datasets_cleaned = medium_datasets.filter(lambda x: x['text'] is not None and x['title'] is not None)


#### Split Dataset into Train, Validation, and Test Sets

Splits the cleaned data randomly in two stages:  
- Reserves `3,000` samples for testing.  
- Then reserves another `3,000` from the remaining training portion for validation.  

Updates the main dataset dictionary accordingly.

In [13]:
datasets_train_test = medium_datasets_cleaned["train"].train_test_split(test_size=3000)
datasets_train_validation = datasets_train_test["train"].train_test_split(test_size=3000)

medium_datasets["train"] = datasets_train_validation["train"]
medium_datasets["validation"] = datasets_train_validation["test"]
medium_datasets["test"] = datasets_train_test["test"]

medium_datasets

DatasetDict({
    train: Dataset({
        features: ['title', 'text', 'url', 'authors', 'timestamp', 'tags'],
        num_rows: 180363
    })
    validation: Dataset({
        features: ['title', 'text', 'url', 'authors', 'timestamp', 'tags'],
        num_rows: 3000
    })
    test: Dataset({
        features: ['title', 'text', 'url', 'authors', 'timestamp', 'tags'],
        num_rows: 3000
    })
})
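
Because the cell above calls `train_test_split` without a `seed`, each run produces a different partition. The following is a minimal pure-Python sketch of the same two-stage random holdout with a fixed seed; the helper name `two_stage_split` and the seed value are illustrative assumptions, not part of the notebook. In the notebook itself, passing `seed=` to `train_test_split` should have a similar effect.

```python
import random

def two_stage_split(n_rows, n_test, n_validation, seed=42):
    """Randomly partition row indices into train/validation/test (hypothetical helper)."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)  # fixed seed -> same partition on every run
    test = indices[:n_test]
    validation = indices[n_test:n_test + n_validation]
    train = indices[n_test + n_validation:]
    return train, validation, test

# 186,363 is the total row count implied by the split sizes shown above.
train_idx, val_idx, test_idx = two_stage_split(186_363, 3_000, 3_000)
print(len(train_idx), len(val_idx), len(test_idx))  # 180363 3000 3000
```

The three index lists are disjoint by construction, so no article can appear in more than one split.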

#### Display Dataset Split Statistics

Calculates and prints the percentage of samples in the training, validation, and test sets to confirm the relative split sizes.

In [14]:
n_samples_train = len(medium_datasets["train"])
n_samples_validation = len(medium_datasets["validation"])
n_samples_test = len(medium_datasets["test"])
n_samples_total = n_samples_train + n_samples_validation + n_samples_test

print(f"- Training set: {n_samples_train*100/n_samples_total:.2f}%")
print(f"- Validation set: {n_samples_validation*100/n_samples_total:.2f}%")
print(f"- Test set: {n_samples_test*100/n_samples_total:.2f}%")

- Training set: 96.78%
- Validation set: 1.61%
- Test set: 1.61%


#### Import Tokenizer Dependencies

Imports `AutoTokenizer` from Hugging Face’s Transformers library, along with Python’s `string` module, which the text cleaner uses for its punctuation check.

In [15]:
import string
from transformers import AutoTokenizer

#### Load the `t5-base` Tokenizer

Sets the model checkpoint name and loads the matching tokenizer, which is used to tokenize both inputs and targets.

In [16]:
model_checkpoint = "t5-base"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

#### Text Cleaning and Preprocessing Functions

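Sets the task prefix and the maximum input and target lengths, then defines functions to clean raw article content and tokenize both the input text and the corresponding titles.  
The cleaner drops lines that do not end in punctuation (typically headings and other fragments) and prepares data in the format expected by the `T5` model.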

In [21]:
prefix = "summarize: "

max_input_length = 512
max_target_length = 64

def clean_text(text):
  sentences = nltk.sent_tokenize(text.strip())
  sentences_cleaned = [s for sent in sentences for s in sent.split("\n")]
  sentences_cleaned_no_titles = [sent for sent in sentences_cleaned
                                 if len(sent) > 0 and
                                 sent[-1] in string.punctuation]
  text_cleaned = "\n".join(sentences_cleaned_no_titles)
  return text_cleaned

def preprocess_data(examples):
  texts_cleaned = [clean_text(text) for text in examples["text"]]
  inputs = [prefix + text for text in texts_cleaned]
  model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

  # Tokenize the titles as targets (text_target replaces the deprecated
  # as_target_tokenizer context manager)
  labels = tokenizer(text_target=examples["title"], max_length=max_target_length,
                     truncation=True)

  model_inputs["labels"] = labels["input_ids"]
  return model_inputs
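
To make the filtering rule in `clean_text` concrete, here is a simplified sketch that skips the `nltk.sent_tokenize` step and applies only the line-level check: a line survives if it is non-empty and its last character is punctuation. The function name `keep_sentence_like_lines` and the sample text are assumptions made for illustration.

```python
import string

def keep_sentence_like_lines(text):
    """Keep only non-empty lines whose last character is punctuation (simplified sketch)."""
    lines = text.strip().split("\n")
    kept = [line for line in lines
            if len(line) > 0 and line[-1] in string.punctuation]
    return "\n".join(kept)

sample = (
    "Getting Started\n"
    "Streamlit reruns your script on every interaction.\n"
    "Why it matters\n"
    "State is lost between reruns!"
)
cleaned_sample = keep_sentence_like_lines(sample)
print(cleaned_sample)
```

The two heading-like lines are dropped because they do not end in punctuation, which is how section titles inside an article are kept out of the model input. A side effect worth knowing is that a genuine sentence with trailing whitespace or a missing period is dropped as well.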

#### Dataset Filtering and Tokenization

Applies a length-based filter that drops articles shorter than 500 characters and titles shorter than 20 characters, then tokenizes the remaining entries using the preprocessing pipeline defined above.

In [22]:
medium_datasets_cleaned = medium_datasets.filter(lambda example: (len(example['text']) >= 500) and (len(example['title']) >= 20))
tokenized_datasets = medium_datasets_cleaned.map(preprocess_data, batched=True)
tokenized_datasets


DatasetDict({
    train: Dataset({
        features: ['title', 'text', 'url', 'authors', 'timestamp', 'tags', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 154380
    })
    validation: Dataset({
        features: ['title', 'text', 'url', 'authors', 'timestamp', 'tags', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 2589
    })
    test: Dataset({
        features: ['title', 'text', 'url', 'authors', 'timestamp', 'tags', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 2557
    })
})
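
As a quick check on how much data the length filter removes, the sketch below restates the filter predicate as a named function and computes the share of rows kept per split from the row counts printed above. The names `is_long_enough` and `retention_percent` are illustrative assumptions.

```python
def is_long_enough(example, min_text_chars=500, min_title_chars=20):
    """Same rule as the lambda in the cell above, with the thresholds named."""
    return (len(example["text"]) >= min_text_chars
            and len(example["title"]) >= min_title_chars)

def retention_percent(before, after):
    """Share of rows that survive the filter, in percent."""
    return round(100 * after / before, 2)

# Row counts copied from the outputs above: (before filter, after filter).
split_sizes = {
    "train": (180_363, 154_380),
    "validation": (3_000, 2_589),
    "test": (3_000, 2_557),
}

retention = {name: retention_percent(before, after)
             for name, (before, after) in split_sizes.items()}
for name, pct in retention.items():
    print(f"{name}: {pct}% kept")
```

Roughly 85 to 86 percent of each split survives, so the filter removes about one article in seven.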

#### Define Training Arguments and Configuration

Sets up training hyperparameters, checkpointing, logging strategies, and evaluation settings for the `Seq2SeqTrainer` using the `t5-base` model.

In [23]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer




In [26]:
batch_size = 8
model_name = "t5-base-medium-title-generation"
model_dir = "/content/Models/t5-base-medium-title-generation"
args = Seq2SeqTrainingArguments(
    model_dir,
    eval_strategy="steps",
    eval_steps=100,
    logging_strategy="steps",
    logging_steps=100,
    save_strategy="steps",
    save_steps=200,
    learning_rate=4e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    predict_with_generate=True,
    fp16=True,
    load_best_model_at_end=True,
    metric_for_best_model="rouge1",
    report_to="tensorboard"
)
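
The arithmetic below estimates what these settings imply for one epoch. It is a rough sketch that assumes a single GPU, no gradient accumulation, and that the last partial batch is kept; the variable names are illustrative. It also checks that `save_steps` is a multiple of `eval_steps`, which `load_best_model_at_end=True` requires when both strategies are `"steps"`.

```python
import math

n_train = 154_380   # training rows after filtering (from the output above)
batch_size = 8
eval_steps = 100
save_steps = 200

# One optimizer step per batch under the stated assumptions.
steps_per_epoch = math.ceil(n_train / batch_size)
n_evaluations = steps_per_epoch // eval_steps
n_checkpoints = steps_per_epoch // save_steps

# Checkpoints must coincide with evaluations for best-model tracking.
schedule_is_compatible = save_steps % eval_steps == 0

print(steps_per_epoch, n_evaluations, n_checkpoints, schedule_is_compatible)
# 19298 192 96 True
```

Each of those evaluations generates titles for the full validation set, so on a Colab GPU a larger `eval_steps` may be worth considering if evaluation dominates the runtime.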

#### Initialize Data Collator for Seq2Seq

Creates a data collator that dynamically pads inputs and labels during training, ensuring compatibility with sequence-to-sequence models.

In [27]:
data_collator = DataCollatorForSeq2Seq(tokenizer)
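
The idea behind dynamic padding can be sketched in plain Python: each batch is padded only to its own longest sequence, inputs are padded with the pad token ID (0 for T5), and labels are padded with `-100` so that padded positions are ignored by the loss. The function `pad_batch` and the token IDs below are made up for illustration and only approximate what the collator does.

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Pad a list of tokenized examples to the longest length in the batch."""
    max_input = max(len(f["input_ids"]) for f in features)
    max_label = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_input - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # 1 marks real tokens, 0 marks padding.
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        n_label_pad = max_label - len(f["labels"])
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_label_pad)
    return batch

# Hypothetical token IDs, not real tokenizer output.
features = [
    {"input_ids": [21, 7, 9, 1], "labels": [5, 1]},
    {"input_ids": [21, 1], "labels": [8, 3, 1]},
]
padded = pad_batch(features)
print(padded["input_ids"])  # [[21, 7, 9, 1], [21, 1, 0, 0]]
print(padded["labels"])     # [[5, 1, -100], [8, 3, 1]]
```

The `-100` label padding is also why `compute_metrics` has to swap `-100` back to the pad token ID before decoding.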

#### Define ROUGE-Based Evaluation Metrics

Implements the `compute_metrics` function using ROUGE scores to evaluate the quality of generated headlines. Also tracks the average length of generated sequences.

In [30]:
import numpy as np
from evaluate import load

# Load the ROUGE metric with trust_remote_code
metric = load("rouge", trust_remote_code=True)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip()))
                      for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) 
                      for label in decoded_labels]
    
    # Compute ROUGE scores
    result = metric.compute(predictions=decoded_preds, references=decoded_labels,
                            use_stemmer=True)

    # Scale ROUGE F1 scores to percentages (the evaluate ROUGE metric
    # returns plain floats, not objects with a .mid.fmeasure attribute)
    result = {key: value * 100 for key, value in result.items()}
    
    # Add mean generated length to metrics
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id)
                      for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}
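
For intuition about the `rouge1` value used to pick the best checkpoint, here is a hand-rolled sketch of ROUGE-1 F1 as unigram overlap. It is a simplification: it lowercases and splits on whitespace, and omits the stemming and tokenization that the real metric applies, so its numbers will not match the library exactly. The function name `rouge1_f1` is an assumption.

```python
from collections import Counter

def rouge1_f1(prediction, reference):
    """Unigram-overlap F1 between a predicted and a reference title (simplified)."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    # Clipped overlap: each word counts at most as often as it appears in both.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

score = rouge1_f1("Session State in Streamlit",
                  "Session State and Callbacks in Streamlit")
print(round(score, 4))  # 0.8
```

Here all four predicted words appear in the six-word reference, giving precision 1.0, recall 4/6, and an F1 of 0.8. A short title that reuses words from the reference can therefore score well even if it misses part of the meaning.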

#### Load Pretrained T5 Model

Loads the base T5 model for sequence-to-sequence learning using the specified checkpoint.

In [31]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

#### Initialize Trainer Object

Configures the `Seq2SeqTrainer` with model, datasets, training arguments, tokenizer, and evaluation metric function.

In [33]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    processing_class=tokenizer,
    compute_metrics=compute_metrics
)

#### Launch TensorBoard for Live Monitoring

Starts TensorBoard interface to track training metrics and visualize progress during model training.

In [None]:
# Start TensorBoard before training to monitor it in progress
%load_ext tensorboard
%tensorboard --logdir '{model_dir}'/runs

#### Fine-Tune the T5 Model

Fine-tunes the pretrained `T5` model on the Medium articles dataset using custom training arguments, evaluation steps, and ROUGE-based performance monitoring.

In [None]:
trainer.train()

#### Save Trained Model Locally

Persists the fine-tuned model and tokenizer to the specified directory for later use or export.

In [None]:
trainer.save_model()

#### Archive and Download Trained Model

Compresses the saved model directory into a ZIP file and initiates download to the local machine using Colab utilities.

In [None]:
# Google Colab (Optional)
!zip -r model.zip /content/Models/t5-base-medium-title-generation
from google.colab import files
files.download("model.zip")
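
Outside Colab, where the `zip` command or `google.colab` may not be available, the standard library can produce the same kind of archive. The sketch below uses `shutil.make_archive` on a throwaway directory so it runs anywhere; the helper name `archive_model_dir` and the placeholder `config.json` are assumptions for the demo.

```python
import shutil
import tempfile
import zipfile
from pathlib import Path

def archive_model_dir(model_dir, archive_stem):
    """Zip `model_dir` (keeping its folder name) and return the archive path."""
    model_dir = Path(model_dir)
    return shutil.make_archive(str(archive_stem), "zip",
                               root_dir=model_dir.parent,
                               base_dir=model_dir.name)

# Demo on a temporary directory standing in for the saved model folder.
with tempfile.TemporaryDirectory() as tmp:
    demo_dir = Path(tmp) / "t5-base-medium-title-generation"
    demo_dir.mkdir()
    (demo_dir / "config.json").write_text("{}")  # placeholder file
    zip_path = archive_model_dir(demo_dir, Path(tmp) / "model")
    with zipfile.ZipFile(zip_path) as zf:
        archived_names = zf.namelist()

print(Path(zip_path).name)  # model.zip
```

To archive the real model, the same helper could be pointed at the `model_dir` used during training.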

#### Reload Model and Tokenizer from Saved Directory

Restores the tokenizer and model from local storage, preparing them for inference or evaluation.

In [None]:
# Load the model from the directory it was saved to
model_dir = "/content/Models/t5-base-medium-title-generation"

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

#### Set Inference Parameters

Defines the maximum input length used during title generation or inference.

In [None]:
max_input_length = 512

#### Generate Title from Sample Article 

Uses the fine-tuned `T5` model to generate a headline for a sample article input. Applies beam search with sampling (`num_beams=8`, `do_sample=True`) to encourage more varied titles, so the generated title can differ between runs.

In [None]:
text = """
We define access to a Streamlit app in a browser tab as a session.
For each browser tab that connects to the Streamlit server, a new session is created.
Streamlit reruns your script from top to bottom every time you interact with your app.
Each reruns takes place in a blank slate: no variables are shared between runs.
Session State is a way to share variables between reruns, for each user session.
In addition to the ability to store and persist state, Streamlit also exposes the
ability to manipulate state using Callbacks. In this guide, we will illustrate the
usage of Session State and Callbacks as we build a stateful Counter app.
For details on the Session State and Callbacks API, please refer to our Session
State API Reference Guide. Also, check out this Session State basics tutorial
video by Streamlit Developer Advocate Dr. Marisa Smith to get started:
"""

inputs = ["summarize: " + text]

inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, return_tensors="pt")
output = model.generate(**inputs, num_beams=8, do_sample=True, min_length=10, max_length=64)
decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
predicted_title = nltk.sent_tokenize(decoded_output.strip())[0]

print(predicted_title)
# Example output: Session State and Callbacks in Streamlit

In [None]:
text = """
Many financial institutions started building conversational AI, prior to the Covid19
pandemic, as part of a digital transformation initiative. These initial solutions
were high profile, highly personalized virtual assistants — like the Erica chatbot
from Bank of America. As the pandemic hit, the need changed as contact centers were
under increased pressures. As Cathal McGloin of ServisBOT explains in “how it started,
and how it is going,” financial institutions were looking for ways to automate
solutions to help get back to “normal” levels of customer service. This resulted
in a change from the “future of conversational AI” to a real tactical assistant
that can help in customer service. Haritha Dev of Wells Fargo, saw a similar trend.
Banks were originally looking to conversational AI as part of digital transformation
to keep up with the times. However, with the pandemic, it has been more about
customer retention and customer satisfaction. In addition, new use cases came about
as a result of Covid-19 that accelerated adoption of conversational AI. As Vinita
Kumar of Deloitte points out, banks were dealing with an influx of calls about new
concerns, like questions around the Paycheck Protection Program (PPP) loans. This
resulted in an increase in volume, without enough agents to assist customers, and
tipped the scale to incorporate conversational AI. When choosing initial use cases
to support, financial institutions often start with high volume, low complexity
tasks. For example, password resets, checking account balances, or checking the
status of a transaction, as Vinita points out. From there, the use cases can evolve
as the banks get more mature in developing conversational AI, and as the customers
become more engaged with the solutions. Cathal indicates another good way for banks
to start is looking at use cases that are a pain point, and also do not require a
lot of IT support. Some financial institutions may have a multi-year technology
roadmap, which can make it harder to get a new service started. A simple chatbot
for document collection in an onboarding process can result in high engagement,
and a high return on investment. For example, Cathal has a banking customer that
implemented a chatbot to capture a driver’s license to be used in the verification
process of adding an additional user to an account — it has over 85% engagement
with high satisfaction. An interesting use case Haritha discovered involved
educating customers on financial matters. People feel more comfortable asking a
chatbot what might be considered a “dumb” question, as the chatbot is less judgmental.
Users can be more ambiguous with their questions as well, not knowing the right
words to use, as chatbot can help narrow things down.
"""

inputs = ["summarize: " + text]

inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, return_tensors="pt")
output = model.generate(**inputs, num_beams=8, do_sample=True, min_length=10, max_length=64)
decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
predicted_title = nltk.sent_tokenize(decoded_output.strip())[0]

print(predicted_title)
# Example output: Conversational AI: The Future of Customer Service

#### Evaluate Model Performance on Test Set

Preprocesses the test set, batches the inputs, and generates predictions using the trained `T5` model.  
Computes **ROUGE-based evaluation metrics** by comparing generated titles to reference titles in the test data.

In [None]:
import torch

# get test split
test_tokenized_dataset = tokenized_datasets["test"]

# pad texts to the same length
def preprocess_test(examples):
  # Clean the articles the same way as during training before adding the prefix
  inputs = [prefix + clean_text(text) for text in examples["text"]]
  model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True,
                           padding="max_length")
  return model_inputs

test_tokenized_dataset = test_tokenized_dataset.map(preprocess_test, batched=True)

# prepare dataloader
test_tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
dataloader = torch.utils.data.DataLoader(test_tokenized_dataset, batch_size=32)

# generate text for each batch
all_predictions = []
for i,batch in enumerate(dataloader):
  predictions = model.generate(**batch)
  all_predictions.append(predictions)

# flatten predictions
all_predictions_flattened = [pred for preds in all_predictions for pred in preds]

# tokenize and pad titles
all_titles = tokenizer(test_tokenized_dataset["title"], max_length=max_target_length,
                       truncation=True, padding="max_length")["input_ids"]

# compute metrics
predictions_labels = [all_predictions_flattened, all_titles]
compute_metrics(predictions_labels)

## Final Thoughts

This project demonstrates how transformer-based models like `T5` can be adapted to specialized NLP tasks beyond summarization, in this case **headline generation**.

Although it was trained for a **single epoch** within **Colab constraints**, **TitleFlare** produces plausible titles for the sample articles above and reflects the potential of fine-tuned generative models in creative applications. Its combination of **text preprocessing**, **ROUGE-based evaluation**, and a **standard encoder-decoder architecture** makes it a reasonable base for future development.

---

**Thank you for checking out TitleFlare**.  
Feel free to explore the code and adapt it for your own **NLP content generation** tasks.
