In [1]:
import copy
import math
import os
import time

# Pin this process to physical GPU 1. This must be set before torch (or any
# project module that might touch CUDA at import time) initialises CUDA,
# because CUDA reads CUDA_VISIBLE_DEVICES only at initialisation.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler

import utils
import parameters
import Unet_models

# Run tag: pretrained checkpoints are read from ckpts/<dir_name>, fine-tuned
# checkpoints are written there, and loss histories go to train_loss/<dir_name>.
dir_name        = "20220829_bgerr" + str(parameters.background_err)
print(dir_name)

20220829_bgerr4.0


In [2]:
# Load the train/val splits plus the stored 'mean'/'std' arrays (presumably the
# normalisation statistics used by the loss rescaling later on). Indexing an
# NpzFile reads that member fully into memory, so the archive can be closed
# as soon as every array has been extracted.
with np.load('dataset/N' + str(parameters.sigNoise) + '_training_set.npz') as training_set_npz:
    x_train_obs = training_set_npz['x_train_obs']
    x_train = training_set_npz['x_train']
    mask_train = training_set_npz['mask_train']

    x_val_obs = training_set_npz['x_val_obs']
    x_val = training_set_npz['x_val']
    mask_val = training_set_npz['mask_val']

    stdTr = training_set_npz['std']
    meanTr = training_set_npz['mean']

In [3]:
batchsize = 32

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Each dataset yields (observations, ground truth, mask) triples;
# torch.Tensor(...) copies each numpy array into a default-dtype float tensor.
training_dataset  = torch.utils.data.TensorDataset(torch.Tensor(x_train_obs), torch.Tensor(x_train), torch.Tensor(mask_train))
val_dataset       = torch.utils.data.TensorDataset(torch.Tensor(x_val_obs),  torch.Tensor(x_val), torch.Tensor(mask_val)) 

# Only the training split is shuffled. Validation batches must stay fixed so
# that per-batch-normalised metrics (loss_R / loss_I in the training loop) are
# comparable across epochs.
dataloaders = {
    'train': torch.utils.data.DataLoader(training_dataset, batch_size=batchsize, shuffle=True, num_workers=4, pin_memory=True),
    'val': torch.utils.data.DataLoader(val_dataset, batch_size=batchsize, shuffle=False, num_workers=4, pin_memory=True),
}

dataset_sizes = {'train': len(training_dataset), 'val': len(val_dataset)}

In [4]:
# Build the three sub-networks on `device` and restore the pretrained head and
# dyn weights. model_sup is not restored here; it starts from its constructor's
# initialisation.
model_head  = Unet_models.L63_UnetConvRec_head().to(device)
model_dyn   = Unet_models.L63_UnetConvRec_dyn().to(device)
model_sup   = Unet_models.L63_UnetConvRec_sup().to(device)

# map_location lets checkpoints that were saved from a GPU load onto whatever
# `device` is available now (including CPU-only machines).
model_head.load_state_dict(torch.load("ckpts/" + dir_name + "/pretrain_head_epoch20", map_location=device))
model_dyn.load_state_dict(torch.load("ckpts/" + dir_name + "/pretrain_dyn_epoch2", map_location=device))

# Snapshots of the current weights; the fine-tuning loop replaces them with the
# best-validation weights. model_sup is included so the checkpoint-saving code
# never hits an undefined name when no validation epoch improves.
best_model_head_wts  = copy.deepcopy(model_head.state_dict())
best_model_dyn_wts   = copy.deepcopy(model_dyn.state_dict())
best_model_sup_wts   = copy.deepcopy(model_sup.state_dict())

In [5]:
# Per-step schedules: entry k is used for refinement step k + 1 (the loop
# indexes these lists with [step-1]).
num_epochs_sup  = [10] * 3
num_epochs_dyn  = [10] * 3
# With lr 0, the head's optimizer makes no parameter updates in steps 2 and 3.
lr_head         = [1e-4, 0, 0]
lr_sup          = [1e-4] * 3
lr_dyn          = [1e-4] * 3

# Indexed by `status` in the loop: 0 -> sup stage, 1 -> dyn stage.
num_epochs_list = [num_epochs_sup, num_epochs_dyn]
# Stage labels in the same order (not referenced elsewhere in this cell).
mod_name        = ["sup", "dyn"]

# Make sure the checkpoint and loss-history output directories exist,
# announcing each one that actually has to be created.
for _out_root in ("ckpts/", "train_loss/"):
    if os.path.exists(_out_root + dir_name):
        continue
    os.makedirs(_out_root + dir_name)
    print("creating dir: " + _out_root + dir_name )
del _out_root

# Alternating fine-tuning. For each refinement step (1..3) a DARNN wrapper is
# built around model_head / model_dyn / model_sup, then two stages run:
#   status 0: train model_sup and model_head on the reconstruction MSE only;
#   status 1: train model_dyn on reconstruction MSE + 100 * mean background
#             dynamical loss over every element of the model's outputs.
# After each stage the best-validation weights are written to ckpts/ and the
# per-epoch loss curves to train_loss/.
for step in range(1, 4):

    print("step: ", step)
    model = Unet_models.L63_DARNN(model_head, model_dyn, model_sup, step)
    
    for status in range(2): # status=0 -> train sup, status=1 -> train dyn
        
        if status == 0:
            print("finetuning model_sup...")
        else:
            print("finetuning model_dyn...")

        since = time.time()
        best_loss_rec = 1e10
        # Start each stage from the current weights so a checkpoint always
        # exists for this stage and is never a stale copy of the previous one,
        # even if no validation epoch beats best_loss_rec (e.g. a NaN loss).
        best_model_head_wts = copy.deepcopy(model.model_head.state_dict())
        best_model_sup_wts  = copy.deepcopy(model.model_sup.state_dict())
        best_model_dyn_wts  = copy.deepcopy(model.model_dyn.state_dict())

        # Per-epoch loss histories, saved with np.savez at the end of the stage.
        train_loss_rec_list = []
        val_loss_rec_list = []
        train_loss_dyn_list = []
        val_loss_dyn_list = []
        train_loss_dynbg_list = []
        val_loss_dynbg_list = []
        train_loss_R_list = []
        val_loss_R_list = []
        train_loss_I_list = []
        val_loss_I_list = []
        
        num_epochs = num_epochs_list[status][step-1]
        
        if status == 0:
            # model_dyn is kept in eval mode and is not stepped in this stage.
            model.model_dyn.eval()
            optimizer_model_sup  = optim.Adam(model.model_sup.parameters(),  lr=lr_sup[step-1])
            optimizer_model_head = optim.Adam(model.model_head.parameters(), lr=lr_head[step-1])
            print("optimizing model_sup, lr = ",  lr_sup[step-1])
            print("optimizing model_head, lr = ", lr_head[step-1])
            
        else:
            # model_sup / model_head are kept in eval mode and not stepped here.
            model.model_sup.eval()
            model.model_head.eval()
            optimizer_model_dyn  = optim.Adam(model.model_dyn.parameters(),  lr=lr_dyn[step-1])
            print("optimizing model_dyn, lr = ",  lr_dyn[step-1])

        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    if status == 0:
                        model.model_sup.train()
                        model.model_head.train()
                    else:
                        model.model_dyn.train()
                else:
                    if status == 0:
                        model.model_sup.eval()
                        model.model_head.eval()
                    else:
                        model.model_dyn.eval()

                running_loss_rec    = 0.0
                running_loss_dyn    = 0.0
                running_loss_dyn_bg = 0.0
                running_loss_R      = 0.0
                running_loss_I      = 0.0
                num_loss            = 0

                # Iterate over data.
                for inputs, targets, mask, in dataloaders[phase]:
                    mask        = mask.to(device)
                    targets     = targets.to(device)
                    inputs      = inputs.to(device)
                    
                    if status == 0:
                        optimizer_model_sup.zero_grad()
                        optimizer_model_head.zero_grad()
                    else:
                        optimizer_model_dyn.zero_grad()

                    # Gradients stay enabled in the val phase too.
                    # NOTE(review): presumably because the model needs autograd
                    # inside its forward pass. Confirm; if not, use
                    # torch.set_grad_enabled(phase == 'train') to save memory.
                    with torch.set_grad_enabled(True): 
                        # Only index 0 along dim 1 is fed in (kept as a size-1 dim).
                        outputs    = model(inputs[:,0:1,:], mask[:,0:1,:])
                        
                        # One dynamical-loss value per element of `outputs`.
                        # NOTE(review): assumes len(outputs) == step + 1; with
                        # fewer outputs, [-1] and the mean would include zeros.
                        dyn_list   = torch.zeros(step + 1)
                        dynbg_list = torch.zeros(step + 1)
                        for ii, elem in enumerate(outputs):
                            dyn_list[ii]   = utils.dynamic_loss(elem, 1, meanTr, stdTr, 1)
                            dynbg_list[ii] = utils.dynamic_loss(elem, 1, meanTr, stdTr, 3)
                            
                        # The last element of `outputs` is scored as the reconstruction.
                        output      = outputs[-1]
                        loss_rec    = torch.mean((output - targets)**2)
                        loss_dyn    = dyn_list[-1]
                        loss_dyn_bg = dynbg_list[-1]
                        # Squared error weighted by mask (R) and by 1 - mask (I).
                        loss_R      = torch.sum((output - targets)**2 * mask) / torch.sum(mask)
                        loss_I      = torch.sum((output - targets)**2 * (1 - mask)) / torch.sum(1 - mask)
                        
                        if status == 0:
                            loss  = loss_rec
                        else:
                            # Penalise the background dynamical loss of all outputs.
                            loss  = loss_rec + 100 * torch.mean(dynbg_list)

                        if phase == 'train':
                            loss.backward()
                            if status == 0:
                                optimizer_model_sup.step()
                                optimizer_model_head.step()
                            else:
                                optimizer_model_dyn.step()

                    # Batch-size-weighted sums; the stdTr**2 factor presumably maps
                    # normalised squared errors back to the original data scale.
                    running_loss_rec         += loss_rec.item()    * inputs.size(0) * stdTr**2
                    running_loss_dyn         += loss_dyn.item()    * inputs.size(0) * stdTr**2
                    running_loss_dyn_bg      += loss_dyn_bg.item() * inputs.size(0) * stdTr**2
                    running_loss_R           += loss_R.item()      * inputs.size(0) * stdTr**2
                    running_loss_I           += loss_I.item()      * inputs.size(0) * stdTr**2
                    num_loss                 += inputs.size(0)

                epoch_loss_rec       = running_loss_rec    / num_loss
                epoch_loss_dyn       = running_loss_dyn    / num_loss
                epoch_loss_dyn_bg    = running_loss_dyn_bg / num_loss
                epoch_loss_R         = running_loss_R      / num_loss
                epoch_loss_I         = running_loss_I      / num_loss

                print('{} rec loss: {:.4e} dyn loss: {:.4e} dyn loss(bg): {:.4e} loss_R: {:.4e} loss_I: {:.4e}'.format(
                    phase, epoch_loss_rec, epoch_loss_dyn, epoch_loss_dyn_bg, epoch_loss_R, epoch_loss_I))

                if phase == 'train':
                    train_loss_rec_list.append(epoch_loss_rec)
                    train_loss_dyn_list.append(epoch_loss_dyn)
                    train_loss_dynbg_list.append(epoch_loss_dyn_bg)
                    train_loss_R_list.append(epoch_loss_R)
                    train_loss_I_list.append(epoch_loss_I)
                else:
                    val_loss_rec_list.append(epoch_loss_rec)
                    val_loss_dyn_list.append(epoch_loss_dyn)
                    val_loss_dynbg_list.append(epoch_loss_dyn_bg)
                    val_loss_R_list.append(epoch_loss_R)
                    val_loss_I_list.append(epoch_loss_I)

                # Keep the weights with the lowest validation reconstruction loss.
                if phase == 'val' and epoch_loss_rec < best_loss_rec:
                    best_loss_rec = epoch_loss_rec
                    best_model_head_wts = copy.deepcopy(model.model_head.state_dict())
                    best_model_sup_wts  = copy.deepcopy(model.model_sup.state_dict())
                    best_model_dyn_wts  = copy.deepcopy(model.model_dyn.state_dict())
                    
            # Early stop on the background dynamical loss of the val phase
            # (val is the last phase run in each epoch).
            if epoch_loss_dyn_bg < parameters.relative_err:
                break

            print()

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print('Best val reconstruction loss: {:.4e}'.format(best_loss_rec))

        # Stage index 2*step + status gives finetune2 .. finetune7 for steps 1..3.
        # NOTE(review): the suffix uses the configured num_epochs even if the
        # stage stopped early.
        save_dir_model_head = "ckpts/" + dir_name + "/finetune" + str(step*2 + status) + "_head_epoch"+ str(num_epochs)
        save_dir_model_sup  = "ckpts/" + dir_name + "/finetune" + str(step*2 + status) + "_sup_epoch" + str(num_epochs)
        save_dir_model_dyn  = "ckpts/" + dir_name + "/finetune" + str(step*2 + status) + "_dyn_epoch" + str(num_epochs)
        print("saving model_head at " + save_dir_model_head)
        print("saving model_sup at "  + save_dir_model_sup)
        print("saving model_dyn at "  + save_dir_model_dyn)
        torch.save(best_model_head_wts, save_dir_model_head)
        torch.save(best_model_sup_wts,  save_dir_model_sup)
        torch.save(best_model_dyn_wts,  save_dir_model_dyn)
        # NOTE(review): the best weights are only written to disk. The live
        # modules keep their last-epoch weights, so the next stage does not
        # resume from the checkpoint just saved. If it should, load
        # best_*_wts back into model.model_head/sup/dyn here. Confirm intent.

        save_dir_loss  = "train_loss/" + dir_name + "/finetune" + str(step*2 + status) + "_epoch" + str(num_epochs)
        print("saving loss at " + save_dir_loss)
        np.savez(save_dir_loss,
                 train_loss_rec   = train_loss_rec_list,   val_loss_rec   = val_loss_rec_list, 
                 train_loss_dyn   = train_loss_dyn_list,   val_loss_dyn   = val_loss_dyn_list,
                 train_loss_dynbg = train_loss_dynbg_list, val_loss_dynbg = val_loss_dynbg_list,
                 train_loss_R     = train_loss_R_list,     val_loss_R     = val_loss_R_list, 
                 train_loss_I     = train_loss_I_list,     val_loss_I     = val_loss_I_list,
                 time = time_elapsed)
            
        torch.cuda.empty_cache()
        print()
        

step:  1
finetuning model_sup...
optimizing model_sup, lr =  0.0001
optimizing model_head, lr =  0.0001
Epoch 0/9
----------


  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


train rec loss: 1.7033e+01 dyn loss: 1.2604e-01 dyn loss(bg): 1.2301e-01 loss_R: 8.2304e+00 loss_I: 1.7368e+01
val rec loss: 1.9518e+00 dyn loss: 7.5354e-02 dyn loss(bg): 7.8525e-02 loss_R: 5.0198e-01 loss_I: 2.0070e+00

Epoch 1/9
----------
train rec loss: 1.6941e+00 dyn loss: 6.1024e-02 dyn loss(bg): 6.5899e-02 loss_R: 4.2372e-01 loss_I: 1.7425e+00
val rec loss: 1.6465e+00 dyn loss: 5.1069e-02 dyn loss(bg): 5.7934e-02 loss_R: 4.2282e-01 loss_I: 1.6931e+00

Epoch 2/9
----------
train rec loss: 1.5141e+00 dyn loss: 4.1057e-02 dyn loss(bg): 4.7236e-02 loss_R: 3.9727e-01 loss_I: 1.5566e+00
val rec loss: 1.5210e+00 dyn loss: 3.5374e-02 dyn loss(bg): 4.2562e-02 loss_R: 3.9686e-01 loss_I: 1.5637e+00

Epoch 3/9
----------
train rec loss: 1.4364e+00 dyn loss: 2.9290e-02 dyn loss(bg): 3.6070e-02 loss_R: 3.9086e-01 loss_I: 1.4762e+00
val rec loss: 1.4738e+00 dyn loss: 2.6273e-02 dyn loss(bg): 3.4498e-02 loss_R: 3.8881e-01 loss_I: 1.5151e+00

Epoch 4/9
----------
train rec loss: 1.3997e+00 dyn l

val rec loss: 1.3152e+00 dyn loss: 6.8588e-03 dyn loss(bg): 1.0134e-02 loss_R: 3.9696e-01 loss_I: 1.3501e+00

Epoch 3/9
----------
train rec loss: 1.1724e+00 dyn loss: 6.2446e-03 dyn loss(bg): 8.9736e-03 loss_R: 3.9751e-01 loss_I: 1.2019e+00
val rec loss: 1.3070e+00 dyn loss: 7.3196e-03 dyn loss(bg): 1.0909e-02 loss_R: 3.8662e-01 loss_I: 1.3420e+00

Epoch 4/9
----------
train rec loss: 1.1736e+00 dyn loss: 6.3382e-03 dyn loss(bg): 9.2162e-03 loss_R: 3.9549e-01 loss_I: 1.2032e+00
val rec loss: 1.2810e+00 dyn loss: 6.3396e-03 dyn loss(bg): 9.5218e-03 loss_R: 4.0076e-01 loss_I: 1.3145e+00

Epoch 5/9
----------
train rec loss: 1.1719e+00 dyn loss: 6.4874e-03 dyn loss(bg): 9.3791e-03 loss_R: 3.9556e-01 loss_I: 1.2014e+00
val rec loss: 1.2883e+00 dyn loss: 6.6447e-03 dyn loss(bg): 9.6022e-03 loss_R: 4.1723e-01 loss_I: 1.3215e+00

Epoch 6/9
----------
train rec loss: 1.1561e+00 dyn loss: 6.8971e-03 dyn loss(bg): 9.9607e-03 loss_R: 3.8622e-01 loss_I: 1.1854e+00
val rec loss: 1.3739e+00 dyn los