In [1]:
import cv2
import os
import timm
import numpy as np
import pandas as pd
import albumentations as A

from glob import glob
from tqdm import tqdm
from easydict import EasyDict
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import f1_score

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

In [2]:
class CustomDataset(Dataset):
    """Image dataset that reads files with OpenCV and optionally pairs them with labels.

    In ``"train"`` mode every item is ``(image, label_tensor)``; in ``"test"``
    mode every item is just the image.

    Args:
        img_list: Sequence of image file paths.
        label_list: Raw labels aligned with ``img_list``. Required in
            ``"train"`` mode and ignored in ``"test"`` mode.
        transforms: Optional callable invoked as ``transforms(image=img)``;
            the ``'image'`` entry of its result replaces the image.
        mode: Either ``"train"`` or ``"test"``.

    Raises:
        ValueError: If ``mode`` is unsupported, or if ``label_list`` is missing
            in ``"train"`` mode.
    """

    _SUPPORTED_MODES = ("train", "test")

    def __init__(self, img_list, label_list=None, transforms=None, mode="train") :
        # Fail at construction time instead of letting __getitem__ silently
        # return None for an unknown mode.
        if mode not in self._SUPPORTED_MODES :
            raise ValueError(
                f"mode must be one of {self._SUPPORTED_MODES}, got {mode!r}"
            )
        if mode == "train" and label_list is None :
            raise ValueError("label_list is required when mode='train'")

        self.img_list = img_list

        if mode == "train" :
            self.label_list = self.label_encoder(label_list)

        self.transforms = transforms
        self.mode = mode

    def __len__(self):
        """Return the number of image paths."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load the image at ``idx`` (converted BGR -> RGB) and apply the transforms.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file at that path.
        """
        img_path = self.img_list[idx]

        img = cv2.imread(img_path)
        # cv2.imread signals a missing/unreadable file by returning None rather
        # than raising; surface the offending path explicitly.
        if img is None :
            raise FileNotFoundError(f"cv2.imread could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transforms:
            img = self.transforms(image=img)['image']

        if self.mode == "train" :
            return img, torch.tensor(self.label_list[idx])

        return img

    def label_encoder(self, label_list) :
        """Encode labels as integers: the index of each label among the sorted unique labels."""
        label_enc = {k : i for i, k in enumerate(sorted(set(label_list)))}
        return [label_enc[label] for label in label_list]
    

# Custom Swin Transformer

In [3]:
class BackBone(nn.Module) :
    """Wrapper holding a timm model under the attribute ``model``.

    The network is created with ``pretrained=True`` and ``num_classes`` set to
    ``backbone_output``.
    """

    def __init__(self, model_name, backbone_output) :
        super().__init__()
        self.model = timm.create_model(
            model_name=model_name,
            pretrained=True,
            num_classes=backbone_output,
        )

    def forward(self, x) :
        """Run ``x`` through the wrapped timm model and return its output."""
        return self.model(x)
    
class MLP(nn.Module) :
    """Classification head: Linear -> GELU -> Dropout -> Linear.

    Args:
        in_features: Width of the incoming feature vector; the hidden layer is
            half of it (integer division).
        dropout_rate: Dropout probability applied after the activation.
        num_state: Number of outputs of the final linear layer.
    """

    def __init__(self, in_features, dropout_rate, num_state) :
        super().__init__()
        # Original author's note: the result of forward_features has already
        # gone through LayerNorm, so LayerNorm and AdaptiveAvgPool1d are not
        # needed in this head.
        hidden_features = in_features // 2

        self.linear_1 = nn.Linear(in_features, hidden_features)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(p=dropout_rate)
        self.linear_2 = nn.Linear(hidden_features, num_state)

    def forward(self, x):
        """Return the head's raw outputs for ``x``."""
        hidden = self.dropout(self.gelu(self.linear_1(x)))
        return self.linear_2(hidden)
    
class CustomSwinTransformer(nn.Module) :
    """Backbone restored from a checkpoint, followed by one ``MLP`` head per class.

    Args:
        model_path: Path of the checkpoint whose ``"model_state_dict"`` entry
            holds the ``BackBone`` weights.
        model_name: timm model name forwarded to ``BackBone``.
        backbone_output: ``num_classes`` the checkpointed backbone was built with.
        num_class: Number of MLP heads to create.
        num_state: Sequence with the number of states of each head; head ``i``
            gets ``num_state[i] + 1`` outputs.
        dropout_rate: Dropout probability used inside every head.
    """

    def __init__(self,
                 model_path,
                 model_name,
                 backbone_output,
                 num_class,
                 num_state,
                 dropout_rate=0.5) :
        super(CustomSwinTransformer, self).__init__()

        self.backbone = self.get_backbone(model_path,
                                          model_name,
                                          backbone_output)

        # num_state[i] + 1: the extra output is reserved for the "None" class
        # that is added to every head.
        # NOTE(review): in_features=1024 is hard-coded; presumably the feature
        # width of the swin_base backbone - confirm before switching models.
        self.mlps = nn.ModuleList([
            MLP(in_features=1024,
                dropout_rate=dropout_rate,
                num_state=num_state[i] + 1)
            for i in range(num_class)
        ])

    def forward(self, x) :
        """Return a list with the output of every head for the backbone features of ``x``."""
        feature_map = self.backbone.forward_features(x)
        return [mlp(feature_map) for mlp in self.mlps]

    def get_backbone(self, model_path, model_name, backbone_output) :
        """Build a ``BackBone``, load the checkpoint weights and return the inner timm model."""
        # map_location="cpu" lets a checkpoint saved from a GPU be restored on
        # any machine; the caller moves the assembled model to its device.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location="cpu")
        backbone = BackBone(model_name, backbone_output)
        backbone.load_state_dict(checkpoint["model_state_dict"])
        return backbone.model

# Weight Freeze

In [4]:
def WeightFreeze(model) :
    """Disable gradients for every parameter inside the child modules of ``model.backbone``.

    The model is modified in place and also returned for convenience.
    """
    for submodule in model.backbone.children() :
        submodule.requires_grad_(False)
    return model

# MLP Label Split

In [5]:
def mlp_label_split(num_state, labels) :
    """Convert global class indices into one target per head.

    The global label space is the concatenation of the heads' states: head
    ``i`` owns the indices ``[sum(num_state[:i]), sum(num_state[:i + 1]))``.
    For a label owned by head ``i`` the target of that head is the offset of
    the label inside the head's range; every other head receives its "None"
    index, ``num_state[i]``.

    Args:
        num_state: Number of states of each head.
        labels: Iterable of global class indices (ints or 0-d integer tensors).

    Returns:
        dict mapping head index -> 1-D ``torch.long`` tensor with one target
        per label, for every head in ``num_state``.
    """
    # Start offset of every head inside the global label space.
    offsets = []
    running_total = 0
    for size in num_state :
        offsets.append(running_total)
        running_total += size

    # Collect plain ints first and build each tensor once; growing tensors
    # with torch.cat per label is quadratic.
    targets = [[] for _ in num_state]
    for label in labels :
        value = int(label)
        for head, size in enumerate(num_state) :
            local = value - offsets[head]
            targets[head].append(local if 0 <= local < size else size)

    return {head : torch.tensor(values, dtype=torch.long)
            for head, values in enumerate(targets)}

# Label Decoder

In [6]:
def label_decoder(labels) :
    """Map every label to its position in ``labels``, print the mapping and return it.

    The mapping is built exactly once, so one-shot iterables (generators,
    iterators) are handled correctly as well.

    Args:
        labels: Iterable of hashable labels.

    Returns:
        dict mapping label -> index of the label in ``labels``.
    """
    # NOTE(review): the header says "each_dec" although this prints the flat
    # label -> index mapping; the text is kept as-is to preserve the output.
    print("==== each_dec ====")
    mapping = {label : index for index, label in enumerate(labels)}
    print(mapping)
    return mapping

def each_label_decoder(labels, num_state) :
    """Split the label list into one ``{local index: label}`` dict per head.

    Labels are consumed in order: the first ``num_state[0]`` labels belong to
    head 0, the next ``num_state[1]`` to head 1, and so on. Exactly one dict is
    produced for every entry of ``num_state`` (an empty dict for a zero-sized
    head); labels beyond ``sum(num_state)`` are ignored.

    Args:
        labels: Ordered labels; passed through ``label_decoder`` first, which
            also prints the flat mapping.
        num_state: Number of labels owned by each head.

    Returns:
        list of dicts, one per head, mapping local index -> label.

    Raises:
        IndexError: If there are fewer labels than ``sum(num_state)``.
    """
    keys = list(label_decoder(labels))

    expected = sum(num_state)
    if len(keys) < expected :
        raise IndexError(
            f"num_state expects {expected} labels but only {len(keys)} were given"
        )

    each_dec = []
    start = 0
    for size in num_state :
        each_dec.append({local : keys[start + local] for local in range(size)})
        start += size

    print("==== each_dec ====")
    print(each_dec)
    return each_dec
        

In [7]:
# Number of defect states of every object class (one entry per MLP head).
# Defined once so the inference options and the model options cannot drift apart.
NUM_STATE = [4, 9, 6, 6, 6, 5, 6, 5, 8, 6, 6, 2, 5, 6, 8]

# Inference options: data locations, checkpoint of the multi-head model, runtime settings.
opt = EasyDict({
    "test_df_path" : "../data/test_df.csv",
    "train_df_path" : "../data/train_df.csv",
    "submission_df_path" : "../data/sample_submission.csv",
    "img_path" : "../data/test",
    "save_path" : "../data/submission/custom_aug_v4_6E_0.0113_swin_base_patch4_window7_224_in22k.csv",
    "model_name" : "swin_base_patch4_window7_224_in22k",
    "model_path" : "../model/custom_swin_aug_v4_mixup/6E_0.0113_swin_base_patch4_window7_224_in22k.pt",
    "num_classes" : 88,
    'num_state' : NUM_STATE,
    "resize" : 224,
    # Fall back to the CPU when no GPU is available instead of crashing on .to("cuda:0").
    "device" : "cuda:0" if torch.cuda.is_available() else "cpu",
    "batch_size" : 64,
    "ensemble" : False
})

# The heads partition the full label space, so their sizes must add up to it.
assert sum(NUM_STATE) == opt.num_classes, "num_state does not add up to num_classes"


# Constructor arguments of CustomSwinTransformer (backbone checkpoint + head layout).
model_opt = EasyDict({
    'model_path' : '../model/swin_aug_v4_CEL/30E_0.0114_swin_base_patch4_window7_224_in22k.pt',
    'model_name' : 'swin_base_patch4_window7_224_in22k',
    'backbone_output' : 88,
    'num_class' : len(NUM_STATE),
    'num_state' : NUM_STATE,
    'dropout_rate' : 0.5
})


# Test-time preprocessing: normalize, resize to a square of side opt.resize, convert to a tensor.
test_transforms = A.Compose([
    A.Normalize(),
    A.Resize(opt.resize, opt.resize),
    ToTensorV2(),
])

# Rebuild the label mappings from the sorted unique training labels.
train_df = pd.read_csv(opt.train_df_path)
label_list = sorted(train_df['label'].unique())
label_dec = label_decoder(label_list)
each_label_dec = each_label_decoder(label_list, opt.num_state)

# Resolve every test file name against the image directory.
test_df = pd.read_csv(opt.test_df_path)
file_names = [os.path.join(opt.img_path, file_name) for file_name in test_df['file_name']]

# Unshuffled loader so predictions stay aligned with the rows of test_df.
test_data = CustomDataset(file_names, transforms=test_transforms, mode="test")
test_loader = DataLoader(test_data, batch_size=opt.batch_size, shuffle=False)



if opt.ensemble :
    # NOTE(review): this branch relies on `opt.models_list` and a `CNN` class,
    # neither of which is defined in the visible part of the notebook - it
    # cannot run until both are provided.
    models = []
    for model_path in opt.models_list :
        model = CNN(opt.model_name, opt.num_classes).to(opt.device)
        # map_location keeps loading independent of the device the checkpoint was saved from.
        model_data = torch.load(model_path, map_location=opt.device)
        model.load_state_dict(model_data["model_state_dict"])
        models.append(model.eval())


else :
    # Build backbone + heads, freeze the backbone, then restore the fine-tuned weights.
    custom_swin = CustomSwinTransformer(**model_opt).to(opt.device)
    model = WeightFreeze(custom_swin)
    # map_location keeps loading independent of the device the checkpoint was saved from.
    model_data = torch.load(opt.model_path, map_location=opt.device)
    model.load_state_dict(model_data["model_state_dict"])
    model.eval()

==== each_dec ====
{'bottle-broken_large': 0, 'bottle-broken_small': 1, 'bottle-contamination': 2, 'bottle-good': 3, 'cable-bent_wire': 4, 'cable-cable_swap': 5, 'cable-combined': 6, 'cable-cut_inner_insulation': 7, 'cable-cut_outer_insulation': 8, 'cable-good': 9, 'cable-missing_cable': 10, 'cable-missing_wire': 11, 'cable-poke_insulation': 12, 'capsule-crack': 13, 'capsule-faulty_imprint': 14, 'capsule-good': 15, 'capsule-poke': 16, 'capsule-scratch': 17, 'capsule-squeeze': 18, 'carpet-color': 19, 'carpet-cut': 20, 'carpet-good': 21, 'carpet-hole': 22, 'carpet-metal_contamination': 23, 'carpet-thread': 24, 'grid-bent': 25, 'grid-broken': 26, 'grid-glue': 27, 'grid-good': 28, 'grid-metal_contamination': 29, 'grid-thread': 30, 'hazelnut-crack': 31, 'hazelnut-cut': 32, 'hazelnut-good': 33, 'hazelnut-hole': 34, 'hazelnut-print': 35, 'leather-color': 36, 'leather-cut': 37, 'leather-fold': 38, 'leather-glue': 39, 'leather-good': 40, 'leather-poke': 41, 'metal_nut-bent': 42, 'metal_nut-colo

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [8]:
def merge_mlp_label(tmp_anws, num_state, batch_size) :
    """Merge the per-head predictions of a batch into one label per sample.

    For every sample the winning head is the one whose top prediction is a
    real state (not the head's "None" index ``num_state[j]``) with the highest
    probability; ties keep the earlier head. If every head predicts "None",
    the real state with the highest probability across all heads is used
    instead of silently answering with the first label of the first head.

    Args:
        tmp_anws: One ``[preds, prob]`` pair per head, where ``preds`` holds the
            predicted index per sample and ``prob`` the per-sample class
            probabilities (tensors; both are read through ``.tolist()``).
        num_state: Number of real states per head; also the "None" index.
        batch_size: Number of samples to decode.

    Returns:
        list with one entry of the module-level ``each_label_dec`` per sample.
    """
    # Convert each head once; indexing tensors element by element is slow.
    heads = [(anws[0].tolist(), anws[1].tolist()) for anws in tmp_anws]

    answer = []
    for bn in range(batch_size) :
        score, index, mlp_state = 0, 0, 0
        found = False
        # Best real state seen so far among heads that voted "None".
        fb_score, fb_index, fb_state = float("-inf"), 0, 0

        # Iterate over every MLP head.
        for j, (preds, prob) in enumerate(heads) :
            none_index = num_state[j]
            top = preds[bn]
            row = prob[bn]

            if top != none_index :
                if score < row[top] :
                    score, index, mlp_state = row[top], top, j
                    found = True
            elif not found :
                for k in range(min(none_index, len(row))) :
                    if fb_score < row[k] :
                        fb_score, fb_index, fb_state = row[k], k, j

        if not found :
            index, mlp_state = fb_index, fb_state

        answer.append(each_label_dec[mlp_state][index])
    return answer
            

In [9]:
with torch.no_grad() :
    answers = []

    if opt.ensemble :
        # label_dec maps label -> index; decoding predictions needs the inverse.
        index_to_label = {index : label for label, index in label_dec.items()}

    for img in tqdm(test_loader) :
        img = img.to(opt.device)

        if not opt.ensemble :
            # One output tensor per MLP head.
            outputs = model(img)
            tmp_anws = []
            for output in outputs :
                # dim=1: normalize over the classes of each sample (explicit,
                # avoids the implicit-dimension deprecation warning).
                prob = F.softmax(output.cpu(), dim=1)
                preds = torch.argmax(prob, dim=1)
                # [predicted index per sample, probabilities per sample]
                tmp_anws.append([preds, prob])

            labels = merge_mlp_label(tmp_anws, opt.num_state, img.size(0))
            answers.extend(labels)

        else :
            # Average the class probabilities of all ensemble members.
            predicts = torch.zeros(img.size(0), opt.num_classes)
            for ensemble_member in models :
                predicts += F.softmax(ensemble_member(img).cpu(), dim=1)

            preds = torch.argmax(predicts / len(models), dim=1)
            answers.extend(index_to_label[index] for index in preds.tolist())

  prob = F.softmax(output.cpu())
100%|██████████████████████████████████████████████████████████████████████████████████| 34/34 [01:04<00:00,  1.90s/it]


In [10]:
# Fill the sample submission (indexed by its 'index' column) with the predicted labels.
submission = (
    pd.read_csv(opt.submission_df_path)
    .set_index('index')
    .assign(label=answers)
)
submission

Unnamed: 0_level_0,label
index,Unnamed: 1_level_1
0,tile-glue_strip
1,grid-good
2,transistor-good
3,tile-gray_stroke
4,tile-good
...,...
2149,tile-gray_stroke
2150,screw-good
2151,grid-good
2152,cable-good


In [11]:
# Create the destination folder first; to_csv does not create missing directories.
os.makedirs(os.path.dirname(opt.save_path) or ".", exist_ok=True)
submission.to_csv(opt.save_path)