In [None]:
# !pip3 install faiss-gpu
# !conda install -c pytorch faiss-gpu
!pip3 install faiss-cpu

In [None]:
import torch
from datasets import load_dataset
from transformers import ViTImageProcessor, ViTModel

# Fetch one sample cat picture to run through the ViT encoder.
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]

# ViT-Base checkpoint (16x16 patches, 224px input per its name) and its
# matching preprocessing pipeline.
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTModel.from_pretrained(model_name)

inputs = processor(image, return_tensors="pt")
# Pure inference: no gradients needed.
with torch.no_grad():
    outputs = model(inputs["pixel_values"])

# Token sequence from the final encoder layer, indexed as (batch, token, hidden).
last_hidden_state = outputs["last_hidden_state"]
# ViT prepends a [CLS] token at token index 0; its embedding serves as a
# single global feature vector for the whole image.
cls_feature = last_hidden_state[:, 0, :]

print("마지막 특징 맵의 형태 :", last_hidden_state.shape)
print("특징 벡터의 차원 수 :", cls_feature.shape)
print("특징 벡터 :", cls_feature)

In [None]:
import torch
import numpy as np
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel

dataset = load_dataset("sasha/dog-food")

NUM_IMAGES = 100  # size of the demo gallery to embed and index
BATCH_SIZE = 32   # images per CLIP forward pass

# Slice rows before selecting the column: in many `datasets` versions
# `dataset["test"]["image"]` decodes every image in the split up front,
# whereas `dataset["test"][:NUM_IMAGES]` decodes only the requested rows.
# The result is a plain list of images, used again by the plotting cell.
images = dataset["test"][:NUM_IMAGES]["image"]

model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)

# Embed the gallery in batches rather than one forward pass per image.
# Padding is a text-tokenizer option, so it is not passed for image-only input.
vectors = []
with torch.no_grad():
    for start in range(0, len(images), BATCH_SIZE):
        batch = images[start:start + BATCH_SIZE]
        inputs = processor(images=batch, return_tensors="pt")
        outputs = model.get_image_features(**inputs)
        vectors.append(outputs.cpu().numpy())

# One row per image. FAISS requires a C-contiguous float32 matrix, so make
# that explicit rather than relying on the model's output dtype.
vectors = np.ascontiguousarray(np.vstack(vectors), dtype=np.float32)
print("이미지 벡터의 shape :", vectors.shape)

In [None]:
import faiss

dimension = vectors.shape[-1]
# Exact (brute-force) index; search results are squared L2 distances.
index = faiss.IndexFlatL2(dimension)

# Move the index to a GPU only when the installed *faiss build* supports it.
# torch.cuda.is_available() says nothing about faiss: the faiss-cpu package
# has no StandardGpuResources, so gating on CUDA alone raises AttributeError
# on any GPU machine running the CPU build. The hasattr check short-circuits
# before get_num_gpus() is consulted.
if hasattr(faiss, "StandardGpuResources") and faiss.get_num_gpus() > 0:
    res = faiss.StandardGpuResources()  # keep a reference alive as long as the GPU index
    index = faiss.index_cpu_to_gpu(res, 0, index)

index.add(vectors)
assert index.ntotal == len(vectors), "every image vector should be indexed"

In [None]:
import matplotlib.pyplot as plt

# Query with the first gallery image. It is itself in the index, so Match 1
# will normally be image 0 at distance ~0.
search_vector = vectors[0].reshape(1, -1)
# Cap k at the index size: faiss fills missing neighbours with id -1, and
# images[-1] would then silently display the last image instead of failing.
num_neighbors = min(5, index.ntotal)
distances, indices = index.search(x=search_vector, k=num_neighbors)

fig, axes = plt.subplots(1, num_neighbors + 1, figsize=(15, 5))

axes[0].imshow(images[0])
axes[0].set_title("Input Image")
axes[0].axis("off")

# Row 0 of the search results corresponds to our single query vector.
for rank, (idx, dist) in enumerate(zip(indices[0], distances[0]), start=1):
    ax = axes[rank]
    ax.imshow(images[int(idx)])
    ax.set_title(f"Match {rank}\nIndex: {idx}\nDist: {dist:.2f}")
    ax.axis("off")

fig.tight_layout()
plt.show()

print("유사한 벡터의 인덱스 번호:", indices)
print("유사도 계산 결과:", distances)