##### prompt

In [None]:
import json
import os
import glob

os.makedirs('jsons/eval/ksdd2', exist_ok=True)

# Each prompt is deliberately a one-element tuple (the original trailing comma
# made it one); the loop below reads the text back with ``[0]``.  The explicit
# parentheses keep that visible -- dropping the comma would make
# ``prompts[0]`` a single character instead of the whole sentence.
norm_prompts = ("This is a non-defective image of a product surface used for visual inspection. Minor dirt or black stains are not considered defects this time.",)
abnorm_prompts = ("This is a defective image of a product surface used for visual inspection. Minor dirt or black stains are not considered defects this time.",)
prompts = ("This is an image of a product surface used for visual inspection. Minor dirt or black stains are not considered defects this time.",)

# final_prompt = "\n<image>\nThen, does this image have any defects? If yes, please provide the bounding box coordinate of the region where the defect is located. If no, please say None."
final_prompt = " Then, does this image have any defects? If yes, please provide the bounding box coordinate of the region where the defect is located. If no, please say None."


normal_data = glob.glob('./KolektorSDD2/good/*.png')
abnormal_data = glob.glob('./KolektorSDD2/anomaly/*.png')
data = normal_data + abnormal_data
new_dict = {}
for d in data:
    # Per-folder prompt variant (pair it with prompt_each.json below):
    # if d.split('/')[-2] == 'good':
    #     prompt = norm_prompts[0]
    # else:
    #     prompt = abnorm_prompts[0]
    prompt = prompts[0]
    prompt += final_prompt
    new_dict[d] = prompt

# Write once, after the mapping is complete.  Dumping inside the loop rewrote
# the whole file for every image (quadratic work), left closing an anonymous
# file handle to the garbage collector each time, and produced no file at all
# when no images were found.
# json.dump(new_dict, open('jsons/eval/ksdd2/prompt_each.json', 'w'), indent=4)
with open('jsons/eval/ksdd2/prompt_single.json', 'w') as f:
    json.dump(new_dict, f, indent=4)

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import glob
import json
import tqdm

# Feature extractor: torchvision's pretrained ResNet-50 with its last child
# module (the ``fc`` classifier head) dropped.  ``eval()`` puts the network in
# inference mode and returns the module itself, so it can be chained here.
encoder = nn.Sequential(
    *list(models.resnet50(pretrained=True).children())[:-1]
).eval()

# Image paths: every PNG in the 'good' folder, followed by every PNG in the
# 'anomaly' folder (glob order within each folder).
normal_data = glob.glob('./KolektorSDD2/good/*.png')
abnormal_data = glob.glob('./KolektorSDD2/anomaly/*.png')
query_paths = [*normal_data, *abnormal_data]

# Retrieval state: output mapping starts empty, the category marker unset.
new_dict, category = {}, None

def nearest_first(unlabeled_embeddings, labeled_embedding, n):
    """Greedily pick row indices of *unlabeled_embeddings*, nearest first.

    The first pick is the row with the smallest Euclidean distance to
    *labeled_embedding*.  Each later pick is the not-yet-chosen row whose
    distance to the query or to any already-picked row is smallest, so the
    selection grows outward from the query one nearest point at a time.

    Args:
        unlabeled_embeddings: 2-D tensor with one embedding per row.
        labeled_embedding: 1-D query embedding of the same width.
        n: number of picks requested.  Clamped to ``[0, number of rows]`` so
            asking for more picks than rows returns every row instead of
            looping forever.

    Returns:
        List of distinct ``int`` row indices in pick order.
    """
    unlabeled_embeddings = unlabeled_embeddings.to('cpu')
    labeled_embedding = labeled_embedding.unsqueeze(0).to('cpu')
    # Without this clamp, once every row is selected argmin can only return
    # selected rows and the retry loop below never terminates.
    n = max(0, min(n, unlabeled_embeddings.shape[0]))
    dist_ctr = torch.cdist(unlabeled_embeddings, labeled_embedding, p=2)
    min_dist = torch.min(dist_ctr, dim=1)[0]
    idxs = []
    selected_indices = set()
    for _ in range(n):
        idx = torch.argmin(min_dist)
        # Picked rows keep resurfacing (they are at distance 0 from themselves
        # and torch.minimum below resets any mask), so push them to +inf and
        # retry until an unpicked row wins.
        while idx.item() in selected_indices:
            min_dist[idx] = float('inf')
            idx = torch.argmin(min_dist)
        selected_indices.add(idx.item())
        idxs.append(idx.item())
        # Fold in distances to the newly picked row.
        dist_new_ctr = torch.cdist(unlabeled_embeddings, unlabeled_embeddings[[idx], :])
        min_dist = torch.minimum(min_dist, dist_new_ctr[:, 0])
    return idxs

def get_transform(image_mode):
    """Build the preprocessing pipeline for a PIL image of mode *image_mode*.

    Grayscale ('L') images are first replicated to three channels so they fit
    the 3-channel ResNet input.  Every image is then resized (shorter side
    256), center-cropped to 224x224, converted to a tensor and normalised with
    mean 0.5 / std 0.5 per channel.  Modes other than 'L' get no channel
    conversion -- presumably they are already 3-channel RGB; TODO confirm for
    the dataset.
    """
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    if image_mode == 'L':
        steps.insert(0, transforms.Grayscale(num_output_channels=3))
    return transforms.Compose(steps)

import os


def _embed_image(path):
    """Encode the image at *path* using the transform for its own PIL mode.

    The file is opened in a ``with`` block so the handle is released once the
    tensor is built.  The encoder runs without gradient tracking, and the
    squeezed output is returned on the CPU.
    """
    with Image.open(path) as img:
        tensor = get_transform(img.mode)(img).unsqueeze(0)
    with torch.no_grad():
        return encoder(tensor).squeeze().cpu()


# The support set is the query set itself, so embed every image exactly once,
# each with the transform matching its own mode.  Previously every support
# image reused the transform picked for the first query's mode (wrong or
# crashing for files with a different mode), and every query was encoded twice.
support_set_paths = query_paths
support_set_embeddings = (
    torch.stack([_embed_image(p) for p in tqdm.tqdm(support_set_paths)])
    if support_set_paths
    else None
)

for idx, query_path in enumerate(query_paths):
    # Request two picks: one is normally the query itself (it is in the
    # support set), the other its nearest neighbour.  The old n=1 combined
    # with index [1] raised IndexError.  Capping at the pool size keeps a
    # one-image dataset from requesting more picks than exist.
    n = min(2, len(support_set_paths))
    selected_indices = nearest_first(support_set_embeddings, support_set_embeddings[idx], n)
    # Skip the query's own index instead of assuming it is pick 0: an exact
    # duplicate, or cdist's matrix-multiply rounding of the self-distance,
    # can put another image first.
    neighbours = [i for i in selected_indices if i != idx]
    if neighbours:
        new_dict[query_path] = support_set_paths[neighbours[0]]

out_path = 'jsons/eval/ksdd2/nearest_first.json'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, 'w') as f:
    json.dump(new_dict, f, indent=4)


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import glob
import json
import tqdm

# Pretrained ResNet-50 backbone; keep every child module except the last one
# (torchvision's ``fc`` classifier head) and switch to inference mode.
_resnet_children = list(models.resnet50(pretrained=True).children())
encoder = nn.Sequential(*_resnet_children[:-1])
encoder.eval()

# Query pool: the 'good' folder's PNGs first, then the 'anomaly' folder's.
normal_data = glob.glob('./KolektorSDD2/good/*.png')
abnormal_data = glob.glob('./KolektorSDD2/anomaly/*.png')
query_paths = list(normal_data)
query_paths.extend(abnormal_data)

# Retrieval state, initially empty / unset.
new_dict = dict()
category = None

def nearest_first(unlabeled_embeddings, labeled_embedding, n):
    """Greedily pick row indices of *unlabeled_embeddings*, most similar first.

    The first pick is the row with the highest cosine similarity to
    *labeled_embedding*.  Each later pick is the not-yet-chosen row whose
    similarity to the query or to any already-picked row is highest.

    Args:
        unlabeled_embeddings: 2-D tensor with one embedding per row.
        labeled_embedding: 1-D query embedding of the same width.
        n: number of picks requested.  Clamped to ``[0, number of rows]`` so
            asking for more picks than rows returns every row instead of
            looping forever.

    Returns:
        List of distinct ``int`` row indices in pick order.
    """
    unlabeled_embeddings = unlabeled_embeddings.to('cpu')
    labeled_embedding = labeled_embedding.unsqueeze(0).to('cpu')
    # Without this clamp, once every row is selected argmax can only return
    # selected rows and the retry loop below never terminates.
    n = max(0, min(n, unlabeled_embeddings.shape[0]))
    similarity = torch.nn.functional.cosine_similarity(unlabeled_embeddings, labeled_embedding, dim=1)
    max_similarity = similarity.clone()
    idxs = []
    selected_indices = set()
    for _ in range(n):
        idx = torch.argmax(max_similarity)
        # Picked rows keep resurfacing (self-similarity is maximal and
        # torch.maximum below resets any mask), so push them to -inf and
        # retry until an unpicked row wins.
        while idx.item() in selected_indices:
            max_similarity[idx] = float('-inf')
            idx = torch.argmax(max_similarity)
        selected_indices.add(idx.item())
        idxs.append(idx.item())
        # Fold in similarities to the newly picked row.
        similarity_new = torch.nn.functional.cosine_similarity(unlabeled_embeddings, unlabeled_embeddings[idx].unsqueeze(0), dim=1)
        max_similarity = torch.maximum(max_similarity, similarity_new)
    return idxs

def get_transform(image_mode):
    """Return the preprocessing ``Compose`` for a PIL image of *image_mode*.

    Only single-channel 'L' images get a leading ``Grayscale(3)`` step that
    copies the channel three times.  The shared tail resizes the shorter side
    to 256, center-crops 224x224, converts to a tensor and normalises every
    channel with mean 0.5 / std 0.5.  Other modes are assumed to already be
    3-channel -- TODO confirm for the dataset.
    """
    channel_fix = [transforms.Grayscale(num_output_channels=3)] if image_mode == 'L' else []
    shared_tail = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(channel_fix + shared_tail)

import os


def _embed_image(path):
    """Encode the image at *path* using the transform for its own PIL mode.

    The file is opened in a ``with`` block so the handle is released once the
    tensor is built.  The encoder runs without gradient tracking, and the
    squeezed output is returned on the CPU.
    """
    with Image.open(path) as img:
        tensor = get_transform(img.mode)(img).unsqueeze(0)
    with torch.no_grad():
        return encoder(tensor).squeeze().cpu()


# The retrieval pool is the query set itself, so embed every image exactly
# once, each with the transform matching its own mode.  Previously every
# support image reused the transform picked for the first query's mode (wrong
# or crashing for files with a different mode), and every query was encoded
# twice.
support_set_paths = query_paths
support_set_embeddings = (
    torch.stack([_embed_image(p) for p in tqdm.tqdm(support_set_paths)])
    if support_set_paths
    else None
)

for idx, query_path in enumerate(query_paths):
    # Only the first two greedy picks are ever used: one is normally the
    # query itself (it is in the pool), the other its best match.  The old
    # n=5 computed three picks that were never read.  Capping at the pool size
    # keeps a one-image dataset from requesting more picks than exist.
    n = min(2, len(support_set_paths))
    selected_indices = nearest_first(support_set_embeddings, support_set_embeddings[idx], n)
    # Skip the query's own index instead of assuming it is pick 0: an exact
    # duplicate, or float rounding of the self-similarity, can put another
    # image first.
    neighbours = [i for i in selected_indices if i != idx]
    if neighbours:
        new_dict[query_path] = support_set_paths[neighbours[0]]

out_path = 'jsons/eval/ksdd2/rices_one.json'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, 'w') as f:
    json.dump(new_dict, f, indent=4)
