# Differential Equations for Neural Networks

This notebook contains PyTorch examples demonstrating differential-equation concepts used in neural networks.

## Table of Contents
1. [Ordinary Differential Equations (ODEs)](#ordinary-differential-equations-odes)
2. [Partial Differential Equations (PDEs)](#partial-differential-equations-pdes)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

## Ordinary Differential Equations (ODEs)

**Formula:** $\frac{dy}{dt} = f(t, y)$

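ResNets and Neural ODEs are closely connected. A residual block $x_{n+1} = x_n + f(x_n)$ is exactly one explicit Euler step (with step size $\Delta t = 1$) of the ODE $\frac{dx}{dt} = f(x)$. Neural ODEs take the continuous-time limit of this update and integrate it with an ODE solver.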

In [None]:
# Simple Neural ODE implementation
class ODEFunc(torch.nn.Module):
    """Learned vector field for a Neural ODE dy/dt = f(t, y).

    The field is autonomous: ``t`` is accepted so the module matches the
    ``func(t, y)`` calling convention used by ``ode_solve``, but only ``y`` is
    passed through a Linear -> Tanh -> Linear MLP (hidden width 50).
    """

    def __init__(self, dim):
        super().__init__()
        hidden = 50
        layers = [
            torch.nn.Linear(dim, hidden),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden, dim),
        ]
        self.net = torch.nn.Sequential(*layers)

    def forward(self, t, y):
        # Time is deliberately unused (autonomous dynamics).
        velocity = self.net(y)
        return velocity

# Simple ODE solver (Euler method)
def ode_solve(func, y0, t_span, dt=0.1):
    """Integrate dy/dt = func(t, y) over ``t_span`` with the explicit Euler method.

    Args:
        func: Callable ``func(t, y)`` returning dy/dt; ``t`` is passed as a
            0-dim tensor. Its output is added to ``y`` scaled by the step size.
        y0: Initial state tensor.
        t_span: ``(t0, t1)`` start and end times with ``t1 >= t0``.
        dt: Maximum step size (must be positive). The span is divided into the
            fewest equal steps no larger than ``dt``, so the last state is
            always at exactly ``t1``.

    Returns:
        Tensor stacking the states at every grid time, starting with ``y0``:
        shape ``(n_steps + 1, *y0.shape)`` when ``func`` preserves shape.

    Raises:
        ValueError: if ``dt <= 0`` or ``t1 < t0``.
    """
    import math

    t0, t1 = float(t_span[0]), float(t_span[1])
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")
    if t1 < t0:
        raise ValueError(f"t_span must satisfy t1 >= t0, got {t_span}")

    # torch.arange with a float step can overshoot the end point (e.g. it
    # integrates (0, 1) with dt=0.1 up to t=1.1). Count steps explicitly; the
    # tolerance absorbs float error such as 0.3 / 0.1 == 2.9999999999999996.
    n_steps = max(0, math.ceil((t1 - t0) / dt - 1e-9))
    t_eval = torch.linspace(t0, t1, n_steps + 1)
    step = (t1 - t0) / n_steps if n_steps else 0.0

    y = y0.clone()
    trajectory = [y.clone()]
    for t in t_eval[:-1]:
        y = y + step * func(t, y)
        trajectory.append(y.clone())

    return torch.stack(trajectory)

# ResNet vs Neural ODE comparison
class ResBlock(torch.nn.Module):
    """Residual block computing ``x + g(x)``.

    ``g`` is a Linear -> ReLU -> Linear MLP with hidden width 50 that maps
    back to ``dim`` features, so the skip connection can add it to ``x``.
    """

    def __init__(self, dim):
        super().__init__()
        hidden = 50
        self.net = torch.nn.Sequential(
            *[torch.nn.Linear(dim, hidden), torch.nn.ReLU(), torch.nn.Linear(hidden, dim)]
        )

    def forward(self, x):
        update = self.net(x)
        return x + update  # Residual connection

# Demonstrate connection: one residual block vs. one explicit Euler step.
dim = 10
x0 = torch.randn(1, dim)

# Modules are built in the same order as before, so RNG consumption is unchanged.
resblock = ResBlock(dim)
ode_func = ODEFunc(dim)

# Inference only: no need to build autograd graphs.
with torch.no_grad():
    # ResNet step: x_{n+1} = x_n + f(x_n)  (an Euler step with dt = 1)
    x_resnet = resblock(x0)

    # Neural ODE step: dx/dt = f(x), x(t+dt) ≈ x(t) + dt*f(x(t))
    x_ode = x0 + 0.1 * ode_func(0, x0)  # Small time step

# Both quantities are the size of the update applied to x0, not of the output itself.
print(f"ResNet update norm ||x1 - x0||: {torch.norm(x_resnet - x0):.3f}")
print(f"Neural ODE update norm ||x(dt) - x(0)||: {torch.norm(x_ode - x0):.3f}")
print("Both represent discrete/continuous versions of the same idea")

## Partial Differential Equations (PDEs)

**Formula:** $\frac{\partial u}{\partial t} = f\left(u, \frac{\partial u}{\partial x}, \frac{\partial^2 u}{\partial x^2}, \ldots\right)$

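Physics-Informed Neural Networks (PINNs) approximate the solution $u(x, t)$ with a neural network. The network is trained so that the PDE residual, computed with automatic differentiation, vanishes at sampled points while the boundary conditions are also satisfied.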

In [None]:
# Physics-Informed Neural Network for heat equation: ∂u/∂t = α ∂²u/∂x²
class PINN(torch.nn.Module):
    """MLP surrogate u(x, t) for a 1-D space-time PDE.

    The network takes the concatenation ``[x, t]`` (2 features), passes it
    through two Tanh-activated hidden layers of width 50, and outputs a single
    value u(x, t) per row.
    """

    def __init__(self):
        super().__init__()
        width = 50
        self.net = torch.nn.Sequential(
            torch.nn.Linear(2, width),      # Input: (x, t)
            torch.nn.Tanh(),
            torch.nn.Linear(width, width),
            torch.nn.Tanh(),
            torch.nn.Linear(width, 1),      # Output: u(x, t)
        )

    def forward(self, x, t):
        # x and t are joined column-wise, so each is expected to be (N, 1).
        return self.net(torch.cat((x, t), dim=1))

def _grad(out, inp):
    """Return d(sum(out))/d(inp) with a graph kept for higher derivatives.

    Returns zeros when ``out`` does not depend on ``inp`` (e.g. the derivative
    of a constant), instead of letting autograd raise.
    """
    if not out.requires_grad:
        return torch.zeros_like(inp)
    (g,) = torch.autograd.grad(out.sum(), inp, create_graph=True, allow_unused=True)
    return torch.zeros_like(inp) if g is None else g


def compute_pde_loss(model, x, t, alpha=0.1):
    """Compute physics loss for heat equation.

    Evaluates the mean squared residual of u_t - alpha * u_xx at the given
    collocation points, where u = model(x, t).

    Args:
        model: Callable ``model(x, t)`` returning u. It must act pointwise
            (row i of the output depends only on row i of the inputs), because
            derivatives are taken of ``u.sum()``.
        x: Spatial coordinates; ``PINN.forward`` expects ``(N, 1)`` columns.
        t: Time coordinates, same shape as ``x``.
        alpha: Diffusion coefficient.

    Returns:
        0-dim tensor with the mean of the squared PDE residual, differentiable
        with respect to the model's parameters.
    """
    # Differentiate w.r.t. detached leaf copies instead of flipping
    # requires_grad on the caller's tensors; otherwise every backward() pass
    # would also accumulate gradients into the caller's x/t.
    x = x.detach().requires_grad_(True)
    t = t.detach().requires_grad_(True)

    u = model(x, t)

    # Compute derivatives (create_graph keeps them differentiable for training)
    u_t = _grad(u, t)
    u_x = _grad(u, x)
    u_xx = _grad(u_x, x)

    # PDE residual: ∂u/∂t - α ∂²u/∂x²
    pde_residual = u_t - alpha * u_xx

    return torch.mean(pde_residual**2)

# Training example: u_t = alpha * u_xx on (x, t) in [0, 1] x [0, 1] with
#   boundary conditions u(0, t) = u(1, t) = 0 and initial condition
#   u(x, 0) = sin(pi x). Without an initial condition, u == 0 satisfies the
#   PDE and the boundary conditions, so the network would just learn zero.
pinn = PINN()
optimizer = torch.optim.Adam(pinn.parameters(), lr=0.001)

# Collocation points: a full grid over the domain. Pairing two separate
# linspaces element-wise would only sample the diagonal x = t.
n_x, n_t = 50, 50
x_grid = torch.linspace(0, 1, n_x)
t_grid = torch.linspace(0, 1, n_t)
x_train = x_grid.repeat_interleave(n_t).unsqueeze(1)  # x0,x0,...,x1,x1,...
t_train = t_grid.repeat(n_x).unsqueeze(1)             # t0,t1,...,t0,t1,...

# Boundary points: x = 0 and x = 1 at every sampled time.
t_col = t_grid.unsqueeze(1)
x_bc = torch.cat([torch.zeros_like(t_col), torch.ones_like(t_col)])
t_bc = torch.cat([t_col, t_col])

# Initial-condition points: t = 0 at every sampled x.
x_ic = x_grid.unsqueeze(1)
t_ic = torch.zeros_like(x_ic)
u_ic_target = torch.sin(np.pi * x_ic)

# Training loop (simplified; 100 epochs is a demo, not convergence)
for epoch in range(100):
    optimizer.zero_grad()

    # Physics loss
    physics_loss = compute_pde_loss(pinn, x_train, t_train)

    # Boundary conditions (u(0,t) = u(1,t) = 0 for all sampled t)
    u_boundary = pinn(x_bc, t_bc)
    boundary_loss = torch.mean(u_boundary**2)

    # Initial condition (u(x,0) = sin(pi x))
    initial_loss = torch.mean((pinn(x_ic, t_ic) - u_ic_target) ** 2)

    # Total loss
    total_loss = physics_loss + boundary_loss + initial_loss
    total_loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        print(f"Epoch {epoch}: Loss = {total_loss.item():.6f}")

print("PINN trained to solve heat equation while respecting physics")