In [None]:
import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from PIL import Image
import models_vit as models
from huggingface_hub import hf_hub_download
# One shared seed so NumPy and PyTorch random streams are reproducible across runs.
SEED = 1

# Never abbreviate printed arrays with "..." regardless of their size.
np.set_printoptions(threshold=np.inf)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [19]:
def prepare_model(chkpt_dir, arch='vit_large_patch16'):
    """Build the requested backbone and load checkpoint weights into it.

    Args:
        chkpt_dir: Path of the checkpoint file handed to ``torch.load``.
        arch: Name looked up in ``models.__dict__`` to obtain the constructor.
            ``'vit_large_patch16'`` reads the weights stored under the
            checkpoint's ``'model'`` key; any other name reads ``'teacher'``.

    Returns:
        The constructed model after a non-strict ``load_state_dict``.
    """
    # Deserialize on the CPU; the caller decides which device the model goes to.
    # NOTE(review): torch.load unpickles the file, so only use trusted checkpoints.
    state = torch.load(chkpt_dir, map_location='cpu')

    if arch != 'vit_large_patch16':
        net = models.__dict__[arch](
            num_classes=5,
            drop_path_rate=0,
            args=None,
        )
        # strict=False: key mismatches between checkpoint and model do not raise.
        net.load_state_dict(state['teacher'], strict=False)
        return net

    net = models.__dict__[arch](
        img_size=224,
        num_classes=5,
        drop_path_rate=0,
        global_pool=True,
    )
    # strict=False: key mismatches between checkpoint and model do not raise.
    net.load_state_dict(state['model'], strict=False)
    return net

def run_one_image(img, model, arch):
    """Encode a single preprocessed image into a latent feature tensor.

    Args:
        img: Array-like image in height x width x channel layout; the batch
            axis is added here.
        model: Module exposing ``forward_features``. The input is moved to the
            device of the model's first parameter (CPU if it has none).
        arch: Architecture name. For ``'dinov2_large'`` the output is averaged
            over dim 1 (skipping index 0) and layer-normalized.

    Returns:
        The feature tensor with singleton dimensions squeezed out. It is
        computed under ``torch.no_grad()`` and therefore carries no autograd
        graph.
    """
    # Follow the model's own device instead of a notebook-level global, so the
    # function works no matter which cells ran before it.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else torch.device('cpu')

    # Cast to float32 before the transfer so only half as many bytes are copied.
    x = torch.as_tensor(img, dtype=torch.float32)
    x = x.unsqueeze(dim=0)
    x = x.permute(0, 3, 1, 2)  # nhwc -> nchw
    x = x.to(target, non_blocking=True)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        latent = model.forward_features(x)

        if arch == 'dinov2_large':
            # Average every token except the first one along dim 1
            # (presumably the class token -- confirm against models_vit).
            latent = latent[:, 1:, :].mean(dim=1, keepdim=True)
            # Affine-free layer norm: equivalent to a freshly initialised
            # nn.LayerNorm (weight=1, bias=0) without allocating a module per call.
            latent = nn.functional.layer_norm(latent, (latent.shape[-1],), eps=1e-6)

    return torch.squeeze(latent)


In [20]:
def _load_normalized_image(path):
    """Read one image file into a (224, 224, 3) float array standardized per channel."""
    # The context manager closes the file handle; convert('RGB') makes
    # grayscale / RGBA / palette files yield three channels as well.
    with Image.open(path) as handle:
        img = handle.convert('RGB').resize((224, 224))
    img = np.array(img) / 255.
    for c in range(3):
        channel = img[..., c]
        spread = channel.std()
        # A constant channel has zero std; divide by 1 so it becomes all zeros
        # instead of NaN.
        img[..., c] = (channel - channel.mean()) / (spread if spread > 0 else 1.0)
    assert img.shape == (224, 224, 3)
    return img


def get_feature(data_path,
                chkpt_dir,
                device,
                arch='vit_large_patch16'):
    """Extract one latent feature vector per image file in a directory.

    Args:
        data_path: Directory whose regular files are all opened as images.
        chkpt_dir: Checkpoint path forwarded to ``prepare_model``.
        device: Torch device the model is moved to.
        arch: Architecture name forwarded to ``prepare_model`` and
            ``run_one_image``.

    Returns:
        ``[name_list, feature_list]``: file names in sorted order and the
        matching features as NumPy arrays, index-aligned.
    """
    # Load the model and switch it to inference mode.
    model_ = prepare_model(chkpt_dir, arch)
    model_.to(device)
    model_.eval()

    # Sort for a reproducible row order and skip sub-directories
    # (e.g. .ipynb_checkpoints), which Image.open cannot read.
    img_list = sorted(
        entry for entry in os.listdir(data_path)
        if os.path.isfile(os.path.join(data_path, entry))
    )

    name_list = []
    feature_list = []

    # Inference only: no autograd graph is needed for feature extraction.
    with torch.no_grad():
        for finished_num, i in enumerate(img_list, start=1):
            if finished_num % 1000 == 0:
                print(str(finished_num)+"finished")

            img = _load_normalized_image(os.path.join(data_path, i))
            latent_feature = run_one_image(img, model_, arch)

            name_list.append(i)
            feature_list.append(latent_feature.detach().cpu().numpy())

    return [name_list, feature_list]



In [None]:
# Download (or reuse the cached copy of) the released checkpoint; the returned
# value is passed to torch.load by prepare_model.
chkpt_dir = hf_hub_download(repo_id="YukunZhou/RETFound_dinov2_meh", filename="RETFound_dinov2_meh.pth")
# Placeholder: replace with the directory that contains the images to encode.
data_path = 'DATA_PATH'
# Prefer the GPU, but fall back to the CPU so the notebook also runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Must correspond to the checkpoint above; any value other than
# 'vit_large_patch16' makes prepare_model read the 'teacher' weights.
arch='dinov2_large'

In [None]:
# Encode every image under data_path; `feature` is index-aligned with `name_list`.
name_list, feature = get_feature(data_path, chkpt_dir, device, arch=arch)

In [23]:
# Save the features: one row per image, the file name first, then one column
# per feature dimension.
df_feature = pd.DataFrame(feature)
df_imgname = pd.DataFrame(name_list)
df_visualization = pd.concat([df_imgname,df_feature], axis=1)

# Take the feature width from the data instead of hard-coding 1024, so other
# embedding sizes (or an empty result) do not break the column assignment.
column_name_list = ["feature_{}".format(i) for i in range(df_feature.shape[1])]
# An empty name_list yields a frame with no columns; only label what exists.
df_visualization.columns = ["name"][:df_imgname.shape[1]] + column_name_list
df_visualization.to_csv("Feature.csv",index=False)