In [13]:
import cv2
import pandas as pd
from PIL import Image
from models.modeling import VisionTransformer, CONFIGS
import torch
from torchvision import transforms

In [2]:
# 이미지 변환을 위한 transform 설정
transform = transforms.Compose([
    transforms.Resize((600, 600), Image.BILINEAR),
    transforms.CenterCrop((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 설정 변수들
img_size = 448
pretrained_model_path = "output/best_model.bin"

# 모델 설정
config = CONFIGS["ViT-B_16"]
config.split = 'overlap'
config.slide_step = 12
num_classes = pd.read_csv('label_encoding.csv')['label'].nunique()

# 모델 초기화 및 사전 학습된 가중치 로드
model = VisionTransformer(config, img_size, num_classes, zero_head=True)
model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu')['model'])

# 장치 설정 및 모델을 평가 모드로 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

  "Argument interpolation should be of type InterpolationMode instead of int. "


VisionTransformer(
  (transformer): Transformer(
    (embeddings): Embeddings(
      (patch_embeddings): Conv2d(3, 768, kernel_size=(16, 16), stride=(12, 12))
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): Encoder(
      (layer): ModuleList(
        (0): Block(
          (attention_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (attn): Attention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (out): Linear(in_features=768, out_features=768, bias=True)
     

In [6]:
# 이미지 읽기 및 변환 
img_path = 'demo/examples/C-220720_13_CR15_02_A1619.jpg'
img = cv2.imread(img_path)
img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) 
img.show()
img = transform(img).unsqueeze(0).to(device)

# 레이블 딕셔너리 생성
label_dict = pd.read_csv('label_encoding.csv').drop_duplicates('label').set_index('label').to_dict()['label_']

# 모델을 사용하여 예측 수행
part_logits = model(img)
probs = torch.nn.functional.softmax(part_logits, dim=-1)
predicted_label = torch.argmax(probs, dim=-1).item()
predicted_name = label_dict[predicted_label]

# 예측된 클래스 이름 출력
print("Predicted class name:", predicted_name)

Predicted class name: 지프#랭글러_JL(2018)
