<a href="https://colab.research.google.com/github/vmattey/SeparablePINN_AC_Codes/blob/main/SPINN_AC_DeepRitz.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Code for Separable PINN in PyTorch
# Solving 1D Allen Cahn Equation

import torch
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.nn as nn
import torch.optim as optim
import torch.func as ft
import numpy as np
import time
import torch.jit as jit
import scipy
import matplotlib.pyplot as plt
import torch.optim.lr_scheduler as lr_scheduler
from torch.autograd import Variable
# Seed for randomization
SEED = 444

# Run on the GPU when CUDA is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
##############################################################
# Neural Network definitions

class NeuralNetwork(nn.Module):
    """Fully connected feed-forward network with a selectable activation.

    Args:
        input_size: Number of input features.
        hidden_sizes: Width of each hidden layer, in order. May be empty, in
            which case the network reduces to a single linear map.
        output_size: Number of output features.
        activation: Name of the hidden-layer non-linearity. ``'tanh'`` selects
            ``nn.Tanh`` and ``'gelu'`` selects ``nn.GELU``; any other value
            falls back to ``nn.LeakyReLU``.
    """

    def __init__(self, input_size, hidden_sizes, output_size, activation='gelu'):
        super().__init__()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size

        # Chain the hidden layers: each one consumes the previous layer's width
        self.hidden_layers = nn.ModuleList()
        prev_size = input_size
        for size in hidden_sizes:
            self.hidden_layers.append(nn.Linear(prev_size, size))
            prev_size = size

        if activation == 'tanh':
            self.act_fun = nn.Tanh()
        elif activation == 'gelu':
            # Previously 'gelu' fell through to LeakyReLU despite its name
            self.act_fun = nn.GELU()
        else:
            # Unrecognised names keep the historical LeakyReLU fallback
            self.act_fun = nn.LeakyReLU()

        # prev_size equals input_size when there are no hidden layers
        self.output_layer = nn.Linear(prev_size, output_size)

    def forward(self, X):
        """Apply every hidden layer with the activation, then the linear output layer."""
        for layer in self.hidden_layers:
            X = self.act_fun(layer(X))
        return self.output_layer(X)

class Combined(nn.Module):
    """Pair of identically configured networks, one per input argument.

    The forward pass evaluates ``model1`` on ``x`` and ``model2`` on ``t`` and
    returns the matrix product of the first output with the transpose of the
    second.
    """

    def __init__(self, input_size, hidden_sizes, output_size, activation):
        super().__init__()
        net_config = (input_size, hidden_sizes, output_size, activation)
        self.model1 = NeuralNetwork(*net_config)
        self.model2 = NeuralNetwork(*net_config)

    def forward(self, x, t):
        """Return ``model1(x) @ model2(t).T``; the inputs are used unscaled."""
        x_features = self.model1(x)
        t_features = self.model2(t)
        return x_features @ t_features.T


In [None]:
##############################################################
# Auxiliary Functions
def init_weights(m):
    """Re-initialise the weight of an ``nn.Linear`` module with Xavier-normal values.

    Modules of any other type are left untouched. Intended to be passed to
    ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight)


def hvp_fwdfwd(f, primals, tangents, return_primals=False):
    """Second directional derivative of ``f`` via nested forward-mode JVPs.

    Args:
        f: Function of a single tensor argument.
        primals: Point at which to differentiate.
        tangents: Direction used for both the inner and the outer JVP.
        return_primals: When true, also return the first directional
            derivative of ``f`` at ``primals``.

    Returns:
        The second directional derivative, or the pair
        ``(first_derivative, second_derivative)`` when ``return_primals`` is
        true.
    """
    def directional_derivative(point):
        # Keep only the tangent output of the inner JVP
        _, jvp_out = ft.jvp(f, (point,), (tangents,))
        return jvp_out

    first_order, second_order = ft.jvp(directional_derivative, (primals,), (tangents,))
    return (first_order, second_order) if return_primals else second_order

In [None]:
##############################################################
# Loss Functions

def spinn_loss(apply_fn, ad_fn, tau, train_data, train_data_ic):
    """Assemble the total training loss for one time step.

    The total is ``energy + ngpt * boundary + moving`` where ``ngpt`` is the
    number of rows in the energy sample points.

    Args:
        apply_fn: Callable that evaluates the model at sample points.
        ad_fn: Callable differentiated with forward-mode AD.
        tau: Step size; the moving term is divided by ``2 * tau``.
        train_data: Tuple ``(xg, xb, h, _)``: sample points for the energy
            term, a pair of boundary points, the weight ``h`` multiplying
            each sample, and a fourth entry that is ignored here.
        train_data_ic: Tuple ``(xi, ui)`` of points and reference values for
            the moving term.

    Returns:
        Tuple ``(loss, ener_loss, bound_loss, moving_loss)``; ``loss`` is moved
        to the module-level ``device``.
    """

    # Tangent vectors are built with ones_like so they share the device and
    # dtype of the point they are paired with; torch.ones(x.shape) would land
    # on the CPU and trigger a device-mismatch error for GPU inputs.

    def residual_loss(x,t):
        # Not part of the total below; expects a two-argument (x, t) model.
        # calculate u
        u = apply_fn(x,t)
        # tangent vector dx/dx
        # assumes t, x, y have same shape (very important)
        v = torch.ones_like(x)
        # 2nd derivatives of u
        ux,uxx = hvp_fwdfwd(lambda x: ad_fn(x,t), x, v,return_primals=True)
        ut = ft.jvp(lambda t: ad_fn(x,t), (t,), (v,))[1]
        #return torch.mean((ut-0.0001*uxx+5*(u**3-u))**2) - 1e-6*(torch.mean(torch.log10(u**2 + ux**2 + uxx**2)) + torch.mean(torch.log10((u-1)**2 + ux**2 + uxx**2))  + torch.mean(torch.log10((u+1)**2 + ux**2 + uxx**2)))
        return torch.mean((ut-0.0001*uxx+5*(u**3-u))**2)

    def initial_loss(x,t,u):
        # Not part of the total below; expects a two-argument (x, t) model.
        return torch.mean((apply_fn(x,t) - u)**2)

    def moving_loss(x,u,tau,h):
        # Weighted squared distance to the reference values, scaled by 1/(2*tau)
        return torch.sum(h*(apply_fn(x) - u)**2)/(2*tau)

    def boundary_loss(x):
        # Mismatch of the value and of the first derivative between the two
        # boundary points x[0] and x[1]
        loss_u = torch.mean((apply_fn(x[0]) - apply_fn(x[1]))**2)
        v = torch.ones_like(x[0])
        ux_lb = ft.jvp(lambda x: ad_fn(x), (x[0],), (v,))[1]
        ux_ub = ft.jvp(lambda x: ad_fn(x), (x[1],), (v,))[1]
        loss_ux = torch.mean((ux_lb - ux_ub)**2)
        return loss_u + loss_ux

    def energy_loss(x, h):
        # Weighted sum of 0.5*ux**2 + 12500*(u**2 - 1)**2 over the sample points
        u = apply_fn(x)
        v = torch.ones_like(x)
        ux = ft.jvp(lambda x: ad_fn(x), (x,), (v,))[1]
        f = h*(0.5*ux**2 + 12500*(u**2 - 1)**2)
        #f = h*(0.5*u**2)
        return torch.sum(f)

    # unpack data
    xg, xb, h, _ = train_data
    xi, ui = train_data_ic
    ngpt = xg.size()[0]

    # Computing the loss value
    #res_loss = residual_loss(xc,tc)
    move_loss = moving_loss(xi, ui, tau, h)
    ener_loss = energy_loss(xg, h)
    bound_loss = boundary_loss(xb)
    loss = ener_loss + ngpt*bound_loss + move_loss

    return loss.to(device), ener_loss, bound_loss, move_loss

def icgl_loss(apply_fn, train_data_icgl):
    """Mean squared mismatch between ``apply_fn`` and the reference values.

    ``train_data_icgl`` is the pair ``(x, u)`` of sample points and target
    values. The result is moved to the module-level ``device``.
    """
    points, targets = train_data_icgl
    mismatch = apply_fn(points) - targets
    return torch.mean(mismatch**2).to(device)

In [None]:
##############################################################
# Generating the training data

def spinn_train_generator_AC1D(nc, keys):
    """Draw random training points for the 1D problem.

    Args:
        nc: Number of random points drawn per coordinate.
        keys: Sequence of ``torch.Generator`` objects; ``keys[1]`` seeds the
            ``x`` samples and ``keys[0]`` seeds both time samples.

    Returns:
        Tuple ``(xc, tc, xi, ti, ui, xb, tb)``: ``xc`` in [-1, 1) and ``tc``
        in [0, 0.5), each of shape ``(nc, 1)``; ``xi`` (same points as
        ``xc``), ``ti`` (a single zero) and ``ui = xi**2 * cos(pi * xi)``;
        the boundary pair ``xb`` at -1 and +1; and ``tb`` in [0, 0.5) of
        shape ``(nc, 1)``. Everything is placed on the module-level
        ``device``.
    """
    # collocation points (x is drawn before t, matching the original order)
    xc = -1 + 2*torch.rand((nc, 1), generator=keys[1])
    tc = 0.5*torch.rand((nc, 1), generator=keys[0])

    # initial points
    ti = torch.zeros((1, 1))
    xi = xc
    ui = (xi**2)*torch.cos(torch.pi*xi)

    # boundary points (hard-coded)
    xb = [-1*torch.ones(1, 1).to(device), 1*torch.ones(1, 1).to(device)]
    # Drawn from the seeded generator so repeated runs give the same samples
    tb = 0.5*torch.rand((nc, 1), generator=keys[0])

    return xc.to(device), tc.to(device), xi.to(device), ti.to(device), ui.to(device), xb, tb.to(device)

def gausspts_train_generator_AC1D(nc):
    """Build two sample points per element on a uniform grid over [-1, 1].

    Each of the ``nc - 1`` elements contributes the two points
    ``mid -/+ (1/sqrt(3)) * half_width``, interleaved element by element.

    Args:
        nc: Number of grid nodes; must be at least 2.

    Returns:
        Tuple ``(xg, xb, h, xd)``: the ``(2*(nc-1), 1)`` sample points, the
        boundary pair at -1 and +1, the half element width ``h``, and the
        ``nc`` grid nodes, all on the module-level ``device``.

    Raises:
        ValueError: If ``nc`` is smaller than 2, since no element exists.
    """
    if nc < 2:
        raise ValueError(f"nc must be at least 2 to form an element, got {nc}")

    # Domain Points
    xd = torch.linspace(-1, 1, nc)

    xleft = xd[:-1]
    xright = xd[1:]
    midpoint = (xright + xleft)/2
    offset = (1/3**0.5)*(xright - xleft)/2
    # Stack the left/right points per element, then flatten to interleave them
    xg = torch.stack((-offset + midpoint, offset + midpoint), dim=1).reshape(-1, 1)

    xb = [-1*torch.ones(1).to(device), 1*torch.ones(1).to(device)]
    # Width of the first element; indexing xd[2] would fail when nc == 2
    h = (xd[1] - xd[0])/2
    return xg.to(device), xb, h.to(device), xd.to(device)

def icgl_train_generator_AC1D(nc):
    """Return sample points on [-1, 1] and the reference values at them.

    The points are two per element of a uniform ``nc``-node grid, placed at
    ``mid -/+ (1/sqrt(3)) * half_width`` and interleaved element by element.
    The values are ``x**2 * cos(pi * x) * exp(-x**2)``. Both tensors have
    shape ``(2*(nc-1), 1)`` and are placed on the module-level ``device``.
    """
    nodes = torch.linspace(-1, 1, nc)
    lower, upper = nodes[:-1], nodes[1:]

    first_pts = -(1/3**0.5)*(upper - lower)/2 + (upper + lower)/2
    second_pts = (1/3**0.5)*(upper - lower)/2 + (upper + lower)/2

    # Even rows hold the first point of each element, odd rows the second
    xi = torch.zeros((2*(nc - 1), 1))
    xi[0:-1:2, 0] = first_pts
    xi[1::2, 0] = second_pts

    damping = torch.exp(-xi**2)
    ui = (xi**2)*torch.cos(torch.pi*xi)*damping
    #ui = torch.sin(xi*torch.pi)

    return xi.to(device), ui.to(device)

In [None]:
##############################################################
# Training functions and utilities

def train_step(loss_fn,optimizer,epoch, lossVal, sol_list, tau, train_data_gauss, train_data_icgl):
    """Run one optimizer update of the module-level ``spinn`` model.

    Args:
        loss_fn: Callable with the signature of ``spinn_loss``; returns
            ``(total, energy, boundary, moving)``.
        optimizer: Optimizer holding the parameters of ``spinn``.
        epoch: Iteration index; losses are printed when it is a multiple of 1000.
        lossVal: List that receives the total loss of this step as a NumPy value.
        sol_list: List that receives the model output on the grid nodes,
            evaluated before this step's parameter update.
        tau: Step size forwarded to ``loss_fn``.
        train_data_gauss: Tuple whose fourth entry holds the grid nodes.
        train_data_icgl: Reference data forwarded to ``loss_fn``.

    Returns:
        The total loss tensor of this step.
    """
    # clear the gradients
    optimizer.zero_grad()

    # Loss
    loss_spinn, ener_loss, bound_loss, moving_loss = loss_fn(spinn, spinn, tau, train_data_gauss, train_data_icgl)
    loss_value = loss_spinn.detach().cpu().numpy()
    lossVal.append(loss_value)

    # Snapshot without autograd tracking: storing graph-attached outputs
    # every epoch would keep each forward graph alive for the whole run.
    _, _, _, xd = train_data_gauss
    with torch.no_grad():
        sol = spinn(xd.reshape(xd.shape[0],1))
    sol_list.append(sol)

    if epoch%1000 == 0:
        print('Energy Loss:',ener_loss.detach().cpu().numpy(),', Bound Loss:',bound_loss.detach().cpu().numpy(),', Moving Loss:',moving_loss.detach().cpu().numpy(),', Total Loss:',loss_value, ', iter:', epoch)

    loss_spinn.backward()
    # Update model weights
    optimizer.step()
    return loss_spinn

def train_step_icgl(loss_fn,optimizer,epoch,train_data_icgl):
    """Run one optimizer update of the module-level ``spinn`` on the fitting loss.

    ``loss_fn`` is called as ``loss_fn(spinn, train_data_icgl)``. The loss is
    printed whenever ``epoch`` is a multiple of 1000, and the loss tensor is
    returned.
    """
    # Start from clean gradients
    optimizer.zero_grad()

    loss_ic = loss_fn(spinn, train_data_icgl)
    loss_value = loss_ic.detach().cpu().numpy()

    report_now = epoch % 1000 == 0
    if report_now:
        print(' Total Loss:',loss_value, ', iter:', epoch)

    # Back-propagate, then update the model weights
    loss_ic.backward()
    optimizer.step()
    return loss_ic


def closure():
    """Re-evaluate the loss and its gradients for the LBFGS optimizer.

    Reads the module-level ``lbfgs``, ``spinn``, ``tau``, ``train_data_gauss``
    and ``train_data_icgl``.

    Returns:
        Tuple ``(loss, ener_loss, bound_loss, init_loss)`` as produced by
        ``spinn_loss``.

    NOTE: torch optimizers convert the closure's return value with
    ``float()``, so hand ``lambda: closure()[0]`` to ``lbfgs.step`` rather
    than this function directly.
    """
    # Zero gradients
    lbfgs.zero_grad()

    # Compute loss. spinn_loss takes (tau, train_data, train_data_ic);
    # star-unpacking the four-element train_data_gauss raised a TypeError.
    loss, ener_loss, bound_loss, init_loss = spinn_loss(spinn, spinn, tau, train_data_gauss, train_data_icgl)

    # Backward pass
    loss.backward()

    return loss, ener_loss, bound_loss, init_loss

def repackage_hidden(h):
    """Detach hidden states from their autograd history.

    Args:
        h: A tensor, or an arbitrarily nested iterable of tensors.

    Returns:
        A detached tensor when ``h`` is a tensor, otherwise a tuple with the
        same nesting whose leaves are detached tensors.
    """
    # isinstance is required here: the exact type of a tensor is never the
    # legacy Variable class, so an equality check on type(h) always failed
    # and the function recursed into individual tensor elements.
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)


In [None]:
##############################################################
# Main Code

# random key
g_cpu = torch.Generator()
# NOTE: Generator.manual_seed returns the generator itself, so the three
# entries below all refer to the same generator object.
keys =  [g_cpu.manual_seed(SEED),g_cpu.manual_seed(SEED),g_cpu.manual_seed(SEED)]


# dataset

nc = 512 # user input
train_data = spinn_train_generator_AC1D(nc, keys)
train_data_gauss = gausspts_train_generator_AC1D(nc)
train_data_icgl = icgl_train_generator_AC1D(nc)


# User Input for Size of Neural Network
input_size = 1  # You can change this to the desired number of input features
hidden_sizes = [64, 64, 64, 64, 64, 64]  # You can specify the number of neurons in each hidden layer
output_size = 1
epochs_icgl = 1001
epochs_pinn = 1001
tau = 0.00001/2
activation = 'gelu' # Activation name; see NeuralNetwork for how names map to layers
N = 512
xgrid = torch.linspace(-1, 1, N).to(device)
t = 0
n_time_steps = 21  # one step from the fitted initial state plus 20 further steps

# Create an instance of the neural network
spinn = NeuralNetwork(input_size,hidden_sizes,output_size, activation).to(device)
# spinn = torch.jit.script(spinn).to(device)
# spinn.apply(init_weights)


# Define an optimizer
adam = optim.Adam(spinn.parameters(),lr=0.001)  # You can adjust the learning rate (lr) as needed
# The scheduler is stepped once per epoch across every time step below
scheduler = lr_scheduler.LinearLR(adam,start_factor=1,end_factor=0.1,total_iters=epochs_pinn)
lbfgs = optim.LBFGS(spinn.parameters(), history_size=4, max_iter=10)
lossVal = []
sol_list = []
upred = []


start = time.time()
# Training with ADAM for Initial Condition
for epoch in range(epochs_icgl):
    train_step_icgl(icgl_loss,adam,epoch,train_data_icgl)

# Snapshots are taken without autograd tracking so the stored tensors do
# not keep their forward graphs alive.
with torch.no_grad():
    upred.append(spinn(xgrid.reshape(N,1)))

for i in range(n_time_steps):
    if i > 0:
        # The current network output becomes the reference for the next step;
        # the very first step uses the generated reference values instead.
        xi, _ = train_data_icgl
        with torch.no_grad():
            ui = spinn(xi)
        train_data_icgl = xi,ui

    # Training with ADAM
    for epoch in range(epochs_pinn):
        train_step(spinn_loss,adam,epoch, lossVal, sol_list,tau,train_data_gauss, train_data_icgl)
        scheduler.step()

    with torch.no_grad():
        upred.append(spinn(xgrid.reshape(N,1)))
    t += tau
    print('Sim Time: ', t)

end = time.time()
print('Total time taken for training: ',end-start)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

In [None]:
# LBFGS(Optional)
# Training with LBFGS
for epoch in range(10000):
    # Update weights. The optimizer converts the closure result with float(),
    # so it is given only the total loss, not the full tuple.
    lbfgs.step(lambda: closure()[0])

    # Re-evaluate once more to obtain the individual terms for reporting
    loss, ener_loss, bound_loss, init_loss = closure()
    running_loss = loss.item()

    if epoch%100 == 0:
        # .cpu() is required before .numpy() when the tensors live on the GPU
        print('Energy Loss:',ener_loss.detach().cpu().numpy(),', Bound Loss:',bound_loss.detach().cpu().numpy(),', Initial Loss:',init_loss.detach().cpu().numpy(),', Total Loss:',running_loss, ', iter:', epoch)

    # Stop early once the loss is small enough (a bare `exit` did nothing)
    if running_loss <= 1e-5:
        break

## Model Saving Utilities

In [None]:
# Saving the Model
# The checkpoint bundles the model weights, the Adam state and the model object
filename = 'AC_withConstraintGELU_Fulltime.pt'
checkpt = dict(
    model_params=spinn.state_dict(),
    optimizer=adam.state_dict(),
    net=spinn,
)
torch.save(checkpt, filename)

In [None]:
# Saving the solution data to mat file
import scipy.io  # a bare `import scipy` does not guarantee the io submodule is loaded

# .cpu() is required before .numpy() when the snapshots live on the GPU
u_pred = [u.detach().cpu().numpy() for u in upred]

umat = {'u': u_pred}
scipy.io.savemat('AC_DeepRitz1D_oldeqn.mat',umat)

## Plotting Utilities

In [None]:
#%% Plotting the results
def plot_ac(xtrain,uexact,upred,t_val):
    """Scatter the reference and predicted values against ``xtrain``.

    Args:
        xtrain: Abscissae shared by both series.
        uexact: Reference values, plotted with the label ``'exact'``.
        upred: Predicted values as a torch tensor; it is detached and moved
            to the CPU before plotting, so GPU tensors are accepted.
        t_val: Value shown in the figure title.
    """
    predicted = upred.detach().cpu().numpy()
    plt.figure()
    plt.scatter(xtrain,uexact,label='exact')
    plt.scatter(xtrain,predicted,label='predicted')
    plt.legend(loc=1)
    plt.title(f"t = {t_val}")
    plt.show()

def plot_ac_cmap(x,t,uexact,upred):
    """Draw two pseudocolour figures, ``'Predicted'`` and ``'Exact'``, on a (t, x) mesh.

    Args:
        x: Values for the vertical axis of the mesh.
        t: Values for the horizontal axis of the mesh.
        uexact: Reference values on the mesh.
        upred: Predicted values as a torch tensor; it is detached and moved
            to the CPU before plotting, so GPU tensors are accepted.
    """
    tmesh,xmesh = np.meshgrid(t,x)
    predicted = upred.detach().cpu().numpy()

    plt.figure()
    plt.pcolor(tmesh,xmesh,predicted,cmap='coolwarm')
    plt.title('Predicted')

    plt.figure()
    plt.pcolor(tmesh,xmesh,uexact,cmap='coolwarm')
    plt.title('Exact')


# Compare the network prediction with a reference solution loaded from a
# .mat file, first at a single time snapshot and then as a 2D colour map.
# NOTE(review): this cell calls `spinn(x, t)` with two inputs, which matches
# the two-argument forward of `Combined`, not a single-input `NeuralNetwork`.
# Confirm which model `spinn` is before running it.
# NOTE(review): the names `t`, `u`, `x` and `upred` assigned below replace
# any earlier notebook variables with the same names.
t_val = 0.5 #change this variable based on the time snapshot

# Chebfun Solution
data = scipy.io.loadmat('AC_R_1.mat')
x = data['x']
t = data['tt']
u = data['uu']
# Column index is t_val*200; presumably the file stores 200 snapshots per
# unit time. TODO confirm against the data file.
uexact = u[:,int(t_val*200)]


# Neural Network Prediction
xtest = torch.tensor(x.T,dtype=torch.float32)
ttest = t_val*torch.ones(1,1)
upred = spinn(xtest,ttest)

plot_ac(x,uexact,upred,t_val)

# For 2D surface plot
# Uses the first int(t_val*200) entries of t.T and the matching columns of u
ttest = torch.tensor(t.T[:int(t_val*200)],dtype=torch.float32)
upred = spinn(xtest,ttest)
plot_ac_cmap(x,t.T[:int(t_val*200)],u[:,:int(t_val*200)],upred)


In [None]:
# Plotting results for Deep Ritz Testing
# Compares the stored snapshot `upred[step]` with x**2 * cos(pi * x) on a
# uniform grid and reports the root-mean-square difference in the title.
# NOTE(review): this reference omits the exp(-x**2) factor used by
# icgl_train_generator_AC1D; confirm which profile is intended.
step = 20
N = 512
# reshape replaces the deprecated non-in-place Tensor.resize
x = torch.linspace(-1, 1, N).reshape(N,1)
ytrue = x**2*torch.cos(torch.pi*x)
ytrue = ytrue.detach().numpy()
ypred = upred[step]
# .cpu() is required before .numpy() when the snapshot lives on the GPU
ypred_np = ypred.detach().cpu().numpy()

rmse = np.sqrt(np.sum((ypred_np - ytrue)**2)/N)

plt.figure()
plt.scatter(x.detach().numpy(), ytrue, label ='IC')
plt.scatter(x.detach().numpy(), ypred_np, label ='Predicted')
plt.legend(loc = 1)
plt.title('rmse = %.6f' %rmse)