In [1]:
from dataset import CustomDataset
from runner.model import TimmModel
from config import CFG

from utils_.set_seed import seed_everything
from utils_.set_path import *




In [2]:
import random
import pandas as pd
import numpy as np
import os
import re
import glob
import cv2

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import torchvision.models as models

from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.metrics import f1_score
from sklearn.metrics import classification_report
from tqdm.auto import tqdm
import timm

import warnings
warnings.filterwarnings(action='ignore') 

import wandb

In [3]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [4]:
df = pd.read_csv('./data/train.csv')
df['filename'] = df['filename'].apply(lambda x : os.path.join('./Data/Training_whole/PNG',x))

In [5]:
df.head()

Unnamed: 0,id,filename,label
0,0,./Data/Training_whole/PNG/울음_12.png,울음
1,1,./Data/Training_whole/PNG/울음_11.png,울음
2,2,./Data/Training_whole/PNG/울음_16.png,울음
3,3,./Data/Training_whole/PNG/울음_17.png,울음
4,4,./Data/Training_whole/PNG/울음_21.png,울음


In [6]:
train, val, _, _ = train_test_split(df, df['label'], test_size=0.3, stratify=df['label'], random_state=CFG['SEED'])

In [7]:
indexes = val['label'].value_counts().index
val['label'].value_counts()

훼손         422
오염         179
걸레받이수정      92
꼬임          63
터짐          49
곰팡이         43
오타공         43
몰딩수정        39
면불량         30
석고수정        17
들뜸          16
피스          15
창틀,문틀수정      8
울음           7
이음부불량        5
가구수정         4
녹오염          4
반점           1
틈새과다         1
Name: label, dtype: int64

In [8]:
le = preprocessing.LabelEncoder()
train['label'] = le.fit_transform(train['label'])
val['label'] = le.transform(val['label'])

In [9]:
val['label'].value_counts().values

array([422, 179,  92,  63,  49,  43,  43,  39,  30,  17,  16,  15,   8,
         7,   5,   4,   4,   1,   1])

# Dataset

In [10]:
train_transform = A.Compose([
                            A.Resize(CFG['IMG_SIZE'],CFG['IMG_SIZE']),
                            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, always_apply=False, p=1.0),
                            ToTensorV2()
                            ])

test_transform = A.Compose([
                            A.Resize(CFG['IMG_SIZE'],CFG['IMG_SIZE']),
                            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, always_apply=False, p=1.0),
                            ToTensorV2()
                            ])

In [14]:
val_dataset = CustomDataset(val['filename'].values, val['label'].values, transforms = False, CFG=CFG)
val_loader = DataLoader(val_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=False, num_workers=4)

# Model load

In [16]:
import timm
from pprint import pprint
model_names = timm.list_models('vit_base_patch16*')
(model_names)

['vit_base_patch16_18x2_224',
 'vit_base_patch16_224',
 'vit_base_patch16_224_dino',
 'vit_base_patch16_224_in21k',
 'vit_base_patch16_224_miil',
 'vit_base_patch16_224_miil_in21k',
 'vit_base_patch16_224_sam',
 'vit_base_patch16_384',
 'vit_base_patch16_plus_240',
 'vit_base_patch16_rpn_224']

In [17]:
model = TimmModel('vit_base_patch16_224', num_classes = 19, pretrained= True)

In [18]:
model.state_dict()

OrderedDict([('model.cls_token',
              tensor([[[ 3.2389e-01,  1.4362e-02, -4.3632e-01, -2.9668e-02,  4.9954e-01,
                         3.4431e-01,  6.9270e-02,  1.5145e-02,  1.5875e-01,  6.6091e-03,
                         2.9595e-02,  2.6433e-02,  4.8807e-02,  1.5438e-01,  7.6776e-02,
                        -5.6193e-02,  2.3481e+00, -3.7596e-02, -7.9648e-02, -3.2865e-02,
                         4.0753e-02,  1.0625e-01,  1.2637e-02,  1.2623e-01, -5.1822e-03,
                        -2.3552e-01,  3.1639e-02,  6.2501e-02,  3.5712e-02, -1.7925e-02,
                        -2.7995e-02,  8.2622e-01,  5.4036e-02,  3.6365e-02,  4.2601e-02,
                         9.5680e-02, -4.4193e-02,  5.1107e-02,  1.9187e-01,  3.5678e-01,
                         6.6560e-02,  8.5484e-03,  2.0514e-02,  2.7039e-02,  6.3104e-02,
                         2.8506e-02, -9.6596e-02,  9.6870e-03,  1.0163e-01,  1.4900e-01,
                         1.1239e-01,  3.5970e-02,  9.4504e-02,  3.3720e-02,  

In [20]:
new = dict()
tmp = (torch.load('/opt/ml/Wallpaper-defect_classification/models/[vit_base_patch16_224]_[score0.8078]_[loss0.6596].pt'))
model.load_state_dict(tmp)
model.eval()
model.to(device)

TimmModel(
  (model): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU()
          (drop1): Dropout(p=0.0, inplace=False)
          (fc2): Linear(in_features=3072, out_features=768, bi

In [21]:
# for key, value in tmp.items():
#     if 'backbone' in key:
#         new[key.replace('backbone', 'model')] = value 
#     # else:
#     #     new[key.replace('fc', 'module.fc')] = value
# model.load_state_dict(new)
# model.eval()
# model.to(device)
# new

# 예측 결과 해석

In [22]:
targets_img = None
preds, targets = torch.Tensor().to(device), torch.Tensor().to(device)
prediction = []
with torch.no_grad():
    for X, y in tqdm(val_loader):
        X, y = X.to(device), y.to(device)
        
        pred = model(X)
        pred = pred.argmax(dim=1)
        indices = (pred != y).nonzero().squeeze()

        if indices.nelement() != 0:
            target_img = X[indices]
            if target_img.dim() == 3:
                target_img = target_img.unsqueeze(dim=0)
            
            if targets_img is None:
                targets_img = target_img
            else:
                targets_img = torch.cat((targets_img, target_img), dim=0)

            ps = pred[indices]
            ys = y[indices]


            if ys.dim() == 0: # 배치에사 하나만 잘 못 판단하면 이 코드 넣어주어야함
                # print(ps, ys)
                # print(ys.dim())
                # break
                ys = ys.unsqueeze(dim=0)
                ps = ps.unsqueeze(dim=0)

            preds = torch.cat((preds, ps))
            targets = torch.cat((targets, ys))
            # print(preds, targets)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=65.0), HTML(value='')))




# 총 틀린 prdes, targets 개수 

In [23]:
worng_preds = pd.Series(list(map(int, preds.cpu()))).value_counts()
worng_targets = pd.Series(list(map(int, targets.cpu()))).value_counts()


In [24]:
worng_preds

18    82
10    32
3     11
11     9
2      8
1      8
15     7
7      7
6      5
17     4
5      4
14     2
13     1
4      1
dtype: int64

In [25]:
sorted(pd.Series(list(map(int,targets))).unique())

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]

In [26]:
total_cnt = val['label'].value_counts()
indexes, val['label'].value_counts()

(Index(['훼손', '오염', '걸레받이수정', '꼬임', '터짐', '곰팡이', '오타공', '몰딩수정', '면불량', '석고수정',
        '들뜸', '피스', '창틀,문틀수정', '울음', '이음부불량', '가구수정', '녹오염', '반점', '틈새과다'],
       dtype='object'),
 18    422
 10    179
 1      92
 3      63
 15     49
 2      43
 11     43
 7      39
 6      30
 9      17
 5      16
 17     15
 14      8
 12      7
 13      5
 4       4
 0       4
 8       1
 16      1
 Name: label, dtype: int64)

In [27]:
from collections import defaultdict

In [28]:
worng_pred = defaultdict(int)
worng_target = defaultdict(int)
for key in list(worng_preds.keys()):
    worng_pred[key] = worng_preds[key] / total_cnt[key]
for key in list(worng_targets.keys()):
    worng_target[key] = worng_targets[key] / total_cnt[key]    

In [29]:
worng_pred, worng_target
# 잘못 예측한 preds, 잘못 예측한 true label

(defaultdict(int,
             {18: 0.1943127962085308,
              10: 0.1787709497206704,
              3: 0.1746031746031746,
              11: 0.20930232558139536,
              2: 0.18604651162790697,
              1: 0.08695652173913043,
              15: 0.14285714285714285,
              7: 0.1794871794871795,
              6: 0.16666666666666666,
              17: 0.26666666666666666,
              5: 0.25,
              14: 0.25,
              13: 0.2,
              4: 0.25}),
 defaultdict(int,
             {18: 0.0947867298578199,
              10: 0.22346368715083798,
              6: 0.6333333333333333,
              9: 0.7058823529411765,
              14: 1.0,
              15: 0.16326530612244897,
              2: 0.18604651162790697,
              7: 0.1794871794871795,
              1: 0.07608695652173914,
              12: 0.7142857142857143,
              17: 0.3333333333333333,
              13: 0.8,
              11: 0.09302325581395349,
              0: 1.0,
  

In [88]:
misprediction_count = {
    (18, 10) : 0,
    (18, 1) : 0,
    (1, 0) : 0,
    (1, 2) : 0,
    (2, 0) : 0,
    (2, 1) : 0
}
misprediction_count = defaultdict

misprediction_count = { (18, i) : 0 for i in range(18)}
print(misprediction_count)

{(18, 0): 0, (18, 1): 0, (18, 2): 0, (18, 3): 0, (18, 4): 0, (18, 5): 0, (18, 6): 0, (18, 7): 0, (18, 8): 0, (18, 9): 0, (18, 10): 0, (18, 11): 0, (18, 12): 0, (18, 13): 0, (18, 14): 0, (18, 15): 0, (18, 16): 0, (18, 17): 0}


In [92]:
for pred, label in tqdm(zip(preds, targets)):
    if (pred.item(), label.item()) in misprediction_count.keys():
        misprediction_count[(pred.item(), label.item())] += 1

HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




In [93]:
misprediction_count

{(18, 0): 2,
 (18, 1): 8,
 (18, 2): 10,
 (18, 3): 8,
 (18, 4): 4,
 (18, 5): 2,
 (18, 6): 26,
 (18, 7): 8,
 (18, 8): 0,
 (18, 9): 18,
 (18, 10): 113,
 (18, 11): 14,
 (18, 12): 2,
 (18, 13): 2,
 (18, 14): 4,
 (18, 15): 14,
 (18, 16): 0,
 (18, 17): 8}