In [None]:
"""
Google Colab Script: Full Dataset Multimodal Encoder
Processes ALL samples from the dataset (train/val/test)
Using Google Drive folder: MyDrive/processed data
"""


# SETUP AND INSTALLATION

print("="*60)
print("GOOGLE COLAB - MULTIMODAL ENCODER (FULL DATASET)")
print("="*60)
print()

print("Installing required packages...")
import sys
!{sys.executable} -m pip install -q transformers torch pillow pandas tqdm
print("Packages installed\n")

# Imports
import torch
from transformers import ViTModel, ViTImageProcessor, AutoTokenizer, AutoModel
from PIL import Image
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import os
import warnings
warnings.filterwarnings('ignore')

print("Libraries imported\n")


# LOAD DATA FROM GOOGLE DRIVE


print("="*60)
print("DATA LOADING FROM GOOGLE DRIVE")
print("="*60)
print()

USE_GOOGLE_DRIVE = True

if USE_GOOGLE_DRIVE:
    from google.colab import drive
    drive.mount('/content/drive')
    DATA_PATH = Path('/content/drive/MyDrive/processed data')
    print(f"Using data from: {DATA_PATH}")
else:
    from google.colab import files
    print("Upload your 'processed_data.zip' file...")
    uploaded = files.upload()

    import zipfile
    for filename in uploaded.keys():
        print(f"Extracting {filename}...")
        with zipfile.ZipFile(filename, 'r') as zip_ref:
            zip_ref.extractall('/content/')

    DATA_PATH = Path('/content/processed data')
    print(f"Data extracted to: {DATA_PATH}")

print()

# Verify data structure
IMAGES_DIR = DATA_PATH / 'images'
CSV_PATH = DATA_PATH / 'processed_data.csv'

assert CSV_PATH.exists(), f"CSV not found at {CSV_PATH}"
assert IMAGES_DIR.exists(), f"Images folder not found at {IMAGES_DIR}"

print("Data structure verified")
print(f"CSV: {CSV_PATH}")
print(f"Images: {IMAGES_DIR}")
print()


# VISION ENCODER (ViT-Base)


print("="*60)
print("VISION ENCODER - PROCESSING FULL DATASET")
print("="*60)
print()

VIT_MODEL_NAME = 'google/vit-base-patch16-224'
BATCH_SIZE = 32
OUTPUT_VIT = 'vit_embeddings_full.pt'

def load_vit_model(device):
    """Load the pretrained ViT backbone named by VIT_MODEL_NAME and its processor.

    Args:
        device: torch device the model is moved onto.

    Returns:
        ``(model, processor)`` with the model on ``device`` in eval mode.
    """
    print(f"Loading ViT-Base: {VIT_MODEL_NAME}")
    vit = ViTModel.from_pretrained(VIT_MODEL_NAME)
    image_processor = ViTImageProcessor.from_pretrained(VIT_MODEL_NAME)
    # eval() returns the module itself, so the placement and mode switch chain.
    vit = vit.to(device).eval()
    print(f"Model loaded on {device}\n")
    return vit, image_processor

def select_all_samples(csv_path, images_dir):
    """Read the CSV and keep, for each split, only rows whose image file exists.

    Args:
        csv_path: path to the processed CSV; must have ``split`` and
            ``filename`` columns.
        images_dir: ``pathlib.Path`` of the images root; an image is expected
            at ``images_dir / split / filename``.

    Returns:
        dict mapping ``"train"``, ``"val"`` and ``"test"`` to a DataFrame of that
        split's rows whose image exists. Every DataFrame keeps all CSV columns
        and dtypes (also when it is empty) and the original row index.
    """
    print("Loading samples from CSV...")
    df = pd.read_csv(csv_path)
    print(f"Total rows in CSV: {len(df)}\n")

    selected = {}
    total_selected = 0

    for split in ['train', 'val', 'test']:
        split_df = df[df['split'] == split]
        print(f"{split.upper()} split")
        print(f"Available in CSV: {len(split_df)}")

        split_dir = images_dir / split
        exists = [
            (split_dir / fname).exists()
            for fname in tqdm(
                split_df['filename'],
                desc=f"Verifying {split}",
                total=len(split_df),
                leave=False
            )
        ]
        # Boolean-mask the original frame instead of rebuilding it from row
        # Series: an empty result still has the 'filename'/'full_report'
        # columns the encoders index, and column dtypes are preserved.
        mask = pd.Series(exists, index=split_df.index, dtype=bool)
        valid_df = split_df.loc[mask].copy()
        print(f"Valid images: {len(valid_df)}\n")

        selected[split] = valid_df
        total_selected += len(valid_df)

    print(f"Total valid samples: {total_selected}")
    print("="*60)
    print()
    return selected

def extract_vit_embeddings(model, processor, samples_df, split, images_dir, device, batch_size):
    """Run every image listed in ``samples_df`` through the ViT encoder.

    Images are read from ``images_dir / split / filename`` in the row order of
    ``samples_df``, converted to RGB, preprocessed by ``processor`` and encoded
    ``batch_size`` at a time without gradient tracking.

    Args:
        model: ViT model, already on ``device`` and in eval mode.
        processor: image processor returning a ``pixel_values`` tensor.
        samples_df: DataFrame with a ``filename`` column.
        split: split name; selects the image sub-directory and labels the
            progress bar.
        images_dir: ``pathlib.Path`` of the root images directory.
        device: torch device the forward pass runs on.
        batch_size: images per forward pass.

    Returns:
        dict with ``"embeddings"`` (each batch's ``last_hidden_state`` moved to
        CPU and concatenated along dim 0, one entry per sample) and
        ``"filenames"`` (the filenames in the same order).
    """
    filenames = samples_df['filename'].tolist()
    all_embeddings = []

    for i in tqdm(range(0, len(filenames), batch_size), desc=f"{split}"):
        batch_images = []
        for fname in filenames[i:i + batch_size]:
            # `with` closes the file deterministically (Pillow keeps e.g.
            # multi-frame files open after loading); convert() returns a
            # fully loaded copy that stays valid after the file is closed.
            with Image.open(images_dir / split / fname) as img:
                batch_images.append(img.convert('RGB'))

        inputs = processor(images=batch_images, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(device)

        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)
        all_embeddings.append(outputs.last_hidden_state.cpu())

    if all_embeddings:
        embeddings = torch.cat(all_embeddings, dim=0)
    else:
        # torch.cat raises on an empty list (a split with no samples). Return
        # an empty tensor with the encoder's trailing (seq_len, hidden) dims:
        # one token per patch plus [CLS], i.e. 197 x 768 for ViT-Base/16 @ 224.
        config = model.config
        seq_len = (config.image_size // config.patch_size) ** 2 + 1
        embeddings = torch.empty(0, seq_len, config.hidden_size)
    return {"embeddings": embeddings, "filenames": filenames}

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}\n")

vit_model, vit_processor = load_vit_model(device)
selected_samples = select_all_samples(CSV_PATH, IMAGES_DIR)

# Per-split embedding tensors, plus a record of how they were produced
# that is saved alongside them.
vit_results = {}
vit_metadata = dict(
    model=VIT_MODEL_NAME,
    embedding_shape=[197, 768],
    batch_size=BATCH_SIZE,
    device=str(device),
    splits={},
)

for split in ('train', 'val', 'test'):
    print(f"Processing {split.upper()}...")
    samples_df = selected_samples[split]
    result = extract_vit_embeddings(
        model=vit_model,
        processor=vit_processor,
        samples_df=samples_df,
        split=split,
        images_dir=IMAGES_DIR,
        device=device,
        batch_size=BATCH_SIZE,
    )
    vit_results[split] = result["embeddings"]
    vit_metadata["splits"][split] = dict(
        num_samples=len(samples_df),
        embedding_shape=list(vit_results[split].shape),
        filenames=result["filenames"],
    )
    print(f"{split}: {vit_results[split].shape}")

print(f"\nSaving vision embeddings to {OUTPUT_VIT}...")
# Keys are written in the order train, val, test, metadata.
torch.save({**vit_results, "metadata": vit_metadata}, OUTPUT_VIT)

print("Vision embeddings saved\n")

# Free the vision model (and cached GPU memory) before the text encoder loads.
del vit_model, vit_processor
torch.cuda.empty_cache()


# TEXT ENCODER (Bio_ClinicalBERT)


print("="*60)
print("TEXT ENCODER - PROCESSING FULL DATASET")
print("="*60)
print()

TEXT_MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
MAX_LENGTH = 512
OUTPUT_TEXT = "text_embeddings_full.pt"

def load_text_model(device):
    """Load the tokenizer and encoder named by TEXT_MODEL_NAME.

    Args:
        device: torch device the encoder is moved onto.

    Returns:
        ``(model, tokenizer)`` with the model on ``device`` in eval mode.
    """
    print(f"Loading Bio_ClinicalBERT: {TEXT_MODEL_NAME}")
    tok = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
    # eval() returns the module itself, so load, placement and mode chain.
    encoder = AutoModel.from_pretrained(TEXT_MODEL_NAME).to(device).eval()
    print(f"Model loaded on {device}\n")
    return encoder, tok

def extract_text_embeddings(model, tokenizer, samples_df, device, batch_size, max_length):
    """Encode each report in ``samples_df`` and keep its first-token ([CLS]) vector.

    Args:
        model: transformer encoder, already on ``device`` and in eval mode.
        tokenizer: tokenizer matching ``model``.
        samples_df: DataFrame with ``full_report`` and ``filename`` columns.
        device: torch device the forward pass runs on.
        batch_size: reports per forward pass.
        max_length: reports are padded per batch and truncated to this many
            tokens.

    Returns:
        dict with ``"embeddings"`` (CPU tensor of shape (num_samples, hidden)),
        ``"filenames"`` and ``"texts"`` (the encoded strings), all in row order.
    """
    # read_csv yields NaN (a float) for empty cells, which the tokenizer
    # rejects; encode missing reports as empty strings instead.
    texts = samples_df["full_report"].fillna("").astype(str).tolist()
    filenames = samples_df["filename"].tolist()
    all_embeddings = []

    for start in tqdm(range(0, len(texts), batch_size), desc="Processing"):
        inputs = tokenizer(
            texts[start:start + batch_size],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )

        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 holds the token the tokenizer prepends ([CLS] for BERT).
        all_embeddings.append(outputs.last_hidden_state[:, 0, :].cpu())

    if all_embeddings:
        embeddings = torch.cat(all_embeddings, dim=0)
    else:
        # torch.cat raises on an empty list (a split with no samples).
        embeddings = torch.empty(0, model.config.hidden_size)
    return {"embeddings": embeddings, "filenames": filenames, "texts": texts}

text_model, text_tokenizer = load_text_model(device)

# Per split: a list of {"filename", "cls_emb", "text"} records.
text_results = {}
text_metadata = {
    "model": TEXT_MODEL_NAME,
    "embedding_shape": [768],
    "max_length": MAX_LENGTH,
    "batch_size": BATCH_SIZE,
    "device": str(device),
    "splits": {}
}

for split in ["train", "val", "test"]:
    print(f"Processing {split.upper()}...")
    samples_df = selected_samples[split]

    result = extract_text_embeddings(
        text_model,
        text_tokenizer,
        samples_df,
        device,
        BATCH_SIZE,
        MAX_LENGTH
    )

    # zip keeps filename, embedding row and report text aligned; iterating
    # the 2-D embeddings tensor yields one row per sample.
    split_data = [
        {"filename": fname, "cls_emb": emb, "text": text}
        for fname, emb, text in zip(
            result["filenames"], result["embeddings"], result["texts"]
        )
    ]

    text_results[split] = split_data
    text_metadata["splits"][split] = {
        "num_samples": len(samples_df),
        # Taken from the actual tensor rather than hard-coded.
        "embedding_shape": list(result["embeddings"].shape[1:]),
    }
    # Same per-split shape report as the vision stage prints.
    print(f"{split}: {result['embeddings'].shape}")

print(f"\nSaving text embeddings to {OUTPUT_TEXT}...")
torch.save({
    "train": text_results["train"],
    "val": text_results["val"],
    "test": text_results["test"],
    "metadata": text_metadata
}, OUTPUT_TEXT)

print("Text embeddings saved\n")

# Free the text model (and cached GPU memory) now that encoding is done.
del text_model, text_tokenizer
torch.cuda.empty_cache()


# BUILD MULTIMODAL DATASET


print("="*60)
print("BUILDING MULTIMODAL DATASET")
print("="*60)
print()

OUTPUT_MULTIMODAL = "multimodal_dataset_full.pt"

def build_multimodal_full(vit_path, text_path, csv_df):
    """Join saved image and text embeddings with CSV metadata, split by split.

    Args:
        vit_path: file saved by the vision stage: per-split embedding tensors
            under ``"train"``/``"val"``/``"test"`` and ``"metadata"`` holding
            ``"model"`` and ``["splits"][split]["filenames"]`` in tensor order.
        text_path: file saved by the text stage: per-split lists of
            ``{"filename", "cls_emb", "text"}`` records and ``"metadata"``
            with ``"model"``.
        csv_df: DataFrame with ``split``, ``filename``, ``impression_final``,
            ``MeSH``, ``Problems`` and ``projection`` columns.

    Returns:
        ``(multimodal_data, metadata)``. ``multimodal_data[split]`` lists one
        dict per image that also has a text record, in vision order. CSV
        metadata comes from the first CSV row with the same split and
        filename; a missing row raises ``KeyError``.
    """
    print("Loading embeddings...")
    # Tensors were moved to CPU before saving; map_location keeps loading
    # working regardless of the current runtime's devices.
    vit_data = torch.load(vit_path, map_location="cpu")
    text_data = torch.load(text_path, map_location="cpu")
    print("Embeddings loaded\n")

    multimodal_data = {}
    metadata = {
        "creation_date": pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S"),
        "vision_model": vit_data["metadata"]["model"],
        "text_model": text_data["metadata"]["model"],
        "image_embedding_shape": [197, 768],
        "text_embedding_shape": [768],
        "splits": {}
    }

    for split in ["train", "val", "test"]:
        print(f"Processing {split.upper()}...")
        vit_emb = vit_data[split]
        vit_files = vit_data["metadata"]["splits"][split]["filenames"]
        text_dict = {sample["filename"]: sample for sample in text_data[split]}

        # Index this split's CSV rows by filename once, instead of scanning
        # the whole CSV per sample (O(n^2)). Restricting to the split means a
        # filename reused in another split's folder cannot pull the wrong
        # metadata; keep="first" matches the previous `.iloc[0]` choice.
        split_rows = (
            csv_df[csv_df["split"] == split]
            .drop_duplicates(subset="filename", keep="first")
            .set_index("filename")
        )

        combined = []
        for i, fname in enumerate(vit_files):
            t = text_dict.get(fname)
            if t is None:
                continue
            row = split_rows.loc[fname]

            combined.append({
                "filename": fname,
                "image_emb": vit_emb[i],
                "text_emb": t["cls_emb"],
                "impression": row["impression_final"],
                "full_report": t["text"],
                "mesh": row["MeSH"],
                "problems": row["Problems"],
                "projection": row["projection"]
            })

        multimodal_data[split] = combined
        metadata["splits"][split] = {
            "num_samples": len(combined),
            "image_embedding_shape": [197, 768],
            "text_embedding_shape": [768]
        }

    return multimodal_data, metadata

# The vision and text stages still hold their embeddings in memory
# (vit_results, text_results) although both are already on disk, and
# build_multimodal_full reloads them from disk. The ViT tensors (197 x 768
# values per image) dominate, so drop the in-memory copies first instead of
# holding two copies at peak. Rebinding (rather than `del`) is safe to re-run.
vit_results = None
text_results = None

csv_df = pd.read_csv(CSV_PATH)

multimodal_data, multimodal_metadata = build_multimodal_full(
    OUTPUT_VIT, OUTPUT_TEXT, csv_df
)

print(f"\nSaving multimodal dataset to {OUTPUT_MULTIMODAL}...")
torch.save({
    "train": multimodal_data["train"],
    "val": multimodal_data["val"],
    "test": multimodal_data["test"],
    "metadata": multimodal_metadata
}, OUTPUT_MULTIMODAL)

print("Multimodal dataset saved\n")

# SUMMARY AND SAVE TO DRIVE


print("="*60)
print("SUMMARY")
print("="*60)

total = 0
for split in ["train", "val", "test"]:
    n = len(multimodal_data[split])
    total += n
    print(f"{split.upper()}: {n} samples")

print(f"\nTotal samples: {total}\n")

print("Saving output files into Google Drive...")

SAVE_DIR = DATA_PATH / "embeddings_full"
SAVE_DIR.mkdir(parents=True, exist_ok=True)

import shutil
shutil.copy(OUTPUT_VIT, SAVE_DIR / OUTPUT_VIT)
shutil.copy(OUTPUT_TEXT, SAVE_DIR / OUTPUT_TEXT)
shutil.copy(OUTPUT_MULTIMODAL, SAVE_DIR / OUTPUT_MULTIMODAL)

print(f"Files saved to: {SAVE_DIR}")

print("="*60)
print("PROCESS COMPLETE")
print("="*60)


GOOGLE COLAB - MULTIMODAL ENCODER (FULL DATASET)

Installing required packages...
Packages installed

Libraries imported

DATA LOADING FROM GOOGLE DRIVE

Mounted at /content/drive
Using data from: /content/drive/MyDrive/processed data

Data structure verified
CSV: /content/drive/MyDrive/processed data/processed_data.csv
Images: /content/drive/MyDrive/processed data/images

VISION ENCODER - PROCESSING FULL DATASET

Using device: cuda

Loading ViT-Base: google/vit-base-patch16-224


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

Model loaded on cuda

Loading samples from CSV...
Total rows in CSV: 7466

TRAIN split
Available in CSV: 5223




Valid images: 5223

VAL split
Available in CSV: 1122




Valid images: 1122

TEST split
Available in CSV: 1121




Valid images: 1121

Total valid samples: 7466

Processing TRAIN...


train: 100%|██████████| 164/164 [33:46<00:00, 12.35s/it]


train: torch.Size([5223, 197, 768])
Processing VAL...


val: 100%|██████████| 36/36 [07:19<00:00, 12.21s/it]


val: torch.Size([1122, 197, 768])
Processing TEST...


test: 100%|██████████| 36/36 [06:49<00:00, 11.39s/it]


test: torch.Size([1121, 197, 768])

Saving vision embeddings to vit_embeddings_full.pt...
Vision embeddings saved

TEXT ENCODER - PROCESSING FULL DATASET

Loading Bio_ClinicalBERT: emilyalsentzer/Bio_ClinicalBERT


config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]