In [1]:
!pip install unsloth

Collecting unsloth
  Downloading unsloth-2025.3.9-py3-none-any.whl.metadata (59 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.3/59.3 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting unsloth_zoo>=2025.3.8 (from unsloth)
  Downloading unsloth_zoo-2025.3.8-py3-none-any.whl.metadata (16 kB)
Collecting xformers>=0.0.27.post2 (from unsloth)
  Downloading xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (1.0 kB)
Collecting bitsandbytes (from unsloth)
  Downloading bitsandbytes-0.45.3-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting triton>=3.0.0 (from unsloth)
  Downloading triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Collecting tyro (from unsloth)
  Downloading tyro-0.9.16-py3-none-any.whl.metadata (9.4 kB)
Collecting transformers!=4.47.0,>=4.46.1 (from unsloth)
  Downloading transformers-4.49.0-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━

In [2]:
%%writefile my_script.py

import os
import torch
import pandas as pd
import numpy as np
from datasets import Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix

# Import Unsloth and required trainer components
from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import DataCollatorForSeq2Seq

# Set seed for reproducibility
# The same value is reused further down for the LoRA `random_state`,
# `train_test_split(random_state=...)` and `TrainingArguments(seed=...)`.
seed = 42
np.random.seed(seed)  # also governs the np.random.choice subsampling of the dataframe
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Set up distributed training
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

# Initialize the distributed environment
def setup_ddp():
    """Join the torch.distributed process group when started by a launcher.

    The launcher (torchrun / torch.distributed.launch) exports RANK and
    WORLD_SIZE for every worker. When those variables are absent the script
    was started as a single ordinary process, so it runs as world_size 1 and
    no process group is created. Treating the visible GPU count as the world
    size in that case would make `init_process_group(init_method='env://')`
    wait for peers and rendezvous variables that do not exist.

    Returns:
        tuple[int, int, int]: ``(rank, world_size, local_rank)``.
    """
    launched_by_torchrun = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ

    if launched_by_torchrun:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        if 'LOCAL_RANK' in os.environ:
            local_rank = int(os.environ['LOCAL_RANK'])
        else:
            # Fallback instead of a KeyError: spread ranks over the visible GPUs.
            local_rank = rank % max(torch.cuda.device_count(), 1)
    else:
        # Plain `python my_script.py`: this is the only process.
        rank, world_size, local_rank = 0, 1, 0

    if world_size > 1:
        # Bind this process to its GPU before NCCL creates its communicators.
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        print(f"Initialized process {rank}/{world_size} with local_rank {local_rank}")

    return rank, world_size, local_rank

# Configuration
model_name = "unsloth/gemma-2b-bnb-4bit"  # Pre-quantized Gemma model
# NOTE(review): the directory name says "9b" while model_name is a 2b checkpoint — confirm which is intended.
output_dir = "./gemma-9b-survey-finetuned"
max_seq_length = 1400  # Context window size
load_in_4bit = True

# Setup distributed training
rank, world_size, local_rank = setup_ddp()
# Rank 0 is the only process that reads the CSV, prints progress, saves the model and runs the final evaluation.
is_main_process = (rank == 0)

# Load your dataset (only on main process to avoid duplicating work)
# Non-main ranks get None placeholders here and receive the real split through
# the broadcast step that follows.
if is_main_process:
    print("Loading data...")
    data_path = "/kaggle/input/updateddatasethackathon/trainDataUpdatedFlag.csv"
    df = pd.read_csv(data_path)
    print(f"Loaded {len(df)} rows of data")

    # Check for NaN values in the Quality Flag column and drop them
    nan_count = df["OE_Quality_Flag"].isna().sum()
    print(f"Found {nan_count} NaN values in 'OE_Quality_Flag' column")
    if nan_count > 0:
        original_count = len(df)
        df = df.dropna(subset=["OE_Quality_Flag"])
        print(f"Dropped {original_count - len(df)} rows with NaN values, {len(df)} rows remaining")

    # OPTIMIZATION 4: Sample a subset of the data for faster training
    # Limit to 2000 examples while preserving class distribution
    if len(df) > 2000:
        # Get indices for each class
        # Rows whose flag is neither 0 nor 1 land in neither list and are
        # therefore absent from the subsample.
        positive_indices = df[df["OE_Quality_Flag"] == 1].index.tolist()
        negative_indices = df[df["OE_Quality_Flag"] == 0].index.tolist()

        # Calculate how many samples to take from each class (maintain proportion)
        positive_ratio = len(positive_indices) / len(df)
        positive_count = int(2000 * positive_ratio)
        negative_count = 2000 - positive_count

        # Sample indices
        # min(...) caps each draw at the class size, so the result can hold fewer than 2000 rows.
        sampled_positive = np.random.choice(positive_indices, min(positive_count, len(positive_indices)), replace=False)
        sampled_negative = np.random.choice(negative_indices, min(negative_count, len(negative_indices)), replace=False)

        # Combine and filter the dataframe
        # Index labels (not positions) are used, hence .loc; the index is renumbered afterwards.
        sampled_indices = np.concatenate([sampled_positive, sampled_negative])
        df = df.loc[sampled_indices].reset_index(drop=True)
        print(f"Sampled down to {len(df)} examples for faster training")

    # Function to convert a row to text (excluding specified columns)
    def row_to_text(row, exclude_columns=["OE_Quality_Flag", "Unique ID", "Start Date", "End Date"]):
        """Render one row as newline-separated "column: value" lines.

        Columns are taken from the enclosing `df` (not from `row`); any column
        named in `exclude_columns` is skipped. The default list is only read,
        never mutated.
        """
        details = []
        for col in df.columns:
            if col in exclude_columns:
                continue
            details.append(f"{col}: {row[col]}")
        return "\n".join(details)

    # Prepare dataset texts and labels
    print("Preparing dataset...")
    texts = df.apply(row_to_text, axis=1).tolist()
    labels = df["OE_Quality_Flag"].astype(int).tolist()

    # Class distribution
    # sum(labels) counts positives because labels are the integers 0/1 at this point.
    positive_samples = sum(labels)
    negative_samples = len(labels) - positive_samples
    print(f"Class distribution: Positive={positive_samples} ({positive_samples/len(labels):.2%}), Negative={negative_samples} ({negative_samples/len(labels):.2%})")

    # Split data into training and validation sets
    # 80/20 split, stratified on the label and seeded for repeatability.
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels, test_size=0.2, random_state=seed, stratify=labels
    )
    print(f"Data split: {len(train_texts)} training samples, {len(val_texts)} validation samples")
else:
    train_texts, val_texts, train_labels, val_labels = None, None, None, None

# Share the split prepared on rank 0 with every other rank (multi-GPU runs only)
if world_size > 1:
    # Only rank 0 holds real data; every other rank contributes an empty slot.
    data_package = None
    if rank == 0:
        data_package = dict(
            train_texts=train_texts,
            val_texts=val_texts,
            train_labels=train_labels,
            val_labels=val_labels,
        )

    # The object travels inside a one-element list, which is read back after the call.
    data_package = [data_package]
    torch.distributed.broadcast_object_list(data_package, src=0)
    (data_package,) = data_package

    # Unpack in the same order the names are used below.
    train_texts, val_texts, train_labels, val_labels = (
        data_package[field]
        for field in ('train_texts', 'val_texts', 'train_labels', 'val_labels')
    )

# Load model and tokenizer using Unsloth (set device appropriately)
# With more than one process, the whole model is mapped onto this process's
# own GPU index (`{"": local_rank}`); otherwise placement is left to "auto".
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    load_in_4bit=load_in_4bit,
    device_map={"": local_rank} if world_size > 1 else "auto",
)

# OPTIMIZATION 1: Set up LoRA fine-tuning with reduced parameters
# Only the q/k/v/o projection modules receive adapters; `random_state` reuses the global seed.
model = FastLanguageModel.get_peft_model(
    model,
    r=8,  # Reduced from 16 to 8
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Reduced target modules
    lora_alpha=16,  # Reduced from 32 to 16
    lora_dropout=0.05,
    bias="none",
    use_gradient_checkpointing="unsloth",  # Uses less VRAM
    random_state=seed,
)

# Apply the Gemma chat template
# The returned tokenizer replaces the original one and is what later
# `apply_chat_template` calls (formatting and prediction) go through.
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
    tokenizer,
    chat_template="gemma",  # Use Gemma's chat template
)

# Prepare the conversations for Gemma format
# Each sample becomes a two-turn conversation: the user turn carries the
# rendered survey row, the assistant turn carries the expected answer.
dataset_list = []
for text, label in zip(train_texts, train_labels):
    conversation = [
        {"role": "user", "content": f"Analyze the following survey response and determine if it is a quality response.\n\n{text}"},
        {"role": "assistant", "content": f"The quality flag for this survey is: {label}"}
    ]
    dataset_list.append({"conversations": conversation})

# The validation conversations must come from the held-out split. Iterating the
# training split here would make eval_loss, early stopping and best-model
# selection all run on data the model is being trained on.
# NOTE(review): the validation answer carries an extra explanatory clause that
# the training answer lacks, so eval_loss scores a longer target than the one
# trained on — confirm this difference is intended.
val_dataset_list = []
for text, label in zip(val_texts, val_labels):
    quality_text = "not flagged (good quality)" if label == 0 else "flagged (poor quality)"
    conversation = [
        {"role": "user", "content": f"Analyze the following survey response and determine if it is a quality response.\n\n{text}"},
        {"role": "assistant", "content": f"The quality flag for this survey is: {label} which means this response is {quality_text}."}
    ]
    val_dataset_list.append({"conversations": conversation})

# Create datasets
train_dataset = Dataset.from_list(dataset_list)
val_dataset = Dataset.from_list(val_dataset_list)

# Standardize format
from unsloth.chat_templates import standardize_sharegpt
train_dataset = standardize_sharegpt(train_dataset)
val_dataset = standardize_sharegpt(val_dataset)

# Create formatting function
def formatting_prompts_func(examples):
    """Turn a batch of conversations into chat-template strings.

    Args:
        examples: Batch mapping with a "conversations" entry holding one
            message list per sample.

    Returns:
        dict: ``{"text": [...]}`` with one rendered string per conversation,
        produced without tokenizing and without a trailing generation prompt.
    """
    rendered = []
    for convo in examples["conversations"]:
        rendered.append(
            tokenizer.apply_chat_template(
                convo, tokenize=False, add_generation_prompt=False
            )
        )
    return {"text": rendered}

# Render both splits to text (OPTIMIZATION 4: batched, 4 worker processes)
train_dataset = train_dataset.map(formatting_prompts_func, batched=True, num_proc=4)
val_dataset = val_dataset.map(formatting_prompts_func, batched=True, num_proc=4)

# Show the first 500 characters of one rendered sample as a sanity check (main process only)
if is_main_process:
    print("\nSample formatted instruction:")
    print(f'{train_dataset[0]["text"][:500]}...')

# Create a custom callback to print losses and evaluation metrics
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
from typing import Dict

class MetricsLoggingCallback(TrainerCallback):
    """Print losses and generation-based classification metrics during training.

    On every evaluation event the main process generates an answer for the
    first (up to 20) validation texts, parses a 0/1 flag out of each answer and
    prints accuracy, F1, precision, recall and the confusion matrix.

    Args:
        eval_dataset: Stored on the instance; not read by any method here.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``decode``.
        model: Model whose ``generate`` method produces the answers.
        val_texts: Validation texts, one string per sample.
        val_labels: Integer 0/1 labels aligned with ``val_texts``.
    """

    def __init__(self, eval_dataset, tokenizer, model, val_texts, val_labels):
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer
        self.model = model
        self.val_texts = val_texts
        self.val_labels = val_labels
        self.best_accuracy = 0.0
        self.best_f1 = 0.0

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print the training and/or validation loss found in ``logs`` (main process only)."""
        if dist.is_initialized() and dist.get_rank() != 0:
            return
        if logs is None:
            return
        if "loss" in logs:
            print(f"Step {state.global_step}: Training Loss = {logs['loss']:.4f}")
        if "eval_loss" in logs:
            print(f"Step {state.global_step}: Validation Loss = {logs['eval_loss']:.4f}")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Score generated predictions on a small validation subset (main process only).

        The computed scores are printed and, when ``metrics`` is provided, also
        written into it under ``eval_accuracy`` / ``eval_f1`` /
        ``eval_precision`` / ``eval_recall``.
        """
        if dist.is_initialized() and dist.get_rank() != 0:
            return

        print(f"\n===== Evaluation at step {state.global_step} =====")
        eval_subset = min(len(self.val_texts), 20)  # Start with a smaller subset
        if eval_subset == 0:
            # Nothing to score; the sklearn metrics are undefined on empty input.
            print("=" * 50)
            return

        predictions = [self.predict_sample(self.val_texts[i]) for i in range(eval_subset)]
        eval_labels = self.val_labels[:eval_subset]

        # zero_division=0 everywhere: a 20-sample subset can easily lack a class.
        accuracy = accuracy_score(eval_labels, predictions)
        f1 = f1_score(eval_labels, predictions, zero_division=0)
        precision = precision_score(eval_labels, predictions, zero_division=0)
        recall = recall_score(eval_labels, predictions, zero_division=0)
        # Fixed label order keeps the matrix 2x2 even if only one class occurs.
        conf_matrix = confusion_matrix(eval_labels, predictions, labels=[0, 1])

        # Update best scores
        self.best_accuracy = max(self.best_accuracy, accuracy)
        self.best_f1 = max(self.best_f1, f1)

        # Print metrics
        print(f"Evaluating on {eval_subset} samples:")
        print(f"Accuracy: {accuracy:.4f} (Best: {self.best_accuracy:.4f})")
        print(f"F1 Score: {f1:.4f} (Best: {self.best_f1:.4f})")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"Confusion Matrix:")
        print(conf_matrix)

        # Store metrics in logs for potential saving
        if metrics is not None:
            metrics["eval_accuracy"] = accuracy
            metrics["eval_f1"] = f1
            metrics["eval_precision"] = precision
            metrics["eval_recall"] = recall

        print("=" * 50)

    @staticmethod
    def _parse_prediction(generated_text):
        """Extract a 0/1 flag from the model's answer.

        Looks at the text after the last "quality flag" mention (or the whole
        answer if the phrase is absent) and returns the first "0" or "1"
        character found. Returns 0 when neither digit appears, so the result is
        always a valid binary label.
        """
        lowered = generated_text.lower()
        if "quality flag" in lowered:
            lowered = lowered.split("quality flag")[-1]
        for char in lowered:
            if char in "01":
                return int(char)
        return 0  # Default

    def predict_sample(self, text):
        """Generate an answer for one survey text and return the parsed 0/1 flag."""
        messages = [
            {"role": "user", "content": f"Analyze the following survey response and determine if it is a quality response.\n\n{text}"}
        ]
        inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(self.model.device)

        # Greedy decoding: a temperature has no effect unless sampling is
        # enabled, and evaluation should be deterministic anyway.
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs,
                max_new_tokens=20,
                do_sample=False,
                use_cache=True
            )

        # Decode only the newly generated tokens. The returned sequence starts
        # with the prompt, whose survey data usually contains "0"/"1" digits
        # that would otherwise be mistaken for the model's answer.
        prompt_length = inputs.shape[-1]
        generated_text = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        return self._parse_prediction(generated_text)

# OPTIMIZATION 6: Add early stopping callback
# The same instance is handed to the trainer on every rank (see the callback
# list built before the trainer is created).
from transformers import EarlyStoppingCallback

early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=3,  # evaluations without sufficient improvement before stopping
    early_stopping_threshold=0.01  # minimum change that counts as an improvement
)

# Define helper function to check bfloat16 support
def is_bfloat16_supported():
    """Return True only when a CUDA device is available and reports bfloat16 support.

    CUDA availability is checked first so the helper answers False on
    CPU-only hosts without querying a device. Always returns a plain ``bool``.
    """
    if not torch.cuda.is_available():
        return False
    probe = getattr(torch.cuda, "is_bf16_supported", None)
    if probe is None:
        # The installed torch build does not expose the capability query.
        return False
    return bool(probe())

# Set up training arguments with distributed training options
# Each process consumes per_device_train_batch_size * gradient_accumulation_steps
# = 16 * 4 = 64 samples per optimizer step.
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    fp16=False,  # Explicitly disable fp16
    bf16=is_bfloat16_supported(),  # Use bf16 if supported
    logging_steps=10,  # Log less frequently (changed from 1)
    logging_first_step=True,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    # NOTE(review): with at most 1600 training samples (80% of the 2000-row cap)
    # and 64 samples per step per process, the run may finish before step 100;
    # presumably no checkpoint would then exist for load_best_model_at_end — confirm.
    save_steps=100,
    eval_strategy="steps",  # Using new parameter name instead of deprecated evaluation_strategy
    eval_steps=20,  # Evaluate less frequently (changed from 5)
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,  # lower eval_loss is better
    do_eval=True,
    report_to="none",
    seed=seed,
    # Distributed training specific args
    local_rank=local_rank,
    ddp_find_unused_parameters=False,
    ddp_bucket_cap_mb=25,
    dataloader_pin_memory=False,
    # OPTIMIZATION 4: Add data loading optimizations
    dataloader_num_workers=4,  # Use multiple workers for data loading
)

# Assemble the trainer callbacks.
# OPTIMIZATION 6: early stopping is registered on all processes; the metric
# logger is placed in front of it on the main process only.
custom_callback = [early_stopping_callback]
if is_main_process:
    metrics_logger = MetricsLoggingCallback(
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        model=model,
        val_texts=val_texts,
        val_labels=val_labels,
    )
    custom_callback.insert(0, metrics_logger)

# Create trainer
# Trains on the pre-rendered "text" column produced by formatting_prompts_func;
# sequences are not packed together (packing=False).
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
    dataset_num_proc=4,  # OPTIMIZATION 4: Use multiprocessing
    packing=False,
    args=training_args,
    callbacks=custom_callback,
)

# Set up to train on assistant responses only
# The two marker strings tell the helper where a user turn and a model turn
# begin in the rendered text.
# NOTE(review): verify the markers against an actual rendered sample (the
# "Sample formatted instruction" printout above shows one).
from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    # Adjust these based on actual template inspection
    instruction_part="<start_of_turn>user",
    response_part="<start_of_turn>model",
)

# Train the model
# trainer.train() runs on every rank; only the surrounding messages are rank-0 only.
if is_main_process:
    print("\nStarting training...")
trainer_stats = trainer.train()

# Print final training metrics (only on main process)
# Everything inside this branch (saving and the final evaluation) is done by rank 0 alone.
if is_main_process:
    print("\nTraining completed!")
    print(f"Final training loss: {trainer_stats.training_loss}")

    # Convert model to inference mode
    FastLanguageModel.for_inference(model)

    # Save the fine-tuned model
    # The model and the tokenizer are written to the same output directory.
    print("\nSaving model...")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Model saved to {output_dir}")

    # Comprehensive evaluation function to evaluate on the full validation set
    def evaluate_model_full():
        """Generate a prediction for every validation text and report metrics.

        Prints accuracy, F1, precision, recall, specificity and the confusion
        matrix, shows the first few predictions, and writes one row per sample
        (true label, predicted label, generated answer) to
        ``validation_results.csv`` inside ``output_dir``.

        Returns:
            dict: keys ``accuracy``, ``f1``, ``precision``, ``recall``,
            ``specificity`` and ``confusion_matrix`` (2x2, label order [0, 1]).
        """
        print("\n===== FINAL MODEL EVALUATION =====")

        # Set model to evaluation mode
        model.eval()

        def parse_prediction(generated_text):
            """Return the first 0/1 digit after the last "quality flag" mention, else 0."""
            lowered = generated_text.lower()
            if "quality flag" in lowered:
                lowered = lowered.split("quality flag")[-1]
            for char in lowered:
                if char in "01":
                    return int(char)
            return 0

        # Function to predict
        def predict(text):
            """Return ``(prediction, generated_text)`` for one survey text."""
            messages = [
                {"role": "user", "content": f"Analyze the following survey response and determine if it is a quality response.\n\n{text}"}
            ]
            inputs = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(model.device)

            # Greedy decoding: a temperature has no effect unless sampling is
            # enabled, and evaluation should be deterministic anyway.
            with torch.no_grad():
                outputs = model.generate(
                    input_ids=inputs,
                    max_new_tokens=20,
                    do_sample=False,
                    use_cache=True
                )

            # Decode only the newly generated tokens. The returned sequence
            # starts with the prompt, whose survey data usually contains
            # "0"/"1" digits that would otherwise be read as the answer.
            prompt_length = inputs.shape[-1]
            generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

            return parse_prediction(generated_text), generated_text

        # Evaluate on full validation set
        print(f"Evaluating on all {len(val_texts)} validation samples...")
        val_predictions = []
        all_generated_texts = []

        for idx, text in enumerate(val_texts):
            pred, gen_text = predict(text)
            val_predictions.append(pred)
            all_generated_texts.append(gen_text)

            # Print progress every 5 samples
            if (idx + 1) % 5 == 0 or idx == len(val_texts) - 1:
                print(f"Processed {idx + 1}/{len(val_texts)} samples...")

        # Calculate metrics
        accuracy = accuracy_score(val_labels, val_predictions)
        f1 = f1_score(val_labels, val_predictions, zero_division=0)
        precision = precision_score(val_labels, val_predictions, zero_division=0)
        recall = recall_score(val_labels, val_predictions, zero_division=0)
        # Fixed label order guarantees a 2x2 matrix, which the four-way unpack
        # below requires even when only one class shows up.
        conf_matrix = confusion_matrix(val_labels, val_predictions, labels=[0, 1])

        # Print detailed metrics
        print("\n===== FINAL EVALUATION RESULTS =====")
        print(f"Accuracy: {accuracy:.4f}")
        print(f"F1 Score: {f1:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"Confusion Matrix:")
        print(conf_matrix)

        # Calculate class-specific metrics
        tn, fp, fn, tp = conf_matrix.ravel()
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

        print(f"\nTrue Positives: {tp}")
        print(f"True Negatives: {tn}")
        print(f"False Positives: {fp}")
        print(f"False Negatives: {fn}")
        print(f"Specificity (True Negative Rate): {specificity:.4f}")

        # Print some example predictions
        print("\n===== EXAMPLE PREDICTIONS =====")
        for i in range(min(5, len(val_texts))):
            print(f"\nSample {i+1}:")
            print(f"True label: {val_labels[i]}")
            print(f"Predicted: {val_predictions[i]}")
            print(f"Generated text: {all_generated_texts[i][:100]}...")

        # Save predictions to CSV for further analysis
        results_df = pd.DataFrame({
            "true_label": val_labels,
            "predicted_label": val_predictions,
            "generated_text": all_generated_texts
        })

        # Make sure the target directory exists before writing into it.
        os.makedirs(output_dir, exist_ok=True)
        results_path = os.path.join(output_dir, "validation_results.csv")
        results_df.to_csv(results_path, index=False)
        print(f"\nDetailed validation results saved to {results_path}")

        return {
            "accuracy": accuracy,
            "f1": f1,
            "precision": precision,
            "recall": recall,
            "specificity": specificity,
            "confusion_matrix": conf_matrix
        }

    # Run the comprehensive evaluation
    # Still inside the main-process branch: only rank 0 evaluates.
    print("\nRunning final model evaluation...")
    final_metrics = evaluate_model_full()

    print("\nFine-tuning and evaluation complete!")

# Cleanup distributed process group if initialized
# Runs on every rank; non-main ranks skip the evaluation above and arrive here directly after training.
if world_size > 1:
    dist.destroy_process_group()


Writing my_script.py


In [3]:
!export CUDA_VISIBLE_DEVICES=0,1

In [4]:
!python -m torch.distributed.launch --nproc_per_node=2 --master_port=29500 my_script.py  # Replace with your actual Python script name

and will be removed in future. Use torchrun.
Note that --use-env is set by default in torchrun.
If your script expects `--local-rank` argument to be set, please
change it to read from `os.environ['LOCAL_RANK']` instead. See 
https://pytorch.org/docs/stable/distributed.html#launch-utility for 
further instructions

  main()
W0310 09:39:55.905000 62 torch/distributed/run.py:792] 
W0310 09:39:55.905000 62 torch/distributed/run.py:792] *****************************************
W0310 09:39:55.905000 62 torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0310 09:39:55.905000 62 torch/distributed/run.py:792] *****************************************
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
