In [43]:
%matplotlib inline
!pip install yacs timm einops opencv-python

Collecting yacs
  Downloading yacs-0.1.8-py3-none-any.whl.metadata (639 bytes)
Downloading yacs-0.1.8-py3-none-any.whl (14 kB)
Installing collected packages: yacs
Successfully installed yacs-0.1.8


In [62]:
# ==========================================================
# üîß STEP 1 ‚Äì SETUP
# ==========================================================
from google.colab import drive
drive.mount('/content/drive')

!pip install -q yacs timm einops opencv-python Pillow ipywidgets transformers scikit-learn

# Clone repo if not yet
!git clone https://github.com/jlianglab/Ark.git || echo "Repo already exists"
%cd /content/Ark

# ==========================================================
# üîß STEP 2 ‚Äì LOAD MODEL (SimMIM / Swin backbone)
# ==========================================================
import sys, os, torch, numpy as np
sys.path.append('/content/Ark/Ark_Plus/Finetuning/simmim')
from config import get_config
from models.build import build_model

class Args:
    def __init__(self):
        self.cfg = "/content/Ark/Ark_Plus/Finetuning/simmim/configs/simmim_finetune__swin_base__img224_window7__800ep.yaml"
        self.opts = None

args = Args()
config = get_config(args)
print("‚úÖ Config loaded")

model = build_model(config, is_pretrain=False)
model.eval()
print("‚úÖ Ark‚Å∫ visual encoder (Swin backbone) built")

# ==========================================================
# üíæ STEP 3 ‚Äì LOAD CHECKPOINT
# ==========================================================
torch.serialization.add_safe_globals([np.core.multiarray.scalar])
ckpt_path = "/content/drive/MyDrive/Ark6_swinLarge768_ep50.pth.tar"  # adjust if needed

ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
if 'state_dict' in ckpt:
    model.load_state_dict(ckpt['state_dict'], strict=False)
else:
    model.load_state_dict(ckpt, strict=False)

print("‚úÖ Ark‚Å∫ Nature weights loaded successfully!")

# ==========================================================
# üß† STEP 4 ‚Äì FEATURE EXTRACTION (NO MASK)
# ==========================================================
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F

transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

@torch.no_grad()
def encode_image(img_path):
    """Extract Ark‚Å∫ visual embedding without needing mask."""
    img = Image.open(img_path).convert("RGB")
    x = transform(img).unsqueeze(0)

    feat = None
    if hasattr(model, "backbone") and hasattr(model.backbone, "forward_features"):
        feat = model.backbone.forward_features(x)
    elif hasattr(model, "encoder") and hasattr(model.encoder, "forward_features"):
        feat = model.encoder.forward_features(x)
    elif hasattr(model, "forward_features"):
        feat = model.forward_features(x)
    else:
        # fallback: zero mask if model(x, mask)
        patch = 4
        L = (224 // patch) * (224 // patch)
        mask = torch.zeros((1, L), dtype=torch.bool)
        feat = model(x, mask)

    if feat.ndim == 4:
        feat = feat.mean(dim=[2,3])
    return F.normalize(feat.flatten(), dim=0)

# ==========================================================
# ‚úçÔ∏è TEXT ENCODER (CLIP)
# ==========================================================
from transformers import CLIPTokenizer, CLIPTextModel
_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
_text_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
_text_model.eval()

@torch.no_grad()
def encode_texts(texts):
    tokens = _tokenizer(texts, padding=True, return_tensors="pt")
    text_feats = _text_model(**tokens).pooler_output
    return F.normalize(text_feats, dim=1)

# ==========================================================
# üîç ZERO-SHOT COMPARISON FUNCTION
# ==========================================================
@torch.no_grad()
def zero_shot_inference(image_path, prompts):
    img_feat = encode_image(image_path)       # (D_img,)
    text_feats = encode_texts(prompts)        # (K, D_text)

    # üîß match dimensions (Ark‚Å∫=1024, CLIP=512)
    D = min(text_feats.shape[1], img_feat.shape[0])
    text_feats = text_feats[:, :D]
    img_feat = img_feat[:D]

    sims = torch.matmul(text_feats, img_feat.unsqueeze(1)).squeeze(1)
    return sims

# ==========================================================
# ü©ª STEP 5 ‚Äì INTERACTIVE CLASSROOM DEMO
# ==========================================================
from ipywidgets import Dropdown, Button, Output, VBox, HBox, Text
from IPython.display import display
import matplotlib.pyplot as plt

img_dir = "/content/drive/MyDrive/X_Rays"
images = sorted([f for f in os.listdir(img_dir) if f.lower().endswith(('.png','.jpg','.jpeg'))])

dropdown = Dropdown(options=images, description="Select X-ray:", layout={'width':'400px'})
prompt1 = Text(value="pneumonia present", description="Prompt 1:", layout={'width':'400px'})
prompt2 = Text(value="normal chest X-ray", description="Prompt 2:", layout={'width':'400px'})
run_btn = Button(description="Run Zero-Shot", button_style='success')
out = Output()

def run_inference(b):
    out.clear_output()
    with out:
        img_path = os.path.join(img_dir, dropdown.value)
        display(Image.open(img_path).convert("RGB"))
        texts = [prompt1.value.strip(), prompt2.value.strip()]
        sims = zero_shot_inference(img_path, texts)

        print(f"üß† Ark‚Å∫ Zero-Shot inference on: {dropdown.value}")
        print("\nüîç Prompts:", texts)
        print("üìä Cosine similarities:", [float(s) for s in sims])

        pred_idx = int(torch.argmax(sims))
        pred_text = texts[pred_idx]
        conf = float(sims[pred_idx])
        print(f"\n‚úÖ Prediction: '{pred_text}' (confidence ‚âà {conf:.3f})")

        # bar chart
        plt.figure(figsize=(5,2.5))
        colors = ['tab:green' if i==pred_idx else 'tab:gray' for i in range(len(texts))]
        plt.barh([f"‚Äú{t}‚Äù" for t in texts], [float(s) for s in sims], color=colors)
        plt.xlabel("Cosine similarity (image ‚Üî text)")
        plt.title("Ark‚Å∫ Zero-Shot Text‚ÄìImage Alignment")
        plt.gca().invert_yaxis()
        plt.tight_layout()
        plt.show()

run_btn.on_click(run_inference)
display(VBox([dropdown, HBox([prompt1, prompt2]), run_btn, out]))

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
fatal: destination path 'Ark' already exists and is not an empty directory.
Repo already exists
/content/Ark
=> merge config from /content/Ark/Ark_Plus/Finetuning/simmim/configs/simmim_finetune__swin_base__img224_window7__800ep.yaml
‚úÖ Config loaded
‚úÖ Ark‚Å∫ visual encoder (Swin backbone) built
‚úÖ Ark‚Å∫ Nature weights loaded successfully!


VBox(children=(Dropdown(description='Select X-ray:', layout=Layout(width='400px'), options=('X_Ray_001.png', '‚Ä¶