# Notebook for training a physics-informed neural network (PINN) to simulate the double pendulum

### Load all libraries and modules

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
import sys
import time
import itertools

import yaml
from configs.config_loader import load_config
from utils.plotting import Plotting, get_optimal_figsize, legend_without_duplicate_labels, init_plot_style

# Apply the project's plot style once, before any figure is created.
# NOTE(review): use_tex=False presumably disables LaTeX text rendering -
# confirm in utils.plotting.
init_plot_style(use_tex=False)

# Standard PINN

#### Define configuration for simulation

In [None]:
# Initial condition: both angles start at the same value (given in degrees and
# converted to radians); the remaining two state entries start at zero.
# NOTE(review): the plot titles further down label entries 0-1 as theta_1/2
# and entries 2-3 as omega_1/2 - presumably y0 uses the same layout.
theta_deg = 150
theta = theta_deg * np.pi / 180
y0 = [theta, theta, 0, 0]

# Settings handed to load_config as `config_update`
config_dict = {
    # Save results as json file, and how the log is named / ordered
    'save_data': False,
    'log_name': ['default', 'weighted'],
    'log_order': ['y0', 't', 'layers', 'epochs', 'loss'],

    # Initial condition, computational domain
    'y0': y0,
    'x_domain': [0, 5],

    # Random seeds
    'seed_pinn': 6,

    # Loss weighting scheme
    'lambda_IC': 100,

    # Training parameters
    'n_epochs': 2000,
}

# Load configs/default.yaml and pass the notebook-specific settings from
# config_dict as the update
default_config_path = Path('', 'configs', 'default.yaml')
config = load_config(default_config_path, config_update=config_dict, verbose=True)

from pinn.neural_net import PhysicsInformedNN

#### PINN training 

In [None]:
##############################################################################
# Start training the PINN
##############################################################################

# Create an instance of the PINN model from the configuration loaded above
# (verbose=False is passed to the constructor)
pinn = PhysicsInformedNN(config, verbose=False)

# Train the model instance
pinn.train() 

# Keep the log held by the model's callback; the plotting cells below read it
# (e.g. log['y_pred']) and pass it to Plotting.
# NOTE(review): presumably contains all training details and results - confirm
# against the callback implementation.
log = pinn.callback.log

### Plotting the solution trajectories and training losses

In [None]:
# Time passed to plotting_continuation as `t_cont`.
# NOTE(review): presumably the time from which the RK reference solution is
# evaluated starting at the corresponding PINN prediction - confirm in
# utils.plotting.
t_cont = 0

# Figure with two side-by-side panels
fig, axes = plt.subplots(
    1, 2,
    figsize=get_optimal_figsize(scale=1, height_factor=.3),
)
axes = axes.flatten()

# Draw the continuation and the solution trajectories onto the panels
Plot = Plotting(log, alpha=1)
Plot.plotting_continuation(axes[0:], t_cont=t_cont)
Plot.plotting_yt_sol(axes[0:])

# First logged prediction; its entries are printed in the panel titles as the
# predicted values at t_0 (entries 0-1 as angles, entries 2-3 as omegas)
IC_pred = log['y_pred'][0]
axes[0].set_title(
    r'$\overline{{\theta}}_1(t_0)={:.2f},\ \overline{{\theta}}_2(t_0)={:.2f}$'.format(IC_pred[0], IC_pred[1]),
    fontsize=10,
)
axes[1].set_title(
    r'$\overline{{\omega}}_1(t_0)={:.2f},\ \overline{{\omega}}_2(t_0)={:.2f}$'.format(IC_pred[2], IC_pred[3]),
    fontsize=10,
)

# One de-duplicated legend per panel: left panel anchored at (-0.9, 0.5),
# right panel at (1, 0.5); each call gets its own copy of the idx list
legend_entries = [-1, 0, 1, 2, 3, 4, 5]
for ax, anchor in zip(axes, [(-0.9, 0.5), (1, 0.5)]):
    legend_without_duplicate_labels(ax, anchor, idx=list(legend_entries), loc='center left')



In [None]:
# Figure with two side-by-side panels and some extra horizontal spacing
fig, axes = plt.subplots(
    1, 2,
    figsize=get_optimal_figsize(scale=1, height_factor=.3),
)
axes = axes.flatten()
fig.subplots_adjust(wspace=0.3)

# Draw the training losses stored in the log onto the panels
Plot = Plotting(log, alpha=1)
Plot.plotting_losses(axes[0:])

# One de-duplicated legend per panel, given as (anchor, idx) pairs
legend_specs = [
    ((-0.7, 0.5), [-1, 2, 3, 4]),
    ((1, 0.5), [2, 3]),
]
for ax, (anchor, entries) in zip(axes, legend_specs):
    legend_without_duplicate_labels(ax, anchor, idx=entries, loc='center left')