In [1]:
import sys
import os

# Add MedCLIP folder to path
# The sibling ``MedCLIP`` directory is resolved relative to the current working
# directory, so the result depends on where the kernel was started.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..', 'MedCLIP'))
if not os.path.isdir(project_root):
    # Surface a wrong path right here; otherwise the first symptom is a bare
    # ModuleNotFoundError in a later cell.
    print(
        f"WARNING: MedCLIP directory not found at {project_root!r}; "
        "`import medclip` will fail unless the package is installed elsewhere.",
        file=sys.stderr,
    )
# Guard against duplicates so re-running the cell does not grow sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

In [3]:
# Sanity check: import the package and print the file it was loaded from,
# which shows which copy of `medclip` the kernel resolved.
import medclip
print(medclip.__file__)

ModuleNotFoundError: No module named 'medclip'

In [2]:
import os
import torch
import pydicom
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from sklearn.manifold import TSNE
import pandas as pd

from medclip.modeling_medclip import MedCLIPVisionModelViT

  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'medclip'

In [None]:
# -----------------------------
# Configuration
# -----------------------------
# Use the GPU when PyTorch reports one as available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint file passed to torch.load in the model-loading cell.
model_path = "../MedCLIP/checkpoints/rsna_binary_classification_SU/final_model.bin"

# CSV must contain columns: path, label
csv_path = "rsna_samples.csv"

# limit number to speed up embedding + plotting
max_samples = 300

In [None]:
# -----------------------------
# Load MedCLIP Vision Model
# -----------------------------
model = MedCLIPVisionModelViT()

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider `weights_only=True` if the installed torch supports it).
state_dict = torch.load(model_path, map_location="cpu")

# Keep only the entries whose key starts with "vision_model." and strip that
# leading prefix. Slicing removes just the prefix; str.replace would also
# delete any later occurrence of the same substring inside a parameter name.
vision_prefix = "vision_model."
vision_state_dict = {
    key[len(vision_prefix):]: tensor
    for key, tensor in state_dict.items()
    if key.startswith(vision_prefix)
}
if not vision_state_dict:
    # Give a clearer message than load_state_dict's missing-key report.
    raise RuntimeError(
        f"No keys starting with {vision_prefix!r} were found in {model_path!r}; "
        "is this a full MedCLIP checkpoint?"
    )

model.load_state_dict(vision_state_dict)
# Move to the configured device and switch to inference mode.
model.to(device).eval()

In [None]:
# -----------------------------
# Image preprocessing
# -----------------------------
# Pipeline applied to each PIL image before it is fed to the model:
# resize to 224x224, convert to a tensor, then normalize with a single
# mean/std pair.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5862785803043838], [0.27950088968644304]),
    ]
)

In [None]:
# -----------------------------
# Load samples
# -----------------------------
# Only the first `max_samples` rows of the CSV are embedded.
df = pd.read_csv(csv_path).head(max_samples)

embeddings = []  # one numpy feature array per CSV row
labels = []      # integer labels, aligned index-for-index with `embeddings`

for dicom_path, raw_label in zip(df["path"], df["label"]):
    label = int(raw_label)

    ds = pydicom.dcmread(dicom_path)
    image_np = ds.pixel_array.astype(np.float32)

    # Min-max scale to [0, 1]. A constant image has zero range; dividing by it
    # would produce NaNs (and an undefined uint8 cast below), so fall back to a
    # divisor of 1, which maps such an image to all zeros.
    pixel_min = image_np.min()
    pixel_range = image_np.max() - pixel_min
    image_np = (image_np - pixel_min) / (pixel_range if pixel_range > 0 else 1.0)

    # Convert to 8-bit grayscale, then replicate to three channels.
    image = Image.fromarray((image_np * 255).astype(np.uint8)).convert("L").convert("RGB")

    # Add a leading batch dimension of size 1 and move to the model's device.
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        emb = model(input_tensor, project=False)

    embeddings.append(emb.cpu().squeeze().numpy())
    labels.append(label)

In [None]:
# -----------------------------
# Dimensionality reduction
# -----------------------------
print("Running t-SNE...")
# scikit-learn's TSNE documents that perplexity must be smaller than the number
# of samples, so cap the default of 30 when only a few rows were embedded.
tsne_perplexity = min(30, len(embeddings) - 1)
tsne = TSNE(n_components=2, perplexity=tsne_perplexity, random_state=42)
# Project the stacked embeddings to two dimensions for plotting.
emb_2d = tsne.fit_transform(np.array(embeddings))

In [None]:
# -----------------------------
# Plot
# -----------------------------
# Build the label array once instead of on every loop iteration.
label_array = np.array(labels)
class_names = {0: "Healthy", 1: "Pneumonia"}

plt.figure(figsize=(8, 6))
# One scatter call per class so each class gets its own legend entry.
for class_id, class_name in class_names.items():
    class_mask = label_array == class_id
    plt.scatter(
        emb_2d[class_mask, 0],
        emb_2d[class_mask, 1],
        label=class_name,
        alpha=0.7,
        s=30,
    )

plt.title("2D Embedding Visualization via t-SNE")
plt.xlabel("Dim 1")
plt.ylabel("Dim 2")
plt.legend()
plt.grid(True)
plt.tight_layout()
# Save to disk before showing the figure.
plt.savefig("embedding_plot.png")
plt.show()