In [2]:
import cv2
import torch
import os
from typing import Any
import torch
from torch import Tensor
from torch import nn
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import time
import copy
import tqdm
import torchvision.transforms as transforms
from torchmetrics.classification import MultilabelAccuracy, Accuracy
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchsummary import summary

import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#######################################데이터관련 로드, 검증, 클래스정의, 데이터로더 #####################
#######################################데이터관련 로드, 검증, 클래스정의, 데이터로더 #####################
#######################################데이터관련 로드, 검증, 클래스정의, 데이터로더 #####################
#image broken check
def check_jpeg_eoi(file_path):
    with open(file_path, 'rb') as f:
        f.seek(-2, 2) # 파일의 끝에서 두 바이트 전으로 이동합니다.
        return f.read() == b'\xff\xd9'
    

def is_image_valid(image_path):
    try:
        img = Image.open(image_path) # 이미지를 열어봅니다.
        img.verify() # verify() 메소드는 파일이 손상되었는지 확인합니다.
        return True
    except (IOError, SyntaxError) as e:
        print('Invalid image: ', image_path, '\n'+ e) # 손상된 이미지에 대한 에러 메시지를 출력합니다.
        return False

#image validation(exist and broken file)
def validate_dataset(df, img_dir):
    count = 0
    df_bar = tqdm.tqdm(df.itertuples(), desc="validating all images", total=len(df))
    for rows in df_bar:
        if os.path.isfile(img_dir+'/'+ rows.id):
            if is_image_valid(img_dir+'/'+ rows.id) and check_jpeg_eoi(img_dir+'/'+ rows.id):
                continue
            else:
                count += 1
                df.drop(df[df['id'] == rows.id].index, inplace=True)
        else:
            count += 1
            df.drop(df[df['id'] == rows.id].index, inplace=True)
        print("Not founded images (Num) : ",count)
    return df

#csv에서 데이터 가져옴
def get_data_from_csv(csv_path, train_ratio, img_dir, randoms_state=42):
    ###### columns example : ['id', 'good', 'b_edge', 'burr', 'borken', 'b_bubble', 'etc', 'no_lens']
    

    df = pd.read_csv(csv_path)
    df = validate_dataset(df=df,img_dir=img_dir)
    train_df , temp_df = train_test_split(df, test_size=1-train_ratio, random_state=randoms_state)
    val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=randoms_state)


    # 'good' 열을 데이터 프레임에서 제거 및 클래스 재정렬
    cls_list = ['no_lens', 'etc', 'burr', 'borken', 'b_edge', 'b_bubble']
    train_df = train_df.drop(columns=['good'])
    train_df = train_df[['id'] + cls_list]
    val_df = val_df.drop(columns=['good'])
    val_df = val_df[['id'] + cls_list]
    test_df = test_df.drop(columns=['good'])
    test_df = test_df[['id'] + cls_list]


    print('num of train_df',len(train_df))
    print('num of val_df',len(val_df))
    print('num of test_df',len(test_df))

    num_cls = len(train_df.columns) - 1  # because, it is multi-label

    print('number of class: ', num_cls)
    # cls_list = list(train_df.columns)
    # cls_list.remove('id')

    print(cls_list)
    
    return train_df, val_df, test_df, num_cls, cls_list

#데이터셋 클래스 정의
class CustomDataset(Dataset):

    def __init__(self, dataframe, image_dir, num_classes, class_list, transforms=None, img_resize = False, img_dsize = (640,640)):
        super().__init__()
        
        self.image_ids = dataframe['id'].unique() # 이미지 고유 ID
        self.df = dataframe
        self.image_dir = image_dir
        self.transforms = transforms
        self.img_resize = img_resize
        self.img_dsize = img_dsize
        self.class_list = class_list
        self.num_classes = num_classes

    #데이터 길이 검증
    def validate_data_records(self):
        for idx, image_id in enumerate(self.image_ids):
            records = self.df[self.df['id'] == image_id]
            target = np.array(records[self.class_list].values).astype(np.float32)
            if target.shape[1] != len(self.class_list):
                print(f"Index {idx} with image_id {image_id} has mismatched target size. Expected {len(self.class_list)}, but got {target.shape[1]}")


    def __getitem__(self, index: int):
        # 이미지 index로 아이템 불러오기

        image_id = self.image_ids[index]
        records = self.df[self.df['id'] == image_id]
        
        image = cv2.imread(f'{self.image_dir}/{image_id}', cv2.IMREAD_COLOR)
            
        # OpenCV가 컬러를 저장하는 방식인 BGR을 RGB로 변환
        if self.img_resize:
            image = cv2.resize(image, self.img_dsize)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0 # 0 ~ 1로 스케일링

        target = np.array(records[self.class_list].values).astype(np.float32)
        target = target.reshape(-1)
        if self.transforms is not None:
            image = self.transforms(image)

        return image, target

    def __len__(self) -> int:
        return self.image_ids.shape[0]

def collate_fn(batch):
    images, targets = zip(*batch)
    images = torch.stack(images, 0)
    
    # Find the maximum target length
    max_len = max([len(t) for t in targets])
    
    # Pad each target to the maximum length
    targets_padded = [torch.cat([torch.tensor(t), torch.zeros(max_len - len(t))]) for t in targets]
    
    targets = torch.stack(targets_padded, 0)
    return images, targets

In [3]:
#####################기타####################
def create_directory(save_path):
    i = 1
    while True:
        dir_name = os.path.join('models/'+save_path+ str(i) +'/')
        if not os.path.exists(dir_name):
            os.makedirs(dir_name)
            os.makedirs(dir_name+'/result')
            return dir_name
            break
        i += 1

In [4]:

####################################### 모델구조정의 ##########################################
####################################### 모델구조정의 ##########################################
####################################### 모델구조정의 ##########################################
class TH_InceptionV4(nn.Module):

    def __init__(self, k=192, l=224, m=256, n=384, num_classes=7):
        super(TH_InceptionV4, self).__init__()
        
        self.stem = InceptionV4Stem(3)
        
        self.inceptionA1 = InceptionA(384)
        self.inceptionA2 = InceptionA(384)
        self.inceptionA3 = InceptionA(384)
        self.inceptionA4 = InceptionA(384)
        self.no_etc_output_linear = nn.Linear(384, 2)
        
        self.reductionA = ReductionA(384, k, l, m, n)
        
        self.inceptionB1 = InceptionB(1024)
        self.inceptionB2 = InceptionB(1024)
        self.inceptionB3 = InceptionB(1024)
        self.inceptionB4 = InceptionB(1024)
        self.inceptionB5 = InceptionB(1024)
        self.inceptionB6 = InceptionB(1024)
        self.inceptionB7 = InceptionB(1024)

        self.burr_broken_output_linear = nn.Linear(1024, 2)
        
        self.reductionB = ReductionB(1024)
        
        self.inceptionC1 = InceptionC(1536)
        self.inceptionC2 = InceptionC(1536)
        self.inceptionC3 = InceptionC(1536)
        
        self.global_average_pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(1536, 1)

        # Initialize neural network weights
        self._initialize_weights()

    def forward(self, x):
        outputs = []
        x = self.stem(x)
        
        x1 = self.inceptionA1(x)
        x2 = self.inceptionA2(x1)
        x3 = self.inceptionA3(x2)
        x4 = self.inceptionA4(x3)

        #no_lens와 etc 클래스 output
        no_etc_output = self.global_average_pooling(x4)
        no_etc_output = torch.flatten(no_etc_output, 1)
        no_etc_output = self.no_etc_output_linear(no_etc_output)
        outputs.append(no_etc_output)

        x_redA = self.reductionA(x4)
        
        xB1 = self.inceptionB1(x_redA)
        xB2 = self.inceptionB2(xB1)
        xB3 = self.inceptionB3(xB2)
        xB4 = self.inceptionB4(xB3)
        xB5 = self.inceptionB5(xB4)
        xB6 = self.inceptionB6(xB5)
        xB7 = self.inceptionB7(xB6)

        burr_broken_output = self.global_average_pooling(xB7)
        burr_broken_output = torch.flatten(burr_broken_output, 1)
        burr_broken_output = self.burr_broken_output_linear(burr_broken_output)
        outputs.append(burr_broken_output)
        
        #b_edge분기 - broken에서 받음
        x_redB = self.reductionB(xB7)
        
        xC1 = self.inceptionC1(x_redB)
        xC2 = self.inceptionC2(xC1)
        xC3 = self.inceptionC3(xC2)

        b_edge_output = self.global_average_pooling(xC3)
        b_edge_output = torch.flatten(b_edge_output, 1)
        b_edge_output = self.linear(b_edge_output)
        outputs.append(b_edge_output)

        #b_bubble분기 - 완전히 따로 내려옴
        x_redA_2 = self.reductionA(x4)
        
        xB1_2 = self.inceptionB1(x_redA_2)
        xB2_2 = self.inceptionB2(xB1_2)
        xB3_2 = self.inceptionB3(xB2_2)
        xB4_2 = self.inceptionB4(xB3_2)
        xB5_2 = self.inceptionB5(xB4_2)
        xB6_2 = self.inceptionB6(xB5_2)
        xB7_2 = self.inceptionB7(xB6_2)

        x_redB_2 = self.reductionB(xB7_2)
        
        xC1_2 = self.inceptionC1(x_redB_2)
        xC2_2 = self.inceptionC2(xC1_2)
        xC3_2 = self.inceptionC3(xC2_2)
        
        b_bubble_output = self.global_average_pooling(xC3_2)
        b_bubble_output = torch.flatten(b_bubble_output, 1)
        b_bubble_output = self.linear(b_bubble_output)
        outputs.append(b_bubble_output)

        final_outputs = torch.cat(outputs, dim=1)
        return final_outputs

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                stddev = float(module.stddev) if hasattr(module, "stddev") else 0.1
                torch.nn.init.trunc_normal_(module.weight, mean=0.0, std=stddev, a=-2, b=2)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

class BasicConv2d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = nn.ReLU(True)

    def forward(self, x: Tensor) -> Tensor:
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)

        return out

class InceptionV4Stem(nn.Module):
    def __init__(
            self,
            in_channels: int,
    ) -> None:
        super(InceptionV4Stem, self).__init__()
        self.conv2d_1a_3x3 = BasicConv2d(in_channels, 32, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0))

        self.conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0))
        self.conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

        self.mixed_3a_branch_0 = nn.MaxPool2d((3, 3), (2, 2))
        self.mixed_3a_branch_1 = BasicConv2d(64, 96, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0))

        self.mixed_4a_branch_0 = nn.Sequential(
            BasicConv2d(160, 64, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            BasicConv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),
        )
        self.mixed_4a_branch_1 = nn.Sequential(
            BasicConv2d(160, 64, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            BasicConv2d(64, 64, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3)),
            BasicConv2d(64, 64, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0)),
            BasicConv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0))
        )

        self.mixed_5a_branch_0 = BasicConv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0))
        self.mixed_5a_branch_1 = nn.MaxPool2d((3, 3), (2, 2))

    def forward(self, x: Tensor) -> Tensor:
        out = self.conv2d_1a_3x3(x)
        out = self.conv2d_2a_3x3(out)
        out = self.conv2d_2b_3x3(out)

        mixed_3a_branch_0 = self.mixed_3a_branch_0(out)
        mixed_3a_branch_1 = self.mixed_3a_branch_1(out)
        mixed_3a_out = torch.cat([mixed_3a_branch_0, mixed_3a_branch_1], 1)

        mixed_4a_branch_0 = self.mixed_4a_branch_0(mixed_3a_out)
        mixed_4a_branch_1 = self.mixed_4a_branch_1(mixed_3a_out)
        mixed_4a_out = torch.cat([mixed_4a_branch_0, mixed_4a_branch_1], 1)

        mixed_5a_branch_0 = self.mixed_5a_branch_0(mixed_4a_out)
        mixed_5a_branch_1 = self.mixed_5a_branch_1(mixed_4a_out)
        mixed_5a_out = torch.cat([mixed_5a_branch_0, mixed_5a_branch_1], 1)

        return mixed_5a_out

class InceptionV4ResNetStem(nn.Module):
    def __init__(
            self,
            in_channels: int,
    ) -> None:
        super(InceptionV4ResNetStem, self).__init__()
        self.features = nn.Sequential(
            BasicConv2d(in_channels, 32, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
            BasicConv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),
            BasicConv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.MaxPool2d((3, 3), (2, 2)),
            BasicConv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            BasicConv2d(80, 192, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),
            nn.MaxPool2d((3, 3), (2, 2)),
        )
        self.branch_0 = BasicConv2d(192, 96, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        self.branch_1 = nn.Sequential(
            BasicConv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            BasicConv2d(48, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)),
        )
        self.branch_2 = nn.Sequential(
            BasicConv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            BasicConv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            BasicConv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
        )
        self.branch_3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
        )

    def forward(self, x):
        features = self.features(x)
        branch_0 = self.branch_0(features)
        branch_1 = self.branch_1(features)
        branch_2 = self.branch_2(features)
        branch_3 = self.branch_3(features)

        out = torch.cat([branch_0, branch_1, branch_2, branch_3], 1)

        return out

class InceptionA(nn.Module):
    def __init__(
            self,
            in_channels: int,
    ) -> None:
        super(InceptionA, self).__init__()
        self.branch_0 = BasicConv2d(in_channels, 96, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, 64, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            BasicConv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
        )
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, 64, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            BasicConv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            BasicConv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
        )
        self.brance_3 = nn.Sequential(
            nn.AvgPool2d((3, 3), (1, 1), (1, 1), count_include_pad=False),
            BasicConv2d(384, 96, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        )

    def forward(self, x: Tensor) -> Tensor:
        branch_0 = self.branch_0(x)
        branch_1 = self.branch_1(x)
        branch_2 = self.branch_2(x)
        brance_3 = self.brance_3(x)

        out = torch.cat([branch_0, branch_1, branch_2, brance_3], 1)

        return out

class ReductionA(nn.Module):
    def __init__(
            self,
            in_channels: int,
            k: int,
            l: int,
            m: int,
            n: int,
    ) -> None:
        super(ReductionA, self).__init__()
        self.branch_0 = BasicConv2d(in_channels, n, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0))
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, k, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            BasicConv2d(k, l, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            BasicConv2d(l, m, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
        )
        self.branch_2 = nn.MaxPool2d((3, 3), (2, 2))

    def forward(self, x: Tensor) -> Tensor:
        branch_0 = self.branch_0(x)
        branch_1 = self.branch_1(x)
        branch_2 = self.branch_2(x)

        out = torch.cat([branch_0, branch_1, branch_2], 1)

        return out

class InceptionB(nn.Module):
    def __init__(
            self,
            in_channels: int,
    ) -> None:
        super(InceptionB, self).__init__()
        self.branch_0 = BasicConv2d(in_channels, 384, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, 192, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            BasicConv2d(192, 224, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3)),
            BasicConv2d(224, 256, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0)),
        )
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, 192, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            BasicConv2d(192, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0)),
            BasicConv2d(192, 224, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3)),
            BasicConv2d(224, 224, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0)),
            BasicConv2d(224, 256, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3)),
        )
        self.branch_3 = nn.Sequential(
            nn.AvgPool2d((3, 3), (1, 1), (1, 1), count_include_pad=False),
            BasicConv2d(in_channels, 128, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
        )

    def forward(self, x: Tensor) -> Tensor:
        branch_0 = self.branch_0(x)
        branch_1 = self.branch_1(x)
        branch_2 = self.branch_2(x)
        branch_3 = self.branch_3(x)

        out = torch.cat([branch_0, branch_1, branch_2, branch_3], 1)

        return out

class ReductionB(nn.Module):
    def __init__(
            self,
            in_channels: int,
    ) -> None:
        super(ReductionB, self).__init__()
        self.branch_0 = nn.Sequential(
            BasicConv2d(in_channels, 192, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            BasicConv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
        )
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, 256, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            BasicConv2d(256, 256, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3)),
            BasicConv2d(256, 320, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0)),
            BasicConv2d(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
        )
        self.branch_2 = nn.MaxPool2d((3, 3), (2, 2))

    def forward(self, x: Tensor) -> Tensor:
        branch_0 = self.branch_0(x)
        branch_1 = self.branch_1(x)
        branch_2 = self.branch_2(x)

        out = torch.cat([branch_0, branch_1, branch_2], 1)

        return out

class InceptionC(nn.Module):
    def __init__(
            self,
            in_channels: int,
    ) -> None:
        super(InceptionC, self).__init__()
        self.branch_0 = BasicConv2d(in_channels, 256, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

        self.branch_1 = BasicConv2d(in_channels, 384, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        self.branch_1_1 = BasicConv2d(384, 256, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
        self.branch_1_2 = BasicConv2d(384, 256, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))

        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, 384, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            BasicConv2d(384, 448, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0)),
            BasicConv2d(448, 512, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
        )
        self.branch_2_1 = BasicConv2d(512, 256, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
        self.branch_2_2 = BasicConv2d(512, 256, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))

        self.branch_3 = nn.Sequential(
            nn.AvgPool2d((3, 3), (1, 1), (1, 1)),
            BasicConv2d(in_channels, 256, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        )

    def forward(self, x: Tensor) -> Tensor:
        branch_0 = self.branch_0(x)
        branch_1 = self.branch_1(x)

        branch_1_1 = self.branch_1_1(branch_1)
        branch_1_2 = self.branch_1_2(branch_1)
        x1 = torch.cat([branch_1_1, branch_1_2], 1)

        branch_2 = self.branch_2(x)
        branch_2_1 = self.branch_2_1(branch_2)
        branch_2_2 = self.branch_2_2(branch_2)
        x2 = torch.cat([branch_2_1, branch_2_2], 1)

        x3 = self.branch_3(x)

        out = torch.cat([branch_0, x1, x2, x3], 1)

        return out


In [5]:
########################################## 학습 매커니즘 설정 #####################################
########################################## 학습 매커니즘 설정 #####################################
########################################## 학습 매커니즘 설정 #####################################

# get current lr
def get_lr(opt):
    for param_group in opt.param_groups:
        return param_group['lr']


# function to start training
def train_val(model, device, params):
    num_epochs=params['num_epochs']
    loss_func=params['loss_func']
    opt=params['optimizer']
    train_dl=params['train_dl']
    val_dl=params['val_dl']
    sanity_check=params['sanity_check']
    lr_scheduler=params['lr_scheduler']
    path2weights=params['path2weights']

    loss_history = {'train': [], 'val': []}
    metric_history = {'train': [], 'val': []}
    metric_cls_history = {'train': [], 'val': []}

    best_loss = float('inf')
    best_model_wts = copy.deepcopy(model.state_dict())
    

    for epoch in range(num_epochs):
        start_time = time.time()
        current_lr = get_lr(opt)

        print(f"Epoch {epoch}/{num_epochs-1}")

        model.train()
        train_loss, train_metric,train_cls_metric = loss_epoch_multi_output(model, device, loss_func, train_dl, sanity_check, opt)
        
        loss_history['train'].append(train_loss)
        metric_history['train'].append(train_metric.item())
        
        metric_cls_history['train'].append(train_cls_metric)

        model.eval()
        with torch.no_grad():
            val_loss, val_metric,val_cls_metric = loss_epoch_multi_output(model, device, loss_func, val_dl, sanity_check)
        
        loss_history['val'].append(val_loss)
        metric_history['val'].append(val_metric.item())
        
        metric_cls_history['val'].append(val_cls_metric)

        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())

        if isinstance(model, torch.nn.DataParallel):
        # model.module is the original model before DataParallel
            torch.save(model.module.state_dict(), path2weights + f'{epoch}_weight.pt')
        else:
            torch.save(model.state_dict(), path2weights + f'{epoch}_weight.pt')

        # torch.save(model.module.state_dict(), path2weights + f'{epoch}_weight.pt')

        lr_scheduler.step(val_loss)
        if current_lr != get_lr(opt):
            print('Loading best model weights!')
            model.load_state_dict(best_model_wts)
        
        print(f'train loss: {train_loss:.6f}, val loss: {val_loss:.6f}, accuracy: {val_metric:.2f},cls_acc : {val_cls_metric}, time: {(time.time()-start_time)/60:.4f} min')
        lossdf = pd.DataFrame(loss_history)
        accdf = pd.DataFrame(metric_history)
        acc_clsdf = pd.DataFrame(metric_cls_history)

        lossdf.to_csv(path2weights + 'result/loss.csv')
        accdf.to_csv(path2weights + 'result/acc.csv')
        acc_clsdf.to_csv(path2weights + 'result/cls_acc.csv')
        

    # model.load_state_dict(best_model_wts)
    return model, loss_history, metric_history, metric_cls_history

def metric_batch_multi_output(output, target, device):
    # output: [batch_size, num_classes], target: [batch_size, num_classes]
    
    pred = output.sigmoid() >= 0.5
    
    num_classes = target.shape[1]
    mla_ova = MultilabelAccuracy(num_labels=num_classes).to(device=device)
    mla = MultilabelAccuracy(num_labels=num_classes, average=None).to(device=device)
    
    class_accuracies = mla(pred, target)
    overall_accuracy = mla_ova(pred, target)
    
    return class_accuracies, overall_accuracy


def loss_batch_multi_output(loss_func, output, target, device, opt=None):
    # output: [batch_size, num_classes], target: [batch_size, num_classes]
    loss_b = loss_func(output, target)
    class_metric_b , metric_b = metric_batch_multi_output(output, target, device)

    if opt is not None:
        opt.zero_grad()
        loss_b.backward()
        opt.step()

    return loss_b.item(), metric_b, class_metric_b

def loss_epoch_multi_output(model, device, loss_func, dataset_dl, sanity_check=False, opt=None):
    running_loss = 0.0
    running_metric = 0.0
    running_class_metrics = torch.zeros(dataset_dl.dataset.num_classes).to(device)
    len_data = len(dataset_dl.dataset)
    num_classes = dataset_dl.dataset.num_classes
    b_count = 0
    with tqdm.tqdm(dataset_dl, unit="batch") as tepoch:
        for xb, yb in tepoch:
            b_count+=1
            xb = xb.to(device)
            yb = yb.to(device)
            output = model(xb)

            loss_b, metric_b, class_metric_b = loss_batch_multi_output(loss_func, output, yb, device, opt)

            running_loss += loss_b

            if metric_b is not None:
                running_metric += metric_b
            
            if class_metric_b is not None:
                running_class_metrics += class_metric_b

            if sanity_check is True:
                break

    loss = running_loss / b_count
    metric = running_metric / b_count # 수정된 부분
    class_metrics = {f'class_{i+1}': (running_class_metrics[i] / b_count).item() for i in range(num_classes)}
    return loss, metric, class_metrics

# check the directory to save weights.pt
def createFolder(directory):
    try:
        if not os.path.exists(directory):
            os.makedirs(directory)
    except os.OSerror:
        print('Error')
createFolder('./models')

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.reset_max_memory_allocated(device=None)
torch.cuda.empty_cache()
print(device)

cuda


In [7]:
############################# 여기는 전체 데이터셋에서 샘플만 추출하는 과정임 #################################

# data_df = pd.read_csv('dataset.csv')
# data_df.sum()

# categories = ['good', 'b_edge', 'burr', 'borken', 'b_bubble', 'etc', 'no_lens']

# resampled_dfs = []
# used_indices = set()  # 이미 사용된 인덱스를 추적합니다.

# for category in categories:
#     # 이미 선택된 샘플을 제외한 데이터프레임을 생성합니다.
#     available_data = data_df.drop(index=used_indices)
    
#     # 각 카테고리별로 데이터프레임을 필터링합니다.
#     category_df = available_data[available_data[category] == 1] # 카테고리별로 적절한 필터링 조건을 적용해야 합니다.

#     # 해당 카테고리에서 사용 가능한 샘플 수가 900개를 초과하는지 확인합니다.
#     if len(category_df) > 900:
#         category_df = category_df.sample(n=400, random_state=42) # 무작위 샘플 선택
#         used_indices.update(category_df.index)  # 선택된 인덱스를 사용된 인덱스 집합에 추가합니다.
#     else:
#         used_indices.update(category_df.index)  # 남은 모든 샘플 사용
        
#     resampled_dfs.append(category_df)

# # 모든 카테고리의 데이터프레임을 하나로 병합합니다.
# balanced_df = pd.concat(resampled_dfs, ignore_index=True)

# # 결과를 확인합니다.
# print(balanced_df.sum())

# balanced_df.to_csv('TH_dataset.csv', index=False)

In [7]:
balanced_df = pd.read_csv('TH_dataset.csv')
print('전체 이미지 수 : ',len(balanced_df))
print('####################### 라벨링 벨런스 ######################')
print(balanced_df.sum())


전체 이미지 수 :  2800
####################### 라벨링 벨런스 ######################
id          None_20230404100000_3_R_1.jpgNone_202304131430...
good                                                      400
b_edge                                                    483
burr                                                      411
borken                                                    431
b_bubble                                                  427
etc                                                       423
no_lens                                                   400
dtype: object


: 

In [8]:
model = TH_InceptionV4()

################## gpu사용처리 ######################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.reset_max_memory_allocated(device=None)
torch.cuda.empty_cache()
num_device = torch.cuda.device_count()
print(device)
device_idx = []
for i in range(num_device):
    if torch.cuda.get_device_name(i) == "NVIDIA DGX Display":
        print(f"Device is not using : {torch.cuda.get_device_name(i)}")
    else:
        device_idx.append(i)

if torch.cuda.device_count() > 1:
    print("Let's use",num_device, "GPUs!")
    if torch.cuda.device_count() > 4: #for GCT
        model=model.to('cuda:0')
        model = nn.DataParallel(model, device_ids=device_idx)
    else:
        model = model.to(device=device)
        model = nn.DataParallel(model)
else:
    model = model.to(device=device)

cuda


In [9]:
csv_path = 'TH_dataset.csv'
img_dir = './data/images2/'
train_ratio = 0.6
IMG_SIZE = 640
BATCH_SIZE = 6
EPOCH = 200
train_name = create_directory('TH_gnet4_')
loss_func = nn.MultiLabelSoftMarginLoss()
opt = optim.Adam(model.parameters(), lr=0.001)
lr_scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.1, patience=5)

In [10]:
train_df, val_df, test_df, NUM_CLS, cls_list = get_data_from_csv(csv_path=csv_path,img_dir=img_dir, train_ratio=train_ratio, randoms_state=42)

validating all images: 100%|██████████| 2800/2800 [00:00<00:00, 3014.00it/s]

num of train_df 1680
num of val_df 560
num of test_df 560
number of class:  6
['no_lens', 'etc', 'burr', 'borken', 'b_edge', 'b_bubble']





In [11]:
transformation = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Resize(IMG_SIZE)
])

train_set = CustomDataset(train_df,num_classes=NUM_CLS, image_dir=img_dir, class_list= cls_list ,img_resize=True, img_dsize=(IMG_SIZE,IMG_SIZE))
train_set.transforms = transformation

val_set = CustomDataset(val_df,num_classes=NUM_CLS, image_dir=img_dir, class_list= cls_list, img_resize=True, img_dsize=(IMG_SIZE,IMG_SIZE))
val_set.transforms = transformation

test_set = CustomDataset(test_df,num_classes=NUM_CLS, image_dir=img_dir, class_list= cls_list, img_resize=True, img_dsize=(IMG_SIZE,IMG_SIZE))
test_set.transforms = transformation

In [12]:
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, collate_fn=collate_fn)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, collate_fn=collate_fn)

In [13]:
params_train = {
    'num_epochs':EPOCH,
    'optimizer':opt,
    'loss_func':loss_func,
    'train_dl':train_loader,
    'val_dl':val_loader,
    'sanity_check':False,
    'lr_scheduler':lr_scheduler,
    'path2weights':train_name,
}

In [14]:
summary(model, (3, IMG_SIZE, IMG_SIZE), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 319, 319]             864
       BatchNorm2d-2         [-1, 32, 319, 319]              64
              ReLU-3         [-1, 32, 319, 319]               0
       BasicConv2d-4         [-1, 32, 319, 319]               0
            Conv2d-5         [-1, 32, 317, 317]           9,216
       BatchNorm2d-6         [-1, 32, 317, 317]              64
              ReLU-7         [-1, 32, 317, 317]               0
       BasicConv2d-8         [-1, 32, 317, 317]               0
            Conv2d-9         [-1, 64, 317, 317]          18,432
      BatchNorm2d-10         [-1, 64, 317, 317]             128
             ReLU-11         [-1, 64, 317, 317]               0
      BasicConv2d-12         [-1, 64, 317, 317]               0
        MaxPool2d-13         [-1, 64, 158, 158]               0
           Conv2d-14         [-1, 96, 1

In [15]:
traind_model, loss_hist, metric_hist, metric_cls_hist = train_val(model, device, params_train)

Epoch 0/199


100%|██████████| 280/280 [35:57<00:00,  7.71s/batch]
100%|██████████| 94/94 [01:01<00:00,  1.53batch/s]


train loss: 0.349484, val loss: 0.277743, accuracy: 0.90,cls_acc : {'class_1': 0.9911347031593323, 'class_2': 0.9361703395843506, 'class_3': 0.9503546953201294, 'class_4': 0.8439715504646301, 'class_5': 0.8351064920425415, 'class_6': 0.8421986103057861}, time: 37.0049 min
Epoch 1/199


100%|██████████| 280/280 [32:15<00:00,  6.91s/batch]
100%|██████████| 94/94 [02:09<00:00,  1.37s/batch]


train loss: 0.262005, val loss: 0.215143, accuracy: 0.91,cls_acc : {'class_1': 0.9982268810272217, 'class_2': 0.9450355172157288, 'class_3': 0.964539110660553, 'class_4': 0.8918437361717224, 'class_5': 0.838652491569519, 'class_6': 0.8386525511741638}, time: 34.4250 min
Epoch 2/199


100%|██████████| 280/280 [29:13<00:00,  6.26s/batch]
100%|██████████| 94/94 [01:01<00:00,  1.54batch/s]


train loss: 0.237439, val loss: 0.198912, accuracy: 0.92,cls_acc : {'class_1': 0.9982268810272217, 'class_2': 0.9503547549247742, 'class_3': 0.9716312885284424, 'class_4': 0.9042550325393677, 'class_5': 0.8439717292785645, 'class_6': 0.8297871351242065}, time: 30.2575 min
Epoch 3/199


100%|██████████| 280/280 [32:13<00:00,  6.91s/batch]
100%|██████████| 94/94 [02:09<00:00,  1.37s/batch]


train loss: 0.217102, val loss: 0.200654, accuracy: 0.91,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9468086361885071, 'class_3': 0.9698582887649536, 'class_4': 0.8936168551445007, 'class_5': 0.8351064920425415, 'class_6': 0.8421986103057861}, time: 34.3896 min
Epoch 4/199


100%|██████████| 280/280 [32:13<00:00,  6.90s/batch]
100%|██████████| 94/94 [02:08<00:00,  1.37s/batch]


train loss: 0.206239, val loss: 0.197927, accuracy: 0.92,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9539008140563965, 'class_3': 0.9751773476600647, 'class_4': 0.9007090926170349, 'class_5': 0.8528369069099426, 'class_6': 0.8244681358337402}, time: 34.3769 min
Epoch 5/199


100%|██████████| 280/280 [28:47<00:00,  6.17s/batch]
100%|██████████| 94/94 [01:00<00:00,  1.55batch/s]


train loss: 0.203081, val loss: 0.197723, accuracy: 0.91,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9503546953201294, 'class_3': 0.9680852293968201, 'class_4': 0.8829786777496338, 'class_5': 0.8546099066734314, 'class_6': 0.8120567202568054}, time: 29.8197 min
Epoch 6/199


100%|██████████| 280/280 [32:12<00:00,  6.90s/batch]
100%|██████████| 94/94 [02:09<00:00,  1.37s/batch]


train loss: 0.190163, val loss: 0.200250, accuracy: 0.92,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9485816359519958, 'class_3': 0.9787234663963318, 'class_4': 0.923758864402771, 'class_5': 0.8351064920425415, 'class_6': 0.8421986103057861}, time: 34.3646 min
Epoch 7/199


100%|██████████| 280/280 [32:13<00:00,  6.91s/batch]
100%|██████████| 94/94 [02:09<00:00,  1.38s/batch]


train loss: 0.191329, val loss: 0.181754, accuracy: 0.92,cls_acc : {'class_1': 0.9982269406318665, 'class_2': 0.9468086361885071, 'class_3': 0.9751774072647095, 'class_4': 0.9131205677986145, 'class_5': 0.8351064920425415, 'class_6': 0.8421986103057861}, time: 34.3919 min
Epoch 8/199


100%|██████████| 280/280 [28:47<00:00,  6.17s/batch]
100%|██████████| 94/94 [01:00<00:00,  1.54batch/s]


train loss: 0.181093, val loss: 0.231046, accuracy: 0.91,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9414895176887512, 'class_3': 0.9734042882919312, 'class_4': 0.8705673813819885, 'class_5': 0.8351064920425415, 'class_6': 0.8421986103057861}, time: 29.8117 min
Epoch 9/199


100%|██████████| 280/280 [28:47<00:00,  6.17s/batch]
100%|██████████| 94/94 [01:01<00:00,  1.54batch/s]


train loss: 0.178381, val loss: 0.192037, accuracy: 0.92,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9521278738975525, 'class_3': 0.9716312885284424, 'class_4': 0.9343969821929932, 'class_5': 0.8209220170974731, 'class_6': 0.8528369069099426}, time: 29.8113 min
Epoch 10/199


100%|██████████| 280/280 [28:47<00:00,  6.17s/batch]
100%|██████████| 94/94 [01:00<00:00,  1.54batch/s]


train loss: 0.171551, val loss: 0.301250, accuracy: 0.89,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.804964542388916, 'class_3': 0.9627659916877747, 'class_4': 0.9042552709579468, 'class_5': 0.8368796110153198, 'class_6': 0.8404256701469421}, time: 29.8198 min
Epoch 11/199


100%|██████████| 280/280 [28:46<00:00,  6.17s/batch]
100%|██████████| 94/94 [01:01<00:00,  1.53batch/s]


train loss: 0.171522, val loss: 0.243418, accuracy: 0.91,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9095746278762817, 'class_3': 0.9574468731880188, 'class_4': 0.9078013300895691, 'class_5': 0.8351064920425415, 'class_6': 0.8421986103057861}, time: 29.8130 min
Epoch 12/199


100%|██████████| 280/280 [28:47<00:00,  6.17s/batch]
100%|██████████| 94/94 [01:00<00:00,  1.54batch/s]


train loss: 0.166558, val loss: 0.234783, accuracy: 0.91,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.8936170935630798, 'class_3': 0.9698581695556641, 'class_4': 0.9131203293800354, 'class_5': 0.8351064920425415, 'class_6': 0.8421986103057861}, time: 29.8109 min
Epoch 13/199


100%|██████████| 280/280 [28:47<00:00,  6.17s/batch]
100%|██████████| 94/94 [01:01<00:00,  1.54batch/s]


Loading best model weights!
train loss: 0.167054, val loss: 0.220676, accuracy: 0.92,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9539007544517517, 'class_3': 0.966312050819397, 'class_4': 0.9237588047981262, 'class_5': 0.831560492515564, 'class_6': 0.8351064920425415}, time: 29.8169 min
Epoch 14/199


100%|██████████| 280/280 [28:47<00:00,  6.17s/batch]
100%|██████████| 94/94 [01:01<00:00,  1.53batch/s]


train loss: 0.153743, val loss: 0.155126, accuracy: 0.93,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9592200517654419, 'class_3': 0.98758864402771, 'class_4': 0.9326242208480835, 'class_5': 0.8546099066734314, 'class_6': 0.8191490769386292}, time: 29.8208 min
Epoch 15/199


100%|██████████| 280/280 [32:11<00:00,  6.90s/batch]
100%|██████████| 94/94 [02:09<00:00,  1.37s/batch]


train loss: 0.134922, val loss: 0.152303, accuracy: 0.93,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9627661108970642, 'class_3': 0.98758864402771, 'class_4': 0.9361702799797058, 'class_5': 0.8439717292785645, 'class_6': 0.8297873735427856}, time: 34.3509 min
Epoch 16/199


100%|██████████| 280/280 [30:40<00:00,  6.57s/batch]
100%|██████████| 94/94 [01:01<00:00,  1.53batch/s]


train loss: 0.121921, val loss: 0.158071, accuracy: 0.93,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9609931707382202, 'class_3': 0.98758864402771, 'class_4': 0.9379432797431946, 'class_5': 0.8421986699104309, 'class_6': 0.831560492515564}, time: 31.7077 min
Epoch 17/199


100%|██████████| 280/280 [30:41<00:00,  6.58s/batch]
100%|██████████| 94/94 [01:01<00:00,  1.54batch/s]


train loss: 0.111626, val loss: 0.159317, accuracy: 0.93,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9680852293968201, 'class_3': 0.9840425848960876, 'class_4': 0.9397163987159729, 'class_5': 0.8404256105422974, 'class_6': 0.836879551410675}, time: 31.7090 min
Epoch 18/199


100%|██████████| 280/280 [30:40<00:00,  6.57s/batch]
100%|██████████| 94/94 [01:01<00:00,  1.54batch/s]


train loss: 0.106341, val loss: 0.165417, accuracy: 0.93,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9609930515289307, 'class_3': 0.9840425848960876, 'class_4': 0.9361701011657715, 'class_5': 0.8404256105422974, 'class_6': 0.836879551410675}, time: 31.6998 min
Epoch 19/199


100%|██████████| 280/280 [30:41<00:00,  6.58s/batch]
100%|██████████| 94/94 [01:01<00:00,  1.54batch/s]


train loss: 0.102404, val loss: 0.176862, accuracy: 0.93,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9698582887649536, 'class_3': 0.9804965853691101, 'class_4': 0.9414891600608826, 'class_5': 0.8510637879371643, 'class_6': 0.8191490173339844}, time: 31.7092 min
Epoch 20/199


100%|██████████| 280/280 [30:41<00:00,  6.58s/batch]
100%|██████████| 94/94 [01:00<00:00,  1.54batch/s]


train loss: 0.097780, val loss: 0.184322, accuracy: 0.93,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9592200517654419, 'class_3': 0.9822695851325989, 'class_4': 0.9432623386383057, 'class_5': 0.8351064324378967, 'class_6': 0.8386525511741638}, time: 31.7181 min
Epoch 21/199


100%|██████████| 280/280 [30:42<00:00,  6.58s/batch]
100%|██████████| 94/94 [01:01<00:00,  1.53batch/s]


Loading best model weights!
train loss: 0.093715, val loss: 0.176511, accuracy: 0.93,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9627661108970642, 'class_3': 0.9822695851325989, 'class_4': 0.9414893388748169, 'class_5': 0.8421985507011414, 'class_6': 0.8280143737792969}, time: 31.7393 min
Epoch 22/199


100%|██████████| 280/280 [30:43<00:00,  6.58s/batch]
100%|██████████| 94/94 [01:01<00:00,  1.54batch/s]


train loss: 0.119901, val loss: 0.150029, accuracy: 0.93,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9592199921607971, 'class_3': 0.98758864402771, 'class_4': 0.9414893388748169, 'class_5': 0.8297869563102722, 'class_6': 0.836879312992096}, time: 31.7425 min
Epoch 23/199


100%|██████████| 280/280 [32:13<00:00,  6.90s/batch]
100%|██████████| 94/94 [02:09<00:00,  1.38s/batch]


train loss: 0.115937, val loss: 0.150345, accuracy: 0.93,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9556737542152405, 'class_3': 0.98758864402771, 'class_4': 0.9432623386383057, 'class_5': 0.8333331942558289, 'class_6': 0.8333333134651184}, time: 34.3811 min
Epoch 24/199


100%|██████████| 280/280 [32:13<00:00,  6.91s/batch]
100%|██████████| 94/94 [02:09<00:00,  1.38s/batch]


train loss: 0.113423, val loss: 0.150665, accuracy: 0.93,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9592198729515076, 'class_3': 0.98758864402771, 'class_4': 0.9450353980064392, 'class_5': 0.8333331942558289, 'class_6': 0.8333332538604736}, time: 34.3877 min
Epoch 25/199


100%|██████████| 280/280 [32:14<00:00,  6.91s/batch]
100%|██████████| 94/94 [02:09<00:00,  1.38s/batch]


train loss: 0.111042, val loss: 0.151087, accuracy: 0.93,cls_acc : {'class_1': 0.9999999403953552, 'class_2': 0.9592198729515076, 'class_3': 0.98758864402771, 'class_4': 0.9450353980064392, 'class_5': 0.8351061940193176, 'class_6': 0.8315602540969849}, time: 34.4038 min
Epoch 26/199


 45%|████▌     | 127/280 [14:40<17:40,  6.93s/batch]


KeyboardInterrupt: 