# 1.0 Environment:

The model is trained with PyTorch and thus can be used in Python environments. The following packages meet the minimal requirements to load the model. 

* The Python version used in the pre-training phase is 3.11.5.

|   Package   |   Version   |   Note                                    |
|-------------|:-----------:|------------------------------------------:|
|  timm       |  1.0.10     |Basic implementations of transformer models|
|  torch      |2.4.1        |Deep learning framework                    |


Code written by Liuyin Yang (liuyin.yang@kuleuven.be)

All rights reserved.

--------------------------------------------------------

# 1.1 Initialize a ViT model instance:

An MAE model contains an encoder and a decoder, which are used in the pre-training phase to perform a reconstruction task. Both the encoder and the decoder are stored in the checkpoint file, but for the downstream classification task only the encoder part is used.

The ViT model shares the same backbone as the MAE encoder and is used for the downstream classification task. The main differences are:

1. num_classes defines number of classification classes (the number of last fc layer output neurons) 
2. drop_path_rate defines the drop-path (stochastic depth) rate
3. global_pool: if True, global average pooling over all tokens is used as the final representation fed to the last layer. If False, the class token is used as the final representation instead.

* Note that the MAE pre-training phase has no notion of num_classes. num_classes is used to initialize a new fully connected layer appended to the end of the model; this layer typically needs to be trained to adapt the model to the downstream classification task.

* Note that this model expects EEG data sampled at 128 Hz, band-pass filtered between 0.5 and 64 Hz, and standardized per channel (z-scoring).

All rights reserved.

--------------------------------------------------------

In [1]:
import models_vit_eeg

# Build the downstream ViT classifier; only the MAE encoder weights are
# loaded into it later, the classification head is freshly initialized.
vit_model_variant = "vit_small_patch16"
vit_kwargs = dict(
    num_classes=4,        # number of output neurons of the new fc head
    drop_path_rate=0.1,   # drop-path (stochastic depth) rate
    global_pool=True,     # True: average all tokens; False: use the class token
)
build_vit = models_vit_eeg.__dict__[vit_model_variant]
encoder_model = build_vit(**vit_kwargs)

  from .autonotebook import tqdm as notebook_tqdm


gloabl pool: True


# 1.2 Load the vit model weights from a checkpoint: 
To load the encoder, which is part of the MAE model, run the following code. It loads the matching weights from the checkpoint and ignores the rest. Finally, you can optionally initialize the weights of the last fc layer for your training using `trunc_normal_()`.

In [2]:
import torch
# `timm.models.layers` is a deprecated alias in timm >= 0.9; import from `timm.layers`.
from timm.layers import trunc_normal_

# Path to the MAE pre-training checkpoint (cluster-specific; adjust to your setup).
check_point_dir = "/lustre1/project/stg_00160/new_eeg_mae/small/checkpoint-200.pth"
# Must match the `global_pool` flag the model was built with: it decides which
# keys are expected to be missing after loading.
global_pool = True

# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load(check_point_dir, map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = encoder_model.state_dict()

# Drop any pre-trained head whose shape does not fit the new num_classes head.
for k in ['head.weight', 'head.bias']:
    if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        del checkpoint_model[k]

# Load the matching (encoder) weights; keys the ViT does not have are ignored.
msg = encoder_model.load_state_dict(checkpoint_model, strict=False)

# Sanity check: only the new head, pos_embed and (with global pooling) the
# fc_norm layer should be left uninitialized by the checkpoint.
expected_missing = {'pos_embed', 'head.weight', 'head.bias'}
if global_pool:
    expected_missing |= {'fc_norm.weight', 'fc_norm.bias'}
missing = set(msg.missing_keys)
assert missing == expected_missing, (
    f"unexpected missing keys after loading: {sorted(missing ^ expected_missing)}"
)

# Optionally re-initialize the new fc head before fine-tuning:
# trunc_normal_(encoder_model.head.weight, std=2e-5)

# 2.1 Passing data to the ViT model 
All PyTorch models work with `torch.Tensor` objects, which can be created from `numpy.ndarray`s. The ViT model takes two inputs: EEG data of shape (batch, channel, time_points) and the channel-location indices (`senloc`) of shape (batch, channel). The output is the output of the last fc layer, of shape (batch, num_classes); these scores relate to the probabilities of your classes (e.g. via a softmax).

In [4]:
# Construct fake EEG input data with numpy. These are random values; real input
# should be sampled at 128 Hz, band-passed 0.5-64 Hz and z-scored per channel.
import numpy as np

eeg_data = np.random.rand(32, 128 * 3)  # 3-second, 32-channel EEG sampled at 128 Hz
# One channel-position index per channel (random here; use the senloc file for real data).
eeg_senloc = np.random.randint(0, 145, 32)

# convert the data to torch tensors
eeg_data = torch.from_numpy(eeg_data).float()
eeg_senloc = torch.from_numpy(eeg_senloc).int()
print(eeg_data.shape, eeg_senloc.shape)

# add a batch dimension of size 1
eeg_data = torch.unsqueeze(eeg_data, 0)
eeg_senloc = torch.unsqueeze(eeg_senloc, 0)
print("after unsqueezing:", eeg_data.shape, eeg_senloc.shape)

# pass into the ViT model; the output has shape (batch, num_classes)
out = encoder_model(eeg_data, eeg_senloc)
print(out.shape)

torch.Size([32, 384]) torch.Size([32])
after unsqueezing: torch.Size([1, 32, 384]) torch.Size([1, 32])
torch.Size([1, 4])


# 2.2 Channel index 
Each channel has a unique index used for the channel positional embeddings. These indices can be retrieved from the senloc file.

In [9]:
import pickle

# Channel-index lookup tables for the channel positional embeddings.
# NOTE: pickle.load executes arbitrary code; only open trusted files.
senloc_path = "/vsc-hard-mounts/leuven-data/343/vsc34340/new_eeg_mae/senloc_file/sen_chan_idx.pkl"
with open(senloc_path, "rb") as senloc_file:
    data = pickle.load(senloc_file)

# 'hgd' holds the channel indices for the high gamma dataset;
# 'channels_mapping' maps each channel name to its index.
hgd_chan_idx, chan_idx_map = data['hgd'], data['channels_mapping']

for table in (hgd_chan_idx, chan_idx_map):
    print(table)

[  6  60 112 133  71  25  11 118  49  18  34  19  31  65 130 110   2  68
  55  51  85  46  52  41  74   1 116  21  63  95   5 103  37  86  13  39
  72  42  54   3  70  14  75  20   0  10 127  84 119  62  90  99 117  23
  96 102  58 101 122  22 111  16 114 129  45  87  89  27  44  38  24  29
  26 100  28  53 121 109  77   7 120 106  32 126  94  67  15 104 115  80
   9 128  35 123  47  91  36  43  57 125 107 105 108  78  82  33  48  76
   4 113  64  92  79 124 132  40  30   8  97  98  83  69  56  59 131  17
  88  12]
{'C1': 0, 'Pz': 1, 'C4': 2, 'F6': 3, 'FTT8h': 4, 'Oz': 5, 'Fp1': 6, 'FCC5h': 7, 'TPP8h': 8, 'CPP6h': 9, 'C2': 10, 'F4': 11, 'OI2h': 12, 'AF4': 13, 'FCz': 14, 'CCP6h': 15, 'TP8': 16, 'POO10h': 17, 'FC1': 18, 'FC6': 19, 'C5': 20, 'P8': 21, 'FT8': 22, 'P6': 23, 'P9': 24, 'Fz': 25, 'AFF1': 26, 'TPP10h': 27, 'AFF2': 28, 'P10': 29, 'CPP2h': 30, 'M1': 31, 'FCC6h': 32, 'FTT7h': 33, 'FC2': 34, 'PPO2': 35, 'AFp3h': 36, 'AF7': 37, 'PO10': 38, 'AF8': 39, 'CPP1h': 40, 'P7': 41, 'F1': 42,

# 3.1 Get attention scores in a forward pass
By default, timm's ViT implementations do not output attention scores. Therefore, we need some tricks in the model's forward call to save the per-layer attention scores. This requires modifying timm's ViT implementation, which can be done by replacing the original `vision_transformer.py` file with the one in the `/adaptation` folder. The following example shows how to get the attention scores after replacing `vision_transformer.py`.

In [17]:
encoder_model.eval()

# Shapes for vit_small_patch16 with a 3 s, 32-channel, 128 Hz input:
# 8 blocks x 8 heads x N x N, with N = 1 + 32 * (384 // 16) = 769 tokens.
NUM_LAYERS, NUM_HEADS, NUM_TOKENS = 8, 8, 769
avg_class0_attention_score = np.zeros((NUM_LAYERS, NUM_HEADS, NUM_TOKENS, NUM_TOKENS))
avg_class1_attention_score = np.zeros((NUM_LAYERS, NUM_HEADS, NUM_TOKENS, NUM_TOKENS))

# Assume we have many trials of class 0 and class 1 and want the per-class
# average attention. Keeping every trial's attention scores is memory-demanding,
# so we sum them per class and divide by the number of trials at the end.

# random 5 trials
labels = [0, 0, 1, 1, 0]

class_0_count = 0
class_1_count = 0

# first call this to enable tracking attention (requires the adapted vision_transformer.py)
encoder_model.register_hook()

for label in labels:
    # random trial data: a 3-second, 32-channel EEG sampled at 128 Hz and its channel indices
    eeg_data = np.random.rand(32, 128 * 3)
    eeg_senloc = np.random.randint(0, 145, 32)

    # convert the data to torch tensors with a batch dimension of size 1
    eeg_data = torch.unsqueeze(torch.from_numpy(eeg_data).float(), 0)
    eeg_senloc = torch.unsqueeze(torch.from_numpy(eeg_senloc).int(), 0)

    out = encoder_model(eeg_data, eeg_senloc)

    # Stack this forward pass's per-layer scores into one (layers, heads, N, N) array.
    trial_scores = np.stack([
        torch.squeeze(encoder_model.attn_scores[layer]).detach().cpu().numpy()
        for layer in range(NUM_LAYERS)
    ])

    # Each trial is counted exactly once (not once per layer).
    if label == 0:
        avg_class0_attention_score += trial_scores
        class_0_count += 1
    elif label == 1:
        avg_class1_attention_score += trial_scores
        class_1_count += 1
    else:
        print("wrong trial", flush=True)

    # Clear the internal list holding this forward pass's attention scores before
    # the next call; otherwise new scores are appended and memory keeps growing.
    encoder_model.attn_scores = []

# calculate the average attention per class
avg_class0_attention_score = avg_class0_attention_score / class_0_count
avg_class1_attention_score = avg_class1_attention_score / class_1_count

# print the shape of the averaged attention scores: (layers, heads, N, N)
print(avg_class0_attention_score.shape)

(8, 8, 769, 769)


# 3.1.1 More details about the dimension of the attention score

This attention is calculated on the basis of tokens. Therefore it's important to know how tokenization is done inside the model.
Given input EEG data of shape (num_channels, num_data_points), the model splits each channel into patches of `patch_size` consecutive samples along the time dimension. The patch size for the vit_small_patch16 model is 16; in other words, every 16 data points (0.125 s at a sampling rate of 128 Hz) are converted into one token. This lets you calculate the shape of the attention score matrix from the input data shape.

Example: in the previous example the input is (32, 384). For each channel, the 384 data points are converted into 384 / 16 = 24 tokens, so in the end we have 32 × 24 + 1 = 769 tokens (including an extra class token). After self-attention (which you can loosely think of as computing pairwise similarities between all token pairs), you get a 769 × 769 attention map.

Another question you may ask is: given the attention score matrix, how do I know which channel and which time each element belongs to? This follows from the tokenization steps:

1. Patchify

EEG comes in as a tensor of shape $(B, C, L)$ (batch, channels, time samples). The model splits each channel into non-overlapping patches of length $P$ (your `patch_size`), giving $\mathrm{Seq} = L / P$ patches per channel.

2. Project & reshape

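After the linear projection the tensor has shape $(B, \mathrm{Seq}, C, D)$. The model then flattens the middle two dimensions into a single “token” dimension, giving $(B, \mathrm{Seq} \times C, D)$. Concretely, token index $t \in [0, \ldots, \mathrm{Seq} \times C - 1]$ corresponds to $\text{time\_patch} = \lfloor t / C \rfloor$ and $\text{channel\_idx} = t \bmod C$.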

3. Add a CLS token

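We prepend one extra “CLS” token at index 0, so the total sequence length is $N = 1 + \mathrm{Seq} \times C$. Consequently, attention index $i \geq 1$ refers to token $t = i - 1$ in the mapping above.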

4. Attention scores

After a forward pass with the hooks registered, the model's `attn_scores` list holds one tensor per transformer block, each of shape $(B, \text{heads}, N, N)$. Element `attn[b, h, i, j]` is how much head $h$, for example $b$, attends from token $i$ to token $j$.

# 3.1.2 Interpreting the per-layer attention score
Borrowing from computer vision, we can analyze the attention scores with attention rollout (see https://github.com/jacobgil/vit-explain), which approximates how attention flows through the layers of the model.

Using the per-layer attention scores obtained above, you can call the following function to compute the attention rollout, which (hopefully) shows where the model attends.

In [19]:
# Attention rollout for estimating the flow of attention in ViT
# (adapted from https://github.com/jacobgil/vit-explain).
# Formula: Ar(1) = norm(A(1) + I), Ar(l) = norm(A(l) + I) @ Ar(l-1),
# where norm() rescales every row to sum to 1.
def rollout(attentions, discard_ratio, head_fusion):
    """Combine per-layer attention maps into a single attention-rollout map.

    Args:
        attentions: tensor of shape (layers, heads, N, N); iterating over it
            yields one (heads, N, N) attention tensor per layer.
        discard_ratio: fraction of the lowest fused attention weights (over the
            whole N x N map) replaced by 1e-10 in each layer. The entry at flat
            index 0 (CLS -> CLS) is never discarded.
        head_fusion: "mean", "max" or "min" to fuse the heads, or an int head
            index to use a single head.

    Returns:
        (N, N) tensor with the same dtype/device as ``attentions``. Each layer's
        matrix is row-normalised, so every row of the result sums to 1.

    Raises:
        ValueError: if ``head_fusion`` is not a supported option.
    """
    num_tokens = attentions.shape[-1]
    result = torch.eye(num_tokens, dtype=attentions.dtype, device=attentions.device)
    with torch.no_grad():
        for attention in attentions:
            num_heads = attention.shape[0]
            if head_fusion == "mean":
                fused = attention.mean(dim=0)
            elif head_fusion == "max":
                fused = attention.max(dim=0)[0]
            elif head_fusion == "min":
                fused = attention.min(dim=0)[0]
            elif isinstance(head_fusion, int) and -num_heads <= head_fusion < num_heads:
                # clone so the in-place discard below never writes into the caller's tensor
                fused = attention[head_fusion].clone()
            else:
                raise ValueError(f"Attention head fusion type not supported: {head_fusion!r}")
            # view() below must alias `fused`, so make sure it is contiguous
            fused = fused.contiguous()

            # Drop the lowest attentions, but don't drop the class token
            # (flat index 0 is the CLS -> CLS entry). Writes go through the view.
            flat = fused.view(1, -1)
            num_discard = int(flat.size(-1) * discard_ratio)
            _, indices = flat.topk(num_discard, dim=-1, largest=False)
            indices = indices[indices != 0]
            flat[0, indices] = 1e-10

            # Add the identity (residual path), then normalise each ROW to sum to 1;
            # keepdim=True makes the division broadcast per row, not per column.
            identity = torch.eye(fused.size(-1), dtype=fused.dtype, device=fused.device)
            a = (fused + identity) / 2
            a = a / a.sum(dim=-1, keepdim=True)

            result = torch.matmul(a, result)

    return result


# Roll out the class-0 averaged attention: fuse heads by their mean and
# discard the lowest 90% of the fused weights in every layer.
class0_attentions = torch.from_numpy(avg_class0_attention_score).type(torch.FloatTensor)
roll_out_attention = rollout(
    attentions=class0_attentions,
    discard_ratio=0.90,
    head_fusion="mean",
)
print("attention rollout:", roll_out_attention.shape)

attention rollout: torch.Size([769, 769])


# 3.1.3 Check the attention matrix structure


In [37]:
#1 input data tokenization:
import torch
from models.models_vit_eeg import PatchEmbedEEG
# Deterministic (1, 3, 128) test signal whose values encode their own position:
# value = 1 + channel + time / 1000 (e.g. 2.005 -> channel 1, sample 5).
# Computed in float64 and cast once, matching Python-float assignment into float32.
channel_ids = torch.arange(3, dtype=torch.float64).view(3, 1)
time_ids = torch.arange(128, dtype=torch.float64).view(1, 128)
test_data = (1 + channel_ids + time_ids / 1000).to(torch.float32).unsqueeze(0)

print(test_data.shape)
print(test_data)

torch.Size([1, 3, 128])
tensor([[[1.0000, 1.0010, 1.0020, 1.0030, 1.0040, 1.0050, 1.0060, 1.0070,
          1.0080, 1.0090, 1.0100, 1.0110, 1.0120, 1.0130, 1.0140, 1.0150,
          1.0160, 1.0170, 1.0180, 1.0190, 1.0200, 1.0210, 1.0220, 1.0230,
          1.0240, 1.0250, 1.0260, 1.0270, 1.0280, 1.0290, 1.0300, 1.0310,
          1.0320, 1.0330, 1.0340, 1.0350, 1.0360, 1.0370, 1.0380, 1.0390,
          1.0400, 1.0410, 1.0420, 1.0430, 1.0440, 1.0450, 1.0460, 1.0470,
          1.0480, 1.0490, 1.0500, 1.0510, 1.0520, 1.0530, 1.0540, 1.0550,
          1.0560, 1.0570, 1.0580, 1.0590, 1.0600, 1.0610, 1.0620, 1.0630,
          1.0640, 1.0650, 1.0660, 1.0670, 1.0680, 1.0690, 1.0700, 1.0710,
          1.0720, 1.0730, 1.0740, 1.0750, 1.0760, 1.0770, 1.0780, 1.0790,
          1.0800, 1.0810, 1.0820, 1.0830, 1.0840, 1.0850, 1.0860, 1.0870,
          1.0880, 1.0890, 1.0900, 1.0910, 1.0920, 1.0930, 1.0940, 1.0950,
          1.0960, 1.0970, 1.0980, 1.0990, 1.1000, 1.1010, 1.1020, 1.1030,
          1.10

In [38]:
# Patch-embedding layer with the same patch size as vit_small_patch16 (16 samples).
test_patch = PatchEmbedEEG(embed_dim=256, patch_size=16)
patched_out = test_patch.patchify_eeg(test_data)

In [39]:
# (batch, seq, channels, patch_len); seq = 128 / 16 = 8 patches per channel
patch_shape = patched_out.shape
print(patch_shape, patched_out)

torch.Size([1, 8, 3, 16]) tensor([[[[1.0000, 1.0010, 1.0020, 1.0030, 1.0040, 1.0050, 1.0060, 1.0070,
           1.0080, 1.0090, 1.0100, 1.0110, 1.0120, 1.0130, 1.0140, 1.0150],
          [2.0000, 2.0010, 2.0020, 2.0030, 2.0040, 2.0050, 2.0060, 2.0070,
           2.0080, 2.0090, 2.0100, 2.0110, 2.0120, 2.0130, 2.0140, 2.0150],
          [3.0000, 3.0010, 3.0020, 3.0030, 3.0040, 3.0050, 3.0060, 3.0070,
           3.0080, 3.0090, 3.0100, 3.0110, 3.0120, 3.0130, 3.0140, 3.0150]],

         [[1.0160, 1.0170, 1.0180, 1.0190, 1.0200, 1.0210, 1.0220, 1.0230,
           1.0240, 1.0250, 1.0260, 1.0270, 1.0280, 1.0290, 1.0300, 1.0310],
          [2.0160, 2.0170, 2.0180, 2.0190, 2.0200, 2.0210, 2.0220, 2.0230,
           2.0240, 2.0250, 2.0260, 2.0270, 2.0280, 2.0290, 2.0300, 2.0310],
          [3.0160, 3.0170, 3.0180, 3.0190, 3.0200, 3.0210, 3.0220, 3.0230,
           3.0240, 3.0250, 3.0260, 3.0270, 3.0280, 3.0290, 3.0300, 3.0310]],

         [[1.0320, 1.0330, 1.0340, 1.0350, 1.0360, 1.0370, 1.038

In [27]:
# second time patch (samples 16-31) of channel 0
patched_out[0, 1, 0]

tensor([1.0160, 1.0170, 1.0180, 1.0190, 1.0200, 1.0210, 1.0220, 1.0230, 1.0240,
        1.0250, 1.0260, 1.0270, 1.0280, 1.0290, 1.0300, 1.0310])

In [40]:
# The model then flattens (seq, channel) into a single token axis:
# token t <- (time patch t // Ch_all, channel t % Ch_all).
B, Seq, Ch_all, Dmodel = patched_out.shape
Seq_total = Seq * Ch_all
patched_out2 = patched_out.flatten(start_dim=1, end_dim=2)
print(patched_out2.shape)

torch.Size([1, 24, 16])


In [44]:
# token 3 = time patch 3 // 3 = 1 of channel 3 % 3 = 0 (matches patched_out[0, 1, 0])
print(patched_out2[0, 3])

tensor([1.0160, 1.0170, 1.0180, 1.0190, 1.0200, 1.0210, 1.0220, 1.0230, 1.0240,
        1.0250, 1.0260, 1.0270, 1.0280, 1.0290, 1.0300, 1.0310])
