In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import math, time, copy

import utils, parameters, Unet_models
import os

# Expose only physical GPU 3 to this process. The variable is set at the top of
# the notebook, before any cell creates a CUDA tensor or module.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Run identifier: a date tag followed by the configured background error.
# It is reused as the sub-directory name under ckpts/ and train_loss/.
dir_name = f"20220901_bgerr{parameters.background_err!s}"
print(dir_name)

20220828_bgerr0.0


In [2]:
# Open the training archive that matches the configured noise level.
# The handle stays bound to `training_set_npz` so it can still be indexed later.
training_set_npz = np.load(f"dataset/N{parameters.sigNoise!s}_training_set.npz")

# Training split. These arrays are later wrapped, in this order, as the
# (inputs, targets, mask) triple of the training TensorDataset.
x_train_obs, x_train, mask_train = (
    training_set_npz[key] for key in ('x_train_obs', 'x_train', 'mask_train')
)

# Validation split, same layout as the training split.
x_val_obs, x_val, mask_val = (
    training_set_npz[key] for key in ('x_val_obs', 'x_val', 'mask_val')
)

# Statistics stored with the dataset; stdTr**2 is used further down to rescale
# the logged losses, and both values are passed to utils.dynamic_loss.
stdTr, meanTr = training_set_npz['std'], training_set_npz['mean']

In [3]:
# Mini-batch size shared by the training and validation loaders.
batchsize = 32

# Use the (single visible) GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Each dataset yields (observed input, target, mask) triples as float32 tensors.
# torch.as_tensor replaces the legacy torch.Tensor(ndarray) constructor and skips
# the extra copy when the array is already float32.
training_dataset = torch.utils.data.TensorDataset(
    *(torch.as_tensor(arr, dtype=torch.float32) for arr in (x_train_obs, x_train, mask_train))
)
val_dataset = torch.utils.data.TensorDataset(
    *(torch.as_tensor(arr, dtype=torch.float32) for arr in (x_val_obs, x_val, mask_val))
)

dataloaders = {
    'train': torch.utils.data.DataLoader(training_dataset, batch_size=batchsize, shuffle=True, num_workers=4, pin_memory=True),
    # Validation batches keep a fixed order: the masked losses in the training
    # loop are per-batch ratios, so reshuffling would change the reported
    # validation numbers from epoch to epoch for reasons unrelated to the model.
    'val': torch.utils.data.DataLoader(val_dataset, batch_size=batchsize, shuffle=False, num_workers=4, pin_memory=True),
}

# Number of samples per split.
dataset_sizes = {'train': len(training_dataset), 'val': len(val_dataset)}

In [4]:
# Instantiate the three sub-networks and move them to the selected device.
model_head  = Unet_models.L96_UnetConvRec_head().to(device)
model_dyn   = Unet_models.L96_UnetConvRec_dyn().to(device)
model_sup   = Unet_models.L96_UnetConvRec_sup().to(device)

# Restore the pre-trained head and dynamics weights. map_location=device remaps
# the stored tensors onto the device chosen above, so the load also works when
# the checkpoint was written from a different CUDA index or when running on the
# CPU fallback. No checkpoint is loaded for model_sup in this cell.
model_head.load_state_dict(
    torch.load("ckpts/" + dir_name + "/pretrain_head_epoch20", map_location=device))
model_dyn.load_state_dict(
    torch.load("ckpts/" + dir_name + "/pretrain_dyn_epoch200", map_location=device))

# Initial "best" dynamics weights: the pre-trained ones.
best_model_dyn_wts  = copy.deepcopy(model_dyn.state_dict())

In [5]:
# Alternating fine-tuning of the three sub-networks.
#
# For step = 1..3 a new Unet_models.L96_DARNN wrapper is built around the same
# model_head / model_dyn / model_sup instances, and two stages are run:
#   status 0: optimise model_sup and model_head on the reconstruction MSE
#             (model_dyn is put in eval mode and is not stepped);
#   status 1: optimise model_dyn on reconstruction MSE + 100 * mean(dynbg_list)
#             (model_sup and model_head are put in eval mode and not stepped).
# After every stage, the weights with the lowest validation reconstruction loss
# and the per-epoch loss curves are written under ckpts/<dir_name>/ and
# train_loss/<dir_name>/. The best weights are only saved; they are not loaded
# back into the live model before the next stage.

# Per-step schedules; entry [step-1] is used for step 1, 2, 3.
num_epochs_sup  = [   10,   10,   10]
num_epochs_dyn  = [   10,   10,   10]
lr_head         = [ 1e-4,    0,    0]
# Alternative kept from the original notebook (head frozen in every step):
# lr_head         = [    0,    0,    0]
lr_sup          = [ 1e-4, 1e-4, 1e-4]
lr_dyn          = [ 1e-4, 1e-4, 1e-4]
num_epochs_list = [num_epochs_sup, num_epochs_dyn]   # indexed by status
mod_name        = ["sup", "dyn"]

# Make sure both output directories exist. exist_ok=True avoids failing if the
# directory appears between the existence check and the creation.
for out_dir in ("ckpts/" + dir_name, "train_loss/" + dir_name):
    if not os.path.exists(out_dir):
        print("creating dir: " + out_dir)
    os.makedirs(out_dir, exist_ok=True)

for step in range(1, 4):

    print("step: ", step)
    model = Unet_models.L96_DARNN(model_head, model_dyn, model_sup, step)

    for status in range(2): # status=0 -> train sup, status=1 -> train dyn

        if status == 0:
            print("finetuning model_sup...")
        else:
            print("finetuning model_dyn...")

        since = time.time()

        # Best-so-far tracking for this stage. The snapshots are initialised from
        # the current weights so that the save below always has well-defined
        # state dicts, even if the validation loss never improves (e.g. NaN) or
        # the stage runs for zero epochs.
        best_loss_rec       = float("inf")
        best_model_head_wts = copy.deepcopy(model.model_head.state_dict())
        best_model_sup_wts  = copy.deepcopy(model.model_sup.state_dict())
        best_model_dyn_wts  = copy.deepcopy(model.model_dyn.state_dict())

        # Per-epoch loss curves for this stage (rescaled by stdTr**2).
        train_loss_rec_list = []
        val_loss_rec_list = []
        train_loss_dyn_list = []
        val_loss_dyn_list = []
        train_loss_dynbg_list = []
        val_loss_dynbg_list = []
        train_loss_R_list = []
        val_loss_R_list = []
        train_loss_I_list = []
        val_loss_I_list = []

        num_epochs = num_epochs_list[status][step-1]

        # Select which sub-networks are optimised in this stage; the others are
        # put in eval mode once and stay there for the whole stage.
        if status == 0:
            model.model_dyn.eval()
            optimizer_model_sup  = optim.Adam(model.model_sup.parameters(),  lr=lr_sup[step-1])
            optimizer_model_head = optim.Adam(model.model_head.parameters(), lr=lr_head[step-1])
            stage_optimizers = [optimizer_model_sup, optimizer_model_head]
            trained_modules  = [model.model_sup, model.model_head]
            print("optimizing model_sup, lr = ",  lr_sup[step-1])
            print("optimizing model_head, lr = ", lr_head[step-1])

        else:
            model.model_sup.eval()
            model.model_head.eval()
            optimizer_model_dyn  = optim.Adam(model.model_dyn.parameters(),  lr=lr_dyn[step-1])
            stage_optimizers = [optimizer_model_dyn]
            trained_modules  = [model.model_dyn]
            print("optimizing model_dyn, lr = ",  lr_dyn[step-1])

        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                is_train = (phase == 'train')

                # train(True) == train(), train(False) == eval()
                for module in trained_modules:
                    module.train(is_train)

                # Sample-weighted accumulators for the epoch averages.
                running_loss_rec    = 0.0
                running_loss_dyn    = 0.0
                running_loss_dyn_bg = 0.0
                running_loss_R      = 0.0
                running_loss_I      = 0.0
                num_loss            = 0

                # Iterate over data.
                for inputs, targets, mask, in dataloaders[phase]:
                    mask        = mask.to(device)
                    targets     = targets.to(device)
                    inputs      = inputs.to(device)

                    for opt in stage_optimizers:
                        opt.zero_grad()

                    # NOTE(review): gradients stay enabled in the 'val' phase as
                    # well, exactly as in the original. SOURCE does not show
                    # whether the model or utils.dynamic_loss need autograd
                    # internally, so this is deliberately left unchanged.
                    with torch.set_grad_enabled(True):
                        outputs    = model(inputs, mask)

                        # One dynamic-loss value per element of `outputs`; the
                        # code assumes `outputs` has step + 1 elements.
                        dyn_list   = torch.zeros(step + 1)
                        dynbg_list = torch.zeros(step + 1)
                        for ii, elem in enumerate(outputs):
                            dyn_list[ii]   = utils.dynamic_loss(elem, 1, meanTr, stdTr, 1)
                            dynbg_list[ii] = utils.dynamic_loss(elem, 1, meanTr, stdTr, 3)

                        # Metrics are reported on the last output only.
                        output      = outputs[-1]
                        loss_rec    = torch.mean((output - targets)**2)
                        loss_dyn    = dyn_list[-1]
                        loss_dyn_bg = dynbg_list[-1]
                        # Squared error averaged where mask == 1 (loss_R) and
                        # where mask == 0 (loss_I), assuming a 0/1 mask.
                        loss_R      = torch.sum((output - targets)**2 * mask) / torch.sum(mask)
                        loss_I      = torch.sum((output - targets)**2 * (1 - mask)) / torch.sum(1 - mask)

                        if status == 0:
                            loss  = loss_rec
                        else:
                            loss  = loss_rec + 100 * torch.mean(dynbg_list)

                        if is_train:
                            loss.backward()
                            for opt in stage_optimizers:
                                opt.step()

                    # Weight by batch size and rescale by stdTr**2.
                    running_loss_rec         += loss_rec.item()    * inputs.size(0) * stdTr**2
                    running_loss_dyn         += loss_dyn.item()    * inputs.size(0) * stdTr**2
                    running_loss_dyn_bg      += loss_dyn_bg.item() * inputs.size(0) * stdTr**2
                    running_loss_R           += loss_R.item()      * inputs.size(0) * stdTr**2
                    running_loss_I           += loss_I.item()      * inputs.size(0) * stdTr**2
                    num_loss                 += inputs.size(0)

                epoch_loss_rec       = running_loss_rec    / num_loss
                epoch_loss_dyn       = running_loss_dyn    / num_loss
                epoch_loss_dyn_bg    = running_loss_dyn_bg / num_loss
                epoch_loss_R         = running_loss_R      / num_loss
                epoch_loss_I         = running_loss_I      / num_loss

                print('{} rec loss: {:.4e} dyn loss: {:.4e} dyn loss(bg): {:.4e} loss_R: {:.4e} loss_I: {:.4e}'.format(
                    phase, epoch_loss_rec, epoch_loss_dyn, epoch_loss_dyn_bg, epoch_loss_R, epoch_loss_I))

                if is_train:
                    train_loss_rec_list.append(epoch_loss_rec)
                    train_loss_dyn_list.append(epoch_loss_dyn)
                    train_loss_dynbg_list.append(epoch_loss_dyn_bg)
                    train_loss_R_list.append(epoch_loss_R)
                    train_loss_I_list.append(epoch_loss_I)
                else:
                    val_loss_rec_list.append(epoch_loss_rec)
                    val_loss_dyn_list.append(epoch_loss_dyn)
                    val_loss_dynbg_list.append(epoch_loss_dyn_bg)
                    val_loss_R_list.append(epoch_loss_R)
                    val_loss_I_list.append(epoch_loss_I)

                # Keep the weights of all three sub-networks at the best
                # validation reconstruction loss seen in this stage.
                if phase == 'val' and epoch_loss_rec < best_loss_rec:
                    best_loss_rec = epoch_loss_rec
                    best_model_head_wts = copy.deepcopy(model.model_head.state_dict())
                    best_model_sup_wts  = copy.deepcopy(model.model_sup.state_dict())
                    best_model_dyn_wts  = copy.deepcopy(model.model_dyn.state_dict())

            # Early stop: at this point epoch_loss_dyn_bg holds the value from
            # the 'val' phase, which runs last in each epoch.
            if epoch_loss_dyn_bg < parameters.relative_err:
                break

            print()

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print('Best val reconstruction loss: {:4e}'.format(best_loss_rec))

        # File names carry the stage index (step*2 + status) and the *scheduled*
        # epoch count, even when the stage stopped early.
        save_dir_model_head = "ckpts/" + dir_name + "/finetune" + str(step*2 + status) + "_head_epoch"  + str(num_epochs)
        save_dir_model_sup  = "ckpts/" + dir_name + "/finetune" + str(step*2 + status) + "_sup_epoch"  + str(num_epochs)
        save_dir_model_dyn  = "ckpts/" + dir_name + "/finetune" + str(step*2 + status) + "_dyn_epoch" + str(num_epochs)
        print("saving model_head at " + save_dir_model_head)
        print("saving model_sup at "  + save_dir_model_sup)
        print("saving model_dyn at "  + save_dir_model_dyn)
        torch.save(best_model_head_wts, save_dir_model_head)
        torch.save(best_model_sup_wts,  save_dir_model_sup)
        torch.save(best_model_dyn_wts,  save_dir_model_dyn)

        save_dir_loss  = "train_loss/" + dir_name + "/finetune" + str(step*2 + status) + "_epoch" + str(num_epochs)
        print("saving loss at " + save_dir_loss)
        np.savez(save_dir_loss,
                 train_loss_rec   = train_loss_rec_list,   val_loss_rec   = val_loss_rec_list, 
                 train_loss_dyn   = train_loss_dyn_list,   val_loss_dyn   = val_loss_dyn_list,
                 train_loss_dynbg = train_loss_dynbg_list, val_loss_dynbg = val_loss_dynbg_list,
                 train_loss_R     = train_loss_R_list,     val_loss_R     = val_loss_R_list, 
                 train_loss_I     = train_loss_I_list,     val_loss_I     = val_loss_I_list,
                 time = time_elapsed)

        # Release cached GPU memory between stages.
        torch.cuda.empty_cache()
        print()
        

step:  1
finetuning model_sup...
optimizing model_sup, lr =  0.0001
optimizing model_head, lr =  0.0001
Epoch 0/9
----------


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


train rec loss: 2.0331e+00 dyn loss: 9.8560e-03 dyn loss(bg): 9.8560e-03 loss_R: 1.9318e+00 loss_I: 2.0476e+00
val rec loss: 5.8775e-01 dyn loss: 1.5539e-03 dyn loss(bg): 1.5539e-03 loss_R: 4.6658e-01 loss_I: 6.0506e-01

Epoch 1/9
----------
train rec loss: 4.5310e-01 dyn loss: 1.1352e-03 dyn loss(bg): 1.1352e-03 loss_R: 3.8010e-01 loss_I: 4.6353e-01
val rec loss: 5.4351e-01 dyn loss: 1.0971e-03 dyn loss(bg): 1.0971e-03 loss_R: 4.3119e-01 loss_I: 5.5956e-01

Epoch 2/9
----------
train rec loss: 4.2101e-01 dyn loss: 8.8788e-04 dyn loss(bg): 8.8788e-04 loss_R: 3.5451e-01 loss_I: 4.3051e-01
val rec loss: 5.2691e-01 dyn loss: 9.5209e-04 dyn loss(bg): 9.5209e-04 loss_R: 4.1829e-01 loss_I: 5.4243e-01

Epoch 3/9
----------
train rec loss: 4.0195e-01 dyn loss: 7.7288e-04 dyn loss(bg): 7.7288e-04 loss_R: 3.3915e-01 loss_I: 4.1092e-01
val rec loss: 5.1151e-01 dyn loss: 8.4729e-04 dyn loss(bg): 8.4729e-04 loss_R: 4.0582e-01 loss_I: 5.2661e-01

Epoch 4/9
----------
train rec loss: 3.8011e-01 dyn l

saving loss at train_loss/20220828_bgerr0.0/finetune4_epoch10

finetuning model_dyn...
optimizing model_dyn, lr =  0.0001
Epoch 0/9
----------
train rec loss: 2.8428e-01 dyn loss: 2.8770e-04 dyn loss(bg): 2.8770e-04 loss_R: 2.4852e-01 loss_I: 2.8938e-01
val rec loss: 4.1228e-01 dyn loss: 3.3485e-04 dyn loss(bg): 3.3485e-04 loss_R: 3.3486e-01 loss_I: 4.2334e-01

Epoch 1/9
----------
train rec loss: 2.7320e-01 dyn loss: 2.4493e-04 dyn loss(bg): 2.4493e-04 loss_R: 2.3832e-01 loss_I: 2.7818e-01
val rec loss: 4.1133e-01 dyn loss: 3.2800e-04 dyn loss(bg): 3.2800e-04 loss_R: 3.3477e-01 loss_I: 4.2227e-01

Epoch 2/9
----------
train rec loss: 2.7254e-01 dyn loss: 2.4476e-04 dyn loss(bg): 2.4476e-04 loss_R: 2.3783e-01 loss_I: 2.7750e-01
val rec loss: 4.1141e-01 dyn loss: 3.3172e-04 dyn loss(bg): 3.3172e-04 loss_R: 3.3443e-01 loss_I: 4.2241e-01

Epoch 3/9
----------
train rec loss: 2.7045e-01 dyn loss: 2.4584e-04 dyn loss(bg): 2.4584e-04 loss_R: 2.3593e-01 loss_I: 2.7538e-01
val rec loss: 4.0895

saving loss at train_loss/20220828_bgerr0.0/finetune7_epoch10

