In [1]:
import torch
import json
import torch.nn as nn
import numpy as np
from math import floor
from tqdm import tqdm

from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer, DefaultDataCollator
from datasets import load_dataset, Dataset

In [2]:
device = "cuda" if torch.cuda.is_available else "cpu"
class GlowViT(ViTForImageClassification):
    def help():
        print(ViTForImageClassification.__doc__)

In [3]:
yalu_ds_list = [("SeaSponge/wildme10_classify", "glow-vit"),
                ("yin30lei/wildlife_very_dark", "glow-vit-dark"),
                ("yin30lei/wildlife_well_illuminated", "glow-vit-illuminate"),
                ("yin30lei/wildlife_mixed", "glow-vit-mix")]
yalu_ds = yalu_ds_list[1][0]
yalu_model_name = yalu_ds_list[1][1]
wildlife_test_ds = load_dataset(yalu_ds, cache_dir=Path.cwd() / "yalu_dataset", split="test" , num_proc=2)

In [4]:
# 11 labels for wildme10_classify
num_labels = 11
print(wildlife_test_ds[0])

{'file_name': '40000625.jpg', 'image_id': 1009, 'width': 640, 'height': 426, 'image': <PIL.Image.Image image mode=RGB size=640x426 at 0x1BCC2E17CA0>, 'labels': 'hare'}


In [5]:
label2id = {"lion": 0, "raccoon": 1, "tiger": 2, "wolf": 3,
            "bear": 4, "hare": 5, "fox": 6, "deer": 7, 
            "leopard" : 8, "hyena": 9, "antelope": 10}


id2label = {v:k for k, v in label2id.items()}

In [6]:
yalu_checkpoint = f"SeaSponge/{yalu_model_name}"
model = GlowViT.from_pretrained(yalu_checkpoint,
                cache_dir=yalu_model_name,
                label2id=label2id,
                id2label=id2label,
                num_labels=num_labels,
                attn_implementation="sdpa") # no flash attention yet for ViT model

image_processor = ViTImageProcessor.from_pretrained(yalu_checkpoint)

config.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/343M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/325 [00:00<?, ?B/s]

In [7]:
def transform_ds(examples):
    images, labels = [], []
    for image, label in zip(examples["image"], examples["labels"]):
        pix_val = image_processor(images=image.convert("RGB"), return_tensors="pt")["pixel_values"].squeeze(0)
        pix_val.to(device)
        #! supposed to be a number here
        label = label2id[label]
        images.append(pix_val)
        labels.append(label)

    return {"pixel_values": images, "labels": labels}

In [8]:
test_ds = wildlife_test_ds.with_transform(transform_ds)

In [9]:
print(test_ds[0])
# data_collator = DefaultDataCollator()

{'pixel_values': tensor([[[-0.9608, -0.9608, -0.9608,  ..., -0.9608, -0.9529, -0.9608],
         [-0.9608, -0.9608, -0.9608,  ..., -0.9686, -0.9608, -0.9686],
         [-0.9608, -0.9608, -0.9608,  ..., -0.9686, -0.9608, -0.9686],
         ...,
         [-0.9608, -0.9529, -0.9451,  ..., -0.9216, -0.9216, -0.9373],
         [-0.9608, -0.9529, -0.9451,  ..., -0.9216, -0.9137, -0.9373],
         [-0.9529, -0.9529, -0.9529,  ..., -0.9137, -0.9059, -0.9216]],

        [[-0.9451, -0.9451, -0.9451,  ..., -0.9529, -0.9451, -0.9529],
         [-0.9451, -0.9451, -0.9451,  ..., -0.9608, -0.9529, -0.9608],
         [-0.9451, -0.9451, -0.9451,  ..., -0.9608, -0.9529, -0.9608],
         ...,
         [-0.9451, -0.9373, -0.9294,  ..., -0.9059, -0.9059, -0.9216],
         [-0.9451, -0.9373, -0.9294,  ..., -0.9059, -0.8980, -0.9216],
         [-0.9373, -0.9373, -0.9373,  ..., -0.8980, -0.8902, -0.9059]],

        [[-0.9529, -0.9529, -0.9529,  ..., -0.9451, -0.9294, -0.9373],
         [-0.9529, -0.9529, 

In [None]:
# Define the testing arguments

testing_args = TrainingArguments(
    output_dir=yalu_model_name,
    per_device_eval_batch_size=16,
    logging_steps=30,
    save_total_limit=1,
    remove_unused_columns=False,
    push_to_hub=True,
)

# Define the trainer

trainer = Trainer(
    model=model,
    args=testing_args,
    # data_collator=data_collator,
    eval_dataset=test_ds,
    tokenizer=image_processor,
)

  trainer = Trainer(


In [11]:
metrics = trainer.evaluate(test_ds)
trainer.log_metrics(f"{yalu_model_name} | test...", metrics)

  context_layer = torch.nn.functional.scaled_dot_product_attention(


  0%|          | 0/44 [00:00<?, ?it/s]

***** glow-vit-dark | test... metrics *****
  eval_loss                   =     0.5614
  eval_model_preparation_time =      0.002
  eval_runtime                = 0:00:09.52
  eval_samples_per_second     =     73.675
  eval_steps_per_second       =      4.618
