In [None]:
# Mount Google Drive at /content/drive; the following cells read the project
# code and the dataset from paths below that mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%cd /content/drive/MyDrive/PRIM/code/src
%run setup_imports.py
import os
import pandas as pd

/content/drive/MyDrive/PRIM/code/src
Preprocess module imported successfully!
model_layers module imported successfully!
losses module imported successfully!
split module imported successfully!


# 1. PREPROCESS



In [None]:
# Run every raw input XML through the `preprocess` pipeline and store the
# result as a CSV next to the other preprocessed inputs.
input_folder = "/content/drive/MyDrive/PRIM/dataset/inputs"
output_folder = "/content/drive/MyDrive/PRIM/dataset/inputs_preprocessed"
os.makedirs(output_folder, exist_ok=True)

for file_name in os.listdir(input_folder):
    # Only XML files are inputs; skip anything else found in the folder.
    if not file_name.endswith(".xml"):
        continue

    input_xml_path = os.path.join(input_folder, file_name)
    print(f"Processing: {input_xml_path}")

    # Pipeline order: load -> extract -> dataframe -> interpolate (with flag)
    # -> normalize. Each stage consumes the previous stage's result.
    df_normalized = preprocess.normalize(
        preprocess.interpolate_data_with_flag(
            preprocess.data_to_dataframe(
                preprocess.extract_data_from_input(
                    preprocess.load_xml(input_xml_path)
                )
            )
        )
    )

    output_csv_path = os.path.join(output_folder, file_name.replace(".xml", "_preprocessed.csv"))
    preprocess.save_data(df_normalized, output_csv_path)

    print(f"Saved preprocessed file to: {output_csv_path}")


In [None]:
# Preprocessing output data
#
# Each annotation XML is turned into a compact label table with the columns
# frame / stroke / player / type / role / impact:
#   stroke : 1 if either player hits the ball on that frame, else 0
#   player : 1 when 'player 1' is set, else 0
#   type   : 1 when 'backhand' is set, else 0
#   role   : 1 = serve, 2 = ball pass, 0 = neither
#   impact : 1 = point, 2 = mistake, 3 = void serve (incl. let serve), 0 = none
import numpy as np
input_folder = "/content/drive/MyDrive/PRIM/dataset/outputs"
output_folder = "/content/drive/MyDrive/PRIM/dataset/outputs_preprocessed"
os.makedirs(output_folder, exist_ok=True)
desired_columns = [
    'frame', 'stroke', 'player 1', 'player 2',
    'backhand', 'forehand', 'serve', 'ball pass',
    'point', 'mistake', 'void serve'
]

for file_name in os.listdir(input_folder):
    if file_name.endswith(".xml"):
        input_xml_path = os.path.join(input_folder, file_name)
        print(f"Processing: {input_xml_path}")
        tree = preprocess.load_xml(input_xml_path)
        extracted_data = preprocess.extract_data_from_output(tree)
        df = preprocess.data_to_dataframe(extracted_data)
        # Merge the two per-player flags into one binary stroke flag (capped at 1).
        df['stroke'] = np.where(df['player 1'] + df['player 2'] <= 1, df['player 1'] + df['player 2'], 1)
        # A let serve is folded into the 'void serve' flag (capped at 1).
        df['void serve'] = np.where(df['let serve'] + df['void serve'] <= 1, df['let serve'] + df['void serve'], 1)
        # .copy() makes the column subset an independent frame; without it the
        # assignments below write to a slice of `df` (SettingWithCopyWarning)
        # and are not guaranteed to take effect.
        df = df[desired_columns].copy()
        df['player'] = np.where(df['player 1'] == 1, 1, 0)
        df['type'] = np.where(df['backhand'] == 1, 1, 0)
        df['role'] = np.where(df['serve'] == 1, 1, np.where(df['ball pass'] == 1, 2, 0))
        df['impact'] = np.where(df['point'] == 1, 1, np.where(df['mistake'] == 1, 2, np.where(df['void serve'] == 1, 3, 0)))
        result = df[['frame', 'stroke', 'player', 'type', 'role', 'impact']]
        output_csv_path = os.path.join(output_folder, file_name.replace(".xml", "_preprocessed.csv"))
        preprocess.save_data(result, output_csv_path)
        print(f"Saved preprocessed file to: {output_csv_path}")


# 2. LOADING

In [None]:
# Collect the full paths of all regular files (sub-directories are ignored)
# in the preprocessed input and output folders. The lists are in
# os.listdir order here; they are sorted in the next cell.
feature_dir = "/content/drive/MyDrive/PRIM/dataset/inputs_preprocessed"
target_dir = "/content/drive/MyDrive/PRIM/dataset/outputs_preprocessed"

feature_files = list(filter(
    os.path.isfile,
    (os.path.join(feature_dir, name) for name in os.listdir(feature_dir)),
))

target_files = list(filter(
    os.path.isfile,
    (os.path.join(target_dir, name) for name in os.listdir(target_dir)),
))



In [None]:
# Sort both lists so that feature file i and target file i are paired by
# position (presumably their names sort identically — TODO confirm).
feature_files.sort()
target_files.sort()
# 90% train / 10% validation / no test set; the per-file lengths are passed
# to proportional_split together with the file lists.
train_features, train_targets, val_features, val_targets, test_features, test_targets = proportional_split(feature_files, target_files, calculate_file_lengths(feature_files), train_ratio=0.9, val_ratio=0.1, test_ratio=0)
# For each split: summed file lengths, then number of files.
print("Train Features:", sum(calculate_file_lengths(train_features)), len(train_features))
print("Validation Features:", sum(calculate_file_lengths(val_features)), len(val_features))
print("Test Features:", sum(calculate_file_lengths(test_features)), len(test_features))


Train Features: 565060 27
Validation Features: 85558 4
Test Features: 0 0


In [None]:
# Build the chunked datasets and their loaders.
from dataset import TimeSeriesDataset
from torch.utils.data import DataLoader

chunk_length = 128
overlap = 112
# Training chunks: no overlap, augmentation always on (augment_prob=1).
train_dataset = TimeSeriesDataset(train_features, train_targets, chunk_length=chunk_length, overlap=0, augment= True, augment_prob=1)
# Validation chunks overlap by `overlap` frames; no augmentation arguments are passed.
val_dataset = TimeSeriesDataset(val_features, val_targets, chunk_length=chunk_length, overlap=overlap)
# NOTE(review): the training loader is not shuffled — confirm this is intentional.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [None]:
# Sanity check: print the tensor shapes of the first training batch only.
for batch_idx, (features, targets) in enumerate(train_dataloader):
  print(features.shape)
  print(targets.shape)
  break


torch.Size([16, 128, 116])
torch.Size([16, 128, 5])


In [None]:
# Authenticate with Weights & Biases.
# The API key must never be hard-coded in the notebook: it is read from the
# WANDB_API_KEY environment variable instead. When the variable is unset,
# key=None lets wandb fall back to its own credential lookup / prompt.
# NOTE: the key that used to be committed here must be revoked and rotated.
import os
import wandb
wandb.login(key=os.environ.get("WANDB_API_KEY"))


# 2. Experiment n°1: regular_loss + baseline model

In [None]:
#baseline model

class TransformerClassifier(nn.Module):
    """Baseline multi-task model: a shared encoder followed by five linear heads.

    The encoder output (last dimension ``hidden_dim``) is fed to one
    single-layer head per task:

    * stroke -> 1 logit
    * player -> 2 logits
    * type   -> 2 logits
    * role   -> 3 logits
    * impact -> 4 logits

    Args:
        input_dim: forwarded to ``TransformerEncoder``.
        num_heads: forwarded to ``TransformerEncoder``.
        hidden_dim: forwarded to ``TransformerEncoder``; also the input width
            of every head.
        num_layers: forwarded to ``TransformerEncoder``.
        dropout: forwarded to ``TransformerEncoder`` (default 0.1).
        max_len: forwarded to ``TransformerEncoder`` (default 1024).
    """

    def __init__(self, input_dim, num_heads, hidden_dim, num_layers, dropout=0.1,max_len=1024):
        super().__init__()

        # Shared feature extractor.
        self.encoder = TransformerEncoder(
            input_dim=input_dim,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            dropout=dropout,
            max_len=max_len,
        )

        # One linear output layer per task. Each stays wrapped in
        # nn.Sequential so parameter names keep the "<head>.0.*" form, and the
        # heads are created in the original order.
        self.stroke_head = nn.Sequential(nn.Linear(hidden_dim, 1))   # single logit
        self.player_head = nn.Sequential(nn.Linear(hidden_dim, 2))   # 2 classes
        self.type_head = nn.Sequential(nn.Linear(hidden_dim, 2))     # 2 classes
        self.role_head = nn.Sequential(nn.Linear(hidden_dim, 3))     # 3 classes
        self.impact_head = nn.Sequential(nn.Linear(hidden_dim, 4))   # 4 classes

    def forward(self, x):
        """Return ``(stroke, player, type, role, impact)`` head outputs for ``x``."""
        shared = self.encoder(x)
        task_heads = (
            self.stroke_head,
            self.player_head,
            self.type_head,
            self.role_head,
            self.impact_head,
        )
        return tuple(head(shared) for head in task_heads)


In [None]:
# Experiment 1: baseline TransformerClassifier trained with `regular_loss`.
from train import train
# NOTE(review): input_dim is 124 here, while the batch printed earlier in the
# notebook shows 116 features per frame — confirm which one is current.
model = TransformerClassifier(
    input_dim=124,
    num_heads=3,
    hidden_dim=192,
    num_layers=4,
    dropout=0.1,
    max_len=128
)

# AdamW with weight decay; the plateau scheduler multiplies the LR by 0.5
# (patience=2, mode='min'). The monitored value is supplied inside `train`.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)
loss_fn = regular_loss

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Train and validate
train(
    model=model,
    params=model.parameters(),
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    num_epochs=100,
    device=device,
    scheduler=scheduler,
    overlap=overlap,
    chunk_length=chunk_length
)


NameError: name 'TransformerClassifier' is not defined

In [None]:
# Count only the parameters that receive gradients (requires_grad=True).
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {total_params}")


Total trainable parameters: 1805772


# 3. Experiment n°2: weighted loss + baseline model

In [None]:
# Experiment 2: same baseline model and optimizer settings as experiment 1,
# but trained with `weighted_loss`.
from train import train

model = TransformerClassifier(
    input_dim=124,
    num_heads=3,
    hidden_dim=192,
    num_layers=4,
    dropout=0.1,
    max_len=128
)

# AdamW with weight decay; plateau scheduler multiplies the LR by 0.5 (patience=2).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)
loss_fn = weighted_loss

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Train and validate
train(
    model=model,
    params=model.parameters(),
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    num_epochs=100,
    device=device,
    scheduler=scheduler,
    overlap=overlap,
    chunk_length=chunk_length
)


# 4. Experiment n°3: multivariate weighted loss + baseline model

In [None]:
# Experiment 3: baseline model trained with `multivariate_weighted_loss`,
# whose five per-task weights (`log_vars`) are learned jointly with the model.
from train import train

model = TransformerClassifier(
    input_dim=124,
    num_heads=3,
    hidden_dim=192,
    num_layers=4,
    dropout=0.1,
    max_len=128
)

# One learnable scalar per task, initialised to 0. It is appended to the
# parameter list so the optimizer updates it together with the model weights.
log_vars = nn.Parameter(torch.zeros(5))
params = list(model.parameters()) + [log_vars]
# Plain Adam (no weight decay) in this experiment, unlike experiments 1 and 2.
optimizer = torch.optim.Adam(params, lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)
loss_fn = multivariate_weighted_loss

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `alpha` hands the learnable task weights to train (presumably forwarded to
# loss_fn as its third argument — TODO confirm in train.py).
train(
    model=model,
    params=params,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    num_epochs=100,
    device=device,
    alpha=log_vars,
    scheduler=scheduler,
    overlap=overlap,
    chunk_length=chunk_length
)



# Experiment: Multivariate Weighted Loss + Undersampling Majority Class

In [None]:
# Undersampling Majority class
#
# For every preprocessed file, keep only the frames within a window of
# `n_before` frames before and `n_after` frames after each stroke frame
# (stroke == 1) and write the reduced input/output pair to the *_train folders.

input_folder_train = '/content/drive/MyDrive/PRIM/dataset/inputs_preprocessed_train'
output_folder_train = '/content/drive/MyDrive/PRIM/dataset/outputs_preprocessed_train'
input_folder = '/content/drive/MyDrive/PRIM/dataset/inputs_preprocessed'
output_folder = '/content/drive/MyDrive/PRIM/dataset/outputs_preprocessed'

# Create the destination folders (the ones written to below). The source
# folders are only read and are expected to exist already.
os.makedirs(input_folder_train, exist_ok=True)
os.makedirs(output_folder_train, exist_ok=True)

# Half-widths of the window kept around each stroke, in frames.
n_before = 20
n_after = 20

for file_name in os.listdir(input_folder):
    # Input and output files share the same file name in their respective folders.
    input_csv_path = os.path.join(input_folder, file_name)
    output_csv_path = os.path.join(output_folder, file_name)

    df_input = pd.read_csv(input_csv_path)
    df_output = pd.read_csv(output_csv_path)

    stroke_indices = df_output[df_output["stroke"] == 1].index

    # Union of [idx - n_before, idx + n_after] over all strokes, clipped to the
    # valid row range; the set removes rows shared by overlapping windows.
    indices_to_keep = set()
    for idx in stroke_indices:
        window_start = max(0, idx - n_before)
        window_stop = min(len(df_output), idx + 1 + n_after)
        indices_to_keep.update(range(window_start, window_stop))
    indices_to_keep = sorted(indices_to_keep)

    reduced_df_output = df_output.iloc[indices_to_keep]
    reduced_df_input = df_input.iloc[indices_to_keep]

    # Drop rows whose 'frame' value repeats in the outputs, and drop the SAME
    # row positions from the inputs so features and labels stay aligned.
    keep_rows = ~reduced_df_output.duplicated(subset=['frame']).to_numpy()
    reduced_df_output = reduced_df_output[keep_rows].reset_index(drop=True)
    reduced_df_input = reduced_df_input[keep_rows].reset_index(drop=True)

    if reduced_df_input.isna().sum().sum() > 0:
        print(f"NaN values detected in reduced input file {file_name}. Skipping...")
        continue

    if reduced_df_output.isna().sum().sum() > 0:
        print(f"NaN values detected in reduced output file {file_name}. Skipping...")
        continue

    print(f"Reduced Output DataFrame shape: {reduced_df_output.shape}")
    print(f"Reduced Input DataFrame shape: {reduced_df_input.shape}")

    input_csv_path = os.path.join(input_folder_train, file_name)
    output_csv_path = os.path.join(output_folder_train, file_name)

    preprocess.save_data(reduced_df_input, input_csv_path)
    preprocess.save_data(reduced_df_output, output_csv_path)

    print(f"Saved preprocessed output file to: {output_csv_path}")
    print(f"Saved preprocessed input file to: {input_csv_path}")


In [None]:
# Training files come from the undersampled (*_train) folders, validation
# files from the full preprocessed folders. Only regular files are listed;
# the lists are in os.listdir order here and sorted in the next cell.
feature_dir_train = "/content/drive/MyDrive/PRIM/dataset/inputs_preprocessed_train"
target_dir_train = "/content/drive/MyDrive/PRIM/dataset/outputs_preprocessed_train"

feature_dir_val = "/content/drive/MyDrive/PRIM/dataset/inputs_preprocessed"
target_dir_val = "/content/drive/MyDrive/PRIM/dataset/outputs_preprocessed"

feature_files_train = list(filter(
    os.path.isfile,
    (os.path.join(feature_dir_train, name) for name in os.listdir(feature_dir_train)),
))

target_files_train = list(filter(
    os.path.isfile,
    (os.path.join(target_dir_train, name) for name in os.listdir(target_dir_train)),
))

feature_files_val = list(filter(
    os.path.isfile,
    (os.path.join(feature_dir_val, name) for name in os.listdir(feature_dir_val)),
))

target_files_val = list(filter(
    os.path.isfile,
    (os.path.join(target_dir_val, name) for name in os.listdir(target_dir_val)),
))

In [None]:
# Sort all four lists so features/targets pair up by position.
feature_files_train.sort()
target_files_train.sort()
feature_files_val.sort()
target_files_val.sort()

# Train on the first 24 undersampled files; validate on the single full file
# at sorted position 24.
# NOTE(review): the undersampling cell skips files containing NaNs, so the
# *_train folders may hold fewer files than the full folders. In that case
# position k in the two sorted lists is not the same recording — verify that
# the validation file is not among the 24 training files.
train_features, train_targets, val_features, val_targets = feature_files_train[:24], target_files_train[:24], feature_files_val[24:25], target_files_val[24:25]
print("Train Features:", sum(calculate_file_lengths(train_features)), len(train_features))
print("Validation Features:", sum(calculate_file_lengths(val_features)), len(val_features))


Train Features: 128552 24
Validation Features: 43614 1


In [None]:
def multivariate_weighted_loss(y_pred, y_true, log_vars):
    """Five-task loss with fixed class weights and learnable per-task weights.

    y_pred   : tuple (stroke, player, type, role, impact) of logits.
    y_true   : tensor indexed as [:, :, k] with one-hot style columns:
               0 = stroke flag (-1 marks frames to ignore), 1:3 = player,
               3:5 = type, 5 and 6 = first two role classes, 7: = first
               impact classes. The last role / impact class is reconstructed
               as "1 minus the others".
    log_vars : indexable with 5 learnable scalars s_i; each task loss L_i
               contributes ``L_i / (2 * exp(s_i)) + s_i`` to the total.

    The stroke loss is computed on every frame whose stroke flag is not -1;
    the four other losses are computed on stroke frames only (flag == 1).
    A task without any usable frame contributes a loss of 0.
    """
    stroke_logits, player_logits, type_logits, role_logits, impact_logits = y_pred

    # --- split the ground truth into per-task targets ---------------------
    stroke_true = y_true[:, :, 0].unsqueeze(-1)
    player_true = y_true[:, :, 1:3]
    type_true = y_true[:, :, 3:5]
    role_first, role_second = y_true[:, :, 5], y_true[:, :, 6]
    role_true = torch.stack(
        [role_first, role_second, torch.ones_like(role_first) - role_first - role_second],
        dim=-1,
    )
    impact_given = y_true[:, :, 7:]
    impact_true = torch.cat(
        [impact_given, 1 - torch.sum(impact_given, dim=-1, keepdim=True)],
        dim=-1,
    )

    # --- stroke detection: weighted BCE on all non-ignored frames ----------
    valid_frames = (stroke_true != -1).squeeze(-1)
    # Weights are calculated based on all data using the following formula
    # w_i = nb_total_samples / (nb_classes * nb_samples_i)
    stroke_weight = torch.zeros_like(stroke_true)
    stroke_weight[stroke_true == 1] = 10.57
    stroke_weight[stroke_true == 0] = 0.52
    stroke_weight = stroke_weight[valid_frames].squeeze(-1)

    stroke_criterion = nn.BCEWithLogitsLoss(weight=stroke_weight)
    if valid_frames.sum() > 0:
        loss_stroke = stroke_criterion(
            stroke_logits[valid_frames].squeeze(-1),
            stroke_true[valid_frames].squeeze(-1),
        )
    else:
        loss_stroke = torch.tensor(0.0, device=stroke_logits.device)

    # --- attribute tasks: only evaluated on stroke frames ------------------
    stroke_frames = (stroke_true == 1).squeeze(-1)

    # Player and type are almost balanced, so no class weights are used.
    if stroke_frames.sum() > 0:
        loss_player = F.cross_entropy(player_logits[stroke_frames], player_true[stroke_frames])
    else:
        loss_player = torch.tensor(0.0, device=player_logits.device)

    if stroke_frames.sum() > 0:
        loss_type = F.cross_entropy(type_logits[stroke_frames], type_true[stroke_frames])
    else:
        loss_type = torch.tensor(0.0, device=type_logits.device)

    role_class_weights = torch.tensor([1.63, 3.41, 0.47], dtype=torch.float32, device=role_logits.device)
    if stroke_frames.sum() > 0:
        loss_role = F.cross_entropy(
            role_logits[stroke_frames], role_true[stroke_frames], weight=role_class_weights
        )
    else:
        loss_role = torch.tensor(0.0, device=role_logits.device)

    # The impact weights (and the impact fallback) live on the role logits' device.
    impact_class_weights = torch.tensor([8.71, 1.42, 34.7, 0.31], dtype=torch.float32, device=role_logits.device)
    if stroke_frames.sum() > 0:
        loss_impact = F.cross_entropy(
            impact_logits[stroke_frames], impact_true[stroke_frames], weight=impact_class_weights
        )
    else:
        loss_impact = torch.tensor(0.0, device=role_logits.device)

    # --- combine: L_i / (2 * exp(s_i)) + s_i, accumulated task by task -----
    task_losses = (loss_stroke, loss_player, loss_type, loss_role, loss_impact)
    total_loss = (1 / (2 * torch.exp(log_vars[0]))) * task_losses[0] + log_vars[0]
    for task_idx in range(1, len(task_losses)):
        total_loss = (
            total_loss
            + (1 / (2 * torch.exp(log_vars[task_idx]))) * task_losses[task_idx]
            + log_vars[task_idx]
        )
    return total_loss


In [None]:
# Datasets/loaders for the undersampling experiment.
from dataset import TimeSeriesDataset
from torch.utils.data import DataLoader

chunk_length = 128
overlap = 112
# No augmentation arguments are passed for the training set in this experiment.
train_dataset = TimeSeriesDataset(train_features, train_targets, chunk_length=chunk_length, overlap=0)
val_dataset = TimeSeriesDataset(val_features, val_targets, chunk_length=chunk_length, overlap=overlap)
# NOTE(review): the training loader is not shuffled — confirm this is intentional.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [None]:
# Undersampling experiment: larger model (8 heads, hidden_dim=256) trained on
# the reduced training files with learnable per-task loss weights.
from train import train


model = TransformerClassifier(
    input_dim=124,
    num_heads=8,
    hidden_dim=256,
    num_layers=4,
    dropout=0.1,
    max_len=128
)

# One learnable scalar per task, optimised together with the model weights.
log_vars = nn.Parameter(torch.zeros(5))
params = list(model.parameters()) + [log_vars]
optimizer = torch.optim.Adam(params, lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)
# Resolves to the most recently executed definition of this name; a cell in
# this section defines a function called multivariate_weighted_loss.
loss_fn = multivariate_weighted_loss

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train(
    model=model,
    params=params,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    num_epochs=100,
    device=device,
    alpha=log_vars,
    scheduler=scheduler,
    overlap=overlap,
    chunk_length=chunk_length
)



# Model Hyperparameter Tuning / Testing

In [None]:
# Re-split the full (non-undersampled) file lists: 80% train / 10% validation
# / 10% test. This overwrites the train_*/val_* variables of earlier cells.
feature_files.sort()
target_files.sort()
train_features, train_targets, val_features, val_targets, test_features, test_targets = proportional_split(feature_files, target_files, calculate_file_lengths(feature_files), train_ratio=0.8, val_ratio=0.1, test_ratio=0.1)
# For each split: summed file lengths, then number of files.
print("Train Features:", sum(calculate_file_lengths(train_features)), len(train_features))
print("Validation Features:", sum(calculate_file_lengths(val_features)), len(val_features))
print("Test Features:", sum(calculate_file_lengths(test_features)), len(test_features))


Train Features: 487571 25
Validation Features: 77489 2
Test Features: 85558 4


In [None]:
# Datasets/loaders for the tuning/testing section, including a test loader.
from dataset import TimeSeriesDataset
from torch.utils.data import DataLoader

chunk_length = 128
overlap = 112
# Training chunks: no overlap, augmentation always on (augment_prob=1).
train_dataset = TimeSeriesDataset(train_features, train_targets, chunk_length=chunk_length, overlap=0, augment= True, augment_prob=1)
# Validation and test chunks overlap by `overlap` frames; no augmentation arguments are passed.
val_dataset = TimeSeriesDataset(val_features, val_targets, chunk_length=chunk_length, overlap=overlap)
test_dataset = TimeSeriesDataset(test_features, test_targets, chunk_length=chunk_length, overlap=overlap)
# NOTE(review): the training loader is not shuffled — confirm this is intentional.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """
    Focal Loss for Binary and Multi-Class Classification.

    The per-element loss is ``alpha_t * (1 - p_t) ** gamma * CE``.

    Args:
        gamma (float): Focusing parameter to down-weight easy examples. Default: 2.
        alpha (float, list, tuple, or None): Per-class balancing weights, looked
            up with the target class index.
            - If float/int ``a``: binary weights ``[a, 1 - a]``, i.e. ``a`` is
              applied to class 0 (negatives) and ``1 - a`` to class 1 (positives).
            - If list/tuple: one weight per class.
            - If None, no weighting is applied. Default: None.
        reduction (str): Specifies the reduction to apply to the output:
            - 'none': No reduction.
            - 'mean': Average over all loss elements (0 for empty input).
            - 'sum': Sum over all loss elements. Default: 'mean'.
    """
    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction
        # Normalise alpha to a 1-D tensor of per-class weights (or None).
        if isinstance(alpha, (float, int)):
            self.alpha = torch.tensor([alpha, 1 - alpha])
        elif isinstance(alpha, (list, tuple)):
            self.alpha = torch.tensor(alpha)
        else:
            self.alpha = None

    def forward(self, logits, targets):
        """
        Args:
            logits (Tensor): Raw scores.
                - Binary: same shape as ``targets``; a trailing singleton
                  channel, e.g. (N, 1) against targets of shape (N,), is
                  squeezed automatically.
                - Multi-class: (N, C, ...) against class-index targets (N, ...).
            targets (Tensor): Ground truth. 0/1 values for binary, integer
                class indices for multi-class (any numeric dtype).
        Returns:
            Tensor: Focal loss on the device of ``logits``. With
            ``reduction='none'`` the shape is that of ``targets`` (binary) or
            of ``logits`` (multi-class, zero at non-target classes).
        """
        # Drop a trailing singleton channel so that (N, 1) logits and (N,)
        # targets do not silently broadcast to an (N, N) loss matrix.
        if (logits.dim() == targets.dim() + 1
                and logits.size(-1) == 1
                and logits.shape[:-1] == targets.shape):
            logits = logits.squeeze(-1)

        class_index = targets.long()
        is_binary = logits.shape == targets.shape

        if is_binary:
            probs = torch.sigmoid(logits)
            y = targets.to(probs.dtype)
            ce_loss = (-y * torch.log(probs.clamp(min=1e-6))
                       - (1 - y) * torch.log((1 - probs).clamp(min=1e-6)))
            p_t = probs * y + (1 - probs) * (1 - y)
        else:
            # Softmax across the class dimension (dim 1).
            probs = F.softmax(logits, dim=1)
            # F.one_hot puts the class axis last; move it to dim 1 so it lines
            # up with `probs` for inputs of any rank.
            targets_one_hot = F.one_hot(class_index, num_classes=logits.size(1))
            targets_one_hot = targets_one_hot.movedim(-1, 1).to(probs.dtype)
            ce_loss = -targets_one_hot * torch.log(probs.clamp(min=1e-6))
            p_t = probs * targets_one_hot + (1 - probs) * (1 - targets_one_hot)

        # Focal modulation: confident, correct predictions are down-weighted.
        loss = (1 - p_t) ** self.gamma * ce_loss

        # Per-class balancing; alpha is moved to the loss' device so the
        # result is not dragged onto the CPU.
        if self.alpha is not None:
            alpha_t = self.alpha.to(device=loss.device, dtype=loss.dtype)[class_index]
            if not is_binary:
                # (N, ...) -> (N, 1, ...) to broadcast over the class axis.
                alpha_t = alpha_t.unsqueeze(1)
            loss = alpha_t * loss

        # Reduction
        if self.reduction == 'mean':
            # mean() of an empty tensor is NaN; return a 0 that keeps the graph.
            return loss.mean() if loss.numel() > 0 else loss.sum()
        elif self.reduction == 'sum':
            return loss.sum()
        else:  # 'none'
            return loss


In [None]:
# dynamic class weighting
import torch
import torch.nn as nn
import torch.nn.functional as F

def loss(y_pred, y_true, log_vars):
    """Multi-task loss with per-batch (dynamic) class weighting.

    Args:
        y_pred: tuple ``(stroke, player, type, role, impact)`` of logits. The
            stroke logits may carry a trailing singleton channel.
        y_true: tensor indexed as ``[:, :, k]`` with k = 0 stroke flag
            (-1 marks frames to ignore), 1 player, 2 type, 3 role, 4 impact,
            the last four being class indices. Float or integer dtype.
        log_vars: indexable with 5 learnable scalars ``s_i``; each task loss
            ``L_i`` contributes ``L_i / (2 * exp(s_i)) + s_i``.

    Returns:
        Scalar tensor with the combined loss.

    The stroke loss (BCE with ``pos_weight = #valid frames / #stroke frames``
    of the batch) uses every non-ignored frame. The other four losses use
    stroke frames only; role and impact are weighted by the inverse class
    frequency within the batch. A task without usable frames contributes 0.
    """
    y_pred_stroke, y_pred_player, y_pred_type, y_pred_role, y_pred_impact = y_pred
    y_true_stroke = y_true[:, :, 0]
    # cross_entropy / bincount require integer class indices, whatever the
    # dtype of the label tensor coming from the data loader.
    y_true_player = y_true[:, :, 1].long()
    y_true_type = y_true[:, :, 2].long()
    y_true_role = y_true[:, :, 3].long()
    y_true_impact = y_true[:, :, 4].long()

    # The stroke head emits one logit per frame in a trailing channel; drop it
    # so logits and targets have identical shapes, as BCE requires.
    if y_pred_stroke.dim() == y_true_stroke.dim() + 1:
        y_pred_stroke = y_pred_stroke.squeeze(-1)

    def _balanced_cross_entropy(logits, class_targets):
        # Cross-entropy weighted by total / count per class, computed on this
        # batch; classes absent from the batch get a neutral weight of 1.0.
        counts = torch.bincount(class_targets.flatten(), minlength=logits.size(-1))
        weights = (counts.sum() / (counts + 1e-6)).to(logits.device)
        weights[counts == 0] = 1.0
        return F.cross_entropy(logits, class_targets, weight=weights)

    # --- stroke detection on all non-ignored frames ------------------------
    mask_stroke = y_true_stroke != -1
    if mask_stroke.sum() > 0:
        # BCE needs floating-point targets.
        stroke_targets = y_true_stroke[mask_stroke].to(y_pred_stroke.dtype)
        num_strokes = (stroke_targets == 1).sum()
        if num_strokes > 0:
            pos_weight_stroke = (mask_stroke.sum() / (num_strokes + 1e-6)).to(y_true.device)
        else:
            pos_weight_stroke = torch.tensor(1.0, device=y_true.device)
        loss_stroke = F.binary_cross_entropy_with_logits(
            y_pred_stroke[mask_stroke], stroke_targets, pos_weight=pos_weight_stroke
        )
    else:
        loss_stroke = torch.tensor(0.0, device=y_pred_stroke.device)

    # --- attribute tasks on stroke frames only -----------------------------
    mask = y_true_stroke == 1
    if mask.sum() > 0:
        loss_player = F.cross_entropy(y_pred_player[mask], y_true_player[mask])
        loss_type = F.cross_entropy(y_pred_type[mask], y_true_type[mask])
        loss_role = _balanced_cross_entropy(y_pred_role[mask], y_true_role[mask])
        loss_impact = _balanced_cross_entropy(y_pred_impact[mask], y_true_impact[mask])
    else:
        loss_player = torch.tensor(0.0, device=y_pred_player.device)
        loss_type = torch.tensor(0.0, device=y_pred_type.device)
        loss_role = torch.tensor(0.0, device=y_pred_role.device)
        loss_impact = torch.tensor(0.0, device=y_pred_impact.device)

    # --- combine with the learnable per-task weights -----------------------
    total_loss = (
        (1 / (2 * torch.exp(log_vars[0]))) * loss_stroke + log_vars[0] +
        (1 / (2 * torch.exp(log_vars[1]))) * loss_player + log_vars[1] +
        (1 / (2 * torch.exp(log_vars[2]))) * loss_type + log_vars[2] +
        (1 / (2 * torch.exp(log_vars[3]))) * loss_role + log_vars[3] +
        (1 / (2 * torch.exp(log_vars[4]))) * loss_impact + log_vars[4]
    )

    return total_loss


In [None]:
# dynamic class weighting
import torch
import torch.nn as nn
import torch.nn.functional as F

def loss(y_pred, y_true, log_vars):
    """Multi-task loss: focal loss for strokes, dynamic class weights elsewhere.

    This definition replaces any earlier function named ``loss`` in the
    notebook namespace once its cell has run.

    Args:
        y_pred: tuple ``(stroke, player, type, role, impact)`` of logits. The
            stroke logits may carry a trailing singleton channel.
        y_true: tensor indexed as ``[:, :, k]`` with k = 0 stroke flag
            (-1 marks frames to ignore), 1 player, 2 type, 3 role, 4 impact,
            the last four being class indices. Float or integer dtype.
        log_vars: indexable with 5 learnable scalars ``s_i``; each task loss
            ``L_i`` contributes ``L_i / (2 * exp(s_i)) + s_i``.

    Returns:
        Scalar tensor with the combined loss.

    The stroke loss is ``FocalLoss(gamma=2, alpha=0.01)`` over every
    non-ignored frame. The other four losses use stroke frames only; role and
    impact are weighted by the inverse class frequency within the batch.
    A task without usable frames contributes 0.
    """
    y_pred_stroke, y_pred_player, y_pred_type, y_pred_role, y_pred_impact = y_pred
    y_true_stroke = y_true[:, :, 0]
    # cross_entropy / bincount require integer class indices, whatever the
    # dtype of the label tensor coming from the data loader.
    y_true_player = y_true[:, :, 1].long()
    y_true_type = y_true[:, :, 2].long()
    y_true_role = y_true[:, :, 3].long()
    y_true_impact = y_true[:, :, 4].long()

    # The stroke head emits one logit per frame in a trailing channel; drop it
    # so logits and targets have identical shapes and cannot broadcast
    # against each other inside the focal loss.
    if y_pred_stroke.dim() == y_true_stroke.dim() + 1:
        y_pred_stroke = y_pred_stroke.squeeze(-1)

    def _balanced_cross_entropy(logits, class_targets):
        # Cross-entropy weighted by total / count per class, computed on this
        # batch; classes absent from the batch get a neutral weight of 1.0.
        counts = torch.bincount(class_targets.flatten(), minlength=logits.size(-1))
        weights = (counts.sum() / (counts + 1e-6)).to(logits.device)
        weights[counts == 0] = 1.0
        return F.cross_entropy(logits, class_targets, weight=weights)

    # --- stroke detection on all non-ignored frames ------------------------
    mask_stroke = y_true_stroke != -1
    if mask_stroke.sum() > 0:
        focal_loss = FocalLoss(gamma=2,alpha=0.01)
        loss_stroke = focal_loss(
            y_pred_stroke[mask_stroke],
            y_true_stroke[mask_stroke].to(y_pred_stroke.dtype),
        )
    else:
        # Averaging over zero frames would yield NaN.
        loss_stroke = torch.tensor(0.0, device=y_pred_stroke.device)

    # --- attribute tasks on stroke frames only -----------------------------
    mask = y_true_stroke == 1
    if mask.sum() > 0:
        loss_player = F.cross_entropy(y_pred_player[mask], y_true_player[mask])
        loss_type = F.cross_entropy(y_pred_type[mask], y_true_type[mask])
        loss_role = _balanced_cross_entropy(y_pred_role[mask], y_true_role[mask])
        loss_impact = _balanced_cross_entropy(y_pred_impact[mask], y_true_impact[mask])
    else:
        loss_player = torch.tensor(0.0, device=y_pred_player.device)
        loss_type = torch.tensor(0.0, device=y_pred_type.device)
        loss_role = torch.tensor(0.0, device=y_pred_role.device)
        loss_impact = torch.tensor(0.0, device=y_pred_impact.device)

    # --- combine with the learnable per-task weights -----------------------
    total_loss = (
        (1 / (2 * torch.exp(log_vars[0]))) * loss_stroke + log_vars[0] +
        (1 / (2 * torch.exp(log_vars[1]))) * loss_player + log_vars[1] +
        (1 / (2 * torch.exp(log_vars[2]))) * loss_type + log_vars[2] +
        (1 / (2 * torch.exp(log_vars[3]))) * loss_role + log_vars[3] +
        (1 / (2 * torch.exp(log_vars[4]))) * loss_impact + log_vars[4]
    )

    return total_loss


In [None]:
# Tuning run: 8 heads, hidden_dim=256, dropout 0.2, 200 epochs, learnable
# per-task loss weights.
from train import train
model = TransformerClassifier(
    input_dim=124,
    num_heads=8,
    hidden_dim=256,
    num_layers=4,
    dropout=0.2,
    max_len=128
)

# One learnable scalar per task, optimised together with the model weights.
log_vars = nn.Parameter(torch.zeros(5))
params = list(model.parameters()) + [log_vars]
optimizer = torch.optim.AdamW(params, lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)
# Two cells in this section define a function named `loss`; this binds
# whichever of them was executed last.
loss_fn = loss

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train(
    model=model,
    params=params,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    num_epochs=200,
    device=device,
    alpha=log_vars,
    scheduler=scheduler,
    overlap=overlap,
    chunk_length=chunk_length
)

In [None]:
# Count only the parameters that receive gradients (requires_grad=True).
# `log_vars` is not part of `model`, so it is not included.
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {total_params}")


Total trainable parameters: 3406158


# 5. Experiment n°4: weighted loss + CNN

In [None]:
import torch
import torch.nn as nn

class ImprovedCNNFeatureExtractor(nn.Module):
    """Three-layer 1-D convolutional feature extractor.

    Every layer is ``Conv1d(kernel_size=3, stride=1, padding=1)`` followed by
    ``BatchNorm1d`` and ``ReLU``, so the sequence length is preserved while
    the channel count goes ``input_channels -> 64 -> 128 -> output_dim``.

    Args:
        input_channels: number of features per time step in the input.
        output_dim: number of channels produced by the last convolution.
    """

    def __init__(self, input_channels, output_dim):
        super().__init__()
        self.cnn = nn.Sequential(
            # Layer 1
            nn.Conv1d(in_channels=input_channels, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),

            # Layer 2
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),

            # Layer 3 (Project to output_dim)
            nn.Conv1d(in_channels=128, out_channels=output_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(output_dim),
            nn.ReLU()
        )

    # The class used to define `forward` twice; Python kept only the second
    # definition. The shadowed one (which also transposed the result back to
    # channels-last) was dead code and has been removed, leaving the effective
    # behaviour unchanged.
    def forward(self, x):
        """Map (batch_size, seq_length, num_features) to (batch_size, output_dim, seq_length).

        The output is channels-first; callers that need
        (batch_size, seq_length, output_dim) must permute it themselves.
        """
        # Conv1d expects channels before the sequence axis.
        x = x.permute(0, 2, 1)
        x = self.cnn(x)
        return x


class StrongHybridModel(nn.Module):
    """CNN front-end followed by the multi-task TransformerClassifier.

    Args:
        cnn_input_dim: number of input features per time step.
        cnn_output_dim: number of channels produced by the CNN; also the
            ``input_dim`` of the transformer.
        transformer_params: dict of remaining ``TransformerClassifier``
            keyword arguments (e.g. num_heads, hidden_dim, num_layers).
    """

    def __init__(self, cnn_input_dim, cnn_output_dim, transformer_params):
        super().__init__()
        self.cnn = ImprovedCNNFeatureExtractor(
            input_channels=cnn_input_dim,
            output_dim=cnn_output_dim,
        )
        self.transformer = TransformerClassifier(input_dim=cnn_output_dim, **transformer_params)

    def forward(self, x):
        """Return the transformer's task outputs for ``x``."""
        # The CNN returns channels-first features: (batch, cnn_output_dim, seq).
        conv_features = self.cnn(x)
        # The transformer takes (batch, seq, cnn_output_dim), so swap the last two axes.
        return self.transformer(conv_features.permute(0, 2, 1))




# Instantiate the improved hybrid model
model = StrongHybridModel(
    cnn_input_dim=124,       # Number of input features
    cnn_output_dim=128,      # Number of features output by the CNN
    transformer_params={     # Transformer-specific settings
        "num_heads": 4,
        "hidden_dim": 256,
        "num_layers": 6,
        "dropout": 0.2,      # Slightly higher dropout for regularization
        "max_len": 128
    }
)

# Define optimizer and scheduler
# (AdamW with weight decay; plateau scheduler multiplies the LR by 0.5, patience=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)

# Use the weighted loss function for multi-task learning
loss_fn = weighted_loss

# Train the model
# `train`, the dataloaders, `overlap` and `chunk_length` come from earlier cells.
train(
    model=model,
    params=model.parameters(),
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    num_epochs=100,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    scheduler=scheduler,
    overlap=overlap,
    chunk_length=chunk_length
)


In [None]:
class StrongHybridModel(nn.Module):
    """Hybrid network: ``ImprovedCNNFeatureExtractor`` followed by ``TransformerClassifier``.

    Args:
        cnn_input_dim: Value given to the CNN as ``input_channels``.
        cnn_output_dim: Value given to the CNN as ``output_dim`` and to the
            Transformer as ``input_dim``.
        transformer_params: Remaining ``TransformerClassifier`` keyword
            arguments, supplied as a dict.
    """

    def __init__(self, cnn_input_dim, cnn_output_dim, transformer_params):
        super().__init__()
        cnn_kwargs = {"input_channels": cnn_input_dim, "output_dim": cnn_output_dim}
        self.cnn = ImprovedCNNFeatureExtractor(**cnn_kwargs)
        self.transformer = TransformerClassifier(input_dim=cnn_output_dim, **transformer_params)

    def forward(self, x):
        """Return the Transformer's predictions for the CNN-encoded input."""
        # presumably the CNN yields (batch_size, cnn_output_dim, reduced_seq_length) — TODO confirm
        encoded = self.cnn(x)
        # Exchange axes 1 and 2 before handing the features to the Transformer.
        return self.transformer(encoded.permute(0, 2, 1))


In [None]:
# Instantiate the improved hybrid model
model = StrongHybridModel(
    cnn_input_dim=124,       # Number of input features
    cnn_output_dim=128,      # Number of features output by the CNN
    transformer_params={     # Transformer-specific settings
        "num_heads": 4,
        "hidden_dim": 256,
        "num_layers": 6,
        "dropout": 0.2,      # Slightly higher dropout for regularization
        "max_len": 128
    }
)

# Build the parameter list once: `model.parameters()` is a generator that can
# be consumed only a single time, so passing it straight to `train` would leave
# nothing to iterate after the first pass. A list is also what the other
# training cells in this notebook hand to `train`.
params = list(model.parameters())

# Define optimizer and scheduler
optimizer = torch.optim.AdamW(params, lr=1e-4, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)

# Use the weighted loss function for multi-task learning
loss_fn = weighted_loss

# Train the model
train(
    model=model,
    params=params,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    num_epochs=100,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    scheduler=scheduler,
    overlap=overlap,
    chunk_length=chunk_length
)


In [None]:
# Preprocessing output data
# Converts every annotation XML in `input_folder` into a CSV with one integer
# label column per task: stroke, player, type, role, impact.
import numpy as np
input_folder = "/content/drive/MyDrive/PRIM/dataset/outputs"
output_folder = "/content/drive/MyDrive/PRIM/dataset/outputs_preprocessed"
os.makedirs(output_folder, exist_ok=True)
desired_columns = [
    'frame', 'stroke', 'player 1', 'player 2',
    'backhand', 'forehand', 'serve', 'ball pass',
    'point', 'mistake', 'void serve'
]

for file_name in os.listdir(input_folder):
    if not file_name.endswith(".xml"):
        continue

    input_xml_path = os.path.join(input_folder, file_name)
    print(f"Processing: {input_xml_path}")
    tree = preprocess.load_xml(input_xml_path)
    extracted_data = preprocess.extract_data_from_output(tree)
    df = preprocess.data_to_dataframe(extracted_data)

    # A frame is a stroke when either player is flagged; the flag is capped at 1.
    players_flagged = df['player 1'] + df['player 2']
    df['stroke'] = np.where(players_flagged <= 1, players_flagged, 1)
    # Let serves are folded into the 'void serve' flag.
    df['void serve'] = df['let serve'] + df['void serve']

    # .copy() makes the column subset an independent frame, so the assignments
    # below do not write into a slice of the original (SettingWithCopyWarning).
    df = df[desired_columns].copy()

    # Collapse the one-hot style flags into integer labels; 0 = none of the flags set.
    df['player'] = np.where(df['player 1'] == 1, 1, np.where(df['player 2'] == 1, 2, 0))
    df['type'] = np.where(df['backhand'] == 1, 1, np.where(df['forehand'] == 1, 2, 0))
    df['role'] = np.where(df['serve'] == 1, 1, np.where(df['ball pass'] == 1, 2, 0))
    df['impact'] = np.where(df['point'] == 1, 1,
                            np.where(df['mistake'] == 1, 2,
                                     np.where(df['void serve'] == 1, 3, 0)))

    # Select columns for the result DataFrame
    result = df[['frame', 'stroke', 'player', 'type', 'role', 'impact']]

    output_csv_path = os.path.join(output_folder, file_name.replace(".xml", "_preprocessed.csv"))
    preprocess.save_data(result, output_csv_path)
    print(f"Saved preprocessed file to: {output_csv_path}")


In [None]:
# Undersampling Majority class
# Keeps only the frames within a fixed window around each stroke frame and
# writes the reduced input/output tables to the *_train folders.

input_folder_train = '/content/drive/MyDrive/PRIM/dataset/inputs_preprocessed_train'
output_folder_train = '/content/drive/MyDrive/PRIM/dataset/outputs_preprocessed_train'
input_folder = '/content/drive/MyDrive/PRIM/dataset/inputs_preprocessed'
output_folder = '/content/drive/MyDrive/PRIM/dataset/outputs_preprocessed'

# Create the folders that are actually written to (the *_train destinations).
os.makedirs(input_folder_train, exist_ok=True)
os.makedirs(output_folder_train, exist_ok=True)

# Number of frames kept on each side of every stroke frame.
n_before = 30
n_after = 30

for file_name in os.listdir(input_folder):
    input_csv_path = os.path.join(input_folder, file_name)
    output_csv_path = os.path.join(output_folder, file_name)

    df_input = pd.read_csv(input_csv_path)
    df_output = pd.read_csv(output_csv_path)

    stroke_indices = df_output[df_output["stroke"] == 1].index

    # Union of the [idx - n_before, idx + n_after] windows, clipped to the file.
    indices_to_keep = set()
    for idx in stroke_indices:
        window_start = max(0, idx - n_before)
        window_stop = min(len(df_output), idx + 1 + n_after)
        indices_to_keep.update(range(window_start, window_stop))
    indices_to_keep = sorted(indices_to_keep)

    window_output = df_output.iloc[indices_to_keep]
    window_input = df_input.iloc[indices_to_keep]

    # Drop repeated frames with ONE mask applied to both tables so that input
    # rows and label rows stay aligned one-to-one.
    first_occurrence = ~window_output.duplicated(subset=['frame']).to_numpy()
    reduced_df_output = window_output[first_occurrence].reset_index(drop=True)
    reduced_df_input = window_input[first_occurrence].reset_index(drop=True)

    if reduced_df_input.isna().sum().sum() > 0:
        print(f"NaN values detected in reduced input file {file_name}. Skipping...")
        continue

    if reduced_df_output.isna().sum().sum() > 0:
        print(f"NaN values detected in reduced output file {file_name}. Skipping...")
        continue

    print(f"Reduced Output DataFrame shape: {reduced_df_output.shape}")
    print(f"Reduced Input DataFrame shape: {reduced_df_input.shape}")

    input_csv_path = os.path.join(input_folder_train, file_name)
    output_csv_path = os.path.join(output_folder_train, file_name)

    preprocess.save_data(reduced_df_input, input_csv_path)
    preprocess.save_data(reduced_df_output, output_csv_path)
    print()
    print(f"Saved preprocessed output file to: {output_csv_path}")
    print(f"Saved preprocessed input file to: {input_csv_path}")


Reduced Output DataFrame shape: (6585, 6)
Reduced Input DataFrame shape: (6585, 117)

Saved preprocessed output file to: /content/drive/MyDrive/PRIM/dataset/outputs_preprocessed_train/alatn_preprocessed.csv
Saved preprocessed input file to: /content/drive/MyDrive/PRIM/dataset/inputs_preprocessed_train/alatn_preprocessed.csv
Reduced Output DataFrame shape: (7476, 6)
Reduced Input DataFrame shape: (7476, 117)

Saved preprocessed output file to: /content/drive/MyDrive/PRIM/dataset/outputs_preprocessed_train/bnivh_preprocessed.csv
Saved preprocessed input file to: /content/drive/MyDrive/PRIM/dataset/inputs_preprocessed_train/bnivh_preprocessed.csv
Reduced Output DataFrame shape: (7395, 6)
Reduced Input DataFrame shape: (7395, 117)

Saved preprocessed output file to: /content/drive/MyDrive/PRIM/dataset/outputs_preprocessed_train/edysv_preprocessed.csv
Saved preprocessed input file to: /content/drive/MyDrive/PRIM/dataset/inputs_preprocessed_train/edysv_preprocessed.csv
Reduced Output DataFra

In [None]:
# Collect the undersampled feature/target CSV paths in sorted order.
feature_dir = "/content/drive/MyDrive/PRIM/dataset/inputs_preprocessed_train"
target_dir = "/content/drive/MyDrive/PRIM/dataset/outputs_preprocessed_train"

# Only regular files are kept; sub-directories are ignored.
feature_files = sorted(
    path
    for path in (os.path.join(feature_dir, name) for name in os.listdir(feature_dir))
    if os.path.isfile(path)
)

target_files = sorted(
    path
    for path in (os.path.join(target_dir, name) for name in os.listdir(target_dir))
    if os.path.isfile(path)
)

In [None]:
# The first 28 entries of each sorted list form the training split.
train_features = feature_files[:28]
train_targets = target_files[:28]

In [None]:
from dataset import TimeSeriesDataset
from torch.utils.data import DataLoader

# Chunking settings; both names are also passed to `train` in later cells.
overlap = 112
chunk_length = 128

# NOTE(review): the dataset is built with overlap=0 rather than the `overlap`
# variable above — confirm this is intentional.
train_dataset = TimeSeriesDataset(
    train_features,
    train_targets,
    chunk_length=chunk_length,
    overlap=0,
    augment=True,
    augment_prob=1,
)

train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=False)


In [None]:
# (totals key, CSV column, label value) for every label that is counted.
counted_labels = [
    ('num_strokes', 'stroke', 1),
    ('num_player_1', 'player', 1),
    ('num_player_2', 'player', 2),
    ('num_type_1', 'type', 1),
    ('num_type_2', 'type', 2),
    ('num_impact_1', 'impact', 1),
    ('num_impact_2', 'impact', 2),
    ('num_impact_3', 'impact', 3),
    ('num_role_1', 'role', 1),
    ('num_role_2', 'role', 2),
]

# Running totals over all inspected files; insertion order is the print order.
totals = {'num_frames': 0}
totals.update((counter_name, 0) for counter_name, _, _ in counted_labels)

# NOTE(review): statistics use the first 25 target files while the training
# split takes the first 28 — confirm this difference is intended.
for file_path in target_files[:25]:
    if not file_path.endswith('.csv'):
        continue
    df = pd.read_csv(file_path)
    totals['num_frames'] += len(df)  # Total number of frames
    for counter_name, column, label in counted_labels:
        # int(...) keeps the totals plain Python integers.
        totals[counter_name] += int((df[column] == label).sum())

# Display the total counts
print("Aggregate Results:")
for key, value in totals.items():
    print(f"{key}: {value}")


Aggregate Results:
num_frames: 487571
num_strokes: 5072
num_player_1: 2550
num_player_2: 0
num_type_1: 2792
num_type_2: 0
num_impact_1: 155
num_impact_2: 892
num_impact_3: 38
num_role_1: 1035
num_role_2: 498


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """
    Focal Loss for Binary and Multi-Class Classification.

    The per-element loss is ``alpha_t * (1 - p_t) ** gamma * CE``, where ``p_t``
    is the probability the model assigns to the true label.

    Args:
        gamma (float): Focusing parameter to down-weight easy examples. Default: 2.
        alpha (float, list, or None): Balancing factor for classes.
            - If float (binary): stored as ``[alpha, 1 - alpha]`` and indexed by
              the target label, i.e. label 0 is weighted by ``alpha`` and
              label 1 by ``1 - alpha``.
            - If list, per-class weights indexed by the target label.
            - If None, no weighting is applied. Default: None.
        reduction (str): Specifies the reduction to apply to the output:
            - 'none': No reduction.
            - 'mean': Average over all examples.
            - 'sum': Sum over all examples. Default: 'mean'.
            Any other value behaves like 'none'.
    """
    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction
        # Process alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.tensor([alpha, 1 - alpha])
        elif isinstance(alpha, list):
            self.alpha = torch.tensor(alpha)
        else:
            self.alpha = None

    def forward(self, logits, targets):
        """
        Args:
            logits (Tensor): Predicted logits (before softmax/sigmoid).
                - Multi-class: shape (N, C, ...), one more dimension than
                  ``targets``; the class axis is dim 1.
                - Binary: same shape as ``targets``. A singleton class axis
                  (N, 1, ...) paired with targets (N, ...) is also accepted.
            targets (Tensor): Ground truth labels of shape (N, ...). Class
                indices for multi-class, 0/1 values for binary.
        Returns:
            Tensor: Focal loss, reduced according to ``self.reduction``. With
            'none' the result has the same shape as ``targets``.
        """
        eps = 1e-6  # lower clamp that keeps log() finite

        # A single-logit "class" axis is just binary classification.
        if logits.dim() == targets.dim() + 1 and logits.size(1) == 1:
            logits = logits.squeeze(1)

        if logits.dim() == targets.dim() + 1:
            # Multi-class: pick the probability of the true class along dim 1.
            probs = F.softmax(logits, dim=1)
            true_index = targets.long().unsqueeze(1)
            p_t = probs.gather(1, true_index).squeeze(1)
            ce_loss = -torch.log(p_t.clamp(min=eps))
        else:
            # Binary: element-wise sigmoid cross-entropy.
            float_targets = targets.to(logits.dtype)
            probs = torch.sigmoid(logits)
            ce_loss = (
                -float_targets * torch.log(probs.clamp(min=eps))
                - (1 - float_targets) * torch.log((1 - probs).clamp(min=eps))
            )
            p_t = probs * float_targets + (1 - probs) * (1 - float_targets)

        # Combine focal term and cross-entropy
        loss = (1 - p_t) ** self.gamma * ce_loss

        # Apply alpha if provided
        if self.alpha is not None:
            # Bring alpha to the loss' device/dtype instead of moving the loss,
            # so the computation stays where the logits live.
            alpha = self.alpha.to(device=loss.device, dtype=loss.dtype)
            loss = alpha[targets.long()] * loss

        # Reduction
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:  # 'none'
            return loss


In [None]:
def multivariate_weighted_loss(y_pred, y_true, log_vars):
    """
    Multi-task loss with inverse-frequency class weights and learned task weights.

    Class weights come from the module-level ``totals`` dictionary using
    ``w_i = nb_total_samples / (nb_classes * nb_samples_i)``, where
    ``nb_classes`` is the number of logits of the corresponding head. A class
    with a zero count gets weight 1.0 instead of raising a division error.

    Args:
        y_pred: tuple of predictions (logits) in the order
            (stroke, player, type, role, impact).
        y_true: tensor indexed as ``y_true[:, :, k]`` with k = 0..4 in the same
            task order; entries equal to -1 are excluded from the loss.
        log_vars: indexable with 0..4; ``log_vars[i]`` weights task i through
            ``loss_i / (2 * exp(log_vars[i])) + log_vars[i]``.

    Returns:
        Scalar tensor: the sum of the five weighted task losses.
    """
    y_pred_stroke, y_pred_player, y_pred_type, y_pred_role, y_pred_impact = y_pred
    y_true_stroke = y_true[:, :, 0]
    y_true_player = y_true[:, :, 1]
    y_true_type = y_true[:, :, 2]
    y_true_role = y_true[:, :, 3]
    y_true_impact = y_true[:, :, 4]

    num_frames = totals['num_frames']

    def inverse_frequency(count, num_classes):
        # w_i = N / (K * n_i); a class never seen in `totals` falls back to 1.0.
        return num_frames / (num_classes * count) if count > 0 else 1.0

    def class_weights(logits, foreground_counts):
        # Class 0 is "everything else"; classes 1.. use the counted labels.
        num_classes = logits.size(-1)
        counts = [num_frames - sum(foreground_counts)] + list(foreground_counts)
        # Fit the list to the head size (pad with unseen classes or truncate).
        counts = (counts + [0] * num_classes)[:num_classes]
        return torch.tensor(
            [inverse_frequency(count, num_classes) for count in counts],
            dtype=torch.float32,
            device=logits.device,
        )

    # --- Stroke: weighted binary cross-entropy on frames not marked -1 ---
    stroke_mask = (y_true_stroke != -1).squeeze(-1)
    weight = torch.zeros_like(y_true_stroke)
    weight[y_true_stroke == 1] = inverse_frequency(totals['num_strokes'], 2)
    weight[y_true_stroke == 0] = inverse_frequency(num_frames - totals['num_strokes'], 2)
    bce_with_logits_loss = nn.BCEWithLogitsLoss(weight=weight[stroke_mask].squeeze(-1), reduction="sum")
    loss_stroke = bce_with_logits_loss(
        y_pred_stroke[stroke_mask].squeeze(-1),
        y_true_stroke[stroke_mask].squeeze(-1),
    )

    # --- Remaining tasks: weighted cross-entropy on frames whose player label is not -1 ---
    mask = (y_true_player != -1)

    def weighted_cross_entropy(logits, labels, foreground_counts):
        return F.cross_entropy(
            logits[mask],
            labels[mask].long(),
            weight=class_weights(logits, foreground_counts),
            reduction="sum",
        )

    loss_player = weighted_cross_entropy(
        y_pred_player, y_true_player,
        [totals['num_player_1'], totals['num_player_2']],
    )
    loss_type = weighted_cross_entropy(
        y_pred_type, y_true_type,
        [totals['num_type_1'], totals['num_type_2']],
    )
    loss_role = weighted_cross_entropy(
        y_pred_role, y_true_role,
        [totals['num_role_1'], totals['num_role_2']],
    )
    loss_impact = weighted_cross_entropy(
        y_pred_impact, y_true_impact,
        [totals['num_impact_1'], totals['num_impact_2'], totals['num_impact_3']],
    )

    # Combine the task losses using the learned log-variance weights.
    task_losses = (loss_stroke, loss_player, loss_type, loss_role, loss_impact)
    total_loss = 0.0
    for task_index, task_loss in enumerate(task_losses):
        total_loss = total_loss + (1 / (2 * torch.exp(log_vars[task_index]))) * task_loss + log_vars[task_index]
    return total_loss


In [None]:
class FocalLoss(nn.Module):
    """Binary focal loss computed from logits, averaged over all elements.

    Args:
        alpha: Constant factor applied to every element of the loss.
        gamma: Exponent of the ``(1 - pt)`` modulating term.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss for logits ``inputs`` and binary ``targets``."""
        per_element_bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # exp(-BCE) is the probability assigned to the true label.
        prob_true = torch.exp(-per_element_bce)
        modulating = (1 - prob_true) ** self.gamma
        weighted = self.alpha * modulating * per_element_bce
        return weighted.mean()


In [None]:
def label_smoothing_loss(predictions, targets, smoothing=0.1):
    """Mean cross-entropy against label-smoothed one-hot targets.

    Args:
        predictions: Logits whose last dimension indexes the classes.
        targets: Integer class indices (as required by ``F.one_hot``).
        smoothing: Probability mass spread uniformly over all classes.

    Returns:
        Scalar tensor: the loss averaged over all examples.
    """
    class_count = predictions.size(-1)
    hard_targets = F.one_hot(targets, class_count)
    # (1 - smoothing) on the true class plus smoothing / C on every class.
    soft_targets = (1 - smoothing) * hard_targets + smoothing / class_count
    log_probs = F.log_softmax(predictions, dim=-1)
    per_example = -(soft_targets * log_probs).sum(dim=-1)
    return per_example.mean()


In [None]:
def multivariate_weighted_loss(y_pred, y_true, log_vars, verbose=False):
    """Multi-task loss: focal loss for strokes, label-smoothed CE for the other tasks.

    Args:
        y_pred: tuple of logits in the order (stroke, player, type, role, impact).
        y_true: tensor indexed as ``y_true[:, :, k]`` with k = 0..4 in the same
            task order; entries equal to -1 are excluded from the loss.
        log_vars: indexable with 0..4; ``log_vars[i]`` weights task i through
            ``loss_i / (2 * exp(log_vars[i])) + log_vars[i]``.
        verbose: When True, print every task loss (per-call debugging output).

    Returns:
        Scalar tensor: the sum of the five weighted task losses.
    """
    y_pred_stroke, y_pred_player, y_pred_type, y_pred_role, y_pred_impact = y_pred
    y_true_stroke = y_true[:, :, 0]
    y_true_player = y_true[:, :, 1]
    y_true_type = y_true[:, :, 2]
    y_true_role = y_true[:, :, 3]
    y_true_impact = y_true[:, :, 4]

    # Focal Loss for Stroke
    stroke_mask = (y_true_stroke != -1).squeeze(-1)
    if stroke_mask.any():
        focal_loss = FocalLoss(alpha=0.25, gamma=2.0)
        loss_stroke = focal_loss(
            y_pred_stroke[stroke_mask].squeeze(-1),
            y_true_stroke[stroke_mask].squeeze(-1),
        )
    else:
        # Averaging over zero frames would give NaN; contribute nothing instead.
        loss_stroke = torch.tensor(0.0, device=y_pred_stroke.device)

    # Cross-Entropy with Label Smoothing for other tasks, on frames whose
    # player label is not -1.
    label_mask = (y_true_player != -1)
    has_labels = bool(label_mask.any())

    def smoothed_task_loss(logits, labels):
        if not has_labels:
            return torch.tensor(0.0, device=logits.device)
        return label_smoothing_loss(logits[label_mask], labels[label_mask].long())

    loss_player = smoothed_task_loss(y_pred_player, y_true_player)
    loss_type = smoothed_task_loss(y_pred_type, y_true_type)
    loss_role = smoothed_task_loss(y_pred_role, y_true_role)
    loss_impact = smoothed_task_loss(y_pred_impact, y_true_impact)

    if verbose:
        print(f"stroke{loss_stroke}")
        print(f"player{loss_player}")
        print(f"type{loss_type}")
        print(f"role{loss_role}")
        print(f"impact{loss_impact}")

    # Dynamic Task Weights
    task_losses = (loss_stroke, loss_player, loss_type, loss_role, loss_impact)
    total_loss = 0.0
    for task_index, task_loss in enumerate(task_losses):
        total_loss = total_loss + (1 / (2 * torch.exp(log_vars[task_index]))) * task_loss + log_vars[task_index]
    return total_loss


In [None]:
# dynamic class weighting
import torch
import torch.nn as nn
import torch.nn.functional as F

def _class_balanced_cross_entropy(logits, labels, mask, balance, default_weight=1.0):
    """Sum-reduced cross-entropy over the frames selected by ``mask``.

    With ``balance=True`` each class is weighted by ``selected / class_count``
    computed on the selected frames; classes absent from the selection get
    ``default_weight``. Returns a zero tensor when nothing is selected.
    """
    if mask.sum() == 0:
        return torch.tensor(0.0, device=logits.device)

    selected_labels = labels[mask]
    weight = None
    if balance:
        # One count per logit, so the weight vector always matches the head size.
        class_counts = torch.bincount(selected_labels.flatten(), minlength=logits.size(-1))
        weight = (class_counts.sum() / (class_counts + 1e-6)).to(logits.device)
        weight[class_counts == 0] = default_weight
    return F.cross_entropy(logits[mask], selected_labels, weight=weight, reduction="sum")


def loss(y_pred, y_true, log_vars):
    """Multi-task loss with class weights recomputed from each batch.

    Args:
        y_pred: tuple of logits in the order (stroke, player, type, role, impact).
        y_true: tensor indexed as ``y_true[:, :, k]`` with k = 0..4 in the same
            task order. Stroke entries equal to -1 are excluded.
        log_vars: indexable with 0..4; ``log_vars[i]`` weights task i through
            ``loss_i / (2 * exp(log_vars[i])) + log_vars[i]``.

    Returns:
        Scalar tensor: the sum of the five weighted task losses.
    """
    y_pred_stroke, y_pred_player, y_pred_type, y_pred_role, y_pred_impact = y_pred
    y_true_stroke = y_true[:, :, 0].unsqueeze(-1)
    y_true_player = y_true[:, :, 1].long()
    y_true_type = y_true[:, :, 2].long()
    y_true_role = y_true[:, :, 3].long()
    y_true_impact = y_true[:, :, 4].long()

    # --- Stroke detection: BCE over frames whose stroke label is not -1 ---
    mask_stroke = (y_true_stroke != -1).squeeze(-1)
    num_strokes = (y_true_stroke[mask_stroke] == 1).sum()
    if num_strokes > 0:
        # Positive frames are up-weighted by (valid frames / stroke frames).
        pos_weight_stroke = (mask_stroke.sum() / (num_strokes + 1e-6)).to(y_true.device)
    else:
        pos_weight_stroke = torch.tensor(1.0, device=y_true.device)

    if mask_stroke.sum() > 0:
        bce_with_logits_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight_stroke, reduction="sum")
        loss_stroke = bce_with_logits_loss(y_pred_stroke[mask_stroke], y_true_stroke[mask_stroke])
    else:
        loss_stroke = torch.tensor(0.0, device=y_pred_stroke.device)

    # --- Attribute heads: only frames labelled as a stroke contribute ---
    mask = (y_true_stroke == 1).squeeze(-1)
    loss_player = _class_balanced_cross_entropy(y_pred_player, y_true_player, mask, balance=False)
    loss_type = _class_balanced_cross_entropy(y_pred_type, y_true_type, mask, balance=False)
    loss_role = _class_balanced_cross_entropy(y_pred_role, y_true_role, mask, balance=True)
    loss_impact = _class_balanced_cross_entropy(y_pred_impact, y_true_impact, mask, balance=True)

    # --- Combine the task losses using the learned log-variance weights ---
    task_losses = (loss_stroke, loss_player, loss_type, loss_role, loss_impact)
    total_loss = 0.0
    for task_index, task_loss in enumerate(task_losses):
        total_loss = total_loss + (1 / (2 * torch.exp(log_vars[task_index]))) * task_loss + log_vars[task_index]

    return total_loss


In [None]:
class TransformerClassifier(nn.Module):
    """Shared ``TransformerEncoder`` followed by five task-specific MLP heads.

    Every head is ``Linear -> ReLU -> Linear -> ReLU -> Linear`` with widths
    ``hidden_dim -> hidden_dim -> hidden_dim // 2 -> n_outputs``.

    Output sizes: stroke 1, player 2, type 2, role 3, impact 4.

    Args:
        input_dim, num_heads, hidden_dim, num_layers, dropout, max_len:
            Forwarded unchanged to ``TransformerEncoder``. ``hidden_dim`` is
            also the input width of every head.
    """

    def __init__(self, input_dim, num_heads, hidden_dim, num_layers, dropout=0.1,max_len=1024):
        super().__init__()

        # Transformer encoder
        # assumes the encoder's output has last dimension hidden_dim — TODO confirm
        self.encoder = TransformerEncoder(
            input_dim=input_dim,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            dropout=dropout,
            max_len=max_len
        )

        def build_head(n_outputs):
            # Three-layer MLP ending in `n_outputs` logits.
            return nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Linear(hidden_dim // 2, n_outputs),
            )

        # Task-specific heads (created in this fixed order).
        self.stroke_head = build_head(1)   # single stroke logit
        self.player_head = build_head(2)   # player classification (binary)
        self.type_head = build_head(2)     # type classification (binary)
        self.role_head = build_head(3)     # role classification (3-class)
        self.impact_head = build_head(4)   # impact classification (4-class)

    def forward(self, x):
        """Encode ``x`` once and return (stroke, player, type, role, impact) logits."""
        # Extract features using the transformer encoder
        encoded = self.encoder(x)

        # Task-specific predictions, all computed from the same features.
        return (
            self.stroke_head(encoded),
            self.player_head(encoded),
            self.type_head(encoded),
            self.role_head(encoded),
            self.impact_head(encoded),
        )


In [None]:
# Build the plain Transformer classifier and report its size.
model = TransformerClassifier(
    input_dim=116, num_heads=32, hidden_dim=256,
    num_layers=3, dropout=0.2, max_len=256,
)

# Count only the parameters that will receive gradients.
total_params = sum(
    weight.numel()
    for weight in model.parameters()
    if weight.requires_grad
)
print(f"Total trainable parameters: {total_params}")


Total trainable parameters: 2959756


In [None]:
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Five learnable values (one per task), optimised together with the model.
# NOTE(review): log_vars is created on the default device — confirm `train`
# places it on `device` alongside the model.
log_vars = nn.Parameter(torch.zeros(5))
params = [*model.parameters(), log_vars]

optimizer = torch.optim.AdamW(params, lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2
)

# Loss with class weights recomputed from each batch.
loss_fn = loss

train(
    model=model,
    params=params,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    alpha=log_vars,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    num_epochs=200,
    device=device,
    overlap=overlap,
    chunk_length=chunk_length,
)

In [None]:
# Finish the current Weights & Biases run.
# NOTE(review): `wandb` is not imported in any cell visible here — presumably
# it is brought into scope by an earlier cell or by `%run setup_imports.py`; confirm.
wandb.finish()

In [None]:
# Fresh, untrained classifier for the next experiment.
model = TransformerClassifier(
    input_dim=116, num_heads=32, hidden_dim=256,
    num_layers=3, dropout=0.2, max_len=256,
)

# Report how many parameters are trainable.
total_params = sum(
    weight.numel()
    for weight in model.parameters()
    if weight.requires_grad
)
print(f"Total trainable parameters: {total_params}")


Total trainable parameters: 2959756


In [None]:
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Five learnable values (one per task), optimised together with the model.
# NOTE(review): log_vars is created on the default device — confirm `train`
# places it on `device` alongside the model.
log_vars = nn.Parameter(torch.zeros(5))
params = [*model.parameters(), log_vars]

optimizer = torch.optim.AdamW(params, lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2
)

# Resolves to whichever `multivariate_weighted_loss` cell was executed last.
loss_fn = multivariate_weighted_loss

train(
    model=model,
    params=params,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    alpha=log_vars,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    num_epochs=200,
    device=device,
    overlap=overlap,
    chunk_length=chunk_length,
)