$$
\begin{align*}
\text{1D Heat Diffusion: }
\frac{du(x,t)}{dt} &= \alpha \frac{d^2u(x,t)}{dx^2} + f(x,t)
\end{align*}
$$


In [1]:
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import GPy
from GPy.kern import RBF
from jax import grad

In [2]:
np.random.seed(123)

In [3]:
# Define parameters
L = 1.0 
N = 50  
alpha = 0.01  
T_final = 10.0  
dt = 0.1

# # Generate mesh grid
x = np.linspace(0, L, N) 


In [4]:
# Generate source terms 
kernel = GPy.kern.RBF(input_dim=1, variance=1.0, lengthscale=1.0) #variance and lengthscale = 1, could implement parameter tuning 
gp = GPy.models.GPRegression(np.zeros((1,1)), np.zeros((1,1)), kernel)

f_source = gp.posterior_samples_f(x.reshape(-1, 1), full_cov=True, size=1) # generate only one source function 
f_source = f_source[:, 0, 0]

In [5]:
# Define initial condition 40th - 60th percentile = 1 
def initial_condition(N):
    """
    Define the initial temperature distribution along the rod.
    """
    u_init = jnp.zeros(N)
    start_index = int(0.4 * N)
    end_index = int(0.6 * N)
    u_init = u_init.at[start_index:end_index].set(1.0) 
    return u_init

u_init = initial_condition(N)

In [6]:
def heat_equation(u, f):

    d2u_dx2 = jnp.zeros_like(u)
    for i in range(len(u)):
        d2u_dx2 = d2u_dx2.at[i].set(grad(grad(lambda x: x))(u[i]))
    
        du_dt = alpha * d2u_dx2 + f

    return du_dt

In [7]:
# Perform time integration to solve the 1D heat diffusion equation using forward Euler's method 

def integrate(u):
    t = 0.0
    while t < T_final:
        # Compute the temperature gradient
        du_dt = heat_equation(u, f_source)
        # Update the temperature using forward Euler's method
        u += dt * du_dt 
        # Increment time
        t += dt
    return u


In [8]:
u_final = integrate(u_init)
u_final

Array([ 1.8455838,  2.0621839,  2.2865326,  2.518291 ,  2.7570882,
        3.0025237,  3.2541919,  3.5116503,  3.7744207,  4.0420423,
        4.3139997,  4.589776 ,  4.868838 ,  5.1506267,  5.4345856,
        5.720142 ,  6.006725 ,  6.2937055,  6.5805025,  6.866494 ,
        8.151109 ,  8.433676 ,  8.713594 ,  8.990275 ,  9.263065 ,
        9.531361 ,  9.794561 , 10.052056 , 10.303208 , 10.547472 ,
        9.784231 , 10.012904 , 10.232933 , 10.443748 , 10.64483  ,
       10.835609 , 11.01557  , 11.184221 , 11.341041 , 11.485586 ,
       11.617358 , 11.735931 , 11.840857 , 11.931769 , 12.008191 ,
       12.069829 , 12.116299 , 12.14724  , 12.16234  , 12.161308 ],      dtype=float32)

In [9]:
f_source

array([0.18273087, 0.20417639, 0.22638939, 0.24933574, 0.27297891,
       0.29727996, 0.32219759, 0.34768796, 0.37370544, 0.40020216,
       0.42712864, 0.45443309, 0.48206271, 0.50996303, 0.53807828,
       0.56635137, 0.59472465, 0.62313887, 0.65153467, 0.67985173,
       0.70802958, 0.73600698, 0.76372274, 0.79111601, 0.81812536,
       0.84468965, 0.8707487 , 0.89624217, 0.92111028, 0.94529395,
       0.96873528, 0.99137681, 1.01316183, 1.03403553, 1.05394313,
       1.07283211, 1.09065047, 1.1073478 , 1.12287536, 1.13718592,
       1.15023307, 1.16197285, 1.17236282, 1.18136191, 1.18893102,
       1.19503308, 1.19963252, 1.20269574, 1.20419124, 1.20408939])

In [10]:
training_data = np.concatenate((u_final[:, np.newaxis], f_source[:, np.newaxis]), axis=1)
training_data

array([[ 1.8455838 ,  0.18273087],
       [ 2.06218386,  0.20417639],
       [ 2.28653264,  0.22638939],
       [ 2.518291  ,  0.24933574],
       [ 2.75708818,  0.27297891],
       [ 3.00252366,  0.29727996],
       [ 3.25419188,  0.32219759],
       [ 3.51165032,  0.34768796],
       [ 3.77442074,  0.37370544],
       [ 4.04204226,  0.40020216],
       [ 4.31399965,  0.42712864],
       [ 4.58977604,  0.45443309],
       [ 4.86883783,  0.48206271],
       [ 5.15062666,  0.50996303],
       [ 5.43458557,  0.53807828],
       [ 5.72014189,  0.56635137],
       [ 6.00672483,  0.59472465],
       [ 6.29370546,  0.62313887],
       [ 6.58050251,  0.65153467],
       [ 6.86649418,  0.67985173],
       [ 8.15110874,  0.70802958],
       [ 8.43367577,  0.73600698],
       [ 8.71359444,  0.76372274],
       [ 8.99027538,  0.79111601],
       [ 9.26306534,  0.81812536],
       [ 9.53136063,  0.84468965],
       [ 9.79456139,  0.8707487 ],
       [10.05205631,  0.89624217],
       [10.30320835,