# fNIRS Data Preprocessing

In [None]:
import re
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from openpyxl import load_workbook

class fNIRSDataset(Dataset):
    """
    Custom dataset for fNIRS HRF data from an Excel workbook.

    Each sample corresponds to one subject–event combination (one column from the headers)
    and returns (fNIRS_data, event_label).

    Workbook layout read by this class:
      - one sheet per (channel, measurement): the sheet name contains "HbO" or "HbR"
        (case-insensitive) and a channel id of the form "number,number"; only channels
        with both an HbO and an HbR sheet are used;
      - row 1 holds event labels, row 2 subject identifiers, rows 3+ the time series;
      - a leading column whose row-1 cell is not "S", "F" or "H" (e.g. timestamps) is skipped.

    fNIRS_data is a tensor of shape [2, num_channels, target_length]:
      - 2: two signal types (0: HbO; 1: HbR)
      - num_channels: number of complete HbO/HbR channel pairs (expected 24)
      - target_length: number of time points after sliding-window processing
    Cell values are multiplied by 1e6; empty or missing cells are read as 0.

    Event label mapping (chosen from the events present in the header row):
      - If the header contains "S" (tertiary): {"S": 0, "F": 1, "H": 2}
      - Otherwise (binary, only {"F","H"}): {"F": 0, "H": 1}
      Any other event label is returned as -1.

    Sliding-window processing:
      Tertiary (has S):
         - S: already 4549 points;
         - F: valid length 4801 → windows [0:4549] and [252:252+4549] averaged → 4549 points;
         - H: valid length 6801 → windows [0:4549] and [2252:2252+4549] averaged.
      Binary (only F and H):
         - F: valid length 4801 → take first 4801 points;
         - H: valid length 6801 → windows [0:4801] and [2000:2000+4801] averaged.

    Splits: subjects are sorted; "train" keeps the first 16, "test" the remaining ones,
    "all" / "validation" keep every subject.
    """
    def __init__(self, excel_file, split="train"):
        self.excel_file = excel_file
        wb = load_workbook(excel_file, read_only=True)
        try:
            valid_channels = self._group_channel_sheets(wb.sheetnames)
            self.channels = sorted(valid_channels.keys())
            self.num_channels = len(self.channels)
            if self.num_channels == 0:
                raise ValueError("No valid channels found. Ensure sheet names contain 'HbO' or 'HbR' and a channel id in the format 'number,number'.")
            self.hb_sheets = [(valid_channels[cid]["HbO"], valid_channels[cid]["HbR"]) for cid in self.channels]

            # Extract header information from one HbO sheet.
            rows = list(wb[self.hb_sheets[0][0]].iter_rows(values_only=True))
        finally:
            # Close the workbook even when validation above raises.
            wb.close()
        if len(rows) < 2:
            raise ValueError(f"Sheet '{self.hb_sheets[0][0]}' needs an event-label row and a subject row.")

        header1 = list(rows[0])  # event labels
        header2 = list(rows[1])  # subject identifiers (used for the split, not for classification)
        expected_events = {"S", "F", "H"}
        # If first cell not an expected event, assume timestamp column and skip it.
        self.drop_first = bool(header1) and header1[0] not in expected_events
        if self.drop_first:
            header1 = header1[1:]
            header2 = header2[1:]
        self.sample_headers = list(zip(header1, header2))

        # Event mapping and target length depend on the events the header actually contains.
        # (Previously this compared the literal `expected_events` with itself, which was
        # always True and made the binary branch unreachable.)
        self._configure_events(header1)

        # Split subjects into train (first 16 sorted subjects) and test (the rest).
        # Empty trailing columns come back as None; they are not subjects.
        all_subjects = sorted({subj for _, subj in self.sample_headers if subj is not None})
        if len(all_subjects) != 20:
            print(f"Warning: Expected 20 subjects but found {len(all_subjects)}.")
        if split == "train":
            selected_subjects = set(all_subjects[:16])
        elif split == "test":
            selected_subjects = set(all_subjects[16:])
        elif split in ("all", "validation"):
            selected_subjects = set(all_subjects)
        else:
            raise ValueError("Invalid split type. Must be 'train', 'test', 'all' or 'validation'.")
        # col_idx is relative to the header after the optional leading column was dropped.
        self.samples_meta = [
            (event, subj, col_idx)
            for col_idx, (event, subj) in enumerate(self.sample_headers)
            if subj in selected_subjects
        ]
        self.sampling_points = len(rows) - 2  # maximum available time points in sheet

    @staticmethod
    def _group_channel_sheets(sheet_names):
        """Return {channel_id: {"HbO": sheet, "HbR": sheet}} for channels that have both sheets."""
        channel_dict = {}
        for sheet_name in sheet_names:
            name_lower = sheet_name.lower()
            if "hbo" in name_lower:
                measurement = "HbO"
            elif "hbr" in name_lower:
                measurement = "HbR"
            else:
                continue
            match = re.search(r'(\d+,\d+)', sheet_name)
            if not match:
                continue
            channel_dict.setdefault(match.group(1), {})[measurement] = sheet_name
        return {cid: sheets for cid, sheets in channel_dict.items()
                if "HbO" in sheets and "HbR" in sheets}

    def _configure_events(self, header_events):
        """Set event_map, target_length and window offsets from the header's event labels."""
        if "S" in header_events:  # tertiary
            self.event_map = {"S": 0, "F": 1, "H": 2}
            self.target_length = 4549
            self.offset_F = 252      # 4801 - 4549
            self.offset_H = 2252     # 6801 - 4549
        else:  # binary: only F and H
            self.event_map = {"F": 0, "H": 1}
            self.target_length = 4801
            self.offset_F = 0        # unused: F already has target_length points
            self.offset_H = 2000     # 6801 - 4801

    @staticmethod
    def _read_column(sheet, col_idx):
        """Read column `col_idx` of data rows 3+ as float32 scaled by 1e6; missing/empty cells -> 0."""
        values = []
        for row in sheet.iter_rows(min_row=3, values_only=True):
            val = row[col_idx] if col_idx < len(row) else None
            values.append(0 if val is None else val * 1_000_000)
        return np.array(values, dtype=np.float32)

    def _apply_windows(self, data, event):
        """Reduce [2, C, T] data to [2, C, target_length] using the sliding-window rules."""
        length = self.target_length
        first = data[:, :, :length]
        if event == "H":
            offset = self.offset_H
        elif event == "F" and "S" in self.event_map:
            offset = self.offset_F
        else:
            # S, binary-mode F, and unknown events keep only the first window.
            return first
        second = data[:, :, offset:offset + length]
        return (first + second) / 2.0

    def __len__(self):
        return len(self.samples_meta)

    def __getitem__(self, idx):
        event, subj, col_idx = self.samples_meta[idx]
        actual_col_idx = col_idx + 1 if self.drop_first else col_idx

        wb = load_workbook(self.excel_file, read_only=True)
        try:
            hbO_list = []
            hbR_list = []
            for hbO_sheet_name, hbR_sheet_name in self.hb_sheets:
                hbO_list.append(self._read_column(wb[hbO_sheet_name], actual_col_idx))
                hbR_list.append(self._read_column(wb[hbR_sheet_name], actual_col_idx))
        finally:
            wb.close()

        HbO_data = np.stack(hbO_list, axis=0)  # [num_channels, T]
        HbR_data = np.stack(hbR_list, axis=0)  # [num_channels, T]
        fNIRS_data = np.stack([HbO_data, HbR_data], axis=0)  # [2, num_channels, T]
        fNIRS_data = torch.tensor(fNIRS_data, dtype=torch.float32)

        new_data = self._apply_windows(fNIRS_data, event)
        event_label = self.event_map.get(event, -1)
        return new_data, event_label

excel_path = '/content/Subjectwise Conc (GLM+no MA) with S.xlsx'
validation_excel_path = '/content/Subjectwise Validation Conc (GLM+no MA) with S.xlsx'

# Train and test are disjoint subject subsets of the same workbook;
# validation uses every subject of a separate workbook.
train_dataset = fNIRSDataset(excel_path, split="train")
val_dataset = fNIRSDataset(validation_excel_path, split="all")
test_dataset = fNIRSDataset(excel_path, split="test")

for caption, split_ds in (("Train samples:", train_dataset),
                          ("Validation samples:", val_dataset),
                          ("Test samples:", test_dataset)):
    print(caption, len(split_ds))

# Peek at the first training sample.
sample_data, event_lbl = train_dataset[0]
print("fNIRS data shape:", sample_data.shape)
print(f"Labels -> Event: {event_lbl}")


Train samples: 48
Validation samples: 12
Test samples: 12


ValueError: not enough values to unpack (expected 5, got 2)

In [None]:
# Re-inspect the first training sample (tensor shape and its event label).
sample_data, event_lbl = train_dataset[0]
print("fNIRS data shape:", sample_data.shape)
print(f"Labels -> Event: {event_lbl}")

fNIRS data shape: torch.Size([2, 24, 4549])
Labels -> Event: 0


In [None]:
from torch.utils.data import DataLoader, Dataset

# Batches of 3 in each dataset's fixed order (no shuffling).
train_loader, test_loader, val_loader = (
    DataLoader(split_ds, batch_size=3, shuffle=False)
    for split_ds in (train_dataset, test_dataset, val_dataset)
)
print("Done!")

Done!


# Shared Modules

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange, repeat
from torch import einsum
from einops.layers.torch import Rearrange

# Shared Modules (Residual, PreNorm, FeedForward, Attention, Transformer)

class Residual(nn.Module):
    """Skip connection: return ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        branch_out = self.fn(x, **kwargs)
        return branch_out + x

class PreNorm(nn.Module):
    """Apply LayerNorm over the last dimension (size ``dim``), then call ``fn`` on the result."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim->hidden_dim), GELU, Dropout, Linear(hidden_dim->dim), Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        project = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(
            expand,
            nn.GELU(),
            nn.Dropout(dropout),
            project,
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over x of shape [B, seq_length, dim].

    If ``mask`` is given, it is flattened per batch item and left-padded with one True
    entry (so the padded length must equal seq_length). Query/key pairs where either
    position is False are filled with the most negative finite value before the softmax.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        # The output projection is skipped only when one head already yields `dim` features.
        if heads == 1 and dim_head == dim:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )

    def forward(self, x, mask=None):
        batch, tokens, _ = x.shape  # unpacking also checks that x is [B, seq_length, dim]

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h=self.heads)

        q, k, v = (split_heads(part) for part in self.to_qkv(x).chunk(3, dim=-1))
        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if mask is not None:
            padded = F.pad(mask.flatten(1), (1, 0), value=True)
            assert padded.shape[-1] == scores.shape[-1], 'mask has incorrect dimensions'
            pair_mask = rearrange(padded, 'b i -> b () i ()') * rearrange(padded, 'b j -> b () () j')
            scores = scores.masked_fill(~pair_mask, -torch.finfo(scores.dtype).max)

        weights = scores.softmax(dim=-1)
        context = einsum('b h i j, b h j d -> b h i d', weights, v)
        return self.to_out(rearrange(context, 'b h n d -> b n (h d)'))

class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm layers.

    Each layer is a residual self-attention sublayer followed by a residual feed-forward
    sublayer; ``mask`` is forwarded to every attention sublayer.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()

        def make_layer():
            attention = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            feed_forward = FeedForward(dim, mlp_dim, dropout=dropout)
            return nn.ModuleList([
                Residual(PreNorm(dim, attention)),
                Residual(PreNorm(dim, feed_forward)),
            ])

        self.layers = nn.ModuleList([make_layer() for _ in range(depth)])

    def forward(self, x, mask=None):
        for attention_block, feed_forward_block in self.layers:
            x = feed_forward_block(attention_block(x, mask=mask))
        return x


# Shared Backbone (Version 1)

## Version 1: Shared Backbone with Separate Heads

Concept
A single fNIRS transformer backbone (a modified version of fNIRS_T that returns latent features instead of class scores) extracts a shared latent representation from the raw input.
This latent representation is then passed to four separate classifier heads, one per task (event, group, music, eye). Each head is a stack of residual self-attention layers followed by average pooling and a small MLP classifier.

Pros
1. Efficiency: The backbone is computed only once for all tasks.
2. Shared Features: If the tasks are related, the shared features may improve performance on each of them (multi-task learning).
3. Parameter Savings: Fewer overall parameters than training four completely separate models.

Cons
1. Task Conflict: If the tasks require very different features, sharing may hurt performance.

In [None]:
# fNIRS Backbone: Modified fNIRS_T that returns latent features

class fNIRS_T_Backbone(nn.Module):
    """
    fNIRS-T backbone that outputs latent features (before final classification).

    Input shape: [B, 2, fNIRS_channels, sampling_point] (2 = HbO/HbR).
    Output shape: [B, 2 * dim], the pooled patch-branch and channel-branch features concatenated.

    Patch branch: Conv2d kernel (5, 30), stride (1, 4) -> fNIRS_channels - 4 tokens.
    Channel branch: Conv2d kernel (1, 30), stride (1, 4) -> fNIRS_channels tokens.
    Positional embeddings are allocated for up to ``max_patches`` / ``max_channels`` tokens
    (plus a CLS token) and sliced to the actual token count.
    Pooling: 'mean' averages all tokens, anything else takes the CLS token.

    NOTE(review): the same ``mask`` is passed to both transformers even though their
    sequence lengths differ, so a non-None mask can match at most one branch.
    """
    def __init__(self, sampling_point, dim, depth, heads, mlp_dim, pool='cls', dim_head=64, dropout=0., emb_dropout=0.,
                 max_patches=100, max_channels=100):
        super().__init__()
        num_patches = max_patches
        num_channels = max_channels

        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=8, kernel_size=(5, 30), stride=(1, 4)),
            Rearrange('b c h w -> b h (c w)'),
            nn.Linear((math.floor((sampling_point-30)/4)+1)*8, dim),
            nn.LayerNorm(dim)
        )
        self.to_channel_embedding = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=8, kernel_size=(1, 30), stride=(1, 4)),
            Rearrange('b c h w -> b h (c w)'),
            nn.Linear((math.floor((sampling_point-30)/4)+1)*8, dim),
            nn.LayerNorm(dim)
        )
        self.pos_embedding_patch = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token_patch = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout_patch = nn.Dropout(emb_dropout)
        self.transformer_patch = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pos_embedding_channel = nn.Parameter(torch.randn(1, num_channels + 1, dim))
        self.cls_token_channel = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout_channel = nn.Dropout(emb_dropout)
        self.transformer_channel = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool
        self.to_latent = nn.Identity()

    def _encode(self, tokens, cls_token, pos_embedding, dropout, transformer, mask):
        """Prepend the CLS token, add positional embeddings, run the transformer and pool."""
        b, n, _ = tokens.shape
        if n + 1 > pos_embedding.shape[1]:
            raise ValueError(
                f"Got {n} tokens but positional embeddings cover only {pos_embedding.shape[1] - 1}; "
                "increase max_patches/max_channels."
            )
        cls_tokens = repeat(cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, tokens), dim=1)
        x += pos_embedding[:, :(n + 1)]
        x = dropout(x)
        x = transformer(x, mask)
        # Pooling: choose either CLS token or mean pooling.
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
        return self.to_latent(x)

    def forward(self, img, mask=None):
        x = self.to_patch_embedding(img)
        # img is passed unchanged: the previous img.squeeze() was a no-op for B > 1 but
        # dropped the batch dimension when B == 1, which broke the 'b c h w' Rearrange.
        x2 = self.to_channel_embedding(img)
        x = self._encode(x, self.cls_token_patch, self.pos_embedding_patch,
                         self.dropout_patch, self.transformer_patch, mask)
        x2 = self._encode(x2, self.cls_token_channel, self.pos_embedding_channel,
                          self.dropout_channel, self.transformer_channel, mask)
        latent = torch.cat((x, x2), 1)  # Shared latent space
        return latent

# TransformerClassifierHead (for each task)

class TransformerClassifierHead(nn.Module):
    """
    Per-task classifier head.

    ``depth`` residual pre-norm self-attention layers (no feed-forward sublayers),
    average pooling over the sequence dimension, then
    LayerNorm -> Linear(input_dim, mlp_dim) -> ReLU -> Linear(mlp_dim, num_classes).

    Input: [B, seq_length, input_dim]; output: logits [B, num_classes].
    """
    def __init__(self, input_dim, num_classes, depth=2, heads=8, dim_head=64, mlp_dim=64, dropout=0.0):
        super().__init__()
        # ModuleList instead of nn.Sequential: Sequential.forward accepts a single input
        # only, so calling it with mask=... raised a TypeError on every forward pass.
        # Parameter names (layers.0..., layers.1...) are identical for both containers.
        self.layers = nn.ModuleList([
            Residual(PreNorm(input_dim, Attention(input_dim, heads, dim_head, dropout)))
            for _ in range(depth)
        ])
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Sequential(
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, num_classes)
        )

    def forward(self, x, mask=None):
        # x - shape [B, seq_length, input_dim].
        for layer in self.layers:
            x = layer(x, mask=mask)
        # [B, seq, dim] -> [B, dim, seq] so AdaptiveAvgPool1d averages over the sequence.
        x = x.transpose(1, 2)
        x = self.pool(x).squeeze(-1)
        x = self.classifier(x)
        return x

# MultiHead fNIRS Classifier using a Shared Backbone

class MultiHeadfNIRSClassifier_Shared(nn.Module):
    """
    Multi-head classifier with a shared backbone and separate transformer classifier heads.

    forward(img) returns a tuple of logits: (event, group, music, eye).
    Class counts per head default to 5 / 2 / 2 / 2 and can be overridden.
    """
    def __init__(self, sampling_point, dim, depth, heads, mlp_dim, emb_dropout=0., dropout=0., pool='cls',
                 dim_head=64, num_event_classes=5, num_group_classes=2, num_music_classes=2, num_eye_classes=2):
        super().__init__()
        self.backbone = fNIRS_T_Backbone(sampling_point, dim, depth, heads, mlp_dim, pool, dim_head=dim_head, dropout=dropout, emb_dropout=emb_dropout)
        # The latent feature dimension is 2*dim due to concatenation.
        input_dim_out = dim * 2
        self.event_classifier = TransformerClassifierHead(input_dim_out, num_classes=num_event_classes)
        self.group_classifier = TransformerClassifierHead(input_dim_out, num_classes=num_group_classes)
        self.music_classifier = TransformerClassifierHead(input_dim_out, num_classes=num_music_classes)
        self.eye_classifier = TransformerClassifierHead(input_dim_out, num_classes=num_eye_classes)

    def forward(self, img, mask=None):
        # Shared latent representation from the backbone: [B, input_dim_out].
        latent = self.backbone(img, mask)
        # The heads expect a sequence; use a sequence length of 1: [B, 1, input_dim_out].
        latent_seq = latent.unsqueeze(1)
        # The heads run unmasked: a mask sized for the backbone's token sequences would
        # fail Attention's length assertion on this length-1 sequence.
        heads = (self.event_classifier, self.group_classifier, self.music_classifier, self.eye_classifier)
        return tuple(head(latent_seq) for head in heads)

## Training

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

num_epochs = 10
criterion = nn.CrossEntropyLoss()

# NOTE: sampling_point, dim, depth, heads, mlp_dim, emb_dropout, dropout and pool are not
# defined in this cell; the Version 2 training cell later in this notebook sets them and
# must be run first. sampling_point must equal the time length of the dataset's tensors.
model_shared = MultiHeadfNIRSClassifier_Shared(sampling_point, dim, depth, heads, mlp_dim, emb_dropout=emb_dropout, dropout=dropout, pool=pool)
optimizer_shared = optim.Adam(model_shared.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    model_shared.train()
    running_loss = 0.0
    num_batches = 0
    # fNIRSDataset yields (data, event_label) only, so the labels are unpacked generically.
    # Each head with a matching label (order: event, group, music, eye) is supervised.
    for img, *labels in train_loader:
        optimizer_shared.zero_grad()
        outputs = model_shared(img)
        loss = sum(criterion(output, label) for output, label in zip(outputs, labels))
        loss.backward()
        optimizer_shared.step()
        running_loss += loss.item()
        num_batches += 1
    avg_loss = running_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

# Independent Model - Version 2

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange, repeat
from torch import einsum
from einops.layers.torch import Rearrange
import torch.optim as optim

# fNIRS_T Model (Original version with integrated classification head)

class fNIRS_T(nn.Module):
    """
    fNIRS-T model for classification.

    Input shape: [B, 2, fNIRS_channels, sampling_point] (2 = HbO/HbR).
    Output: logits of shape [B, n_class].

    Patch branch: Conv2d kernel (6, 100), stride (2, 30) -> floor((fNIRS_channels - 6) / 2) + 1 tokens.
    Channel branch: Conv2d kernel (1, 30), stride (1, 30) -> fNIRS_channels tokens.
    Each branch prepends a CLS token, adds learned positional embeddings, runs its own
    Transformer and is pooled ('mean' -> mean over tokens, otherwise the CLS token).
    The two pooled vectors (2 * dim) are classified by LayerNorm + Linear.

    NOTE(review): the same ``mask`` is passed to both transformers although their sequence
    lengths differ, so a non-None mask can match at most one branch.
    """
    def __init__(self, n_class, fNIRS_channels, sampling_point, dim, depth, heads, mlp_dim, pool='cls', dim_head=32, dropout=0., emb_dropout=0.):
        super().__init__()
        # Patch embedding branch. Kernel 100 / stride 30 samples (4 s / 1.2 s, assuming 25 Hz sampling).
        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=8, kernel_size=(6, 100), stride=(2, 30)),
            Rearrange('b c h w -> b h (c w)'),  # e.g. [B,2,24,4549] -> [B,8,10,149] -> [B,10,8*149=1192]
            nn.Linear((math.floor((sampling_point - 100) / 30) + 1) * 8, dim),  # e.g. 1192 -> dim
            nn.LayerNorm(dim)
        )
        # Channel embedding branch: one token per fNIRS channel.
        self.to_channel_embedding = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=8, kernel_size=(1, 30), stride=(1, 30)),
            Rearrange('b c h w -> b h (c w)'),
            nn.Linear((math.floor((sampling_point - 30) / 30) + 1) * 8, dim),
            nn.LayerNorm(dim)
        )
        # Number of patch tokens produced by the (6, ·) kernel with stride 2 along channels.
        num_patches = math.floor((fNIRS_channels - 6) / 2) + 1
        self.pos_embedding_patch = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token_patch = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout_patch = nn.Dropout(emb_dropout)
        self.transformer_patch = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # For channel embedding, the number of tokens is fNIRS_channels.
        num_channels = fNIRS_channels
        self.pos_embedding_channel = nn.Parameter(torch.randn(1, num_channels + 1, dim))
        self.cls_token_channel = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout_channel = nn.Dropout(emb_dropout)
        self.transformer_channel = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim * 2),
            nn.Linear(dim * 2, n_class)
        )

    def _encode(self, tokens, cls_token, pos_embedding, dropout, transformer, mask):
        """Prepend the CLS token, add positional embeddings, run the transformer and pool."""
        b, n, _ = tokens.shape
        if n + 1 > pos_embedding.shape[1]:
            raise ValueError(
                f"Got {n} tokens but positional embeddings cover only {pos_embedding.shape[1] - 1}; "
                "check fNIRS_channels against the input."
            )
        cls_tokens = repeat(cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, tokens), dim=1)
        x += pos_embedding[:, :(n + 1)]
        x = dropout(x)
        x = transformer(x, mask)
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
        return self.to_latent(x)

    def forward(self, img, mask=None):
        x = self.to_patch_embedding(img)
        # img is passed unchanged: the previous img.squeeze() was a no-op for B > 1 but
        # dropped the batch dimension when B == 1, which broke the 'b c h w' Rearrange.
        x2 = self.to_channel_embedding(img)
        x = self._encode(x, self.cls_token_patch, self.pos_embedding_patch,
                         self.dropout_patch, self.transformer_patch, mask)
        x2 = self._encode(x2, self.cls_token_channel, self.pos_embedding_channel,
                          self.dropout_channel, self.transformer_channel, mask)
        x3 = torch.cat((x, x2), 1)
        return self.mlp_head(x3)


## Training

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# fNIRS_T is defined in the previous cell; fNIRSClassifier below is a thin wrapper around it.
class fNIRSClassifier(nn.Module):
    """Thin wrapper that builds an ``fNIRS_T`` and delegates ``forward`` to it.

    The wrapped network is stored as ``self.model``, so state_dict keys carry a
    ``model.`` prefix.
    """

    def __init__(self, n_class, fNIRS_channels, sampling_point, dim, depth, heads, mlp_dim,
                 dim_head=32, dropout=0., emb_dropout=0., pool='cls'):
        super().__init__()
        config = dict(
            n_class=n_class,
            fNIRS_channels=fNIRS_channels,
            sampling_point=sampling_point,
            dim=dim,
            depth=depth,
            heads=heads,
            mlp_dim=mlp_dim,
            pool=pool,
            dim_head=dim_head,
            dropout=dropout,
            emb_dropout=emb_dropout,
        )
        self.model = fNIRS_T(**config)

    def forward(self, x, mask=None):
        return self.model(x, mask)

# Hyperparameters
num_epochs = 10
# NOTE(review): 1e-6 is unusually small for Adam; the logged run stayed at chance
# validation accuracy (33.33%) — consider tuning.
learning_rate = 1e-6
criterion = nn.CrossEntropyLoss()
# Data-dependent sizes come from the training dataset so the model always matches what
# fNIRSDataset yields (3 classes / 4549 points when the header has "S", otherwise
# 2 classes / 4801 points; 24 channels expected).
sampling_point = train_dataset.target_length
fNIRS_channels = train_dataset.num_channels
n_class = len(train_dataset.event_map)
dim = 128
depth = 3
heads = 4
dim_head = 32
mlp_dim = 128
dropout = 0.0
emb_dropout = 0.0
pool = 'cls'

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate model and move to device.
model = fNIRSClassifier(n_class=n_class,
                        fNIRS_channels=fNIRS_channels,
                        sampling_point=sampling_point,
                        dim=dim,
                        depth=depth,
                        heads=heads,
                        mlp_dim=mlp_dim,
                        dim_head=dim_head,
                        dropout=dropout,
                        emb_dropout=emb_dropout,
                        pool=pool)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

checkpoint_dir = "/content/event_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)


def _accuracy(net, loader, device):
    """Return top-1 accuracy (percent) of `net` on `loader`, or None if the loader is empty."""
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_img, batch_label in loader:
            batch_img = batch_img.to(device)
            batch_label = batch_label.to(device)
            _, predicted = torch.max(net(batch_img), 1)
            total += batch_label.size(0)
            correct += (predicted == batch_label).sum().item()
    return 100.0 * correct / total if total else None


# Training loop with validation evaluation.
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for img, label in train_loader:
        img = img.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        outputs = model(img)
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    avg_loss = running_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {avg_loss:.4f}")

    # Save checkpoint.
    checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}.pth")
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
    }, checkpoint_path)
    print(f"Saved checkpoint: {checkpoint_path}")

    # Evaluate on validation set.
    val_accuracy = _accuracy(model, val_loader, device)
    if val_accuracy is None:
        print("Validation Accuracy: n/a (empty validation set)")
    else:
        print(f"Validation Accuracy: {val_accuracy:.2f}%")

# Final evaluation on test set.
test_accuracy = _accuracy(model, test_loader, device)
if test_accuracy is None:
    print("Test Accuracy: n/a (empty test set)")
else:
    print(f"Test Accuracy: {test_accuracy:.2f}%")

Epoch 1/10, Training Loss: 1.1521
Saved checkpoint: /content/event_checkpoints/model_epoch_1.pth
Validation Accuracy: 33.33%


# Validation

In [None]:
# Inspect the first validation batch, from a single value up to the whole batch.
# (The loader is `val_loader`; `val_dataloader` was never defined.)
# inputs is [batch, 2 (HbO/HbR), channels, time] after DataLoader stacking.
first_batch = next(iter(val_loader), None)
if first_batch is not None:
    inputs, labels = first_batch
    views = (
        inputs[0][0][0][0],  # first time point of channel 0, HbO, sample 0
        inputs[0][0][0],     # time series of channel 0 (HbO) of sample 0
        inputs[0][0],        # HbO for all channels of sample 0
        inputs[0],           # HbO and HbR of sample 0 (1st index: subject-event pair, 2nd: HbO/HbR)
        inputs,              # the whole batch
    )
    for view in views:
        print(view.shape)
        print(view)

# Total values for 60 samples of shape [2, 24, 6801]: 2*24*6801*60
# = 19,586,880
'''
NOT USING ANYMORE
Each event is of 270 sec duration (after padding event S and event F)
F -> 270 (same padding technique as S), S->270 (179.92 secs original data other time segments consists of 0's) and H->270
And sampling freq is 25Hz i.e data is captured once every 0.04 secs
t1 (Onset), duration = 270, buffer_time=2 secs, [t1-2,t1+270)+1->data to extracted
# sampling points = 272*25+1 = 6800 + 1 = 6801
'''
# Analogy with image batches: an RGB image has 3 in the channel position, a black-and-white image has 1.
# [32,3,64,64] -> a batch of 32 images; 3 is the number of colour channels (red, green, blue)
# [3,64,64] -> a single image: 4096 pixels in each of the 3 colour channels
# [64,64] -> one channel's 64x64 grid of 4096 pixels

# [4,2,24,6801] -> a batch of 4 subject-event samples
# [2,24,6801] -> one subject-event sample (HbO, HbR)
# [24,6801] -> inside HbO or HbR: the time series of all 24 channels
# [6801] -> the time series of a single channel

Total samples (subject-event combinations): 12
Sampling points (from first channel sheet): 6801
Number of channels (HbO/HbR pairs): 24
fNIRS data shape: torch.Size([2, 24, 6801])
Event label (mapped): 0
torch.Size([6801])
tensor([0.0458, 0.0564, 0.0669,  ..., 0.0000, 0.0000, 0.0000])
torch.Size([24, 6801])
tensor([[ 0.0458,  0.0564,  0.0669,  ...,  0.0000,  0.0000,  0.0000],
        [-0.6289, -0.6088, -0.5889,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.2758,  0.2937,  0.3117,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 1.1112,  1.1138,  1.1169,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.0854,  1.1100,  1.1351,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.9804,  2.9705,  2.9626,  ...,  0.0000,  0.0000,  0.0000]])
torch.Size([2, 24, 6801])
tensor([[[ 0.0458,  0.0564,  0.0669,  ...,  0.0000,  0.0000,  0.0000],
         [-0.6289, -0.6088, -0.5889,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.2758,  0.2937,  0.3117,  ...,  0.0000,  0.0000,  0.0000],
         ...,
     

"\nEach event is of 270 sec duration (after padding event S and event F)\nF -> 270 (same padding technique as S), S->270 (179.92 secs original data other time segments consists of 0's) and H->270\nAnd sampling freq is 25Hz i.e data is captured once every 0.04 secs\nt1 (Onset), duration = 270, buffer_time=2 secs, [t1-2,t1+270)+1->data to extracted \n# sampling points = 270*25+1 = 6800 + 1 = 6801\n"

In [None]:
# -----------------------------
# Validation Code
# -----------------------------
# Evaluates the latest checkpoint written by the training cell on val_loader.
import os
import re

checkpoint_dir = "/content/event_checkpoints"  # same directory the training cell saves to
# Architecture hyperparameters must match the ones the checkpoint was trained with; these
# mirror the training cell, and data-dependent sizes come from the training dataset.
n_class = len(train_dataset.event_map)
sampling_point = train_dataset.target_length
fNIRS_channels = train_dataset.num_channels
dim = 128
depth = 3
heads = 4
dim_head = 32
mlp_dim = 128
dropout = 0.0
emb_dropout = 0.0
pool = 'cls'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _checkpoint_epoch(filename):
    """Return the epoch number of a 'model_epoch_<n>.pth' filename, else None."""
    match = re.fullmatch(r"model_epoch_(\d+)\.pth", filename)
    return int(match.group(1)) if match else None


# Load the checkpoint from the last saved epoch.
saved_epochs = [e for e in map(_checkpoint_epoch, os.listdir(checkpoint_dir)) if e is not None]
if not saved_epochs:
    raise FileNotFoundError(f"No model_epoch_*.pth checkpoints found in {checkpoint_dir}")
checkpoint_to_load = os.path.join(checkpoint_dir, f"model_epoch_{max(saved_epochs)}.pth")

model_val = fNIRSClassifier(n_class=n_class, fNIRS_channels=fNIRS_channels,
                            sampling_point=sampling_point, dim=dim,
                            depth=depth, heads=heads, mlp_dim=mlp_dim,
                            dim_head=dim_head, dropout=dropout, emb_dropout=emb_dropout, pool=pool)
model_val = model_val.to(device)
checkpoint = torch.load(checkpoint_to_load, map_location=device)
# The training cell saves a dict holding 'model_state_dict'; a bare state_dict is also accepted.
model_val.load_state_dict(checkpoint.get("model_state_dict", checkpoint))
model_val.eval()

correct = 0
total = 0

with torch.no_grad():
    for img, label in val_loader:
        img = img.to(device)
        label = label.to(device)
        outputs = model_val(img)
        _, predicted = torch.max(outputs, 1)
        print("Predicted labels:", predicted.cpu().tolist())
        print("True labels:", label.cpu().tolist())
        total += label.size(0)
        correct += (predicted == label).sum().item()

accuracy = 100 * correct / total if total else float("nan")
print(f"Validation Accuracy: {accuracy:.2f}%")

Predicted labels: [2, 2, 1, 2]
True labels: [0, 1, 1, 1]
Predicted labels: [2, 2, 2, 2]
True labels: [0, 0, 2, 2]
Predicted labels: [1, 2, 2, 2]
True labels: [2, 1, 2, 0]
Validation Accuracy: 33.33%
