In [2]:


import torch
from torchvision import transforms
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import numpy as np
import os
import tqdm
from tqdm import tqdm
import timm
from huggingface_hub import login, hf_hub_download
from cebmf_torch import cEBMF
from cebmf_torch.torch_main import ModelParams, NoiseParams, CovariateParams
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
### Load data
# images = images
# locs   = locations x,y
# ged    = gene expression

# NOTE(review): machine-specific absolute path — point `pwd` at your local copy of
# the data directory before running this cell on another machine.
pwd = r"C:/Document/Serieux/Travail/Data_analysis_and_papers/chevrier/torch_data"
dir1 = os.path.join(pwd, 'gene_expression_dataset.pt')
dir2 = os.path.join(pwd, 'subimg_coord_dataset.pt')

# NOTE(review): torch.load unpickles the file contents; only load files that come
# from a trusted source.
ged = torch.load(dir1)
scd = torch.load(dir2)

# Every item of `scd` unpacks into an (image, location) pair. Collect both sides
# separately, then stack each list into one tensor with a new leading axis.
images = []
locs = []
for img, loc in scd:
    images.append(img)
    locs.append(loc)

images = torch.stack(images)
locs = torch.stack(locs)

In [4]:
# Assuming 'processor', 'model', and 'device' are already defined.
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)


def get_clip_embedding(image_input, model, processor, device):
    """
    Compute an L2-normalised CLIP image embedding for a single image.

    Parameters
    ----------
    image_input : str, PIL.Image.Image or torch.Tensor
        - str: treated as a file path, opened with PIL and converted to RGB.
        - PIL image: converted to RGB when its mode is not already "RGB".
        - torch.Tensor: a single image shaped [C,H,W], [1,C,H,W] or [H,W].
          A [H,W] tensor is treated as one channel; single-channel tensors are
          repeated three times along the channel axis to form an RGB image.
    model
        Object exposing ``get_image_features(**inputs)`` (e.g. a ``CLIPModel``).
        It is assumed to already live on ``device``.
    processor
        Callable accepting ``images=...`` and ``return_tensors="pt"`` whose
        result supports ``.to(device)`` (e.g. a ``CLIPProcessor``).
    device
        Device the preprocessed inputs are moved to before the forward pass.

    Returns
    -------
    torch.Tensor
        The unit-norm embedding with the leading batch axis squeezed out.
        NOTE(review): presumably shape (512,) for the ViT-B/32 checkpoint —
        the size is determined by ``model``; confirm for other checkpoints.

    Raises
    ------
    ValueError
        If a 4-D tensor holds more than one image, or the tensor cannot be
        reduced to three dimensions.
    TypeError
        If ``image_input`` is none of the supported types.

    Notes
    -----
    Tensor inputs are always moved to the CPU and converted to a PIL image
    before preprocessing, so sending them to an accelerator beforehand gains
    nothing.
    NOTE(review): the PIL conversion presumably expects float tensors scaled
    to [0, 1] (or uint8 data) — verify the value range of the stored images.
    """
    if isinstance(image_input, torch.Tensor):
        image = image_input

        # [1,C,H,W] -> [C,H,W]; a real batch is rejected.
        if image.ndim == 4:
            if image.shape[0] != 1:
                raise ValueError("Batch tensors not supported here. Pass single image.")
            image = image.squeeze(0)

        # [H,W] -> [1,H,W]: this is what squeezing a [1,H,W] grayscale image
        # produces, so accept it instead of failing the rank check below.
        if image.ndim == 2:
            image = image.unsqueeze(0)

        if image.ndim != 3:
            raise ValueError(f"Expected [C,H,W] tensor, got shape {image.shape}")

        # Grayscale -> RGB by repeating the single channel.
        if image.shape[0] == 1:
            image = image.repeat(3, 1, 1)

        # The processor is fed a PIL image, which has to be built from CPU data.
        image = transforms.ToPILImage()(image.cpu())

    elif isinstance(image_input, str):
        image = Image.open(image_input).convert("RGB")

    elif isinstance(image_input, Image.Image):
        image = image_input if image_input.mode == "RGB" else image_input.convert("RGB")

    else:
        raise TypeError(f"Unsupported input type: {type(image_input)}")

    # Preprocess, run the image tower without tracking gradients, then scale
    # the embedding to unit L2 norm.
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_emb = model.get_image_features(**inputs)
        image_emb = image_emb / image_emb.norm(p=2, dim=-1, keepdim=True)

    return image_emb.squeeze(0)

In [None]:
# Load pretrained CLIP
# Use the best available accelerator: CUDA first, then Apple MPS, else CPU.
# torch.backends.mps.is_available() is the documented availability check.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Embed the images one at a time. get_clip_embedding converts tensor inputs to a
# CPU PIL image before preprocessing, so the raw image is not moved to `device`
# here; only the preprocessed inputs are.
clip_embedded_images = []
for image in tqdm(images):
    tempimage = image.squeeze(0)  # drops a leading axis of size 1, if there is one
    feature_emb = get_clip_embedding(tempimage, clip_model, clip_processor, device)  # one embedding vector per image
    clip_embedded_images.append(feature_emb)

# One row per image. NOTE(review): presumably shape (N, 512) for this checkpoint — confirm.
clip_embedded_images = torch.stack(clip_embedded_images)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
