In [3]:
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import pytorch_lightning as pl
import open_clip
from torchvision import transforms
from pytorch_lightning.callbacks import ModelCheckpoint
from tqdm import tqdm


In [4]:
class BioCLIPDataset(Dataset):
    def __init__(self, csv_path, tokenizer, transform):
        self.data = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_path = self.data.iloc[idx]['image_path']
        text = self.data.iloc[idx]['text']

        image = self.transform(Image.open(image_path).convert("RGB"))
        token = self.tokenizer([text])[0]

        return image, token


In [5]:
class BioCLIPLightningModule(pl.LightningModule):
    def __init__(self, model_name='hf-hub:imageomics/bioclip', lr=5e-5):
        super().__init__()
        self.model, self.preprocess_train, _ = open_clip.create_model_and_transforms(model_name)
        self.tokenizer = open_clip.get_tokenizer(model_name)
        self.lr = lr

    def forward(self, images, texts):
        image_features = self.model.encode_image(images)
        text_features = self.model.encode_text(texts)
        return image_features, text_features

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, stage="train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, stage="val")

    def _shared_step(self, batch, stage):
        images, texts = batch
        image_features, text_features = self(images, texts)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logits_per_image = image_features @ text_features.T
        logits_per_text = text_features @ image_features.T
        labels = torch.arange(len(images), device=self.device)

        loss_i = nn.CrossEntropyLoss()(logits_per_image, labels)
        loss_t = nn.CrossEntropyLoss()(logits_per_text, labels)
        loss = (loss_i + loss_t) / 2

        self.log(f"{stage}_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)


In [5]:
# Paths
train_csv = "/home/dzimmerman2021/Documents/fathomnet/train/bioclip_train.csv"
val_csv = "/home/dzimmerman2021/Documents/fathomnet/train/bioclip_val.csv"

# Model
model_module = BioCLIPLightningModule()

# Datasets and Dataloaders
train_dataset = BioCLIPDataset(train_csv, model_module.tokenizer, model_module.preprocess_train)
val_dataset = BioCLIPDataset(val_csv, model_module.tokenizer, model_module.preprocess_train)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

# Checkpointing
checkpoint_callback = ModelCheckpoint(
    save_top_k=1,
    monitor="val_loss",
    mode="min",
    filename="bioclip-{epoch:02d}-{val_loss:.2f}",
    save_weights_only=True
)

# Trainer
trainer = pl.Trainer(
    max_epochs=10,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    log_every_n_steps=10,
    callbacks=[checkpoint_callback]
)

# Train
trainer.fit(model_module, train_loader, val_loader)


Trainer will use only 1 of 3 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=3)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
/home/dzimmerman2021/miniconda3/envs/fathmonet_bioclip/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/dzimmerman2021/miniconda3/envs/fathmonet_biocl ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more d

Epoch 9: 100%|██████████| 593/593 [03:03<00:00,  3.23it/s, v_num=0]        

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 593/593 [03:03<00:00,  3.23it/s, v_num=0]


In [9]:
test_csv = "/home/dzimmerman2021/Documents/fathomnet/train/bioclip_test_internal.csv"
test_dataset = BioCLIPDataset(test_csv, model_module.tokenizer, model_module.preprocess_train)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# Simple accuracy evaluation
def evaluate_accuracy(model_module, dataloader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model_module.model.to(device).eval()

    correct = total = 0

    with torch.no_grad():
        for imgs, texts in tqdm(dataloader, desc="Evaluating"):
            imgs = imgs.to(device)
            texts = texts.to(device)

            img_feat = model.encode_image(imgs)
            txt_feat = model.encode_text(texts)

            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)

            sims = img_feat @ txt_feat.T
            preds = sims.argmax(dim=1)
            labels = torch.arange(len(imgs), device=device)

            correct += (preds == labels).sum().item()
            total += len(imgs)

    acc = correct / total
    print(f"✅ Internal Test Accuracy: {acc:.4f}")


evaluate_accuracy(model_module, test_loader)


Evaluating:   0%|          | 0/75 [00:00<?, ?it/s]

Evaluating: 100%|██████████| 75/75 [00:08<00:00,  8.90it/s]

✅ Internal Test Accuracy: 0.5937





In [6]:
import os
import torch
from PIL import Image
import pandas as pd
from tqdm import tqdm
import open_clip

# Recreate the module and load weights
model_module = BioCLIPLightningModule.load_from_checkpoint(
    "/home/dzimmerman2021/Documents/fathomnet/lightning_logs/version_0/checkpoints/bioclip-epoch=08-val_loss=2.66.ckpt",
    model_name='hf-hub:imageomics/bioclip'
)
model_module.eval()

# Load BioCLIP model
model = model_module.model
_, preprocess_val, _ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')
tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip')
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

# Load species prompts
species_names = [
        "Abyssocucumis abyssorum", "Acanthascinae", "Acanthoptilum", "Actinernus", "Actiniaria", "Actinopterygii", "Amphipoda", "Apostichopus leukothele", "Asbestopluma", "Asbestopluma monticola", "Asteroidea", "Benthocodon pedunculata", "Brisingida", "Caridea", "Ceriantharia", "Chionoecetes tanneri", "Chorilia longipes", "Corallimorphus pilatus", "Crinoidea", "Delectopecten", "Elpidia", "Farrea", "Florometra serratissima", "Funiculina", "Gastropoda", "Gersemia juliepackardae", "Heterocarpus", "Heterochone calyx", "Heteropolypus ritteri", "Hexactinellida", "Hippasteria", "Holothuroidea", "Hormathiidae", "Isidella tentaculum", "Isididae", "Isosicyonis", "Keratoisis", "Liponema brevicorne", "Lithodidae", "Mediaster aequalis", "Merluccius productus", "Metridium farcimen", "Microstomus pacificus", "Munidopsis", "Munnopsidae", "Mycale", "Octopus rubescens", "Ophiacanthidae", "Ophiuroidea", "Paelopatides confundens", "Pandalus amplus", "Pandalus platyceros", "Pannychia moseleyi", "Paragorgia", "Paragorgia arborea", "Paralomis multispina", "Parastenella", "Peniagone", "Pennatula phosphorea", "Porifera", "Psathyrometra fragilis", "Psolus squamatus", "Ptychogastria polaris", "Pyrosoma atlanticum", "Rathbunaster californicus", "Scleractinia", "Scotoplanes", "Scotoplanes globosa", "Sebastes", "Sebastes diploproa", "Sebastolobus", "Serpulidae", "Staurocalyptus", "Strongylocentrotus fragilis", "Terebellidae", "Tunicata", "Umbellula", "Vesicomyidae", "Zoantharia"
    ] 
species_prompts = [f"a photo of {s}" for s in species_names]
device = next(model.parameters()).device
text_tokens = tokenizer(species_prompts).to(device)
with torch.no_grad():
    text_features = model.encode_text(text_tokens)
    text_features /= text_features.norm(dim=-1, keepdim=True)

# Load test annotations
ann_df = pd.read_csv("test/annotations.csv")  # update path
image_dir = "test/rois"  # update if needed

# Run inference
results = []
for idx, row in tqdm(ann_df.iterrows(), total=len(ann_df), desc="BioCLIP Test Inference"):
    image_path = os.path.join(image_dir, row["path"])
    image = preprocess_val(Image.open(image_path).convert("RGB")).unsqueeze(0).to(next(model.parameters()).device)


    with torch.no_grad():
        image_feature = model.encode_image(image)
        image_feature /= image_feature.norm(dim=-1, keepdim=True)
        similarity = image_feature @ text_features.T
        pred_idx = similarity.argmax().item()
        pred_name = species_names[pred_idx]

    results.append((idx + 1, pred_name))  # annotation_id starts at 1

# Save results
submission = pd.DataFrame(results, columns=["annotation_id", "concept_name"])
submission.to_csv("bioclip_submission.csv", index=False)
print("Submission saved to bioclip_submission.csv")


  checkpoint = torch.load(checkpoint_path, map_location=map_location)
BioCLIP Test Inference: 100%|██████████| 788/788 [00:05<00:00, 141.93it/s]

Submission saved to bioclip_submission.csv



