In [None]:
!pip install bitsandbytes
!pip install git+https://github.com/huggingface/transformers.git
!pip install git+https://github.com/huggingface/peft.git
!pip install git+https://github.com/huggingface/accelerate.git
!pip install datasets

Collecting bitsandbytes
  Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-c

In [None]:
import torch
import pandas as pd
import numpy as np
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    BitsAndBytesConfig, DataCollatorForLanguageModeling,
    TrainingArguments, Trainer
)
from peft import (
    prepare_model_for_kbit_training,
    LoraConfig,
    get_peft_model
)
from huggingface_hub import login
from google.colab import drive
import os
from sklearn.model_selection import train_test_split

In [None]:
# Login and mount drive
login()
drive.mount('/content/drive')

# Model configuration
hf_model = "epfl-llm/meditron-7b"

Mounted at /content/drive


In [None]:
# Enhanced quantization config for better performance
bits_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(hf_model, trust_remote_code=True, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    hf_model,
    quantization_config=bits_config,
    device_map="auto",
    trust_remote_code=True,
)
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

# Enhanced LoRA configuration for medical domain
lora_config = LoraConfig(
    r=32,  # Increased rank for better capacity
    lora_alpha=64,  # Increased alpha for stronger adaptation
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],  # More comprehensive targeting
    lora_dropout=0.1,  # Slightly higher dropout for regularization
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.config.use_cache = False

def print_trainable_parameters(model):
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")

print_trainable_parameters(model)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/4.08k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.85M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/344 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/736 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/610 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 8 files:   0%|          | 0/8 [00:00<?, ?it/s]

model-00004-of-00008.safetensors:   0%|          | 0.00/1.92G [00:00<?, ?B/s]

model-00002-of-00008.safetensors:   0%|          | 0.00/1.90G [00:00<?, ?B/s]

model-00003-of-00008.safetensors:   0%|          | 0.00/1.84G [00:00<?, ?B/s]

model-00001-of-00008.safetensors:   0%|          | 0.00/1.91G [00:00<?, ?B/s]

model-00007-of-00008.safetensors:   0%|          | 0.00/1.89G [00:00<?, ?B/s]

model-00008-of-00008.safetensors:   0%|          | 0.00/262M [00:00<?, ?B/s]

model-00006-of-00008.safetensors:   0%|          | 0.00/1.84G [00:00<?, ?B/s]

model-00005-of-00008.safetensors:   0%|          | 0.00/1.90G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

Trainable: 79,953,920 / 3,580,506,112 (2.23%)


In [None]:
# Load your processed FAERS dataset
df = pd.read_csv("/content/drive/MyDrive/custom_training_data.csv")
print(f"Loaded dataset with {len(df)} examples")

# Enhanced data preprocessing
def create_enhanced_training_text(row):
    """Create enhanced training text with better structure"""
    input_text = row['input']
    output_text = row['output']

    # Parse patient info if available
    age_info = ""
    sex_info = ""

    if 'AGE_' in input_text:
        age_match = input_text.split('AGE_')[1].split()[0]
        try:
            age_years = float(age_match) * 100  # Convert normalized age back
            age_info = f"Age: {age_years:.0f} years"
        except:
            pass

    if 'SEX_M' in input_text:
        sex_info = "Sex: Male"
    elif 'SEX_F' in input_text:
        sex_info = "Sex: Female"

    # Clean medication list
    medications = []
    med_parts = input_text.split()
    for part in med_parts:
        if '_PS' in part:
            med_name = part.replace('_PS', '')
            medications.append(f"{med_name} (Primary Suspect)")
        elif '_SS' in part:
            med_name = part.replace('_SS', '')
            medications.append(f"{med_name} (Secondary Suspect)")
        elif '_C' in part and not part.startswith('AGE_') and not part.startswith('SEX_'):
            med_name = part.replace('_C', '')
            medications.append(f"{med_name} (Concomitant)")
        elif '_I' in part:
            medications.append("Drug Interaction Potential")

    patient_info = " | ".join(filter(None, [age_info, sex_info]))
    medications_text = ", ".join(medications) if medications else "No medications specified"

    return f"""### Instruction:
You are a clinical pharmacovigilance AI. Predict adverse drug reactions based on patient demographics and medications with their clinical significance levels.

Medication significance:
- Primary Suspect (PS): Most likely causative drug
- Secondary Suspect (SS): Possibly causative drug
- Concomitant (C): Concurrent medication
- Drug Interaction: Potential interaction between drugs

Classify reactions by severity:
- Severe: Life-threatening or serious conditions
- Moderate: Significant symptoms requiring medical attention
- Mild: Minor symptoms or laboratory abnormalities

### Input:
Patient Information: {patient_info if patient_info else "Not specified"}
Medications: {medications_text}

### Output:
{output_text}"""

# Apply enhanced preprocessing
df['text'] = df.apply(create_enhanced_training_text, axis=1)

# ADVANCED SAMPLING STRATEGIES - Choose based on your priorities

def stratified_severity_sample(df, sample_size=25000):
    """Strategy 1: Balanced severity sampling (recommended for medical AI)"""
    severe_examples = df[df['output'].str.contains('Severe:', case=False, na=False)]
    moderate_examples = df[df['output'].str.contains('Moderate:', case=False, na=False)]
    mild_examples = df[df['output'].str.contains('Mild:', case=False, na=False)]

    print(f"üìä Original distribution:")
    print(f"   Severe: {len(severe_examples):,} ({len(severe_examples)/len(df)*100:.1f}%)")
    print(f"   Moderate: {len(moderate_examples):,} ({len(moderate_examples)/len(df)*100:.1f}%)")
    print(f"   Mild: {len(mild_examples):,} ({len(mild_examples)/len(df)*100:.1f}%)")

    # Clinical priority sampling: Over-represent severe cases
    severe_sample = min(len(severe_examples), max(int(sample_size * 0.25), 2000))
    moderate_sample = min(len(moderate_examples), max(int(sample_size * 0.35), 5000))
    mild_sample = min(len(mild_examples), sample_size - severe_sample - moderate_sample)

    print(f"üéØ Sampling strategy (Clinical Priority):")
    print(f"   Severe: {severe_sample:,} ({severe_sample/sample_size*100:.1f}%)")
    print(f"   Moderate: {moderate_sample:,} ({moderate_sample/sample_size*100:.1f}%)")
    print(f"   Mild: {mild_sample:,} ({mild_sample/sample_size*100:.1f}%)")

    return pd.concat([
        severe_examples.sample(n=severe_sample, random_state=42) if len(severe_examples) >= severe_sample else severe_examples,
        moderate_examples.sample(n=moderate_sample, random_state=42) if len(moderate_examples) >= moderate_sample else moderate_examples,
        mild_examples.sample(n=mild_sample, random_state=42) if len(mild_examples) >= mild_sample else mild_examples
    ]).reset_index(drop=True)

def drug_significance_sample(df, sample_size=25000):
    """Strategy 2: Sample based on drug significance (PS > SS > C)"""
    ps_examples = df[df['input'].str.contains('_PS', case=False, na=False)]
    ss_examples = df[df['input'].str.contains('_SS', case=False, na=False)]
    c_examples = df[df['input'].str.contains('_C', case=False, na=False)]

    print(f"üìä Drug significance distribution:")
    print(f"   Primary Suspect cases: {len(ps_examples):,}")
    print(f"   Secondary Suspect cases: {len(ss_examples):,}")
    print(f"   Concomitant only cases: {len(c_examples):,}")

    # Prioritize cases with Primary/Secondary Suspects
    ps_sample = min(len(ps_examples), int(sample_size * 0.4))
    ss_sample = min(len(ss_examples), int(sample_size * 0.4))
    c_sample = min(len(c_examples), sample_size - ps_sample - ss_sample)

    return pd.concat([
        ps_examples.sample(n=ps_sample, random_state=42) if len(ps_examples) >= ps_sample else ps_examples,
        ss_examples.sample(n=ss_sample, random_state=42) if len(ss_examples) >= ss_sample else ss_examples,
        c_examples.sample(n=c_sample, random_state=42) if len(c_examples) >= c_sample else c_examples
    ]).reset_index(drop=True)

def comprehensive_sample(df, sample_size=25000):
    """Strategy 3: Multi-criteria sampling (severity + drug significance + demographics)"""

    # First, separate by severity
    severe_df = df[df['output'].str.contains('Severe:', case=False, na=False)]
    moderate_df = df[df['output'].str.contains('Moderate:', case=False, na=False)]
    mild_df = df[df['output'].str.contains('Mild:', case=False, na=False)]

    # Within each severity, sample by drug significance
    def sample_by_drug_sig(subset_df, n_samples):
        if len(subset_df) <= n_samples:
            return subset_df

        ps_subset = subset_df[subset_df['input'].str.contains('_PS')]
        ss_subset = subset_df[subset_df['input'].str.contains('_SS')]
        c_subset = subset_df[subset_df['input'].str.contains('_C')]

        ps_n = min(len(ps_subset), int(n_samples * 0.4))
        ss_n = min(len(ss_subset), int(n_samples * 0.4))
        c_n = min(len(c_subset), n_samples - ps_n - ss_n)

        return pd.concat([
            ps_subset.sample(n=ps_n, random_state=42) if ps_n > 0 else pd.DataFrame(),
            ss_subset.sample(n=ss_n, random_state=42) if ss_n > 0 else pd.DataFrame(),
            c_subset.sample(n=c_n, random_state=42) if c_n > 0 else pd.DataFrame()
        ])

    # Sample each severity group
    severe_sample = sample_by_drug_sig(severe_df, int(sample_size * 0.25))
    moderate_sample = sample_by_drug_sig(moderate_df, int(sample_size * 0.35))
    mild_sample = sample_by_drug_sig(mild_df, sample_size - len(severe_sample) - len(moderate_sample))

    print(f"üéØ Comprehensive sampling:")
    print(f"   Severe (multi-criteria): {len(severe_sample):,}")
    print(f"   Moderate (multi-criteria): {len(moderate_sample):,}")
    print(f"   Mild (multi-criteria): {len(mild_sample):,}")

    return pd.concat([severe_sample, moderate_sample, mild_sample]).reset_index(drop=True)

# CHOOSE YOUR SAMPLING STRATEGY:
print("üîç Available sampling strategies:")
print("1. stratified_severity_sample() - Balances severe/moderate/mild (recommended)")
print("2. drug_significance_sample() - Prioritizes PS/SS cases")
print("3. comprehensive_sample() - Multi-criteria approach")
print()
TRAINING_SIZE = 25000

# Use the recommended strategy
df_sampled = stratified_severity_sample(df, TRAINING_SIZE)

# Smart sampling strategy based on dataset size
print(f"Full dataset size: {len(df):,} examples")

# Option 1: Conservative sampling for faster training
  # Adjust this based on your compute budget

# Option 2: Progressive training sizes (uncomment to use)
# TRAINING_SIZE = 15000  # Quick prototype
# TRAINING_SIZE = 30000  # Balanced training
# TRAINING_SIZE = 50000  # Full training (longer)
df_sampled = stratified_severity_sample(df, TRAINING_SIZE)
print(f"Using {len(df_sampled):,} examples for training ({len(df_sampled)/len(df)*100:.1f}% of full dataset)")

# Create stratified train/val/test splits
train_df, temp_df = train_test_split(df_sampled, test_size=0.3, random_state=42, stratify=df_sampled['output'].str.contains('Severe:', case=False, na=False))
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

print(f"Train: {len(train_df)}, Validation: {len(val_df)}, Test: {len(test_df)}")

# Convert to datasets
train_dataset = Dataset.from_pandas(train_df[["text"]])
val_dataset = Dataset.from_pandas(val_df[["text"]])
test_dataset = Dataset.from_pandas(test_df[["text"]])

# Save test set
test_df.to_csv("/content/drive/MyDrive/meditron7b_faers_test_set.csv", index=False)

# Enhanced tokenization with dynamic length
def get_optimal_max_length(texts, tokenizer, percentile=95):
    """Determine optimal max length based on data distribution"""
    lengths = [len(tokenizer.encode(text)) for text in texts[:1000]]  # Sample for efficiency
    optimal_length = int(np.percentile(lengths, percentile))
    return min(optimal_length, 1024)  # Cap at reasonable limit

max_length = get_optimal_max_length(train_df['text'].tolist(), tokenizer)
print(f"Using max_length: {max_length}")

def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt"
    )

Loaded dataset with 215770 examples
üîç Available sampling strategies:
1. stratified_severity_sample() - Balances severe/moderate/mild (recommended)
2. drug_significance_sample() - Prioritizes PS/SS cases
3. comprehensive_sample() - Multi-criteria approach

üìä Original distribution:
   Severe: 21,839 (10.1%)
   Moderate: 61,752 (28.6%)
   Mild: 198,912 (92.2%)
üéØ Sampling strategy (Clinical Priority):
   Severe: 6,250 (25.0%)
   Moderate: 8,750 (35.0%)
   Mild: 10,000 (40.0%)
Full dataset size: 215,770 examples
üìä Original distribution:
   Severe: 21,839 (10.1%)
   Moderate: 61,752 (28.6%)
   Mild: 198,912 (92.2%)
üéØ Sampling strategy (Clinical Priority):
   Severe: 6,250 (25.0%)
   Moderate: 8,750 (35.0%)
   Mild: 10,000 (40.0%)
Using 25,000 examples for training (11.6% of full dataset)
Train: 17500, Validation: 3750, Test: 3750
Using max_length: 458


In [None]:
# Tokenize datasets
tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_val = val_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Enhanced data collator
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Optimized training arguments for resuming from checkpoint
output_dir = "/content/drive/MyDrive/meditron7b_faers_efficient"
checkpoint_path = os.path.join(output_dir, "checkpoint-350")
os.makedirs(output_dir, exist_ok=True)

# FAST RESUME CONFIGURATION - Optimized for speed
print(f"üöÄ Fast Resume Training Configuration:")
print(f"   Resuming from: checkpoint-250")
print(f"   Dataset size: {len(df_sampled):,} examples")
print(f"   Remaining epochs: Optimized for quick completion")
print(f"   Focus: Speed over extensive training")

training_args = TrainingArguments(
    output_dir=output_dir,

    # SPEED OPTIMIZATIONS
    per_device_train_batch_size=8,  # Increased for faster processing
    gradient_accumulation_steps=2,  # Reduced for faster updates
    learning_rate=1e-4,  # Slightly higher for faster convergence
    weight_decay=0.01,
    warmup_steps=50,  # Minimal warmup since resuming
    lr_scheduler_type="linear",  # Faster than cosine

    # MINIMAL EPOCHS for quick completion
    num_train_epochs=1,  # Complete remaining training quickly
    max_steps=500,  # Hard limit - will stop at 500 steps from resume point

    # AGGRESSIVE CHECKPOINTING for safety
    logging_steps=25,  # More frequent logging
    save_steps=500,  # More frequent saves
    eval_steps=500,  # More frequent evaluation
    eval_strategy="steps",
    save_strategy="steps",
    save_total_limit=3,  # Keep more checkpoints for safety

    # Model selection
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    # MAXIMUM SPEED OPTIMIZATIONS
    fp16=True,  # Mixed precision for speed
    dataloader_num_workers=4,  # More workers for data loading
    dataloader_pin_memory=True,
    group_by_length=True,  # Group similar lengths for efficiency

    # Hub settings - disabled for speed
    push_to_hub=False,  # Disable to save time
    report_to=[],  # No logging to external services
    seed=42,

    # Memory and compute optimizations
    remove_unused_columns=True,
    ddp_find_unused_parameters=False,
    prediction_loss_only=True,  # Faster evaluation

    # Skip time-consuming operations
    skip_memory_metrics=True,
)

# Minimal trainer setup for speed
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
    data_collator=data_collator,
    # Removed early stopping for faster completion
)

# Check if checkpoint exists
if os.path.exists(checkpoint_path):
    print(f"‚úÖ Found checkpoint at {checkpoint_path}")
    print("üèÉ‚Äç‚ôÇÔ∏è Starting FAST resume training...")
    trainer.train(resume_from_checkpoint=checkpoint_path)
else:
    print(f"‚ùå Checkpoint not found at {checkpoint_path}")
    print("üèÉ‚Äç‚ôÇÔ∏è Starting fresh training...")
    trainer.train()

print("üéâ Training completed!")

# Quick save of final model
final_model_path = os.path.join(output_dir, "final_model")
trainer.save_model(final_model_path)
print(f"üíæ Model saved to {final_model_path}")

Map:   0%|          | 0/17500 [00:00<?, ? examples/s]

Map:   0%|          | 0/3750 [00:00<?, ? examples/s]

  trainer = Trainer(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


üöÄ Fast Resume Training Configuration:
   Resuming from: checkpoint-250
   Dataset size: 25,000 examples
   Remaining epochs: Optimized for quick completion
   Focus: Speed over extensive training
‚úÖ Found checkpoint at /content/drive/MyDrive/meditron7b_faers_efficient/checkpoint-350
üèÉ‚Äç‚ôÇÔ∏è Starting FAST resume training...


	logging_steps: 25 (from args) != 50 (from trainer_state.json)
	eval_steps: 500 (from args) != 50 (from trainer_state.json)
	save_steps: 500 (from args) != 50 (from trainer_state.json)
	per_device_train_batch_size: 8 (from args) != 4 (from trainer_state.json)
  return fn(*args, **kwargs)


Step,Training Loss,Validation Loss
400,0.2208,0.211918
450,0.2089,0.2081


  return fn(*args, **kwargs)
  return fn(*args, **kwargs)


Step,Training Loss,Validation Loss
400,0.2208,0.211918
450,0.2089,0.2081
500,0.2069,0.205595




üéâ Training completed!
üíæ Model saved to /content/drive/MyDrive/meditron7b_faers_efficient/final_model


In [None]:
# Save final model
trainer.save_model(f"{output_dir}/final_model")
tokenizer.save_pretrained(f"{output_dir}/final_model")

# Enhanced inference function
def generate_adverse_reactions(patient_info, medications, model=model, tokenizer=tokenizer):
    """Generate adverse reaction predictions with enhanced formatting"""

    prompt = f"""### Instruction:
You are a clinical pharmacovigilance AI. Predict adverse drug reactions based on patient demographics and medications with their clinical significance levels.

Medication significance:
- Primary Suspect (PS): Most likely causative drug
- Secondary Suspect (SS): Possibly causative drug
- Concomitant (C): Concurrent medication
- Drug Interaction: Potential interaction between drugs

Classify reactions by severity:
- Severe: Life-threatening or serious conditions
- Moderate: Significant symptoms requiring medical attention
- Mild: Minor symptoms or laboratory abnormalities

### Input:
Patient Information: {patient_info}
Medications: {medications}

### Output:
"""

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract only the output portion
    if "### Output:" in response:
        response = response.split("### Output:")[-1].strip()

    return response

# Test the enhanced model
test_example = generate_adverse_reactions(
    patient_info="Age: 65 years | Sex: Male",
    medications="WARFARIN (Primary Suspect), ASPIRIN (Concomitant), METFORMIN (Concomitant)"
)

print("Test Example Output:")
print(test_example)

print("\nüéâ Enhanced training complete!")
print(f"üìä Training Statistics:")
print(f"   - Total examples trained: {len(train_df):,}")
print(f"   - Validation examples: {len(val_df):,}")
print(f"   - Test examples: {len(test_df):,}")
print(f"   - Max sequence length: {max_length}")
print(f"   - Model saved to: {output_dir}")