In [2]:
import os
os.chdir('..')
import matplotlib.pyplot as plt
import numpy as np
import torch
from notebooks import config
import torch.nn as nn
import time
import random
import glob
from dataloader.loader import patch_tensor_dset
from dataloader.parallel_loader import threads_scene_dset
from dataloader.preprocess import read_normalize
from utils.metric import oa_binary, miou_binary
from utils.geotif_io import readTiff
from model.seg_model.unet import unet
from model.seg_model.deeplabv3_plus import deeplabv3plus, deeplabv3plus_imp


In [3]:
# ------------device---------------- #
device = torch.device('cuda:0')
# device = torch.device('cpu')

# ----------Data paths-------------- #
# ---s1 path---
paths_as = sorted(glob.glob(config.dir_as+'/*'))
paths_des = sorted(glob.glob(config.dir_des+'/*'))
paths_truth = sorted(glob.glob(config.dir_truth+'/*'))
# # ---patch path---
# paths_patch_tra = sorted(glob.glob(config.root+'/data/tra_patches/*'))
paths_patch_val = sorted(glob.glob(config.dir_patch_val+'/*'))

#----------Training parameter------- #
epochs = 200
lr = 0.005
torch.manual_seed(999)   # make the trianing replicable
random.seed(999)         # make the data augmentation replicable

In [5]:
## train dataset
scene_list, truth_list = read_normalize(paths_as=paths_as, \
                            paths_des=paths_des, paths_truth=paths_truth)
tra_dset = threads_scene_dset(scene_list[0:15], \
                            truth_list[0:15], \
                            transforms=config.transforms_tra, \
                            num_thread=20)
## validation dataset
patch_list_val = [torch.load(path) for path in paths_patch_val]
val_dset = patch_tensor_dset(patch_pair_list=patch_list_val)


In [6]:
print(tra_dset.__len__())
print(val_dset.__len__())


300
250


In [7]:
# tra_loader = torch.utils.data.DataLoader(tra_dset, batch_size=8, shuffle=True)
tra_loader = torch.utils.data.DataLoader(tra_dset, batch_size=4, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dset, batch_size=4)


In [8]:
### Configuration
# model = unet(num_bands=4, num_classes=2).to(device)
model = deeplabv3plus(num_bands=4, num_classes=2).to(device)
# model = deeplabv3plus_xception65(img_channels=6, num_classes=2, output_stride=32).to(device)
# model = deeplabv3plus_mobilev2_imp(img_channels=6, num_classes=2).to(device)
# summary(model, input_size=(6, 512, 512))
# loss_ce = nn.CrossEntropyLoss()
loss_bce = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)


In [9]:
'''------train step------'''
def train_step(model, loss_fn, optimizer, x, y):
    pred = model(x)
    loss = loss_fn(pred, y.float())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    miou = miou_binary(pred=pred, truth=y)
    oa = oa_binary(pred=pred, truth=y)
    return loss, miou, oa

'''------test step------'''
def test_step(model, loss_fn, x, y):
    model.eval()
    with torch.no_grad():
        pred = model(x)
        loss = loss_fn(pred, y.float())
    miou = miou_binary(pred=pred, truth=y)
    oa = oa_binary(pred=pred, truth=y)
    return loss, miou, oa

'''------train loops------'''
def train_loops(model, loss_fn, optimizer, tra_loader, test_loader, epoches, lr_scheduler):
    size_tra_loader = len(tra_loader)
    size_test_loader = len(test_loader)
    for epoch in range(epoches):
        # !free -m   # print the memory to be used
        start_epoch = time.time()
        traload_start = time.time()
        tra_loss, test_loss = 0, 0
        tra_miou, test_miou = 0, 0
        tra_oa, test_oa = 0, 0
        '''-----train the model-----'''
        time_traload = []
        time_testload = []
        time_train = []
        time_test = []
        for x_batch, y_batch in tra_loader:
            x_batch, y_batch = [batch.to(device) for batch in x_batch], y_batch.to(device)
            x_batch = x_batch[2]    #!!!note: x_batch[2] for single-scale model
            batch_loadend = time.time()  # 
            time_traload.append(batch_loadend-traload_start) # test time
            loss, miou, oa = train_step(model=model, loss_fn=loss_fn, 
                                    optimizer=optimizer, x=x_batch, y=y_batch)
            time_train.append(time.time()-batch_loadend)
            tra_loss += loss.item()
            tra_miou += miou.item()
            tra_oa += oa.item()
            traload_start = time.time()
        lr_scheduler.step()  # dynamic adjust learning rate
        '''-----test the model-----'''
        testload_start = time.time()
        for x_batch, y_batch in val_loader:
            x_batch, y_batch = [batch.to(device) for batch in x_batch], y_batch.to(device)            
            x_batch = x_batch[2]   #!!!note: x_batch[2] for single-scale model
            batch_loadend = time.time()  # for test time
            time_testload.append(batch_loadend-testload_start) # test time
            loss, miou, oa = test_step(model=model, loss_fn=loss_fn, 
                                                    x=x_batch, y=y_batch)
            time_test.append(time.time()-batch_loadend)
            test_loss += loss.item()
            test_miou += miou.item()
            test_oa += oa.item()
            testload_start = time.time()

        format = 'ep: {}, tradata load: {:.2f}, testdata load: {:.2f}, model train: {:.2f}, model test: {:.2f}, all time: {:.2f}'
        print(format.format(epoch, np.sum(time_traload), np.sum(time_testload), np.sum(time_train), np.sum(time_test), time.time()-start_epoch))


In [10]:
train_loops(model=model,
            loss_fn=loss_bce,
            optimizer=optimizer,
            tra_loader=tra_loader,
            test_loader=val_loader,
            epoches=100,
            lr_scheduler=lr_scheduler)


ep: 0, tradata load: 1.96, testdata load: 0.22, model train: 5.51, model test: 0.86, all time: 8.96
ep: 1, tradata load: 1.96, testdata load: 0.20, model train: 4.83, model test: 0.84, all time: 8.25
ep: 2, tradata load: 1.98, testdata load: 0.21, model train: 4.83, model test: 0.87, all time: 8.30
