## Setup

In [None]:
import warnings
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
import torch
import torch.nn as nn
import torch.nn.functional as F
from plotly.subplots import make_subplots
from tqdm.notebook import tqdm

In [None]:
# @title Styling
# @markdown Making things pretty! (This is meant to work with Colab's dark theme **(Settings > Site > Theme > "Dark")**.)
# Dark styling for matplotlib (2D plots) and plotly (3D / interactive plots).
# NOTE(review): Colab dark-mode detection via
#   google.colab.output.eval_js('document.documentElement.matches("[theme=dark]")')
# is disabled, so the dark styling below is always applied.

PITAYA_STYLE_URL = (
    "https://github.com/dhaitz/matplotlib-stylesheets/raw/master/pitayasmoothie-dark.mplstyle"
)
DARK_GREY = (0.22, 0.22, 0.22, 1.0)

# Base style sheet. It is fetched over the network, so an offline kernel must not
# abort the whole notebook here -- fall back to the built-in dark style instead.
try:
    plt.style.use(PITAYA_STYLE_URL)
except OSError as err:  # urllib's URLError is an OSError subclass
    warnings.warn(
        f"Could not load remote matplotlib style ({err}); using 'dark_background' only."
    )
# Layered on top: overrides whichever colours the Pitaya sheet also defines.
plt.style.use("dark_background")

# Explicit overrides are applied last so that no style sheet above can clobber them.
plt.rcParams.update(
    {
        "figure.dpi": 100,
        "hatch.color": "white",
        "figure.facecolor": DARK_GREY,
        "axes.facecolor": DARK_GREY,
        "savefig.facecolor": DARK_GREY,
        "grid.color": (0.4, 0.4, 0.4, 1.0),
    }
)

# plotly: built-in dark template combined with a "draft" template whose plot and
# paper backgrounds are fully transparent (alpha 0).
plotly_template = pio.templates["plotly_dark"]  # NOTE(review): unused in the visible cells
pio.templates["draft"] = go.layout.Template(
    layout=dict(
        plot_bgcolor="rgba(56,56,56,0)",
        paper_bgcolor="rgba(56,56,56,0)",
    )
)
pio.templates.default = "plotly_dark+draft"

In [None]:
# Compute device: the GPU when CUDA is available, otherwise the CPU.
# ("cpu" is the valid torch device string; the former "device" fallback was not a
# device type, so any torch.device(device) / .to(device) on a CPU-only box would fail.)
# NOTE(review): `device` is not referenced in the visible cells -- presumably heat.py
# chooses its own device; confirm before relying on this variable.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device}")

## Net

In [None]:
from mlp import MLP
from heat import PINN_1D, PINN_2D

## Heat 1D

### Init Network

In [None]:
# Problem setup for the 1D heat equation.
# NOTE(review): column 0 of `pts` appears to be time and column 1 space (earlier
# boundary experiments used pts[:, 0:1] as a time signal) -- confirm against heat.PINN_1D.


def d_fn(pts):
    """Diffusivity: a smoothed step in pts[:, 1:2].

    Evaluates to ~0.107 where the coordinate is negative, 10.1 at exactly 0 and
    ~20.09 where it is positive (sigmoid of sign * 8, scaled by 20, offset by 0.1).
    """
    side = torch.sign(pts[:, 1:2])
    return torch.sigmoid(side * 8) * 20 + 0.1


def ic_fn(pts):
    """Initial condition: Gaussian bump centred at 0 in pts[:, 1:2] (sigma = 3).

    The normal density with sigma = 3 is scaled by 10; no other column of pts is read.
    """
    peak = 10 * (1 / (3 * np.sqrt(2 * np.pi)))
    return peak * torch.exp(-(pts[:, 1:2]) ** 2 / (2 * 3**2))


# Boundary values reuse the initial-condition formula (it ignores column 0 entirely).
bc_fn = ic_fn

pinn = PINN_1D(
    net=MLP(hidden_layer_ct=4, hidden_dim=256, act=F.tanh, learnable_act="SINGLE"),
    initial_ct=250,
    collocation_ct=1000,
    d_fn=d_fn,
    ic_fn=ic_fn,
    bc_fn=bc_fn,
    # "dirchelet" (sic) is kept verbatim -- presumably the exact string heat.PINN_1D
    # matches on; confirm before correcting it to "dirichlet".
    boundary_type=["dirchelet", "dirchelet"],
    t_bounds=[0, 4 * np.pi],
    space_bounds=[-4 * np.pi, 4 * np.pi],
    lr=1e-3,
)
pinn.plot_config_and_res()

### Train

In [None]:
# Adam training in 15 rounds of 1000 epochs. The learning rate of param_groups[0]
# is stepped down at fixed rounds; the solution is plotted after every round.
# Round index -> learning rate set at the start of that round (round 0 already
# overrides the lr=1e-3 passed to PINN_1D).
lr_schedule_1d = {0: 1e-4, 3: 5e-5, 5: 1e-5, 7: 5e-6}

for i in range(15):
    # NOTE(review): presumably the learnable activation parameter
    # (learnable_act="SINGLE") -- confirm in mlp.MLP.
    print(pinn.net.a)
    scheduled_lr = lr_schedule_1d.get(i)
    if scheduled_lr is not None:
        pinn.adam.param_groups[0]["lr"] = scheduled_lr
    pinn.train(
        n_epochs=1000,
        reporting_frequency=100,
        mode="Adam",
        phys_weight=1,
        res_weight=0.01,
        bc_weight=4,
        ic_weight=4,
    )
    pinn.plot_3d()
    pinn.plot_config_and_res()

pinn.plot_config_and_res()

---

## Heat 2D

### Init Network

In [None]:
# Problem setup for the 2D heat equation: constant diffusivity, a Gaussian initial
# bump centred at the origin, and zero-valued boundary data on all four Neumann walls.
# NOTE(review): pts columns appear to be (t, x, y) -- confirm against heat.PINN_2D.


def d_fn(pts):
    """Diffusivity: the constant 4.

    Written as ``pts[:, 1:2] * 0 + 4`` so the result takes pts' row count, dtype and
    device (and stays attached to pts in autograd if pts requires grad).
    """
    return pts[:, 1:2] * 0 + 4


def ic_fn(pts):
    """Initial condition: radially symmetric Gaussian at (0, 0) with sigma = 5.

    Amplitude is 15 / (5 * sqrt(2*pi)); only columns 1 and 2 of pts are read.
    """
    peak = 15 * (1 / (5 * np.sqrt(2 * np.pi)))
    r_squared = pts[:, 1:2] ** 2 + pts[:, 2:3] ** 2
    return peak * torch.exp(-r_squared / (2 * 5**2))


def bc_fn(pts):
    """Boundary data: zero everywhere (with Neumann walls, presumably zero flux)."""
    return pts[:, 1:2] * 0


pinn = PINN_2D(
    net=MLP(input_dim=3, hidden_layer_ct=10, hidden_dim=256, act=F.tanh, learnable_act="SINGLE"),
    initial_ct=1000,
    boundary_ct=1000,
    collocation_ct=2500,
    d_fn=d_fn,
    ic_fn=ic_fn,
    bc_fn=bc_fn,
    boundary_type=["neumann", "neumann", "neumann", "neumann"],
    t_bounds=[0, 8 * np.pi],
    space_bounds=[-4 * np.pi, 4 * np.pi],
    lr=1e-3,
    adaptive_resample=True,
)
# Other diagnostics available on the PINN: pinn.plot_bcs(), pinn.plot_frames(n=5), pinn.plot_3d()
pinn.plot_ics()

### Train

In [6]:
# Adam training in 15 rounds of 1000 epochs with a stepped learning-rate decay on
# param_groups[0]; the network weights are checkpointed after every round.
# Round index -> learning rate set at the start of that round (unchanged after round 4).
lr_schedule_2d = {0: 1e-3, 1: 1e-4, 2: 5e-5, 3: 1e-5, 4: 5e-6}

# torch.save does not create missing parent directories. Create the checkpoint folder
# up front so a fresh checkout does not crash *after* a full round of training.
checkpoint_dir = Path("./models")
checkpoint_dir.mkdir(parents=True, exist_ok=True)

for i in range(15):
    if i in lr_schedule_2d:
        pinn.adam.param_groups[0]["lr"] = lr_schedule_2d[i]
    pinn.train(
        n_epochs=1000,
        reporting_frequency=100,
        mode="Adam",
        phys_weight=1,
        res_weight=0.01,
        bc_weight=8,
        ic_weight=4,
    )
    # NOTE(review): "heated-with-wind" may be a stale label -- the visible setup uses a
    # constant d_fn and zero bc_fn; confirm the intended run name.
    torch.save(
        pinn.net.state_dict(),
        checkpoint_dir / f"2d-simple-neumann2d-{i}-heated-with-wind-adaptive.pth",
    )

  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 00000 Loss: 4.381638526916504
Epoch 00100 Loss: 0.6931319832801819
Epoch 00200 Loss: 0.46909621357917786
Epoch 00300 Loss: 0.2870258688926697
Epoch 00400 Loss: 0.22368653118610382
Epoch 00500 Loss: 0.14275917410850525
Epoch 00600 Loss: 0.07241560518741608
Epoch 00700 Loss: 0.03652005270123482
Epoch 00800 Loss: 0.0320710726082325
Epoch 00900 Loss: 0.04372389614582062


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 00000 Loss: 0.023601055145263672
Epoch 00100 Loss: 0.016271183267235756
Epoch 00200 Loss: 0.01593012921512127
Epoch 00300 Loss: 0.014745509251952171
Epoch 00400 Loss: 0.016473976895213127
Epoch 00500 Loss: 0.016047408804297447
Epoch 00600 Loss: 0.01677597127854824
Epoch 00700 Loss: 0.012741058133542538
Epoch 00800 Loss: 0.014271609485149384
Epoch 00900 Loss: 0.01205350086092949


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 00000 Loss: 0.011332088150084019
Epoch 00100 Loss: 0.012753859162330627
Epoch 00200 Loss: 0.011069725267589092
Epoch 00300 Loss: 0.011033800430595875
Epoch 00400 Loss: 0.010414921678602695
Epoch 00500 Loss: 0.01146179623901844
Epoch 00600 Loss: 0.01145590003579855
Epoch 00700 Loss: 0.010397358797490597
Epoch 00800 Loss: 0.009236177429556847
Epoch 00900 Loss: 0.0113290436565876


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 00000 Loss: 0.01048200111836195
Epoch 00100 Loss: 0.010543321259319782
Epoch 00200 Loss: 0.00871056318283081
Epoch 00300 Loss: 0.008770403452217579
Epoch 00400 Loss: 0.009366831742227077
Epoch 00500 Loss: 0.008287688717246056
Epoch 00600 Loss: 0.009144569747149944
Epoch 00700 Loss: 0.00835158210247755
Epoch 00800 Loss: 0.00899388175457716
Epoch 00900 Loss: 0.007972817867994308


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 00000 Loss: 0.009529300965368748
Epoch 00100 Loss: 0.007628492079675198
Epoch 00200 Loss: 0.008181105367839336
Epoch 00300 Loss: 0.009754357859492302
Epoch 00400 Loss: 0.0074717761017382145
Epoch 00500 Loss: 0.00888130534440279
Epoch 00600 Loss: 0.008361006155610085
Epoch 00700 Loss: 0.00811695959419012
Epoch 00800 Loss: 0.009527618065476418
Epoch 00900 Loss: 0.0072561874985694885


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 00000 Loss: 0.00875427108258009
Epoch 00100 Loss: 0.008503681048750877
Epoch 00200 Loss: 0.007588635664433241
Epoch 00300 Loss: 0.007911596447229385
Epoch 00400 Loss: 0.007880032062530518
Epoch 00500 Loss: 0.007422515656799078
Epoch 00600 Loss: 0.007729413919150829
Epoch 00700 Loss: 0.007309338077902794
Epoch 00800 Loss: 0.007210635580122471
Epoch 00900 Loss: 0.007551664020866156


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 00000 Loss: 0.008034484460949898
Epoch 00100 Loss: 0.007867430336773396
Epoch 00200 Loss: 0.01041320525109768
Epoch 00300 Loss: 0.006907113362103701
Epoch 00400 Loss: 0.006819537840783596
Epoch 00500 Loss: 0.007210506591945887
Epoch 00600 Loss: 0.006845904514193535
Epoch 00700 Loss: 0.006988021079450846
Epoch 00800 Loss: 0.007609926164150238
Epoch 00900 Loss: 0.006910546217113733


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 00000 Loss: 0.006518734153360128
Epoch 00100 Loss: 0.0076891896314918995
Epoch 00200 Loss: 0.009758475236594677
Epoch 00300 Loss: 0.00722165871411562
Epoch 00400 Loss: 0.00948956236243248
Epoch 00500 Loss: 0.0064216176979243755
Epoch 00600 Loss: 0.0075678289867937565
Epoch 00700 Loss: 0.007337545044720173
Epoch 00800 Loss: 0.009251110255718231
Epoch 00900 Loss: 0.007763986010104418


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 00000 Loss: 0.007784481160342693
Epoch 00100 Loss: 0.006770794745534658
Epoch 00200 Loss: 0.0063942186534404755
