In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import clip
from torch.nn import functional as F
import torch.nn as nn
from torchvision import transforms
from PIL import Image
# Module-level switches read by load_data() below.
train = False  # True -> scan the training_images directory, False -> test_images
classes = None  # optional sequence of class (folder) indices to keep; None keeps every folder
pictures= None  # optional image indices, paired element-wise with `classes`; ignored when `classes` is None

def load_data():
    """Collect class descriptions and image paths from the THINGS-EEG image set.

    The function reads the module-level switches ``train``, ``classes`` and
    ``pictures``:

    * ``train`` selects the ``training_images`` or ``test_images`` directory.
    * ``classes`` (optional) is a sequence of indices into the sorted list of
      class folders; ``None`` selects every folder.
    * ``pictures`` (optional) gives, for each entry of ``classes``, the index
      of the single image to keep from that folder. It is ignored when
      ``classes`` is ``None``.

    Returns:
        tuple: ``(texts, images)``. ``texts`` holds, for each selected folder,
        the part of the folder name after its first ``'_'``. ``images`` holds
        the full paths of the selected ``.png``/``.jpg``/``.jpeg`` files.

    Note:
        Folders without ``'_'`` in their name are reported and left out of
        ``texts`` but still contribute to ``images``, so the two lists are not
        guaranteed to be aligned in that case.
    """
    if train:
        image_root = "/media/hanakawalab/Transcend/THINGS-Data/THINGS-EEG_images_set/training_images"
    else:
        image_root = "/media/hanakawalab/Transcend/THINGS-Data/THINGS-EEG_images_set/test_images"

    # A single sorted listing serves both the labels and the image lookup
    # (both originally came from the same directory).
    all_folders = sorted(
        entry for entry in os.listdir(image_root)
        if os.path.isdir(os.path.join(image_root, entry))
    )

    def _list_images(folder):
        """Return the sorted full paths of the image files inside ``folder``."""
        folder_path = os.path.join(image_root, folder)
        file_names = sorted(
            name for name in os.listdir(folder_path)
            if name.lower().endswith(('.png', '.jpg', '.jpeg'))
        )
        return [os.path.join(folder_path, name) for name in file_names]

    # ---- text labels -----------------------------------------------------
    if classes is None:
        selected_folders = all_folders
    else:
        selected_folders = [all_folders[i] for i in classes]

    texts = []
    for folder in selected_folders:
        _, separator, description = folder.partition('_')
        if not separator:
            print(f"Skipped: {folder} due to no '_' found.")
            continue
        texts.append(description)

    # ---- image paths -----------------------------------------------------
    images = []
    if classes is None:
        for folder in all_folders:
            images.extend(_list_images(folder))
    else:
        for position, class_idx in enumerate(classes):
            if class_idx >= len(all_folders):
                continue  # out-of-range class index: nothing to add
            folder_images = _list_images(all_folders[class_idx])
            if pictures is None:
                images.extend(folder_images)
            else:
                pic_idx = pictures[position]
                if pic_idx < len(folder_images):
                    images.append(folder_images[pic_idx])

    return texts, images
# Build the description / image-path lists once, using the module-level switches defined above.
texts, images = load_data()
# images

In [3]:
import os

import torch
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader

# wandb settings are placed in the environment before the project imports that follow.
# NOTE(review): "KEY" looks like a placeholder; supply a real key through the
# environment instead of committing one to the notebook.
os.environ["WANDB_API_KEY"] = "KEY"
os.environ["WANDB_MODE"] = 'offline'
from itertools import combinations

import clip
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
import tqdm
from eegdatasets_leaveone import EEGDataset

from einops.layers.torch import Rearrange, Reduce

from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader, Dataset
import random
from util import wandb_logger
from braindecode.models import EEGNetv4, ATCNet, EEGConformer, EEGITNet, ShallowFBCSPNet
import csv
from torch import Tensor
import itertools
import math
import re
from subject_layers.Transformer_EncDec import Encoder, EncoderLayer
from subject_layers.SelfAttention_Family import FullAttention, AttentionLayer
from subject_layers.Embed import DataEmbedding
import numpy as np
from loss import ClipLoss, FeatureContrastLoss
import argparse
from torch import nn
from torch.optim import AdamW
from pytorch_wavelets import DWT1DForward, DWT1DInverse

class WaveBlock(nn.Module):
    """Single-level ``db1`` wavelet decomposition followed by a 1-D convolution.

    The low-pass and high-pass coefficients returned by the forward DWT are
    concatenated along the last axis and passed through a kernel-3,
    stride-1, padding-1 ``Conv1d``.

    Args:
        c_in: Number of input channels of the convolution.
        c_out: Number of output channels of the convolution.
    """

    def __init__(self, c_in, c_out):
        super(WaveBlock, self).__init__()
        # The DWT is a registered sub-module, so it follows the parent model's
        # ``.to(device)``. The original forced it onto 'cuda' here, which
        # raises on CPU-only machines.
        self.dwt = DWT1DForward(J=1, wave='db1', mode='zero')
        self.conv = nn.Conv1d(c_in, c_out, 3, 1, padding=1)

    def forward(self, x):
        """Return ``conv(cat(low, high))`` for the one-level DWT of ``x``."""
        xl, xh = self.dwt(x)

        # xh is a list with one entry per decomposition level (J=1 here).
        x2 = torch.cat((xl, xh[0]), dim=-1)
        x3 = self.conv(x2)
        return x3

class iWaveBlock(nn.Module):
    """Inverse ``db1`` wavelet transform of a tensor holding both coefficient bands.

    The last axis of the input must have length 256: the first 128 values are
    used as low-pass coefficients and the last 128 as high-pass coefficients.

    Args:
        c_in: Accepted for signature compatibility; not used.
        c_out: Accepted for signature compatibility; not used.
    """

    def __init__(self, c_in, c_out):
        super(iWaveBlock, self).__init__()

        # Registered sub-module: it is moved together with the parent model,
        # so no device is hard-coded here (the original forced 'cuda', which
        # raises on CPU-only machines).
        self.idwt = DWT1DInverse(wave='db1', mode='zero')

    def forward(self, x ):
        """Split ``x`` into two 128-wide halves and apply the inverse DWT."""
        xl, xh = torch.split(x, [128, 128], dim=-1)
        x4 = self.idwt([xl, [xh]])
        return x4


class Config:
    """Default hyper-parameters used to build the ``iTransformer`` encoder."""

    def __init__(self):
        defaults = {
            'task_name': 'classification',  # Example task name
            'seq_len': 250,                 # Sequence length
            'pred_len': 256,                # Prediction length
            'output_attention': False,      # Whether to output attention weights
            'd_model': 256,                 # Model dimension
            'embed': 'timeF',               # Time encoding method
            'freq': 'h',                    # Time frequency
            'dropout': 0.25,                # Dropout rate
            'factor': 1,                    # Attention scaling factor
            'n_heads': 4,                   # Number of attention heads
            'e_layers': 1,                  # Number of encoder layers
            'd_ff': 256,                    # Feedforward network dimension
            'activation': 'gelu',           # Activation function
            'enc_in': 63,                   # Encoder input dimension (example value)
        }
        # Expose every entry as an instance attribute, in declaration order.
        for attr_name, attr_value in defaults.items():
            setattr(self, attr_name, attr_value)


class iTransformer(nn.Module):
    """EEG encoder mixing a wavelet block, local convolutions and an attention encoder.

    The forward pass embeds the input, applies ``WaveBlock``, and then blends
    two views of the first 63 channels with a learned per-channel gate: a
    convolutional view (per-channel temporal conv followed by a conv across
    neighbouring channels) and the attention-encoder output. The blended
    tensor is concatenated with its inverse wavelet transform and reduced by a
    stride-2 convolution.

    Args:
        configs: Object exposing the attributes read below (``task_name``,
            ``seq_len``, ``pred_len``, ``output_attention``, ``d_model``,
            ``embed``, ``freq``, ``dropout``, ``factor``, ``n_heads``,
            ``e_layers``, ``d_ff``, ``activation``).
        joint_train: Accepted but not forwarded — see the review note in
            ``__init__``.
        num_subjects: Forwarded to ``DataEmbedding``.
    """
    def __init__(self, configs, joint_train=False, num_subjects=10):
        super(iTransformer, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.output_attention = configs.output_attention
        # Embedding
        # NOTE(review): the constructor's `joint_train` argument is ignored here;
        # DataEmbedding always receives joint_train=False. Confirm this is intended.
        self.enc_embedding = DataEmbedding(configs.seq_len, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout, joint_train=False, num_subjects=num_subjects)

        # 64 channels in/out, whereas every other layer below works on 63;
        # forward() keeps only the first 63 channels of this block's output.
        self.dwt = WaveBlock(64, 64)
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=configs.output_attention),
                        configs.d_model, configs.n_heads
                    ),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for l in range(configs.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model)
        )
        # Per-channel temporal convolution (groups=63 gives each channel its own kernel).
        self.local_conv_time = nn.Conv1d(63, 63, kernel_size=3, padding=3 // 2,
                                         groups=63)
        # Cross-electrode convolution
        self.local_conv_electrode = nn.Conv2d(1, 1, kernel_size=(3, 1), padding=(3 // 2, 0),
                                              groups=1)
        # Per-channel gate in (0, 1): global average over the last axis -> 1x1 conv -> sigmoid.
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(63, 63, 1),
            nn.Sigmoid()
        )

        self.idwt = iWaveBlock(63, 63)
        # Kernel 4 / stride 2 / padding 1: halves the length of the last axis.
        self.conv = nn.Conv1d(63, 63, 4, 2, padding=1)
    def forward(self, x_enc, x_mark_enc, subject_ids=None):
        # Embedding
        # assumes the embedding returns (batch, 64, d_model), since WaveBlock's conv
        # expects 64 channels — TODO confirm against DataEmbedding
        enc_out = self.enc_embedding(x_enc, x_mark_enc, subject_ids)
        a = self.dwt(enc_out)
        a = a[:, :63, :]  # keep the first 63 channels only
        local_time = self.local_conv_time(a)
        local_time_expanded = local_time.unsqueeze(1)  # (Batch, 1, Channels, Signal_Length)
        local_electrode = self.local_conv_electrode(local_time_expanded).squeeze(1)  # (Batch, Channels, Signal_Length)
        out, attns = self.encoder(a, attn_mask=None)  # attns is not used afterwards
        out = out[:, :63, :]  # presumably a no-op, as `a` already has 63 channels — verify against Encoder
        weight = self.gate(a)  # (Batch, Channels, 1)
        # Gated blend of the convolutional view and the attention view.
        out = local_electrode * (1 - weight) + out * weight
        b = self.idwt(out)  # iWaveBlock needs the last axis of `out` to be exactly 256 long
        # Concatenate the blend with its inverse transform along the last axis, then downsample by 2.
        c = self.conv(torch.cat([out, b], dim=-1))
        #print("c", c.shape)
        return c


class PrintShape(nn.Module):
    """Pass-through debugging layer that prints the shape of its input.

    Args:
        name: Label printed in front of the shape.
    """

    def __init__(self, name):
        super(PrintShape, self).__init__()
        self.name = name

    def forward(self, x):
        """Print ``"<name> shape: <x.shape>"`` and return ``x`` unchanged."""
        message = f"{self.name} shape: {x.shape}"
        print(message)
        return x


class MultiBranchNet(nn.Module):
    """Two-branch convolutional feature extractor with channel attention.

    ``forward`` expects ``x`` of shape ``(batch, in_channels, time)``; a
    singleton height axis is inserted so the 2-D layers operate on
    ``(batch, in_channels, 1, time)``. The temporal and spatial branches each
    produce ``(batch, n_filters1, 1, 37)``; their concatenation is re-weighted
    by a squeeze-and-excite style attention block, fused down to
    ``n_filters2`` channels and rearranged to ``(batch, 37, n_filters2)``.

    Args:
        in_channels: Number of input channels.
        n_filters1: Number of filters in each branch.
        n_filters2: Number of channels after fusion (and of the output).
        dropout_rate: Dropout probability used in the temporal and fusion branches.
    """

    def __init__(self,
                 in_channels=63,
                 n_filters1=40,
                 n_filters2=40,
                 dropout_rate=0.5):
        super().__init__()

        # Temporal-feature branch (based on the ShallowNet design)
        self.temporal_branch = nn.Sequential(
            # Temporal convolution
            nn.Conv2d(in_channels, n_filters1, kernel_size=(1, 25), padding='same'),
            nn.BatchNorm2d(n_filters1),
            nn.ELU(),
            # Average pooling
            nn.AvgPool2d((1, 4), stride=(1, 2)),
            nn.Dropout(dropout_rate),
            # Point-wise convolution
            nn.Conv2d(n_filters1, n_filters1, kernel_size=(1, 1)),
            nn.BatchNorm2d(n_filters1),
            nn.ELU(),
            # Final pooling
            nn.AdaptiveAvgPool2d((1, 37))
        )

        # Spatial-feature branch (intended to capture inter-channel relations).
        # NOTE(review): forward() gives the input a height of 1, so the
        # (3, 1) kernel only sees zero padding above and below the single row;
        # channel mixing happens through the convolution's input channels.
        self.spatial_branch = nn.Sequential(
            # Spatial convolution (across channels)
            nn.Conv2d(in_channels, n_filters1, kernel_size=(3, 1), padding=(1, 0)),
            nn.BatchNorm2d(n_filters1),
            nn.ELU(),
            # Grouped (depthwise-style) convolution followed by a 1x1 convolution
            nn.Conv2d(n_filters1, n_filters1 * 2, kernel_size=(3, 3),
                      padding='same', groups=n_filters1),
            nn.BatchNorm2d(n_filters1 * 2),
            nn.ELU(),
            nn.Conv2d(n_filters1 * 2, n_filters1, kernel_size=1),
            nn.BatchNorm2d(n_filters1),
            nn.ELU(),
            # Final pooling
            nn.AdaptiveAvgPool2d((1, 37))
        )

        # Feature-fusion branch
        self.fusion_branch = nn.Sequential(
            # Merge the concatenated features
            nn.Conv2d(n_filters1 * 2, n_filters2, kernel_size=1),
            nn.BatchNorm2d(n_filters2),
            nn.ELU(),
            nn.Dropout(dropout_rate),
            # Guarantee the output size
            nn.AdaptiveAvgPool2d((1, 37))
        )

        # Attention: global pooling -> bottleneck -> per-channel weights in (0, 1)
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(n_filters1 * 2, n_filters1 // 2, kernel_size=1),
            nn.ELU(),
            nn.Conv2d(n_filters1 // 2, n_filters1 * 2, kernel_size=1),
            nn.Sigmoid()
        )

        # The channel count follows n_filters2 (previously hard-coded to 40,
        # which only worked with the default n_filters2).
        self.projection = nn.Sequential(
            nn.Conv2d(n_filters2, n_filters2, (1, 1), stride=(1, 1)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x):
        """Map ``(batch, in_channels, time)`` to ``(batch, 37, n_filters2)``."""
        # [batch, channels, time] -> [batch, channels, 1, time]
        x = x.unsqueeze(2)
        # 1. Temporal feature extraction
        temporal_features = self.temporal_branch(x)

        # 2. Spatial feature extraction
        spatial_features = self.spatial_branch(x)

        # 3. Concatenate the features along the channel axis
        combined_features = torch.cat([temporal_features, spatial_features], dim=1)

        # 4. Apply the attention weights
        attention_weights = self.attention(combined_features)
        weighted_features = combined_features * attention_weights

        # 5. Feature fusion
        output = self.fusion_branch(weighted_features)
        output = self.projection(output)

        return output


# Helper module: custom depthwise-separable convolution block
class DepthwiseSeparableConv(nn.Module):
    """Depthwise convolution followed by a 1x1 point-wise convolution.

    Args:
        in_channels: Channels of the input (also the number of depthwise groups).
        out_channels: Channels produced by the point-wise convolution.
        kernel_size: Kernel size of the depthwise convolution ('same' padding).
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super(DepthwiseSeparableConv, self).__init__()
        # One filter per input channel; spatial size is preserved.
        self.depthwise = nn.Conv2d(in_channels, in_channels,
                                   kernel_size=kernel_size,
                                   padding='same',
                                   groups=in_channels)
        # Mixes the channels without touching the spatial axes.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the depthwise convolution, then the point-wise one."""
        return self.pointwise(self.depthwise(x))


class PatchEmbedding(nn.Module):
    """ShallowNet-style convolutional patch embedding.

    ``forward`` inserts a singleton channel axis at position 1, applies a
    temporal convolution with average pooling and a ``(63, 1)`` convolution
    over the row axis, then projects to ``emb_size`` channels and rearranges
    the result to ``(batch, h * w, emb_size)``.

    Args:
        emb_size: Number of channels produced by the final 1x1 projection.
    """

    def __init__(self, emb_size=40):
        super(PatchEmbedding, self).__init__()
        # Revised from ShallowNet
        tsconv_layers = [
            nn.Conv2d(1, 40, (1, 25), stride=(1, 1)),   # temporal convolution
            nn.AvgPool2d((1, 51), (1, 5)),              # temporal pooling
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.Conv2d(40, 40, (63, 1), stride=(1, 1)),  # spans 63 rows at once
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.Dropout(0.5),
        ]
        self.tsconv = nn.Sequential(*tsconv_layers)

        projection_layers = [
            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        ]
        self.projection = nn.Sequential(*projection_layers)

    def forward(self, x: Tensor) -> Tensor:
        """Embed ``x`` and return a ``(batch, h * w, emb_size)`` tensor."""
        with_channel_axis = x.unsqueeze(1)
        conv_features = self.tsconv(with_channel_axis)
        return self.projection(conv_features)


class ResidualAdd(nn.Module):
    """Wrap ``fn`` with a skip connection: ``fn(x, **kwargs) + x``.

    Args:
        fn: Callable (typically a module) applied to the input.
    """

    def __init__(self, fn):
        super(ResidualAdd, self).__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Apply ``fn`` and add the original input to its result in place."""
        shortcut = x
        transformed = self.fn(x, **kwargs)
        transformed += shortcut  # in-place add on fn's output, as before
        return transformed


class FlattenHead(nn.Sequential):
    """Layer that flattens every axis after the first into one."""

    def __init__(self):
        super(FlattenHead, self).__init__()

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.size(0), -1)``."""
        batch = x.size(0)
        # view() needs contiguous memory, hence the explicit contiguous().
        dense = x.contiguous()
        return dense.view(batch, -1)


class Enc_eeg(nn.Sequential):
    """EEG feature extractor: ``MultiBranchNet`` followed by ``FlattenHead``.

    Args:
        emb_size: Accepted for signature compatibility; not used.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, emb_size=40, **kwargs):
        feature_extractor = MultiBranchNet()
        flatten = FlattenHead()
        super(Enc_eeg, self).__init__(feature_extractor, flatten)


class Proj_eeg(nn.Sequential):
    """Projection head: linear map, residual GELU/linear/dropout block, layer norm.

    Args:
        embedding_dim: Size of the flattened input features.
        proj_dim: Size of the projected embedding.
        drop_proj: Dropout probability inside the residual block.
    """

    def __init__(self, embedding_dim=1480, proj_dim=1024, drop_proj=0.5):
        # Layers are created in the same order in which they are registered.
        input_projection = nn.Linear(embedding_dim, proj_dim)
        residual_block = ResidualAdd(nn.Sequential(
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim),
            nn.Dropout(drop_proj),
        ))
        output_norm = nn.LayerNorm(proj_dim)
        super(Proj_eeg, self).__init__(input_projection, residual_block, output_norm)


class ATMS(nn.Module):
    """EEG-to-embedding model: ``iTransformer`` -> ``Enc_eeg`` -> ``Proj_eeg``.

    Only ``sequence_length`` and ``num_subjects`` are read by the constructor
    (to size ``subject_wise_linear``); ``num_channels``, ``num_features``,
    ``num_latents`` and ``num_blocks`` are accepted but unused.
    """
    def __init__(self, num_channels=63, sequence_length=250, num_subjects=1, num_features=64, num_latents=2048,
                 num_blocks=1):
        super(ATMS, self).__init__()
        default_config = Config()
        self.encoder = iTransformer(default_config)
        # Registered but never called in forward() (its use is commented out there).
        # Its parameters are still part of the state dict, so removing it would
        # change the keys expected when loading a checkpoint.
        self.subject_wise_linear = nn.ModuleList(
            [nn.Linear(default_config.d_model, sequence_length) for _ in range(num_subjects)])
        self.enc_eeg = Enc_eeg()
        self.proj_eeg = Proj_eeg()
        # Learnable scalar initialised to log(1 / 0.07); passed to loss_func by the evaluation code.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.loss_func = ClipLoss()
        self.loss2 = FeatureContrastLoss()

    def forward(self, x, subject_ids):
        # No time-mark input is supplied to the encoder (second argument is None).
        x = self.encoder(x, None, subject_ids)
        # print(f'After attention shape: {x.shape}')
        # print("x", x.shape)
        # x = self.subject_wise_linear[0](x)
        # print(f'After subject-specific linear transformation shape: {x.shape}')
        eeg_embedding = self.enc_eeg(x)

        out = self.proj_eeg(eeg_embedding)
        return out

def extract_id_from_string(s):
    """Return the integer formed by the digits at the end of ``s``.

    Args:
        s: String such as ``'sub-08'``.

    Returns:
        The trailing digits as an ``int`` (``'sub-08'`` -> ``8``), or ``None``
        when ``s`` does not end with a digit.
    """
    trailing_digits = re.search(r'\d+$', s)
    return int(trailing_digits.group()) if trailing_digits else None


def get_eegfeatures_test(sub, eegmodel, dataloader, device, text_features_all, img_features_all, k):
    """Evaluate ``eegmodel`` on ``dataloader`` and save the EEG embeddings it produces.

    The loss and the k-way retrieval accuracy are computed against the image
    features only. The embeddings of all batches are concatenated and written
    to ``20250103_ATM_S_eeg_features_{sub}-test.pt`` in the working directory.

    Args:
        sub: Subject name ending in digits (e.g. ``'sub-08'``); the trailing
            number is passed to the model as the subject id.
        eegmodel: Model called as ``eegmodel(eeg_data, subject_ids)`` that also
            exposes ``logit_scale`` and ``loss_func``.
        dataloader: Iterable yielding ``(eeg_data, labels, text,
            text_features, img, img_features)`` batches.
        device: Device the batches and the reference features are moved to.
        text_features_all: Tensor whose first dimension gives the number of
            candidate classes.
        img_features_all: Image features indexed by class label.
        k: Candidates per retrieval trial (the true class plus ``k - 1``
            randomly drawn other classes).

    Returns:
        tuple: ``(average_loss, accuracy, labels, features)``. ``labels`` are
        the labels of the LAST batch only (unchanged from the original
        behaviour); ``features`` is the CPU tensor of all EEG embeddings.

    Raises:
        ValueError: If ``sub`` has no trailing digits or ``dataloader`` yields
            no batches.
    """
    eegmodel.eval()
    text_features_all = text_features_all.to(device).float()
    img_features_all = img_features_all.to(device).float()

    alpha = 0.9  # weight of the regression term against the contrastive term
    beta = 0.9   # weight of the combined term against the feature-contrast term
    save_features = True

    all_labels = set(range(text_features_all.size(0)))
    mse_loss_fn = nn.MSELoss()
    loss_CF = FeatureContrastLoss()

    subject_id = extract_id_from_string(sub)
    if subject_id is None:
        raise ValueError(f"cannot extract a subject id from {sub!r}")

    total_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    labels = None
    features_list = []  # per-batch EEG embeddings, kept on the CPU
    with torch.no_grad():
        for eeg_data, labels, text, text_features, img, img_features in dataloader:
            num_batches += 1
            print(eeg_data.shape)
            eeg_data = eeg_data.to(device)
            labels = labels.to(device)
            img_features = img_features.to(device).float()

            batch_size = eeg_data.size(0)
            subject_ids = torch.full((batch_size,), subject_id, dtype=torch.long).to(device)

            # Use the model that was passed in; the original silently relied on
            # the notebook-global ``eeg_model`` here.
            eeg_features = eegmodel(eeg_data, subject_ids)
            features_list.append(eeg_features.cpu())

            logit_scale = eegmodel.logit_scale

            regress_loss = mse_loss_fn(eeg_features, img_features)
            contrastive_loss = eegmodel.loss_func(eeg_features, img_features, logit_scale)
            loss = alpha * regress_loss *10 + (1 - alpha) * contrastive_loss*10
            loss = beta * loss + (1 - beta) * loss_CF(eeg_features, labels)
            total_loss += loss.item()

            # k-way retrieval: the true class competes with k-1 random other classes.
            for idx, target in enumerate(labels.tolist()):
                possible_classes = list(all_labels - {target})
                selected_classes = random.sample(possible_classes, k-1) + [target]
                selected_img_features = img_features_all[selected_classes]

                logits_img = logit_scale * eeg_features[idx] @ selected_img_features.T
                predicted_label = selected_classes[torch.argmax(logits_img).item()]
                if predicted_label == target:
                    correct += 1
                total += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")

    features_tensor = torch.cat(features_list, dim=0)
    if save_features:
        print("features_tensor", features_tensor.shape)
        torch.save(features_tensor.cpu(), f"20250103_ATM_S_eeg_features_{sub}-test.pt")  # Save features as .pt file

    average_loss = total_loss / num_batches
    accuracy = correct / total
    return average_loss, accuracy, labels, features_tensor.cpu()

from IPython.display import Image, display
# Run configuration. Only "data_path" and "batch_size" are read in the visible cells.
config = {
"data_path": "/media/hanakawalab/Transcend/THINGS-Data/EEG/Preprocessed_data_250Hz/",
"project": "train_pos_img_text_rep",
"entity": "sustech_rethinkingbci",
"name": "lr=3e-4_img_pos_pro_eeg",
"lr": 3e-4,
"epochs": 50,
"batch_size": 1024,
"logger": True,
"encoder_type":'ATMS',
}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data_path = config['data_path']
# Pre-computed image embeddings (relative to the working directory).
emb_img_test = torch.load('variables/ViT-H-14_features_test.pt')
emb_img_train = torch.load('variables/ViT-H-14_features_train.pt')

eeg_model = ATMS(63, 250)
print('number of parameters:', sum([p.numel() for p in eeg_model.parameters()]))

#####################################################################################
# Subject-specific checkpoint. map_location lets a checkpoint that was saved
# from a GPU load on whatever `device` resolved to above (including CPU).
eeg_model.load_state_dict(
    torch.load("models/contrast/ATMS/sub-08/01-03_10-18/15.pth", map_location=device)
)
eeg_model = eeg_model.to(device)
sub = 'sub-08'
#####################################################################################

test_dataset = EEGDataset(data_path, subjects= [sub], train=False)
test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=0)
text_features_test_all = test_dataset.text_features
img_features_test_all = test_dataset.img_features
test_loss, test_accuracy,labels, eeg_features_test = get_eegfeatures_test(sub, eeg_model, test_loader, device, text_features_test_all, img_features_test_all,k=200)
print(f" - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

number of parameters: 3214978
self.subjects ['sub-08']
exclude_subject None
Data tensor shape: torch.Size([200, 63, 250]), label tensor shape: torch.Size([200]), text length: 200, image length: 200
torch.Size([200, 63, 250])
features_tensor torch.Size([200, 1024])
 - Test Loss: 1.8868, Test Accuracy: 0.4300


In [4]:

def get_eegfeatures_train(sub, eegmodel, dataloader, device, text_features_all, img_features_all, k):
    """Run ``eegmodel`` over the training ``dataloader`` and save its EEG embeddings.

    The loss and the k-way retrieval accuracy are computed against the image
    features only. The embeddings of all batches are concatenated and written
    to ``20250103_ATM_S_eeg_features_{sub}.pt`` in the working directory.

    Args:
        sub: Subject name ending in digits (e.g. ``'sub-08'``); the trailing
            number is passed to the model as the subject id.
        eegmodel: Model called as ``eegmodel(eeg_data, subject_ids)`` that also
            exposes ``logit_scale`` and ``loss_func``.
        dataloader: Iterable yielding ``(eeg_data, labels, text,
            text_features, img, img_features)`` batches.
        device: Device the batches and the reference features are moved to.
        text_features_all: Tensor whose first dimension gives the number of
            candidate classes.
        img_features_all: Image features indexed by class label.
        k: Candidates per retrieval trial (the true class plus ``k - 1``
            randomly drawn other classes).

    Returns:
        tuple: ``(average_loss, accuracy, labels, features)``. ``labels`` are
        the labels of the LAST batch only (unchanged from the original
        behaviour); ``features`` is the CPU tensor of all EEG embeddings.

    Raises:
        ValueError: If ``sub`` has no trailing digits or ``dataloader`` yields
            no batches.
    """
    eegmodel.eval()
    text_features_all = text_features_all.to(device).float()
    img_features_all = img_features_all.to(device).float()

    alpha = 0.9  # weight of the regression term against the contrastive term
    beta = 0.9   # weight of the combined term against the feature-contrast term
    save_features = True

    all_labels = set(range(text_features_all.size(0)))
    mse_loss_fn = nn.MSELoss()
    loss_CF = FeatureContrastLoss()

    subject_id = extract_id_from_string(sub)
    if subject_id is None:
        raise ValueError(f"cannot extract a subject id from {sub!r}")

    total_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    labels = None
    features_list = []  # per-batch EEG embeddings, kept on the CPU
    with torch.no_grad():
        for eeg_data, labels, text, text_features, img, img_features in dataloader:
            num_batches += 1
            print(eeg_data.shape)
            eeg_data = eeg_data.to(device)
            labels = labels.to(device)
            img_features = img_features.to(device).float()

            batch_size = eeg_data.size(0)
            subject_ids = torch.full((batch_size,), subject_id, dtype=torch.long).to(device)

            # Use the model that was passed in; the original silently relied on
            # the notebook-global ``eeg_model`` here.
            eeg_features = eegmodel(eeg_data, subject_ids)
            features_list.append(eeg_features.cpu())

            logit_scale = eegmodel.logit_scale

            regress_loss = mse_loss_fn(eeg_features, img_features)
            contrastive_loss = eegmodel.loss_func(eeg_features, img_features, logit_scale)
            loss = alpha * regress_loss *10 + (1 - alpha) * contrastive_loss*10
            loss = beta * loss + (1 - beta) * loss_CF(eeg_features, labels)
            total_loss += loss.item()

            # k-way retrieval: the true class competes with k-1 random other classes.
            for idx, target in enumerate(labels.tolist()):
                possible_classes = list(all_labels - {target})
                selected_classes = random.sample(possible_classes, k-1) + [target]
                selected_img_features = img_features_all[selected_classes]

                logits_img = logit_scale * eeg_features[idx] @ selected_img_features.T
                predicted_label = selected_classes[torch.argmax(logits_img).item()]
                if predicted_label == target:
                    correct += 1
                total += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")

    features_tensor = torch.cat(features_list, dim=0)
    if save_features:
        print("features_tensor", features_tensor.shape)
        torch.save(features_tensor.cpu(), f"20250103_ATM_S_eeg_features_{sub}.pt")  # Save features as .pt file

    average_loss = total_loss / num_batches
    accuracy = correct / total
    return average_loss, accuracy, labels, features_tensor.cpu()

#####################################################################################
# Extract EEG embeddings for the TRAINING split of the same subject.
train_dataset = EEGDataset(data_path, subjects= [sub], train=True)
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=0)
# Kept under their own names so the test-split feature tensors are not overwritten.
text_features_train_all = train_dataset.text_features
img_features_train_all = train_dataset.img_features

train_loss, train_accuracy, labels, eeg_features_train = get_eegfeatures_train(sub, eeg_model, train_loader, device, text_features_train_all, img_features_train_all,k=200)
# These are training-split metrics; the message previously labelled them "Test".
print(f" - Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
#####################################################################################

self.subjects ['sub-08']
exclude_subject None
data_tensor torch.Size([66160, 63, 250])
Data tensor shape: torch.Size([66160, 63, 250]), label tensor shape: torch.Size([66160]), text length: 1654, image length: 16540
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])
torch.Size([1024, 63, 250])


In [6]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import open_clip
from matplotlib.font_manager import FontProperties

import sys
from diffusion_prior import *
# os.environ["CUDA_VISIBLE_DEVICES"] = "5"
# Re-selects the compute device with the same rule as the evaluation cell: first GPU if available, else CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [7]:
# Repeat every training image embedding 4 times in a row:
# (1654, 10, 1, 1024) -> (1654, 10, 4, 1024) -> (66160, 1024).
# 66160 equals the number of training EEG trials printed above; presumably
# 4 EEG repetitions per image — TODO confirm the ordering matches the dataset.
emb_img_train_4 = emb_img_train.view(1654,10,1,1024).repeat(1,1,4,1).view(-1,1024)
# EEG embeddings saved earlier, reloaded from absolute paths on the external drive.
# NOTE(review): `emb_eeg` is not used by the visible cells below (the dataset is
# built from the in-memory `eeg_features_train`); `emb_eeg_test` drives generation.
emb_eeg = torch.load('/media/hanakawalab/Transcend/EEG_Image_decode-main/Generation/20250103_ATM_S_eeg_features_sub-08.pt')
emb_eeg_test = torch.load('/media/hanakawalab/Transcend/EEG_Image_decode-main/Generation/20250103_ATM_S_eeg_features_sub-08-test.pt')

In [8]:
# Pair each EEG embedding (condition) with its image embedding (target).
dataset = EmbeddingDataset(
    c_embeddings=eeg_features_train, h_embeddings=emb_img_train_4,
    # h_embeds_uncond=h_embeds_imgnet
)
# `dl` is only consumed by the commented-out pipe.train(...) call below.
# NOTE(review): num_workers=64 starts many worker processes — confirm the host can afford it.
dl = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=64)
diffusion_prior = DiffusionPriorUNet(cond_dim=1024, dropout=0.1)
# number of trainable parameters
print(sum(p.numel() for p in diffusion_prior.parameters() if p.requires_grad))
pipe = Pipe(diffusion_prior, device=device)

# load pretrained model
model_name = 'diffusion_prior' # 'diffusion_prior_vice_pre_imagenet' or 'diffusion_prior_vice_pre'
# Training is disabled; the weights come from the subject-specific checkpoint loaded below.
#pipe.train(dl, num_epochs=150, learning_rate=1e-3) # to 0.142
string = f'/media/hanakawalab/Transcend/EEG_Image_decode/fintune_ckpts/{sub}/{model_name}.pt'
pipe.diffusion_prior.load_state_dict(torch.load(string, map_location=device))

9675648




<All keys matched successfully>

In [9]:
from custom_pipeline import *
from PIL import Image
import gc
import os

# Every test embedding needs a class description for its output folder;
# fail before the (expensive) generator is loaded if the two are misaligned.
num_test_embeddings = len(emb_eeg_test)
if len(texts) < num_test_embeddings:
    raise ValueError(
        f"{num_test_embeddings} test embeddings but only {len(texts)} class descriptions"
    )

# The recorded run of this cell ended in a CUDA out-of-memory error. Drop
# unreferenced objects and return cached blocks to the driver before loading
# the image generator.
# NOTE(review): this only reclaims unused memory; if the error persists, the
# models still resident on the GPU from earlier cells must be moved off it.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Assuming generator.generate returns a PIL Image
generator = Generator4Embeds(num_inference_steps=4, device=device)

directory = f"generated_imgs/{sub}"
os.makedirs(directory, exist_ok=True)

# Pure inference: disabling autograd avoids keeping intermediate activations.
with torch.no_grad():
    for k in range(num_test_embeddings):
        eeg_embeds = emb_eeg_test[k:k+1]
        # Sample an image embedding from the diffusion prior, conditioned on the EEG embedding.
        h = pipe.generate(c_embeds=eeg_embeds, num_inference_steps=50, guidance_scale=5.0)
        # Generate 10 images from the same sampled embedding.
        for j in range(10):
            image = generator.generate(h.to(dtype=torch.float16))
            # Construct the save path for each image
            path = f'{directory}/{texts[k]}/{j}.png'
            # Ensure the directory exists
            os.makedirs(os.path.dirname(path), exist_ok=True)
            # Save the PIL Image
            image.save(path)
            print(f'Image saved to {path}')



Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

50it [00:00, 637.13it/s]


  0%|          | 0/4 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 15.69 GiB total capacity; 14.52 GiB already allocated; 25.50 MiB free; 14.97 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF