In [10]:
pip install git+https://github.com/lilab-stanford/MUSK.git

Collecting git+https://github.com/lilab-stanford/MUSK.git
  Cloning https://github.com/lilab-stanford/MUSK.git to /tmp/pip-req-build-42era1aw
  Running command git clone --filter=blob:none --quiet https://github.com/lilab-stanford/MUSK.git /tmp/pip-req-build-42era1aw
  Resolved https://github.com/lilab-stanford/MUSK.git to commit fc9421aaebb2a3651fed5b69558c306f2836c228
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: musk
  Building wheel for musk (pyproject.toml) ... done
  Created wheel for musk: filename=musk-1.0.0-py3-none-any.whl size=51766 sha256=2a44c5481ac837b58e4bb3cd9b4efbf860f88ff2996d6d31310c558faaf92d99
  Stored in directory: /tmp/pip-ephem-wheel-cache-qldhnzd_/wheels/05/9d/56/a69e763dd2663e34d1a36ceb4feec33f79f036d7cd20fa7396
Successfully built musk
Installing collected packages: musk
Successfully in

In [26]:
!pip install fairscale

Collecting fairscale
  Using cached fairscale-0.4.13.tar.gz (266 kB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: fairscale
  Building wheel for fairscale (pyproject.toml) ... done
  Created wheel for fairscale: filename=fairscale-0.4.13-py3-none-any.whl size=332208 sha256=662df6dc9399d1675a801cebc8b795c96289ae9c629bc3de9d26c31f35951fc5
  Stored in directory: /homes2/vmishra/.cache/pip/wheels/5a/88/aa/d84b2cf1bad6b273cbf661640141a82c7b9f496e024f80aac0
Successfully built fairscale
Installing collected packages: fairscale
Successfully installed fairscale-0.4.13


In [27]:
from PIL import Image
import os
import numpy as np
import torch
import pandas as pd
from math import ceil
from tqdm import tqdm
from huggingface_hub import login
import torchvision
from timm.models import create_model
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from musk import utils
from musk import modeling

# Constants
# NOTE(review): none of these three values is referenced by the cells visible
# in this part of the notebook; they may be used further down or be leftovers.
# Presumably the number of slides per cohort -- TODO confirm (the captured run
# output reports only 3 patch files per cohort directory).
num_slides = 250
# Presumably the number of patches stored per slide file -- TODO confirm
# (the captured run output shows 750 patches loaded from 3 files).
num_patches_per_slide = 250
# Presumably the patch edge length in pixels -- TODO confirm; the preprocessing
# transform in this notebook resizes to 384 regardless of this value.
patch_size = 224

In [28]:
# One directory of preprocessed patch files per cohort.
preprocessed_patches_dir_brca = "/lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_BRCA"
preprocessed_patches_dir_luad = "/lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_LUAD"
preprocessed_patches_dir_lusc = "/lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_LUSC"
preprocessed_patches_dir_coad = "/lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_COAD"

# Authenticate with the Hugging Face Hub without storing a secret in the notebook:
# the token is read from the HF_TOKEN environment variable. When the variable is
# unset, token=None makes login() prompt for the token interactively instead.
login(token=os.environ.get("HF_TOKEN"))

In [35]:
def embed(
    patches,
    model,
    transform,
    device,
    batch_size=64,
    verbose=True,
    dtype=None,
):
    """Run ``model`` over ``patches`` in mini-batches and collect the outputs on the CPU.

    Args:
        patches: Sequence (e.g. a NumPy array) of image arrays; it must support
            ``len()`` and slicing. Each element is cast to ``uint8`` and turned
            into an RGB PIL image before ``transform`` is applied.
            NOTE(review): the uint8 cast assumes pixel values are already in
            0-255 -- confirm against the code that wrote the patch files.
        model: Callable invoked as ``model(image=..., with_head=False,
            out_norm=False, ms_aug=True, return_global=True)``; only element
            ``[0]`` of its return value is kept (presumably the image
            embedding -- confirm against the MUSK documentation).
        transform: Callable mapping one PIL image to a tensor; the per-image
            tensors of a batch are stacked with ``torch.stack``.
        device: Device the stacked batch is moved to before the forward pass.
        batch_size: Maximum number of patches per forward pass.
        verbose: When False, the tqdm progress bar is disabled.
        dtype: Floating-point dtype of the batch fed to the model. When None
            (default) it is taken from the model's first floating-point
            parameter, falling back to ``torch.float16`` if the model exposes
            none, so the input always matches the model's precision.

    Returns:
        torch.Tensor on the CPU holding the per-batch outputs concatenated
        along dim 0. An empty ``patches`` yields an empty tensor rather than
        an error from ``torch.cat``.
    """
    total = len(patches)
    if total == 0:
        # torch.cat() rejects an empty list, so short-circuit explicitly.
        return torch.empty(0)

    if dtype is None:
        # Match the input precision to the model instead of assuming half precision.
        parameters = getattr(model, "parameters", None)
        candidates = parameters() if callable(parameters) else ()
        dtype = next(
            (p.dtype for p in candidates if p.is_floating_point()),
            torch.float16,
        )

    collected = []

    # Stepping the range by batch_size visits the same slices as ceil(total / batch_size) batches.
    for start in tqdm(range(0, total, batch_size), disable=not verbose):
        # Slicing clamps at the end of the sequence, so the last batch may be shorter.
        batch_np = patches[start:start + batch_size]

        # The transform is applied to PIL images, so convert each array first.
        batch_pil = [Image.fromarray(patch.astype('uint8')).convert("RGB") for patch in batch_np]

        # Transform every image, then stack into a single batch on the target device.
        batch = torch.stack([transform(img) for img in batch_pil]).to(device, dtype=dtype)

        # Forward pass without autograd bookkeeping.
        with torch.inference_mode():
            batch_emb = model(
                image=batch,
                with_head=False,
                out_norm=False,
                ms_aug=True,
                return_global=True
            )[0]

        # Move the result to host memory right away so it does not accumulate on the device.
        collected.append(batch_emb.cpu())

    return torch.cat(collected, dim=0)

In [36]:
def load_patches_from_individual_files(patches_dir, normalized=False):
    """Load every matching ``.npy`` patch file in ``patches_dir`` and concatenate them.

    Files are selected by suffix: ``_patches-normalized.npy`` when
    ``normalized`` is True, otherwise ``_patches.npy``. They are read in
    sorted filename order so the row order of the result is reproducible
    across runs and machines, and so the original and normalized variants of
    a directory line up whenever their filenames share a prefix.

    Args:
        patches_dir: Directory containing the per-file patch arrays.
        normalized: Selects which of the two filename suffixes to load.

    Returns:
        The loaded arrays concatenated along axis 0. An empty array
        (``np.array([])``) is returned when the directory is missing, no file
        matches, or no file could be loaded.
    """
    # isdir (rather than exists) also rejects a path that points at a regular file,
    # which would otherwise make os.listdir() raise.
    if not os.path.isdir(patches_dir):
        print(f"Directory not found: {patches_dir}")
        return np.array([])

    pattern = "_patches-normalized.npy" if normalized else "_patches.npy"

    # os.listdir() returns entries in arbitrary order; sort for a deterministic result.
    filenames = sorted(f for f in os.listdir(patches_dir) if f.endswith(pattern))

    if not filenames:
        print(f"No files found matching pattern '{pattern}' in {patches_dir}")
        return np.array([])

    print(f"Found {len(filenames)} patch files in {patches_dir}")

    patches_list = []
    for filename in tqdm(filenames, desc=f"Loading {'normalized' if normalized else 'original'} patches"):
        try:
            patches_list.append(np.load(os.path.join(patches_dir, filename)))
        except Exception as e:
            # Best effort: report the unreadable file and keep going.
            # NOTE(review): a skipped file shifts the row positions of everything after it.
            print(f"Error loading {filename}: {e}")

    if not patches_list:
        print("No patches could be loaded")
        return np.array([])

    all_patches = np.concatenate(patches_list, axis=0)
    print(f"Total patches loaded: {len(all_patches)}")
    return all_patches

In [37]:
def embed_patches(patches, model, transform, device):
    """Embed ``patches`` with ``embed`` and hand the result back as a NumPy array.

    An empty ``patches`` input skips the model entirely and yields an empty
    array; otherwise the tensor returned by ``embed`` is converted with
    ``.numpy()``.
    """
    has_patches = len(patches) != 0
    if has_patches:
        embeddings = embed(patches, model, transform, device)
        return embeddings.numpy()

    # Nothing to embed.
    return np.array([])

In [38]:
# Keep using the second GPU (index 1) when it exists, but fall back to GPU 0 on
# single-GPU machines, where "cuda:1" is an invalid device ordinal.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# Build the MUSK architecture through timm (presumably registered by the
# `from musk import modeling` import -- confirm) and load the published weights.
model = create_model("musk_large_patch16_384")
utils.load_model_and_may_interpolate("hf_hub:xiangjx/musk", model, 'model|module', '')
# NOTE(review): the model is cast to half precision even on the CPU fallback;
# float16 inference on CPU may be slow or unsupported for some ops -- confirm
# before relying on the CPU path.
model.to(device=device, dtype=torch.float16)
model.eval()

# Image preprocessing: bicubic resize (InterpolationMode.BICUBIC is the named form of
# the Pillow integer constant 3) to 384, centre crop to 384x384, then normalization
# with the Inception mean/std constants from timm.
preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize(
        384,
        interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
        antialias=True,
    ),
    torchvision.transforms.CenterCrop((384, 384)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)
])

model.safetensors:   0%|          | 0.00/1.35G [00:00<?, ?B/s]

Load ckpt from hf_hub:xiangjx/musk


In [39]:
def _load_patch_pair(patches_dir):
    """Load the original patches, then the normalized patches, from one directory."""
    unnormalized = load_patches_from_individual_files(patches_dir, normalized=False)
    normalized = load_patches_from_individual_files(patches_dir, normalized=True)
    return unnormalized, normalized


def _index_labels(prefix, count):
    """Build 1-based labels ``<prefix>_1`` .. ``<prefix>_<count>``."""
    return [f"{prefix}_{position}" for position in range(1, count + 1)]


# Load both patch variants for every cohort (original first, then normalized).
brca_patches, brca_patches_norm = _load_patch_pair(preprocessed_patches_dir_brca)
luad_patches, luad_patches_norm = _load_patch_pair(preprocessed_patches_dir_luad)
lusc_patches, lusc_patches_norm = _load_patch_pair(preprocessed_patches_dir_lusc)
coad_patches, coad_patches_norm = _load_patch_pair(preprocessed_patches_dir_coad)

# Embed all original patch sets, then all normalized ones.
brca_embeddings, luad_embeddings, lusc_embeddings, coad_embeddings = [
    embed_patches(cohort_patches, model, preprocess, device)
    for cohort_patches in (brca_patches, luad_patches, lusc_patches, coad_patches)
]
brca_embeddings_norm, luad_embeddings_norm, lusc_embeddings_norm, coad_embeddings_norm = [
    embed_patches(cohort_patches, model, preprocess, device)
    for cohort_patches in (brca_patches_norm, luad_patches_norm, lusc_patches_norm, coad_patches_norm)
]

# Number of embeddings per cohort and variant.
num_brca, num_luad, num_lusc, num_coad = map(
    len, (brca_embeddings, luad_embeddings, lusc_embeddings, coad_embeddings)
)
num_brca_norm, num_luad_norm, num_lusc_norm, num_coad_norm = map(
    len, (brca_embeddings_norm, luad_embeddings_norm, lusc_embeddings_norm, coad_embeddings_norm)
)

# One positional label per embedding row.
brca_labels = _index_labels("BRCA", num_brca)
luad_labels = _index_labels("LUAD", num_luad)
lusc_labels = _index_labels("LUSC", num_lusc)
coad_labels = _index_labels("COAD", num_coad)

brca_labels_norm = _index_labels("BRCA_norm", num_brca_norm)
luad_labels_norm = _index_labels("LUAD_norm", num_luad_norm)
lusc_labels_norm = _index_labels("LUSC_norm", num_lusc_norm)
coad_labels_norm = _index_labels("COAD_norm", num_coad_norm)

# Persist everything: embeddings, normalized embeddings, labels, normalized labels.
_save_plan = [
    ("/lotterlab/users/vmishra/RSA_updated100/brca_embeddings_musk_updated.npy", brca_embeddings),
    ("/lotterlab/users/vmishra/RSA_updated100/luad_embeddings_musk_updated.npy", luad_embeddings),
    ("/lotterlab/users/vmishra/RSA_updated100/lusc_embeddings_musk_updated.npy", lusc_embeddings),
    ("/lotterlab/users/vmishra/RSA_updated100/coad_embeddings_musk_updated.npy", coad_embeddings),
    ("/lotterlab/users/vmishra/RSA_updated100/brca_embeddings_musk_normalized_updated.npy", brca_embeddings_norm),
    ("/lotterlab/users/vmishra/RSA_updated100/luad_embeddings_musk_normalized_updated.npy", luad_embeddings_norm),
    ("/lotterlab/users/vmishra/RSA_updated100/lusc_embeddings_musk_normalized_updated.npy", lusc_embeddings_norm),
    ("/lotterlab/users/vmishra/RSA_updated100/coad_embeddings_musk_normalized_updated.npy", coad_embeddings_norm),
    ("/lotterlab/users/vmishra/RSA_updated100/brca_labels_musk_updated.npy", brca_labels),
    ("/lotterlab/users/vmishra/RSA_updated100/luad_labels_musk_updated.npy", luad_labels),
    ("/lotterlab/users/vmishra/RSA_updated100/lusc_labels_musk_updated.npy", lusc_labels),
    ("/lotterlab/users/vmishra/RSA_updated100/coad_labels_musk_updated.npy", coad_labels),
    ("/lotterlab/users/vmishra/RSA_updated100/brca_labels_musk_norm_updated.npy", brca_labels_norm),
    ("/lotterlab/users/vmishra/RSA_updated100/luad_labels_musk_norm_updated.npy", luad_labels_norm),
    ("/lotterlab/users/vmishra/RSA_updated100/lusc_labels_musk_norm_updated.npy", lusc_labels_norm),
    ("/lotterlab/users/vmishra/RSA_updated100/coad_labels_musk_norm_updated.npy", coad_labels_norm),
]
for _target_path, _payload in _save_plan:
    np.save(_target_path, _payload)

Found 3 patch files in /lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_BRCA


Loading original patches: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 58.01it/s]


Total patches loaded: 750
Found 3 patch files in /lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_BRCA


Loading normalized patches: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 71.31it/s]


Total patches loaded: 750
Found 3 patch files in /lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_LUAD


Loading original patches: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 66.27it/s]


Total patches loaded: 750
Found 3 patch files in /lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_LUAD


Loading normalized patches: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 66.81it/s]


Total patches loaded: 750
Found 3 patch files in /lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_LUSC


Loading original patches: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 67.32it/s]


Total patches loaded: 750
Found 3 patch files in /lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_LUSC


Loading normalized patches: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 66.80it/s]


Total patches loaded: 750
Found 3 patch files in /lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_COAD


Loading original patches: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 70.50it/s]


Total patches loaded: 750
Found 3 patch files in /lotterlab/users/vmishra/RSA_updated100/preprocessed_patches_COAD


Loading normalized patches: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 72.44it/s]


Total patches loaded: 750


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:20<00:00,  1.68s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:17<00:00,  1.49s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:17<00:00,  1.48s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:17<00:00,  1.48s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████