In [None]:
# ! pip install peft

In [None]:
import os
import json
import torch
import pandas as pd
from datasets import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import numpy as np
from sklearn.model_selection import train_test_split

# 1. Data Preparation


def prepare_dataset(data_path):
    """
    Load a labelled CSV and split it into train/validation/test datasets.

    Args:
        data_path: Path to a CSV file containing 'text' and 'topic' columns.

    Returns:
        A tuple ``(train_dataset, val_dataset, test_dataset, num_topics,
        topic_to_id)``: three HuggingFace ``Dataset`` objects (a 70/15/15
        split seeded with ``random_state=42``), the number of distinct
        topics, and the ``{topic: integer label}`` mapping that was used to
        fill the 'label' column.
    """
    df = pd.read_csv(data_path)

    # Rows without a text or a topic can be neither tokenized nor labelled,
    # so they are dropped before labels are assigned.
    df = df.dropna(subset=['text', 'topic']).reset_index(drop=True)

    # Convert topics to integer labels (ids follow first appearance in the
    # file). .tolist() yields plain Python scalars instead of numpy ones, so
    # the mapping stays JSON-serializable even when topics are numeric.
    unique_topics = df['topic'].unique().tolist()
    topic_to_id = {topic: idx for idx, topic in enumerate(unique_topics)}
    df['label'] = df['topic'].map(topic_to_id)

    # Split dataset: 70% train, then the remaining 30% halved into val/test
    train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
    val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

    # Convert to HuggingFace datasets. preserve_index=False keeps the
    # shuffled pandas index from being stored as an extra
    # '__index_level_0__' column.
    train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
    val_dataset = Dataset.from_pandas(val_df, preserve_index=False)
    test_dataset = Dataset.from_pandas(test_df, preserve_index=False)

    return train_dataset, val_dataset, test_dataset, len(unique_topics), topic_to_id

# 2. Model and Tokenizer Configuration


def setup_model_and_tokenizer(num_topics):
    """
    Load the base model in 4-bit and wrap it with LoRA adapters (QLoRA).

    Args:
        num_topics: Number of output classes for the classification head.

    Returns:
        A tuple ``(model, tokenizer)`` where ``model`` is the PEFT-wrapped
        sequence-classification model and ``tokenizer`` is its tokenizer,
        both guaranteed to have a padding token configured.
    """
    # Quantization configuration: 4-bit NF4 weights, fp16 compute
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False
    )

    # Load base model and tokenizer
    model_name = "databricks/dolly-v2-3b"  # You can change this to other models
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # The pipeline pads every example to max_length, which requires a pad
    # token; causal-LM tokenizers frequently ship without one, so fall back
    # to the end-of-sequence token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        num_labels=num_topics,
        device_map="auto"
    )

    # The classification head needs to know the pad id to handle batches of
    # more than one sequence; the checkpoint config does not always set it.
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    # Prepare model for k-bit training
    model = prepare_model_for_kbit_training(model)

    # LoRA configuration
    lora_config = LoraConfig(
        r=16,  # Rank
        lora_alpha=32,  # Alpha scaling
        target_modules=["query_key_value"],  # Target modules to apply LoRA
        lora_dropout=0.05,
        bias="none",
        task_type="SEQ_CLS"  # Sequence classification task
    )

    # Create PEFT model
    model = get_peft_model(model, lora_config)

    return model, tokenizer

# 3. Data Processing


def preprocess_function(examples, tokenizer):
    """
    Tokenize the 'text' field of a batch.

    Every sequence is truncated to at most 512 tokens and padded up to that
    same fixed length; whatever the tokenizer returns is passed back as is.
    """
    tokenizer_options = {
        "truncation": True,
        "max_length": 512,
        "padding": "max_length",
    }
    texts = examples["text"]
    return tokenizer(texts, **tokenizer_options)

# 4. Training Configuration


def setup_training_args(output_dir):
    """
    Build the TrainingArguments used for fine-tuning.

    Args:
        output_dir: Directory where checkpoints and the final model are
            written.

    Returns:
        A ``TrainingArguments`` instance that evaluates and saves once per
        epoch, reloads the best checkpoint at the end, and trains in fp16
        with an effective batch of 4 x 4 accumulated steps per device.
    """
    import inspect

    # Newer transformers releases renamed `evaluation_strategy` to
    # `eval_strategy`; use whichever keyword the installed version accepts
    # so the call works on both sides of the rename.
    accepted = inspect.signature(TrainingArguments.__init__).parameters
    if "eval_strategy" in accepted:
        eval_keyword = "eval_strategy"
    else:
        eval_keyword = "evaluation_strategy"

    return TrainingArguments(
        output_dir=output_dir,
        learning_rate=2e-4,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=3,
        weight_decay=0.01,
        save_strategy="epoch",
        # load_best_model_at_end requires the eval and save schedules to match
        load_best_model_at_end=True,
        push_to_hub=False,
        gradient_accumulation_steps=4,
        fp16=True,
        logging_steps=100,
        **{eval_keyword: "epoch"},
    )

# 5. Training Functions


def compute_metrics(eval_pred):
    """
    Score an evaluation pass.

    ``eval_pred`` unpacks into (per-class scores, reference labels). The
    predicted class is the arg-max over axis 1, and the returned dict holds
    the fraction of predictions matching the labels under "accuracy".
    """
    scores, labels = eval_pred
    predicted_classes = np.argmax(scores, axis=1)
    is_correct = predicted_classes == labels
    return {"accuracy": is_correct.mean()}

# 6. Main Training Pipeline


def train_topic_model(data_path, output_dir):
    """
    Run the full pipeline: load data, fine-tune, evaluate, and save.

    Args:
        data_path: Path to the CSV file handed to ``prepare_dataset``.
        output_dir: Directory that receives checkpoints, the final model,
            the tokenizer, and ``topic_map.json``.

    Returns:
        A tuple ``(model, tokenizer, topic_map)`` where ``topic_map`` is the
        ``{topic: integer label}`` mapping produced by ``prepare_dataset``.
    """
    # Prepare dataset
    train_dataset, val_dataset, test_dataset, num_topics, topic_map = prepare_dataset(
        data_path)

    # Setup model and tokenizer
    model, tokenizer = setup_model_and_tokenizer(num_topics)

    # Preprocess datasets: the same batched tokenization for every split
    def tokenize_batch(batch):
        return preprocess_function(batch, tokenizer)

    train_dataset, val_dataset, test_dataset = (
        split.map(tokenize_batch, batched=True)
        for split in (train_dataset, val_dataset, test_dataset)
    )

    # Setup training arguments
    training_args = setup_training_args(output_dir)

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
    )

    # Train model
    trainer.train()

    # Evaluate on test set
    test_results = trainer.evaluate(test_dataset)
    print(f"Test results: {test_results}")

    # Save model and tokenizer
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Save topic mapping. Keys or values may be numpy scalars (e.g. when the
    # topics are numeric), and json.dump rejects numpy integers as keys, so
    # convert to plain Python values first rather than failing after a
    # completed training run.
    serializable_map = {
        (topic.item() if hasattr(topic, "item") else topic): int(label)
        for topic, label in topic_map.items()
    }
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "topic_map.json"), "w") as f:
        json.dump(serializable_map, f)

    return model, tokenizer, topic_map

# 7. Inference Function


def predict_topic(text, model, tokenizer, topic_map):
    """
    Predict the topic of a single text.

    Args:
        text: The input string to classify.
        model: The fine-tuned classification model; it is called with the
            tokenized inputs and must return an object with ``.logits``.
        tokenizer: The tokenizer that was used during training.
        topic_map: The ``{topic: integer label}`` mapping from training.

    Returns:
        The topic whose label id has the highest logit.

    Raises:
        KeyError: If the predicted label id is not present in ``topic_map``.
    """
    # Tokenize input the same way as during training
    inputs = tokenizer(
        text,
        truncation=True,
        max_length=512,
        padding="max_length",
        return_tensors="pt"
    ).to(model.device)

    # Inference must run in eval mode so dropout (including LoRA dropout) is
    # disabled; the previous mode is restored so an interrupted training
    # session is not silently left in eval mode.
    was_training = getattr(model, "training", False)
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(**inputs)
            prediction = torch.argmax(outputs.logits, dim=1).item()
    finally:
        if was_training:
            model.train()

    # Convert the label id back to its topic
    id_to_topic = {v: k for k, v in topic_map.items()}
    predicted_topic = id_to_topic[prediction]

    return predicted_topic


# Example usage
if __name__ == "__main__":
    # Locations used by the example run: the labelled input CSV and the
    # directory that receives the model, tokenizer, and topic_map.json.
    DATA_PATH = "path/to/your/dataset.csv"  # CSV with 'text' and 'topic' columns
    OUTPUT_DIR = "topic_model_output"

    # Run the whole pipeline (data prep, fine-tuning, test evaluation, saving)
    model, tokenizer, topic_map = train_topic_model(DATA_PATH, OUTPUT_DIR)

    # Classify one sample text with the freshly trained model and print the
    # predicted topic.
    test_text = "Your test text here"
    predicted_topic = predict_topic(test_text, model, tokenizer, topic_map)
    print(f"Predicted topic: {predicted_topic}")