In [None]:
# BIOSCAN

import torch
from torchvision.transforms import v2 as transforms
from bioscan_dataset import BIOSCAN5M
from bioscan_dataset.bioscan5m import RGB_MEAN, RGB_STDEV

# Image preprocessing, applied in the order listed:
#   1. centre-crop to 256 pixels,
#   2. convert to a torchvision image tensor,
#   3. cast to float32 with value scaling switched on (scale=True),
#   4. normalise with the RGB_MEAN / RGB_STDEV constants imported from bioscan_dataset.
_image_steps = [
    transforms.CenterCrop(256),
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=RGB_MEAN, std=RGB_STDEV),
]
image_transform = transforms.Compose(_image_steps)
# DNA transform: map nucleotide characters to integer tokens and force a fixed length
# so that every encoded barcode has the same shape.
charmap = {"P": 0, "A": 1, "C": 2, "G": 3, "T": 4, "N": 5}


def dna_transform(seq, max_len=660):
    """Encode a DNA barcode string as a fixed-length ``torch.long`` tensor.

    Args:
        seq: Barcode as a string of nucleotide characters.
        max_len: Length of the returned tensor. Longer sequences are truncated,
            shorter ones are right-padded with the "P" token (0).

    Returns:
        1-D ``torch.long`` tensor with exactly ``max_len`` elements.

    Characters are looked up case-insensitively. Anything that is not a key of
    ``charmap`` (e.g. IUPAC ambiguity codes) is encoded as "N" instead of raising
    ``KeyError``.
    """
    unknown_token = charmap["N"]
    # Upper-case per character: str.upper() on the whole string may change its length.
    tokens = [charmap.get(char.upper(), unknown_token) for char in seq[:max_len]]
    tokens.extend([charmap["P"]] * (max_len - len(tokens)))
    return torch.tensor(tokens, dtype=torch.long)

# Instantiate BIOSCAN-5M rooted at the current directory, passing the image and DNA
# transforms to the dataset; download=True presumably fetches any missing files.
# NOTE(review): the variable is named ds_train but split="val" is requested —
# confirm which partition is actually intended.
ds_train = BIOSCAN5M(
    root=".",
    split="val",
    transform=image_transform,
    dna_transform=dna_transform,
    download=True,
)

In [None]:
# WIT

import itertools
import re
import numpy as np
import pandas as pd
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_mutual_info_score, mean_absolute_error, mean_squared_error
from sklearn.neighbors import KNeighborsRegressor
from scipy.optimize import linear_sum_assignment
from nltk.translate.meteor_score import meteor_score
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from sklearn.cluster import DBSCAN
from sklearn.metrics import silhouette_score, davies_bouldin_score
import matplotlib.pyplot as plt
from sacrebleu.metrics import CHRF
from k_means_constrained import KMeansConstrained

# Number of examples to take from the head of the stream.
n_samples = 100000

# 1. Stream the WIT dataset (avoid full download)
ds_stream = load_dataset("wikimedia/wit_base", split="train", streaming=True)

# 2. Take first `n_samples` examples from the stream
# islice stops after n_samples items; list() materialises all of them in memory.
subset = list(itertools.islice(ds_stream, n_samples))

# The English part ends at the next language label or at the end of the string.
# A label is a single whitespace-free token directly followed by a colon and then
# whitespace/end (e.g. "Français:", "Русский:"). Its first character must be a letter
# other than ASCII lowercase, so ordinary words ("note:"), times ("10:30") and URLs
# ("http://...") are not mistaken for labels. DOTALL lets captions span several lines.
_ENGLISH_CAPTION_RE = re.compile(
    r"English:\s*(.*?)(?=\s+[^\W\d_a-z][^\s:]*:(?:\s|\Z)|\s*\Z)",
    re.DOTALL,
)


def extract_english_caption(s: str) -> str:
    """
    Extracts the substring following 'English:' up to the next '<Lang>:' or end of string.
    Returns an empty string if no English part is found (or if `s` is empty/None).

    The previous pattern accepted *any* capital letter followed by *any* later colon
    as the start of the next label, so "English: The Eiffel Tower Français: ..."
    yielded an empty caption. Here a label must be one token immediately before the
    colon.

    Known limitation: multi-word labels such as "Norsk bokmål:" are not recognised
    as terminators.
    """
    if not s:
        return ""
    match = _ENGLISH_CAPTION_RE.search(s)
    return match.group(1).strip() if match else ""


# 3. Assemble the working frame from the streamed items:
#    "x" <- "embedding", "y" <- "caption_attribution_description", "raw_f" <- "wit_features".
df = pd.DataFrame({
    column: [item.get(field) for item in subset]
    for column, field in (
        ("x", "embedding"),
        ("y", "caption_attribution_description"),
        ("raw_f", "wit_features"),
    )
})

# 4. Keep rows that have both an embedding and a caption, reduce every caption to its
#    English part, then discard rows whose English part came back empty.
df = df.dropna(subset=["x", "y"]).reset_index(drop=True)
df["y"] = df["y"].apply(extract_english_caption)
has_english_caption = df["y"].astype(bool)
df = df[has_english_caption].reset_index(drop=True)

# 5. Extract only the English page_title into "z"
def extract_en_title(wf):
    """Return the page title paired with language "en", or None if there is none.

    Args:
        wf: Mapping with parallel lists under "language" and "page_title".
            May also be None / a non-mapping value for rows without features.

    Returns:
        The title at the position of the first "en" language entry, or None when
        `wf` is missing, has no "en" entry, or has no title at that position.
    """
    # Rows whose features are missing (None/NaN) have no title to offer.
    if not hasattr(wf, "get"):
        return None
    # `or []` also covers keys that are present but hold None.
    langs = wf.get("language") or []
    titles = wf.get("page_title") or []
    # zip stops at the shorter list, so mismatched lengths cannot raise IndexError.
    for lang, title in zip(langs, titles):
        if lang == "en":
            return title
    return None

# Attach the English page title as "z" and drop rows that have none.
df["z"] = df["raw_f"].apply(extract_en_title)
df = df.dropna(subset=["z"]).reset_index(drop=True)

# 6./7. Embed the captions ("y" -> "yv") and the English titles ("z" -> "zv") with the
#       same Sentence-BERT encoder and identical encoding settings.
model = SentenceTransformer('all-MiniLM-L6-v2')  # 384-dim encoder
for text_col, vec_col in (("y", "yv"), ("z", "zv")):
    vectors = model.encode(
        df[text_col].tolist(),
        batch_size=32,
        show_progress_bar=True,
        convert_to_numpy=True,
    )
    # Store one plain Python list per row.
    df[vec_col] = vectors.tolist()

In [None]:
# Flickr30k

import io
import torch
from PIL import Image
from datasets import load_dataset
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from sentence_transformers import SentenceTransformer
from torchvision.models import (
    resnet50, ResNet50_Weights,
    vit_b_16, ViT_B_16_Weights,
    efficientnet_b0, EfficientNet_B0_Weights
)

# 1) Load only the "test" split
ds = load_dataset("nlphuji/flickr30k", split="test")

# 2) Prepare models & transforms
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Alternative backbone, kept for reference:
#img_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).to(device).eval()

# .eval() is required here: in training mode BatchNorm normalises with per-batch
# statistics (and updates its running averages) and dropout / stochastic depth stay
# active, so the extracted features would depend on batch composition and differ
# between runs. torch.no_grad() alone does not switch these layers to inference mode.
img_model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1).to(device).eval()

# Drop the last child module (the classifier head in torchvision's EfficientNet) so the
# network returns pooled features instead of class logits.
feature_extractor = torch.nn.Sequential(*list(img_model.children())[:-1]).eval()

# Resize / centre-crop to 224x224 and normalise with the ImageNet channel statistics.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Sentence encoder used for the captions, placed on the same device as the image model.
sbert = SentenceTransformer("all-MiniLM-L6-v2", device=device)

# 3) Batch-encode function, handling both PIL images and raw bytes
def encode_batch(batch):
    """Add image and caption embeddings to one batch of examples.

    Args:
        batch: Column dict with an "image" list (PIL images or raw encoded bytes)
            and a "caption" list (one group of captions per image).

    Returns:
        The same dict with two extra keys: "image_emb" (array with one flattened
        feature row per image) and "caption_embs" (list with one array of caption
        embeddings per image).
    """
    def _as_rgb(item):
        # Raw bytes are decoded first; anything else is taken to be a PIL image already.
        if isinstance(item, bytes):
            item = Image.open(io.BytesIO(item))
        return item.convert("RGB")

    rgb_images = [_as_rgb(item) for item in batch["image"]]
    pixels = torch.stack([preprocess(image) for image in rgb_images]).to(device)

    # Inference only: no gradients are tracked while extracting features.
    with torch.no_grad():
        flat_feats = feature_extractor(pixels).view(len(rgb_images), -1)
    batch["image_emb"] = flat_feats.cpu().numpy()

    # One sbert call per image: each call encodes that image's whole caption group.
    caption_embs = []
    for caps in batch["caption"]:
        caption_embs.append(sbert.encode(caps, convert_to_numpy=True))
    batch["caption_embs"] = caption_embs

    return batch

# Run encode_batch over the dataset in batches of 32 examples; the "image" column is
# removed from the mapped result.
ds = ds.map(
    encode_batch,
    batched=True,
    batch_size=32,
    remove_columns=["image"],
)

# Write the mapped dataset to a Parquet file in the working directory.
ds.to_parquet("flickr30k_eff.parquet")

In [None]:
# Coco

import os
import random
import json
import pandas as pd
import torch
from PIL import Image
from pycocotools.coco import COCO
from sentence_transformers import SentenceTransformer
import torchvision.transforms as T
import torchvision.models as models

# Settings
# dataType selects which annotation files are read; imgDir is where the image files
# are looked up; output_csv is the file the final table is written to.
dataType    = 'val2017'
imgDir      = 'val2017'  # adjust to your local image folder
capAnnFile  = f'annotations/captions_{dataType}.json'
instAnnFile = f'annotations/instances_{dataType}.json'
output_csv  = 'coco_2.csv'

# Load COCO APIs
# Two separate indexes: one over the captions file, one over the instances file.
coco_caps = COCO(capAnnFile)
coco_inst = COCO(instAnnFile)

# Pick up to 45 categories, visited in a seeded random order, that each have at least
# 40 "single-category" images, i.e. images whose instance annotations all carry that
# one category id.
all_cats = coco_inst.loadCats(coco_inst.getCatIds())
random.seed(42)
random.shuffle(all_cats)

selected_cats = []
for category in all_cats:
    # Stop as soon as the quota of categories is filled.
    if len(selected_cats) == 45:
        break

    category_id = category['id']

    # Keep an image only when the set of category ids annotated on it is exactly {category_id}.
    exclusive_img_ids = [
        candidate_id
        for candidate_id in coco_inst.getImgIds(catIds=[category_id])
        if {
            ann['category_id']
            for ann in coco_inst.loadAnns(coco_inst.getAnnIds(imgIds=[candidate_id]))
        } == {category_id}
    ]

    # Categories with too few such images are skipped.
    if len(exclusive_img_ids) < 40:
        continue

    selected_cats.append({
        'id': category_id,
        'name': category['name'],
        'img_ids': exclusive_img_ids,
    })

# Prepare models & transforms
device    = 'cuda' if torch.cuda.is_available() else 'cpu'
# `pretrained=True` is deprecated in torchvision; the explicit weights enum below loads
# the same ImageNet checkpoint that `pretrained=True` resolved to.
resnet    = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Replace the classification layer with a pass-through so the model outputs the
# pooled feature vector instead of class logits.
resnet.fc = torch.nn.Identity()
# eval() puts BatchNorm into inference mode for deterministic features.
resnet    = resnet.to(device).eval()
# Resize straight to 224x224 (no cropping) and normalise with ImageNet statistics.
transform = T.Compose([
    T.Resize((224,224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
# Sentence encoder for the captions, on the same device as the image model.
sbert = SentenceTransformer('all-MiniLM-L6-v2', device=device)

# Build records: one row per (image, caption) pair, sampled down to at most
# `rows_per_cat` rows for each selected category.
records = []
rows_per_cat = 200
for cat in selected_cats:
    cat_name      = cat['name']
    valid_img_ids = cat['img_ids']

    # Collect all caption-rows for this category
    cat_rows = []
    for img_id in valid_img_ids:
        # encode image (file name is the image id zero-padded to 12 digits)
        img_path = os.path.join(imgDir, f"{img_id:012d}.jpg")
        # The context manager closes the file handle; convert() returns an
        # independent in-memory copy, so `img` stays usable afterwards.
        with Image.open(img_path) as img_file:
            img = img_file.convert('RGB')
        with torch.no_grad():
            x_t    = transform(img).unsqueeze(0).to(device)
            x_feat = resnet(x_t).squeeze(0).cpu().tolist()
        # Serialise the image feature once; it is shared by every caption row of this image.
        x_json = json.dumps(x_feat)

        # explode by caption
        cap_anns = coco_caps.loadAnns(coco_caps.getAnnIds(imgIds=[img_id]))
        for ann in cap_anns:
            caption = ann['caption']
            yv = sbert.encode(caption).tolist()
            cat_rows.append({
                'cat': cat_name,
                'image_id': img_id,
                'x'  : x_json,
                'y'  : caption,
                'yv' : json.dumps(yv),
            })

    # Down-sample to rows_per_cat rows. random.sample raises ValueError when asked for
    # more items than exist, so cap the request at the number of available rows.
    sampled_rows = random.sample(cat_rows, min(rows_per_cat, len(cat_rows)))
    records.extend(sampled_rows)

# Save to CSV
# "x" and "yv" already hold JSON-encoded strings, so they are written as plain text
# cells; index=False leaves the row index out of the file.
df = pd.DataFrame.from_records(records)
df.to_csv(output_csv, index=False)
print(f"Saved {len(df)} rows to {output_csv}")