In [None]:
!pip install wandb
import wandb

In [None]:
import os
import random
import time
import json
import warnings 
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from utils import label_accuracy_score, add_hist
import cv2

import numpy as np
import pandas as pd
from tqdm import tqdm

from pycocotools.coco import COCO
import torchvision
import torchvision.transforms as transforms

import albumentations as A
from albumentations.pytorch import ToTensorV2

# Report the PyTorch build and whether a CUDA device is visible.
print('pytorch version: {}'.format(torch.__version__))
print('GPU 사용 가능 여부: {}'.format(torch.cuda.is_available()))

# Choose the compute device once; later cells read this global.
device = "cuda" if torch.cuda.is_available() else "cpu"

# torch.cuda.get_device_name(0) fails on a CPU-only machine, so only query the
# device name when CUDA is actually available.
if device == "cuda":
    print(torch.cuda.get_device_name(0))
print(torch.cuda.device_count())

In [None]:
from easydict import EasyDict as eDict

def getArg():
    """Return the experiment configuration.

    The result is an EasyDict, so settings are read as attributes
    (e.g. ``arg.batch``, ``arg.output_path``).
    """
    return eDict(
        # training hyper-parameters
        batch=32,
        epoch=20,
        lr=1e-4,
        seed=21,  # NOTE(review): the seed cell hard-codes 42 and does not read this value
        save_capacity=5,
        # dataset / output locations
        train_image_root="../input/train2014",
        train_mask_root="../input/train_mask",
        val_image_root="../input/val2014",
        val_mask_root="../input/val_mask",
        output_path="../output",
        # DataLoader worker counts
        train_worker=8,
        valid_worker=4,
        test_worker=4,
        # Weights & Biases logging
        wandb=True,
        wandb_project="alchera",
        wandb_entity="cv4",
        # run name; joined onto output_path to form the run directory
        custom_name="sample",
        # inference settings
        TTA=True,
        test_batch=1,
        csv_size=256,
    )

In [None]:
# Fix every random number generator the notebook touches so runs are reproducible.
# NOTE(review): getArg() defines arg.seed = 21, but this cell uses 42 and
# nothing reads arg.seed — confirm which value is intended.
def _seed_everything(seed):
    """Seed PyTorch (CPU and all CUDA devices), NumPy and Python's ``random``."""
    for seeder in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # multi-GPU
        np.random.seed,
        random.seed,
    ):
        seeder(seed)


random_seed = 42
_seed_everything(random_seed)

# Ask cuDNN for deterministic kernels and disable its autotuner, which could
# otherwise pick different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Dataset

In [None]:
from torch.utils.data import Dataset
import cv2
import os


# Human-readable names of the 15 segmentation classes.
# Presumably the list position equals the label value stored in the masks
# (index 0 = background) — TODO confirm against the mask encoding.
category_names = [
    'Background',
    'Body',
    'RightHand',
    'LeftHand',
    'LeftFeet',
    'RightFeet',
    'RightThigh',
    'LeftThigh',
    'RightCalf',
    'LeftCalf',
    'LeftArm',
    'RightArm',
    'LeftForeArm',
    'RightForeArm',
    'Head',
]

def get_classname(classID, cats):
    """Look up a category name by id.

    Args:
        classID: id to search for.
        cats: sequence of dicts with ``'id'`` and ``'name'`` keys.

    Returns:
        The ``'name'`` of the first entry whose ``'id'`` equals ``classID``, or
        the *string* ``"None"`` (not the ``None`` object) when nothing matches.
    """
    matches = (cat['name'] for cat in cats if cat['id'] == classID)
    return next(matches, "None")

class CustomDataset(Dataset):
    """Image / mask segmentation dataset read directly from directories.

    Every file in ``image_root`` is one sample. In ``'train'`` and ``'val'`` mode
    the matching mask is loaded from ``mask_root`` as a grayscale image named
    ``<stem>.png`` (train) or ``<stem>.grayscale.png`` (val), and ``(image, mask)``
    is returned. In ``'test'`` mode ``(image, stem)`` is returned.

    Args:
        image_root: directory containing the input images.
        mask_root: directory containing the masks; required for train/val.
        mode: one of ``'train'``, ``'val'``, ``'test'``.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` (or ``transform(image=...)`` in
            test mode) and returning a dict with ``"image"`` / ``"mask"``.

    Raises:
        ValueError: on an unknown ``mode`` or a missing ``mask_root`` for
            train/val.
    """
    num_classes = 15
    _MODES = ('train', 'val', 'test')
    # Suffix appended to the image stem to locate its mask, per labelled mode.
    _MASK_SUFFIX = {'train': '.png', 'val': '.grayscale.png'}

    def __init__(self, image_root, mask_root=None, mode='train', transform=None):
        super().__init__()
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        if mode != 'test' and mask_root is None:
            raise ValueError(f"mask_root is required when mode={mode!r}")
        self.mode = mode
        self.image_root = image_root  # original images
        self.mask_root = mask_root
        self.transform = transform
        self.setup()

    def setup(self):
        """Collect the file names of all images under ``image_root``.

        ``os.listdir`` order depends on the filesystem, so the names are sorted
        to keep sample indices (and validation / test order) reproducible.
        """
        self.image_names = sorted(os.listdir(self.image_root))

    @staticmethod
    def _read(path, flags):
        """``cv2.imread`` that raises instead of silently returning ``None``."""
        img = cv2.imread(path, flags)
        if img is None:
            raise FileNotFoundError(f"could not read image: {path}")
        return img

    def __getitem__(self, index: int):
        image_name = self.image_names[index]
        image = self._read(os.path.join(self.image_root, image_name), cv2.IMREAD_COLOR)
        # OpenCV decodes to BGR; convert to RGB before the transforms.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # FIXME albu issue
        stem = os.path.splitext(image_name)[0]

        if self.mode == 'test':
            if self.transform is not None:
                image = self.transform(image=image)["image"]
            return image, stem

        # train / val: the mask shares the image's stem, with a mode-specific suffix.
        mask_path = os.path.join(self.mask_root, stem + self._MASK_SUFFIX[self.mode])
        mask = self._read(mask_path, cv2.IMREAD_GRAYSCALE)
        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image, mask = transformed["image"], transformed["mask"]
        return image, mask

    def __len__(self) -> int:
        """Number of images found under ``image_root``."""
        return len(self.image_names)


# Augmentation

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import ttach as tta
# https://albumentations-demo.herokuapp.com/

def getTransform(height=512, width=512):
    """Build the (train, validation) albumentations pipelines.

    Both pipelines resize to ``height`` x ``width`` (512 x 512 by default),
    normalise with ImageNet mean/std and convert to a tensor via ``ToTensorV2``.
    The train pipeline additionally applies a random ``A.Flip``.

    Args:
        height, width: output spatial size of both pipelines.

    Returns:
        ``(train_transform, val_transform)``.
    """
    # ImageNet channel statistics, shared by both pipelines.
    norm = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_transform = A.Compose([
        A.Resize(height, width, p=1.0),
        A.Flip(),
        A.Normalize(**norm),
        ToTensorV2(),
    ])

    val_transform = A.Compose([
        A.Resize(height, width, p=1.0),
        A.Normalize(**norm),
        ToTensorV2(),
    ])

    return train_transform, val_transform


def getInferenceTransform():
    """Build the inference-time pipelines.

    Returns:
        test_transform: albumentations pipeline that only normalises with
            ImageNet mean/std and converts to a tensor (no resizing).
        tta_transform: ttach test-time-augmentation set combining horizontal
            flip, 0/180 degree rotation, 1x/2x/4x scaling and 0.9/1.0/1.1
            intensity multiplication.
    """
    # NOTE(review): unlike getTransform there is no Resize here, so images reach
    # the model at native resolution — confirm the encoder accepts those sizes
    # (U-Net style encoders often require sides divisible by 32).
    normalize_and_tensorize = [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    test_transform = A.Compose(normalize_and_tensorize)

    tta_ops = [
        tta.HorizontalFlip(),
        tta.Rotate90(angles=[0, 180]),
        tta.Scale(scales=[1, 2, 4]),
        tta.Multiply(factors=[0.9, 1, 1.1]),
    ]
    return test_transform, tta.Compose(tta_ops)

# DataLoader

In [None]:
from torch.utils.data.dataloader import DataLoader

# Custom collate_fn: keep each batch as tuples of per-sample items (the training/validation loops stack them)
def collate_fn(batch):
    """Transpose a list of samples into a tuple of per-field tuples.

    ``[(img0, mask0), (img1, mask1)]`` -> ``((img0, img1), (mask0, mask1))``.
    """
    return tuple(zip(*batch))


def getDataloader(trainDataset, validDataset, batch, trainWorker, validWorker):
    """Build the (train, validation) DataLoaders.

    The train loader shuffles and the validation loader does not; both use
    ``pin_memory=True`` and ``collate_fn``.

    Args:
        trainDataset, validDataset: datasets to wrap.
        batch: batch size for both loaders.
        trainWorker, validWorker: ``num_workers`` for each loader.

    Returns:
        ``(trainDataloader, validDataloader)``.
    """
    trainDataloader = DataLoader(trainDataset, batch_size=batch, shuffle=True, pin_memory=True, num_workers=trainWorker, collate_fn=collate_fn)
    validDataloader = DataLoader(validDataset, batch_size=batch, shuffle=False, pin_memory=True, num_workers=validWorker, collate_fn=collate_fn)
    return trainDataloader, validDataloader


def getInferenceDataloader(dataset, batch, num_worker):
    """Build a non-shuffled inference DataLoader.

    Args:
        dataset: dataset to wrap.
        batch: batch size (honoured; pass 1 for one image per step).
        num_worker: ``num_workers`` for the loader.
    """
    return DataLoader(dataset, batch_size=batch, num_workers=num_worker, shuffle=False, collate_fn=collate_fn)

# Model

In [None]:
import segmentation_models_pytorch as smp

# https://smp.readthedocs.io/en/latest/index.html
# https://smp.readthedocs.io/en/latest/encoders_timm.html
# When using a timm encoder, prefix encoder_name with "tu-"

def getModel(encoder_name="resnet18", encoder_weights="imagenet", in_channels=3, classes=15):
    """Build the segmentation network (an ``smp.Unet``).

    Defaults reproduce the original setup: ResNet-18 encoder with ImageNet
    weights, 3 input channels and 15 output classes (the same count as
    ``CustomDataset.num_classes``). For timm encoders, prefix ``encoder_name``
    with ``"tu-"``.

    Args:
        encoder_name: smp encoder identifier.
        encoder_weights: pretrained weights for the encoder (or ``None``).
        in_channels: number of input image channels.
        classes: number of output segmentation classes.
    """
    return smp.Unet(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
    )

## Loss, Optimizer, Scheduler

In [None]:
from torch import nn

def getLoss():
    """Return the training criterion: cross-entropy over the class logits."""
    return nn.CrossEntropyLoss()

In [None]:
import torch

#https://sanghyu.tistory.com/113
#https://gaussian37.github.io/dl-pytorch-lr_scheduler/

def getOptAndScheduler(model, lr, weight_decay=1e-5, T_max=10, eta_min=1e-5):
    """Create an Adam optimizer and a cosine-annealing LR scheduler.

    Args:
        model: module whose parameters are optimised.
        lr: initial (peak) learning rate.
        weight_decay: Adam weight decay.
        T_max: number of ``scheduler.step()`` calls to anneal from ``lr`` down
            to ``eta_min`` (the training loop steps once per epoch).
        eta_min: learning-rate floor.

    Returns:
        ``(optimizer, scheduler)``.
    """
    # NOTE(review): with the default T_max=10 and a 20-epoch run, the LR reaches
    # eta_min at epoch 10 and then climbs back toward lr; pass T_max=num_epochs
    # for a single monotone decay.
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)
    return optimizer, scheduler

# Training

In [None]:
import torch
import numpy as np
from utils.utils import add_hist, label_accuracy_score
from utils.wandb_method import WandBMethod
from utils.tqdm import TQDM
from utils.save_helper import SaveHelper
from torch.cuda.amp import GradScaler, autocast

def train(num_epochs, model, train_loader, val_loader, criterion, optimizer, scheduler, saved_dir, save_capacity, device, doWandb):
    """Mixed-precision training loop with per-epoch validation and checkpointing.

    For each epoch: one AMP pass over ``train_loader``, then ``validation``. If
    ``saveHelper.checkBestIoU`` accepts the returned mIoU, ``removeModel`` and
    ``saveModel`` are called to store a checkpoint under ``saved_dir``. The LR
    scheduler is stepped once per epoch.

    Args:
        num_epochs: number of epochs to run.
        model: network to train; moved to ``device`` once, before the loop.
        train_loader, val_loader: loaders yielding ``(images, masks)`` as tuples
            of per-sample tensors (see ``collate_fn``); they are stacked here.
        criterion: loss applied to ``(logits, masks)``.
        optimizer, scheduler: optimizer and epoch-level LR scheduler.
        saved_dir, save_capacity: forwarded to ``SaveHelper``.
        device: target device string, e.g. ``"cuda"`` or ``"cpu"``.
        doWandb: if True, training and validation metrics are sent to W&B.
    """
    n_class = 15
    scaler = GradScaler(enabled=True)

    saveHelper = SaveHelper(save_capacity, saved_dir)
    mainPbar = TQDM.makeMainProcessBar(num_epochs)

    # Moving the model is loop-invariant, so do it once instead of every step.
    model = model.to(device)

    for epoch in mainPbar:
        model.train()
        pbar = TQDM.makePbar(train_loader, epoch, True)

        # n_class x n_class histogram accumulated with add_hist over the epoch;
        # the metrics shown per step are running values over it.
        hist = np.zeros((n_class, n_class))
        for images, masks in pbar:
            images = torch.stack(images).to(device)
            masks = torch.stack(masks).long().to(device)

            optimizer.zero_grad()

            with autocast(enabled=True):
                outputs = model(images)
                loss = criterion(outputs, masks)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # With an auxiliary head the model returns a tuple; keep only the mask logits.
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            preds = torch.argmax(outputs, dim=1).detach().cpu().numpy()
            masks_np = masks.detach().cpu().numpy()

            hist = add_hist(hist, masks_np, preds, n_class=n_class)
            acc, acc_cls, acc_clsmean, mIoU, fwavacc, IoU = label_accuracy_score(hist)

            TQDM.setPbarPostInStep(pbar, acc, acc_clsmean, loss, mIoU)

            if doWandb:
                WandBMethod.trainLog(loss, acc, scheduler.get_last_lr())

        avrg_loss, mIoU = validation(epoch, model, val_loader, criterion, device, doWandb)
        TQDM.setMainPbarPostInValid(mainPbar, avrg_loss)

        if saveHelper.checkBestIoU(mIoU, epoch):
            TQDM.setMainPbarDescInSaved(mainPbar, epoch, mIoU)
            saveHelper.removeModel()
            saveHelper.saveModel(epoch, model, optimizer, scheduler)

        # The scheduler is stepped once per epoch, not per batch.
        scheduler.step()

def validation(epoch, model, valid_loader, criterion, device, doWandb):
    """Evaluate ``model`` on ``valid_loader``.

    Returns:
        ``(avrg_loss, mIoU)``. ``avrg_loss`` is the mean of the per-batch losses.
        ``mIoU`` is computed by ``label_accuracy_score`` on the histogram
        accumulated over *every* batch. Averaging the running per-step mIoU
        values instead would over-weight early batches: batch 1 contributes to
        every running value, the last batch to only one.

    Raises:
        ValueError: if ``valid_loader`` yields no batches.
    """
    model.eval()
    # Loop-invariant: move the model once rather than every step.
    model = model.to(device)
    n_class = 15
    total_loss = 0
    cnt = 0
    hist = np.zeros((n_class, n_class))

    with torch.no_grad():
        pbar = TQDM.makePbar(valid_loader, epoch, False)

        # Pick a random step; that batch's images/predictions/masks go to W&B.
        targetStep = WandBMethod.pickImageStep(len(pbar))
        targetImages, targetOutputs, targetMasks = None, None, None

        for step, (images, masks) in enumerate(pbar):
            images = torch.stack(images).to(device)
            masks = torch.stack(masks).long().to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            # With an auxiliary head the model returns a tuple; keep only the mask logits.
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            outputs = torch.argmax(outputs, dim=1).detach().cpu().numpy()
            masks = masks.detach().cpu().numpy()

            hist = add_hist(hist, masks, outputs, n_class=n_class)
            acc, acc_cls, acc_clsmean, mIoU, fwavacc, IoU = label_accuracy_score(hist)

            total_loss += loss
            cnt += 1

            # Running metrics over all batches so far (same convention as train()).
            TQDM.setPbarPostInStep(pbar, acc, acc_clsmean, loss, mIoU)

            if step == targetStep:
                targetImages, targetOutputs, targetMasks = images.detach().cpu().numpy(), outputs, masks

        if cnt == 0:
            raise ValueError("validation loader produced no batches")

        avrg_loss = total_loss / cnt
        # mIoU over the full accumulated histogram (from the last loop iteration).
        avrg_mIoU = mIoU

        if doWandb:
            WandBMethod.validLog(IoU, acc_cls, acc_clsmean, acc, avrg_mIoU, targetImages, targetOutputs, targetMasks)

    return avrg_loss, avrg_mIoU


In [None]:
arg = getArg()

train_transform, val_transform = getTransform()

train_dataset = CustomDataset(image_root=arg.train_image_root, mask_root=arg.train_mask_root, mode='train', transform=train_transform)
val_dataset = CustomDataset(image_root=arg.val_image_root, mask_root=arg.val_mask_root, mode='val', transform=val_transform)

trainLoader, valLoader = getDataloader(train_dataset, val_dataset, arg.batch, arg.train_worker, arg.valid_worker)

model = getModel()
criterion = getLoss()
optimizer, scheduler = getOptAndScheduler(model, arg.lr)

# Choose a fresh run directory: <output_path>/<custom_name>, falling back to
# <custom_name>_2, _3, ... if earlier runs exist. The suffix is applied to the
# base path each time so names do not accumulate ("sample_2_3_4").
basePath = os.path.join(arg.output_path, arg.custom_name)
outputPath = basePath
i = 2
while os.path.exists(outputPath):
    outputPath = f"{basePath}_{i}"
    i += 1

# Create <run>/models (which also creates the run directory) and save this run's
# settings next to the checkpoints.
# NOTE(review): this cell previously called shutil.copytree(f"custom/{custom_dir}", ...),
# but `shutil` was never imported and `custom_dir` is not defined anywhere in the
# notebook; restore that copy if a custom config directory is required.
os.makedirs(os.path.join(outputPath, "models"))
with open(os.path.join(outputPath, "config.json"), "w") as f:
    json.dump(arg, f, indent=2)

# wandb
if arg.wandb:
    from utils.wandb_method import WandBMethod
    WandBMethod.login(arg, model, criterion)

train(arg.epoch, model, trainLoader, valLoader, criterion, optimizer, scheduler, outputPath, arg.save_capacity, device, arg.wandb)