In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

# Set random seeds for reproducibility
torch.manual_seed(1234)
np.random.seed(1234)

In [3]:
# --- 1. Define the Neural Network Architecture ---
class PINN(nn.Module):
    def __init__(self, input_dim=2, output_dim=1, num_hidden_layers=4, num_neurons_per_layer=50):
        super(PINN, self).__init__()
        
        layers = [nn.Linear(input_dim, num_neurons_per_layer), nn.Tanh()]
        for _ in range(num_hidden_layers - 1):
            layers.append(nn.Linear(num_neurons_per_layer, num_neurons_per_layer))
            layers.append(nn.Tanh())
        layers.append(nn.Linear(num_neurons_per_layer, output_dim))
        
        self.net = nn.Sequential(*layers)
        
    def forward(self, x):
        return self.net(x)

In [15]:
# --- 2. Define the PINN Solver ---
class PINN_WaveEquation_Implementation:
    def __init__(self, c=1.0, num_hidden_layers=4, num_neurons_per_layer=50, learning_rate=0.001):
        self.c = c
        self.model = PINN(num_hidden_layers=num_hidden_layers, num_neurons_per_layer=num_neurons_per_layer)
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.loss_fn = nn.MSELoss() # Mean Squared Error Loss

        # Device configuration
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        print(f"Using device: {self.device}")
    
        # Helper method to compute derivatives using torch.autograd.grad
    def compute_derivatives(self, u, x, t):
        # First derivatives
        # We need create_graph=True for the first derivative to compute second derivative later
        u_x = torch.autograd.grad(u, x, torch.ones_like(u), create_graph=True)[0]
        u_t = torch.autograd.grad(u, t, torch.ones_like(u), create_graph=True)[0]

        # Second derivatives
        u_xx = torch.autograd.grad(u_x, x, torch.ones_like(u_x))[0]
        u_tt = torch.autograd.grad(u_t, t, torch.ones_like(u_t))[0]
        
        return u_x, u_t, u_xx, u_tt

    def compute_loss(self, X_collocation, X_initial, X_boundary):
        # Ensure inputs require gradients
        x_c = X_collocation[:, 0:1].clone().detach().requires_grad_(True)
        t_c = X_collocation[:, 1:2].clone().detach().requires_grad_(True)
        
        # Predict u at collocation points
        u_pred_c = self.model(torch.cat([x_c, t_c], dim=1))

        # Compute derivatives for PDE loss
        u_x, u_t, u_xx, u_tt = self.compute_derivatives(u_pred_c, x_c, t_c)
        
        # PDE residual (f_u = u_tt - c^2 * u_xx = 0)
        pde_residual = u_tt - self.c**2 * u_xx
        loss_pde = self.loss_fn(pde_residual, torch.zeros_like(pde_residual))

        # --- Initial Conditions Loss ---
        x_ic = X_initial[:, 0:1].clone().detach().requires_grad_(True)
        t_ic = X_initial[:, 1:2].clone().detach().requires_grad_(True) # Should be all zeros
        
        u_pred_ic = self.model(torch.cat([x_ic, t_ic], dim=1))
        # Compute u_t at IC points
        u_t_pred_ic = torch.autograd.grad(u_pred_ic, t_ic, torch.ones_like(u_pred_ic), create_graph=True)[0]
        
        # u(x, 0) = sin(pi * x)
        u_initial_true = torch.sin(torch.pi * x_ic)
        loss_ic_u = self.loss_fn(u_pred_ic, u_initial_true)

        # du/dt(x, 0) = 0
        loss_ic_ut = self.loss_fn(u_t_pred_ic, torch.zeros_like(u_t_pred_ic))

        # --- Boundary Conditions Loss ---
        x_bc = X_boundary[:, 0:1].clone().detach().requires_grad_(True)
        t_bc = X_boundary[:, 1:2].clone().detach().requires_grad_(True)
        
        u_pred_bc = self.model(torch.cat([x_bc, t_bc], dim=1))

        # u(0, t) = 0 and u(1, t) = 0
        loss_bc = self.loss_fn(u_pred_bc, torch.zeros_like(u_pred_bc))

        # Total Loss (can add weights)
        total_loss = loss_pde + loss_ic_u + loss_ic_ut + loss_bc
        return total_loss, loss_pde, loss_ic_u, loss_ic_ut, loss_bc

    def train(self, num_epochs, num_collocation_points, num_initial_points, num_boundary_points):
        # Generate training points on CPU, then move to device
        # Collocation points
        X_collocation_np = np.random.rand(num_collocation_points, 2)
        X_collocation_np[:, 0] = X_collocation_np[:, 0] * 1.0 # x in [0, 1]
        X_collocation_np[:, 1] = X_collocation_np[:, 1] * 1.0 # t in [0, 1]
        X_collocation = torch.from_numpy(X_collocation_np).float().to(self.device)

        # Initial condition points (t=0)
        X_initial_np = np.random.rand(num_initial_points, 2)
        X_initial_np[:, 0] = X_initial_np[:, 0] * 1.0 # x in [0, 1]
        X_initial_np[:, 1] = 0.0                     # t = 0
        X_initial = torch.from_numpy(X_initial_np).float().to(self.device)

        # Boundary condition points (x=0 or x=1)
        X_boundary_left_np = np.random.rand(num_boundary_points // 2, 2)
        X_boundary_left_np[:, 0] = 0.0               # x = 0
        X_boundary_left_np[:, 1] = X_boundary_left_np[:, 1] * 1.0 # t in [0, 1]

        X_boundary_right_np = np.random.rand(num_boundary_points // 2, 2)
        X_boundary_right_np[:, 0] = 1.0              # x = 1
        X_boundary_right_np[:, 1] = X_boundary_right_np[:, 1] * 1.0 # t in [0, 1]
        X_boundary = torch.from_numpy(np.vstack((X_boundary_left_np, X_boundary_right_np))).float().to(self.device)
        
        history = {'total_loss': [], 'pde_loss': [], 'ic_u_loss': [], 'ic_ut_loss': [], 'bc_loss': []}

        for epoch in range(num_epochs):
            self.optimizer.zero_grad()
            total_loss, loss_pde, loss_ic_u, loss_ic_ut, loss_bc = self.compute_loss(
                X_collocation, X_initial, X_boundary
            )
            
            total_loss.backward() # Backpropagation
            self.optimizer.step() # Update model parameters

            history['total_loss'].append(total_loss.item())
            history['pde_loss'].append(loss_pde.item())
            history['ic_u_loss'].append(loss_ic_u.item())
            history['ic_ut_loss'].append(loss_ic_ut.item())
            history['bc_loss'].append(loss_bc.item())

            if epoch % 1000 == 0:
                print(f"Epoch {epoch}, Total Loss: {total_loss.item():.4e}, "
                        f"PDE Loss: {loss_pde.item():.4e}, "
                        f"IC_u Loss: {loss_ic_u.item():.4e}, "
                        f"IC_ut Loss: {loss_ic_ut.item():.4e}, "
                        f"BC Loss: {loss_bc.item():.4e}")
        return history

In [19]:
# --- 3. Run the PINN ---
if __name__ == "__main__":
    # Parameters
    wave_speed = 1.0
    num_epochs = 10000 # Increase for better accuracy, but takes longer
    num_collocation_points = 10000
    num_initial_points = 500
    num_boundary_points = 500

    # Initialize and train the PINN
    pinn = PINN_WaveEquation_Implementation(c=wave_speed)
    print("Starting PINN training...")
    history = pinn.train(num_epochs, num_collocation_points, num_initial_points, num_boundary_points)
    print("Training finished.")

Using device: cpu
Starting PINN training...


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [18]:
# --- 4. Visualize Results ---

# Plot loss history
plt.figure(figsize=(12, 6))
plt.plot(history['total_loss'], label='Total Loss')
plt.plot(history['pde_loss'], label='PDE Loss')
plt.plot(history['ic_u_loss'], label='IC (u) Loss')
plt.plot(history['ic_ut_loss'], label='IC (du/dt) Loss')
plt.plot(history['bc_loss'], label='BC Loss')
plt.yscale('log')
plt.xlabel('Epoch')
plt.ylabel('Loss (log scale)')
plt.title('PINN Training Loss History (PyTorch)')
plt.legend()
plt.grid(True)
plt.show()

# Generate a grid for prediction
x_grid = np.linspace(0, 1, 100)
t_grid = np.linspace(0, 1, 100)
X_plot, T_plot = np.meshgrid(x_grid, t_grid)
XT_flat_np = np.vstack([X_plot.ravel(), T_plot.ravel()]).T
XT_tensor = torch.from_numpy(XT_flat_np).float().to(pinn.device)

# Predict u(x,t) using the trained PINN
pinn.model.eval() # Set model to evaluation mode
with torch.no_grad(): # Disable gradient calculations for inference
    u_pred_flat = pinn.model(XT_tensor).cpu().numpy()
u_pred = u_pred_flat.reshape(X_plot.shape)

# Analytical solution for comparison
u_analytical = np.sin(np.pi * X_plot) * np.cos(np.pi * T_plot)

# Plotting 3D surface
fig = plt.figure(figsize=(15, 7))

ax1 = fig.add_subplot(121, projection='3d')
ax1.plot_surface(X_plot, T_plot, u_pred, cmap=cm.viridis)
ax1.set_xlabel('x')
ax1.set_ylabel('t')
ax1.set_zlabel('u(x,t) (PINN)')
ax1.set_title('PINN Solution of Wave Equation (PyTorch)')

ax2 = fig.add_subplot(122, projection='3d')
ax2.plot_surface(X_plot, T_plot, u_analytical, cmap=cm.viridis)
ax2.set_xlabel('x')
ax2.set_ylabel('t')
ax2.set_zlabel('u(x,t) (Analytical)')
ax2.set_title('Analytical Solution of Wave Equation')

plt.tight_layout()
plt.show()

# Plotting error
error = np.abs(u_pred - u_analytical)
plt.figure(figsize=(8, 6))
plt.imshow(error.T, extent=[0, 1, 1, 0], origin='lower', cmap='hot', aspect='auto')
plt.colorbar(label='Absolute Error')
plt.xlabel('x')
plt.ylabel('t')
plt.title(f'Absolute Error |u_pred - u_analytical| (Max Error: {np.max(error):.4e})')
plt.show()

# Plotting snapshots in time
plt.figure(figsize=(10, 8))
times_to_plot = [0.0, 0.25, 0.5, 0.75, 1.0]
for i, t_val in enumerate(times_to_plot):
    idx_t = np.argmin(np.abs(t_grid - t_val))
    plt.subplot(len(times_to_plot), 1, i + 1)
    plt.plot(x_grid, u_pred[idx_t, :], label=f'PINN at t={t_val:.2f}')
    plt.plot(x_grid, u_analytical[idx_t, :], '--', label=f'Analytical at t={t_val:.2f}')
    plt.title(f'u(x, {t_val:.2f})')
    plt.legend()
    plt.ylim([-1.1, 1.1])
    if i == len(times_to_plot) - 1:
        plt.xlabel('x')
    else:
        plt.xticks([]) # Hide x-ticks for intermediate plots
    plt.grid(True)
plt.tight_layout()
plt.suptitle('Wave Snapshots Over Time', y=1.02)
plt.show()

NameError: name 'history' is not defined

<Figure size 1200x600 with 0 Axes>