<a href="https://colab.research.google.com/github/vishnuy/AI-ML/blob/main/TrainMistralSentiAnalysis2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset
from peft import LoraConfig, get_peft_model, TaskType
import pandas as pd
from sklearn.model_selection import train_test_split
import json
import os

class MistralFineTuner:
    def __init__(self, model_name="mistralai/Mistral-7B-v0.1", max_length=512):
        """
        Initialize the fine-tuner with Mistral model

        Args:
            model_name: Huging Face model identifier
            max_length: Maximum sequence length for tokenization
        """
        self.model_name = model_name
        self.max_length = max_length
        self.tokenizer = None
        self.model = None
        self.peft_model = None

    def setup_model_and_tokenizer(self):
        """Load the tokenizer and base model, then wrap the model with LoRA adapters.

        Sets ``self.tokenizer``, ``self.model`` (the base causal LM loaded in
        float16) and ``self.peft_model`` (the LoRA-wrapped model returned by
        ``get_peft_model``).
        """
        print(f"Loading {self.model_name}...")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if self.tokenizer.pad_token is None:
            # Without a pad token batches cannot be padded; reuse EOS for padding.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Half-precision weights; device_map="auto" lets accelerate decide
        # where each layer is placed.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            # NOTE(review): trust_remote_code executes Python code shipped in the
            # hub repository. Mistral is supported natively by transformers, so
            # this flag is most likely unnecessary -- confirm and drop it.
            trust_remote_code=True,
        )

        # LoRA: train small rank-16 adapters on the attention projections only.
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=16,  # Rank of the adapter matrices
            lora_alpha=32,  # LoRA scaling parameter
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        )

        self.peft_model = get_peft_model(self.model, lora_config)
        self.peft_model.print_trainable_parameters()

    def create_sample_dataset(self):
        """Create a sample dataset for demonstration (sentiment analysis)"""
        data = [
            {"text": "I love this product! It's amazing and works perfectly.", "label": "positive"},
            {"text": "This is the worst purchase I've ever made. Completely disappointed.", "label": "negative"},
            {"text": "The item is okay, nothing special but does the job.", "label": "neutral"},
            {"text": "Fantastic quality! Highly recommend to everyone.", "label": "positive"},
            {"text": "Terrible customer service and poor product quality.", "label": "negative"},
            {"text": "It's an average product, meets basic expectations.", "label": "neutral"},
            {"text": "Outstanding! Exceeded all my expectations.", "label": "positive"},
            {"text": "Waste of money. Doesn't work as advertised.", "label": "negative"},
            {"text": "Decent product for the price point.", "label": "neutral"},
            {"text": "Absolutely perfect! Will buy again.", "label": "positive"},
            {"text": "Poor build quality and arrived damaged.", "label": "negative"},
            {"text": "It's fine, does what it's supposed to do.", "label": "neutral"},
            {"text": "Incredible value for money! Love it!", "label": "positive"},
            {"text": "Horrible experience, would not recommend.", "label": "negative"},
            {"text": "Standard quality, no complaints but nothing exciting.", "label": "neutral"},
        ]

        return pd.DataFrame(data)

    def format_prompt(self, text, label=None):
        """Format input text into a prompt for the model"""
        if label:
            # Training format
            return f"Analyze the sentiment of this text: {text}\nSentiment: {label}"
        else:
            # Inference format
            return f"Analyze the sentiment of this text: {text}\nSentiment:"

    def prepare_dataset(self, df):
        """Convert a labelled DataFrame into a tokenized causal-LM dataset.

        Args:
            df: DataFrame with ``text`` and ``label`` columns (for example the
                output of ``create_sample_dataset``).

        Returns:
            A ``datasets.Dataset`` whose rows carry the tokenizer outputs
            (``input_ids``, ``attention_mask``) padded/truncated to
            ``self.max_length``, plus ``labels``. ``labels`` copies
            ``input_ids`` except at padding positions, which are set to -100
            so they are ignored by the loss.
        """
        # Render each row as a complete "prompt + answer" training string.
        prompts = [
            self.format_prompt(text, label)
            for text, label in zip(df["text"], df["label"])
        ]
        dataset = Dataset.from_dict({"text": prompts})

        # NOTE(review): no EOS token is appended, so the model is never shown
        # where the answer ends; generation only stops at max_new_tokens.
        # NOTE(review): padding every row to max_length wastes compute on short
        # prompts; dynamic padding in the collator would be cheaper.
        def tokenize_function(batch):
            tokenized = self.tokenizer(
                batch["text"],
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            )
            labels = tokenized["input_ids"].clone()
            attention_mask = tokenized.get("attention_mask")
            if attention_mask is not None:
                # Padding positions must not contribute to the loss; -100 is
                # the ignore index used by transformers' causal-LM loss.
                labels[attention_mask == 0] = -100
            tokenized["labels"] = labels
            return tokenized

        return dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset.column_names,
        )

    def train(self, train_dataset, eval_dataset=None, output_dir="./mistral-finetuned"):
        """Fine-tune the LoRA-wrapped model with the Hugging Face ``Trainer``.

        Args:
            train_dataset: Tokenized training dataset (e.g. from ``prepare_dataset``).
            eval_dataset: Optional tokenized evaluation dataset. When given, the
                model is evaluated and checkpointed every epoch and the checkpoint
                with the lowest ``eval_loss`` is reloaded at the end.
            output_dir: Directory for checkpoints, the saved model and the tokenizer.
        """
        has_eval = eval_dataset is not None

        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=3,
            per_device_train_batch_size=1,  # Small batch size for limited resources
            per_device_eval_batch_size=1,
            gradient_accumulation_steps=4,
            # NOTE(review): on small datasets (the 15-row sample gives only a
            # handful of optimizer steps) 100 warmup steps means the learning
            # rate never reaches its peak; consider warmup_ratio instead.
            warmup_steps=100,
            learning_rate=5e-5,
            # fp16 mixed precision requires CUDA; TrainingArguments raises if it
            # is requested on a CPU-only host.
            fp16=torch.cuda.is_available(),
            logging_steps=10,
            save_strategy="epoch",
            eval_strategy="epoch" if has_eval else "no",
            load_best_model_at_end=has_eval,
            metric_for_best_model="eval_loss" if has_eval else None,
            # report_to=None means "all installed integrations" in transformers
            # 4.x; the string "none" is what actually disables wandb/tensorboard.
            report_to="none",
            gradient_checkpointing=True,
            # With a frozen base model the embedding outputs do not require
            # grad; reentrant checkpointing would then cut gradient flow to the
            # LoRA weights. Non-reentrant checkpointing does not have that issue.
            gradient_checkpointing_kwargs={"use_reentrant": False},
        )

        # Causal LM collator: pads the batch and derives labels from input_ids.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,  # Causal LM, not masked LM
        )

        trainer = Trainer(
            model=self.peft_model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
        )

        print("Starting training...")
        trainer.train()

        # save_model() writes to training_args.output_dir, i.e. output_dir.
        trainer.save_model()
        self.tokenizer.save_pretrained(output_dir)

        print(f"Model saved to {output_dir}")

    def inference(self, text, max_new_tokens=50, *, do_sample=True, temperature=0.7):
        """Generate the model's sentiment answer for ``text``.

        Args:
            text: Raw input text to analyze.
            max_new_tokens: Maximum number of tokens to generate.
            do_sample: Sample from the distribution (default, as before) or
                decode greedily when ``False`` (deterministic output).
            temperature: Sampling temperature; only used when ``do_sample``.

        Returns:
            Only the newly generated text (the prompt is excluded), with
            surrounding whitespace stripped.
        """
        # Use the LoRA-wrapped model; fall back to the base model if it is absent.
        model = self.peft_model if self.peft_model is not None else self.model
        # Trainer may leave the model in training mode; disable dropout.
        model.eval()

        prompt = self.format_prompt(text)
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
        )
        # The tokenizer returns CPU tensors; move them to the model's device.
        inputs = inputs.to(model.device)

        generate_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.eos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            generate_kwargs["temperature"] = temperature

        with torch.no_grad():
            outputs = model.generate(**inputs, **generate_kwargs)

        # generate() returns prompt + continuation; slice off the prompt by token
        # count rather than by characters of the decoded text, which need not
        # reproduce the prompt exactly (normalization, truncation).
        prompt_length = inputs["input_ids"].shape[-1]
        new_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()