In [6]:
import os

import torch
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader

os.environ["WANDB_API_KEY"] = "KEY"
os.environ["WANDB_MODE"] = 'offline'
from itertools import combinations

import clip
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
import tqdm
from eegdatasets_leaveone import EEGDataset
from eegencoder import eeg_encoder
from einops.layers.torch import Rearrange, Reduce
from lavis.models.clip_models.loss import ClipLoss
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader, Dataset
import random
from utils import wandb_logger
from torch import Tensor
import math
from modules.fft import *
from modules.MambaIR import *
from torch.autograd import Variable

# Cross-entropy used by get_eegfeatures for the symmetric EEG<->image InfoNCE term.
criterion_cls = torch.nn.CrossEntropyLoss().cuda()
# Legacy CUDA int64 tensor type, used by get_eegfeatures to cast target indices.
LongTensor = torch.cuda.LongTensor 
# NOTE(review): not referenced anywhere in the visible code; it equals the 0.07
# used to initialise SCFD.logit_scale, but the two are not linked.
temperature = 0.07

def get_1d_sincos_pos_embed(embed_dim, pos):
    """Build fixed sine/cosine embeddings for a 1-D set of positions.

    Args:
        embed_dim: Output width D of each embedding; must be even.
        pos: Positions to encode. Any array-like of numbers (list, tuple,
            ndarray, ...); it is converted to a float array and flattened
            to shape (M,).

    Returns:
        ndarray of shape (M, D). The first D/2 columns hold
        sin(pos * omega) and the last D/2 columns hold cos(pos * omega),
        with omega_i = 1 / 10000 ** (i / (D / 2)).

    Raises:
        AssertionError: If ``embed_dim`` is odd.
    """
    assert embed_dim % 2 == 0
    exponents = np.arange(embed_dim // 2, dtype=float) / (embed_dim / 2.)
    omega = 1. / 10000 ** exponents  # (D/2,)

    # Coerce first so that plain Python sequences work as the docstring
    # promises; ndarray inputs behave exactly as before.
    pos = np.asarray(pos, dtype=float).reshape(-1)  # (M,)
    angles = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)

def get_sph_sincos_pos_embed(embed_dim, sph_coordination, cls_token=False):
    """Sine/cosine positional embedding built from two coordinates per entry.

    Args:
        embed_dim: Total embedding width D. Each of the two coordinates is
            encoded with D // 2 columns.
        sph_coordination: Array that is reshaped to (2, N); row 0 is encoded
            as the first coordinate (theta) and row 1 as the second (phi).
        cls_token: When True, a single all-zero row is appended after the
            N embedding rows.

    Returns:
        ndarray of shape (N, D), or (N + 1, D) when ``cls_token`` is True.
    """
    theta, phi = sph_coordination.reshape(2, -1)

    # One half of the width per coordinate, theta first.
    halves = [
        get_1d_sincos_pos_embed(embed_dim // 2, angle)  # (N, D/2)
        for angle in (theta, phi)
    ]
    pos_embed = np.concatenate(halves, axis=1)  # (N, D)

    if not cls_token:
        return pos_embed

    zero_row = np.zeros([1, embed_dim])
    return np.concatenate([pos_embed, zero_row], axis=0)

class PositionalEncoding(nn.Module):
    """Adds a fixed sine/cosine positional table to a sequence-first tensor.

    The table is built once in ``__init__`` and stored as a non-persistent
    buffer: it follows the module across ``.to(device)`` calls but is not
    written to (or expected in) ``state_dict``, so existing checkpoints load
    unchanged.

    NOTE: the table is produced by ``get_sph_sincos_pos_embed`` from
    ``np.arange(max_len)``. That helper splits its input into two halves, so
    the table has ``max_len // 2`` rows and ``max_len`` must be even.

    Args:
        d_model: Feature width of the tensors passed to ``forward``. Odd
            values are rounded up to the next even number internally and the
            extra column is dropped again in ``forward``.
        max_len: Number of positions handed to the embedding helper.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        # The sin/cos helper needs an even width.
        if d_model % 2 != 0:
            d_model += 1

        self.d_model = d_model
        self.max_len = max_len

        table = get_sph_sincos_pos_embed(d_model, np.arange(max_len), cls_token=False)
        table = torch.tensor(table, dtype=torch.float).unsqueeze(1)  # (rows, 1, d_model)
        # persistent=False keeps the table out of state_dict (checkpoint-compatible).
        self.register_buffer('pe', table, persistent=False)

    def forward(self, x):
        """Return ``x`` plus the positional table.

        Args:
            x: Tensor, shape [time_length, batch_size, channel]

        Returns:
            Tensor of the same shape as ``x``.
        """
        time_length = x.size(0)

        # The singleton batch axis broadcasts over batch_size, so no explicit
        # repeat is needed. ``.to`` is a no-op when the buffer already lives
        # on x's device.
        pe = self.pe[:time_length].to(x.device)

        # Drop the padding column added when d_model was rounded up.
        if self.d_model > x.size(2):
            pe = pe[:, :, :x.size(2)]

        return x + pe


class EEGAttention(nn.Module):
    """One-layer transformer encoder run over the last axis of the input.

    Args:
        channel: Stored on the instance; not otherwise used here.
        d_model: Feature width given to the positional encoding and to the
            transformer layer.
        nhead: Number of attention heads of the transformer layer.
    """

    def __init__(self, channel, d_model, nhead):
        super().__init__()
        self.pos_encoder = PositionalEncoding(d_model)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.encoder_layer = layer
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=1)
        self.channel = channel
        self.d_model = d_model

    def forward(self, src):
        """Encode ``src`` ([batch_size, channel, time_length]) and return the same layout."""
        # The transformer expects [time_length, batch_size, channel].
        seq_first = self.pos_encoder(src.permute(2, 0, 1))
        encoded = self.transformer_encoder(seq_first)
        # Back to [batch_size, channel, time_length].
        return encoded.permute(1, 2, 0)
class PatchEmbedding(nn.Module):
    """Convolutional patch embedding (revised from ShallowNet).

    A temporal convolution and average pooling are followed by a convolution
    with a 63-row kernel; a 1x1 convolution then maps the 40 feature maps to
    ``emb_size`` and the result is rearranged to (batch, positions, emb_size).

    Args:
        emb_size: Number of output channels of the final 1x1 projection.
    """

    def __init__(self, emb_size=40):
        super().__init__()
        tsconv_layers = [
            nn.Conv2d(1, 40, (1, 5), (1, 1)),
            nn.AvgPool2d((1, 17), (1, 5)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.Conv2d(40, 40, (63, 1), (1, 1)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.Dropout(0.5),
        ]
        self.tsconv = nn.Sequential(*tsconv_layers)

        projection_layers = [
            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        ]
        self.projection = nn.Sequential(*projection_layers)

    def forward(self, x: Tensor) -> Tensor:
        """Insert a singleton channel axis at dim 1, then apply both stages."""
        features = self.tsconv(x.unsqueeze(1))
        return self.projection(features)


class ResidualAdd(nn.Module):
    """Wraps a callable ``fn`` and returns ``fn(x) + x``.

    The sum is accumulated in place into the value returned by ``fn``.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Apply ``fn`` (forwarding ``kwargs``) and add the input back on."""
        out = self.fn(x, **kwargs)
        out += x  # in-place, as the skip connection was originally written
        return out


class FlattenHead(nn.Sequential):
    """Flattens every axis after the first into a single feature axis."""

    def forward(self, x):
        """Return ``x`` viewed as (x.size(0), -1)."""
        batch = x.size(0)
        # contiguous() first so that view() also works on permuted inputs.
        return x.contiguous().view(batch, -1)


class Enc_eeg(nn.Sequential):
    """EEG encoder: ``PatchEmbedding`` followed by ``FlattenHead``.

    Args:
        emb_size: Forwarded to ``PatchEmbedding``.
        **kwargs: Accepted but ignored.
    """

    def __init__(self, emb_size=40, **kwargs):
        patch_embed = PatchEmbedding(emb_size)
        flatten = FlattenHead()
        super().__init__(patch_embed, flatten)

        
class Proj_eeg(nn.Sequential):
    """Projection head: linear map, residual GELU-MLP block, then LayerNorm.

    Args:
        embedding_dim: Width of the incoming features.
        proj_dim: Width of the projected features.
        drop_proj: Dropout probability inside the residual block.
    """

    def __init__(self, embedding_dim=1840, proj_dim=1024, drop_proj=0.5):
        input_proj = nn.Linear(embedding_dim, proj_dim)
        residual_mlp = ResidualAdd(
            nn.Sequential(
                nn.GELU(),
                nn.Linear(proj_dim, proj_dim),
                nn.Dropout(drop_proj),
            )
        )
        output_norm = nn.LayerNorm(proj_dim)
        super().__init__(input_proj, residual_mlp, output_norm)


class Proj_img(nn.Sequential):
    """Image-side projection head whose ``forward`` is an identity.

    The layers are built like ``Proj_eeg``'s (linear map, residual GELU-MLP
    block, LayerNorm), but ``forward`` is overridden to return its input
    unchanged, so none of them is applied.

    Args:
        embedding_dim: Width of the incoming features.
        proj_dim: Width of the projected features.
        drop_proj: Dropout probability inside the residual block.
    """

    def __init__(self, embedding_dim=1024, proj_dim=1024, drop_proj=0.3):
        input_proj = nn.Linear(embedding_dim, proj_dim)
        residual_mlp = ResidualAdd(
            nn.Sequential(
                nn.GELU(),
                nn.Linear(proj_dim, proj_dim),
                nn.Dropout(drop_proj),
            )
        )
        output_norm = nn.LayerNorm(proj_dim)
        super().__init__(input_proj, residual_mlp, output_norm)

    def forward(self, x):
        """Return ``x`` untouched; the registered layers are bypassed."""
        return x

class SCFD(nn.Module):
    """EEG encoder: attention -> linear over the last axis -> VSS block -> conv encoder -> projection.

    Args:
        num_channels: Used as both ``channel`` and ``d_model`` of the
            attention stage.
        sequence_length: In/out width of each subject-wise linear layer.
        num_subjects: Number of subject-wise linear layers created
            (``forward`` always uses the first one).
        num_features: Unused.
        num_latents: Unused.
        num_blocks: Unused.
    """

    def __init__(self, num_channels=63, sequence_length=250, num_subjects=1, num_features=64, num_latents=1024, num_blocks=1):
        super().__init__()
        self.attention_model = EEGAttention(num_channels, num_channels, nhead=1)
        self.VSS = VSSBlock(hidden_dim=250, drop_path=0.1, attn_drop_rate=0.1, d_state=16, expand=2.0, is_light_sr=False)
        self.subject_wise_linear = nn.ModuleList(
            nn.Linear(sequence_length, sequence_length) for _ in range(num_subjects)
        )
        self.enc_eeg = Enc_eeg()
        self.proj_eeg = Proj_eeg()
        # Learnable logit scale, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.loss_func = ClipLoss()

    def forward(self, x):
        """Map an EEG batch to projected embeddings."""
        attended = self.attention_model(x)
        # Only the first subject-wise layer is ever applied.
        mixed = self.subject_wise_linear[0](attended)
        # NOTE(review): (7, 9) is presumably a spatial size whose product is
        # the 63 channels; confirm against VSSBlock.
        mixed = self.VSS(mixed, (7, 9))
        eeg_embedding = self.enc_eeg(mixed)
        return self.proj_eeg(eeg_embedding)
    
# Prefer the GPU when one is available (reassigned to "cuda:0"/"cpu" further down).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
# Second EEG encoder; get_eegfeatures averages its output with the main model's.
# NOTE(review): CaputoEncoder is not defined in this file (presumably it comes
# from one of the star imports) and no checkpoint is loaded into it in the
# visible code - confirm it is not being used with freshly initialised weights.
caputo = CaputoEncoder(input_size=250, lstm_size=512, lstm_layers=1, output_size=1024).to(device)    

def get_eegfeatures(sub, eegmodel, dataloader, device, text_features_all, img_features_all, k, mode):
    """Run the EEG encoder over ``dataloader``, score it, and save the features.

    For every batch the output of ``eegmodel`` is averaged 50/50 with the
    output of the module-level ``caputo`` encoder. The fused features are
    scored with ``eegmodel.loss_func`` against image and text features plus a
    symmetric cross-entropy (InfoNCE) term, and used for a k-way retrieval
    test against ``img_features_all``. All fused features are concatenated
    and written to ``ATM_S_eeg_features_{sub}_{mode}.pt`` in the working
    directory.

    Relies on the module-level names ``img_model``, ``caputo`` and
    ``criterion_cls``.

    Args:
        sub: Subject identifier; only used in the name of the saved file.
        eegmodel: EEG encoder exposing ``logit_scale`` and ``loss_func``.
            It is switched to eval mode.
        dataloader: Yields ``(eeg_data, labels, text, text_features, img,
            img_features)`` batches.
        device: Device the batches are moved to.
        text_features_all: Only its first dimension is used, as the number of
            candidate classes.
        img_features_all: Image features, indexed by class label in the
            k-way retrieval test.
        k: Number of candidates per retrieval trial (the true class plus
            ``k - 1`` randomly drawn other classes).
        mode: Tag used in the name of the saved file.

    Returns:
        Tuple ``(average_loss, accuracy, labels, features)``: the mean
        per-batch loss, the k-way retrieval accuracy, the labels of all
        samples in loader order (on ``device``), and the fused features of
        all samples on the CPU.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    eegmodel.eval()
    # NOTE(review): caputo and img_model are not put into eval mode here;
    # confirm that is intended if they contain dropout or batch norm.
    img_features_all = img_features_all.to(device).float()

    alpha = 0.99  # weight of the image loss relative to the text loss
    all_labels = set(range(text_features_all.size(0)))

    total_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    features_list = []  # fused features, one tensor per batch
    labels_list = []  # labels, one tensor per batch

    with torch.no_grad():
        for eeg_data, labels, _text, text_features, _img, img_features in dataloader:
            # Keep only the first 250 samples along the last axis.
            eeg_data = eeg_data.to(device)[:, :, :250]
            text_features = text_features.to(device).float()
            labels = labels.to(device)
            img_features = img_features.to(device).float()

            eeg_features = eegmodel(eeg_data).float()
            img_features = img_model(img_features).float()
            # Element-wise average of the two EEG encoders.
            eeg_features = 0.5 * caputo(eeg_data) + 0.5 * eeg_features
            features_list.append(eeg_features)
            labels_list.append(labels)

            logit_scale = eegmodel.logit_scale
            img_loss = eegmodel.loss_func(eeg_features, img_features, logit_scale)
            text_loss = eegmodel.loss_func(eeg_features, text_features, logit_scale)
            loss = alpha * img_loss + (1 - alpha) * text_loss

            # Symmetric InfoNCE: sample i of the batch should match image i.
            logits_eeg = logit_scale * eeg_features @ img_features.T
            # Build the targets on the logits' device instead of hard-coding CUDA.
            targets = torch.arange(eeg_data.shape[0], device=logits_eeg.device)
            loss_eeg = criterion_cls(logits_eeg, targets)
            loss_img = criterion_cls(logits_eeg.t(), targets)
            loss = loss + (loss_eeg + loss_img) / 2

            total_loss += loss.item()
            num_batches += 1

            # k-way retrieval: the true class against k - 1 random others.
            # tolist() moves the labels to the host once per batch.
            for idx, label in enumerate(labels.tolist()):
                possible_classes = list(all_labels - {label})
                selected_classes = random.sample(possible_classes, k - 1) + [label]
                selected_img_features = img_features_all[selected_classes]

                logits_single = logit_scale * eeg_features[idx] @ selected_img_features.T
                predicted_label = selected_classes[torch.argmax(logits_single).item()]
                if predicted_label == label:
                    correct += 1
                total += 1

        if not features_list:
            raise ValueError("dataloader yielded no batches; nothing to evaluate")

        features_tensor = torch.cat(features_list, dim=0)
        print("features_tensor", features_tensor.shape)
        features_cpu = features_tensor.cpu()
        torch.save(features_cpu, f"ATM_S_eeg_features_{sub}_{mode}.pt")  # Save features as .pt file

    average_loss = total_loss / num_batches
    accuracy = correct / total
    # Return the labels of every sample so they line up with the features
    # (previously only the last batch's labels were returned).
    return average_loss, accuracy, torch.cat(labels_list, dim=0), features_cpu

from IPython.display import Image, display
# Run configuration. The visible code reads only "data_path", "batch_size",
# "img_encoder" and (in a later cell) "encoder_type".
config = {
"data_path": "/root/autodl-tmp/EEG/EEG_Image_decode/datasets/THINGS/Preprocessed_data_250Hz",
"project": "train_pos_img_text_rep",
"entity": "sustech_rethinkingbci",
"name": "lr=3e-4_img_pos_pro_eeg",
"lr": 3e-4,
"epochs": 50,
"batch_size": 512,
"logger": True,
"encoder_type":'SCFD',
"img_encoder": 'Proj_img'
}

# Overrides the earlier module-level `device`, pinning the first GPU when available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data_path = config['data_path']
# SECURITY: torch.load unpickles the file; only load files from a trusted source.
emb_img_test = torch.load('variables/ViT-H-14_features_test.pt')
emb_img_train = torch.load('variables/ViT-H-14_features_train.pt')

# NOTE(review): ATM_S_reconstruction_scale_0_1000 is not defined in this file
# (presumably it comes from one of the star imports), while
# config["encoder_type"] names 'SCFD' - confirm which encoder is intended.
eeg_model = ATM_S_reconstruction_scale_0_1000(63, 250)
print('number of parameters:', sum([p.numel() for p in eeg_model.parameters()]))

#####################################################################################
# Restore the checkpoint for the subject evaluated below.
eeg_model.load_state_dict(torch.load("/root/autodl-tmp/EEG/EEG_Image_decode/Generation/models/contrast/ATM_S_reconstruction_scale_0_1000/08-30_00-58/sub-08/39.pth"))
eeg_model = eeg_model.to(device)
# Instantiate the image encoder class named in the config; get_eegfeatures
# reads `img_model` as a global.
img_model = globals()[config['img_encoder']]().to(device)
sub = 'sub-08'

#####################################################################################

# Evaluate on the held-out test split and save its EEG features.
test_dataset = EEGDataset(data_path, subjects= [sub], train=False)
test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=0)
text_features_test_all = test_dataset.text_features
img_features_test_all = test_dataset.img_features
test_loss, test_accuracy,labels, eeg_features_test = get_eegfeatures(sub, eeg_model, test_loader, device, text_features_test_all, img_features_test_all,k=200, mode="test")
print(f" - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

number of parameters: 4580132
self.subjects ['sub-08']
exclude_subject None
Data tensor shape: torch.Size([200, 63, 250]), label tensor shape: torch.Size([200]), text length: 200, image length: 200
features_tensor torch.Size([200, 1024])
 - Test Loss: 4.7581, Test Accuracy: 0.4250


In [7]:
#####################################################################################
# Extract and save EEG features for the training split (loader order is kept,
# shuffle=False, so the rows stay aligned with the dataset).
train_dataset = EEGDataset(data_path, subjects= [sub], train=True)
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=0)
# NOTE: these two globals keep their "_test_" names but now hold the TRAIN
# features, overwriting the values set by the test-split evaluation.
text_features_test_all = train_dataset.text_features
img_features_test_all = train_dataset.img_features

# NOTE(review): the recorded output of this cell shows an accuracy of 0.0050
# (= 1/k), and the dataset reports 1654 texts but 16540 images, so indexing
# img_features_all by class label inside get_eegfeatures may not select the
# matching images on this split - confirm before trusting this number.
train_loss, train_accuracy, labels, eeg_features_train = get_eegfeatures(sub, eeg_model, train_loader, device, text_features_test_all, img_features_test_all,k=200, mode="train")
# These are train-split metrics; the message used to say "Test".
print(f" - Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
#####################################################################################

self.subjects ['sub-08']
exclude_subject None
data_tensor torch.Size([66160, 63, 250])
Data tensor shape: torch.Size([66160, 63, 250]), label tensor shape: torch.Size([66160]), text length: 1654, image length: 16540
features_tensor torch.Size([66160, 1024])
 - Test Loss: 4.4049, Test Accuracy: 0.0050


In [8]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import open_clip
from matplotlib.font_manager import FontProperties

import sys
from diffusion_prior import *
from custom_pipeline import *
# os.environ["CUDA_VISIBLE_DEVICES"] = "5" 
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [9]:
# Repeat every training image embedding 4 times in a row:
# (1654, 10, 1, 1024) -> (1654, 10, 4, 1024) -> (66160, 1024).
# NOTE(review): presumably 1654 classes x 10 images x 4 EEG repetitions, which
# matches the 66160 EEG rows reported for the train split - confirm the row
# order matches the dataset's.
emb_img_train_4 = emb_img_train.view(1654,10,1,1024).repeat(1,1,4,1).view(-1,1024)
# Reload the EEG features from absolute paths (get_eegfeatures writes files of
# the same name to the working directory). torch.load unpickles: trusted files only.
# NOTE(review): the cells below train on the in-memory eeg_features_train, not
# on emb_eeg - confirm that is intended.
emb_eeg = torch.load(f'/root/autodl-tmp/EEG/EEG_Image_decode/Generation/ATM_S_eeg_features_{sub}_train.pt')
emb_eeg_test = torch.load(f'/root/autodl-tmp/EEG/EEG_Image_decode/Generation/ATM_S_eeg_features_{sub}_test.pt')

In [10]:
emb_eeg.shape, emb_eeg_test.shape

(torch.Size([66160, 1024]), torch.Size([200, 1024]))

In [11]:
eeg_features_train

tensor([[-0.5160,  1.4970,  1.2727,  ...,  0.3626, -0.6813,  0.4593],
        [-1.0497,  0.0807, -0.0973,  ...,  1.3810, -0.0080, -0.5037],
        [ 0.2541,  0.4853,  0.5700,  ..., -0.3420, -0.6477, -0.6254],
        ...,
        [-0.4777, -0.2025,  0.6289,  ...,  0.6740, -0.3684,  0.9980],
        [ 0.5863,  0.2192,  1.1327,  ...,  0.7494, -0.1537,  1.3203],
        [ 0.2057,  1.1828, -1.2171,  ..., -0.4977,  0.5306,  0.5098]])

In [12]:
# Pair each EEG feature (condition, c) with its image embedding (target, h).
dataset = EmbeddingDataset(
    c_embeddings=eeg_features_train, h_embeddings=emb_img_train_4, 
    # h_embeds_uncond=h_embeds_imgnet
)
# NOTE(review): 64 worker processes for tensors that are already in memory is
# probably more than needed - confirm.
dl = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=64)
diffusion_prior = DiffusionPriorUNet(cond_dim=1024, dropout=0.1)
# number of trainable parameters
print(sum(p.numel() for p in diffusion_prior.parameters() if p.requires_grad))
pipe = Pipe(diffusion_prior, device=device)


9675648


In [15]:
# Same setup as the earlier cell, rebuilt from scratch: a fresh, untrained
# prior is created before training starts.
dataset = EmbeddingDataset(
    c_embeddings=eeg_features_train, h_embeddings=emb_img_train_4, 
    # h_embeds_uncond=h_embeds_imgnet
)
# NOTE(review): 64 worker processes for tensors that are already in memory is
# probably more than needed - confirm.
dl = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=64)
diffusion_prior = DiffusionPriorUNet(cond_dim=1024, dropout=0.1)
# number of trainable parameters
print(sum(p.numel() for p in diffusion_prior.parameters() if p.requires_grad))
pipe = Pipe(diffusion_prior, device=device)

# Checkpoint name used when the prior is saved in a later cell; nothing is
# loaded here.
model_name = 'diffusion_prior' # 'diffusion_prior_vice_pre_imagenet' or 'diffusion_prior_vice_pre'
pipe.train(dl, num_epochs=300, learning_rate=1e-3) # to 0.142, 0.147

9675648
epoch: 0, loss: 1.1356700970576359
epoch: 1, loss: 0.9270742874879103
epoch: 2, loss: 0.7335815182098976
epoch: 3, loss: 0.5874950317236093
epoch: 4, loss: 0.4823428438260005
epoch: 5, loss: 0.40065046411294203
epoch: 6, loss: 0.3466468091194446
epoch: 7, loss: 0.314552659713305
epoch: 8, loss: 0.2961499993617718
epoch: 9, loss: 0.2809753693067111
epoch: 10, loss: 0.26901656480935904
epoch: 11, loss: 0.2579496131493495
epoch: 12, loss: 0.24792072429106785
epoch: 13, loss: 0.2399262079825768
epoch: 14, loss: 0.2308388590812683
epoch: 15, loss: 0.22408424134437854
epoch: 16, loss: 0.2179516063286708
epoch: 17, loss: 0.211402075107281
epoch: 18, loss: 0.20465697669065916
epoch: 19, loss: 0.2039179047712913
epoch: 20, loss: 0.20301525432329912
epoch: 21, loss: 0.20329864873335912
epoch: 22, loss: 0.20044077451412495
epoch: 23, loss: 0.19819101576621717
epoch: 24, loss: 0.19529663989177118
epoch: 25, loss: 0.1955018609762192
epoch: 26, loss: 0.19413241904515485
epoch: 27, loss: 0.19

In [16]:

# To restore a previously saved prior instead of training, load it from the
# same path that is written below, e.g.
#   pipe.diffusion_prior.load_state_dict(torch.load(save_path, map_location=device))
# Checkpoints are grouped by encoder type and subject.
save_path = f'./fintune_ckpts/{config["encoder_type"]}/{sub}/{model_name}.pt'

directory = os.path.dirname(save_path)

# Create the directory if it doesn't exist
os.makedirs(directory, exist_ok=True)
# Only the prior's weights (state_dict) are saved, not the Pipe wrapper.
torch.save(pipe.diffusion_prior.state_dict(), save_path)
