In [None]:
from __future__ import annotations

import os 
import torch 
import numpy as np 
import pandas as pd

from math import ceil
from PIL import Image
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, log_loss, roc_auc_score, root_mean_squared_error

from mmpfn.models.mmpfn_v2 import MMPFNClassifier
from mmpfn.models.dino_v2.models.vision_transformer import vit_base
from mmpfn.models.mmpfn_v2.constants import ModelInterfaceConfig
from mmpfn.models.mmpfn_v2.preprocessing import PreprocessorConfig
from mmpfn.scripts_finetune_mm.finetune_tabpfn_main import fine_tune_tabpfn



In [None]:
# --- Column configuration ----------------------------------------------------
# Tabular feature columns, in the order used to derive `cat_features_index`.
col_features = ["Age","Breed1","Breed2","Color1","Color2","Color3","Dewormed","Fee","FurLength","Gender","Health","MaturitySize","PhotoAmt","State","Sterilized","Type","Vaccinated","VideoAmt","Quantity",]
# Columns that are not tabular features. The rescuer column in the PetFinder
# train.csv is spelled "RescuerID" (the previous "RescureID" matched nothing).
col_exclude = ["PetID", "RescuerID", "Description", "Name"]
col_target = "AdoptionSpeed"
# Subset of `col_features` treated as categorical, plus their positions
# inside `col_features`.
cat_features = ["Breed1","Breed2","Color1","Color2","Color3","Dewormed","FurLength","Gender","Health","MaturitySize","State","Sterilized","Type","Vaccinated",]
cat_features_index = [col_features.index(feature) for feature in cat_features]

# --- Load the tabular data ---------------------------------------------------
# Define the dataset root first so every path below derives from one place.
datasets_dir = "datasets/petfinder-adoption-prediction"
train = pd.read_csv(os.path.join(datasets_dir, "train", "train.csv"))

# --- Keep only pets that have a first image ("<PetID>-1.jpg") ----------------
train["PetID"] = train["PetID"].astype(str)
# Set membership avoids scanning the whole PetID column once per image file.
known_pet_ids = set(train["PetID"])
train_images = [f for f in os.listdir(os.path.join(datasets_dir, "train_images")) if f.endswith(".jpg")]
train_images = [f for f in train_images if f.split("-")[0] in known_pet_ids]
# File names are parsed as "<PetID>-<ImageNumber>.jpg".
train_images_df = pd.DataFrame(
    {
        "PetID": [f.split("-")[0] for f in train_images],
        "ImageNumber": [f.split("-")[1].split(".")[0] for f in train_images],
    }
)
train_images_df = train_images_df[train_images_df["ImageNumber"] == "1"]
# Left merge + notna filter drops pets without a "-1" image while keeping the
# post-merge row index; `.copy()` makes the result an independent frame so the
# column assignment below does not trigger SettingWithCopyWarning.
train = train.merge(train_images_df, on="PetID", how="left")
train = train[train["ImageNumber"].notna()].copy()
train["ImagePath"] = train["PetID"] + "-1.jpg"

In [None]:
# Cache files for the image-encoder outputs: patch-token embeddings and
# CLS-token embeddings respectively.
path_patch, path_cls = 'adoption_patch.pt', 'adoption_cls.pt'

In [None]:
# Compute image embeddings for every row of `train` and cache them on disk.
# The work is skipped when the CLS cache file already exists.
if os.path.exists(path_cls):
    # Cache present: nothing is computed and nothing is loaded in this cell
    # (loading via torch.load(path_patch) / torch.load(path_cls) is left to
    # whoever needs the tensors).
    # NOTE(review): only `path_cls` is checked; `path_patch` is assumed to
    # exist alongside it. The save order below writes `path_cls` last so that
    # its presence implies the patch file was written too.
    pass
else:
    # Side length fed to the encoder; a multiple of the encoder's patch_size (14).
    img_size = 14*24
    image_dir = os.path.join(datasets_dir, "train_images")

    # --- Load and resize images ----------------------------------------------
    X_image = []
    skipped_paths = []
    for path in train['ImagePath']:
        full_path = os.path.join(image_dir, path)
        if not os.path.exists(full_path):
            # Skip just this image, as the message says (previously `break`
            # silently dropped every remaining image as well).
            print(f"Image {full_path} does not exist, skipping.")
            skipped_paths.append(full_path)
            continue
        with Image.open(full_path) as img:
            # Force 3-channel RGB for every non-RGB mode (grayscale, CMYK,
            # palette, ...) so all arrays stack into one (N, H, W, 3) array.
            if img.mode != "RGB":
                img = img.convert("RGB")
            image = np.array(img.resize((img_size, img_size), Image.BILINEAR))
        X_image.append(image)
    if skipped_paths:
        print(
            f"WARNING: {len(skipped_paths)} image(s) were missing; the saved "
            "embeddings no longer line up row-for-row with `train`."
        )
    X_image = np.array(X_image)

    # --- Build the image encoder ---------------------------------------------
    # NOTE(review): the encoder is constructed with img_size=518 while inputs
    # are img_size x img_size (336); presumably the model adapts its position
    # embeddings to the input size — confirm against the vit_base implementation.
    image_encoder = vit_base(
        patch_size=14, img_size=518, init_values=1.0, num_register_tokens=0, block_chunks=0
    )

    image_model_path = f"{Path().absolute()}/parameters/dinov2_vitb14_pretrain.pth"
    # torch.load unpickles the file: only use checkpoints from a trusted source.
    # map_location="cpu" keeps loading independent of the device the checkpoint
    # was saved from; the model is moved to the GPU right after.
    image_state_dict = torch.load(image_model_path, map_location="cpu")
    image_encoder.load_state_dict(image_state_dict)
    _ = image_encoder.cuda().eval()

    # --- Encode in batches ---------------------------------------------------
    batch_size = 16
    adoption_patch, adoption_cls = [], []

    # Keep the full image stack as uint8 in (N, C, H, W) layout and convert to
    # float only per batch; this yields the same values as converting the whole
    # stack up front while using a quarter of the host memory.
    X_image_torch = torch.from_numpy(np.transpose(X_image, (0,3,1,2)))

    # NOTE(review): pixels are only scaled to [0, 1]; no mean/std normalisation
    # is applied before the encoder — confirm this is intended.
    with torch.no_grad():
        for start in range(0, X_image_torch.shape[0], batch_size):
            batch = X_image_torch[start:start+batch_size].to("cuda", non_blocking=True)
            batch = batch.float() / 255.0
            feats = image_encoder.forward_features(batch)
            adoption_patch.append(feats['x_norm_patchtokens'].detach().cpu())
            adoption_cls.append(feats['x_norm_clstoken'].detach().cpu())

    torch.cuda.empty_cache()

    # Both outputs are saved as Python lists of per-batch CPU tensors (not
    # concatenated). The patch file is written first and the CLS file last,
    # because the CLS file is what the guard at the top of this cell checks.
    torch.save(adoption_patch, path_patch)
    torch.save(adoption_cls, path_cls)
