In [1]:
from PIL import Image
import os
import numpy as np
import torch
import pandas as pd
from math import ceil
from tqdm import tqdm
from huggingface_hub import login

# Constants
num_slides = 250
num_patches_per_slide = 250
patch_size = 224

In [2]:
preprocessed_patches_dir_brca = "/lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_BRCA"
preprocessed_patches_dir_luad = "/lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_LUAD"
preprocessed_patches_dir_lusc = "/lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_LUSC"
preprocessed_patches_dir_coad = "/lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_COAD"

In [3]:
def embed(patches, model, preprocess, device, batch_size=64, verbose=True):
    num_batches = ceil(len(patches) / batch_size)
    opt_embs = []

    for batch_idx in tqdm(range(num_batches), disable=not verbose):
        start = batch_idx * batch_size
        end = min(start + batch_size, len(patches))
        batch_np = patches[start:end]

        try:
            # Convert numpy arrays to PIL Images
            batch_pil = [Image.fromarray(patch.astype('uint8')) for patch in batch_np]
            
            # Apply CONCH preprocessing to each image individually
            batch_transformed = torch.stack([preprocess(img) for img in batch_pil])
            
            # Move batch to device
            batch = batch_transformed.to(device)

            # Get embeddings
            with torch.no_grad():
                batch_emb = model.encode_image(batch)

            # Move to CPU and append
            opt_embs.append(batch_emb.cpu())

        except Exception as e:
            print(f"Error processing batch {batch_idx}: {str(e)}")
            continue

    if not opt_embs:
        return np.array([])

    # Stack all embeddings
    opt_embs = torch.cat(opt_embs, dim=0)
    
    return opt_embs.numpy()

In [4]:
def load_patches_from_individual_files(patches_dir, normalized=False):
    patches_list = []
    
    if not os.path.exists(patches_dir):
        print(f"Directory not found: {patches_dir}")
        return np.array([])
    
    if normalized:
        pattern = "_patches-normalized.npy"
    else:
        pattern = "_patches.npy"
    
    filenames = [f for f in os.listdir(patches_dir) if f.endswith(pattern)]
    
    if not filenames:
        print(f"No files found matching pattern '{pattern}' in {patches_dir}")
        return np.array([])
    
    print(f"Found {len(filenames)} patch files in {patches_dir}")
    
    for filename in tqdm(filenames, desc=f"Loading {'normalized' if normalized else 'original'} patches"):
        try:
            patches = np.load(os.path.join(patches_dir, filename))
            patches_list.append(patches)
        except Exception as e:
            print(f"Error loading {filename}: {e}")
            continue
    
    if patches_list:
        all_patches = np.concatenate(patches_list, axis=0)
        print(f"Total patches loaded: {len(all_patches)}")
        return all_patches
    else:
        print("No patches could be loaded")
        return np.array([])

In [5]:
def embed_patches(patches, model, preprocess, device):
    if len(patches) == 0:
        return np.array([])
    
    model = model.to(device)
    model.eval()
    return embed(patches, model, preprocess, device)

In [None]:
pip install git+https://github.com/Mahmoodlab/CONCH.git

In [None]:
# Run this: !rm conch.py if not working

In [6]:
from conch.open_clip_custom import create_model_from_pretrained



In [8]:
model, preprocess = create_model_from_pretrained(
    'conch_ViT-B-16',
    "hf_hub:MahmoodLab/conch",
    hf_auth_token="YOUR_HF_TOKEN"
)
model.eval()
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model.to(device)

CoCa(
  (text): TextTransformer(
    (token_embedding): Embedding(32007, 768)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0-11): 12 x ResidualAttentionBlock(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ls_1): Identity()
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ls_2): Identity()
        )
      )
    )
    (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (visual): VisualModel(
    (trunk): VisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 768, kerne

In [9]:
brca_patches = load_patches_from_individual_files(preprocessed_patches_dir_brca, normalized=False)
brca_patches_norm = load_patches_from_individual_files(preprocessed_patches_dir_brca, normalized=True)

luad_patches = load_patches_from_individual_files(preprocessed_patches_dir_luad, normalized=False)
luad_patches_norm = load_patches_from_individual_files(preprocessed_patches_dir_luad, normalized=True)

lusc_patches = load_patches_from_individual_files(preprocessed_patches_dir_lusc, normalized=False)
lusc_patches_norm = load_patches_from_individual_files(preprocessed_patches_dir_lusc, normalized=True)

coad_patches = load_patches_from_individual_files(preprocessed_patches_dir_coad, normalized=False)
coad_patches_norm = load_patches_from_individual_files(preprocessed_patches_dir_coad, normalized=True)

brca_embeddings = embed_patches(brca_patches, model, preprocess, device)
luad_embeddings = embed_patches(luad_patches, model, preprocess, device)
lusc_embeddings = embed_patches(lusc_patches, model, preprocess, device)
coad_embeddings = embed_patches(coad_patches, model, preprocess, device)

brca_embeddings_norm = embed_patches(brca_patches_norm, model, preprocess, device)
luad_embeddings_norm = embed_patches(luad_patches_norm, model, preprocess, device)
lusc_embeddings_norm = embed_patches(lusc_patches_norm, model, preprocess, device)
coad_embeddings_norm = embed_patches(coad_patches_norm, model, preprocess, device)

num_brca = len(brca_embeddings)
num_luad = len(luad_embeddings)
num_lusc = len(lusc_embeddings)
num_coad = len(coad_embeddings)

num_brca_norm = len(brca_embeddings_norm)
num_luad_norm = len(luad_embeddings_norm)
num_lusc_norm = len(lusc_embeddings_norm)
num_coad_norm = len(coad_embeddings_norm)

brca_labels = [f"BRCA_{i+1}" for i in range(num_brca)]
luad_labels = [f"LUAD_{i+1}" for i in range(num_luad)]
lusc_labels = [f"LUSC_{i+1}" for i in range(num_lusc)]
coad_labels = [f"COAD_{i+1}" for i in range(num_coad)]

brca_labels_norm = [f"BRCA_norm_{i+1}" for i in range(num_brca_norm)]
luad_labels_norm = [f"LUAD_norm_{i+1}" for i in range(num_luad_norm)]
lusc_labels_norm = [f"LUSC_norm_{i+1}" for i in range(num_lusc_norm)]
coad_labels_norm = [f"COAD_norm_{i+1}" for i in range(num_coad_norm)]

np.save("/lotterlab/users/vmishra/RSA_updated100/brca_embeddings_conch_updated.npy", brca_embeddings)
np.save("/lotterlab/users/vmishra/RSA_updated100/luad_embeddings_conch_updated.npy", luad_embeddings)
np.save("/lotterlab/users/vmishra/RSA_updated100/lusc_embeddings_conch_updated.npy", lusc_embeddings)
np.save("/lotterlab/users/vmishra/RSA_updated100/coad_embeddings_conch_updated.npy", coad_embeddings)

np.save("/lotterlab/users/vmishra/RSA_updated100/brca_embeddings_conch_normalized_updated.npy", brca_embeddings_norm)
np.save("/lotterlab/users/vmishra/RSA_updated100/luad_embeddings_conch_normalized_updated.npy", luad_embeddings_norm)
np.save("/lotterlab/users/vmishra/RSA_updated100/lusc_embeddings_conch_normalized_updated.npy", lusc_embeddings_norm)
np.save("/lotterlab/users/vmishra/RSA_updated100/coad_embeddings_conch_normalized_updated.npy", coad_embeddings_norm)

np.save("/lotterlab/users/vmishra/RSA_updated100/brca_labels_conch_updated.npy", brca_labels)
np.save("/lotterlab/users/vmishra/RSA_updated100/luad_labels_conch_updated.npy", luad_labels)
np.save("/lotterlab/users/vmishra/RSA_updated100/lusc_labels_conch_updated.npy", lusc_labels)
np.save("/lotterlab/users/vmishra/RSA_updated100/coad_labels_conch_updated.npy", coad_labels)

np.save("/lotterlab/users/vmishra/RSA_updated100/brca_labels_conch_norm_updated.npy", brca_labels_norm)
np.save("/lotterlab/users/vmishra/RSA_updated100/luad_labels_conch_norm_updated.npy", luad_labels_norm)
np.save("/lotterlab/users/vmishra/RSA_updated100/lusc_labels_conch_norm_updated.npy", lusc_labels_norm)
np.save("/lotterlab/users/vmishra/RSA_updated100/coad_labels_conch_norm_updated.npy", coad_labels_norm)

Found 3 patch files in /lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_BRCA


Loading original patches: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 145.95it/s]


Total patches loaded: 750
Found 3 patch files in /lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_BRCA


Loading normalized patches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 137.99it/s]


Total patches loaded: 750
Found 3 patch files in /lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_LUAD


Loading original patches: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 146.59it/s]


Total patches loaded: 750
Found 3 patch files in /lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_LUAD


Loading normalized patches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 149.90it/s]


Total patches loaded: 750
Found 3 patch files in /lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_LUSC


Loading original patches: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 149.45it/s]


Total patches loaded: 750
Found 3 patch files in /lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_LUSC


Loading normalized patches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 152.11it/s]


Total patches loaded: 750
Found 3 patch files in /lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_COAD


Loading original patches: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 151.00it/s]


Total patches loaded: 750
Found 3 patch files in /lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_COAD


Loading normalized patches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 152.08it/s]


Total patches loaded: 750


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:06<00:00,  1.80it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:06<00:00,  1.87it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:06<00:00,  1.86it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:06<00:00,  1.74it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████