In [None]:
# !pip install --no-deps /kaggle/input/scikit-learn-1-6-1/scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

#!/usr/bin/env python
# coding: utf-8

# # Importable version of https://www.kaggle.com/code/metric/cmi-2025

# In[ ]:


"""
Hierarchical macro F1 metric for the CMI 2025 Challenge.

This script defines a single entry point `score(solution, submission, row_id_column_name)`
that the Kaggle metrics orchestrator will call.
It validates the required columns and the submitted gesture labels, then computes a combined binary & multiclass F1 score.
"""

import math

import pandas as pd
from sklearn.metrics import f1_score


class ParticipantVisibleError(Exception):
    """Raised for submission problems whose message may be shown to the competitor as-is."""


class CompetitionMetric:
    """
    Hierarchical macro F1 for the CMI 2025 challenge.

    The final score is the mean of two F1 scores computed on positionally
    aligned solution/submission rows:
      * a binary F1 for "is the gesture one of the target gestures?";
      * a macro F1 over the labels that occur, where every non-target gesture
        is collapsed into the single label 'non_target'.
    """

    def __init__(self):
        # Gestures that are scored individually.
        self.target_gestures = [
            'Above ear - pull hair',
            'Cheek - pinch skin',
            'Eyebrow - pull hair',
            'Eyelash - pull hair',
            'Forehead - pull hairline',
            'Forehead - scratch',
            'Neck - pinch skin',
            'Neck - scratch',
        ]
        # Gestures that are valid in a submission but only count as 'non_target'.
        self.non_target_gestures = [
            'Write name on leg',
            'Wave hello',
            'Glasses on/off',
            'Text on phone',
            'Write name in air',
            'Feel around in tray and pull out an object',
            'Scratch knee/leg skin',
            'Pull air toward your face',
            'Drink from bottle/cup',
            'Pinch knee/leg skin'
        ]
        # Every label a submission is allowed to contain.
        self.all_classes = self.target_gestures + self.non_target_gestures

    def _collapse_non_targets(self, gestures: pd.Series) -> pd.Series:
        """Keep target gesture names and replace every other value with 'non_target'."""
        targets = self.target_gestures
        return gestures.apply(lambda g: g if g in targets else 'non_target')

    def calculate_hierarchical_f1(
        self,
        sol: pd.DataFrame,
        sub: pd.DataFrame
    ) -> float:
        """
        Score ``sub`` against ``sol``; both need a 'gesture' column and rows are
        compared by position.

        Raises:
            ParticipantVisibleError: if ``sub`` contains a gesture that is not
                in ``all_classes``.
        """
        unknown = {g for g in sub['gesture'].unique() if g not in self.all_classes}
        if unknown:
            raise ParticipantVisibleError(
                f"Invalid gesture values in submission: {unknown}"
            )

        # Level 1: target vs. non-target.
        true_is_target = sol['gesture'].isin(self.target_gestures).values
        pred_is_target = sub['gesture'].isin(self.target_gestures).values
        binary_f1 = f1_score(
            true_is_target,
            pred_is_target,
            pos_label=True,
            zero_division=0,
            average='binary'
        )

        # Level 2: which target gesture (all non-targets share one label).
        macro_f1 = f1_score(
            self._collapse_non_targets(sol['gesture']),
            self._collapse_non_targets(sub['gesture']),
            average='macro',
            zero_division=0
        )

        return 0.5 * binary_f1 + 0.5 * macro_f1


def score(
    solution: pd.DataFrame,
    submission: pd.DataFrame,
    row_id_column_name: str
) -> float:
    """
    Compute hierarchical macro F1 for the CMI 2025 challenge.

    Expected input:
      - solution and submission as pandas.DataFrame
      - Column ``row_id_column_name`` (e.g. 'sequence_id'): identifier for each
        sequence. It is only checked for presence; rows are compared by
        position, so both frames must already be in the same row order.
      - 'gesture': one of the 18 known gesture names (8 target gestures plus
        10 non-target gestures)

    This metric averages:
    1. Binary F1 on whether the gesture is a target gesture (Target vs Non-Target)
    2. Macro F1 on gesture, with every non-target gesture collapsed into a
       single 'non_target' label

    Raises ParticipantVisibleError if a required column is missing from either
    frame, or if the submission contains a gesture name that is not one of the
    known classes.


    Examples
    --------
    >>> import pandas as pd
    >>> row_id_column_name = "id"
    >>> solution = pd.DataFrame({'id': range(4), 'gesture': ['Eyebrow - pull hair']*4})
    >>> submission = pd.DataFrame({'id': range(4), 'gesture': ['Forehead - pull hairline']*4})
    >>> score(solution, submission, row_id_column_name=row_id_column_name)
    0.5
    >>> submission = pd.DataFrame({'id': range(4), 'gesture': ['Text on phone']*4})
    >>> score(solution, submission, row_id_column_name=row_id_column_name)
    0.0
    >>> score(solution, solution, row_id_column_name=row_id_column_name)
    1.0
    """
    # Validate required columns
    for col in (row_id_column_name, 'gesture'):
        if col not in solution.columns:
            raise ParticipantVisibleError(f"Solution file missing required column: '{col}'")
        if col not in submission.columns:
            raise ParticipantVisibleError(f"Submission file missing required column: '{col}'")

    metric = CompetitionMetric()
    # f1_score may return a NumPy scalar; coerce so the result honours the
    # `-> float` annotation and the doctests print plain floats on any NumPy version.
    return float(metric.calculate_hierarchical_f1(solution, submission))


import torch
import torch.distributed as dist


def zeropower_via_newtonschulz5(G, steps: int):
    """
    Approximately orthogonalize ``G`` with a quintic Newton-Schulz iteration.

    For G = U S V^T (SVD) the exact target is U V^T. The coefficients (a, b, c)
    are tuned to maximise the slope at zero rather than to converge exactly, so
    the result is closer to U S' V^T with S'_ii roughly in [0.5, 1.5]. In
    practice this does not hurt training compared with exact U V^T, and it
    needs fewer steps.

    ``G`` may carry leading batch dimensions (batched Muon implementation by
    @scottjmaddox, put into practice by @YouJiacheng). The iteration runs in
    bfloat16, and the result is bfloat16 with the same shape as ``G``.
    """
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    tall = G.size(-2) > G.size(-1)
    # Iterate on the wide orientation so that X @ X^T is the smaller Gram matrix.
    X = G.bfloat16().mT if tall else G.bfloat16()

    # Divide by the Frobenius norm, which upper-bounds the spectral norm, so sigma_max <= 1.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    for _ in range(steps):
        gram = X @ X.mT
        # Quintic strategy adapted from suggestions by @jxbz, @leloykun and @YouJiacheng.
        poly = b * gram + (c * gram) @ gram
        X = a * X + poly @ X

    return X.mT if tall else X


def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    """
    Compute one Muon update direction from a gradient.

    Args:
        grad: gradient tensor. With ``nesterov=True`` it is overwritten in place
            with the Nesterov blend of itself and the momentum buffer.
        momentum: momentum buffer (same shape as ``grad``), updated in place as
            an EMA of the gradients: ``m <- beta * m + (1 - beta) * grad``.
        beta: momentum coefficient.
        ns_steps: number of Newton-Schulz iterations.
        nesterov: use the Nesterov look-ahead instead of the raw momentum buffer.

    Returns:
        The orthogonalized update in bfloat16. 4D inputs (conv filters) come back
        flattened to ``(out_channels, -1)``; callers reshape it to the parameter shape.
    """
    momentum.lerp_(grad, 1 - beta)
    update = grad.lerp_(momentum, beta) if nesterov else momentum
    if update.ndim == 4:  # conv filters: orthogonalize as an (out, in*kh*kw) matrix
        update = update.reshape(len(update), -1)
    update = zeropower_via_newtonschulz5(update, steps=ns_steps)
    # Scale tall matrices by sqrt(rows / cols). Use the shape of the matrix that was
    # actually orthogonalized: for 4D filters grad.size(-2)/grad.size(-1) would be kh/kw
    # rather than out/(in*kh*kw). For 2D and batched inputs this equals the grad shape.
    update *= max(1, update.size(-2) / update.size(-1))**0.5
    return update


class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the
    advantage that it can be stably run in bfloat16 on the GPU.

    Muon should only be used for hidden weight layers. The input embedding, final output layer,
    and any internal gains or biases should be optimized using a standard method such as AdamW.
    Hidden convolutional weights can be trained using Muon by viewing them as 2D and then
    collapsing their last 3 dimensions.

    This is the distributed variant. Each rank updates every world_size-th parameter, and the
    results are shared with ``torch.distributed.all_gather``, so a process group must be
    initialized before ``step`` is called. Parameters whose ``.grad`` is None are left unchanged.

    Arguments:
        lr: The learning rate, in units of spectral norm per update.
        weight_decay: The AdamW-style weight decay.
        momentum: The momentum. A value of 0.95 here is usually fine.
    """
    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
        # Sort by shape (largest first) so parameters of equal shape sit next to each other;
        # each all_gather window below exchanges one tensor per rank.
        params = sorted(params, key=lambda x: x.size(), reverse=True)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        world_size = dist.get_world_size()
        rank = dist.get_rank()
        for group in self.param_groups:
            params = group["params"]
            # Pad so every all_gather window has a tensor for every rank.
            params_pad = params + [torch.empty_like(params[-1])] * (world_size - len(params) % world_size)
            for base_i in range(0, len(params), world_size):
                if base_i + rank < len(params):
                    p = params[base_i + rank]
                    # A param with no gradient is left unchanged. This rank must still reach
                    # the all_gather below, because every rank has to join the collective.
                    if p.grad is not None:
                        state = self.state[p]
                        if len(state) == 0:
                            state["momentum_buffer"] = torch.zeros_like(p)
                        update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                        p.mul_(1 - group["lr"] * group["weight_decay"])
                        p.add_(update.reshape(p.shape), alpha=-group["lr"])
                dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank])

        return loss


class SingleDeviceMuon(torch.optim.Optimizer):
    """
    Muon variant for usage in non-distributed settings.

    Same update rule as ``Muon``, but every parameter is updated locally: there are no
    torch.distributed calls and no sorting by size. Parameters whose ``.grad`` is None
    are skipped.

    Arguments:
        lr: learning rate, in units of spectral norm per update.
        weight_decay: decoupled (AdamW-style) weight decay, applied as p *= 1 - lr * wd.
        momentum: momentum coefficient.
    """
    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                # muon_update cannot handle a missing gradient (e.g. a parameter unused in
                # this forward pass), so leave such parameters untouched.
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    state["momentum_buffer"] = torch.zeros_like(p)
                update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                p.mul_(1 - group["lr"] * group["weight_decay"])
                p.add_(update.reshape(p.shape), alpha=-group["lr"])

        return loss


def adam_update(grad, buf1, buf2, step, betas, eps):
    """
    Advance the Adam moment buffers in place and return the bias-corrected update.

    Args:
        grad: current gradient.
        buf1: first-moment buffer (EMA of grad), updated in place.
        buf2: second-moment buffer (EMA of grad**2), updated in place.
        step: step count used for bias correction. Must be >= 1, since
            step == 0 makes 1 - beta**step zero.
        betas: (beta1, beta2) EMA coefficients.
        eps: added to the denominator for numerical stability.

    Returns:
        m_hat / (sqrt(v_hat) + eps), with no learning rate applied.
    """
    beta1, beta2 = betas[0], betas[1]
    buf1.lerp_(grad, 1 - beta1)
    buf2.lerp_(grad.square(), 1 - beta2)
    m_hat = buf1 / (1 - beta1**step)
    v_hat = buf2 / (1 - beta2**step)
    return m_hat / (v_hat.sqrt() + eps)


class MuonWithAuxAdam(torch.optim.Optimizer):
    """
    Distributed Muon variant that can be used for all parameters in the network, since it runs an
    internal AdamW for the parameters that are not compatible with Muon. The user must manually
    specify which parameters shall be optimized with Muon and which with Adam by passing in a
    list of param_groups with the `use_muon` flag set.

    The point of this class is to allow the user to have a single optimizer in their code, rather
    than having both a Muon and an Adam which each need to be stepped.

    Parameters whose ``.grad`` is None are left unchanged. In Muon groups the owning rank still
    takes part in the all_gather, so all ranks stay in step.

    You can see an example usage below:

    https://github.com/KellerJordan/modded-nanogpt/blob/master/records/052525_MuonWithAuxAdamExample/b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt#L470
    ```
    hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
    embed_params = [p for n, p in model.named_parameters() if "embed" in n]
    scalar_params = [p for p in model.parameters() if p.ndim < 2]
    head_params = [model.lm_head.weight]

    from muon import MuonWithAuxAdam
    adam_groups = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
    adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups]
    muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True)
    param_groups = [*adam_groups, muon_group]
    optimizer = MuonWithAuxAdam(param_groups)
    ```
    """
    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:
                # Sort by shape so equal-shaped params are adjacent for the all_gather windows.
                group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True)
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if group["use_muon"]:
                params = group["params"]
                world_size = dist.get_world_size()
                rank = dist.get_rank()
                # Pad so every all_gather window has a tensor for every rank.
                params_pad = params + [torch.empty_like(params[-1])] * (world_size - len(params) % world_size)
                for base_i in range(0, len(params), world_size):
                    if base_i + rank < len(params):
                        p = params[base_i + rank]
                        # Skip the update for a param without a gradient, but do not skip the
                        # all_gather below: every rank has to join the collective.
                        if p.grad is not None:
                            state = self.state[p]
                            if len(state) == 0:
                                state["momentum_buffer"] = torch.zeros_like(p)
                            update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                            p.mul_(1 - group["lr"] * group["weight_decay"])
                            p.add_(update.reshape(p.shape), alpha=-group["lr"])
                    dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank])
            else:
                for p in group["params"]:
                    # adam_update cannot handle a missing gradient.
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
                                         state["step"], group["betas"], group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss


class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
    """
    Non-distributed variant of MuonWithAuxAdam.

    Every param group must carry a boolean ``use_muon`` flag:
      * ``use_muon=True``  -> Muon; allowed keys: params, lr, momentum, weight_decay.
      * ``use_muon=False`` -> Adam with decoupled weight decay; allowed keys: params,
        lr, betas, eps, weight_decay.
    Missing hyper-parameters are filled with defaults, and an unexpected key triggers
    an AssertionError. Parameters whose ``.grad`` is None are skipped.
    """
    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:
                fallbacks = (("lr", 0.02), ("momentum", 0.95), ("weight_decay", 0))
                allowed = {"params", "lr", "momentum", "weight_decay", "use_muon"}
            else:
                fallbacks = (("lr", 3e-4), ("betas", (0.9, 0.95)), ("eps", 1e-10), ("weight_decay", 0))
                allowed = {"params", "lr", "betas", "eps", "weight_decay", "use_muon"}
            for key, value in fallbacks:
                group.setdefault(key, value)
            assert set(group.keys()) == allowed
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            decay = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if group["use_muon"]:
                    if not state:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    direction = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                    direction = direction.reshape(p.shape)
                else:
                    if not state:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    direction = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
                                            state["step"], group["betas"], group["eps"])
                # Decoupled weight decay, then the step itself.
                p.mul_(1 - lr * decay)
                p.add_(direction, alpha=-lr)

        return loss

from torch.optim.lr_scheduler import _LRScheduler
class ConstantCosineLR(_LRScheduler):
    """
    Constant learning rate followed by cosine annealing.

    For schedule positions up to ``int(total_steps * pct_cosine)`` every group keeps
    its base lr. After that the lr follows a half cosine from the base lr down to
    ``min_lr`` (0), reaching it at position ``total_steps``. If stepping continues
    beyond that, the lr is held at ``min_lr`` rather than rising again.

    Note: despite its name, ``pct_cosine`` is the fraction of ``total_steps`` spent
    in the *constant* phase, because it sets the milestone where the decay starts.

    Args:
        optimizer: wrapped optimizer.
        total_steps: number of scheduler steps over which the schedule runs.
        pct_cosine: fraction of ``total_steps`` before the cosine decay starts.
        last_epoch: index of the last step; -1 starts a fresh schedule.
    """
    def __init__(
        self, 
        optimizer,
        total_steps, 
        pct_cosine, 
        last_epoch=-1,
        ):
        self.total_steps = total_steps
        self.milestone = int(total_steps * (pct_cosine))
        # At least 1 to avoid dividing by zero when the constant phase covers everything.
        self.cosine_steps = max(total_steps - self.milestone, 1)
        # Floor that the cosine decays to.
        self.min_lr = 0
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # 1-based position in the schedule.
        step = self.last_epoch + 1
        if step <= self.milestone:
            factor = 1.
        else:
            # Clamp so that stepping past total_steps keeps the lr at min_lr instead of
            # letting cos() climb back up.
            s = min(step - self.milestone, self.cosine_steps)
            factor = 0.5 * (1 + math.cos(math.pi * s / self.cosine_steps))
        return [self.min_lr + (lr - self.min_lr) * factor for lr in self.base_lrs]

In [None]:
import os, json, joblib, numpy as np, pandas as pd
import random, math
from pathlib import Path
import warnings 
warnings.filterwarnings("ignore")

import sys
sys.path.append('/root/autodl-tmp/')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import StratifiedKFold
from timm.scheduler import CosineLRScheduler
from scipy.signal import firwin
import polars as pl
import numpy as np
import random


import numpy as np
import pandas as pd
from scipy.spatial.transform import Rotation as R
from tqdm.auto import tqdm 


class SignalTransform:
    """
    Base class for random signal augmentations.

    Subclasses implement ``apply``. Calling an instance runs ``apply`` always when
    ``always_apply`` is set, otherwise with probability ``p``; when it does not
    fire, the input is returned unchanged (same object).
    """
    def __init__(self, always_apply: bool = False, p: float = 0.5):
        self.always_apply = always_apply
        self.p = p

    def __call__(self, y: np.ndarray):
        # The RNG is only consulted when the transform is not forced.
        if self.always_apply or np.random.rand() < self.p:
            return self.apply(y)
        return y

    def apply(self, y: np.ndarray):
        raise NotImplementedError


class Compose:
    """Apply a list of transforms in order, feeding each one the previous output."""
    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, y: np.ndarray):
        out = y
        for transform in self.transforms:
            out = transform(out)
        return out


class OneOf:
    """Pick one transform uniformly at random and apply it (its own ``p`` still applies)."""
    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, y: np.ndarray):
        chosen = self.transforms[np.random.choice(len(self.transforms))]
        return chosen(y)

class TimeStretch(SignalTransform):
    """
    Randomly stretch or compress a signal in time by linear interpolation.

    A rate is drawn uniformly from [min_rate, max_rate] and the signal is resampled
    to ``int(L / rate)`` samples (at least 1). The result is then zero-padded at the
    start when it is shorter than L, or cropped from the start when it is longer, so
    the output has the same length L as the input.
    """
    def __init__(self, max_rate=1.5, min_rate=0.5, always_apply=False, p=0.5):
        super().__init__(always_apply, p)  # stores always_apply and p
        self.max_rate = max_rate
        self.min_rate = min_rate

    def apply(self, x: np.ndarray):
        """
        Stretch a 1D or 2D array in time (axis 0) using linear interpolation.
        After stretching, pad or crop at the beginning to match the original length.
        Padding value is 0.

        - x: np.ndarray of shape (L,) or (L, N)
        Returns an array with the same shape as ``x``. Raises ValueError for other ranks.
        """
        if x.ndim not in (1, 2):
            raise ValueError("Only 1D or 2D arrays are supported.")
        rate = np.random.uniform(self.min_rate, self.max_rate)
        L = x.shape[0]
        if L == 0:
            # Nothing to interpolate; np.interp rejects empty sample points.
            return x.copy()
        # int(L / rate) is 0 whenever L < rate (e.g. a single-sample signal). Keep at least
        # one sample, otherwise the padding slice below would be [-0:], i.e. the whole array.
        L_new = max(1, int(L / rate))
        orig_idx = np.linspace(0, L - 1, num=L)
        new_idx = np.linspace(0, L - 1, num=L_new)

        # Treat 1D input as a single column so both ranks share one code path.
        cols = x[:, None] if x.ndim == 1 else x
        stretched = np.stack([
            np.interp(new_idx, orig_idx, cols[:, i]) for i in range(cols.shape[1])
        ], axis=1)

        if L_new < L:
            # Pad at the beginning.
            out = np.zeros((L, cols.shape[1]), dtype=stretched.dtype)
            out[-L_new:] = stretched
        else:
            # Crop at the beginning (a no-op when L_new == L).
            out = stretched[-L:]
        return out[:, 0] if x.ndim == 1 else out


class TimeShift(SignalTransform):
    """
    Randomly shift a (L, N) signal along the time axis by up to ``int(L * max_shift_pct)``
    samples in either direction.

    padding_mode:
        "replace" - circular shift (np.roll): samples pushed off one end re-enter at the other.
        "zero"    - as above, but the samples that wrapped around are set to 0.
    """
    def __init__(self, always_apply=False, p=0.5, max_shift_pct=0.25, padding_mode="replace"):
        super().__init__(always_apply, p)

        assert 0 <= max_shift_pct <= 1.0, "`max_shift_pct` must be between 0 and 1"
        assert padding_mode in ["replace", "zero"], "`padding_mode` must be either 'replace' or 'zero'"

        self.max_shift_pct = max_shift_pct
        self.padding_mode = padding_mode

    def apply(self, x: np.ndarray, **params):
        assert x.ndim == 2, "`x` must be a 2D array with shape (L, N)"

        limit = int(x.shape[0] * self.max_shift_pct)
        offset = np.random.randint(-limit, limit + 1)
        shifted = np.roll(x, offset, axis=0)

        if self.padding_mode == "zero" and offset != 0:
            # Blank out the rows that wrapped around during the roll.
            wrapped = slice(None, offset) if offset > 0 else slice(offset, None)
            shifted[wrapped, :] = 0

        return shifted

# Default train-time augmentation pipeline: a circular time shift followed by a
# time stretch, each applied independently with probability 0.25.
transforms_custom = Compose(
    transforms=[
        TimeShift(max_shift_pct=0.25, padding_mode="replace", p=0.25),
        TimeStretch(min_rate=0.5, max_rate=1.5, p=0.25),
    ]
)



# Locate the test CSVs: Kaggle kernel layout first, then a local copy of the
# competition folder, and finally the current working directory.
if os.path.exists("../input/cmi-detect-behavior-with-sensor-data"):
    test_path1, test_path2 = (
        '/kaggle/input/cmi-detect-behavior-with-sensor-data/test.csv',
        '/kaggle/input/cmi-detect-behavior-with-sensor-data/test_demographics.csv',
    )
elif os.path.exists("./cmi-detect-behavior-with-sensor-data/"):
    test_path1, test_path2 = (
        './cmi-detect-behavior-with-sensor-data/test.csv',
        './cmi-detect-behavior-with-sensor-data/test_demographics.csv',
    )
else:
    test_path1, test_path2 = './test.csv', './test_demographics.csv'

class model_zhou_v26:
    def __init__(self, kaggle_input_path="/kaggle/input/cmi3v23", seed=42, save_path='./'):
        """
        Set up paths, hyper-parameters and the compute device.

        The run mode depends on which competition-data directory exists:
          * ``../input/cmi-detect-behavior-with-sensor-data`` -> ``TRAIN = False``
            (presumably the Kaggle inference kernel) and data is read from there;
          * ``./cmi-detect-behavior-with-sensor-data/`` -> ``TRAIN = True`` and data
            is read from there;
          * otherwise -> ``TRAIN = True`` and data is read from the current directory.

        Args:
            kaggle_input_path: stored as ``PRETRAINED_DIR``.
            seed: stored as ``SEED`` and ``random_state``.
            save_path: stored as ``EXPORT_DIR``; the directory is created if missing.
        """
        self.models = []

        import os
        if os.path.exists("../input/cmi-detect-behavior-with-sensor-data"):
            self.TRAIN = False
            self.RAW_DIR = Path("../input/cmi-detect-behavior-with-sensor-data")
        elif os.path.exists("./cmi-detect-behavior-with-sensor-data/"):
            self.TRAIN = True
            self.RAW_DIR = Path("./cmi-detect-behavior-with-sensor-data/")
        else:
            self.TRAIN = True
            self.RAW_DIR = Path("./")
        # These do not depend on which data layout was detected.
        self.PRETRAINED_DIR = Path(kaggle_input_path)
        self.EXPORT_DIR = Path(save_path)

        if not os.path.exists(self.EXPORT_DIR):
            # os.makedirs instead of shelling out to `mkdir`: no shell quoting issues,
            # missing parent directories are created, and no subprocess is spawned.
            os.makedirs(self.EXPORT_DIR, exist_ok=True)

        self.VALIDATION = False
        self.SEED = seed

        self.BATCH_SIZE = 64 * 1
        self.PAD_PERCENTILE = 128 + 64
        self.maxlen = self.PAD_PERCENTILE
        self.LR_INIT = 1e-3
        self.WD = 2e-1

        self.PATIENCE = 40
        self.random_state = self.SEED

        self.tof_mode = 4

        self.FOLDS = 5
        self.TRAIN_FOLDS = [0, 1, 2, 3, 4,]

        self.EPOCHS = 160

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"▶ imports ready · pytorch {torch.__version__} · device: {self.device}")
    
    def create_model(self, param_a, param_b, param_c, param_d):        
        def get_gravity(quat):
            g=9.81
             # 四元数旋转
            x = quat[..., 0:1]  # rot_x
            y = quat[..., 1:2]  # rot_y
            z = quat[..., 2:3]  # rot_z
            w = quat[..., 3:4]  # rot_w
            g_x = 2 * g * (x * z - w * y)
            g_y = 2 * g * (y * z + w * x)
            g_z = g * (w**2 - x**2 - y**2 + z**2)
            gravity = np.concatenate([g_x, g_y, g_z], -1)
            return gravity
            
        def rotate_quaternion_np(quat: np.ndarray, angle_range: tuple = [(-10, 10), (-10, 10), (-10, 10), ]) -> np.ndarray:
            # 确保输入形状正确
            if quat.ndim != 2 or quat.shape[1] != 4:
                raise ValueError("输入四元数必须是形状为[len, 4]的数组")
            
            seq_len = quat.shape[0]
            
            # 1. 生成绕x、y、z轴的随机旋转角度（度）
            angle_x = np.random.uniform(angle_range[0][0], angle_range[0][1])
            angle_y = np.random.uniform(angle_range[1][0], angle_range[1][1])
            angle_z = np.random.uniform(angle_range[2][0], angle_range[2][1])
            
            # 2. 将角度转换为弧度
            rad_x = math.pi * angle_x / 180.0
            rad_y = math.pi * angle_y / 180.0
            rad_z = math.pi * angle_z / 180.0
            
            # 3. 生成绕各轴旋转的四元数
            qx = np.array([math.sin(rad_x/2), 0, 0, math.cos(rad_x/2)], dtype=np.float32)  # x轴旋转
            qy = np.array([0, math.sin(rad_y/2), 0, math.cos(rad_y/2)], dtype=np.float32)  # y轴旋转
            qz = np.array([0, 0, math.sin(rad_z/2), math.cos(rad_z/2)], dtype=np.float32)  # z轴旋转
            
            # 4. 组合旋转四元数 (q_total = qz * qy * qx，注意乘法顺序)
            q_zy = _quat_mul(qz, qy)
            q_total = _quat_mul(q_zy, qx)
            
            # 5. 将组合旋转应用到每个四元数上
            # 扩展旋转四元数以匹配序列长度
            q_total_expanded = np.tile(q_total, (seq_len, 1))  # [len, 4]
            rotated_quat = _quat_mul(q_total_expanded, quat)
            
            # 6. 归一化确保四元数有效性
            norms = np.linalg.norm(rotated_quat, axis=1, keepdims=True)
            rotated_quat = rotated_quat / np.maximum(norms, 1e-8)  # 防止除零
            
            return rotated_quat
        
        def _quat_mul(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
            # 提取分量（利用广播机制自动适配维度）
            x1, y1, z1, w1 = q1[..., 0], q1[..., 1], q1[..., 2], q1[..., 3]
            x2, y2, z2, w2 = q2[..., 0], q2[..., 1], q2[..., 2], q2[..., 3]
            
            # 四元数乘法公式（完全向量化）
            x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
            y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
            z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
            w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
            
            # 堆叠结果（保持维度）
            return np.stack([x, y, z, w], axis=-1).astype(q1.dtype)
        
        
        def remove_average_rotation_optimized(q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
            """
            从四元数序列中移除平均旋转，忽略全零四元数，并旋转到指定方向
            Args:
                q: 输入四元数序列，形状 [bs, len, 4]，顺序 [rot_x, rot_y, rot_z, rot_w]
                eps: 数值稳定性阈值
            Returns:
                q_final: 移除平均旋转并旋转到[0.5,0.5,0.5,0.5]方向后的四元数序列，形状 [bs, len, 4]
                         原始全零的位置仍为零
            """
            def is_zero_quaternion(q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
                """
                判断四元数是否为全零（或接近全零）
                Args:
                    q: 四元数，形状 [..., 4]
                    eps: 接近零的阈值
                Returns:
                    布尔张量，形状 [...], 指示每个四元数是否为零
                """
                return torch.norm(q, dim=-1) < eps
            
            def safe_normalize(q: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
                """
                安全地规范化四元数，避免除零错误
                Args:
                    q: 四元数，形状 [..., 4]
                    eps: 最小范数阈值
                Returns:
                    规范化后的四元数
                """
                norm = torch.norm(q, dim=-1, keepdim=True)
                # 添加最小范数保护
                norm = torch.clamp(norm, min=eps)
                return q / norm
            
            def quat_multiply(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
                """
                批量四元数乘法 (q1 * q2)
                Args:
                    q1: 形状 [..., 4]，四元数 [x1,y1,z1,w1]
                    q2: 形状 [..., 4]，四元数 [x2,y2,z2,w2]
                Returns:
                    q: 形状 [..., 4]，乘积 q1*q2 [x,y,z,w]
                """
                x1, y1, z1, w1 = q1[..., 0], q1[..., 1], q1[..., 2], q1[..., 3]
                x2, y2, z2, w2 = q2[..., 0], q2[..., 1], q2[..., 2], q2[..., 3]
                
                x = w1*x2 + x1*w2 + y1*z2 - z1*y2
                y = w1*y2 - x1*z2 + y1*w2 + z1*x2
                z = w1*z2 + x1*y2 - y1*x2 + z1*w2
                w = w1*w2 - x1*x2 - y1*y2 - z1*z2
                
                return torch.stack([x, y, z, w], dim=-1)
            
            def find_first_valid_indices(mask: torch.Tensor) -> torch.Tensor:
                """
                向量化方法查找每个batch中第一个有效四元数的索引
                Args:
                    mask: 有效掩码，形状 [bs, len]
                Returns:
                    indices: 每个batch的第一个有效索引，形状 [bs]
                """
                bs, seq_len = mask.shape
                device = mask.device
                
                # 创建索引矩阵
                indices = torch.arange(seq_len, device=device).expand(bs, seq_len)
                # 将无效位置的索引设置为大数
                indices = torch.where(mask, indices, seq_len)
                # 找到最小索引（第一个有效位置）
                first_valid = torch.min(indices, dim=1)[0]
                # 处理全无效的情况
                all_invalid = first_valid == seq_len
                first_valid[all_invalid] = 0  # 设置为0，后续会处理
                
                return first_valid
            
            def quaternion_average_frechet_optimized(q: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
                """
                Masked, batched quaternion average via the eigenvector ("Markley") method.

                Despite the name, this is not an iterative Frechet/Karcher mean. It returns
                the eigenvector of the largest eigenvalue of M = sum_i q_i q_i^T, taken over
                the valid quaternions (the chordal L2 average). Because q q^T == (-q)(-q)^T,
                the result does not depend on the sign of each input. No hemisphere
                alignment step is therefore needed.

                Args:
                    q: quaternions, shape [bs, len, 4], order [rot_x, rot_y, rot_z, rot_w].
                    mask: validity mask, shape [bs, len]; nonzero/True marks a valid quaternion.
                    eps: forwarded to ``safe_normalize`` and added to M's diagonal.
                Returns:
                    q_avg: shape [bs, 4]. It is passed through ``safe_normalize`` and
                    sign-flipped so that rot_w >= 0. Rows with no valid quaternion use
                    M = identity. For that degenerate case the eigenvector chosen by eigh
                    is not unique; NOTE(review): confirm callers do not rely on a specific value.
                """
                device = q.device

                # Normalise, then zero out invalid entries so they contribute nothing to M.
                q_unit = safe_normalize(q, eps=eps) * mask.unsqueeze(-1)  # [bs, len, 4]

                # M = sum over time of the outer products q q^T -> [bs, 4, 4].
                outer_products = q_unit.unsqueeze(-1) @ q_unit.unsqueeze(-2)  # [bs, len, 4, 4]
                M = outer_products.sum(dim=1)

                eye = torch.eye(4, device=device, dtype=M.dtype)

                # Rows without any valid quaternion: replace M by the identity.
                all_invalid = mask.sum(dim=1) == 0
                if all_invalid.any():
                    M[all_invalid] = eye

                # eps * I shifts every eigenvalue equally (eigenvectors are unchanged)
                # and keeps the decomposition well-conditioned.
                _, eigvecs = torch.linalg.eigh(M + eps * eye.unsqueeze(0))  # eigvecs: [bs, 4, 4]
                # eigh returns eigenvalues in ascending order, so the last column is the
                # eigenvector of the largest eigenvalue.
                q_avg = safe_normalize(eigvecs[:, :, -1], eps=eps)

                # Canonical sign: make the real part (rot_w) non-negative.
                return torch.where(q_avg[:, 3:4] < 0, -q_avg, q_avg)
            
            bs, len_q, _ = q.shape
            device = q.device
            
            # 1. 识别全零四元数
            zero_mask = is_zero_quaternion(q, eps)  # [bs, len]
            valid_mask = ~zero_mask  # [bs, len]，1表示有效四元数
            
            # 2. 计算有效四元数的平均旋转
            q_avg = quaternion_average_frechet_optimized(q, valid_mask, eps)  # [bs, 4]
            
            # 3. 创建目标四元数 [0.5, 0.5, 0.5, 0.5] 并规范化
            target_quat = torch.tensor([0, 0, 0, 1], device=device, dtype=q.dtype)
            target_quat = safe_normalize(target_quat)
            
            # 4. 将目标旋转整合到平均旋转中
            # 计算目标旋转与平均旋转的复合旋转
            # q_composite = target_quat * q_avg_conj
            # 这等价于先应用平均旋转的逆，然后应用目标旋转
            
            # 计算平均四元数的共轭（逆）
            q_avg_conj = q_avg.clone()
            q_avg_conj[:, :3] *= -1  # 虚部取反：[x,y,z,w] → [-x,-y,-z,w]
            
            # 将目标旋转与平均旋转的共轭相乘
            q_composite = quat_multiply(
                target_quat.unsqueeze(0).expand(bs, 4),  # 扩展目标四元数以匹配批次大小
                q_avg_conj
            )  # [bs, 4]
            
            q_composite = safe_normalize(q_composite)  # 确保复合旋转是单位四元数
            q_composite = q_composite.unsqueeze(1)  # [bs, 1, 4]，便于广播
            
            # 5. 应用复合旋转：q_final = q_composite * q
            q_final = quat_multiply(q_composite, q)  # [bs, len, 4]
            
            # 6. 保持原始全零位置为零
            q_final = q_final * valid_mask.unsqueeze(-1)  # [bs, len, 4]
            
            return q_final

        
        class MotionFeatureExtractor(nn.Module):
            """
            Derive motion features from raw IMU samples on CPU or CUDA input.

            Input:  [bs, len, 7] tensor ([acc_x, acc_y, acc_z, rot_x, rot_y, rot_z, rot_w]).
            Output: [bs, len, 7] tensor ([linear acc x/y/z, angular velocity x/y/z,
                    angular distance]), with the same device and dtype as the input.

            Quaternions whose norm is NaN or ~0 (for example zero padding) are treated
            as missing. Any consecutive-step pair that touches one contributes 0 to both
            angular velocity and angular distance.
            """
            def __init__(self, time_delta=1/200):
                super().__init__()
                # Seconds between consecutive samples; angular velocity = angle / time_delta.
                self.time_delta = time_delta
                # World-frame gravity. It is a plain tensor attribute (not a buffer), so it
                # is moved to the input's device/dtype on every use.
                self.gravity_world_val = torch.tensor([0.0, 0.0, 9.81], dtype=torch.float32)

            def _valid_quat_mask(self, rot_data):
                """Boolean [bs, len] mask: True where the quaternion norm is neither NaN nor ~0."""
                quat_norms = torch.norm(rot_data, dim=-1)
                return ~(torch.isnan(quat_norms) | torch.isclose(quat_norms, rot_data.new_zeros(1)))

            def quat_to_rot_matrix(self, quat):
                """Convert quaternion(s) [..., 4] in [x, y, z, w] order to rotation matrices [..., 3, 3]."""
                x, y, z, w = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3]

                xx = x * x
                yy = y * y
                zz = z * z
                xy = x * y
                xz = x * z
                yz = y * z
                xw = x * w
                yw = y * w
                zw = z * w
                ww = w * w

                # Homogeneous form; equals the usual 1 - 2(...) diagonal for unit quaternions.
                # An all-zero quaternion yields the all-zero matrix.
                rot_mat = torch.stack([
                    xx + ww - (yy + zz), 2 * (xy - zw), 2 * (xz + yw),
                    2 * (xy + zw), yy + ww - (xx + zz), 2 * (yz - xw),
                    2 * (xz - yw), 2 * (yz + xw), zz + ww - (xx + yy)
                ], dim=-1).view(*quat.shape[:-1], 3, 3)

                return rot_mat

            def rot_matrix_inverse(self, rot_mat):
                """Invert rotation matrices via transpose (valid for orthonormal matrices)."""
                return rot_mat.transpose(-2, -1)

            def rot_matrix_multiply(self, rot_mat1, rot_mat2):
                """Batched matrix product rot_mat1 @ rot_mat2."""
                return torch.matmul(rot_mat1, rot_mat2)

            def apply_rotation(self, rot_mat, vec):
                """Apply rotation matrices [..., 3, 3] to vectors [..., 3]."""
                return torch.matmul(rot_mat, vec.unsqueeze(-1)).squeeze(-1)

            def remove_gravity_from_acc(self, acc_data, rot_data):
                """Subtract gravity, rotated into the sensor frame, from the accelerometer reading."""
                bs, seq_len, _ = acc_data.shape

                norms = torch.norm(rot_data, dim=-1, keepdim=True).clamp(min=1e-8)
                rotations = self.quat_to_rot_matrix(rot_data / norms)

                # Match the input's device AND dtype: matmul does not promote float32 vs float64.
                gravity_world = self.gravity_world_val.to(
                    device=acc_data.device, dtype=acc_data.dtype
                ).expand(bs, seq_len, 3)
                # World -> sensor frame: R^T g.
                gravity_sensor_frame = self.apply_rotation(
                    self.rot_matrix_inverse(rotations),
                    gravity_world
                )

                # A zero (missing) quaternion gives a zero rotation matrix, so acc passes through.
                return acc_data - gravity_sensor_frame

            def calculate_angular_velocity_from_quat(self, rot_data):
                """Angular velocity [bs, len, 3] from consecutive quaternions; last step repeats the previous one."""
                bs, seq_len, _ = rot_data.shape
                angular_vel = rot_data.new_zeros(bs, seq_len, 3)

                if seq_len < 2:
                    return angular_vel

                quat_norms = torch.norm(rot_data, dim=-1)
                valid_mask = self._valid_quat_mask(rot_data)
                valid_pairs_mask = valid_mask[:, :-1] & valid_mask[:, 1:]

                if torch.any(valid_pairs_mask):
                    # Gather only the pairs where both quaternions are usable.
                    q_t = rot_data[:, :-1][valid_pairs_mask]
                    q_t_plus_dt = rot_data[:, 1:][valid_pairs_mask]

                    q_t_norm = q_t / quat_norms[:, :-1][valid_pairs_mask].unsqueeze(-1).clamp(min=1e-8)
                    q_t_plus_dt_norm = q_t_plus_dt / quat_norms[:, 1:][valid_pairs_mask].unsqueeze(-1).clamp(min=1e-8)

                    # Relative rotation R_t^T R_{t+dt}.
                    rot_t = self.quat_to_rot_matrix(q_t_norm)
                    rot_t_plus_dt = self.quat_to_rot_matrix(q_t_plus_dt_norm)
                    delta_rot = self.rot_matrix_multiply(self.rot_matrix_inverse(rot_t), rot_t_plus_dt)

                    # Rotation angle from the trace, then the axis-angle vector (matrix log).
                    trace = delta_rot[..., 0, 0] + delta_rot[..., 1, 1] + delta_rot[..., 2, 2]
                    theta = torch.acos(torch.clamp((trace - 1) / 2, -1.0, 1.0))
                    sin_theta = torch.sin(theta)

                    rot_vec = torch.zeros_like(q_t_norm[..., :3])
                    # sin(theta) ~ 0: the rotation vector is left at 0.
                    nonzero = ~torch.isclose(sin_theta, rot_data.new_zeros(1))

                    if torch.any(nonzero):
                        factor = theta / (2 * sin_theta)
                        rot_vec[nonzero] = factor[nonzero, None] * torch.stack([
                            delta_rot[nonzero, 2, 1] - delta_rot[nonzero, 1, 2],
                            delta_rot[nonzero, 0, 2] - delta_rot[nonzero, 2, 0],
                            delta_rot[nonzero, 1, 0] - delta_rot[nonzero, 0, 1]
                        ], dim=1)

                    angular_vel[:, :-1][valid_pairs_mask] = rot_vec / self.time_delta

                # The last step has no successor: copy the previous value.
                angular_vel[:, -1] = angular_vel[:, -2]

                return angular_vel

            def calculate_angular_distance(self, rot_data, cumulative=False):
                """
                Rotation angle [bs, len, 1] between consecutive quaternions.

                Pairs involving a missing (NaN / ~zero-norm) quaternion give 0. Without this,
                a zero quaternion produces a zero rotation matrix and a spurious
                acos(-0.5) = 2*pi/3. The last step is 0, or a copy of the previous
                increment when ``cumulative``, in which case a running sum is returned.
                """
                bs, seq_len, _ = rot_data.shape
                angular_dist = rot_data.new_zeros(bs, seq_len, 1)

                if seq_len < 2:
                    return angular_dist

                quat_norms = torch.norm(rot_data, dim=-1)
                valid_mask = self._valid_quat_mask(rot_data)
                valid_pairs_mask = valid_mask[:, :-1] & valid_mask[:, 1:]

                q1_norm = rot_data[:, :-1] / quat_norms[:, :-1].unsqueeze(-1).clamp(min=1e-8)
                q2_norm = rot_data[:, 1:] / quat_norms[:, 1:].unsqueeze(-1).clamp(min=1e-8)

                r1 = self.quat_to_rot_matrix(q1_norm)
                r2 = self.quat_to_rot_matrix(q2_norm)
                relative_rotation = self.rot_matrix_multiply(self.rot_matrix_inverse(r1), r2)

                trace = relative_rotation[..., 0, 0] + relative_rotation[..., 1, 1] + relative_rotation[..., 2, 2]
                theta = torch.acos(torch.clamp((trace - 1) / 2, -1.0, 1.0))
                # Zero out pairs touching padding/missing quaternions (also clears NaNs).
                angular_dist[:, :-1, 0] = theta.masked_fill(~valid_pairs_mask, 0.0)

                angular_dist[:, -1] = angular_dist[:, -2] if cumulative else 0.0

                if cumulative:
                    angular_dist = torch.cumsum(angular_dist, dim=1)

                return angular_dist

            def forward(self, x):
                """Split [acc(3), quat(4)] and return [linear acc(3), angular vel(3), angular dist(1)]."""
                acc_data = x[..., :3]
                rot_data = x[..., 3:]

                linear_accel = self.remove_gravity_from_acc(acc_data, rot_data)
                angular_vel = self.calculate_angular_velocity_from_quat(rot_data)
                angular_dist = self.calculate_angular_distance(rot_data)

                features = torch.cat([linear_accel, angular_vel, angular_dist], dim=-1)
                return features
        

        
        class ImuFeatureExtractor(nn.Module):
            """
            Expand a 7-channel IMU tensor (B, 7, T) into 73 hand-crafted channels (B, 73, T).

            Channels 0-2 are read as "acc", 3-5 as "gyro" and 6+ as "extra". The output
            layout is fixed by the order of ``parts`` in ``forward``. The five depthwise
            Conv1d filters are trainable. "lpf"/"hpf" name their intended role (smoothed
            signal / residual), not a fixed frequency response.
            """
            def __init__(self, ):
                super().__init__()
                taps = 15
                # Construction order is kept as-is: it fixes the RNG draws used for
                # parameter init and the state_dict key order.
                self.lpf_all = nn.Conv1d(7, 7, kernel_size=7, padding=7 // 2, groups=7, bias=False)
                self.lpf_acc = nn.Conv1d(3, 3, taps, padding=taps // 2, groups=3, bias=False)
                self.lpf_gyro = nn.Conv1d(3, 3, taps, padding=taps // 2, groups=3, bias=False)
                self.lpf_acc2 = nn.Conv1d(3, 3, taps, padding=taps // 2, groups=3, bias=False)
                self.lpf_gyro2 = nn.Conv1d(3, 3, taps, padding=taps // 2, groups=3, bias=False)

            @staticmethod
            def _split(signal, smoother):
                """Return (smoother(signal), signal - smoother(signal))."""
                low = smoother(signal)
                return low, signal - low

            @staticmethod
            def _diff1(sig):
                """First difference along time, left-padded with one zero."""
                return F.pad(sig[:, :, 1:] - sig[:, :, :-1], (1, 0))

            @staticmethod
            def _diff2(sig):
                """Second difference along time, zero-padded by one on each side."""
                return F.pad(sig[:, :, 2:] + sig[:, :, :-2] - sig[:, :, 1:-1] * 2, (1, 1))

            def forward(self, imu):
                acc, gyro, extra = imu[:, 0:3, :], imu[:, 3:6, :], imu[:, 6:, :]

                # L2 and L1 magnitudes across the three axes -> (B, 1, T) each.
                acc_l2 = torch.norm(acc, dim=1, p=2, keepdim=True)
                gyro_l2 = torch.norm(gyro, dim=1, p=2, keepdim=True)
                acc_l1 = torch.norm(acc, dim=1, p=1, keepdim=True)
                gyro_l1 = torch.norm(gyro, dim=1, p=1, keepdim=True)

                # Direction-only signals; the clip guards against division by zero.
                acc_dir_l2 = acc / acc_l2.clip(1e-12)
                gyro_dir_l2 = gyro / gyro_l2.clip(1e-12)
                acc_dir_l1 = acc / acc_l1.clip(1e-12)
                gyro_dir_l1 = gyro / gyro_l1.clip(1e-12)

                acc_lo, acc_hi = self._split(acc, self.lpf_acc)
                gyro_lo, gyro_hi = self._split(gyro, self.lpf_gyro)
                acc_dir_lo, acc_dir_hi = self._split(acc_dir_l2, self.lpf_acc2)
                gyro_dir_lo, gyro_dir_hi = self._split(gyro_dir_l2, self.lpf_gyro2)
                _, imu_hi = self._split(imu, self.lpf_all)

                # Per-channel rank of each time step (argsort of argsort), scaled by 1/128.
                ranks = imu.argsort(-1).argsort(-1) / 128

                parts = [
                    acc, gyro,
                    acc_l2, gyro_l2,
                    acc_dir_l1, gyro_dir_l1,
                    acc_l1, gyro_l1,
                    self._diff1(acc), self._diff1(gyro),
                    self._diff2(acc), self._diff2(gyro),
                    acc ** 2, gyro ** 2,
                    acc_lo, acc_hi,
                    gyro_lo, gyro_hi,
                    acc_dir_lo, acc_dir_hi,
                    gyro_dir_lo, gyro_dir_hi,
                    imu_hi,
                    extra, ranks,
                ]
                return torch.cat(parts, dim=1)  # (B, 73, T) for a 7-channel input
                
        class SEBlock(nn.Module):
            """
            Squeeze-and-excitation channel gate for (B, C, T) feature maps.

            Each channel is averaged over time and passed through a
            C -> C//reduction -> C bottleneck MLP ending in a sigmoid. The result
            rescales the corresponding channel of the input.
            """
            def __init__(self, channels, reduction=8):
                super().__init__()
                # Keep the bottleneck at least one unit wide. With channels < reduction,
                # plain integer division gives a 0-wide layer, and the gate collapses to
                # a constant sigmoid(0) = 0.5.
                hidden = max(1, channels // reduction)
                self.squeeze = nn.AdaptiveAvgPool1d(1)
                self.excitation = nn.Sequential(
                    nn.Linear(channels, hidden, bias=False),
                    nn.ReLU(inplace=True),
                    nn.Linear(hidden, channels, bias=False),
                    nn.Sigmoid()
                )

            def forward(self, x):
                b, c, _ = x.size()
                y = self.squeeze(x).view(b, c)
                y = self.excitation(y).view(b, c, 1)
                return x * y.expand_as(x)

        class ResidualSECNNBlock(nn.Module):
            """
            Residual 1-D conv block.

            The block runs Conv-BN-ReLU, then Conv-BN, then an SE gate. It adds a skip
            connection (1x1 conv + BN when channel counts differ) and a residual
            bidirectional GRU, applies ReLU, and finishes with optional max-pooling and
            dropout.

            Input (B, in_channels, T) -> output (B, out_channels, T'). T' == T for odd
            kernel_size and pool_size == 1; otherwise pooling floors T by pool_size.

            ``weight_decay`` is accepted for signature compatibility but is not used here.

            Raises:
                ValueError: if ``out_channels`` is odd. The BiGRU emits
                    2 * (out_channels // 2) features, which could not be added back.
            """
            def __init__(self, in_channels, out_channels, kernel_size, pool_size=2, dropout=0.3, weight_decay=1e-4):
                super().__init__()
                if out_channels % 2:
                    raise ValueError(
                        f"out_channels must be even for the bidirectional GRU residual, got {out_channels}"
                    )

                # First conv block
                self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, padding=kernel_size//2, bias=False)
                self.bn1 = nn.BatchNorm1d(out_channels)

                # Second conv block
                self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, padding=kernel_size//2, bias=False)
                self.bn2 = nn.BatchNorm1d(out_channels)

                # Squeeze-and-excitation gate
                self.se = SEBlock(out_channels)

                # Skip connection: identity, or a 1x1 projection when the width changes.
                # Module creation order is unchanged so init RNG and state_dict keys match.
                self.shortcut = nn.Sequential()
                if in_channels != out_channels:
                    self.shortcut = nn.Sequential(
                        nn.Conv1d(in_channels, out_channels, 1, bias=False),
                        nn.BatchNorm1d(out_channels)
                    )
                self.pool_size = pool_size
                self.pool = nn.MaxPool1d(pool_size)
                self.dropout = nn.Dropout(dropout)

                self.act = nn.ReLU()

                # Bidirectional: 2 * (out_channels // 2) == out_channels outputs for the residual add.
                self.gru = nn.GRU(out_channels, out_channels//2, batch_first=True, bidirectional=True)

            def forward(self, x):
                shortcut = self.shortcut(x)

                out = self.act(self.bn1(self.conv1(x)))
                out = self.bn2(self.conv2(out))
                out = self.se(out)

                out += shortcut

                # Residual BiGRU over time: (B, C, T) -> (B, T, C) -> GRU -> back to (B, C, T).
                out = out + self.gru(out.transpose(1, 2))[0].transpose(1, 2)
                out = self.act(out)

                if self.pool_size > 1:
                    out = self.pool(out)
                return self.dropout(out)

        class AttentionLayer(nn.Module):
            """
            Additive attention pooling over the sequence axis.

            Maps (batch, seq_len, hidden_dim) -> (batch, hidden_dim). Each step gets a
            tanh-squashed scalar score, the scores are softmax-normalised over
            seq_len, and the steps are summed with those weights.
            """
            def __init__(self, hidden_dim):
                super().__init__()
                self.attention = nn.Linear(hidden_dim, 1)

            def forward(self, x):
                step_scores = self.attention(x).tanh().squeeze(-1)       # (batch, seq_len)
                alpha = torch.softmax(step_scores, dim=1).unsqueeze(-1)  # (batch, seq_len, 1)
                return (alpha * x).sum(dim=1)                            # (batch, hidden_dim)


        def get_gravity_torch(quat):
            """
            Gravity (|g| = 9.81) expressed in the sensor frame for quaternion(s) ``quat``.

            ``quat`` has shape [..., 4] in [x, y, z, w] order; the result has shape [..., 3].
            The components equal g times the third row of the quaternion's rotation
            matrix, i.e. R^T [0, 0, g] when q is a unit quaternion.
            """
            g = 9.81
            qx, qy, qz, qw = (quat[..., k:k + 1] for k in range(4))
            gx = 2 * g * (qx * qz - qw * qy)
            gy = 2 * g * (qy * qz + qw * qx)
            gz = g * (qw**2 - qx**2 - qy**2 + qz**2)
            return torch.cat((gx, gy, gz), dim=-1)

        tof_mode = self.tof_mode
        
        class TwoBranchModel(nn.Module):
            """
            IMU + thermopile + ToF sequence classifier.

            Input ``x``: [B, T, 7 + 5 + n_tof] with channels ordered IMU (7), THM (5),
            then ToF. Time steps whose features are all exactly zero are treated as
            padding and zeroed after the per-channel affine rescale.

            Architecture:
              * Two IMU streams (the raw 7 channels, and the 7 MotionFeatureExtractor
                outputs) are each expanded by an ImuFeatureExtractor and run through
                two ResidualSECNNBlocks.
              * THM and ToF each get their own pair of blocks.
              * The four 256-channel maps are summed and fed to a 2-layer BiLSTM.
              * The BiLSTM output is max+mean pooled over time and passed to a dense head.

            Args:
                pad_len, imu_dim_raw, feature_engineering: accepted but unused.
                tof_dim: stored as ``self.tof_dim``; not used to size layers.
                n_classes: number of classes; the classifier emits 4 * n_classes logits.
                dropouts: 7 rates; indices 0, 1, 4, 5, 6 are used.
                **kwargs: forwarded to every ImuFeatureExtractor.

            forward returns:
                training: (logits [B, 4 * n_classes], logits2 [B, T, 5], logits3 [B, 5])
                eval:     (mean over the 4 logit groups [B, n_classes], logits3 [B, 5])
            """
            def __init__(self, pad_len, imu_dim_raw, tof_dim, n_classes, dropouts=(0.3, 0.3, 0.3, 0.3, 0.4, 0.5, 0.3), feature_engineering=True, **kwargs):
                super().__init__()
                # NOTE: module creation order below is load-bearing. It fixes the RNG draws
                # used for weight init and the state_dict / named_parameters order.
                # imu_fe3 and scale3/4, bias3/4 are unused in forward but kept for
                # checkpoint compatibility.
                self.imu_fe1 = ImuFeatureExtractor(**kwargs)
                self.imu_fe2 = ImuFeatureExtractor(**kwargs)
                self.imu_fe3 = ImuFeatureExtractor(**kwargs)

                # Channels emitted by ImuFeatureExtractor for a 7-channel input (= 73).
                imu_dim = 32 + 1 + 14 + 7 + 6 + 6 + 7

                self.new_feat_module = MotionFeatureExtractor()

                self.imu_dim = imu_dim
                self.tof_dim = tof_dim

                self.fir_nchan = 7
                self.thm_nchan = 5
                # tof_mode is read from the enclosing scope.
                self.tof_nchan = 5 * (5 + 3 * tof_mode)

                # Forwarded to ResidualSECNNBlock, which does not use it.
                weight_decay = 3e-3

                # IMU deep branch (raw stream)
                self.imu_block1 = ResidualSECNNBlock(imu_dim * 1, 160, 5, dropout=dropouts[0], pool_size=1, weight_decay=weight_decay)
                self.imu_block2 = ResidualSECNNBlock(160, 256, 5, dropout=dropouts[1], pool_size=1, weight_decay=weight_decay)

                # IMU deep branch (derived-motion stream)
                self.imu_block12 = ResidualSECNNBlock(imu_dim * 1, 160, 5, dropout=dropouts[0], pool_size=1, weight_decay=weight_decay)
                self.imu_block22 = ResidualSECNNBlock(160, 256, 5, dropout=dropouts[1], pool_size=1, weight_decay=weight_decay)

                self.thm_block = nn.Sequential(
                    ResidualSECNNBlock(self.thm_nchan * 1, 32, 5, dropout=dropouts[0], pool_size=1, weight_decay=weight_decay),
                    ResidualSECNNBlock(32, 256, 5, dropout=dropouts[0], pool_size=1, weight_decay=weight_decay)
                )

                self.tof_block = nn.Sequential(
                    ResidualSECNNBlock(self.tof_nchan * 1, 64, 5, dropout=dropouts[0], pool_size=1, weight_decay=weight_decay),
                    ResidualSECNNBlock(64, 256, 5, dropout=dropouts[0], pool_size=1, weight_decay=weight_decay)
                )

                self.emb_all = 256

                # Dropout1d on (B, C, T) drops whole channels.
                self.dropout1d = nn.Dropout1d(0.15)

                # BiLSTM
                self.dim_lstm = 512
                self.dim_encoder = self.dim_lstm * 2

                self.bilstm = nn.LSTM(self.emb_all, self.dim_lstm, num_layers=2, bidirectional=True, batch_first=True)

                self.lstm_dropout = nn.Dropout(dropouts[4])

                # Per-time-step auxiliary head.
                self.output2 = nn.Linear(self.dim_encoder, 5)

                # Dense layers over [max-pool, mean-pool] of the BiLSTM output.
                self.dense1 = nn.Linear(self.dim_encoder * 2, 512, bias=False)
                self.bn_dense1 = nn.BatchNorm1d(512)
                self.drop1 = nn.Dropout(dropouts[5])

                self.dense2 = nn.Linear(512, 512, bias=False)
                self.bn_dense2 = nn.BatchNorm1d(512)
                self.drop2 = nn.Dropout(dropouts[6])

                self.output3 = nn.Linear(512, 5)
                self.classifier = nn.Linear(512, n_classes * 4)

                # Learned per-channel affine rescaling, sliced to each stream's channel count.
                self.scale = nn.Parameter(torch.ones((1, 256, 1)))
                self.bias = nn.Parameter(torch.zeros((1, 256, 1)))

                self.scale2 = nn.Parameter(torch.ones((1, 256, 1)))
                self.bias2 = nn.Parameter(torch.zeros((1, 256, 1)))

                self.scale3 = nn.Parameter(torch.ones((1, 512, 1)))
                self.bias3 = nn.Parameter(torch.zeros((1, 512, 1)))

                self.scale4 = nn.Parameter(torch.ones((1, 512, 1)))
                self.bias4 = nn.Parameter(torch.zeros((1, 512, 1)))

                self.scale5 = nn.Parameter(torch.ones((1, 512, 1)))
                self.bias5 = nn.Parameter(torch.zeros((1, 512, 1)))

                self.scale6 = nn.Parameter(torch.ones((1, 512, 1)))
                self.bias6 = nn.Parameter(torch.zeros((1, 512, 1)))

                self.act = nn.GELU()

            def forward(self, x, ):
                # Padding mask: a time step is padding when every feature is exactly 0.
                mask_all = (x.abs().max(-1)[0] != 0).float()[:, :, None]  # (B, T, 1)
                mask_all_trans = mask_all.transpose(1, 2)                  # (B, 1, T)

                imu = x[:, :, :self.fir_nchan]                   # (B, T, 7)
                imu2 = self.new_feat_module(imu)[:, :, :7]        # (B, T, 7) derived motion features

                imu = imu.transpose(1, 2)                        # (B, 7, T)
                imu2 = imu2.transpose(1, 2)

                thm = x[:, :, self.fir_nchan:self.fir_nchan + self.thm_nchan].transpose(1, 2)  # (B, 5, T)
                tof = x[:, :, self.fir_nchan + self.thm_nchan:].transpose(1, 2)                # (B, n_tof, T)

                imu = self.imu_fe1(imu)    # (B, imu_dim, T)
                imu2 = self.imu_fe2(imu2)  # (B, imu_dim, T)

                # Per-channel affine, then re-zero the padded steps.
                imu = (imu * self.scale[:, :imu.shape[1], :] + self.bias[:, :imu.shape[1], :]) * mask_all_trans
                imu2 = (imu2 * self.scale2[:, :imu2.shape[1], :] + self.bias2[:, :imu2.shape[1], :]) * mask_all_trans

                thm = (thm * self.scale5[:, :thm.shape[1], :] + self.bias5[:, :thm.shape[1], :]) * mask_all_trans
                tof = (tof * self.scale6[:, :tof.shape[1], :] + self.bias6[:, :tof.shape[1], :]) * mask_all_trans

                # Call order kept unchanged (dropout RNG consumption order).
                thm = self.dropout1d(thm)
                tof = self.dropout1d(tof)
                imu = self.dropout1d(imu)
                imu2 = self.dropout1d(imu2)

                thm = self.thm_block(thm)
                tof = self.tof_block(tof)

                x1 = self.imu_block1(imu)
                x1 = self.imu_block2(x1)

                x12 = self.imu_block12(imu2)
                x12 = self.imu_block22(x12)

                merged = (x1 + x12 + thm + tof).transpose(1, 2)  # (B, T, 256)

                lstm_out, _ = self.bilstm(merged)                 # (B, T, 1024)

                logits2 = self.output2(self.lstm_dropout(lstm_out))  # (B, T, 5)

                pooled = torch.cat((lstm_out.max(1)[0], lstm_out.mean(1)), -1)  # (B, 2048)
                pooled = self.lstm_dropout(pooled)

                h = self.act(self.bn_dense1(self.dense1(pooled)))
                h = self.drop1(h)
                h = self.act(self.bn_dense2(self.dense2(h)))
                h = self.drop2(h)

                logits = self.classifier(h)   # (B, 4 * n_classes)
                logits3 = self.output3(h)     # (B, 5)

                if not self.training:
                    # Average the four groups of class logits; -1 infers n_classes.
                    return logits.reshape(logits.shape[0], 4, -1).mean(1), logits3

                return logits, logits2, logits3


        return TwoBranchModel(param_a, param_b, param_c, param_d)

    def preprocess_sequence(self, df_seq: pd.DataFrame, feature_cols: list, scaler: StandardScaler):
        """Fill gaps in one sequence and return its features as a float32 matrix.

        Missing values are forward-filled, then back-filled. Anything still missing
        (an all-NaN column) becomes 0. No scaling is applied: ``scaler`` is accepted
        for interface compatibility but ignored.

        Returns:
            np.ndarray of shape (len(df_seq), len(feature_cols)), dtype float32.
        """
        gap_free = df_seq[feature_cols].ffill().bfill().fillna(0)
        return gap_free.to_numpy().astype('float32')
        
    def pad_sequences_torch(self, sequences, maxlen, padding='pre', truncating='pre', value=0.0):
        """Pad or truncate every sequence to ``maxlen`` steps (a NumPy port of Keras ``pad_sequences``).

        Despite the name, this works on and returns NumPy arrays.

        Args:
            sequences: iterable of array-likes; the first axis is time. Any trailing
                feature dimensions are preserved (1-D, 2-D, ... all work).
            maxlen: target length.
            padding: 'post' pads at the end; anything else pads at the front.
            truncating: 'post' keeps the first ``maxlen`` steps; anything else keeps the last.
            value: fill value for padding.
        Returns:
            float32 array of shape (len(sequences), maxlen, *feature_dims).
        """
        result = []
        for seq in sequences:
            seq = np.asarray(seq)
            n = len(seq)
            if n >= maxlen:
                # Use an explicit start index: seq[-0:] would keep everything when maxlen == 0.
                seq = seq[:maxlen] if truncating == 'post' else seq[n - maxlen:]
            else:
                pad = np.full((maxlen - n, *seq.shape[1:]), value)
                seq = np.concatenate([seq, pad] if padding == 'post' else [pad, seq])
            result.append(seq)
        return np.array(result, dtype=np.float32)
        
    def preprocess_left_handed(self, l_tr):
        """
        Mirror one left-handed recording so it resembles right-handed data.

        Takes a polars DataFrame and returns a new one in which:
          * quaternions (rot_w/x/y/z) have missing values repaired, then are mirrored
            by negating the 2nd and 3rd scipy "xyz" (extrinsic) Euler angles. Rows
            that are all-zero or unrecoverable stay all-zero;
          * acc_x is negated;
          * thm_3 and thm_5 are swapped;
          * ToF sensors 3 and 5 exchange their 64 pixel columns. Then, assuming pixel
            index v = row * 8 + col, the grids of sensors 1, 2 and 4 are flipped
            left-right (col -> 7 - col), and those of sensors 3 and 5 top-bottom
            (row -> 7 - row).
        """
        def handle_quaternion_missing_values(rot_data: np.ndarray) -> np.ndarray:
            """
            Repair NaNs in an (N, 4) quaternion array, row by row.

            * no NaN: normalise to unit length (rows with norm <= 1e-8 become zeros);
            * one NaN: rebuild it from the unit-norm constraint, taking the sign of the
              same component in the previous cleaned row. If the other three already
              exceed unit norm, the row becomes zeros;
            * two or more NaNs: the row becomes zeros.
            All-zero rows mean "no usable orientation"; they are NOT the identity.
            """
            rot_cleaned = rot_data.copy()

            for i in range(len(rot_data)):
                row = rot_data[i]
                missing_count = np.isnan(row).sum()

                if missing_count == 0:
                    norm = np.linalg.norm(row)
                    if norm > 1e-8:
                        rot_cleaned[i] = row / norm
                    else:
                        rot_cleaned[i] = [0.0, 0.0, 0.0, 0.0]  # zero quaternion = missing

                elif missing_count == 1:
                    # Solve w^2 + x^2 + y^2 + z^2 = 1 for the missing component.
                    missing_idx = np.where(np.isnan(row))[0][0]
                    valid_values = row[~np.isnan(row)]

                    sum_squares = np.sum(valid_values**2)
                    if sum_squares <= 1.0:
                        missing_value = np.sqrt(max(0, 1.0 - sum_squares))
                        # Choose the sign for continuity with the previous quaternion.
                        if i > 0 and not np.isnan(rot_cleaned[i-1, missing_idx]):
                            if rot_cleaned[i-1, missing_idx] < 0:
                                missing_value = -missing_value
                        rot_cleaned[i, missing_idx] = missing_value
                        rot_cleaned[i, ~np.isnan(row)] = valid_values
                    else:
                        rot_cleaned[i] = [0.0, 0.0, 0.0, 0.0]
                else:
                    # Two or more components missing: mark the row as missing (zeros).
                    rot_cleaned[i] = [0.0, 0.0, 0.0, 0.0]

            return rot_cleaned

        rot_cleaned = handle_quaternion_missing_values(l_tr[["rot_w","rot_x", "rot_y", "rot_z"]].to_numpy())
        # Reorder (w, x, y, z) -> (x, y, z, w), the layout scipy's Rotation uses.
        rot_scipy = rot_cleaned[:, [1, 2, 3, 0]]

        # scipy cannot build a rotation from a zero-norm quaternion: substitute the
        # identity, then restore zeros after mirroring.
        degenerate = np.linalg.norm(rot_scipy, axis=1) < 1e-8
        rot_scipy[degenerate] = [0.0, 0.0, 0.0, 1.0]

        euler = R.from_quat(rot_scipy).as_euler("xyz")
        euler[:, 1] = -euler[:, 1]
        euler[:, 2] = -euler[:, 2]
        mirrored = R.from_euler("xyz", euler).as_quat()
        mirrored[degenerate] = [0.0, 0.0, 0.0, 0.0]

        # orient="row" is explicit. A sequence with exactly 4 rows gives a (4, 4) array
        # whose orientation polars cannot infer, and it may read the array column-wise.
        l_tr = l_tr.with_columns(
            pl.DataFrame(mirrored, schema=["rot_x", "rot_y", "rot_z", "rot_w"], orient="row")
        )
        l_tr = l_tr.with_columns(-pl.col("acc_x"))

        # Swap thermopiles 3 and 5 (rename a selected copy, then write it back).
        thm_swapped = l_tr[["thm_3", "thm_5"]]
        thm_swapped.columns = ["thm_5", "thm_3"]
        l_tr = l_tr.with_columns(thm_swapped)

        # Exchange ToF sensors 3 and 5. Both directions read from the same snapshot,
        # so a single batched with_columns call is equivalent to sequential updates.
        snapshot = l_tr
        l_tr = l_tr.with_columns(
            [snapshot[f"tof_3_v{p}"].alias(f"tof_5_v{p}") for p in range(64)]
            + [snapshot[f"tof_5_v{p}"].alias(f"tof_3_v{p}") for p in range(64)]
        )

        # Flip the 8x8 grids, reading every source pixel from the post-swap snapshot.
        snapshot = l_tr
        col_flip = [(r * 8 + c, r * 8 + (7 - c)) for r in range(8) for c in range(8)]
        row_flip = [(r * 8 + c, (7 - r) * 8 + c) for r in range(8) for c in range(8)]
        moves = [(s, src, dst) for s in (1, 2, 4) for src, dst in col_flip]
        moves += [(s, src, dst) for s in (3, 5) for src, dst in row_flip]
        l_tr = l_tr.with_columns(
            [snapshot[f"tof_{s}_v{src}"].alias(f"tof_{s}_v{dst}") for s, src, dst in moves]
        )
        return l_tr
        
    def train(self, ):
        class CMI3Dataset(Dataset):
            """
            Map-style dataset over per-sequence arrays, optionally repeated per epoch.

            Each of the ``len(X_list)`` samples is exposed ``epoch_multiplier`` times;
            consecutive indices map to the same sample. In ``mode == "train"``:
              * ``X`` and ``y2`` are concatenated column-wise, so they need the same
                number of rows;
              * the result is passed through ``transforms_custom`` and split back;
              * columns 3:7 of ``X`` are rescaled to (approximately) unit L2 norm.
            ``maxlen``, ``imu_dim`` and ``augment`` are stored but not used here.
            """
            def __init__(self,
                         X_list,
                         y_list, y_list2, y_list3,
                         maxlen,
                         mode="train",
                         imu_dim=7,
                         augment=None,
                         epoch_multiplier=1,
                        ):
                self.X_list, self.y_list, self.y_list2, self.y_list3 = X_list, y_list, y_list2, y_list3
                self.mode = mode
                self.maxlen = maxlen
                self.imu_dim = imu_dim
                self.augment = augment
                self.epoch_multiplier = epoch_multiplier

            def __getitem__(self, index):
                src = index // self.epoch_multiplier
                X = self.X_list[src]
                y = self.y_list[src]
                y2 = self.y_list2[src]
                y3 = self.y_list3[src]

                if self.mode != 'train':
                    return X, y, y2, y3

                n_feat = X.shape[-1]
                # Transform features and y2 jointly, then split them apart again.
                joint = transforms_custom(np.concatenate([X, y2], -1))
                X, y2 = joint[:, :n_feat], joint[:, n_feat:]
                # Re-normalise columns 3:7 (presumably the quaternion) after the transform.
                quat = X[:, 3:7]
                X[:, 3:7] = quat / (np.sqrt((quat ** 2).sum(-1)) + 1e-6)[:, None]
                return X, y, y2, y3

            def __len__(self):
                return self.epoch_multiplier * len(self.X_list)

        class EMA:
            def __init__(self, model, decay=0.999):
                self.decay = decay
                self.shadow = {}
                self.backup = {}
        
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        self.shadow[name] = param.data.clone()
        
            def update(self, model):
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        assert name in self.shadow
                        new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                        self.shadow[name] = new_average.clone()
        
            def apply_shadow(self, model):
                self.backup = {}
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        self.backup[name] = param.data.clone()
                        param.data = self.shadow[name]
        
            def restore(self, model):
                for name, param in model.named_parameters():
                    if param.requires_grad and name in self.backup:
                        param.data = self.backup[name]
                self.backup = {}

        def set_seed(seed: int = 42):
            import numpy as np
            
            random.seed(seed)
        
            os.environ['PYTHONHASHSEED'] = str(seed)
        
            np.random.seed(seed)
        
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed) 
            # torch.backends.cudnn.deterministic = True
            # torch.backends.cudnn.benchmark = False
            # torch.use_deterministic_algorithms(True)
        
        set_seed(self.SEED)

        class RDROPLoss(nn.Module):
            """
            RDROP损失函数实现
            结合原始损失和KL散度约束
            """
            def __init__(self, alpha=0.5):
                super(RDROPLoss, self).__init__()
                self.alpha = alpha  # KL损失的权重系数
                self.kl_div = nn.KLDivLoss(reduction='batchmean')
        
            def forward(self, logits1, logits2):
                # KL散度约束：两次次输出分布的一致性
                p1 = F.log_softmax(logits1, dim=1)
                p2 = F.softmax(logits2, dim=1)
                kl_loss1 = self.kl_div(p1, p2)
                
                p2 = F.log_softmax(logits2, dim=1)
                p1 = F.softmax(logits1, dim=1)
                kl_loss2 = self.kl_div(p2, p1)
                
                # 总损失 = 分类损失 + alpha * KL散度损失
                total_loss = self.alpha * (kl_loss1 + kl_loss2)
                return total_loss
                
        import numpy as np
        import pandas as pd
        from scipy.spatial.transform import Rotation as R
        from tqdm.auto import tqdm 
        
        print("▶ TRAIN MODE – loading dataset …")
        df = pd.read_csv(self.RAW_DIR / "train.csv")
        self.demo = pd.read_csv(self.RAW_DIR / "train_demographics.csv")



        
        group_list = []
        for _, group in tqdm(df.groupby('sequence_id')):
            index_old = group.index
            if self.demo[self.demo['subject']==group['subject'].values[0]]['handedness'].values[0] ==0:
                ###################################################################################
                group = self.preprocess_left_handed(pl.from_pandas(group)).to_pandas()
                group.index = index_old
                
                group_list.append(group)
        df.update(pd.concat(group_list))


        
    
        dict_ = dict(zip(df['subject'].unique(), list(range(df['subject'].nunique()))))
        df['subject_id_new'] = df['subject'].map(dict_)
    
        #df = df[~df['subject'].isin({'SUBJ_045235', 'SUBJ_019262'})]
    
        dict_behavior = {'Relaxes and moves hand to target location': 1,
                         'Hand at target location': 2,
                         'Performs gesture': 3,
                         'Moves hand to target location': 4}
        
        dict_orientation = {
                            'Lie on Back': 0,
                            'Lie on Side - Non Dominant': 1,
                            'Seated Lean Non Dom - FACE DOWN': 2,
                            'Seated Straight': 3,
                           }
    
        df['bid'] = df['behavior'].map(dict_behavior)
        df['oid'] = df['orientation'].map(dict_orientation)
    
        df, new_feat_cols, new_feat_cols2 = self.get_new_features(df)
    
        # Label encoding
        le = LabelEncoder()
        df['gesture_int'] = le.fit_transform(df['gesture'])


        df['gesture_int'] = df['gesture_int'] + df['oid'] * 18

        
        np.save(self.EXPORT_DIR / "gesture_classes.npy", le.classes_)
    
        # Feature list
        meta_cols = {'gesture', 'gesture_int', 'sequence_type', 'behavior', 'orientation',
                     'row_id', 'subject', 'phase', 'sequence_id', 'sequence_counter', 'subject_id_new'}
        feature_cols = [c for c in df.columns if c not in meta_cols]
        imu_cols = [c for c in feature_cols if not (c.startswith('thm_') or c.startswith('tof_'))]
        # tof_cols = [c for c in feature_cols if c.startswith('thm_') or c.startswith('tof_')]
    
        thm_cols = [c for c in feature_cols if c.startswith('thm_')]
        print(thm_cols)
        
        tof_cols = [] # [c for c in feature_cols if c.startswith('thm_') or c.startswith('tof_')]
    
        imu_cols = ['acc_x', 'acc_y', 'acc_z', 'rot_x', 'rot_y', 'rot_z', 'rot_w',]
        
        feature_cols = imu_cols + new_feat_cols + thm_cols + new_feat_cols2
        
        print(f"  IMU {len(imu_cols)+len(new_feat_cols)} | THM {len(thm_cols)} | TOF {len(new_feat_cols2)}  | total {len(feature_cols)} features")
    
        # Save feature_cols
        np.save(self.EXPORT_DIR / "feature_cols.npy", np.array(feature_cols))
        pad_len = self.PAD_PERCENTILE
        
        # Group sequences
        seq_gp = df.groupby('sequence_id')
        X_list_raw, y_list, id_list, subject_list, y2list, y3list = [], [], [], [], [], []
        for seq_id, seq in seq_gp:
            mat = seq[feature_cols].ffill().bfill().fillna(0).values
            X_list_raw.append(mat)
            y_list.append(seq['gesture_int'].iloc[0])
            id_list.append(seq_id)
            subject_list.append(seq['subject_id_new'].iloc[0])
    
            if len(seq) < pad_len:
                bid = np.zeros(pad_len)
                bid[-len(seq):] = seq['bid'].values.ravel()
            else:
                bid = seq['bid'].values.ravel()[-pad_len:]
            
            y2list.append(bid)
            y3list.append(seq['oid'].iloc[0])
    
        pad_len = self.PAD_PERCENTILE
        np.save(self.EXPORT_DIR / "sequence_maxlen.npy", pad_len)
    
        id_list = np.array(id_list)
        y_list_all = np.eye(4 * len(le.classes_))[y_list].astype(np.float32)  # one-hot
    
        y_list_all2 = np.vstack(y2list).astype(int)
        y_list_all2 = (np.eye(5)[y_list_all2.reshape(-1, 1)]).reshape(-1, pad_len, 5).astype(np.float32)
        y_list_all3 = np.eye(5)[y3list].astype(np.float32)
    
        augmenter = None
        metrics = []
    
        criterion_rdrop = RDROPLoss(alpha=0.5)

        from sklearn.model_selection import GroupKFold
        gkf = GroupKFold(
                         n_splits=self.FOLDS, 
                         shuffle=True, 
                         random_state=self.random_state
                        )
    
        def clipped_cross_entropy(logits, y, clipval=0.6, num=18):
            return -torch.sum(F.log_softmax(logits, dim=1) * y.clip((1-clipval)/num, clipval), dim=1).mean() 

            # return 1 - (logits.softmax(-1) * y).mean()
            
        idlistall = []
        targetfinalall = []
        predimuonlyall = []
        predallfeatall = []
        predorientationall = []
        targetsorientations = []
        foldlistall = []
        for fold, (train_idx, val_idx) in enumerate(gkf.split(id_list, id_list, groups=subject_list)):
            
        # skf = StratifiedKFold(n_splits=FOLDS, shuffle=True, random_state=self.random_state)
        # for fold, (train_idx, val_idx) in enumerate(skf.split(id_list, np.argmax(y_list_all, axis=1))):  
            
            if fold not in self.TRAIN_FOLDS:
                continue
    
            print(f"\n▶ Fold {fold}")
            
            X_list_scaled = [x for x in X_list_raw]
            X_list_all = self.pad_sequences_torch(X_list_scaled, maxlen=pad_len, padding='pre', truncating='pre')
    
            # Prepare train/val sets
            train_list = X_list_all[train_idx]
            train_y_list = y_list_all[train_idx]
            train_y_list2 = y_list_all2[train_idx]
            train_y_list3 = y_list_all3[train_idx]
    
            
            val_list = X_list_all[val_idx]
            val_y_list = y_list_all[val_idx]
    
            val_y_list2 = y_list_all2[val_idx]
            val_y_list3 = y_list_all3[val_idx]

            id_list_valid = id_list[val_idx]

            print(len(id_list_valid))
            idlistall.append(id_list_valid)
            foldlistall.append([fold for _ in range(len(id_list_valid))])
    
            # Data loaders
            train_dataset = CMI3Dataset(train_list, train_y_list, train_y_list2, train_y_list3, pad_len, mode="train", imu_dim=len(imu_cols),
                                        augment=augmenter)
            train_loader = DataLoader(train_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=6, drop_last=False, pin_memory=True)
            
            val_dataset = CMI3Dataset(val_list, val_y_list, val_y_list2, val_y_list3, pad_len, mode="val")
            val_loader = DataLoader(val_dataset, batch_size=self.BATCH_SIZE, shuffle=False, num_workers=6, drop_last=False, pin_memory=True)

            device = self.device
            
            # Model & EMA 
            model = self.create_model(pad_len, len(imu_cols), len(tof_cols), len(le.classes_)).to(device)
            model2 = self.create_model(pad_len, len(imu_cols), len(tof_cols), len(le.classes_)).to(device)

            if self.VALIDATION==True:
                checkpoint = torch.load(self.PRETRAINED_DIR / f'gesture_two_branch_fold{fold}.pth', map_location=device)
                model.load_state_dict({k.replace('_orig_mod.', ''):v for k, v in checkpoint['model_state_dict'].items()})
            
            ema = EMA(model, decay=0.998)
    
            # Optimizer
            hidden_weights = [p for p in model.parameters() if p.ndim >= 2 and p.requires_grad]
            hidden_gains_biases = [p for p in model.parameters() if p.ndim < 2 and p.requires_grad]
            param_groups = [
                dict(params=hidden_weights, use_muon=True, lr=0.005, weight_decay=self.WD),
                dict(params=hidden_gains_biases, use_muon=False, lr=self.LR_INIT, betas=(0.9, 0.95), weight_decay=self.WD),
            ]
            optimizer = SingleDeviceMuonWithAuxAdam(param_groups)

            hidden_weights2 = [p for p in model2.parameters() if p.ndim >= 2 and p.requires_grad]
            hidden_gains_biases2 = [p for p in model2.parameters() if p.ndim < 2 and p.requires_grad]
            param_groups2 = [
                dict(params=hidden_weights2, use_muon=True, lr=0.005, weight_decay=self.WD),
                dict(params=hidden_gains_biases2, use_muon=False, lr=self.LR_INIT, betas=(0.9, 0.95), weight_decay=self.WD),
            ]
            optimizer2 = SingleDeviceMuonWithAuxAdam(param_groups2)
    
            nbatch = len(train_loader)
            nsteps = self.EPOCHS * nbatch
    
            scheduler = ConstantCosineLR(optimizer, nsteps, 0.)    
            print("▶ Starting training...")
    
            best_val_acc = 0
            for epoch in tqdm(range(self.EPOCHS)):
                model.train()
                model2.train()
                
                train_preds, train_targets = [], []
                train_loss = 0.0
    
                for X, y, y2, y3 in train_loader:
                    if self.VALIDATION==True:
                        break
                        
                    BS = X.shape[0]
                    X0 = X.clone()

                    X[BS//8:,:,7:] = 0.

                    if np.random.random() > 0.5:
                        X[:BS//8*2,:,3:7] = 0.
                    else:
                        X[:BS//8*2,:,:3] = 0.
                    
                    X, y = X.float().to(device), y.to(device)
                    y2, y3 = y2.float().to(device), y3.to(device)
                    X0 = X0.float().to(device)
                    
                    optimizer.zero_grad()
                    optimizer2.zero_grad()
                    
                    LEN = 1200 + np.random.randint(80)
                    
                    logits, logits2, logits3 = model(X[:,-LEN:])
                    
                    logits_, logits2_, logits3_ = model2(X0[:,-LEN:])
    
                    clipval = 0.6
                    loss = clipped_cross_entropy(logits, y, clipval, 18) 
                    loss += clipped_cross_entropy(logits2, y2[:,-LEN:], clipval, 5)  * 0.5
                    loss += clipped_cross_entropy(logits3, y3, clipval, 5)  * 0.5

                    clipval = 0.6
                    loss2 = clipped_cross_entropy(logits_, y, clipval, 18) 
                    loss2 += clipped_cross_entropy(logits2_, y2[:,-LEN:], clipval, 5)  * 0.5
                    loss2 += clipped_cross_entropy(logits3_, y3, clipval, 5)  * 0.5

                    if epoch > 40:
                        loss = loss + loss2 + criterion_rdrop(logits, logits_)
                    else:
                        loss = loss + loss2 
                    
                    loss.backward()
                    
                    optimizer.step()
                    optimizer2.step()
                    
                    ema.update(model)
    
                    train_preds.append(logits.argmax(dim=1).cpu().numpy())
                    train_targets.append(y.argmax(dim=1).cpu().numpy())
    
                    scheduler.step()
                    train_loss += loss.item()

                model2.eval()
                model.eval()
                ema.apply_shadow(model)
                
                val_loss = 0.0
                val_preds_accmask, val_targets_accmask = [], []
                val_preds_imuonly, val_targets_imuonly = [], []
                val_preds_rotmask, val_targets_rotmask = [], []

                val_preds_mother, val_targets_rotmask = [], []

                val_preds_logits = []
                val_preds_imuonly_logits = []

                val_preds_orientation, val_targets_orientation = [], []
                val_preds_orientation_logits = []
                
                
                with torch.inference_mode():
                    for X, y, y2, y3 in val_loader:
                        X_imuonly = X[:].clone()
                        X_imuonly[:, :, 7:] = 0.0

                        X_accmask = X[:].clone()
                        X_accmask[:, :, 7:] = 0.0
                        X_accmask[:, :, 0:3] = 0.0

                        X_rotmask = X[:].clone()
                        X_rotmask[:, :, 7:] = 0.0
                        X_rotmask[:, :, 3:7] = 0.0
                    
                        X, y = X.float().to(device), y.to(device)
                        X_imuonly = X_imuonly.float().to(device)
                        X_accmask = X_accmask.float().to(device)
                        X_rotmask = X_rotmask.float().to(device)

                        logits_mother, logits_orientation = model(X)
                        
                        logits_rotmask, _  = model(X_rotmask)
                        logits_accmask, _  = model(X_accmask)
                        logits_imuonly, _  = model(X_imuonly)

                        val_preds_mother.append(logits_mother.argmax(dim=1).cpu().numpy())
                        val_preds_accmask.append(logits_accmask.argmax(dim=1).cpu().numpy())
                        val_preds_rotmask.append(logits_rotmask.argmax(dim=1).cpu().numpy())
                        val_preds_imuonly.append(logits_imuonly.argmax(dim=1).cpu().numpy())
                        
                        val_preds_logits.append(logits_mother.cpu().numpy())
                        val_preds_imuonly_logits.append(logits_imuonly.cpu().numpy())
                        
                        val_targets_imuonly.append(y.reshape(y.shape[0], 4, 18).sum(1).argmax(dim=1).cpu().numpy())


                        val_preds_orientation_logits.append(logits_orientation.cpu().numpy())
                        val_preds_orientation.append(logits_orientation.argmax(dim=1).cpu().numpy())
                        val_targets_orientation.append(y3.argmax(dim=1).cpu().numpy())
    
    
                if len(train_targets) >= 0:
                    train_acc = 0.
                    try:
                        train_acc = CompetitionMetric().calculate_hierarchical_f1(
                            pd.DataFrame({'gesture': le.classes_[np.concatenate(train_targets)]}),
                            pd.DataFrame({'gesture': le.classes_[np.concatenate(train_preds)]})
                        )
                    except:
                        pass
                    val_acc_rotmask = CompetitionMetric().calculate_hierarchical_f1(
                        pd.DataFrame({'gesture': le.classes_[np.concatenate(val_targets_imuonly)]}),
                        pd.DataFrame({'gesture': le.classes_[np.concatenate(val_preds_rotmask)]})
                    )
                    val_acc_imuonly = CompetitionMetric().calculate_hierarchical_f1(
                        pd.DataFrame({'gesture': le.classes_[np.concatenate(val_targets_imuonly)]}),
                        pd.DataFrame({'gesture': le.classes_[np.concatenate(val_preds_imuonly)]})
                    )
                    val_acc_accmask = CompetitionMetric().calculate_hierarchical_f1(
                        pd.DataFrame({'gesture': le.classes_[np.concatenate(val_targets_imuonly)]}),
                        pd.DataFrame({'gesture': le.classes_[np.concatenate(val_preds_accmask)]})
                    )
                    val_acc_mother = CompetitionMetric().calculate_hierarchical_f1(
                        pd.DataFrame({'gesture': le.classes_[np.concatenate(val_targets_imuonly)]}),
                        pd.DataFrame({'gesture': le.classes_[np.concatenate(val_preds_mother)]})
                    )

                    val_acc_orientation = np.mean(
                        np.concatenate(val_preds_orientation)==np.concatenate(val_targets_orientation)
                    )
                    
                    train_loss = np.mean(train_loss)
                    print('epoch', epoch, 'loss : ', round(train_loss, 4), '| TRAIN : ', round(train_acc, 4), '| ORIENTATION: ', round(val_acc_orientation, 4), '| IMUONLY : ', round(val_acc_imuonly, 4), '| ROTMASK : ', round(val_acc_rotmask, 4),  '| ACCMASK : ', round(val_acc_accmask, 4),  '| MOTHER : ', round(val_acc_mother, 4), '| LR : ', optimizer.param_groups[0]['lr'])
    
                    metric = val_acc_orientation # val_acc
                    
                    if metric > best_val_acc:
                        best_val_acc = metric
                        # Save model
                        torch.save({
                            'model_state_dict': model.state_dict(),
                            'imu_dim': len(imu_cols),
                            'tof_dim': len(tof_cols),
                            'n_classes': len(le.classes_),
                            'pad_len': pad_len,
                            
                        }, self.EXPORT_DIR / f"gesture_two_branch_fold{fold}.pth")
                        print(f"fold: {fold} val_all_acc: {metric:.4f}")
                        print("✔ Training done – artefacts saved in", self.EXPORT_DIR)
                    
                    ema.restore(model)

                targetfinalall.append(np.concatenate(val_targets_imuonly))
                predimuonlyall.append(np.concatenate(val_preds_imuonly_logits, 0))
                predallfeatall.append(np.concatenate(val_preds_logits, 0))
                predorientationall.append(np.concatenate(val_preds_orientation_logits, 0))
                targetsorientations.append(np.concatenate(val_targets_orientation))
                if self.VALIDATION:
                    break
                
            
            ema.apply_shadow(model)
            metrics.append(best_val_acc)

        print(metrics, sum(metrics)/len(metrics))

        import joblib
        name = str(self.PRETRAINED_DIR).replace('/', '_')
        joblib.dump(
            {
                'guestures': le.classes_, 
                'pred_all': np.concatenate(predallfeatall, 0), 
                'pred_imuonly': np.concatenate(predimuonlyall, 0), 
                'targets': np.concatenate(targetfinalall), 
                'idlist': np.concatenate(idlistall),
                'fold': np.concatenate(foldlistall),
                'pred_orientation' : np.concatenate(predorientationall, 0),
                'target_orientation' : np.concatenate(targetsorientations), 
            }, 
                    f"oof_{name}.joblib", 
                   )

        return le.classes_, np.concatenate(predallfeatall, 0), np.concatenate(predimuonlyall, 0), np.concatenate(targetfinalall), np.concatenate(idlistall), np.concatenate(predorientationall, 0), np.concatenate(targetsorientations)
        
        
    def get_new_features(self, df):
        """Rescale the thermal/ToF columns of ``df`` and append per-sensor ToF statistics.

        Side effects on the input frame (done in place, before concatenation):
            * ``thm_1`` .. ``thm_5`` are divided by 5.
            * Each ToF pixel column ``tof_{i}_v{p}`` (i in 1..5, p in 0..63) has -1
              replaced by 255 and is then mapped to ``6 - value / 50``.

        New columns, per ToF sensor ``i`` in this order:
            * ``tof_{i}_isna_mean1``: fraction of pixels equal to -1 (before rescaling).
            * ``tof_{i}_isna_mean2``: 1 minus the fraction of NaN pixels.
            * ``tof_{i}_mean`` / ``_std`` / ``_max`` over the 64 rescaled pixels.
            * If ``self.tof_mode > 1``: mean/std/max over ``self.tof_mode``
              contiguous pixel-column regions of width ``64 // self.tof_mode``
              (trailing columns are ignored when the division is not exact).

        Returns:
            ``(df_with_new_columns, [], list_of_new_column_names)``
        """
        for thm_name in [f"thm_{k}" for k in range(1, 6)]:
            df[thm_name] = df[thm_name] / 5.

        extra = {}
        for i in range(1, 6):
            pixels = [f"tof_{i}_v{p}" for p in range(64)]
            raw = df[pixels]

            # Invalid-reading and missing-value ratios, computed before rescaling.
            extra[f'tof_{i}_isna_mean1'] = (raw == -1).mean(axis=1)
            extra[f'tof_{i}_isna_mean2'] = 1 - raw.isna().mean(axis=1)

            df[pixels] = 6 - raw.replace(-1, 255.) / 50.
            scaled = df[pixels]

            extra[f'tof_{i}_mean'] = scaled.mean(axis=1)
            extra[f'tof_{i}_std'] = scaled.std(axis=1)
            extra[f'tof_{i}_max'] = scaled.max(axis=1)

            if self.tof_mode <= 1:
                continue

            width = 64 // self.tof_mode
            for r in range(self.tof_mode):
                region = scaled.iloc[:, r * width:(r + 1) * width]
                extra[f'tof{self.tof_mode}_{i}_region_{r}_mean'] = region.mean(axis=1)
                extra[f'tof{self.tof_mode}_{i}_region_{r}_std'] = region.std(axis=1)
                extra[f'tof{self.tof_mode}_{i}_region_{r}_max'] = region.max(axis=1)

        stats = pd.DataFrame(extra)
        combined = pd.concat([df, stats], axis=1)
        return combined, [], list(stats.columns)
        

    def get_model(self, ):
        """Load the exported artefacts and one model per trained fold for inference.

        Reads from ``self.PRETRAINED_DIR``:
            * ``feature_cols.npy``   -> ``self.feature_cols`` (list)
            * ``sequence_maxlen.npy`` -> ``self.pad_len`` (int)
            * ``gesture_classes.npy`` -> ``self.gesture_classes``
            * ``gesture_two_branch_fold{i}.pth`` for every ``i`` in
              ``range(self.FOLDS)`` that also appears in ``self.TRAIN_FOLDS``

        Also splits ``self.feature_cols`` into ``self.imu_cols`` (no ``thm_``/``tof_``
        prefix) and ``self.tof_cols`` (the rest). ``self.models`` is filled with the
        restored models, each switched to eval mode. Returns None.

        NOTE(review): ``torch.load`` unpickles the checkpoint, so only point this at
        trusted artefact directories.
        """
        print("▶ INFERENCE MODE – loading artefacts from", self.PRETRAINED_DIR)
        source_dir = self.PRETRAINED_DIR
        self.feature_cols = np.load(source_dir / "feature_cols.npy", allow_pickle=True).tolist()
        self.pad_len = int(np.load(source_dir / "sequence_maxlen.npy"))
        self.gesture_classes = np.load(source_dir / "gesture_classes.npy", allow_pickle=True)

        sensor_prefixes = ('thm_', 'tof_')
        self.imu_cols = [name for name in self.feature_cols if not name.startswith(sensor_prefixes)]
        self.tof_cols = [name for name in self.feature_cols if name.startswith(sensor_prefixes)]

        checkpoint_names = [
            f'gesture_two_branch_fold{fold}.pth'
            for fold in range(self.FOLDS)
            if fold in self.TRAIN_FOLDS
        ]

        self.models = []
        for ckpt_name in checkpoint_names:
            state = torch.load(source_dir / ckpt_name, map_location=self.device)
            net = self.create_model(
                state['pad_len'],
                state['imu_dim'],
                state['tof_dim'],
                state['n_classes'],
            ).to(self.device)
            # Strip the '_orig_mod.' key prefix (presumably left by torch.compile)
            # so the keys match a plain, uncompiled module.
            weights = {
                key.replace('_orig_mod.', ''): value
                for key, value in state['model_state_dict'].items()
            }
            net.load_state_dict(weights)
            net.eval()
            self.models.append(net)

        print("  model, scaler, pads loaded – ready for evaluation")
    
    def predict(self, sequence: pl.DataFrame, demographics: pl.DataFrame, nanratio = 0):
        """Score one sequence by averaging the outputs of every loaded fold model.

        Args:
            sequence: raw sensor rows of a single sequence (polars frame).
            demographics: demographics for the sequence's subject. Only
                ``demographics[0, "handedness"]`` is read; a value of 0 triggers
                ``self.preprocess_left_handed``.
            nanratio: when equal to 1, every feature column from index 7 onward
                (all non-IMU features) is zeroed before inference.

        Returns:
            ``(self.gesture_classes, (gesture_scores, second_scores))``. Both scores
            are numpy arrays averaged over ``self.models``. ``second_scores`` keeps
            only the first 4 columns of the models' second output (presumably the 4
            orientations; confirm in create_model).
        """
        left_handed = demographics[0, "handedness"] == 0
        if left_handed:
            sequence = self.preprocess_left_handed(sequence)

        frame, _, _ = self.get_new_features(sequence.to_pandas())

        with torch.no_grad():
            features = self.preprocess_sequence(frame, self.feature_cols, None)
            padded = self.pad_sequences_torch([features], maxlen=self.pad_len, padding='pre', truncating='pre')
            batch = torch.FloatTensor(padded).float().to(self.device)

            if nanratio == 1:
                # Keep only the 7 IMU channels.
                batch[:, :, 7:] = 0.

            gesture_sum = orient_sum = None
            for net in self.models:
                net.eval()
                gesture_out, orient_out = net(batch)
                if gesture_sum is None:
                    gesture_sum, orient_sum = gesture_out, orient_out
                    continue
                gesture_sum += gesture_out
                orient_sum += orient_out

            n_models = len(self.models)
            gesture_sum /= n_models
            orient_sum /= n_models

        return self.gesture_classes, (gesture_sum.cpu().numpy(), orient_sum.cpu().numpy()[:, :4])

In [None]:
# for local training, prepare all the data under ./cmi-detect-behavior-with-sensor-data/ folder
# Build the v26 trainer and train all five folds with training enabled (VALIDATION off).
# NOTE(review): the positional arguments are presumably (pretrained dir, seed,
# save path), as in the commented v15 cell below; confirm against model_zhou_v26.__init__.

model_zhou_inference_v26_1 = model_zhou_v26('/kaggle/input/cmi3v90', 42, './42')

model_zhou_inference_v26_1.TRAIN_FOLDS = list(range(5))
model_zhou_inference_v26_1.VALIDATION = False
model_zhou_inference_v26_1.train()

In [None]:
# for seed in [
#              # 42,
#              6665252, 
#              # 88885252, 
#              #12345252, 
#              #325252,
#              #885252,
#              #99995252,
#              #115252,
#             ]:
#     model_zhou_inference_v0 = model_zhou_v15(f'./{seed}', seed=seed, save_path=f'./{seed}')
    
#     model_zhou_inference_v0.TRAIN_FOLDS = [0, 1, 2, 3, 4]

#     model_zhou_inference_v0.VALIDATION = False
    
#     model_zhou_inference_v0.train()
