In [2]:
import torch
import numpy as np
import torch.nn.functional as F
import warnings
import torchvision.transforms as transforms
# Silence every warning for the rest of the notebook.
# NOTE(review): this also hides deprecation notices from the libraries used below.
warnings.filterwarnings('ignore')

# Preferred GPU index. Fall back to the first GPU when the machine exposes fewer
# devices than that, and to the CPU when CUDA is not available at all.
_preferred_gpu = 3
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > _preferred_gpu:
    device = torch.device(f"cuda:{_preferred_gpu}")
else:
    device = torch.device("cuda:0")

## Loading PNGs and converting them into a tensor

In [None]:
import os
import torch
import torchvision.transforms as transforms
from PIL import Image

def load_images_and_save_tensor(image_folder, output_file):
    """Stack every numerically named PNG of a folder into one tensor and save it.

    Files ending in ``.png`` are ordered by the integer value of their name
    without extension (so ``2.png`` comes before ``10.png``), converted to RGB,
    turned into tensors with ``transforms.ToTensor()`` and combined with
    ``torch.stack`` before being written with ``torch.save``.

    Args:
        image_folder: Directory that contains the ``.png`` files.
        output_file: Destination path handed to ``torch.save``.

    Raises:
        ValueError: If the folder holds no ``.png`` file, or if a PNG file name
            (without extension) cannot be parsed as an integer.
    """
    png_names = [f for f in os.listdir(image_folder) if f.endswith('.png')]
    if not png_names:
        # Fail early with a clear message instead of an opaque torch.stack error.
        raise ValueError(f"No .png files found in {image_folder!r}")

    # Sort numerically rather than lexicographically.
    image_files = sorted(png_names, key=lambda x: int(os.path.splitext(x)[0]))

    transform = transforms.ToTensor()
    images = []
    for img_file in image_files:
        img_path = os.path.join(image_folder, img_file)
        # The context manager closes the underlying file as soon as the image is converted.
        with Image.open(img_path) as img:
            # Convert to RGB so every tensor has the same number of channels.
            images.append(transform(img.convert('RGB')))

    images_tensor = torch.stack(images)  # Stack into a single tensor
    torch.save(images_tensor, output_file)  # Save tensor to .pt file
    print(f"Saved {len(images)} images as a tensor in {output_file}")

# Convert the folder of reconstructed PNGs into a single tensor file; the
# resulting .pt file is loaded again in a later cell.
load_images_and_save_tensor("/New_results/cross_model_final", "/New_results/cross_model_final.pt")

Saved 4623 images as a tensor in /srv/nfs-data/sisko/kashif/New_results/cross_model_final.pt


In [None]:

# Load the BOLD5000 test stimuli. map_location="cpu" restores the data on the CPU
# regardless of the device it was saved from.
BOLD_test_stimulus = torch.load("/BOLD5000_V2/test_stimulus.pt", map_location="cpu")

In [None]:
# Load the NSD test stimuli. map_location="cpu" restores the data on the CPU
# regardless of the device it was saved from.
NSD_test_stimulus = torch.load("/decoding_NSD/data_fmri_nsd/img_test.pt", map_location="cpu")

In [None]:
# Load the GOD test stimuli. map_location="cpu" restores the data on the CPU
# regardless of the device it was saved from.
GOD_test_stimulus = torch.load("/decoding/GOD/test_stimulus.pt", map_location="cpu")

In [None]:
# Load the stacked reconstructions written by load_images_and_save_tensor.
# map_location="cpu" keeps this large tensor in host memory.
cross_model_reconstructed = torch.load("/New_results/cross_model_final.pt", map_location="cpu")

In [7]:
# Split the stacked reconstructions into the three datasets, in storage order:
# NSD first, then BOLD, then GOD.
(
    NSD_recontructed_images,
    BOLD_recontructed_images,
    GOD_recontructed_images,
) = (
    cross_model_reconstructed[:3928],
    cross_model_reconstructed[3928:4373],
    cross_model_reconstructed[4373:],
)

# Show the three shapes as the cell output.
tuple(
    part.shape
    for part in (NSD_recontructed_images, BOLD_recontructed_images, GOD_recontructed_images)
)

(torch.Size([3928, 3, 1024, 1024]),
 torch.Size([445, 3, 1024, 1024]),
 torch.Size([250, 3, 1024, 1024]))

In [1]:
# import torch
# import torchvision.transforms as transforms
# import matplotlib.pyplot as plt
# import numpy as np

# # Tensors are assumed to be already loaded as 'NSD_test_stimulus' (originals) and 'NSD_recontructed_images' (reconstructions)
# # NSD_test_stimulus: torch.Size([3928, 3, 425, 425])
# # NSD_recontructed_images: torch.Size([3928, 3, 1024, 1024])

# # Define transformation to convert tensors to PIL images
# to_pil = transforms.ToPILImage()

# # Number of images to display
# num_images = 10
# images_per_subject = 982
# subject_ids = [1, 2, 5, 7]  # NSD subject IDs
# num_subjects = len(subject_ids)

# # Initialize lists for images
# original_samples = NSD_test_stimulus[:num_images]
# reconstructed_samples = {subject: NSD_recontructed_images[i * images_per_subject: (i * images_per_subject) + num_images] for i, subject in enumerate(subject_ids)}

# # Plot the original and reconstructed images side by side
# fig, axes = plt.subplots(num_images, num_subjects + 1, figsize=(15, num_images * 5))

# for i in range(num_images):
#     # Convert original tensor to PIL image
#     ori_img = to_pil(original_samples[i])
    
#     # Plot original image in the first column
#     axes[i, 0].imshow(ori_img)
#     axes[i, 0].axis('off')
#     axes[i, 0].set_title('Original Image')
    
#     # Plot reconstructed images for each subject
#     for j, subject in enumerate(subject_ids):
#         rec_img = to_pil(reconstructed_samples[subject][i])
#         axes[i, j + 1].imshow(rec_img)
#         axes[i, j + 1].axis('off')
#         axes[i, j + 1].set_title(f'Subject {subject}')

# plt.tight_layout()
# plt.show()

In [None]:
import torch
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage
import random


# (start, end) index range of every subject inside the BOLD tensors; end is exclusive
subject_slices = {
    'Subject1': (0, 113),
    'Subject2': (113, 226),
    'Subject3': (226, 339),
    'Subject4': (339, 445),
}

# Setup
to_pil = ToPILImage()
# Show at most 100 rows, but never more than the smallest subject range holds,
# so a row index can never run into the images of the next subject.
n_rows = min(100, *(end - start for start, end in subject_slices.values()))
n_cols = 2 * len(subject_slices)  # one (original, reconstruction) column pair per subject

# First n_rows indices of every subject. Stored under a cell-specific name so
# the `subject_indices` slice mapping read by calculate_subject_wise_metrics is
# not overwritten when this cell is (re-)run.
bold_plot_indices = {
    subj: list(range(start, start + n_rows))
    for subj, (start, end) in subject_slices.items()
}

# Create one big figure; squeeze=False keeps `axs` two-dimensional for any n_rows
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2), squeeze=False)

for row in range(n_rows):
    for col, (subj, indices) in enumerate(bold_plot_indices.items()):
        idx = indices[row]
        panels = (
            (BOLD_test_stimulus[idx], 'Original'),
            (BOLD_recontructed_images[idx], 'Reconstructed'),
        )
        # Original goes into the even column, its reconstruction right next to it
        for offset, (tensor, label) in enumerate(panels):
            ax = axs[row, col * 2 + offset]
            ax.imshow(to_pil(tensor.cpu()))
            ax.axis('off')
            ax.set_title(f'{subj} {label}', fontsize=8)

plt.tight_layout()

plt.show()


In [None]:
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np


# Tensor -> PIL converter used for every panel
to_pil = transforms.ToPILImage()

# Grid layout: one row per stimulus, one column for the original plus one per subject
num_images = 50
images_per_subject = 50
subject_ids = [1, 2, 3, 4, 5]
num_subjects = len(subject_ids)

# Ground-truth images shown in the first column
original_samples = GOD_test_stimulus[:num_images]

# Reconstructions of each subject, taken from consecutive blocks of the GOD tensor
reconstructed_samples = {}
for position, subject in enumerate(subject_ids):
    first = position * images_per_subject
    reconstructed_samples[subject] = GOD_recontructed_images[first:first + num_images]

# Plot the original and reconstructed images side by side
fig, axes = plt.subplots(num_images, num_subjects + 1, figsize=(15, num_images * 5))

for row in range(num_images):
    # Column 0 holds the original, columns 1..N the per-subject reconstructions
    panels = [('Original Image', original_samples[row])]
    panels.extend(
        (f'Subject {subject}', reconstructed_samples[subject][row])
        for subject in subject_ids
    )
    for column, (title, tensor) in enumerate(panels):
        ax = axes[row, column]
        ax.imshow(to_pil(tensor))
        ax.axis('off')
        ax.set_title(title)

plt.tight_layout()
plt.show()

In [8]:
# Bring the BOLD and GOD stimuli to 425x425 so they can be concatenated with the NSD stimuli
bold_resized, GOD_resized = (
    F.interpolate(stimuli, size=(425, 425), mode='bilinear', align_corners=False)
    for stimuli in (BOLD_test_stimulus, GOD_test_stimulus)
)

# Same dataset order as the stacked reconstructions: NSD, BOLD, GOD
NSD_BOLD_GOD_Combine = torch.cat([NSD_test_stimulus, bold_resized, GOD_resized], dim=0)
NSD_BOLD_GOD_Combine.shape

torch.Size([4623, 3, 425, 425])

In [9]:
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
from torchvision.models.feature_extraction import create_feature_extractor
from skimage.metrics import structural_similarity as ssim
from skimage.color import rgb2gray
from tqdm import tqdm
import pandas as pd
import scipy.spatial as sp
import clip
from PIL import Image
from scipy.spatial.distance import correlation

@torch.no_grad()
def two_way_identification(all_brain_recons, all_images, model, preprocess, feature_layer=None, return_avg=True, device= device):
    """Score reconstructions with the two-way identification test.

    Both image sets are preprocessed, passed through ``model`` and flattened to
    one feature vector per image. For every reconstruction, the function counts
    how many ground-truth images correlate less with it than its own ground
    truth does.

    Args:
        all_brain_recons: Reconstructed images, index-aligned with ``all_images``.
        all_images: Ground-truth images.
        model: Callable applied to the stacked, preprocessed batch.
        preprocess: Callable applied to each image before stacking.
        feature_layer: Key used to index the model output; ``None`` uses the
            output directly.
        return_avg: Whether to return the averaged score or the raw counts.
        device: Device the stacked batches are moved to before the forward pass.

    Returns:
        The mean win count divided by ``len(all_images) - 1`` when
        ``return_avg`` is true, otherwise the tuple
        ``(win counts per reconstruction, len(all_images) - 1)``.
    """
    def _features(batch):
        # Preprocess each image, run the model once on the whole batch, flatten per image
        stacked = torch.stack([preprocess(sample) for sample in batch], dim=0).to(device)
        output = model(stacked)
        if feature_layer is not None:
            output = output[feature_layer]
        return output.float().flatten(1).cpu().numpy()

    recon_feats = _features(all_brain_recons)
    real_feats = _features(all_images)
    n_images = len(all_images)

    # corrcoef stacks both inputs; the upper-right block holds corr(real_i, recon_j)
    cross_corr = np.corrcoef(real_feats, recon_feats)[:n_images, n_images:]
    matched = np.diag(cross_corr)

    # Per reconstruction (column): number of ground-truth images correlating
    # less with it than its own ground truth
    wins_per_recon = (cross_corr < matched).sum(axis=0)
    n_comparisons = n_images - 1

    if not return_avg:
        return wins_per_recon, n_comparisons
    return np.mean(wins_per_recon) / n_comparisons

def cal_metrics(all_images, all_brain_recons, device):
    """Compute pixel-level and feature-level metrics between originals and reconstructions.

    The reconstructions are cast to the dtype of ``all_images``, clamped to
    [0, 1] and resized to the spatial size of ``all_images`` before any metric
    is computed. Metrics, in order: PixCorr, SSIM, MSE, cosine similarity,
    two-way identification with AlexNet (``features.4`` / ``features.11``),
    InceptionV3 (``avgpool``) and CLIP ViT-L/14 image embeddings, and the mean
    correlation distance of EfficientNet-B1 and SwAV ResNet-50 ``avgpool``
    features.

    Args:
        all_images: Ground-truth image tensor, indexed as (N, C, H, W) here.
        all_brain_recons: Reconstructions, index-aligned with ``all_images``.
        device: Device on which the tensors and all models are placed.

    Returns:
        dict mapping each metric name to a one-element list with its value.
    """
    all_images = all_images[:].to(device)
    all_brain_recons = torch.stack([img for img in all_brain_recons[:]]).to(device).to(all_images.dtype).clamp(0,1).squeeze()
    # squeeze() also removes the batch dimension when a single image is passed; restore it.
    if len(all_images) == 1 and all_brain_recons.dim() == all_images.dim() - 1:
        all_brain_recons = all_brain_recons.unsqueeze(0)

    print("Images shape:", all_images.shape)
    print("Recons shape:", all_brain_recons.shape)

    # Ensure both tensors are the same size for MSE
    resize = transforms.Resize((all_images.size(2), all_images.size(3)), interpolation=transforms.InterpolationMode.BILINEAR)
    all_brain_recons = resize(all_brain_recons)

    print("Images shape after resize:", all_images.shape)
    print("Recons shape after resize:", all_brain_recons.shape)

    # Preprocess
    preprocess = transforms.Compose([
        transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
    ])

    # Flatten images while keeping the batch dimension
    all_images_flattened = preprocess(all_images).reshape(len(all_images), -1).cpu()
    all_brain_recons_flattened = preprocess(all_brain_recons).view(len(all_brain_recons), -1).cpu()

    print(all_images_flattened.shape)
    print(all_brain_recons_flattened.shape)

    # PixCorr: mean per-image Pearson correlation of the flattened pixels
    print("\n------calculating pixcorr------")
    corrsum = 0
    for i in tqdm(range(len(all_images))):
        corrsum += np.corrcoef(all_images_flattened[i], all_brain_recons_flattened[i])[0][1]
    pixcorr = corrsum / len(all_images)
    print("PixCorr:", pixcorr)

    # SSIM
    preprocess = transforms.Compose([
        transforms.Resize(625, interpolation=transforms.InterpolationMode.BILINEAR),
    ])

    img_gray = rgb2gray(preprocess(all_images).permute((0,2,3,1)).cpu().numpy())
    recon_gray = rgb2gray(preprocess(all_brain_recons).permute((0,2,3,1)).cpu().numpy())
    print("converted, now calculating ssim...")

    # NOTE(review): the inputs are 2-D grayscale arrays, yet `multichannel=True`
    # is passed; that keyword is deprecated in recent scikit-image releases, so
    # confirm how the installed version interprets this call. Left unchanged so
    # the scores stay comparable with earlier runs.
    ssim_score=[]
    for im, rec in tqdm(zip(img_gray, recon_gray), total=len(all_images)):
        ssim_score.append(ssim(rec, im, multichannel=True, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1.0))

    ssim_mean = np.mean(ssim_score)
    print("SSIM:", ssim_mean)

    # MSE (on the reconstructions resized to the originals' resolution)
    mse = torch.nn.functional.mse_loss(all_brain_recons, all_images).item()
    print("MSE:", mse)

    # Cosine Similarity
    cosine_sim = torch.nn.functional.cosine_similarity(all_brain_recons_flattened, all_images_flattened).mean().item()
    print("Cosine Similarity:", cosine_sim)

    # Feature-based evaluations using different models
    def evaluate_model(model, preprocess, feature_layers, layer_names):
        """Run two-way identification once per requested feature layer."""
        results = {}
        for feature_layer, layer_name in zip(feature_layers, layer_names):
            print(f"\n---{layer_name}---")
            all_per_correct = two_way_identification(all_brain_recons.to(device).float(), all_images, 
                                                     model, preprocess, feature_layer, device=device)
            results[layer_name] = np.mean(all_per_correct)
            print(f"2-way Percent Correct: {results[layer_name]:.4f}")
        return results

    def mean_feature_distance(model, preprocess):
        """Mean correlation distance between 'avgpool' features of originals and reconstructions."""
        with torch.no_grad():  # inference only
            gt = model(preprocess(all_images))['avgpool']
            fake = model(preprocess(all_brain_recons))['avgpool']
        gt = gt.reshape(len(gt), -1).cpu().numpy()
        fake = fake.reshape(len(fake), -1).cpu().numpy()
        return np.array([correlation(gt[i], fake[i]) for i in range(len(gt))]).mean()

    # AlexNet
    from torchvision.models import alexnet, AlexNet_Weights
    alex_weights = AlexNet_Weights.IMAGENET1K_V1
    alex_model = create_feature_extractor(alexnet(weights=alex_weights), return_nodes=['features.4', 'features.11']).to(device)
    alex_model.eval().requires_grad_(False)

    preprocess = transforms.Compose([
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    alexnet_results = evaluate_model(alex_model, preprocess, ['features.4', 'features.11'], ['AlexNet(2)', 'AlexNet(5)'])
    del alex_model
    torch.cuda.empty_cache()

    # InceptionV3
    from torchvision.models import inception_v3, Inception_V3_Weights
    inception_weights = Inception_V3_Weights.DEFAULT
    inception_model = create_feature_extractor(inception_v3(weights=inception_weights), return_nodes=['avgpool']).to(device)
    inception_model.eval().requires_grad_(False)

    preprocess = transforms.Compose([
        transforms.Resize(342, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    inception_results = evaluate_model(inception_model, preprocess, ['avgpool'], ['InceptionV3'])
    del inception_model
    torch.cuda.empty_cache()

    #CLIP
    # The preprocessing returned by clip.load is not used; tensors are
    # normalised with the transform defined right below instead.
    clip_model, _ = clip.load("ViT-L/14", device=device)

    preprocess = transforms.Compose([
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                            std=[0.26862954, 0.26130258, 0.27577711]),
    ])

    # Pass `device` explicitly so the batches land on the same device as the
    # CLIP model instead of on the default bound in two_way_identification.
    all_per_correct = two_way_identification(all_brain_recons, all_images,
                                            clip_model.encode_image, preprocess, None,
                                            device=device) # final layer
    clip_results = np.mean(all_per_correct)
    print("CLIP:", clip_results)
    # Release CLIP like the other backbones before loading the next model
    del clip_model
    torch.cuda.empty_cache()

    # EfficientNet
    from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights
    eff_weights = EfficientNet_B1_Weights.DEFAULT
    eff_model = create_feature_extractor(efficientnet_b1(weights=eff_weights), return_nodes=['avgpool']).to(device)
    eff_model.eval().requires_grad_(False)

    preprocess = transforms.Compose([
        transforms.Resize(255, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    effnet_distance = mean_feature_distance(eff_model, preprocess)
    print("EffNet Distance:", effnet_distance)
    del eff_model
    torch.cuda.empty_cache()

    # SwAV
    swav_model = torch.hub.load('facebookresearch/swav:main', 'resnet50')
    swav_model = create_feature_extractor(swav_model, return_nodes=['avgpool']).to(device)
    swav_model.eval().requires_grad_(False)

    preprocess = transforms.Compose([
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    swav_distance = mean_feature_distance(swav_model, preprocess)
    print("SwAV Distance:", swav_distance)
    del swav_model
    torch.cuda.empty_cache()

    # Collect the results (each value wrapped in a one-element list)
    metrics = {
        'PixCorr': [pixcorr],
        'SSIM': [ssim_mean],
        'MSE': [mse],
        'Cosine Similarity': [cosine_sim],
        'AlexNet(2)': [alexnet_results["AlexNet(2)"]],
        'AlexNet(5)': [alexnet_results["AlexNet(5)"]],
        'InceptionV3': [inception_results["InceptionV3"]],
        'CLIP': [clip_results],
        'EffNet Distance': [effnet_distance],
        'SwAV Distance': [swav_distance]
    }
    return metrics



In [10]:

# Row range of every subject inside the concatenated tensors. The boundaries
# follow the dataset split used above: rows 0-3928 are NSD, 3928-4373 BOLD and
# 4373-4623 GOD.
subject_indices = {
    f"subject{subject_id}": slice(start, stop)
    for subject_id, start, stop in (
        # NSD block
        (1, 0, 982),
        (2, 982, 1964),
        (5, 1964, 2946),
        (7, 2946, 3928),
        # BOLD block
        (8, 3928, 4041),
        (9, 4041, 4154),
        (10, 4154, 4267),
        (11, 4267, 4373),
        # GOD block
        (12, 4373, 4423),
        (13, 4423, 4473),
        (14, 4473, 4523),
        (15, 4523, 4573),
        (16, 4573, 4623),
    )
}

def calculate_subject_wise_metrics(all_images, all_brain_recons, device, subject_slices=None):
    """Run ``cal_metrics`` per subject, average across subjects and save a table.

    Args:
        all_images: Ground-truth images of all subjects, concatenated.
        all_brain_recons: Reconstructions, index-aligned with ``all_images``.
        device: Device forwarded to ``cal_metrics``.
        subject_slices: Mapping of subject name to the index (slice) selecting
            that subject's rows. Defaults to the module-level ``subject_indices``.

    Returns:
        pandas.DataFrame with one row per subject plus an ``Average`` row
        (unweighted mean over subjects); the same table is written as a
        tab-separated file to ``New_results/cross_model_final.csv``.

    Raises:
        ValueError: If ``subject_slices`` is empty.
    """
    import os  # local import: the cell that imports os at file level may not have been run

    if subject_slices is None:
        subject_slices = subject_indices
    if not subject_slices:
        raise ValueError("subject_slices is empty; nothing to evaluate")

    subject_results = {}
    for subject, indices in subject_slices.items():
        print(f"\nProcessing {subject}...")
        metrics = cal_metrics(all_images[indices], all_brain_recons[indices], device)
        # cal_metrics wraps every value in a one-element list; unwrap to plain
        # floats so the saved table holds numbers instead of "[x]" cells.
        subject_results[subject] = {key: float(np.mean(value)) for key, value in metrics.items()}

    # Take the metric names from the first evaluated subject rather than a hard-coded key
    metric_names = list(next(iter(subject_results.values())))
    avg_metrics = {
        key: float(np.mean([result[key] for result in subject_results.values()]))
        for key in metric_names
    }

    print("\nAverage Metrics Across Subjects:")
    for key, value in avg_metrics.items():
        print(f"{key}: {value:.4f}")

    # Save results
    df = pd.DataFrame.from_dict(subject_results, orient='index')
    df.loc['Average'] = avg_metrics
    output_path = 'New_results/cross_model_final.csv'
    # Create the target folder up front so the write cannot fail after the long evaluation
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    df.to_csv(output_path, sep='\t')
    return df

# Run the per-subject evaluation on all 4623 stimulus/reconstruction pairs
# (NSD, BOLD and GOD concatenated in that order).
calculate_subject_wise_metrics(NSD_BOLD_GOD_Combine, cross_model_reconstructed, device)


Processing subject1...
Images shape: torch.Size([982, 3, 425, 425])
Recons shape: torch.Size([982, 3, 1024, 1024])
Images shape after resize: torch.Size([982, 3, 425, 425])
Recons shape after resize: torch.Size([982, 3, 425, 425])
torch.Size([982, 541875])
torch.Size([982, 541875])

------calculating pixcorr------


100%|██████████| 982/982 [00:01<00:00, 640.65it/s]


PixCorr: 0.06303688303255169
converted, now calculating ssim...


100%|██████████| 982/982 [00:17<00:00, 57.01it/s]


SSIM: 0.3376463062255907
MSE: 0.1145549938082695
Cosine Similarity: 0.7825461626052856

---AlexNet(2)---
2-way Percent Correct: 0.8063

---AlexNet(5)---
2-way Percent Correct: 0.9190

---InceptionV3---
2-way Percent Correct: 0.9265
CLIP: 0.9414205962160894
EffNet Distance: 0.6895740615917768


Using cache found in /home/kashif/.cache/torch/hub/facebookresearch_swav_main


SwAV Distance: 0.3866674695156197

Processing subject2...
Images shape: torch.Size([982, 3, 425, 425])
Recons shape: torch.Size([982, 3, 1024, 1024])
Images shape after resize: torch.Size([982, 3, 425, 425])
Recons shape after resize: torch.Size([982, 3, 425, 425])
torch.Size([982, 541875])
torch.Size([982, 541875])

------calculating pixcorr------


100%|██████████| 982/982 [00:01<00:00, 750.38it/s]


PixCorr: 0.06404633235039306
converted, now calculating ssim...


100%|██████████| 982/982 [00:16<00:00, 57.91it/s]


SSIM: 0.3379220873522155
MSE: 0.11432468146085739
Cosine Similarity: 0.781416118144989

---AlexNet(2)---
2-way Percent Correct: 0.8100

---AlexNet(5)---
2-way Percent Correct: 0.9232

---InceptionV3---
2-way Percent Correct: 0.9266
CLIP: 0.9319514772531459
EffNet Distance: 0.7050631283804736


Using cache found in /home/kashif/.cache/torch/hub/facebookresearch_swav_main


SwAV Distance: 0.396642854508883

Processing subject5...
Images shape: torch.Size([982, 3, 425, 425])
Recons shape: torch.Size([982, 3, 1024, 1024])
Images shape after resize: torch.Size([982, 3, 425, 425])
Recons shape after resize: torch.Size([982, 3, 425, 425])
torch.Size([982, 541875])
torch.Size([982, 541875])

------calculating pixcorr------


100%|██████████| 982/982 [00:01<00:00, 802.70it/s]


PixCorr: 0.06496410980078911
converted, now calculating ssim...


100%|██████████| 982/982 [00:16<00:00, 59.13it/s]


SSIM: 0.33996850282505714
MSE: 0.11399119347333908
Cosine Similarity: 0.7817409634590149

---AlexNet(2)---
2-way Percent Correct: 0.8102

---AlexNet(5)---
2-way Percent Correct: 0.9190

---InceptionV3---
2-way Percent Correct: 0.9262
CLIP: 0.9440406418488968
EffNet Distance: 0.685955386859804


Using cache found in /home/kashif/.cache/torch/hub/facebookresearch_swav_main


SwAV Distance: 0.383147000366125

Processing subject7...
Images shape: torch.Size([982, 3, 425, 425])
Recons shape: torch.Size([982, 3, 1024, 1024])
Images shape after resize: torch.Size([982, 3, 425, 425])
Recons shape after resize: torch.Size([982, 3, 425, 425])
torch.Size([982, 541875])
torch.Size([982, 541875])

------calculating pixcorr------


100%|██████████| 982/982 [00:01<00:00, 799.96it/s]


PixCorr: 0.0676461679428338
converted, now calculating ssim...


100%|██████████| 982/982 [00:15<00:00, 61.69it/s]


SSIM: 0.33514182883399135
MSE: 0.11407474428415298
Cosine Similarity: 0.7857064008712769

---AlexNet(2)---
2-way Percent Correct: 0.7868

---AlexNet(5)---
2-way Percent Correct: 0.9001

---InceptionV3---
2-way Percent Correct: 0.9092
CLIP: 0.9249570765107303
EffNet Distance: 0.7136681880021835


Using cache found in /home/kashif/.cache/torch/hub/facebookresearch_swav_main


SwAV Distance: 0.4049867496376099

Processing subject8...
Images shape: torch.Size([113, 3, 425, 425])
Recons shape: torch.Size([113, 3, 1024, 1024])
Images shape after resize: torch.Size([113, 3, 425, 425])
Recons shape after resize: torch.Size([113, 3, 425, 425])
torch.Size([113, 541875])
torch.Size([113, 541875])

------calculating pixcorr------


100%|██████████| 113/113 [00:00<00:00, 736.86it/s]

PixCorr: 0.033629428999447895





converted, now calculating ssim...


100%|██████████| 113/113 [00:01<00:00, 61.38it/s]


SSIM: 0.3787125321777612
MSE: 0.11293075233697891
Cosine Similarity: 0.777799665927887

---AlexNet(2)---
2-way Percent Correct: 0.7069

---AlexNet(5)---
2-way Percent Correct: 0.8237

---InceptionV3---
2-way Percent Correct: 0.7062
CLIP: 0.7664348925410873
EffNet Distance: 0.8946479359553924


Using cache found in /home/kashif/.cache/torch/hub/facebookresearch_swav_main


SwAV Distance: 0.5420178037631832

Processing subject9...
Images shape: torch.Size([113, 3, 425, 425])
Recons shape: torch.Size([113, 3, 1024, 1024])
Images shape after resize: torch.Size([113, 3, 425, 425])
Recons shape after resize: torch.Size([113, 3, 425, 425])
torch.Size([113, 541875])
torch.Size([113, 541875])

------calculating pixcorr------


100%|██████████| 113/113 [00:00<00:00, 771.96it/s]

PixCorr: 0.019490338119794646





converted, now calculating ssim...


100%|██████████| 113/113 [00:01<00:00, 60.93it/s]


SSIM: 0.3730833337385661
MSE: 0.11912981420755386
Cosine Similarity: 0.7724331021308899

---AlexNet(2)---
2-way Percent Correct: 0.6757

---AlexNet(5)---
2-way Percent Correct: 0.7485

---InceptionV3---
2-way Percent Correct: 0.6609
CLIP: 0.714680783817952
EffNet Distance: 0.9245341491973111


Using cache found in /home/kashif/.cache/torch/hub/facebookresearch_swav_main


SwAV Distance: 0.5778855604514083

Processing subject10...
Images shape: torch.Size([113, 3, 425, 425])
Recons shape: torch.Size([113, 3, 1024, 1024])
Images shape after resize: torch.Size([113, 3, 425, 425])
Recons shape after resize: torch.Size([113, 3, 425, 425])
torch.Size([113, 541875])
torch.Size([113, 541875])

------calculating pixcorr------


100%|██████████| 113/113 [00:00<00:00, 740.17it/s]

PixCorr: 0.042105902268694095





converted, now calculating ssim...


100%|██████████| 113/113 [00:01<00:00, 61.10it/s]


SSIM: 0.37561151445801216
MSE: 0.11604425311088562
Cosine Similarity: 0.7753045558929443

---AlexNet(2)---
2-way Percent Correct: 0.6448

---AlexNet(5)---
2-way Percent Correct: 0.7714

---InceptionV3---
2-way Percent Correct: 0.7173
CLIP: 0.7568742098609356
EffNet Distance: 0.9150810840843947


Using cache found in /home/kashif/.cache/torch/hub/facebookresearch_swav_main


SwAV Distance: 0.5629533629704593

Processing subject11...
Images shape: torch.Size([106, 3, 425, 425])
Recons shape: torch.Size([106, 3, 1024, 1024])
Images shape after resize: torch.Size([106, 3, 425, 425])
Recons shape after resize: torch.Size([106, 3, 425, 425])
torch.Size([106, 541875])
torch.Size([106, 541875])

------calculating pixcorr------


100%|██████████| 106/106 [00:00<00:00, 731.23it/s]

PixCorr: 0.04384213116365503





converted, now calculating ssim...


100%|██████████| 106/106 [00:01<00:00, 60.33it/s]


SSIM: 0.37497458333140726
MSE: 0.1149488165974617
Cosine Similarity: 0.7808213829994202

---AlexNet(2)---
2-way Percent Correct: 0.7143

---AlexNet(5)---
2-way Percent Correct: 0.8188

---InceptionV3---
2-way Percent Correct: 0.7130
CLIP: 0.780952380952381
EffNet Distance: 0.9114474792701386


Using cache found in /home/kashif/.cache/torch/hub/facebookresearch_swav_main


SwAV Distance: 0.5452974103243928

Processing subject12...
Images shape: torch.Size([50, 3, 425, 425])
Recons shape: torch.Size([50, 3, 1024, 1024])
Images shape after resize: torch.Size([50, 3, 425, 425])
Recons shape after resize: torch.Size([50, 3, 425, 425])
torch.Size([50, 541875])
torch.Size([50, 541875])

------calculating pixcorr------


100%|██████████| 50/50 [00:00<00:00, 810.51it/s]

PixCorr: 0.05199317191856506





converted, now calculating ssim...


100%|██████████| 50/50 [00:00<00:00, 60.30it/s]


SSIM: 0.3411259384904603
MSE: 0.12853121757507324
Cosine Similarity: 0.7547575235366821

---AlexNet(2)---
2-way Percent Correct: 0.7110

---AlexNet(5)---
2-way Percent Correct: 0.8412

---InceptionV3---
2-way Percent Correct: 0.7151
CLIP: 0.7848979591836734
EffNet Distance: 0.9248186806504047


Using cache found in /home/kashif/.cache/torch/hub/facebookresearch_swav_main


SwAV Distance: 0.57542197040215

Processing subject13...
Images shape: torch.Size([50, 3, 425, 425])
Recons shape: torch.Size([50, 3, 1024, 1024])
Images shape after resize: torch.Size([50, 3, 425, 425])
Recons shape after resize: torch.Size([50, 3, 425, 425])
torch.Size([50, 541875])
torch.Size([50, 541875])

------calculating pixcorr------


100%|██████████| 50/50 [00:00<00:00, 729.16it/s]


PixCorr: 0.006128569235353803
converted, now calculating ssim...


100%|██████████| 50/50 [00:00<00:00, 59.08it/s]


SSIM: 0.32120155601480893
MSE: 0.1338493525981903
Cosine Similarity: 0.7377384305000305

---AlexNet(2)---
2-way Percent Correct: 0.6722

---AlexNet(5)---
2-way Percent Correct: 0.8139

---InceptionV3---
2-way Percent Correct: 0.6731
CLIP: 0.8453061224489796
EffNet Distance: 0.9141452738170132


Using cache found in /home/kashif/.cache/torch/hub/facebookresearch_swav_main


SwAV Distance: 0.5634931046845765

Processing subject14...
Images shape: torch.Size([50, 3, 425, 425])
Recons shape: torch.Size([50, 3, 1024, 1024])
Images shape after resize: torch.Size([50, 3, 425, 425])
Recons shape after resize: torch.Size([50, 3, 425, 425])
torch.Size([50, 541875])
torch.Size([50, 541875])

------calculating pixcorr------


100%|██████████| 50/50 [00:00<00:00, 742.07it/s]

PixCorr: 0.03363652577853462





converted, now calculating ssim...


100%|██████████| 50/50 [00:00<00:00, 59.82it/s]


SSIM: 0.3318005115071707
MSE: 0.12873952090740204
Cosine Similarity: 0.7382766008377075

---AlexNet(2)---
2-way Percent Correct: 0.7437

---AlexNet(5)---
2-way Percent Correct: 0.8690

---InceptionV3---
2-way Percent Correct: 0.6106
CLIP: 0.8155102040816327
EffNet Distance: 0.9223298053519089


Using cache found in /home/kashif/.cache/torch/hub/facebookresearch_swav_main


SwAV Distance: 0.553029740301233

Processing subject15...
Images shape: torch.Size([50, 3, 425, 425])
Recons shape: torch.Size([50, 3, 1024, 1024])
Images shape after resize: torch.Size([50, 3, 425, 425])
Recons shape after resize: torch.Size([50, 3, 425, 425])
torch.Size([50, 541875])
torch.Size([50, 541875])

------calculating pixcorr------


100%|██████████| 50/50 [00:00<00:00, 784.99it/s]

PixCorr: 0.034170926154237384





converted, now calculating ssim...


100%|██████████| 50/50 [00:00<00:00, 60.18it/s]


SSIM: 0.33428428201347704
MSE: 0.13001148402690887
Cosine Similarity: 0.7602849006652832

---AlexNet(2)---
2-way Percent Correct: 0.7298

---AlexNet(5)---
2-way Percent Correct: 0.8580

---InceptionV3---
2-way Percent Correct: 0.7143
CLIP: 0.8420408163265306
EffNet Distance: 0.8962686529077939


Using cache found in /home/kashif/.cache/torch/hub/facebookresearch_swav_main


SwAV Distance: 0.5555986318587284

Processing subject16...
Images shape: torch.Size([50, 3, 425, 425])
Recons shape: torch.Size([50, 3, 1024, 1024])
Images shape after resize: torch.Size([50, 3, 425, 425])
Recons shape after resize: torch.Size([50, 3, 425, 425])
torch.Size([50, 541875])
torch.Size([50, 541875])

------calculating pixcorr------


100%|██████████| 50/50 [00:00<00:00, 812.87it/s]

PixCorr: 0.017489945389707946





converted, now calculating ssim...


100%|██████████| 50/50 [00:00<00:00, 60.36it/s]


SSIM: 0.3566275423336576
MSE: 0.12199704349040985
Cosine Similarity: 0.7541359066963196

---AlexNet(2)---
2-way Percent Correct: 0.7629

---AlexNet(5)---
2-way Percent Correct: 0.8661

---InceptionV3---
2-way Percent Correct: 0.7065
CLIP: 0.7644897959183674
EffNet Distance: 0.9077663490629392


Using cache found in /home/kashif/.cache/torch/hub/facebookresearch_swav_main


SwAV Distance: 0.5591891438549023

Average Metrics Across Subjects:
PixCorr: 0.0417
SSIM: 0.3491
MSE: 0.1202
Cosine Similarity: 0.7679
AlexNet(2): 0.7365
AlexNet(5): 0.8517
InceptionV3: 0.7620
CLIP: 0.8318
EffNet Distance: 0.8466
SwAV Distance: 0.5082


In [19]:
# Bilinear resize with size 425, applied to both tensors below.
preprocess = transforms.Compose([
    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
])

# NOTE(review): the BOLD_* names are rebound to the NSD tensors here. Every
# later cell that reads BOLD_test_stimulus / BOLD_recontructed_images therefore
# works on NSD data, and the BOLD plotting cell above no longer shows BOLD
# images if it is re-run after this cell. Presumably done to reuse the
# classification loop below for NSD — consider dedicated variable names.
BOLD_test_stimulus = preprocess(NSD_test_stimulus)
BOLD_recontructed_images = preprocess(NSD_recontructed_images)

In [20]:
from torchmetrics.functional import accuracy

def n_way_top_k_acc(pred, class_id, n_way, num_trials=40, top_k=1):
    """Estimate top-k accuracy of ``pred`` over repeated random n-way trials.

    In every trial the score of ``class_id`` is placed first, followed by the
    scores of ``n_way - 1`` other entries of ``pred`` drawn without
    replacement; ``accuracy`` is then evaluated with target index 0.

    Args:
        pred: 1-D tensor of scores.
        class_id: Index of the target entry in ``pred``.
        n_way: Number of candidates per trial (target included).
        num_trials: Number of random trials.
        top_k: ``top_k`` forwarded to ``accuracy``.

    Returns:
        Tuple ``(mean, std)`` of the per-trial accuracies.
    """
    # Every index except the target may be drawn as a distractor
    candidates = [idx for idx in np.arange(len(pred)) if idx != class_id]
    # The target always sits at position 0 of the trial vector
    target = torch.tensor([0], device=pred.device)

    trial_scores = []
    for _ in range(num_trials):
        distractors = np.random.choice(candidates, n_way-1, replace=False)
        trial_pred = torch.cat([pred[class_id].unsqueeze(0), pred[distractors]])
        trial_acc = accuracy(
            trial_pred.unsqueeze(0),
            target,
            task='multiclass',
            num_classes=n_way,
            top_k=top_k,
        )
        trial_scores.append(trial_acc.item())

    return np.mean(trial_scores), np.std(trial_scores)


In [None]:
from torchvision.models import ViT_H_14_Weights, vit_h_14
from tqdm import tqdm

# ViT-H/14 classifier with its default weights and matching preprocessing
weights = ViT_H_14_Weights.DEFAULT
vit_model = vit_h_14(weights=weights)
vit_model.to(device)
vit_model.eval()
preprocess = weights.transforms()
acc_list = []
std_list = []  # per-image trial std; collected but not used in the summary below

# Fix the NumPy seed so the random distractor draws in n_way_top_k_acc are reproducible
np.random.seed(0)

# Inference only: no_grad avoids building an autograd graph for every forward pass
with torch.no_grad():
    for i in tqdm(range(len(BOLD_recontructed_images)), desc="Processing images"):
        image = preprocess(BOLD_test_stimulus[i].unsqueeze(0)).to(device)
        recon_image = preprocess(BOLD_recontructed_images[i].unsqueeze(0)).to(device)
        # Class probabilities predicted for the reconstruction
        recon_image_out = vit_model(recon_image).squeeze(0).softmax(0).detach()
        # The classifier's top prediction on the original image serves as ground-truth class
        gt_class_id = vit_model(image).squeeze(0).softmax(0).argmax().item()
        # 50-way, top-1 accuracy over 1000 random trials
        acc, std = n_way_top_k_acc(recon_image_out, gt_class_id, 50, 1000, 1)
        acc_list.append(acc)
        std_list.append(std)

# "std acc" is the spread of the per-image mean accuracies
print("mean acc: {}, std acc: {}".format(np.mean(acc_list), np.std(acc_list)))
