In [None]:
import pandas as pd
import os
from os.path import join
import numpy as np
import mne
from mne_bids import (
    BIDSPath,
    read_raw_bids,
    print_dir_tree,
    make_report,
    find_matching_paths,
    get_entity_vals,
)

import h5py
from os.path import join as opj
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
import tqdm
#from versatile_diffusion_dual_guided_fake_images import *

from torchsummary import summary

import pandas as pd
import os
from os.path import join as opj
from PIL import Image
import h5py
import numpy as np
import nibabel as nib
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

from sklearn.metrics import accuracy_score,classification_report,confusion_matrix

#import labelencoder
from sklearn.preprocessing import LabelEncoder
#import pipeline
from sklearn.pipeline import Pipeline
import tqdm
import torchvision

In [None]:

class NSDDataset(Dataset):
    """Paired fMRI / stimulus image / caption samples loaded from ``.npy`` files.

    All three arrays are indexed with the same sample index, so they are
    expected to be aligned along their first axis (assumption - the files are
    produced elsewhere; verify against the preprocessing step).

    Args:
        fmri_data: Path to the ``.npy`` file holding the fMRI responses.
        imgs_data: Path to the ``.npy`` file holding the stimulus images. The
            array is cast to ``uint8`` so it can be handed to ``PIL``.
        caption_data: Path to the ``.npy`` file holding the captions. Each
            entry is itself indexable (presumably several captions per image).
        transforms: Optional callable applied to the PIL image in
            ``__getitem__``.
        caption_idx: Which caption of each entry to return. Defaults to ``0``,
            i.e. the first one, which was previously hard-coded.
    """

    def __init__(self, fmri_data, imgs_data, caption_data, transforms=None, caption_idx=0):
        self.fmri_data = np.load(fmri_data)
        # copy=False avoids duplicating the image array when it is already uint8.
        self.imgs_data = np.load(imgs_data).astype(np.uint8, copy=False)
        # NOTE(security): allow_pickle=True unpickles arbitrary objects; only
        # load caption files that come from a trusted source.
        self.caption_data = np.load(caption_data, allow_pickle=True)
        self.transforms = transforms
        self.caption_idx = caption_idx

    def __len__(self):
        """Return the number of samples (length of the fMRI array)."""
        return len(self.fmri_data)

    def __getitem__(self, idx):
        """Return ``(fmri, img, caption)`` for sample ``idx``.

        ``fmri`` is a tensor built from the stored fMRI row, ``img`` is a PIL
        image (or the output of ``transforms`` when one was given) and
        ``caption`` is the entry selected by ``caption_idx``.
        """
        fmri = torch.tensor(self.fmri_data[idx])
        img = Image.fromarray(self.imgs_data[idx])

        # Explicit None check: a transform object may legitimately be falsy
        # (e.g. an empty container) and must still be honoured.
        if self.transforms is not None:
            img = self.transforms(img)

        caption = self.caption_data[idx][self.caption_idx]

        return fmri, img, caption
                       

In [None]:
# Build the image-to-image diffusion pipeline used further down to compute
# IP-Adapter image embeddings.
from diffusers import AutoPipelineForImage2Image
# NOTE(review): load_image is not used in the cells visible here.
from diffusers.utils import load_image

# Load the SDXL base checkpoint in float16 and move it to GPU "cuda:3".
pipeline = AutoPipelineForImage2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda:3")
# Attach the SDXL IP-Adapter weights from the "h94/IP-Adapter" repository.
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
# Set the IP-Adapter scale to 0.6.
pipeline.set_ip_adapter_scale(0.6)

In [None]:
def _embed_split(dataloader, device):
    """Iterate once over ``dataloader`` and gather fMRI, images and image embeddings.

    For every image the module-level ``pipeline`` is asked for its IP-Adapter
    image embeddings; the first element of the returned list is kept.

    Returns:
        ``(fmri, imgs, embeds)``: each is the concatenation over all batches
        along the first axis.
    """
    to_pil = torchvision.transforms.ToPILImage()

    fmri_batches, img_batches, embed_batches = [], [], []
    with torch.no_grad():
        for fmri, imgs, _captions in tqdm.tqdm(dataloader, position=0):
            # The pipeline is called one image at a time, as PIL images.
            per_image = []
            for img_tensor in imgs:
                image_features = pipeline.prepare_ip_adapter_image_embeds(
                    ip_adapter_image=to_pil(img_tensor),
                    ip_adapter_image_embeds=None,
                    device=device,
                    num_images_per_prompt=1,
                    do_classifier_free_guidance=True,
                )
                per_image.append(image_features[0])

            embed_batches.append(torch.stack(per_image, dim=0))
            fmri_batches.append(fmri)
            img_batches.append(imgs)

    return (
        torch.cat(fmri_batches, 0),
        torch.cat(img_batches, 0),
        torch.cat(embed_batches, dim=0),
    )


def get_dataset(sub, base_path="/storage/fMRI_NSD/data", device="cuda:3", batch_size=32):
    """Load one subject's train/test data and compute IP-Adapter image embeddings.

    Args:
        sub: Subject identifier ending in the subject number, e.g. ``"subj01"``.
        base_path: Root data directory; the subject's files are read from
            ``<base_path>/processed_data/<sub>``.
        device: Device string forwarded to the pipeline's embedding call.
        batch_size: Batch size of the (unshuffled) data loaders.

    Returns:
        ``(train_fmri, test_fmri, train_imgs, test_imgs, train_embeds,
        test_embeds)``. The fMRI tensors are standardized per feature with the
        *training* mean and standard deviation, and remaining NaNs are replaced
        through ``torch.nan_to_num``.

    Raises:
        ValueError: If ``sub`` does not end in a number.
    """
    processed_data = opj(base_path, "processed_data", sub)

    # Take the trailing digits as subject number. Splitting on "0" (the former
    # approach) breaks for ids whose number ends in 0, such as "subj10".
    digits = sub[len(sub.rstrip("0123456789")):]
    if not digits:
        raise ValueError(f"cannot extract a subject number from {sub!r}")
    sub_idx = int(digits)

    def split_paths(split):
        # fMRI, stimulus and caption file of one split ("train" or "test").
        return (
            opj(processed_data, f"nsd_{split}_fmriavg_nsdgeneral_sub{sub_idx}.npy"),
            opj(processed_data, f"nsd_{split}_stim_sub{sub_idx}.npy"),
            opj(processed_data, f"nsd_{split}_cap_sub{sub_idx}.npy"),
        )

    tr = torchvision.transforms.ToTensor()
    train_dataset = NSDDataset(*split_paths("train"), transforms=tr)
    test_dataset = NSDDataset(*split_paths("test"), transforms=tr)

    # shuffle=False keeps fMRI rows, images and embeddings in file order.
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False)

    train_fmri_data, train_imgs, train_clip_img_embeds = _embed_split(train_dataloader, device)
    test_fmri_data, test_imgs, test_clip_img_embeds = _embed_split(test_dataloader, device)

    # Standardize with training statistics only.
    mean = train_fmri_data.mean(0)
    std = train_fmri_data.std(0)
    # A zero std would turn test values into +/-inf, which nan_to_num maps to
    # the largest finite float; dividing by 1 keeps such features well-scaled.
    std = torch.where(std == 0, torch.ones_like(std), std)

    train_fmri_data = torch.nan_to_num((train_fmri_data - mean) / std)
    test_fmri_data = torch.nan_to_num((test_fmri_data - mean) / std)

    return train_fmri_data, test_fmri_data, train_imgs, test_imgs, train_clip_img_embeds, test_clip_img_embeds


In [None]:
# Per-subject results: one list entry per subject, in loop order.
train_datas, test_datas = [], []
img_train, img_test = [], []
train_clip_img_embeds, test_clip_img_embeds = [], []
# One subject label per sample, in the same order as the fMRI data.
subject_train_ids, subject_test_ids = [], []

for p in tqdm.tqdm(["subj01", "subj02", "subj05", "subj07"]):
    (
        train_data_,
        test_data_,
        img_train_,
        img_test_,
        train_clip_img_embeds_,
        test_clip_img_embeds_,
    ) = get_dataset(p)

    train_datas.append(train_data_)
    test_datas.append(test_data_)
    img_train.append(img_train_)
    img_test.append(img_test_)
    train_clip_img_embeds.append(train_clip_img_embeds_)
    test_clip_img_embeds.append(test_clip_img_embeds_)

    # Repeat the subject id once for every sample of that subject.
    subject_train_ids.extend([p] * len(train_data_))
    subject_test_ids.extend([p] * len(test_data_))

In [None]:
# Sanity check: report the tensor shapes of the first subject only.
for _label, _per_subject in (
    ("train data shape", train_datas),
    ("test data shape", test_datas),
    ("img shape", img_train),
    ("imgtest shape", img_test),
    ("train_emb shape", train_clip_img_embeds),
    ("test emb shape", test_clip_img_embeds),
):
    print(_label, _per_subject[0].shape)
