In [1]:
import os

import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import models
from torchvision import transforms

from pytorch_metric_learning import distances, losses, miners, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

### Load data

In [5]:
images_w_boxes = pd.read_csv('./public_test/images_w_boxes.csv')

In [9]:
submission_list = pd.read_csv('./public_test/submission_list.csv')

In [6]:
class CustomDataset(Dataset):
    
    def __init__(self, df, path):
        self.images = df['img_path'].values
        self.bbox_x1 = df['bbox_x1'].values
        self.bbox_y1 = df['bbox_y1'].values
        self.bbox_x2 = df['bbox_x2'].values
        self.bbox_y2 = df['bbox_y2'].values
        self.path = path
        
    def __getitem__(self, index):
        image_path = self.images[index]
        try:
            image = Image.open(self.path + image_path)
        except:
            image = Image.open(self.images[0])
            print('OPEN FILE ERR', image_path)
        x1 = self.bbox_x1[index]
        y1 = self.bbox_y1[index]
        x2 = self.bbox_x2[index]
        y2 = self.bbox_y2[index]
        try:
            image = image.crop((int(x1), int(y1), int(x2), int(y2)))
        except:
            print('BB ERRR', image_path)
        image = image.resize((224, 224))
        image = transforms.ToTensor()(image)
        try:
            image = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(image)
        except:
            image = torch.zeros([3, 224, 224])
            print('IMAGE FORMAT ERRR', image_path)
        if image.shape != torch.Size([3, 224, 224]):
            image = torch.zeros([3, 224, 224])
            print('IMG ERR', image_path)
        return image
        
    def __len__ (self):
        return len(self.images)

In [7]:
dataset = CustomDataset(images_w_boxes, './public_test/images/')

In [8]:
idx2img = {images_w_boxes.loc[i]['img_path'] : images_w_boxes.loc[i].name for i in range(len(images_w_boxes))}

In [10]:
from sklearn.preprocessing import normalize

def get_embedding(model, dataset, img):
    idx = idx2img[img]
    img = dataset[idx]
    emb = normalize(model(torch.unsqueeze(img, 0)).data.cpu().numpy()[0].reshape(1, -1))
    return emb

### Load model

In [2]:
model = models.resnet18(pretrained=False)

In [3]:
model.fc = torch.nn.Identity()

In [4]:
model.load_state_dict(torch.load('./models/model_mc.pth', map_location=torch.device('cpu')))

<All keys matched successfully>

### Prediction

In [11]:
sim_list = []
for i in tqdm(submission_list['id'].values):
    ids, img_a, img_b = submission_list[submission_list['id'] == i].values[0]
    emb_a = get_embedding(model, dataset, img_a)
    emb_b = get_embedding(model, dataset, img_b)
    dist = (2 - np.linalg.norm(emb_a - emb_b)) / 2
    sim_list.append(dist)

  1%|▏         | 188/14960 [00:33<48:16,  5.10it/s]

IMAGE FORMAT ERRR b18bf4af16e58c45f0b44371d73ba6ed.jpg


  7%|▋         | 976/14960 [02:48<40:17,  5.78it/s]

IMAGE FORMAT ERRR b18bf4af16e58c45f0b44371d73ba6ed.jpg


  9%|▊         | 1299/14960 [03:44<39:41,  5.74it/s]

IMAGE FORMAT ERRR 86eae0a5c83f5435019efb3b7cc8a3d2.jpg


 13%|█▎        | 2012/14960 [05:53<38:37,  5.59it/s]

IMAGE FORMAT ERRR 8cc25f5c544625a9f54e03b081945e5c.jpg


 28%|██▊       | 4178/14960 [11:31<27:32,  6.52it/s]

IMAGE FORMAT ERRR 658a9061fda92bd65a0dbb4332b472b4.jpg


 31%|███       | 4568/14960 [12:32<27:18,  6.34it/s]

IMAGE FORMAT ERRR b18bf4af16e58c45f0b44371d73ba6ed.jpg


 31%|███       | 4609/14960 [12:39<27:12,  6.34it/s]

IMAGE FORMAT ERRR 1675d8443dc4cefdf71e93ccf08d22e5.jpg


 32%|███▏      | 4788/14960 [13:07<26:19,  6.44it/s]

IMAGE FORMAT ERRR b18bf4af16e58c45f0b44371d73ba6ed.jpg


 32%|███▏      | 4842/14960 [13:15<26:34,  6.34it/s]

IMAGE FORMAT ERRR 8cc25f5c544625a9f54e03b081945e5c.jpg


 34%|███▍      | 5055/14960 [13:49<25:52,  6.38it/s]

IMAGE FORMAT ERRR b18bf4af16e58c45f0b44371d73ba6ed.jpg


 36%|███▌      | 5311/14960 [14:29<25:18,  6.35it/s]

IMAGE FORMAT ERRR 1675d8443dc4cefdf71e93ccf08d22e5.jpg


 37%|███▋      | 5543/14960 [15:06<24:45,  6.34it/s]

IMAGE FORMAT ERRR 1675d8443dc4cefdf71e93ccf08d22e5.jpg


 45%|████▍     | 6663/14960 [18:01<21:37,  6.39it/s]

IMAGE FORMAT ERRR 8cc25f5c544625a9f54e03b081945e5c.jpg


 49%|████▊     | 7265/14960 [19:36<19:52,  6.45it/s]

IMAGE FORMAT ERRR 1675d8443dc4cefdf71e93ccf08d22e5.jpg


 50%|█████     | 7506/14960 [20:14<19:28,  6.38it/s]

IMAGE FORMAT ERRR 658a9061fda92bd65a0dbb4332b472b4.jpg


 50%|█████     | 7549/14960 [20:21<19:37,  6.30it/s]

IMAGE FORMAT ERRR 86eae0a5c83f5435019efb3b7cc8a3d2.jpg


 53%|█████▎    | 7870/14960 [21:11<18:29,  6.39it/s]

IMAGE FORMAT ERRR b18bf4af16e58c45f0b44371d73ba6ed.jpg


 54%|█████▍    | 8055/14960 [21:40<17:49,  6.45it/s]

IMAGE FORMAT ERRR 1675d8443dc4cefdf71e93ccf08d22e5.jpg


 58%|█████▊    | 8694/14960 [23:20<16:35,  6.29it/s]

IMAGE FORMAT ERRR 98eb14af0b772c77c60c09339d96562e.jpg


 61%|██████▏   | 9164/14960 [24:34<15:01,  6.43it/s]

IMAGE FORMAT ERRR 86eae0a5c83f5435019efb3b7cc8a3d2.jpg


 61%|██████▏   | 9166/14960 [24:34<15:17,  6.32it/s]

IMAGE FORMAT ERRR 86eae0a5c83f5435019efb3b7cc8a3d2.jpg


 66%|██████▌   | 9869/14960 [26:24<13:15,  6.40it/s]

IMAGE FORMAT ERRR 658a9061fda92bd65a0dbb4332b472b4.jpg


 68%|██████▊   | 10170/14960 [27:11<12:33,  6.35it/s]

IMAGE FORMAT ERRR 658a9061fda92bd65a0dbb4332b472b4.jpg


 73%|███████▎  | 10875/14960 [29:01<10:37,  6.41it/s]

IMAGE FORMAT ERRR b18bf4af16e58c45f0b44371d73ba6ed.jpg


 76%|███████▌  | 11348/14960 [30:15<09:21,  6.44it/s]

IMAGE FORMAT ERRR b18bf4af16e58c45f0b44371d73ba6ed.jpg


 77%|███████▋  | 11484/14960 [30:37<09:10,  6.31it/s]

IMAGE FORMAT ERRR 658a9061fda92bd65a0dbb4332b472b4.jpg


 79%|███████▉  | 11801/14960 [31:26<08:13,  6.40it/s]

IMAGE FORMAT ERRR b18bf4af16e58c45f0b44371d73ba6ed.jpg


 79%|███████▉  | 11826/14960 [31:30<08:18,  6.29it/s]

IMAGE FORMAT ERRR b18bf4af16e58c45f0b44371d73ba6ed.jpg


 79%|███████▉  | 11891/14960 [31:40<08:01,  6.38it/s]

IMAGE FORMAT ERRR 8cc25f5c544625a9f54e03b081945e5c.jpg


 83%|████████▎ | 12489/14960 [33:14<06:19,  6.50it/s]

IMAGE FORMAT ERRR b18bf4af16e58c45f0b44371d73ba6ed.jpg


 86%|████████▌ | 12830/14960 [34:07<05:29,  6.47it/s]

IMAGE FORMAT ERRR 8cc25f5c544625a9f54e03b081945e5c.jpg


 89%|████████▊ | 13260/14960 [35:14<04:27,  6.36it/s]

IMAGE FORMAT ERRR 658a9061fda92bd65a0dbb4332b472b4.jpg


 89%|████████▊ | 13275/14960 [35:16<04:21,  6.44it/s]

IMAGE FORMAT ERRR 8cc25f5c544625a9f54e03b081945e5c.jpg
IMAGE FORMAT ERRR 1675d8443dc4cefdf71e93ccf08d22e5.jpg


 89%|████████▉ | 13304/14960 [35:21<04:19,  6.37it/s]

IMAGE FORMAT ERRR 1675d8443dc4cefdf71e93ccf08d22e5.jpg


 89%|████████▉ | 13327/14960 [35:25<04:17,  6.34it/s]

IMAGE FORMAT ERRR 86eae0a5c83f5435019efb3b7cc8a3d2.jpg


 91%|█████████ | 13640/14960 [36:14<03:24,  6.47it/s]

IMAGE FORMAT ERRR 1675d8443dc4cefdf71e93ccf08d22e5.jpg


 93%|█████████▎| 13934/14960 [37:00<02:40,  6.39it/s]

IMAGE FORMAT ERRR 98eb14af0b772c77c60c09339d96562e.jpg


 94%|█████████▍| 14104/14960 [37:26<02:13,  6.43it/s]

IMAGE FORMAT ERRR 86eae0a5c83f5435019efb3b7cc8a3d2.jpg


100%|██████████| 14960/14960 [39:40<00:00,  6.28it/s]


In [12]:
submission = pd.DataFrame()
submission['id'] = submission_list['id'].values
submission['score'] = sim_list

In [13]:
submission.to_csv('mcs_prediction.csv', index=False)