In [1]:
import numpy as np


In [2]:
# Read the three arrays that make up the dataset (EEG recordings, images and
# labels) from .npy files in the current working directory, in that order.
dataset_eeg, dataset_image, dataset_label = map(
    np.load,
    ('dataset_eeg.npy', 'dataset_image.npy', 'dataset_label.npy'),
)


In [3]:
# Inspect the dimensions of the EEG array as loaded, before any channel selection.
# NOTE(review): axis meaning is presumably (samples, channels, time) — confirm.
dataset_eeg.shape


(505, 15, 400)

In [4]:
# Inspect the label array; its length has to equal the first dimension of
# `dataset_eeg` because the two are paired row-by-row when the Dataset is built.
dataset_label.shape


(505,)

In [5]:
# Keep only indices 1..4 along axis 1 of the EEG array.
# The selection is taken from the array on disk rather than from the current
# value of `dataset_eeg`, so this cell is idempotent: running it a second time
# no longer slices an already-sliced array (which would silently drop a channel).
# NOTE(review): axis 1 is presumably the EEG channel axis — confirm.
dataset_eeg = np.load('dataset_eeg.npy')[:, 1:5]


In [6]:
# Confirm the effect of the `[:, 1:5]` slice: axis 1 should now have length 4.
dataset_eeg.shape


(505, 4, 400)

In [7]:
from datasets import Dataset, DatasetDict

# Pair every EEG recording with its label in a single Hugging Face Dataset.
dataset = Dataset.from_dict({'eeg': dataset_eeg, 'label': dataset_label})

# Hold out 10% of the rows for testing. The split shuffles the rows, so a fixed
# seed is passed to make train/test membership identical on every run;
# without it the saved dataset differs each time the notebook is executed.
SPLIT_SEED = 42
dataset = dataset.train_test_split(test_size=0.1, seed=SPLIT_SEED)


  from .autonotebook import tqdm as notebook_tqdm


In [8]:
import numpy as np
from datasets import load_dataset
from transformers import CLIPTokenizerFast
from transformers.utils.logging import set_verbosity_warning

from trainable_clip_model import TrainableCLIPModel
from wavelet_spectrogram import cwt_spectrogram



# %%
# Image-like tensors built below are laid out as (B, C, H, W).
C = 4           # planes per sample: one spectrogram per selected EEG channel
H = 168         # rows of the canvas filled by each spectrogram
W = 400         # columns of the canvas filled by each spectrogram
L = 400         # not referenced in the cells shown here
S = max(H, W)   # side length of the square, zero-padded canvas


# %%
# Integer class id -> class name used as the text prompt for that class.
# The position of a name in the tuple is its class id.
# NOTE(review): names and order look like the CIFAR-10 classes — confirm.
map_digit_to_token = dict(enumerate((
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)))


# %%
# Tokenizer that belongs to the CLIP ViT-L/14 checkpoint.
tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")

# Tokenize each class name once up front so `preprocess` only does a lookup.
# Every name is encoded on its own; per the tokenizer documentation,
# `padding=True` adds no padding when a single sequence is passed, so each
# entry is just that name's own token ids.
map_digit_to_input_ids = {
    digit: tokenizer(f"{token}", padding=True)["input_ids"]
    for digit, token in map_digit_to_token.items()
}

def preprocess(inputs):
  """Turn a batch of raw rows into model-ready fields.

  Args:
    inputs: column-oriented batch (dict of lists) with an "eeg" column holding
      one array of per-channel signals per row and a "label" column holding
      integer class ids that are keys of `map_digit_to_input_ids`.

  Returns:
    The same `inputs` dict, modified in place, with these keys added:
      "labels": the values of the "label" column.
      "input_ids": token ids of each row's class name.
      "attention_mask": a list of ones per row, as long as that row's
        "input_ids".
      "pixel_values": float array of shape (B, C, S, S); the top-left
        H x W corner of each plane holds that channel's wavelet power
        spectrogram and the remainder stays zero.

  Raises:
    KeyError: if a label has no entry in `map_digit_to_input_ids`.
  """
  labels = inputs["label"]
  batch_size = len(labels)
  inputs["labels"] = labels

  # Copy each id list so rows do not alias the shared module-level lookup lists.
  input_ids = [list(map_digit_to_input_ids[label]) for label in labels]
  inputs["input_ids"] = input_ids
  # Build one independent mask per row, sized from that row's ids. A hard-coded
  # length of 3 is only right while every class name is exactly one token, and
  # `[[1] * 3] * B` would make every row the very same list object.
  inputs["attention_mask"] = [[1] * len(ids) for ids in input_ids]

  samples = np.zeros((batch_size, C, S, S))
  eeg = np.array(inputs["eeg"])
  for b in range(batch_size):
    for i in range(C):
      # NOTE(review): 200 is presumably the sampling rate in Hz and nNotes the
      # number of scales per octave — confirm against wavelet_spectrogram.
      power, *_ = cwt_spectrogram(eeg[b, i], 200, nNotes=24, detrend=True, normalize=True)
      # Assumes `power` squeezes to (H, W); any other shape makes this
      # assignment raise a broadcasting error.
      samples[b, i, :min(H, S), :min(W, S)] = power.squeeze()
  inputs["pixel_values"] = samples
  return inputs


# %%
# Smoke-test `preprocess` on the first eight training rows and show the
# labels, the token ids and the shape of the generated spectrogram tensor.
preview_rows = dataset["train"][:8]
item = preprocess(preview_rows)
(
    item["labels"],
    item["input_ids"],
    item['pixel_values'].shape,
)


([7, 4, 3, 9, 0, 5, 8, 9],
 [[49406, 4558, 49407],
  [49406, 8700, 49407],
  [49406, 2368, 49407],
  [49406, 4629, 49407],
  [49406, 16451, 49407],
  [49406, 1929, 49407],
  [49406, 1158, 49407],
  [49406, 4629, 49407]],
 (8, 4, 400, 400))

In [11]:
# %%
# Raw columns that are dropped once the model-ready fields have been derived.
remove_columns = ["eeg", "label"]


# %%
# Run `preprocess` batch-wise over both the train and the test split.
preprocessed_dataset = dataset.map(
    preprocess,
    batched=True,
    remove_columns=remove_columns,
)
print("done preprocessing!")
# %%
# Write the processed splits to the "our_data" directory using 24 worker processes.
preprocessed_dataset.save_to_disk(
    "our_data",
    num_proc=24,
)


Map: 100%|██████████| 454/454 [00:25<00:00, 17.88 examples/s]
Map: 100%|██████████| 51/51 [00:02<00:00, 17.32 examples/s]


done preprocessing!


Saving the dataset (24/24 shards): 100%|██████████| 454/454 [01:01<00:00,  7.42 examples/s]
Saving the dataset (24/24 shards): 100%|██████████| 51/51 [00:08<00:00,  6.03 examples/s]
