In [28]:
from dataset import identity, Dataset
from scipy.interpolate import interp1d
from transformers import AutoProcessor
import os
from PIL import Image

class EEGDataset(Dataset):
    
    # Constructor
    def __init__(self, eeg_signals_path, image_transform=identity, subject = 0):
        # Load EEG signals
        loaded = torch.load(eeg_signals_path)
        if subject!=0:
            self.data = [loaded['dataset'][i] for i in range(len(loaded['dataset']) ) if loaded['dataset'][i]['subject']==subject]
        else:
            self.data = loaded['dataset']        
        self.labels = loaded["labels"]
        self.images = loaded["images"]
        self.imagenet = 'datasets/imageNet_images'
        self.image_transform = image_transform
        self.num_voxels = 440
        self.data_len = 512
        # Compute size
        self.size = len(self.data)
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")

    # Get size
    def __len__(self):
        return self.size

    # Get item
    def __getitem__(self, i):
        # Process EEG
        # print(self.data[i])
        eeg = self.data[i]["eeg"].float().t()

        eeg = eeg[20:460,:]
        ##### 2023 2 13 add preprocess and transpose
        eeg = np.array(eeg.transpose(0,1))
        x = np.linspace(0, 1, eeg.shape[-1])
        x2 = np.linspace(0, 1, self.data_len)
        f = interp1d(x, eeg)
        eeg = f(x2)
        eeg = torch.from_numpy(eeg).float()
        ##### 2023 2 13 add preprocess
        label = torch.tensor(self.data[i]["label"]).long()

        # Get label
        image_name = self.images[self.data[i]["image"]]
        image_path = os.path.join(self.imagenet, image_name.split('_')[0], image_name+'.JPEG')
        # print(image_path)
        image_raw = Image.open(image_path).convert('RGB') 
        
        image = np.array(image_raw) / 255.0
        image_raw = self.processor(images=image_raw, return_tensors="pt")
        image_raw['pixel_values'] = image_raw['pixel_values'].squeeze(0)


        return {'eeg': eeg, 'label': label, 'image': self.image_transform(image), 'image_raw': image_raw}

In [29]:
from dataset import EEGDataset, create_EEG_dataset
from config import Config_MBM_EEG
eeg_signals_path = "/Data/summer24/DreamDiffusion/datasets/eeg_5_95_std.pth"
splits_path = "/Data/summer24/DreamDiffusion/datasets/block_splits_by_image_all.pth"
dataset_train, dataset_val = create_EEG_dataset(eeg_signals_path, splits_path, subject = 0)

In [58]:
split_file_all

{'splits': [{'train': [0,
    3,
    4,
    5,
    6,
    8,
    9,
    10,
    11,
    12,
    13,
    15,
    16,
    17,
    18,
    19,
    21,
    22,
    23,
    24,
    25,
    27,
    29,
    33,
    35,
    36,
    37,
    38,
    39,
    42,
    43,
    44,
    45,
    46,
    47,
    48,
    51,
    52,
    54,
    55,
    57,
    58,
    60,
    61,
    62,
    63,
    66,
    68,
    69,
    70,
    71,
    72,
    73,
    74,
    77,
    78,
    79,
    83,
    84,
    87,
    88,
    89,
    90,
    91,
    93,
    94,
    95,
    96,
    97,
    98,
    99,
    100,
    101,
    103,
    104,
    105,
    106,
    107,
    108,
    110,
    112,
    113,
    115,
    116,
    118,
    119,
    120,
    121,
    122,
    123,
    124,
    125,
    127,
    128,
    129,
    130,
    131,
    135,
    137,
    138,
    140,
    141,
    142,
    143,
    145,
    146,
    147,
    149,
    150,
    151,
    152,
    153,
    156,
    157,
    159,
    161,
    163,
    16

In [2]:
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import BatchSampler
import torchvision.transforms as transforms
import torch
import numpy as np
from PIL import Image
import os

def load_data(eeg_path, img_path, splits_path, batch_size=16):
    """
    Args:
        is_inception: True | False
    """
    loaded_eeg = torch.load(eeg_path)
    loaded_splits = torch.load(splits_path)['splits']
    
    train_dataset = EEGDataset(img_path, loaded_eeg, loaded_splits, mode="train")
    val_dataset = EEGDataset(img_path, loaded_eeg, loaded_splits, mode="val")
    test_dataset = EEGDataset(img_path, loaded_eeg, loaded_splits, mode="test")


    # kwargs = {'num_workers': 4, 'pin_memory': True} if device else {}
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, )
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True )
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    return train_dataloader, val_dataloader, test_dataloader

class EEGDataset(Dataset):
    """
    Train: For each sample (anchor) randomly chooses a positive and negative samples
    Test: Creates fixed triplets for testing
    """
    def __init__(self, img_dir_path, loaded_eeg, loaded_splits, mode="train", transform=None):
        """
        Args:
            img_dir_path: directory path of imagenet images,
            loaded_eeg: eeg dataset loaded from torch.load(),
            loaded_splits: cross-validation splits loaded from torch.load(),
        All arrays and data are returned as torch Tensors
        """
        self.mode = mode
        self.transform = transform
        self.img_dir_path = img_dir_path
        self.splits = loaded_splits
        dataset, classes, img_filenames = [loaded_eeg[k] for k in ['dataset', 'labels', 'images']]
        self.classes = classes
        self.img_filenames = img_filenames

        self.eeg_dataset = dataset
        """We use only split 0, no cross-validation"""
        self.split_chosen = loaded_splits[0]
        self.split_train = self.split_chosen['train']
        self.split_val = self.split_chosen['val']
        self.split_test = self.split_chosen['test']

        if self.mode == "train":
            self.labels = [self.eeg_dataset[sample_idx]['label'] for sample_idx in self.split_train]
        elif self.mode == "val":
            self.labels = [self.eeg_dataset[sample_idx]['label'] for sample_idx in self.split_val]
        elif self.mode == "test":
            self.labels = [self.eeg_dataset[sample_idx]['label'] for sample_idx in self.split_test]
        else:
            raise ValueError()

        self.labels = torch.tensor(self.labels)
    def __getitem__(self, index):
        """
        Return: (eeg, img), []
            - eeg: Tensor()
            - image: Tensor()
        """
        if self.mode == "train":
            dataset_idx = self.split_train[index]
        elif self.mode == "val":
            dataset_idx = self.split_val[index]
        elif self.mode == "test":
            dataset_idx = self.split_test[index]
        else:
            raise ValueError()
        eeg,_, label = [self.eeg_dataset[dataset_idx][key] for key in ['eeg', 'image', 'label']]

        return eeg, label

    def __len__(self):
        if self.mode == "train":
            return len(self.split_train)
        elif self.mode == "val":
            return len(self.split_val)
        elif self.mode == "test":
            return len(self.split_test)
        else:
            raise ValueError()
        


In [3]:
eeg_path = "/Data/summer24/DreamDiffusion/datasets/eeg_5_95_std.pth"
img_path = "/Data/summer24/DreamDiffusion/datasets/imageNet_images"
splits_path = "/Data/summer24/DreamDiffusion/datasets/block_splits_by_image_all.pth"
train_dataloader, val_dataloader, test_dataloader = load_data(eeg_path,img_path,splits_path)

In [107]:
loaded_eeg = torch.load(eeg_path)
loaded_splits = torch.load(splits_path)['splits']

train_dataset = EEGDataset(img_path, loaded_eeg, loaded_splits, mode="train")

In [3]:
len(train_dataset)

NameError: name 'train_dataset' is not defined

In [104]:
lst = []
for idx in range (len(train_dataset)):
    print(train_dataset[idx][1])

10
10
30
25
18
3
8
11
18
28
38
20
3
28
23
0
34
20
23
39
0
34
21
39
6
26
20
1
27
37
19
9
12
18
25
27
34
35
8
29
12
17
37
27
25
32
35
5
8
20
31
25
31
37
6
37
32
20
30
23
18
2
3
19
26
20
27
27
18
39
33
36
21
5
8
39
24
34
13
36
8
39
30
1
19
37
11
24
36
12
7
29
37
6
30
33
5
27
2
4
30
7
8
12
33
24
30
32
5
17
23
33
28
34
7
14
18
0
24
39
23
27
28
30
26
17
15
30
38
22
23
11
28
12
20
34
26
4
36
0
5
34
25
12
9
23
10
35
5
0
18
29
34
29
28
17
8
17
22
15
21
30
16
23
3
13
22
1
15
33
30
7
39
8
7
30
38
24
3
6
26
1
28
36
21
32
7
9
30
36
35
8
28
38
36
4
32
39
30
4
26
3
9
39
30
29
22
22
22
35
38
3
22
8
6
30
13
1
1
11
13
11
27
8
1
14
8
24
23
8
17
34
19
6
33
6
14
9
27
34
1
34
3
3
1
16
3
1
12
27
36
33
1
38
33
0
9
3
3
1
20
39
11
24
20
14
7
32
14
35
33
9
17
8
39
29
7
39
17
35
28
34
8
18
2
16
20
31
39
8
36
5
9
29
28
25
13
20
2
16
21
27
34
5
35
0
4
22
23
15
23
14
35
31
8
6
15
12
2
4
2
28
16
20
0
35
13
10
19
8
18
8
15
27
4
14
34
23
26
29
22
18
37
20
33
36
12
31
36
32
35
20
15
37
19
11
10
16
16
1
18
8
20
18
19
23


In [94]:
train_dataset[1]

(tensor([[ 0.0076,  0.0605,  0.1214,  ...,  0.1027,  0.0541,  0.0113],
         [ 0.0049,  0.0806,  0.1632,  ...,  0.0397,  0.0348,  0.0237],
         [-0.0148,  0.0531,  0.1162,  ...,  0.1000,  0.0386, -0.0223],
         ...,
         [-0.0229,  0.0678,  0.1807,  ..., -0.5030, -0.2601,  0.0196],
         [-0.0062,  0.0032,  0.0129,  ..., -0.0317, -0.0186, -0.0039],
         [-0.0179,  0.1185,  0.2597,  ..., -0.4428, -0.2308,  0.0150]]),
 10)

In [74]:
train_dataset

NameError: name 'train_dataset' is not defined

In [70]:
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x7fa9b153c670>

In [5]:
from dataset import  create_EEG_dataset
from config import Config_Generative_Model
from einops import rearrange
import torchvision.transforms as transforms


class random_crop:
    def __init__(self, size, p):
        self.size = size
        self.p = p
    def __call__(self, img):
        if torch.rand(1) < self.p:
            return transforms.RandomCrop(size=(self.size, self.size))(img)
        return img

def channel_last(img):
        if img.shape[-1] == 3:
            return img
        return rearrange(img, 'c h w -> h w c')

def normalize(img):
    if img.shape[-1] == 3:
        img = rearrange(img, 'h w c -> c h w')
    img = torch.tensor(img)
    img = img * 2.0 - 1.0 # to -1 ~ 1
    return img


config = Config_Generative_Model()

crop_pix = int(config.crop_ratio*config.img_size)

img_transform_train = transforms.Compose([
    normalize,

    transforms.Resize((512, 512)),
    random_crop(config.img_size-crop_pix, p=0.5),

    transforms.Resize((512, 512)),
    channel_last
])
img_transform_test = transforms.Compose([
    normalize, 

    transforms.Resize((512, 512)),
    channel_last
    ])


In [27]:
eeg_signals_path = "/Data/summer24/DreamDiffusion/datasets/eeg_5_95_std.pth"
splits_path = "/Data/summer24/DreamDiffusion/datasets/block_splits_by_image_single.pth"


eeg_latents_dataset_train, eeg_latents_dataset_test = create_EEG_dataset(eeg_signals_path, splits_path, 
                image_transform=[img_transform_train, img_transform_test], subject = config.subject)

In [26]:
dataset

(<dataset.Splitter at 0x7fc1c8575880>, <dataset.Splitter at 0x7fbff456b5e0>)

In [30]:
dir(eeg_latents_dataset_train)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'data_len',
 'dataset',
 'num_voxels',
 'size',
 'split_idx']

In [37]:
dir(eeg_latents_dataset_train.dataset)

['__add__',
 '__class__',
 '__class_getitem__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__orig_bases__',
 '__parameters__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_is_protocol',
 'data',
 'data_len',
 'image_transform',
 'imagenet',
 'images',
 'labels',
 'num_voxels',
 'processor',
 'size']

In [41]:
len(eeg_latents_dataset_train.dataset.labels)

40

In [13]:
dir(eeg_latents_dataset_train.dataset)

['__add__',
 '__class__',
 '__class_getitem__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__orig_bases__',
 '__parameters__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_is_protocol',
 'data',
 'data_len',
 'image_transform',
 'imagenet',
 'images',
 'labels',
 'num_voxels',
 'processor',
 'size']

In [17]:
eeg_latents_dataset_train.dataset.data

[{'eeg': tensor([[-8.6142e-02, -2.9463e-01, -6.2266e-01,  ..., -2.0556e-01,
           -2.0201e-01, -2.2041e-01],
          [-4.6275e-02, -1.5711e-01, -3.3941e-01,  ..., -4.8833e-02,
           -7.8179e-02, -1.2174e-01],
          [ 2.5091e-02,  1.3313e-01,  2.6716e-01,  ...,  8.9221e-02,
            6.9278e-02,  5.2495e-02],
          ...,
          [ 8.0494e-02,  3.2728e-01,  4.4147e-01,  ...,  1.4659e+00,
            7.8212e-01,  2.5284e-03],
          [-1.2392e-03,  1.0254e-02,  1.4863e-02,  ...,  7.1831e-02,
            3.5992e-02, -5.0048e-03],
          [ 5.3350e-02,  2.9741e-01,  4.4446e-01,  ...,  8.7492e-01,
            4.6427e-01, -4.8223e-03]]),
  'image': 0,
  'label': 10,
  'subject': 6},
 {'eeg': tensor([[-0.0595, -0.2462, -0.5524,  ...,  2.3339,  1.2344,  0.1786],
          [-0.0297, -0.1574, -0.3423,  ...,  1.4544,  0.7648,  0.0985],
          [ 0.0175,  0.0805,  0.1608,  ..., -0.4176, -0.2216, -0.0385],
          ...,
          [-0.0880,  0.4591,  0.6931,  ...,  0.848

In [20]:
eeg_latents_dataset_test.dataset.data

[{'eeg': tensor([[-8.6142e-02, -2.9463e-01, -6.2266e-01,  ..., -2.0556e-01,
           -2.0201e-01, -2.2041e-01],
          [-4.6275e-02, -1.5711e-01, -3.3941e-01,  ..., -4.8833e-02,
           -7.8179e-02, -1.2174e-01],
          [ 2.5091e-02,  1.3313e-01,  2.6716e-01,  ...,  8.9221e-02,
            6.9278e-02,  5.2495e-02],
          ...,
          [ 8.0494e-02,  3.2728e-01,  4.4147e-01,  ...,  1.4659e+00,
            7.8212e-01,  2.5284e-03],
          [-1.2392e-03,  1.0254e-02,  1.4863e-02,  ...,  7.1831e-02,
            3.5992e-02, -5.0048e-03],
          [ 5.3350e-02,  2.9741e-01,  4.4446e-01,  ...,  8.7492e-01,
            4.6427e-01, -4.8223e-03]]),
  'image': 0,
  'label': 10,
  'subject': 6},
 {'eeg': tensor([[-0.0595, -0.2462, -0.5524,  ...,  2.3339,  1.2344,  0.1786],
          [-0.0297, -0.1574, -0.3423,  ...,  1.4544,  0.7648,  0.0985],
          [ 0.0175,  0.0805,  0.1608,  ..., -0.4176, -0.2216, -0.0385],
          ...,
          [-0.0880,  0.4591,  0.6931,  ...,  0.848

In [42]:
import torch
eeg_signals_path = "/Data/summer24/DreamDiffusion/datasets/eeg_5_95_std.pth"
splits_path = "/Data/summer24/DreamDiffusion/datasets/block_splits_by_image_all.pth"

eeg = torch.load(eeg_signals_path)
splits = torch.load(splits_path)

In [78]:
eeg['dataset']

[{'eeg': tensor([[-0.0098,  0.0195,  0.0620,  ...,  0.0638,  0.0120, -0.0118],
          [-0.0045,  0.1303,  0.2673,  ...,  0.0894,  0.0342, -0.0082],
          [ 0.0215, -0.2017, -0.4305,  ..., -0.2022, -0.0940,  0.0188],
          ...,
          [ 0.0160,  0.0707,  0.1005,  ...,  0.2066,  0.1156,  0.0036],
          [-0.0046, -0.0084, -0.0119,  ...,  0.0007, -0.0026, -0.0053],
          [ 0.0040,  0.0419,  0.0665,  ...,  0.0765,  0.0309, -0.0063]]),
  'image': 0,
  'label': 10,
  'subject': 4},
 {'eeg': tensor([[-0.0120,  0.0473,  0.1264,  ...,  0.0109,  0.0188,  0.0211],
          [-0.0061,  0.0061,  0.0379,  ...,  0.0466,  0.0355,  0.0135],
          [ 0.0016,  0.0690,  0.1212,  ...,  0.0077, -0.0025,  0.0047],
          ...,
          [ 0.0189,  0.0461,  0.0376,  ..., -0.0657, -0.0639, -0.0245],
          [-0.0043,  0.0026,  0.0072,  ..., -0.0171, -0.0132, -0.0064],
          [ 0.0073,  0.1099,  0.1803,  ..., -0.1500, -0.1019, -0.0184]]),
  'image': 1,
  'label': 30,
  'subject': 

In [51]:
from collections import Counter
subject_counts = {}

# 100까지 반복하면서 subject 값들을 수집
for i in range(len(eeg['dataset'])):
    # 데이터에서 'subject' 필드를 가져옵니다.
    subjects = eeg['dataset'][i]['subject']
    subjects = str(subjects)

    # subjects의 빈도를 계산합니다.
    counts = Counter(subjects)

    # subject_counts 딕셔너리에 빈도를 추가합니다.
    for subject, count in counts.items():
        if subject not in subject_counts:
            subject_counts[subject] = 0
        subject_counts[subject] += count

In [52]:
subject_counts

{'4': 1996, '1': 1985, '6': 1996, '3': 1996, '2': 1996, '5': 1996}

In [75]:
splits['splits'][0]['train']

[0,
 3,
 4,
 5,
 6,
 8,
 9,
 10,
 11,
 12,
 13,
 15,
 16,
 17,
 18,
 19,
 21,
 22,
 23,
 24,
 25,
 27,
 29,
 33,
 35,
 36,
 37,
 38,
 39,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 51,
 52,
 54,
 55,
 57,
 58,
 60,
 61,
 62,
 63,
 66,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 77,
 78,
 79,
 83,
 84,
 87,
 88,
 89,
 90,
 91,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 103,
 104,
 105,
 106,
 107,
 108,
 110,
 112,
 113,
 115,
 116,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 127,
 128,
 129,
 130,
 131,
 135,
 137,
 138,
 140,
 141,
 142,
 143,
 145,
 146,
 147,
 149,
 150,
 151,
 152,
 153,
 156,
 157,
 159,
 161,
 163,
 164,
 166,
 168,
 169,
 173,
 174,
 175,
 176,
 177,
 179,
 180,
 182,
 185,
 186,
 187,
 188,
 189,
 190,
 191,
 194,
 195,
 199,
 200,
 202,
 208,
 209,
 211,
 212,
 213,
 215,
 217,
 220,
 222,
 223,
 224,
 226,
 227,
 228,
 230,
 231,
 232,
 233,
 234,
 235,
 236,
 237,
 238,
 239,
 241,
 242,
 243,
 244,
 247,
 248,
 250,
 252,
 253,
 255,
 256,
 257,
 258,
 259,
 260

In [105]:
from collections import Counter
subject_counts = {}

# 100까지 반복하면서 subject 값들을 수집
for i in range(len(val_dataset['val_data'])):
    # 데이터에서 'subject' 필드를 가져옵니다.
    subjects = val_dataset['val_data'][i]['subject']
    subjects = str(subjects)

    # subjects의 빈도를 계산합니다.
    counts = Counter(subjects)

    # subject_counts 딕셔너리에 빈도를 추가합니다.
    for subject, count in counts.items():
        if subject not in subject_counts:
            subject_counts[subject] = 0
        subject_counts[subject] += count

subject_counts

{'4': 333, '1': 333, '6': 333, '3': 333, '2': 333, '5': 333}

In [1]:
import torch
eeg_signals_path = "/Data/summer24/DreamDiffusion/datasets/eeg_5_95_std.pth"
splits_path = "/Data/summer24/DreamDiffusion/datasets/block_splits_by_image_all.pth"

eeg = torch.load(eeg_signals_path)
splits = torch.load(splits_path)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train_indices = splits['splits'][0]['train']
train_data = [eeg['dataset'][idx] for idx in train_indices]
train_dataset = {'train_data': train_data}

In [3]:
val_dataloader_indices = splits['splits'][0]['val']
val_data = [eeg['dataset'][idx] for idx in val_dataloader_indices]
val_dataset = {'val_data': val_data}

In [7]:
print(len(train_dataset['train_data']))
print(len(val_dataset['val_data']))

7970
1998


In [9]:
train_dataset['train_data']

[{'eeg': tensor([[-0.0098,  0.0195,  0.0620,  ...,  0.0638,  0.0120, -0.0118],
          [-0.0045,  0.1303,  0.2673,  ...,  0.0894,  0.0342, -0.0082],
          [ 0.0215, -0.2017, -0.4305,  ..., -0.2022, -0.0940,  0.0188],
          ...,
          [ 0.0160,  0.0707,  0.1005,  ...,  0.2066,  0.1156,  0.0036],
          [-0.0046, -0.0084, -0.0119,  ...,  0.0007, -0.0026, -0.0053],
          [ 0.0040,  0.0419,  0.0665,  ...,  0.0765,  0.0309, -0.0063]]),
  'image': 0,
  'label': 10,
  'subject': 4},
 {'eeg': tensor([[ 0.0076,  0.0605,  0.1214,  ...,  0.1027,  0.0541,  0.0113],
          [ 0.0049,  0.0806,  0.1632,  ...,  0.0397,  0.0348,  0.0237],
          [-0.0148,  0.0531,  0.1162,  ...,  0.1000,  0.0386, -0.0223],
          ...,
          [-0.0229,  0.0678,  0.1807,  ..., -0.5030, -0.2601,  0.0196],
          [-0.0062,  0.0032,  0.0129,  ..., -0.0317, -0.0186, -0.0039],
          [-0.0179,  0.1185,  0.2597,  ..., -0.4428, -0.2308,  0.0150]]),
  'image': 3,
  'label': 10,
  'subject': 

In [14]:
import torch
import os
import numpy as np

dataset = val_dataset['val_data']
base_save_path = "/Data/summer24/DreamDiffusion/datasets/eegdata/val/"


for idx, tensor_item in enumerate(dataset):
    for key, value in tensor_item.items():
        subfolder_path = os.path.join(base_save_path, key)
        
        # 폴더 생성
        if not os.path.exists(subfolder_path):
            os.makedirs(subfolder_path)
        
        # 파일 저장
        try:
            file_path = os.path.join(subfolder_path, f"{idx}.npy")
            if isinstance(value, torch.Tensor):
                ndarray = value.numpy()
            else:
                ndarray = np.array(value)
            
            np.save(file_path, ndarray)
        
        except Exception as e:
            print(f"Error saving file at index {idx}, key {key}: {e}")

In [18]:
from dataset import eeg_pretrain_dataset
from config import Config_MBM_EEG
config = Config_MBM_EEG()

# data augmentation 기법
def fmri_transform(x, sparse_rate=0.2):
    # x: 1, num_voxels
    x_aug = copy.deepcopy(x)
    idx = np.random.choice(x.shape[0], int(x.shape[0]*sparse_rate), replace=False)
    x_aug[idx] = 0
    return torch.FloatTensor(x_aug)

dataset_pretrain = eeg_pretrain_dataset(path='/Data/summer24/DreamDiffusion/data/processed/eegData_npy/', roi=config.roi, patch_size=config.patch_size,
            transform=fmri_transform, aug_times=config.aug_times, num_sub_limit=config.num_sub_limit, 
            include_kam=config.include_kam, include_hcp=config.include_hcp)

print(f'Dataset size: {len(dataset_pretrain)}\n Time len: {dataset_pretrain.data_len}')

# sampler = torch.utils.data.DistributedSampler(dataset_pretrain, rank=local_rank)

# dataloader_eeg = DataLoader(dataset_pretrain, batch_size=config.batch_size, sampler=sampler, 
#             shuffle=False, pin_memory=True)

# # create model
# config.time_len=dataset_pretrain.data_len
# model = MAEforEEG(time_len=dataset_pretrain.data_len, patch_size=config.patch_size, embed_dim=config.embed_dim,
#                 decoder_embed_dim=config.decoder_embed_dim, depth=config.depth, 
#                 num_heads=config.num_heads, decoder_num_heads=config.decoder_num_heads, mlp_ratio=config.mlp_ratio,
#                 focus_range=config.focus_range, focus_rate=config.focus_rate, 
#                 img_recon_weight=config.img_recon_weight, use_nature_img_loss=config.use_nature_img_loss)   

In [21]:
dataset_pretrain = eeg_pretrain_dataset(path='/Data/summer24/DreamDiffusion/data/processed/eegData_npy/', roi=config.roi, patch_size=config.patch_size,
            transform=fmri_transform, aug_times=config.aug_times, num_sub_limit=config.num_sub_limit, 
            include_kam=config.include_kam, include_hcp=config.include_hcp)

print(f'Dataset size: {len(dataset_pretrain)}\n Time len: {dataset_pretrain.data_len}')

# sampler = torch.utils.data.DistributedSampler(dataset_pretrain, rank=local_rank)

# dataloader_eeg = DataLoader(dataset_pretrain, batch_size=config.batch_size, sampler=sampler, 
#             shuffle=False, pin_memory=True)

# # create model
# config.time_len=dataset_pretrain.data_len
# model = MAEforEEG(time_len=dataset_pretrain.data_len, patch_size=config.patch_size, embed_dim=config.embed_dim,
#                 decoder_embed_dim=config.decoder_embed_dim, depth=config.depth, 
#                 num_heads=config.num_heads, decoder_num_heads=config.decoder_num_heads, mlp_ratio=config.mlp_ratio,
#                 focus_range=config.focus_range, focus_rate=config.focus_rate, 
#                 img_recon_weight=config.img_recon_weight, use_nature_img_loss=config.use_nature_img_loss)   

ValueError: Path must contain 'train' or 'val' to determine dataset type