In [None]:
!pip install -q transformers wandb huggingface_hub trl pydantic

In [None]:
!pip install -qU datasets

In [None]:
!pip install -qU unsloth

In [None]:
# Authenticate with the Hugging Face Hub interactively.
# NOTE(review): HubConfig.token defaults to None, so the Hub calls made later
# (dataset downloads, create_repo/upload_folder) presumably rely on the token
# stored by this login — confirm before running non-interactively.
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
import unsloth
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
import os
import json
import logging
from typing import Dict, Any, Optional, List
from pydantic import BaseModel, Field, field_validator
from typing import Dict, Any, Optional, List
from pathlib import Path
import torch
from datasets import Dataset, load_dataset
from transformers import TrainingArguments
from huggingface_hub import HfApi, create_repo

In [None]:
# ================================
# Configuration Management
# ================================
class ModelConfig(BaseModel):
    """Model configuration settings.

    ``dtype`` is a dtype *name* (or None). It is validated here so that a
    typo fails immediately instead of deep inside model loading.
    """
    base_model: str = Field(default="unsloth/llama-3-8b-bnb-4bit", description="Base model to fine-tune")
    max_seq_length: int = Field(default=2048, gt=0, description="Maximum sequence length")
    dtype: Optional[str] = Field(
        default=None,
        description="Data type for model weights: None (let the loader choose), 'float16', 'bfloat16' or 'float32'",
    )
    load_in_4bit: bool = Field(default=True, description="Load model in 4-bit quantization")

    @field_validator('max_seq_length')
    @classmethod
    def validate_seq_length(cls, v):
        """Reject non-positive sequence lengths (also enforced by ``gt=0``)."""
        if v <= 0:
            raise ValueError('max_seq_length must be positive')
        return v

    @field_validator('dtype')
    @classmethod
    def validate_dtype(cls, v):
        """Allow only None or one of the supported dtype names."""
        allowed_dtypes = ["float16", "bfloat16", "float32"]
        if v is not None and v not in allowed_dtypes:
            raise ValueError(f'dtype must be None or one of {allowed_dtypes}')
        return v

class LoRAConfig(BaseModel):
    """Hyper-parameters forwarded to ``FastLanguageModel.get_peft_model``."""
    # Adapter rank and the projection layers that receive adapters.
    r: int = Field(default=16, gt=0, description="LoRA rank")
    target_modules: List[str] = Field(
        default_factory=lambda: [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        description="Target modules for LoRA",
    )
    # Scaling and regularisation of the adapter updates.
    lora_alpha: int = Field(default=16, gt=0, description="LoRA alpha parameter")
    lora_dropout: float = Field(default=0.0, ge=0.0, le=1.0, description="LoRA dropout rate")
    bias: str = Field(default="none", description="Bias type for LoRA")
    # Memory / reproducibility knobs.
    use_gradient_checkpointing: str = Field(default="unsloth", description="Gradient checkpointing method")
    random_state: int = Field(default=3407, description="Random seed")
    use_rslora: bool = Field(default=False, description="Use RSLoRA")
    loftq_config: Optional[Dict] = Field(default=None, description="LoftQ configuration")

    @field_validator('bias')
    @classmethod
    def validate_bias(cls, v):
        """Accept only the bias modes listed below."""
        permitted = ["none", "all", "lora_only"]
        if v in permitted:
            return v
        raise ValueError(f'bias must be one of {permitted}')

class SFTConfig(BaseModel):
    """Supervised fine-tuning configuration (mapped onto ``TrainingArguments``)."""
    output_dir: str = Field(default="./sft_results", description="Output directory for SFT")
    per_device_train_batch_size: int = Field(default=2, gt=0, description="Training batch size per device")
    gradient_accumulation_steps: int = Field(default=4, gt=0, description="Gradient accumulation steps")
    warmup_steps: int = Field(default=5, ge=0, description="Warmup steps")
    max_steps: int = Field(default=100, gt=0, description="Maximum training steps")
    learning_rate: float = Field(default=2e-4, gt=0, description="Learning rate")
    fp16: bool = Field(default=True, description="Use FP16 precision")
    bf16: bool = Field(default=False, description="Use BF16 precision")
    logging_steps: int = Field(default=1, gt=0, description="Logging frequency")
    optim: str = Field(default="adamw_8bit", description="Optimizer")
    weight_decay: float = Field(default=0.01, ge=0, description="Weight decay")
    lr_scheduler_type: str = Field(default="linear", description="Learning rate scheduler")
    seed: int = Field(default=3407, description="Random seed")
    dataset_text_field: str = Field(default="text", description="Dataset text field name")

    @field_validator('optim')
    @classmethod
    def validate_optimizer(cls, v):
        """Restrict ``optim`` to transformers optimizer names.

        The old alias "adamw" is not a valid transformers optimizer name, so
        it is mapped to "adamw_torch" instead of failing later inside
        ``TrainingArguments``.
        """
        if v == "adamw":
            v = "adamw_torch"
        allowed_optims = [
            "adamw_8bit", "adamw_torch", "adamw_torch_fused",
            "paged_adamw_8bit", "sgd", "adafactor",
        ]
        if v not in allowed_optims:
            raise ValueError(f'optim must be one of {allowed_optims}')
        return v

class DPOConfig(BaseModel):
    """Direct Preference Optimization configuration (mapped onto ``TrainingArguments``)."""
    output_dir: str = Field(default="./dpo_results", description="Output directory for DPO")
    per_device_train_batch_size: int = Field(default=1, gt=0, description="Training batch size per device")
    gradient_accumulation_steps: int = Field(default=8, gt=0, description="Gradient accumulation steps")
    warmup_steps: int = Field(default=5, ge=0, description="Warmup steps")
    max_steps: int = Field(default=50, gt=0, description="Maximum training steps")
    learning_rate: float = Field(default=5e-7, gt=0, description="Learning rate")
    fp16: bool = Field(default=True, description="Use FP16 precision")
    bf16: bool = Field(default=False, description="Use BF16 precision")
    logging_steps: int = Field(default=1, gt=0, description="Logging frequency")
    optim: str = Field(default="adamw_8bit", description="Optimizer")
    weight_decay: float = Field(default=0.01, ge=0, description="Weight decay")
    lr_scheduler_type: str = Field(default="linear", description="Learning rate scheduler")
    seed: int = Field(default=3407, description="Random seed")
    beta: float = Field(default=0.1, gt=0, description="DPO beta parameter")
    dataset_num_proc: int = Field(default=4, gt=0, description="Number of processes for dataset")

    @field_validator('beta')
    @classmethod
    def validate_beta(cls, v):
        """Reject non-positive beta values (also enforced by ``gt=0``)."""
        if v <= 0:
            raise ValueError('beta must be positive')
        return v

    @field_validator('optim')
    @classmethod
    def validate_optimizer(cls, v):
        """Restrict ``optim`` to transformers optimizer names (same rules as SFT).

        The old alias "adamw" is mapped to "adamw_torch".
        """
        if v == "adamw":
            v = "adamw_torch"
        allowed_optims = [
            "adamw_8bit", "adamw_torch", "adamw_torch_fused",
            "paged_adamw_8bit", "sgd", "adafactor",
        ]
        if v not in allowed_optims:
            raise ValueError(f'optim must be one of {allowed_optims}')
        return v

class HubConfig(BaseModel):
    """HuggingFace Hub configuration."""
    repo_id: str = Field(default="rahulsamant37/voice-assistant-model", description="HuggingFace repository ID")
    private: bool = Field(default=False, description="Make repository private")
    token: Optional[str] = Field(default=None, description="HuggingFace API token")
    commit_message: str = Field(default="Upload fine-tuned voice model", description="Commit message")

    @field_validator('repo_id')
    @classmethod
    def validate_repo_id(cls, v):
        """Require exactly ``owner/name`` with both parts non-empty.

        A bare ``'/' in v`` check also accepted "/x", "x/" and "a/b/c".
        """
        owner, sep, name = v.partition('/')
        if not sep or not owner or not name or '/' in name:
            raise ValueError('repo_id must be in format "username/repo-name"')
        return v

class PipelineConfig(BaseModel):
    """Main pipeline configuration grouping all stage-specific settings."""

    # Pydantic v2 configuration. This replaces the deprecated v1-style inner
    # ``class Config``. A plain dict is accepted wherever ConfigDict is.
    model_config = {
        "validate_assignment": True,  # re-validate top-level field assignments
        "extra": "forbid",            # unknown keys (e.g. typos in a JSON config) are errors
        "use_enum_values": True,
    }

    model: ModelConfig = Field(default_factory=ModelConfig, description="Model configuration")
    lora: LoRAConfig = Field(default_factory=LoRAConfig, description="LoRA configuration")
    sft: SFTConfig = Field(default_factory=SFTConfig, description="SFT configuration")
    dpo: DPOConfig = Field(default_factory=DPOConfig, description="DPO configuration")
    hub: HubConfig = Field(default_factory=HubConfig, description="Hub configuration")
    wandb_project: str = Field(default="voice-finetuning", description="Weights & Biases project name")
    use_wandb: bool = Field(default=False, description="Enable Weights & Biases logging")
    save_intermediate: bool = Field(default=True, description="Save intermediate checkpoints")

In [None]:
# ================================
# Logging Service
# ================================
class LoggingService:
    """Centralized logging service wrapping one named ``logging.Logger``."""

    def __init__(self, name: str = "voice_pipeline", level: str = "INFO", propagate: bool = False):
        """Configure and wrap the logger called ``name``.

        Args:
            name: logger name. Instances that share a name share one logger.
            level: level name such as "INFO" or "debug" (case-insensitive).
            propagate: whether records also reach ancestor (root) handlers.
                Defaults to False. When the root logger has its own handler,
                as in this notebook, propagation prints every message twice.
        """
        self.logger = logging.getLogger(name)
        self.logger.setLevel(getattr(logging, level.upper()))
        self.logger.propagate = propagate

        # Attach the formatter only once, so that re-instantiating the service
        # (e.g. re-running a cell) does not multiply output lines.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

    def debug(self, message: str):
        self.logger.debug(message)

    def info(self, message: str):
        self.logger.info(message)

    def error(self, message: str):
        self.logger.error(message)

    def warning(self, message: str):
        self.logger.warning(message)

In [None]:
# ================================
# Data Service
# ================================
class DataService:
    """Service for handling dataset operations.

    SFT data is normalised to a single ``text`` column of ChatML-style
    conversations (``<|im_start|>{role}``, newline, content, ``<|im_end|>``).
    DPO data is normalised to ``prompt`` / ``chosen`` / ``rejected`` string
    columns. Whenever no remote dataset can be used, a small synthetic
    dataset is returned so that the pipeline can still run.
    """

    # Remote SFT sources, tried in order as (dataset name, split).
    # NOTE(review): "microsoft/DialoGPT-medium" was dropped from this list. It
    # is presumably a model repository, not a dataset, so load_dataset() on it
    # only ever produced a failed attempt.
    SFT_DATASET_CANDIDATES = [
        ("lmsys/chatbot_arena_conversations", "train[:1000]"),
        ("HuggingFaceH4/ultrachat_200k", "train_sft[:1000]"),
        ("Anthropic/hh-rlhf", "train[:1000]"),
    ]

    # Columns holding lists of {"role": ..., "content": ...} dicts.
    # NOTE(review): "messages" (ultrachat-style) and "conversation_a"
    # (chatbot-arena-style) are assumed layouts; verify against dataset cards.
    MESSAGE_COLUMNS = ("messages", "conversation", "conversation_a")

    def __init__(self, logger: LoggingService):
        self.logger = logger

    # ------------------------------------------------------------------
    # SFT
    # ------------------------------------------------------------------
    def load_sft_dataset(self) -> Dataset:
        """Load and format the dataset used for supervised fine-tuning.

        Each entry of ``SFT_DATASET_CANDIDATES`` is tried in order and the
        first one that loads is formatted. Falls back to the synthetic
        dataset when nothing loads, when formatting yields no rows, or when
        any step raises.

        Returns:
            Dataset: a dataset with a ``text`` column.
        """
        self.logger.info("Loading SFT dataset...")

        try:
            dataset = None
            for dataset_name, split in self.SFT_DATASET_CANDIDATES:
                try:
                    self.logger.info(f"Trying dataset: {dataset_name}")
                    dataset = load_dataset(dataset_name, split=split or "train[:1000]")
                    self.logger.info(f"Successfully loaded: {dataset_name}")
                    break
                except Exception as e:
                    # Any failure (gated, missing, offline) just moves on to the next candidate.
                    self.logger.warning(f"Failed to load {dataset_name}: {str(e)}")

            if dataset is None:
                self.logger.info("Creating synthetic conversation dataset...")
                dataset = self._create_synthetic_dataset()

            formatted_dataset = self._format_sft_dataset(dataset)
            if len(formatted_dataset) == 0:
                # An unrecognised schema can format to zero rows. That would
                # only surface later as a confusing trainer error.
                self.logger.warning("Formatted SFT dataset is empty; using synthetic data instead")
                formatted_dataset = self._create_synthetic_dataset()

            self.logger.info(f"Loaded {len(formatted_dataset)} SFT examples")
            return formatted_dataset

        except Exception as e:
            self.logger.error(f"All dataset loading attempts failed: {str(e)}")
            # Create synthetic dataset as final fallback
            return self._create_synthetic_dataset()

    def _create_synthetic_dataset(self) -> Dataset:
        """Create a synthetic conversational dataset (10 dialogues x 11 = 110 rows)."""
        self.logger.info("Creating synthetic conversational dataset...")

        # Sample conversation data for voice assistant training
        conversations = [
            {"text": "<|im_start|>user\nWhat's the weather like today?<|im_end|>\n<|im_start|>assistant\nI'd be happy to help you with the weather, but I don't have access to real-time weather data. Could you please tell me your location so I can provide more specific guidance?<|im_end|>"},
            {"text": "<|im_start|>user\nSet a timer for 10 minutes<|im_end|>\n<|im_start|>assistant\nI'll set a timer for 10 minutes. The timer is now running and will alert you when the time is up.<|im_end|>"},
            {"text": "<|im_start|>user\nPlay some music<|im_end|>\n<|im_start|>assistant\nI'd be happy to help you play music! However, I don't have direct access to music streaming services. You can try saying 'Hey Google, play music' or 'Alexa, play music' to your smart speaker, or open your favorite music app.<|im_end|>"},
            {"text": "<|im_start|>user\nWhat time is it?<|im_end|>\n<|im_start|>assistant\nI don't have access to real-time information, but you can check the time on your device or ask your voice assistant. Is there anything else I can help you with?<|im_end|>"},
            {"text": "<|im_start|>user\nTell me a joke<|im_end|>\n<|im_start|>assistant\nHere's a joke for you: Why don't scientists trust atoms? Because they make up everything! I hope that brought a smile to your face.<|im_end|>"},
            {"text": "<|im_start|>user\nWhat can you help me with?<|im_end|>\n<|im_start|>assistant\nI'm a voice assistant and I can help you with various tasks like answering questions, providing information, helping with calculations, setting reminders, and having conversations. What would you like assistance with today?<|im_end|>"},
            {"text": "<|im_start|>user\nCall my mom<|im_end|>\n<|im_start|>assistant\nI'm not able to make phone calls directly, but I can guide you on how to call your mom. You can say 'Call Mom' to your phone's voice assistant, or find her contact in your phone app and tap to call.<|im_end|>"},
            {"text": "<|im_start|>user\nWhat's 25 plus 37?<|im_end|>\n<|im_start|>assistant\n25 plus 37 equals 62.<|im_end|>"},
            {"text": "<|im_start|>user\nTurn off the lights<|im_end|>\n<|im_start|>assistant\nI can't directly control your lights, but if you have smart lights, you can try saying 'Turn off the lights' to your smart home assistant, or use your smart home app to control them.<|im_end|>"},
            {"text": "<|im_start|>user\nGood morning<|im_end|>\n<|im_start|>assistant\nGood morning! I hope you're having a wonderful start to your day. How can I assist you this morning?<|im_end|>"},
        ]

        # Repeat the dialogues (original + 10 copies) to reach 100+ examples.
        expanded_conversations = list(conversations) * 11

        dataset = Dataset.from_list(expanded_conversations)
        self.logger.info(f"Created synthetic dataset with {len(dataset)} examples")
        return dataset

    @staticmethod
    def _render_chatml(turns) -> str:
        """Render ``(role, content)`` pairs as ChatML turns joined by newlines."""
        return "\n".join(f"<|im_start|>{role}\n{content}<|im_end|>" for role, content in turns)

    @staticmethod
    def _has_exchange(turns) -> bool:
        """Return True when ``turns`` has at least one user and one assistant turn."""
        roles = {role for role, _ in turns}
        return "user" in roles and "assistant" in roles

    @staticmethod
    def _message_list_to_turns(messages) -> List[tuple]:
        """Convert a list of ``{"role", "content"}`` dicts into ``(role, content)`` pairs.

        Entries that are not dicts, have an unknown role, or have empty or
        non-string content are skipped.
        """
        turns = []
        for message in messages or []:
            if not isinstance(message, dict):
                continue
            role, content = message.get("role"), message.get("content")
            if role in ("system", "user", "assistant") and isinstance(content, str) and content:
                turns.append((role, content))
        return turns

    @staticmethod
    def _hh_transcript_to_turns(transcript) -> List[tuple]:
        """Split a "Human: ... Assistant: ..." transcript into ``(role, content)`` pairs.

        A speaker tag counts only at the start of the text or after
        whitespace. Each turn therefore ends at the next tag instead of
        absorbing the rest of the dialogue.
        """
        import re

        if not isinstance(transcript, str):
            return []
        pieces = re.split(r"(?:^|(?<=\s))(Human|Assistant):", transcript)
        role_for = {"Human": "user", "Assistant": "assistant"}
        turns = []
        # re.split with one capture group alternates: [prefix, tag, text, tag, text, ...]
        for speaker, content in zip(pieces[1::2], pieces[2::2]):
            content = content.strip()
            if content:
                turns.append((role_for[speaker], content))
        return turns

    def _format_sft_dataset(self, dataset: Dataset) -> Dataset:
        """Convert ``dataset`` into a single ChatML ``text`` column.

        Recognised layouts, checked in this order:
          * ``text``: assumed to be formatted already, passed through;
          * one of ``MESSAGE_COLUMNS``: role/content dict lists, with every
            turn rendered;
          * ``chosen``: HH-style "Human:/Assistant:" transcripts, with every
            turn rendered;
          * otherwise the first string column is used as both the user turn
            and the assistant turn.

        In the structured layouts, rows without at least one user turn and one
        assistant turn are dropped. If mapping raises, the dataset is returned
        with its first string column renamed to ``text`` (when there is none).
        """
        # Local aliases keep ``self`` out of the mapped closure.
        render = self._render_chatml
        has_exchange = self._has_exchange
        message_turns = self._message_list_to_turns
        hh_turns = self._hh_transcript_to_turns
        message_columns = self.MESSAGE_COLUMNS

        def format_conversation(examples):
            formatted_texts = []

            if 'text' in examples:
                # Already formatted
                formatted_texts = list(examples['text'])
            else:
                message_column = next((col for col in message_columns if col in examples), None)
                if message_column is not None:
                    for messages in examples[message_column]:
                        turns = message_turns(messages)
                        if has_exchange(turns):
                            formatted_texts.append(render(turns))
                elif 'chosen' in examples:
                    # HH-RLHF format
                    for chosen in examples['chosen']:
                        turns = hh_turns(chosen)
                        if has_exchange(turns):
                            formatted_texts.append(render(turns))
                else:
                    # Generic fallback - use first text field available.
                    # NOTE(review): this teaches the model to echo its input;
                    # it is only reached for unrecognised schemas.
                    text_fields = [k for k in examples.keys() if isinstance(examples[k][0], str)]
                    if text_fields:
                        for text in examples[text_fields[0]]:
                            formatted_texts.append(render([("user", text), ("assistant", text)]))

            return {"text": formatted_texts}

        try:
            # remove_columns lets the batch change its row count when rows are dropped.
            return dataset.map(format_conversation, batched=True, remove_columns=dataset.column_names)
        except Exception as e:
            self.logger.warning(f"Error formatting dataset: {str(e)}, using as-is")
            # If formatting fails, ensure we have a 'text' column
            if 'text' not in dataset.column_names:
                first_text_col = [col for col in dataset.column_names if dataset[col][0] and isinstance(dataset[col][0], str)]
                if first_text_col:
                    dataset = dataset.rename_column(first_text_col[0], 'text')
            return dataset

    # ------------------------------------------------------------------
    # DPO
    # ------------------------------------------------------------------
    def load_dpo_dataset(self) -> Dataset:
        """Load the preference dataset used for DPO training.

        Uses Anthropic/hh-rlhf when available and the synthetic DPO set
        otherwise. A dataset that already has ``prompt``/``chosen``/``rejected``
        columns is used as-is. HH transcripts are split at their *last*
        "Assistant:" tag, so the prompt holds the whole preceding dialogue.

        Returns:
            Dataset: rows with ``prompt``, ``chosen`` and ``rejected`` strings.
        """
        self.logger.info("Loading DPO dataset...")

        try:
            # Try Anthropic HH-RLHF first
            try:
                dataset = load_dataset("Anthropic/hh-rlhf", split="train[:500]")
                self.logger.info("Successfully loaded Anthropic/hh-rlhf dataset")
            except Exception as e:
                self.logger.warning(f"Failed to load Anthropic/hh-rlhf: {str(e)}")
                # Create synthetic DPO dataset
                dataset = self._create_synthetic_dpo_dataset()

            def format_dpo_data(examples):
                if 'chosen' not in examples or 'rejected' not in examples:
                    # Use synthetic data
                    return self._get_synthetic_dpo_batch()

                prompts, chosen, rejected = [], [], []
                for chosen_text, rejected_text in zip(examples["chosen"], examples["rejected"]):
                    if not isinstance(chosen_text, str) or not isinstance(rejected_text, str):
                        continue
                    if "Assistant:" not in chosen_text or "Assistant:" not in rejected_text:
                        continue
                    # Split at the LAST assistant tag: the prefix is the shared
                    # dialogue, the suffix is the response being compared.
                    chosen_prefix, chosen_response = chosen_text.rsplit("Assistant:", 1)
                    rejected_prefix, rejected_response = rejected_text.rsplit("Assistant:", 1)
                    if chosen_prefix != rejected_prefix:
                        # The two responses answer different dialogues, so this is not a valid preference pair.
                        continue
                    chosen_response = chosen_response.strip()
                    rejected_response = rejected_response.strip()
                    if not chosen_response or not rejected_response:
                        continue
                    prompts.append(chosen_prefix + "Assistant:")
                    chosen.append(chosen_response)
                    rejected.append(rejected_response)

                return {
                    "prompt": prompts,
                    "chosen": chosen,
                    "rejected": rejected
                }

            if {"prompt", "chosen", "rejected"}.issubset(dataset.column_names):
                # Already in DPO form (e.g. the synthetic set). HH parsing would
                # drop every row here, because the responses carry no "Assistant:" tag.
                formatted_dataset = dataset
            else:
                formatted_dataset = dataset.map(
                    format_dpo_data, batched=True, remove_columns=dataset.column_names
                )
            # Filter out empty examples
            formatted_dataset = formatted_dataset.filter(lambda x: len(x["prompt"]) > 0)

            if len(formatted_dataset) == 0:
                self.logger.warning("No usable DPO pairs after formatting; using synthetic data instead")
                formatted_dataset = self._create_synthetic_dpo_dataset()

            self.logger.info(f"Loaded {len(formatted_dataset)} DPO examples")
            return formatted_dataset

        except Exception as e:
            self.logger.error(f"Failed to load DPO dataset: {str(e)}")
            return self._create_synthetic_dpo_dataset()

    def _create_synthetic_dpo_dataset(self) -> Dataset:
        """Create the synthetic DPO dataset (5 preference pairs x 20 = 100 rows)."""
        self.logger.info("Creating synthetic DPO dataset...")

        dpo_examples = [
            {
                "prompt": "Human: What's the weather like today? Assistant:",
                "chosen": "I'd be happy to help you with the weather, but I don't have access to real-time weather data. Could you please tell me your location?",
                "rejected": "It's sunny and warm everywhere today!"
            },
            {
                "prompt": "Human: Tell me a joke Assistant:",
                "chosen": "Here's a joke for you: Why don't scientists trust atoms? Because they make up everything!",
                "rejected": "Jokes are stupid and waste of time."
            },
            {
                "prompt": "Human: How do I cook pasta? Assistant:",
                "chosen": "To cook pasta: 1) Boil water in a large pot, 2) Add salt, 3) Add pasta and cook according to package directions, 4) Drain and serve!",
                "rejected": "Just put pasta in cold water and hope for the best."
            },
            {
                "prompt": "Human: What's 2+2? Assistant:",
                "chosen": "2 + 2 equals 4.",
                "rejected": "2 + 2 equals 5, obviously."
            },
            {
                "prompt": "Human: Help me with my homework Assistant:",
                "chosen": "I'd be happy to help you with your homework! What subject are you working on and what specific questions do you have?",
                "rejected": "I'm not going to do your homework for you. Figure it out yourself."
            },
        ]

        expanded_examples = list(dpo_examples) * 20  # 100 examples

        dataset = Dataset.from_list(expanded_examples)
        self.logger.info(f"Created synthetic DPO dataset with {len(dataset)} examples")
        return dataset

    def _get_synthetic_dpo_batch(self):
        """Return a fixed two-row DPO batch, used for inputs without chosen/rejected columns."""
        return {
            "prompt": ["Human: Hello Assistant:", "Human: What time is it? Assistant:"],
            "chosen": ["Hello! How can I assist you today?", "I don't have access to real-time information, but you can check your device for the current time."],
            "rejected": ["Go away.", "Time is an illusion."]
        }

In [None]:
# ================================
# Model Service
# ================================
class ModelService:
    """Loads the base model/tokenizer through Unsloth and attaches LoRA adapters."""

    def __init__(self, config: PipelineConfig, logger: LoggingService):
        self.config = config
        self.logger = logger
        self.model = None      # set by load_base_model(), wrapped by prepare_for_training()
        self.tokenizer = None  # set by load_base_model()

    @staticmethod
    def _resolve_dtype(dtype):
        """Map a dtype name such as "float16" to ``torch.float16``.

        ``None`` and actual ``torch.dtype`` values are returned unchanged.

        Raises:
            ValueError: if ``dtype`` does not name a torch dtype.
        """
        if dtype is None or isinstance(dtype, torch.dtype):
            return dtype
        resolved = getattr(torch, str(dtype), None)
        if not isinstance(resolved, torch.dtype):
            raise ValueError(f"Unsupported dtype: {dtype!r}")
        return resolved

    def load_base_model(self):
        """Load the base model and tokenizer into ``self.model`` / ``self.tokenizer``."""
        self.logger.info(f"Loading base model: {self.config.model.base_model}")

        # Extra bitsandbytes-style kwargs such as llm_int8_enable_fp32_cpu_offload
        # must not be passed here: the loader forwarded that one to
        # LlamaForCausalLM.__init__ and failed with TypeError. The device_map
        # override is dropped too, so Unsloth's default placement is used.
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.config.model.base_model,
            max_seq_length=self.config.model.max_seq_length,
            dtype=self._resolve_dtype(self.config.model.dtype),
            load_in_4bit=self.config.model.load_in_4bit,
        )

        self.logger.info("Base model loaded successfully")

    def prepare_for_training(self):
        """Wrap the loaded model with LoRA adapters configured by ``config.lora``.

        Raises:
            RuntimeError: if called before ``load_base_model()``.
        """
        if self.model is None:
            raise RuntimeError("load_base_model() must be called before prepare_for_training()")

        self.logger.info("Preparing model for training...")

        lora = self.config.lora
        self.model = FastLanguageModel.get_peft_model(
            self.model,
            r=lora.r,
            target_modules=lora.target_modules,
            lora_alpha=lora.lora_alpha,
            lora_dropout=lora.lora_dropout,
            bias=lora.bias,
            use_gradient_checkpointing=lora.use_gradient_checkpointing,
            random_state=lora.random_state,
            use_rslora=lora.use_rslora,
            loftq_config=lora.loftq_config,
        )

        self.logger.info("Model prepared for training")

In [None]:
# ================================
# Training Services
# ================================

class SFTService:
    """Supervised fine-tuning stage, built on TRL's ``SFTTrainer``."""

    def __init__(self, config: PipelineConfig, logger: LoggingService):
        self.config = config
        self.logger = logger

    def _report_to(self) -> str:
        """Return the ``report_to`` target chosen by ``config.use_wandb``.

        When W&B is enabled, ``WANDB_PROJECT`` is set (unless already set)
        so that runs land in ``config.wandb_project``.
        """
        if self.config.use_wandb:
            os.environ.setdefault("WANDB_PROJECT", self.config.wandb_project)
            return "wandb"
        return "none"

    def train(self, model, tokenizer, dataset: Dataset):
        """Run supervised fine-tuning.

        Args:
            model: the LoRA-wrapped model to train.
            tokenizer: its tokenizer.
            dataset: dataset with a ``config.sft.dataset_text_field`` column.

        Returns:
            The ``SFTTrainer`` after training. When
            ``config.save_intermediate`` is set, the model is also saved to
            ``config.sft.output_dir``.
        """
        self.logger.info("Starting SFT training...")

        from trl import SFTTrainer

        sft = self.config.sft
        # NOTE(review): newer TRL versions may expect tokenizer/dataset_text_field/
        # max_seq_length/packing on trl.SFTConfig rather than on the trainer.
        # Confirm compatibility with the installed trl/unsloth versions.
        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset,
            dataset_text_field=sft.dataset_text_field,
            max_seq_length=self.config.model.max_seq_length,
            dataset_num_proc=2,
            packing=False,
            args=TrainingArguments(
                per_device_train_batch_size=sft.per_device_train_batch_size,
                gradient_accumulation_steps=sft.gradient_accumulation_steps,
                warmup_steps=sft.warmup_steps,
                max_steps=sft.max_steps,
                learning_rate=sft.learning_rate,
                fp16=sft.fp16,
                bf16=sft.bf16,
                logging_steps=sft.logging_steps,
                optim=sft.optim,
                weight_decay=sft.weight_decay,
                lr_scheduler_type=sft.lr_scheduler_type,
                seed=sft.seed,
                output_dir=sft.output_dir,
                # Previously hard-coded to "none", which ignored config.use_wandb.
                report_to=self._report_to(),
            ),
        )

        trainer.train()
        self.logger.info("SFT training completed")

        if self.config.save_intermediate:
            trainer.save_model()
            self.logger.info(f"SFT model saved to {sft.output_dir}")

        return trainer

In [None]:
class DPOService:
    """Direct Preference Optimization stage, built on TRL's ``DPOTrainer``."""

    def __init__(self, config: PipelineConfig, logger: LoggingService):
        self.config = config
        self.logger = logger

    def _report_to(self) -> str:
        """Return the ``report_to`` target chosen by ``config.use_wandb``.

        When W&B is enabled, ``WANDB_PROJECT`` is set (unless already set)
        so that runs land in ``config.wandb_project``.
        """
        if self.config.use_wandb:
            os.environ.setdefault("WANDB_PROJECT", self.config.wandb_project)
            return "wandb"
        return "none"

    def train(self, model, tokenizer, dataset: Dataset):
        """Run DPO training and always save the result to ``config.dpo.output_dir``.

        Args:
            model: the (SFT-trained) LoRA model to optimise.
            tokenizer: its tokenizer.
            dataset: rows with ``prompt``, ``chosen`` and ``rejected`` columns.

        Returns:
            The ``DPOTrainer`` after training.
        """
        self.logger.info("Starting DPO training...")

        from trl import DPOTrainer

        dpo = self.config.dpo
        # NOTE(review): newer TRL versions may expect beta/max_length/
        # max_prompt_length on trl.DPOConfig rather than on the trainer.
        # Confirm compatibility with the installed trl/unsloth versions.
        trainer = DPOTrainer(
            model=model,
            ref_model=None,  # Use implicit reference model
            args=TrainingArguments(
                per_device_train_batch_size=dpo.per_device_train_batch_size,
                gradient_accumulation_steps=dpo.gradient_accumulation_steps,
                warmup_steps=dpo.warmup_steps,
                max_steps=dpo.max_steps,
                learning_rate=dpo.learning_rate,
                fp16=dpo.fp16,
                bf16=dpo.bf16,
                logging_steps=dpo.logging_steps,
                optim=dpo.optim,
                weight_decay=dpo.weight_decay,
                lr_scheduler_type=dpo.lr_scheduler_type,
                seed=dpo.seed,
                output_dir=dpo.output_dir,
                remove_unused_columns=False,
                # Previously hard-coded to "none", which ignored config.use_wandb.
                report_to=self._report_to(),
            ),
            beta=dpo.beta,
            train_dataset=dataset,
            tokenizer=tokenizer,
            max_length=self.config.model.max_seq_length,
            max_prompt_length=self.config.model.max_seq_length // 2,
        )

        trainer.train()
        self.logger.info("DPO training completed")

        # Always saved: this directory is what the Hub stage uploads.
        trainer.save_model()
        self.logger.info(f"DPO model saved to {dpo.output_dir}")

        return trainer

In [None]:
# ================================
# Hub Service
# ================================
class HubService:
    """Publishes a local model directory to the HuggingFace Hub."""

    def __init__(self, config: PipelineConfig, logger: LoggingService):
        self.config = config
        self.logger = logger
        self.api = HfApi(token=self.config.hub.token)

    def push_to_hub(self, model_path: str):
        """Create ``config.hub.repo_id`` if needed, then upload ``model_path`` to it.

        Raises:
            Exception: whatever the Hub client raised, after it is logged.
        """
        hub = self.config.hub
        self.logger.info(f"Pushing model to Hub: {hub.repo_id}")

        try:
            # exist_ok makes repeated pushes to the same repo safe.
            create_repo(
                repo_id=hub.repo_id,
                private=hub.private,
                token=hub.token,
                exist_ok=True,
            )
            self.api.upload_folder(
                folder_path=model_path,
                repo_id=hub.repo_id,
                commit_message=hub.commit_message,
                token=hub.token,
            )
            self.logger.info("Model successfully pushed to Hub")
        except Exception as exc:
            self.logger.error(f"Failed to push to Hub: {exc!s}")
            raise

In [None]:
# ================================
# Main Pipeline Orchestrator
# ================================

class VoiceFinetuningPipeline:
    """Main pipeline orchestrator: SFT, then DPO, then upload to the Hub.

    Every service is built around the pipeline's ``PipelineConfig``.
    Assigning a new object to ``pipeline.config`` rebuilds the services so
    that they actually use it.
    """

    def __init__(self, config_path: Optional[str] = None, config: Optional[PipelineConfig] = None):
        """Load the configuration and create the services.

        Args:
            config_path: optional JSON file containing ``PipelineConfig`` fields.
                If the path does not exist, a warning is logged and defaults are used.
            config: optional ready-made configuration. Takes precedence over
                ``config_path``.
        """
        # The logger must exist before self.config is assigned, because the
        # config setter builds the services with it.
        self.logger = LoggingService()

        if config is not None:
            self.config = config
        elif config_path and os.path.exists(config_path):
            with open(config_path, 'r') as f:
                self.config = PipelineConfig(**json.load(f))
        else:
            if config_path:
                self.logger.warning(f"Config file {config_path} not found; using default configuration")
            self.config = PipelineConfig()

    @property
    def config(self) -> PipelineConfig:
        """The configuration shared by every service."""
        return self._config

    @config.setter
    def config(self, value: PipelineConfig):
        self._config = value
        # Services hold a reference to the config they were built with, so
        # they must be rebuilt. This also drops any model already loaded by
        # the previous ModelService.
        self._build_services()

    def _build_services(self):
        """Create every service around the current configuration."""
        self.data_service = DataService(self.logger)
        self.model_service = ModelService(self.config, self.logger)
        self.sft_service = SFTService(self.config, self.logger)
        self.dpo_service = DPOService(self.config, self.logger)
        self.hub_service = HubService(self.config, self.logger)

    def save_config(self, path: str = "pipeline_config.json"):
        """Save the current configuration as JSON.

        The file can be passed back as ``config_path``.
        """
        # PipelineConfig is a pydantic model, not a dataclass: dataclasses.asdict
        # raised TypeError here. model_dump(mode="json") yields JSON-safe values.
        with open(path, 'w') as f:
            json.dump(self.config.model_dump(mode="json"), f, indent=2, default=str)
        self.logger.info(f"Configuration saved to {path}")

    def _log_phase(self, title: str):
        """Log a banner that separates pipeline phases."""
        self.logger.info("=" * 50)
        self.logger.info(title)
        self.logger.info("=" * 50)

    def run_pipeline(self):
        """Run the complete fine-tuning pipeline.

        Raises:
            Exception: any failure from a stage, after it is logged.
        """
        self.logger.info("Starting Voice Fine-tuning Pipeline")

        try:
            # Step 1: Load base model
            self.model_service.load_base_model()
            self.model_service.prepare_for_training()

            # Step 2: SFT Training
            self._log_phase("PHASE 1: Supervised Fine-tuning")
            sft_dataset = self.data_service.load_sft_dataset()
            self.sft_service.train(
                self.model_service.model,
                self.model_service.tokenizer,
                sft_dataset
            )

            # Step 3: DPO Training (continues from the SFT-trained adapters)
            self._log_phase("PHASE 2: Direct Preference Optimization")
            dpo_dataset = self.data_service.load_dpo_dataset()
            self.dpo_service.train(
                self.model_service.model,
                self.model_service.tokenizer,
                dpo_dataset
            )

            # Step 4: Push to Hub
            self._log_phase("PHASE 3: Publishing to HuggingFace Hub")
            self.hub_service.push_to_hub(self.config.dpo.output_dir)

            self.logger.info("Pipeline completed successfully!")

        except Exception as e:
            self.logger.error(f"Pipeline failed: {str(e)}")
            raise

In [None]:
"""Main calling points"""
def create_default_config():
    """Build a pipeline with default settings and write its config to disk.

    Returns:
        VoiceFinetuningPipeline: the pipeline whose configuration was saved
        to ``pipeline_config.json``.
    """
    default_pipeline = VoiceFinetuningPipeline()
    default_pipeline.save_config()
    print("✅ Default configuration saved to pipeline_config.json")
    return default_pipeline

def run_pipeline_with_defaults():
    """Run the whole pipeline using the built-in default configuration.

    Intended as a one-liner from a notebook cell. Failures are printed and
    then re-raised.

    Returns:
        VoiceFinetuningPipeline: the pipeline after a successful run.
    """
    header = "🎙️ Starting Voice Fine-tuning Pipeline with Default Settings"
    print(header)
    print("=" * 60)

    try:
        default_pipeline = VoiceFinetuningPipeline()
        default_pipeline.run_pipeline()
        print("🎉 Pipeline completed successfully!")
        return default_pipeline
    except Exception as err:
        print(f"❌ Pipeline failed: {err!s}")
        raise

def run_pipeline_with_config(config_path: str = "pipeline_config.json"):
    """Run the pipeline with the configuration stored in ``config_path``.

    Returns:
        VoiceFinetuningPipeline: the pipeline after a successful run.

    Raises:
        FileNotFoundError: if ``config_path`` does not exist.
        Exception: any pipeline failure, after it is printed.
    """
    print(f"🎙️ Starting Voice Fine-tuning Pipeline with config: {config_path}")
    print("=" * 60)

    try:
        # VoiceFinetuningPipeline falls back to defaults for a missing path.
        # Here the caller explicitly asked for this file, so fail loudly
        # instead of silently training with default settings.
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config file not found: {config_path}")
        pipeline = VoiceFinetuningPipeline(config_path)
        pipeline.run_pipeline()
        print("🎉 Pipeline completed successfully!")
        return pipeline
    except Exception as e:
        print(f"❌ Pipeline failed: {str(e)}")
        raise

def run_quick_demo():
    """Quick demo with minimal training steps - Great for testing.

    Shrinks training to 10 SFT / 5 DPO steps with batch size 1. The result
    is uploaded to ``<logged-in user>/voice-assistant-test``, or to
    ``demo/voice-assistant-test`` if the user cannot be determined.

    Returns:
        VoiceFinetuningPipeline: the pipeline after a successful run.
    """
    print("🎯 Running Quick Demo Mode")
    print("=" * 30)

    try:
        pipeline = VoiceFinetuningPipeline()
        # Edit, in place, the config object the services were built with,
        # so that the overrides reach every service. Building a separate
        # PipelineConfig and assigning it afterwards left the services on
        # the default settings.
        demo_config = pipeline.config
        demo_config.sft.max_steps = 10
        demo_config.dpo.max_steps = 5
        demo_config.sft.per_device_train_batch_size = 1
        demo_config.dpo.per_device_train_batch_size = 1
        try:
            # Push to the logged-in user's namespace. The "demo" namespace is
            # presumably not writable for most accounts.
            username = HfApi().whoami()["name"]
            demo_config.hub.repo_id = f"{username}/voice-assistant-test"
        except Exception:
            demo_config.hub.repo_id = "demo/voice-assistant-test"

        pipeline.run_pipeline()
        print("🎉 Demo completed successfully!")
        return pipeline
    except Exception as e:
        print(f"❌ Demo failed: {str(e)}")
        raise

def setup_custom_pipeline(**kwargs):
    """Set up and run a pipeline with configuration overrides.

    Keys may be:
      * a top-level ``PipelineConfig`` field, e.g. ``use_wandb=True``;
      * ``"section.param"``, e.g. ``**{"sft.max_steps": 50}``;
      * ``section__param``, e.g. ``sft__max_steps=50``;
      * ``section_param``, e.g. ``sft_max_steps=50``, split at the first underscore.

    Returns:
        VoiceFinetuningPipeline: the pipeline after a successful run.

    Raises:
        ValueError: for keys that match no configuration field. These were
            previously ignored silently.
    """
    print("🔧 Setting up Custom Pipeline Configuration")
    print("=" * 45)

    try:
        pipeline = VoiceFinetuningPipeline()
        # Apply overrides to the config object the services were built with,
        # so that every service sees them.
        config = pipeline.config

        for key, value in kwargs.items():
            # model_fields (on the class) contains only real fields, unlike
            # hasattr, which also matches pydantic methods.
            if key in type(config).model_fields:
                setattr(config, key, value)
                continue

            if '.' in key:
                section, _, param = key.partition('.')
            elif '__' in key:
                section, _, param = key.partition('__')
            else:
                section, _, param = key.partition('_')

            section_config = getattr(config, section, None) if section in type(config).model_fields else None
            if not isinstance(section_config, BaseModel) or param not in type(section_config).model_fields:
                raise ValueError(f"Unknown configuration key: {key!r}")
            setattr(section_config, param, value)

        pipeline.run_pipeline()
        print("🎉 Custom pipeline completed successfully!")
        return pipeline
    except Exception as e:
        print(f"❌ Custom pipeline failed: {str(e)}")
        raise


def main():
    """Notebook entry point: run the pipeline with its default settings."""
    pipeline = run_pipeline_with_defaults()
    return pipeline

In [None]:
# Entry-point cell: keep exactly one of the options below uncommented.
#
# # Quick start - runs the full pipeline immediately with defaults
# pipeline = main()

# # Or run specific functions:
# pipeline = run_pipeline_with_defaults()

# Quick demo for testing (10 SFT / 5 DPO steps)
pipeline = run_quick_demo()

# Custom parameters inline. Nested settings use "section.param" keys. Those
# are not valid Python identifiers, so pass them through a dict:
# pipeline = setup_custom_pipeline(**{
#     "hub.repo_id": "rahulsamant37/my-voice-model",
#     "sft.max_steps": 50,
#     "dpo.max_steps": 25,
# })

2025-06-21 18:27:13,507 - voice_pipeline - INFO - Starting Voice Fine-tuning Pipeline
INFO:voice_pipeline:Starting Voice Fine-tuning Pipeline
2025-06-21 18:27:13,509 - voice_pipeline - INFO - Loading base model: unsloth/llama-3-8b-bnb-4bit
INFO:voice_pipeline:Loading base model: unsloth/llama-3-8b-bnb-4bit


🎯 Running Quick Demo Mode
==((====))==  Unsloth 2025.6.4: Fast Llama patching. Transformers: 4.52.4.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


2025-06-21 18:27:15,572 - voice_pipeline - ERROR - Pipeline failed: LlamaForCausalLM.__init__() got an unexpected keyword argument 'llm_int8_enable_fp32_cpu_offload'
ERROR:voice_pipeline:Pipeline failed: LlamaForCausalLM.__init__() got an unexpected keyword argument 'llm_int8_enable_fp32_cpu_offload'


❌ Demo failed: LlamaForCausalLM.__init__() got an unexpected keyword argument 'llm_int8_enable_fp32_cpu_offload'


TypeError: LlamaForCausalLM.__init__() got an unexpected keyword argument 'llm_int8_enable_fp32_cpu_offload'