In [2]:
import os
import time
from glob import glob

import config   # kept ahead of the third-party imports, as in the original ordering
import numpy as np
import torch
import pandas as pd
from model import unet, deeplabv3plus, deeplabv3plus_mobilev2, hrnet
from utils.metrics import oa_binary, miou_binary
from dataloader.preprocess import read_normalize
from dataloader.loader import patch_tensor_dset, scene_dset
from dataloader.parallel_loader import threads_scene_dset
from utils.plot_dset_one import plot_dset_one


In [3]:
# Run configuration: the intended number of training epochs, and the file stem
# used when the trained weights / metrics are written out at the end of the notebook.
num_epoch, model_name_save = 200, 'unet_trained_1'


In [4]:
# Use the first GPU when one is available; fall back to CPU so the notebook still runs without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Seed before building the model so weight initialisation is reproducible.
torch.manual_seed(999)   # make the training replicable
# NOTE(review): assumes config.transforms_tra / scene_dset draw from numpy's global RNG — confirm.
np.random.seed(999)
model = unet(num_bands=6, num_classes=2).to(device)


In [5]:
## Data paths
### The whole dataset.
paths_scene = sorted(glob(config.dir_s2 + '/scene/*'))
assert paths_scene, 'no scene files found under ' + config.dir_s2 + '/scene/'


def truth_path_of(path_scene):
    """Map a scene file path to its ground-truth file path.

    '/scene/' becomes '/truth/', '_6Bands' is removed, and everything from the first
    '.' of the *file name* is replaced by '_truth.tif'. Only the file name is split on
    '.', so a dot in a directory name (e.g. '../data') no longer truncates the path.
    """
    dir_truth, name_truth = os.path.split(path_scene.replace('/scene/', '/truth/').replace('_6Bands', ''))
    return os.path.join(dir_truth, name_truth.split('.')[0] + '_truth.tif')


paths_truth = [truth_path_of(path_scene) for path_scene in paths_scene]
### Select training part from the dataset: every scene index not listed in config.i_valset.
id_scene = list(range(len(paths_scene)))
ids_val = set(config.i_valset)
id_tra_scene = [i for i in id_scene if i not in ids_val]   # keeps ascending scene order
paths_tra_scene = [paths_scene[i] for i in id_tra_scene]
paths_tra_truth = [paths_truth[i] for i in id_tra_scene]
len(paths_tra_scene)


75

In [6]:
# Validation part of the dataset (patch format): sorted list of validation patch files.
paths_patch_val = sorted(glob(config.dir_patch_val_s2 + '/*'))

# --------- 1. Data loading --------
# ----- 1.1 Training data loading (from scene paths).
# Presumably read_normalize rescales each band with config.bands_max / config.bands_min —
# see dataloader.preprocess to confirm.
tra_scenes, tra_truths = read_normalize(
    paths_img=paths_tra_scene,
    paths_truth=paths_tra_truth,
    max_bands=config.bands_max,
    min_bands=config.bands_min,
)


In [7]:
''' ----- 1.2. Training data loading and auto augmentation'''
time_start = time.time()
# Alternative: dataloader.parallel_loader.threads_scene_dset(scene_list, truth_list,
# transforms, num_thread) — per its original note, num_thread patches are made per scene.
tra_dset = scene_dset(scene_list=tra_scenes,
                      truth_list=tra_truths,
                      transforms=config.transforms_tra,
                      patch_size=[256, 256])
print('size of training data:  ', len(tra_dset))
# NOTE(review): the recorded run reports well under a millisecond here, so this most likely
# times only dataset construction, not patch extraction / augmentation — confirm in scene_dset.
print('time consuming:  ', time.time()-time_start)



size of training data:   75
time comsuming:   0.00028705596923828125


In [8]:
# ----- 1.3. Validation data loading (pre-cut validation patches, one file each) -----
patch_list_val = list(map(torch.load, paths_patch_val))
val_dset = patch_tensor_dset(patch_pair_list=patch_list_val)
print('size of validation data:', len(val_dset))

# DataLoaders: training batches are reshuffled every epoch; validation order stays fixed.
tra_loader = torch.utils.data.DataLoader(tra_dset, batch_size=config.batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dset, batch_size=16)


size of validation data: 1000


In [19]:
'''------train step------'''
def train_step(model, loss_fn, optimizer, x, y):
    """Run one optimisation step on a single batch.

    Args:
        model: network being trained; it is switched to train mode here.
        loss_fn: callable(pred, target) returning a scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.
        x: input batch, already on the model's device.
        y: target batch; cast to float for the loss and passed as-is to the metrics.

    Returns:
        (loss, miou, oa): the loss tensor of this step and the values returned by
        miou_binary / oa_binary for this batch.
    """
    # val_step puts the model in eval mode. Without switching back, every epoch after the
    # first would train with any dropout / BatchNorm layers behaving as in evaluation.
    model.train()
    optimizer.zero_grad()
    pred = model(x)
    loss = loss_fn(pred, y.float())
    loss.backward()
    optimizer.step()
    # Metrics need no autograd graph.
    # NOTE(review): train_loops passes label-smoothed targets here, so these metrics may be
    # computed against soft labels — confirm miou_binary / oa_binary handle that as intended.
    with torch.no_grad():
        miou = miou_binary(pred=pred, truth=y)
        oa = oa_binary(pred=pred, truth=y)
    return loss, miou, oa

'''------validation step------'''
def val_step(model, loss_fn, x, y):
    """Evaluate the model on one batch without tracking gradients.

    The model is switched to eval mode first. Returns (loss, miou, oa), where loss is
    loss_fn(output, y.float()) and the metrics come from miou_binary / oa_binary.
    """
    model.eval()
    with torch.no_grad():
        output = model(x)
        batch_loss = loss_fn(output, y.float())
    return batch_loss, miou_binary(pred=output, truth=y), oa_binary(pred=output, truth=y)

'''------ train loops ------'''
def _train_one_epoch(model, loss_fn, optimizer, loader, device):
    """One optimisation pass over `loader`; returns the per-batch mean (loss, miou, oa)."""
    model.train()   # val_step leaves the model in eval mode; restore training behaviour
    loss_sum = miou_sum = oa_sum = 0.0
    for x_batch, y_batch in loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        y_batch = config.label_smooth(y_batch)
        loss, miou, oa = train_step(model=model, loss_fn=loss_fn,
                                    optimizer=optimizer, x=x_batch, y=y_batch)
        loss_sum += loss.item()
        miou_sum += miou.item()
        oa_sum += oa.item()
    num_batches = len(loader)
    return loss_sum / num_batches, miou_sum / num_batches, oa_sum / num_batches


def _validate_one_epoch(model, loss_fn, loader, device):
    """Evaluate over `loader` (inputs cast to float32); returns the per-batch mean (loss, miou, oa)."""
    loss_sum = miou_sum = oa_sum = 0.0
    for x_batch, y_batch in loader:
        x_batch = x_batch.to(device).to(dtype=torch.float32)
        y_batch = y_batch.to(device)
        loss, miou, oa = val_step(model=model, loss_fn=loss_fn, x=x_batch, y=y_batch)
        loss_sum += loss.item()
        miou_sum += miou.item()
        oa_sum += oa.item()
    num_batches = len(loader)
    return loss_sum / num_batches, miou_sum / num_batches, oa_sum / num_batches


def train_loops(model, loss_fn, optimizer, tra_loader, val_loader, epoches, lr_scheduler=None, device=None):
    """Train `model` for `epoches` epochs, validating after every epoch.

    Args:
        model: network to train; it is put in train mode for each training pass
            (validation switches it to eval mode).
        loss_fn: callable(pred, target) returning a scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.
        tra_loader: DataLoader of (x, y) training batches; targets go through
            ``config.label_smooth`` before each step.
        val_loader: DataLoader of (x, y) validation batches; inputs are cast to float32.
        epoches: number of epochs to run.
        lr_scheduler: optional scheduler. A ReduceLROnPlateau is stepped with the
            epoch's mean training loss; any other scheduler gets a plain ``step()``
            once per epoch.
        device: where batches are moved. Defaults to the device of the model's
            first parameter.

    Returns:
        dict of per-epoch lists with keys 'tra_loss', 'tra_miou', 'val_loss',
        'val_miou', followed by 'tra_oa' and 'val_oa'.

    Raises:
        ValueError: if either loader yields no batches (the per-batch means would
            otherwise divide by zero).
    """
    if len(tra_loader) == 0 or len(val_loader) == 0:
        raise ValueError('train_loops: tra_loader and val_loader must each yield at least one batch')
    if device is None:
        device = next(model.parameters()).device
    log_fmt = ('Ep{}: Tra-> Loss:{:.3f}, Oa:{:.3f}, Miou:{:.3f}, '
               'Val-> Loss:{:.3f}, Oa:{:.3f}, Miou:{:.3f}, Time:{:.1f}s')
    # Original keys first so existing CSV column positions are unchanged; OA appended.
    history = {'tra_loss': [], 'tra_miou': [], 'val_loss': [], 'val_miou': [],
               'tra_oa': [], 'val_oa': []}
    for epoch in range(epoches):
        start = time.time()

        '''----- 1. train the model -----'''
        tra_loss, tra_miou, tra_oa = _train_one_epoch(model, loss_fn, optimizer, tra_loader, device)
        if lr_scheduler is not None:
            if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                lr_scheduler.step(tra_loss)   # monitors the mean training loss
            else:
                lr_scheduler.step()           # epoch-based schedulers take no metric

        '''----- 2. validate the model -----'''
        val_loss, val_miou, val_oa = _validate_one_epoch(model, loss_fn, val_loader, device)

        '''------ 3. record and print mean accuracy ------'''
        for key, value in (('tra_loss', tra_loss), ('tra_miou', tra_miou), ('val_loss', val_loss),
                           ('val_miou', val_miou), ('tra_oa', tra_oa), ('val_oa', val_oa)):
            history[key].append(value)
        print(log_fmt.format(epoch + 1, tra_loss, tra_oa, tra_miou,
                             val_loss, val_oa, val_miou, time.time() - start))
    return history


In [20]:
''' -------- 2. Model loading and training strategy ------- '''
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
# Multiply the learning rate by 0.6 once the monitored loss has not improved for 20 epochs.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.6, patience=20)

''' -------- 3. Model training for loops ------- '''
# Use the epoch count from the configuration cell rather than repeating the literal here.
metrics = train_loops(model=model,
                      loss_fn=config.loss_bce,
                      optimizer=optimizer,
                      tra_loader=tra_loader,
                      val_loader=val_loader,
                      epoches=num_epoch,
                      lr_scheduler=lr_scheduler,
                      )


Ep1: Tra-> Loss:0.366, Oa:0.959, Miou:0.913, Val-> Loss:0.157, Oa:0.966, Miou:0.934, Time:3.0s


In [13]:
''' -------- 4. trained model and accuracy metric saving  ------- '''
dir_trained_model = 'model/trained_model/'
# torch.save and DataFrame.to_csv do not create missing folders, so make sure it exists.
os.makedirs(dir_trained_model, exist_ok=True)
# model saving
path_weights = dir_trained_model + model_name_save + '_weights.pth'
torch.save(model.state_dict(), path_weights)
print('Model weights are saved to --> ', path_weights)
## metrics saving
path_metrics = dir_trained_model + model_name_save + '_metrics.csv'
metrics_df = pd.DataFrame(metrics)
metrics_df.to_csv(path_metrics, index=False, sep=',')
metrics_df = pd.read_csv(path_metrics)   # re-read, so metrics_df reflects the file on disk
print('Training metrics are saved to --> ', path_metrics)


Model weights are saved to -->  model/trained_model/unet_trained_1_weights.pth
Training metrics are saved to -->  model/trained_model/unet_trained_1_metrics.csv
