### Deep learning model training.


In [5]:
import time
import torch
import random
import pandas as pd
import torch.nn as nn
from glob import glob
from model import u2net 
import geopandas as gpd 
from notebooks import config
import torch.nn.functional as F
from utils.imgShow import imsShow
from utils.utils import read_scenes
from utils.metrics import oa_binary, miou_binary
from utils.dataloader import SceneArraySet, PatchPathSet, PatchPathSet_2
from model import unet, deeplabv3plus, deeplabv3plus_mobilev2, u2net
from model import unet_att, u2net_att, u2net_att_3


In [6]:
## Patch settings
patch_size = 512     # patch size; also selects the validation patch folder below
patch_resize = None  # patch resize setting (None here)
## Training set: scene / truth / DEM path lists come from the project config
paths_scene_tra = config.paths_scene_tra
paths_truth_tra = config.paths_truth_tra
paths_dem_tra = config.paths_dem_tra
# paths_dem_tra = config.paths_dem_adjust_tra
print(f'train scenes: {len(paths_scene_tra)}')
## Validation set: every entry of the patch folder matching `patch_size`, sorted
## (used for model prediction)
paths_valset = sorted(glob(f'data/dset/valset/patch_{patch_size}/*'))
# paths_patch_valset = sorted(glob(f'data/dset/valset/patch_{patch_size}_dem_adjust/*'))
print(f'vali patch {patch_size}: {len(paths_valset)}')



train scenes: 48
vali patch 512: 293


### Dataset loading

In [7]:
## Load the training data from the scene, truth and DEM path lists
scenes_arr, truths_arr = read_scenes(
    paths_scene_tra, paths_truth_tra, paths_dem_tra
)


In [8]:
## Datasets: training samples come from the in-memory scene arrays,
## validation samples from the patch paths collected above.
tra_data = SceneArraySet(
    scenes_arr=scenes_arr,
    truths_arr=truths_arr,
    patch_size=patch_size,
    patch_resize=patch_resize,
)
val_data = PatchPathSet_2(
    paths_valset=paths_valset,
    patch_resize=patch_resize,
)

## Loaders: only the training loader shuffles; both use batches of 4 and 10 workers.
tra_loader = torch.utils.data.DataLoader(
    tra_data,
    batch_size=4,
    shuffle=True,
    num_workers=10,
)
val_loader = torch.utils.data.DataLoader(
    val_data,
    batch_size=4,
    num_workers=10,
)


### Model training

In [3]:
### Select the model to train.
### Exactly one constructor is active; the commented-out lines are the
### alternative architectures imported at the top of the notebook.
# model = unet(num_bands=7)
# model = unet_att(num_bands=7)
# model = u2net(num_bands_b1=6, num_bands_b2=1)
model = u2net_att_3(num_bands_b1=6, num_bands_b2=1)
# model = deeplabv3plus(num_bands=7)
# model = deeplabv3plus_mobilev2(num_bands=7)


In [4]:
## Sanity-check the model's output shape on a random batch.
tensor = torch.randn(2, 7, 512, 512)
# Run the check in eval mode and without autograd: this is only a shape
# check, so no graph is needed, and (if the model contains normalization
# layers with running statistics) random noise should not update them.
was_training = model.training
model.eval()
with torch.no_grad():
    output = model(tensor)
model.train(was_training)  # restore whatever mode the model was in before
print(output.shape)


torch.Size([2, 1, 512, 512])


In [9]:
### Loss functions, optimizer and learning-rate schedule
loss_bce = nn.BCELoss()
loss_mse = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
# ReduceLROnPlateau: the value handed to `.step()` is monitored in 'min' mode;
# the learning rate is scaled by `factor` once `patience` epochs pass without improvement.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.6,
    patience=20,
)


In [10]:
'''------train step------'''
def train_step(model, loss_fn, optimizer, x, y):
    """Run one optimization step on a single batch.

    Args:
        model: network to optimize; called as ``model(x)``.
        loss_fn: callable taking ``(pred, y.float())`` and returning the loss.
        optimizer: optimizer whose gradients are reset and which is stepped.
        x: input batch, already on the target device.
        y: truth batch, already on the target device.

    Returns:
        Tuple ``(loss, miou, oa)``: the loss returned by ``loss_fn`` and the
        values returned by ``miou_binary`` / ``oa_binary`` for this batch.
    """
    # val_step() switches the model to eval mode and nothing switched it back,
    # so every epoch after the first used to train in eval mode.
    # Re-enter training mode explicitly before each optimization step.
    model.train()
    optimizer.zero_grad()
    pred = model(x)
    loss = loss_fn(pred, y.float())
    loss.backward()
    optimizer.step()
    miou = miou_binary(pred=pred, truth=y, device=x.device)
    oa = oa_binary(pred=pred, truth=y, device=x.device)
    return loss, miou, oa

'''------validation step------'''
def val_step(model, loss_fn, x, y, patch_size, patch_resize):
    """Score one validation batch without gradient tracking.

    The model is switched to eval mode (and left there) and applied to ``x``
    cast to float. When ``patch_size`` is larger than 256, only the centered
    256x256 window of the prediction is scored; if in addition
    ``patch_resize == 256`` the prediction is first bilinearly resized to
    ``patch_size`` x ``patch_size``.

    Returns:
        Tuple ``(loss, miou, oa)``: the loss from ``loss_fn`` and the values
        returned by ``miou_binary`` / ``oa_binary`` for the scored prediction.
    """
    inner = 256  # side length of the central window that gets scored
    model.eval()
    with torch.no_grad():
        scores = model(x.float())
        needs_crop = patch_size > inner
        if needs_crop and patch_resize == inner:
            # bring the prediction back to the full patch size before cropping
            scores = F.interpolate(scores,
                                   size=(patch_size, patch_size),
                                   mode='bilinear',
                                   align_corners=True)
        if needs_crop:
            offset = (patch_size - inner) // 2
            window = slice(offset, offset + inner)
            scores = scores[:, :, window, window]  ## keep the inner 256x256 patch
        loss = loss_fn(scores, y.float())
    batch_miou = miou_binary(pred=scores, truth=y, device=x.device)
    batch_oa = oa_binary(pred=scores, truth=y, device=x.device)
    return loss, batch_miou, batch_oa

'''------train loops------'''
def _show_val_sample(model, dataset, device, patch_size, patch_resize):
    """Plot the prediction for one randomly chosen sample of ``dataset``."""
    model.eval()
    sam_index = random.randrange(len(dataset))
    patch, truth = dataset[sam_index]
    patch = torch.unsqueeze(patch.float(), 0).to(device)
    with torch.no_grad():  # visualization only; no autograd graph needed
        pred = model(patch)
    if patch_size > 256:
        if patch_resize == 256:
            patch = F.interpolate(patch,
                            size=(patch_size, patch_size),
                            mode='bilinear',
                            align_corners=True)
            pred = F.interpolate(pred,
                            size=(patch_size, patch_size),
                            mode='bilinear',
                            align_corners=True)
        crop_start = (patch_size-256)//2
        crop_end = crop_start+256
        patch_val = patch[:, :, crop_start:crop_end, crop_start:crop_end]  ## crop to match the size
        pred_val = pred[:, :, crop_start:crop_end, crop_start:crop_end]  ## crop to match the size
    else:
        # No inner crop is applied for patches of 256 or smaller; previously
        # patch_val / pred_val were left undefined here (NameError below).
        patch_val, pred_val = patch, pred
    ## convert to numpy and plot
    patch = patch[0].to('cpu').detach().numpy().transpose(1,2,0)
    pdem = patch[:,:, -1]   # last channel of the input patch, plotted as 'pdem'
    patch_val = patch_val[0].to('cpu').detach().numpy().transpose(1,2,0)
    pred_val = pred_val[0].to('cpu').detach().numpy()
    truth_val = truth.to('cpu').detach().numpy()
    imsShow([patch, pdem, patch_val, pred_val, truth_val],
            clip_list = (2,2,2,0,0),
            img_name_list=['input_patch', 'pdem', 'patch_val', 'pred_val', 'truth_val'],
            figsize=(15, 3))


def train_loops(model, loss_fn, optimizer, 
                    tra_loader, val_loader, epoches, 
                    device, patch_size, patch_resize,
                    lr_scheduler=None):
    """Train ``model`` for ``epoches`` epochs, validating after every epoch.

    Args:
        model: network to train; moved to ``device``.
        loss_fn: loss callable forwarded to ``train_step`` / ``val_step``.
        optimizer: optimizer forwarded to ``train_step``.
        tra_loader: iterable of ``(x, y)`` training batches; must support ``len``.
        val_loader: iterable of ``(x, y)`` validation batches; must support
            ``len``. Its ``.dataset`` attribute is read every 20th epoch to
            draw a preview sample.
        epoches: number of epochs to run.
        device: device the model and batches are moved to.
        patch_size, patch_resize: forwarded to ``val_step`` and used for the
            preview crop.
        lr_scheduler: optional scheduler; stepped once per epoch with the
            summed training loss of that epoch.

    Returns:
        Dict with per-epoch lists under the keys ``tra_loss``, ``tra_miou``,
        ``tra_oa``, ``val_loss``, ``val_miou`` and ``val_oa`` (each value is
        the mean over the batches of the corresponding loader).
    """
    tra_loss_loops, tra_miou_loops, tra_oa_loops = [], [], []
    val_loss_loops, val_miou_loops, val_oa_loops = [], [], []
    model = model.to(device)
    size_tra_loader = len(tra_loader)
    size_val_loader = len(val_loader)
    for epoch in range(epoches):
        start = time.time()
        tra_loss, val_loss = 0, 0
        tra_miou, val_miou = 0, 0
        tra_oa, val_oa = 0, 0
        '''-----train the model-----'''
        # The previous epoch's validation (and preview) leave the model in
        # eval mode, so training mode has to be restored every epoch.
        model.train()
        for x_batch, y_batch in tra_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            loss, miou, oa = train_step(model=model, loss_fn=loss_fn, 
                                    optimizer=optimizer, x=x_batch, y=y_batch)
            tra_loss += loss.item()
            tra_miou += miou.item()
            tra_oa += oa.item()
        if lr_scheduler:
            lr_scheduler.step(tra_loss)    # if using ReduceLROnPlateau
        '''----- validation the model: time consuming -----'''
        for x_batch, y_batch in val_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            loss, miou, oa = val_step(model=model, 
                                        loss_fn=loss_fn, 
                                        x=x_batch, 
                                        y=y_batch, 
                                        patch_size=patch_size, 
                                        patch_resize=patch_resize)
            val_loss += loss.item()
            val_miou += miou.item()
            val_oa += oa.item()
        ## Accuracy: turn the per-batch sums into per-batch means
        tra_loss = tra_loss/size_tra_loader
        val_loss = val_loss/size_val_loader
        tra_miou = tra_miou/size_tra_loader
        val_miou = val_miou/size_val_loader
        tra_oa = tra_oa/size_tra_loader
        val_oa = val_oa/size_val_loader
        tra_loss_loops.append(tra_loss); tra_miou_loops.append(tra_miou); tra_oa_loops.append(tra_oa)
        val_loss_loops.append(val_loss); val_miou_loops.append(val_miou); val_oa_loops.append(val_oa)
        print(f'Ep{epoch+1}: tra-> Loss:{tra_loss:.3f},Oa:{tra_oa:.3f},Miou:{tra_miou:.3f}, '
                f'val-> Loss:{val_loss:.3f},Oa:{val_oa:.3f}, Miou:{val_miou:.3f},time:{time.time()-start:.1f}s')
        ## show the result every 20th epoch
        if (epoch+1)%20 == 0:
            # Sample from the dataset behind val_loader instead of relying on
            # a notebook-global `val_data` variable.
            _show_val_sample(model=model, 
                             dataset=val_loader.dataset, 
                             device=device, 
                             patch_size=patch_size, 
                             patch_resize=patch_resize)
    metrics = {'tra_loss':tra_loss_loops, 'tra_miou':tra_miou_loops, 'tra_oa': tra_oa_loops, 
                'val_loss': val_loss_loops, 'val_miou': val_miou_loops, 'val_oa': val_oa_loops}
    return metrics 


In [13]:
## Launch training.
## NOTE(review): the recorded run of this cell ended in a CUDA out-of-memory
## error (a single 64 GiB allocation request); the cause is not visible in this
## notebook — confirm batch size / patch size / model choice before re-running.
device = torch.device('cuda:1')
metrics = train_loops(model=model,
                epoches=200,
                loss_fn=loss_mse,
                optimizer=optimizer,
                tra_loader=tra_loader,
                val_loader=val_loader,
                patch_size=patch_size,
                # use the same setting the datasets were built with instead of
                # a hard-coded None, so validation cropping stays consistent
                patch_resize=patch_resize,
                lr_scheduler=lr_scheduler,
                device=device)


OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 GiB. GPU 1 has a total capacity of 23.53 GiB of which 7.86 GiB is free. Process 1968237 has 4.49 GiB memory in use. Including non-PyTorch memory, this process has 11.14 GiB memory in use. Of the allocated memory 10.65 GiB is allocated by PyTorch, and 45.37 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# # model saving
# model_name = 'u2net'
# # model_name = 'unet_att'
# # model_name = 'deeplabv3plus'
# # # model_name = 'deeplabv3plus_mb2'
# # path_save = f'model/trained/patch_{patch_size}/{model_name}_weights_1.pth'
# path_save = f'model/trained/{model_name}_{patch_size}/{model_name}_weights_2.pth'
# torch.save(model.state_dict(), path_save)     ## save weights of the trained model 
# # # model.load_state_dict(torch.load(path_save, weights_only=True))  ## load the weights of the trained model
# # # ## metrics saving
# path_metrics = f'model/trained/{model_name}_{patch_size}/{model_name}_metrics_2.csv'    
# # path_metrics = f'model/trained/patch_{patch_size}/{model_name}_metrics_1.csv'    
# metrics_df = pd.DataFrame(metrics)
# metrics_df.to_csv(path_metrics, index=False, sep=',')
