## Segmentation 학습에 DFL backbone 적용하기


#### 1. 라이브러리 호출

In [None]:
import sys

# Put the parent directory on the import path so the project-local packages
# imported below (segmentation, dfl_model, utils) can be resolved from here.
_PARENT_DIR = '../'
sys.path.append(_PARENT_DIR)

In [None]:
import os, glob, random, cv2
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import albumentations as A
import segmentation_models_pytorch as smp
import segmentation.model.metric as module_metric
import segmentation.model.model as model
import dfl_model.dfl_cnn as DFL

from segmentation.data_loader.dataloader import get_dataloader
from utils.data import get_datasize
from utils.visual import *
from albumentations.pytorch import transforms
from segmentation.model.loss import *
from segmentation.train import *
from pathlib import Path


#### 2. 시드 고정

In [None]:
import numpy as np  # explicit import: `np` was otherwise only reachable through a star import

# Fix every RNG the training pipeline draws from, so runs are reproducible.
SEED = 201
random.seed(SEED)                 # Python's `random` (used by albumentations for augmentation params)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # every visible GPU, not only the current one
# Trade cuDNN autotuning speed for deterministic kernels.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

#### 3. 하이퍼파라미터 설정

###### lr : 1e-2
###### batch : 4
###### epochs : 200
###### loss : dice

In [None]:
# Run on the first GPU when CUDA is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Optimisation hyper-parameters.
lr = 1e-2
batch_size = 4
num_epoch = 200

# Damage category: the dataset sub-directory whose splits are loaded.
damage = 'spacing'

dataset_root = '/aiffel/aiffel/final_project/dataset/accida_segmentation_dataset_v1'
train_dir, val_dir, test_dir = (
    f'{dataset_root}/{damage}/{split}/' for split in ('train', 'valid', 'test')
)


In [None]:
# Side length of the square input every pipeline resizes to.
image_size = 512

# Per-channel statistics handed to A.Normalize (and to the Trainer).
mean, std = (0.5,) * 3, (0.5,) * 3

In [None]:
# Training pipeline: resize, random flip / brightness-contrast / small rotation,
# then normalise and convert image and mask to tensors.
transform_train = A.Compose([
    A.Resize(image_size, image_size),
    A.HorizontalFlip(),
    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
    A.Rotate((-10, 10), p=0.5, border_mode=cv2.BORDER_REFLECT),
    A.Normalize(mean, std),
    transforms.ToTensorV2(transpose_mask=True),
])


def _build_eval_transform():
    """Return a new deterministic pipeline (resize, normalise, to-tensor) for evaluation splits."""
    return A.Compose([
        A.Resize(image_size, image_size),
        A.Normalize(mean, std),
        transforms.ToTensorV2(transpose_mask=True),
    ])


# Validation and test use identical preprocessing, built as two separate objects.
transform_val = _build_eval_transform()
transform_test = _build_eval_transform()

In [None]:
# (directory, transform, last positional flag) per split; only the training split
# passes True — presumably the shuffle flag, confirm against get_dataloader.
_split_specs = (
    (train_dir, transform_train, True),
    (val_dir, transform_val, False),
    (test_dir, transform_test, False),
)
train_dataloader, val_dataloader, test_dataloader = (
    get_dataloader(split_dir, split_transform, batch_size, split_flag)
    for split_dir, split_transform, split_flag in _split_specs
)

## DFL 불러오기

In [None]:
# VGG16-based DFL model.
# Instantiate it; the bare name on the last line makes the notebook display the module structure.

model_dfl_v = DFL.DFL_VGG16()
model_dfl_v

In [None]:
# # resnet50 기반

# model_resnet = DFL.DFL_RESNET50()
# model_resnet

In [None]:
# Print every parameter tensor of the freshly built DFL model, as a baseline to
# compare against after the saved checkpoint is loaded.
for dfl_param in model_dfl_v.parameters():
    print(dfl_param)

In [None]:
# Load the saved DFL weights.
# map_location='cpu': the model still lives on the CPU here (it is only moved to
# `device` later), and this lets a checkpoint saved from GPU tensors load on a
# CPU-only machine instead of failing.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model_dfl_v.load_state_dict(torch.load('./DFL_Model.pth', map_location='cpu'))

In [None]:
# Print the DFL parameters again after loading the checkpoint; compare with the
# earlier printout to confirm the weights were replaced.
for loaded_param in model_dfl_v.parameters():
    print(loaded_param)

## 학습할 segmentation 모델 

In [None]:
# Build the segmentation model to be trained.
# NOTE(review): all other smp.Unet arguments (encoder_weights, classes, ...) are left at their defaults — confirm intended.

model_unet = smp.Unet(encoder_name='vgg16_bn') # The encoder must match the backbone of the DFL model loaded above
model_unet

In [None]:
# Print the U-Net's parameters before any DFL weights are copied into its encoder.
for unet_param in model_unet.parameters():
    print(unet_param)

## segmentation 모델의 encoder weight를 DFL weight로 바꿔주기

In [None]:
# Copy the DFL conv1-conv4 stage into the first 33 layers of the U-Net encoder,
# pairing layers index-by-index (assumes both follow the same VGG layout — TODO confirm).
#
# Each layer is copied with load_state_dict rather than by re-assigning
# `.weight` / `.bias`:
#   * parameters AND buffers are copied, so BatchNorm running_mean / running_var
#     travel together with the affine weight and bias;
#   * values are copied, so the two models do not share Parameter objects;
#   * a shape or layer-type mismatch raises instead of being accepted silently.
for i in range(33):
    dfl_layer = model_dfl_v.conv1_conv4[i]
    unet_layer = model_unet.encoder.features[i]
    dfl_state = dfl_layer.state_dict()
    if not dfl_state:
        # Layer with no parameters/buffers (e.g. activation or pooling): nothing
        # to copy, just show the pair for inspection.
        print(f'{i} : ', dfl_layer)
        print(f'{i} : ', unet_layer)
        continue
    unet_layer.load_state_dict(dfl_state)
print("완료!")

In [None]:
# Copy the DFL conv5 stage into the U-Net encoder: conv5[i] maps to
# encoder.features[i + 33], i.e. right after the 33 conv1-conv4 layers.
# conv5[0] is deliberately skipped — NOTE(review): presumably a stateless
# pooling layer; confirm.
for i in range(1, len(model_dfl_v.conv5)):
    dfl_layer = model_dfl_v.conv5[i]
    unet_layer = model_unet.encoder.features[i + 33]
    dfl_state = dfl_layer.state_dict()
    if not dfl_state:
        # No parameters/buffers to copy; show the pair for inspection.
        print(f'{i} : ', dfl_layer)
        print(f'{i} : ', unet_layer)
        continue
    # Copies parameters and buffers (incl. BatchNorm running stats) by value and
    # raises on a shape / layer-type mismatch.
    unet_layer.load_state_dict(dfl_state)


print("완료!")

## segmentation 모델의 weight가 바뀌었는지 확인하기

In [None]:
# Print the U-Net parameters again; the encoder values should now match the DFL ones.
for unet_param in model_unet.parameters():
    print(unet_param)

In [None]:
# Move the U-Net to the selected device; as the cell's last expression the module is also displayed.
model_unet.to(device)

In [None]:
# Loss: Dice loss.
# Alternative: criterion = nn.BCEWithLogitsLoss().to(device)
criterion = DiceLoss().to(device)

# SGD with momentum and weight decay over all U-Net parameters.
optimizer = optim.SGD(
    model_unet.parameters(),
    lr=lr,
    momentum=0.9,
    weight_decay=5 * 1e-4,
)

# Halve the LR when the monitored (minimised) value stops improving for 5
# scheduler steps, never going below 1e-6.
# Alternative: optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=20, T_mult=2, eta_min=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    min_lr=1e-6,
)

# Evaluation metrics, looked up by name in the project's metric module.
metric_names = ['IOUscore', 'PixelAccuracy']
metrics = [getattr(module_metric, metric_name) for metric_name in metric_names]

#### 4. wandb config 생성

In [None]:
# Hyper-parameters and component names logged to wandb as the run config.
train_config = {
    'Batch size': batch_size,
    'Learning Rate': lr,
    'Epochs': num_epoch,
    'Image size': image_size,
    'Loss fn': type(criterion).__name__,
    'Optimizer': type(optimizer).__name__,
    'LR Scheduler': type(scheduler).__name__,
    # Keys are 1-based positions (as strings); values are each metric's __name__.
    'Metric': {str(idx): metric.__name__ for idx, metric in enumerate(metrics, start=1)},
}

In [None]:
# Checkpoint directory: ./saved/U-Net/<segmentation model class>_<DFL model class>_<damage>_ver0/
# Class names are formatted directly. Do not build this by joining a *set* of
# names: string iteration order in a set depends on the hash seed.
save_dir = f"./saved/U-Net/{type(model_unet).__name__}_{type(model_dfl_v).__name__}_{damage}_ver0/"
print(save_dir)

In [None]:
# Positional: model, loss, metrics, optimizer, device, epoch count, checkpoint
# directory, normalisation mean/std; then the three loaders and the LR scheduler.
trainer = Trainer(
    model_unet,
    criterion,
    metrics,
    optimizer,
    device,
    num_epoch,
    save_dir,
    mean,
    std,
    data_loader=train_dataloader,
    valid_data_loader=val_dataloader,
    test_data_loader=test_dataloader,
    lr_scheduler=scheduler,
)

In [None]:
# Early-stopping patience for the trainer (unit presumably epochs — confirm in
# Trainer); the value the trainer holds is also recorded in the wandb config.
early_stop_patience = 30
trainer.early_stop = early_stop_patience
train_config['Early stop'] = trainer.early_stop

In [None]:
# NOTE(review): if trainer.dir equals save_dir ('./saved/U-Net/...'), then
# split('/')[1] is 'saved', not 'U-Net' — confirm the intended project name.
wandb_project = f"[FINAL]_Fin_{trainer.dir.split('/')[1]}"
wandb_run_name = f"{type(model_dfl_v).__name__}_unet-weight(vgg)"
wandb.init(project=wandb_project, name=wandb_run_name, config=train_config)

#### 5. 학습 시작하기

In [None]:
# Start training with the trainer, data loaders and wandb run configured above.
trainer.train()