# 라이브러리

In [1]:
#*---------- basic ------------*
import os, cv2, random, time
from glob import glob
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import math
import copy                        # 가중치 복사

#*--------- torch ------------*
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

#*--------- torch vision -------*
import torchvision
import torchsummary as summary
import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

#*--------- Learning_rate Scheduler ---------*
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, _LRScheduler

#*-------- warnings ------------*
import warnings
warnings.filterwarnings(action='ignore')

# Custom 라이브러리 

In [2]:
#*------------- for Unet ++ -----------*
from collections import OrderedDict
from torch.optim import lr_scheduler
from model import Unet_block, UNet, Nested_UNet
from init_weights import init_weights
from utils import BCEDiceLoss, AverageMeter, count_params, iou_score

# GPU 설정, 모델 관련 configuration 설정, SEED 고정

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Experiment configuration shared by the dataset, loader and training cells.
CFG = dict(
    IMG_SIZE=224,      # images/masks are resized to IMG_SIZE x IMG_SIZE
    BATCH_SIZE=8,
    SEED=1010,
    NUM_EPOCHS=100,
    NUM_CLASS=4,       # background, No info, Fault, Hole
)

#*---------- Set up random seed ---------*
def set_seed(seed):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), and puts cuDNN into deterministic mode.

    Args:
        seed: integer seed.

    Note:
        ``PYTHONHASHSEED`` only affects interpreters started after it is set
        (e.g. child processes); it cannot change hash randomization of the
        interpreter that is already running.
    """
    # `random` is imported by this notebook but was never seeded.
    random.seed(seed)
    np.random.seed(seed)
    # Was misspelled 'PYTHONSHSEED', a variable Python ignores.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for reproducible kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
set_seed(seed=CFG['SEED'])  # fix all random seeds for reproducibility

# get RGB statistics for normalization for Custom Dataset

In [None]:
def get_RGB_statis(imgs_paths):
    """Return per-channel RGB mean and std of a set of images, on a [0, 1] scale.

    Each image is read with OpenCV, converted from BGR to RGB and divided by
    255. For every image the mean and standard deviation of each colour
    channel are taken over all of its pixels; the returned values are the
    averages of those per-image statistics over the whole set.

    Args:
        imgs_paths: iterable of image file paths.

    Returns:
        ``([R_mean, G_mean, B_mean], [R_std, G_std, B_std])`` as lists of floats.

    Raises:
        FileNotFoundError: if an image cannot be read.
        ValueError: if ``imgs_paths`` is empty.
    """
    means, stds = [], []
    for img_path in tqdm(imgs_paths):
        img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        if img is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f'Could not read image: {img_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.
        # Flatten to (pixels, 3) so statistics are taken per colour channel.
        # The previous version indexed image[0], which is the first pixel ROW
        # of an (H, W, 3) image, not the R channel.
        pixels = img.reshape(-1, 3)
        means.append(pixels.mean(axis=0))
        stds.append(pixels.std(axis=0))
        # Images are reduced one at a time: they need not share a size and the
        # whole dataset is never held in memory at once.

    if not means:
        raise ValueError('imgs_paths is empty; cannot compute RGB statistics')

    mean = np.mean(means, axis=0)
    std = np.mean(stds, axis=0)
    return [float(m) for m in mean], [float(s) for s in std]

In [None]:
# Project root; the relative dataset paths used below ('./true_v3/...',
# './false_v7/...') are resolved against it after the chdir.
WorkSpace_path = r'C:\Users\SDML\1_Continue\4. SmartFactoryCapstoneDesign\Code\unet++\[Final]'
os.chdir(WorkSpace_path)
os.getcwd()  # echo the new working directory in the notebook output

In [None]:
# Image folders of the normal ("true") and abnormal ("false") sets.
true_path = './true_v3/images'
false_path = './false_v7/images'

# Gather every PNG of both sets into one array of paths.
true_paths = np.array(glob(os.path.join(true_path, '*.PNG')))
false_paths = np.array(glob(os.path.join(false_path, '*.PNG')))
imgs_paths = np.concatenate((true_paths, false_paths), axis=0)

print(true_paths.shape, false_paths.shape)
print(imgs_paths.shape)

# Per-channel statistics consumed by the Normalize transforms.
RGB_mean, RGB_std = get_RGB_statis(imgs_paths)
RGB_mean, RGB_std

# set Transform for Custom Dataset

In [None]:
# Rigid (geometric) transformation. CustomDataset.__getitem__ calls it with the
# image plus the four masks as extra keyword targets (mask_background, ...).
# NOTE(review): this Compose registers no `additional_targets`, so albumentations
# presumably does not treat those keywords as masks; only the image is resized
# here (the masks are already resized in CustomDataset.get_mask). If the
# flips/rotations below are re-enabled, register the masks as mask targets or
# the labels will no longer line up with the image.
train_image_Rigid_transform = A.Compose([
                            A.Resize(CFG['IMG_SIZE'],CFG['IMG_SIZE']),
#                           A.OneOf([
#                             A.HorizontalFlip(p=1.0),
#                             A.VerticalFlip(p=1.0),             
#                           ], p=1.0),
#                           A.OneOf([
#                               A.RandomRotate90(p=1),
#                              A.Rotate(limit=60, p=1)], p=1.0)
])

# Intensity transformation (image only). With probability 0.5 the whole block
# runs: Gaussian noise (p=0.5), then one of the photometric ops in the OneOf.
# NOTE(review): RandomContrast and JpegCompression are deprecated names in newer
# albumentations releases (RandomBrightnessContrast / ImageCompression) -
# confirm they exist in the installed version.
train_image_intensity_transform = A.Compose([
                            A.GaussNoise(p=0.5),
                            A.OneOf([
                                A.RandomBrightnessContrast(p=1.0),
                                A.RandomContrast(p=1.0),
                                A.RandomGamma(p=1.0),
                                A.JpegCompression(p=1.0),
                                A.CLAHE(p=1.0)
                            ], p=1.0)                                
                            ], p = 0.5)
# Earlier 3-channel (RGB) normalization, kept for reference:
#train_image_Normalization_trainsform = A.Compose([
#                            A.Normalize(mean= (RGB_mean[0], RGB_mean[1], RGB_mean[2]),
#                                        std= (RGB_std[0], RGB_std[1], RGB_std[2]), max_pixel_value=255.0, always_apply=False, p=1.0),
#                            ToTensorV2()
#])

# Normalization (image only), then conversion to a CHW torch tensor.
# The statistics are on a [0, 1] scale, which matches max_pixel_value=255.0.
# NOTE(review): CustomDataset reads images with cv2.IMREAD_GRAYSCALE, but this
# normalizes them with RGB_mean[0] / RGB_std[0] (the first entry returned by
# get_RGB_statis) rather than grayscale statistics - confirm this is intended.
train_image_Normalization_trainsform = A.Compose([
                            A.Normalize(mean= (RGB_mean[0]),
                                        std= (RGB_std[0]), max_pixel_value=255.0, always_apply=False, p=1.0),
                            ToTensorV2()])

# Validation/test: resize + the same normalization as training, no augmentation.
# Earlier 3-channel (RGB) version, kept for reference:
#test_transform = A.Compose([
#                            A.Resize(CFG['IMG_SIZE'],CFG['IMG_SIZE']),
#                            A.Normalize(mean= (RGB_mean[0], RGB_mean[1], RGB_mean[2]),
#                                        std= (RGB_std[0], RGB_std[1], RGB_std[2]), max_pixel_value=255.0, always_apply=False, p=1.0),
#                            ToTensorV2()
#                            ])
test_transform = A.Compose([
                            A.Resize(CFG['IMG_SIZE'],CFG['IMG_SIZE']),
                            A.Normalize(mean= (RGB_mean[0]),
                                        std= (RGB_std[0]), max_pixel_value=255.0, always_apply=False, p=1.0),
                            ToTensorV2()])

# Custom Dataset

In [None]:
import natsort
#true -> images, mask_background, mask_Noinfo, mask_fault, mask_hole
#false -> images, mask_background, mask_Noinfo, mask_fault, mask_hole
#Split ratio -> 0.1 of each set (true / false) goes to validation, respectively
#have to concat

class CustomDataset(Dataset):
    """Segmentation dataset built from a normal ("true") and an abnormal ("false") set.

    Expected layout of each set folder (files matching ``*.PNG``)::

        <set>/images/*.PNG
        <set>/masks/mask_background/*.PNG
        <set>/masks/mask_NoInfo/*.PNG
        <set>/masks/mask_fault/*.PNG
        <set>/masks/mask_hole/*.PNG

    Images and masks are paired by natural-sort position, so the k-th file of
    every folder must belong to the same sample.

    Both sets are shuffled with a private RNG seeded by ``split_seed`` and
    split 90% / 10%; ``train`` selects which part this instance exposes.
    Because the shuffle depends only on the seed, a ``train=True`` and a
    ``train=False`` instance built with the same seed never share a sample.

    ``__getitem__`` returns ``(image, mask)``: ``image`` is whatever the
    configured transform returns, ``mask`` is a float32 ndarray of shape
    (4, H, W) with channels background, NoInfo, fault, hole.
    """

    # Sub-folders of ``<set>/masks``, in the channel order of the returned mask.
    MASK_FOLDERS = ('mask_background', 'mask_NoInfo', 'mask_fault', 'mask_hole')

    def __init__(self, true_folder_path, false_folder_path,
                 train_image_Rigid_transform=None,
                 train_image_intensity_transform=None,
                 train_image_Normalization_trainsform=None,
                 test_transform=None,
                 train=True,
                 split_seed=None):
        """
        Args:
            true_folder_path: root of the normal set (contains images/ and masks/).
            false_folder_path: root of the abnormal set.
            train_image_Rigid_transform: transform called with the image and the
                four masks as keyword targets (train mode only; optional).
            train_image_intensity_transform: image-only transform (train mode
                only; optional).
            train_image_Normalization_trainsform: image-only transform applied
                last in train mode; required when ``train`` is true.
            test_transform: image-only transform used when ``train`` is false;
                required in that mode.
            train: expose the 90% train split if true, else the 10% remainder.
                (The old default was the type ``bool``, which is truthy, so
                ``True`` keeps the previous behaviour.)
            split_seed: seed of the shuffle that defines the split; defaults to
                ``CFG['SEED']``.

        Raises:
            ValueError: if a set's image and mask folders hold different numbers
                of files.
        """
        # image folders
        self.true_folder_path = true_folder_path      # normal images (128 per the author's note)
        self.false_folder_path = false_folder_path    # abnormal images (91 per the author's note)

        # transforms
        self.train_image_Rigid_transform = train_image_Rigid_transform
        self.train_image_intensity_transform = train_image_intensity_transform
        self.train_image_Normalization_trainsform = train_image_Normalization_trainsform
        self.test_transform = test_transform

        self.train = train
        self.extension = '*.PNG'

        # Naturally sorted path lists per set, so files pair up by index.
        (self.true_image_paths, self.true_mask_background_paths,
         self.true_mask_Noinfo_paths, self.true_mask_fault_paths,
         self.true_mask_hole_paths) = self._collect_paths(true_folder_path)
        (self.false_image_paths, self.false_mask_background_paths,
         self.false_mask_Noinfo_paths, self.false_mask_fault_paths,
         self.false_mask_hole_paths) = self._collect_paths(false_folder_path)

        # Rows of (image, background, NoInfo, fault, hole) paths.
        self.true_set = list(zip(self.true_image_paths, self.true_mask_background_paths,
                                 self.true_mask_Noinfo_paths, self.true_mask_fault_paths,
                                 self.true_mask_hole_paths))
        self.false_set = list(zip(self.false_image_paths, self.false_mask_background_paths,
                                  self.false_mask_Noinfo_paths, self.false_mask_fault_paths,
                                  self.false_mask_hole_paths))

        # Shuffle with a private, seeded RNG. The global np.random state used
        # before advanced between instances, so the train and validation
        # datasets got DIFFERENT permutations and their splits overlapped.
        seed = CFG['SEED'] if split_seed is None else split_seed
        rng = np.random.RandomState(seed)
        self.true_indices = rng.permutation(len(self.true_set))
        self.false_indices = rng.permutation(len(self.false_set))

        self.true_total_length = len(self.true_set)
        self.false_total_length = len(self.false_set)

        self.train_ratio = 0.9
        self.true_train_length = int(self.true_total_length * self.train_ratio)
        self.false_train_length = int(self.false_total_length * self.train_ratio)

        if train:
            self.true_train_indices = self.true_indices[:self.true_train_length]
            self.false_train_indices = self.false_indices[:self.false_train_length]
            true_idx, false_idx = self.true_train_indices, self.false_train_indices
        else:
            self.true_valid_indices = self.true_indices[self.true_train_length:]
            self.false_valid_indices = self.false_indices[self.false_train_length:]
            true_idx, false_idx = self.true_valid_indices, self.false_valid_indices

        self.true_result_set = self._as_rows(self.true_set)[true_idx]
        self.false_result_set = self._as_rows(self.false_set)[false_idx]
        self.result_set = np.concatenate([self.true_result_set, self.false_result_set], axis=0)

    def _collect_paths(self, folder):
        """Return ``[image_paths, *mask_paths]`` (naturally sorted) for one set folder.

        Raises ValueError if the folders hold different numbers of files,
        because zip() would otherwise silently drop or misalign samples.
        """
        image_paths = natsort.natsorted(glob(os.path.join(folder + '/images', self.extension)))
        mask_paths = [
            natsort.natsorted(glob(os.path.join(os.path.join(folder + '/masks', name), self.extension)))
            for name in self.MASK_FOLDERS
        ]
        path_lists = [image_paths, *mask_paths]
        counts = [len(paths) for paths in path_lists]
        if len(set(counts)) > 1:
            raise ValueError(f'{folder}: images/masks file counts differ {counts} '
                             f'(images, {", ".join(self.MASK_FOLDERS)})')
        return path_lists

    @staticmethod
    def _as_rows(path_set):
        """Turn a list of 5-tuples of paths into an (N, 5) string array.

        The explicit reshape keeps an empty set as shape (0, 5) so that it can
        still be concatenated with the other set.
        """
        return np.array(path_set, dtype=str).reshape(-1, 5)

    @staticmethod
    def _stack_masks(*masks):
        """Stack channel-last (H, W, 1) masks into one float32 (N, H, W) array."""
        return np.concatenate(
            [np.asarray(m, dtype=np.float32).transpose(2, 0, 1) for m in masks], axis=0)

    def __getitem__(self, idx):
        image_path, *mask_paths = self.result_set[idx]
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f'Could not read image: {image_path}')

        size = CFG['IMG_SIZE']
        # Each mask as (H, W, 1), the channel-last layout albumentations expects.
        mask_background, mask_Noinfo, mask_fault, mask_hole = [
            np.expand_dims(self.get_mask(path, size, size), axis=-1) for path in mask_paths
        ]

        if self.train:
            # Rigid transformation: image and masks.
            if self.train_image_Rigid_transform is not None:
                rigid = self.train_image_Rigid_transform(image=image,
                                                         mask_background=mask_background,
                                                         mask_Noinfo=mask_Noinfo,
                                                         mask_fault=mask_fault,
                                                         mask_hole=mask_hole)
                image = rigid['image']
                mask_background = rigid['mask_background']
                mask_Noinfo = rigid['mask_Noinfo']
                mask_fault = rigid['mask_fault']
                mask_hole = rigid['mask_hole']

            # Intensity transformation: image only.
            if self.train_image_intensity_transform is not None:
                image = self.train_image_intensity_transform(image=image)['image']

            # Normalization (+ tensor conversion): image only.
            image = self.train_image_Normalization_trainsform(image=image)['image']
        else:
            image = self.test_transform(image=image)['image']

        result_mask = self._stack_masks(mask_background, mask_Noinfo, mask_fault, mask_hole)
        return image, result_mask

    def __len__(self):
        return len(self.result_set)

    def get_mask(self, mask_path, IMG_HEIGHT, IMG_WIDTH):
        """Read a single-channel mask and resize it to (IMG_HEIGHT, IMG_WIDTH) as uint8.

        Per the author's note the source masks already hold 0/1 values, so no
        /255 scaling is applied.
        """
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f'Could not read mask: {mask_path}')
        # cv2.resize takes dsize as (width, height); nearest-neighbour keeps the
        # label values intact instead of blending them along edges.
        mask = cv2.resize(mask, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)
        return mask.astype(np.uint8, copy=False)

# create DataLoaders based on CustomDataset

In [None]:
# Return to the project root so the relative dataset paths below resolve.
# Reuses WorkSpace_path instead of repeating the same literal path.
os.chdir(WorkSpace_path)
os.getcwd()

In [None]:
true_folder_path = './true_v3'
false_folder_path = './false_v7'

# Train split: rigid + intensity augmentation followed by normalization.
trainDataset = CustomDataset(
    true_folder_path, false_folder_path,
    train_image_Rigid_transform=train_image_Rigid_transform,
    train_image_intensity_transform=train_image_intensity_transform,
    train_image_Normalization_trainsform=train_image_Normalization_trainsform,
    train=True,
)

# Validation split: with train=False only test_transform is applied.
validDataset = CustomDataset(
    true_folder_path, false_folder_path,
    train_image_Rigid_transform=train_image_Rigid_transform,
    train_image_intensity_transform=train_image_intensity_transform,
    train_image_Normalization_trainsform=train_image_Normalization_trainsform,
    test_transform=test_transform,
    train=False,
)

Train_Dataloader = DataLoader(trainDataset, shuffle=True, batch_size=CFG['BATCH_SIZE'])
Valid_Dataloader = DataLoader(validDataset, shuffle=False, batch_size=CFG['BATCH_SIZE'])

len(trainDataset), len(validDataset)

In [None]:
# TrainDataset
# Sanity check: pull one batch from the training loader and show its shapes.
iter_mode = iter(Train_Dataloader)
images, masks = next(iter_mode)
images.size(), masks.size()

In [None]:
# Visual check of one training sample: min-max-scaled image + its four masks.
idx = 2
image = images[idx].permute(1, 2, 0).numpy()
image = (image - image.min()) / (image.max() - image.min())   # rescale to [0, 1]
mask = masks[idx].permute(1, 2, 0).numpy()

mask_background, mask_NoInfo, mask_fault, mask_hole = (mask[:, :, ch] for ch in range(4))

figs, axes = plt.subplots(1, 5, figsize=(15, 15))
for ax, (panel, title) in zip(axes, [(image, 'image'),
                                     (mask_background, 'background'),
                                     (mask_NoInfo, 'NO info'),
                                     (mask_fault, 'Fault'),
                                     (mask_hole, 'hole')]):
    ax.imshow(panel, cmap='gray')
    ax.set_title(title)

In [None]:
# ValidDataset
# Sanity check: pull one batch from the validation loader and show its shapes.
iter_mode = iter(Valid_Dataloader)
images, masks = next(iter_mode)
images.size(), masks.size()

In [None]:
# Visual check of one validation sample: min-max-scaled image + its four masks.
idx = 3
image = images[idx].permute(1, 2, 0).numpy()
image = (image - image.min()) / (image.max() - image.min())   # rescale to [0, 1]
mask = masks[idx].permute(1, 2, 0).numpy()

mask_background, mask_NoInfo, mask_fault, mask_hole = (mask[:, :, ch] for ch in range(4))

figs, axes = plt.subplots(1, 5, figsize=(15, 15))
for ax, (panel, title) in zip(axes, [(image, 'image'),
                                     (mask_background, 'background'),
                                     (mask_NoInfo, 'NO info'),
                                     (mask_fault, 'Fault'),
                                     (mask_hole, 'hole')]):
    ax.imshow(panel, cmap='gray')
    ax.set_title(title)

# CosineAnnealingWarmUpRestarts

In [None]:
class CosineAnnealingWarmUpRestarts(_LRScheduler):
    """Cosine-annealing learning-rate schedule with linear warm-up and warm restarts.

    Within a cycle of ``T_i`` epochs:
      * for the first ``T_up`` epochs the lr rises linearly from each param
        group's base lr (the optimizer's initial lr) to ``eta_max``;
      * for the remaining ``T_i - T_up`` epochs it follows a half cosine from
        ``eta_max`` back to the base lr.
    When ``step()`` is called without an epoch and a cycle ends, the next cycle
    has length ``(T_i - T_up) * T_mult + T_up``. On every step ``eta_max`` is
    set to ``base_eta_max * gamma ** cycle``.

    The base lr is the start/end point of every cycle and ``eta_max`` the
    turning point: if ``eta_max`` is smaller than the optimizer's lr, the
    schedule dips down to ``eta_max`` and climbs back instead of peaking.

    Args:
        optimizer: wrapped optimizer.
        T_0: length in epochs of the first cycle (positive int).
        T_mult: cycle-length multiplier after each restart (int >= 1).
        eta_max: turning-point lr of the first cycle.
        T_up: warm-up epochs at the start of each cycle (int >= 0; the error
            message says "positive", but 0 is accepted).
        gamma: per-cycle multiplicative decay of ``eta_max``.
        last_epoch: index of the last epoch; -1 starts a fresh schedule.
    """
    def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        if T_up < 0 or not isinstance(T_up, int):
            raise ValueError("Expected positive integer T_up, but got {}".format(T_up))
        # NOTE(review): T_up >= T_0 is not rejected, but get_lr divides by
        # (T_i - T_up) and would then divide by zero or go negative.
        self.T_0 = T_0
        self.T_mult = T_mult
        self.base_eta_max = eta_max   # eta_max of cycle 0; later cycles are scaled by gamma
        self.eta_max = eta_max
        self.T_up = T_up
        self.T_i = T_0                # length of the current cycle
        self.gamma = gamma
        self.cycle = 0                # index of the current cycle
        self.T_cur = last_epoch       # position inside the current cycle
        super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch)
    
    def get_lr(self):
        """Return the lr of every param group for the current ``T_cur``."""
        if self.T_cur == -1:
            # Before the first step: the optimizer's initial lrs.
            return self.base_lrs
        elif self.T_cur < self.T_up:
            # Linear warm-up from base_lr towards eta_max.
            return [(self.eta_max - base_lr)*self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs]
        else:
            # Half cosine from eta_max (at T_cur == T_up) back to base_lr (as T_cur -> T_i).
            return [base_lr + (self.eta_max - base_lr) * (1 + math.cos(math.pi * (self.T_cur-self.T_up) / (self.T_i - self.T_up))) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Advance the schedule and write the new lr into every param group.

        With ``epoch=None`` the counters advance by one epoch and a restart
        happens when ``T_cur`` reaches ``T_i``. With an explicit ``epoch`` the
        cycle index and position are recomputed from scratch.

        NOTE(review): in the explicit-epoch branch the ``T_mult > 1`` formulas
        use ``T_0 * T_mult ** n`` and ignore ``T_up``, so they disagree with the
        cycle lengths the ``epoch=None`` branch produces when ``T_up > 0``.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                # Restart: start a new, longer cycle.
                self.cycle += 1
                self.T_cur = self.T_cur - self.T_i
                self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
        else:
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                    self.cycle = epoch // self.T_0
                else:
                    # Invert the geometric series T_0 * (T_mult**n - 1) / (T_mult - 1)
                    # to find the cycle index n that contains `epoch`.
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.cycle = n
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
                
        # Decay the turning-point lr once per completed cycle.
        self.eta_max = self.base_eta_max * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

# model + Grid-Search

In [None]:
#set up
def train(train_loader, model, criterion, optimizer):
    """Run one training epoch and return the epoch metrics and the model.

    Args:
        train_loader: iterable of ``(inputs, labels)`` batches.
        model: network to optimize; it is put in training mode. If it returns
            a list/tuple of outputs (e.g. a deep-supervision head), the loss is
            the mean over all outputs and IoU is measured on the last one.
        criterion: loss callable ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        ``(log, model)`` where ``log`` is an OrderedDict with the
        sample-weighted average ``'loss'`` and ``'iou'`` of the epoch.
    """
    avg_meters = {'loss': AverageMeter(),
                  'iou': AverageMeter()}

    model.train()

    for inputs, labels in train_loader:
        # .to() moves/casts an existing tensor; torch.tensor(tensor) made an
        # extra copy and raised a UserWarning (hidden by the warnings filter).
        inputs = torch.as_tensor(inputs).to(device=device, dtype=torch.float32)
        labels = torch.as_tensor(labels).to(device=device, dtype=torch.float32)

        optimizer.zero_grad()
        outputs = model(inputs)
        if isinstance(outputs, (list, tuple)):
            loss = sum(criterion(out, labels) for out in outputs) / len(outputs)
            outputs = outputs[-1]
        else:
            loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # log
        iou = iou_score(outputs, labels, threshold=0.8)
        avg_meters['loss'].update(loss.item(), n=inputs.size(0))
        avg_meters['iou'].update(iou, n=inputs.size(0))

    # Built after the loop so an empty loader returns a log instead of raising
    # UnboundLocalError (values are whatever AverageMeter reports with no updates).
    log = OrderedDict([
        ('loss', avg_meters['loss'].avg),
        ('iou', avg_meters['iou'].avg),
    ])
    return log, model

In [None]:
#set-up
def validation(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        val_loader: iterable of ``(inputs, labels)`` batches.
        model: network to evaluate; it is put in eval mode. If it returns a
            list/tuple of outputs (e.g. a deep-supervision head), the loss is the
            mean over all outputs and IoU is measured on the last one.
        criterion: loss callable ``criterion(outputs, labels)``.

    Returns:
        OrderedDict with the sample-weighted average ``'loss'`` and ``'iou'``.
    """
    avg_meters = {'loss': AverageMeter(),
                  'iou': AverageMeter()}

    model.eval()
    with torch.no_grad():
        for inputs, labels in val_loader:
            # .to() instead of torch.tensor(tensor): no extra copy, no warning.
            inputs = torch.as_tensor(inputs).to(device=device, dtype=torch.float32)
            labels = torch.as_tensor(labels).to(device=device, dtype=torch.float32)

            outputs = model(inputs)
            if isinstance(outputs, (list, tuple)):
                loss = sum(criterion(out, labels) for out in outputs) / len(outputs)
                outputs = outputs[-1]
            else:
                loss = criterion(outputs, labels)
            iou = iou_score(outputs, labels, threshold=0.8)

            avg_meters['loss'].update(loss.item(), n=inputs.size(0))
            avg_meters['iou'].update(iou, n=inputs.size(0))

    # Built after the loop so an empty loader does not raise UnboundLocalError.
    return OrderedDict([
        ('loss', avg_meters['loss'].avg),
        ('iou', avg_meters['iou'].avg),
    ])

# Train - No Deep Supervision

In [None]:
# train: grid search over the optimizer's initial learning rate (no deep supervision)
# learningRates = [0.1, 0.01, 0.001, 0.0001, 0.00001]
learningRates = [0.1]
epochs = CFG['NUM_EPOCHS']
es_window = 20     # number of recent validation IoUs in the running mean
es_patience = 20   # stop after this many consecutive epochs below that mean
best_iou = float('-inf')   # guarantees best_weights is set after the first epoch
best_lr = None

for lr in learningRates:
    model = Nested_UNet(num_classes=CFG['NUM_CLASS'], input_channels=1, deep_supervision=False).to(device)
    criterion = BCEDiceLoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
    # Named `scheduler` so it no longer shadows the imported torch.optim.lr_scheduler.
    # NOTE(review): this scheduler moves the lr between the optimizer's lr (`lr`)
    # and eta_max; with eta_max=1e-5 < lr it dips to 1e-5 and climbs back to `lr`
    # rather than warming up to a peak - confirm this is intended.
    scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0=100, T_mult=2, eta_max=1e-5, T_up=10, gamma=0.5)

    # Early-stopping state is per learning rate; it used to carry over between runs.
    earlyStopping = []
    cnt = 0

    for epoch in range(1, epochs + 1):
        train_log, model = train(Train_Dataloader, model, criterion, optimizer)
        val_log = validation(Valid_Dataloader, model, criterion)
        print(f'{epoch}Epoch')
        print(f'train loss:{train_log["loss"]:.3f} |train iou:{train_log["iou"]:.3f}')
        print(f'val loss:{val_log["loss"]:.3f} |val iou:{val_log["iou"]:.3f}\n')
        valid_iou = val_log['iou']
        scheduler.step()
        if best_iou < valid_iou:
            best_iou = valid_iou
            best_epoch = epoch
            best_lr = lr
            best_weights = copy.deepcopy(model.state_dict())
        print(f'Best iou : {best_iou:.3f}')
        print('-' * 30)

        # Early stopping: stop once the validation IoU has stayed below the
        # mean of the last `es_window` values for `es_patience` epochs in a row.
        earlyStopping.append(valid_iou)
        if len(earlyStopping) > es_window:
            del earlyStopping[0]
        if valid_iou < np.mean(earlyStopping):
            cnt += 1
            if cnt == es_patience:
                break
        else:
            cnt = 0

print(f'Best iou {best_iou:.3f} at epoch {best_epoch} (lr={best_lr})')

In [None]:
# Rebuild the architecture that was trained (single-channel grayscale input)
# and evaluate the best weights on the validation loader. input_channels was 3,
# which does not match the 1-channel weights, so load_state_dict failed.
Best_model = Nested_UNet(num_classes = 4, input_channels = 1, deep_supervision=False).to(device)
Best_model.load_state_dict(best_weights)
val_log =  validation(Valid_Dataloader, Best_model, criterion)

print(f'val loss:{val_log["loss"]:.3f} |val iou:{val_log["iou"]:.3f}\n')

In [None]:
# Persist the best state_dict found during training.
path = r'D:\3_프로젝트\4_SmartFactoryCapstoneDesign\SmartFactory\weights'
filename = 'No_DeepSupervision_weights_gray.pt'
save_path = os.path.join(path, filename)
os.makedirs(path, exist_ok=True)  # torch.save fails if the folder does not exist
torch.save(best_weights, save_path)

# Train Result Visualization

In [None]:
# Reload the saved weights into a fresh single-channel (grayscale) model.
path = r'D:\3_프로젝트\4_SmartFactoryCapstoneDesign\SmartFactory\weights'
filename = 'No_DeepSupervision_weights_gray.pt'
save_path = os.path.join(path, filename)
# map_location lets weights saved from a GPU run load on a CPU-only machine.
check_point = torch.load(save_path, map_location=device)

Best_model = Nested_UNet(num_classes = 4, input_channels = 1, deep_supervision=False).to(device)
Best_model.load_state_dict(check_point)

In [None]:
# Rebuild the validation split and iterate it one sample at a time for plotting.
# NOTE(review): confirm this rebuilt split matches the one used during training.
false_folder_path = './false_v7'

validDataset = CustomDataset(true_folder_path, false_folder_path,
                             train_image_Rigid_transform, train_image_intensity_transform, train_image_Normalization_trainsform,
                             test_transform, train = False)

# Was DataLoader(trainDataset, ...), which fed the (augmented) training split
# into this validation visualization and left validDataset unused.
Valid_Dataloader = DataLoader(validDataset, batch_size=1, shuffle=False)
val_iter = iter(Valid_Dataloader)

In [None]:
# `image` is still the NumPy array from the visualization cells above, whose
# `.size` is an int attribute, so `image.size()` raised TypeError.
# np.shape works for both NumPy arrays and torch tensors.
np.shape(image)

In [None]:
# Run Best_model on every sample of Valid_Dataloader and save one 2x5 figure
# per sample: row 0 = input + ground-truth masks, row 1 = input + predicted
# masks (sigmoid >= 0.5).
save_dir = r'C:\Users\SDML\Desktop\SmartFactory\gray'
os.makedirs(save_dir, exist_ok=True)   # savefig fails if the folder is missing
class_names = ['background', 'NO info', 'Fault', 'hole']

Best_model.eval()          # set once instead of on every iteration
with torch.no_grad():      # inference only: do not build autograd graphs
    for idx, (image, masks) in enumerate(Valid_Dataloader):
        image = image.to(device)

        # Predicted masks: sigmoid probabilities thresholded at 0.5, as (H, W, C).
        seg_image = torch.sigmoid(Best_model(image))
        seg_image = (seg_image >= 0.5).float().squeeze(dim=0)
        prediction = seg_image.permute(1, 2, 0).cpu().numpy()

        # Input image, min-max rescaled to [0, 1] for display.
        image = image.detach().cpu().squeeze().numpy()
        image = (image - np.min(image)) / (np.max(image) - np.min(image))

        # Ground-truth masks as (H, W, C).
        masks = masks.squeeze().permute(1, 2, 0).numpy()

        figs, axes = plt.subplots(2, 5, figsize=(15, 15))
        for row, (prefix, stack) in enumerate([('GT', masks), ('Prediction', prediction)]):
            # Grayscale colormap for the input as well (it used the default colormap).
            axes[row, 0].imshow(image, cmap='gray')
            axes[row, 0].set_title('image')
            for ch, name in enumerate(class_names):
                axes[row, ch + 1].imshow(stack[:, :, ch], cmap='gray')
                axes[row, ch + 1].set_title(f'{prefix} - {name}')
        figs.savefig(os.path.join(save_dir, '%s.jpg' % str(idx)))
        # Close after saving; otherwise one open figure accumulates per sample.
        plt.close(figs)

# Deep Supervision

In [None]:
# train: grid search over the optimizer's initial learning rate (deep supervision)
# learningRates = [0.1, 0.01, 0.001, 0.0001, 0.00001]
learningRates = [0.1, 0.00001]
epochs = CFG['NUM_EPOCHS']
es_window = 20     # number of recent validation IoUs in the running mean
es_patience = 20   # stop after this many consecutive epochs below that mean
best_iou = float('-inf')   # guarantees best_weights is set after the first epoch
best_lr = None

for lr in learningRates:
    # input_channels=1: CustomDataset reads images with cv2.IMREAD_GRAYSCALE and
    # the non-deep-supervision model is built with input_channels=1 (it was 3 here).
    # NOTE(review): with deep_supervision=True the model presumably returns a list
    # of outputs; train()/validation() must combine them - verify.
    model = Nested_UNet(num_classes=CFG['NUM_CLASS'], input_channels=1, deep_supervision=True).to(device)
    criterion = BCEDiceLoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
    # Named `scheduler` so it no longer shadows the imported torch.optim.lr_scheduler.
    # NOTE(review): this scheduler moves the lr between the optimizer's lr (`lr`)
    # and eta_max; for lr > 1e-5 it dips to 1e-5 and climbs back rather than
    # warming up to a peak - confirm this is intended.
    scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0=100, T_mult=2, eta_max=1e-5, T_up=10, gamma=0.5)

    # Early-stopping state is per learning rate; it used to carry over between runs.
    earlyStopping = []
    cnt = 0

    for epoch in range(1, epochs + 1):
        train_log, model = train(Train_Dataloader, model, criterion, optimizer)
        val_log = validation(Valid_Dataloader, model, criterion)
        print(f'{epoch}Epoch')
        print(f'train loss:{train_log["loss"]:.3f} |train iou:{train_log["iou"]:.3f}')
        print(f'val loss:{val_log["loss"]:.3f} |val iou:{val_log["iou"]:.3f}\n')
        valid_iou = val_log['iou']
        scheduler.step()
        if best_iou < valid_iou:
            best_iou = valid_iou
            best_epoch = epoch
            best_lr = lr
            best_weights = copy.deepcopy(model.state_dict())
        print(f'Best iou : {best_iou:.3f}')
        print('-' * 30)

        # Early stopping: stop once the validation IoU has stayed below the
        # mean of the last `es_window` values for `es_patience` epochs in a row.
        earlyStopping.append(valid_iou)
        if len(earlyStopping) > es_window:
            del earlyStopping[0]
        if valid_iou < np.mean(earlyStopping):
            cnt += 1
            if cnt == es_patience:
                break
        else:
            cnt = 0

print(f'Best iou {best_iou:.3f} at epoch {best_epoch} (lr={best_lr})')