In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms
from PIL import Image
from skimage import exposure

In [None]:
import sys
sys.path.append('./utils')
from build_model import MultiModalHead

In [None]:
def extract_prefixed_state(state_dict, prefix):
    """Return the entries of ``state_dict`` whose keys start with ``prefix``, with that prefix removed.

    Args:
        state_dict: mapping of parameter names to tensors, e.g. a checkpoint read by ``torch.load``.
        prefix: exact leading key segment to match and strip, including the trailing dot.

    Returns:
        A new dict whose keys are relative to the sub-module, ready for ``load_state_dict``.
    """
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


model = MultiModalHead(in_dim=128 * 2, out_dim=4, ecg_drop=0, projector=False)

# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only machine;
# load_state_dict copies the tensors into the model's own parameters either way.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
model_dict = torch.load('./saved_models/best_multimodal.pth', map_location='cpu')

# Prefixes reproduce the original fixed slice offsets (21 / 20 / 20 / 19 characters) and match
# the target sub-module names -- confirm against model_dict.keys() if loading reports missing keys.
xray_dict = extract_prefixed_state(model_dict, 'xray_model.vit_model.')
ecg_dict = extract_prefixed_state(model_dict, 'ecg_model.vit_model.')
# The checkpoint's text-aligned projectors are what get loaded into projector_xray / projector_ecg.
proj_xray = extract_prefixed_state(model_dict, 'projector_xray_text.')
proj_ecg = extract_prefixed_state(model_dict, 'projector_ecg_text.')

model.xray_model.vit_model.load_state_dict(xray_dict)
model.projector_xray.load_state_dict(proj_xray)
model.ecg_model.vit_model.load_state_dict(ecg_dict)
model.projector_ecg.load_state_dict(proj_ecg)

In [None]:
chexpert = pd.read_csv('./data/chexpert_5x200.csv')
# Load onto CPU: the model built in this notebook is never moved to a GPU, so a tensor that was
# saved from CUDA would otherwise fail to load on CPU-only machines or mismatch the model's device.
chexpert_imgs = torch.load('./data/zero_shot_chexpert_tensor.pth', map_location='cpu')

# Labels (from the CSV) and features (from the images) are later paired by row position.
assert len(chexpert) == len(chexpert_imgs), (
    f'{len(chexpert)} CSV rows vs {len(chexpert_imgs)} images -- labels would be misaligned'
)

In [None]:
# Label columns are picked by position; presumably these are the four target findings matching
# out_dim=4 -- TODO confirm against chexpert.columns.
LABEL_COLUMNS = chexpert.columns[[14, 8, 11, 16]]
ground_truth = chexpert[LABEL_COLUMNS]

# Work on a plain ndarray: np.argmax on a DataFrame falls back to numpy's array-wrapping path,
# which tries to rebuild a DataFrame of the original shape from the 1-D result and can raise.
gt_values = ground_truth.to_numpy(dtype=float)

# argmax of an all-zero (or all-NaN) row returns 0, which would silently lump every image that is
# negative for all selected columns into the first class. Give those rows their own index,
# len(LABEL_COLUMNS), so they stay visible as a separate group. NaN compares False here.
has_positive = (gt_values > 0).any(axis=1)
labels_indices = np.where(
    has_positive,
    np.nan_to_num(gt_values, nan=-np.inf).argmax(axis=1),
    len(LABEL_COLUMNS),
)

In [None]:
BATCH_SIZE = 64  # images per forward pass; lower this if memory is tight

model.eval()
with torch.no_grad():
    # Encode in mini-batches instead of pushing the whole image tensor through the ViT at once,
    # which bounds peak activation memory. Assumes the encoder/projector have no cross-sample
    # operations in eval mode (true for standard BatchNorm/Dropout/ViT layers) -- TODO confirm.
    features = torch.cat([
        model.projector_xray(model.xray_model(batch))
        for batch in torch.split(chexpert_imgs, BATCH_SIZE)
    ])

In [None]:
TSNE_SEED = 42  # fixed seed so the 2-D embedding is reproducible across runs

# Reduce the projected image features to two t-SNE components for plotting.
features_np = features.cpu().numpy()
tsne = TSNE(n_components=2, random_state=TSNE_SEED)
features_reduced = tsne.fit_transform(features_np)

In [None]:
fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(features_reduced[:, 0], features_reduced[:, 1],
                     c=labels_indices, cmap='tab20b', alpha=0.6)

# legend_elements(num=None) yields one handle per unique value of `c`, in ascending order (the
# same order as np.unique), so pair each handle with that index's column name. A `set` has no
# guaranteed order and would only show bare integers. Indices past the last column denote rows
# with no positive label among the selected columns.
class_names = list(ground_truth.columns)
present = np.unique(np.asarray(labels_indices))
legend_labels = [class_names[i] if i < len(class_names) else 'no positive label' for i in present]
handles, _ = scatter.legend_elements(num=None)
ax.legend(handles=handles, labels=legend_labels)

ax.set_title('t-SNE Visualization')
ax.set_xlabel('Component 1')
ax.set_ylabel('Component 2')

# Save before showing: under the inline backend plt.show() closes the figure, so a savefig
# issued afterwards writes a blank image.
fig.savefig('tnse_ours.png')
plt.show()