In [88]:
def clean_time_series_data(df, cgm_cols=["Dexcom GL", "Libre GL"], activity_cols=["HR", "METs"]):
    """Clean and validate time-series data"""
    # Copy to avoid modifying original
    df = df.copy()
    
    # 1. Handle missing values
    for col in cgm_cols + activity_cols:
        if col in df.columns:
            # Linear interpolation with 5-minute window
            df[col] = df[col].interpolate(method='linear', limit=5)
            # Forward/backward fill remaining
            df[col] = df[col].ffill().bfill()
    
    # 2. Remove physiologically impossible values
    if "Dexcom GL" in df.columns:
        df = df[(df["Dexcom GL"] >= 40) & (df["Dexcom GL"] <= 400)]
    if "HR" in df.columns:
        df = df[(df["HR"] >= 40) & (df["HR"] <= 200)]
    if "METs" in df.columns:
        df = df[(df["METs"] >= 0.5) & (df["METs"] <= 20)]
    
    # 3. Smooth noisy data (Savitzky-Golay filter)
    for col in cgm_cols + activity_cols:
        if col in df.columns:
            try:
                from scipy.signal import savgol_filter
                df[col] = savgol_filter(df[col], window_length=15, polyorder=2)
            except ImportError:
                # Fallback to rolling mean if scipy not available
                df[col] = df[col].rolling(window=15, min_periods=1, center=True).mean()
    
    return df

In [160]:
import pandas as pd
import os, pdb, copy
import numpy as np
from torch.utils.data import Dataset
import torch
from sklearn.model_selection import train_test_split
import cv2
from tqdm import tqdm

def get_image(
    img_filename: str,
    subject_id: int,
    target_size: tuple,
    cgmacros_path: str = "CGMacros 2/",
) -> np.ndarray:
    subject_path = f"CGMacros-{subject_id:03d}/"
    img_path = f"{cgmacros_path}{subject_path}{img_filename}"
    if not os.path.exists(img_path):
        print(f"File {img_path} not found")
        raise FileNotFoundError
    # Loading names out
    img_data = cv2.resize(
        cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB),
        target_size,
        interpolation=cv2.INTER_LANCZOS4,
    )
    return img_data


def load_CGMacros(
    subject_id: int,
    csv_dir: str = "CGMacros 2",
) -> pd.DataFrame:
    if type(subject_id) != int:
        print("subject_id should be an integer")
        raise ValueError
    subejct_path = f"CGMacros-{subject_id:03d}/CGMacros-{subject_id:03d}.csv"
    subject_file = os.path.join(csv_dir, subejct_path)
    if not os.path.exists(subject_file):
        tqdm.write(f"File {subject_file} not found")
        raise FileNotFoundError
    dataset_df = pd.read_csv(subject_file, index_col=None)
    dataset_df["Timestamp"] = pd.to_datetime(dataset_df["Timestamp"])
    dataset_df = clean_time_series_data(dataset_df)  # Add cleaning step
    return dataset_df.set_index("Timestamp")

def load_daily_traces(
    dataset_df: pd.DataFrame,
    subject_id: int,
    cgm_cols=["Dexcom GL", "Libre GL"],
    activity_cols=["HR", "Calories (Activity)"],
    img_size=(112, 112)
):
    """
    Enhanced version with data cleaning, normalization, and NaN handling using median for time-series data (CGM and activity).
    
    Returns:
        tuple: (
            days_list: List of date strings,
            cgm_daily_data: Normalized array of shape (n_days, n_cgm_cols, 1440),
            activity_daily_data: Normalized array of shape (n_days, n_activity_cols, 1440),
            image_data_by_day: dict of image data,
            nutrition_data_by_day: dict of nutrition data,
            cgm_stats: dict of normalization parameters {'mean', 'std'},
            activity_stats: dict of normalization parameters {'mean', 'std'}
        )
    """
    # 1. Clean and preprocess the raw data
    def clean_series(series):
        """Helper function to clean a time series"""
        # Interpolation
        series = series.interpolate(method='linear', limit=5).ffill().bfill()
        
        # Smoothing (Savitzky-Golay if available, else rolling mean)
        try:
            from scipy.signal import savgol_filter
            series = savgol_filter(series, window_length=15, polyorder=2)
        except ImportError:
            series = series.rolling(window=15, min_periods=1, center=True).mean()
        
        return series

    # 2. Apply cleaning to each relevant column (CGM and activity)
    cleaned_df = dataset_df.copy()
    for col in cgm_cols + activity_cols:
        if col in cleaned_df.columns:
            cleaned_df[col] = clean_series(cleaned_df[col])
    
    # 3. Resample to 1-minute frequency and handle NaNs after resampling with median
    resampled_df = cleaned_df.resample('1min').ffill(limit=5)
    
    # Fill NaNs with the median of each column (for both CGM and activity data)
    for col in cgm_cols + activity_cols:
        resampled_df[col] = resampled_df[col].fillna(resampled_df[col].median())
    
    # 4. Calculate normalization statistics
    cgm_stats = {
        'mean': resampled_df[cgm_cols].mean().values,
        'std': resampled_df[cgm_cols].std().values
    }
    activity_stats = {
        'mean': resampled_df[activity_cols].mean().values,
        'std': resampled_df[activity_cols].std().values
    }

    # 5. Initialize arrays for CGM and activity data
    days = pd.Series(resampled_df.index.date).unique()
    days_list = [str(day) for day in days]
    cgm_daily_data = np.full((len(days), len(cgm_cols), 1440), np.nan)
    activity_daily_data = np.full((len(days), len(activity_cols), 1440), np.nan)
    image_data_by_day = {}
    nutrition_data_by_day = {}

    # 6. Process each day
    for i, day in enumerate(days):
        day_start = pd.Timestamp(day)
        day_end = day_start + pd.Timedelta(days=1) - pd.Timedelta(minutes=1)
        day_data = resampled_df.loc[day_start:day_end]
        
        if not day_data.empty:
            minutes_of_day = (day_data.index.hour * 60 + day_data.index.minute).values
            
            # Process and normalize CGM data
            for j, col in enumerate(cgm_cols):
                if col in day_data.columns:
                    vals = day_data[col].values
                    # Replace NaN values with the median of the column for that day
                    vals = np.nan_to_num(vals, nan=day_data[col].median())
                    # Normalize the values
                    vals = (vals - cgm_stats['mean'][j]) / (cgm_stats['std'][j] + 1e-8)
                    cgm_daily_data[i, j, minutes_of_day] = vals
            
            # Process and normalize activity data
            for j, col in enumerate(activity_cols):
                if col in day_data.columns:
                    vals = day_data[col].values
                    # Replace NaN values with the median of the column for that day
                    vals = np.nan_to_num(vals, nan=day_data[col].median())
                    # Normalize the values
                    vals = (vals - activity_stats['mean'][j]) / (activity_stats['std'][j] + 1e-8)
                    activity_daily_data[i, j, minutes_of_day] = vals
        
        # Process nutrition and image data (with NaN handling)
        day_str = str(day)
        original_day_data = dataset_df.loc[day_start:day_end]
        
        # Nutrition data processing...
        nutrition_rows = original_day_data.dropna(subset=['Calories', 'Carbs', 'Protein', 'Fat', 'Fiber'], how='all')
        day_nutrition = []
        for ts, row in nutrition_rows.iterrows():
            nutrition = {
                'timestamp': ts.strftime('%Y-%m-%d %H:%M:%S'),
                'MealType': row['Meal Type'],
                'calories': row['Calories'] if pd.notna(row['Calories']) else 0,
                'carbs': row['Carbs'] if pd.notna(row['Carbs']) else 0,
                'protein': row['Protein'] if pd.notna(row['Protein']) else 0,
                'fat': row['Fat'] if pd.notna(row['Fat']) else 0,
                'fiber': row['Fiber'] if pd.notna(row['Fiber']) else 0,
                'has_image': pd.notna(row['Image path'])
            }
            day_nutrition.append(nutrition)
        nutrition_data_by_day[day_str] = day_nutrition
        
        # Image data processing (with NaN handling for missing paths)
        image_rows = original_day_data.dropna(subset=['Image path'])
        day_images = []
        for ts, row in image_rows.iterrows():
            try:
                img_data = get_image(row['Image path'], subject_id, img_size)
                metadata = {
                    'timestamp': ts.strftime('%Y-%m-%d %H:%M:%S'),
                    'meal_type': row['Meal Type'] if 'Meal Type' in row else None,
                    'calories': row['Calories'] if 'Calories' in row else None,
                    'carbs': row['Carbs'] if 'Carbs' in row else None,
                    'protein': row['Protein'] if 'Protein' in row else None,
                    'fat': row['Fat'] if 'Fat' in row else None,
                    'fiber': row['Fiber'] if 'Fiber' in row else None
                }
                day_images.append({'image': img_data, 'metadata': metadata})
            except FileNotFoundError:
                continue
        image_data_by_day[day_str] = day_images if day_images else []

    return (
        days_list,
        cgm_daily_data,          # Now normalized
        activity_daily_data,     # Now normalized
        image_data_by_day,
        nutrition_data_by_day,
        cgm_stats,               # New: normalization parameters
        activity_stats           # New: normalization parameters
    )



def create_daily_dataset(
    subject_id: int,
    csv_dir: str = "CGMacros 2",
    cgm_cols=["Dexcom GL", "Libre GL"],
    activity_cols=["HR"],  # Consider alternative spellings
    img_size=(112, 112),
    verbose=False
):
    try:
        # 1. Load data with column validation
        dataset_df = load_CGMacros(subject_id, csv_dir)
        
        if verbose:
            print("Available columns:", dataset_df.columns.tolist())
        
        # 2. Handle missing columns gracefully
        available_activity_cols = [col for col in activity_cols 
                                 if col in dataset_df.columns]
        if len(available_activity_cols) < len(activity_cols):
            print(f"Warning: Missing activity columns. Using {available_activity_cols} for subject {subject_id}")
        
        # 3. Process data with validated columns
        result = load_daily_traces(
            dataset_df, subject_id, 
            cgm_cols=cgm_cols,
            activity_cols=available_activity_cols,  # Use only available columns
            img_size=img_size
        )
        
        return (subject_id,) + result[:5]  # Return first 5 elements + subject_id
    
    except FileNotFoundError:
        print(f"Data for subject {subject_id} not found.")
        return None
    except Exception as e:
        print(f"Error processing subject {subject_id}: {str(e)}")
        return None
    

def process_multiple_subjects(
    subject_ids=None,
    csv_dir="CGMacros 2",
    save_dir="processed_data/",
    cgm_cols=["Dexcom GL","Libre GL"],
    activity_cols=["HR"],
    img_size=(112, 112)
):
    """
    Process multiple subjects and save their daily trace data.
    
    Args:
        subject_ids (list): List of subject IDs to process. If None, try subjects 1-50.
        csv_dir (str): Directory containing the CGMacros data
        save_dir (str): Directory to save processed data
        cgm_cols (list): List of CGM columns to extract
        activity_cols (list): List of activity columns to extract
        img_size (tuple): Size to resize images to
    
    Returns:
        dict: Summary of processed data
    """
    # Create save directory if it doesn't exist
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    
    if subject_ids is None:
        subject_ids = range(1, 51)  # Try subjects 1-50
    
    summary = {
        'processed_subjects': [],
        'total_days': 0,
        'total_images': 0,
        'total_meals':0
    }
    
    for subject_id in tqdm(subject_ids, desc="Processing subjects"):
        result = create_daily_dataset(subject_id, csv_dir)
        if result is None:
            continue
            
        subject_id, days, cgm, activity, images, nutrition = result  # Unpack new return
        
        # Save data
        subject_data = {
            'subject_id': subject_id,
            'days': days,
            'cgm_data': cgm,
            'activity_data': activity,
            'image_data': images,
            'nutrition_data': nutrition  # NEW: Store nutrition separately
        }
        torch.save(subject_data, os.path.join(save_dir, f"subject_{subject_id:03d}_daily_data.pt"))
        
        # Update summary counts (can add nutrition-specific metrics)
        summary['processed_subjects'].append(subject_id)
        summary['total_days'] += len(days)
        summary['total_images'] += sum(len(imgs) for imgs in images.values())
        summary['total_meals'] += sum(len(meals) for meals in nutrition.values())
    return summary

class DailyTracesDataset(Dataset):
    def __init__(self, data_dir, subject_ids=None, transform=None, skip_days=[1]):
        """
        Args:
            data_dir (str): Directory containing processed .pt files
            subject_ids (list): Optional list of subject IDs to include
            transform (callable): Optional transform for images/time-series
            skip_days (list): Day numbers to exclude (e.g., [1] skips day 1 of each month)
        """
        self.data_dir = data_dir
        self.transform = transform
        self.skip_days = skip_days or [1]  # Default to no skipped days
        
        # Find relevant subject files
        if subject_ids is None:
            self.data_files = [
                f for f in os.listdir(data_dir) 
                if f.startswith("subject_") and f.endswith("_daily_data.pt")
            ]
        else:
            self.data_files = [
                f"subject_{sid:03d}_daily_data.pt" 
                for sid in subject_ids 
                if os.path.exists(os.path.join(data_dir, f"subject_{sid:03d}_daily_data.pt"))
            ]
        
        # Build indices accounting for skip_days
        self.indices = []
        self.subject_day_pairs = []
        
        for file_idx, fname in enumerate(self.data_files):
            data = torch.load(os.path.join(data_dir, fname))
            subject_id = data['subject_id']
            
            for day_idx, day_str in enumerate(data['days']):
                day_num = int(day_str.split('-')[2])
                if day_num not in self.skip_days:
                    self.indices.append((file_idx, day_idx))
                    self.subject_day_pairs.append((subject_id, day_num))

    def __len__(self):
        """Returns total number of valid (subject, day) pairs after filtering"""
        return len(self.indices)

    def __getitem__(self, idx):
        file_idx, day_idx = self.indices[idx]
        data = torch.load(os.path.join(self.data_dir, self.data_files[file_idx]))
        day = data['days'][day_idx]
        
        # Apply transforms if specified
        def _apply_transform(x):
            return self.transform(x) if self.transform else x
        
        return {
            'subject_id': data['subject_id'],
            'day': day,
            'cgm_data': _apply_transform(data['cgm_data'][day_idx]),
            'activity_data': _apply_transform(data['activity_data'][day_idx]),
            'images': [_apply_transform(img['image']) for img in data['image_data'].get(day, [])],
            'nutrition': data['nutrition_data'].get(day, []),
            'subject_day_pair': self.subject_day_pairs[idx]
        }


def split_dataset_by_subject_day(dataset, test_size=0.2, random_state=2025):
    """
    Split the dataset based on subject-day pairs to ensure all data from
    the same subject and day stays together in either training or testing set.
    
    Args:
        dataset (DailyTracesDataset): The dataset to split
        test_size (float): Proportion of data to use for testing
        random_state (int): Random seed for reproducibility
    
    Returns:
        tuple: (train_indices, test_indices)
    """
    # Get unique subject-day pairs
    subject_day_df = pd.DataFrame(dataset.subject_day_pairs, columns=['subject_id', 'day_id'])
    unique_pairs = subject_day_df.drop_duplicates()
    
    # Split the unique subject-day pairs
    train_pairs, test_pairs = train_test_split(
        unique_pairs, 
        test_size=test_size,
        random_state=random_state
    )
    
    # Convert to sets for faster lookup
    train_pairs_set = set(zip(train_pairs['subject_id'], train_pairs['day_id']))
    test_pairs_set = set(zip(test_pairs['subject_id'], test_pairs['day_id']))
    
    # Create masks for train and test indices
    train_indices = []
    test_indices = []
    
    for i, (subject_id, day_id) in enumerate(dataset.subject_day_pairs):
        if (subject_id, day_id) in train_pairs_set:
            train_indices.append(i)
        elif (subject_id, day_id) in test_pairs_set:
            test_indices.append(i)
    
    return train_indices, test_indices


class SubjectDaySubset(Dataset):
    """
    Subset of DailyTracesDataset based on indices.
    """
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices
    
    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]


def get_train_test_datasets(data_dir, subject_ids=None, test_size=0.2, random_state=2025, transform=None):
    """
    Get train and test datasets split by subject-day pairs.
    
    Args:
        data_dir (str): Directory containing processed data
        subject_ids (list): List of subject IDs to include. If None, include all available.
        test_size (float): Proportion of data to use for testing
        random_state (int): Random seed for reproducibility
        transform (callable): Optional transform to apply to the data
    
    Returns:
        tuple: (train_dataset, test_dataset)
    """
    # Create the full dataset
    full_dataset = DailyTracesDataset(data_dir, subject_ids, transform)
    
    # Split by subject-day pairs
    train_indices, test_indices = split_dataset_by_subject_day(full_dataset, test_size, random_state)
    
    # Create train and test subsets
    train_dataset = SubjectDaySubset(full_dataset, train_indices)
    test_dataset = SubjectDaySubset(full_dataset, test_indices)
    
    print(f"Full dataset size: {len(full_dataset)}")
    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Test dataset size: {len(test_dataset)}")
    
    return train_dataset, test_dataset


def extract_image_features(images, feature_extractor):
    """
    Extract features from a list of images using a pre-trained model.
    
    Args:
        images (list): List of image arrays
        feature_extractor: Model for feature extraction
    
    Returns:
        np.ndarray: Array of image features
    """
    if not images:
        return np.array([])
    
    # Stack images into a batch
    batch = np.stack([img_data['image'] for img_data in images])
    batch_tensor = torch.tensor(batch).float().permute(0, 3, 1, 2)  # NHWC -> NCHW
    
    # Normalize if needed
    if batch_tensor.max() > 1.0:
        batch_tensor = batch_tensor / 255.0
    
    # Extract features
    with torch.no_grad():
        features = feature_extractor(batch_tensor)
    
    return features.cpu().numpy()

def custom_collate(batch):
    """Handles variable-length nutrition data and images"""
    import numpy as np

    def fix_nans(array):
        """Replace NaNs with median (per channel)"""
        median_vals = np.nanmedian(array, axis=1, keepdims=True)  # Compute median per channel
        return np.where(np.isnan(array), median_vals, array)  # Replace NaNs

    # Fix NaNs before converting to tensors
    for i, x in enumerate(batch):
        x['cgm_data'] = fix_nans(x['cgm_data'])
        x['activity_data'] = fix_nans(x['activity_data'])

    return {
        'subject_ids': torch.tensor([x['subject_id'] for x in batch]),
        'days': [x['day'] for x in batch],
        'cgm_data': torch.stack([torch.tensor(x['cgm_data'], dtype=torch.float32) for x in batch]),
        'activity_data': torch.stack([torch.tensor(x['activity_data'], dtype=torch.float32) for x in batch]),
        'images': [x['images'] for x in batch],  # List of lists (variable length)
        'nutrition': [x['nutrition'] for x in batch],  # List of lists
        'subject_day_pairs': [x['subject_day_pair'] for x in batch]
    }


In [135]:
from torch.utils.data import DataLoader

def main():
    # Step 1: Process the raw CSV data for subjects and save the daily traces
    # You can adjust the subject IDs, directories, etc. as needed
    summary = process_multiple_subjects(
        subject_ids=range(1, 50),  # Process subjects 1-10
        csv_dir="CGMacros 2",  # Path to your CSV files
        save_dir="processed_data/",    # Where to save processed data
        cgm_cols=["Dexcom GL","Libre GL"],
        activity_cols=["HR"],
        img_size=(112, 112)
    )
    
    print(f"Processing summary:")
    print(f"- Processed {len(summary['processed_subjects'])} subjects")
    print(f"- Total days: {summary['total_days']}")
    print(f"- Total images: {summary['total_images']}")
    
    # Step 2: Create train and test datasets from the processed data
    train_dataset, test_dataset = get_train_test_datasets(
        data_dir="processed_data/",
        subject_ids=None,  # Use all available subjects
        test_size=0.2,     # 80% train, 20% test
        random_state=2025, # For reproducibility
        transform=None     # Add any transforms you need
    )
    
    print("Done creating train and test datasets")

    # Step 3: Create DataLoaders for efficient batching
    train_loader = DataLoader(
        train_dataset,
        batch_size=8,
        shuffle=True,
        num_workers=0,
        collate_fn=custom_collate
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=8,
        shuffle=False,
        num_workers=0,
        collate_fn=custom_collate
    )
    
    return train_loader, test_loader

In [136]:
train_loader, test_loader = main()

Processing subjects:   2%|▏         | 1/49 [00:00<00:24,  1.98it/s]

File CGMacros 2/CGMacros-001/photos/00000075-PHOTO-2020-5-9-9-48-0.jpg not found


Processing subjects:   4%|▍         | 2/49 [00:01<00:23,  1.99it/s]

File CGMacros 2/CGMacros-002/photos/00000083-PHOTO-2019-11-24-12-13-0.jpg not found


Processing subjects:  47%|████▋     | 23/49 [00:10<00:12,  2.09it/s]

File CGMacros 2/CGMacros-024/CGMacros-024.csv not found
Data for subject 24 not found.
File CGMacros 2/CGMacros-025/CGMacros-025.csv not found
Data for subject 25 not found.


Processing subjects:  73%|███████▎  | 36/49 [00:16<00:08,  1.55it/s]

File CGMacros 2/CGMacros-037/CGMacros-037.csv not found
Data for subject 37 not found.


Processing subjects:  80%|███████▉  | 39/49 [00:17<00:04,  2.15it/s]

File CGMacros 2/CGMacros-040/CGMacros-040.csv not found
Data for subject 40 not found.


Processing subjects: 100%|██████████| 49/49 [00:23<00:00,  2.10it/s]
  data = torch.load(os.path.join(data_dir, fname))


Processing summary:
- Processed 45 subjects
- Total days: 532
- Total images: 2289
Full dataset size: 511
Train dataset size: 408
Test dataset size: 103
Done creating train and test datasets


In [140]:
summary = process_multiple_subjects(
    subject_ids=range(1, 50),  # Process subjects 1-10
    csv_dir="CGMacros 2",  # Path to your CSV files
    save_dir="processed_data/",    # Where to save processed data
    cgm_cols=["Dexcom GL","Libre GL"],
    activity_cols=["HR"],
    img_size=(112, 112)
)

print(f"Processing summary:")
print(f"- Processed {len(summary['processed_subjects'])} subjects")
print(f"- Total days: {summary['total_days']}")
print(f"- Total images: {summary['total_images']}")

# Step 2: Create train and test datasets from the processed data
train_dataset, test_dataset = get_train_test_datasets(
    data_dir="processed_data/",
    subject_ids=None,  # Use all available subjects
    test_size=0.2,     # 80% train, 20% test
    random_state=2025, # For reproducibility
    transform=None     # Add any transforms you need
    )

Processing subjects:   2%|▏         | 1/49 [00:00<00:24,  1.96it/s]

File CGMacros 2/CGMacros-001/photos/00000075-PHOTO-2020-5-9-9-48-0.jpg not found


Processing subjects:   4%|▍         | 2/49 [00:01<00:24,  1.96it/s]

File CGMacros 2/CGMacros-002/photos/00000083-PHOTO-2019-11-24-12-13-0.jpg not found


Processing subjects:  47%|████▋     | 23/49 [00:10<00:12,  2.12it/s]

File CGMacros 2/CGMacros-024/CGMacros-024.csv not found
Data for subject 24 not found.
File CGMacros 2/CGMacros-025/CGMacros-025.csv not found
Data for subject 25 not found.


Processing subjects:  73%|███████▎  | 36/49 [00:16<00:07,  1.67it/s]

File CGMacros 2/CGMacros-037/CGMacros-037.csv not found
Data for subject 37 not found.


Processing subjects:  80%|███████▉  | 39/49 [00:17<00:04,  2.29it/s]

File CGMacros 2/CGMacros-040/CGMacros-040.csv not found
Data for subject 40 not found.


Processing subjects: 100%|██████████| 49/49 [00:22<00:00,  2.22it/s]
  data = torch.load(os.path.join(data_dir, fname))


Processing summary:
- Processed 45 subjects
- Total days: 532
- Total images: 2289
Full dataset size: 511
Train dataset size: 408
Test dataset size: 103


In [159]:
import numpy as np

for sample in test_dataset:
    # Fix NaNs in CGM Data
    if 'cgm_data' in sample and sample['cgm_data'] is not None:
        cgm_data = sample['cgm_data']
        for i in range(cgm_data.shape[0]):  # Iterate over each row (sensor type)
            nan_mask = np.isnan(cgm_data[i])  # Identify NaNs
            if np.any(nan_mask):  # If NaNs exist, replace with median
                median_val = np.nanmedian(cgm_data[i])  # Compute median ignoring NaNs
                cgm_data[i, nan_mask] = median_val

    # Fix NaNs in Activity Data
    if 'activity_data' in sample and sample['activity_data'] is not None:
        activity_data = sample['activity_data']
        for i in range(activity_data.shape[0]):  # Iterate over each row (sensor type)
            nan_mask = np.isnan(activity_data[i])  # Identify NaNs
            if np.any(nan_mask):  # If NaNs exist, replace with median
                median_val = np.nanmedian(activity_data[i])  # Compute median ignoring NaNs
                activity_data[i, nan_mask] = median_val

# Confirm all NaNs have been replaced
assert not np.any(np.isnan(train_dataset[0]['cgm_data'])), "NaNs still exist in CGM data!"
assert not np.any(np.isnan(train_dataset[0]['activity_data'])), "NaNs still exist in Activity data!"

print("NaN handling completed. No NaNs remain in CGM and Activity data.")


  data = torch.load(os.path.join(self.data_dir, self.data_files[file_idx]))


NaN handling completed. No NaNs remain in CGM and Activity data.


In [162]:
def custom_collate2(batch):
    """Handles variable-length nutrition data and images"""
    import numpy as np

    def fix_nans(array):
        """Replace NaNs with median (per channel)"""
        median_vals = np.nanmedian(array, axis=1, keepdims=True)  # Compute median per channel
        return np.where(np.isnan(array), median_vals, array)  # Replace NaNs

    # Fix NaNs before converting to tensors
    for i, x in enumerate(batch):
        x['cgm_data'] = fix_nans(x['cgm_data'])
        x['activity_data'] = fix_nans(x['activity_data'])

    return {
        'subject_ids': torch.tensor([x['subject_id'] for x in batch]),
        'days': [x['day'] for x in batch],
        'cgm_data': torch.stack([torch.tensor(x['cgm_data'], dtype=torch.float32) for x in batch]),
        'activity_data': torch.stack([torch.tensor(x['activity_data'], dtype=torch.float32) for x in batch]),
        'images': [x['images'] for x in batch],  # List of lists (variable length)
        'nutrition': [x['nutrition'] for x in batch],  # List of lists
        'subject_day_pairs': [x['subject_day_pair'] for x in batch]
    }

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0,
    collate_fn=custom_collate2
)

test_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=0,
    collate_fn=custom_collate2
    )

In [155]:
train_dataset[0]['activity_data'][0]

  data = torch.load(os.path.join(self.data_dir, self.data_files[file_idx]))


array([-0.03312733, -0.03312733, -0.03312733, ...,  0.33720096,
        0.33720096,  0.33720096])

In [157]:
train_dataset[3]['activity_data']
np.isnan(train_dataset[0]['cgm_data'][0]).any()

  data = torch.load(os.path.join(self.data_dir, self.data_files[file_idx]))


np.False_

In [124]:
for batch_dict in train_loader:
    break


print(batch_dict['cgm_data'][0])
print(torch.isnan(batch_dict['cgm_data'][0]).any())


tensor([[-0.6221, -0.6057, -0.5854,  ...,  1.0794,  1.0635,  1.0475],
        [-1.4413, -1.4321, -1.4227,  ...,  0.2647,  0.2417,  0.2183]],
       dtype=torch.float64)
tensor(False)


  data = torch.load(os.path.join(self.data_dir, self.data_files[file_idx]))


In [169]:
import torch
import torch.nn as nn  # Import torch.nn
import torch.nn.functional as F 
class RMSRELoss(nn.Module):
    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon  # Small value to avoid division by zero
    
    def forward(self, pred, target):
        relative_error = (pred - target) / (target + self.epsilon)
        squared_rel_error = relative_error ** 2
        mean_squared_rel_error = torch.mean(squared_rel_error)
        return torch.sqrt(mean_squared_rel_error)

In [164]:
import torch

def check_data(loader):
    for batch in loader:
        # Check CGM data
        if torch.isnan(batch['cgm_data']).any():
            print("NaN in CGM data!")
            nan_cgm_indices = torch.isnan(batch['cgm_data']).nonzero(as_tuple=True)[0]
            for idx in nan_cgm_indices:
                nan_value = batch['cgm_data'][idx]
                print(f"NaN in CGM data at index {idx.item()}: {nan_value}")
                break
            return False
        
        # Check activity data
        if torch.isnan(batch['activity_data']).any():
            print("NaN in activity data!")
            nan_activity_indices = torch.isnan(batch['activity_data']).nonzero(as_tuple=True)[0]
            for idx in nan_activity_indices:
                nan_value = batch['activity_data'][idx]
                print(f"NaN in activity data at index {idx.item()}: {nan_value}")
                break
            return False
        
        # Check nutrition values
        for day_meals in batch['nutrition']:
            for meal in day_meals:
                if not isinstance(meal['calories'], (int, float)) or meal['calories'] < 0:
                    print(f"Invalid calorie value: {meal['calories']}")
                    return False
    
    print("Data validation passed!")
    return True

# Example usage before training
assert check_data(test_loader), "Invalid training data detected"


  data = torch.load(os.path.join(self.data_dir, self.data_files[file_idx]))


Data validation passed!


In [None]:
from transformer import MultiheadAttention as TransformerEncoder
import torch



# Set device to MPS for Apple Silicon (M3 Mac)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Initialize the CGM and Activity Encoders
cgm_model = TransformerEncoder(
    n_features=180,
    embed_dim=96,
    num_heads=2,
    num_classes=64,
    dropout=0.2,
    num_layers=3,
).to(device)

activity_model = TransformerEncoder(
    n_features=180,
    embed_dim=96,
    num_heads=2,
    num_classes=64,
    dropout=0.2,
    num_layers=3,
).to(device)

# Example forward pass
x = torch.randn(32, 1, 180).to(device)
output = cgm_model(x)
print(output.shape)  # Should output: torch.Size([32, 64])
