#Org the data 

In [1]:
import torch
import platform
import sys

# Check if running on macOS
if platform.system() == 'Darwin':
    # Check for MPS (Metal Performance Shaders) availability on Mac
    print(f"Running on macOS {platform.mac_ver()[0]}")
    print(f"MPS is built: {torch.backends.mps.is_built()}")
    print(f"MPS is available: {torch.backends.mps.is_available()}")
else:
    print(f"Running on {platform.system()} {platform.release()}")
    if platform.system() == 'Windows':
        # Check for CUDA availability on Windows
        print(f"CUDA is available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"CUDA version: {torch.version.cuda}")
            print(f"Number of CUDA devices: {torch.cuda.device_count()}")
            print(f"Current CUDA device: {torch.cuda.current_device()}")
            print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
    else:
        print("MPS is only available on macOS devices")

Running on macOS 15.4
MPS is built: True
MPS is available: True


In [6]:
# 
import pandas as pd
import numpy as np
from datetime import datetime
import os
import matplotlib.pyplot as plt
import seaborn as sns

def extract_action_samples(csv_file, plot_stats=False):
    """
    Extract action samples from a CSV file according to specific rules.
    
    Args:
        csv_file (str): Full path to the CSV file
        plot_stats (bool): If True, plot statistics about the extracted samples
    
    Returns:
        pandas.DataFrame: Processed data with action groups and group indices
        list: Group start times
        datetime: Start time of the whole process (timestamp of first record)
    """
    # Load the CSV file
    df = pd.read_csv(csv_file)
    
    # Ensure timestamp is in datetime format
    df['timestamp'] = pd.to_datetime(df['timestamp'])
    
    # Get the start time of the whole process (first record)
    process_start_time = df['timestamp'].iloc[0] if not df.empty else None
    
    # Initialize variables
    action_groups = []
    group_start_times = []
    current_idx = 0
    all_extracted_records = []
    
    # Iterate through dataframe to find action groups
    while current_idx < len(df) - 60:  # Need at least 60 records for a complete group
        # Look for a cross button press
        if df.loc[current_idx + 10, 'button_press'] == 'cross':
            # Check if we have 10 'none' records before the press
            pre_press = df.iloc[current_idx:current_idx + 10]
            if all(pre_press['button_press'] == 'none'):
                # Extract the group: 10 records before press + 50 records starting from press
                action_group = df.iloc[current_idx:current_idx + 60]
                
                # Calculate duration
                start_time = action_group.iloc[0]['timestamp']
                end_time = action_group.iloc[-1]['timestamp']
                duration = (end_time - start_time).total_seconds()
                
                # Store group and start time
                action_groups.append({
                    'group': action_group,
                    'duration': duration
                })
                group_start_times.append(start_time)
                
                # Add group_index to records
                group_index = len(action_groups) - 1
                group_with_index = action_group.copy()
                group_with_index['group_index'] = group_index
                
                # Add to collection
                all_extracted_records.append(group_with_index)
                
                # Move to position after this group
                current_idx += 60
            else:
                # Pre-press condition not met, move forward by 1
                current_idx += 1
        else:
            # No press at expected position, move forward by 1
            current_idx += 1
    
    # Combine all extracted records
    if all_extracted_records:
        extracted_df = pd.concat(all_extracted_records)
        print(f"Extracted {len(action_groups)} action groups from {csv_file}")
        
        # Print the shape of the sample and count of how many samples
        print(f"Shape of extracted data: {extracted_df.shape}")
        print(f"Total number of samples: {len(extracted_df)}")
        print(f"Number of unique groups: {extracted_df['group_index'].nunique()}")
        
        # Print the shape of each individual sample (group)
        print("\nShape of each sample (group):")
        for i, group in enumerate(action_groups):
            group_df = group['group']
            print(f"Group {i}: {group_df.shape} - Duration: {group['duration']:.2f}s")
        
        # If plot_stats is True, visualize detailed statistics about the samples
        if plot_stats:
            # Create a figure with multiple subplots
            fig, axes = plt.subplots(len(action_groups), 3, figsize=(18, 4*len(action_groups)))
            
            # If only one group, make axes indexable
            if len(action_groups) == 1:
                axes = np.array([axes])
                
            for i, group in enumerate(action_groups):
                group_df = group['group']
                
                # Count statistics
                pre_press_count = sum(group_df.iloc[:10]['button_press'] == 'none')
                press_events = group_df.iloc[10:]['button_press'].value_counts()
                press_count = sum(group_df.iloc[10:]['button_press'] != 'none')
                post_press_none_count = sum(group_df.iloc[10:]['button_press'] == 'none')
                
                # Plot 1: Button press distribution
                press_events_df = pd.DataFrame(press_events).reset_index()
                press_events_df.columns = ['Button', 'Count']
                sns.barplot(x='Button', y='Count', data=press_events_df, ax=axes[i, 0])
                axes[i, 0].set_title(f'Group {i}: Button Press Distribution')
                axes[i, 0].set_ylabel('Count')
                axes[i, 0].tick_params(axis='x', rotation=45)
                
                # Plot 2: Gyro data over time
                time_indices = range(len(group_df))
                axes[i, 1].plot(time_indices, group_df['gyro_pitch'], label='Pitch')
                axes[i, 1].plot(time_indices, group_df['gyro_yaw'], label='Yaw')
                axes[i, 1].plot(time_indices, group_df['gyro_roll'], label='Roll')
                axes[i, 1].axvline(x=10, color='r', linestyle='--', label='First Press')
                axes[i, 1].set_title(f'Group {i}: Gyro Data')
                axes[i, 1].set_xlabel('Time Index')
                axes[i, 1].set_ylabel('Gyro Values')
                axes[i, 1].legend()
                
                # Plot 3: Accelerometer data over time
                axes[i, 2].plot(time_indices, group_df['acc_x'], label='X')
                axes[i, 2].plot(time_indices, group_df['acc_y'], label='Y')
                axes[i, 2].plot(time_indices, group_df['acc_z'], label='Z')
                axes[i, 2].axvline(x=10, color='r', linestyle='--', label='First Press')
                axes[i, 2].set_title(f'Group {i}: Accelerometer Data')
                axes[i, 2].set_xlabel('Time Index')
                axes[i, 2].set_ylabel('Accel Values')
                axes[i, 2].legend()
                
                # Add text annotation with statistics
                stats_text = (f"Duration: {group['duration']:.2f}s\n"
                             f"Pre-press none: {pre_press_count}\n"
                             f"Button presses: {press_count}\n"
                             f"Post-press none: {post_press_none_count}")
                axes[i, 0].annotate(stats_text, xy=(0.5, -0.4), xycoords='axes fraction', 
                                   ha='center', va='center', fontsize=10,
                                   bbox=dict(boxstyle='round', fc='lightyellow', alpha=0.7))
            
            plt.tight_layout()
            plt.show()
            
            # Create a summary plot
            plt.figure(figsize=(12, 8))
            
            # Plot 1: Group durations
            plt.subplot(2, 2, 1)
            durations = [group['duration'] for group in action_groups]
            plt.bar(range(len(durations)), durations)
            plt.xlabel('Group Index')
            plt.ylabel('Duration (seconds)')
            plt.title('Duration of Each Action Group')
            
            # Plot 2: Button press distribution across all groups
            plt.subplot(2, 2, 2)
            sns.countplot(x='button_press', data=extracted_df)
            plt.title('Button Press Distribution')
            plt.xlabel('Button Type')
            plt.ylabel('Count')
            plt.xticks(rotation=45)
            
            # Plot 3: Gyro data distribution
            plt.subplot(2, 2, 3)
            sns.boxplot(data=extracted_df[['gyro_pitch', 'gyro_yaw', 'gyro_roll']])
            plt.title('Gyro Data Distribution')
            plt.ylabel('Values')
            
            # Plot 4: Accelerometer data distribution
            plt.subplot(2, 2, 4)
            sns.boxplot(data=extracted_df[['acc_x', 'acc_y', 'acc_z']])
            plt.title('Accelerometer Data Distribution')
            plt.ylabel('Values')
            
            plt.tight_layout()
            plt.show()

        # Print basic information about the extracted data
        print("\nExtracted Data Overview:")
        print(f"Total records: {len(extracted_df)}")
        print(f"Number of groups: {len(action_groups)}")
        print(f"Columns: {extracted_df.columns.tolist()}")
        print("\nSample data (first 5 rows):")
        print(extracted_df.head())
        print("\nButton press distribution:")
        print(extracted_df['button_press'].value_counts())
            
        return extracted_df, group_start_times, process_start_time
    else:
        print(f"No action groups found in {csv_file}")
        return pd.DataFrame(), [], process_start_time

In [7]:
import numpy as np
import soundfile as sf
from scipy import signal
import librosa
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import os

def split_audio_by_timestamps(wav_file, group_start_times, process_start_time, segment_duration=0.3, 
                            half_peak_duration=0.03, plot_spectrograms=False):
    """
    Split audio file into segments based on group start times from action samples.
    
    Args:
        wav_file (str): Path to the WAV file
        group_start_times (list): List of timestamps for each group start
        process_start_time (datetime): Start time of the whole process
        segment_duration (float): Duration of each segment in seconds (default: 0.3)
        half_peak_duration (float): Half duration of peak window in seconds (default: 0.03)
        plot_spectrograms (bool): Whether to plot spectrograms (default: False)
    
    Returns:
        dict: Contains processed audio data with following keys:
            - 'audio_segments': List of raw audio segments
            - 'peak_segments': List of peak audio segments
            - 'spectrograms': List of spectrogram data
            - 'mfccs': List of MFCC features
            - 'low_power_peaks': List of indices of low power peak segments
    """
    # Load the audio file
    audio, sr = sf.read(wav_file)
    
    # Initialize storage for results
    audio_segments = []
    peak_segments = []
    spectrograms = []
    mfccs = []
    low_power_peaks = []
    off_center_peaks = []
    
    # Process each group start time
    for i, start_time in enumerate(group_start_times):
        # Calculate the start index in audio samples
        relative_seconds = (start_time - process_start_time).total_seconds()
        start_idx = int(relative_seconds * sr)
        end_idx = start_idx + int(segment_duration * sr)
        
        # Ensure indices are within bounds
        if start_idx >= 0 and end_idx <= len(audio):
            # Extract the segment
            segment = audio[start_idx:end_idx]
            audio_segments.append(segment)
            
            # Convert to mono if stereo
            if len(segment.shape) > 1 and segment.shape[1] > 1:
                segment_mono = np.mean(segment, axis=1)
            else:
                segment_mono = segment
            
            # Apply Wiener filter for denoising
            noise_samples = int(0.03 * sr)
            noise = segment_mono[:noise_samples]
            noise_psd = 0.0005 * np.mean(np.abs(np.fft.rfft(noise))**2)
            denoised_segment = signal.wiener(segment_mono, mysize=1024, noise=noise_psd)
            
            # Find peak in the middle section
            start_exclude_idx = int(0.01 * sr)
            end_exclude_idx = len(denoised_segment) - int(0.01 * sr)
            valid_segment = denoised_segment[start_exclude_idx:end_exclude_idx]
            
            # Peak detection
            abs_segment = np.abs(valid_segment)
            max_amplitude = np.max(abs_segment)
            peaks, peak_properties = signal.find_peaks(abs_segment,
                                                     prominence=0.2*max_amplitude,  # Lowered prominence threshold
                                                     distance=int(0.01*sr))
            
            # If no peaks found, use maximum value
            if len(peaks) == 0:
                peak_index = np.argmax(abs_segment) + start_exclude_idx
                peak_properties = {'prominences': [0]}
            else:
                highest_peak_idx = np.argmax(peak_properties['prominences'])
                peak_index = peaks[highest_peak_idx] + start_exclude_idx
            
            peak_time = peak_index / sr
            
            # Calculate peak window
            peak_start_time = max(0.0, peak_time - half_peak_duration)
            peak_end_time = min(segment_duration, peak_time + half_peak_duration)
            peak_start_idx = int(peak_start_time * sr)
            peak_end_idx = int(peak_end_time * sr)
            
            # Extract peak window
            peak_segment = denoised_segment[peak_start_idx:peak_end_idx]
            
            # New low power peak detection logic
            peak_absolute_value = np.abs(denoised_segment[peak_index])
            is_low_power = peak_absolute_value < 0.05  # Mark as low power if absolute peak value < 0.05
            
            if is_low_power:
                low_power_peaks.append(i)
            
            # Check if peak is centered
            window_duration = peak_end_time - peak_start_time
            middle_time = peak_start_time + window_duration / 2
            tolerance = window_duration * 0.1
            is_off_center = abs(peak_time - middle_time) > tolerance
            if is_off_center:
                off_center_peaks.append(i)
            
            # Compute spectrogram
            n_fft = 256
            hop_length = 128
            frequencies, times, Sxx = signal.spectrogram(peak_segment, sr,
                                                       nperseg=n_fft,
                                                       noverlap=n_fft-hop_length,
                                                       scaling='density')
            
            Sxx_db = 10 * np.log10(Sxx + 1e-10)
            
            # Only save the spectrogram data
            spectrograms.append(Sxx_db)
            
            # Compute MFCC
            stft = librosa.stft(peak_segment, n_fft=n_fft, hop_length=hop_length)
            mel_spec = librosa.feature.melspectrogram(S=np.abs(stft)**2, sr=sr, n_mels=40)
            mfcc_features = librosa.feature.mfcc(S=librosa.power_to_db(mel_spec), 
                                               n_mfcc=13, fmax=12000, fmin=0)
            
            # Only save the MFCC data
            mfccs.append(mfcc_features)
            
            peak_segments.append(peak_segment)
    
    # Print statistics with more detail
    print(f"Processed {len(audio_segments)} segments")
    print(f"Groups with off-center peaks: {off_center_peaks}")
    print(f"Number of low power peaks (abs peak < 0.05): {len(low_power_peaks)}")
    print(f"Low power peak group indices: {low_power_peaks}")
    
    # Calculate percentage of low power peaks
    low_power_percentage = (len(low_power_peaks) / len(audio_segments)) * 100 if audio_segments else 0
    print(f"Percentage of low power peaks: {low_power_percentage:.2f}%")
    
    peak_values = [np.max(np.abs(peak_seg)) for peak_seg in peak_segments]
    print(f"\nPeak amplitude statistics:")
    print(f"Min peak amplitude: {min(peak_values):.3f}")
    print(f"Max peak amplitude: {max(peak_values):.3f}")
    print(f"Mean peak amplitude: {np.mean(peak_values):.3f}")
    
    # Print sample shapes
    if spectrograms and mfccs:
        print(f"\nSample spectrogram shape frequency bins , number of time slices: {spectrograms[0].shape[0]} , {spectrograms[0].shape[1 ]}")
        print(f"Sample MFCC shape number of mfcc coefficients , number of time slices: {mfccs[0].shape[0]} , {mfccs[0].shape[1 ]}")
        print('\n')
    
    if plot_spectrograms and spectrograms:
        # Plot configuration
        num_specs = len(spectrograms)
        cols = 10
        rows = (num_specs + cols - 1) // cols
        
        fig = plt.figure(figsize=(20, 2 * rows))
        gs = GridSpec(rows, cols, figure=fig)
        
        for i in range(len(spectrograms)):
            row = i // cols
            col = i % cols
            
            ax = fig.add_subplot(gs[row, col])
            denoised_segment_to_plot = audio_segments[i]
            if len(denoised_segment_to_plot.shape) > 1:
                denoised_segment_to_plot = np.mean(denoised_segment_to_plot, axis=1)
            
            time_axis = np.linspace(0, segment_duration, len(denoised_segment_to_plot))
            ax.plot(time_axis, denoised_segment_to_plot)
            
            # Simplified title
            title_color = 'red' if i in low_power_peaks else 'black'
            ax.set_title(f'Group {i}', color=title_color)
            
            if row < rows - 1:
                ax.set_xlabel('')
            if col > 0:
                ax.set_ylabel('')
            else:
                ax.set_ylabel('Amplitude')
        
        plt.suptitle('Denoised audio segments (Y-axis: normalized amplitude [-1, 1])', fontsize=16)
        plt.tight_layout(rect=[0, 0, 1, 0.97])
        plt.show()
    
    return {
        'audio_segments': audio_segments,
        'peak_segments': peak_segments,
        'spectrograms': spectrograms,
        'mfccs': mfccs,
        'low_power_peaks': low_power_peaks
    }

In [4]:
### read the data 

user1_offh_csv="../data/fuse_one/controller_data_user1_non-haptic_20250328_181856.csv"
user2_offh_csv="../data/fuse_one/controller_data_user2_non-haptic_20250327_152659.csv"
user3_offh_csv='../data/fuse_one/controller_data_user3_non-haptic_20250327_160855.csv'
user4_offh_csv='../data/fuse_one/controller_data_user4_non-haptic_20250327_165510.csv'

user1_offh_wav="../data/fuse_one/audio_user1_non-haptic_20250328_181856.wav"
user2_offh_wav="../data/fuse_one/audio_user2_non-haptic_20250327_152659.wav"
user3_offh_wav='../data/fuse_one/audio_user3_non-haptic_20250327_160855.wav'
user4_offh_wav='../data/fuse_one/audio_user4_non-haptic_20250327_165510.wav'

In [5]:
u1_df,u1_g_time,u1_pro_time= extract_action_samples(user1_offh_csv)
u2_df,u2_g_time,u2_pro_time=extract_action_samples(user2_offh_csv)
u3_df,u3_g_time,u3_pro_time=extract_action_samples(user3_offh_csv)
u4_df,u4_g_time,u4_pro_time=extract_action_samples(user4_offh_csv)

Extracted 126 action groups from ../data/fuse_one/controller_data_user1_non-haptic_20250328_181856.csv
Shape of extracted data: (7560, 11)
Total number of samples: 7560
Number of unique groups: 126

Shape of each sample (group):
Group 0: (60, 10) - Duration: 0.29s
Group 1: (60, 10) - Duration: 0.30s
Group 2: (60, 10) - Duration: 0.30s
Group 3: (60, 10) - Duration: 0.29s
Group 4: (60, 10) - Duration: 0.30s
Group 5: (60, 10) - Duration: 0.30s
Group 6: (60, 10) - Duration: 0.30s
Group 7: (60, 10) - Duration: 0.30s
Group 8: (60, 10) - Duration: 0.29s
Group 9: (60, 10) - Duration: 0.30s
Group 10: (60, 10) - Duration: 0.30s
Group 11: (60, 10) - Duration: 0.30s
Group 12: (60, 10) - Duration: 0.30s
Group 13: (60, 10) - Duration: 0.30s
Group 14: (60, 10) - Duration: 0.30s
Group 15: (60, 10) - Duration: 0.29s
Group 16: (60, 10) - Duration: 0.30s
Group 17: (60, 10) - Duration: 0.29s
Group 18: (60, 10) - Duration: 0.29s
Group 19: (60, 10) - Duration: 0.29s
Group 20: (60, 10) - Duration: 0.30s
Grou

In [8]:
u1_audio_splited=split_audio_by_timestamps(user1_offh_wav,u1_g_time,u1_pro_time,plot_spectrograms=False)
u2_audio_splited=split_audio_by_timestamps(user2_offh_wav,u2_g_time,u2_pro_time,plot_spectrograms=False)
u3_audio_splited=split_audio_by_timestamps(user3_offh_wav,u3_g_time,u3_pro_time,plot_spectrograms=False)
u4_audio_splited=split_audio_by_timestamps(user4_offh_wav,u4_g_time,u4_pro_time,plot_spectrograms=False)

Processed 126 segments
Groups with off-center peaks: [47, 49, 95, 117, 123]
Number of low power peaks (abs peak < 0.05): 64
Low power peak group indices: [0, 1, 3, 5, 10, 11, 13, 14, 17, 19, 21, 24, 27, 32, 34, 35, 37, 40, 42, 46, 47, 48, 49, 50, 51, 53, 54, 55, 56, 58, 61, 71, 74, 75, 76, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 94, 95, 96, 98, 99, 102, 106, 107, 108, 109, 110, 115, 116, 117, 121, 122, 123, 124, 125]
Percentage of low power peaks: 50.79%

Peak amplitude statistics:
Min peak amplitude: 0.003
Max peak amplitude: 0.475
Mean peak amplitude: 0.089

Sample spectrogram shape frequency bins , number of time slices: 129 , 21
Sample MFCC shape number of mfcc coefficients , number of time slices: 13 , 23


Processed 98 segments
Groups with off-center peaks: [49]
Number of low power peaks (abs peak < 0.05): 98
Low power peak group indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 

In [10]:
### prepare the data 
def calculate_derivative_rms_features(data, window_size=5):
    """
    Calculate RMS features for derivatives using sliding windows
    Args:
        data: numpy array of shape (N, 6) - sequence with 6 features
        window_size: size of window for RMS calculation (default=5)
    Returns:
        numpy array of RMS values for each complete window
    """
    n_windows = len(data) // window_size
    if len(data) % window_size > 0:  # Handle remaining data
        n_windows += 1
    
    rms_features = []
    for i in range(n_windows):
        start_idx = i * window_size
        end_idx = min(start_idx + window_size, len(data))
        window_data = data[start_idx:end_idx, :]
        
        # Calculate RMS for each feature in this window
        rms = np.sqrt(np.mean(np.square(window_data), axis=0))
        rms_features.append(rms)
    
    return np.array(rms_features)  # Shape: (n_windows, 6)

def prepare_derivative_features(data):
    """
    Calculate first and second derivatives and their RMS features
    Args:
        data: numpy array of shape (60, 6) - one sample/sequence with 6 features
    Returns:
        combined RMS features from both derivatives
        shape: (12, 12) - 12 time steps, 12 features (6 from each derivative)
    """
    # Apply z-score normalization to the entire sequence first
    mean = np.mean(data, axis=0)
    std = np.std(data, axis=0) + 1e-9
    normalized_data = (data - mean) / std
    
    # Calculate first derivative from normalized data
    first_derivative = np.diff(normalized_data, axis=0)  # Shape: (59, 6)
    
    # Calculate second derivative
    second_derivative = np.diff(first_derivative, axis=0)  # Shape: (58, 6)
    
    # Calculate RMS features for both derivatives
    first_der_rms = calculate_derivative_rms_features(first_derivative)  # Shape: (~12, 6)
    second_der_rms = calculate_derivative_rms_features(second_derivative)  # Shape: (~12, 6)
    
    # Ensure both have exactly 12 windows by padding or truncating
    target_length = 12
    
    def adjust_length(features, target_len):
        if len(features) > target_len:
            return features[:target_len]
        elif len(features) < target_len:
            padding = np.zeros((target_len - len(features), features.shape[1]))
            return np.vstack([features, padding])
        return features
    
    first_der_rms = adjust_length(first_der_rms, target_length)
    second_der_rms = adjust_length(second_der_rms, target_length)
    
    # Combine features
    combined_features = np.hstack([first_der_rms, second_der_rms])  # Shape: (12, 12)
    
    return combined_features


In [27]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from collections import Counter
import pandas as pd

@dataclass
class UserData:
    """Data class to store user-specific data components"""
    df: pd.DataFrame
    audio_splited: Dict[str, List]
    meg_data: Optional[Dict] = None

class PlaySenseDataManager:
    """
    A comprehensive data manager for PlaySense data that handles:
    1. Data merging from different sources
    2. Dataset creation for PyTorch
    3. Data splitting and loader creation
    """
    def __init__(self):
        self.users_data = {}
        self.merged_data = {}
        self.dataset = None
    
    def add_user_data(self, user_id: int, user_data: UserData) -> None:
        """Add data for a specific user"""
        self.users_data[user_id] = user_data
        
    def merge_all_users_data(self) -> None:
        """Merge data for all added users"""
        for user_id, user_data in self.users_data.items():
            merged = self._merge_action_and_audio_data(
                df=user_data.df,
                audio_samples=user_data.audio_splited['audio_segments'],
                peak_samples=user_data.audio_splited['peak_segments'],
                spec_samples=user_data.audio_splited['spectrograms'],
                mfcc_samples=user_data.audio_splited['mfccs'],
                low_power_peaks=user_data.audio_splited['low_power_peaks']
            )
            self.merged_data[user_id] = merged
            print(f"\nUser {user_id} data merged successfully:")
            print(f"Number of groups: {len(merged)}")
            
            # Print example shapes for the first group
            if merged:
                first_group_idx = list(merged.keys())[0]
                first_group = merged[first_group_idx]
                
                # Debug information to check what's in first_group
                print(f"DEBUG - First group keys: {list(first_group.keys())}")
                print(f"DEBUG - Action data keys: {list(first_group['action_data'].keys())}")
                print(f"DEBUG - Inertial features keys: {list(first_group['action_data']['inertial_features'].keys())}")
                
                # Check if derivative_rms is None before accessing shape
                derivative_rms = first_group['action_data']['inertial_features']['derivative_rms']
                if derivative_rms is None:
                    print(f"ERROR: derivative_rms is None for user {user_id}, group {first_group_idx}")
                    print(f"DEBUG - Check _prepare_derivative_features implementation")
                    continue
                
                inertial_shape = derivative_rms.shape
                
                # Check if mfcc is None before accessing shape
                mfcc = first_group['audio_data']['mfcc']
                if mfcc is None:
                    print(f"ERROR: mfcc is None for user {user_id}, group {first_group_idx}")
                    continue
                
                mfcc_shape = mfcc.shape
                
                print(f"Example shapes from first group (group {first_group_idx}):")
                print(f"  - Inertial features shape: {inertial_shape}")
                print(f"  - MFCC features shape: {mfcc_shape}")

    def _merge_action_and_audio_data(self, df, audio_samples, peak_samples, 
                                   spec_samples, mfcc_samples, low_power_peaks):
        """Internal method for merging action and audio data"""
        # Verify input sample counts
        group_indices = df['group_index'].unique()
        n_groups = len(group_indices)
        
        # Initialize the merged dataset
        merged_data = {}
        
        # Inertial data columns to extract
        inertial_columns = ['gyro_pitch', 'gyro_yaw', 'gyro_roll', 
                           'acc_x', 'acc_y', 'acc_z']
        
        # Process each group
        for group_idx in group_indices:
            if int(group_idx) in low_power_peaks:
                continue
                
            group_df = df[df['group_index'] == group_idx].copy()
            inertial_data = group_df[inertial_columns].values
            
            # Debug information before calling _prepare_derivative_features
            print(f"DEBUG - Processing group {group_idx}")
            print(f"DEBUG - Inertial data shape: {inertial_data.shape}")
            
            # Check if inertial_data is valid
            if inertial_data.size == 0:
                print(f"WARNING: Empty inertial data for group {group_idx}, skipping")
                continue
                
            derivative_features = self._prepare_derivative_features(inertial_data)
            
            # Debug information after calling _prepare_derivative_features
            if derivative_features is None:
                print(f"ERROR: _prepare_derivative_features returned None for group {group_idx}")
                continue
                
            # Check if audio data exists for this group
            # Since audio_samples and mfcc_samples are arrays, we need to check differently
            group_idx_int = int(group_idx)
            audio_idx = np.where(np.array([i == group_idx_int for i in range(len(audio_samples))]))[0]
            mfcc_idx = np.where(np.array([i == group_idx_int for i in range(len(mfcc_samples))]))[0]
            
            if len(audio_idx) == 0:
                print(f"WARNING: No audio data for group {group_idx}, skipping")
                continue
                
            if len(mfcc_idx) == 0:
                print(f"WARNING: No MFCC data for group {group_idx}, skipping")
                continue
            
            merged_data[int(group_idx)] = {
                'action_data': {
                    'dataframe': group_df,
                    'inertial_features': {
                        'derivative_rms': derivative_features,
                        'feature_names': {
                            'first_derivative': [f'first_der_rms_{col}' for col in inertial_columns],
                            'second_derivative': [f'second_der_rms_{col}' for col in inertial_columns]
                        }
                    }
                },
                'audio_data': {
                    'raw_audio': audio_samples[audio_idx[0]],
                    'peak_audio': peak_samples[audio_idx[0]],
                    'spectrogram': spec_samples[audio_idx[0]],
                    'mfcc': mfcc_samples[mfcc_idx[0]],
                    'is_low_power': False
                }
            }
        
        return merged_data

    def _prepare_derivative_features(self, inertial_data):
        """Prepare derivative features from inertial data"""
        # Debug information
        print(f"DEBUG - Inside _prepare_derivative_features")
        print(f"DEBUG - Input data shape: {inertial_data.shape}")
        
        try:
            # Call the external prepare_derivative_features function
            # This is likely referring to the function defined outside this class
            return prepare_derivative_features(inertial_data)
        except Exception as e:
            print(f"ERROR in _prepare_derivative_features: {str(e)}")
            import traceback
            traceback.print_exc()
            return None

    def create_dataset(self) -> None:
        """Create PyTorch dataset from merged data"""
        self.dataset = PlaySenseDataset(
            data_dict=self.merged_data
        )

    def create_data_loaders(self, 
                          batch_size: int = 32,
                          train_ratio: float = 0.8,
                          shuffle: bool = True) -> Tuple[DataLoader, DataLoader]:
        """Create train and validation data loaders"""
        if self.dataset is None:
            raise ValueError("Dataset not created. Call create_dataset() first.")

        # Calculate split sizes
        train_size = int(train_ratio * len(self.dataset))
        val_size = len(self.dataset) - train_size

        # Split dataset
        train_dataset, val_dataset = torch.utils.data.random_split(
            self.dataset, 
            [train_size, val_size]
        )

        # Create data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=shuffle
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False
        )

        return train_loader, val_loader
    
class PlaySenseDataset(Dataset):
    """PyTorch Dataset for PlaySense data"""
    def __init__(self, data_dict: Dict):
        self.inertial_features = []
        self.mfcc_features = []
        self.labels = []
        
        # Print initial data statistics
        print("\nInitializing Dataset:")
        print(f"Number of users: {len(data_dict)}")
        
        # Create label mapping (user_id -> consecutive index)
        unique_user_ids = sorted(list(data_dict.keys()))
        self.label_mapping = {user_id: idx for idx, user_id in enumerate(unique_user_ids)}
        self.inverse_mapping = {idx: user_id for user_id, idx in self.label_mapping.items()}
        
        print("\nLabel mapping (user_id -> model_label):")
        for user_id, idx in self.label_mapping.items():
            print(f"User {user_id} -> Label {idx}")
        
        # Track samples statistics
        samples_stats = {user_id: {'kept': 0, 'discarded': 0} for user_id in data_dict.keys()}
        discarded_groups = {user_id: [] for user_id in data_dict.keys()}
        
        for user_id, user_data in data_dict.items():
            if user_data is None or len(user_data) == 0:
                print(f"Warning: Empty or None data for user {user_id}")
                continue
            
            for group_idx in user_data:
                group = user_data[group_idx]
                
                try:
                    # Get MFCC features and check shape
                    mfcc = group['audio_data']['mfcc']
                    
                    # Only keep samples with shape (13, 23)
                    if not isinstance(mfcc, np.ndarray) or mfcc.shape != (13, 23):
                        samples_stats[user_id]['discarded'] += 1
                        discarded_groups[user_id].append(group_idx)
                        continue
                    
                    # Get inertial features
                    inertial = group['action_data']['inertial_features']['derivative_rms']
                    
                    # Verify inertial feature type
                    if not isinstance(inertial, np.ndarray):
                        print(f"Warning: Invalid inertial feature type for user {user_id}, group {group_idx}")
                        continue
                    
                    self.inertial_features.append(inertial)
                    self.mfcc_features.append(mfcc)
                    # Map user_id to consecutive index
                    self.labels.append(self.label_mapping[user_id])
                    samples_stats[user_id]['kept'] += 1
                    
                except KeyError as e:
                    print(f"Warning: Missing data structure for user {user_id}, group {group_idx}: {e}")
                except Exception as e:
                    print(f"Error processing user {user_id}, group {group_idx}: {e}")
        
        # Print detailed statistics
        print("\nSampling Statistics:")
        print("-------------------")
        for user_id in samples_stats:
            stats = samples_stats[user_id]
            print(f"\nUser {user_id} (Label {self.label_mapping[user_id]}):")
            print(f"  Kept samples: {stats['kept']}")
            print(f"  Discarded samples: {stats['discarded']}")
            if discarded_groups[user_id]:
                print(f"  Discarded group indices: {discarded_groups[user_id]}")
        
        # Convert to tensors if we have data
        if self.inertial_features and self.mfcc_features:
            try:
                self.inertial_features = torch.FloatTensor(np.array(self.inertial_features))
                self.mfcc_features = torch.FloatTensor(np.array(self.mfcc_features))
                self.labels = torch.LongTensor(self.labels)
                
                # Print final dataset statistics
                print("\nFinal Dataset Statistics:")
                print("------------------------")
                print(f"Total samples: {len(self.labels)}")
                class_dist = Counter(self.labels.numpy())
                print("\nClass distribution:")
                for label, count in sorted(class_dist.items()):
                    original_id = self.inverse_mapping[label]
                    print(f"User {original_id} (Label {label}): {count} samples")
                print(f"\nFeature shapes:")
                print(f"Inertial features: {self.inertial_features.shape}")
                print(f"MFCC features: {self.mfcc_features.shape}")
                
                # Verify label range
                min_label = self.labels.min().item()
                max_label = self.labels.max().item()
                num_classes = len(self.label_mapping)
                print(f"\nLabel range: [{min_label}, {max_label}] (Number of classes: {num_classes})")
                
            except Exception as e:
                print(f"Error converting to tensors: {e}")
                raise
        else:
            raise ValueError("No valid samples found in the dataset")
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return {
            'inertial': self.inertial_features[idx],
            'mfcc': self.mfcc_features[idx],
            'label': self.labels[idx]
        }
    
    def get_num_classes(self):
        """Return the number of classes in the dataset"""
        return len(self.label_mapping)

# Model


In [14]:
import torch
import torch.nn as nn

class InertialMFCCTransformer(nn.Module):
    def __init__(self, 
                 inertial_dim=12,      # Inertial features dimension (12 for your 12x12 data)
                 mfcc_dim=13,          # MFCC features dimension (13 for your 13x23 data)
                 num_classes=3,         # Number of output classes
                 d_model=64,           # Hidden dimension
                 nhead=4,              # Number of attention heads
                 num_layers=2,         # Number of transformer layers
                 dropout=0.1):         # Dropout rate
        super().__init__()
        
        # 1. Inertial Branch
        # Convert inertial features to transformer dimension
        self.inertial_embedding = nn.Sequential(
            nn.Linear(inertial_dim, d_model),
            nn.LayerNorm(d_model),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        # Inertial transformer encoder
        self.inertial_transformer = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_model * 4,
                dropout=dropout,
                batch_first=True
            ),
            num_layers=num_layers
        )
        
        # 2. MFCC Branch
        # Convert MFCC features to transformer dimension
        self.mfcc_embedding = nn.Sequential(
            nn.Linear(mfcc_dim, d_model),
            nn.LayerNorm(d_model),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        # MFCC transformer encoder
        self.mfcc_transformer = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_model * 4,
                dropout=dropout,
                batch_first=True
            ),
            num_layers=num_layers
        )
        
        # 3. Fusion Layer
        self.fusion = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.LayerNorm(d_model),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        # 4. Classification Head
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.LayerNorm(d_model // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, num_classes)
        )
        
    def forward(self, inertial_data, mfcc_data):
        """
        Forward pass of the model.
        Args:
            inertial_data: tensor of shape (batch_size, 12, 12)
            mfcc_data: tensor of shape (batch_size, 13, 23)
        Returns:
            output: tensor of shape (batch_size, num_classes)
        """
        # 1. Process Inertial Data
        # Transform and encode inertial features
        x_inertial = self.inertial_embedding(inertial_data)     # (batch_size, 12, d_model)
        x_inertial = self.inertial_transformer(x_inertial)      # (batch_size, 12, d_model)
        # Global average pooling
        x_inertial = torch.mean(x_inertial, dim=1)             # (batch_size, d_model)
        
        # 2. Process MFCC Data
        # Transpose MFCC data to handle time dimension
        x_mfcc = self.mfcc_embedding(mfcc_data.transpose(1, 2)) # (batch_size, 23, d_model)
        x_mfcc = self.mfcc_transformer(x_mfcc)                  # (batch_size, 23, d_model)
        # Global average pooling
        x_mfcc = torch.mean(x_mfcc, dim=1)                     # (batch_size, d_model)
        
        # 3. Mid-Fusion
        # Concatenate features from both modalities
        x_combined = torch.cat([x_inertial, x_mfcc], dim=1)    # (batch_size, d_model*2)
        x_fused = self.fusion(x_combined)                      # (batch_size, d_model)
        
        # 4. Classification
        output = self.classifier(x_fused)                      # (batch_size, num_classes)
        
        return output

# Example usage:
"""
# Initialize model
model = InertialMFCCTransformer(
    inertial_dim=12,     # Input dimension for inertial data
    mfcc_dim=13,         # Input dimension for MFCC data
    num_classes=3,       # Number of classes to predict
    d_model=64,          # Hidden dimension size
    nhead=4,             # Number of attention heads
    num_layers=2,        # Number of transformer layers
    dropout=0.1          # Dropout rate
)

# Forward pass example
batch_size = 32
inertial_input = torch.randn(batch_size, 12, 12)    # Batch of inertial data
mfcc_input = torch.randn(batch_size, 13, 23)        # Batch of MFCC data
output = model(inertial_input, mfcc_input)          # Shape: (batch_size, num_classes)
"""

'\n# Initialize model\nmodel = InertialMFCCTransformer(\n    inertial_dim=12,     # Input dimension for inertial data\n    mfcc_dim=13,         # Input dimension for MFCC data\n    num_classes=3,       # Number of classes to predict\n    d_model=64,          # Hidden dimension size\n    nhead=4,             # Number of attention heads\n    num_layers=2,        # Number of transformer layers\n    dropout=0.1          # Dropout rate\n)\n\n# Forward pass example\nbatch_size = 32\ninertial_input = torch.randn(batch_size, 12, 12)    # Batch of inertial data\nmfcc_input = torch.randn(batch_size, 13, 23)        # Batch of MFCC data\noutput = model(inertial_input, mfcc_input)          # Shape: (batch_size, num_classes)\n'

# Train 

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np
from datetime import datetime
import os

class PlaySenseTrainer:
    """
    Trainer class for PlaySense model that handles:
    1. Training loop
    2. Validation
    3. Metrics tracking
    4. Model checkpointing
    """
    def __init__(
        self,
        model,
        data_manager,
        learning_rate=0.001,
        num_epochs=50,
        device=None,
        checkpoint_dir='checkpoints'
    ):
        self.model = model
        self.data_manager = data_manager
        self.num_epochs = num_epochs
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.checkpoint_dir = checkpoint_dir
        
        # Create checkpoint directory
        os.makedirs(checkpoint_dir, exist_ok=True)
        
        # Move model to device
        self.model = self.model.to(self.device)
        
        # Initialize criterion and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        
        # Initialize best metrics
        self.best_val_accuracy = 0.0
        
        # Metrics history
        self.history = {
            'train_loss': [], 'train_acc': [], 'train_f1': [],
            'val_loss': [], 'val_acc': [], 'val_f1': []
        }

    def train(self, batch_size=32, train_ratio=0.8):
        """Main training loop"""
        print(f"Starting training on device: {self.device}")
        print(f"Model architecture:\n{self.model}")
        
        # Create data loaders
        train_loader, val_loader = self.data_manager.create_data_loaders(
            batch_size=batch_size,
            train_ratio=train_ratio
        )
        
        # Training loop
        for epoch in range(self.num_epochs):
            # Training phase
            train_metrics = self._train_epoch(train_loader)
            
            # Validation phase
            val_metrics = self._validate(val_loader)
            
            # Update history
            self._update_history(train_metrics, val_metrics)
            
            # Print metrics
            self._print_metrics(epoch, train_metrics, val_metrics)
            
            # Save checkpoint if best model
            if val_metrics['accuracy'] > self.best_val_accuracy:
                self.best_val_accuracy = val_metrics['accuracy']
                self._save_checkpoint(epoch, val_metrics)

    def _train_epoch(self, train_loader):
        """Train for one epoch"""
        self.model.train()
        running_loss = 0.0
        predictions = []
        true_labels = []
        
        for batch in train_loader:
            # Move data to device
            inertial = batch['inertial'].to(self.device)
            mfcc = batch['mfcc'].to(self.device)
            labels = batch['label'].to(self.device)
            
            # Forward pass
            self.optimizer.zero_grad()
            outputs = self.model(inertial, mfcc)
            loss = self.criterion(outputs, labels)
            
            # Backward pass
            loss.backward()
            self.optimizer.step()
            
            # Track metrics
            running_loss += loss.item()
            predictions.extend(outputs.argmax(dim=1).cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
        
        # Calculate metrics
        metrics = self._calculate_metrics(predictions, true_labels, running_loss, len(train_loader))
        return metrics

    def _validate(self, val_loader):
        """Validate the model"""
        self.model.eval()
        running_loss = 0.0
        predictions = []
        true_labels = []
        
        with torch.no_grad():
            for batch in val_loader:
                # Move data to device
                inertial = batch['inertial'].to(self.device)
                mfcc = batch['mfcc'].to(self.device)
                labels = batch['label'].to(self.device)
                
                # Forward pass
                outputs = self.model(inertial, mfcc)
                loss = self.criterion(outputs, labels)
                
                # Track metrics
                running_loss += loss.item()
                predictions.extend(outputs.argmax(dim=1).cpu().numpy())
                true_labels.extend(labels.cpu().numpy())
        
        # Calculate metrics
        metrics = self._calculate_metrics(predictions, true_labels, running_loss, len(val_loader))
        return metrics

    def _calculate_metrics(self, predictions, true_labels, running_loss, num_batches):
        """Calculate training/validation metrics"""
        accuracy = accuracy_score(true_labels, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            true_labels, predictions, average='weighted'
        )
        
        return {
            'loss': running_loss / num_batches,
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1
        }

    def _update_history(self, train_metrics, val_metrics):
        """Update metrics history"""
        self.history['train_loss'].append(train_metrics['loss'])
        self.history['train_acc'].append(train_metrics['accuracy'])
        self.history['train_f1'].append(train_metrics['f1'])
        self.history['val_loss'].append(val_metrics['loss'])
        self.history['val_acc'].append(val_metrics['accuracy'])
        self.history['val_f1'].append(val_metrics['f1'])

    def _print_metrics(self, epoch, train_metrics, val_metrics):
        """Print current metrics"""
        print(f"\nEpoch {epoch+1}/{self.num_epochs}")
        print("Training Metrics:")
        print(f"Loss: {train_metrics['loss']:.4f}, Accuracy: {train_metrics['accuracy']:.4f}")
        print(f"F1: {train_metrics['f1']:.4f}, Precision: {train_metrics['precision']:.4f}")
        print("\nValidation Metrics:")
        print(f"Loss: {val_metrics['loss']:.4f}, Accuracy: {val_metrics['accuracy']:.4f}")
        print(f"F1: {val_metrics['f1']:.4f}, Precision: {val_metrics['precision']:.4f}")

    def _save_checkpoint(self, epoch, metrics):
        """Save model checkpoint"""
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        checkpoint_path = os.path.join(
            self.checkpoint_dir,
            f'model_epoch{epoch}_acc{metrics["accuracy"]:.4f}_{timestamp}.pth'
        )
        
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'metrics': metrics,
            'history': self.history
        }
        
        torch.save(checkpoint, checkpoint_path)
        print(f"\nCheckpoint saved: {checkpoint_path}")


# Execut 

In [28]:
# 1. First, set up the data manager (as shown in previous code)
data_manager = PlaySenseDataManager()

# Add user data
user1_data = UserData(
    df=u1_df,
    audio_splited=u1_audio_splited
)
data_manager.add_user_data(user_id=1, user_data=user1_data)

user3_data = UserData(
    df=u3_df,
    audio_splited=u3_audio_splited
)
data_manager.add_user_data(user_id=3, user_data=user3_data)

user4_data = UserData(
    df=u4_df,
    audio_splited=u4_audio_splited
)
data_manager.add_user_data(user_id=4, user_data=user4_data)

# Merge data and create dataset
data_manager.merge_all_users_data()
data_manager.create_dataset()

# 2. Initialize the model
model = InertialMFCCTransformer(
    inertial_dim=12,     # Derivative features dimension
    mfcc_dim=13,         # MFCC features dimension
    num_classes=data_manager.dataset.get_num_classes(),       # 3 users (1, 3, 4)
    d_model=64,
    nhead=4,
    num_layers=2,
    dropout=0.1
)

# 3. Create trainer instance
trainer = PlaySenseTrainer(
    model=model,
    data_manager=data_manager,
    learning_rate=0.001,
    num_epochs=50,
    checkpoint_dir='checkpoints'
)

# 4. Start training
trainer.train(batch_size=32, train_ratio=0.8)

DEBUG - Processing group 2
DEBUG - Inertial data shape: (60, 6)
DEBUG - Inside _prepare_derivative_features
DEBUG - Input data shape: (60, 6)
DEBUG - Processing group 4
DEBUG - Inertial data shape: (60, 6)
DEBUG - Inside _prepare_derivative_features
DEBUG - Input data shape: (60, 6)
DEBUG - Processing group 6
DEBUG - Inertial data shape: (60, 6)
DEBUG - Inside _prepare_derivative_features
DEBUG - Input data shape: (60, 6)
DEBUG - Processing group 7
DEBUG - Inertial data shape: (60, 6)
DEBUG - Inside _prepare_derivative_features
DEBUG - Input data shape: (60, 6)
DEBUG - Processing group 8
DEBUG - Inertial data shape: (60, 6)
DEBUG - Inside _prepare_derivative_features
DEBUG - Input data shape: (60, 6)
DEBUG - Processing group 9
DEBUG - Inertial data shape: (60, 6)
DEBUG - Inside _prepare_derivative_features
DEBUG - Input data shape: (60, 6)
DEBUG - Processing group 12
DEBUG - Inertial data shape: (60, 6)
DEBUG - Inside _prepare_derivative_features
DEBUG - Input data shape: (60, 6)
DEBUG

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 1/50
Training Metrics:
Loss: 1.0945, Accuracy: 0.3864
F1: 0.3700, Precision: 0.3895

Validation Metrics:
Loss: 1.0525, Accuracy: 0.2000
F1: 0.1485, Precision: 0.4413

Checkpoint saved: checkpoints/model_epoch0_acc0.2000_20250411_180520.pth

Epoch 2/50
Training Metrics:
Loss: 1.0363, Accuracy: 0.4886
F1: 0.4763, Precision: 0.5114

Validation Metrics:
Loss: 0.7499, Accuracy: 0.8667
F1: 0.8087, Precision: 0.7641

Checkpoint saved: checkpoints/model_epoch1_acc0.8667_20250411_180520.pth

Epoch 3/50
Training Metrics:
Loss: 0.7906, Accuracy: 0.7102
F1: 0.7002, Precision: 0.7137

Validation Metrics:
Loss: 0.5708, Accuracy: 0.7556
F1: 0.7676, Precision: 0.9137

Epoch 4/50
Training Metrics:
Loss: 0.6371, Accuracy: 0.7443
F1: 0.7374, Precision: 0.7415

Validation Metrics:
Loss: 0.4518, Accuracy: 0.9556
F1: 0.9556, Precision: 0.9556

Checkpoint saved: checkpoints/model_epoch3_acc0.9556_20250411_180520.pth

Epoch 5/50
Training Metrics:
Loss: 0.5703, Accuracy: 0.8182
F1: 0.8135, Precision: 0.