In [1]:
import sys
import config
sys.path.append(config.root)
sys.path.append(config.root+'/model')
sys.path.append(config.root+'/trainer')
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import time
import random
import glob
from dataloader.loader import patch_tensor_dset
from dataloader.parallel_loader import parallel_load_dset
from dataloader.parallel_loader import parallel_read
from dataloader.read_preprocess import read_preprocess
from utils.metric import oa_binary, miou_binary
from utils.tiff_io import readTiff
from utils.transforms import toTensor, crop_scales, normalize
from utils.img_aug import rotate, flip, noise, missing
from utils.plot_dset_one import plot_dset_one
from utils.imgShow import imgShow
from seg_model.unet import unet
import matplotlib.pyplot as plt

In [2]:
# ------------device---------------- #
device = torch.device('cuda:0')
# device = torch.device('cpu')

# ----------Data paths-------------- #
# ---s1 path---
paths_as = sorted(glob.glob(config.root+'/data/s1_ascend/*'))
paths_des = sorted(glob.glob(config.root+'/data/s1_descend/*'))
paths_truth = sorted(glob.glob(config.root+'/data/s1_truth/*'))
# ---patch path---
paths_patch_tra = sorted(glob.glob(config.root+'/data/tra_patches/*'))
paths_patch_test = sorted(glob.glob(config.root+'/data/test_patches/*'))

#----------Training parameter------- #
epochs = 200
lr = 0.005
torch.manual_seed(999)   # make the trianing replicable
random.seed(999)         # make the data augmentation replicable

In [3]:
# patch_list_tra = [torch.load(path) for path in paths_patch_tra]
# patch_list_test = [torch.load(path) for path in paths_patch_test]


In [4]:
scene_list, truth_list = read_preprocess(paths_as=paths_as, \
                        paths_des=paths_des, paths_truth=paths_truth)
tra_dset = parallel_load_dset(scene_list[0:15], \
                        truth_list[0:15], num_thread=30)
patch_list_test = [torch.load(path) for path in paths_patch_test]
test_dset = patch_tensor_dset(patch_pair_list=patch_list_test)


In [5]:
tra_dset.__len__()


450

In [6]:
# tra_loader = torch.utils.data.DataLoader(tra_dset, batch_size=8, shuffle=True)
tra_loader = torch.utils.data.DataLoader(tra_dset, batch_size=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=8)


In [7]:
### Configuration
model = unet(num_bands=4, num_classes=2).to(device)
# model = deeplabv3plus_xception65(img_channels=6, num_classes=2, output_stride=32).to(device)
# model = deeplabv3plus_mobilev2_imp(img_channels=6, num_classes=2).to(device)
# summary(model, input_size=(6, 512, 512))
# loss_ce = nn.CrossEntropyLoss()
loss_bce = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)


In [8]:
'''------train step------'''
def train_step(model, loss_fn, optimizer, x, y):
    pred = model(x)
    loss = loss_fn(pred, y.float())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    miou = miou_binary(pred=pred, truth=y)
    oa = oa_binary(pred=pred, truth=y)
    return loss, miou, oa

'''------test step------'''
def test_step(model, loss_fn, x, y):
    model.eval()
    with torch.no_grad():
        pred = model(x)
        loss = loss_fn(pred, y.float())
    miou = miou_binary(pred=pred, truth=y)
    oa = oa_binary(pred=pred, truth=y)
    return loss, miou, oa

'''------train loops------'''
def train_loops(model, loss_fn, optimizer, tra_loader, test_loader, epoches, lr_scheduler):
    size_tra_loader = len(tra_loader)
    size_test_loader = len(test_loader)
    for epoch in range(epoches):
        !free -m
        start = time.time()
        read_start = start
        tra_loss, test_loss = 0, 0
        tra_miou, test_miou = 0, 0
        tra_oa, test_oa = 0, 0
        '''-----train the model-----'''
        time_dataread = []
        time_train = []
        for x_batch, y_batch in tra_loader:
            x_batch, y_batch = [batch.to(device) for batch in x_batch], y_batch.to(device)
            x_batch = x_batch[2]    #!!!note: x_batch[2] for single-scale model
            batch_readend = time.time()  # for test time
            time_dataread.append(batch_readend-read_start) # test time
            loss, miou, oa = train_step(model=model, loss_fn=loss_fn, 
                                    optimizer=optimizer, x=x_batch, y=y_batch)
            time_train.append(time.time()-batch_readend)
            read_start = time.time()
            tra_loss += loss.item()
            tra_miou += miou.item()
            tra_oa += oa.item()
        lr_scheduler.step()  # dynamic adjust learning rate
        format = 'ep: {}, data read: {:.2f}, model train: {:.2f}, all time: {:.2f}'
        print(format.format(epoch, np.sum(time_dataread), np.sum(time_train), time.time()-start))


In [9]:
train_loops(model=model,
            loss_fn=loss_bce,
            optimizer=optimizer,
            tra_loader=tra_loader,
            test_loader=test_loader,
            epoches=100,
            lr_scheduler=lr_scheduler)


              total        used        free      shared  buff/cache   available
Mem:          64301       46164         466          10       17670       17466
Swap:          2047        2047           0
ep: 0, data read: 7.09, model train: 2.47, all time: 9.57
              total        used        free      shared  buff/cache   available
Mem:          64301       48309        1137          15       14854       15316
Swap:          2047        2047           0
ep: 1, data read: 6.50, model train: 1.82, all time: 8.32
              total        used        free      shared  buff/cache   available
Mem:          64301       47638        1845          15       14818       15987
Swap:          2047        2047           0
ep: 2, data read: 6.54, model train: 1.80, all time: 8.34
              total        used        free      shared  buff/cache   available
Mem:          64301       48772        1189          15       14339       14852
Swap:          2047        2047           0
ep: 3, dat