In [None]:
import deepxde as dde
import numpy as np
import matplotlib.animation as ani
import matplotlib.pyplot as plt

In [None]:
# Problem parameters shared by the cells below.
ALPHA = 1        # coefficient multiplying u_xx in the PDE residual
BASE_TEMP = 0    # value imposed by the Dirichlet boundary condition
LENGTH = 10      # upper end of the spatial interval [0, LENGTH]
DURATION = 15    # upper end of the time domain [0, DURATION]

In [None]:
# Space-time domain: x in [0, LENGTH] paired with t in [0, DURATION].
space_domain = dde.geometry.Interval(0, LENGTH)
time_domain = dde.geometry.TimeDomain(0, DURATION)
domain = dde.geometry.GeometryXTime(space_domain, time_domain)

In [None]:
def diff_EQ(inputs, u):
    """Residual of the PDE ``u_t - ALPHA * u_xx``.

    Args:
        inputs: Collocation points; column 0 holds x and column 1 holds t.
        u: Quantity being differentiated with respect to ``inputs``.

    Returns:
        A one-element list containing ``u_t - ALPHA * u_xx``.
    """
    # j=1 selects the t column for the first derivative.
    time_derivative = dde.grad.jacobian(u, inputs, i=0, j=1)
    # i=0, j=0 selects the x column twice for the second derivative.
    second_space_derivative = dde.grad.hessian(u, inputs, i=0, j=0)

    residual = time_derivative - (ALPHA * second_space_derivative)
    return [residual]

In [None]:
def is_boundary(inp, on_boundary):
    """Point selector handed to the Dirichlet boundary condition.

    Returns ``on_boundary`` unchanged; ``inp`` is not inspected.
    """
    return on_boundary

def is_initial(inp, on_initial):
    """Point selector handed to the initial condition.

    Returns ``on_initial`` unchanged; ``inp`` is not inspected.
    """
    return on_initial

# Dirichlet condition: constant BASE_TEMP on every point is_boundary selects.
bc = dde.icbc.DirichletBC(domain, lambda _xt: BASE_TEMP, is_boundary)

# Initial condition: 6 * sin(pi * x / LENGTH), with x taken from column 0.
ic = dde.icbc.IC(
    domain,
    lambda xt: 6 * np.sin(np.pi * xt[:, 0:1] / LENGTH),
    is_initial,
)

conditions = [bc, ic]

In [None]:
# Point budgets handed to dde.data.TimePDE.
train_sample = 1_000   # passed as num_domain
bound_sample = 100     # passed as num_boundary
initial_sample = 100   # passed as num_initial
test_sample = 300      # passed as num_test

def sol_function(inputs):
    """Closed-form reference solution used as ``solution=`` for the data set.

    Evaluates ``6 * sin(pi * x / LENGTH) * exp(-ALPHA * (pi / LENGTH)**2 * t)``
    where x is column 0 and t is column 1 of ``inputs``.

    Args:
        inputs: 2-D array whose first two columns are x and t.

    Returns:
        Array with one column holding the solution at each row of ``inputs``.
    """
    position = inputs[:, 0:1]
    time = inputs[:, 1:2]

    spatial_profile = np.sin((np.pi * position) / LENGTH)
    decay_rate = -ALPHA * np.square(np.pi / LENGTH)
    temporal_decay = np.exp(decay_rate * time)

    return 6 * spatial_profile * temporal_decay

# Bundle geometry, PDE residual, constraints and the reference solution.
data = dde.data.TimePDE(
    domain,
    diff_EQ,
    conditions,
    num_domain=train_sample,
    num_boundary=bound_sample,
    num_initial=initial_sample,
    solution=sol_function,
    num_test=test_sample,
)

In [None]:
# Fully connected network: 2 inputs, three hidden layers of 16 units, 1 output.
layer_size = [2, 16, 16, 16, 1]
activation = "tanh"
initializer = "Glorot uniform"

# NOTE(review): a dropout rate of 0.5 is a strong setting for 16-unit layers —
# confirm it is intended and that the active DeepXDE backend's FNN accepts
# the dropout_rate keyword.
dropout = 0.5

network = dde.nn.FNN(
    layer_size,
    activation,
    initializer,
    dropout_rate=dropout,
)

In [None]:
model = dde.Model(data, network)

# Stage 1: Adam for 50000 iterations, reporting every 500.
model.compile("adam", lr=1e-3)
model.train(iterations=50000, display_every=500)

# Stage 2: L-BFGS-B; the history and state returned here feed the plots below.
model.compile("L-BFGS-B")
losshistory, train_state = model.train()

# Save and show the loss curves and the final train state.
dde.saveplot(losshistory, train_state, issave=True, isplot=True)

In [None]:
plt.ion()

# Pair every test point (x, t) with the prediction stored for it in train_state.
xty = [
    [point[0], point[1], prediction[0]]
    for point, prediction in zip(train_state.X_test, train_state.best_y)
]

# Order by time, then by position within each time value, so every frame's
# samples run along the bar. One stable sort on (t, x) gives the same order
# as sorting by x and then by t.
xty.sort(key=lambda row: (row[1], row[0]))

x = [row[0] for row in xty]
t = [row[1] for row in xty]
y = [row[2] for row in xty]

fig, ax = plt.subplots()

# Start with an empty line; update() fills in the data of each frame. This
# also keeps the cell from failing when there are no test points.
line = ax.plot([], [], c="b")[0]
ax.set(xlim=[0, LENGTH], ylim=[0, 6], xlabel="bar", ylabel="temp")

# One frame per distinct time value, in first-seen (ascending) order.
# dict.fromkeys de-duplicates in a single pass while preserving order.
frames = list(dict.fromkeys(t))





def update(frame):
    """Redraw the line with the samples belonging to one time value.

    Reads the module-level lists ``x``, ``t`` and ``y`` (parallel, one entry
    per sample) and the ``line`` artist.

    Args:
        frame: Time value whose samples should be displayed.

    Returns:
        A one-element tuple holding the modified ``line`` artist.
    """
    # A single pass over the paired samples keeps positions and temperatures
    # aligned and avoids relying on x, t and y having matching lengths.
    profile = [(xi, yi) for xi, ti, yi in zip(x, t, y) if ti == frame]
    loc = [xi for xi, _ in profile]
    temp = [yi for _, yi in profile]

    line.set_xdata(loc)
    line.set_ydata(temp)

    # FuncAnimation requires an iterable of artists when blitting is enabled;
    # returning a tuple keeps this callback usable either way.
    return (line,)


# Keep the animation bound to a name so it stays alive; update() is called
# once for each entry of frames.
animation = ani.FuncAnimation(fig, update, frames=frames)
# If nothing is displayed in the current environment, plt.show() can be called here.
