# DiaTrend Forecasting Framework with Explainability

Complete framework for blood glucose forecasting with explainability mechanisms.

## Framework Phases:
1. **Data Loading and Preprocessing**
2. **Model Construction and Training**
3. **Model Evaluation**
4. **Explainability with SHAP**
5. **Per-Patient Analysis and Conclusions**

In [None]:
from stacked_RNNs import make_stacked_RNNs

# Exploratory check of the stacked-RNN builder; `model` is redefined in Phase 2
model = make_stacked_RNNs(input_shape=(128, 32), dropout=0.5, type_model="LSTM",
                          num_layers=3, hidden_units=64, bidirectional=True)
model.summary()

In [5]:
from Trasformers import Transformer
transformer = Transformer(
    num_layers=3,              # Number of encoder/decoder layers
    d_model=64,                # Embedding dimension
    num_heads=4,               # Number of attention heads
    dff=256,                   # Feed-forward dimension
    target_vocab_size=1,       # Single numeric output (forecasting)
    use_timeseries_embedding=True,  # Enable the 1D CNN embedding for time series
    input_features=1,          # Number of input features (glucose only)
    num_conv_layers=2,         # Number of 1D CNN layers
    kernel_size=3,             # CNN kernel size
    dropout_rate=0.1
)

# Build the model by specifying the input dimensions
seq_len = 12  # Sequence length (same as WINDOW_SIZE)
target_len = 6  # Target length (same as HORIZON)

# Build with the correct shapes
transformer.build([
    (None, seq_len, 1),      # Encoder input: (batch, seq_len, features)
    (None, target_len)       # Decoder input: (batch, target_len)
])

# summary() can now be called
transformer.summary()

Model: "transformer_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder_2 (Encoder)         multiple                  311872    
                                                                 
 decoder_2 (Decoder)         multiple                  498688    
                                                                 
 dense_43 (Dense)            multiple                  65        
                                                                 
Total params: 810,625
Trainable params: 810,625
Non-trainable params: 0
_________________________________________________________________
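The `build` call above declares two inputs: an encoder window of shape `(batch, seq_len, 1)` and a decoder input of shape `(batch, target_len)`. This notebook does not show how `Trasformers.Transformer` expects the decoder input to be filled. A common convention for seq2seq forecasting is teacher forcing. During training, the decoder sees the target shifted right by one step, starting from the last observed reading.

The sketch below builds such arrays with NumPy. `make_encoder_decoder_inputs` is a hypothetical helper. The shift-right convention is an assumption, not the module's confirmed behavior.

```python
import numpy as np

def make_encoder_decoder_inputs(series, seq_len=12, target_len=6):
    """Hypothetical helper: slice a 1-D glucose series into encoder windows,
    teacher-forcing decoder inputs and targets (assumed shift-right convention)."""
    series = np.asarray(series, dtype=np.float32)
    n = len(series) - seq_len - target_len + 1
    if n <= 0:
        raise ValueError("series too short for seq_len + target_len")
    enc, dec, tgt = [], [], []
    for start in range(n):
        past = series[start:start + seq_len]
        future = series[start + seq_len:start + seq_len + target_len]
        # Decoder input = [last observed value, y_1, ..., y_{T-1}]
        dec_in = np.concatenate(([past[-1]], future[:-1]))
        enc.append(past[:, None])  # (seq_len, 1): a single feature (glucose)
        dec.append(dec_in)         # (target_len,)
        tgt.append(future)         # (target_len,)
    return np.stack(enc), np.stack(dec), np.stack(tgt)

enc, dec, tgt = make_encoder_decoder_inputs(np.arange(20), seq_len=12, target_len=6)
print(enc.shape, dec.shape, tgt.shape)
```

At inference time the future values are unknown. The decoder inputs would therefore have to be generated autoregressively, feeding each prediction back as the next decoder input, rather than taken from the targets.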


In [None]:
# Import the required libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from sklearn.metrics import mean_squared_error, mean_absolute_error
import warnings
warnings.filterwarnings('ignore')

# Import the custom modules
from funcs import *
from stacked_RNNs import *
from explainability import *

print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {tf.config.list_physical_devices('GPU')}")

# Configuration
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
np.random.seed(42)
tf.random.set_seed(42)

## Phase 1: Data Loading and Preprocessing

In [None]:
# Framework parameters
DATA_DIR = r'C:\Users\rosar\Desktop\UNISA\Magistrale - Informatica\II semestre\DL\timeseries\dataset'
DEMOGRAPHICS_FILE = r'C:\Users\rosar\Desktop\UNISA\Magistrale - Informatica\II semestre\DL\timeseries\dataset\SubjectDemographics_3-15-23.xlsx'

# Model parameters
WINDOW_SIZE = 12  # 12 past readings (1 hour at one reading every 5 min)
HORIZON = 6       # 6 timesteps ahead (30 min at one reading every 5 min)
CATEGORICAL_COLS = ['Gender', 'Race']
NUMERICAL_COLS = ['Age', 'Hemoglobin A1C']

print("=== PHASE 1: DATA LOADING ===")

# 1.1 Load CGM data
print("Loading CGM data...")
cgm_data = load_cgm_data(DATA_DIR)
print(f"Loaded data for {len(cgm_data)} patients")

# Check the CGM data format
first_patient = list(cgm_data.keys())[0]
print(f"\nData format for patient {first_patient}:")
print(cgm_data[first_patient].head())
print(f"Columns: {cgm_data[first_patient].columns.tolist()}")
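`WINDOW_SIZE` and `HORIZON` are counts of readings. The comments translate them into 60 and 30 minutes only if the CGM samples every 5 minutes. A quick way to check that assumption is the median gap between consecutive timestamps. `median_sampling_minutes` is a hypothetical helper, not part of `funcs`.

```python
import numpy as np

def median_sampling_minutes(timestamps):
    """Median gap between consecutive CGM readings, in minutes."""
    ts = np.sort(np.asarray(timestamps, dtype="datetime64[s]"))
    gaps = np.diff(ts).astype("timedelta64[s]").astype(float) / 60.0
    return float(np.median(gaps))

ts = np.array(["2023-01-01T00:00", "2023-01-01T00:05",
               "2023-01-01T00:10", "2023-01-01T00:20"], dtype="datetime64[m]")
step = median_sampling_minutes(ts)
WINDOW_SIZE, HORIZON = 12, 6
print(step, WINDOW_SIZE * step, HORIZON * step)  # minutes per step, input span, forecast span
```

For a real patient, it could be called on the parsed timestamps once the `date` column has been converted with `pd.to_datetime` (step 1.3), e.g. `median_sampling_minutes(standardized_cgm_data[first_patient]['date'])`.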

In [None]:
# 1.2 Load static features
print("\nLoading static features...")
static_features = load_static_features_from_excel(DEMOGRAPHICS_FILE)
print(f"Static features loaded for {len(static_features)} patients")
print("\nFirst rows of the static features:")
print(static_features.head())
print(f"\nColumns: {static_features.columns.tolist()}")

# Check for missing values
print("\nMissing values in the static features:")
print(static_features.isnull().sum())
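If the check above reports missing values, they have to be filled before encoding. The notebook does not show whether `prepare_dataset_for_training` already does this. The following is therefore only a sketch of one common strategy: the median for numeric columns and the mode for categorical ones.

`impute_static_features` is hypothetical. Passing `fit_df` makes the statistics come from training patients only, which avoids leaking information from validation and test patients.

```python
import pandas as pd

def impute_static_features(df, categorical_cols, numerical_cols, fit_df=None):
    """Hypothetical helper: median for numeric gaps, mode for categorical gaps.
    Statistics are taken from fit_df (e.g. training patients) when given."""
    ref = df if fit_df is None else fit_df
    out = df.copy()
    for col in numerical_cols:
        out[col] = out[col].fillna(ref[col].median())
    for col in categorical_cols:
        mode = ref[col].mode(dropna=True)
        out[col] = out[col].fillna(mode.iloc[0] if not mode.empty else "Unknown")
    return out

demo = pd.DataFrame({"Gender": ["F", None, "F", "M"], "Age": [30.0, None, 50.0, 40.0]})
filled = impute_static_features(demo, ["Gender"], ["Age"])
print(filled)
```

After the split in step 1.4, a leakage-free call would be `impute_static_features(static_features, CATEGORICAL_COLS, NUMERICAL_COLS, fit_df=static_features[static_features['SubjectID'].isin(train_patients)])`.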

In [None]:
# 1.3 Basic preprocessing and column-name standardization
print("\n=== BASIC PREPROCESSING ===")

# Standardize the CGM column names where needed
standardized_cgm_data = {}
for patient_id, df in cgm_data.items():
    df_copy = df.copy()
    
    # Rename the glucose column to 'glucose'
    if 'mg/dl' in df_copy.columns:
        df_copy.rename(columns={'mg/dl': 'glucose'}, inplace=True)
    elif 'mg/dL' in df_copy.columns:
        df_copy.rename(columns={'mg/dL': 'glucose'}, inplace=True)
    
    # Convert dates to datetime
    if 'date' in df_copy.columns:
        df_copy['date'] = pd.to_datetime(df_copy['date'])
    
    standardized_cgm_data[patient_id] = df_copy

# Check the result of the standardization
first_patient = list(standardized_cgm_data.keys())[0]
print(f"Standardized columns: {standardized_cgm_data[first_patient].columns.tolist()}")

# Preliminary statistics
all_glucose_values = []
for patient_id, df in standardized_cgm_data.items():
    if 'glucose' in df.columns:
        all_glucose_values.extend(df['glucose'].dropna().tolist())

print(f"\nGlobal glucose statistics:")
print(f"- Total readings: {len(all_glucose_values):,}")
print(f"- Mean: {np.mean(all_glucose_values):.2f} mg/dL")
print(f"- Std: {np.std(all_glucose_values):.2f} mg/dL")
print(f"- Min: {np.min(all_glucose_values):.2f} mg/dL")
print(f"- Max: {np.max(all_glucose_values):.2f} mg/dL")
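The loop above handles only the exact spellings `mg/dl` and `mg/dL`. A file with another variant, such as upper case or extra spaces, would keep its original column name. That patient would then be dropped silently by the later `'glucose' in df.columns` checks.

A slightly more tolerant sketch normalizes case and whitespace before matching. `glucose_rename_map` is a hypothetical helper.

```python
def glucose_rename_map(columns, target="glucose"):
    """Return {original_name: target} for the first column that reads 'mg/dl'
    after lower-casing and removing whitespace; empty dict if none matches."""
    for col in columns:
        if "".join(str(col).split()).lower() == "mg/dl":
            return {col: target}
    return {}

print(glucose_rename_map(["date", "mg/dl"]))
print(glucose_rename_map(["date", " MG/DL "]))
print(glucose_rename_map(["date", "value"]))
```

Inside the loop it would replace the `if/elif` pair with `df_copy.rename(columns=glucose_rename_map(df_copy.columns), inplace=True)`.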

In [None]:
# 1.4 Split patients into train/val/test
print("\n=== PATIENT SPLIT ===")

# Valid patients (with both CGM data and static features)
valid_patients = []
for patient_id in standardized_cgm_data.keys():
    if patient_id in static_features['SubjectID'].values:
        df = standardized_cgm_data[patient_id]
        if 'glucose' in df.columns and len(df['glucose'].dropna()) >= WINDOW_SIZE + HORIZON:
            valid_patients.append(patient_id)

print(f"Valid patients for training: {len(valid_patients)}")

# Cross-patient split: each patient is assigned to a single split
train_patients, val_patients, test_patients = split_patients(
    valid_patients, val_ratio=0.15, test_ratio=0.15, seed=42
)

print(f"Train: {len(train_patients)} patients")
print(f"Validation: {len(val_patients)} patients")
print(f"Test: {len(test_patients)} patients")

print(f"\nTrain patients: {train_patients[:5]}...")  # Show the first 5
print(f"Val patients: {val_patients}")
print(f"Test patients: {test_patients}")
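`split_patients` lives in `funcs` and its body is not shown. The essential property of a cross-patient split is that every patient's windows end up in exactly one of train, validation or test. The model is then evaluated on people it never saw.

The sketch below gets that property with a seeded shuffle. `split_patients_sketch` is a hypothetical stand-in, and its rounding rules may differ from the real function.

```python
import random

def split_patients_sketch(patient_ids, val_ratio=0.15, test_ratio=0.15, seed=42):
    """Hypothetical stand-in for funcs.split_patients: shuffle IDs once with a
    fixed seed, then cut them into disjoint test/val/train groups."""
    ids = sorted(patient_ids)            # deterministic order before shuffling
    rng = random.Random(seed)
    rng.shuffle(ids)
    n = len(ids)
    # Keep at least one patient in val/test when there are enough patients
    n_test = max(1, round(n * test_ratio)) if n >= 3 else 0
    n_val = max(1, round(n * val_ratio)) if n >= 3 else 0
    test = ids[:n_test]
    val = ids[n_test:n_test + n_val]
    train = ids[n_test + n_val:]
    return train, val, test

pids = [f"Subject{i}" for i in range(1, 21)]
train, val, test = split_patients_sketch(pids)
print(len(train), len(val), len(test))
```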

In [None]:
# 1.5 Final dataset preparation
print("\n=== FINAL DATASET PREPARATION ===")

# Build the dataset: sliding windows, static-feature encoding, scaling
dataset = prepare_dataset_for_training(
    cgm_data_dict=standardized_cgm_data,
    static_features_df=static_features,
    window_size=WINDOW_SIZE,
    horizon=HORIZON,
    train_patients=train_patients,
    val_patients=val_patients,
    test_patients=test_patients,
    categorical_cols=CATEGORICAL_COLS,
    numerical_cols=NUMERICAL_COLS,
    scaler_type='zscore'
)

# Check the dataset dimensions
print("\nDataset dimensions:")
for split in ['train', 'val', 'test']:
    if dataset[split]['X_seq'] is not None:
        print(f"{split.upper()}:")
        print(f"  - X_seq: {dataset[split]['X_seq'].shape}")
        print(f"  - X_static: {dataset[split]['X_static'].shape}")
        print(f"  - y: {dataset[split]['y'].shape}")
    else:
        print(f"{split.upper()}: No data")

# Static feature names (used for explainability)
feature_names = CATEGORICAL_COLS + NUMERICAL_COLS
# Take the model's static input width from the data: if the categorical columns
# are one-hot encoded, X_static has more columns than len(feature_names)
static_dim = dataset['train']['X_static'].shape[1]
print(f"\nStatic features ({len(feature_names)}): {feature_names}")
print(f"Static input dimension: {static_dim}")
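The comment above summarizes what `prepare_dataset_for_training` does (windows, encoding, scaling) without showing it. The sketch below makes two assumptions:

- the function slides a window of `WINDOW_SIZE` readings and takes the next `HORIZON` readings as targets;
- the z-score scaler is fitted on training data only.

Under those assumptions, the core could look like this. `make_windows` and `ZScore` are hypothetical. In practice `sklearn.preprocessing.StandardScaler` offers the same `fit` / `transform` / `inverse_transform` interface.

```python
import numpy as np

def make_windows(series, window_size=12, horizon=6):
    """Slide over a 1-D series: X = window_size past readings, y = next horizon readings."""
    series = np.asarray(series, dtype=float)
    n = len(series) - window_size - horizon + 1
    if n <= 0:
        return np.empty((0, window_size)), np.empty((0, horizon))
    X = np.stack([series[i:i + window_size] for i in range(n)])
    y = np.stack([series[i + window_size:i + window_size + horizon] for i in range(n)])
    return X, y

class ZScore:
    """Minimal global z-score scaler; fit it on training data only to avoid leakage."""
    def fit(self, values):
        self.mean_ = float(np.mean(values))
        self.std_ = float(np.std(values)) or 1.0  # guard against zero variance
        return self

    def transform(self, values):
        return (np.asarray(values, dtype=float) - self.mean_) / self.std_

    def inverse_transform(self, values):
        return np.asarray(values, dtype=float) * self.std_ + self.mean_

X, y = make_windows(np.arange(100, 130), window_size=12, horizon=6)
scaler = ZScore().fit(y)
y_back = scaler.inverse_transform(scaler.transform(y))
print(X.shape, y.shape, np.allclose(y_back, y))
```

The z-score transform is affine, so with a single global target scaler the errors in mg/dL are the scaled errors multiplied by the scaler's standard deviation σ. MAE and RMSE scale by σ and MSE by σ². This relation links the scaled and original-unit metrics in Phase 3.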

## Phase 2: Model Construction and Training

In [None]:
print("\n=== PHASE 2: MODEL CONSTRUCTION ===")

# 2.1 Model parameters
SEQ_LEN = WINDOW_SIZE
STATIC_DIM = static_dim
LSTM_HIDDEN = 64
FC_HIDDEN = 32
NUM_LAYERS = 2

# Build the LSTM + static-feature model
model = build_lstm_static_model(
    seq_len=SEQ_LEN,
    static_dim=STATIC_DIM,
    lstm_hidden=LSTM_HIDDEN,
    fc_hidden=FC_HIDDEN,
    num_layers=NUM_LAYERS
)

# Compile the model
model = compile_model(model, learning_rate=0.001, loss='mse')

# Show the architecture
print("\nModel architecture:")
model.summary()

# Plot the model (optional; requires pydot and graphviz)
try:
    tf.keras.utils.plot_model(model, to_file='model_architecture.png', show_shapes=True, show_layer_names=True)
    print("Architecture diagram saved to 'model_architecture.png'")
except Exception as e:
    print(f"Could not create the architecture diagram (pydot/graphviz missing?): {e}")

In [None]:
# 2.2 Model training
print("\n=== MODEL TRAINING ===")

# Training parameters
EPOCHS = 100
BATCH_SIZE = 32
PATIENCE = 15

# Training
history = train_model(
    model=model,
    X_train_seq=dataset['train']['X_seq'],
    X_train_static=dataset['train']['X_static'],
    y_train=dataset['train']['y'],
    X_val_seq=dataset['val']['X_seq'],
    X_val_static=dataset['val']['X_static'],
    y_val=dataset['val']['y'],
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    patience=PATIENCE
)

print("\nTraining complete!")
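`train_model` receives `patience=PATIENCE` but its body is not shown. Suppose it wraps Keras `EarlyStopping(monitor='val_loss', patience=PATIENCE)`. The stopping rule can then be replayed on a list of validation losses to find the best epoch and the epoch where training stops. `early_stopping_epoch` is a hypothetical helper that mimics that rule.

```python
def early_stopping_epoch(val_losses, patience=15, min_delta=0.0):
    """Replay a Keras-style early-stopping rule on a list of validation losses.
    Returns (best_epoch, stop_epoch), both 1-based."""
    best, best_epoch, wait = float("inf"), 0, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best, best_epoch, wait = loss, epoch, 0   # improvement: reset the counter
        else:
            wait += 1
            if wait >= patience:                      # patience exhausted: stop here
                return best_epoch, epoch
    return best_epoch, len(val_losses)

losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.74]
print(early_stopping_epoch(losses, patience=3))
```

Under this rule the last logged validation loss is usually worse than the best one. The weights kept at the end correspond to the best epoch only if `restore_best_weights=True` was passed to the callback.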

In [None]:
# 2.3 Learning curves
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Loss during training')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 3, 2)
plt.plot(history.history['mae'], label='Train MAE')
plt.plot(history.history['val_mae'], label='Validation MAE')
plt.title('MAE during training')
plt.xlabel('Epoch')
plt.ylabel('Mean Absolute Error')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 3, 3)
# The learning rate is only logged when a callback such as ReduceLROnPlateau records it;
# depending on the Keras version the key is 'lr' or 'learning_rate'
lr_values = history.history.get('lr', history.history.get('learning_rate', []))
if lr_values:
    plt.plot(lr_values)
    plt.title('Learning Rate')
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.yscale('log')
    plt.grid(True, alpha=0.3)
else:
    plt.text(0.5, 0.5, 'Learning rate\nnot available', 
             ha='center', va='center', transform=plt.gca().transAxes)

plt.tight_layout()
plt.show()

# Final training statistics
final_train_loss = history.history['loss'][-1]
final_val_loss = history.history['val_loss'][-1]
best_val_loss = min(history.history['val_loss'])
best_epoch = np.argmin(history.history['val_loss']) + 1

print(f"\nTraining statistics:")
print(f"- Total epochs: {len(history.history['loss'])}")
print(f"- Best epoch: {best_epoch}")
print(f"- Best val loss: {best_val_loss:.6f}")
print(f"- Final train loss: {final_train_loss:.6f}")
print(f"- Final val loss: {final_val_loss:.6f}")

## Phase 3: Model Evaluation

In [None]:
print("\n=== PHASE 3: MODEL EVALUATION ===")

# 3.1 Test-set evaluation
results = evaluate_model(
    model=model,
    X_test_seq=dataset['test']['X_seq'],
    X_test_static=dataset['test']['X_static'],
    y_test=dataset['test']['y']
)

print(f"\nTest set metrics (scaled data):")
print(f"- MSE: {results['mse']:.6f}")
print(f"- MAE: {results['mae']:.6f}")
print(f"- RMSE: {results['rmse']:.6f}")

# 3.2 Convert back to the original scale (mg/dL)
y_scaler = dataset['scalers']['y_scaler']

# Inverse-transform predictions and targets
y_test_original = inverse_transform_predictions(dataset['test']['y'], y_scaler)
y_pred_original = inverse_transform_predictions(results['predictions'], y_scaler)

# Metrics on the original scale
mse_original = mean_squared_error(y_test_original, y_pred_original)
mae_original = mean_absolute_error(y_test_original, y_pred_original)
rmse_original = np.sqrt(mse_original)

print(f"\nTest set metrics (original values, mg/dL):")
print(f"- MSE: {mse_original:.2f} (mg/dL)²")
print(f"- MAE: {mae_original:.2f} mg/dL")
print(f"- RMSE: {rmse_original:.2f} mg/dL")

# 3.3 Clinical metrics
clinical_metrics = calculate_clinical_metrics(y_test_original, y_pred_original)
print(f"\nClinical metrics:")
print(f"- Accuracy in the normal range: {clinical_metrics['normal_accuracy']:.3f}")
print(f"- Accuracy in hypoglycemia: {clinical_metrics['hypoglycemia_accuracy']:.3f}")
print(f"- Accuracy in hyperglycemia: {clinical_metrics['hyperglycemia_accuracy']:.3f}")
print(f"- MAPE: {clinical_metrics['mean_absolute_percentage_error']:.2f}%")
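`calculate_clinical_metrics` is defined in `funcs`, and its exact definitions are not shown. One plausible reading of the three range accuracies uses three glucose ranges: hypoglycemia (< 70 mg/dL), target (70 to 180 mg/dL) and hyperglycemia (> 180 mg/dL). Among readings whose true value falls in a range, the metric is the fraction whose prediction falls in the same range.

The sketch below implements that reading plus MAPE. `clinical_metrics_sketch` and its thresholds are assumptions.

```python
import numpy as np

HYPO, HYPER = 70.0, 180.0  # mg/dL cut-offs (assumed: standard CGM target range)

def glycemic_class(values):
    """0 = hypoglycemia (<70), 1 = target range (70-180), 2 = hyperglycemia (>180)."""
    v = np.asarray(values, dtype=float)
    return np.where(v < HYPO, 0, np.where(v > HYPER, 2, 1))

def clinical_metrics_sketch(y_true, y_pred):
    """Hypothetical reimplementation of calculate_clinical_metrics."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    true_cls, pred_cls = glycemic_class(y_true), glycemic_class(y_pred)
    out = {}
    for name, cls in [("hypoglycemia_accuracy", 0), ("normal_accuracy", 1),
                      ("hyperglycemia_accuracy", 2)]:
        mask = true_cls == cls
        # Share of readings truly in this range that are also predicted in it
        out[name] = float(np.mean(pred_cls[mask] == cls)) if mask.any() else float("nan")
    # CGM values are strictly positive, so the division is safe
    out["mean_absolute_percentage_error"] = float(np.mean(np.abs((y_true - y_pred) / y_true)) * 100)
    return out

m = clinical_metrics_sketch([60, 100, 150, 200], [75, 110, 140, 190])
print(m)
```

Per-range accuracy is sensitive to class imbalance, since hypoglycemic readings are usually rare. A formal clinical accuracy assessment more commonly uses the Clarke or Parkes error grid.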

In [None]:
# 3.4 Results visualization
print("\n=== RESULTS VISUALIZATION ===")

# Predictions vs actual values
plot_prediction_vs_actual(y_test_original, y_pred_original, 
                          title="Predictions vs Actual Values - Test Set")

# Error distribution: flatten both arrays so multi-step (n_samples, HORIZON)
# forecasts are treated as a single 1-D distribution
y_true_flat = np.ravel(y_test_original)
errors = np.ravel(y_pred_original) - y_true_flat

plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.hist(errors, bins=50, alpha=0.7, edgecolor='black')
plt.axvline(x=0, color='red', linestyle='--', linewidth=2)
plt.title('Prediction Error Distribution')
plt.xlabel('Error (mg/dL)')
plt.ylabel('Frequency')
plt.grid(True, alpha=0.3)

plt.subplot(1, 3, 2)
plt.boxplot(errors)
plt.title('Error Boxplot')
plt.ylabel('Error (mg/dL)')
plt.grid(True, alpha=0.3)

plt.subplot(1, 3, 3)
plt.scatter(y_true_flat, errors, alpha=0.6)
plt.axhline(y=0, color='red', linestyle='--', linewidth=2)
plt.title('Errors vs Actual Values')
plt.xlabel('Actual values (mg/dL)')
plt.ylabel('Error (mg/dL)')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nError statistics:")
print(f"- Mean error: {np.mean(errors):.2f} mg/dL")
print(f"- Error std: {np.std(errors):.2f} mg/dL")
print(f"- Median error: {np.median(errors):.2f} mg/dL")
print(f"- 95% of errors within: ±{np.percentile(np.abs(errors), 95):.2f} mg/dL")
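The error statistics above pool all forecast steps together. With `HORIZON = 6`, however, the error is generally expected to grow with lead time, and reporting RMSE per step makes that visible. The helper below assumes targets and predictions of shape `(n_samples, HORIZON)` with 5-minute spacing. `rmse_per_step` is hypothetical.

```python
import numpy as np

def rmse_per_step(y_true, y_pred, minutes_per_step=5):
    """RMSE for each forecast step of a (n_samples, horizon) array,
    keyed by lead time in minutes (assumes regular sampling)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_true.ndim == 1:                       # single-step forecast
        y_true, y_pred = y_true[:, None], y_pred[:, None]
    rmse = np.sqrt(np.mean((y_pred - y_true) ** 2, axis=0))
    return {(k + 1) * minutes_per_step: float(r) for k, r in enumerate(rmse)}

y_true = np.array([[100, 110], [120, 130]])
y_pred = np.array([[103, 106], [117, 134]])
print(rmse_per_step(y_true, y_pred))
```

On the test set it would be called as `rmse_per_step(y_test_original, y_pred_original)`.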

## Phase 4: Explainability with SHAP

In [None]:
print("\n=== PHASE 4: EXPLAINABILITY ===")

# 4.1 Initialize the XAI framework
xai_framework = XAIFramework(
    model=model,
    dataset=dataset,
    feature_names=feature_names
)

print("XAI framework initialized")
print(f"Static features: {feature_names}")

# 4.2 Set up the SHAP explainer
print("\nConfiguring the SHAP explainer...")
xai_framework.setup_shap_explainer(
    background_size=50,  # Reduced for speed
    explainer_type='gradient'  # Faster than 'deep'
)

# 4.3 Compute SHAP values
print("\nComputing SHAP values...")
shap_values = xai_framework.calculate_shap_values(
    max_samples=30  # Limited for the demo
)

print("SHAP values computed!")
print(f"SHAP shape (sequence): {shap_values[0].shape}")
print(f"SHAP shape (static features): {shap_values[1].shape}")
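The `'gradient'` option presumably wraps SHAP's `GradientExplainer`. That explainer estimates attributions from gradients averaged over a background set (expected gradients). With only four static features, exact Shapley values for one sample are cheap to compute by enumerating every feature subset, which gives a sanity check on the static-feature attributions.

The sketch below uses a single baseline vector to stand in for "absent" features. This is a different estimator from `GradientExplainer`, so the values will not match exactly. `exact_shapley` is a hypothetical helper.

```python
import itertools
import math
import numpy as np

def exact_shapley(predict, x, baseline):
    """Exact Shapley values for one sample, replacing 'absent' features with
    baseline values. Uses d * 2**d calls to predict (64 for d = 4)."""
    x = np.asarray(x, dtype=float)
    baseline = np.asarray(baseline, dtype=float)
    d = len(x)
    phi = np.zeros(d)
    for i in range(d):
        others = [j for j in range(d) if j != i]
        for size in range(d):
            # Shapley weight |S|! (d - |S| - 1)! / d!
            weight = math.factorial(size) * math.factorial(d - size - 1) / math.factorial(d)
            for subset in itertools.combinations(others, size):
                z = baseline.copy()
                z[list(subset)] = x[list(subset)]   # features in S take their real values
                without_i = predict(z)
                z[i] = x[i]                         # add feature i to the coalition
                with_i = predict(z)
                phi[i] += weight * (with_i - without_i)
    return phi

w = np.array([1.0, 2.0, -1.0, 0.5])
print(exact_shapley(lambda z: float(w @ z) + 3.0, np.ones(4), np.zeros(4)))
```

For the LSTM+static model, `predict` would fix one sequence window and vary only the static vector. One option is `lambda s: float(model.predict([seq, s[None, :]], verbose=0)[0, 0])`, which explains the first forecast step, with the mean training static vector as baseline. The input order `[seq, static]` is an assumption.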

In [None]:
# 4.4 Global importance
print("\n=== GLOBAL IMPORTANCE ANALYSIS ===")

xai_framework.plot_global_importance()

# Text summary
summary = xai_framework.generate_explanation_summary()
print(summary)
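`plot_global_importance` is a method of the project's `XAIFramework`, and its aggregation is not shown. A common definition of global importance is the mean absolute SHAP value per feature. For the sequence input, it can also be averaged per timestep to see which past readings matter most.

The sketch assumes `shap_values[0]` has shape `(n_samples, seq_len, 1)` and `shap_values[1]` has shape `(n_samples, static_dim)`, as the later analysis cell also assumes. `global_importance` is hypothetical.

```python
import numpy as np

def global_importance(static_shap, seq_shap, feature_names):
    """Mean |SHAP| per static feature (ranked, largest first) and per input timestep."""
    static_imp = np.mean(np.abs(np.asarray(static_shap, dtype=float)), axis=0)
    seq = np.abs(np.asarray(seq_shap, dtype=float))
    seq = seq.reshape(seq.shape[0], seq.shape[1], -1)  # (n_samples, seq_len, channels)
    timestep_imp = seq.mean(axis=(0, 2))               # average over samples and channels
    ranking = sorted(zip(feature_names, static_imp), key=lambda t: t[1], reverse=True)
    return ranking, timestep_imp

static_demo = np.array([[1.0, -2.0], [3.0, 0.0]])
seq_demo = np.array([[[1.0], [0.0], [-1.0]], [[3.0], [2.0], [1.0]]])
ranking, timestep_imp = global_importance(static_demo, seq_demo, ["A", "B"])
print(ranking, timestep_imp)
```

On the notebook's outputs it would be called as `global_importance(shap_values[1], shap_values[0], feature_names)`. The last entry of `timestep_imp` corresponds to the most recent reading in the window.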

In [None]:
# 4.5 Analysis of specific samples
print("\n=== SPECIFIC SAMPLE ANALYSIS ===")

# Analyze up to 3 different samples
for i in range(min(3, len(xai_framework.test_indices))):
    print(f"\n--- Sample {i} ---")
    xai_framework.plot_sample_explanation(sample_idx=i)
    
    # Clinical explanation
    sample_seq = dataset['test']['X_seq'][xai_framework.test_indices[i]].reshape(1, -1, 1)
    sample_static = dataset['test']['X_static'][xai_framework.test_indices[i]].reshape(1, -1)
    
    clinical_explanation = create_clinical_explanation(
        model, sample_seq, sample_static, feature_names
    )
    print("\nClinical explanation:")
    print(clinical_explanation)
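`create_clinical_explanation` is imported from `explainability`, and its output format is not shown. As an illustration of the idea, the sketch below turns one row of static-feature SHAP values into short sentences, ranking features by absolute contribution. `explain_sample_text` is hypothetical, and the demo values are toy inputs.

The "mg/dL" unit is only correct if the SHAP values were computed on unscaled outputs. On z-scored targets, multiply them first by the target scaler's standard deviation.

```python
import numpy as np

def explain_sample_text(static_shap_row, feature_values, feature_names, top_k=2, unit="mg/dL"):
    """Hypothetical helper: describe the top_k static features by |SHAP| for one sample."""
    shap_row = np.asarray(static_shap_row, dtype=float)
    order = np.argsort(-np.abs(shap_row))[:top_k]      # largest absolute contribution first
    lines = []
    for idx in order:
        value = shap_row[idx]
        if value > 0:
            effect = f"raises the forecast by {abs(value):.1f} {unit}"
        elif value < 0:
            effect = f"lowers the forecast by {abs(value):.1f} {unit}"
        else:
            effect = "does not change the forecast"
        lines.append(f"{feature_names[idx]} = {feature_values[idx]} {effect}")
    return lines

for line in explain_sample_text([0.5, -4.0, 2.0, 0.0], ["F", "group_1", 34, 7.2],
                                ["Gender", "Race", "Age", "Hemoglobin A1C"]):
    print(line)
```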

In [None]:
# 4.6 Feature interaction analysis
print("\n=== FEATURE INTERACTION ANALYSIS ===")

xai_framework.analyze_feature_interactions()

# Correlation analysis on the SHAP values
static_shap = shap_values[1]
seq_shap = shap_values[0]

# Mean absolute SHAP of the sequence input for each sample (the same for every panel)
seq_importance = np.mean(np.abs(seq_shap).reshape(len(seq_shap), -1), axis=1)

# One panel per static feature, in a 2-column grid with as many rows as needed
n_feats = len(feature_names)
n_rows = (n_feats + 1) // 2
plt.figure(figsize=(12, 4 * n_rows))

for i, feat_name in enumerate(feature_names):
    plt.subplot(n_rows, 2, i + 1)
    
    # Static-feature importance for each sample
    # (assumes one X_static column per entry of feature_names)
    feat_importance = np.abs(static_shap[:, i])
    
    plt.scatter(feat_importance, seq_importance, alpha=0.6)
    plt.xlabel(f'|SHAP| {feat_name}')
    plt.ylabel('Mean |SHAP| of the sequence')
    plt.title(f'{feat_name} vs sequence importance')
    
    # Pearson correlation
    if len(feat_importance) > 1:
        corr = np.corrcoef(feat_importance, seq_importance)[0, 1]
        plt.text(0.05, 0.95, f'r = {corr:.3f}', transform=plt.gca().transAxes,
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
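The panels above report Pearson's r on at most `max_samples=30` points, where a single outlier can dominate the coefficient. Spearman's rank correlation (Pearson's r computed on ranks) is a more robust alternative for such small samples.

The sketch implements it with NumPy only and gives tied values their average rank. `average_ranks` and `spearman` are hypothetical names. If SciPy is available, `scipy.stats.spearmanr` computes the same quantity.

```python
import numpy as np

def average_ranks(x):
    """Ranks 1..n; tied values share the average of their ranks."""
    x = np.asarray(x, dtype=float)
    order = np.argsort(x, kind="mergesort")
    ranks = np.empty(len(x), dtype=float)
    ranks[order] = np.arange(1, len(x) + 1)
    _, inverse = np.unique(x, return_inverse=True)
    sums = np.bincount(inverse, weights=ranks)   # sum of ranks per distinct value
    counts = np.bincount(inverse)                # occurrences per distinct value
    return (sums / counts)[inverse]

def spearman(a, b):
    """Spearman rank correlation = Pearson correlation of the ranks."""
    return float(np.corrcoef(average_ranks(a), average_ranks(b))[0, 1])

x = np.arange(1.0, 6.0)
print(average_ranks([10, 20, 20, 30]))
print(spearman(x, x ** 3), spearman(x, -x))
```

Inside the loop, `spearman(feat_importance, seq_importance)` could be reported next to the Pearson value.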

## Phase 5: Per-Patient Analysis and Conclusions

In [None]:
print("\n=== PER-PATIENT ANALYSIS ===")

# Example: detailed analysis of one test patient
sample_patient = test_patients[0] if test_patients else None

if sample_patient:
    print(f"\nDetailed analysis of patient: {sample_patient}")
    
    # Patient data
    patient_cgm = standardized_cgm_data[sample_patient]
    patient_static = static_features[static_features['SubjectID'] == sample_patient]
    
    print(f"Available CGM data: {len(patient_cgm)} readings")
    print(f"Time span: {patient_cgm['date'].min()} - {patient_cgm['date'].max()}")
    
    if not patient_static.empty:
        print("\nPatient characteristics:")
        for col in ['Age', 'Gender', 'Race', 'Hemoglobin A1C']:
            if col in patient_static.columns:
                value = patient_static[col].iloc[0]
                print(f"- {col}: {value}")
    
    # Plot the patient's time series
    plt.figure(figsize=(15, 8))
    
    plt.subplot(2, 1, 1)
    plt.plot(patient_cgm['date'], patient_cgm['glucose'], 'b-', linewidth=1, alpha=0.8)
    plt.axhspan(70, 180, alpha=0.2, color='green', label='Target range (70-180)')
    plt.axhspan(0, 70, alpha=0.2, color='red', label='Hypoglycemia (<70)')
    plt.axhspan(180, 400, alpha=0.2, color='orange', label='Hyperglycemia (>180)')
    plt.title(f'Full time series - {sample_patient}')
    plt.ylabel('Glucose (mg/dL)')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.subplot(2, 1, 2)
    # Last 24 hours (assuming one reading every 5 min = 288 readings/day);
    # tail() returns all rows if there are fewer than 288
    last_24h = patient_cgm.tail(288)
    plt.plot(last_24h['date'], last_24h['glucose'], 'b-', linewidth=2)
    plt.axhspan(70, 180, alpha=0.2, color='green')
    plt.axhspan(0, 70, alpha=0.2, color='red')
    plt.axhspan(180, 400, alpha=0.2, color='orange')
    plt.title('Last 24 hours (or all available data)')
    plt.xlabel('Date')
    plt.ylabel('Glucose (mg/dL)')
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Patient statistics
    glucose_values = patient_cgm['glucose'].dropna()
    print(f"\nGlucose statistics for {sample_patient}:")
    print(f"- Mean: {glucose_values.mean():.2f} mg/dL")
    print(f"- Std: {glucose_values.std():.2f} mg/dL")
    print(f"- TIR (70-180): {((glucose_values >= 70) & (glucose_values <= 180)).mean()*100:.1f}%")
    print(f"- Time in hypoglycemia (<70): {(glucose_values < 70).mean()*100:.1f}%")
    print(f"- Time in hyperglycemia (>180): {(glucose_values > 180).mean()*100:.1f}%")
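The three percentages above use the 70 and 180 mg/dL cut-offs. The international consensus on CGM time in range splits them further into level 1 and level 2 bands: < 54, 54 to 69, 70 to 180, 181 to 250 and > 250 mg/dL. It also reports the coefficient of variation (CV) as a variability measure.

`cgm_summary` below is a hypothetical helper that computes those figures from a glucose array. It uses the sample standard deviation (`ddof=1`), like pandas' `.std()` in the cell above.

```python
import numpy as np

def cgm_summary(glucose):
    """CGM summary over the consensus glucose bands (mg/dL); NaNs are ignored."""
    g = np.asarray(glucose, dtype=float)
    g = g[~np.isnan(g)]

    def pct(mask):
        return float(mask.mean() * 100)

    return {
        "mean": float(g.mean()),
        "cv_percent": float(g.std(ddof=1) / g.mean() * 100),
        "TBR_<54": pct(g < 54),
        "TBR_54-69": pct((g >= 54) & (g < 70)),
        "TIR_70-180": pct((g >= 70) & (g <= 180)),
        "TAR_181-250": pct((g > 180) & (g <= 250)),
        "TAR_>250": pct(g > 250),
    }

print(cgm_summary([50, 60, 100, 150, 200, 300, float("nan")]))
```

Applied to a patient, it would be `cgm_summary(patient_cgm['glucose'].to_numpy())`. Commonly cited consensus targets for most adults with type 1 diabetes are TIR above 70%, time below 70 mg/dL under 4% and CV at or below 36%.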

In [None]:
# Conclusions and model saving
print("\n=== CONCLUSIONS ===")

print(f"""
DIATREND FORECASTING FRAMEWORK - FINAL RESULTS
==============================================

DATASET:
- Total patients processed: {len(valid_patients)}
- Time windows created: {len(dataset['train']['y']) + len(dataset['val']['y']) + len(dataset['test']['y']):,}
- Input window length: {WINDOW_SIZE} timesteps
- Forecast horizon: {HORIZON} timesteps

MODEL:
- Architecture: LSTM + static features
- Static features used: {len(feature_names)}
- Total parameters: {model.count_params():,}

PERFORMANCE:
- RMSE: {rmse_original:.2f} mg/dL
- MAE: {mae_original:.2f} mg/dL
- MAPE: {clinical_metrics['mean_absolute_percentage_error']:.2f}%
- Accuracy in the normal range: {clinical_metrics['normal_accuracy']:.1%}

EXPLAINABILITY:
- SHAP-based explanations
- Importance of static features vs CGM sequences
- Sample-level and global explanations
- Clinically interpretable visualizations

The framework is ready for further validation and analysis!
""")

# Save the model
model.save('diatrend_lstm_static_model.h5')
print("\nModel saved to 'diatrend_lstm_static_model.h5'")

# Save the dataset and scalers
import pickle
with open('dataset_and_scalers.pkl', 'wb') as f:
    pickle.dump({
        'dataset': dataset,
        'feature_names': feature_names,
        'window_size': WINDOW_SIZE,
        'horizon': HORIZON
    }, f)
print("Dataset and scalers saved to 'dataset_and_scalers.pkl'")

print("\n🎉 DiaTrend framework completed successfully!")