## 脳波分類

In [2]:
import os, sys
import numpy as np
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy
import hydra
from omegaconf import DictConfig # Operate configs as a dict
import wandb
from termcolor import cprint
from tqdm import tqdm
from torcheeg import transforms

from src.datasets import ThingsMEGDataset
from src.models import BasicConvClassifier
from src.utils import set_seed
from src.preprocess import CAR, ToNDarray, extract_timepoint

### 設定の読み込み

In [3]:
# Configs
import yaml

class AttrDict:
    """
    Wrap a dictionary so its top-level keys can be read as attributes.

    Every key/value pair of ``dictionary`` is copied onto the instance with
    ``setattr``. Values are stored as given; nested dictionaries are not
    wrapped.
    """
    def __init__(self, dictionary: dict):
        for name in dictionary:
            setattr(self, name, dictionary[name])

# Build the path with os.path.join: the previous literal 'configs\config.yaml'
# depended on the invalid escape sequence "\c" and on the Windows separator.
config_path = os.path.join("configs", "config.yaml")
with open(config_path, encoding="utf-8") as file:
    # safe_load accepts the open stream directly, so no separate read() is needed.
    args = yaml.safe_load(file)

# Expose the top-level config keys as attributes (e.g. args.seed, args.lr).
args = AttrDict(args)

### データの読み込み

#### 自作のデータセットの作成（dataset.py）

In [7]:
import os
import numpy as np
import torch
from typing import Tuple
from termcolor import cprint


class ThingsMEGDataset(torch.utils.data.Dataset):
    """Dataset backed by pre-saved tensors found in ``data_dir``.

    Example:
        train_set = ThingsMEGDataset("train", data_dir)

    Attributes:
        split: Name of the loaded split ("train", "val" or "test").
        num_classes: Number of target classes (fixed to 1854).
        X: Data loaded from ``{split}_X.pt``; ``num_channels`` and ``seq_len``
            read it as [n, ch, seq].
        subject_idxs: Subject index of each sample.
        n_subject: Number of distinct values in ``subject_idxs``.
        y: Labels, only present for the "train" and "val" splits.
        transform: Optional callable invoked as ``transform(eeg=x)`` that
            returns a dict whose ``'eeg'`` entry is the transformed sample.
    """

    def __init__(self, split: str, data_dir: str = "data", transform = None) -> None:
        """Load the tensors of ``split`` from ``data_dir``.

        Raises:
            AssertionError: If ``split`` is unknown, or if the labels of a
                labelled split do not cover exactly ``num_classes`` classes.
        """
        super().__init__()

        assert split in ["train", "val", "test"], f"Invalid split: {split}"
        self.split = split
        self.num_classes = 1854

        self.X = torch.load(os.path.join(data_dir, f"{split}_X.pt"))
        self.subject_idxs = torch.load(os.path.join(data_dir, f"{split}_subject_idxs.pt"))
        self.n_subject = len(self.subject_idxs.unique())

        self.transform = transform

        # Only the train/val splits have a label file.
        if split in ["train", "val"]:
            self.y = torch.load(os.path.join(data_dir, f"{split}_y.pt"))
            assert len(torch.unique(self.y)) == self.num_classes, "Number of classes do not match."

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, i):
        """Return ``(x, y, subject_idx)`` when labels exist, else ``(x, subject_idx)``.

        ``x`` is passed through ``transform`` (when one is set) in both cases.
        """
        x = self.X[i]
        if self.transform is not None:
            x = self.transform(eeg=x)['eeg']

        if hasattr(self, "y"):
            return x, self.y[i], self.subject_idxs[i]
        # Fix: return the transformed sample here too. The unlabelled branch
        # used to return the raw self.X[i], so the test split skipped the
        # preprocessing applied to train/val.
        return x, self.subject_idxs[i]

    @property
    def num_channels(self) -> int:
        """Size of axis 1 of ``X``."""
        return self.X.shape[1]

    @property
    def seq_len(self) -> int:
        """Size of axis 2 of ``X``."""
        return self.X.shape[2]

#### Transformクラスの作成

In [5]:
class ToNDarray(object):
    """Transform that turns a tensor into a NumPy array."""

    def __init__(self):
        pass

    def __call__(self, eeg):
        """Return ``{'eeg': array}`` built from a detached CPU copy of ``eeg``."""
        as_array = eeg.detach().clone().cpu().numpy()
        return {'eeg': as_array}

class CAR(object):
    '''
    Median re-referencing for multi-channel signals.

    The input is a tensor of shape [C, T] (leading batch dimensions are also
    accepted). Two medians are removed in turn:

    1. each channel's own median over time (last axis);
    2. the median across channels at every time point (second-to-last axis).

    The input tensor is left untouched; a new tensor is returned.
    '''
    def __init__(self):
        pass

    def __call__(self, eeg):
        """Return ``{'eeg': x}`` where ``x`` is the re-referenced copy of ``eeg``."""
        # Offset inside each channel. Out-of-place on purpose: the previous
        # in-place `-=` on a transposed view wrote back into the caller's
        # tensor, i.e. into the dataset storage.
        x = eeg - eeg.median(dim=-1, keepdim=True).values

        # Component shared among channels at each time point.
        x = x - x.median(dim=-2, keepdim=True).values

        return {'eeg': x}

class extract_timepoint(object):
    """Transform that keeps the window ``[start:end]`` along the second axis."""

    def __init__(self, start, end):
        # Bounds of the window to keep (used as a Python slice).
        self.start = start
        self.end = end

    def __call__(self, eeg):
        """Return ``{'eeg': eeg[:, start:end]}``."""
        window = slice(self.start, self.end)
        return {'eeg': eeg[:, window]}
        

#### データローダー

In [116]:
# Scratch check: run torcheeg's BandSignal transform on one training sample.
c = transforms.BandSignal(sampling_rate=200)
X = train_set[0][0]  # first element of the first dataset item, i.e. the data tensor
# Flatten the leading axes so every row is a 281-sample series. The recorded
# output is (1084, 281) - presumably 4 bands x 271 channels (TODO confirm).
c(eeg=X)['eeg'].reshape((-1,281)).shape

(1084, 281)

In [17]:
set_seed(args.seed)

# ------------------
#    Dataloader
# ------------------
# num_workers is pinned to 1 here rather than taken from args.num_workers.
# All three loaders share these settings; the test loader previously used
# args.num_workers and could therefore behave differently from train/val.
loader_args = {"batch_size": args.batch_size, "num_workers": 1}


def _make_transform():
    """Build the preprocessing pipeline applied identically to every split."""
    return transforms.Compose([
        # CAR() can be inserted here to enable median re-referencing.
        ToNDarray(),
        transforms.MeanStdNormalize(axis=-1),
        transforms.ToTensor(),
    ])


# Training data [n, ch, seq]; the only loader that shuffles.
train_set = ThingsMEGDataset("train", args.data_dir)
train_set.transform = _make_transform()
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_args)

# Validation data [n, ch, seq]
val_set = ThingsMEGDataset("val", args.data_dir)
val_set.transform = _make_transform()
val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, **loader_args)

# Test data [n, ch, seq] (no labels)
test_set = ThingsMEGDataset("test", args.data_dir)
test_set.transform = _make_transform()
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_args)

### モデル

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange

# Original models
class BasicConvClassifier(nn.Module):
    """Baseline classifier: two ``ConvBlock`` stages followed by a pooled linear head.

    Args:
        num_classes: Size of the output layer.
        seq_len: Accepted for interface compatibility; not used by this class.
        in_channels: Channel count expected by the first ``ConvBlock``.
        hid_dim: Channel count produced by both ``ConvBlock`` stages.
    """

    def __init__(
        self,
        num_classes: int,
        seq_len: int,
        in_channels: int,
        hid_dim: int = 128
    ) -> None:
        super().__init__()

        feature_layers = [
            ConvBlock(in_channels, hid_dim),
            ConvBlock(hid_dim, hid_dim),
        ]
        self.blocks = nn.Sequential(*feature_layers)

        head_layers = [
            nn.AdaptiveAvgPool1d(1),       # average over the last axis
            Rearrange("b d 1 -> b d"),     # drop the pooled axis
            nn.Linear(hid_dim, num_classes),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Map an input batch to class scores.

        Args:
            X ( b, c, t ): input batch.
        Returns:
            X ( b, num_classes ): one score per class.
        """
        features = self.blocks(X)
        logits = self.head(features)
        return logits


class ConvBlock(nn.Module):
    """Two "same"-padded 1-D convolutions, each followed by BatchNorm and GELU.

    The second convolution always has a residual connection; the first has one
    only when ``in_dim == out_dim``. Dropout is applied to the final output.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        kernel_size: int = 3,
        p_drop: float = 0.1,
    ) -> None:
        super().__init__()

        self.in_dim, self.out_dim = in_dim, out_dim

        self.conv0 = nn.Conv1d(in_dim, out_dim, kernel_size, padding="same")
        self.conv1 = nn.Conv1d(out_dim, out_dim, kernel_size, padding="same")

        self.batchnorm0 = nn.BatchNorm1d(num_features=out_dim)
        self.batchnorm1 = nn.BatchNorm1d(num_features=out_dim)

        self.dropout = nn.Dropout(p_drop)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Apply both convolution stages and return the dropped-out result."""
        first = self.conv0(X)
        if self.in_dim == self.out_dim:
            # Channel counts match, so the input can be added back.
            first = first + X
        hidden = F.gelu(self.batchnorm0(first))

        second = self.conv1(hidden) + hidden  # residual around the second conv
        activated = F.gelu(self.batchnorm1(second))

        return self.dropout(activated)

In [93]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange

# Model based on Defossez et al., 2023
# 1. Convolution layer without activation
# 2. Subject-dependent convolution layer without activation
# 3. Five dilated convolution blocks (GELU + GLU), then two 1x1 convolutions with GELU
class DefossezClassifier(nn.Module):
    """Classifier made of subject-aware convolution blocks and a pooled linear head.

    Args:
        num_classes: Size of the output layer.
        seq_len: Accepted for interface compatibility; not used by this class.
        in_channels: Channel count of the input.
        hid_dim: Channel count passed between the dilated ``ConvBlock`` stages;
            the head operates on ``hid_dim * 2`` channels.
        n_subject: Number of subjects handled by ``SubjectBlock``.
    """

    def __init__(
        self,
        num_classes: int,
        seq_len: int,
        in_channels: int,
        hid_dim: int = 320,
        n_subject: int = 4,
    ) -> None:
        super().__init__()

        # nn.ModuleList instead of nn.Sequential: every block's forward takes
        # (X, subject_idxs), while nn.Sequential forwards one positional
        # argument, so calling it with two raised a TypeError. Parameter names
        # ("blocks.<i>....") are the same for both containers.
        self.blocks = nn.ModuleList([
            SimpleConvBlock(in_dim=in_channels, out_dim=in_channels, kernel_size=1),
            SubjectBlock(in_channels,in_channels,n_subject),
            ConvBlock(in_dim=in_channels,hid_dim=hid_dim,out_dim=hid_dim*2,kernel_size=3,k=0),
            ConvBlock(in_dim=hid_dim,hid_dim=hid_dim,out_dim=hid_dim*2,kernel_size=3,k=1),
            ConvBlock(in_dim=hid_dim,hid_dim=hid_dim,out_dim=hid_dim*2,kernel_size=3,k=2),
            ConvBlock(in_dim=hid_dim,hid_dim=hid_dim,out_dim=hid_dim*2,kernel_size=3,k=3),
            ConvBlock(in_dim=hid_dim,hid_dim=hid_dim,out_dim=hid_dim*2,kernel_size=3,k=4),
            SimpleConvBlock(in_dim = hid_dim, out_dim = hid_dim*2, kernel_size=1, activate = True),
            SimpleConvBlock(in_dim = hid_dim*2, out_dim = hid_dim*2, kernel_size=1, activate = True),
        ])

        self.head = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            Rearrange("b d 1 -> b d"),
            nn.Linear(hid_dim*2, num_classes),
        )

    def forward(self, X: torch.Tensor, subject_idxs: torch.Tensor) -> torch.Tensor:
        """Map an input batch to class scores.

        Args:
            X ( b, c, t ): input batch.
            subject_idxs ( b ): subject index of each sample, handed to every block.
        Returns:
            X ( b, num_classes ): one score per class.
        """
        for block in self.blocks:
            X = block(X, subject_idxs)

        return self.head(X)

In [94]:
class SimpleConvBlock(nn.Module):
    """Single "same"-padded 1-D convolution with optional GELU, followed by dropout.

    Args:
        in_dim: Input channel count.
        out_dim: Output channel count.
        kernel_size: Convolution kernel size.
        p_drop: Dropout probability.
        activate: When True, apply GELU after the convolution.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        kernel_size: int = 1,
        p_drop: float = 0.1,
        activate: bool = False,
    ) -> None:
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.activate = activate

        self.conv0 = nn.Conv1d(in_dim, out_dim, kernel_size, padding="same")
        self.dropout = nn.Dropout(p_drop)

    def forward(self, X: torch.Tensor, subject_idxs: "torch.Tensor | None" = None) -> torch.Tensor:
        """Convolve ``X``, optionally activate, then apply dropout.

        ``subject_idxs`` is ignored. It exists so this block shares a call
        signature with the subject-aware blocks, and is optional so the block
        can also be called as ``block(X)``.
        """
        X = self.conv0(X)
        if self.activate:
            X = F.gelu(X)
        return self.dropout(X)

In [95]:
class ConvBlock(nn.Module):
    """Dilated convolution block: two residual conv+BatchNorm+GELU stages, then a GLU conv.

    NOTE: within this notebook this definition shadows the earlier two-argument
    ``ConvBlock`` of the same name.

    Args:
        in_dim: Input channel count.
        hid_dim: Channel count of the first two convolutions.
        out_dim: Channel count of the third convolution; the GLU halves it, so
            the block emits ``out_dim // 2`` channels.
        kernel_size: Kernel size of all three convolutions.
        k: Block index; selects the dilations ``2**((2*k) % 5)`` and
            ``2**((2*k + 1) % 5)`` of the first two convolutions.
        p_drop: Dropout probability.
    """

    def __init__(
        self,
        in_dim = 271,
        hid_dim = 320,
        out_dim = 640,
        kernel_size: int = 3,
        k: int = 0,
        p_drop: float = 0.1,
    ) -> None:
        super().__init__()

        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.out_dim = out_dim

        # Fix: the modulo applies to the exponent. `2**(2*k)%5` parses as
        # `(2**(2*k)) % 5`, which only ever yielded dilations 1/4 (and 2/3)
        # instead of the growing schedule 1, 4, 16, 2, 8 (and 2, 8, 1, 4, 16).
        dilation0 = 2 ** ((2 * k) % 5)
        dilation1 = 2 ** ((2 * k + 1) % 5)

        self.conv0 = nn.Conv1d(in_dim, hid_dim, kernel_size, padding="same", dilation=dilation0)
        self.conv1 = nn.Conv1d(hid_dim, hid_dim, kernel_size, padding="same", dilation=dilation1)
        self.conv2 = nn.Conv1d(hid_dim, out_dim, kernel_size, padding="same", dilation=2)

        self.batchnorm0 = nn.BatchNorm1d(num_features=hid_dim)
        self.batchnorm1 = nn.BatchNorm1d(num_features=hid_dim)

        self.dropout = nn.Dropout(p_drop)

    def forward(self, X: torch.Tensor, subject_idxs: "torch.Tensor | None" = None) -> torch.Tensor:
        """Run the three stages on ``X`` and return the dropped-out result.

        ``subject_idxs`` is ignored. It exists so this block shares a call
        signature with the subject-aware blocks, and is optional so the block
        can also be called as ``block(X)``.
        """
        if self.in_dim == self.hid_dim:
            X = self.conv0(X) + X  # skip connection
        else:
            # Channel counts differ, so the input cannot be added back.
            X = self.conv0(X)
        X = F.gelu(self.batchnorm0(X))

        X = self.conv1(X) + X  # skip connection
        X = F.gelu(self.batchnorm1(X))

        X = self.conv2(X)
        X = F.glu(X, dim=-2)  # gates along the channel axis; no normalization here

        return self.dropout(X)

In [98]:
class SubjectBlock(nn.Module):
    """Subject-specific linear map over the channel axis.

    Holds one [C_in, C_out] matrix per subject, initialised from a standard
    normal scaled by ``1 / sqrt(in_channels)``.
    """
    def __init__(self, in_channels: int, out_channels: int, n_subjects: int = 4):
        super().__init__()
        scale = 1 / in_channels**0.5
        # One weight matrix per subject: [S, C_in, C_out]
        self.weights = nn.Parameter(torch.randn(n_subjects, in_channels, out_channels) * scale)

    def forward(self, X:torch.Tensor, subject_idxs: torch.Tensor):
        """Mix the channels of each sample with its subject's matrix: [B,C_in,T] -> [B,C_out,T]."""
        n_in, n_out = self.weights.shape[1:]
        # Repeat each sample's subject index over a full [C_in, C_out] matrix
        # so gather picks that subject's weights: result is [B, C_in, C_out].
        index = subject_idxs.view(-1, 1, 1).expand(-1, n_in, n_out)
        per_sample = torch.gather(self.weights, 0, index)
        return torch.einsum("bct,bcd->bdt", X, per_sample)

In [90]:
# Shape trace: push one batch through every layer of the subject-aware model.
# Slice the dataset once - each indexing call re-runs the transform.
batch = train_set[0:128]
X, subject_idxs = batch[0], batch[2]
n_subject = train_set.n_subject
print(f"Original shape: {X.shape}") # [B,C,T]
print(f"Subject index: {subject_idxs.shape}") # [B]

# Every block is called as block(X, subject_idxs), the signature their
# forward methods declare.
c = SimpleConvBlock(271, 271, kernel_size=1)
X = c(X, subject_idxs) # [B,C_in,T] -> [B,C_out,T]
print(f"1st conv block: {X.shape}")

# Fix: the class defined in this notebook is SubjectBlock; `SubjectLayers`
# does not exist.
c = SubjectBlock(271,271,n_subject)
X = c(X,subject_idxs) # [B,C_in,T] -> [B,C_out,T]
print(f"Subject block: {X.shape}") # [B,C,T]

c = ConvBlock(in_dim=271,hid_dim=320,out_dim=640,kernel_size=3,k=0)
X = c(X, subject_idxs)
print(f"Conv block (k=0): {X.shape}") # [B,C,T]

c = ConvBlock(in_dim=320,hid_dim=320,out_dim=640,kernel_size=3,k=1)
X = c(X, subject_idxs)
print(f"Conv block (k=1): {X.shape}") # [B,C,T]

c = ConvBlock(in_dim=320,hid_dim=320,out_dim=640,kernel_size=3,k=2)
X = c(X, subject_idxs)
print(f"Conv block (k=2): {X.shape}") # [B,C,T]

c = ConvBlock(in_dim=320,hid_dim=320,out_dim=640,kernel_size=3,k=3)
X = c(X, subject_idxs)
print(f"Conv block (k=3): {X.shape}") # [B,C,T]

c = ConvBlock(in_dim=320,hid_dim=320,out_dim=640,kernel_size=3,k=4)
X = c(X, subject_idxs)
print(f"Conv block (k=4): {X.shape}") # [B,C,T]

c = SimpleConvBlock(in_dim = 320, out_dim = 640, kernel_size=1, activate = True)
X = c(X, subject_idxs)
print(f"2bd conv block: {X.shape}")

c = SimpleConvBlock(in_dim = 640, out_dim = 640, kernel_size=1, activate = True)
X = c(X, subject_idxs)
print(f"Final conv block: {X.shape}")

# Classification head: pool over time, drop the pooled axis, project to classes.
c = nn.AdaptiveAvgPool1d(1)
X = c(X)
print(f"After adaptive pooling: {X.shape}")
c = Rearrange("b d 1 -> b d")
X = c(X)
print(f"After rearrangement: {X.shape}")
c = nn.Linear(640, 1854)
X = c(X)
print(f"Final output: {X.shape}")

Original shape: torch.Size([128, 271, 281])
Subject index: torch.Size([128])
1st conv block: torch.Size([128, 271, 281])
Subject block: torch.Size([128, 271, 281])
Conv block (k=0): torch.Size([128, 320, 281])
Conv block (k=1): torch.Size([128, 320, 281])
Conv block (k=2): torch.Size([128, 320, 281])
Conv block (k=3): torch.Size([128, 320, 281])
Conv block (k=4): torch.Size([128, 320, 281])
2bd conv block: torch.Size([128, 640, 281])
Final conv block: torch.Size([128, 640, 281])
After adaptive pooling: torch.Size([128, 640, 1])
After rearrangement: torch.Size([128, 640])
Final output: torch.Size([128, 1854])


In [33]:
# Sanity check: output shape of the baseline classifier on one batch.
X = train_set[0:128][0]
model = BasicConvClassifier(
    train_set.num_classes, train_set.seq_len, train_set.num_channels
)
# Call the module itself rather than .forward() so registered hooks run.
model(X).shape

torch.Size([128, 1854])

In [None]:
        # NOTE(review): stray, never-executed fragment. It repeats the layer
        # definitions from BasicConvClassifier.__init__ and is not valid as a
        # standalone cell (indented, refers to `self`).
        self.blocks = nn.Sequential(
            ConvBlock(in_channels, hid_dim),
            ConvBlock(hid_dim, hid_dim),
        )

        self.head = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            Rearrange("b d 1 -> b d"),
            nn.Linear(hid_dim, num_classes),
        )

### モデルの訓練

In [178]:
# ------------------
#       Model
# ------------------
# Build the baseline classifier from the dataset's dimensions, then move it
# to the configured device.
model = BasicConvClassifier(
    train_set.num_classes,
    train_set.seq_len,
    train_set.num_channels,
)
model = model.to(args.device)

# ------------------
#     Optimizer
# ------------------
# Adam over all model parameters with the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)

In [None]:


# ------------------
#   Start training
# ------------------
# NOTE(review): `logdir` is not defined in the visible part of this notebook;
# it must be set before this cell runs.
max_val_acc = 0
# Top-10 accuracy over all classes.
accuracy = Accuracy(
    task="multiclass", num_classes=train_set.num_classes, top_k=10
).to(args.device)

for epoch in range(args.epochs):
    print(f"Epoch {epoch+1}/{args.epochs}")

    # Per-batch values, averaged at the end of the epoch.
    train_loss, train_acc, val_loss, val_acc = [], [], [], []

    model.train()
    for X, y, subject_idxs in tqdm(train_loader, desc="Train"):
        X, y = X.to(args.device), y.to(args.device)

        y_pred = model(X)

        loss = F.cross_entropy(y_pred, y)
        train_loss.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = accuracy(y_pred, y)
        train_acc.append(acc.item())

    model.eval()
    # The whole validation pass runs without autograd, including the loss and
    # metric computation.
    with torch.no_grad():
        for X, y, subject_idxs in tqdm(val_loader, desc="Validation"):
            X, y = X.to(args.device), y.to(args.device)

            y_pred = model(X)

            val_loss.append(F.cross_entropy(y_pred, y).item())
            val_acc.append(accuracy(y_pred, y).item())

    # Compute each epoch mean once and reuse it for printing, logging and
    # model selection.
    mean_train_loss = np.mean(train_loss)
    mean_train_acc = np.mean(train_acc)
    mean_val_loss = np.mean(val_loss)
    mean_val_acc = np.mean(val_acc)

    print(f"Epoch {epoch+1}/{args.epochs} | train loss: {mean_train_loss:.3f} | train acc: {mean_train_acc:.3f} | val loss: {mean_val_loss:.3f} | val acc: {mean_val_acc:.3f}")
    torch.save(model.state_dict(), os.path.join(logdir, "model_last.pt"))
    if args.use_wandb:
        wandb.log({"train_loss": mean_train_loss, "train_acc": mean_train_acc, "val_loss": mean_val_loss, "val_acc": mean_val_acc})

    # Keep a separate copy of the weights with the best validation accuracy.
    if mean_val_acc > max_val_acc:
        cprint("New best.", "cyan")
        torch.save(model.state_dict(), os.path.join(logdir, "model_best.pt"))
        max_val_acc = mean_val_acc

### モデルの検証

In [102]:
# ----------------------------------
#  Start evaluation with best model
# ----------------------------------
# NOTE(review): the traceback recorded for this cell shows the checkpoint
# holding keys such as "pre_conv_block.*" and "subject_block.weights", so
# `model` must be an instance of the architecture that produced it; a
# BasicConvClassifier cannot load it.
# The path is built with os.path.join so it does not depend on the Windows
# separator.
checkpoint_path = os.path.join("outputs", "2024-06-19", "19-07-08", "model_best.pt")
model.load_state_dict(torch.load(checkpoint_path, map_location=args.device))

preds = []
model.eval()
# Inference only: disable autograd so no graph is built per batch.
with torch.no_grad():
    for X, subject_idxs in tqdm(test_loader, desc="Validation"):
        preds.append(model(X.to(args.device)).detach().cpu())

preds = torch.cat(preds, dim=0).numpy()
# NOTE(review): `logdir` is not defined in the visible part of this notebook.
np.save(os.path.join(logdir, "submission"), preds)
cprint(f"Submission {preds.shape} saved at {logdir}", "cyan")

RuntimeError: Error(s) in loading state_dict for BasicConvClassifier:
	Missing key(s) in state_dict: "blocks.0.conv0.weight", "blocks.0.conv0.bias", "blocks.0.conv1.weight", "blocks.0.conv1.bias", "blocks.0.batchnorm0.weight", "blocks.0.batchnorm0.bias", "blocks.0.batchnorm0.running_mean", "blocks.0.batchnorm0.running_var", "blocks.0.batchnorm1.weight", "blocks.0.batchnorm1.bias", "blocks.0.batchnorm1.running_mean", "blocks.0.batchnorm1.running_var", "blocks.1.conv0.weight", "blocks.1.conv0.bias", "blocks.1.conv1.weight", "blocks.1.conv1.bias", "blocks.1.batchnorm0.weight", "blocks.1.batchnorm0.bias", "blocks.1.batchnorm0.running_mean", "blocks.1.batchnorm0.running_var", "blocks.1.batchnorm1.weight", "blocks.1.batchnorm1.bias", "blocks.1.batchnorm1.running_mean", "blocks.1.batchnorm1.running_var". 
	Unexpected key(s) in state_dict: "pre_conv_block.0.conv0.weight", "pre_conv_block.0.conv0.bias", "subject_block.weights", "post_conv_block.0.conv0.weight", "post_conv_block.0.conv0.bias", "post_conv_block.0.conv1.weight", "post_conv_block.0.conv1.bias", "post_conv_block.0.conv2.weight", "post_conv_block.0.conv2.bias", "post_conv_block.0.batchnorm0.weight", "post_conv_block.0.batchnorm0.bias", "post_conv_block.0.batchnorm0.running_mean", "post_conv_block.0.batchnorm0.running_var", "post_conv_block.0.batchnorm0.num_batches_tracked", "post_conv_block.0.batchnorm1.weight", "post_conv_block.0.batchnorm1.bias", "post_conv_block.0.batchnorm1.running_mean", "post_conv_block.0.batchnorm1.running_var", "post_conv_block.0.batchnorm1.num_batches_tracked", "post_conv_block.1.conv0.weight", "post_conv_block.1.conv0.bias", "post_conv_block.1.conv1.weight", "post_conv_block.1.conv1.bias", "post_conv_block.1.conv2.weight", "post_conv_block.1.conv2.bias", "post_conv_block.1.batchnorm0.weight", "post_conv_block.1.batchnorm0.bias", "post_conv_block.1.batchnorm0.running_mean", "post_conv_block.1.batchnorm0.running_var", "post_conv_block.1.batchnorm0.num_batches_tracked", "post_conv_block.1.batchnorm1.weight", "post_conv_block.1.batchnorm1.bias", "post_conv_block.1.batchnorm1.running_mean", "post_conv_block.1.batchnorm1.running_var", "post_conv_block.1.batchnorm1.num_batches_tracked", "post_conv_block.2.conv0.weight", "post_conv_block.2.conv0.bias", "post_conv_block.2.conv1.weight", "post_conv_block.2.conv1.bias", "post_conv_block.2.conv2.weight", "post_conv_block.2.conv2.bias", "post_conv_block.2.batchnorm0.weight", "post_conv_block.2.batchnorm0.bias", "post_conv_block.2.batchnorm0.running_mean", "post_conv_block.2.batchnorm0.running_var", "post_conv_block.2.batchnorm0.num_batches_tracked", "post_conv_block.2.batchnorm1.weight", "post_conv_block.2.batchnorm1.bias", "post_conv_block.2.batchnorm1.running_mean", "post_conv_block.2.batchnorm1.running_var", 
"post_conv_block.2.batchnorm1.num_batches_tracked", "post_conv_block.3.conv0.weight", "post_conv_block.3.conv0.bias", "post_conv_block.3.conv1.weight", "post_conv_block.3.conv1.bias", "post_conv_block.3.conv2.weight", "post_conv_block.3.conv2.bias", "post_conv_block.3.batchnorm0.weight", "post_conv_block.3.batchnorm0.bias", "post_conv_block.3.batchnorm0.running_mean", "post_conv_block.3.batchnorm0.running_var", "post_conv_block.3.batchnorm0.num_batches_tracked", "post_conv_block.3.batchnorm1.weight", "post_conv_block.3.batchnorm1.bias", "post_conv_block.3.batchnorm1.running_mean", "post_conv_block.3.batchnorm1.running_var", "post_conv_block.3.batchnorm1.num_batches_tracked", "post_conv_block.4.conv0.weight", "post_conv_block.4.conv0.bias", "post_conv_block.4.conv1.weight", "post_conv_block.4.conv1.bias", "post_conv_block.4.conv2.weight", "post_conv_block.4.conv2.bias", "post_conv_block.4.batchnorm0.weight", "post_conv_block.4.batchnorm0.bias", "post_conv_block.4.batchnorm0.running_mean", "post_conv_block.4.batchnorm0.running_var", "post_conv_block.4.batchnorm0.num_batches_tracked", "post_conv_block.4.batchnorm1.weight", "post_conv_block.4.batchnorm1.bias", "post_conv_block.4.batchnorm1.running_mean", "post_conv_block.4.batchnorm1.running_var", "post_conv_block.4.batchnorm1.num_batches_tracked", "post_conv_block.5.conv0.weight", "post_conv_block.5.conv0.bias", "post_conv_block.6.conv0.weight", "post_conv_block.6.conv0.bias". 
	size mismatch for head.2.weight: copying a param with shape torch.Size([1854, 640]) from checkpoint, the shape in current model is torch.Size([1854, 128]).