In [None]:
import sys
sys.path.append('../util/')
sys.path.append('../datasets/')
sys.path.append('../models/')

import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
from torch.utils.data import DataLoader
import numpy as np
from sklearn import model_selection
from scipy.optimize import fmin_cg
import scipy.signal
import random
import matplotlib.pyplot as plt
from typing import *

from engine import setup_seed,Namespace,train_vae_one_epoch_test
from datasets import build_dataset
from models import build_models
from monitor import Monitor


In [None]:
setup_seed(42)

In [None]:


args = Namespace()

args.dataset_type ='cylinder2d-p'
args.dataset_path =  r'../threeCylinder-grid-256-512-p.npy'
args.dataset_mask_path =  r'../.npy'
args.test_size = 0.33

args.model_type = "MLP"
args.mod_number = 3
args.pod_loss = None

args.lr = 1e-4
args.step_size = 100
args.gamma = 0.1
args.epochs=int(1000)

args.log_interval = 100
args.batch_size = 10
args.device='cuda:0'
args.word_dir='work_dir/'


args.monitorType = "random"
# args.monitorGridShape = (5,5)
args.random_num = 25

# log

In [None]:
# Every run gets its own directory ./log/<asctime with ' ' and ':' -> '_'>.
# time.asctime() with no argument formats time.localtime().
localtime = time.asctime().translate(str.maketrans({' ': '_', ':': '_'}))
print(localtime)
log_path = f'./log/{localtime}'

os.makedirs(log_path)

# Persist the full run configuration next to the run's other artefacts.
with open(f'{log_path}/arg.txt', 'w', encoding='utf-8') as arg_file:
    arg_file.write(str(vars(args)))

# dataset

In [None]:
# (train_dataset.data - train_dataset.data.mean(axis=0)).shape

In [None]:
# build_dataset(dataset_path,dataset_type,test_size,mod_input_shape
# Build train/validation splits: build_dataset(dataset_path, dataset_type, test_size).
# The returned dict also carries 'data_shape', which the model and Monitor need.
build_dataset_res = build_dataset(args.dataset_path, args.dataset_type, args.test_size)
train_dataset = build_dataset_res['train_dataset']
val_dataset = build_dataset_res['val_dataset']
args.data_shape = build_dataset_res['data_shape']

print("val_dataset.shape", val_dataset.shape)
print("train_dataset.shape", train_dataset.shape)

# Reshuffle training samples every epoch. Without this, every epoch sees the
# same batches in the same order (and the snapshots may be time-ordered).
# Validation order is left fixed.
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1)

# model

In [None]:
# build_models returns the network plus the shape metadata that the training
# and reconstruction code needs; that metadata is copied onto `args`.
build_model = build_models(args.model_type, args.data_shape, args.mod_number, args.pod_loss)
model = build_model['model']
for _shape_key in ('mod_input_shape', 'mod_output_shape', 'code_shape'):
    setattr(args, _shape_key, build_model[_shape_key])
del _shape_key

# Record the architecture alongside the run's arguments.
with open(f'{log_path}/model.txt', 'w', encoding='utf-8') as model_file:
    model_file.write(str(model))
    

# train pipeline

In [None]:
# Adam with plateau-based LR decay: the LR is multiplied by `args.gamma` after
# `args.step_size` scheduler steps without a relative improvement > 1e-5.
# NOTE(review): this only takes effect if scheduler.step(metric) is called.
optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=args.gamma,
    patience=args.step_size,
    threshold=1e-5,
)
# Not referenced by the cells of this notebook (training uses model.loss_function).
loss = nn.MSELoss()

# training function

In [None]:
def train_vae_one_epoch_test(
        model: torch.nn.Module,
        train_loader,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        epoch: int,
        mod_input_shape,
        mod_output_shape,
        log_interval,
        ):
    """Run one training epoch and return the mean per-batch loss.

    Each batch from ``train_loader`` is moved to ``device``. It is reshaped to
    ``[batch] + mod_input_shape`` for the forward pass and to
    ``[batch] + mod_output_shape`` as the target passed to
    ``model.loss_function(target, output)``, whose ``'loss'`` entry is
    back-propagated.

    This definition shadows the ``train_vae_one_epoch_test`` imported from
    ``engine`` in the first cell.

    Args:
        model: network exposing ``loss_function(target, output) -> {'loss': ...}``.
        train_loader: iterable of tensor batches with ``.dataset`` and ``len()``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.
        epoch: epoch number, used only for logging.
        mod_input_shape: per-sample input shape (without batch dim).
        mod_output_shape: per-sample target shape (without batch dim).
        log_interval: print progress every ``log_interval`` batches.

    Returns:
        float: mean of the per-batch losses (NaN if the loader yields nothing).
        Callers that ignore the return value are unaffected.
    """
    model.train()
    loss_total = 0.0  # becomes a tensor; summed on-device to avoid a sync per batch
    n_batches = 0
    # `batch` rather than `data`: the notebook imports torch.utils.data as `data`.
    for batch_idx, batch in enumerate(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        batch_size = len(batch)
        output = model(batch.reshape([batch_size] + list(mod_input_shape)))
        loss_info = model.loss_function(batch.reshape([batch_size] + list(mod_output_shape)), output)
        loss = loss_info['loss']
        loss.backward()
        optimizer.step()

        loss_total = loss_total + loss.detach()
        n_batches += 1
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * batch_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    return float(loss_total) / n_batches if n_batches else float('nan')


# validation function

In [None]:
# Quick look at the sensor layout for a single validation snapshot.
target = val_dataset[15]

# Monitor(monitorType, data_shape, random_num=None, monitorGridShape=None, ..., mask=...)
# NOTE(review): signature reconstructed from the call sites; confirm in monitor.py.
meansure = Monitor(args.monitorType, args.data_shape, random_num=args.random_num, mask=None)
mask_map = torch.Tensor(meansure.grid2D())

    
def val(targets: List, m_s: List, model, code_shape, output_shape, device):
    """Mean relative L2 reconstruction error for each monitor configuration.

    For every monitor in ``m_s`` and every snapshot in ``targets``:
    1. build a sensor mask via ``m.grid2D()``;
    2. fit a latent code with ``fmin_cg``, using the model's
       ``loss_decoder_helper`` / ``grad_loss_decoder_helper`` closures;
    3. decode it and compute ``||target - recon|| / ||target||``.

    Args:
        targets: snapshots (tensors) to reconstruct.
        m_s: monitor objects exposing ``grid2D()``.
        model: model exposing ``decode`` and the two helper factories above.
        code_shape: latent shape (without batch dim); its size sets the
            optimisation dimension (previously read from the global
            ``args.mod_number``).
        output_shape: shape the target and mask are reshaped to for the helpers.
        device: device for the model and helper inputs.

    Returns:
        list: one mean relative L2 error per monitor, in the order of ``m_s``.
    """
    model = model.to(device)  # hoisted: was redone for every (monitor, target) pair
    n_code = int(np.prod(code_shape))
    B_code_shape = [1] + list(code_shape)

    con_m = []
    for m in m_s:
        errors = []
        for target in targets:
            # grid2D() is still called once per target, as before, in case it is stochastic.
            mask_map = torch.Tensor(m.grid2D())
            # .float() matches test(); the decoder helpers presumably expect float32.
            target_re = target.float().reshape(output_shape).to(device)
            mask_re = mask_map.reshape(output_shape).to(device)

            f = model.loss_decoder_helper(target_re, mask_re)
            fp = model.grad_loss_decoder_helper(target_re, mask_re)

            start_code = np.full(n_code, 0.1, dtype=np.float32)
            fmin_code = fmin_cg(f, start_code, fprime=fp, disp=False)
            with torch.no_grad():
                y_per = model.decode(
                    torch.Tensor(fmin_code.astype(np.float32)).to(device).reshape(B_code_shape))

            target_flat = target.detach().flatten().cpu()
            y_flat = y_per.detach().flatten().cpu()
            D = torch.sqrt(((target_flat - y_flat) ** 2).sum())
            T = torch.sqrt((target_flat ** 2).sum())
            errors.append((D / T).numpy())
        con_m.append(np.array(errors).mean())
    return con_m


# train

In [None]:
# Move the model once; nn.Module.to() works in place, so the optimizer built
# earlier still references the same parameters.
model = model.to(args.device)

val_best = np.inf  # NOTE(review): not updated anywhere in this notebook

for epoch in range(1, args.epochs + 1):
    epoch_loss = train_vae_one_epoch_test(
        model,
        train_loader,
        optimizer,
        args.device,
        epoch,
        args.mod_input_shape,
        args.mod_output_shape,
        args.log_interval,
    )
    # ReduceLROnPlateau only acts when stepped with a metric. A training
    # function that returns None (e.g. engine's version) leaves the LR unchanged.
    if epoch_loss is not None:
        scheduler.step(epoch_loss)

# For the default config this is 'MLP3-last.pth', as before.
torch.save(model, f'{log_path}/{args.model_type}{args.mod_number}-last.pth')


# eval

In [None]:
def test(target  ,m  ,model,code_shape,output_shape,device):

    model = model.to(device)
    mask_map,meansure_x,meansure_y = m.grid2D(reqxy=True)
    mask_map = torch.Tensor(mask_map)

    target_re = (target.float()).reshape(output_shape).to(device)
    mask_re   = mask_map.reshape(output_shape).to(device)

    f = model.loss_decoder_helper(target_re,mask_re)
    fp= model.grad_loss_decoder_helper(target_re,mask_re)

    start_code = np.zeros(args.mod_number).astype(np.float32)+0.1
    fmin_code =fmin_cg(f,start_code,fprime=fp,disp=False )
    B_code_shape = [1,]+list(code_shape)
    y_per = model.decode( torch.Tensor(fmin_code.astype(np.float32)).to(device).reshape(B_code_shape))


    DS = (target.flatten().detach().cpu()-y_per.flatten().detach().cpu())
    D = torch.sqrt((DS**2).sum())
    T = torch.sqrt(((target.detach())**2).sum())
    L2 = D/T
    L2 = L2.cpu().numpy()

    y_per = y_per.flatten().detach().cpu().numpy()


    return y_per,L2,meansure_x,meansure_y

In [None]:
# Sweep the number of random sensors n. For each n, report the mean and std of
# the relative L2 error over the first 10 validation snapshots x 10 layouts each.
# The mask file is loop-invariant: load it once instead of 1200 times.
monitor_mask = np.load(args.dataset_mask_path)

for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 150, 200]:
    ll = []
    for snapshot in val_dataset[0:10]:
        for _ in range(10):
            # Each Monitor gets its own copy, in case it modifies the mask it is handed.
            mm = Monitor(args.monitorType, args.data_shape, random_num=n, mask=monitor_mask.copy())
            val_res, L2, meansure_x, meansure_y = test(
                snapshot, mm, model, args.code_shape, args.mod_output_shape, args.device)
            ll.append(L2)
    ll = np.array(ll)
    print(n, '\t', ll.mean(), ll.std())

# debug


In [None]:
val_dataset[10].device

In [None]:
test_data = val_dataset[10].to(args.device).flatten()

In [None]:
test_data_out =model(test_data)

In [None]:
test_data_out = test_data_out.reshape(args.data_shape)

In [None]:
out = test_data_out.cpu().detach().numpy().reshape(args.data_shape)
gt = test_data.cpu().detach().numpy().reshape(args.data_shape)
plt.imshow(out+TM)
plt.colorbar()

In [None]:
plt.imshow(gt)
plt.colorbar()

In [None]:
# plt.imshow( out -gt+TM)
# plt.colorbar()

In [None]:
args.data_shape