# In [None]:
# !pip install --no-deps /kaggle/input/scikit-learn-1-6-1/scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

#!/usr/bin/env python
# coding: utf-8

# # Importable version of https://www.kaggle.com/code/metric/cmi-2025

# In[ ]:


"""
Hierarchical macro F1 metric for the CMI 2025 Challenge.

This script defines a single entry point `score(solution, submission, row_id_column_name)`
that the Kaggle metrics orchestrator will call.
It performs validation on submission IDs and computes a combined binary & multiclass F1 score.
"""

import pandas as pd
from sklearn.metrics import f1_score


class ParticipantVisibleError(Exception):
    """Exception type whose message is meant to be displayed to the competitor as-is."""


class CompetitionMetric:
    """Scorer for the CMI 2025 challenge: the mean of a binary F1 and a macro F1."""

    def __init__(self):
        # Gestures that count as "target" behaviours.
        self.target_gestures = [
            'Above ear - pull hair',
            'Cheek - pinch skin',
            'Eyebrow - pull hair',
            'Eyelash - pull hair',
            'Forehead - pull hairline',
            'Forehead - scratch',
            'Neck - pinch skin',
            'Neck - scratch',
        ]
        # Every other accepted gesture; these are pooled into one class for the macro F1.
        self.non_target_gestures = [
            'Write name on leg',
            'Wave hello',
            'Glasses on/off',
            'Text on phone',
            'Write name in air',
            'Feel around in tray and pull out an object',
            'Scratch knee/leg skin',
            'Pull air toward your face',
            'Drink from bottle/cup',
            'Pinch knee/leg skin'
        ]
        self.all_classes = self.target_gestures + self.non_target_gestures

    def calculate_hierarchical_f1(
        self,
        sol: pd.DataFrame,
        sub: pd.DataFrame
    ) -> float:
        """
        Return ``0.5 * binary_f1 + 0.5 * macro_f1`` for the 'gesture' columns.

        Args:
            sol: ground-truth frame; its 'gesture' column is read.
            sub: submission frame; its 'gesture' column is read and validated.

        Raises:
            ParticipantVisibleError: if the submission contains a gesture that is
                neither a target nor a non-target class.
        """
        predicted = sub['gesture']

        # Reject any submitted label outside the known vocabulary.
        unknown = {label for label in predicted.unique() if label not in self.all_classes}
        if unknown:
            raise ParticipantVisibleError(
                f"Invalid gesture values in submission: {unknown}"
            )

        expected = sol['gesture']

        # Level 1: target vs. non-target as a plain binary problem.
        binary_f1 = f1_score(
            expected.isin(self.target_gestures).values,
            predicted.isin(self.target_gestures).values,
            pos_label=True,
            zero_division=0,
            average='binary'
        )

        # Level 2: each target gesture keeps its name, everything else shares one label.
        def collapse(label):
            if label in self.target_gestures:
                return label
            return 'non_target'

        macro_f1 = f1_score(
            expected.apply(collapse),
            predicted.apply(collapse),
            average='macro',
            zero_division=0
        )

        return 0.5 * binary_f1 + 0.5 * macro_f1


def score(
    solution: pd.DataFrame,
    submission: pd.DataFrame,
    row_id_column_name: str
) -> float:
    """
    Entry point of the metric: hierarchical macro F1 for the CMI 2025 challenge.

    Both frames must contain the column named by ``row_id_column_name`` and a
    'gesture' column. The returned value is the average of:

    1. the binary F1 of target vs. non-target gestures, and
    2. the macro F1 over gestures, with all non-target gestures merged into one class.

    Raises ParticipantVisibleError when a required column is missing or when the
    submission contains an unknown gesture value.

    Examples
    --------
    >>> import pandas as pd
    >>> row_id_column_name = "id"
    >>> solution = pd.DataFrame({'id': range(4), 'gesture': ['Eyebrow - pull hair']*4})
    >>> submission = pd.DataFrame({'id': range(4), 'gesture': ['Forehead - pull hairline']*4})
    >>> score(solution, submission, row_id_column_name=row_id_column_name)
    0.5
    >>> submission = pd.DataFrame({'id': range(4), 'gesture': ['Text on phone']*4})
    >>> score(solution, submission, row_id_column_name=row_id_column_name)
    0.0
    >>> score(solution, solution, row_id_column_name=row_id_column_name)
    1.0
    """
    def _fail_if_absent(frame, column, message):
        # Raise the participant-facing error when `column` is not in `frame`.
        if column not in frame.columns:
            raise ParticipantVisibleError(message)

    # The solution is checked before the submission for each required column.
    for column in (row_id_column_name, 'gesture'):
        _fail_if_absent(solution, column, f"Solution file missing required column: '{column}'")
        _fail_if_absent(submission, column, f"Submission file missing required column: '{column}'")

    return CompetitionMetric().calculate_hierarchical_f1(solution, submission)


import torch
import torch.distributed as dist


def zeropower_via_newtonschulz5(G, steps: int):
    """
    Approximately orthogonalize G (its "zeroth power") with a quintic Newton-Schulz iteration.

    The coefficients are chosen to maximize the slope at zero; empirically it pays off to keep
    increasing that slope even past the point where the iteration converges to one everywhere on
    the interval. The result is therefore not exactly UV^T but something like US'V^T, where
    USV^T = G is the SVD and S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out
    not to hurt model performance relative to UV^T.

    Args:
        G: tensor with at least two dimensions; leading dimensions are treated as batch.
        steps: number of Newton-Schulz iterations to run.

    Returns:
        A bfloat16 tensor with the same shape as G.
    """
    # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750,  2.0315)

    # Work on the wide orientation so the Gram matrix below is the smaller of the two.
    is_tall = G.size(-2) > G.size(-1)
    X = G.bfloat16()
    if is_tall:
        X = X.mT

    # Dividing by the Frobenius norm guarantees a spectral norm of at most 1.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
    for _ in range(steps):
        gram = X @ X.mT
        poly = b * gram + c * gram @ gram
        X = a * X + poly @ X

    # Undo the transpose applied to tall inputs.
    return X.mT if is_tall else X


def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    """
    Compute one Muon update direction from a gradient and its momentum buffer.

    Args:
        grad: gradient tensor of the parameter. It is read only; it is NOT modified.
        momentum: momentum buffer for the parameter. Updated in place to
            ``lerp(momentum, grad, 1 - beta)``.
        beta: momentum coefficient.
        ns_steps: number of Newton-Schulz iterations used for orthogonalization.
        nesterov: if True the update is ``lerp(grad, momentum, beta)`` (Nesterov-style
            look-ahead); otherwise the momentum buffer itself is used.

    Returns:
        The orthogonalized update as produced by ``zeropower_via_newtonschulz5``,
        rescaled by ``sqrt(max(1, rows / cols))``. 4-D inputs come back flattened to
        ``(len(grad), -1)``; callers reshape to the parameter shape.
    """
    momentum.lerp_(grad, 1 - beta)
    # Out-of-place lerp: an in-place `grad.lerp_` would silently overwrite the caller's
    # `p.grad`, corrupting anything that reads the gradient after the optimizer step.
    update = grad.lerp(momentum, beta) if nesterov else momentum
    if update.ndim == 4: # for the case of conv filters
        # reshape (not view) so non-contiguous layouts such as channels_last also work
        update = update.reshape(len(update), -1)
    update = zeropower_via_newtonschulz5(update, steps=ns_steps)
    # NOTE(review): for 4-D gradients this ratio uses the last two dims of the original
    # tensor (kernel height / width), not the flattened matrix shape — confirm that is intended.
    update *= max(1, grad.size(-2) / grad.size(-1))**0.5
    return update


class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the
    advantage that it can be stably run in bfloat16 on the GPU.

    Muon should only be used for hidden weight layers. The input embedding, final output layer,
    and any internal gains or biases should be optimized using a standard method such as AdamW.
    Hidden convolutional weights can be trained using Muon by viewing them as 2D and then
    collapsing their last 3 dimensions.

    The work is sharded across ranks: within every chunk of ``world_size`` parameters each rank
    updates the one parameter it owns and the results are exchanged with ``all_gather``. When no
    process group is initialised the optimizer behaves as a single rank and skips the collective.
    Parameters whose ``grad`` is None are left untouched.

    Arguments:
        lr: The learning rate, in units of spectral norm per update.
        weight_decay: The AdamW-style weight decay.
        momentum: The momentum. A value of 0.95 here is usually fine.
    """
    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
        # Largest shapes first, so each all_gather chunk groups parameters that are adjacent in size order.
        params = sorted(params, key=lambda x: x.size(), reverse=True)
        super().__init__(params, defaults)

    @staticmethod
    def _world():
        """Return ``(world_size, rank)``; ``(1, 0)`` when no process group is initialised."""
        if dist.is_available() and dist.is_initialized():
            return dist.get_world_size(), dist.get_rank()
        return 1, 0

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step; returns the closure's loss (or None)."""

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        world_size, rank = self._world()

        for group in self.param_groups:
            params = group["params"]
            # Pad with scratch tensors so the last chunk also holds world_size entries.
            n_pad = -len(params) % world_size
            params_pad = params
            if n_pad:
                params_pad = params + [torch.empty_like(params[-1])] * n_pad
            for base_i in range(0, len(params), world_size):
                idx = base_i + rank
                if idx < len(params):
                    p = params[idx]
                    # A parameter without a gradient is not updated, but this rank must
                    # still join the all_gather below or the other ranks would deadlock.
                    if p.grad is not None:
                        state = self.state[p]
                        if len(state) == 0:
                            state["momentum_buffer"] = torch.zeros_like(p)
                        update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                        p.mul_(1 - group["lr"] * group["weight_decay"])
                        p.add_(update.reshape(p.shape), alpha=-group["lr"])
                if world_size > 1:
                    dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[idx])

        return loss


class SingleDeviceMuon(torch.optim.Optimizer):
    """
    Muon variant for usage in non-distributed settings.

    Arguments:
        params: parameters (or parameter groups) to optimize.
        lr: learning rate applied to the orthogonalized update.
        weight_decay: decoupled weight decay; each step scales the parameter by
            ``1 - lr * weight_decay`` before the update is added.
        momentum: momentum coefficient passed to ``muon_update`` as ``beta``.
    """
    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step; returns the closure's loss (or None)."""

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                # Frozen / unused parameters have no gradient: skip them entirely
                # (no state, no weight decay), as SingleDeviceMuonWithAuxAdam does.
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    state["momentum_buffer"] = torch.zeros_like(p)
                update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                p.mul_(1 - group["lr"] * group["weight_decay"])
                p.add_(update.reshape(p.shape), alpha=-group["lr"])

        return loss


def adam_update(grad, buf1, buf2, step, betas, eps):
    """
    Return the bias-corrected Adam direction and update both moment buffers in place.

    Args:
        grad: gradient tensor.
        buf1: running first moment; moved toward ``grad`` by ``1 - betas[0]``.
        buf2: running second moment; moved toward ``grad**2`` by ``1 - betas[1]``.
        step: 1-based step count used for bias correction.
        betas: sequence whose first two entries are the moment decay rates.
        eps: constant added to the denominator.

    Returns:
        ``m_hat / (sqrt(v_hat) + eps)``; learning rate and weight decay are left to the caller.
    """
    beta1, beta2 = betas[0], betas[1]

    # Exponential moving averages of the gradient and of its square.
    buf1.lerp_(grad, 1 - beta1)
    buf2.lerp_(grad.square(), 1 - beta2)

    # Undo the bias toward zero caused by zero-initialised buffers.
    m_hat = buf1 / (1 - beta1**step)
    v_hat = buf2 / (1 - beta2**step)

    denominator = v_hat.sqrt() + eps
    return m_hat / denominator


class MuonWithAuxAdam(torch.optim.Optimizer):
    """
    Distributed Muon variant that can be used for all parameters in the network, since it runs an
    internal AdamW for the parameters that are not compatible with Muon. The user must manually
    specify which parameters shall be optimized with Muon and which with Adam by passing in a
    list of param_groups with the `use_muon` flag set.

    The point of this class is to allow the user to have a single optimizer in their code, rather
    than having both a Muon and an Adam which each need to be stepped.

    Parameters whose ``grad`` is None are left untouched. When no process group is initialised
    the Muon groups are processed as a single rank and the all_gather is skipped.

    You can see an example usage below:

    https://github.com/KellerJordan/modded-nanogpt/blob/master/records/052525_MuonWithAuxAdamExample/b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt#L470
    ```
    hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
    embed_params = [p for n, p in model.named_parameters() if "embed" in n]
    scalar_params = [p for p in model.parameters() if p.ndim < 2]
    head_params = [model.lm_head.weight]

    from muon import MuonWithAuxAdam
    adam_groups = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
    adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups]
    muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True)
    param_groups = [*adam_groups, muon_group]
    optimizer = MuonWithAuxAdam(param_groups)
    ```
    """
    def __init__(self, param_groups):
        # Fill in per-kind defaults in place and reject groups carrying unexpected keys.
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:
                # Largest shapes first, matching the chunking done in step().
                group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True)
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
        super().__init__(param_groups, dict())

    @staticmethod
    def _world():
        """Return ``(world_size, rank)``; ``(1, 0)`` when no process group is initialised."""
        if dist.is_available() and dist.is_initialized():
            return dist.get_world_size(), dist.get_rank()
        return 1, 0

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step; returns the closure's loss (or None)."""

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        world_size, rank = self._world()

        for group in self.param_groups:
            if group["use_muon"]:
                params = group["params"]
                # Pad with scratch tensors so the last chunk also holds world_size entries.
                n_pad = -len(params) % world_size
                params_pad = params
                if n_pad:
                    params_pad = params + [torch.empty_like(params[-1])] * n_pad
                for base_i in range(0, len(params), world_size):
                    idx = base_i + rank
                    if idx < len(params):
                        p = params[idx]
                        # No gradient: skip the update but still take part in the
                        # all_gather below so the other ranks do not deadlock.
                        if p.grad is not None:
                            state = self.state[p]
                            if len(state) == 0:
                                state["momentum_buffer"] = torch.zeros_like(p)
                            update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                            p.mul_(1 - group["lr"] * group["weight_decay"])
                            p.add_(update.reshape(p.shape), alpha=-group["lr"])
                    if world_size > 1:
                        dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[idx])
            else:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
                                         state["step"], group["betas"], group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss


class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
    """
    Non-distributed variant of MuonWithAuxAdam.

    Every param group must carry a ``use_muon`` flag. Groups with the flag set are updated
    with ``muon_update``; the remaining groups are updated with ``adam_update``. Missing
    hyper-parameters are filled in on the group dicts in place, and a group with any key
    outside the expected set fails an assertion.
    """
    def __init__(self, param_groups):
        muon_fallbacks = (("lr", 0.02), ("momentum", 0.95), ("weight_decay", 0))
        adam_fallbacks = (("lr", 3e-4), ("betas", (0.9, 0.95)), ("eps", 1e-10), ("weight_decay", 0))

        for group in param_groups:
            assert "use_muon" in group
            fallbacks = muon_fallbacks if group["use_muon"] else adam_fallbacks
            # defaults
            for key, value in fallbacks:
                group.setdefault(key, value)
            assert set(group) == {"params", "use_muon", *(key for key, _ in fallbacks)}
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step; returns the closure's loss (or None)."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            orthogonalize = group["use_muon"]
            for p in group["params"]:
                # Parameters that received no gradient are skipped entirely.
                if p.grad is None:
                    continue

                state = self.state[p]
                if orthogonalize:
                    if not state:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    direction = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                    direction = direction.reshape(p.shape)
                else:
                    if not state:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    direction = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
                                            state["step"], group["betas"], group["eps"])

                # Decoupled weight decay first, then the step along the update direction.
                p.mul_(1 - group["lr"] * group["weight_decay"])
                p.add_(direction, alpha=-group["lr"])

        return loss

from torch.optim.lr_scheduler import _LRScheduler
class ConstantCosineLR(_LRScheduler):
    """
    Constant learning rate followed by CosineAnnealing.

    Arguments:
        optimizer: wrapped optimizer.
        total_steps: total number of scheduler steps in the run.
        pct_cosine: fraction of ``total_steps`` after which cosine decay starts. Note that,
            despite the name, the first ``int(total_steps * pct_cosine)`` steps are the
            constant phase and the remainder is the cosine phase.
        last_epoch: index of the last step, forwarded to the base scheduler.
    """
    def __init__(
        self, 
        optimizer,
        total_steps, 
        pct_cosine, 
        last_epoch=-1,
        ):
        self.total_steps = total_steps
        # Last step (1-based) that still uses the undecayed base learning rate.
        self.milestone = int(total_steps * (pct_cosine))
        # At least 1 so the division in get_lr is always defined.
        self.cosine_steps = max(total_steps - self.milestone, 1)
        # Kept for compatibility; the schedule below always decays to zero.
        self.min_lr = 0
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for the current step."""
        # Imported here so the class does not depend on a module-level `import math`
        # that, in this file, only appears after the class definition.
        import math

        step = self.last_epoch + 1
        if step <= self.milestone:
            factor = 1.
        else:
            s = step - self.milestone
            # Half cosine from 1 (s == 0) down to 0 (s == cosine_steps).
            factor = 0.5 * (1 + math.cos(math.pi * s / self.cosine_steps))
        return [lr * factor for lr in self.base_lrs]

# In [None]:
import os, json, joblib, numpy as np, pandas as pd
import random, math
from pathlib import Path
import warnings 
warnings.filterwarnings("ignore")

import sys
sys.path.append('/root/autodl-tmp/')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import StratifiedKFold
from timm.scheduler import CosineLRScheduler
from scipy.signal import firwin
import polars as pl
import numpy as np
import random


import numpy as np
import pandas as pd
from scipy.spatial.transform import Rotation as R
from tqdm.auto import tqdm 




# Pick the test CSV locations: Kaggle input directory first, then a local copy of the
# competition folder, and finally the current working directory.
if os.path.exists("../input/cmi-detect-behavior-with-sensor-data"):
    test_path1, test_path2 = (
        '/kaggle/input/cmi-detect-behavior-with-sensor-data/test.csv',
        '/kaggle/input/cmi-detect-behavior-with-sensor-data/test_demographics.csv',
    )
elif os.path.exists("./cmi-detect-behavior-with-sensor-data/"):
    test_path1, test_path2 = (
        './cmi-detect-behavior-with-sensor-data/test.csv',
        './cmi-detect-behavior-with-sensor-data/test_demographics.csv',
    )
else:
    test_path1, test_path2 = './test.csv', './test_demographics.csv'

class model_zhou_v1:
    def __init__(self, kaggle_input_path="/kaggle/input/cmi3v23"):
        self.models = []

        import os
        if os.path.exists("../input/cmi-detect-behavior-with-sensor-data"):
            self.TRAIN = False                     
            self.RAW_DIR = Path("../input/cmi-detect-behavior-with-sensor-data")
            self.PRETRAINED_DIR = Path(kaggle_input_path) 
            self.EXPORT_DIR = Path("./")                                   
        else:
            if os.path.exists("./cmi-detect-behavior-with-sensor-data/"):
                self.TRAIN = True                     
                self.RAW_DIR = Path("./cmi-detect-behavior-with-sensor-data/")
                self.PRETRAINED_DIR = Path(kaggle_input_path) 
                self.EXPORT_DIR = Path("./")                                  
            else:
                self.TRAIN = True                    
                self.RAW_DIR = Path("./")
                self.EXPORT_DIR = Path("./")
                self.PRETRAINED_DIR = Path(kaggle_input_path) 

        self.VALIDATION = False
        self.SEED = 42

        self.BATCH_SIZE = 64 * 1
        self.PAD_PERCENTILE = 128
        self.maxlen = self.PAD_PERCENTILE
        self.LR_INIT = 1e-3
        self.WD = 2e-1
        
        self.PATIENCE = 40
        self.random_state = self.SEED
        
        self.tof_mode = 4
        
        self.FOLDS = 5
        self.TRAIN_FOLDS = [0, 1, 2, 3, 4]
        
        self.EPOCHS = 160

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"▶ imports ready · pytorch {torch.__version__} · device: {self.device}")
    
    def create_model(self, param_a, param_b, param_c, param_d):
        class ImuFeatureExtractor(nn.Module):
            """
            Expand a raw IMU signal into a stack of hand-crafted feature channels.

            The input is indexed as (batch, channel, time). Channels 0:3 are treated as the
            accelerometer axes, 3:6 as the gyroscope axes and everything from 6 onward as
            "extra". `lpf_all` is built for exactly 7 input channels, so in practice the input
            has 7 channels and the output has 66.

            NOTE(review): the `lpf_*` layers are ordinary depthwise, bias-free Conv1d modules;
            nothing in this class sets or freezes their weights as low-pass kernels — confirm
            whether that happens elsewhere.
            """
            def __init__(self, ):
                super().__init__()
                k = 15
                # Depthwise conv (one filter per channel) over all 7 channels, kernel size 7.
                self.lpf_all = nn.Conv1d(7, 7, kernel_size=7, padding=7//2, groups=7, bias=False)
        
                # Depthwise filters for the raw accelerometer / gyroscope axes.
                self.lpf_acc  = nn.Conv1d(3, 3, k, padding=k//2, groups=3, bias=False)
                self.lpf_gyro = nn.Conv1d(3, 3, k, padding=k//2, groups=3, bias=False)
        
                # Depthwise filters for the L2-normalised (unit direction) versions.
                self.lpf_acc2  = nn.Conv1d(3, 3, k, padding=k//2, groups=3, bias=False)
                self.lpf_gyro2 = nn.Conv1d(3, 3, k, padding=k//2, groups=3, bias=False)
        
            def forward(self, imu):
                """Return all feature groups concatenated along the channel axis."""
                # imu: (B, C, T); B, C and T themselves are not used below
                B, C, T = imu.shape
                acc  = imu[:, 0:3, :]                 # acc_x, acc_y, acc_z
                gyro = imu[:, 3:6, :]                 # gyro_x, gyro_y, gyro_z
                extra = imu[:, 6:, :]
        
                # 1) L2 magnitude per time step
                acc_mag  = torch.norm(acc,  dim=1, p=2, keepdim=True)          # (B,1,T)
                gyro_mag = torch.norm(gyro, dim=1, p=2, keepdim=True)
        
                # 1.2) unit direction: each axis divided by the L2 magnitude
                #      (the magnitude is clipped from below to avoid division by zero)
                acc2 = acc/acc_mag.clip(1e-12)
                gyro2 = gyro/gyro_mag.clip(1e-12)
        
        
                # filtered unit direction and its residual (input minus filter output)
                acc_lpf2  = self.lpf_acc2(acc2)
                acc_hpf2  = acc2 - acc_lpf2
                gyro_lpf2 = self.lpf_gyro2(gyro2)
                gyro_hpf2 = gyro2 - gyro_lpf2
                
        
                # 1.3) L1 magnitude (sum of absolute axis values) per time step
                acc_mag2  = torch.norm(acc,  dim=1, p=1, keepdim=True)          # (B,1,T)
                gyro_mag2 = torch.norm(gyro, dim=1, p=1, keepdim=True)
        
                # 1.4) each axis divided by the L1 magnitude
                acc3 = acc/acc_mag2.clip(1e-12)
                gyro3 = gyro/gyro_mag2.clip(1e-12)
        
                # 2) jerk: first difference along time, zero-padded at the front to keep length T
                jerk = F.pad(acc[:, :, 1:] - acc[:, :, :-1], (1,0))       # (B,3,T)
                gyro_delta = F.pad(gyro[:, :, 1:] - gyro[:, :, :-1], (1,0))
        
                # 2) jerk level2: second difference x[t+1] + x[t-1] - 2*x[t], zero-padded on both ends
                jerk2 = F.pad(acc[:, :, 2:] + acc[:, :, :-2] - acc[:, :, 1:-1] * 2, (1,1))       # (B,3,T)
                gyro_delta2 = F.pad(gyro[:, :, 2:] + gyro[:, :, :-2] - gyro[:, :, 1:-1] * 2, (1,1))
        
                # 3) energy: element-wise square
                acc_pow  = acc ** 2
                gyro_pow = gyro ** 2
        
                # 4) LPF / HPF: filter output and the residual (input minus filter output)
                acc_lpf  = self.lpf_acc(acc)
                acc_hpf  = acc - acc_lpf
                gyro_lpf = self.lpf_gyro(gyro)
                gyro_hpf = gyro - gyro_lpf
        
        
                # residual of the all-channel filter
                imu_hpf = imu - self.lpf_all(imu)
        
                # The order of this list fixes the output channel layout; acc2 / gyro2 are
                # computed above but deliberately left out of the output.
                features = [
                    acc, gyro,
                    
                    # acc2, gyro2,
                    acc_mag, gyro_mag,
                    
                    acc3, gyro3,
                    acc_mag2, gyro_mag2,
                    
                    jerk, gyro_delta,
                    jerk2, gyro_delta2, 
                    
                    acc_pow, gyro_pow,
                    
                    acc_lpf, acc_hpf,
                    gyro_lpf, gyro_hpf,
        
                    acc_lpf2, acc_hpf2,
                    gyro_lpf2, gyro_hpf2,
        
                    imu_hpf,
        
                    extra, 
                    
                ]
        
                return torch.cat(features, dim=1)  # (B, C_out, T)
                
        class SEBlock(nn.Module):
            """Squeeze-and-excitation gate: rescales every channel by a learned factor in (0, 1)."""
            def __init__(self, channels, reduction=8):
                super().__init__()
                bottleneck = channels // reduction
                # Squeeze: average each channel over the time axis.
                self.squeeze = nn.AdaptiveAvgPool1d(1)
                # Excitation: bias-free bottleneck MLP ending in a sigmoid.
                self.excitation = nn.Sequential(
                    nn.Linear(channels, bottleneck, bias=False),
                    nn.ReLU(inplace=True),
                    nn.Linear(bottleneck, channels, bias=False),
                    nn.Sigmoid(),
                )

            def forward(self, x):
                # x: (batch, channels, time)
                batch, chans, _ = x.size()
                pooled = self.squeeze(x).view(batch, chans)
                gate = self.excitation(pooled).view(batch, chans, 1)
                # The (batch, channels, 1) gate broadcasts over the time axis.
                return x * gate

        class ResidualSECNNBlock(nn.Module):
            """
            Residual 1-D conv block: conv-BN-act, conv-BN, SE gate, shortcut add, act,
            optional max-pool, dropout.

            Args:
                in_channels: channels of the input.
                out_channels: channels produced by both convolutions.
                kernel_size: kernel size of both convolutions (padding is kernel_size//2).
                pool_size: max-pool window; pooling is applied only when it is > 1.
                dropout: dropout probability applied at the end of the block.
                weight_decay: accepted but not used anywhere in this class.
            """
            def __init__(self, in_channels, out_channels, kernel_size, pool_size=2, dropout=0.3, weight_decay=1e-4):
                super().__init__()
                
                # First conv block
                self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, padding=kernel_size//2, bias=False)
                self.bn1 = nn.BatchNorm1d(out_channels)
                
                # Second conv block
                self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, padding=kernel_size//2, bias=False)
                self.bn2 = nn.BatchNorm1d(out_channels)
                
                # SE block
                self.se = SEBlock(out_channels)
                
                # Shortcut connection: identity (empty Sequential) when the channel counts
                # match, otherwise a 1x1 conv + BN projection to out_channels
                self.shortcut = nn.Sequential()
                if in_channels != out_channels:
                    self.shortcut = nn.Sequential(
                        nn.Conv1d(in_channels, out_channels, 1, bias=False),
                        nn.BatchNorm1d(out_channels)
                    )
                self.pool_size = pool_size
                self.pool = nn.MaxPool1d(pool_size)
                self.dropout = nn.Dropout(dropout)
        
                self.act = nn.ReLU()
                
            def forward(self, x):
                """Apply the block to x; time length shrinks only if pool_size > 1."""
                shortcut = self.shortcut(x)
                
                # First conv
                out = self.act(self.bn1(self.conv1(x)))
                # Second conv (no activation before the SE gate)
                out = self.bn2(self.conv2(out))
                
                # SE block
                out = self.se(out)
                
                # Add shortcut (in place on the SE output)
                out += shortcut
                out = self.act(out)
                
                # Pool and dropout
                if self.pool_size>1:
                    out = self.pool(out)
                out = self.dropout(out)
                
                return out

        class AttentionLayer(nn.Module):
            """Attention pooling: collapse the sequence axis into a softmax-weighted sum."""
            def __init__(self, hidden_dim):
                super().__init__()
                # One scalar score per time step.
                self.attention = nn.Linear(hidden_dim, 1)

            def forward(self, x):
                # x: (batch, seq_len, hidden_dim)
                raw = self.attention(x)                         # (batch, seq_len, 1)
                squashed = torch.tanh(raw).squeeze(-1)          # (batch, seq_len)
                weights = F.softmax(squashed, dim=1)            # sums to 1 over seq_len
                weighted = x * weights.unsqueeze(-1)            # (batch, seq_len, hidden_dim)
                return torch.sum(weighted, dim=1)               # (batch, hidden_dim)



        tof_mode = self.tof_mode
        
        class TwoBranchModel(nn.Module):
            def __init__(self, pad_len, imu_dim_raw, tof_dim, n_classes, dropouts=[0.3, 0.3, 0.3, 0.3, 0.4, 0.5, 0.3], feature_engineering=True, **kwargs):
                super().__init__()
                self.feature_engineering = feature_engineering
                if feature_engineering:
                    self.imu_fe = ImuFeatureExtractor(**kwargs)
                    imu_dim = 32 + 1 + 14 + 7 + 6 + 6
                else:
                    self.imu_fe = nn.Identity()
                    imu_dim = imu_dim_raw   
                
                self.imu_dim = imu_dim
                self.tof_dim = tof_dim
        
                self.fir_nchan = 7 + 7
                self.thm_nchan = 5
                self.tof_nchan = 5 * (9 + 4 * tof_mode)
        
                weight_decay = 3e-3
        
                
                # IMU deep branch
                self.imu_block1 = ResidualSECNNBlock(imu_dim * 1, 160, 3, dropout=dropouts[0], pool_size=1, weight_decay=weight_decay)
                self.imu_block2 = ResidualSECNNBlock(160, 256, 3, dropout=dropouts[1], pool_size=1, weight_decay=weight_decay)
        
                self.imu_block12 = ResidualSECNNBlock(imu_dim * 1, 160, 3, dropout=dropouts[0], pool_size=1, weight_decay=weight_decay)
                self.imu_block22 = ResidualSECNNBlock(160, 256, 3, dropout=dropouts[1], pool_size=1, weight_decay=weight_decay)
        
                self.thm_block = nn.Sequential(
                    ResidualSECNNBlock(self.thm_nchan * 1, 32, 5, dropout=dropouts[0], pool_size=1, weight_decay=weight_decay),
                    ResidualSECNNBlock(32, 32, 5, dropout=dropouts[0], pool_size=1, weight_decay=weight_decay)
                )
                
                self.tof_block = nn.Sequential(
                    ResidualSECNNBlock(self.tof_nchan * 1, 64, 5, dropout=dropouts[0], pool_size=1, weight_decay=weight_decay),
                    ResidualSECNNBlock(64, 64, 5, dropout=dropouts[0], pool_size=1, weight_decay=weight_decay)
                )
        
                self.emb_all = 256 + 256 + 32 + 64
        
                self.dropout1d = nn.Dropout1d(0.15)
                
                # BiLSTM
                self.dim_lstm = 256
                self.dim_encoder = self.dim_lstm * 2
        
                self.bilstm = nn.LSTM(self.emb_all, self.dim_lstm, num_layers=1, bidirectional=True, batch_first=True)
            
                self.lstm_dropout = nn.Dropout(dropouts[4])
        
                self.output2 = nn.Linear(self.dim_encoder, 5)
                
                # Dense layers
            
                self.dense1 = nn.Linear(self.dim_encoder * 2, 512, bias=False)
                self.bn_dense1 = nn.BatchNorm1d(512)
                self.drop1 = nn.Dropout(dropouts[5])
                
                self.dense2 = nn.Linear(512, 256, bias=False)
                self.bn_dense2 = nn.BatchNorm1d(256)
                self.drop2 = nn.Dropout(dropouts[6])
        
                self.output3 = nn.Linear(256, 5)
                self.classifier = nn.Linear(256, n_classes)
        
                self.scale = nn.Parameter(torch.ones((1, 256, 1)))
                self.bias = nn.Parameter(torch.zeros((1, 256, 1)))
        
                self.scale2 = nn.Parameter(torch.ones((1, 256, 1)))
                self.bias2 = nn.Parameter(torch.zeros((1, 256, 1)))
        
                self.scale3 = nn.Parameter(torch.ones((1, 256, 1)))
                self.bias3 = nn.Parameter(torch.zeros((1, 256, 1)))
        
                self.scale4 = nn.Parameter(torch.ones((1, 512, 1)))
                self.bias4 = nn.Parameter(torch.zeros((1, 512, 1)))
        
                self.act = nn.GELU()
                
            def forward(self, x, ):
                """Run the network on a batch of sensor sequences.

                ``x`` is indexed as a rank-3 tensor whose last axis holds the
                channels in this order: two equal IMU halves (``fir_nchan``
                channels in total), ``thm_nchan`` thermopile channels, then the
                remaining channels, which feed the ToF branch.  A timestep whose
                channels are all zero is masked back to zero after the learned
                per-channel scale/shift.

                Returns the gesture logits alone in eval mode.  In training mode
                returns ``(logits, logits2, logits3)``: ``logits2`` is computed
                from every BiLSTM timestep, ``logits3`` from the final dense
                features.
                """
                imu_end = self.fir_nchan
                imu_mid = imu_end // 2
                thm_end = imu_end + self.thm_nchan

                # 1.0 where at least one channel of the timestep is non-zero,
                # already transposed so it broadcasts over channel-first tensors.
                step_mask = (x.abs().max(-1)[0] != 0).float()[:, :, None].transpose(1, 2)

                def affine_masked(feat, scale, bias):
                    # Learned per-channel scale and shift, then re-zero masked steps.
                    n_ch = feat.shape[1]
                    return (feat * scale[:, :n_ch, :] + bias[:, :n_ch, :]) * step_mask

                # Split the channel axis and move each group to channel-first layout.
                imu_a = x[:, :, :imu_mid].transpose(1, 2)
                imu_b = x[:, :, imu_mid:imu_end].transpose(1, 2)
                thm = x[:, :, imu_end:thm_end].transpose(1, 2)
                tof = x[:, :, thm_end:].transpose(1, 2)

                # Both IMU halves share one feature extractor.
                imu_a = self.imu_fe(imu_a)
                imu_b = self.imu_fe(imu_b)

                imu_a = affine_masked(imu_a, self.scale, self.bias)
                imu_b = affine_masked(imu_b, self.scale2, self.bias2)
                thm = affine_masked(thm, self.scale3, self.bias3)
                tof = affine_masked(tof, self.scale4, self.bias4)

                # The call order below is kept so the random stream is consumed
                # in the same sequence as before.
                thm = self.dropout1d(thm)
                tof = self.dropout1d(tof)
                imu_a = self.dropout1d(imu_a)
                imu_b = self.dropout1d(imu_b)

                thm = self.thm_block(thm)
                tof = self.tof_block(tof)

                # IMU branches
                branch_a = self.imu_block2(self.imu_block1(imu_a))
                branch_b = self.imu_block22(self.imu_block12(imu_b))

                # Concatenate all branches on the channel axis, then go back to
                # sequence-major layout for the batch_first LSTM.
                merged = torch.cat((branch_a, branch_b, thm, tof), 1).transpose(1, 2)

                # BiLSTM
                lstm_out, _ = self.bilstm(merged)

                # Per-timestep auxiliary head
                logits2 = self.output2(self.lstm_dropout(lstm_out))

                # Max- and mean-pool over the sequence axis, side by side
                pooled = torch.cat((lstm_out.max(1)[0], lstm_out.mean(1)), -1)
                pooled = self.lstm_dropout(pooled)

                # Dense layers
                hidden = self.drop1(self.act(self.bn_dense1(self.dense1(pooled))))
                hidden = self.drop2(self.act(self.bn_dense2(self.dense2(hidden))))

                # Classification heads
                logits = self.classifier(hidden)
                logits3 = self.output3(hidden)

                if self.training:
                    return logits, logits2, logits3

                return logits


        return TwoBranchModel(param_a, param_b, param_c, param_d)

    def preprocess_sequence(self, df_seq: pd.DataFrame, feature_cols: list, scaler: StandardScaler):
        """Return the selected feature columns as a gap-free float32 matrix.

        Missing values are forward-filled, then backward-filled; whatever is
        still missing afterwards (a column with no value at all) becomes 0.
        No scaling happens here: ``scaler`` is accepted but never used.
        """
        selected = df_seq[feature_cols]
        gap_filled = selected.ffill().bfill()
        dense = gap_filled.fillna(0)
        return dense.values.astype('float32')
        
    def pad_sequences_torch(self, sequences, maxlen, padding='pre', truncating='pre', value=0.0):
        """Bring every 2-D sequence to exactly ``maxlen`` rows and stack them.

        Sequences with at least ``maxlen`` rows are cut: ``truncating='post'``
        keeps the first ``maxlen`` rows, any other value keeps the last ones.
        Shorter sequences are filled with ``value``: ``padding='post'`` appends
        the filler rows, any other value prepends them.

        Returns a float32 array built from the list of adjusted sequences.
        """
        keep_head = truncating == 'post'
        pad_tail = padding == 'post'

        fixed = []
        for item in sequences:
            n_rows = len(item)
            if n_rows >= maxlen:
                fixed.append(item[:maxlen] if keep_head else item[-maxlen:])
                continue

            filler = np.full((maxlen - n_rows, item.shape[1]), value)
            pieces = [item, filler] if pad_tail else [filler, item]
            fixed.append(np.concatenate(pieces))

        return np.array(fixed, dtype=np.float32)
        
    
    def train(self, ):
        """Train one model per selected fold and collect out-of-fold predictions.

        Steps, as implemented below:

        1. Read ``train.csv`` from ``self.RAW_DIR``, drop two hard-coded
           subjects and add engineered columns via ``self.get_new_features``.
        2. Build one feature matrix per ``sequence_id`` plus three targets:
           gesture (one-hot), per-timestep behaviour id (one-hot, 5 slots) and
           orientation id (one-hot, 5 slots).
        3. For every fold listed in ``self.TRAIN_FOLDS`` (folds are grouped by
           subject): train, keep an EMA of the trainable parameters, evaluate
           the EMA weights after each epoch on the full input and on an
           IMU-only copy, and save the best checkpoint to ``self.EXPORT_DIR``.
        4. Dump the out-of-fold logits to ``oof_<name>.joblib`` in the current
           working directory.

        When ``self.VALIDATION`` is truthy, no optimisation step runs: each
        fold's checkpoint is loaded from ``self.PRETRAINED_DIR`` and evaluated
        once (the checkpoint is still re-saved to ``self.EXPORT_DIR`` if its
        score is above 0).

        Returns:
            ``(classes, pred_all, pred_imuonly, targets, sequence_ids)`` where
            ``classes`` are the label-encoder classes, the two ``pred_*`` arrays
            are validation logits (all features / IMU-only input), ``targets``
            are integer class indices and ``sequence_ids`` are the matching
            sequence ids.  For each fold the rows come from the epoch with the
            best validation score, so all four arrays have one row per
            validation sequence.
        """
        # Imported up front so a missing package fails before any training.
        import joblib
        import numpy as np
        import pandas as pd
        from sklearn.model_selection import GroupKFold
        from tqdm.auto import tqdm

        class CMI3Dataset(Dataset):
            """Serves already-padded sequences together with their three targets."""

            def __init__(self,
                         X_list,
                         y_list, y_list2, y_list3,
                         maxlen,
                         mode="train",
                         imu_dim=7,
                         augment=None,
                         epoch_multiplier=1,
                        ):
                self.X_list = X_list
                self.mode = mode
                self.y_list = y_list
                self.y_list2 = y_list2
                self.y_list3 = y_list3
                self.maxlen = maxlen
                self.imu_dim = imu_dim
                self.augment = augment
                self.epoch_multiplier = epoch_multiplier

            def pad_sequences_torch(self, seq, maxlen, padding='post', truncating='post', value=0.0):
                """Pad or truncate a single 2-D sequence to ``maxlen`` rows.

                Not called by ``__getitem__``; the data handed to this dataset
                is padded beforehand.
                """
                if seq.shape[0] >= maxlen:
                    if truncating == 'post':
                        seq = seq[:maxlen]
                    else:  # 'pre'
                        seq = seq[-maxlen:]
                else:
                    pad_len = maxlen - seq.shape[0]
                    if padding == 'post':
                        seq = np.concatenate([seq, np.full((pad_len, seq.shape[1]), value)])
                    else:  # 'pre'
                        seq = np.concatenate([np.full((pad_len, seq.shape[1]), value), seq])
                return seq

            def __getitem__(self, index):
                # Every underlying sample is exposed `epoch_multiplier` times.
                base = index // self.epoch_multiplier
                return self.X_list[base], self.y_list[base], self.y_list2[base], self.y_list3[base]

            def __len__(self):
                return len(self.X_list) * self.epoch_multiplier

        class EMA:
            """Exponential moving average of a model's trainable parameters.

            Only entries of ``named_parameters()`` with ``requires_grad`` are
            tracked; buffers are not.
            """

            def __init__(self, model, decay=0.999):
                self.decay = decay
                self.shadow = {}
                self.backup = {}

                for name, param in model.named_parameters():
                    if param.requires_grad:
                        self.shadow[name] = param.data.clone()

            def update(self, model):
                """Blend the current parameters into the running average."""
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        assert name in self.shadow
                        new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                        self.shadow[name] = new_average.clone()

            def apply_shadow(self, model):
                """Swap the averaged weights into ``model``, backing up the live ones."""
                self.backup = {}
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        self.backup[name] = param.data.clone()
                        param.data = self.shadow[name]

            def restore(self, model):
                """Put the backed-up live weights back into ``model``."""
                for name, param in model.named_parameters():
                    if param.requires_grad and name in self.backup:
                        param.data = self.backup[name]
                self.backup = {}

        def set_seed(seed: int = 42):
            """Seed Python, NumPy and torch (CPU and CUDA) random generators."""
            random.seed(seed)

            os.environ['PYTHONHASHSEED'] = str(seed)

            np.random.seed(seed)

            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            # cuDNN determinism flags are not set here, so GPU runs may still
            # differ slightly between executions.
            # torch.backends.cudnn.deterministic = True
            # torch.backends.cudnn.benchmark = False
            # torch.use_deterministic_algorithms(True)

        set_seed(self.SEED)

        class RDROPLoss(nn.Module):
            """R-Drop style consistency term.

            Returns ``alpha`` times the sum of the two directed KL divergences
            between the softmax distributions of two sets of logits.  It does
            not include a classification loss.
            """
            def __init__(self, alpha=0.5):
                super(RDROPLoss, self).__init__()
                self.alpha = alpha  # weight of the KL term
                self.kl_div = nn.KLDivLoss(reduction='batchmean')

            def forward(self, logits1, logits2):
                # KL constraint: the two output distributions should agree.
                p1 = F.log_softmax(logits1, dim=1)
                p2 = F.softmax(logits2, dim=1)
                kl_loss1 = self.kl_div(p1, p2)

                p2 = F.log_softmax(logits2, dim=1)
                p1 = F.softmax(logits1, dim=1)
                kl_loss2 = self.kl_div(p2, p1)

                total_loss = self.alpha * (kl_loss1 + kl_loss2)
                return total_loss

        print("▶ TRAIN MODE – loading dataset …")
        df = pd.read_csv(self.RAW_DIR / "train.csv")

        # Subjects are numbered before the exclusion below, so the numbering can
        # have gaps; the numbers are only used as group labels for the CV split.
        subject_to_idx = dict(zip(df['subject'].unique(), list(range(df['subject'].nunique()))))
        df['subject_id_new'] = df['subject'].map(subject_to_idx)

        df = df[~df['subject'].isin({'SUBJ_045235', 'SUBJ_019262'})]

        # Ids start at 1; slot 0 of the behaviour target is what padded
        # timesteps get (see the zero-initialised `bid` below).
        dict_behavior = {'Relaxes and moves hand to target location': 1,
                         'Hand at target location': 2,
                         'Performs gesture': 3,
                         'Moves hand to target location': 4}

        dict_orientation = {'Seated Lean Non Dom - FACE DOWN': 1,
                         'Lie on Side - Non Dominant': 2,
                         'Seated Straight': 3,
                         'Lie on Back': 4}

        df['bid'] = df['behavior'].map(dict_behavior)
        df['oid'] = df['orientation'].map(dict_orientation)

        df, new_feat_cols, new_feat_cols2 = self.get_new_features(df)

        # Label encoding
        le = LabelEncoder()
        df['gesture_int'] = le.fit_transform(df['gesture'])
        np.save(self.EXPORT_DIR / "gesture_classes.npy", le.classes_)

        # Feature list: raw IMU, engineered IMU, thermopile, engineered ToF.
        thm_cols = [c for c in df.columns if c.startswith('thm_')]
        print(thm_cols)

        # Raw ToF pixels are not fed to the model; only their aggregates are.
        tof_cols = []

        imu_cols = ['acc_x', 'acc_y', 'acc_z', 'rot_x', 'rot_y', 'rot_z', 'rot_w',]

        feature_cols = imu_cols + new_feat_cols + thm_cols + new_feat_cols2

        # Columns from this index on are thermopile/ToF features; zeroing them
        # gives the "IMU-only" view used during evaluation.
        imu_feat_end = len(imu_cols) + len(new_feat_cols)

        print(f"  IMU {len(imu_cols)+len(new_feat_cols)} | THM {len(thm_cols)} | TOF {len(new_feat_cols2)}  | total {len(feature_cols)} features")

        # Save feature_cols
        np.save(self.EXPORT_DIR / "feature_cols.npy", np.array(feature_cols))
        pad_len = self.PAD_PERCENTILE

        # Group sequences
        seq_gp = df.groupby('sequence_id')
        X_list_raw, y_list, id_list, subject_list, y2list, y3list = [], [], [], [], [], []
        for seq_id, seq in seq_gp:
            mat = seq[feature_cols].ffill().bfill().fillna(0).values
            X_list_raw.append(mat)
            y_list.append(seq['gesture_int'].iloc[0])
            id_list.append(seq_id)
            subject_list.append(seq['subject_id_new'].iloc[0])

            # Behaviour ids are left-padded with 0 / cut to the last `pad_len`
            # steps, mirroring the 'pre' padding applied to the features.
            if len(seq) < pad_len:
                bid = np.zeros(pad_len)
                bid[-len(seq):] = seq['bid'].values.ravel()
            else:
                bid = seq['bid'].values.ravel()[-pad_len:]

            y2list.append(bid)
            y3list.append(seq['oid'].iloc[0])

        np.save(self.EXPORT_DIR / "sequence_maxlen.npy", pad_len)

        id_list = np.array(id_list)
        y_list_all = np.eye(len(le.classes_))[y_list].astype(np.float32)  # one-hot

        y_list_all2 = np.vstack(y2list).astype(int)
        y_list_all2 = (np.eye(5)[y_list_all2.reshape(-1, 1)]).reshape(-1, pad_len, 5).astype(np.float32)
        y_list_all3 = np.eye(5)[y3list].astype(np.float32)

        # Padding does not depend on the fold, so it is done once.
        X_list_all = self.pad_sequences_torch(X_list_raw, maxlen=pad_len, padding='pre', truncating='pre')

        augmenter = None
        metrics = []

        # Only needed by the disabled consistency experiment in the train loop.
        criterion_rdrop = RDROPLoss(alpha=0.5)

        gkf = GroupKFold(
                         n_splits=self.FOLDS,
                         shuffle=True,
                         random_state=self.random_state
                        )

        def clipped_cross_entropy(logits, y, clipval=0.6, num=18):
            """Cross-entropy against one-hot targets clipped to [(1-clipval)/num, clipval].

            Softmax and the class sum run over the LAST axis, the axis the
            targets are one-hot on.  For 2-D logits this equals ``dim=1``; for
            the per-timestep head (batch, time, classes) ``dim=1`` would
            normalise over time instead of over classes.
            """
            return -torch.sum(F.log_softmax(logits, dim=-1) * y.clip((1-clipval)/num, clipval), dim=-1).mean()

        idlistall = []
        targetfinalall = []
        predimuonlyall = []
        predallfeatall = []
        for fold, (train_idx, val_idx) in enumerate(gkf.split(id_list, id_list, groups=subject_list)):

        # skf = StratifiedKFold(n_splits=FOLDS, shuffle=True, random_state=self.random_state)
        # for fold, (train_idx, val_idx) in enumerate(skf.split(id_list, np.argmax(y_list_all, axis=1))):

            if fold not in self.TRAIN_FOLDS:
                continue

            print(f"\n▶ Fold {fold}")

            # Prepare train/val sets
            train_list = X_list_all[train_idx]
            train_y_list = y_list_all[train_idx]
            train_y_list2 = y_list_all2[train_idx]
            train_y_list3 = y_list_all3[train_idx]

            val_list = X_list_all[val_idx]
            val_y_list = y_list_all[val_idx]
            val_y_list2 = y_list_all2[val_idx]
            val_y_list3 = y_list_all3[val_idx]

            # Data loaders
            train_dataset = CMI3Dataset(train_list, train_y_list, train_y_list2, train_y_list3, pad_len, mode="train", imu_dim=len(imu_cols),
                                        augment=augmenter)
            train_loader = DataLoader(train_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=6, drop_last=False, pin_memory=True)

            val_dataset = CMI3Dataset(val_list, val_y_list, val_y_list2, val_y_list3, pad_len, mode="val")
            val_loader = DataLoader(val_dataset, batch_size=self.BATCH_SIZE, shuffle=False, num_workers=6, drop_last=False, pin_memory=True)

            device = self.device

            # Model & EMA
            model = self.create_model(pad_len, len(imu_cols), len(tof_cols), len(le.classes_)).to(device)

            if self.VALIDATION:
                checkpoint = torch.load(self.PRETRAINED_DIR / f'gesture_two_branch_fold{fold}.pth', map_location=device)
                # Strip the '_orig_mod.' key prefix if the checkpoint has it.
                model.load_state_dict({k.replace('_orig_mod.', ''):v for k, v in checkpoint['model_state_dict'].items()})

            ema = EMA(model, decay=0.998)

            # Optimizer: parameters with >= 2 dims are flagged for Muon, the
            # rest (gains/biases) for the auxiliary Adam.
            hidden_weights = [p for p in model.parameters() if p.ndim >= 2 and p.requires_grad]
            hidden_gains_biases = [p for p in model.parameters() if p.ndim < 2 and p.requires_grad]
            param_groups = [
                dict(params=hidden_weights, use_muon=True, lr=0.005, weight_decay=self.WD),
                dict(params=hidden_gains_biases, use_muon=False, lr=self.LR_INIT, betas=(0.9, 0.95), weight_decay=self.WD),
            ]
            optimizer = SingleDeviceMuonWithAuxAdam(param_groups)

            nbatch = len(train_loader)
            nsteps = self.EPOCHS * nbatch

            # Stepped once per batch, hence the total step count.
            scheduler = ConstantCosineLR(optimizer, nsteps, 0.)
            print("▶ Starting training...")

            best_val_acc = 0
            # (targets, IMU-only logits, all-feature logits) of the best epoch
            # and of the most recent epoch; exactly one of them is exported per
            # fold so the OOF arrays stay aligned with the sequence ids.
            best_oof = None
            last_oof = None
            for epoch in tqdm(range(self.EPOCHS)):
                model.train()
                train_preds, train_targets = [], []
                train_loss = 0.0

                # In validation-only mode no batch is consumed at all, so the
                # loader workers are not even started.
                batches = () if self.VALIDATION else train_loader
                for X, y, y2, y3 in batches:
                    X, y = X.float().to(device), y.to(device)
                    y2, y3 = y2.float().to(device), y3.to(device)

                    optimizer.zero_grad()

                    logits, logits2, logits3 = model(X)

                    clipval = 0.6
                    loss = clipped_cross_entropy(logits, y, clipval, 18)
                    loss += clipped_cross_entropy(logits2, y2, clipval, 5)  * 0.4
                    loss += clipped_cross_entropy(logits3, y3, clipval, 5)  * 0.4

                    # Disabled experiment: a second forward pass on an IMU-only
                    # copy of X (columns from `imu_feat_end` on set to 0), its
                    # three losses weighted by 0.1, plus `criterion_rdrop`
                    # between the two passes for each head.

                    loss.backward()

                    optimizer.step()
                    ema.update(model)

                    train_preds.append(logits.argmax(dim=1).cpu().numpy())
                    train_targets.append(y.argmax(dim=1).cpu().numpy())

                    scheduler.step()
                    train_loss += loss.item()

                # Evaluate with the EMA weights swapped in.
                model.eval()
                ema.apply_shadow(model)

                val_loss = 0.0
                val_preds, val_targets = [], []
                val_preds_imuonly, val_targets_imuonly = [], []
                val_preds_all, val_targets_all = [], []

                val_preds_logits = []
                val_preds_imuonly_logits = []

                with torch.inference_mode():
                    for X, y, y2, y3 in val_loader:
                        X_imuonly = X[:].clone()
                        X_imuonly[:, :, imu_feat_end:] = 0.0

                        X, y = X.float().to(device), y.to(device)
                        X_imuonly = X_imuonly.float().to(device)

                        logits = model(X)
                        logits_imuonly = model(X_imuonly)

                        batch_targets = y.argmax(dim=1).cpu().numpy()

                        val_preds.append(logits.argmax(dim=1).cpu().numpy())
                        val_preds_imuonly.append(logits_imuonly.argmax(dim=1).cpu().numpy())

                        val_preds_logits.append(logits.cpu().numpy())
                        val_preds_imuonly_logits.append(logits_imuonly.cpu().numpy())

                        # "SPLIT" score: both views pooled into one prediction set.
                        val_preds_all.append(logits.argmax(dim=1).cpu().numpy())
                        val_preds_all.append(logits_imuonly.argmax(dim=1).cpu().numpy())

                        val_targets.append(batch_targets)
                        val_targets_imuonly.append(batch_targets)

                        val_targets_all.append(batch_targets)
                        val_targets_all.append(batch_targets)

                        loss = F.cross_entropy(logits, y)
                        val_loss += loss.item()

                # No training batches ran in validation-only mode; report 0 then.
                train_acc = 0.
                if train_targets:
                    train_acc = CompetitionMetric().calculate_hierarchical_f1(
                        pd.DataFrame({'gesture': le.classes_[np.concatenate(train_targets)]}),
                        pd.DataFrame({'gesture': le.classes_[np.concatenate(train_preds)]})
                    )
                val_acc = CompetitionMetric().calculate_hierarchical_f1(
                    pd.DataFrame({'gesture': le.classes_[np.concatenate(val_targets)]}),
                    pd.DataFrame({'gesture': le.classes_[np.concatenate(val_preds)]})
                )
                val_acc_imuonly = CompetitionMetric().calculate_hierarchical_f1(
                    pd.DataFrame({'gesture': le.classes_[np.concatenate(val_targets_imuonly)]}),
                    pd.DataFrame({'gesture': le.classes_[np.concatenate(val_preds_imuonly)]})
                )
                val_acc_split = CompetitionMetric().calculate_hierarchical_f1(
                    pd.DataFrame({'gesture': le.classes_[np.concatenate(val_targets_all)]}),
                    pd.DataFrame({'gesture': le.classes_[np.concatenate(val_preds_all)]})
                )

                # Mean loss per processed batch (one entry of train_targets per batch).
                train_loss = train_loss / max(len(train_targets), 1)
                val_loss /= len(val_loader)
                print('epoch', epoch, 'loss : ', round(train_loss, 4), '| TRAIN : ', round(train_acc, 4), '| IMUONLY : ', round(val_acc_imuonly, 4), '| ALL : ', round(val_acc, 4),  '| SPLIT : ', round(val_acc_split, 4), '| LR : ', optimizer.param_groups[0]['lr'])

                last_oof = (
                    np.concatenate(val_targets),
                    np.concatenate(val_preds_imuonly_logits, 0),
                    np.concatenate(val_preds_logits, 0),
                )

                metric = val_acc

                if metric > best_val_acc:
                    best_val_acc = metric
                    best_oof = last_oof
                    # Save model (the EMA weights are the ones currently applied)
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'imu_dim': len(imu_cols),
                        'tof_dim': len(tof_cols),
                        'n_classes': len(le.classes_),
                        'pad_len': pad_len
                    }, self.EXPORT_DIR / f"gesture_two_branch_fold{fold}.pth")
                    print(f"fold: {fold} val_all_acc: {metric:.4f}")
                    print("✔ Training done – artefacts saved in", self.EXPORT_DIR)

                # Back to the live weights before the next training epoch.
                ema.restore(model)

                if self.VALIDATION:
                    break

            ema.apply_shadow(model)
            metrics.append(best_val_acc)

            # Export one prediction set per fold: the best epoch's (matching the
            # saved checkpoint), or the last epoch's if no epoch scored above 0.
            fold_oof = best_oof if best_oof is not None else last_oof
            if fold_oof is not None:
                fold_targets, fold_pred_imuonly, fold_pred_all = fold_oof
                idlistall.append(id_list[val_idx])
                targetfinalall.append(fold_targets)
                predimuonlyall.append(fold_pred_imuonly)
                predallfeatall.append(fold_pred_all)

        print(metrics, sum(metrics)/len(metrics))

        oof = {
            'guestures': le.classes_,
            'pred_all': np.concatenate(predallfeatall, 0),
            'pred_imuonly': np.concatenate(predimuonlyall, 0),
            'targets': np.concatenate(targetfinalall),
            'idlist': np.concatenate(idlistall),
        }

        name = str(self.PRETRAINED_DIR).replace('/', '_')
        joblib.dump(oof, f"oof_{name}.joblib")

        return oof['guestures'], oof['pred_all'], oof['pred_imuonly'], oof['targets'], oof['idlist']
        
    def get_new_features(self, df):
        """Append engineered IMU, thermopile and ToF features to ``df``.

        Per ``sequence_id`` three groups of IMU columns are added:
        ``f1..f3`` (acceleration with gravity removed), ``f4..f6`` (angular
        velocity from consecutive quaternions) and ``f7`` (angular distance
        between consecutive quaternions).  The five ``thm_*`` columns get their
        missing values replaced by the row mean of the five and are divided
        by 5.  For each of the five ToF sensors the 64 pixel columns have -1
        replaced by 255 and are divided by 50, and row-wise aggregates are
        added; ``self.tof_mode`` selects additional per-region aggregates
        (``> 1``: that many regions, ``-1``: 2, 4, 8, 16 and 32 regions).

        Returns:
            ``(df, imu_feature_names, tof_feature_names)``: the extended frame,
            the list ``['f1', ..., 'f7']`` and the names of the added ToF
            aggregate columns in insertion order.  Callers use these lists to
            define the model's input column order, so the order must not change.
        """
        def remove_gravity_from_acc(acc_data, rot_data):
            """Subtract gravity, rotated into the sensor frame, from the acceleration.

            Rows whose quaternion norm is NaN or ~0 keep their raw acceleration.
            If the rotation step fails, all rows keep their raw acceleration
            (best effort, no error is raised).
            """
            if isinstance(acc_data, pd.DataFrame):
                acc_values = acc_data[['acc_x', 'acc_y', 'acc_z']].values
            else:
                acc_values = np.asarray(acc_data)

            if isinstance(rot_data, pd.DataFrame):
                quat_values = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
            else:
                quat_values = np.asarray(rot_data)

            linear_accel = np.zeros_like(acc_values)
            gravity_world = np.array([0, 0, 9.81])

            quat_norms = np.linalg.norm(quat_values, axis=1)
            valid_mask = ~(np.isnan(quat_norms) | np.isclose(quat_norms, 0))

            valid_quats = quat_values[valid_mask]
            valid_quats_normalized = valid_quats / quat_norms[valid_mask, np.newaxis]
            try:
                rotations = R.from_quat(valid_quats_normalized)

                gravity_sensor_frame = rotations.apply(gravity_world, inverse=True)

                linear_accel[valid_mask] = acc_values[valid_mask] - gravity_sensor_frame
            except Exception:
                # Deliberate fallback (e.g. no valid quaternion in the sequence);
                # unlike a bare `except`, Ctrl-C and SystemExit still propagate.
                linear_accel[valid_mask] = acc_values[valid_mask]
            linear_accel[~valid_mask] = acc_values[~valid_mask]

            return linear_accel


        def calculate_angular_velocity_from_quat(rot_data, time_delta=1/200):
            """Rotation vector between consecutive quaternions divided by ``time_delta``.

            Pairs with an invalid quaternion (norm NaN or ~0) yield zeros.  The
            last row repeats the previous row's value.  Sequences with fewer
            than two rows yield all zeros.
            """
            if isinstance(rot_data, pd.DataFrame):
                quat_values = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
            else:
                quat_values = np.asarray(rot_data)

            num_samples = quat_values.shape[0]
            angular_vel = np.zeros((num_samples, 3))

            if num_samples < 2:
                return angular_vel

            quat_norms = np.linalg.norm(quat_values, axis=1)
            valid_mask = ~(np.isnan(quat_norms) | np.isclose(quat_norms, 0))
            valid_pairs_mask = valid_mask[:-1] & valid_mask[1:]

            if np.any(valid_pairs_mask):
                q_t = quat_values[:-1][valid_pairs_mask]
                q_t_plus_dt = quat_values[1:][valid_pairs_mask]

                q_t_norms = quat_norms[:-1][valid_pairs_mask]
                q_t_plus_dt_norms = quat_norms[1:][valid_pairs_mask]

                q_t_norm = q_t / q_t_norms[:, np.newaxis]
                q_t_plus_dt_norm = q_t_plus_dt / q_t_plus_dt_norms[:, np.newaxis]

                rot_t = R.from_quat(q_t_norm)
                rot_t_plus_dt = R.from_quat(q_t_plus_dt_norm)
                delta_rot = rot_t.inv() * rot_t_plus_dt

                # `angular_vel[:-1]` is a view, so this writes into angular_vel.
                angular_vel[:-1][valid_pairs_mask] = delta_rot.as_rotvec() / time_delta

            # num_samples >= 2 is guaranteed by the early return above.
            angular_vel[-1, :] = angular_vel[-2, :]

            return angular_vel


        def calculate_angular_distance(rot_data, cumulative=False):
            """Rotation angle between consecutive quaternions.

            Invalid pairs yield 0.  The last row is 0, or repeats the previous
            row when ``cumulative`` is set, in which case the running sum is
            returned.  Sequences with fewer than two rows yield all zeros.
            """
            if isinstance(rot_data, pd.DataFrame):
                quat_values = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
            else:
                quat_values = np.asarray(rot_data)

            num_samples = quat_values.shape[0]
            angular_dist = np.zeros(num_samples)

            if num_samples < 2:
                return angular_dist

            quat_norms = np.linalg.norm(quat_values, axis=1)
            valid_mask = ~(np.isnan(quat_norms) | np.isclose(quat_norms, 0))
            valid_pairs_mask = valid_mask[:-1] & valid_mask[1:]

            if np.any(valid_pairs_mask):
                q1 = quat_values[:-1][valid_pairs_mask]
                q2 = quat_values[1:][valid_pairs_mask]

                q1_norms = quat_norms[:-1][valid_pairs_mask]
                q2_norms = quat_norms[1:][valid_pairs_mask]

                q1_norm = q1 / q1_norms[:, np.newaxis]
                q2_norm = q2 / q2_norms[:, np.newaxis]

                r1 = R.from_quat(q1_norm)
                r2 = R.from_quat(q2_norm)
                relative_rotation = r1.inv() * r2

                angular_dist[:-1][valid_pairs_mask] = np.linalg.norm(relative_rotation.as_rotvec(), axis=1)

            # num_samples >= 2 is guaranteed by the early return above.
            angular_dist[-1] = angular_dist[-2] if cumulative else 0

            if cumulative:
                angular_dist = np.cumsum(angular_dist)

            return angular_dist


        def append_per_sequence(frame, compute, columns):
            """Run ``compute`` on every sequence and append the result as ``columns``."""
            parts = [
                pd.DataFrame(compute(group), columns=columns, index=group.index)
                for _, group in tqdm(frame.groupby('sequence_id'))
            ]
            return pd.concat([frame, pd.concat(parts)], axis=1)


        def region_stats(sensor, tof_data, n_regions):
            """Mean/std/min/max over ``n_regions`` equal column slices of one sensor."""
            region_size = 64 // n_regions
            stats = {}
            for r in tqdm(range(n_regions)):
                region_data = tof_data.iloc[:, r*region_size : (r+1)*region_size]
                stats.update({
                    f'tof{n_regions}_{sensor}_region_{r}_mean': region_data.mean(axis=1),
                    f'tof{n_regions}_{sensor}_region_{r}_std': region_data.std(axis=1),
                    f'tof{n_regions}_{sensor}_region_{r}_min': region_data.min(axis=1),
                    f'tof{n_regions}_{sensor}_region_{r}_max': region_data.max(axis=1),
                })
            return stats


        acc_cols = ['acc_x', 'acc_y', 'acc_z']
        quat_cols = ['rot_x', 'rot_y', 'rot_z', 'rot_w']

        df = append_per_sequence(
            df,
            lambda g: remove_gravity_from_acc(g[acc_cols].values, g[quat_cols].values),
            ['f1', 'f2', 'f3'],
        )
        df = append_per_sequence(
            df,
            lambda g: calculate_angular_velocity_from_quat(g[quat_cols].values),
            ['f4', 'f5', 'f6'],
        )
        df = append_per_sequence(
            df,
            lambda g: calculate_angular_distance(g[quat_cols].values),
            ['f7'],
        )

        feature_names = ['f1', 'f2', 'f3', 'f4', 'f5', 'f6', 'f7']

        # Thermopile: fill gaps with the row mean of the five sensors (computed
        # once, before any column is rewritten), then scale by 1/5.
        thm_cols = ['thm_1', 'thm_2', 'thm_3', 'thm_4', 'thm_5']
        thm_mean = df[thm_cols].mean(1)

        for col in thm_cols:
            df[col] = (np.where(df[col].isna(), thm_mean, df[col]) ) / 5.

        # Collected in a dict and concatenated once at the end.
        new_columns = {}
        for i in range(1, 6):
            pixel_cols = [f"tof_{i}_v{p}" for p in range(64)]

            # Share of pixels equal to -1 and share of missing pixels, taken
            # before the -1 values are replaced below.
            new_columns.update({
                f'tof_{i}_isna_mean1': (df[pixel_cols]==-1).mean(axis=1),
                f'tof_{i}_isna_mean2': (df[pixel_cols].isna()).mean(axis=1),
            })

            df[pixel_cols] = df[pixel_cols].replace(-1, 255.) / 50.

            tof_data = df[pixel_cols]

            row_mean = tof_data.mean(axis=1)
            row_min = tof_data.min(axis=1)
            row_max = tof_data.max(axis=1)
            # Denominator floored at 1 for the normalised aggregates.
            mean_floor = row_mean.clip(1)

            new_columns.update({
                f'tof_{i}_mean': row_mean,
                f'tof_{i}_std': tof_data.std(axis=1),
                f'tof_{i}_min': row_min,
                f'tof_{i}_max': row_max,

                f'tof_{i}_median_norm': tof_data.median(axis=1)/mean_floor,
                # NOTE(review): the next two names are swapped relative to their
                # values ('max_norm' holds min/mean, 'min_norm' holds max/mean).
                # Left as-is on purpose: existing checkpoints were presumably
                # trained on these exact column contents - confirm before renaming.
                f'tof_{i}_max_norm': row_min/mean_floor,
                f'tof_{i}_min_norm': row_max/mean_floor,
            })
            if self.tof_mode > 1:
                new_columns.update(region_stats(i, tof_data, self.tof_mode))

            if self.tof_mode == -1:
                for mode in [2, 4, 8, 16, 32]:
                    new_columns.update(region_stats(i, tof_data, mode))
        df_tof = pd.DataFrame(new_columns)
        df = pd.concat([df, df_tof], axis=1)


        return df, feature_names, list(df_tof.columns)
        

    def get_model(self):
        """Load the saved inference artefacts and one network per selected fold.

        Reads ``feature_cols.npy``, ``sequence_maxlen.npy`` and
        ``gesture_classes.npy`` from ``self.PRETRAINED_DIR`` and stores them on
        the instance as ``feature_cols``, ``pad_len`` and ``gesture_classes``.
        The feature columns are split into ``tof_cols`` (names starting with
        ``thm_`` or ``tof_``) and ``imu_cols`` (every other column).

        For each fold index in ``range(5)`` that is listed in
        ``self.TRAIN_FOLDS``, the checkpoint ``gesture_two_branch_fold{i}.pth``
        is loaded, a network is built with ``self.create_model`` from the
        dimensions stored in that checkpoint, its weights are restored, and it
        is appended to ``self.models`` in eval mode.

        Returns:
            None. All results are stored as attributes on ``self``.
        """
        artefact_dir = self.PRETRAINED_DIR
        print("▶ INFERENCE MODE – loading artefacts from", artefact_dir)

        self.feature_cols = np.load(
            artefact_dir / "feature_cols.npy", allow_pickle=True
        ).tolist()
        self.pad_len = int(np.load(artefact_dir / "sequence_maxlen.npy"))
        self.gesture_classes = np.load(
            artefact_dir / "gesture_classes.npy", allow_pickle=True
        )

        # Columns with these prefixes go to the second group; the rest are IMU.
        sensor_prefixes = ('thm_', 'tof_')
        self.imu_cols = [
            col for col in self.feature_cols if not col.startswith(sensor_prefixes)
        ]
        self.tof_cols = [
            col for col in self.feature_cols if col.startswith(sensor_prefixes)
        ]

        # One checkpoint file per fold that was selected for training.
        checkpoint_names = [
            f'gesture_two_branch_fold{fold}.pth'
            for fold in range(5)
            if fold in self.TRAIN_FOLDS
        ]

        self.models = []
        for name in checkpoint_names:
            saved = torch.load(artefact_dir / name, map_location=self.device)

            net = self.create_model(
                saved['pad_len'],
                saved['imu_dim'],
                saved['tof_dim'],
                saved['n_classes'],
            )
            net = net.to(self.device)

            # Drop the '_orig_mod.' fragment from parameter names so they match
            # the freshly built network (presumably a torch.compile wrapper
            # prefix -- TODO confirm).
            weights = {
                key.replace('_orig_mod.', ''): tensor
                for key, tensor in saved['model_state_dict'].items()
            }
            net.load_state_dict(weights)
            net.eval()
            self.models.append(net)

        # NOTE: the message mentions a scaler, but this method does not load one.
        print("  model, scaler, pads loaded – ready for evaluation")

        return
    
    def predict(self, sequence: pl.DataFrame, demographics: pl.DataFrame, nanratio=0.):
        """Return the class labels and the model-averaged output for one sequence.

        The sequence is converted to pandas and passed through
        ``self.get_new_features``. For every network in ``self.models`` the
        frame is run through ``self.preprocess_sequence``, padded/truncated
        with ``self.pad_sequences_torch`` (``'pre'`` mode, length
        ``self.pad_len``), moved to ``self.device`` and fed to the network.
        The outputs are summed and divided by the number of networks.

        Args:
            sequence: Object exposing ``to_pandas()`` holding the sensor rows.
            demographics: Not used by this method.
            nanratio: Not used by this method.

        Returns:
            A tuple ``(self.gesture_classes, averaged_output)`` where the
            second element is a NumPy array.

        Raises:
            TypeError: If ``self.models`` is empty (nothing to average).
        """
        frame, _, _ = self.get_new_features(sequence.to_pandas())

        summed = None
        with torch.no_grad():
            for net in self.models:
                features = self.preprocess_sequence(frame, self.feature_cols, None)
                padded = self.pad_sequences_torch(
                    [features],
                    maxlen=self.pad_len,
                    padding='pre',
                    truncating='pre',
                )
                batch = torch.FloatTensor(padded).to(self.device)

                net.eval()
                pred = net(batch)
                if summed is None:
                    # The first output doubles as the running-sum buffer.
                    summed = pred
                else:
                    summed += pred
            # Turn the running sum into a mean over all networks.
            summed /= len(self.models)

        return self.gesture_classes, summed.cpu().numpy()

In [None]:
# For local training, first place all competition data under the
# ./cmi-detect-behavior-with-sensor-data/ folder.

# Build the model wrapper; the argument is presumably the directory holding the
# pretrained artefacts -- TODO confirm against the class constructor.
model_zhou_inference_v1_2 = model_zhou_v1('/kaggle/input/cmi3v38')

# Select all five folds (0-4) and turn the VALIDATION flag off, then call train().
# NOTE(review): the variable is named "inference" yet this cell calls train();
# what train() does with VALIDATION = False is not visible here -- confirm.
model_zhou_inference_v1_2.TRAIN_FOLDS = [0, 1, 2, 3, 4]
model_zhou_inference_v1_2.VALIDATION = False
model_zhou_inference_v1_2.train()

In [None]:
# for seed in [
#              # 42,
#              6665252, 
#              # 88885252, 
#              #12345252, 
#              #325252,
#              #885252,
#              #99995252,
#              #115252,
#             ]:
#     model_zhou_inference_v0 = model_zhou_v15(f'./{seed}', seed=seed, save_path=f'./{seed}')
    
#     model_zhou_inference_v0.TRAIN_FOLDS = [0, 1, 2, 3, 4]

#     model_zhou_inference_v0.VALIDATION = False
    
#     model_zhou_inference_v0.train()
