In [None]:
# MSFAET vs CCT Performance Comparison
## Multi-Scale Frequency-Aware EEG Transformer の性能評価


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import time
import pickle
from torch.utils.data import DataLoader, TensorDataset
from torchinfo import summary
import warnings
warnings.filterwarnings('ignore')

# Import models
from model.cct import CCT
from model.msfaet import MSFAET


In [None]:
## Model Architecture Comparison


In [None]:
# Initialize both models with matching hyper-parameters (dim=64, 4 layers,
# 8 heads, dropout 0.1) and print their architecture summaries so parameter
# counts and intermediate shapes can be compared side by side.
#
# NOTE(review): the input geometry below assumes EEG epochs of N_CHANNELS x
# N_TIMESTEPS fed as (batch, 1, channels, time) -- confirm against the data
# loader used later in the notebook.
N_CHANNELS = 22          # EEG electrodes; also the height of CCT's first conv kernel
N_TIMESTEPS = 1000       # samples per trial
N_CLASSES = 2
SUMMARY_BATCH_SIZE = 32  # batch size used only for the torchinfo dry run
SEED = 42

# Seed before constructing the models so weight initialisation is reproducible.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# CCT model (original)
cct_model = CCT(
    kernel_sizes=[(N_CHANNELS, 1), (1, 24)], stride=(1, 1), padding=(0, 0),
    pooling_kernel_size=(3, 3), pooling_stride=(1, 1), pooling_padding=(0, 0),
    n_conv_layers=2, n_input_channels=1, in_planes=64, activation=None,
    max_pool=False, conv_bias=False, dim=64, num_layers=4, num_heads=8, num_classes=N_CLASSES,
    attn_dropout=0.1, dropout=0.1, mlp_size=64, positional_emb="learnable"
).to(device)

# MSFAET model (new)
msfaet_model = MSFAET(
    n_channels=N_CHANNELS, n_classes=N_CLASSES, dim=64, depth=4, num_heads=8,
    mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1
).to(device)

summary_input_size = (SUMMARY_BATCH_SIZE, 1, N_CHANNELS, N_TIMESTEPS)

print("\n=== Model Summary ===")
# With verbose=0, torchinfo.summary prints nothing and only returns a
# ModelStatistics object. A notebook displays only the cell's LAST expression,
# so a bare mid-cell summary(...) call is silently discarded. Print each
# summary explicitly so both tables are shown.
for model_name, model in (("CCT", cct_model), ("MSFAET", msfaet_model)):
    print(f"\n{model_name} Model:")
    print(summary(model, input_size=summary_input_size, verbose=0))
