# 3D Neural Network for fMRI Data Classification

## Data

### Dataset

In [11]:
import pandas as pd
import nibabel as nib
import numpy as np
from typing import Tuple

import torch
from torch.utils.data import Dataset
from torch.nn.functional import one_hot

class fMRIDataset(Dataset):
    """Dataset of fMRI volumes listed in a ``;``-separated label table.

    The table is read with pandas and must provide the columns
    ``ext_frmi_pths`` (path of the image file loaded with nibabel),
    ``trial_ids``, ``trial_types``, ``subjects`` and ``runs``.

    Args:
        data_path: Path of the ``;``-separated CSV label table.
        selected_subjects: Keep only rows whose ``subjects`` value is listed.
        selected_runs: Keep only rows whose ``runs`` value is listed.
        selected_classes: Keep only the listed classes. A list of ``str`` is
            matched against ``trial_types``; a list of integers (Python or
            NumPy) is matched against ``trial_ids``.
        encode_labels: Return one-hot labels instead of the raw trial id.
        exclude_end: Drop rows whose ``trial_types`` is ``'end'``.
        exclude_blank: Drop rows whose ``trial_types`` is ``'blank'``.
        exclude_scrambled: Drop rows whose ``trial_types`` is ``'scrambled'``.
        normalize_data: Standardise each loaded volume with its own global
            mean and standard deviation.
        time_first: If True the loaded array is permuted with ``(3, 0, 1, 2)``,
            otherwise with ``(2, 3, 0, 1)``.

    Attributes:
        data_info: The filtered label table (index reset to ``0..len-1``).
        class_ids: Sorted unique trial ids, computed after the ``exclude_*``
            filters but before the subject/run/class selection.
        subject_ids: Unique subjects remaining after all filters.
        run_ids: Unique runs remaining after all filters.
    """

    def __init__(
        self,
        data_path: str,
        selected_subjects: list = None,
        selected_runs: list = None,
        selected_classes: list = None,
        encode_labels: bool = True,
        exclude_end: bool = True,
        exclude_blank: bool = False,
        exclude_scrambled: bool = False,
        normalize_data: bool = True,
        time_first: bool = True
    ):
        self.selected_subjects = selected_subjects
        self.selected_runs = selected_runs
        self.selected_classes = selected_classes
        self.encode_labels = encode_labels
        self.exclude_end = exclude_end
        self.exclude_blank = exclude_blank
        self.exclude_scrambled = exclude_scrambled
        self.normalize_data = normalize_data
        self.time_first = time_first
        self.data_info, self.class_ids, self.subject_ids, self.run_ids = self.__load_data(data_path)

    def __len__(self) -> int:
        return len(self.data_info.index)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load one volume and its label.

        Returns:
            ``(fmri_data, fmri_label)``. The label is repeated once per entry
            of the first axis of ``fmri_data``.

        Raises:
            IndexError: If ``index`` is outside the label table.
        """
        try:
            data_record = self.data_info.iloc[index]
        except IndexError as err:
            raise IndexError(
                f'Sample index {index} is out of range for a dataset of length {len(self)}'
            ) from err

        image = nib.load(data_record['ext_frmi_pths'])
        fmri_data = torch.from_numpy(image.get_fdata()).float()
        # NOTE(review): presumably the file stores (X, Y, Z, T), so (3, 0, 1, 2)
        # yields (T, X, Y, Z) - confirm against the preprocessing output.
        if self.time_first:
            fmri_data = fmri_data.permute(3, 0, 1, 2).contiguous()
        else:
            fmri_data = fmri_data.permute(2, 3, 0, 1).contiguous()

        fmri_label = self._encode_label(data_record['trial_ids'])

        if self.normalize_data:
            std = fmri_data.std()
            # A constant volume has zero std; divide by 1 instead of producing NaN/inf.
            std = torch.where(std > 0, std, torch.ones_like(std))
            fmri_data = (fmri_data - fmri_data.mean()) / std

        # One label row per entry of the leading axis of the volume.
        fmri_label = fmri_label.repeat(fmri_data.shape[0], 1)

        return fmri_data, fmri_label

    def _encode_label(self, trial_id) -> torch.Tensor:
        """Turn a raw trial id into the label tensor for one sample.

        With ``encode_labels`` the result is a one-hot float vector whose hot
        position is the rank of ``trial_id`` inside the sorted ``class_ids``.
        For ids ``0..n-1`` this equals the id itself, and it keeps working when
        ids are not contiguous or do not start at zero. Without
        ``encode_labels`` the raw id is returned as a float scalar tensor.

        Raises:
            ValueError: If ``trial_id`` is not one of ``class_ids``.
        """
        if not self.encode_labels:
            return torch.from_numpy(np.array(trial_id)).float()

        position = int(np.searchsorted(self.class_ids, trial_id))
        if position >= len(self.class_ids) or self.class_ids[position] != trial_id:
            raise ValueError(f'Trial id {trial_id} is not among the known class ids')
        return one_hot(torch.tensor(position), num_classes=len(self.class_ids)).float()

    def __load_data(self, data_path: str) -> Tuple[pd.DataFrame, np.ndarray, np.ndarray, np.ndarray]:
        """Read the label table and apply every configured filter.

        Returns:
            ``(data, class_ids, subject_ids, run_ids)``.

        Raises:
            AttributeError: If ``selected_classes`` holds values that are
                neither ``str`` nor integers.
        """
        data = pd.read_csv(data_path, sep=';')

        excluded_types = [
            trial_type
            for trial_type, excluded in (
                ('end', self.exclude_end),
                ('blank', self.exclude_blank),
                ('scrambled', self.exclude_scrambled),
            )
            if excluded
        ]
        if excluded_types:
            data = data.loc[~data['trial_types'].isin(excluded_types)].reset_index(drop=True)

        # Taken before the subject/run/class selection so the label encoding
        # does not depend on which subset is selected.
        class_ids = np.unique(data['trial_ids']).astype(np.longlong)

        if self.selected_subjects is not None:
            data = data.loc[data['subjects'].isin(self.selected_subjects)].reset_index(drop=True)
        if self.selected_runs is not None:
            data = data.loc[data['runs'].isin(self.selected_runs)].reset_index(drop=True)
        if self.selected_classes is not None:
            selected = list(self.selected_classes)
            if selected and isinstance(selected[0], str):
                column = 'trial_types'
            elif not selected or isinstance(selected[0], (int, np.integer)):
                # NumPy integers (e.g. values taken from ``class_ids``) count as ints;
                # an empty selection simply selects nothing.
                column = 'trial_ids'
            else:
                raise AttributeError('Not supported class types! Only a list of str or int values can be given!')
            data = data.loc[data[column].isin(selected)].reset_index(drop=True)

        subject_ids = np.unique(data['subjects'])
        run_ids = np.unique(data['runs'])

        return data, class_ids, subject_ids, run_ids

### Dataset Splitter

In [12]:
import numpy as np
from torch.utils.data import Subset

def split_data_set(
    dataset: "fMRIDataset",
    train_test_ratio: float = 0.8,
    val_ratio: float = 0.1,
    shuffle: bool = False,
    seed: int = 42,
):
    """Split a dataset into train/test(/validation) subsets, stratified per class and subject.

    For every ``(class_id, subject_id)`` pair the matching rows are cut at two
    boundaries: the first ``train_test_ratio - val_ratio`` share goes to
    training, the share up to ``train_test_ratio`` goes to validation, and the
    remainder goes to testing.

    Args:
        dataset: Object exposing ``data_info`` (DataFrame with ``trial_ids``
            and ``subjects`` columns), ``class_ids`` and ``subject_ids``.
        train_test_ratio: Share of each group that is not used for testing.
        val_ratio: Share of each group used for validation (taken out of the
            training share). ``0`` disables the validation subset.
        shuffle: Shuffle each group before cutting it.
        seed: Seed of the ``RandomState`` used when ``shuffle`` is set.

    Returns:
        ``(train_set, test_set, val_set)`` when ``val_ratio > 0``, otherwise
        ``(train_set, test_set)``.

    Raises:
        ValueError: If the ratios do not satisfy
            ``0 <= val_ratio <= train_test_ratio <= 1``.
    """
    if not 0 <= val_ratio <= train_test_ratio <= 1:
        raise ValueError(
            'Ratios must satisfy 0 <= val_ratio <= train_test_ratio <= 1, '
            f'got train_test_ratio={train_test_ratio}, val_ratio={val_ratio}'
        )

    use_val = val_ratio > 0
    rnd_state = np.random.RandomState(seed=seed) if shuffle else None

    didf = dataset.data_info
    # Plain boolean masks work for numeric and string ids alike; an f-string
    # DataFrame.query would need quoting for strings.
    trial_column = didf['trial_ids'].to_numpy()
    subject_column = didf['subjects'].to_numpy()

    train_parts, test_parts, val_parts = [], [], []
    for class_id in dataset.class_ids:
        class_mask = trial_column == class_id
        for subject_id in dataset.subject_ids:
            # Row positions, because the dataset indexes its table with ``iloc``.
            selection = np.flatnonzero(class_mask & (subject_column == subject_id))
            if shuffle:
                rnd_state.shuffle(selection)

            train_end = int(len(selection) * (train_test_ratio - val_ratio))
            test_start = int(len(selection) * train_test_ratio)

            train_parts.append(selection[:train_end])
            test_parts.append(selection[test_start:])
            if use_val:
                val_parts.append(selection[train_end:test_start])

    def _merge(parts):
        # np.concatenate rejects an empty list (no classes or no subjects).
        return np.concatenate(parts, axis=-1) if parts else np.empty(0, dtype=np.intp)

    train_set = Subset(dataset, _merge(train_parts))
    test_set = Subset(dataset, _merge(test_parts))
    if use_val:
        return train_set, test_set, Subset(dataset, _merge(val_parts))
    return train_set, test_set

## Model

### Architecture

#### EfficientNet

In [13]:
# Number of input channels fed to the first convolution of the EfficientNet
# defined in the next cell. Alternative value kept for reference:
# INPUT_CHANS = 79
INPUT_CHANS = 10

In [14]:
import copy
import math
import warnings
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch import nn, Tensor
from torchvision.ops import StochasticDepth

from torchvision.ops.misc import Conv3dNormActivation
from torchvision.transforms._presets import ImageClassification, InterpolationMode


def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class SqueezeExcitation3D(torch.nn.Module):
    """
    3D Squeeze-and-Excitation block (https://arxiv.org/abs/1709.01507, Fig. 1).

    The input is pooled to a single value per channel, passed through two
    1x1x1 convolutions and the resulting per-channel gate is multiplied onto
    the input. ``activation`` and ``scale_activation`` play the roles of
    ``delta`` and ``sigma`` in eq. 3 of the paper.

    Args:
        input_channels (int): Number of channels of the input tensor
        squeeze_channels (int): Number of channels of the bottleneck
        activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU``
        scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid``
    """

    def __init__(
        self,
        input_channels: int,
        squeeze_channels: int,
        activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
        scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
    ) -> None:
        super().__init__()
        # Registration order is kept: it fixes the module/parameter traversal order.
        self.avgpool = torch.nn.AdaptiveAvgPool3d(output_size=1)
        self.fc1 = torch.nn.Conv3d(input_channels, squeeze_channels, kernel_size=1)
        self.fc2 = torch.nn.Conv3d(squeeze_channels, input_channels, kernel_size=1)
        self.activation = activation()
        self.scale_activation = scale_activation()

    def _scale(self, input: Tensor) -> Tensor:
        """Compute the per-channel gate for ``input``."""
        squeezed = self.fc1(self.avgpool(input))
        excited = self.fc2(self.activation(squeezed))
        return self.scale_activation(excited)

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` re-weighted channel-wise by its gate."""
        return self._scale(input) * input


@dataclass
class _MBConvConfig:
    expand_ratio: float
    kernel: int
    stride: int
    input_channels: int
    out_channels: int
    num_layers: int
    block: Callable[..., nn.Module]

    @staticmethod
    def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
        return _make_divisible(channels * width_mult, 8, min_value)


class MBConvConfig(_MBConvConfig):
    """Stage settings for MBConv blocks.

    Holds the values listed in Table 1 of the EfficientNet paper and Table 4 of
    the EfficientNetV2 paper. Channel counts are scaled by ``width_mult`` and
    the number of layers by ``depth_mult`` before being stored.
    """

    def __init__(
        self,
        expand_ratio: float,
        kernel: int,
        stride: int,
        input_channels: int,
        out_channels: int,
        num_layers: int,
        width_mult: float = 1.0,
        depth_mult: float = 1.0,
        block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        scaled_input = self.adjust_channels(input_channels, width_mult)
        scaled_output = self.adjust_channels(out_channels, width_mult)
        scaled_depth = self.adjust_depth(num_layers, depth_mult)
        # MBConv is the default building block of this configuration.
        chosen_block = MBConv if block is None else block
        super().__init__(
            expand_ratio=expand_ratio,
            kernel=kernel,
            stride=stride,
            input_channels=scaled_input,
            out_channels=scaled_output,
            num_layers=scaled_depth,
            block=chosen_block,
        )

    @staticmethod
    def adjust_depth(num_layers: int, depth_mult: float) -> int:
        """Scale ``num_layers`` by ``depth_mult``, rounding up."""
        scaled = num_layers * depth_mult
        return int(math.ceil(scaled))


class FusedMBConvConfig(_MBConvConfig):
    """Stage settings for FusedMBConv blocks (Table 4 of the EfficientNetV2 paper).

    Unlike ``MBConvConfig`` the values are stored as given, without width or
    depth scaling.
    """

    def __init__(
        self,
        expand_ratio: float,
        kernel: int,
        stride: int,
        input_channels: int,
        out_channels: int,
        num_layers: int,
        block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        # FusedMBConv is the default building block of this configuration.
        chosen_block = FusedMBConv if block is None else block
        super().__init__(
            expand_ratio=expand_ratio,
            kernel=kernel,
            stride=stride,
            input_channels=input_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            block=chosen_block,
        )


class MBConv(nn.Module):
    """3D inverted-residual block: expand -> depthwise -> squeeze-excite -> project.

    Args:
        cnf: Stage configuration of the block.
        stochastic_depth_prob: Drop probability of the stochastic-depth layer
            applied on the residual branch.
        norm_layer: Factory of the normalization layer used in every conv unit.
        se_layer: Factory of the squeeze-and-excitation layer.

    Raises:
        ValueError: If ``cnf.stride`` is not within ``[1, 2]``.
    """

    def __init__(
        self,
        cnf: MBConvConfig,
        stochastic_depth_prob: float,
        norm_layer: Callable[..., nn.Module],
        se_layer: Callable[..., nn.Module] = SqueezeExcitation3D,
    ) -> None:
        super().__init__()

        if cnf.stride < 1 or cnf.stride > 2:
            raise ValueError("illegal stride value")

        # The skip connection is only valid when the block keeps the tensor shape.
        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        # Conv + norm + SiLU unit shared by the expand and depthwise steps.
        conv_norm_silu = partial(Conv3dNormActivation, norm_layer=norm_layer, activation_layer=nn.SiLU)
        expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)

        steps: List[nn.Module] = []

        # Pointwise expansion, skipped when the width does not change.
        if expanded_channels != cnf.input_channels:
            steps.append(conv_norm_silu(cnf.input_channels, expanded_channels, kernel_size=1))

        # Depthwise convolution (one group per channel).
        steps.append(
            conv_norm_silu(
                expanded_channels,
                expanded_channels,
                kernel_size=cnf.kernel,
                stride=cnf.stride,
                groups=expanded_channels,
            )
        )

        # Squeeze and excitation; the bottleneck width derives from the block input width.
        se_bottleneck = max(1, cnf.input_channels // 4)
        steps.append(se_layer(expanded_channels, se_bottleneck, activation=partial(nn.SiLU, inplace=True)))

        # Linear projection to the output width (no activation).
        steps.append(
            Conv3dNormActivation(
                expanded_channels,
                cnf.out_channels,
                kernel_size=1,
                norm_layer=norm_layer,
                activation_layer=None,
            )
        )

        self.block = nn.Sequential(*steps)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.out_channels = cnf.out_channels

    def forward(self, input: Tensor) -> Tensor:
        """Apply the block, adding the residual connection when it is enabled."""
        out = self.block(input)
        if not self.use_res_connect:
            return out
        out = self.stochastic_depth(out)
        out += input
        return out


class FusedMBConv(nn.Module):
    """3D fused inverted-residual block.

    The expansion and the spatial convolution are fused into one regular
    convolution, followed by a linear projection. When no expansion is needed
    the block is a single conv unit.

    Args:
        cnf: Stage configuration of the block.
        stochastic_depth_prob: Drop probability of the stochastic-depth layer
            applied on the residual branch.
        norm_layer: Factory of the normalization layer used in every conv unit.

    Raises:
        ValueError: If ``cnf.stride`` is not within ``[1, 2]``.
    """

    def __init__(
        self,
        cnf: FusedMBConvConfig,
        stochastic_depth_prob: float,
        norm_layer: Callable[..., nn.Module],
    ) -> None:
        super().__init__()

        if cnf.stride < 1 or cnf.stride > 2:
            raise ValueError("illegal stride value")

        # The skip connection is only valid when the block keeps the tensor shape.
        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        # Spatial conv + norm + SiLU unit used by both variants of the block.
        spatial_conv = partial(
            Conv3dNormActivation,
            kernel_size=cnf.kernel,
            stride=cnf.stride,
            norm_layer=norm_layer,
            activation_layer=nn.SiLU,
        )
        expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)

        steps: List[nn.Module]
        if expanded_channels == cnf.input_channels:
            # No expansion: one conv unit maps the input directly to the output width.
            steps = [spatial_conv(cnf.input_channels, cnf.out_channels)]
        else:
            steps = [
                # fused expand
                spatial_conv(cnf.input_channels, expanded_channels),
                # project (no activation)
                Conv3dNormActivation(
                    expanded_channels,
                    cnf.out_channels,
                    kernel_size=1,
                    norm_layer=norm_layer,
                    activation_layer=None,
                ),
            ]

        self.block = nn.Sequential(*steps)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.out_channels = cnf.out_channels

    def forward(self, input: Tensor) -> Tensor:
        """Apply the block, adding the residual connection when it is enabled."""
        out = self.block(input)
        if not self.use_res_connect:
            return out
        out = self.stochastic_depth(out)
        out += input
        return out


class EfficientNet(nn.Module):
    def __init__(
        self,
        inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
        dropout: float,
        stochastic_depth_prob: float = 0.2,
        num_classes: int = 1000,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        last_channel: Optional[int] = None,
        in_channels: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """
        3D EfficientNet V1 and V2 main class

        Args:
            inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure
            dropout (float): The dropout probability
            stochastic_depth_prob (float): The stochastic depth probability
            num_classes (int): Number of classes
            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use.
                Defaults to ``nn.BatchNorm3d``.
            last_channel (int): The number of channels on the penultimate layer. Defaults to four times
                the output channels of the last stage.
            in_channels (Optional[int]): Number of channels of the network input. Defaults to the
                module-level ``INPUT_CHANS`` constant.
            **kwargs: Only the deprecated ``block`` entry is inspected; other entries are ignored.

        Raises:
            ValueError: If ``inverted_residual_setting`` is empty or ``in_channels`` is not positive.
            TypeError: If ``inverted_residual_setting`` is not a sequence of stage configurations.
        """
        super().__init__()

        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        elif not (
            isinstance(inverted_residual_setting, Sequence)
            and all(isinstance(s, _MBConvConfig) for s in inverted_residual_setting)
        ):
            raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")

        if "block" in kwargs:
            warnings.warn(
                "The parameter 'block' is deprecated since 0.13 and will be removed 0.15. "
                "Please pass this information on 'MBConvConfig.block' instead."
            )
            if kwargs["block"] is not None:
                for s in inverted_residual_setting:
                    if isinstance(s, MBConvConfig):
                        s.block = kwargs["block"]

        if norm_layer is None:
            norm_layer = nn.BatchNorm3d

        # Fall back to the notebook-level constant so existing callers keep their behaviour.
        if in_channels is None:
            in_channels = INPUT_CHANS
        if in_channels < 1:
            raise ValueError(f"in_channels must be positive, got {in_channels}")
        self.in_channels = in_channels

        layers: List[nn.Module] = []

        # building first layer
        firstconv_output_channels = inverted_residual_setting[0].input_channels
        layers.append(
            Conv3dNormActivation(
                in_channels, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
            )
        )

        # building inverted residual blocks
        total_stage_blocks = sum(cnf.num_layers for cnf in inverted_residual_setting)
        stage_block_id = 0
        for cnf in inverted_residual_setting:
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # copy to avoid modifications. shallow copy is enough
                block_cnf = copy.copy(cnf)

                # overwrite info if not the first conv in the stage
                if stage:
                    block_cnf.input_channels = block_cnf.out_channels
                    block_cnf.stride = 1

                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks

                stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer))
                stage_block_id += 1

            layers.append(nn.Sequential(*stage))

        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels
        layers.append(
            Conv3dNormActivation(
                lastconv_input_channels,
                lastconv_output_channels,
                kernel_size=1,
                norm_layer=norm_layer,
                activation_layer=nn.SiLU,
            )
        )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool3d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout, inplace=True),
            nn.Linear(lastconv_output_channels, num_classes),
        )

        # Weight initialisation: Kaiming-normal for convolutions, identity for
        # normalization layers, uniform in +-1/sqrt(out_features) for linear layers.
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                init_range = 1.0 / math.sqrt(m.out_features)
                nn.init.uniform_(m.weight, -init_range, init_range)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run the feature extractor, global average pooling and the classifier."""
        x = self.features(x)

        x = self.avgpool(x)
        # Collapse the pooled (N, C, 1, 1, 1) tensor to (N, C) for the linear layer.
        x = torch.flatten(x, 1)

        x = self.classifier(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Return the class logits for the input batch ``x``."""
        return self._forward_impl(x)


def _efficientnet_conf(
    arch: str,
    **kwargs: Any,
) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]:
    """Return the stage configuration and penultimate width of an architecture.

    Args:
        arch: Architecture name. Matched by prefix against ``efficientnet_b``,
            ``efficientnet_v2_s``, ``efficientnet_v2_m`` and ``efficientnet_v2_l``.
        **kwargs: ``width_mult`` and ``depth_mult`` are required for the
            ``efficientnet_b*`` family.

    Returns:
        ``(inverted_residual_setting, last_channel)``; ``last_channel`` is
        ``None`` for the ``efficientnet_b*`` family.

    Raises:
        ValueError: If ``arch`` matches none of the supported prefixes.
    """
    if arch.startswith("efficientnet_b"):
        width_mult = kwargs.pop("width_mult")
        depth_mult = kwargs.pop("depth_mult")
        scaled_stage = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult)
        # Columns: expand_ratio, kernel, stride, input_channels, out_channels, num_layers
        b_stages = [
            scaled_stage(1, 3, 1, 32, 16, 1),
            scaled_stage(6, 3, 2, 16, 24, 2),
            scaled_stage(6, 5, 2, 24, 40, 2),
            scaled_stage(6, 3, 2, 40, 80, 3),
            scaled_stage(6, 5, 1, 80, 112, 3),
            scaled_stage(6, 5, 2, 112, 192, 4),
            scaled_stage(6, 3, 1, 192, 320, 1),
        ]
        return b_stages, None

    if arch.startswith("efficientnet_v2_s"):
        v2_s_stages = [
            FusedMBConvConfig(1, 3, 1, 24, 24, 2),
            FusedMBConvConfig(4, 3, 2, 24, 48, 4),
            FusedMBConvConfig(4, 3, 2, 48, 64, 4),
            MBConvConfig(4, 3, 2, 64, 128, 6),
            MBConvConfig(6, 3, 1, 128, 160, 9),
            MBConvConfig(6, 3, 2, 160, 256, 15),
        ]
        return v2_s_stages, 1280

    if arch.startswith("efficientnet_v2_m"):
        v2_m_stages = [
            FusedMBConvConfig(1, 3, 1, 24, 24, 3),
            FusedMBConvConfig(4, 3, 2, 24, 48, 5),
            FusedMBConvConfig(4, 3, 2, 48, 80, 5),
            MBConvConfig(4, 3, 2, 80, 160, 7),
            MBConvConfig(6, 3, 1, 160, 176, 14),
            MBConvConfig(6, 3, 2, 176, 304, 18),
            MBConvConfig(6, 3, 1, 304, 512, 5),
        ]
        return v2_m_stages, 1280

    if arch.startswith("efficientnet_v2_l"):
        v2_l_stages = [
            FusedMBConvConfig(1, 3, 1, 32, 32, 4),
            FusedMBConvConfig(4, 3, 2, 32, 64, 7),
            FusedMBConvConfig(4, 3, 2, 64, 96, 7),
            MBConvConfig(4, 3, 2, 96, 192, 10),
            MBConvConfig(6, 3, 1, 192, 224, 19),
            MBConvConfig(6, 3, 2, 224, 384, 25),
            MBConvConfig(6, 3, 1, 384, 640, 7),
        ]
        return v2_l_stages, 1280

    raise ValueError(f"Unsupported model type {arch}")


def _efficientnet(
    inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
    dropout: float,
    last_channel: Optional[int],
    weights: object,
    progress: bool,
    **kwargs: Any,
) -> EfficientNet:
    """Build an ``EfficientNet`` and optionally load weights into it.

    Args:
        inverted_residual_setting: Stage configuration of the network.
        dropout: Dropout probability of the classifier.
        last_channel: Width of the penultimate layer, or ``None`` for the default.
        weights: ``None`` or an object whose ``get_state_dict(progress=...)``
            result is loaded into the model.
        progress: Forwarded to ``weights.get_state_dict``.
        **kwargs: Forwarded to the ``EfficientNet`` constructor.

    Returns:
        The constructed model.
    """
    net = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)
    if weights is None:
        return net

    state_dict = weights.get_state_dict(progress=progress)
    net.load_state_dict(state_dict)
    return net


def efficientnet_b0(*, weights: object = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """Build the 3D EfficientNet-B0 variant.

    Architecture from `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_, using width and depth
    multipliers of 1.0 and a classifier dropout of 0.2.

    Args:
        weights: Optional object providing ``get_state_dict(progress=...)``;
            when given, its state dict is loaded into the model. By default no
            weights are loaded.
        progress (bool, optional): Forwarded to ``weights.get_state_dict``.
            Default is True.
        **kwargs: Parameters passed on to the ``EfficientNet`` class defined in
            this file (e.g. ``num_classes``).

    Returns:
        The constructed ``EfficientNet``.
    """
    stage_settings, penultimate_width = _efficientnet_conf(
        "efficientnet_b0", width_mult=1.0, depth_mult=1.0
    )
    return _efficientnet(
        stage_settings,
        0.2,
        penultimate_width,
        weights,
        progress,
        **kwargs,
    )


def efficientnet_v2_s(*, weights: object = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """Build the 3D EfficientNetV2-S variant.

    Architecture from `EfficientNetV2: Smaller Models and Faster Training
    <https://arxiv.org/abs/2104.00298>`_, using a classifier dropout of 0.2 and
    ``BatchNorm3d`` with ``eps=1e-03`` as normalization layer.

    Args:
        weights: Optional object providing ``get_state_dict(progress=...)``;
            when given, its state dict is loaded into the model. By default no
            weights are loaded.
        progress (bool, optional): Forwarded to ``weights.get_state_dict``.
            Default is True.
        **kwargs: Parameters passed on to the ``EfficientNet`` class defined in
            this file (e.g. ``num_classes``).

    Returns:
        The constructed ``EfficientNet``.
    """
    stage_settings, penultimate_width = _efficientnet_conf("efficientnet_v2_s")
    batch_norm = partial(nn.BatchNorm3d, eps=1e-03)
    return _efficientnet(
        stage_settings,
        0.2,
        penultimate_width,
        weights,
        progress,
        norm_layer=batch_norm,
        **kwargs,
    )

#### Own model

In [15]:
import torch
from torch import nn

class SimpleNet(nn.Module):
    """Linear baseline: one fully connected layer on the flattened, standardised volume.

    Args:
        in_features: Number of values per flattened sample. The default
            ``44 * 64 * 64`` is the size this notebook was written for.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features: int = 44 * 64 * 64, num_classes: int = 2):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(in_features, num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify every entry of the first two axes of ``x`` independently.

        Args:
            x: Tensor with at least two axes. The first two are merged into
                the batch axis and the rest is flattened to ``in_features``.

        Returns:
            Logits of shape ``(x.shape[0] * x.shape[1], num_classes)``.
        """
        # reshape (unlike view) also accepts non-contiguous input.
        x = x.reshape(x.shape[0] * x.shape[1], -1)
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        # Rows without variance would otherwise turn into NaN/inf.
        std = torch.where(std > 0, std, torch.ones_like(std))
        x = (x - mean) / std
        return self.classifier(x)


## Framework

### Configurations

In [16]:
import os
from datetime import datetime

# Keyword arguments of fMRIDataset(...).
DATASET_CFG = {
    # Placeholder path of the ';'-separated label table; must be replaced before running.
    'data_path': r"...\data\ds002338\PreProcessed\Trials\XP1\labels.csv",  # need to be set
    # None disables the corresponding subject/run/class filter.
    'selected_subjects': None,
    'selected_runs': None,
    'selected_classes': None,
    'encode_labels': True,
    'exclude_end': False,
    'exclude_blank': False,
    'exclude_scrambled': False,
    'normalize_data': False,
    'time_first': True
}

# Keyword arguments of split_data_set(...). With these values each
# (class, subject) group is cut into the first 60% for training, the next 20%
# for validation and the last 20% for testing.
DATA_SPLIT_CFG = {
    'train_test_ratio': 0.8,
    'val_ratio': 0.2,
    'shuffle': True,
    'seed': 42,
}

# Keyword arguments shared by the train, test and validation DataLoader.
DATA_LOAD_CFG = {
    'batch_size': 16,
    'shuffle': True,
    'num_workers': 0,
    'drop_last': False,
}

# Settings read by the training loop.
RUN_CFG = {
    'epochs': 1000,
    'learning_rate': 1e-4,
    # A checkpoint is written every n-th epoch (and whenever the validation loss improves).
    'save_every_n_epoch': 10,
    # Training loss/accuracy are logged every n-th batch.
    'log_every_n_step': 1,
    # Use CUDA whenever it is available (relies on torch imported in an earlier cell).
    'on_gpu': torch.cuda.is_available(),
}

### Setup dataset

In [17]:
from torch.utils.data import DataLoader

# Build the dataset and cut it into subsets. split_data_set returns three
# subsets when a validation share is configured and two otherwise.
dataset = fMRIDataset(**DATASET_CFG)
with_validation = DATA_SPLIT_CFG['val_ratio'] > 0
subsets = split_data_set(dataset, **DATA_SPLIT_CFG)
if with_validation:
    train_set, test_set, val_set = subsets
else:
    train_set, test_set = subsets

# One loader per subset, all sharing DATA_LOAD_CFG; val_loader is None when
# no validation subset exists.
train_loader = DataLoader(dataset=train_set, **DATA_LOAD_CFG)
test_loader = DataLoader(dataset=test_set, **DATA_LOAD_CFG)
if with_validation:
    val_loader = DataLoader(dataset=val_set, **DATA_LOAD_CFG)
else:
    val_loader = None

### Setup model and utils

In [18]:
from torch.utils.tensorboard import SummaryWriter

# Model under training. Alternative backbone kept for reference:
# model = efficientnet_v2_s(weights=None, pretrained=False, num_classes=len(dataset.class_ids))
model = SimpleNet()
run_on_gpu = RUN_CFG['on_gpu']
if run_on_gpu:
    # Parameters are moved before the optimizer is created from them.
    model.cuda()

# Plain SGD plus binary cross-entropy. The training loop feeds the criterion
# softmax probabilities, not raw logits.
learning_rate = RUN_CFG['learning_rate']
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
criterion = torch.nn.BCELoss()

### Training loop

#### Evaluation utils

In [19]:
from sklearn.metrics import roc_curve, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

def get_conf_mat(predictions, labels):
    """Plot a confusion matrix extended with recall, precision and accuracy.

    Args:
        predictions: CPU tensor of per-class scores, shape ``(N, C)``.
        labels: CPU tensor of one-hot (or score) targets, shape ``(N, C)``.
            The class of each row is its ``argmax`` over the last axis.

    Returns:
        A new matplotlib figure holding a ``(C + 1) x (C + 1)`` heatmap: the
        count matrix, a trailing column with the recall of every true class,
        a trailing row with the precision of every predicted class and the
        overall accuracy in the bottom-right corner. The heatmap axes are also
        stored on the figure as ``fig.ax``.
    """
    num_classes = labels.shape[-1]
    class_labels = list(range(num_classes))
    counts = confusion_matrix(
        labels.argmax(-1).numpy(),
        predictions.argmax(-1).numpy(),
        labels=class_labels,
        normalize=None,
    )

    correct = np.diag(counts).astype(float)
    row_totals = counts.sum(axis=1)
    col_totals = counts.sum(axis=0)
    # Masked division: classes without samples/predictions get 0 instead of
    # triggering a divide-by-zero warning.
    recall = np.divide(correct, row_totals, out=np.zeros(num_classes), where=row_totals > 0)
    precision = np.divide(correct, col_totals, out=np.zeros(num_classes), where=col_totals > 0)
    total = counts.sum()
    accuracy = correct.sum() / total if total > 0 else 0.0

    table = np.concatenate((counts, recall[:, None]), axis=-1)
    summary_row = np.concatenate((precision, [accuracy]))[None, :]
    table = np.concatenate((table, summary_row), axis=0)

    # Draw on a dedicated figure/axes instead of whatever pyplot considers current.
    fig, ax = plt.subplots()
    xlabels = class_labels + ['Recall']
    ylabels = class_labels + ['Precision']
    fig.ax = sns.heatmap(table, annot=True, cmap="YlGnBu", linewidths=.2,
                         xticklabels=xlabels, yticklabels=ylabels, ax=ax)
    ax.set_title('Confusion Matrix')
    ax.set_xlabel('Predicted label')
    ax.set_ylabel('True label')

    return fig

#### Main loop

In [20]:
# from copy import deepcopy

# train_balances = {'0': 0,'1': 0,}
# for (_, label) in deepcopy(train_loader):
#     train_balances[str(label.argmax(-1).item())] += 1

# val_balances = {'0': 0,'1': 0,}
# for (_, label) in deepcopy(val_loader):
#     val_balances[str(label.argmax(-1).item())] += 1

# print(train_balances)
# print(val_balances)

In [None]:
from tqdm.notebook import tqdm

# Training / validation driver. Logs to TensorBoard under ./logs/<timestamp>
# and writes checkpoints into the same directory.
logger = SummaryWriter(os.path.join(os.getcwd(), 'logs', datetime.now().strftime('%y%m%d_%H%M%S')))
epochs = RUN_CFG['epochs']
use_cuda = RUN_CFG['on_gpu']
epoch_train_length = len(train_loader)
# Validation is optional: val_loader is None when no validation split was configured.
num_val_batches = len(val_loader) if val_loader is not None else 0
best_val_loss = float('inf')

for epoch in range(epochs):
    # Training part
    model.train()
    epoch_loss_train = 0.
    with tqdm(total=epoch_train_length, leave=True) as pbar:
        for bidx, (fmri, label) in enumerate(train_loader):
            if use_cuda:
                fmri, label = fmri.cuda(), label.cuda()
            # Merge the batch axis and the per-sample label axis: one label row per prediction row.
            label = label.view(-1, label.shape[-1]).contiguous()

            optimizer.zero_grad()
            prediction = torch.softmax(model(fmri), dim=-1)
            loss = criterion(prediction, label)
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.)
            optimizer.step()

            if (bidx + 1) % RUN_CFG['log_every_n_step'] == 0:
                step = epoch * epoch_train_length + bidx + 1
                batch_acc = (prediction.argmax(-1) == label.argmax(-1)).float().mean().item()
                logger.add_scalar('Loss/Train', loss.item(), step)
                logger.add_scalar('Accuracy/Train', batch_acc, step)

            epoch_loss_train += loss.item()
            pbar.set_postfix({
                'Mode': 'Train',
                'Epoch': f'{epoch + 1}/{epochs}',
                'AvgLoss': epoch_loss_train / (bidx + 1)
            })
            pbar.update(1)

    # Validation part
    is_best = False
    if num_val_batches > 0:
        model.eval()
        epoch_loss_eval = 0.
        correct = 0
        seen = 0
        predictions = []
        labels = []

        with torch.no_grad(), tqdm(total=num_val_batches, leave=True) as pbar:
            for bidx, (fmri, label) in enumerate(val_loader):
                if use_cuda:
                    fmri, label = fmri.cuda(), label.cuda()
                label = label.view(-1, label.shape[-1]).contiguous()

                prediction = torch.softmax(model(fmri), dim=-1)
                # The criterion already averages over the batch; dividing by the
                # batch size again would make this incomparable with Loss/Train.
                loss = criterion(prediction, label)

                correct += (prediction.argmax(-1) == label.argmax(-1)).sum().item()
                seen += len(label)
                predictions.append(prediction.cpu())
                labels.append(label.cpu())

                epoch_loss_eval += loss.item()
                pbar.set_postfix({
                    'Mode': 'Validation',
                    'Epoch': f'{epoch + 1}/{epochs}',
                    'AvgLoss': epoch_loss_eval / (bidx + 1)
                })
                pbar.update(1)

        predictions = torch.cat(predictions, dim=0)
        labels = torch.cat(labels, dim=0)
        avg_val_loss = epoch_loss_eval / num_val_batches
        # Sample-weighted accuracy, so a smaller last batch does not skew the mean.
        avg_acc = correct / seen if seen > 0 else 0.
        logger.add_figure('Confusion Matrix/Validation', get_conf_mat(predictions, labels), epoch + 1)
        logger.add_scalar('Loss/Validation', avg_val_loss, epoch + 1)
        logger.add_scalar('Accuracy/Validation', avg_acc, epoch + 1)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            is_best = True

    # Checkpoint on every improvement of the validation loss and every n-th epoch.
    if is_best or (epoch + 1) % RUN_CFG['save_every_n_epoch'] == 0:
        file_name = 'model_best_val.pth' if is_best else f'model_epoch{epoch}.pth'
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss': epoch_loss_train,
            'epoch': epoch},
            os.path.join(logger.log_dir, file_name)
        )

# Flush pending events to disk.
logger.close()