## Resources

* [OpenMind Wiki](https://github.mit.edu/MGHPCC/OpenMind/wiki)
* [OpenMind Website](https://openmind.mit.edu)
* [OpenMind Jupyter Notebook Tutorial](https://github.mit.edu/MGHPCC/OpenMind/wiki/How-to-use-Jupyter-Notebook-on-OpenMind%3F)
* [Accusleep Dataset](https://osf.io/py5eb/)

## Package Installation

In [None]:
# Install the packages listed in requirements.txt using the %pip notebook magic.
%pip install -r requirements.txt

In [None]:
# Print the hostname of the machine running this kernel so it can be checked
# against the intended node.
!hostname

## Imports

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import accelerate
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from sklearn.metrics import classification_report, balanced_accuracy_score, accuracy_score, confusion_matrix
from sklearn.preprocessing import RobustScaler
from scipy.io import loadmat
from scipy.signal import resample
from mne.filter import filter_data
import matplotlib.pyplot as plt
import wandb
import os
from pathlib import Path
from tqdm import tqdm
from time import time
from types import SimpleNamespace

## Helper Functions/Classes

In [None]:
def downsample(x, sf, new_sf):
    """Resample ``x`` along its last axis from ``sf`` Hz to ``new_sf`` Hz.

    Args:
        x (np.ndarray): signal array; the last axis is resampled.
        sf (float): current sampling frequency.
        new_sf (float): target sampling frequency.

    Returns:
        np.ndarray: ``x`` resampled with ``scipy.signal.resample`` so that its
        last axis has ``floor(x.shape[-1] * new_sf / sf)`` samples.
    """
    # Multiply before dividing: ``new_sf / sf * n`` can land just below the
    # exact integer (e.g. 29 / 100 * 100 == 28.999...) and truncation would
    # then drop a sample.
    num = int(x.shape[-1] * new_sf / sf)
    return resample(x, num, axis=-1)


class AttrDict(dict):
    """Dictionary whose items are also reachable as attributes."""

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Use the dict itself as the instance namespace, so ``d.key`` and
        # ``d["key"]`` read and write the same storage.
        self.__dict__ = self


class SleepDataset(Dataset):
    """Dataset class for EEG/EMG sleep data from the Accusleep Dataset (https://osf.io/py5eb/).

    Supports "homemade" data of the same file/directory format. Each item is
    a ``(2, context_window * window_size)`` tensor stacking the EEG and EMG
    excerpts that end at a labelled epoch, together with that epoch's label.

    Attributes:
        eegs, emgs, labels: per-recording lists of 1-D tensors.
        index_map: array of ``(recording index, epoch index)`` rows, one per
            dataset item.
        weight: per-class weight ``1 - class_count / total_count`` over the
            three label classes.
    """

    def __init__(
        self,
        data_dir,
        raw_sf,
        model_sf,
        raw_epoch_len,
        model_epoch_len,
        homemade,
        *,
        run_preproc=False,
        bandpass_freqs=None,
        context_window=1,
        eeg_idx=0,
        emg_idx=5,
        random_shift=False,
        eeg_transform=None,
        emg_transform=None,
        target_transform=None,
    ):
        """Load (and optionally preprocess) every recording found under ``data_dir``.

        Args:
            data_dir: root directory; every sub-directory at any depth is
                tried as a recording, or ``data_dir`` itself if it has none.
            raw_sf: sampling frequency of the raw recordings.
            model_sf: sampling frequency of the preprocessed signals.
            raw_epoch_len: duration of one labelled epoch in the raw data.
            model_epoch_len: duration of one epoch fed to the model.
            homemade (bool): homemade file layout instead of Accusleep.
            run_preproc (bool): regenerate the ``*_preproc.npy`` files first.
            bandpass_freqs: optional ``(low, high)`` passed to ``filter_data``.
            context_window (int): number of consecutive epochs per item (>= 1).
            eeg_idx, emg_idx (int): rows of the homemade array to use.
            random_shift (bool): shift each excerpt forward by a random
                number of samples smaller than one epoch.
            eeg_transform, emg_transform, target_transform: optional callables
                applied to the EEG excerpt, EMG excerpt and label.

        Raises:
            FileNotFoundError: if no directory contains the expected files.
        """
        # Build the paths
        all_dirs = sorted(p for p in Path(data_dir).glob("**/*") if p.is_dir())
        if not all_dirs:
            all_dirs = [Path(data_dir)]
        assert (
            float(raw_sf * raw_epoch_len).is_integer()
            and float(model_sf * model_epoch_len).is_integer()
        )

        # Save attributes
        self.window_size = int(model_sf * model_epoch_len)
        self.raw_sf = raw_sf
        self.model_sf = model_sf
        self.raw_epoch_len = raw_epoch_len
        self.model_epoch_len = model_epoch_len
        assert context_window >= 1
        self.context_window = context_window
        self.random_shift = random_shift
        self.eeg_transform = eeg_transform
        self.emg_transform = emg_transform
        self.target_transform = target_transform

        self.eegs = []
        self.emgs = []
        self.labels = []
        index_map = []
        dir_map = []
        weight = torch.zeros(3, dtype=torch.long)
        i_map = 0

        # Load data
        for rec_dir in all_dirs:
            try:
                if run_preproc:
                    self.preprocess(
                        rec_dir, homemade, raw_sf, model_sf, bandpass_freqs
                    )
                eeg, emg, label = self.get_preproc_files(
                    rec_dir, homemade, eeg_idx, emg_idx
                )
            except FileNotFoundError:
                # Directories without the expected files are not recordings.
                continue
            self.eegs.append(eeg)
            self.emgs.append(emg)
            self.labels.append(label)

            # Count labels for class weighting. bincount keeps one slot per
            # class even when a recording lacks a class; unique() would
            # return fewer counts and break (or mis-broadcast) the sum.
            weight += torch.bincount(label, minlength=weight.numel())

            # Mapping from "global index" to list and tensor indices. The
            # first context_window - 1 epochs have no full history and are
            # therefore never used as targets.
            dir_map.append(np.repeat(i_map, len(label) - self.context_window + 1))
            index_map.append(np.arange(self.context_window - 1, len(label)))
            i_map += 1
        if len(dir_map) == 0:
            raise FileNotFoundError(
                f"No preprocessed recordings found under {str(data_dir)!r}"
            )

        # Build index mapping
        self.weight = 1 - weight / weight.sum()
        dir_map = np.concatenate(dir_map)
        index_map = np.concatenate(index_map)
        self.index_map = np.column_stack((dir_map, index_map))

    @staticmethod
    def _require_files(*paths):
        """Raise FileNotFoundError unless every path is an existing file."""
        missing = [str(p) for p in paths if not os.path.isfile(p)]
        if missing:
            raise FileNotFoundError(f"Missing data file(s): {', '.join(missing)}")

    def get_preproc_files(self, data_dir, homemade, eeg_idx, emg_idx):
        """Retrieves preprocessed EEG, EMG, and label data

        Args:
            data_dir (str): directory in which data is stored
            homemade (bool): homemade or Accusleep data
            eeg_idx (int): desired eeg probe index in the homemade data array
            emg_idx (int): desired emg probe index in the homemade data array

        Raises:
            FileNotFoundError: raised if any of the expected ``*_preproc.npy``
                files does not exist within data_dir

        Returns:
            eeg, emg, label (torch.Tensor, torch.Tensor, torch.Tensor): float
            EEG and EMG signals trimmed to ``len(label) * window_size``
            samples, and the long label tensor
        """
        label_file = os.path.join(data_dir, "labels_preproc.npy")
        if homemade:
            eeg_emg_file = os.path.join(data_dir, "EEG-EMG_preproc.npy")
            self._require_files(eeg_emg_file, label_file)
            # load files
            eeg_emg = np.load(eeg_emg_file)
            label = np.load(label_file)

            eeg, emg = eeg_emg[eeg_idx], eeg_emg[emg_idx]
        else:
            eeg_file = os.path.join(data_dir, "EEG_preproc.npy")
            emg_file = os.path.join(data_dir, "EMG_preproc.npy")
            self._require_files(eeg_file, emg_file, label_file)
            # load files
            eeg = np.load(eeg_file)
            emg = np.load(emg_file)
            label = np.load(label_file)

        # Check lengths
        assert eeg.shape[0] == emg.shape[0]
        num_eeg_samples = eeg.shape[0] / self.window_size
        assert num_eeg_samples >= len(label)

        # Drop trailing samples that do not belong to a labelled epoch.
        eeg = eeg[: len(label) * self.window_size]
        emg = emg[: len(label) * self.window_size]
        return (
            torch.tensor(eeg, dtype=torch.float),
            torch.tensor(emg, dtype=torch.float),
            torch.tensor(label, dtype=torch.long),
        )

    def preprocess(
        self,
        data_dir,
        homemade,
        raw_sf,
        model_sf,
        bandpass_freqs,
    ):
        """Create the ``*_preproc.npy`` files of one recording directory.

        The raw signals are optionally band-pass filtered, resampled from
        ``raw_sf`` to ``model_sf``, robust-scaled and, when the raw and model
        epoch lengths differ, trimmed to a whole number of model epochs. The
        labels are remapped with ``accusleep_dict`` and resampled to the
        model epoch length.

        Args:
            data_dir: directory holding the raw files (``EEG-EMG.npy`` and
                ``labels.npy`` for homemade data, otherwise ``EEG.mat``,
                ``EMG.mat`` and ``labels.mat``).
            homemade (bool): homemade or Accusleep data.
            raw_sf: sampling frequency of the raw signals.
            model_sf: sampling frequency to resample to.
            bandpass_freqs: optional ``(low, high)`` passed to ``filter_data``.

        Raises:
            FileNotFoundError: if any raw file is missing from data_dir.
        """
        resample_labels = self.raw_epoch_len != self.model_epoch_len

        if homemade:
            eeg_emg_file = os.path.join(data_dir, "EEG-EMG.npy")
            label_file = os.path.join(data_dir, "labels.npy")
            self._require_files(eeg_emg_file, label_file)
            # load files
            eeg_emg = np.load(eeg_emg_file)
            label = np.load(label_file)

            if bandpass_freqs:
                eeg_emg = filter_data(
                    eeg_emg.astype("float64"),
                    raw_sf,
                    bandpass_freqs[0],
                    bandpass_freqs[1],
                    verbose=0,
                )
            if raw_sf != model_sf:
                eeg_emg = downsample(eeg_emg, raw_sf, model_sf)

            # Scale EEG and EMG (RobustScaler scales columns, hence the
            # transposes around it).
            eeg_emg = RobustScaler().fit_transform(eeg_emg.T).T  # type: ignore

            num_samples = int(eeg_emg.shape[1] // (model_sf * self.raw_epoch_len))
            assert num_samples == len(label)

            if resample_labels:
                label_new_len = int(
                    len(label) * self.raw_epoch_len // self.model_epoch_len
                )
                # Time is the second axis here (see shape[1] above), so trim
                # along it; slicing the first axis would drop rows instead.
                eeg_emg = eeg_emg[
                    :, : int(label_new_len * model_sf * self.model_epoch_len)
                ]

            np.save(os.path.join(data_dir, "EEG-EMG_preproc.npy"), eeg_emg)
        else:
            eeg_file = os.path.join(data_dir, "EEG.mat")
            emg_file = os.path.join(data_dir, "EMG.mat")
            label_file = os.path.join(data_dir, "labels.mat")
            self._require_files(eeg_file, emg_file, label_file)
            # load files
            eeg = np.squeeze(loadmat(eeg_file)["EEG"])
            emg = np.squeeze(loadmat(emg_file)["EMG"])
            label = np.squeeze(loadmat(label_file)["labels"])

            assert eeg.shape[0] == emg.shape[0]

            if bandpass_freqs:
                eeg = filter_data(
                    eeg.astype("float64"),
                    raw_sf,
                    bandpass_freqs[0],
                    bandpass_freqs[1],
                    verbose=0,
                )
                emg = filter_data(
                    emg.astype("float64"),
                    raw_sf,
                    bandpass_freqs[0],
                    bandpass_freqs[1],
                    verbose=0,
                )
            if raw_sf != model_sf:
                eeg = downsample(eeg, raw_sf, model_sf)
                emg = downsample(emg, raw_sf, model_sf)

            # Scale EEG and EMG
            eeg = RobustScaler().fit_transform(eeg[:, np.newaxis]).squeeze()  # type: ignore
            emg = RobustScaler().fit_transform(emg[:, np.newaxis]).squeeze()  # type: ignore

            num_samples = int(eeg.shape[0] // (model_sf * self.raw_epoch_len))
            assert num_samples == len(label)

            if resample_labels:
                label_new_len = int(
                    len(label) * self.raw_epoch_len // self.model_epoch_len
                )
                keep = int(label_new_len * model_sf * self.model_epoch_len)
                eeg = eeg[:keep]
                emg = emg[:keep]

            np.save(os.path.join(data_dir, "EEG_preproc.npy"), eeg)
            np.save(os.path.join(data_dir, "EMG_preproc.npy"), emg)

        accusleep_dict = {
            1: 2,  # "R",
            2: 0,  # "W",
            3: 1,  # "N"
        }
        # re-map the label values
        label_df = pd.DataFrame({"label": label})
        label_df["label"] = label_df["label"].map(accusleep_dict)
        label = label_df["label"].values

        if resample_labels:
            # Each model epoch takes the label of the raw epoch at the
            # (rounded) matching position.
            label_new = np.zeros(label_new_len, dtype=int)  # type: ignore
            last_raw = len(label) - 1
            for i in range(len(label_new)):
                src = int(round(i * self.model_epoch_len / self.raw_epoch_len))
                # Rounding can point one past the final raw epoch; clamp it.
                label_new[i] = label[min(src, last_raw)]
            label = label_new

        np.save(os.path.join(data_dir, "labels_preproc.npy"), label)  # type: ignore

    def __len__(self):
        """Return the number of (recording, epoch) items in the dataset."""
        return self.index_map.shape[0]

    def __getitem__(self, idx):
        """Return ``(x, label)`` for item ``idx``.

        ``x`` stacks the EEG and EMG excerpts covering ``context_window``
        epochs ending at the target epoch; ``label`` belongs to that epoch.
        """
        dir_idx, adjusted_idx = self.index_map[idx]
        label = self.labels[dir_idx][adjusted_idx]
        # The last epoch of a recording is never shifted: there are no
        # samples after it to shift into.
        if self.random_shift and adjusted_idx < len(self.labels[dir_idx]) - 1:
            offset = np.random.randint(self.window_size)
        else:
            offset = 0
        low = (adjusted_idx - self.context_window + 1) * self.window_size + offset
        high = (adjusted_idx + 1) * self.window_size + offset
        eeg = self.eegs[dir_idx][low:high]
        emg = self.emgs[dir_idx][low:high]
        if self.eeg_transform:
            eeg = self.eeg_transform(eeg)
        if self.emg_transform:
            emg = self.emg_transform(emg)
        if self.target_transform:
            label = self.target_transform(label)
        x = torch.stack((eeg, emg), dim=0)
        return x, label

## Parameters

In [None]:
# Single configuration object shared by the data, model and training cells.
params = AttrDict(
    # System
    threads_per_gpu=1,
    gpus=4,

    # Data
    train_data_dir="data/24-hour_recordings",
    test_data_dir="data/4-hour_recordings",
    sample_data_dir="data/4-hour_recordings/Mouse01/Day1",
    homemade_data_dir="data/homemade_recordings",
    pretrain_dir="pretrain",
    finetune_dir="finetune",
    homemade_eeg_idx=4,
    homemade_emg_idx=5,
    model_sf=128,
    accusleep_sf=512,  # Hz
    homemade_sf=1000,  # Hz
    bandpass_freqs=[1, 64],
    run_preproc=False,
    homemade_epoch_len=2.0,  # sec
    accusleep_epoch_len=2.5,  # sec
    model_epoch_len=2.5,  # sec
    context_window=3,
    random_shift=False,
    eeg_transform=None,
    emg_transform=None,
    target_transform=None,

    # Model
    in_dim=2,
    out_dim=3,
    embed_dim=32,
    feedforward_dim=128,
    do_pos_embed=True,
    pos_embed_dim=16,
    dim_blocks=[16, 32, 32, 32],
    res_blocks=[True, True, True, True],
    attn_blocks=[False, False, True, True],
    num_heads=8,
    kernel_size=3,
    activation=nn.GELU,
    dropout=0.0,
    log_wandb=True,

    # Pretraining
    criterion=nn.CrossEntropyLoss,
    optimizer=optim.Adam,
    lr=0.001,
    num_epochs=30,
    batch_size=16,
    log_freq=50,
    val_batches=50,
    use_weight=True,

    # Finetuning
    num_epochs_ft=100,
    ft_last_block=False,
    log_freq_ft=10,
)

# Create the output directories, tolerating ones that already exist.
for _out_dir in (params.pretrain_dir, params.finetune_dir):
    try:
        os.mkdir(_out_dir)
    except FileExistsError:
        pass

## Data

### Accusleep Data

In [None]:
# Plot a short excerpt of the preprocessed EEG and EMG of the sample Accusleep
# recording. The path and sampling rate come from `params` so this cell stays
# in sync with the configuration instead of repeating the literals.
eeg = np.load(os.path.join(params.sample_data_dir, "EEG_preproc.npy"))
emg = np.load(os.path.join(params.sample_data_dir, "EMG_preproc.npy"))
i = 100  # excerpt start, in units of model_sf samples (seconds if the file is at model_sf Hz)
window_size = 5  # excerpt length, same units as i
window = slice(params.model_sf * i, int(params.model_sf * (i + window_size)))
plt.plot(eeg[window])
plt.title("EEG")
plt.show()
plt.plot(emg[window])
plt.title("EMG")
plt.show()

### Homemade Data

In [None]:
# Plot a short excerpt of the configured EEG and EMG rows of the preprocessed
# homemade recording. The path and sampling rate come from `params` so this
# cell stays in sync with the configuration instead of repeating the literals.
eeg_emg = np.load(os.path.join(params.homemade_data_dir, "EEG-EMG_preproc.npy"))
i = 20  # excerpt start, in units of model_sf samples (seconds if the file is at model_sf Hz)
window_size = 5  # excerpt length, same units as i
window = slice(params.model_sf * i, int(params.model_sf * (i + window_size)))
plt.plot(eeg_emg[params.homemade_eeg_idx, window])
plt.title("EEG")
plt.show()
plt.plot(eeg_emg[params.homemade_emg_idx, window])
plt.title("EMG")
plt.show()

In [None]:
# Build a dataset and loader over the single sample recording; one batch from
# it is used to record the model's input shape in `params`.
sample_data = SleepDataset(
    data_dir=params.sample_data_dir,
    raw_sf=params.accusleep_sf,
    model_sf=params.model_sf,
    raw_epoch_len=params.accusleep_epoch_len,
    model_epoch_len=params.model_epoch_len,
    homemade=False,
    run_preproc=params.run_preproc,
    bandpass_freqs=params.bandpass_freqs,
    context_window=params.context_window,
    random_shift=params.random_shift,
    eeg_transform=params.eeg_transform,
    emg_transform=params.emg_transform,
    target_transform=params.target_transform,
)
sample_loader = DataLoader(
    dataset=sample_data,
    batch_size=params.batch_size,
    shuffle=False,
    pin_memory=False,
)
# Shape of the inputs of the first batch: (batch, channels, samples).
params["input_shape"] = next(iter(sample_loader))[0].shape

In [None]:
def get_pretrain_dataloaders(params):
    """Build the data loaders used for pretraining.

    Args:
        params: configuration object (see the Parameters cell) providing the
            data directories, sampling rates, epoch lengths and loader
            settings read below.

    Returns:
        tuple: ``(train_loader, val_loader, homemade_test_loader, weight)``
        where ``weight`` is the class weight tensor of the training dataset.
    """
    # Options shared by every dataset built here.
    dataset_kwargs = dict(
        run_preproc=params.run_preproc,
        bandpass_freqs=params.bandpass_freqs,
        context_window=params.context_window,
        random_shift=params.random_shift,
        eeg_transform=params.eeg_transform,
        emg_transform=params.emg_transform,
        target_transform=params.target_transform,
    )
    accusleep_args = (
        params.accusleep_sf,
        params.model_sf,
        params.accusleep_epoch_len,
        params.model_epoch_len,
        False,
    )

    # Datasets
    train_data = SleepDataset(
        params.train_data_dir, *accusleep_args, **dataset_kwargs
    )
    # NOTE(review): validation reads sample_data_dir (a single recording)
    # while params.test_data_dir is never used; confirm this is intended.
    test_data = SleepDataset(
        params.sample_data_dir, *accusleep_args, **dataset_kwargs
    )
    homemade_test_data = SleepDataset(
        params.homemade_data_dir,
        params.homemade_sf,
        params.model_sf,
        params.homemade_epoch_len,
        params.model_epoch_len,
        True,
        eeg_idx=9,
        emg_idx=10,
        **dataset_kwargs,
    )

    # DataLoaders
    loader_kwargs = dict(
        pin_memory=torch.cuda.is_available(),
        num_workers=params.threads_per_gpu,
    )
    train_loader = DataLoader(
        train_data, params.batch_size, shuffle=True, **loader_kwargs
    )
    # Validate on a fixed random subset. Sampling without replacement raises
    # if more items are requested than exist, so cap the subset size.
    num_val = min(len(test_data), params.val_batches * params.batch_size)
    indices = np.random.choice(len(test_data), num_val, replace=False)
    val_sampler = SubsetRandomSampler(indices)
    val_loader = DataLoader(
        test_data, params.batch_size, sampler=val_sampler, **loader_kwargs
    )
    homemade_test_loader = DataLoader(
        homemade_test_data, params.batch_size, shuffle=False, **loader_kwargs
    )

    return train_loader, val_loader, homemade_test_loader, train_data.weight

In [None]:
def get_finetune_dataloaders(params):
    """Build the homemade-data loaders used for fine-tuning.

    Both datasets read ``params.homemade_data_dir``; the training set uses the
    rows ``params.homemade_eeg_idx`` / ``params.homemade_emg_idx`` and the
    test set uses rows 9 / 10.

    Returns:
        tuple: ``(homemade_train_loader, homemade_test_loader, weight)`` where
        ``weight`` is the class weight tensor of the training dataset.
    """
    # Positional arguments common to both datasets.
    homemade_args = (
        params.homemade_data_dir,
        params.homemade_sf,
        params.model_sf,
        params.homemade_epoch_len,
        params.model_epoch_len,
        True,
    )

    def build_dataset(eeg_idx, emg_idx):
        return SleepDataset(
            *homemade_args,
            run_preproc=params.run_preproc,
            bandpass_freqs=params.bandpass_freqs,
            context_window=params.context_window,
            eeg_idx=eeg_idx,
            emg_idx=emg_idx,
            random_shift=params.random_shift,
            eeg_transform=params.eeg_transform,
            emg_transform=params.emg_transform,
            target_transform=params.target_transform,
        )

    def build_loader(dataset, shuffle):
        return DataLoader(
            dataset,
            params.batch_size,
            shuffle=shuffle,
            pin_memory=torch.cuda.is_available(),
            num_workers=params.threads_per_gpu,
        )

    # Datasets
    homemade_train_data = build_dataset(
        params.homemade_eeg_idx, params.homemade_emg_idx
    )
    homemade_test_data = build_dataset(9, 10)

    # DataLoaders
    homemade_train_loader = build_loader(homemade_train_data, True)
    homemade_test_loader = build_loader(homemade_test_data, False)

    return homemade_train_loader, homemade_test_loader, homemade_train_data.weight

## 1D Convolutional Attention Model

In [None]:
class ResidualBlock(nn.Module):
    """Pre-activation 1-D convolutional residual block.

    Applies (LayerNorm over channels -> activation -> Conv1d) twice and adds a
    shortcut of the input, projected with a 1x1 convolution when the channel
    count changes. Inputs are indexed ``(batch, channels, length)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, activation):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, padding="same")
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, padding="same")
        self.activation = activation()
        if in_channels == out_channels:
            self.downsample = nn.Identity()
        else:
            # 1x1 convolution so the shortcut matches the output channels.
            self.downsample = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        self.norm1 = nn.LayerNorm(in_channels)
        self.norm2 = nn.LayerNorm(out_channels)

    @staticmethod
    def _norm_over_channels(norm, x):
        # LayerNorm normalises the last axis, so move channels there and back.
        return norm(x.transpose(1, 2)).transpose(1, 2)

    def forward(self, x):
        shortcut = self.downsample(x)
        out = self.conv1(self.activation(self._norm_over_channels(self.norm1, x)))
        out = self.conv2(self.activation(self._norm_over_channels(self.norm2, out)))
        return out + shortcut


class FeedForward(nn.Module):
    """Two-layer perceptron with an optional LayerNorm on its input.

    Computes ``linear2(activation(linear1(norm(x))))`` where ``norm`` is a
    LayerNorm over ``in_dim`` when ``do_ln`` is true and the identity otherwise.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, do_ln, activation):
        super().__init__()
        if do_ln:
            self.norm = nn.LayerNorm(in_dim)
        else:
            self.norm = nn.Identity()
        self.linear1 = nn.Linear(in_dim, hidden_dim)
        self.activation = activation()
        self.linear2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        hidden = self.activation(self.linear1(self.norm(x)))
        return self.linear2(hidden)


def positional_encoding(length, depth):
    """Return a ``(length, depth)`` float tensor of sinusoidal position codes.

    The first half of the columns holds sines and the second half cosines of
    ``position / 10000 ** (k / half)``; for an odd ``depth`` the table is
    built one column wider and the last cosine column is cut off.
    """
    # Number of frequencies: depth / 2, rounded up for odd depths.
    half = (depth + 1) // 2

    pos = np.arange(length).reshape(-1, 1)  # (length, 1)
    exponents = np.arange(half).reshape(1, -1) / half  # (1, half)
    angles = pos / np.power(10000, exponents)

    table = np.concatenate((np.sin(angles), np.cos(angles)), axis=-1)
    table = table[:, :depth]

    return torch.as_tensor(table, dtype=torch.float32)


class PositionalEmbedding(nn.Module):
    """Add a learned transformation of a fixed sinusoidal code to the input.

    The fixed ``(1, seq_length, out_dim)`` encoding is stored as a
    non-trainable parameter, passed through a small feed-forward network and
    added to the input, which is indexed ``(batch, out_dim, seq_length)``.
    """

    def __init__(self, seq_length, out_dim, ff_dim, activation):
        super().__init__()
        pos_encoding = positional_encoding(seq_length, out_dim)[None, :, :]
        self.pos_encoding = nn.Parameter(pos_encoding, requires_grad=False)
        self.ffn = FeedForward(out_dim, ff_dim, out_dim, False, activation)

    def forward(self, x):
        # Out-of-place add: ``x.transpose`` is a view of the caller's tensor,
        # so an in-place ``+=`` would silently modify the input batch (and
        # fails outright on a leaf tensor that requires grad).
        x = x.transpose(1, 2) + self.ffn(self.pos_encoding)
        return x.transpose(1, 2)


class AttentionBlock(nn.Module):
    """Pre-norm self-attention block followed by a feed-forward block.

    The input is indexed ``(batch, in_dim, length)``. Features are projected
    to ``embed_dim`` for multi-head attention and back to ``in_dim``; both
    sub-blocks have residual connections.
    """

    def __init__(
        self, in_dim, embed_dim, num_heads, feedforward_dim, activation, dropout
    ):
        super().__init__()
        self.qkv = nn.Linear(in_dim, embed_dim * 3)
        self.attention = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.feed_forward = FeedForward(
            in_dim, feedforward_dim, in_dim, True, activation
        )
        self.proj = (
            nn.Linear(embed_dim, in_dim) if in_dim != embed_dim else nn.Identity()
        )
        # Kept for interface compatibility; not applied in forward().
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(in_dim)
        # norm2 runs after the projection back to in_dim, so it must
        # normalise in_dim features. Sizing it with embed_dim only worked
        # when the two dimensions happened to be equal.
        self.norm2 = nn.LayerNorm(in_dim)

    def forward(self, x):
        x = x.transpose(1, 2)  # (batch, length, in_dim)

        # Self-attention sub-block
        res = x
        x = self.norm1(x)
        q, k, v = self.qkv(x).chunk(3, -1)
        x, _ = self.attention(q, k, v, need_weights=False)
        x = self.proj(x) + res

        # Feed-forward sub-block
        res = x
        x = self.norm2(x)
        x = self.feed_forward(x) + res
        return x.transpose(1, 2)


class Model(torch.nn.Module):
    """1-D convolutional attention classifier.

    The network is an optional positional embedding, a stack of blocks (each
    an optional ``ResidualBlock`` followed by an optional ``AttentionBlock``)
    and a two-layer head applied to the flattened features.

    Args:
        input_shape: shape of an input batch; only the last entry (sequence
            length) is used.
        in_dim: number of input channels.
        out_dim: number of output classes.
        embed_dim: attention embedding size.
        feedforward_dim: hidden size of the attention feed-forward network.
        do_pos_embed: whether to add the positional embedding.
        pos_embed_dim: hidden size of the positional embedding network.
        dim_blocks: output channel count of each block.
        res_blocks: per-block flag (or one bool for all blocks) enabling the
            residual convolution.
        attn_blocks: per-block flag (or one bool for all blocks) enabling
            attention.
        num_heads: number of attention heads.
        kernel_size: convolution kernel size.
        activation: activation class, instantiated where needed.
        dropout: attention dropout probability.
        **params: ignored; lets the whole config be splatted in.
    """

    def __init__(
        self,
        input_shape,
        in_dim,
        out_dim,
        embed_dim,
        feedforward_dim,
        do_pos_embed,
        pos_embed_dim,
        dim_blocks,
        res_blocks,
        attn_blocks,
        num_heads,
        kernel_size,
        activation,
        dropout,
        **params
    ):
        super().__init__()
        seq_length = input_shape[-1]
        self.num_blocks = len(dim_blocks)
        # Broadcast scalar flags BEFORE validating lengths; checking first
        # would call len() on a bool and make the scalar form unusable.
        res_blocks = (
            res_blocks
            if isinstance(res_blocks, list)
            else [res_blocks] * self.num_blocks
        )
        attn_blocks = (
            attn_blocks
            if isinstance(attn_blocks, list)
            else [attn_blocks] * self.num_blocks
        )
        assert len(dim_blocks) == len(attn_blocks) == len(res_blocks)

        self.pos_embed = (
            PositionalEmbedding(seq_length, in_dim, pos_embed_dim, activation)
            if do_pos_embed
            else nn.Identity()
        )

        self.blocks = nn.ModuleList()
        for i, (dim_block, do_res, do_attn) in enumerate(
            zip(dim_blocks, res_blocks, attn_blocks)
        ):
            # A block with neither part would be a no-op.
            assert do_res or do_attn
            block = nn.Module()
            block.res = (
                ResidualBlock(
                    in_dim if i == 0 else dim_blocks[i - 1],
                    dim_block,
                    kernel_size,
                    activation,
                )
                if do_res
                else nn.Identity()
            )
            # NOTE(review): without the residual part the channel count is
            # not changed, so an attention-only block presumably needs
            # dim_block to equal the previous block's width - confirm.
            block.attn = (
                AttentionBlock(
                    dim_block,
                    embed_dim,
                    num_heads,
                    feedforward_dim,
                    activation,
                    dropout,
                )
                if do_attn
                else nn.Identity()
            )
            self.blocks.append(block)

        # Classification head over the flattened (channels * length) features.
        self.out_block = nn.Sequential(
            nn.Linear(seq_length * dim_blocks[-1], 64),
            activation(),
            nn.Linear(64, out_dim),
        )

    def forward(self, x):
        x = self.pos_embed(x)
        for block in self.blocks:
            x = block.res(x)
            x = block.attn(x)
        return self.out_block(x.flatten(1))

## Training

In [None]:
def _eval_pass(model, loader, criterion):
    """Return the mean per-batch loss and accuracy of ``model`` over ``loader``."""
    total_loss = 0
    total_acc = 0
    with torch.no_grad():
        for x, y in loader:
            pred = model(x)
            total_loss += criterion(pred, y)
            total_acc += (pred.argmax(-1) == y).float().mean()
    return total_loss / len(loader), total_acc / len(loader)


def pretrain(params):
    """Pretrain ``Model`` on the pretraining data loaders.

    Every epoch trains on the training loader, evaluates on the validation
    and homemade test loaders, logs metrics through the accelerator's
    trackers when ``params.log_wandb`` is set, and saves the accelerator
    state under ``params.pretrain_dir``.

    Args:
        params: configuration object (see the Parameters cell); it is also
            splatted into ``Model`` and passed as the tracker config.
    """
    torch.set_num_threads(1)
    project_config = ProjectConfiguration(project_dir=os.path.join(os.getcwd(), params.pretrain_dir), automatic_checkpoint_naming=True)
    accelerator = Accelerator(log_with="wandb", project_config=project_config)

    if params.log_wandb:
        accelerator.init_trackers(
            project_name="Sleep Staging",
            config=params
        )

    model = Model(**params)
    optimizer = params.optimizer(model.parameters(), params.lr)
    train_loader, val_loader, homemade_test_loader, weight = get_pretrain_dataloaders(params)
    criterion = params.criterion(weight=weight.to(accelerator.device) if params.use_weight else None)

    train_loader, val_loader, homemade_test_loader, model, optimizer = accelerator.prepare(
        train_loader, val_loader, homemade_test_loader, model, optimizer)

    for epoch in range(params.num_epochs):
        model.train()
        total_loss = 0
        total_acc = 0
        avg_loss = 0
        avg_acc = 0
        if params.log_wandb:
            accelerator.log(dict(epoch = epoch))

        bar = tqdm(train_loader, desc=(
            f"Training | Epoch: {epoch} | "
            f"Acc: {avg_acc:.2%} | "
            f"Loss: {avg_loss:.4f}"),
            disable=not accelerator.is_main_process
        )
        for i, (x, y) in enumerate(bar):
            # Forward pass
            pred = model(x)
            loss = criterion(pred, y)

            # Backward pass and optimization
            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()

            # Accumulate detached values so the running sums do not keep
            # every iteration's autograd graph alive.
            loss_value = loss.detach()
            total_loss += loss_value
            avg_loss += loss_value

            # Calculate Accuracy
            acc = (pred.argmax(-1) == y).float().mean()
            total_acc += acc
            avg_acc += acc

            # Log
            if (i + 1) % params.log_freq == 0:
                avg_loss /= params.log_freq
                avg_acc /= params.log_freq
                avg_loss = accelerator.gather(avg_loss).mean()
                avg_acc = accelerator.gather(avg_acc).mean()
                if params.log_wandb:
                    accelerator.log(dict(
                        train_loss = avg_loss,
                        train_acc = avg_acc)
                    )
                bar.set_description(f"Training | Epoch: {epoch} | "
                                f"Acc: {avg_acc:.2%} | "
                                f"Loss: {avg_loss:.4f}")
                avg_loss = 0
                avg_acc = 0

        if params.log_wandb:
            train_epoch_loss = total_loss/len(train_loader)
            train_epoch_acc = total_acc/len(train_loader)
            train_epoch_loss = accelerator.gather(train_epoch_loss).mean()
            train_epoch_acc = accelerator.gather(train_epoch_acc).mean()
            accelerator.log(dict(
                train_epoch_loss = train_epoch_loss,
                train_epoch_acc = train_epoch_acc)
            )

        # Test Phase
        model.eval()
        val_epoch_loss, val_epoch_acc = _eval_pass(model, val_loader, criterion)
        if params.log_wandb:
            val_epoch_loss = accelerator.gather(val_epoch_loss).mean()
            val_epoch_acc = accelerator.gather(val_epoch_acc).mean()
            accelerator.log(dict(
                val_epoch_loss = val_epoch_loss,
                val_epoch_acc = val_epoch_acc)
            )

        homemade_epoch_loss, homemade_epoch_acc = _eval_pass(
            model, homemade_test_loader, criterion
        )
        if params.log_wandb:
            homemade_epoch_loss = accelerator.gather(homemade_epoch_loss).mean()
            homemade_epoch_acc = accelerator.gather(homemade_epoch_acc).mean()
            # Log the gathered per-batch means, not the raw running sums.
            accelerator.log(dict(
                homemade_epoch_loss = homemade_epoch_loss,
                homemade_epoch_acc = homemade_epoch_acc)
            )

        # Save Model
        accelerator.save_state()

    accelerator.end_training()

In [None]:
# Run `pretrain(params)` from the notebook through Accelerate, using
# `params.gpus` processes.
accelerate.notebook_launcher(pretrain, (params,), num_processes=params.gpus)

## Fine-Tuning

In [None]:
def finetune(params, checkpoint_idx):
    """Fine-tune a pretrained model on the homemade dataset.

    Loads the weights stored under
    ``<cwd>/<params.pretrain_dir>/checkpoints/checkpoint_<checkpoint_idx>/pytorch_model.bin``,
    optionally freezes every parameter whose name does not contain
    ``"out_block"``, and then trains for ``params.num_epochs_ft`` epochs. After
    each epoch the model is evaluated on the homemade test loader and the
    accelerator state is saved under ``params.finetune_dir``.

    Args:
        params: Run configuration. It must support attribute access and ``**``
            unpacking (it is forwarded to ``Model(**params)``). Fields read
            here: ``finetune_dir``, ``pretrain_dir``, ``log_wandb``,
            ``ft_last_block``, ``optimizer``, ``lr``, ``criterion``,
            ``use_weight``, ``num_epochs_ft`` and ``log_freq_ft``.
        checkpoint_idx: Index of the pretraining checkpoint to start from.
    """
    torch.set_num_threads(1)
    project_config = ProjectConfiguration(
        project_dir=os.path.join(os.getcwd(), params.finetune_dir),
        automatic_checkpoint_naming=True,
    )
    accelerator = Accelerator(log_with="wandb", project_config=project_config)

    if params.log_wandb:
        accelerator.init_trackers(
            project_name="Sleep Staging",
            config=params
        )

    model = Model(**params)
    checkpoint_path = os.path.join(
        os.getcwd(),
        params.pretrain_dir,
        f"checkpoints/checkpoint_{checkpoint_idx}/pytorch_model.bin",
    )
    # Deserialize onto the CPU so that every launched process does not try to
    # materialize the weights on the device the checkpoint was saved from;
    # accelerator.prepare() moves the model to the right device afterwards.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    if params.ft_last_block:
        # Only parameters whose name contains "out_block" stay trainable.
        for name, param in model.named_parameters():
            param.requires_grad = "out_block" in name

    optimizer = params.optimizer(model.parameters(), params.lr)

    homemade_train_loader, homemade_test_loader, weight = get_finetune_dataloaders(params)
    criterion = params.criterion(weight=weight.to(accelerator.device) if params.use_weight else None)

    homemade_train_loader, homemade_test_loader, model, optimizer = accelerator.prepare(
        homemade_train_loader, homemade_test_loader, model, optimizer)

    for epoch in range(params.num_epochs_ft):
        model.train()
        total_loss = 0  # summed over the whole epoch
        total_acc = 0
        avg_loss = 0  # summed over the current logging window
        avg_acc = 0
        if params.log_wandb:
            accelerator.log(dict(epoch_ft = epoch))
        bar = tqdm(homemade_train_loader, desc=(
            f"Training | Epoch: {epoch} | "
            f"Acc: {avg_acc:.2%} | "
            f"Loss: {avg_loss:.4f}"),
            disable=not accelerator.is_main_process
        )
        for i, (x, y) in enumerate(bar):
            # Forward pass
            pred = model(x)
            loss = criterion(pred, y)

            # Backward pass and optimization
            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()

            # Accumulate a detached copy: summing the raw loss would keep every
            # batch's autograd graph alive until the end of the epoch.
            batch_loss = loss.detach()
            total_loss += batch_loss
            avg_loss += batch_loss

            # Calculate Accuracy
            acc = (pred.argmax(-1) == y).float().mean()
            total_acc += acc
            avg_acc += acc

            # Log
            if (i + 1) % params.log_freq_ft == 0:
                # The window holds exactly log_freq_ft batches, so that is the
                # divisor (not the pretraining log frequency).
                avg_loss /= params.log_freq_ft
                avg_acc /= params.log_freq_ft
                avg_loss = accelerator.gather(avg_loss).mean()
                avg_acc = accelerator.gather(avg_acc).mean()
                if params.log_wandb:
                    accelerator.log(dict(
                        homemade_train_loss = avg_loss,
                        homemade_train_acc = avg_acc)
                    )
                bar.set_description(f"Training | Epoch: {epoch} | "
                                f"Acc: {avg_acc:.2%} | "
                                f"Loss: {avg_loss:.4f}")
                avg_loss = 0
                avg_acc = 0

        if params.log_wandb:
            train_epoch_loss = total_loss/len(homemade_train_loader)
            train_epoch_acc = total_acc/len(homemade_train_loader)
            train_epoch_loss = accelerator.gather(train_epoch_loss).mean()
            train_epoch_acc = accelerator.gather(train_epoch_acc).mean()
            accelerator.log(dict(
                train_epoch_loss = train_epoch_loss,
                train_epoch_acc = train_epoch_acc)
            )

        # Test Phase
        model.eval()
        with torch.no_grad():
            total_loss = 0
            total_acc = 0
            for i, (x, y) in enumerate(homemade_test_loader):
                # Forward pass
                pred = model(x)
                loss = criterion(pred, y)
                total_loss += loss

                # Calculate Accuracy
                acc = (pred.argmax(-1) == y).float().mean()
                total_acc += acc
            if params.log_wandb:
                homemade_epoch_loss = total_loss/len(homemade_test_loader)
                homemade_epoch_acc = total_acc/len(homemade_test_loader)
                homemade_epoch_loss = accelerator.gather(homemade_epoch_loss).mean()
                homemade_epoch_acc = accelerator.gather(homemade_epoch_acc).mean()
                # Log the cross-process averages, not the raw per-process sums.
                accelerator.log(dict(
                    homemade_epoch_loss = homemade_epoch_loss,
                    homemade_epoch_acc = homemade_epoch_acc)
                )

        # Save Model
        accelerator.save_state()

    # Close any initialised trackers, mirroring the end of pretrain().
    accelerator.end_training()

In [None]:
# Pretraining checkpoint that fine-tuning starts from.
checkpoint_idx = 7
# Launch `finetune(params, checkpoint_idx)` using `params.gpus` processes.
accelerate.notebook_launcher(
    function=finetune,
    args=(params, checkpoint_idx),
    num_processes=params.gpus,
)

## Testing

In [None]:
# Fine-tuning checkpoint to evaluate.
checkpoint_idx = 2
device = torch.device(0) if torch.cuda.is_available() else torch.device("cpu")
model = Model(**params).to(device)
checkpoint_path = os.path.join(
    os.getcwd(),
    params.finetune_dir,
    f"checkpoints/checkpoint_{checkpoint_idx}/pytorch_model.bin",
)
# map_location keeps the CPU fallback above usable: without it, a checkpoint
# saved from a GPU cannot be deserialized on a machine that has no CUDA device.
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)

In [None]:
# Homemade Test Phase
# Runs the loaded model over the homemade test loader, collecting the ground
# truth (`ys`) and predicted classes (`preds`) used by the following cells.
model.eval()
_, loader, weight = get_finetune_dataloaders(params)
total_loss = 0
total_acc = 0
# Respect params.use_weight so the reported loss uses the same criterion
# configuration as finetune().
criterion = params.criterion(weight=weight.to(device) if params.use_weight else None)
ys = []
preds = []
with torch.no_grad():
    t0 = time()
    for i, (x, y) in enumerate(tqdm(loader, desc="Testing")):
        # Save ground truth for plotting
        ys.append(y)

        # Send to GPU (if available)
        x = x.to(device)
        y = y.to(device)

        # Forward pass
        pred = model(x)

        loss = criterion(pred, y)
        total_loss += loss.item()

        # Calculate Accuracy
        pred = pred.argmax(-1)
        acc = (pred == y).float().mean()
        # Accumulate a Python float, like the loss, rather than a device tensor.
        total_acc += acc.item()

        # Save prediction for plotting
        preds.append(pred.cpu())
    t1 = time()
    preds = torch.cat(preds)
    ys = torch.cat(ys)

# Loss/accuracy are means over batches; the timing is divided by the number of
# individual predictions so that it matches its "per pred" label.
print(f"Loss: {total_loss/len(loader):.4f}\n"
      f"Acc: {total_acc/len(loader):.4f}\n"
      f"Time per pred: {(t1-t0)/len(preds)}")

In [None]:
# Per-class precision/recall/F1 for the predictions collected above.
print(classification_report(y_true=ys, y_pred=preds))

In [None]:
# Confusion matrix of ground truth vs. predicted classes.
print(confusion_matrix(y_true=ys, y_pred=preds))

In [None]:
# Compare ground truth and predictions over the first 500 samples,
# stacked on a shared x-axis.
window = slice(0, 500)
_, (ax1, ax2) = plt.subplots(nrows=2, sharex=True)
ax1.plot(ys[window])
ax2.plot(preds[window])
ax1.set_title("Ground Truth")
ax2.set_title("Predicted")