In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
from scipy.optimize import minimize
import math



In [None]:
class PINN(nn.Module):
    """Fully connected tanh network that maps x to the approximate solution u(x).

    The network is ``num_layers`` linear layers; every layer except the last
    is followed by a ``Tanh`` activation.

    Args:
        num_layers: Total number of ``nn.Linear`` layers (asserted to be >= 2).
        hidden_size: Width of each hidden layer.
        output_size: Number of output features.
        input_size: Number of input features.

    Note:
        A second sub-network for a multiplier ``lambda`` was prototyped (see
        the commented-out ``lambda_net`` code) but is disabled; ``forward``
        returns only ``u``.
    """

    def __init__(self,num_layers = 5, hidden_size = 60, output_size = 1, input_size = 1):
        super().__init__()
        assert num_layers >=2

        # (in_features, out_features) of every linear layer, front to back.
        layer_shapes = (
            [(input_size, hidden_size)]
            + [(hidden_size, hidden_size)] * (num_layers - 2)
            + [(hidden_size, output_size)]
        )

        # Layers are instantiated in forward order so that the random
        # initialisation consumes the RNG in the same sequence as before.
        self.fc_u = nn.ModuleList([nn.Linear(n_in, n_out) for n_in, n_out in layer_shapes])
        # One activation per linear layer except the output layer.
        self.ln_u = nn.ModuleList([nn.Tanh() for _ in layer_shapes[:-1]])

    def forward(self, x):
        """Return ``u_net(x)``."""
        return self.u_net(x)

    def u_net(self, x):
        """Apply the linear/tanh stack to ``x`` and return the raw output of the last layer."""
        hidden = x
        # zip stops at the shorter list, i.e. it pairs every hidden linear
        # layer with its activation and leaves the output layer for below.
        for linear, activation in zip(self.fc_u, self.ln_u):
            hidden = activation(linear(hidden))
        return self.fc_u[-1](hidden)

#     def lambda_net(self, x):
#         for i in range(len(self.fc_lambda) - 1):
#             layer = self.fc_lambda[i]
#             tanh = self.ln_lambda[i]
#             x = layer(x)
#             x = tanh(x)
            
#         x = self.fc_lambda[-1](x)
#         x = -torch.exp(x)
#         return x
    
# def lambda_weights_init(m):
#     if isinstance(m, nn.Linear):
#         torch.nn.init.normal_(m.weight, mean=-1 , std=0.4)  # Initialize weights with a normal distribution centered at -1
#         m.bias.data.fill_(0.01)
        
# def u_weights_init(m):
#     if isinstance(m, nn.Linear):
#         torch.nn.init.normal_(m.weight, mean=2, std=1)  # Initialize weights with a normal distribution centered at 2
#         m.bias.data.fill_(0.01)

def cal_u_exact(x):
    """Evaluate the piecewise reference solution u(x) element-wise.

    With ``a = 1 / (2 * sqrt(2))`` the returned tensor holds

    * ``(100 - 50 * sqrt(2)) * x``        for ``0 <= x < a``
    * ``100 * x * (1 - x) - 12.5``        for ``a <= x < 1 - a``
    * ``(100 - 50 * sqrt(2)) * (1 - x)``  for ``1 - a <= x <= 1``
    * ``0``                               everywhere else

    Args:
        x: Tensor of evaluation points.

    Returns:
        Tensor with the same shape as ``x``.
    """
    left_break = 1 / (2 * math.sqrt(2))
    right_break = 1 - left_break
    slope = 100 - 50 * math.sqrt(2)

    on_left_line = (x >= 0) & (x < left_break)
    on_parabola = (x >= left_break) & (x < right_break)
    on_right_line = (x >= right_break) & (x <= 1)

    # Start from zeros so points outside [0, 1] keep the value 0.
    u = torch.zeros_like(x)
    u[on_left_line] = slope * x[on_left_line]
    x_mid = x[on_parabola]
    u[on_parabola] = 100 * x_mid * (1 - x_mid) - 12.5
    u[on_right_line] = slope * (1 - x[on_right_line])
    return u
        
def g(x):
    """Evaluate the piecewise obstacle function element-wise.

    The returned tensor holds

    * ``100 * x**2``                 for ``0 <= x < 0.25``
    * ``100 * x * (1 - x) - 12.5``   for ``0.25 <= x < 0.75``
    * ``100 * (1 - x)**2``           for ``0.75 <= x <= 1``
    * ``0``                          everywhere else

    Args:
        x: Tensor of evaluation points.

    Returns:
        Tensor with the same shape as ``x``.
    """
    in_left = (x >= 0) & (x < 0.25)
    in_middle = (x >= 0.25) & (x < 0.75)
    in_right = (x >= 0.75) & (x <= 1.0)

    # Start from zeros so points outside [0, 1] keep the value 0.
    values = torch.zeros_like(x)
    values[in_left] = 100 * x[in_left].pow(2)
    x_mid = x[in_middle]
    values[in_middle] = 100 * x_mid * (1 - x_mid) - 12.5
    values[in_right] = 100 * (1 - x[in_right]).pow(2)
    return values


def differentiable_heaviside(x, epsilon=1e-6):
    """Smooth step function ``(tanh(x / epsilon) + 1) / 2``.

    Args:
        x: Input tensor.
        epsilon: Width of the transition; ``x`` is divided by it before the
            ``tanh`` is applied.

    Returns:
        Tensor of values between 0 and 1, equal to 0.5 where ``x`` is 0.
    """
    scaled = x / epsilon
    return 0.5 * (1 + torch.tanh(scaled))



def loss_function(u, x, g, gamma):
    """Return the scalar training loss for the network output ``u`` on points ``x``.

    The loss is the sum of four terms:

    * ``sum((H(u - g) * u'')**2)`` where ``H`` is ``differentiable_heaviside``
      and ``u''`` is the second derivative of ``u`` with respect to ``x``,
    * ``sum(relu(g - u)**2)``, which is non-zero only where ``u`` is below ``g``,
    * ``1000 * sum(u[0]**2)`` and ``1000 * sum(u[-1]**2)``, penalising
      non-zero values in the first and last rows of ``u``.

    Args:
        u: Network output; must have been computed from ``x`` so that it can
            be differentiated with respect to it.
        x: Input points with ``requires_grad`` enabled.
        g: Obstacle values, broadcastable against ``u``.
        gamma: Accepted for interface compatibility; not used by this loss.

    Returns:
        Zero-dimensional tensor that can be back-propagated.
    """
    # create_graph keeps both derivatives differentiable w.r.t. the weights.
    (du_dx,) = torch.autograd.grad(
        u, x, grad_outputs=torch.ones_like(u), create_graph=True, retain_graph=True
    )
    (d2u_dx2,) = torch.autograd.grad(
        du_dx, x, grad_outputs=torch.ones_like(du_dx), create_graph=True, retain_graph=True
    )

    # Gate the second derivative with the smooth indicator of u > g.
    above_obstacle = differentiable_heaviside(u - g)
    residual_term = (above_obstacle * d2u_dx2).square().sum()

    # Penalise the part of u that lies below the obstacle.
    penetration_term = torch.relu(g - u).square().sum()

    boundary_weight = 10**3
    first_row_term = boundary_weight * u[0].square().sum()
    last_row_term = boundary_weight * u[-1].square().sum()

    # Added left to right in the same order as the terms are listed above.
    return residual_term + penetration_term + first_row_term + last_row_term

def train(model, x, g, gamma, epochs=6000, lr=0.001):
    """Fit ``model`` with Adam by minimising ``loss_function`` on the points ``x``.

    Every 100 epochs the loss of that epoch is printed and the current network
    output is plotted against ``cal_u_exact(x)``.

    Args:
        model: Network to optimise. It is called as ``model(x)`` for training
            and as ``model.u_net(x)`` for plotting.
        x: Evaluation points (CPU tensor). A detached copy with gradients
            enabled is used internally, so the caller's tensor and its
            ``.grad`` are left untouched.
        g: Obstacle values at ``x``. Treated as a constant: if it is a tensor
            it is detached from any autograd graph it carries.
        gamma: Forwarded unchanged to ``loss_function``.
        epochs: Number of optimiser steps.
        lr: Adam learning rate.

    Returns:
        None. ``model`` is updated in place.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Work on a private leaf tensor. loss_function needs d/dx, so gradients
    # must be enabled, but nothing should accumulate in the caller's x.grad.
    x = x.detach().requires_grad_(True)
    # The obstacle does not depend on the model parameters. Detaching it means
    # no autograd graph is shared between epochs, so backward() does not need
    # retain_graph=True; the parameter gradients are unchanged.
    if torch.is_tensor(g):
        g = g.detach()

    # Loop-invariant plotting data, computed once.
    x_np = x.detach().numpy()
    with torch.no_grad():
        u_exact_np = cal_u_exact(x).detach().numpy()

    for epoch in range(epochs):
        optimizer.zero_grad()
        u = model(x)

        loss = loss_function(u, x, g, gamma)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        if epoch % 100 == 0:
            # The printed loss was computed before this epoch's optimiser
            # step; the plotted curve is evaluated after it.
            print(f"Epoch {epoch}/{epochs}, Loss: {loss.item()}")
            with torch.no_grad():  # plotting needs values only, no graph
                u_np = model.u_net(x).detach().numpy()

            fig, ax = plt.subplots(figsize=(6, 6))
            ax.plot(x_np, u_np, label='Approximate Solution')
            ax.plot(x_np, u_exact_np, label='Exact u(x)')
            ax.set_xlabel('x')
            ax.set_ylabel('u(x)')
            ax.grid(True)
            ax.legend()
            plt.show()
            # Release the figure so a long run does not pile up open figures.
            plt.close(fig)


# PINN() draws random initial weights; seeding makes Restart & Run All
# reproduce the same training run.
SEED = 0
torch.manual_seed(SEED)

x = torch.linspace(0, 1, 500).reshape(-1, 1)
x.requires_grad_(True)  # loss_function differentiates the network output w.r.t. x

# Obstacle values are fixed data: detach them so they do not carry an
# autograd graph through x into every training step.
gx = g(x).detach()
gamma = 10

model = PINN()

train(model, x, gx, gamma)
