# FlexMatch-lite Training: Dynamic Threshold + Focal Loss

**Mục tiêu:** Cải thiện self-training bằng cách:
1. **Dynamic Threshold theo lớp** - Giúp lớp hiếm được chọn dễ dàng hơn
2. **Focal Loss** - Giảm ảnh hưởng của lớp đa số trong training

**Kỳ vọng:** Cải thiện F1-macro, đặc biệt cho các lớp AQI nguy hiểm (Hazardous, Very Unhealthy)

## 1. Setup & Load Data

In [None]:
# Papermill parameters
# Keep this cell as plain literal assignments: when the notebook is executed
# with `papermill -p NAME VALUE`, papermill injects a cell right after this one
# that overrides these defaults.
BASE_TAU = 0.90  # base threshold for the combined FlexMatch run (Experiment 4) and the reference line in the threshold plot
GAMMA = 2.0  # gamma passed to FlexMatchConfig for the combined run (presumably the focal-loss focusing parameter)
MAX_ITER = 10  # iteration cap shared by every experiment in this notebook
LABEL_MISSING_FRACTION = 0.95  # NOTE(review): not referenced by any cell in this notebook — confirm whether SemiDataConfig should consume it

In [None]:
import sys
import json
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

warnings.filterwarnings('ignore')
sns.set_style('whitegrid')

# Make the project's `src/` package importable.
# The notebook may be launched from the project root or from a sub-folder
# (e.g. `notebooks/`), so walk up from the working directory and stop at the
# first ancestor that contains a `src` directory.
current_dir = Path.cwd()
project_root = next(
    (candidate for candidate in (current_dir, *current_dir.parents)
     if (candidate / 'src').exists()),
    None,
)

if project_root is not None:
    src_dir = str(project_root / 'src')
    # Re-running this cell must not pile up duplicate sys.path entries;
    # remove an existing entry first so `src` still ends up with top priority.
    if src_dir in sys.path:
        sys.path.remove(src_dir)
    sys.path.insert(0, src_dir)
    print(f"Added {project_root / 'src'} to path")
else:
    # Fall back to the working directory (not the filesystem root) so later
    # cells that build paths from `project_root` stay relative to where the
    # notebook was started.
    project_root = current_dir
    print("Warning: could not find src directory")

from semi_supervised_library import (
    SemiDataConfig,
    FlexMatchConfig,
    SelfTrainingConfig,
    run_self_training,
    run_flexmatch_training,
    AQI_CLASSES
)

print("[OK] Libraries imported successfully")

In [None]:
# Load the prepared dataset. Candidate locations, tried in order:
#   1. <project_root>/data/processed  (project_root comes from the setup cell)
#   2. <cwd>/../data/processed        (notebook launched from a sub-folder)
# If none exists, fail with every searched path listed. Otherwise the only
# error would be the parquet reader's, naming whichever path was tried last.
_data_candidates = [
    project_root / 'data' / 'processed' / 'dataset_for_semi.parquet',
    Path.cwd().parent / 'data' / 'processed' / 'dataset_for_semi.parquet',
]
DATA_PATH = next((p for p in _data_candidates if p.exists()), None)
if DATA_PATH is None:
    raise FileNotFoundError(
        "dataset_for_semi.parquet not found; searched: "
        + ", ".join(str(p) for p in _data_candidates)
    )

df = pd.read_parquet(DATA_PATH)

print(f"Dataset shape: {df.shape}")
print("\nClass distribution:")
print(df['aqi_class'].value_counts().sort_index())

# Fraction of rows flagged by the `is_labeled` column.
labeled_frac = df['is_labeled'].mean()
print(f"\nLabeled fraction: {labeled_frac:.2%}")

## 2. Experiment 1: Baseline Self-Training (τ=0.90)

In [None]:
# Baseline: plain self-training with a single fixed confidence threshold.
data_cfg = SemiDataConfig()
st_cfg = SelfTrainingConfig(tau=0.90, max_iter=MAX_ITER)

print("Running baseline self-training...")
baseline_results = run_self_training(df, data_cfg, st_cfg)
baseline_metrics = baseline_results['test_metrics']

banner = '=' * 50
print(f"\n{banner}")
print("BASELINE SELF-TRAINING RESULTS")
print(banner)
print(f"Test Accuracy: {baseline_metrics['accuracy']:.4f}")
print(f"Test F1-macro: {baseline_metrics['f1_macro']:.4f}")
print(f"\nIterations: {len(baseline_results['history'])}")
print(f"Total pseudo-labels: {sum(step['new_pseudo'] for step in baseline_results['history'])}")

## 3. Experiment 2: FlexMatch with Dynamic Threshold Only

In [None]:
# Ablation: dynamic per-class thresholds only (focal loss switched off).
# Sweeps the base threshold; every run is stored under its tau value.
flexmatch_dynamic_results = {}

for tau in (0.85, 0.90, 0.95):
    print(f"\nTesting base_tau={tau}...")

    sweep_cfg = FlexMatchConfig(
        base_tau=tau,
        gamma=2.0,  # presumably ignored while use_focal_loss=False
        max_iter=MAX_ITER,
        use_focal_loss=False,
    )

    results = run_flexmatch_training(df, data_cfg, sweep_cfg)
    flexmatch_dynamic_results[tau] = results

    metrics = results['test_metrics']
    print(f"  Accuracy: {metrics['accuracy']:.4f}")
    print(f"  F1-macro: {metrics['f1_macro']:.4f}")

print("\n[OK] Dynamic threshold experiments completed")

## 4. Experiment 3: FlexMatch with Focal Loss Only

In [None]:
# Ablation: focal loss with an (intended) fixed threshold; sweeps gamma.
# Every run is stored under its gamma value.
flexmatch_focal_results = {}

for g in (1.0, 2.0, 3.0):
    print(f"\nTesting gamma={g}...")

    # NOTE(review): alpha=1.0 is meant to disable threshold smoothing so the
    # threshold stays at base_tau. Whether alpha=1.0 really freezes it depends
    # on the update rule inside the library — confirm.
    sweep_cfg = FlexMatchConfig(
        base_tau=0.90,
        gamma=g,
        max_iter=MAX_ITER,
        alpha=1.0,
        use_focal_loss=True,
    )

    results = run_flexmatch_training(df, data_cfg, sweep_cfg)
    flexmatch_focal_results[g] = results

    metrics = results['test_metrics']
    print(f"  Accuracy: {metrics['accuracy']:.4f}")
    print(f"  F1-macro: {metrics['f1_macro']:.4f}")

print("\n[OK] Focal loss experiments completed")

## 5. Experiment 4: FlexMatch Combined (Dynamic Threshold + Focal Loss, papermill parameters)

In [None]:
# Combined run: dynamic thresholds (alpha=0.9) plus focal loss, driven by the
# papermill parameters BASE_TAU / GAMMA / MAX_ITER.
fm_cfg_best = FlexMatchConfig(
    base_tau=BASE_TAU,
    gamma=GAMMA,
    alpha=0.9,
    max_iter=MAX_ITER,
    use_focal_loss=True,
)

print("Running FlexMatch with combined approach...")
flexmatch_combined_results = run_flexmatch_training(df, data_cfg, fm_cfg_best)
combined_metrics = flexmatch_combined_results['test_metrics']

banner = '=' * 50
print(f"\n{banner}")
print("FLEXMATCH COMBINED RESULTS")
print(banner)
print(f"Test Accuracy: {combined_metrics['accuracy']:.4f}")
print(f"Test F1-macro: {combined_metrics['f1_macro']:.4f}")
print(f"\nIterations: {len(flexmatch_combined_results['history'])}")
print(f"Total pseudo-labels: {sum(step['new_pseudo'] for step in flexmatch_combined_results['history'])}")

## 6. Visualization & Analysis

In [None]:
# Create the directory for this notebook's plots, CSV and JSON artifacts.
# Anchor it next to the dataset that was actually loaded (DATA_PATH from the
# data-loading cell). Re-guessing the `data/processed` location here could pick
# a different folder than the one the input was read from.
OUTPUT_DIR = DATA_PATH.parent / 'flexmatch_experiments'
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print(f"Output directory: {OUTPUT_DIR}")

### 6.1. Comparison Chart: Baseline vs. FlexMatch Combined

In [None]:
# Compare the baseline against the combined FlexMatch run on the test split.
comparison_data = [
    {
        'Method': 'Baseline Self-Training',
        'Accuracy': baseline_metrics['accuracy'],
        'F1-macro': baseline_metrics['f1_macro']
    },
    {
        'Method': 'FlexMatch Combined',
        'Accuracy': combined_metrics['accuracy'],
        'F1-macro': combined_metrics['f1_macro']
    }
]

comparison_df = pd.DataFrame(comparison_data)

# One panel per metric: (column in comparison_df, panel title).
metric_panels = [
    ('Accuracy', 'Test Accuracy Comparison'),
    ('F1-macro', 'Test F1-macro Comparison'),
]
bar_colors = ['#3498db', '#e74c3c']

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

for panel_ax, (metric, title) in zip(axes, metric_panels):
    values = comparison_df[metric].to_numpy(dtype=float)
    panel_ax.bar(comparison_df['Method'], values, color=bar_colors)
    panel_ax.set_ylabel(metric, fontsize=12)
    panel_ax.set_title(title, fontsize=14, fontweight='bold')
    # Zoom in on the scores, but derive the window from the data. Hard-coded
    # limits silently clip bars (and their value labels) whenever a score
    # falls outside the expected range. The +0.05 headroom keeps the labels
    # (drawn at v + 0.005) inside the axes.
    panel_ax.set_ylim([max(0.0, values.min() - 0.05), values.max() + 0.05])
    panel_ax.tick_params(axis='x', rotation=15)

    for i, v in enumerate(values):
        panel_ax.text(i, v + 0.005, f'{v:.4f}', ha='center', fontweight='bold')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'test_performance_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("[OK] Comparison chart saved")

### 6.2. Threshold Evolution Over Iterations

In [None]:
# Per-class threshold trajectory recorded in the combined FlexMatch run.
history = flexmatch_combined_results['history']
iterations = [step['iter'] for step in history]

# For each class, look up its threshold in every iteration, falling back to
# that iteration's base_tau when the class has no entry.
# NOTE(review): this assumes `thresholds` is keyed by the integer class index
# (position in AQI_CLASSES). If it is keyed by class name, every lookup falls
# back to base_tau and the lines come out flat — confirm against the library.
threshold_data = {
    name: [step['thresholds'].get(idx, step['base_tau']) for step in history]
    for idx, name in enumerate(AQI_CLASSES)
}

plt.figure(figsize=(12, 6))

for name, series in threshold_data.items():
    plt.plot(iterations, series, marker='o', label=name, linewidth=2)

plt.axhline(y=BASE_TAU, color='black', linestyle='--', alpha=0.5, label=f'Base τ={BASE_TAU}')
plt.xlabel('Iteration', fontsize=12)
plt.ylabel('Threshold (τ)', fontsize=12)
plt.title('Dynamic Threshold Evolution per Class', fontsize=14, fontweight='bold')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'threshold_evolution.png', dpi=300, bbox_inches='tight')
plt.show()

print("[OK] Threshold evolution chart saved")

### 6.3. Per-Class F1-Score Comparison

In [None]:
# Per-class F1 for the baseline vs. the combined FlexMatch run, in AQI_CLASSES order.
def _class_f1(metrics, class_name):
    """Return the F1-score of *class_name* from ``metrics['report']``.

    Returns NaN when the report has no entry for the class (presumably because
    it is absent from the evaluated split) instead of raising KeyError. The rare
    AQI classes this notebook targets are the most likely to be missing.
    """
    row = metrics['report'].get(class_name)
    return float('nan') if row is None else row['f1-score']


baseline_f1_per_class = [_class_f1(baseline_metrics, c) for c in AQI_CLASSES]
flexmatch_f1_per_class = [_class_f1(combined_metrics, c) for c in AQI_CLASSES]

# Plot
x = np.arange(len(AQI_CLASSES))
width = 0.35

fig, ax = plt.subplots(figsize=(14, 6))
rects1 = ax.bar(x - width/2, baseline_f1_per_class, width, label='Baseline Self-Training', color='#3498db')
rects2 = ax.bar(x + width/2, flexmatch_f1_per_class, width, label='FlexMatch Combined', color='#e74c3c')

ax.set_ylabel('F1-Score', fontsize=12)
ax.set_title('Per-Class F1-Score Comparison', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels([c.replace('_', ' ') for c in AQI_CLASSES], rotation=15, ha='right')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')


def autolabel(rects):
    """Write each bar's height (3 decimals) just above it on ``ax``.

    Bars with a non-finite height (classes missing from a report) are skipped:
    there is no meaningful position to place their label at.
    """
    for rect in rects:
        height = rect.get_height()
        if not np.isfinite(height):
            continue
        ax.annotate(f'{height:.3f}',
                    xy=(rect.get_x() + rect.get_width() / 2, height),
                    xytext=(0, 3),
                    textcoords="offset points",
                    ha='center', va='bottom', fontsize=9)


autolabel(rects1)
autolabel(rects2)

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'per_class_f1_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("[OK] Per-class F1 comparison saved")

### 6.4. Improvement Analysis

In [None]:
# Per-class absolute and relative F1 change (FlexMatch combined vs. baseline).
improvements = []

for class_name, baseline_f1, flexmatch_f1 in zip(
    AQI_CLASSES, baseline_f1_per_class, flexmatch_f1_per_class
):
    abs_improvement = flexmatch_f1 - baseline_f1
    # A relative change is undefined when the baseline F1 is 0 (or NaN).
    # Reporting 0% there would hide exactly the rare-class gains this
    # notebook is looking for, so use NaN instead.
    rel_improvement = (abs_improvement / baseline_f1) * 100 if baseline_f1 > 0 else np.nan

    improvements.append({
        'Class': class_name,
        'Baseline F1': baseline_f1,
        'FlexMatch F1': flexmatch_f1,
        'Absolute Δ': abs_improvement,
        'Relative Δ (%)': rel_improvement
    })

improvements_df = pd.DataFrame(improvements)
# Undefined (NaN) relative changes are listed last.
improvements_df = improvements_df.sort_values('Relative Δ (%)', ascending=False, na_position='last')

print("\n" + "="*80)
print("PER-CLASS IMPROVEMENT ANALYSIS")
print("="*80)
print(improvements_df.to_string(index=False))
print("\n" + "="*80)

# Save to CSV (NaN cells are written as empty fields).
improvements_df.to_csv(OUTPUT_DIR / 'per_class_improvements.csv', index=False)
print("\n[OK] Improvement analysis saved")

## 7. Save Results

In [None]:
def _json_default(obj):
    """Convert NumPy scalars/arrays into built-in types for ``json``.

    Metric dicts and pseudo-label counts may carry NumPy values (e.g. np.int64),
    which ``json`` rejects by default. Anything else still raises TypeError.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Save metrics of the combined run.
with open(OUTPUT_DIR / 'metrics_flexmatch.json', 'w', encoding='utf-8') as f:
    json.dump(combined_metrics, f, indent=2, default=_json_default)

f1_macro_delta = combined_metrics['f1_macro'] - baseline_metrics['f1_macro']

# Save comparison summary
summary = {
    'baseline': {
        'accuracy': baseline_metrics['accuracy'],
        'f1_macro': baseline_metrics['f1_macro'],
        'iterations': len(baseline_results['history']),
        'total_pseudo_labels': sum(h['new_pseudo'] for h in baseline_results['history'])
    },
    'flexmatch_combined': {
        'accuracy': combined_metrics['accuracy'],
        'f1_macro': combined_metrics['f1_macro'],
        'iterations': len(flexmatch_combined_results['history']),
        'total_pseudo_labels': sum(h['new_pseudo'] for h in flexmatch_combined_results['history']),
        # Read the config from the object that was actually trained, so the
        # summary cannot drift from it. Previously alpha/use_focal_loss were
        # duplicated here as literals.
        'config': {
            'base_tau': fm_cfg_best.base_tau,
            'gamma': fm_cfg_best.gamma,
            'alpha': fm_cfg_best.alpha,
            'use_focal_loss': fm_cfg_best.use_focal_loss
        }
    },
    'improvements': {
        'accuracy_delta': combined_metrics['accuracy'] - baseline_metrics['accuracy'],
        'f1_macro_delta': f1_macro_delta,
        # None (JSON null) when the baseline F1-macro is 0: the relative change
        # is undefined and the division would raise ZeroDivisionError.
        'f1_macro_relative_improvement': (
            f1_macro_delta / baseline_metrics['f1_macro'] * 100
            if baseline_metrics['f1_macro'] > 0 else None
        )
    }
}

with open(OUTPUT_DIR / 'flexmatch_summary.json', 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2, default=_json_default)

print("\n" + "="*80)
print("SUMMARY")
print("="*80)
print(json.dumps(summary, indent=2, default=_json_default))
print("\n[OK] All results saved to:", OUTPUT_DIR)

## 8. Conclusion

**FlexMatch-lite Results:**
- Dynamic threshold helps minority classes get selected
- Focal loss reduces majority class dominance
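- Combined approach is expected to give the largest F1-macro improvement; confirm by comparing the ablation runs (Experiments 2–3) against the combined run (Experiment 4)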

**Next Steps:**
- Compare with Label Spreading
- Analyze failure cases
- Consider ensemble methods