import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

# Neural Network
class PINN(nn.Module):
    """Fully connected tanh network used as the PINN solution ansatz u(x, y).

    The default arguments build exactly the original architecture::

        Linear(2, 50) -> Tanh -> Linear(50, 50) -> Tanh -> Linear(50, 1)

    with the same ``net.<idx>`` parameter names (so existing state dicts still
    load) and the same layer-creation order (so seeded initialisation matches).

    Args:
        in_dim: Number of input coordinates (2 for (x, y)).
        hidden_dim: Width of every hidden layer.
        n_hidden: Number of hidden ``Linear + Tanh`` layers (>= 0).
        out_dim: Number of network outputs.

    Raises:
        ValueError: If ``n_hidden`` is negative.
    """

    def __init__(
        self,
        in_dim: int = 2,
        hidden_dim: int = 50,
        n_hidden: int = 2,
        out_dim: int = 1,
    ) -> None:
        super().__init__()
        if n_hidden < 0:
            raise ValueError(f"n_hidden must be >= 0, got {n_hidden}")
        layers = []
        width_in = in_dim
        for _ in range(n_hidden):
            # Tanh (smooth) rather than ReLU: the loss takes second derivatives
            # of the output w.r.t. the input.
            layers += [nn.Linear(width_in, hidden_dim), nn.Tanh()]
            width_in = hidden_dim
        layers.append(nn.Linear(width_in, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the network on a batch of points ``x`` of shape ``(N, in_dim)``."""
        return self.net(x)

# Assume unit disk domain, rotation-invariant BC: u=0 on boundary
def loss_fn(model, X_int, X_bc, u_bc, rhs: float = -1.0, bc_weight: float = 1.0) -> torch.Tensor:
    """Return the composite PINN loss for the Poisson problem ``∇²u = rhs``.

    loss = mean(residual**2) + bc_weight * mean((model(X_bc) - u_bc)**2),
    where residual = ∇²u(X_int) - rhs. The defaults (``rhs=-1``,
    ``bc_weight=1``) reproduce the original ``∇²u = -1`` formulation with an
    unweighted boundary term.

    Args:
        model: Callable mapping an ``(N, d)`` tensor of points to predictions
            (``PINN`` returns shape ``(N, 1)``).
        X_int: Interior collocation points, shape ``(N_int, d)``. Its
            ``requires_grad`` flag is switched on in place.
        X_bc: Boundary points, shape ``(N_bc, d)``.
        u_bc: Target boundary values. A 1-D ``(N_bc,)`` tensor is treated as
            ``(N_bc, 1)`` so it cannot silently broadcast against the
            ``(N_bc, 1)`` prediction into an ``(N_bc, N_bc)`` error matrix.
        rhs: Right-hand side (source term) of the Poisson equation.
        bc_weight: Weight of the boundary term relative to the physics term.

    Returns:
        Scalar loss tensor, differentiable w.r.t. the model parameters.
    """
    # Physics loss: Poisson residual ∇²u - rhs at the interior points.
    X_int.requires_grad_(True)
    u = model(X_int)
    # One first-derivative pass yields the whole gradient (u_x, u_y, ...);
    # slicing it avoids re-running autograd once per coordinate.
    grad_u = torch.autograd.grad(u, X_int, torch.ones_like(u), create_graph=True)[0]
    laplacian = 0
    for i in range(X_int.shape[1]):
        u_i = grad_u[:, i]
        # create_graph=True keeps the loss differentiable w.r.t. the parameters.
        u_ii = torch.autograd.grad(u_i, X_int, torch.ones_like(u_i), create_graph=True)[0][:, i]
        laplacian = laplacian + u_ii
    residual = laplacian - rhs
    phys_loss = torch.mean(residual ** 2)

    # BC loss
    u_pred_bc = model(X_bc)
    if torch.is_tensor(u_bc) and u_bc.dim() == u_pred_bc.dim() - 1:
        u_bc = u_bc.unsqueeze(-1)
    bc_loss = torch.mean((u_pred_bc - u_bc) ** 2)

    return phys_loss + bc_weight * bc_loss

# Example training setup: a default-architecture network optimised with Adam.
model = PINN()
learning_rate = 1e-3
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Sample collocation points on the unit disk.
N_int = 1000
# Uniform-by-area sampling needs r = sqrt(U), U ~ U(0, 1). Drawing r ~ U(0, 1)
# directly gives a point density proportional to 1/r, over-weighting the
# interior residual near the origin.
r = torch.sqrt(torch.rand(N_int))
theta = torch.rand(N_int) * 2 * np.pi
X_int = torch.stack([r * torch.cos(theta), r * torch.sin(theta)], dim=1)

N_bc = 200
# N_bc equally spaced angles on [0, 2π). linspace(0, 2π, N_bc) would include
# both endpoints, duplicating the point (1, 0) and double-weighting it in the
# boundary loss.
theta_bc = torch.linspace(0, 2 * np.pi, N_bc + 1)[:-1]
X_bc = torch.stack([torch.cos(theta_bc), torch.sin(theta_bc)], dim=1)
u_bc = torch.zeros(N_bc, 1)  # Homogeneous Dirichlet BC: u = 0 on the boundary

# Training loop: full-batch Adam on the fixed collocation set.
def _closure():
    """Clear old gradients, recompute the loss, backpropagate, return the loss."""
    optimizer.zero_grad()
    step_loss = loss_fn(model, X_int, X_bc, u_bc)
    step_loss.backward()
    return step_loss


for epoch in range(10000):
    # Optimizer.step(closure) evaluates the closure once (for Adam), applies the
    # parameter update, and returns the closure's loss.
    loss = optimizer.step(_closure)

# Evaluate the trained model on a regular grid over [-1, 1]^2 and draw filled
# contours inside the unit disk; grid points outside the disk stay NaN.
n_grid = 100
x = np.linspace(-1, 1, n_grid)
y = np.linspace(-1, 1, n_grid)
X, Y = np.meshgrid(x, y)
mask = X**2 + Y**2 <= 1
points = np.column_stack((X.ravel(), Y.ravel()))
with torch.no_grad():
    grid_input = torch.from_numpy(points).float()
    Z_flat = model(grid_input).numpy().ravel()
Z = np.full(X.shape, np.nan)
Z[mask] = Z_flat.reshape(X.shape)[mask]
plt.contourf(X, Y, Z, levels=50, cmap=cm.jet)
plt.colorbar()
plt.show()