# GraphRec: A Graph Neural Network Framework for Athletic Recovery Optimization

## 1. Imports and Setup

In [None]:
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, GCNConv
from torch_geometric.data import Data, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
from pathlib import Path
import warnings
from typing import Dict, List, Tuple
import math
import random

warnings.filterwarnings("ignore")

if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA GPU")
else:
    device = torch.device("cpu")
    print("Using CPU")

  from .autonotebook import tqdm as notebook_tqdm


Using MPS


## 2. Data Processing Pipeline

In [3]:
class DataProcessor:
    """Processes multi-modal athletic recovery data into graph representations"""
    
    def __init__(self, data_path: str = "pmdata"):
        self.data_path = Path(data_path)
        self.scalers = {}
        self.athlete_mapping = {}
        
    def calculate_hrv_metrics(self, hr_data: List[Dict]) -> Dict:
        """Calculate Heart Rate Variability metrics from minute-by-minute HR data"""
        try:
            bpm_values = [entry['value']['bpm'] for entry in hr_data if 'value' in entry and 'bpm' in entry['value']]
            
            if len(bpm_values) < 50:
                return {'hrv_rmssd': 0, 'hrv_pnn50': 0, 'hrv_mean': np.mean(bpm_values) if bpm_values else 60}
            
            # Calculate RR intervals from BPM
            rr_intervals = [60000 / bpm for bpm in bpm_values if bpm > 0]
            
            if len(rr_intervals) < 2:
                return {'hrv_rmssd': 0, 'hrv_pnn50': 0, 'hrv_mean': np.mean(bpm_values)}
            
            # RMSSD calculation
            successive_diffs = [abs(rr_intervals[i+1] - rr_intervals[i]) for i in range(len(rr_intervals)-1)]
            rmssd = np.sqrt(np.mean([diff**2 for diff in successive_diffs])) if successive_diffs else 0
            
            # pNN50 calculation
            pnn50 = len([diff for diff in successive_diffs if diff > 50]) / len(successive_diffs) * 100 if successive_diffs else 0
            
            return {
                'hrv_rmssd': rmssd,
                'hrv_pnn50': pnn50,
                'hrv_mean': np.mean(bpm_values)
            }
        except Exception:
            return {'hrv_rmssd': 0, 'hrv_pnn50': 0, 'hrv_mean': 60}
    
    def extract_sleep_features(self, sleep_data: List[Dict]) -> Dict:
        """Extract sleep architecture features as described in the paper"""
        if not sleep_data:
            return self._default_sleep_features()
        
        try:
            latest_sleep = sleep_data[-1]
            levels = latest_sleep.get('levels', {})
            summary = levels.get('summary', {})
            
            features = {
                'sleep_duration': latest_sleep.get('timeInBed', 0) / 60,
                'sleep_efficiency': latest_sleep.get('efficiency', 0),
                'minutes_to_sleep': latest_sleep.get('minutesToFallAsleep', 0),
                'sleep_deep_min': summary.get('deep', {}).get('minutes', 0),
                'sleep_rem_min': summary.get('rem', {}).get('minutes', 0),
                'sleep_light_min': summary.get('light', {}).get('minutes', 0),
                'sleep_wake_min': summary.get('wake', {}).get('minutes', 0),
                'sleep_deep_pct': summary.get('deep', {}).get('minutes', 0) / max(latest_sleep.get('minutesAsleep', 1), 1) * 100,
                'sleep_rem_pct': summary.get('rem', {}).get('minutes', 0) / max(latest_sleep.get('minutesAsleep', 1), 1) * 100,
                'sleep_restlessness': len(levels.get('data', [])),
            }
            
            return features
        except Exception:
            return self._default_sleep_features()
    
    def _default_sleep_features(self) -> Dict:
        return {
            'sleep_duration': 7.0,
            'sleep_efficiency': 85,
            'minutes_to_sleep': 15,
            'sleep_deep_min': 90,
            'sleep_rem_min': 90,
            'sleep_light_min': 240,
            'sleep_wake_min': 20,
            'sleep_deep_pct': 20,
            'sleep_rem_pct': 20,
            'sleep_restlessness': 15
        }
    
    def extract_activity_features(self, activity_data: List[Dict], hr_zones_data: List[Dict], calories_data: List[Dict]) -> Dict:
        """Extract physical activity features including heart rate zones"""
        try:
            hr_zones = {'zone_1': 0, 'zone_2': 0, 'zone_3': 0, 'zone_4': 0}
            if hr_zones_data:
                latest_zones = hr_zones_data[-1].get('value', {}).get('valuesInZones', {})
                hr_zones = {
                    'zone_1': latest_zones.get('IN_DEFAULT_ZONE_1', 0),
                    'zone_2': latest_zones.get('IN_DEFAULT_ZONE_2', 0), 
                    'zone_3': latest_zones.get('IN_DEFAULT_ZONE_3', 0),
                    'zone_4': latest_zones.get('ABOVE_DEFAULT_ZONE_3', 0)
                }
            
            recent_exercises = [ex for ex in activity_data if 'startTime' in ex][-7:] if activity_data else []
            total_exercise_min = sum(ex.get('duration', 0) for ex in recent_exercises) / 60000
            avg_hr = np.mean([ex.get('averageHeartRate', 0) for ex in recent_exercises if ex.get('averageHeartRate', 0) > 0]) if recent_exercises else 0
            total_calories = sum(ex.get('calories', 0) for ex in recent_exercises)
            
            daily_calories = 0
            if calories_data:
                daily_calories = sum(float(entry.get('value', 0)) for entry in calories_data[-1440:])
            
            return {
                'exercise_duration_week': total_exercise_min,
                'avg_exercise_hr': avg_hr,
                'exercise_calories_week': total_calories,
                'daily_caloric_expenditure': daily_calories,
                'hr_zone_1_min': hr_zones['zone_1'],
                'hr_zone_2_min': hr_zones['zone_2'],
                'hr_zone_3_min': hr_zones['zone_3'],
                'hr_zone_4_min': hr_zones['zone_4'],
                'training_stress_score': (hr_zones['zone_3'] * 2 + hr_zones['zone_4'] * 3) / 60
            }
        except Exception:
            return {
                'exercise_duration_week': 0,
                'avg_exercise_hr': 0,
                'exercise_calories_week': 0,
                'daily_caloric_expenditure': 2000,
                'hr_zone_1_min': 0,
                'hr_zone_2_min': 0,
                'hr_zone_3_min': 0,
                'hr_zone_4_min': 0,
                'training_stress_score': 0
            }
    
    def load_participant_data(self, participant_id: str) -> List[Dict]:
        """Load and process all data for a single participant"""
        p_path = self.data_path / participant_id
        if not p_path.exists():
            return []
        
        # Load all data files
        data_files = {
            'wellness': p_path / 'pmsys' / 'wellness.csv',
            'srpe': p_path / 'pmsys' / 'srpe.csv', 
            'injury': p_path / 'pmsys' / 'injury.csv',
            'heart_rate': p_path / 'fitbit' / 'heart_rate.json',
            'sleep': p_path / 'fitbit' / 'sleep.json',
            'sleep_score': p_path / 'fitbit' / 'sleep_score.csv',
            'exercise': p_path / 'fitbit' / 'exercise.json',
            'hr_zones': p_path / 'fitbit' / 'time_in_heart_rate_zones.json',
            'calories': p_path / 'fitbit' / 'calories.json',
            'resting_hr': p_path / 'fitbit' / 'resting_heart_rate.json'
        }
        
        # Load JSON files
        json_data = {}
        for key in ['heart_rate', 'sleep', 'exercise', 'hr_zones', 'calories', 'resting_hr']:
            try:
                if data_files[key].exists():
                    with open(data_files[key], 'r') as f:
                        json_data[key] = json.load(f)
                else:
                    json_data[key] = []
            except Exception:
                json_data[key] = []
        
        # Load CSV files
        csv_data = {}
        for key in ['wellness', 'srpe', 'injury', 'sleep_score']:
            try:
                if data_files[key].exists():
                    csv_data[key] = pd.read_csv(data_files[key])
                else:
                    csv_data[key] = pd.DataFrame()
            except Exception:
                csv_data[key] = pd.DataFrame()
        
        # Process wellness data (primary temporal anchor)
        samples = []
        if not csv_data['wellness'].empty:
            for _, wellness_row in csv_data['wellness'].iterrows():
                try:
                    sample_date = pd.to_datetime(wellness_row['effective_time_frame']).date()
                    
                    # Extract wellness features (subjective)
                    wellness_features = {
                        'fatigue': int(wellness_row['fatigue']),
                        'mood': int(wellness_row['mood']),
                        'readiness': int(wellness_row['readiness']),
                        'sleep_quality_subj': int(wellness_row['sleep_quality']),
                        'soreness': int(wellness_row['soreness']),
                        'stress': int(wellness_row['stress']),
                        'sleep_duration_subj': int(wellness_row['sleep_duration_h'])
                    }
                    
                    # Extract physiological features
                    hrv_features = self.calculate_hrv_metrics(json_data['heart_rate'])
                    sleep_features = self.extract_sleep_features(json_data['sleep'])
                    activity_features = self.extract_activity_features(
                        json_data['exercise'], 
                        json_data['hr_zones'],
                        json_data['calories']
                    )
                    
                    # Get resting heart rate
                    resting_hr = 60
                    if json_data['resting_hr']:
                        resting_hr = json_data['resting_hr'][-1].get('value', {}).get('restingHeartRate', 60)
                    
                    # Combine all features
                    all_features = {
                        **wellness_features,
                        **hrv_features,
                        **sleep_features,
                        **activity_features,
                        'resting_heart_rate': resting_hr,
                        'participant_id': participant_id,
                        'date': sample_date
                    }
                    
                    samples.append(all_features)
                    
                except Exception:
                    continue
        
        return samples
    
    def augment_sample(self, sample: Dict, augmentation_strength: float = 0.1) -> Dict:
        """Apply data augmentation as described in Section 3.6"""
        augmented = sample.copy()
        
        # Physiological noise augmentation (Gaussian noise)
        physio_keys = ['hrv_rmssd', 'hrv_pnn50', 'hrv_mean', 'resting_heart_rate',
                      'sleep_duration', 'sleep_efficiency', 'exercise_duration_week']
        
        for key in physio_keys:
            if key in augmented:
                noise = np.random.normal(0, augmentation_strength * abs(augmented[key]))
                augmented[key] = max(0, augmented[key] + noise)
        
        # Subjective data augmentation (discrete modifications)
        subjective_keys = ['fatigue', 'mood', 'readiness', 'sleep_quality_subj', 'soreness', 'stress']
        
        for key in subjective_keys:
            if key in augmented and np.random.random() < 0.1:
                current_val = augmented[key]
                change = np.random.choice([-1, 0, 1])
                if key in ['fatigue', 'soreness', 'stress']:
                    new_val = np.clip(current_val + change, 1, 4)
                else:
                    new_val = np.clip(current_val + change, 1, 8)
                augmented[key] = int(new_val)
        
        return augmented
    
    def create_augmented_samples(self, samples: List[Dict], num_augmentations: int = 2) -> List[Dict]:
        """Create augmented versions of samples for training"""
        augmented_samples = samples.copy()
        
        for _ in range(num_augmentations):
            for sample in samples:
                augmented = self.augment_sample(sample, augmentation_strength=0.1)
                augmented_samples.append(augmented)
        
        return augmented_samples
    
    def create_recovery_graphs(self, samples: List[Dict], participant_id: str, athlete_idx: int) -> List[Data]:
        """Create graph representations for recovery optimization"""
        if len(samples) < 7:
            return []
        
        graphs = []
        samples = sorted(samples, key=lambda x: x['date'])
        
        # Create temporal windows (7-day sliding window)
        for i in range(6, len(samples)):
            try:
                current_sample = samples[i]
                history_samples = samples[i-6:i]
                
                # Create multi-modal graph
                node_features, edge_index, edge_attr = self._build_recovery_graph(
                    current_sample, history_samples, athlete_idx
                )
                
                # Define recovery targets (Eq. 1-4 in paper)
                targets = self._create_recovery_targets(current_sample, samples[i:i+3] if i+3 < len(samples) else [current_sample])
                
                graph = Data(
                    x=torch.tensor(node_features, dtype=torch.float32),
                    edge_index=torch.tensor(edge_index, dtype=torch.long),
                    edge_attr=torch.tensor(edge_attr, dtype=torch.float32),
                    y=torch.tensor(targets, dtype=torch.float32),
                    participant_id=participant_id,
                    athlete_idx=athlete_idx,
                    date=current_sample['date']
                )
                
                graphs.append(graph)
                
            except Exception:
                continue
        
        return graphs
    
    def _build_recovery_graph(self, current: Dict, history: List[Dict], athlete_idx: int) -> Tuple:
        """Build multi-modal recovery graph with specialized node types (Section 3.2)"""
        
        nodes = []
        
        # Current physiological nodes (0-2)
        physio_node_1 = [  # HRV & Heart Rate
            current['hrv_rmssd'], current['hrv_pnn50'], current['hrv_mean'],
            current['resting_heart_rate'], current['avg_exercise_hr']
        ]
        
        physio_node_2 = [  # Sleep architecture
            current['sleep_duration'], current['sleep_efficiency'], current['minutes_to_sleep'],
            current['sleep_deep_min'], current['sleep_rem_min'], current['sleep_light_min'],
            current['sleep_deep_pct'], current['sleep_rem_pct'], current['sleep_restlessness']
        ]
        
        physio_node_3 = [  # Activity & training
            current['exercise_duration_week'], current['exercise_calories_week'],
            current['daily_caloric_expenditure'], current['training_stress_score'],
            current['hr_zone_1_min'], current['hr_zone_2_min'], 
            current['hr_zone_3_min'], current['hr_zone_4_min']
        ]
        
        # Current subjective nodes (3-5)
        subj_node_1 = [  # Wellness perception
            current['fatigue'], current['mood'], current['stress'],
            current['sleep_quality_subj'], current['sleep_duration_subj']
        ]
        
        subj_node_2 = [  # Performance readiness
            current['readiness'], current['soreness']
        ]
        
        subj_node_3 = [  # Recovery status (derived)
            (current['readiness'] - current['fatigue']) / 8,
            current['mood'] / 4,
            (4 - current['stress']) / 4
        ]
        
        # Historical temporal nodes (6-11)
        history_nodes = []
        for day_data in history:
            trend_features = [
                day_data['readiness'] / 8,
                day_data['fatigue'] / 4,
                day_data['sleep_quality_subj'] / 4,
                day_data['training_stress_score'] / 10
            ]
            history_nodes.append(trend_features)
        
        # Pad if insufficient history
        while len(history_nodes) < 6:
            history_nodes.insert(0, [0.5, 0.5, 0.5, 0.0])
        
        # Athlete identity node (12)
        athlete_features = [
            athlete_idx / 16,
            np.mean([h[0] for h in history_nodes]),
            np.mean([h[1] for h in history_nodes]),
            np.std([h[0] for h in history_nodes]) if len(history_nodes) > 1 else 0
        ]
        
        # Combine all nodes
        nodes = [physio_node_1, physio_node_2, physio_node_3,
                 subj_node_1, subj_node_2, subj_node_3] + history_nodes + [athlete_features]
        
        # Pad nodes to consistent size
        max_node_size = max(len(node) for node in nodes)
        for i, node in enumerate(nodes):
            if len(node) < max_node_size:
                nodes[i] = node + [0.0] * (max_node_size - len(node))
        
        edge_index, edge_attr = self._create_recovery_edges(current, history)
        
        return nodes, edge_index, edge_attr
    
    def _create_recovery_edges(self, current: Dict, history: List[Dict]) -> Tuple:
        """Create meaningful edge connections with relationship weights"""
        
        edges = []
        edge_weights = []
        
        # Physiological interconnections
        physio_connections = [
            (0, 1, 0.8),  # HRV ↔ Sleep
            (0, 2, 0.6),  # HRV ↔ Activity
            (1, 2, 0.7),  # Sleep ↔ Activity
        ]
        
        # Subjective interconnections  
        subj_connections = [
            (3, 4, 0.9),  # Wellness ↔ Readiness
            (3, 5, 0.8),  # Wellness ↔ Recovery status
            (4, 5, 0.9),  # Readiness ↔ Recovery status
        ]
        
        # Cross-modal connections (physio ↔ subjective)
        cross_connections = [
            (0, 3, 0.7),  # HRV → Wellness
            (1, 3, 0.8),  # Sleep → Wellness  
            (1, 4, 0.9),  # Sleep → Readiness
            (2, 4, 0.6),  # Activity → Readiness
            (2, 5, 0.7),  # Activity → Recovery status
        ]
        
        # Temporal connections (history → current)
        temporal_connections = []
        for i, hist_idx in enumerate(range(6, 12)):
            temporal_connections.extend([
                (hist_idx, 4, 0.5 + i * 0.05),
                (hist_idx, 5, 0.4 + i * 0.05),
            ])
        
        # Athlete identity connections
        identity_connections = [
            (12, 3, 0.6),  # Athlete → Wellness  
            (12, 4, 0.7),  # Athlete → Readiness
            (12, 5, 0.8),  # Athlete → Recovery status
        ]
        
        all_connections = physio_connections + subj_connections + cross_connections + temporal_connections + identity_connections
        
        # Create bidirectional edges
        for src, dst, weight in all_connections:
            edges.extend([(src, dst), (dst, src)])
            edge_weights.extend([weight, weight * 0.9])
        
        edge_index = [[src for src, dst in edges], [dst for src, dst in edges]]
        
        return edge_index, edge_weights
    
    def _create_recovery_targets(self, current: Dict, future: List[Dict]) -> List[float]:
        """Create multi-task recovery targets (Section 3.1)"""
        
        # Primary target: Next-day readiness prediction
        next_day_readiness = current['readiness'] / 8.0
        
        # Secondary target: Recovery quality score
        recovery_score = (
            (8 - current['fatigue']) / 8 * 0.3 +
            current['mood'] / 4 * 0.2 +
            current['readiness'] / 8 * 0.4 +
            (4 - current['stress']) / 4 * 0.1
        )
        
        # Tertiary target: Training readiness
        if current['readiness'] >= 7 and current['fatigue'] <= 2:
            training_readiness = 3.0
        elif current['readiness'] >= 5 and current['fatigue'] <= 3:
            training_readiness = 2.0
        elif current['readiness'] >= 3:
            training_readiness = 1.0
        else:
            training_readiness = 0.0
        
        # Quaternary target: Overreaching risk
        overreach_risk = 0.0
        if current['fatigue'] >= 3 and current['readiness'] <= 4:
            overreach_risk = 0.8
        elif current['fatigue'] >= 3 or current['readiness'] <= 5:
            overreach_risk = 0.4
        
        return [next_day_readiness, recovery_score, training_readiness / 3.0, overreach_risk]
    
    def process_all_participants(self) -> Tuple[List[Data], Dict]:
        """Process all participants and create recovery graphs with caching"""
        import pickle
        import os
        
        cache_file = 'processed_graphs_cache.pkl'
        
        if os.path.exists(cache_file):
            try:
                cache_time = os.path.getmtime(cache_file)
                data_dir_time = os.path.getmtime(self.data_path)
                
                newest_data_time = data_dir_time
                for p in self.data_path.iterdir():
                    if p.is_dir() and p.name.startswith('p'):
                        newest_data_time = max(newest_data_time, os.path.getmtime(p))
                
                if cache_time > newest_data_time:
                    with open(cache_file, 'rb') as f:
                        cached_data = pickle.load(f)
                        all_graphs = cached_data['graphs']
                        participant_stats = cached_data['stats']
                        self.athlete_mapping = cached_data['athlete_mapping']
                    
                    return all_graphs, participant_stats
            except Exception:
                pass
        
        all_graphs = []
        participant_stats = {}
        
        participants = [p for p in self.data_path.iterdir() if p.is_dir() and p.name.startswith('p')]
        participants = sorted(participants, key=lambda x: int(x.name[1:]))
        
        self.athlete_mapping = {p.name: i for i, p in enumerate(participants)}
        
        for participant_path in participants:
            participant_id = participant_path.name
            athlete_idx = self.athlete_mapping[participant_id]
            
            samples = self.load_participant_data(participant_id)
            
            if not samples:
                continue
            
            graphs = self.create_recovery_graphs(samples, participant_id, athlete_idx)
            all_graphs.extend(graphs)
            
            participant_stats[participant_id] = {
                'samples': len(samples),
                'graphs': len(graphs),
                'date_range': (min(s['date'] for s in samples), max(s['date'] for s in samples)) if samples else None
            }
        
        # Save processed data to cache
        try:
            cache_data = {
                'graphs': all_graphs,
                'stats': participant_stats,
                'athlete_mapping': self.athlete_mapping
            }
            with open(cache_file, 'wb') as f:
                pickle.dump(cache_data, f)
        except Exception:
            pass
        
        return all_graphs, participant_stats
    
    def create_contrastive_pairs(self, graphs: List[Data]) -> List[Tuple[Data, Data]]:
        """Create positive pairs for contrastive learning (Section 3.5)"""
        pairs = []
        
        # Group graphs by participant
        participant_graphs = {}
        for graph in graphs:
            pid = graph.participant_id
            if pid not in participant_graphs:
                participant_graphs[pid] = []
            participant_graphs[pid].append(graph)
        
        # Create positive pairs within same participant
        for pid, p_graphs in participant_graphs.items():
            if len(p_graphs) < 2:
                continue
                
            p_graphs.sort(key=lambda x: x.date)
            
            # Pairs of temporally adjacent samples
            for i in range(len(p_graphs) - 1):
                pairs.append((p_graphs[i], p_graphs[i + 1]))
                
            # Pairs of samples within a week
            for i in range(len(p_graphs)):
                for j in range(i + 1, min(i + 8, len(p_graphs))):
                    if abs((p_graphs[i].date - p_graphs[j].date).days) <= 7:
                        pairs.append((p_graphs[i], p_graphs[j]))
        
        return pairs

## 3. Multi-Scale Temporal Processing

In [4]:
class MultiScaleTemporalConv(nn.Module):
    """Multi-scale temporal convolution for capturing different time patterns"""
    
    def __init__(self, in_channels: int, out_channels: int, scales: List[int] = [1, 3, 5, 7]):
        super().__init__()
        self.scales = scales
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels, out_channels // len(scales), 
                     kernel_size=min(k, in_channels), padding=min(k, in_channels)//2)
            for k in scales
        ])
        self.norm = nn.LayerNorm(out_channels)
        self.activation = nn.GELU()
        
    def forward(self, x):
        # x shape: (batch_size, seq_len, in_channels)
        x = x.transpose(1, 2)  # (batch_size, in_channels, seq_len)
        
        outputs = []
        for conv in self.convs:
            conv_out = conv(x)
            outputs.append(conv_out)
        
        x = torch.cat(outputs, dim=1)  # (batch_size, out_channels, seq_len)
        x = x.transpose(1, 2)  # (batch_size, seq_len, out_channels)
        
        return self.activation(self.norm(x))

## 4. Adaptive Graph Structure Learning

In [5]:
class AdaptiveGraphLearner(nn.Module):
    """Learn personalized graph structures for each athlete"""
    
    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        
        self.feature_encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        
        # Bilinear edge predictor (Equation 1)
        self.edge_predictor = nn.Bilinear(hidden_dim, hidden_dim, 1)
        self.sparsity_threshold = 0.3
        
    def forward(self, node_features, athlete_idx=None):
        # Encode node features
        encoded_features = self.feature_encoder(node_features)
        
        # Compute pairwise edge probabilities
        num_nodes = encoded_features.size(0)
        edge_probs = torch.zeros(num_nodes, num_nodes, device=node_features.device)
        
        for i in range(num_nodes):
            for j in range(i + 1, num_nodes):
                prob = torch.sigmoid(self.edge_predictor(
                    encoded_features[i].unsqueeze(0),
                    encoded_features[j].unsqueeze(0)
                ).squeeze())
                edge_probs[i, j] = prob
                edge_probs[j, i] = prob  # Symmetric
        
        # Apply sparsity threshold
        adj_matrix = (edge_probs > self.sparsity_threshold).float()
        
        # Convert to edge_index format
        edge_indices = torch.nonzero(adj_matrix, as_tuple=True)
        edge_index = torch.stack(edge_indices, dim=0)
        edge_weights = edge_probs[edge_indices]
        
        return edge_index, edge_weights, encoded_features

## 5. Enhanced Multi-Modal Recovery GNN

In [6]:
class MultiModalRecoveryGNN(nn.Module):
    """Enhanced Multi-Modal Recovery GNN with Temporal and Adaptive Components"""
    
    def __init__(self, hidden_dim: int = 64, num_heads: int = 8):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        
        # Multi-scale temporal processing
        self.temporal_conv = MultiScaleTemporalConv(
            in_channels=hidden_dim, 
            out_channels=hidden_dim,
            scales=[1, 3, 5, 7]
        )
        
        # Adaptive graph learning
        self.graph_learner = AdaptiveGraphLearner(
            input_dim=hidden_dim,
            hidden_dim=hidden_dim
        )
        
        # Enhanced multi-head attention for cross-modal relationships
        self.cross_modal_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=0.1,
            batch_first=True
        )
        
        # Specialized processors for different modalities
        self.physio_processor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1)
        )
        
        self.subjective_processor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1)
        )
        
        self.temporal_processor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1)
        )
        
        # Graph convolution layers
        self.gcn_layers = nn.ModuleList([
            GCNConv(hidden_dim, hidden_dim) for _ in range(2)
        ])
        
        # Residual connections
        self.residual_weights = nn.Parameter(torch.ones(2))
        
    def forward(self, x, edge_index=None, edge_attr=None, batch=None, athlete_idx=None):
        batch_size = x.size(0)
        
        # Learn adaptive graph structure if not provided
        if edge_index is None:
            edge_index, edge_weights, x = self.graph_learner(x, athlete_idx)
        else:
            edge_weights = edge_attr if edge_attr is not None else torch.ones(edge_index.size(1), device=x.device)
        
        # Apply temporal convolutions to capture multi-scale patterns
        if x.dim() == 2:
            x_temporal = x.unsqueeze(1)  # (nodes, 1, features)
        else:
            x_temporal = x
            
        x_temporal = self.temporal_conv(x_temporal)
        x = x_temporal.squeeze(1) if x_temporal.size(1) == 1 else x_temporal.mean(dim=1)
        
        # Node type-specific processing
        node_features = []
        for i in range(batch_size):
            if i < 3:  # Physiological nodes
                processed = self.physio_processor(x[i].unsqueeze(0))
            elif i < 6:  # Subjective nodes
                processed = self.subjective_processor(x[i].unsqueeze(0))
            elif i < 12:  # Temporal nodes
                processed = self.temporal_processor(x[i].unsqueeze(0))
            else:  # Identity node
                processed = x[i].unsqueeze(0)
            node_features.append(processed)
        
        if node_features:
            x_processed = torch.cat(node_features, dim=0)
        else:
            x_processed = x
            
        # Multi-layer graph convolution with residual connections (Equation 2)
        h = x_processed
        for i, gcn in enumerate(self.gcn_layers):
            h_new = gcn(h, edge_index, edge_weights)
            h = self.residual_weights[i] * h + (1 - self.residual_weights[i]) * h_new
            h = F.gelu(h)
            h = F.dropout(h, p=0.1, training=self.training)
        
        # Cross-modal attention for final refinement
        if h.size(0) > 1:
            h_attended, attention_weights = self.cross_modal_attention(
                h.unsqueeze(0), h.unsqueeze(0), h.unsqueeze(0)
            )
            h = h_attended.squeeze(0)
        
        return h

## 6. Temporal Attention Recovery Network (TARN)

In [7]:
class TemporalAttentionRecoveryNetwork(nn.Module):
    """Temporal Attention Recovery Network (TARN) - Core Innovation"""
    
    def __init__(self, input_dim: int, hidden_dim: int = 128, num_layers: int = 3, num_heads: int = 8):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        # Input projection
        self.input_projection = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        
        # Multi-layer recovery-specific GNN
        self.recovery_layers = nn.ModuleList([
            MultiModalRecoveryGNN(hidden_dim, num_heads) for _ in range(num_layers)
        ])
        
        # Temporal attention for time-aware processing
        self.temporal_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=0.1,
            batch_first=True
        )
        
        # Graph-level representation
        self.graph_pooling = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),  # Mean + max pooling
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        
        # Individual recovery signature learning
        self.signature_learning = nn.Sequential(
            nn.Linear(hidden_dim + 4, hidden_dim),  # Graph + athlete features
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2)
        )
        
    def forward(self, x, edge_index, edge_attr, batch, athlete_idx):
        """Forward pass through TARN"""
        
        # Project input features
        h = self.input_projection(x)
        
        # Multi-layer recovery-specific message passing
        for layer in self.recovery_layers:
            h_new = layer(h, edge_index, edge_attr, batch)
            h = h + h_new  # Residual connection
            h = F.layer_norm(h, h.shape[-1:])  # Normalize
        
        # Temporal attention (if batch contains multiple time steps)
        if batch is not None:
            unique_batches = torch.unique(batch)
            attended_features = []
            
            for b in unique_batches:
                mask = (batch == b)
                graph_nodes = h[mask].unsqueeze(0)  # Add batch dimension
                
                # Self-attention within graph
                attended, attention_weights = self.temporal_attention(
                    graph_nodes, graph_nodes, graph_nodes
                )
                attended_features.append(attended.squeeze(0))
            
            h = torch.cat(attended_features, dim=0)
        
        # Graph-level pooling
        if batch is not None:
            graph_features = []
            for b in torch.unique(batch):
                mask = (batch == b)
                graph_nodes = h[mask]
                
                # Mean and max pooling
                mean_pool = torch.mean(graph_nodes, dim=0)
                max_pool = torch.max(graph_nodes, dim=0)[0]
                
                graph_repr = torch.cat([mean_pool, max_pool])
                graph_features.append(graph_repr)
            
            graph_features = torch.stack(graph_features)
        else:
            # Single graph case
            mean_pool = torch.mean(h, dim=0)
            max_pool = torch.max(h, dim=0)[0]
            graph_features = torch.cat([mean_pool, max_pool]).unsqueeze(0)
        
        # Apply graph pooling transformation
        graph_features = self.graph_pooling(graph_features)
        
        # Individual recovery signature learning
        if athlete_idx is not None:
            # Add athlete-specific features
            athlete_features = torch.zeros(graph_features.size(0), 4, device=graph_features.device)
            if isinstance(athlete_idx, torch.Tensor):
                athlete_features[:, 0] = athlete_idx.float() / 16.0
            else:
                athlete_features[:, 0] = athlete_idx / 16.0
            
            # Combine with graph features
            combined_features = torch.cat([graph_features, athlete_features], dim=-1)
            recovery_signature = self.signature_learning(combined_features)
            
            # Final representation
            final_features = torch.cat([graph_features, recovery_signature], dim=-1)
        else:
            final_features = graph_features
        
        return final_features

## 7. Contrastive Learning Components

In [8]:
class ContrastiveEncoder(nn.Module):
    """Encoder for contrastive self-supervised pre-training"""
    
    def __init__(self, input_dim: int, hidden_dim: int = 128, projection_dim: int = 64):
        super().__init__()
        self.encoder = TemporalAttentionRecoveryNetwork(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_layers=2,
            num_heads=4
        )
        
        self.projector = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim // 2, projection_dim),
            nn.LayerNorm(projection_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim, projection_dim)
        )
    
    def forward(self, x, edge_index, edge_attr, batch=None, athlete_idx=None):
        # Get representations from encoder
        features = self.encoder(x, edge_index, edge_attr, batch, athlete_idx)
        
        # Project to contrastive space
        projections = self.projector(features)
        
        return features, projections

## 8. Main GraphRec Model

In [9]:
class GraphRecModel(nn.Module):
    """Enhanced GraphRec Model with Contrastive Pre-training Support"""
    
    def __init__(self, input_dim: int, hidden_dim: int = 128, num_athletes: int = 16, 
                 use_pretrained: bool = False, pretrained_encoder=None):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_athletes = num_athletes
        self.use_pretrained = use_pretrained
        
        # Core TARN network (potentially pre-trained)
        if use_pretrained and pretrained_encoder is not None:
            self.tarn = pretrained_encoder.encoder
            # Freeze pre-trained weights initially
            for param in self.tarn.parameters():
                param.requires_grad = False
        else:
            self.tarn = TemporalAttentionRecoveryNetwork(
                input_dim=input_dim,
                hidden_dim=hidden_dim,
                num_layers=3,
                num_heads=8
            )
        
        # Multi-task prediction heads
        final_dim = hidden_dim + hidden_dim // 2
        
        # Primary task: Next-day readiness prediction
        self.readiness_predictor = nn.Sequential(
            nn.Linear(final_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )
        
        # Secondary task: Recovery quality score
        self.recovery_quality_predictor = nn.Sequential(
            nn.Linear(final_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )
        
        # Tertiary task: Training readiness classification
        self.training_readiness_classifier = nn.Sequential(
            nn.Linear(final_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )
        
        # Quaternary task: Overreaching risk detection
        self.overreach_detector = nn.Sequential(
            nn.Linear(final_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )
        
        # Task importance weighting
        self.task_attention = nn.Sequential(
            nn.Linear(final_dim, 4),
            nn.Softmax(dim=-1)
        )
    
    def unfreeze_encoder(self):
        """Unfreeze pre-trained encoder for fine-tuning"""
        if self.use_pretrained:
            for param in self.tarn.parameters():
                param.requires_grad = True
        
    def forward(self, x, edge_index, edge_attr, batch=None, athlete_idx=None):
        """Forward pass through GraphRec"""
        
        # Core TARN processing
        recovery_features = self.tarn(x, edge_index, edge_attr, batch, athlete_idx)
        
        # Multi-task predictions
        readiness_pred = self.readiness_predictor(recovery_features)
        quality_pred = self.recovery_quality_predictor(recovery_features)
        training_pred = self.training_readiness_classifier(recovery_features)
        overreach_pred = self.overreach_detector(recovery_features)
        
        # Task importance weighting
        task_weights = self.task_attention(recovery_features)
        
        return {
            'readiness': readiness_pred,
            'quality': quality_pred,
            'training': training_pred,
            'overreach': overreach_pred,
            'task_weights': task_weights,
            'features': recovery_features
        }

## 9. Physics-Informed Loss Functions

In [10]:
class PhysicsInformedLoss(nn.Module):
    """Physics-informed loss with physiological constraints"""
    
    def __init__(self):
        super().__init__()
        
        # Physiological bounds (normalized values)
        self.hrv_bounds = (0.0, 1.0)
        self.sleep_efficiency_bounds = (0.3, 1.0)
        self.readiness_bounds = (0.0, 1.0)
        self.fatigue_bounds = (0.0, 1.0)
        
    def compute_physics_violations(self, predictions: Dict, node_features: torch.Tensor = None):
        """Compute violations of physiological constraints (Equations 5-7)"""
        violations = []
        
        # Readiness-quality correlation constraint
        readiness = predictions['readiness'].squeeze()
        quality = predictions['quality'].squeeze()
        
        readiness_quality_violation = torch.relu(-torch.corrcoef(torch.stack([readiness, quality]))[0, 1])
        violations.append(readiness_quality_violation)
        
        # Overreach risk constraint
        overreach = predictions['overreach'].squeeze()
        fatigue_proxy = 1.0 - readiness
        expected_correlation = torch.corrcoef(torch.stack([fatigue_proxy, overreach]))[0, 1]
        overreach_violation = torch.relu(-expected_correlation)
        violations.append(overreach_violation)
        
        # Bounds violations
        readiness_bounds_violation = (
            torch.relu(self.readiness_bounds[0] - readiness).mean() +
            torch.relu(readiness - self.readiness_bounds[1]).mean()
        )
        violations.append(readiness_bounds_violation)
        
        return torch.stack(violations).mean()
    
    def forward(self, predictions: Dict, node_features: torch.Tensor = None, physics_weight: float = 0.1):
        """Compute physics-informed loss"""
        physics_loss = self.compute_physics_violations(predictions, node_features)
        return physics_weight * physics_loss


class ContrastiveLoss(nn.Module):
    """Contrastive learning loss for self-supervised pre-training"""
    
    def __init__(self, temperature: float = 0.5):
        super().__init__()
        self.temperature = temperature
        
    def forward(self, z1: torch.Tensor, z2: torch.Tensor):
        """InfoNCE loss for contrastive learning (Equation 8)"""
        # Normalize embeddings
        z1 = F.normalize(z1, dim=-1)
        z2 = F.normalize(z2, dim=-1)
        
        # Compute similarity matrix
        similarity_matrix = torch.matmul(z1, z2.T) / self.temperature
        
        # Labels are diagonal (positive pairs)
        batch_size = z1.size(0)
        labels = torch.arange(batch_size, device=z1.device)
        
        loss = F.cross_entropy(similarity_matrix, labels)
        return loss


class CorrelationLoss(nn.Module):
    """Loss function that directly optimizes correlation"""
    
    def __init__(self):
        super().__init__()
        
    def forward(self, predictions: torch.Tensor, targets: torch.Tensor):
        """Compute negative correlation as loss (maximize correlation)"""
        if predictions.dim() > 1:
            predictions = predictions.flatten()
        if targets.dim() > 1:
            targets = targets.flatten()
            
        # Compute Pearson correlation
        vx = predictions - torch.mean(predictions)
        vy = targets - torch.mean(targets)
        
        correlation = torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2)) * torch.sqrt(torch.sum(vy ** 2)) + 1e-8)
        
        return -correlation

## 10. Enhanced Multi-Objective Loss Function

In [11]:
class GraphRecLoss(nn.Module):
    """Enhanced multi-objective loss function for recovery optimization"""
    
    def __init__(self, task_weights: List[float] = [0.4, 0.3, 0.2, 0.1], use_correlation_loss: bool = True):
        super().__init__()
        self.task_weights = task_weights
        self.use_correlation_loss = use_correlation_loss
        
        # Standard losses
        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCELoss()
        self.l1_loss = nn.L1Loss()
        
        # Enhanced losses
        self.correlation_loss = CorrelationLoss()
        self.physics_loss = PhysicsInformedLoss()
        self.contrastive_loss = ContrastiveLoss()
        
        # Focal loss parameters
        self.focal_alpha = 0.25
        self.focal_gamma = 2.0
        
    def focal_loss(self, predictions: torch.Tensor, targets: torch.Tensor):
        """Focal loss for handling imbalanced data"""
        ce_loss = F.binary_cross_entropy_with_logits(predictions, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.focal_alpha * (1-pt)**self.focal_gamma * ce_loss
        return focal_loss.mean()
    
    def forward(self, predictions: Dict, targets: torch.Tensor, task_attention_weights: torch.Tensor = None, 
                node_features: torch.Tensor = None, epoch: int = 0, max_epochs: int = 150):
        """Calculate enhanced multi-objective loss"""
        
        # Handle different target tensor shapes
        if targets.dim() == 0:
            targets = torch.tensor([[0.5, 0.5, 0.5, 0.5]], device=targets.device)
        elif targets.dim() == 1:
            if targets.size(0) == 4:
                targets = targets.unsqueeze(0)
            else:
                targets = targets.view(-1, 4)
        
        batch_size = predictions['readiness'].size(0)
        
        if targets.size(0) != batch_size:
            if targets.size(0) == 1:
                targets = targets.expand(batch_size, -1)
            else:
                raise ValueError(f"Target batch size {targets.size(0)} doesn't match prediction batch size {batch_size}")
        
        # Extract predictions and targets
        readiness_pred = predictions['readiness'].squeeze(-1)
        quality_pred = predictions['quality'].squeeze(-1)
        training_pred = predictions['training'].squeeze(-1)
        overreach_pred = predictions['overreach'].squeeze(-1)
        
        # Extract targets with proper dimensions
        if targets.size(1) >= 4:
            readiness_target = targets[:, 0]
            quality_target = targets[:, 1]
            training_target = targets[:, 2]
            overreach_target = targets[:, 3]
        else:
            # Fallback for malformed targets
            readiness_target = torch.zeros_like(readiness_pred)
            quality_target = torch.zeros_like(quality_pred) 
            training_target = torch.zeros_like(training_pred)
            overreach_target = torch.zeros_like(overreach_pred)
        
        # Standard losses with correlation enhancement
        if self.use_correlation_loss and batch_size > 1:
            readiness_loss = 0.5 * self.mse_loss(readiness_pred, readiness_target) + 0.5 * self.correlation_loss(readiness_pred, readiness_target)
            quality_loss = 0.5 * self.mse_loss(quality_pred, quality_target) + 0.5 * self.correlation_loss(quality_pred, quality_target)
            training_loss = 0.5 * self.mse_loss(training_pred, training_target) + 0.5 * self.correlation_loss(training_pred, training_target)
        else:
            readiness_loss = self.mse_loss(readiness_pred, readiness_target)
            quality_loss = self.mse_loss(quality_pred, quality_target)
            training_loss = self.mse_loss(training_pred, training_target)
        
        # Focal loss for overreach detection
        overreach_loss = self.focal_loss(overreach_pred, overreach_target)
        
        # Base weighted loss
        base_loss = (
            self.task_weights[0] * readiness_loss +
            self.task_weights[1] * quality_loss +
            self.task_weights[2] * training_loss +
            self.task_weights[3] * overreach_loss
        )
        
        # Physics-informed loss with dynamic weighting
        physics_loss_val = torch.tensor(0.0, device=readiness_pred.device)
        if batch_size > 1:
            physics_weight = max(0.1, 0.5 * (1 - epoch / max_epochs))
            try:
                physics_loss_val = self.physics_loss(predictions, node_features, physics_weight)
            except Exception:
                physics_loss_val = torch.tensor(0.0, device=readiness_pred.device)
        
        # Adaptive weighting based on attention
        if task_attention_weights is not None:
            attention_weights = torch.mean(task_attention_weights, dim=0)
            adaptive_loss = (
                attention_weights[0] * readiness_loss +
                attention_weights[1] * quality_loss +
                attention_weights[2] * training_loss +
                attention_weights[3] * overreach_loss
            )
            total_loss = 0.6 * base_loss + 0.3 * adaptive_loss + 0.1 * physics_loss_val
        else:
            total_loss = 0.9 * base_loss + 0.1 * physics_loss_val
        
        return {
            'total_loss': total_loss,
            'readiness_loss': readiness_loss,
            'quality_loss': quality_loss,
            'training_loss': training_loss,
            'overreach_loss': overreach_loss,
            'physics_loss': physics_loss_val
        }

## 11. Contrastive Pre-training Function

In [12]:
def pretrain_contrastive_encoder(graphs: List[Data], input_dim: int, hidden_dim: int = 128, 
                                epochs: int = 50, batch_size: int = 64):
    """Pre-train encoder using contrastive learning"""
    
    # Create contrastive pairs
    processor = DataProcessor()
    contrastive_pairs = processor.create_contrastive_pairs(graphs)
    
    if not contrastive_pairs:
        return None
    
    # Initialize encoder
    encoder = ContrastiveEncoder(input_dim, hidden_dim).to(device)
    contrastive_loss = ContrastiveLoss(temperature=0.5)
    optimizer = torch.optim.AdamW(encoder.parameters(), lr=0.001, weight_decay=1e-4)
    
    # Training loop
    encoder.train()
    for epoch in range(epochs):
        total_loss = 0
        num_batches = 0
        
        random.shuffle(contrastive_pairs)
        
        for i in range(0, len(contrastive_pairs), batch_size):
            batch_pairs = contrastive_pairs[i:i + batch_size]
            
            if len(batch_pairs) < 2:
                continue
                
            optimizer.zero_grad()
            
            z1_list, z2_list = [], []
            
            for graph1, graph2 in batch_pairs:
                try:
                    # Move graphs to device
                    graph1.x = graph1.x.to(device)
                    graph1.edge_index = graph1.edge_index.to(device)
                    graph1.edge_attr = graph1.edge_attr.to(device)
                    graph2.x = graph2.x.to(device)
                    graph2.edge_index = graph2.edge_index.to(device)
                    graph2.edge_attr = graph2.edge_attr.to(device)
                    
                    _, z1 = encoder(graph1.x, graph1.edge_index, graph1.edge_attr, 
                                  None, getattr(graph1, 'athlete_idx', None))
                    _, z2 = encoder(graph2.x, graph2.edge_index, graph2.edge_attr,
                                  None, getattr(graph2, 'athlete_idx', None))
                    
                    z1_list.append(z1.squeeze() if z1.dim() > 1 else z1)
                    z2_list.append(z2.squeeze() if z2.dim() > 1 else z2)
                except Exception:
                    continue
            
            if len(z1_list) > 1:
                try:
                    z1_batch = torch.stack(z1_list)
                    z2_batch = torch.stack(z2_list)
                    
                    loss = contrastive_loss(z1_batch, z2_batch)
                    loss.backward()
                    optimizer.step()
                    
                    total_loss += loss.item()
                    num_batches += 1
                except Exception:
                    continue
        
        if num_batches > 0:
            avg_loss = total_loss / num_batches
            if (epoch + 1) % 10 == 0:
                print(f"Pre-training Epoch {epoch+1:3d} | Contrastive Loss: {avg_loss:.4f}")
    
    return encoder

## 12. Training Pipeline

In [13]:
def train_graphrec(use_contrastive_pretraining: bool = True, use_data_augmentation: bool = True):
    """Train the GraphRec system"""
    
    # Data processing
    processor = DataProcessor()
    graphs, stats = processor.process_all_participants()
    
    # Apply data augmentation if requested
    if use_data_augmentation:
        original_count = len(graphs)
    
    if not graphs:
        return
    
    print(f"Generated {len(graphs)} recovery graphs")
    
    # Dataset statistics
    total_samples = sum(s['samples'] for s in stats.values())
    total_graphs = sum(s['graphs'] for s in stats.values())
    print(f"Total samples: {total_samples}, Total graphs: {total_graphs}, Participants: {len(stats)}")
    
    # Data splitting
    train_graphs, temp_graphs = train_test_split(graphs, test_size=0.4, random_state=42)
    val_graphs, test_graphs = train_test_split(temp_graphs, test_size=0.5, random_state=42)
    
    print(f"Train/Val/Test: {len(train_graphs)}/{len(val_graphs)}/{len(test_graphs)}")
    
    # Create data loaders with device transfer
    train_loader = DataLoader(train_graphs, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_graphs, batch_size=64, shuffle=False)
    test_loader = DataLoader(test_graphs, batch_size=64, shuffle=False)
    
    def move_batch_to_device(batch, device):
        batch.x = batch.x.to(device)
        batch.edge_index = batch.edge_index.to(device)
        batch.edge_attr = batch.edge_attr.to(device)
        batch.y = batch.y.to(device)
        if hasattr(batch, 'batch'):
            batch.batch = batch.batch.to(device)
        return batch
    
    # Model setup with optional contrastive pre-training
    input_dim = train_graphs[0].x.size(1)
    
    pretrained_encoder = None
    if use_contrastive_pretraining:
        pretrained_encoder = pretrain_contrastive_encoder(
            train_graphs, input_dim, hidden_dim=128, epochs=30
        )
    
    model = GraphRecModel(
        input_dim=input_dim, 
        hidden_dim=128, 
        num_athletes=16,
        use_pretrained=use_contrastive_pretraining,
        pretrained_encoder=pretrained_encoder
    ).to(device)
    
    print(f"Model Architecture: Input dim: {input_dim}, Hidden dim: 128, Parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Training setup with enhanced loss function
    criterion = GraphRecLoss(use_correlation_loss=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
    
    # Unfreeze pre-trained encoder after warmup
    warmup_epochs = 20
    
    # Training loop
    best_val_loss = float('inf')
    best_model_state = None
    patience = 0
    max_patience = 25
    
    for epoch in range(150):
        # Unfreeze encoder after warmup
        if epoch == warmup_epochs and model.use_pretrained:
            model.unfreeze_encoder()
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
        
        # Training phase
        model.train()
        total_loss = 0
        all_predictions = {'readiness': [], 'quality': [], 'training': [], 'overreach': []}
        all_targets = {'readiness': [], 'quality': [], 'training': [], 'overreach': []}
        
        for batch in train_loader:
            batch = move_batch_to_device(batch, device)
            
            optimizer.zero_grad()
            
            # Forward pass
            predictions = model(
                batch.x, 
                batch.edge_index, 
                batch.edge_attr,
                batch.batch,
                getattr(batch, 'athlete_idx', None)
            )
            
            # Loss calculation with enhanced features
            loss_dict = criterion(
                predictions, 
                batch.y, 
                predictions['task_weights'],
                batch.x,
                epoch,
                150
            )
            loss = loss_dict['total_loss']
            
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            total_loss += loss.item()
            
            # Collect predictions for metrics
            all_predictions['readiness'].extend(predictions['readiness'].detach().cpu().numpy().flatten())
            all_predictions['quality'].extend(predictions['quality'].detach().cpu().numpy().flatten())
            all_predictions['training'].extend(predictions['training'].detach().cpu().numpy().flatten())
            all_predictions['overreach'].extend(predictions['overreach'].detach().cpu().numpy().flatten())
            
            # Handle PyG batching of y tensors
            if batch.y.dim() == 2:
                all_targets['readiness'].extend(batch.y[:, 0].cpu().numpy())
                all_targets['quality'].extend(batch.y[:, 1].cpu().numpy())
                all_targets['training'].extend(batch.y[:, 2].cpu().numpy())
                all_targets['overreach'].extend(batch.y[:, 3].cpu().numpy())
            elif batch.y.dim() == 1:
                batch_size = len(batch.ptr) - 1
                y_reshaped = batch.y.view(batch_size, 4)
                all_targets['readiness'].extend(y_reshaped[:, 0].cpu().numpy())
                all_targets['quality'].extend(y_reshaped[:, 1].cpu().numpy())
                all_targets['training'].extend(y_reshaped[:, 2].cpu().numpy())
                all_targets['overreach'].extend(y_reshaped[:, 3].cpu().numpy())
        
        # Validation phase
        model.eval()
        val_loss = 0
        val_predictions = {'readiness': [], 'quality': [], 'training': [], 'overreach': []}
        val_targets = {'readiness': [], 'quality': [], 'training': [], 'overreach': []}
        
        with torch.no_grad():
            for batch in val_loader:
                batch = move_batch_to_device(batch, device)
                
                predictions = model(
                    batch.x,
                    batch.edge_index,
                    batch.edge_attr,
                    batch.batch,
                    getattr(batch, 'athlete_idx', None)
                )
                
                loss_dict = criterion(
                    predictions, 
                    batch.y, 
                    predictions['task_weights'],
                    batch.x,
                    epoch,
                    150
                )
                val_loss += loss_dict['total_loss'].item()
                
                # Collect validation predictions
                val_predictions['readiness'].extend(predictions['readiness'].detach().cpu().numpy().flatten())
                val_predictions['quality'].extend(predictions['quality'].detach().cpu().numpy().flatten())
                val_predictions['training'].extend(predictions['training'].detach().cpu().numpy().flatten())
                val_predictions['overreach'].extend(predictions['overreach'].detach().cpu().numpy().flatten())
                
                # Handle PyG batching of y tensors
                if batch.y.dim() == 2:
                    val_targets['readiness'].extend(batch.y[:, 0].cpu().numpy())
                    val_targets['quality'].extend(batch.y[:, 1].cpu().numpy())
                    val_targets['training'].extend(batch.y[:, 2].cpu().numpy())
                    val_targets['overreach'].extend(batch.y[:, 3].cpu().numpy())
                elif batch.y.dim() == 1:
                    batch_size = len(batch.ptr) - 1
                    y_reshaped = batch.y.view(batch_size, 4)
                    val_targets['readiness'].extend(y_reshaped[:, 0].cpu().numpy())
                    val_targets['quality'].extend(y_reshaped[:, 1].cpu().numpy())
                    val_targets['training'].extend(y_reshaped[:, 2].cpu().numpy())
                    val_targets['overreach'].extend(y_reshaped[:, 3].cpu().numpy())
        
        # Calculate metrics
        avg_train_loss = total_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        
        # Readiness prediction accuracy (main metric)
        readiness_mae = mean_absolute_error(val_targets['readiness'], val_predictions['readiness'])
        readiness_rmse = np.sqrt(mean_squared_error(val_targets['readiness'], val_predictions['readiness']))
        
        scheduler.step()
        
        # Early stopping
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_state = model.state_dict().copy()
            patience = 0
        else:
            patience += 1
        
        # Progress reporting
        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"Epoch {epoch+1:3d} | Loss: {avg_train_loss:.4f}/{avg_val_loss:.4f} | "
                  f"Readiness MAE: {readiness_mae:.3f} | RMSE: {readiness_rmse:.3f}")
        
        if patience >= max_patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break
    
    # Load best model
    if best_model_state:
        model.load_state_dict(best_model_state)
    
    # Final evaluation
    model.eval()
    test_predictions = {'readiness': [], 'quality': [], 'training': [], 'overreach': []}
    test_targets = {'readiness': [], 'quality': [], 'training': [], 'overreach': []}
    
    with torch.no_grad():
        for batch in test_loader:
            batch = move_batch_to_device(batch, device)
            
            predictions = model(
                batch.x,
                batch.edge_index,
                batch.edge_attr,
                batch.batch,
                getattr(batch, 'athlete_idx', None)
            )
            
            test_predictions['readiness'].extend(predictions['readiness'].detach().cpu().numpy().flatten())
            test_predictions['quality'].extend(predictions['quality'].detach().cpu().numpy().flatten())
            test_predictions['training'].extend(predictions['training'].detach().cpu().numpy().flatten())
            test_predictions['overreach'].extend(predictions['overreach'].detach().cpu().numpy().flatten())
            
            if batch.y.dim() == 2:
                test_targets['readiness'].extend(batch.y[:, 0].cpu().numpy())
                test_targets['quality'].extend(batch.y[:, 1].cpu().numpy())
                test_targets['training'].extend(batch.y[:, 2].cpu().numpy())
                test_targets['overreach'].extend(batch.y[:, 3].cpu().numpy())
            elif batch.y.dim() == 1:
                batch_size = len(batch.ptr) - 1
                y_reshaped = batch.y.view(batch_size, 4)
                test_targets['readiness'].extend(y_reshaped[:, 0].cpu().numpy())
                test_targets['quality'].extend(y_reshaped[:, 1].cpu().numpy())
                test_targets['training'].extend(y_reshaped[:, 2].cpu().numpy())
                test_targets['overreach'].extend(y_reshaped[:, 3].cpu().numpy())
    
    # Calculate final metrics
    results = {}
    for task in ['readiness', 'quality', 'training', 'overreach']:
        mae = mean_absolute_error(test_targets[task], test_predictions[task])
        rmse = np.sqrt(mean_squared_error(test_targets[task], test_predictions[task]))
        correlation = np.corrcoef(test_targets[task], test_predictions[task])[0, 1]
        
        results[task] = {
            'mae': mae,
            'rmse': rmse,
            'correlation': correlation
        }
    
    # Print results
    print("\nGRAPHREC RESULTS:")
    print("=" * 50)
    for task, metrics in results.items():
        print(f"{task.upper()} Prediction:")
        print(f"   MAE: {metrics['mae']:.4f}")
        print(f"   RMSE: {metrics['rmse']:.4f}")
        print(f"   Correlation: {metrics['correlation']:.4f}")
        print()
    
    # Overall performance
    avg_correlation = np.mean([r['correlation'] for r in results.values()])
    avg_mae = np.mean([r['mae'] for r in results.values()])
    
    print(f"OVERALL PERFORMANCE:")
    print(f"   Average Correlation: {avg_correlation:.4f}")
    print(f"   Average MAE: {avg_mae:.4f}")
    
    # Save model
    torch.save({
        'model_state_dict': model.state_dict(),
        'results': results,
        'processor': processor,
        'model_config': {
            'input_dim': input_dim,
            'hidden_dim': 128,
            'num_athletes': 16
        }
    }, 'graphrec_model.pth')
    
    return model, results

## 13. Model Training and Evaluation

In [15]:
model, results = train_graphrec(
    use_contrastive_pretraining=False,
    use_data_augmentation=True
)

Generated 1651 recovery graphs
Total samples: 1747, Total graphs: 1651, Participants: 16
Train/Val/Test: 990/330/331
Model Architecture: Input dim: 9, Hidden dim: 128, Parameters: 922,577
Epoch   1 | Loss: -0.2183/-0.3610 | Readiness MAE: 0.149 | RMSE: 0.182
Epoch  10 | Loss: -0.3829/-0.3939 | Readiness MAE: 0.043 | RMSE: 0.062
Epoch  20 | Loss: -0.3911/-0.3993 | Readiness MAE: 0.068 | RMSE: 0.086
Epoch  30 | Loss: -0.3884/-0.3950 | Readiness MAE: 0.045 | RMSE: 0.062
Epoch  40 | Loss: -0.3937/-0.3986 | Readiness MAE: 0.056 | RMSE: 0.072
Epoch  50 | Loss: -0.3946/-0.4031 | Readiness MAE: 0.025 | RMSE: 0.044
Epoch  60 | Loss: -0.3982/-0.4057 | Readiness MAE: 0.046 | RMSE: 0.059
Epoch  70 | Loss: -0.4004/-0.4067 | Readiness MAE: 0.036 | RMSE: 0.048
Epoch  80 | Loss: -0.4017/-0.4070 | Readiness MAE: 0.031 | RMSE: 0.045
Epoch  90 | Loss: -0.4022/-0.4075 | Readiness MAE: 0.022 | RMSE: 0.039
Epoch 100 | Loss: -0.4027/-0.4075 | Readiness MAE: 0.025 | RMSE: 0.041
Epoch 110 | Loss: -0.4036/-0.40