In [14]:
import sys

# Make the directory one level above this notebook importable; the project
# packages imported in the next cell (data_loader, model, utils, train) are
# presumably located there — TODO confirm.
# The membership guard keeps this cell idempotent: re-running it no longer
# appends a duplicate '../' entry to sys.path on every execution.
_PROJECT_ROOT = '../'
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)

In [15]:
import os, glob, random, cv2
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import albumentations as A
import segmentation_models_pytorch as smp
import model.metric as module_metric

from data_loader.dataloader import get_dataloader 
from utils.data import get_datasize
from utils.visual import *
from albumentations.pytorch import transforms
from model.loss import *
from train import *
from pathlib import Path

In [16]:
import numpy as np  # local: only needed to seed NumPy's global RNG below

# Run on the first GPU when one is visible, otherwise fall back to CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Optimisation hyper-parameters (also logged to W&B further down).
lr = 1e-3
batch_size = 4
num_epoch = 100

# Dataset split locations, relative to the notebook's working directory.
train_dir = './dataset/refined/crushed/train/'
val_dir = './dataset/refined/crushed/val/'

# Seed the Python, NumPy and PyTorch global RNGs so a Restart & Run All
# reproduces model initialisation and (presumably) the augmentation draws.
# Full determinism may additionally depend on the data loader's shuffling /
# workers and on cuDNN kernel selection — NOTE(review): confirm if needed.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [17]:
def _to_model_input():
    """Return a fresh resize -> normalise -> tensor tail used by both pipelines.

    A new list of new transform instances is built on every call, so the
    train and validation Compose objects never share transform objects.
    """
    # NOTE(review): the model cell uses encoder_weights='imagenet'; such
    # encoders are commonly fed ImageNet mean/std rather than 0.5/0.5.
    # This may be intentional (e.g. matching visualisation code) — confirm.
    return [
        A.Resize(256, 256),
        A.Normalize(mean=0.5, std=0.5),
        transforms.ToTensorV2(transpose_mask=True),
    ]


# Training pipeline: random horizontal flip, a +/-30 degree rotation applied
# with p=0.5 (exposed borders filled by cv2.BORDER_REFLECT) and brightness /
# contrast jitter, followed by the common model-input tail.
transform_train = A.Compose([
    A.HorizontalFlip(),
    A.Rotate((-30, 30), p=0.5, border_mode=cv2.BORDER_REFLECT),
    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3),
    *_to_model_input(),
])

# Validation pipeline: no random augmentation, only the model-input tail.
transform_val = A.Compose(_to_model_input())

In [18]:
# Both splits come from the same factory; only the directory and the
# transform pipeline differ. Train is built first, then validation.
train_dataloader, val_dataloader = [
    get_dataloader(split_dir, split_transform, batch_size)
    for split_dir, split_transform in (
        (train_dir, transform_train),
        (val_dir, transform_val),
    )
]

In [19]:
# U-Net with an ImageNet-initialised EfficientNet-B0 encoder, 3-channel input
# and a single output channel. activation=None means no final activation is
# applied, so the network's output is left unactivated. Moved to `device`.
model = smp.Unet(
    encoder_name='efficientnet-b0',
    encoder_weights='imagenet',
    in_channels=3,
    classes=1,
    activation=None,
).to(device)

In [20]:
# Combined Dice + BCE objective (presumably from model.loss), placed on `device`.
criterion = DiceBCELoss()
criterion = criterion.to(device)

# Adam with a small L2 weight-decay term on every model parameter.
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

# Halve the learning rate once the monitored value (mode='min') has not
# improved for `patience`=10 scheduler steps, never dropping below 1e-6.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10,
    min_lr=1e-6,
)

In [21]:
# Empty run configuration; filled in by the next cell and passed to wandb.init.
train_config = dict()

In [22]:
# Hyper-parameters recorded with the W&B run.
# Optimizer / scheduler entries are read back from the live objects built in
# the previous cells, so the logged config cannot silently drift from what is
# actually trained (type names evaluate to 'Adam' / 'ReduceLROnPlateau', the
# same values previously hard-coded). Weight decay and the plateau-scheduler
# settings were not logged before and are added as new keys.
train_config['Batch size'] = batch_size
train_config['Learning Rate'] = lr
train_config['Epochs'] = num_epoch

train_config['Loss fn'] = 'DiceBCE'
train_config['Optimizer'] = type(optimizer).__name__
train_config['Weight Decay'] = optimizer.defaults['weight_decay']
train_config['LR Scheduler'] = type(scheduler).__name__
train_config['LR Factor'] = scheduler.factor
train_config['LR Patience'] = scheduler.patience
# ReduceLROnPlateau keeps one floor per param group; a single group is used here.
train_config['Min LR'] = min(scheduler.min_lrs)
# Plain list literal; the previous identity list comprehension was a no-op copy.
train_config['Metric'] = ['IOUscore', 'PixelAccuracy']

In [23]:
# Start a W&B run in project 'Segmentation' named 'UNet', logging train_config.
wandb.init(
    project='Segmentation',
    name='UNet',
    config=train_config,
)

In [None]:
# Resolve metric callables by name from model.metric. The names come from the
# list already logged to W&B (train_config['Metric']) instead of a second
# hard-coded copy, so the logged metrics and the evaluated metrics stay in sync.
metrics = [getattr(module_metric, name) for name in train_config['Metric']]

In [None]:
# Output directory handed to the Trainer (presumably for checkpoints/logs).
save_dir = Path('./saved/Efficient_B0')

trainer = Trainer(
    model, criterion, metrics, optimizer, device, num_epoch, save_dir,
    data_loader=train_dataloader,
    valid_data_loader=val_dataloader,
    lr_scheduler=scheduler,
)

In [None]:
# Launch training through the Trainer configured above.
# NOTE(review): the recorded run failed right after "Epoch : 0" with
# `ValueError: axes don't match array`. NumPy typically raises that when a
# transpose is given an axes tuple whose length differs from the array's rank.
# The debug print shows 3-D (4, 256, 256) arrays, so code inside Trainer /
# utils.visual (not visible here) presumably transposes them as if they were
# 4-D (N, C, H, W). If the dataset yields 2-D masks, ToTensorV2 with
# transpose_mask=True is believed to leave them without a channel axis —
# TODO confirm and fix the shape handling in train.py / utils.visual.
trainer.train()


Epoch : 0 | Train Loss : 1.15082 | Train P.A : 93.08% | Train IOU : 0.31803 | (4, 256, 256) (4, 256, 256)


ValueError: axes don't match array