## Setup

In [363]:
import torch
import torch.nn as nn
import numpy as np

from tqdm.notebook import tqdm

import matplotlib
import matplotlib.pyplot as plt
import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [364]:
# @title Styling
# @markdown Making things pretty! (This is meant to work with Colab's dark theme **(Settings > Site > Theme > "Dark")**.)
# import google
# is_dark = google.colab.output.eval_js(
#     'document.documentElement.matches("[theme=dark]")'
# )

matplotlib.rcParams["figure.dpi"] = 100
plt.rcParams["hatch.color"] = "white"

# if is_dark:
# load style sheet for matplotlib, a plotting library we use for 2D visualizations
plt.style.use(
    "https://github.com/dhaitz/matplotlib-stylesheets/raw/master/pitayasmoothie-dark.mplstyle"
)
plt.style.use("dark_background")
plt.rcParams.update(
    {
        "figure.facecolor": (0.22, 0.22, 0.22, 1.0),
        "axes.facecolor": (0.22, 0.22, 0.22, 1.0),
        "savefig.facecolor": (0.22, 0.22, 0.22, 1.0),
        "grid.color": (0.4, 0.4, 0.4, 1.0),
    }
)

plotly_template = pio.templates["plotly_dark"]
pio.templates["draft"] = go.layout.Template(
    layout=dict(
        plot_bgcolor="rgba(56,56,56,0)",
        paper_bgcolor="rgba(56,56,56,0)",
    )
)
pio.templates.default = "plotly_dark+draft"
    

In [365]:
device = 'cuda' if torch.cuda.is_available() else 'device'
print(f"Using {device}")

Using cuda


## 1D Heat Diffusion (Homogenous)

In [407]:
class MLP(nn.Module):
    def __init__(self, input_dim=2, hidden_layer_ct=2, hidden_dim=128, output_dim=1, activation=nn.Tanh, *args, **kwargs) -> None:
        super(MLP,self).__init__(*args, **kwargs)
        self.input_dim = input_dim
        self.input_layer = nn.Sequential(nn.Linear(input_dim,hidden_dim), activation())
        hidden_layers = []
        for _ in range(hidden_layer_ct):
            hidden_layers.append(nn.Linear(hidden_dim,hidden_dim))
            hidden_layers.append(activation())
        self.hidden_layers = nn.Sequential(*hidden_layers)
        self.output_layer = nn.Sequential(nn.Linear(hidden_dim,output_dim))
    
    def forward(self, x):
        x = self.input_layer(x)
        x = self.hidden_layers(x)
        x = self.output_layer(x)
        return x

In [432]:
class PINN_1D:
    # TODOs: test with fixed/dynamic/non-periodic BCs, varying diffusivity field, etc
    def __init__(self, net, lr=1e-4, collocation_ct=500, t_bounds=[0,4], space_bounds=[-2*np.pi,2*np.pi], diffusivity=0.2) -> None:
        self.net = net.to(device)
        self.t_bounds = t_bounds
        self.space_bounds = space_bounds
        self.collocation_ct = collocation_ct
        self.boundary_ct = 100
        self.initial_ct = 100
        self.dim = 1 + 1
        
        self.loss_fn = nn.MSELoss()
        self.adam = torch.optim.Adam(self.net.parameters(), lr=lr)
        self.lbfgs = torch.optim.LBFGS(
            self.net.parameters(),
            history_size=50,
            tolerance_grad=1e-7, 
            tolerance_change=1.0 * np.finfo(float).eps,
            line_search_fn="strong_wolfe",   # better numerical stability
        )

        self.D_heat = diffusivity

        self.runs = []
        self.loss_history = []

        self.u_range = 50
        self.u_mid = 0
        self.u_sigma = 0.6

        self.phys_weight = 0.1
        self.grad_weight = 0.05

    def normalize_pts(self, pts):
      pts_norm = torch.empty_like(pts)
      pts_norm[:, 0] = (pts[:,0] - self.t_bounds[0]) / (self.t_bounds[1] - self.t_bounds[0]) 
      pts_norm[:, 1] = (pts_norm[:,1] - self.space_bounds[0]) / (self.space_bounds[1]- self.space_bounds[0]) 
      return pts
      # return pts_norm*2-1
      

    def unnormalize_pts(self, pts_norm):
      pts = torch.empty_like(pts_norm)
      pts[:, 0] = pts_norm[:,0] * (self.t_bounds[1] - self.t_bounds[0]) + self.t_bounds[0]
      pts[:, 1] = pts_norm[:,1] * (self.space_bounds[1]- self.space_bounds[0]) + self.space_bounds[0]
      return pts_norm
      # return (pts+1)/2
    
    def sample_ics(self):
        # torch.manual_seed(0)
        pts = torch.rand((self.initial_ct, self.dim), device=device, requires_grad=True) * (self.space_bounds[1]-self.space_bounds[0]) + self.space_bounds[0]
        pts[:,0] = 0 # time at start

        # Gaussian start
        # u = self.u_range*(1/(self.u_sigma*np.sqrt(2*np.pi)))*torch.exp(-(pts[:,1])**2 / (2*self.u_sigma**2)).reshape(-1,1) + self.u_mid
        
        # sinusoidal start
        u = self.u_range/2*(torch.cos(pts[:,1])).reshape(-1,1) + self.u_mid

        # parabolic start
        # u = (-0.5*(pts[:,1]-self.space_bounds[0])*(pts[:,1]+self.space_bounds[1])).reshape(-1,1)

        return pts, u
    
    def sample_bcs(self):
        # torch.manual_seed(1)
        half = int(self.boundary_ct/2)
        pts_left = torch.ones((half, self.dim), device=device, requires_grad=True) * self.space_bounds[0]
        pts_right = torch.ones((half, self.dim), device=device, requires_grad=True) * self.space_bounds[1]

        # stack = torch.vstack([pts_left, pts_right])
        # stack[:,0] = torch.rand(2*half) * (self.t_bounds[1] - self.t_bounds[0]) + self.t_bounds[0]
        # pts = stack.to(device).requires_grad_(True)

        # u = torch.zeros_like(pts).to(device)
        # return pts, u

        # u = self.u_range*(1/(self.u_sigma*np.sqrt(2*np.pi)))*torch.exp(-(pts[:,1])**2 / (2*self.u_sigma**2)).reshape(-1,1) + self.u_mid
        # u = torch.maximum(-torch.abs(4*pts[:,1])+20, torch.zeros_like(u).flatten()).reshape(-1,1)
        # return pts, u

        t = torch.rand(half) * (self.t_bounds[1] - self.t_bounds[0]) + self.t_bounds[0]
        pts_left[:,0] = t
        pts_right[:,0] = t
        pts = torch.vstack([pts_left, pts_right])
        return pts, pts_left, pts_right # for periodic bcs

    def sample_collocation(self):
        # torch.manual_seed(2)
        pts_rand = torch.rand((self.collocation_ct, self.dim), device=device).requires_grad_(True)
        pts = torch.empty_like(pts_rand)
        limiter = 0.995
        pts[:,0] = pts_rand[:,0] * (self.t_bounds[1]- self.t_bounds[0])*limiter + self.t_bounds[0] + (1-limiter) * (self.t_bounds[1]- self.t_bounds[0]) 
        pts[:,1] = pts_rand[:,1] * (self.space_bounds[1]- self.space_bounds[0])*limiter + self.space_bounds[0] + (1-limiter)/2 * (self.space_bounds[1] - self.space_bounds[0])
        return pts

    def train_step(self):

        """Sample pts"""
        pts_ic, u_ic = self.sample_ics()
        pts_col = self.sample_collocation()
        # pts_bc, u_bc = self.sample_bcs()
        pts_bc, pts_bc_l, pts_bc_r = self.sample_bcs()
        

        """Reset grads"""
        self.adam.zero_grad()
        self.lbfgs.zero_grad()

        """Initial Conditions"""
        u_ic_pred = self.net(self.normalize_pts(pts_ic))
        loss_u_ic = self.loss_fn(u_ic_pred,u_ic)

        """Boundary Pts"""
        # u_bc_pred = self.net(self.normalize_pts(pts_bc))
        # loss_u_bc = self.loss_fn(u_bc_pred,u_bc)
        u_bc_pred_l = self.net(self.normalize_pts(pts_bc_l))
        u_bc_pred_r = self.net(self.normalize_pts(pts_bc_r))
        loss_u_bc = self.loss_fn(u_bc_pred_l,u_bc_pred_r) # periodic bcs

        """Collocation pts"""
        pts_col_norm = self.normalize_pts(pts_col)
        u_col_pred = self.net(pts_col_norm)

        u_prime_pred = torch.autograd.grad(
            inputs=pts_col_norm,
            outputs= u_col_pred,
            grad_outputs=torch.ones_like(u_col_pred).to(device),
            create_graph=True
        )[0]


        u_pprime_pred = torch.autograd.grad(
            inputs=pts_col_norm,
            outputs=u_prime_pred,
            grad_outputs=torch.ones_like(u_prime_pred).to(device),
            create_graph=True
        )[0]
        
        u_x = torch.sum(u_prime_pred[:,1:],axis=1)
        u_xx = torch.sum(u_pprime_pred[:,1:],axis=1)
        u_t = u_prime_pred[:,0].reshape(-1,1)
 
        residual_diffusion = u_t - self.D_heat * u_xx # Heat Equation
        residual_regularizer = u_t + self.D_heat * u_x

        r2 = residual_regularizer**2 + residual_diffusion**2
        loss_physics = torch.mean(r2)
        loss_physics_grad_penalty = torch.maximum(torch.max(r2),torch.Tensor([10000]).to(device)) #torch.max(r2)

        """Known Data"""
        # known_pts = torch.linspace(*self.t_bounds,7).to(device).requires_grad_(True).reshape(-1,1)
        # u_known_true = self.oscillator_true(known_pts)
        # u_known_pred = self.net(self.normalize_pts(nown_pts))
        # loss_known = self.loss_fn(u_known_pred, u_known_true)

        """Loss"""
        loss = loss_u_ic + self.phys_weight*loss_physics + loss_u_bc #+ self.grad_weight*loss_physics_grad_penalty
        loss.backward()
        return loss
    
    def train(self, n_epochs=500, mode='Adam', reporting_frequency=500, phys_weight=None, grad_weight=None):
        if phys_weight is not None:
          self.phys_weight = phys_weight
        if grad_weight is not None:
          self.grad_weight = grad_weight
        results = {
            "loss": []
        }
        for it in tqdm(range(n_epochs)):
            loss = self.adam.step(self.train_step) if mode == "Adam" else self.lbfgs.step(self.train_step)
            results["loss"].append(loss.item())
            self.loss_history.append(loss.item())
            if it % reporting_frequency == 0:
              print(f"Loss: {loss.item()}")
        self.runs.append(results)
    
    def plot_ics_and_bcs(self):
        with torch.no_grad():
            pts_bc, _, _ = self.sample_bcs()
            u_bc = self.net(self.normalize_pts(pts_bc))
            # pts_bc, u_bc = self.sample_bcs()

            pts_ic, u_ic = self.sample_ics()
            pts_bc = pts_bc.detach().cpu()
            pts_ic = pts_ic.detach().cpu()
            plt.scatter(pts_bc[:,0].flatten(), pts_bc[:,1].flatten() ,cmap="plasma", s=1, c=u_bc.detach().cpu().flatten(), vmin=self.u_mid-self.u_range/2, vmax=self.u_mid+self.u_range/2)
            plt.scatter(pts_ic[:,0].flatten(), pts_ic[:,1].flatten() ,cmap="plasma", s=1, c=u_ic.detach().cpu().flatten(), vmin=self.u_mid-self.u_range/2, vmax=self.u_mid+self.u_range/2)

    def plot_pts(self):
        with torch.no_grad():
            n_old = self.collocation_ct
            self.collocation_ct = 10000
            pts = self.sample_collocation()
            self.collocation_ct = n_old
            u = self.net(self.normalize_pts(pts))
            pts = pts.detach().cpu()
            plt.scatter(pts[:,0].flatten(), pts[:,1].flatten() ,cmap="plasma", s=1, c=u.detach().cpu().flatten(), vmin=self.u_mid-self.u_range/2, vmax=self.u_mid+self.u_range/2)

    def plot_all_pts(self):
        with torch.no_grad():
            """Col pts"""
            n_old = self.collocation_ct
            self.collocation_ct = 10000
            pts_col = self.sample_collocation()
            self.collocation_ct = n_old
            u_col = self.net(self.normalize_pts(pts_col))

            """BC pts"""
            pts_bc, _, _ = self.sample_bcs()
            u_bc = self.net(self.normalize_pts(pts_bc))
            # pts_bc, u_bc = self.sample_bcs()

            """IC pts"""
            pts_ic, u_ic = self.sample_ics()

            """Collect"""
            pts = torch.vstack([pts_col,pts_bc, pts_ic])
            u   = torch.vstack([u_col, u_bc, u_ic])
            pts = pts.detach().cpu()
            u = u.detach().cpu()

            """Plot"""
            plt.scatter(pts[:,0].flatten(), pts[:,1].flatten() ,cmap="plasma", s=1, c=u.detach().cpu().flatten())
    
    def plot_3d(self,res=200):
      with torch.no_grad():
        fig = go.Figure()
        t = torch.linspace(*self.t_bounds,res).to(device)
        x = torch.linspace(*self.space_bounds,res).to(device)
        xx, tt = torch.meshgrid(x,t, indexing="xy")
        X = torch.stack([tt,xx]).T.reshape(-1,2)
        y = self.net(X).reshape(res,res).cpu()
        scatter_plot = go.Surface(
          x=tt.cpu(),
          y=xx.cpu(),
          z=y.T,
        )
        fig.add_trace(scatter_plot)
        fig.update_layout(
          coloraxis=dict(colorscale='plasma'), 
          showlegend=False,
        )

        fig.update_scenes(
          xaxis_title_text='t [s]',  
          yaxis_title_text='x [m]',  
          zaxis_title_text='T [deg C]',
          xaxis_showbackground=False,
          yaxis_showbackground=False,
          zaxis_showbackground=False,
        )
        fig.show()
              

In [433]:
diffusivities = [
    8,5,2,0.5
]
t_maxes = [
    4,4,12,12
]

pinns = []
for d,t in zip(diffusivities,t_maxes):
  pinn = PINN_1D(net=MLP(hidden_layer_ct=4,hidden_dim=256), collocation_ct=1000, diffusivity=d, t_bounds=[0,t], lr=1e-4)
  pinn.train(n_epochs=5000, reporting_frequency=1000, mode="Adam", phys_weight=0.1, grad_weight=0.1)
  pinns.append(pinn)

  0%|          | 0/5000 [00:00<?, ?it/s]

Loss: 302.5671691894531
Loss: 196.86676025390625
Loss: 182.23684692382812
Loss: 191.5963592529297
Loss: 213.74244689941406


  0%|          | 0/5000 [00:00<?, ?it/s]

Loss: 293.625244140625
Loss: 113.26795196533203
Loss: 120.15091705322266
Loss: 122.09503173828125
Loss: 125.6242904663086


  0%|          | 0/5000 [00:00<?, ?it/s]

Loss: 313.7002258300781
Loss: 13.70343017578125
Loss: 9.903806686401367
Loss: 12.149616241455078
Loss: 15.083488464355469


  0%|          | 0/5000 [00:00<?, ?it/s]

Loss: 253.45472717285156
Loss: 5.4922709465026855
Loss: 3.0515682697296143
Loss: 2.9329214096069336
Loss: 2.64601469039917


In [434]:
for pinn in pinns:
  pinn.plot_3d()

Output hidden; open in https://colab.research.google.com to view.