In [None]:
import math
import torch
import torch.nn as nn

def get_positional_encoding(seq_len: int, d_model: int, device=None) -> torch.Tensor:
    """
    Build the fixed sinusoidal positional-encoding table (Vaswani et al., 2017).

    Even columns hold ``sin(pos * w_i)`` and odd columns ``cos(pos * w_i)`` with
    ``w_i = exp(-(2i) * ln(10000) / d_model)``. An odd ``d_model`` is supported:
    the last (even-indexed) column gets a sine with no matching cosine.

    Args:
        seq_len: number of positions (rows of the table).
        d_model: encoding width (columns of the table).
        device: optional torch device on which to build the table.

    Returns:
        float32 tensor of shape ``[seq_len, d_model]``.
    """
    scale = -math.log(10000.0) / d_model
    freqs = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float, device=device) * scale)
    positions = torch.arange(0, seq_len, dtype=torch.float, device=device).unsqueeze(1)  # [L, 1]
    angles = positions * freqs  # [L, ceil(d_model / 2)]

    table = torch.zeros(seq_len, d_model, device=device)
    table[:, 0::2] = torch.sin(angles)
    # An odd d_model has one fewer cosine column than sine columns.
    n_cos = d_model // 2
    table[:, 1::2] = torch.cos(angles[:, :n_cos])
    return table


class MultiHeadAttention(nn.Module):
    """
    Lightweight multi-head self-attention used for temporal feature fusion.

    If ``d_model`` is not a multiple of ``n_heads``, the input is first linearly
    projected up to the next multiple (``self.d_model``). The surplus output
    columns are sliced off again, so callers always see ``d_model`` features.

    Attributes:
        orig_d: caller-facing feature width.
        d_model: internal (possibly padded) width.
        d_k: per-head width, ``d_model // n_heads``.
        pad_extra: number of padding columns added (0 when no padding).
    """

    def __init__(self, d_model, n_heads=4, dropout=0.1):
        super().__init__()
        self.orig_d = d_model
        self.n_heads = n_heads

        padded = math.ceil(d_model / n_heads) * n_heads
        self.pad_extra = padded - d_model

        # Only learn an up-projection when padding is actually required.
        if padded == d_model:
            self.pre_proj = nn.Identity()
        else:
            self.pre_proj = nn.Linear(d_model, padded)
        self.d_model = padded
        self.d_k = padded // n_heads

        # Query/key/value projections (created and registered in q, k, v order).
        self.w_q, self.w_k, self.w_v = (nn.Linear(padded, padded, bias=False) for _ in range(3))
        self.w_o = nn.Linear(padded, padded)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: ``[B, L, orig_d]``.

        Returns:
            ``[B, L, orig_d]``.
        """
        h = self.pre_proj(x)  # [B, L, d_model]
        bsz, length, _ = h.size()

        def split_heads(t):
            # [B, L, d_model] -> [B, heads, L, d_k]
            return t.view(bsz, length, self.n_heads, self.d_k).transpose(1, 2)

        q = split_heads(self.w_q(h))
        k = split_heads(self.w_k(h))
        v = split_heads(self.w_v(h))

        logits = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        weights = self.dropout(logits.softmax(dim=-1))

        # [B, heads, L, d_k] -> [B, L, d_model]
        merged = (weights @ v).transpose(1, 2).contiguous().view(bsz, length, self.d_model)
        projected = self.w_o(merged)

        return projected[..., : self.orig_d] if self.pad_extra else projected


class ResidualConvBlock(nn.Module):
    """
    Two-layer Conv1d block with BatchNorm, GELU, squeeze-and-excitation gating
    and a residual connection.

    Input and output are channels-first ``[B, C, L]``. The sequence length ``L``
    is preserved for any kernel size (``padding="same"``), so the residual sum is
    always shape-compatible.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel: temporal kernel size of both convolutions (odd or even).
        dropout: dropout probability applied after the first conv stage.
        se_ratio: bottleneck ratio of the squeeze-and-excitation MLP; at least
            one hidden channel is always kept.
    """
    def __init__(self, in_ch, out_ch, kernel, dropout, se_ratio=0.25):
        super().__init__()
        # padding="same" equals kernel // 2 for odd kernels. For even kernels,
        # kernel // 2 would lengthen the sequence by 1 per conv and break the
        # residual addition; "same" keeps L unchanged.
        self.conv1 = nn.Conv1d(in_ch, out_ch, kernel, padding="same")
        self.conv2 = nn.Conv1d(out_ch, out_ch, kernel, padding="same")
        self.norm1 = nn.BatchNorm1d(out_ch)
        self.norm2 = nn.BatchNorm1d(out_ch)
        self.act = nn.GELU()
        self.drop = nn.Dropout(dropout)

        # Squeeze-and-Excitation: global average -> bottleneck MLP -> per-channel gate in (0, 1)
        se_ch = max(1, int(out_ch * se_ratio))
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(out_ch, se_ch, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(se_ch, out_ch, kernel_size=1),
            nn.Sigmoid()
        )

        # Residual path: 1x1 projection + BatchNorm only when channel counts differ
        self.res = nn.Conv1d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
        self.res_norm = nn.BatchNorm1d(out_ch) if in_ch != out_ch else nn.Identity()

    def forward(self, x):
        """
        Args:
            x: ``[B, in_ch, L]``.

        Returns:
            ``[B, out_ch, L]``.
        """
        # Both residual modules are nn.Identity when in_ch == out_ch, so calling
        # them unconditionally is a no-op in that case.
        residual = self.res_norm(self.res(x))

        out = self.drop(self.act(self.norm1(self.conv1(x))))
        out = self.norm2(self.conv2(out))

        out = out * self.se(out)  # gate shape [B, out_ch, 1] broadcasts over L

        return self.act(out + residual)


class ModalitySpecificEncoder(nn.Module):
    """
    Per-timestep MLP that embeds the EEG channel vector at each time step.

    Two identical stages of ``Linear -> LayerNorm -> GELU -> Dropout`` map
    ``[B, L, input_dim]`` to ``[B, L, hidden_dim]``. The Linear layers act on the
    last axis only, so every timestep is encoded independently.
    """
    def __init__(self, input_dim, hidden_dim, dropout=0.1):
        super().__init__()

        def stage(fan_in):
            # One Linear -> LayerNorm -> GELU -> Dropout stage.
            return [
                nn.Linear(fan_in, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ]

        # A single flat nn.Sequential: index 0 is the input Linear layer.
        self.encoder = nn.Sequential(*stage(input_dim), *stage(hidden_dim))

    def forward(self, x):
        """Map ``[B, L, input_dim]`` to ``[B, L, hidden_dim]``."""
        encoded = self.encoder(x)
        return encoded


class EEGRegressor(nn.Module):
    """
    Single-task EEG regressor that returns one scalar prediction per sample.

    Pipeline:
      1. per-timestep MLP encoder over the EEG channels;
      2. linear fusion projection + LayerNorm (+ optional sinusoidal positions);
      3. multi-head self-attention with a residual connection + LayerNorm;
      4. two parallel branches over the attended sequence:
         - an LSTM, summarised by its last layer's final hidden state(s),
         - a stack of residual Conv1d blocks, summarised by global avg + max pooling;
      5. an MLP regression head on the concatenated summaries.

    The defaults target 129 EEG channels. The time length is not fixed:
    positional encodings are built for whatever ``L`` is passed in.

    Args:
        eeg_channels: number of input channels per timestep.
        enc_hidden_dim: width of the per-timestep encoder.
        dropout_enc: dropout inside the encoder.
        num_features: width of the fused sequence fed to attention/LSTM/CNN.
        n_heads: attention heads.
        dropout_attn: dropout on attention weights and on the attention output.
        lstm_hidden_dim: LSTM hidden size per direction.
        lstm_layers: number of stacked LSTM layers.
        bidirectional: whether the LSTM is bidirectional.
        dropout_lstm: inter-layer LSTM dropout (only used when lstm_layers > 1).
        cnn_filter_sizes: output channels of each residual conv block, in order.
            Any non-empty sequence is accepted; the default builds 3 blocks.
        cnn_kernel_sizes: kernel size of each conv block. Must have the same
            length as ``cnn_filter_sizes``.
        dropout_cnn: dropout inside each conv block.
        se_ratio: squeeze-and-excitation bottleneck ratio of each conv block.
        dropout_classifier: dropout at the head input; half of it is applied
            before the final Linear.
        use_positional: add sinusoidal positional encodings before attention.

    Raises:
        ValueError: if ``cnn_filter_sizes``/``cnn_kernel_sizes`` are empty or
            differ in length.
    """
    def __init__(
        self,
        eeg_channels: int = 129,
        enc_hidden_dim: int = 64,
        dropout_enc: float = 0.1,
        num_features: int = 128,
        n_heads: int = 4,
        dropout_attn: float = 0.1,
        lstm_hidden_dim: int = 64,
        lstm_layers: int = 2,
        bidirectional: bool = True,
        dropout_lstm: float = 0.2,
        cnn_filter_sizes: list = (64, 128, 256),
        cnn_kernel_sizes: list = (3, 3, 3),
        dropout_cnn: float = 0.2,
        se_ratio: float = 0.25,
        dropout_classifier: float = 0.3,
        use_positional: bool = True,
    ):
        super().__init__()

        self.use_positional = use_positional

        # --- per-timestep EEG channel encoder ---
        self.encoder = ModalitySpecificEncoder(eeg_channels, enc_hidden_dim, dropout_enc)

        # --- fusion projection (single modality, so the input width is enc_hidden_dim) ---
        fusion_dim = enc_hidden_dim
        self.fusion_proj = nn.Linear(fusion_dim, num_features)
        self.fusion_norm = nn.LayerNorm(num_features)

        # --- temporal self-attention with post-attention dropout + LayerNorm ---
        self.attention = MultiHeadAttention(num_features, n_heads=n_heads, dropout=dropout_attn)
        self.attn_dropout = nn.Dropout(dropout_attn)
        self.attn_norm = nn.LayerNorm(num_features)

        # --- LSTM branch (nn.LSTM dropout only applies between stacked layers) ---
        self.lstm = nn.LSTM(
            num_features,
            lstm_hidden_dim,
            num_layers=lstm_layers,
            dropout=dropout_lstm if lstm_layers > 1 else 0.0,
            bidirectional=bidirectional,
            batch_first=True,
        )
        lstm_out_dim = lstm_hidden_dim * (2 if bidirectional else 1)

        # --- CNN branch: residual Conv1d blocks on channels-first [B, C, L] ---
        fs = list(cnn_filter_sizes)
        ks = list(cnn_kernel_sizes)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not fs or len(fs) != len(ks):
            raise ValueError(
                "cnn_filter_sizes and cnn_kernel_sizes must be non-empty and of equal "
                f"length (got {len(fs)} filter sizes and {len(ks)} kernel sizes)"
            )
        in_channels = [num_features] + fs[:-1]  # each block consumes the previous block's output
        self.conv_blocks = nn.ModuleList(
            ResidualConvBlock(c_in, c_out, k, dropout_cnn, se_ratio)
            for c_in, c_out, k in zip(in_channels, fs, ks)
        )

        # --- global pooling over time ---
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.global_max_pool = nn.AdaptiveMaxPool1d(1)

        # --- regression head ---
        conv_feat_dim = fs[-1] * 2  # avg-pooled + max-pooled features
        final_dim = lstm_out_dim + conv_feat_dim
        head_dim = max(16, final_dim // 2)
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_classifier),
            nn.Linear(final_dim, head_dim),
            nn.LayerNorm(head_dim),
            nn.GELU(),
            nn.Dropout(dropout_classifier * 0.5),
            nn.Linear(head_dim, 1),  # scalar regression output
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """
        Initialise one submodule (called recursively through ``self.apply``).

        - Linear: Xavier-uniform weights.
        - Conv1d: Kaiming-normal weights (fan_out, ReLU gain).
        - Biases of both are zeroed.
        - BatchNorm1d/LayerNorm affine parameters are reset to identity (weight 1, bias 0).
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv1d):
            torch.nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, (nn.BatchNorm1d, nn.LayerNorm)):
            # Non-affine norms set weight/bias to None (hasattr would still be
            # True), so test for None explicitly.
            if getattr(module, "weight", None) is not None:
                torch.nn.init.ones_(module.weight)
            if getattr(module, "bias", None) is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, x):
        """
        Args:
            x: channels-last EEG window ``[B, L, eeg_channels]``.

        Returns:
            ``[B]`` tensor, one scalar regression output per sample.
        """
        _, seq_len, _ = x.shape  # also rejects non-3D input early

        # Per-timestep channel encoding, then projection to the shared width.
        enc = self.encoder(x)                            # [B, L, enc_hidden_dim]
        fused = self.fusion_norm(self.fusion_proj(enc))  # [B, L, num_features]

        if self.use_positional:
            pos = get_positional_encoding(seq_len, fused.size(-1), device=fused.device)  # [L, D]
            fused = fused + pos.unsqueeze(0)  # broadcast over the batch

        # Self-attention with a post-norm residual connection.
        attn_out = self.attn_dropout(self.attention(fused))
        fused = self.attn_norm(fused + attn_out)

        # LSTM branch: final hidden state of the last layer.
        # h_n is [num_layers * num_directions, B, H] (batch_first does not affect it).
        _, (h_n, _) = self.lstm(fused)
        if self.lstm.bidirectional:
            h_final = torch.cat([h_n[-2], h_n[-1]], dim=1)  # last layer fwd + bwd: [B, 2H]
        else:
            h_final = h_n[-1]                               # [B, H]

        # CNN branch over channels-first features, summarised by avg + max pooling.
        x_conv = fused.permute(0, 2, 1)  # [B, num_features, L]
        for block in self.conv_blocks:
            x_conv = block(x_conv)       # [B, ch, L]
        conv_avg = self.global_avg_pool(x_conv).squeeze(-1)  # [B, cnn_filter_sizes[-1]]
        conv_max = self.global_max_pool(x_conv).squeeze(-1)  # [B, cnn_filter_sizes[-1]]
        conv_feat = torch.cat([conv_avg, conv_max], dim=1)

        final_feat = torch.cat([h_final, conv_feat], dim=1)  # [B, final_dim]
        return self.classifier(final_feat).squeeze(-1)       # [B]


In [None]:
# Initialize model
import torch.optim as optim  # required by optim.Adamax below

L = 200   # time length (samples per window)
C = 129   # EEG channels
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = EEGRegressor(
    eeg_channels=C,
    enc_hidden_dim=64,
    num_features=128,
    n_heads=4,
    lstm_hidden_dim=64,
    lstm_layers=2,
    bidirectional=True,
    cnn_filter_sizes=(64, 128, 256),
    cnn_kernel_sizes=(3, 3, 3),
).to(device)

# Optimizer and loss (L1 = mean absolute error)
optimizer = optim.Adamax(params=model.parameters(), lr=1e-3)
criterion = nn.L1Loss()

# Counts only parameters that require gradients (i.e. trainable ones).
print("Number of total parameters: ", sum(p.numel() for p in model.parameters() if p.requires_grad))


In [None]:
# Create PyTorch Dataloader
# DataLoader is imported here so this cell does not depend on a later cell's imports.
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

batch_size = 32
num_workers = 1 if device.type == "cpu" else 2
pin_memory = device.type == "cuda"  # pinned host memory is only useful for GPU transfers

# Optional: live-updating plots when running inside IPython/Jupyter.
try:
    from IPython.display import clear_output
    _have_ipy = True
except ImportError:
    _have_ipy = False

dataloader = DataLoader(
    windows_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

dataloader_val = DataLoader(
    windows_ds_val,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

def ensure_channels_last(X: torch.Tensor, model) -> torch.Tensor:
    """
    Return ``X`` laid out as ``[B, L, C]``, where ``C`` is the model's input width.

    The expected channel count is read from ``model.encoder.encoder[0].in_features``.
    A tensor that is already channels-last is returned unchanged; this check
    wins when ``L == C``. A channels-first ``[B, C, L]`` tensor is permuted and
    made contiguous.

    Raises:
        ValueError: if ``X`` is not 3-D, or neither axis 1 nor axis 2 matches the
            expected channel count.
    """
    if X.dim() != 3:
        raise ValueError(f"Expected X dim==3 (B,L,C). Got {X.dim()} dims: {list(X.shape)}")

    expected_in = model.encoder.encoder[0].in_features  # 129 by default
    _, second_axis, last_axis = X.shape

    if last_axis == expected_in:  # already [B, L, C]
        return X
    if second_axis == expected_in:  # [B, C, L] -> [B, L, C]
        return X.permute(0, 2, 1).contiguous()

    raise ValueError(f"Can't reconcile input shape {list(X.shape)} with model expected channels {expected_in}."
                     " If your dataset returns (C,L) per sample, permute to (L,C).")


# Training loop parameters
n_epochs = 100
grad_clip_max_norm = 5.0  # global gradient-norm clip, helps stabilise training

train_epoch_losses = []  # per-epoch mean train L1 loss (weighted by samples)
val_epoch_losses = []    # per-epoch mean validation L1 loss (weighted by samples)

# Not populated by this loop; kept so any code referencing these names still works.
epoch_losses = []
all_batch_losses = []


def _prepare_batch(batch):
    """
    Move one ``(X, y, *extras)`` batch to ``device`` as float32 tensors.

    Returns ``X`` as ``[B, L, C]`` (via ``ensure_channels_last``) and ``y``
    flattened to ``[B]``. Extra batch fields (e.g. crop indices, infos) are
    ignored.

    Raises:
        ValueError: if X's layout cannot be reconciled with the model, or if y
            does not hold exactly one target per sample. The explicit check
            prevents silent broadcasting inside the loss.
    """
    X, y, *_ = batch
    X = X.to(device=device, dtype=torch.float32)
    y = y.to(device=device, dtype=torch.float32)

    try:
        X = ensure_channels_last(X, model)
    except ValueError:
        # Useful debug information, then re-raise with the original traceback.
        print("=== SHAPE DEBUG ===")
        print("X.shape:", list(X.shape))
        print("model expected channels (encoder input):", model.encoder.encoder[0].in_features)
        raise

    # reshape(-1) handles [B], [B, 1], [B, 1, 1] ... and, unlike squeeze(),
    # keeps a batch of size 1 as shape [1] rather than a 0-d tensor.
    y_flat = y.reshape(-1)
    if y_flat.numel() != X.size(0):
        raise ValueError(
            f"Expected one target per sample ({X.size(0)}), got target shape {list(y.shape)}"
        )
    return X, y_flat


# --- Training loop ---
for epoch in range(n_epochs):
    # ---- training ----
    model.train()
    train_loss_sum, train_count = 0.0, 0

    for idx, batch in enumerate(dataloader):
        X, y = _prepare_batch(batch)

        optimizer.zero_grad()
        y_pred = model(X).reshape(-1)  # model returns [B]
        loss = criterion(y_pred, y)
        loss.backward()

        # Gradient clipping (optional but recommended)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_max_norm)
        optimizer.step()

        batch_loss = loss.item()
        # Weight by batch size so a smaller final batch does not skew the epoch mean.
        train_loss_sum += batch_loss * y.numel()
        train_count += y.numel()

        if idx % 50 == 0:
            print(f"Epoch {epoch} - step {idx}, loss: {batch_loss:.6f}")

    avg_train_loss = train_loss_sum / max(1, train_count)
    train_epoch_losses.append(avg_train_loss)

    # ---- validation ----
    model.eval()
    val_loss_sum, val_count = 0.0, 0

    with torch.no_grad():
        for idx, batch in enumerate(dataloader_val):
            X, y = _prepare_batch(batch)
            y_pred = model(X).reshape(-1)

            batch_val_loss = criterion(y_pred, y).item()
            val_loss_sum += batch_val_loss * y.numel()
            val_count += y.numel()

            if idx % 50 == 0:
                print(f"[Val]   Epoch {epoch} - step {idx}, batch loss: {batch_val_loss:.6f}")

    avg_val_loss = val_loss_sum / max(1, val_count)
    val_epoch_losses.append(avg_val_loss)

    print(f"Epoch {epoch} finished. Train avg loss: {avg_train_loss:.6f} | Val avg loss: {avg_val_loss:.6f}")

    # ---- plot train & val losses so far ----
    if _have_ipy:
        clear_output(wait=True)

    fig = plt.figure(figsize=(7, 4))
    epochs = range(1, len(train_epoch_losses) + 1)
    plt.plot(epochs, train_epoch_losses, marker='o', linestyle='-', label='Train')
    plt.plot(epochs, val_epoch_losses, marker='o', linestyle='-', label='Validation')
    plt.xlabel("Epoch")
    plt.ylabel("Avg L1 loss")
    plt.title("Training and Validation loss per epoch")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # a new figure is created every epoch; release it explicitly


# Finally, save the trained weights for later use
torch.save(model.state_dict(), "model_weights_challenge_2.pt")

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import matplotlib.pyplot as plt
import csv
import os

# --- dataloader_test ---
# shuffle=False keeps evaluation deterministic and ties each batch index to a
# fixed set of samples, which the per-batch plot and CSV below rely on.
dataloader_test = DataLoader(
    windows_ds_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.L1Loss()

# Ensure model is on device and in inference mode (no dropout, frozen BatchNorm stats)
model.to(device)
model.eval()

batch_losses = []
batch_indices = []
batch_sizes = []  # used to compute a sample-weighted average

with torch.no_grad():
    for idx, batch in enumerate(dataloader_test):
        X, y, *_ = batch  # extra fields (e.g. crop indices, infos) are ignored

        X = X.to(device=device, dtype=torch.float32)
        y = y.to(device=device, dtype=torch.float32)

        # Ensure shape [B, L, C] (helper defined alongside the training loop)
        X = ensure_channels_last(X, model)

        # Flatten targets to [B]; reshape keeps a size-1 batch as [1], unlike squeeze().
        y_flat = y.reshape(-1)
        if y_flat.numel() != X.size(0):
            raise ValueError(
                f"Expected one target per sample ({X.size(0)}), got target shape {list(y.shape)}"
            )

        y_pred = model(X).reshape(-1)  # model returns [B]

        loss_val = criterion(y_pred, y_flat).item()
        batch_losses.append(loss_val)
        batch_indices.append(idx)
        batch_sizes.append(y_flat.numel())

        if idx % 10 == 0:
            print(f"Test step {idx}, batch loss: {loss_val:.6f}")

# Summary: weight each batch mean by its size so the last (smaller) batch is not over-counted
num_batches = len(batch_losses)
num_samples = sum(batch_sizes)
weighted_sum = sum(bl * n for bl, n in zip(batch_losses, batch_sizes))
avg_test_loss = float(weighted_sum / max(1, num_samples))
print(f"Test finished. Num batches: {num_batches}. Avg L1 loss: {avg_test_loss:.6f}")

# Plot batch losses
fig = plt.figure(figsize=(8, 4))
plt.plot(batch_indices, batch_losses, marker='o', linestyle='-')
plt.axhline(avg_test_loss, color='k', linewidth=1.0, linestyle='--', label=f"Avg loss = {avg_test_loss:.4f}")
plt.xlabel("Batch index")
plt.ylabel("L1 loss")
plt.title("Test loss per batch")
plt.legend()
plt.grid(True)
plt.tight_layout()

# Save BEFORE show(): inline/notebook backends close the figure on show(),
# so saving afterwards writes an empty image.
out_fig = "test_loss.png"
fig.savefig(out_fig)
print(f"Saved plot to {out_fig}")

plt.show()
plt.close(fig)

# Optionally save per-batch losses to CSV for further analysis
out_csv = "test_batch_losses.csv"
with open(out_csv, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["batch_idx", "loss"])
    writer.writerows(zip(batch_indices, batch_losses))
print(f"Saved batch losses to {out_csv}")
