# CMI - Detect Behavior with Sensor Data

## IMU_LSTM_cross_attention_exp_272

In [19]:
import os, warnings, numpy as np, pandas as pd
from pathlib import Path
from copy import deepcopy

from sklearn.model_selection import GroupKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LinearLR, ChainedScheduler

from scipy.spatial.transform import Rotation as R
import math

from cmi_2025_metric_copy_for_import import CompetitionMetric
# Silence every Python warning for the remainder of the notebook.
warnings.filterwarnings("ignore")
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Configuration

In [None]:
# Directory the raw CSV inputs are read from.
RAW_DIR = Path("data")
# Output directory for this experiment's artifacts; created if missing and
# then converted to a Path so it can be joined with `/`.
EXPORT_DIR = "models/exp_272"
os.makedirs(EXPORT_DIR, exist_ok = True)
EXPORT_DIR = Path(EXPORT_DIR)
# Training hyper-parameters. Their consumers are outside this section, so the
# notes below are assumptions to be confirmed against the training loop.
BATCH_SIZE = 64
PAD_PERCENTILE = 94  # presumably the sequence-length percentile used as pad length -- TODO confirm
LR_INIT = 5e-4  # presumably the initial learning rate
WD = 3e-4  # presumably weight decay
MIXUP_ALPHA = 0.4  # presumably the Beta(alpha, alpha) parameter passed to MixupDataset
EPOCHS = 220
PATIENCE = 30  # presumably early-stopping patience in epochs
T_0 = 10  # presumably CosineAnnealingWarmRestarts' T_0

# presumably the decay passed to ModelEMA -- TODO confirm
ema_decay = 0.99

# Raw accelerometer and quaternion column names; used by add_features().
ACC_COLS = ["acc_x", "acc_y", "acc_z"]
ROT_COLS = ["rot_w", "rot_x", "rot_y", "rot_z"]

# Experiment switches (all disabled for this run); consumers not visible here.
FOCAL_LOSS = False
CLASS_WEIGHTS = False
DEL_SUBJ = False

# Random seed; where it is applied is not visible in this section.
SEED = 42

In [21]:
# Table loaded from ids2del_prop_0.2.csv; presumably the ids to drop when
# DEL_SUBJ is enabled -- its consumer is not visible here, TODO confirm.
ids2del = pd.read_csv(RAW_DIR / "ids2del_prop_0.2.csv")

### Utility Functions

In [22]:
class SmoothFocalLoss(nn.Module):
    """Focal loss evaluated against label-smoothed targets.

    Args:
        classes: number of classes (used to spread the smoothing mass).
        smoothing: probability mass moved away from the arg-max class.
        gamma: focal exponent applied to ``1 - pt``.
    """

    def __init__(self, classes, smoothing=0.1, gamma=2):
        super().__init__()
        self.smoothing = smoothing
        self.gamma = gamma
        self.classes = classes

    def forward(self, pred, target):
        """Return the mean focal loss.

        ``pred`` holds raw logits with classes on dim 1; ``target`` holds
        per-class scores on dim 1 and is reduced to its arg-max class.
        """
        probs = F.softmax(pred, dim=1)

        # Smoothed target distribution: the arg-max class of `target` gets
        # 1 - smoothing, every other class shares smoothing / (classes - 1).
        with torch.no_grad():
            hard_label = target.max(1)[1].unsqueeze(1)
            off_value = self.smoothing / (self.classes - 1)
            smoothed = torch.full_like(probs, off_value)
            smoothed.scatter_(1, hard_label, 1.0 - self.smoothing)

        # Probability mass the model places on the smoothed target.
        pt = (smoothed * probs).sum(1)
        modulator = (1 - pt) ** self.gamma

        # Cross-entropy against the smoothed target (epsilon guards log(0)).
        cross_entropy = -torch.sum(smoothed * torch.log(probs + 1e-6), dim=1)

        return (modulator * cross_entropy).mean()

In [23]:
class SignalTransform:
    """Base class for augmentations that fire with probability ``p``.

    Subclasses implement :meth:`apply`. When ``always_apply`` is true the
    transform runs on every call and no random number is drawn.
    """

    def __init__(self, always_apply: bool = False, p: float = 0.5):
        self.always_apply = always_apply
        self.p = p

    def __call__(self, y: np.ndarray):
        # Short-circuit keeps the RNG untouched when always_apply is set.
        if self.always_apply or np.random.rand() < self.p:
            return self.apply(y)
        return y

    def apply(self, y: np.ndarray):
        """Transform ``y``; must be overridden by subclasses."""
        raise NotImplementedError


class Compose:
    """Chain transforms: each one receives the previous one's output."""

    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, y: np.ndarray):
        result = y
        for transform in self.transforms:
            result = transform(result)
        return result


class OneOf:
    """Apply exactly one transform, picked uniformly at random per call."""

    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, y: np.ndarray):
        chosen = self.transforms[np.random.choice(len(self.transforms))]
        return chosen(y)
    
class TimeStretch(SignalTransform):
    """Randomly speed up or slow down a signal while keeping its length.

    Args:
        max_rate: upper bound of the uniformly sampled stretch rate.
        min_rate: lower bound of the uniformly sampled stretch rate.
        always_apply: apply on every call instead of with probability ``p``.
        p: probability of applying the transform.
    """

    def __init__(self, max_rate=1.5, min_rate=0.5, always_apply=False, p=0.5):
        # always_apply / p are stored by the base class.
        super().__init__(always_apply, p)
        self.max_rate = max_rate
        self.min_rate = min_rate

    def apply(self, x: np.ndarray):
        """
        Stretch a 1D or 2D array in time using linear interpolation.

        The signal is resampled to ``int(L / rate)`` samples (at least one).
        A shorter result is left-padded with zeros; a longer result keeps its
        last ``L`` samples, so the output always has the input's length.

        - x: np.ndarray of shape (L,) or (L, N)

        Raises:
            ValueError: if ``x`` is neither 1D nor 2D.
        """
        rate = np.random.uniform(self.min_rate, self.max_rate)
        if x.ndim not in (1, 2):
            raise ValueError("Only 1D or 2D arrays are supported.")

        L = x.shape[0]
        if L == 0:
            # Nothing to interpolate; np.interp would raise on empty input.
            return x

        # int(L / rate) is 0 for very short inputs with rate > L, which used
        # to break the padding below; always keep at least one sample.
        L_new = max(1, int(L / rate))
        orig_idx = np.linspace(0, L - 1, num=L)
        new_idx = np.linspace(0, L - 1, num=L_new)

        if x.ndim == 1:
            stretched = np.interp(new_idx, orig_idx, x)
        else:
            stretched = np.stack(
                [np.interp(new_idx, orig_idx, x[:, i]) for i in range(x.shape[1])],
                axis=1,
            )

        if L_new < L:
            # Pad at the beginning so the signal stays right-aligned.
            padded = np.zeros(x.shape, dtype=stretched.dtype)
            padded[-L_new:] = stretched
            return padded
        # Crop at the beginning (no-op when L_new == L).
        return stretched[-L:]


class TimeShift(SignalTransform):
    """Shift a (L, N) signal along time by a random number of steps.

    Args:
        always_apply: apply on every call instead of with probability ``p``.
        p: probability of applying the transform.
        max_shift_pct: maximum shift as a fraction of the sequence length.
        padding_mode: "replace" keeps the wrapped-around samples produced by
            the circular roll; "zero" blanks them out.
    """

    def __init__(self, always_apply=False, p=0.5, max_shift_pct=0.25, padding_mode="replace"):
        super().__init__(always_apply, p)

        assert 0 <= max_shift_pct <= 1.0, "`max_shift_pct` must be between 0 and 1"
        assert padding_mode in ["replace", "zero"], "`padding_mode` must be either 'replace' or 'zero'"

        self.max_shift_pct = max_shift_pct
        self.padding_mode = padding_mode

    def apply(self, x: np.ndarray, **params):
        """Return a rolled copy of ``x``; ``params`` is accepted but unused."""
        assert x.ndim == 2, "`x` must be a 2D array with shape (L, N)"

        n_steps = x.shape[0]
        limit = int(n_steps * self.max_shift_pct)
        offset = np.random.randint(-limit, limit + 1)

        # np.roll returns a new array, so the input is left untouched.
        shifted = np.roll(x, offset, axis=0)

        if self.padding_mode != "zero" or offset == 0:
            return shifted

        # Zero the rows that wrapped around: the head for a forward shift,
        # the tail for a backward shift.
        wrapped = slice(None, offset) if offset > 0 else slice(offset, None)
        shifted[wrapped, :] = 0
        return shifted

In [24]:
def handle_quaternion_missing_values(rot_data: np.ndarray) -> np.ndarray:
    """
    Repair NaNs in an (N, 4) array of [w, x, y, z] quaternions.

    Per row:
    - no NaN: normalise to unit length (identity if the norm is ~0);
    - one NaN: rebuild it from the unit-length constraint, copying the sign
      of the same component in the previous cleaned row; identity if the
      known components already exceed unit length;
    - several NaNs: identity quaternion.

    The input array is not modified; a cleaned copy is returned.
    """
    identity = [1.0, 0.0, 0.0, 0.0]
    cleaned = rot_data.copy()

    for i, quat in enumerate(rot_data):
        nan_mask = np.isnan(quat)
        n_missing = nan_mask.sum()

        if n_missing > 1:
            # Too little information left to reconstruct anything.
            cleaned[i] = identity
            continue

        if n_missing == 0:
            length = np.linalg.norm(quat)
            cleaned[i] = quat / length if length > 1e-8 else identity
            continue

        # Exactly one component is missing: w^2 + x^2 + y^2 + z^2 = 1.
        known = quat[~nan_mask]
        energy = np.sum(known**2)
        if not (energy <= 1.0):
            cleaned[i] = identity
            continue

        gap = np.where(nan_mask)[0][0]
        value = np.sqrt(max(0, 1.0 - energy))
        # Follow the previous row's sign for continuity.
        if i > 0 and not np.isnan(cleaned[i-1, gap]) and cleaned[i-1, gap] < 0:
            value = -value
        # The known components are already present in the copy.
        cleaned[i, gap] = value

    return cleaned

def compute_world_acceleration(acc: np.ndarray, rot: np.ndarray) -> np.ndarray:
    """
    Rotate device-frame acceleration by the per-sample orientation quaternion.

    Args:
        acc: acceleration in device coordinates, shape (time_steps, 3) [x, y, z]
        rot: rotation quaternion, shape (time_steps, 4) [w, x, y, z]

    Returns:
        acc_world: rotated acceleration, shape (time_steps, 3). If anything
        in the transformation raises, a warning is printed and a copy of
        ``acc`` is returned instead.
    """
    try:
        # SciPy expects [x, y, z, w]; fancy indexing yields a copy, so the
        # caller's array is never modified below.
        quat_xyzw = rot[:, [1, 2, 3, 0]]

        # Zero-norm quaternions cannot be normalised: swap in the identity.
        degenerate = np.linalg.norm(quat_xyzw, axis=1) < 1e-8
        if np.any(degenerate):
            quat_xyzw[degenerate] = [0.0, 0.0, 0.0, 1.0]

        return R.from_quat(quat_xyzw).apply(acc)

    except Exception:
        # Deliberate best-effort fallback to the untransformed acceleration.
        print("Warning: World coordinate transformation failed, using device coordinates")
        return acc.copy()

In [25]:
def calculate_angular_velocity_from_quat(rot_data, time_delta=1/200): # Assuming 200Hz sampling rate
    """Estimate angular velocity from consecutive quaternions.

    Args:
        rot_data: DataFrame with rot_x/rot_y/rot_z/rot_w columns, or an
            (N, 4) array already ordered [x, y, z, w].
        time_delta: time between consecutive samples.

    Returns:
        (N, 3) array; row i is the rotation vector from sample i to i + 1
        divided by ``time_delta``. The last row, and rows where either
        quaternion is all-NaN, all ~0 or rejected by SciPy, stay zero.
    """
    if isinstance(rot_data, pd.DataFrame):
        quats = rot_data[["rot_x", "rot_y", "rot_z", "rot_w"]].values
    else:
        quats = rot_data

    def _unusable(q):
        return np.all(np.isnan(q)) or np.all(np.isclose(q, 0))

    velocities = np.zeros((quats.shape[0], 3))

    for i, (q_now, q_next) in enumerate(zip(quats[:-1], quats[1:])):
        if _unusable(q_now) or _unusable(q_next):
            continue

        try:
            # Relative rotation between the two samples; its rotation vector
            # scaled by 1/dt approximates the angular velocity for small steps.
            step = R.from_quat(q_now).inv() * R.from_quat(q_next)
            velocities[i, :] = step.as_rotvec() / time_delta
        except ValueError:
            # Invalid quaternion: leave this row at zero.
            pass

    return velocities

def calculate_angular_distance(rot_data):
    """Angle (radians) of the rotation between consecutive quaternions.

    Args:
        rot_data: DataFrame with rot_x/rot_y/rot_z/rot_w columns, or an
            (N, 4) array already ordered [x, y, z, w].

    Returns:
        Length-N array; entry i is the angle between samples i and i + 1.
        The last entry, and entries where either quaternion is all-NaN,
        all ~0 or rejected by SciPy, are zero.
    """
    if isinstance(rot_data, pd.DataFrame):
        quats = rot_data[["rot_x", "rot_y", "rot_z", "rot_w"]].values
    else:
        quats = rot_data

    def _unusable(q):
        return np.all(np.isnan(q)) or np.all(np.isclose(q, 0))

    distances = np.zeros(quats.shape[0])

    for i, (q1, q2) in enumerate(zip(quats[:-1], quats[1:])):
        if _unusable(q1) or _unusable(q2):
            distances[i] = 0  # or np.nan, depending on the desired behaviour
            continue
        try:
            # r1.inv() * r2 is the relative rotation between the two samples;
            # the norm of its rotation vector is the rotation angle, i.e. the
            # angular distance.
            relative = R.from_quat(q1).inv() * R.from_quat(q2)
            distances[i] = np.linalg.norm(relative.as_rotvec())
        except ValueError:
            # Invalid quaternions count as zero distance.
            distances[i] = 0

    return distances

In [26]:
def remove_gravity_from_acc(acc_data, rot_data):
    """Subtract gravity, expressed in the sensor frame, from raw acceleration.

    Args:
        acc_data: DataFrame with acc_x/acc_y/acc_z columns, or an (N, 3) array.
        rot_data: DataFrame with rot_x/rot_y/rot_z/rot_w columns, or an
            (N, 4) array already ordered [x, y, z, w].

    Returns:
        Array shaped like the acceleration input. Rows whose quaternion is
        all-NaN, all ~0 or raises ValueError keep the raw acceleration.
    """
    if isinstance(acc_data, pd.DataFrame):
        acc_values = acc_data[["acc_x", "acc_y", "acc_z"]].values
    else:
        acc_values = acc_data

    if isinstance(rot_data, pd.DataFrame):
        quat_values = rot_data[["rot_x", "rot_y", "rot_z", "rot_w"]].values
    else:
        quat_values = rot_data

    # Gravity along +z of the world frame, 9.81 in magnitude.
    gravity_world = np.array([0, 0, 9.81])
    linear_accel = np.zeros_like(acc_values)

    for i in range(acc_values.shape[0]):
        quat = quat_values[i]
        raw = acc_values[i, :]

        if np.all(np.isnan(quat)) or np.all(np.isclose(quat, 0)):
            linear_accel[i, :] = raw
            continue

        try:
            # Inverse rotation takes the world gravity vector into the sensor frame.
            gravity_sensor = R.from_quat(quat).apply(gravity_world, inverse=True)
            linear_accel[i, :] = raw - gravity_sensor
        except ValueError:
            linear_accel[i, :] = raw

    return linear_accel

In [27]:
def _rolling_axis_correlations(df: pd.DataFrame, axis_cols, window: int = 20) -> np.ndarray:
    """Return per-sequence rolling correlations (xy, xz, yz) as an (N, 3) array.

    Rows are emitted in ``groupby("sequence_id")`` order, so the result only
    lines up with ``df`` when its rows are already grouped in that order.
    """
    x_col, y_col, z_col = axis_cols
    chunks = []
    for _, group in df.groupby("sequence_id"):
        chunks.append(np.column_stack([
            group[x_col].rolling(window).corr(group[y_col]).fillna(0).to_numpy(),
            group[x_col].rolling(window).corr(group[z_col]).fillna(0).to_numpy(),
            group[y_col].rolling(window).corr(group[z_col]).fillna(0).to_numpy(),
        ]))
    return np.concatenate(chunks, axis=0)


def _add_acceleration_features(df: pd.DataFrame, suffix: str = "") -> None:
    """Add magnitude, jerk and axis-correlation columns for acc_{x,y,z}{suffix}, in place."""
    x_col, y_col, z_col = f"acc_x{suffix}", f"acc_y{suffix}", f"acc_z{suffix}"

    # NOTE(review): diff() and np.gradient run over the whole frame, so values
    # at sequence boundaries mix in the neighbouring sequence. Kept as is to
    # stay compatible with the existing feature definition.
    df[f"acc_mag{suffix}"] = np.sqrt(df[x_col]**2 + df[y_col]**2 + df[z_col]**2)
    df[f"acc_mag_jerk{suffix}"] = df[f"acc_mag{suffix}"].diff().fillna(0)

    # Jerk (rate of change of acceleration) per axis and its magnitude.
    df[f"jerk_x{suffix}"] = np.gradient(df[x_col])
    df[f"jerk_y{suffix}"] = np.gradient(df[y_col])
    df[f"jerk_z{suffix}"] = np.gradient(df[z_col])
    df[f"jerk_magnitude{suffix}"] = np.sqrt(
        df[f"jerk_x{suffix}"]**2 + df[f"jerk_y{suffix}"]**2 + df[f"jerk_z{suffix}"]**2
    )

    # Rolling correlation between axes, computed within each sequence.
    corr_cols = [f"acc_xy_corr{suffix}", f"acc_xz_corr{suffix}", f"acc_yz_corr{suffix}"]
    df[corr_cols] = _rolling_axis_correlations(df, (x_col, y_col, z_col))


def add_features(df: pd.DataFrame) -> pd.DataFrame:

    """
    Add engineered IMU features and return the frame with a fixed column order.

    Added features: acceleration magnitude, jerk and rolling axis
    correlations; rotation angle and its difference; angular velocity and
    angular distance derived from the quaternions; gravity-removed
    acceleration (``acc_*2``) with the same derived features (suffix ``2``).

    Per-sequence results are concatenated in ``groupby("sequence_id")`` order
    and assigned positionally, so ``df`` must already be grouped by
    ``sequence_id`` in that order.

    Args:
        df (pd.DataFrame): Input dataframe containing sensor data and all
            metadata columns listed in ``meta_cols``. New columns are added
            to it in place.

    Returns:
        pd.DataFrame: Metadata, IMU features and the original ``thm_``/``tof_``
        columns, with +/-inf replaced by +/-1.
    """

    meta_cols = [
        "gesture", "gesture_int", "sequence_type",
        "behavior", "orientation", "row_id", "subject",
        "phase", "sequence_id", "sequence_counter"
    ]
    other_cols = [
        c for c in df.columns
        if c not in meta_cols and c.startswith(("thm_", "tof_"))
    ]
    quat_cols = ["rot_x", "rot_y", "rot_z", "rot_w"]

    # Add features https://www.kaggle.com/code/rktqwe/lb-0-77-linear-accel-tf-bilstm-gru-attention
    _add_acceleration_features(df)

    df["rot_angle"] = 2 * np.arccos(df["rot_w"].clip(-1, 1))
    df["rot_angle_vel"] = df["rot_angle"].diff().fillna(0)

    print("Calculating angular velocity from quaternion derivatives...")
    df[["angular_vel_x", "angular_vel_y", "angular_vel_z"]] = np.concatenate(
        [
            calculate_angular_velocity_from_quat(group[quat_cols])
            for _, group in df.groupby("sequence_id")
        ],
        axis=0,
    )

    # Fixed: the z component was missing (x was squared twice).
    df["angular_vel_magnitude"] = np.sqrt(
        df["angular_vel_x"]**2 + df["angular_vel_y"]**2 + df["angular_vel_z"]**2
    )

    print("Calculating angular distance between successive quaternions...")
    df["angular_distance"] = np.concatenate(
        [
            calculate_angular_distance(group[quat_cols])
            for _, group in df.groupby("sequence_id")
        ],
        axis=0,
    )

    # Gravity-removed acceleration and the same derived features (suffix "2").
    acc_cols2 = ["acc_x2", "acc_y2", "acc_z2"]
    df[acc_cols2] = np.concatenate(
        [
            remove_gravity_from_acc(group[ACC_COLS], group[ROT_COLS])
            for _, group in df.groupby("sequence_id")
        ],
        axis=0,
    )
    _add_acceleration_features(df, suffix="2")

    selected = (
        meta_cols
        + ACC_COLS + ["acc_mag", "acc_mag_jerk"]
        + ["jerk_x", "jerk_y", "jerk_z", "jerk_magnitude"]
        + ["acc_xy_corr", "acc_xz_corr", "acc_yz_corr"]
        + ROT_COLS + ["rot_angle", "rot_angle_vel"]
        + ["angular_vel_x", "angular_vel_y", "angular_vel_z"]
        + ["angular_vel_magnitude", "angular_distance"]
        + acc_cols2 + ["acc_mag2", "acc_mag_jerk2"]
        + ["jerk_x2", "jerk_y2", "jerk_z2", "jerk_magnitude2"]
        + ["acc_xy_corr2", "acc_xz_corr2", "acc_yz_corr2"] + other_cols
    )

    # Non-inplace replace avoids mutating a slice of the original frame.
    df = df[selected].replace(-np.inf, -1).replace(np.inf, 1)

    return df


def pad_sequences(
    sequences: list[np.ndarray],
    maxlen: int,
    padding: str = "pre",
    truncating: str = "pre",
    dtype: str = "float32"
) -> np.ndarray:

    """
    Pad sequences to the same length.

    Parameters:
    -----------
    sequences : list of numpy.ndarray
        List of sequences (each sequence is a numpy array) to pad
    maxlen : int
        Maximum length of all sequences. Sequences longer than maxlen will be
        truncated. ``None`` means the length of the longest sequence.
    padding : str, optional (default="pre")
        "pre" or "post", pad either before or after each sequence
    truncating : str, optional (default="pre")
        "pre" or "post", remove values from sequences larger than maxlen either at the
        beginning or at the end
    dtype : str, optional (default="float32")
        Type of the output array

    Returns:
    --------
    numpy.ndarray
        Zero-padded array of shape (len(sequences), maxlen, ...)

    Raises:
    -------
    NotImplementedError
        If ``padding`` or ``truncating`` is not "pre" or "post"
    """

    if padding not in ["pre", "post"]:
        raise NotImplementedError("Invalid padding")
    if truncating not in ["pre", "post"]:
        raise NotImplementedError("Invalid truncating")

    n_samples = len(sequences)
    if maxlen is None:
        maxlen = max((len(s) for s in sequences), default=0)

    # Trailing (per-step) shape is taken from the first non-empty sequence;
    # with no non-empty sequence the output has shape (n_samples, maxlen).
    sample_shape = tuple()
    for s in sequences:
        if len(s) > 0:
            sample_shape = np.asarray(s).shape[1:]
            break

    x = np.zeros((n_samples, maxlen) + sample_shape, dtype = dtype)

    for idx, s in enumerate(sequences):
        if len(s) == 0:
            continue

        # Accept plain lists as well as arrays.
        s = np.asarray(s)

        overflow = s.shape[0] - maxlen
        if overflow > 0:
            if truncating == "pre":
                # Drop the oldest steps; explicit start index also handles
                # maxlen == 0, where s[-0:] would keep everything.
                s = s[overflow:]
            else:
                # Fixed: the slice result used to be discarded.
                s = s[:maxlen]

        trunc = s.shape[0]
        if trunc == 0:
            continue

        if padding == "pre":
            x[idx, -trunc:] = s
        else:
            x[idx, :trunc] = s

    return x

def to_categorical(y, num_classes = None):
    """One-hot encode integer labels.

    When ``num_classes`` is falsy it is inferred as ``max(y) + 1``.
    """
    labels = np.array(y, dtype = "int")
    n_classes = num_classes if num_classes else np.max(labels) + 1
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(n_classes)[labels]

In [28]:
def mirror_quaternion(quat):
    """
    Mirror a batch of quaternions through the YZ plane.

    Args:
        quat (array of shape (N, 4)): [w, x, y, z]

    Returns:
        mirrored (np.ndarray of shape (N, 4)): mirrored quaternions [w, x, y, z]
    """

    # Reflection through the YZ plane flips the x axis.
    reflect_x = np.diag([-1, 1, 1])

    # SciPy works with [x, y, z, w].
    matrices = R.from_quat(quat[:, [1, 2, 3, 0]]).as_matrix()

    # Conjugating each rotation matrix with the reflection mirrors it.
    mirrored = R.from_matrix(reflect_x @ matrices @ reflect_x).as_quat()

    # Back to [w, x, y, z].
    return mirrored[:, [3, 0, 1, 2]]


def mirror_data(data):
    """
    Mirror IMU samples through the YZ plane, in place.

    Args:
        data (np.ndarray of shape (N, 7)): columns 0-2 hold acceleration
            [x, y, z] and columns 3-6 the quaternion [w, x, y, z].

    Returns:
        The same array object, with acc_x negated and the quaternions mirrored.

    """

    # Flip the x axis of the acceleration, writing into the input array.
    np.negative(data[:, 0], out=data[:, 0])

    # Mirror the orientation part, [w, x, y, z].
    quat_part = data[:, 3:]
    data[:, 3:] = mirror_quaternion(quat_part)

    return data

def process_left_handed(df, dem):
    """Mirror the IMU columns of every subject whose ``handedness`` is 0.

    Args:
        df: sensor dataframe with a ``subject`` column plus the acc_* and
            rot_* columns; it is updated in place.
        dem: demographics dataframe with ``subject`` and ``handedness`` columns.

    Returns:
        ``df`` with the selected rows mirrored via ``mirror_data``.
    """
    cols_to_transform = ["acc_x", "acc_y", "acc_z", "rot_w", "rot_x", "rot_y", "rot_z"]

    left_subjects = dem.loc[dem["handedness"] == 0, "subject"]
    is_left = df["subject"].isin(left_subjects)

    mirrored = mirror_data(df.loc[is_left, cols_to_transform].to_numpy())
    df.loc[is_left, cols_to_transform] = mirrored
    return df

In [29]:
class ModelEMA(nn.Module):

    """
    Exponential moving average (EMA) of a model's state.

    Keeps a deep copy of the model, held in eval mode, whose state-dict
    entries track a running average of the live model's entries.

    Parameters:
    -----------
    model : torch.nn.Module
        Model to copy and track
    decay : float, optional (default=0.99)
        Weight given to the stored average on each update; values closer
        to 1 make the average move more slowly.
    device : torch.device, optional (default=None)
        Device to hold the EMA copy on. If None, the copy is not moved.

    Methods:
    --------
    update(model):
        Blend the current model state into the average
    set(model):
        Overwrite the average with the current model state
    """

    def __init__(self, model, decay = 0.99, device = None):
        super().__init__()
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device
        if self.device is not None:
            self.module.to(device = device)

    def _update(self, model, update_fn):

        """
        Apply ``update_fn(ema_value, model_value)`` to every state-dict entry.

        Parameters:
        -----------
        model : torch.nn.Module
            The live model
        update_fn : callable
            Returns the new EMA value for a pair of tensors
        """

        ema_state = self.module.state_dict()
        live_state = model.state_dict()

        with torch.no_grad():
            # Entries are paired by position, so both models must share the
            # same state-dict layout.
            for ema_v, model_v in zip(ema_state.values(), live_state.values()):
                if self.device is not None:
                    model_v = model_v.to(device = self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):

        """
        Blend the live model into the average.

        new_value = decay * old_value + (1 - decay) * current_value

        Parameters:
        -----------
        model : torch.nn.Module
            The live model to average in
        """

        def _blend(ema_value, model_value):
            return self.decay * ema_value + (1. - self.decay) * model_value

        self._update(model, update_fn = _blend)

    def set(self, model):

        """
        Copy the live model's state into the EMA copy.

        Parameters:
        -----------
        model : torch.nn.Module
            The model whose state replaces the average
        """

        def _take_model(ema_value, model_value):
            return model_value

        self._update(model, update_fn = _take_model)

### Preprocess data

In [30]:
# Load the raw sensor recordings and the per-subject demographics table.
df = pd.read_csv(RAW_DIR / "train.csv")
dem = pd.read_csv(RAW_DIR / "train_demographics.csv")

# Encode gesture names as integer class ids and save the class order next to
# the other exported artifacts.
le = LabelEncoder()
df["gesture_int"] = le.fit_transform(df["gesture"])
np.save(EXPORT_DIR / "gesture_classes.npy", le.classes_)

# Repair/normalise the quaternions first, then fill every remaining gap
# within each sequence (forward fill, backward fill, finally zeros).
df[ROT_COLS] = handle_quaternion_missing_values(df[ROT_COLS].to_numpy())
seq_list = []
for _, seq in df.groupby("sequence_id"):
    seq = seq.ffill().bfill().fillna(0)
    seq_list.append(seq)
# Concatenating in groupby order leaves the frame grouped by sequence_id,
# which add_features() relies on for its positional assignments.
df = pd.concat(seq_list, axis = 0)

# Mirror IMU data for left-handed
df = process_left_handed(df, dem)

# Add features
df = add_features(df)

Calculating angular velocity from quaternion derivatives...
Calculating angular distance between successive quaternions...


In [31]:
# Columns that carry labels/metadata rather than sensor features.
meta_cols = {
    "gesture", "gesture_int", "sequence_type", 
    "behavior", "orientation", "row_id", "subject", 
    "phase", "sequence_id", "sequence_counter"
}
# Everything else is a feature; the IMU subset excludes the thm_/tof_ columns.
feature_cols_all = [c for c in df.columns if c not in meta_cols]
imu_cols  = [c for c in feature_cols_all if not (c.startswith("thm_") or c.startswith("tof_"))]
# Save the IMU feature names (in column order) with the exported artifacts.
np.save(EXPORT_DIR / "feature_cols_imu.npy", np.array(imu_cols))
print(f"{len(imu_cols)} IMU features")

35 IMU features


In [None]:
# Augmentation pipeline: each transform fires independently with p = 0.35.
# TimeShift runs first (zero-filling the wrapped samples), then TimeStretch.
transforms = Compose([
    TimeShift(p = 0.35, padding_mode = "zero", max_shift_pct = 0.25),
    TimeStretch(p = 0.35, max_rate = 1.5, min_rate = 0.5),
])

class MixupDataset(Dataset):
    """Dataset that applies mixup and signal transforms in "train" mode.

    Args:
        X: indexable collection of samples.
        y: indexable collection of targets aligned with ``X``.
        transforms: callable applied to the (possibly mixed) sample in
            "train" mode.
        alpha: Beta(alpha, alpha) parameter for mixup; mixup is skipped
            when it is not positive.
        mode: "train" enables mixup and transforms; any other value returns
            the raw sample.
    """

    def __init__(self, X, y, transforms, alpha = 0.2, mode = "train"):
        self.X = X
        self.y = y
        self.transforms = transforms
        self.alpha = alpha
        self.mode = mode

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(sample, target)`` as float tensors."""
        sample, target = self.X[idx], self.y[idx]

        if self.mode != "train":
            return torch.FloatTensor(sample), torch.FloatTensor(target)

        if self.alpha > 0:
            # Blend with a uniformly drawn partner (possibly the same index).
            lam = np.random.beta(self.alpha, self.alpha)
            partner = np.random.randint(len(self.X))
            sample = lam * sample + (1 - lam) * self.X[partner]
            target = lam * target + (1 - lam) * self.y[partner]

        # Transforms run after mixup, on the blended sample.
        sample = self.transforms(sample)
        return torch.FloatTensor(sample), torch.FloatTensor(target)

### Model training

In [33]:
class SEBlock(nn.Module):
    """Squeeze-and-excitation gate over the channel dimension of (B, C, T) input.

    Args:
        channel: number of input channels.
        reduction: bottleneck ratio of the gating MLP.
    """

    def __init__(self, channel, reduction=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        hidden = channel // reduction
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        batch, channels, _ = x.size()
        # Squeeze: average over the last dimension -> (B, C).
        squeezed = self.avg_pool(x).view(batch, channels)
        # Excite: per-channel gate in (0, 1), broadcast over the last dimension.
        gate = self.fc(squeezed).view(batch, channels, 1)
        return x * gate.expand_as(x)

class ResidualSEBlock(nn.Module):
    """Two conv/BN layers with an SE gate and a residual shortcut, then
    ReLU, max-pooling by 2 and dropout.

    Args:
        in_channels: channels of the (B, C, T) input.
        out_channels: channels produced by the block.
        kernel_size: kernel size of both convolutions ('same' padding).
        drop: dropout probability applied to the output.
    """

    def __init__(self, in_channels, out_channels, kernel_size, drop=0.3):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, padding='same', bias = False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, padding='same', bias = False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.se = SEBlock(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(drop)
        self.pool = nn.MaxPool1d(2)

        # Project the shortcut with a 1x1 conv only when channel counts differ.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, 1, padding='same', bias = False),
                nn.BatchNorm1d(out_channels)
            )

    def forward(self, x):
        residual = self.shortcut(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.se(self.bn2(self.conv2(out)))

        out += residual
        return self.dropout(self.pool(self.relu(out)))

class AttentionLayer(nn.Module):
    """Attention pooling: collapse dim 1 of a (B, T, H) tensor into (B, H)."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.Tanh()
        )

    def forward(self, x):
        # One tanh-bounded score per step, normalised over dim 1.
        scores = self.attention(x).squeeze(-1)
        alpha = F.softmax(scores, dim=1).unsqueeze(-1)
        # Weighted sum of the steps.
        return torch.sum(x * alpha, dim=1)
    
class CrossAttention(nn.Module):
    """
    Multi-head cross-attention used to fuse feature branches.

    The query branch attends over the concatenation (along time) of the
    other branches; the result is added back to the query and layer-normed.
    """
    def __init__(self, feature_dim, num_heads=8, dropout=0.1):
        super().__init__()
        self.feature_dim = feature_dim
        self.num_heads = num_heads
        self.head_dim = feature_dim // num_heads

        assert feature_dim % num_heads == 0, "feature_dim must be divisible by num_heads"

        # Bias-free projections for queries, keys and values.
        self.q_linear = nn.Linear(feature_dim, feature_dim, bias=False)
        self.k_linear = nn.Linear(feature_dim, feature_dim, bias=False)
        self.v_linear = nn.Linear(feature_dim, feature_dim, bias=False)

        # Output projection and dropout on the attention weights.
        self.out_linear = nn.Linear(feature_dim, feature_dim)
        self.dropout = nn.Dropout(dropout)

        # Normalisation applied after the residual sum.
        self.layer_norm = nn.LayerNorm(feature_dim)

    def _split_heads(self, tensor):
        """(B, N, C) -> (B, num_heads, N, head_dim)."""
        batch, length, _ = tensor.shape
        return tensor.view(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query_branch, key_value_branches):
        """
        Args:
            query_branch: (B, T, C) - the branch that will be updated
            key_value_branches: List of (B, T, C) tensors - branches to attend to
        Returns:
            Updated query_branch with cross-attention: (B, T, C)
        """
        B, T, C = query_branch.shape

        # Stack the context branches and flatten them into one long sequence:
        # (B, num_branches, T, C) -> (B, num_branches*T, C).
        context = torch.stack(key_value_branches, dim=1)
        context = context.reshape(B, context.shape[1] * T, C)

        queries = self._split_heads(self.q_linear(query_branch))  # (B, H, T, d)
        keys = self._split_heads(self.k_linear(context))          # (B, H, num_branches*T, d)
        values = self._split_heads(self.v_linear(context))        # (B, H, num_branches*T, d)

        # Scaled dot-product attention over every context position.
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_weights = self.dropout(F.softmax(scores, dim=-1))

        # Weighted values, heads merged back into the channel dimension.
        attended = torch.matmul(attn_weights, values)
        attended = attended.transpose(1, 2).contiguous().view(B, T, C)

        # Residual connection followed by layer normalisation.
        return self.layer_norm(query_branch + self.out_linear(attended))

class IMUCrossAttentionFusion(nn.Module):
    """
    Fuse three IMU feature streams with pairwise cross-attention.

    A separate ``CrossAttention`` module is held per stream; every stream is
    used as the query once, with the remaining two streams as keys/values.
    """

    def __init__(self, feature_dim=256, num_heads=8, dropout=0.1):
        super().__init__()

        # One attention module per query stream. The attribute names are the
        # state_dict prefixes and the 1 -> 2 -> 3 creation order is kept.
        attn_args = (feature_dim, num_heads, dropout)
        self.cross_attn1 = CrossAttention(*attn_args)
        self.cross_attn2 = CrossAttention(*attn_args)
        self.cross_attn3 = CrossAttention(*attn_args)

    def forward(self, imu1, imu2, imu3):
        """
        Args:
            imu1, imu2, imu3: (B, T, C) tensors, one per IMU branch
        Returns:
            Tuple ``(enhanced_imu1, enhanced_imu2, enhanced_imu3)``; each entry
            is the matching input after attending to the other two inputs.
        """
        streams = (imu1, imu2, imu3)
        attenders = (self.cross_attn1, self.cross_attn2, self.cross_attn3)

        enhanced = []
        for pos, (attend, query) in enumerate(zip(attenders, streams)):
            # Keys/values are always the *original* other two streams, in
            # their input order, never an already-enhanced one.
            others = [s for other_pos, s in enumerate(streams) if other_pos != pos]
            enhanced.append(attend(query, others))

        return tuple(enhanced)
    
class OneBranchModel(nn.Module):
    """
    IMU-only sequence classifier.

    Pipeline: three ``ResidualSEBlock`` branches (one per group of input
    channels) -> cross-attention fusion between the branches -> BiLSTM ->
    ``AttentionLayer`` -> fully connected head producing ``n_classes`` logits.

    Args:
        imu_dim: number of leading feature columns of the input that are fed
            to the branches. The branches take 12, 11 and the remaining
            columns, and the third branch is built for 12 input channels.
        n_classes: size of the output layer.
    """

    def __init__(self, imu_dim, n_classes):
        super().__init__()
        self.imu_dim = imu_dim

        # IMU branches (input channel counts must match the slices in forward)
        self.imu_branch1 = self._make_branch(12)
        self.imu_branch2 = self._make_branch(11)
        self.imu_branch3 = self._make_branch(12)

        self.cross_attention_fusion = IMUCrossAttentionFusion(
            feature_dim=256, num_heads=8, dropout=0.1
        )

        # BiLSTM over the concatenated branch features (3 x 256 channels);
        # bidirectional with hidden size 512 -> 1024 output features.
        self.bilstm = nn.LSTM(256 * 3, 512, num_layers=1, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(0.4)

        # Attention over the BiLSTM outputs
        self.attention = AttentionLayer(1024)

        # Dense head; layer order fixes the ``fc_layers.<idx>`` state_dict keys
        head = [
            nn.Linear(1024, 512, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256, bias=False),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, n_classes),
        ]
        self.fc_layers = nn.Sequential(*head)

    @staticmethod
    def _make_branch(in_channels):
        """Build one feature branch: in_channels -> 128 -> 256 channels."""
        return nn.Sequential(
            ResidualSEBlock(in_channels, 128, 3, drop=0.3),
            ResidualSEBlock(128, 256, 5, drop=0.3),
        )

    def forward(self, x):
        """
        Args:
            x: tensor indexed as (batch, time, feature); only the first
               ``imu_dim`` features are used.
        Returns:
            Output of the dense head (class logits).
        """
        features = x[:, :, :self.imu_dim]

        # Feature columns consumed by branch 1, 2 and 3 respectively.
        column_groups = (slice(None, 12), slice(12, 23), slice(23, None))
        branches = (self.imu_branch1, self.imu_branch2, self.imu_branch3)

        # Branches run on (batch, channel, time); swap back afterwards.
        encoded = [
            branch(features[:, :, cols].transpose(1, 2)).transpose(1, 2)
            for branch, cols in zip(branches, column_groups)
        ]

        # Apply cross-attention fusion between branches
        fused1, fused2, fused3 = self.cross_attention_fusion(*encoded)

        merged = torch.cat((fused1, fused2, fused3), dim=2)

        lstm_out, _ = self.bilstm(merged)
        attended = self.attention(self.dropout(lstm_out))

        return self.fc_layers(attended)

In [None]:
def _collect_sequences(frame):
    """Turn a long-format frame into per-sequence feature matrices and labels.

    Rows are grouped by ``sequence_id``. For every group the ``imu_cols``
    columns become one 2-D array and the first ``gesture_int`` value of the
    group is used as its label.

    Returns:
        (seq_ids, matrices, labels): three parallel lists, all in the order
        in which ``groupby("sequence_id")`` yields the groups.
    """
    seq_ids, matrices, labels = [], [], []
    for seq_id, seq in frame.groupby("sequence_id"):
        seq_ids.append(seq_id)
        labels.append(seq["gesture_int"].iloc[0])
        matrices.append(seq[imu_cols].to_numpy())
    return seq_ids, matrices, labels


def train_single_fold(fold, train_inds, val_inds):
    """Train and validate ``OneBranchModel`` on one cross-validation fold.

    Uses the notebook-level ``df``, ``le``, ``imu_cols``, ``ids2del`` and
    ``transforms`` plus the configuration constants. Side effects, all written
    to ``EXPORT_DIR``: the fold's padding length, a per-epoch text log, the
    EMA weights of the best epoch, and a CSV with out-of-fold probabilities.

    Args:
        fold: fold number, used in log messages and file names.
        train_inds: positional row indices of ``df`` for training.
        val_inds: positional row indices of ``df`` for validation.

    Returns:
        (best_val_F1, val_to_save): the best validation score reached and a
        DataFrame with one row per validation sequence (``sequence_id`` plus
        one probability column per class). The probabilities are those of the
        epoch that achieved ``best_val_F1``, i.e. of the saved checkpoint.
    """
    print(f"Train fold №{fold}")

    n_classes = len(le.classes_)

    # Split
    df_train, df_val = df.iloc[train_inds], df.iloc[val_inds]
    # Pad/crop length by percentile per fold (computed before any filtering)
    seq_lengths = [len(group) for _, group in df_train.groupby("sequence_id")]
    pad_len = int(np.percentile(seq_lengths, PAD_PERCENTILE))
    np.save(EXPORT_DIR / f"sequence_maxlen_fold{fold}.npy", pad_len)

    # Remove problematic sequence if it exists; ``ids2del`` sequences are
    # dropped from training only and stay in validation.
    df_train = df_train[df_train["sequence_id"] != "SEQ_011975"]
    df_train = df_train[~df_train["sequence_id"].isin(ids2del["sequence_id"])]
    if DEL_SUBJ:
        df_train = df_train[~df_train["subject"].isin(["SUBJ_045235", "SUBJ_019262"])]

    df_val = df_val[df_val["sequence_id"] != "SEQ_011975"]
    if DEL_SUBJ:
        df_val = df_val[~df_val["subject"].isin(["SUBJ_045235", "SUBJ_019262"])]

    # Create sequences train
    _, X_list, train_labels = _collect_sequences(df_train)
    X_tr = pad_sequences(X_list, maxlen = pad_len, padding = "pre", truncating = "pre")
    y_tr = to_categorical(train_labels, num_classes = n_classes)

    # Create sequences val
    val_ids, X_list, val_labels = _collect_sequences(df_val)
    X_val = pad_sequences(X_list, maxlen = pad_len, padding = "pre", truncating = "pre")
    y_val = to_categorical(val_labels, num_classes = n_classes)

    # For OOF preds. The ids come from the same groupby pass that produced
    # X_val, so row i of the predictions always belongs to val_ids[i]
    # (``unique()`` keeps order of appearance, groupby sorts its keys).
    val_to_save = pd.DataFrame({"sequence_id": val_ids})

    # Model
    model = OneBranchModel(len(imu_cols), n_classes).to(device)
    ema_model = ModelEMA(model, decay = ema_decay, device = device)
    print(f"Model has {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")

    optimizer = torch.optim.Adam(model.parameters(), lr = LR_INIT, weight_decay = WD)

    if FOCAL_LOSS:
        criterion = SmoothFocalLoss(classes = n_classes)
    else:
        class_weights = None
        if CLASS_WEIGHTS:
            # Only computed when used: "balanced" weights need every class
            # to be present in the training labels.
            cw_vals = compute_class_weight(
                "balanced", classes = np.arange(n_classes), y = train_labels
            )
            class_weights = torch.FloatTensor(cw_vals).to(device)
        criterion = nn.CrossEntropyLoss(weight = class_weights, label_smoothing = 0.1)

    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0 = T_0)

    train_dataset = MixupDataset(X_tr, y_tr, alpha = MIXUP_ALPHA, mode = "train", transforms = transforms)
    train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True, drop_last = True)
    val_dataset = MixupDataset(X_val, y_val, mode = "valid", transforms = transforms)
    val_loader = DataLoader(val_dataset, batch_size = BATCH_SIZE)

    best_val_F1 = 0
    best_outputs = None  # validation probabilities of the best epoch
    patience_counter = 0

    print(X_tr.shape)

    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0
        for batch_X, batch_y in train_loader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            ema_model.update(model)
        scheduler.step()  # scheduler advances once per epoch
        train_loss /= len(train_loader)

        # Validation is done with the EMA weights. Force inference mode on
        # the EMA copy; this is a no-op if ModelEMA already keeps it in eval.
        ema_model.module.eval()
        val_loss = 0
        outputs = []
        with torch.no_grad():
            for batch_X, batch_y in val_loader:
                batch_X, batch_y = batch_X.to(device), batch_y.to(device)
                out = ema_model.module(batch_X)
                val_loss += criterion(out, batch_y).item()
                out = F.softmax(out, dim=1)
                outputs.append(out.cpu())
            outputs = np.concatenate(outputs)

        val_loss /= len(val_loader)

        val_F1 = CompetitionMetric().calculate_hierarchical_f1(
            pd.DataFrame({"gesture": le.classes_[y_val.argmax(1)]}),
            pd.DataFrame({"gesture": le.classes_[outputs.argmax(1)]})
        )

        log_message = f"Epoch [{epoch+1}/{EPOCHS}] Train Loss: {train_loss:.4f} \
            Val Loss: {val_loss:.4f} Val F1: {val_F1:.3f}"
        print(log_message)
        with open(EXPORT_DIR / f"log_1branch_fold{fold}.txt", "a") as f:
            f.write(log_message + "\n")

        if val_F1 > best_val_F1:
            best_val_F1 = val_F1
            best_outputs = outputs  # fresh array every epoch, safe to keep
            patience_counter = 0
            torch.save(
                ema_model.module.state_dict(),
                EXPORT_DIR / f"one_branch_mixup_best_f1_fold{fold}.pt"
            )
            print(f"Model saved with F1 score: {best_val_F1:.4f}")
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f"Early stopping after {epoch+1} epochs")
                break

    if best_outputs is None:
        # No epoch ever beat the initial score, so no checkpoint exists;
        # fall back to the predictions of the last epoch that ran.
        best_outputs = outputs

    # Store the predictions that match the saved checkpoint, not whatever the
    # last (possibly PATIENCE epochs stale) epoch produced.
    val_to_save[le.classes_] = best_outputs
    val_to_save.to_csv(EXPORT_DIR / f"1branch_val_sequences_fold{fold}.csv", index = False)
    print("Training done - artefacts saved in", EXPORT_DIR)

    return best_val_F1, val_to_save

In [None]:
# Subject-grouped 5-fold cross-validation: all rows of one subject land in
# the same fold. Splitting is done on the rows of ``df``.
group_kfold = GroupKFold(n_splits=5, shuffle=True, random_state=SEED)
fold_splits = group_kfold.split(df, groups=df["subject"])

best_scores, oof_preds = [], []

for fold, (train_inds, val_inds) in enumerate(fold_splits):
    best_score, oof_pred = train_single_fold(fold, train_inds, val_inds)
    best_scores.append(best_score)
    oof_preds.append(oof_pred)

# Stack the per-fold out-of-fold predictions into one table and export it.
oof_preds = pd.concat(oof_preds, axis=0)
oof_preds.to_csv(EXPORT_DIR / "1branch_val_sequences.csv", index=False)

In [None]:
# Mean of the per-fold best validation scores (shown as the cell output)
np.mean(best_scores)