In [1]:
!pip install -q \
    "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git" \
    datasets \
    evaluate \
    rouge-score \
    peft \
    accelerate \
    bitsandbytes \
    transformers \
    trl \
    gradio \
    xformers \
    scikit-learn

import os
import platform
import torch
import gc
import warnings
warnings.filterwarnings('ignore')

if torch.cuda.is_available():
    device = "cuda"
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    
    print(f"  GPU: {gpu_name}")
    print(f"  VRAM: {gpu_memory_gb:.1f} GB")
    print(f"  CUDA: {torch.version.cuda}")
    
    torch.backends.cudnn.benchmark = True
    gc.collect()
    torch.cuda.empty_cache()
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
    torch.cuda.memory.set_per_process_memory_fraction(0.95)
    
    print(f"  Available VRAM: {torch.cuda.mem_get_info()[0] / 1e9:.1f} GB")
else:
    device = "cpu"
    print(" No GPU detected - check Runtime > Change runtime type")

print(f"PyTorch: {torch.__version__}\n")

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0mm
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m122.9/122.9 MB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m899.7/899.7 MB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m594.3/594.3 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m154.1 MB/s[0m eta 

In [3]:
import os

hf_token = os.getenv("HF_TOKEN")
wandb_key = os.getenv("WANDB_API_KEY")

print("HF_TOKEN set:", bool(hf_token))
print("WANDB_API_KEY set:", bool(wandb_key))

if hf_token:
    from huggingface_hub import login
    login(token=hf_token)

if wandb_key:
    import wandb
    wandb.login(key=wandb_key)


HF_TOKEN set: True
WANDB_API_KEY set: True


Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmanivarshithpc[0m ([33mmanivarshithpc-vignan-institute-of-technology-and-science[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
import random
import numpy as np
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
print(f"✓ Seed set to {SEED}")

✓ Seed set to 42


In [5]:
MODEL_NAME_BASE = "unsloth/DeepSeek-R1-Distill-Llama-8B-bnb-4bit"  

HF_REPO_ID = os.getenv("HF_REPO_ID", "varshith7/deepseek-medical-cot")
WANDB_PROJECT = os.getenv("WANDB_PROJECT", "deepseek-medical-cot")


hyperparameters = {
    # Training schedule
    "num_train_epochs": 3,  
    "train_batch_size": 1,  
    "gradient_accumulation_steps": 8,  
    
    
    "learning_rate": 2e-5,  
    "lr_scheduler_type": "cosine", 
    "warmup_ratio": 0.1,  
    
    # Sequence and optimization
    "max_length": 1024, 
    "weight_decay": 0.01,
    "max_grad_norm": 0.3,  
    
    # Logging and evaluation
    "logging_steps": 5,
    "eval_steps": 50,  
    "save_steps": 100,
    "evaluation_strategy": "steps",
    "save_total_limit": 2,
}

print("="*60)
print(f"Model: {MODEL_NAME_BASE}")
print(f"HF Repo: {HF_REPO_ID}")
print(f"W&B Project: {WANDB_PROJECT}")
print(f"Effective Batch Size: {hyperparameters['train_batch_size'] * hyperparameters['gradient_accumulation_steps']}")
print("="*60)

Model: unsloth/DeepSeek-R1-Distill-Llama-8B-bnb-4bit
HF Repo: varshith7/deepseek-medical-cot
W&B Project: deepseek-medical-cot
Effective Batch Size: 8


In [6]:
from datasets import load_dataset, concatenate_datasets

print("Loading PubMedQA dataset...")
raw_dataset = load_dataset("pubmed_qa", "pqa_labeled")

# Create larger validation set for better evaluation
split_dataset = raw_dataset["train"].train_test_split(test_size=200, seed=SEED)
train_dataset = split_dataset["train"]
val_dataset = split_dataset["test"]

# Enhanced system prompt for medical reasoning
system_prompt = (
    "You are an expert medical AI assistant. Analyze the clinical information carefully, "
    "apply evidence-based reasoning, and provide a clear step-by-step explanation before "
    "giving your final answer."
)

def format_sample_enhanced(example):
    """Enhanced formatting with better structure"""
    context = example.get("context", {})
    question = example.get("question", "")
    long_answer = example.get("long_answer", "")
    
    # Join context if it's a dict with multiple sections
    if isinstance(context, dict):
        context_text = " ".join([f"{k}: {v}" for k, v in context.items()])
    else:
        context_text = str(context)
    
    # DeepSeek-R1 style formatting with chain-of-thought
    formatted_text = f"""### Instruction:
{system_prompt}

### Input:
Context: {context_text}
Question: {question}

### Response:
<think>
Let me analyze this medical question step by step:
1. Understanding the context and key information
2. Applying clinical reasoning
3. Formulating the answer based on evidence
</think>

<answer>
{long_answer}
</answer>"""
    
    return {
        "text": formatted_text,
        "question": question,
        "answer": long_answer
    }

print("Formatting dataset...")
train_dataset = train_dataset.map(
    format_sample_enhanced, 
    remove_columns=train_dataset.column_names,
    num_proc=2  # Parallel processing
)
val_dataset = val_dataset.map(
    format_sample_enhanced, 
    remove_columns=val_dataset.column_names,
    num_proc=2
)

print(f"✓ Train: {len(train_dataset)} | Validation: {len(val_dataset)}")
print("\nSample formatted text:")
print(train_dataset[0]["text"][:500] + "...")

Loading PubMedQA dataset...


README.md: 0.00B [00:00, ?B/s]

pqa_labeled/train-00000-of-00001.parquet:   0%|          | 0.00/1.08M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Formatting dataset...


Map (num_proc=2):   0%|          | 0/800 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/200 [00:00<?, ? examples/s]

✓ Train: 800 | Validation: 200

Sample formatted text:
### Instruction:
You are an expert medical AI assistant. Analyze the clinical information carefully, apply evidence-based reasoning, and provide a clear step-by-step explanation before giving your final answer.

### Input:
Context: contexts: ['Figures from the British Defence Dental Services reveal that serving personnel in the British Army have a persistently lower level of dental fitness than those in the Royal Navy or the Royal Air Force. No research had been undertaken to ascertain if this r...


In [7]:
from transformers import AutoTokenizer

print(f"\nLoading tokenizer from {MODEL_NAME_BASE}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_BASE, use_fast=True)


if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

print(f"  Tokenizer loaded")
print(f"  Vocab size: {len(tokenizer)}")
print(f"  PAD token: {tokenizer.pad_token}")


Loading tokenizer from unsloth/DeepSeek-R1-Distill-Llama-8B-bnb-4bit...


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

  Tokenizer loaded
  Vocab size: 128256
  PAD token: <|finetune_right_pad_id|>
