<a href="https://colab.research.google.com/github/pankaj-juneja/dask/blob/master/train_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets
import datasets
import torch
import pandas as pd
import random
from transformers import BartForConditionalGeneration, BartTokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
from datasets import load_dataset, Dataset
import os

def generate_sample_data(output_file="sample_data.csv", num_samples=100):
    """Write a toy text/summary CSV built by sampling fixed headline pairs.

    Args:
        output_file: Destination path of the CSV file.
        num_samples: Number of rows to draw (with replacement) from the
            built-in (text, summary) pairs.
    """
    # Each article sentence is kept next to its reference summary so a single
    # random index selects a matching pair.
    pairs = (
        (
            "The Federal Reserve announced an interest rate hike to combat inflation.",
            "Fed raises interest rates to curb inflation.",
        ),
        (
            "Apple Inc. reported a record quarterly revenue of $123.9 billion.",
            "Apple posts record revenue of $123.9B.",
        ),
        (
            "Tesla's stock price surged after strong vehicle delivery numbers.",
            "Tesla stock jumps on strong deliveries.",
        ),
        (
            "Crude oil prices fell as global supply concerns eased.",
            "Oil prices decline as supply stabilizes.",
        ),
        (
            "Bitcoin dropped below $40,000 amid market uncertainty.",
            "Bitcoin dips below $40K amid volatility.",
        ),
    )

    last_index = len(pairs) - 1
    # One randint call per row, in row order.
    chosen = [pairs[random.randint(0, last_index)] for _ in range(num_samples)]

    frame = pd.DataFrame(
        {
            "text": [article for article, _ in chosen],
            "summary": [digest for _, digest in chosen],
        }
    )
    frame.to_csv(output_file, index=False)
    print(f"Sample data generated and saved to {output_file}")

def load_data(dataset_path):
    """Load a CSV file as a dataset and split it into train and test parts.

    Args:
        dataset_path: Path of the CSV file handed to ``load_dataset``.

    Returns:
        The result of ``train_test_split`` on the loaded ``"train"`` split,
        with 10% of the rows held out as the test portion.
    """
    # The CSV loader is indexed by "train" to get the rows of the file.
    all_rows = load_dataset("csv", data_files=dataset_path)["train"]
    return all_rows.train_test_split(test_size=0.1)

def preprocess_data(examples, tokenizer, max_length=1024, max_target_length=200):
    """Tokenize articles and summaries into fixed-length model inputs.

    Args:
        examples: Mapping with a ``"text"`` entry (source documents) and a
            ``"summary"`` entry (reference summaries).
        tokenizer: Callable tokenizer; it is invoked with ``max_length``,
            ``truncation`` and ``padding`` keyword arguments and its result
            must expose ``input_ids`` and accept item assignment.
        max_length: Length the source texts are truncated/padded to.
        max_target_length: Length the summaries are truncated/padded to.

    Returns:
        The tokenized source encoding with an added ``"labels"`` entry in
        which padding positions are replaced by ``-100``.
    """
    model_inputs = tokenizer(
        examples["text"], max_length=max_length, truncation=True, padding="max_length"
    )
    label_ids = tokenizer(
        examples["summary"],
        max_length=max_target_length,
        truncation=True,
        padding="max_length",
    ).input_ids

    # Labels are padded here rather than by the data collator, so the pad
    # positions must be masked explicitly; -100 is the index ignored by the
    # cross-entropy loss. Without this the model is trained to emit padding.
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is not None:

        def mask_row(row):
            return [token if token != pad_id else -100 for token in row]

        if label_ids and isinstance(label_ids[0], (list, tuple)):
            # Batched call: one list of ids per example.
            label_ids = [mask_row(row) for row in label_ids]
        else:
            # Single example: a flat list of ids.
            label_ids = mask_row(label_ids)

    model_inputs["labels"] = label_ids
    return model_inputs

def train_model(dataset_path, model_name="facebook/bart-large-cnn", output_dir="finetuned_model", report_to="none"):
    """Fine-tune a BART summarization model on a CSV of text/summary pairs.

    Args:
        dataset_path: CSV file passed to ``load_data``.
        model_name: Checkpoint name or path used for both the tokenizer and
            the model.
        output_dir: Directory used for training artifacts and where the final
            model and tokenizer are saved.
        report_to: Logging integrations forwarded to ``TrainingArguments``.
            Defaults to ``"none"`` so that training does not stop at an
            interactive experiment-tracker login prompt (e.g. the wandb API
            key request) in notebook environments.
    """
    dataset = load_data(dataset_path)
    tokenizer = BartTokenizer.from_pretrained(model_name)
    # Tokenize both splits in batches.
    dataset = dataset.map(lambda x: preprocess_data(x, tokenizer), batched=True)

    model = BartForConditionalGeneration.from_pretrained(model_name)
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        learning_rate=5e-5,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=3,
        weight_decay=0.01,
        save_total_limit=2,
        # Mixed precision only when a CUDA device is present.
        fp16=torch.cuda.is_available(),
        report_to=report_to,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        tokenizer=tokenizer,
        data_collator=DataCollatorForSeq2Seq(tokenizer, model=model)
    )

    trainer.train()
    # Save model and tokenizer together so output_dir can be loaded directly.
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

def summarize_text(text, model_path="finetuned_model", max_length=200, min_length=50):
    """Generate a summary of ``text`` with a saved BART model.

    Args:
        text: Input document to summarize.
        model_path: Directory (or checkpoint name) from which the tokenizer
            and model are loaded.
        max_length: Upper bound on generated length, forwarded to ``generate``.
        min_length: Lower bound on generated length, forwarded to ``generate``.

    Returns:
        The decoded summary string with special tokens removed.
    """
    tokenizer = BartTokenizer.from_pretrained(model_path)
    model = BartForConditionalGeneration.from_pretrained(model_path)
    # Inputs longer than 1024 tokens are truncated.
    inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
    summary_ids = model.generate(
        inputs.input_ids,
        # Pass the mask explicitly rather than letting generate() infer it.
        attention_mask=inputs.attention_mask,
        max_length=max_length,
        min_length=min_length,
        do_sample=False,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

def main():
    """Run the demo: create sample data, fine-tune, then summarize one sentence."""
    csv_path = "sample_data.csv"

    # Build the toy dataset and fine-tune on it with the default settings.
    generate_sample_data(csv_path)
    train_model(csv_path)

    # Summarize a sentence with the freshly saved model.
    generated = summarize_text(
        "The Federal Reserve increased interest rates to manage inflation risks."
    )
    print(f"Generated Summary: {generated}")


# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()


Collecting datasets
  Downloading datasets-3.4.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.4.0-py3-none-any.whl (487 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m487.4/487.4 kB[0m [31m20.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading

Generating train split: 0 examples [00:00, ? examples/s]

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.58k [00:00<?, ?B/s]

Map:   0%|          | 0/90 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

  trainer = Trainer(
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

# New Section