In [1]:
import torch 
import unsloth
import bitsandbytes

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [2]:
if torch.cuda.is_available():
    print(" PyTorch can access the GPU.")
    print(f"   - GPU Device: {torch.cuda.get_device_name(0)}")
else:
    print(" PyTorch cannot access the GPU. Please check your CUDA installation.")

 PyTorch can access the GPU.
   - GPU Device: NVIDIA GeForce RTX 3070 Laptop GPU


In [3]:
print(f"✅ PyTorch version:  {torch.__version__}")
print(f"✅ Unsloth version:  {unsloth.__version__}")
print(f"✅ Bitsandbytes version:  {bitsandbytes.__version__}")
print("\nYour environment is ready for GPU fine-tuning!")

✅ PyTorch version:  2.7.1+cu126
✅ Unsloth version:  2025.7.7
✅ Bitsandbytes version:  0.46.1

Your environment is ready for GPU fine-tuning!


In [None]:
import logging
from pathlib import Path
from typing import Dict, List, Any, Optional
from datasets import Dataset, DatasetDict, load_dataset, concatenate_datasets
from transformers import AutoTokenizer

class GemmaDatasetBuilder:
    """
    Professional dataset builder for Gemma 3 4B fine-tuning.
    Uses verified, large-scale datasets for general chat, emotion,
    and therapeutic dialogue.
    
    Features:
    - Domain tag injection for style switching
    - Stage-wise curriculum support
    """

    SOURCES: Dict[str, str] = {
        # General instruction-following
        "openorca": "Open-Orca/OpenOrca",
        # Open-domain multi-turn chat
        "ultrachat": "HuggingFaceH4/ultrachat_200k",
        # Safety/preference alignment
        "ultrafeedback": "HuggingFaceH4/ultrafeedback_binarized",
        # Emotion understanding
        "emotion": "dair-ai/emotion",
        # Empathetic multi-turn counseling
        "empathetic_dialogues": "facebook/empathetic_dialogues",
        # Benchmark mental-health counseling
        "mentalchat16k": "ShenLab/MentalChat16K",
        # Real counselor Q&A
        "counsel_chat": "nbertagnolli/counsel-chat"
    }

    # Define training phases for curriculum learning
    TRAINING_PHASES = {
        "phase1": {
            "name": "General & Emotion",
            "sources": ["openorca", "ultrachat", "emotion"],
            "description": "Basic conversational and emotional understanding"
        },
        "phase2": {
            "name": "Therapeutic",
            "sources": ["empathetic_dialogues", "mentalchat16k", "counsel_chat"],
            "description": "Empathetic counseling and therapeutic dialogue"
        }
    }

    def __init__(
        self,
        max_tokens: int = 4096,
        train_split: float = 0.9,
        seed: int = 42,
        enable_domain_tags: bool = True
    ):
        self.max_tokens = max_tokens
        self.train_split = train_split
        self.seed = seed
        self.enable_domain_tags = enable_domain_tags
        self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")
        self.logger = self._setup_logging()
        self._prepared_data: Optional[Dataset] = None

    def _setup_logging(self):
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
        return logging.getLogger(__name__)

    def _tagged(self, text: str, domain: str) -> str:
        """Prepend a domain tag to any instruction for style switching."""
        if self.enable_domain_tags:
            return f"<domain>{domain}</domain>\n{text}"
        return text

    def _load_and_prepare_sources(self) -> Dataset:
        """
        Loads all unique datasets, processes them into a unified format,
        and concatenates them. This is the efficient core of the builder.
        """
        if self._prepared_data is not None:
            return self._prepared_data

        self.logger.info("Loading and preparing all source datasets...")
        
        source_map = {
            "openorca": self._load_openorca,
            "ultrachat": self._load_ultrachat,
            "ultrafeedback": self._load_ultrafeedback,
            "emotion": self._load_emotion,
            "empathetic_dialogues": self._load_empathetic,
            "mentalchat16k": self._load_mentalchat,
            "counsel_chat": self._load_counsel
        }

        all_sources = set(s for phase in self.TRAINING_PHASES.values() for s in phase["sources"])
        
        parts: List[Dataset] = []
        for source_name in all_sources:
            if source_name in source_map:
                try:
                    self.logger.info(f"Processing source: {source_name}")
                    dataset = source_map[source_name]()
                    if dataset:
                        parts.append(dataset)
                except Exception as e:
                    self.logger.error(f"Failed to load or process source {source_name}: {e}", exc_info=True)
            else:
                self.logger.warning(f"No loader found for source: {source_name}")

        if not parts:
            raise ValueError("No datasets could be loaded. Aborting.")

        self.logger.info("Combining all datasets...")
        combined = concatenate_datasets(parts).shuffle(seed=self.seed)
        self._prepared_data = combined
        return self._prepared_data

    def _format_conversation(self, ex: Dict[str, Any]) -> Dict[str, Any]:
        ins = ex["instruction"].strip()
        rsp = ex["response"].strip()
        
        return {
            "text": f"<start_of_turn>user\n{ins}<end_of_turn>\n<start_of_turn>model\n{rsp}<end_of_turn>",
            "source": ex.get("source"), 
            "domain": ex.get("domain")
        }

    def _filter_and_format_dataset(self, dataset: Dataset) -> Dataset:
        """Filters out long sequences and formats the text."""
        
        # 1. Format the text first
        formatted_ds = dataset.map(self._format_conversation, num_proc=4)
        
        # 2. Filter out samples that are too long
        original_size = len(formatted_ds)
        self.logger.info(f"Filtering dataset of size {original_size} to keep samples under {self.max_tokens} tokens...")
        
        filtered_ds = formatted_ds.filter(
            lambda x: len(self.tokenizer.encode(x["text"])) <= self.max_tokens,
            num_proc=4
        )
        
        new_size = len(filtered_ds)
        if original_size > new_size:
            self.logger.info(f"Removed {original_size - new_size} samples exceeding max token length.")
            
        return filtered_ds

    def build_complete_dataset(self) -> DatasetDict:
        """Load, combine, format, filter, and split all verified datasets."""
        self.logger.info("Building complete dataset with domain tags...")
        
        prepared_data = self._load_and_prepare_sources()
        final_data = self._filter_and_format_dataset(prepared_data)

        total = len(final_data)
        if total == 0:
            raise ValueError("The dataset is empty after filtering. Check max_tokens or source data.")
            
        train_size = int(total * self.train_split)
        
        return DatasetDict(
            train=final_data.select(range(train_size)),
            validation=final_data.select(range(train_size, total))
        )

    def build_phase_dataset(self, phase: str) -> DatasetDict:
        """Build dataset for a specific training phase using pre-loaded data."""
        if phase not in self.TRAINING_PHASES:
            raise ValueError(f"Unknown phase: {phase}. Available: {list(self.TRAINING_PHASES.keys())}")
        
        phase_info = self.TRAINING_PHASES[phase]
        self.logger.info(f"Building dataset for {phase_info['name']} phase...")
        
        prepared_data = self._load_and_prepare_sources()
        
        phase_sources = phase_info["sources"]
        phase_data = prepared_data.filter(lambda x: x['source'] in phase_sources, num_proc=4)
        
        if len(phase_data) == 0:
            raise ValueError(f"No data found for sources in phase {phase}: {phase_sources}")

        final_data = self._filter_and_format_dataset(phase_data)

        total = len(final_data)
        if total == 0:
            raise ValueError(f"The dataset for phase '{phase}' is empty after filtering.")

        train_size = int(total * self.train_split)
        
        return DatasetDict(
            train=final_data.select(range(train_size)),
            validation=final_data.select(range(train_size, total))
        )

    def _load_openorca(self) -> Dataset:
        ds = load_dataset(self.SOURCES["openorca"], split="train[:50000]")
        ds = ds.filter(lambda ex: ex.get("question") and ex.get("response"))
        return ds.map(lambda ex: {
            "instruction": self._tagged(ex["question"].strip(), "general"),
            "response": ex["response"].strip(),
            "source": "openorca",
            "domain": "general"
        }, num_proc=4)

    def _load_ultrachat(self) -> Dataset:
        ds = load_dataset(self.SOURCES["ultrachat"], split="train_sft[:25000]")
        def extract(item):
            msgs = item.get("messages", [])
            if len(msgs) >= 2 and msgs[0].get("content") and msgs[1].get("content"):
                return {"instruction": self._tagged(msgs[0]["content"].strip(), "chat"), "response": msgs[1]["content"].strip()}
            return {"instruction": None, "response": None}
        
        ds = ds.map(extract, num_proc=4).filter(lambda x: x["instruction"] is not None)
        return ds.add_column("source", ["ultrachat"] * len(ds)).add_column("domain", ["chat"] * len(ds))

    def _load_ultrafeedback(self) -> Dataset:
        ds = load_dataset(self.SOURCES["ultrafeedback"], split="train_prefs[:10000]")
        def extract(item):
            prompt = item.get("prompt", "").strip()
            chosen = item.get("chosen", [])
            if prompt and chosen:
                response_content = (chosen[-1].get("content") if isinstance(chosen[-1], dict) else str(chosen[-1]))
                if response_content:
                    return {"instruction": self._tagged(prompt, "safety"), "response": response_content.strip()}
            return {"instruction": None, "response": None}

        ds = ds.map(extract, num_proc=4).filter(lambda x: x["instruction"] is not None)
        return ds.add_column("source", ["ultrafeedback"] * len(ds)).add_column("domain", ["safety"] * len(ds))

    def _load_emotion(self) -> Dataset:
        ds = load_dataset(self.SOURCES["emotion"], split="train[:5000]")
        ds = ds.filter(lambda ex: ex.get("text") is not None)
        labels = ds.features['label'].names
        def extract(ex):
            txt = ex["text"].strip()
            emo = labels[ex["label"]]
            ins = f"What emotion is expressed in the following text?\n'{txt}'"
            rsp = f"The primary emotion expressed is {emo}."
            return {"instruction": self._tagged(ins, "emotion"), "response": rsp}
        ds = ds.map(extract, num_proc=4, remove_columns=ds.column_names)
        return ds.add_column("source", ["emotion"] * len(ds)).add_column("domain", ["emotion"] * len(ds))

    def _load_empathetic(self) -> Dataset:
        ds = load_dataset(self.SOURCES["empathetic_dialogues"], split="train[:15000]")
        ds = ds.filter(lambda ex: ex.get("prompt") and ex.get("utterance"))
        ds = ds.map(lambda ex: {
            "instruction": self._tagged(ex["prompt"].strip(), "therapeutic"), 
            "response": ex["utterance"].strip(), 
        }, num_proc=4, remove_columns=ds.column_names)
        return ds.add_column("source", ["empathetic_dialogues"] * len(ds)).add_column("domain", ["therapeutic"] * len(ds))

    def _load_mentalchat(self) -> Dataset:
        ds = load_dataset(self.SOURCES["mentalchat16k"], split="train[:16000]")
        ds = ds.filter(lambda ex: ex.get("instruction") and ex.get("chosen"))
        ds = ds.map(lambda ex: {
            "instruction": self._tagged(ex["instruction"].strip(), "therapeutic"),
            "response": (ex["chosen"].strip() if isinstance(ex["chosen"], str) else ex["chosen"][0].strip()),
        }, num_proc=4, remove_columns=ds.column_names)
        return ds.add_column("source", ["mentalchat16k"] * len(ds)).add_column("domain", ["therapeutic"] * len(ds))

    def _load_counsel(self) -> Dataset:
        ds = load_dataset(self.SOURCES["counsel_chat"], split="train")
        ds = ds.filter(lambda ex: ex.get("questionText") and ex.get("answerText"))
        ds = ds.map(lambda ex: {
            "instruction": self._tagged(ex["questionText"].strip(), "therapeutic"),
            "response": ex["answerText"].strip(),
        }, num_proc=4, remove_columns=ds.column_names)
        return ds.add_column("source", ["counsel_chat"] * len(ds)).add_column("domain", ["therapeutic"] * len(ds))

    def save_phase_datasets(self, output_dir: str = "./phase_datasets"):
        """Build and save individual phase datasets and the complete dataset."""
        output_path = Path(output_dir)
        output_path.mkdir(exist_ok=True, parents=True)
        
        self.logger.info(f"Saving all datasets to {output_dir}...")
        
        for phase_name in self.TRAINING_PHASES:
            self.logger.info(f"Creating and saving '{phase_name}'...")
            phase_dataset = self.build_phase_dataset(phase_name)
            phase_path = output_path / phase_name
            phase_dataset.save_to_disk(str(phase_path))
            self.logger.info(f"Saved '{phase_name}' with {len(phase_dataset['train'])} train and {len(phase_dataset['validation'])} validation samples.")
        
        self.logger.info("Creating and saving 'complete' dataset...")
        complete_dataset = self.build_complete_dataset()
        complete_path = output_path / "complete"
        complete_dataset.save_to_disk(str(complete_path))
        self.logger.info(f"Saved 'complete' with {len(complete_dataset['train'])} train and {len(complete_dataset['validation'])} validation samples.")
        
        self.logger.info("All datasets saved successfully!")

    def print_phase_info(self):
        """Print information about training phases."""
        print("\n" + "=" * 60)
        print("Training Phase Information")
        print("=" * 60)
        for phase_name, phase_info in self.TRAINING_PHASES.items():
            print(f"\n{phase_name.upper()} ({phase_info['name']})")
            print(f"  Purpose: {phase_info['description']}")
            print(f"  Sources: {', '.join(phase_info['sources'])}")
        print("=" * 60)

def get_training_commands(base_dir: str = "./phase_datasets", output_dir: str = "./checkpoints"):
    """Return the training commands for each phase."""
    commands = {
        "Phase 1 (General & Emotion)":
            f"python train_sft.py --data_path {base_dir}/phase1 --output_dir {output_dir}/phase1 --epochs 1",
        "Phase 2 (Therapeutic)":
            f"python train_sft.py --data_path {base_dir}/phase2 --resume_from {output_dir}/phase1 --output_dir {output_dir}/phase2 --epochs 1"
    }
    return commands

def main():
    """
    Main function to demonstrate the dataset building process.
    """
    builder = GemmaDatasetBuilder(enable_domain_tags=True)
    builder.print_phase_info()
    builder.save_phase_datasets("./gemma_finetune_datasets")
    
    print("\n" + "=" * 60)
    print("Stage-wise Training Commands")
    print("=" * 60)
    commands = get_training_commands("./gemma_finetune_datasets")
    for phase_name, cmd in commands.items():
        print(f"\n# {phase_name}")
        print(cmd)
    print("=" * 60)

if __name__ == "__main__":
    
    main()

2025-07-23 17:47:34,605 - INFO - Saving all datasets to ./gemma_finetune_datasets...
2025-07-23 17:47:34,606 - INFO - Creating and saving 'phase1'...
2025-07-23 17:47:34,606 - INFO - Building dataset for General & Emotion phase...
2025-07-23 17:47:34,607 - INFO - Loading and preparing all source datasets...
2025-07-23 17:47:34,607 - INFO - Processing source: emotion



Training Phase Information

PHASE1 (General & Emotion)
  Purpose: Basic conversational and emotional understanding
  Sources: openorca, ultrachat, emotion

PHASE2 (Therapeutic)
  Purpose: Empathetic counseling and therapeutic dialogue
  Sources: empathetic_dialogues, mentalchat16k, counsel_chat


Map (num_proc=4):   0%|          | 0/5000 [00:00<?, ? examples/s]

2025-07-23 17:47:43,823 - INFO - Processing source: ultrachat


Map (num_proc=4):   0%|          | 0/25000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/25000 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/24998 [00:00<?, ? examples/s]

2025-07-23 17:47:59,686 - INFO - Processing source: openorca


Map (num_proc=4):   0%|          | 0/50000 [00:00<?, ? examples/s]

2025-07-23 17:48:08,058 - INFO - Processing source: counsel_chat
Repo card metadata block was not found. Setting CardData to empty.


Map (num_proc=4):   0%|          | 0/2612 [00:00<?, ? examples/s]

2025-07-23 17:48:15,240 - INFO - Processing source: mentalchat16k
2025-07-23 17:48:18,291 - INFO - Processing source: empathetic_dialogues


Map (num_proc=4):   0%|          | 0/15000 [00:00<?, ? examples/s]

2025-07-23 17:48:24,728 - INFO - Combining all datasets...


Filter (num_proc=4):   0%|          | 0/97610 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/79998 [00:00<?, ? examples/s]

2025-07-23 17:48:34,941 - INFO - Filtering dataset of size 79998 to keep samples under 4096 tokens...


Filter (num_proc=4):   0%|          | 0/79998 [00:00<?, ? examples/s]

2025-07-23 17:48:56,747 - INFO - Removed 115 samples exceeding max token length.


Saving the dataset (0/1 shards):   0%|          | 0/71894 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7989 [00:00<?, ? examples/s]

2025-07-23 17:48:59,012 - INFO - Saved 'phase1' with 71894 train and 7989 validation samples.
2025-07-23 17:48:59,012 - INFO - Creating and saving 'phase2'...
2025-07-23 17:48:59,012 - INFO - Building dataset for Therapeutic phase...


Filter (num_proc=4):   0%|          | 0/97610 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/17612 [00:00<?, ? examples/s]

2025-07-23 17:49:06,271 - INFO - Filtering dataset of size 17612 to keep samples under 4096 tokens...


Filter (num_proc=4):   0%|          | 0/17612 [00:00<?, ? examples/s]

2025-07-23 17:49:11,340 - INFO - Removed 7 samples exceeding max token length.


Saving the dataset (0/1 shards):   0%|          | 0/15844 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1761 [00:00<?, ? examples/s]

2025-07-23 17:49:11,735 - INFO - Saved 'phase2' with 15844 train and 1761 validation samples.
2025-07-23 17:49:11,736 - INFO - Creating and saving 'complete' dataset...
2025-07-23 17:49:11,737 - INFO - Building complete dataset with domain tags...


Map (num_proc=4):   0%|          | 0/97610 [00:00<?, ? examples/s]

2025-07-23 17:49:20,195 - INFO - Filtering dataset of size 97610 to keep samples under 4096 tokens...


Filter (num_proc=4):   0%|          | 0/97610 [00:00<?, ? examples/s]

2025-07-23 17:49:43,740 - INFO - Removed 122 samples exceeding max token length.


Saving the dataset (0/1 shards):   0%|          | 0/87739 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/9749 [00:00<?, ? examples/s]

2025-07-23 17:49:46,506 - INFO - Saved 'complete' with 87739 train and 9749 validation samples.
2025-07-23 17:49:46,507 - INFO - All datasets saved successfully!



Stage-wise Training Commands

# Phase 1 (General & Emotion)
python train_sft.py --data_path ./gemma_finetune_datasets/phase1 --output_dir ./checkpoints/phase1 --epochs 1

# Phase 2 (Therapeutic)
python train_sft.py --data_path ./gemma_finetune_datasets/phase2 --resume_from ./checkpoints/phase1 --output_dir ./checkpoints/phase2 --epochs 1


In [4]:
import argparse
import os
import torch
from datasets import load_from_disk
from unsloth import FastLanguageModel, is_bfloat16_supported
from transformers import TrainingArguments
from trl import SFTTrainer
from peft import LoraConfig, TaskType


In [9]:
!python run_sft.py \
  --model_name "google/gemma-3-4b-it" \
  --data_path ./gemma_finetune_datasets/phase1 \
  --output_dir ./checkpoints/phase1 \
  --epochs 1 \
  --batch_size 2 \
  --grad_accum 4 \
  --report_to "wandb"

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
Using dtype: torch.bfloat16
==((====))==  Unsloth 2025.7.7: Fast Gemma3 patching. Transformers: 4.53.3.
   \\   /|    NVIDIA GeForce RTX 3070 Laptop GPU. Num GPUs = 1. Max memory: 7.632 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.1+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.1
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
model.safetensors:   2%|▎                  | 71.9M/4.56G [00:38<22:38, 3.31MB/s]^C
Cancellation requested; stopping current tasks.
Traceback (most recent call last):
  File "/home/sanj-ai/miniconda3/envs/unsloth_env/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 627, in xet_get
    download_files(
KeyboardInterrupt

During handling of th

usage: ipykernel_launcher.py [-h] [--model_name MODEL_NAME] --data_path
                             DATA_PATH --output_dir OUTPUT_DIR
                             [--resume_from RESUME_FROM] [--epochs EPOCHS]
                             [--batch_size BATCH_SIZE]
                             [--grad_accum GRAD_ACCUM] [--lr LR]
                             [--max_len MAX_LEN] [--seed SEED]
                             [--report_to REPORT_TO]
ipykernel_launcher.py: error: the following arguments are required: --data_path, --output_dir


SystemExit: 2