In [2]:
import matplotlib.pyplot as plt
import numpy as np
import requests
import seaborn as sns
import torch
from datasets import load_dataset
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
from wavelet_spectrogram import cwt_spectrogram
import itertools
import os




In [3]:
dataset = load_dataset("DavidVivancos/MindBigData2022_MNIST_IN")
list(dataset["train"][0].items())[:8]


  0%|          | 0/2 [00:00<?, ?it/s]

[('label', 6),
 ('AF3-0', 4284.102564),
 ('AF3-1', 4281.073718),
 ('AF3-2', 4296.923077),
 ('AF3-3', 4312.588141),
 ('AF3-4', 4318.141025),
 ('AF3-5', 4314.20673),
 ('AF3-6', 4296.65064)]

In [4]:
channels = sorted(set(map(lambda x: x.split("-")[0], dataset["train"][0].keys())) - {'label'})
channels


['AF3', 'AF4', 'PZ', 'T7', 'T8']

In [168]:
# B, C, H, W
C = len(channels)
H, W = 168, 256
L = 256


In [93]:
map_digit_to_token = {
    0: "zero",
    1: "one",
    2: "two",
    3: "three",
    4: "four",
    5: "five",
    6: "six",
    7: "seven",
    8: "eight",
    9: "nine",
}


In [169]:
def preprocess(inputs):
  outputs = {}
  outputs["label"] = [map_digit_to_token[label] for label in inputs["label"]]
  B = len(inputs["label"])
  samples = np.zeros((B, C, H, W))
  for b in range(B):
    for i, channel in enumerate(channels):
      sample = np.array([inputs[f"{channel}-{j}"][b] for j in range(L)])
      power, *_ = cwt_spectrogram(sample, 120, nNotes=24, detrend=True, normalize=True)
      samples[b, i, :, :] = power.squeeze()
  outputs["pixel_values"] = samples
  return outputs


In [166]:
item = preprocess(dataset["train"][:8])
item["label"], item['pixel_values'].shape


(['six', 'six', 'six', 'four'], (4, 5, 168, 256))

In [176]:

remove_columns = [f"{channel}-{i}" for channel, i in itertools.product(channels, range(L))]


In [177]:
preprocessed_dataset = dataset.map(preprocess, batched=True, remove_columns=remove_columns, num_proc=os.cpu_count())


                                 

#0:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#1:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#2:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#3:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#4:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#5:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#6:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#7:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#8:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#9:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#10:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#11:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#12:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#13:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#14:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#15:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#16:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#17:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#18:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#19:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#20:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#21:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#22:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#23:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#24:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#25:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#26:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#27:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#28:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#29:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#30:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#31:   0%|          | 0/1 [00:00<?, ?ba/s]

KeyboardInterrupt: 

In [59]:
from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection


In [125]:
configuration = CLIPVisionConfig(
    image_size=max(H, W),
    num_channels=C,
)
vision_model = CLIPVisionModelWithProjection(configuration)


In [147]:
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


In [127]:
# model.vision_model = vision_model.vision_model


In [123]:
# for p in model.text_model.parameters():
#   p.requires_grad = False
# for p in model.text_projection.parameters():
#   p.requires_grad = False


In [84]:
model


CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e-05,

In [124]:
print(f"All Parameters: {sum(p.numel() for p in model.parameters()) / 1e6}M")
print(f"Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6}M")


All Parameters: 152.861697M
Trainable Parameters: 89.433601M


In [62]:
from torchvision.transforms import Pad


In [64]:
S = max(H, W)
pad = Pad(
    (0, S - H, 0, S - W),
    padding_mode="constant",
    fill=0,
)


In [151]:
urls = [
  "http://images.cocodataset.org/val2017/000000039769.jpg",
  "https://farm3.staticflickr.com/2674/5850229113_4fe05d5265_z.jpg",
]
images = [Image.open(requests.get(url, stream=True).raw) for url in urls]


In [154]:
dataset["train"].batch()


AttributeError: 'Dataset' object has no attribute 'batch'

In [153]:
for item in dataset["train"]:
    item = preprocess(item)
    texts = item["label"]
    texts = ["six", "cat"]
    inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)
    a = inputs["pixel_values"]
    inputs["pixel_values"] = pad(torch.tensor(item["pixel_values"]).unsqueeze(0))
    b = inputs["pixel_values"]
    clip_output = model(**inputs, return_loss=True)
    print(clip_output.loss)
    break


tensor(2.3250, grad_fn=<DivBackward0>)


In [143]:
b.shape


torch.Size([1, 3, 224, 224])

In [144]:
b.shape


torch.Size([1, 5, 256, 256])