# $L^2_\mu$ and $H^1_\mu$ training of DeepONet

In [None]:
# MIT License
# Copyright (c) 2025
#
# This is part of the dino_tutorial package
# 
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND.
# For additional questions contact Thomas O'Leary-Roseberry

import os, sys
import torch
from torch import nn
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm

sys.path.append('../../')

from dinotorch_lite import *


## Load the Data

In [None]:
# Directory holding the precomputed full-state data set.
data_dir = 'data/full_state/'

# Input samples (m) and the corresponding outputs (q); the first axis of each
# array indexes the samples, the second axis is the discretisation dimension.
mq_data_dict = np.load(data_dir + 'mq_data.npz')

q_data = mq_data_dict['q_data']
m_data = mq_data_dict['m_data']

# Kept for reference from the FNO variant of this notebook; only the
# commented-out FNO model in the training cell would need these.
# fno_metadata = np.load(data_dir+'fno_metadata.npz')

# d2v = fno_metadata['d2v_param']
# v2d = fno_metadata['v2d_param']
# nx = fno_metadata['nx']
# ny = fno_metadata['ny']

n_data, dQ = q_data.shape
n_data, dM = m_data.shape

print('dQ = ',dQ,', dM = ',dM)

# Training uses every sample except the last 800; testing uses the last 200.
# NOTE(review): the 600 samples in between are used by neither split --
# confirm this reduced training set is intentional (vs. `[:-200]`).
m_train = torch.Tensor(m_data[:-800])
q_train = torch.Tensor(q_data[:-800])

m_test = torch.Tensor(m_data[-200:])
q_test = torch.Tensor(q_data[-200:])

# Set up datasets and loaders
l2train = L2Dataset(m_train, q_train)
l2test = L2Dataset(m_test, q_test)
batch_size = 32

train_loader = DataLoader(l2train, batch_size=batch_size, shuffle=True)
# Evaluation gains nothing from a random order; a fixed order keeps the
# composition of every validation batch identical from one pass to the next.
validation_loader = DataLoader(l2test, batch_size=batch_size, shuffle=False)

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'

## $L^2_\mu$ training

In [None]:
# FNO configuration kept for reference only; this notebook trains a DeepONet.
# model_settings = fno2d_settings(modes1=4, modes2=4, width=64, n_layers=4, d_out=2)
# model = VectorFNO2D(v2d=[d2v, d2v], d2v=[v2d, v2d], nx=nx, ny=ny, dim=2, settings=model_settings).to(device) 

# --- Hyperparameters -------------------------------------------------------
# rQ is the third constructor argument of DeepONetNodal
# (presumably the number of output basis functions -- confirm in dinotorch_lite).
rQ = 100
n_epochs = 100
lr_scheduler = None

# The same function serves as the training loss and as the error metric
# reported after training.
loss_func = normalized_f_mse
# Alternative: a loss weighted by the output mass matrix.
# from scipy.sparse import csr_matrix, save_npz, load_npz
# M_output = load_npz(data_dir+'M_output_csr.npz')
# M_torch = scipy_csr_to_torch_csr(M_output).to(torch.float32) 
# M_torch.to(device)
# loss_func = weighted_l2_norm(M_torch)

# --- Model and optimizer ---------------------------------------------------
model = DeepONetNodal(dM, dQ, rQ)

# Print the number of trainable parameters.
n_trainable = sum(w.numel() for w in model.parameters() if w.requires_grad)
print(n_trainable)

# Adam with its default settings.
optimizer = torch.optim.Adam(model.parameters())

# --- Training --------------------------------------------------------------
network, history = l2_training(
    model,
    loss_func,
    train_loader,
    validation_loader,
    optimizer,
    lr_scheduler=lr_scheduler,
    n_epochs=n_epochs,
    verbose=True,
)

# --- Evaluation and checkpoint ---------------------------------------------
rel_error = evaluate_l2_error(model, validation_loader, error_func=loss_func)

print('L2 relative error = ', rel_error)

# Store only the weights (state dict) next to the data.
torch.save(model.state_dict(), f"{data_dir}l2_model_don.pth")

## $H^1_\mu$ training of DeepONet (DON)

In [None]:
# Additional data needed for derivative-informed (H^1) training.
rQ = 100

# SECURITY NOTE: allow_pickle=True lets np.load unpickle arbitrary objects,
# which can execute code. Only load this file from a trusted source.
# NOTE(review): if the archive holds plain numeric arrays, the flag can
# likely be dropped -- confirm against how the file was written.
J_data_dict = np.load(data_dir + 'JstarPhi_data.npz', allow_pickle=True)

# Swap the last two axes, then keep the first rQ entries along the new
# second axis.
J_data = J_data_dict['JstarPhi_data'].transpose((0, 2, 1))[:, :rQ, :]

# Keep the first rQ columns of the stored POD encoder.
POD_encoder = np.load(data_dir + 'POD/POD_encoder.npy')[:, :rQ]
# POD_encoder = J_data_dict['MPhi'][:,:rQ]
# POD_encoder.shape
POD_encoder = torch.Tensor(POD_encoder).to(torch.float32)

# Same sample split as used for m_train / m_test and q_train / q_test.
J_train = torch.Tensor(J_data[:-800])
J_test = torch.Tensor(J_data[-200:])

# Set up datasets and loaders
dinotrain = DINODataset(m_train, q_train, J_train)
dinotest = DINODataset(m_test, q_test, J_test)
batch_size = 32

dino_train_loader = DataLoader(dinotrain, batch_size=batch_size, shuffle=True)
# Evaluation gains nothing from a random order; a fixed order keeps the
# composition of every validation batch identical from one pass to the next.
dino_validation_loader = DataLoader(dinotest, batch_size=batch_size, shuffle=False)




In [None]:
# --- Hyperparameters -------------------------------------------------------
n_epochs = 100
lr_scheduler = None

# The same normalised MSE is used for the output term and the Jacobian term.
loss_func = normalized_f_mse
loss_func_jac = normalized_f_mse

# --- Model and optimizer ---------------------------------------------------
# A fresh network: this rebinds `model`, replacing any earlier instance.
model = DeepONetNodal(dM, dQ, rQ)

# Print the number of trainable parameters.
n_trainable = sum(w.numel() for w in model.parameters() if w.requires_grad)
print(n_trainable)

# Adam with its default settings.
optimizer = torch.optim.Adam(model.parameters())

# --- Training --------------------------------------------------------------
network, history = h1_training(
    model,
    loss_func,
    loss_func_jac,
    dino_train_loader,
    dino_validation_loader,
    optimizer,
    lr_scheduler=lr_scheduler,
    n_epochs=n_epochs,
    verbose=True,
    output_projector=POD_encoder,
)

# --- Evaluation and checkpoint ---------------------------------------------
# The output error is measured on `validation_loader`, the (m, q) loader built
# in the data-loading cell, not on `dino_validation_loader`.
rel_error = evaluate_l2_error(model, validation_loader, error_func=loss_func)

print('L2 relative error = ', rel_error)

# Store only the weights (state dict) next to the data.
torch.save(model.state_dict(), f"{data_dir}h1_model_don.pth")