In [None]:
import numpy as np
import os
import torch
import yaml
import sys
sys.path.append("./utils/")
from utilities import *
from Adam import Adam
from dataload import *
from RealNVP3D_DeepONet import *
from fwd_inv_loss import *

In [None]:
# Report which accelerator (if any) this session can see.
if not torch.cuda.is_available():
    print("No GPU available.")
else:
    device_name = torch.cuda.get_device_name(0)
    print(f"Your GPU model: {device_name}")

# Device used by later cells: the first CUDA GPU when one is present, else the CPU.
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")

## Load MT data

In [None]:
# Read the run configuration from the YAML file.
config_file_name = "train_32_32.yml"
with open(config_file_name, 'r') as f:
    args = yaml.safe_load(f)

size = 1024

# Pull the individual settings out of the config mapping.
TRAIN_PATH, TEST_PATH = args["TRAIN_PATH"], args["TEST_PATH"]
ntrain, ntest = args['ntrain'], args['ntest']
s_train, r_train = args['s_train'], args['r_train']
s_test, r_test = args['s_test'], args['r_test']
batch_size = args['batch_size']
n_out = args['n_out']

# Load the train/test arrays. The helper also hands back the x/y normalizers
# together with `freq_base` and `obs_base`.
(
    loc_train, loc_test,
    x_train, y_train,
    x_test, y_test,
    freq_base, obs_base,
    y_normalizer, x_normalizer,
) = get_batch_data(
    TRAIN_PATH, TEST_PATH, ntrain, ntest,
    r_train, s_train, r_test, s_test, batch_size, n_out,
)

### train data

In [None]:
# --- training tensors -------------------------------------------------------
# NOTE: x_train / y_train are rebound in place, so this cell expects the
# tensors exactly as returned by the data-loading cell (run it once per load).

# One flat row per sample; for y only index 0 of the last axis is kept.
x_train = x_train.reshape(ntrain, -1)
y_train = y_train[:, :, :, 0].reshape(ntrain, -1)

# Input function as a single-channel 32x32 grid per sample.
u_t2 = x_train.reshape(ntrain, 1, 32, 32)

# The same location tensor is replicated once for every training sample.
y_t2 = loc_train.unsqueeze(0).repeat(ntrain, 1, 1)

# Targets: 1024 values per sample with a trailing singleton axis.
s_t2 = y_train.reshape(ntrain, 1024, 1)

print("u_t2 shape:", u_t2.shape)
print("y_t2 shape:", y_t2.shape)
print("s_t2 shape:", s_t2.shape)

### test data

In [None]:
# --- test tensors -----------------------------------------------------------
# NOTE: x_test / y_test are rebound in place, so this cell expects the
# tensors exactly as returned by the data-loading cell (run it once per load).

# One flat row per sample; for y only index 0 of the last axis is kept.
x_test = x_test.reshape(ntest, -1)
y_test = y_test[:, :, :, 0].reshape(ntest, -1)

# Input function as a single-channel 32x32 grid per sample.
u_t3 = x_test.reshape(ntest, 1, 32, 32)

# The same location tensor is replicated once for every test sample.
y_t3 = loc_test.unsqueeze(0).repeat(ntest, 1, 1)

# Targets: 1024 values per sample with a trailing singleton axis.
s_t3 = y_test.reshape(ntest, 1024, 1)

print("u_t3 shape:", u_t3.shape)
print("y_t3 shape:", y_t3.shape)
print("s_t3 shape:", s_t3.shape)

In [None]:
# Build the batched train / test loaders from the (u, y, s) tensor triples.
op_batch_size = 64

# Training loader: reshuffled every epoch.
# NOTE: the name `operator_dataset` first holds the dataset and is then rebound
# to its loader; the training cell iterates over the loader under this name.
operator_dataset = Onet_dataset(u_t2, y_t2, s_t2)
operator_dataset = DataLoader(operator_dataset, batch_size=op_batch_size, shuffle=True)

# Evaluation loader: no shuffling, so every epoch is scored on identical
# batches and the per-epoch test metrics stay directly comparable.
# (Assumes `DataLoader` is torch.utils.data.DataLoader, brought in by one of
# the star imports -- TODO confirm.)
operator_dataset_test = Onet_dataset(u_t3, y_t3, s_t3)
operator_dataset_test = DataLoader(operator_dataset_test, batch_size=op_batch_size, shuffle=False)

## Model definition

In [None]:
# Layer widths handed to the model as `layer_sizes_trunk`:
# query_dim inputs -> five layers of width 128 -> output_dim outputs.
output_dim = 1024
input_dim = output_dim
query_dim = 2
layer_sizes_trunk = [query_dim, *([128] * 5), output_dim]

# Activation and weight-initialisation scheme, passed by name to the model.
activation = "leaky_relu"
kernel_initializer = "Glorot normal"

# Instantiate the network and move it to the selected device.
Model = DeepONetCartesianProd(
    layer_sizes_trunk,
    activation,
    kernel_initializer,
).to(device)

# Train

In [None]:
# Optimisation hyper-parameters.
epochs, learning_rate = 400, 1e-3
step_size, gamma = 10, 0.95

# Project-local Adam with weight decay. The StepLR scheduler multiplies the
# learning rate by `gamma` once every `step_size` scheduler steps.
optimizer = Adam(Model.parameters(), weight_decay=1e-4, lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=gamma, step_size=step_size)

In [None]:
# Train the model for `epochs` epochs, evaluating on the test loader after each
# one. Per-epoch losses are printed and appended to a text log; the model is
# checkpointed every `save_interval` epochs and after the final epoch.
save_interval = 20

save_dir = "Model/1"
# exist_ok=True replaces the check-then-create pattern and keeps the cell re-runnable.
os.makedirs(save_dir, exist_ok=True)

# Derive the log location from save_dir so that the log and the checkpoints
# always land in the same directory when save_dir is changed.
log_path = os.path.join(save_dir, "1.txt")

with open(log_path, 'w') as file:
    record_interval = 1
    for epoch in range(epochs):
        train_loss = 0
        total_operator_loss = 0
        total_INV_loss = 0

        # ---------------- training pass ----------------
        Model.train()
        for (u_op, y_op, Guy_op) in operator_dataset:
            # The last axis of the location tensor holds the two query coordinates.
            x_op = y_op[:, :, 0]
            t_op = y_op[:, :, 1]

            (u_op, x_op, t_op, Guy_op) = (u_op.to(device), x_op.to(device), t_op.to(device), Guy_op.to(device))

            # Two loss terms from the project helpers; note that loss_inv takes
            # Guy_op first and u_op fourth, the reverse of the operator loss.
            OP_loss = OP_residual_calculator(u_op, x_op, t_op, Guy_op, Model)
            INV_loss = loss_inv(Guy_op, x_op, t_op, u_op, Model)

            loss = OP_loss + INV_loss

            train_loss += loss.item()
            total_operator_loss += OP_loss.item()
            total_INV_loss += INV_loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # NOTE(review): the totals are sums of one value per batch but are divided
        # by the number of samples. That is a per-sample average only if the loss
        # helpers return batch sums; if they return batch means, the divisor
        # should be len(operator_dataset) instead -- confirm against fwd_inv_loss.
        average_OP_loss = total_operator_loss / ntrain
        average_INV_loss = total_INV_loss / ntrain

        # ---------------- evaluation pass ----------------
        Model.eval()
        test_loss = 0
        total_test_OP_loss = 0
        total_test_INV_loss = 0

        with torch.no_grad():
            for (u_op_test, y_op_test, Guy_op_test) in operator_dataset_test:
                x_op_test = y_op_test[:, :, 0]
                t_op_test = y_op_test[:, :, 1]

                (u_op_test, x_op_test, t_op_test, Guy_op_test) = (u_op_test.to(device), x_op_test.to(device), t_op_test.to(device), Guy_op_test.to(device))

                test_OP_loss = OP_residual_calculator(u_op_test, x_op_test, t_op_test, Guy_op_test, Model)
                test_INV_loss = loss_inv(Guy_op_test, x_op_test, t_op_test, u_op_test, Model)

                test_loss += (test_OP_loss + test_INV_loss).item()
                total_test_OP_loss += test_OP_loss.item()
                total_test_INV_loss += test_INV_loss.item()

        # Same normalisation caveat as for the training averages above.
        average_test_OP_loss = total_test_OP_loss / ntest
        average_test_INV_loss = total_test_INV_loss / ntest

        if epoch % record_interval == 0 or epoch == epochs - 1:
            # Build the message once so the console and the log file cannot drift apart.
            log_line = (f"Epoch {epoch}: "
                        f"Train - Average Operator Loss = {average_OP_loss}, Average INV Loss = {average_INV_loss} | "
                        f"Test - Average Operator Loss = {average_test_OP_loss}, Average INV Loss = {average_test_INV_loss}")
            print(log_line)
            file.write(log_line + "\n")
            # Flush so the log can be inspected while a long run is still going.
            file.flush()

        # One scheduler step per epoch.
        scheduler.step()

        if epoch % save_interval == 0 or epoch == epochs - 1:
            save_path = os.path.join(save_dir, f"model_epoch_{epoch}.pth")
            # NOTE(review): this pickles the whole module, so loading it later
            # needs the same class definitions importable. Saving
            # Model.state_dict() is the more portable form; left unchanged here
            # because code that loads these files may expect a full model object.
            torch.save(Model, save_path)
            print(f"Model saved at epoch {epoch} to {save_path}")