In [1]:
import os
from PIL import Image
import re
import torch
import torch.nn as nn
import vision_former as vits
from vision_former import DINOHead
from torchvision import  transforms
import pandas as pd
from tqdm import tqdm

In [2]:
# prepare the model and the learned prototype space!!! 
device = torch.device("cuda")
model = vits.__dict__['vit_small'](patch_size=16, num_classes=0)
state_dict = torch.load("checkpoint_student_teacher.pth", map_location="cpu") 
#https://huggingface.co/shenxiaochen/SongCi/blob/main/checkpoint_student_teacher.pth 
head = DINOHead(384,
        65536,
        use_bn=True,
        norm_last_layer=False,
        predictor=True,)

class MultiCropWrapper(nn.Module):
    def __init__(self, backbone, head, in_dim=384, hidden_dim=2048, nlayers=2, output_dim=256, nmb_prototypes=1024, teacher=False):
        super(MultiCropWrapper, self).__init__()
        backbone.fc, backbone.head = nn.Identity(), nn.Identity()
        self.teacher = teacher
        self.backbone = backbone
        self.head = head
        self.prototypes = nn.Linear(output_dim, nmb_prototypes, bias=False)
        layers = [nn.Linear(in_dim, hidden_dim)]
        layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
        for _ in range(nlayers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, output_dim))
        self.projection = nn.Sequential(*layers)

    def forward(self, x):
        x = self.backbone(x)
        x = self.projection(x)
        x = nn.functional.normalize(x, dim=1, p=2)
        x = self.prototypes(x)
        return x

student = MultiCropWrapper(model,head)
state_dict_new = state_dict["student"]
state_dict_new = {k.replace("module.", ""): v for k, v in state_dict_new.items()}
student.load_state_dict(state_dict_new,strict=True)

for p in student.parameters():
    p.requires_grad = False


normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
trans = transforms.Compose([transforms.Resize((224,224)),normalize])

student = student.to(device)
student.eval()
with torch.no_grad():
    w = student.prototypes.weight.data.clone()
    w = nn.functional.normalize(w, dim=1, p=2)
    student.prototypes.weight.copy_(w)

In [3]:
# Retrieve each patch and determine its corresponding prototype position.
data = pd.DataFrame({"name": [], "x_axis": [], "y_axis": [], "pro_index": [], "sim_value": []})
data_dir = "/home/ssddata1/patch_all/patch2/a21-1884_brain0/" # an example!
for patch in os.listdir(data_dir):

    ll = re.split("-|_", patch)
    pil_image = Image.open(data_dir+patch)
    pil_image = trans(pil_image)
    img = torch.unsqueeze(pil_image, 0)

    with torch.no_grad():
        ind = student(img.to(device))
        v, pro_ind = ind.max(1)


    data.loc[len(data.index)] = ["a21-1884_brain0", int(ll[-3]), int(ll[-2]), pro_ind.item(), v.item()]
    #break


In [4]:
data

Unnamed: 0,name,x_axis,y_axis,pro_index,sim_value
0,a21-1884_brain0,55,39,786,0.993029
1,a21-1884_brain0,44,81,430,0.978069
2,a21-1884_brain0,71,30,306,0.994450
3,a21-1884_brain0,19,39,766,0.994393
4,a21-1884_brain0,42,28,960,0.976979
...,...,...,...,...,...
7077,a21-1884_brain0,16,52,242,0.993580
7078,a21-1884_brain0,54,5,535,0.994139
7079,a21-1884_brain0,86,31,786,0.990055
7080,a21-1884_brain0,34,5,882,0.981500


In [5]:
data["pro_index"].value_counts()

pro_index
830    1259
555     567
786     504
535     497
857     357
       ... 
401       1
768       1
433       1
976       1
556       1
Name: count, Length: 80, dtype: int64

In [6]:
len(data["pro_index"].value_counts())

80