# 🚀 V5 Optimized Training - Qwen2.5-0.5B

## Improvements:
- ✅ **3 epochs** (vs. 1 epoch in v4)
- ✅ **Lower LR** (5e-6 vs. 2e-5): slower but more stable learning
- ✅ **Cosine scheduler** with a longer warmup
- ✅ **LoRA rank 64** (vs. 32): higher model capacity
- ✅ **Gradient clipping** for training stability
- ✅ **Smart eval** every 2000 steps
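
How the lower peak LR combines with a warmup-plus-cosine schedule can be sketched in plain Python. This is an illustration only: the function name `lr_at_step`, the total step count, and the warmup length are assumptions, not values read from the training script. Only the peak LR of 5e-6 comes from the list above.

```python
import math

def lr_at_step(step: int, total_steps: int, warmup_steps: int, peak_lr: float = 5e-6) -> float:
    """Linear warmup from 0 to peak_lr, then cosine decay back to 0."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))

# Hypothetical schedule: 10,000 optimizer steps with 1,000 warmup steps.
total, warmup = 10_000, 1_000
for s in (0, 500, 1_000, 5_500, 10_000):
    print(f"step {s:>6}: lr = {lr_at_step(s, total, warmup):.2e}")
```

A longer warmup stretches the initial ramp, so the earliest updates use a smaller step size. Transformers ships a comparable schedule as `get_cosine_schedule_with_warmup`; whether the training script uses it is not shown in this notebook.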

## Expected Results:
- **Target**: 60-65% accuracy (+9 to 14 percentage points over v4's 51%)
- **Training time**: ~4-5 hours on T4 GPU
- **Memory**: Fits in T4 15GB VRAM

## Step 1: Check GPU

In [None]:
!nvidia-smi
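
If you prefer a programmatic check over reading the `nvidia-smi` table, a minimal sketch could look like the following. The helper name `gpu_name` is hypothetical; it only wraps the `nvidia-smi` CLI and returns `None` when the tool is missing or fails.

```python
import shutil
import subprocess
from typing import Optional

def gpu_name() -> Optional[str]:
    """Return the first GPU name reported by nvidia-smi, or None if unavailable."""
    if shutil.which("nvidia-smi") is None:
        return None
    try:
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
            capture_output=True, text=True, check=True, timeout=30,
        ).stdout
    except (subprocess.SubprocessError, OSError):
        return None
    names = [line.strip() for line in out.splitlines() if line.strip()]
    return names[0] if names else None

name = gpu_name()
print(f"GPU: {name}" if name else "No GPU detected; switch to a GPU runtime before training.")
```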

## Step 2: Install Dependencies

In [None]:
!pip install -q transformers datasets peft accelerate bitsandbytes sentencepiece

## Step 3: Clone Repository

In [None]:
!git clone https://github.com/phucfix/medical-data-mining.git
%cd medical-data-mining

## Step 4: Pull Latest Code (if already cloned)

In [None]:
# Uncomment if you already cloned and need to update
# %cd medical-data-mining
# !git pull origin main

## Step 5: Verify Training Data

In [None]:
import json

# Check training data
with open('data/slm_train_style_adapted.jsonl', 'r', encoding='utf-8') as f:
    train_data = [json.loads(line) for line in f if line.strip()]

print(f"📊 Training samples: {len(train_data):,}")
print(f"\n📝 Sample data:")
print(json.dumps(train_data[0], indent=2, ensure_ascii=False))

# Check validation data if it exists
try:
    with open('data/slm_val.jsonl', 'r', encoding='utf-8') as f:
        val_data = [json.loads(line) for line in f if line.strip()]
    print(f"\n✅ Validation samples: {len(val_data):,}")
except FileNotFoundError:
    # Only a missing file means "no validation set"; malformed JSON should still raise
    print("\n⚠️  No validation data found (will train without eval)")
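
Beyond counting samples, it can help to confirm that every record has the same fields before a multi-hour run. The sketch below is an assumption-laden helper (`summarize_jsonl` is not part of the repository) that groups records by their key set and reports lines that fail to parse. The field names in the demo are made up.

```python
import io
import json
from collections import Counter

def summarize_jsonl(lines):
    """Count records per key set and collect line numbers that are not JSON objects."""
    schemas, bad_lines = Counter(), []
    for lineno, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # skip blank lines, as the cell above does
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            bad_lines.append(lineno)
            continue
        if not isinstance(record, dict):
            bad_lines.append(lineno)
            continue
        schemas[tuple(sorted(record))] += 1
    return schemas, bad_lines

# Demo on an in-memory sample with hypothetical field names.
sample = io.StringIO('{"prompt": "a", "label": "b"}\n\n{"prompt": "c"}\nnot json\n')
schemas, bad = summarize_jsonl(sample)
print("Key sets:", dict(schemas))
print("Unparseable lines:", bad)
```

To run it on the real data, pass an open file handle, for example `summarize_jsonl(open('data/slm_train_style_adapted.jsonl', encoding='utf-8'))`. More than one key set in the result usually points to inconsistent records.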

## Step 6: Start Training 🔥

**IMPORTANT**: This will take ~4-5 hours on a T4 GPU.

In [None]:
!python src/train_slm_qwen_lora_v5_optimized.py

## Step 7: Check Training Metrics

In [None]:
with open('models/qwen2.5-0.5b-med-slm-lora-v5-optimized/metrics.json', 'r') as f:
    metrics = json.load(f)

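print("📊 Training Metrics:")
print("=" * 50)
for key, value in metrics.items():
    # Format floats to 4 decimals; print other value types (ints, strings) unchanged
    if isinstance(value, float):
        print(f"{key}: {value:.4f}")
    else:
        print(f"{key}: {value}")
print("=" * 50)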

## Step 8: Download Trained Model

In [None]:
# Zip the model
!zip -r qwen2.5-0.5b-med-slm-lora-v5-optimized.zip models/qwen2.5-0.5b-med-slm-lora-v5-optimized

# Download
from google.colab import files
files.download('qwen2.5-0.5b-med-slm-lora-v5-optimized.zip')
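
Outside Colab, where the `zip` command or `google.colab` may not be available, the same archive can be built with the standard library. This is a sketch under that assumption; `zip_dir` is a hypothetical helper and the demo uses a throwaway directory with a made-up file name rather than the real model folder.

```python
import shutil
import tempfile
import zipfile
from pathlib import Path

def zip_dir(src_dir: str, archive_stem: str) -> str:
    """Zip src_dir into <archive_stem>.zip, keeping the top-level folder name."""
    src = Path(src_dir).resolve()
    return shutil.make_archive(archive_stem, "zip", root_dir=str(src.parent), base_dir=src.name)

# Demo on a temporary directory (hypothetical contents).
with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp) / "adapter"
    model_dir.mkdir()
    (model_dir / "adapter_config.json").write_text("{}")
    archive = zip_dir(str(model_dir), str(Path(tmp) / "adapter_backup"))
    with zipfile.ZipFile(archive) as zf:
        print(zf.namelist())
```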

## Step 9: Test Model on Test_sample.v1.0.csv 🧪

Test the trained model and compare with v4-chunked baseline

In [None]:
# Run test script
!python src/test_qwen_on_sample_v4.py --version v5-optimized

## Step 10: Analyze Test Results 📊

Load and visualize the test results

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Load test results
with open('data/test_sample_v5-optimized_results.json', 'r') as f:
    results = json.load(f)

# Display metrics
print("=" * 60)
print("🎯 V5 OPTIMIZED TEST RESULTS")
print("=" * 60)
print(f"\n📊 Overall Performance:")
print(f"   Accuracy: {results['accuracy']:.2%}")
print(f"   Correct: {results['correct']}/{results['total']}")
print(f"\n📈 Detailed Metrics:")
print(f"   Precision: {results['precision']:.4f}")
print(f"   Recall: {results['recall']:.4f}")
print(f"   F1-Score: {results['f1']:.4f}")

# Confusion Matrix
print(f"\n🔢 Confusion Matrix:")
cm = results['confusion_matrix']
print(f"   True Positives:  {cm['TP']}")
print(f"   False Positives: {cm['FP']}")
print(f"   True Negatives:  {cm['TN']}")
print(f"   False Negatives: {cm['FN']}")

# Prediction Distribution
pred_dist = results['prediction_distribution']
actual_dist = results['actual_distribution']
print(f"\n📊 Prediction Distribution:")
print(f"   Predicted TRUE:  {pred_dist['TRUE']} ({pred_dist['TRUE']/results['total']*100:.1f}%)")
print(f"   Predicted FALSE: {pred_dist['FALSE']} ({pred_dist['FALSE']/results['total']*100:.1f}%)")
print(f"\n📊 Actual Distribution:")
print(f"   Actual TRUE:  {actual_dist['TRUE']} ({actual_dist['TRUE']/results['total']*100:.1f}%)")
print(f"   Actual FALSE: {actual_dist['FALSE']} ({actual_dist['FALSE']/results['total']*100:.1f}%)")
print("=" * 60)

# Visualize Confusion Matrix
plt.figure(figsize=(8, 6))
# Rows are actual classes and columns are predicted classes, matching the tick labels below
cm_matrix = [[cm['TP'], cm['FN']], [cm['FP'], cm['TN']]]
sns.heatmap(cm_matrix, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['Pred TRUE', 'Pred FALSE'],
            yticklabels=['Actual TRUE', 'Actual FALSE'])
plt.title(f'V5 Optimized - Confusion Matrix\nAccuracy: {results["accuracy"]:.2%}')
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.tight_layout()
plt.savefig('v5_confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()
print("\n✅ Confusion matrix saved as 'v5_confusion_matrix.png'")
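
As a sanity check on the results file, the headline metrics can be recomputed from the four confusion-matrix counts. The sketch below assumes TRUE is the positive class; `binary_metrics` is a hypothetical helper and the counts in the demo are made up.

```python
def binary_metrics(tp: int, fp: int, tn: int, fn: int) -> dict:
    """Derive accuracy, precision, recall and F1 from confusion-matrix counts."""
    total = tp + fp + tn + fn
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "accuracy": (tp + tn) / total if total else 0.0,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

# Made-up counts, for illustration only.
for name, value in binary_metrics(tp=40, fp=10, tn=30, fn=20).items():
    print(f"{name}: {value:.4f}")
```

Calling `binary_metrics(cm['TP'], cm['FP'], cm['TN'], cm['FN'])` and comparing against `results['accuracy']`, `results['precision']`, and so on would flag a results file whose counts and metrics disagree.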

## Step 11: Compare v4 vs v5 Performance 📈

Compare the improvements from v4-chunked to v5-optimized

In [None]:
# Comparison data (v4-chunked baseline)
v4_accuracy = 0.5120  # 51.20%
v4_precision = 0.5308
v4_recall = 0.6899
v4_f1 = 0.6000

# v5 results
v5_accuracy = results['accuracy']
v5_precision = results['precision']
v5_recall = results['recall']
v5_f1 = results['f1']

# Calculate improvements
acc_improvement = (v5_accuracy - v4_accuracy) * 100
prec_improvement = (v5_precision - v4_precision) * 100
recall_improvement = (v5_recall - v4_recall) * 100
f1_improvement = (v5_f1 - v4_f1) * 100

# Display comparison (improvements are absolute differences in percentage points, pp)
print("=" * 70)
print("📊 V4 vs V5 COMPARISON")
print("=" * 70)
print(f"\n{'Metric':<15} {'v4-chunked':<15} {'v5-optimized':<15} {'Improvement':<15}")
print("-" * 70)
print(f"{'Accuracy':<15} {v4_accuracy:<15.2%} {v5_accuracy:<15.2%} {acc_improvement:>+13.2f} pp")
print(f"{'Precision':<15} {v4_precision:<15.4f} {v5_precision:<15.4f} {prec_improvement:>+13.2f} pp")
print(f"{'Recall':<15} {v4_recall:<15.4f} {v5_recall:<15.4f} {recall_improvement:>+13.2f} pp")
print(f"{'F1-Score':<15} {v4_f1:<15.4f} {v5_f1:<15.4f} {f1_improvement:>+13.2f} pp")
print("=" * 70)

# Visualize comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Bar chart comparison
metrics = ['Accuracy', 'Precision', 'Recall', 'F1-Score']
v4_values = [v4_accuracy*100, v4_precision*100, v4_recall*100, v4_f1*100]
v5_values = [v5_accuracy*100, v5_precision*100, v5_recall*100, v5_f1*100]

x = range(len(metrics))
width = 0.35

axes[0].bar([i - width/2 for i in x], v4_values, width, label='v4-chunked', color='skyblue')
axes[0].bar([i + width/2 for i in x], v5_values, width, label='v5-optimized', color='lightcoral')
axes[0].set_ylabel('Score (%)')
axes[0].set_title('V4 vs V5 Performance Comparison')
axes[0].set_xticks(x)
axes[0].set_xticklabels(metrics, rotation=45, ha='right')
axes[0].legend()
axes[0].grid(axis='y', alpha=0.3)

# Add value labels on bars
for i, (v4, v5) in enumerate(zip(v4_values, v5_values)):
    axes[0].text(i - width/2, v4 + 1, f'{v4:.1f}', ha='center', va='bottom', fontsize=9)
    axes[0].text(i + width/2, v5 + 1, f'{v5:.1f}', ha='center', va='bottom', fontsize=9)

# Improvement chart
improvements = [acc_improvement, prec_improvement, recall_improvement, f1_improvement]
colors = ['green' if imp > 0 else 'red' for imp in improvements]
axes[1].barh(metrics, improvements, color=colors, alpha=0.7)
axes[1].set_xlabel('Improvement (percentage points)')
axes[1].set_title('V5 Improvements over V4')
axes[1].axvline(x=0, color='black', linestyle='-', linewidth=0.5)
axes[1].grid(axis='x', alpha=0.3)

# Add value labels
for i, imp in enumerate(improvements):
    axes[1].text(imp, i, f' {imp:+.2f} pp', va='center',
                ha='left' if imp > 0 else 'right', fontsize=10, fontweight='bold')

plt.tight_layout()
plt.savefig('v4_vs_v5_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
print("\n✅ Comparison chart saved as 'v4_vs_v5_comparison.png'")

# Summary
print(f"\n🎯 SUMMARY:")
if v5_accuracy >= 0.60:
    print(f"   ✅ SUCCESS! Achieved {v5_accuracy:.2%} accuracy (target: 60-65%)")
    print(f"   ✅ Improved by {acc_improvement:+.2f} points from v4's {v4_accuracy:.2%}")
elif v5_accuracy > v4_accuracy:
    print(f"   ⚠️  Partial success: {v5_accuracy:.2%} accuracy")
    print(f"   📈 Improved by {acc_improvement:+.2f} points but below 60% target")
else:
    print(f"   ❌ No improvement: {v5_accuracy:.2%} vs v4's {v4_accuracy:.2%}")
    print(f"   📉 Need to investigate training issues")
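
Because the v4 baseline numbers are hard-coded, it is worth confirming they are internally consistent. The sketch below recomputes F1 as the harmonic mean of the quoted precision and recall and compares it with the quoted F1. The helper name `f1_from` and the tolerance are assumptions.

```python
def f1_from(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0

# v4-chunked baseline values quoted in the comparison cell above.
baseline_precision, baseline_recall, baseline_f1_reported = 0.5308, 0.6899, 0.6000

baseline_f1_derived = f1_from(baseline_precision, baseline_recall)
print(f"derived F1 = {baseline_f1_derived:.4f}, reported F1 = {baseline_f1_reported:.4f}")

# Tolerance of half a unit in the fourth decimal, since the inputs are rounded.
assert abs(baseline_f1_derived - baseline_f1_reported) < 5e-4
```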

## Step 12: Error Analysis 🔍

Analyze prediction errors to understand model weaknesses

In [None]:
# Load test data with predictions
import csv

test_data = []
with open('Test_sample.v1.0.csv', 'r', encoding='utf-8') as f:
    reader = csv.DictReader(f)
    for row in reader:
        test_data.append(row)

# Match predictions from results file
predictions = results['predictions']

# Analyze errors
errors = []
for i, (sample, pred) in enumerate(zip(test_data, predictions)):
    actual = sample['answer']
    if pred != actual:
        errors.append({
            'index': i,
            'question': sample['question'],
            'actual': actual,
            'predicted': pred,
            'error_type': f"False {'Positive' if pred == 'Đúng' else 'Negative'}"
        })

print("=" * 80)
print(f"🔍 ERROR ANALYSIS")
print("=" * 80)
print(f"\n📊 Total Errors: {len(errors)} / {len(test_data)} ({len(errors)/len(test_data)*100:.1f}%)")

# Error type distribution
false_positives = sum(1 for e in errors if e['error_type'] == 'False Positive')
false_negatives = sum(1 for e in errors if e['error_type'] == 'False Negative')

print(f"\n📈 Error Distribution:")
error_total = max(len(errors), 1)  # guard against ZeroDivisionError when there are no errors
print(f"   False Positives: {false_positives} ({false_positives/error_total*100:.1f}%)")
print(f"   False Negatives: {false_negatives} ({false_negatives/error_total*100:.1f}%)")

# Show sample errors
print(f"\n🔴 Sample False Positives (predicted Đúng, actually Sai):")
print("-" * 80)
fp_samples = [e for e in errors if e['error_type'] == 'False Positive'][:5]
for i, err in enumerate(fp_samples, 1):
    print(f"\n{i}. Question: {err['question'][:100]}...")
    print(f"   Predicted: {err['predicted']} | Actual: {err['actual']}")

print(f"\n🔴 Sample False Negatives (predicted Sai, actually Đúng):")
print("-" * 80)
fn_samples = [e for e in errors if e['error_type'] == 'False Negative'][:5]
for i, err in enumerate(fn_samples, 1):
    print(f"\n{i}. Question: {err['question'][:100]}...")
    print(f"   Predicted: {err['predicted']} | Actual: {err['actual']}")

print("\n" + "=" * 80)

# Save detailed error analysis
error_df = pd.DataFrame(errors)
error_df.to_csv('v5_error_analysis.csv', index=False, encoding='utf-8')
print(f"\n✅ Detailed error analysis saved to 'v5_error_analysis.csv'")
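
The error loop pairs rows with predictions via `zip`, which silently truncates if the two lists differ in length, and it treats any prediction other than 'Đúng' as negative. A small tally of (actual, predicted) pairs can surface both problems. This is a sketch: `tally_pairs` is a hypothetical helper and the labels in the demo are illustrative.

```python
from collections import Counter

def tally_pairs(actuals, predictions):
    """Count (actual, predicted) label pairs; raise if the lengths differ."""
    if len(actuals) != len(predictions):
        raise ValueError(
            f"length mismatch: {len(actuals)} labels vs {len(predictions)} predictions"
        )
    return Counter(zip(actuals, predictions))

# Illustrative labels only; 'Đúng'/'Sai' mirror the values used in the cell above.
pairs = tally_pairs(["Đúng", "Sai", "Sai", "Đúng"], ["Đúng", "Đúng", "Sai", "Sai"])
for (actual, predicted), count in sorted(pairs.items()):
    print(f"actual={actual!r:8} predicted={predicted!r:8} count={count}")
```

With the notebook's variables, this would be `tally_pairs([row['answer'] for row in test_data], predictions)`. Any pair containing a label other than the two expected spellings suggests a formatting mismatch rather than a genuine model error.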

## Expected Improvements

| Version | Epochs | LR | LoRA Rank | Accuracy | Notes |
|---------|--------|-----|-----------|----------|-------|
| v4-chunked | 1 | 2e-5 | 32 | 51% | Baseline |
| v5-optimized | 3 | 5e-6 | 64 | **60-65%** (target) | +9 to 14 points expected |
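
The expected gain in the table is an absolute difference in accuracy, which is not the same as a relative improvement over the baseline. A short worked example, using only the 51% baseline and the 60-65% target from the table (the helper name `gain` is hypothetical):

```python
def gain(baseline: float, target: float) -> tuple:
    """Return (gain in percentage points, gain relative to the baseline in percent)."""
    diff = target - baseline
    return diff * 100, diff / baseline * 100

baseline = 0.51  # v4-chunked accuracy from the table
for target in (0.60, 0.65):
    points, relative = gain(baseline, target)
    print(f"target {target:.0%}: +{points:.0f} points, +{relative:.1f}% relative")
```

So reaching the target means 9 to 14 more correct answers per 100 test questions, or roughly an 18% to 27% relative increase over v4.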

## Key Changes:
1. **More training**: 3 epochs vs 1 epoch → better learning
2. **Stable learning**: Lower LR (5e-6) + cosine scheduler → avoid overfitting
3. **More capacity**: LoRA rank 64 → can learn more complex patterns
4. **Better regularization**: Gradient clipping + longer warmup
4. **Better regularization**: Gradient clipping + longer warmup
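
To make the "more capacity" point concrete: LoRA adds two low-rank matrices to each adapted linear layer, of shapes `rank x d_in` and `d_out x rank`, so the trainable parameter count grows linearly with the rank. The sketch below uses a hypothetical layer width of 1024, which is not the actual Qwen2.5-0.5B dimension; which layers the training script adapts is also not shown in this notebook.

```python
def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one linear layer with d_in inputs and d_out outputs."""
    return rank * (d_in + d_out)

# Hypothetical square projection; substitute the real layer shapes.
d = 1024
for r in (32, 64):
    print(f"rank {r}: {lora_params(d, d, r):,} trainable params per adapted layer")
```

Going from rank 32 to rank 64 doubles the adapter parameters for every adapted layer, which also raises optimizer memory for those parameters.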