In [None]:
import os
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch
from transformers import SwinModel, AutoImageProcessor
import cv2
import matplotlib.pyplot as plt


In [None]:
def generate_label_maps(image_folder):
    """Derive binary callback labels and patient IDs from image file names.

    Every entry of ``image_folder`` is visited in sorted name order. A trailing
    ``.png`` / ``.jpg`` / ``.jpeg`` extension is removed (case-insensitively),
    the remainder is upper-cased and split on ``"_"``. Entries with fewer than
    four parts are skipped. For the rest:

    * the patient ID is ``"<part 0>_<part 1>"``;
    * the pathology is the last part (``-`` and spaces turned into ``_``);
    * the binary label is 1 when the pathology is ``MALIGNANT`` or ``BENIGN``
      and 0 for anything else.

    Args:
        image_folder: Directory whose entries are parsed. Only names are
            inspected; the files themselves are never opened.

    Returns:
        A 3-tuple ``(labels, patient_ids, pathology_encoder)`` where ``labels``
        and ``patient_ids`` are NumPy arrays with one element per kept file (in
        sorted file-name order) and ``pathology_encoder`` maps each distinct
        pathology string to its index in sorted order.
    """
    image_exts = {".png", ".jpg", ".jpeg"}
    callback_labels = {"MALIGNANT", "BENIGN"}

    patient_ids = []
    pathology_labels = []
    pathology_set = set()

    for fname in sorted(os.listdir(image_folder)):
        # Strip the extension regardless of case so that "X_MALIGNANT.PNG" is
        # parsed as pathology "MALIGNANT" rather than "MALIGNANT.PNG" (which
        # would silently fall through to label 0).
        cleaned = fname.strip()
        stem, ext = os.path.splitext(cleaned)
        if ext.lower() not in image_exts:
            stem = cleaned

        parts = stem.strip().upper().split("_")
        if len(parts) < 4:
            # Not enough fields to hold a patient ID and a pathology.
            continue

        pathology = parts[-1].strip().replace("-", "_").replace(" ", "_").upper()

        patient_ids.append(f"{parts[0]}_{parts[1]}")
        pathology_labels.append(1 if pathology in callback_labels else 0)
        pathology_set.add(pathology)

    # Stable integer code for every pathology string that was seen.
    pathology_encoder = {name: i for i, name in enumerate(sorted(pathology_set))}

    return (
        np.array(pathology_labels),
        np.array(patient_ids),
        pathology_encoder
    )

In [None]:
def extract_features_swin(folder_path, output_path, label_map_patho=None, label_map_calc=None):
    """Extract pooled Swin-Tiny features for every image in a folder and save them.

    Entries of ``folder_path`` are visited in sorted name order; names that do
    not end in ``.png`` / ``.jpg`` / ``.jpeg`` (case-insensitive) are skipped.
    Each image is converted to RGB, run through the pretrained
    ``microsoft/swin-tiny-patch4-window7-224`` model without gradients, and its
    ``pooler_output`` is stored as one feature row.

    Args:
        folder_path: Directory containing the images.
        output_path: Path handed to ``np.save`` for the stacked features. Its
            parent directory is created if it does not exist.
        label_map_patho: Optional mapping ``file name -> pathology label``.
        label_map_calc: Optional mapping ``file name -> calc-type label``.

    Returns:
        ``(features, patho_labels, calc_labels, filenames)`` when BOTH label
        maps are supplied and non-empty; otherwise ``(features, filenames)``.
        Files missing from a supplied map receive the label ``-1``.
        NOTE: labels collected from a single supplied map are not returned.
    """
    # Create the destination directory before the expensive loop so a missing
    # folder cannot make np.save fail after all features were computed.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Load pretrained Swin-Tiny
    model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
    extractor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

    features, patho_labels, calc_labels, filenames = [], [], [], []

    for fname in tqdm(sorted(os.listdir(folder_path))):
        if not fname.lower().endswith((".png", ".jpg", ".jpeg")):
            continue

        img_path = os.path.join(folder_path, fname)
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the RGB image.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        inputs = extractor(images=img, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
            # Drop the singleton batch dimension to get a flat feature vector.
            feat = outputs.pooler_output.squeeze().numpy()

        features.append(feat)
        filenames.append(fname)

        if label_map_patho:
            patho_labels.append(label_map_patho.get(fname, -1))
        if label_map_calc:
            calc_labels.append(label_map_calc.get(fname, -1))

    features = np.array(features)
    np.save(output_path, features)
    print(f"Saved features to {output_path}")

    if label_map_patho and label_map_calc:
        return features, np.array(patho_labels), np.array(calc_labels), filenames
    else:
        return features, filenames

In [None]:
folder_path = "raw/cropped_images_all"
labels_callback_out = "features/labels_callback_binary.npy"
labels_patient_out = "features/labels_patient_id.npy"
features_out = "features/features_swin.npy"

# np.save does not create missing parent directories, so make sure they exist
# before any expensive work starts.
for out_path in (labels_callback_out, labels_patient_out, features_out):
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)

# generate_label_maps returns exactly three values: labels, patient IDs, encoder.
y_binary_patho, patient_ids, patho_encoder = generate_label_maps(folder_path)

# generate_label_maps skips names with fewer than four "_"-separated parts, so
# apply the same rule here; zipping against the full directory listing would
# shift every label after a skipped file onto the wrong image.
labelled_files = [f for f in sorted(os.listdir(folder_path)) if len(f.split("_")) >= 4]
if len(labelled_files) != len(y_binary_patho):
    raise ValueError(
        f"Label/file mismatch: {len(y_binary_patho)} labels for {len(labelled_files)} files"
    )

label_map_callback = dict(zip(labelled_files, y_binary_patho.tolist()))
patient_id_by_file = dict(zip(labelled_files, patient_ids.tolist()))

# ==== Feature Extraction With Swin ====
# Without both label maps the extractor returns only (features, filenames);
# labels are aligned to the extracted file order below.
features, filenames = extract_features_swin(
    folder_path=folder_path,
    output_path=features_out,
)

# One entry per feature row; -1 / "UNKNOWN" mark images that have no parsed label.
callback_labels = np.array([label_map_callback.get(fname, -1) for fname in filenames])
patient_ids_aligned = np.array([patient_id_by_file.get(fname, "UNKNOWN") for fname in filenames])

# ==== Save ====
np.save(labels_callback_out, callback_labels)
np.save(labels_patient_out, patient_ids_aligned)
