# Paper

## Initialization

In [None]:
# !pip install --upgrade "jax[cpu]"
# !pip install jraph matplotlib numpy optax equinox
# !pip install git+https://github.com/ravinderbhattoo/shadow
# !pip install pint
%matplotlib inline
import os
os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=5"

from multiprocessing import Process
import sys; sys.path.append("..")
import src
from src import tt, pt, utils, dyn, data_gen, models, train, units
from src import notebook_cells as nc
from src.examples import E1 as eg

def reset_config():
    """Reapply the notebook's default figure settings (100 dpi, 6 x 6 inches)."""
    pt.plt.rcParams.update({
        "figure.dpi": 100,
        "figure.figsize": (6, 6),
    })

plt = pt.plt

import importlib
def reload():
    """Reload every imported project module in place, then restore the plot defaults."""
    project_modules = (src, tt, pt, utils, dyn, data_gen, models, train, units, nc, eg)
    for module in project_modules:
        importlib.reload(module)
    reset_config()

reload()

In [None]:
import jax.numpy as jnp
import numpy as np
import jax
from jax import config
config.update("jax_enable_x64", True)

from jax import nn, jit, vmap, grad
from jraph import GraphsTuple, GraphNetwork
jax.config


import optax

In [None]:
# CONSTANTS
U = units.U

# Derived units, built up from metre, second and kilogram.
U_Length = U.meter
U_Time = U.second
U_Velocity = U.meter / U.second
U_Acceleration = U_Velocity / U.second
U_Area = U.meter ** 2
U_Volume = U.meter ** 3
U_Force = U.kg * U_Acceleration
U_Density = U.kg / U_Volume

# Unit-carrying inputs: E0 is a force per area, ρ a density, A0 an area,
# width and height are lengths.
E0 = 70*1.0e9 * U_Force / U_Area
ρ = 2770 * U_Density
A0 = 0.01 * U_Area
width = 1.52 * U_Length
height = 1.52 * U_Length

In [None]:
# 100 per unit of time.
rate = 1.0e2 / U_Time
# Load amplitude: 1000 kg times 10 m/s^2.
F0 = 1000 * U.kg * 10 * U_Acceleration

items = (E0, ρ, A0, F0, width, height, rate)
# Target unit system handed to units.convert (metre, kilogram, second).
p_units = [U.m, U.kg, U.s]
# Convert every quantity, then strip its unit wrapper (presumably leaving plain
# numbers -- confirm against src.units).
items = units.ustrip(*units.convert(p_units, *items))

# NOTE: the unit-carrying names are rebound to the stripped values here, so this
# cell is not idempotent -- re-run the constants cell before re-running this one.
E0, ρ, A0, F0, width, height, rate = items
EA0 = E0*A0

items

In [None]:
reload()
from PIL import Image

dmembers = [1, 12]
dvalues = [10, 20]

truss = tt.BoxTruss(n = 6, width=width, height=height)
def truss_plot(dmembers=None, dvalues=None, **kwargs):
    """Plot the notebook-level ``truss`` with its damaged members drawn in red.

    Parameters
    ----------
    dmembers : sequence of int, optional
        Member indices forwarded as ``damage_plot_kwargs["members"]``.
        Defaults to an empty list.
    dvalues : sequence, optional
        Values forwarded as ``damage_plot_kwargs["values"]``.
    **kwargs
        Passed through unchanged to ``truss.plot`` (for example ``ax=...``).

    Returns
    -------
    Whatever ``truss.plot`` returns.
    """
    # A fresh list per call: a list literal as the default would be one shared
    # object that any callee could mutate for all later calls.
    if dmembers is None:
        dmembers = []
    damage_style = {
        "c": "r",
        "ls": "-",
        "lw": 2,
        "members": dmembers,
        "values": dvalues,
    }
    return truss.plot(
        scatter_kwargs={"s": 0},
        damage_plot_kwargs=damage_style,
        plot_kwargs={"c": "k", "lw": 2},
        **kwargs,
    )


truss_plot(dmembers=dmembers, dvalues=dvalues)

In [None]:
f_nodes = [3]

In [None]:
# NOTE: this name shadows Python's builtin globals() for the rest of the notebook.
globals = utils.GLOBAL(PE = jnp.array([0.0]),
                 KE =  jnp.array([0.0]))

# Fixed seed so the synthetic damage pattern below is reproducible.
np.random.seed(24)

# Per-member damage: normal draws around 20 (std 10), shifted so the minimum is 0.
# The values are percentages (train_an_example divides the same quantity by 100).
damage = 20.0 + np.random.randn(truss.nedges) * 10
damage -= damage.min()

# Anything below 5 counts as undamaged.
damage[damage < 5] = 0.0

# Zero a random 80% of the members, so at most 20% remain damaged.
_ch = np.random.choice(len(damage), int(len(damage) * 0.8), replace=False)
damage[_ch] = 0.0

nodes, edges, full_dict = nc.get_nodes_edges(truss, E0, A0, ρ, damage=damage)
damage

In [None]:
key = jax.random.PRNGKey(0)
graph_truss = GraphsTuple(**full_dict, globals=globals)


In [None]:
reload()

# Graph network assembled from the project's edge, node and global update functions.
gn = GraphNetwork(nc.update_edge_fn, nc.update_node_fn, nc.update_global_fn)
# One pass through the network; the result replaces the input graph, so re-running
# this cell applies the network to an already-updated graph.
graph_truss = gn(graph_truss)

# N nodes, each with DIM coordinates.
N, DIM = nodes.position.shape
tmodel = dyn.DYN(gn, N=N, DIM=DIM)

In [None]:
R = jnp.array(truss.nodes["position"])
tmodel.PE(R*1.1, graph_truss)

In [None]:
V = jnp.array(truss.nodes["position"])*0+1.0
print(tmodel.KE(V, graph_truss))
V = V*0.0

In [None]:
tmodel.LAG(R, V, graph_truss)

In [None]:
# DT is the sampling interval; dt = DT / STRIDE is the step size handed to the solver.
Totaltime = 0.1 # s
sfreq = 1.0e3 # 1 kHz
DT = 1 / sfreq
STRIDE = 10

# Round before converting: int() truncates, and a float ratio can land just below
# the intended integer (0.3 / 0.1 == 2.9999999999999996 would give 2 samples).
RUNS = int(round(Totaltime / DT))

dt = DT/STRIDE
print(f"Runs: {RUNS}, dt: {dt} s, total time: {RUNS*DT} s, \n\
sampling: {DT} s, sfreq: {1/DT} Hz")

In [None]:
reload()

F_factor = 1

freq_load = 10

F02 = F0 * F_factor

print(f"freq_load: {freq_load} Hz")

# Angular frequency of the sinusoidal load.
w = 2*np.pi*freq_load
func = "sin"
ftype = "sin"
f_args = (w, )

constraints = eg.get_constraints(width)
external_force = eg.get_eF(jnp.array(f_nodes),
                           -F02,
                           nodes.position.shape,
                           func=func,
                           f_args=f_args,
                          )

# Evaluate the load at every solver step and keep component 1 of the first loaded node.
time_ = np.arange(0, Totaltime, dt)
y = vmap(lambda x: external_force(None, None, x, None))(time_).reshape(-1, N, DIM)[:, f_nodes[0], 1]

fig, ax = plt.subplots(figsize=(6, 6))
# Raw f-string: "\p" is not a valid escape sequence, so a plain f-string triggers
# an invalid-escape warning. The resulting label text is unchanged.
plt.plot(time_, -y/1000, lw=2,
         label=rf"${F02/1000}sin({2*freq_load} \pi t), f = {freq_load}Hz$")
plt.xlabel("Time (s)")
plt.ylabel("Force (kN)")
plt.legend(loc=3)

# First figure written by the notebook: make sure the output directory exists.
os.makedirs("../output/results/figures", exist_ok=True)
plt.savefig("../output/results/figures/sin.png", dpi=300)

In [None]:
# Node positions and a same-shaped array of zero velocities.
R = jnp.array(truss.nodes["position"])
V = jnp.array(truss.nodes["position"])*0
x_vec, v_vec = R.flatten(), V.flatten()

# d2L_dv2 is presumably the Hessian of the Lagrangian w.r.t. velocity, i.e. the
# mass matrix -- confirm in src.dyn.
M = tmodel.d2L_dv2(x_vec, v_vec, graph_truss)
# Moore-Penrose pseudo-inverse; passed on later as constant_mass_inv.
M_1 =  jnp.linalg.pinv(M)


In [None]:
# Acceleration function built from the model, the load and the constraints;
# dissipative forces are off and M_1 is supplied as the constant mass inverse.
acc = dyn.acceleration(tmodel,
                       external_force=external_force,
                       constraints=constraints,
                       use_dissipative_forces=False,
                       constant_mass_inv=M_1
                      )

# Displayed side by side: load amplitude over nodal mass, and the acceleration
# evaluated at t = 0 with zero velocity.
F02 / graph_truss.nodes.nodal_mass, acc(R, 0*R, 0, graph_truss)

In [None]:
t = 0.0

# Initial state tuple: positions, velocities, accelerations, graph, time, step size.
A = acc(R, 0*V, t, graph_truss)
state = (R, 0*V, A, graph_truss, t, dt)

apply = utils.get_apply(acc)
# Presumably RUNS recorded samples with STRIDE steps of size dt between them --
# confirm in utils.solve_dynamics.
traj = utils.solve_dynamics(state, apply, runs=RUNS, stride=STRIDE)

print(f"Total time: {dt*RUNS*STRIDE} s")


In [None]:
reload()
nc.plot_traj_energy([traj], tmodel, graph_truss)


In [None]:
trajs = [traj]
for traj_i in trajs:
    # Take the time axis from the trajectory being plotted (as the acceleration
    # cell does), and use a separate loop name so `traj` itself is not rebound.
    _t = traj_i[-2]
    # Displacement of every node relative to the first recorded frame.
    X = traj_i[0] - traj_i[0][0, :, :]
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    fig.suptitle('Position')
    for component, panel in enumerate(axs):
        panel.plot(_t, X[:, :, component], label="xy"[component])
        panel.set_xlabel("Time")
    plt.show()

In [None]:
for traj in trajs:
    # Element -2 is used as the time axis and element 2 holds the accelerations.
    _t, X = traj[-2], traj[2]
    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    fig.suptitle('Acceleration', fontsize=16)
    for component, panel in enumerate(axs):
        panel.plot(_t, X[:, :, component], label="xy"[component])
    plt.show()

### Data generation

In [None]:
# Trajectory generator sharing the constraints and mass inverse used above.
generate_traj = data_gen.get_gen_traj(
    tmodel,
    state,
    constraints=constraints,
    constant_mass_inv=M_1,
    use_dissipative_forces=False,
)


def generate_data(*args, **kwargs):
    """Run ``generate_traj`` and return its element 2 reshaped to ``(-1, N*DIM)``.

    Element 2 of a trajectory is what the notebook plots as acceleration.
    All arguments are forwarded unchanged to ``generate_traj``.
    """
    return generate_traj(*args, **kwargs)[2].reshape(-1, N*DIM)


In [None]:
reload()

# Indices of the members with non-zero damage.
dmembers = np.argwhere(damage>0.0).ravel()
fig, ax = plt.subplots(figsize=(8, 8), dpi=300)

# Green star markers: three at y = 0 and two at y = height, each shifted by -0.1
# in x and by +x in y. Whether they mark the sensor nodes is not stated here --
# TODO confirm.
x = 0.05
ax.plot(np.array([1*width, 3*width, 5*width, 2*width, 4*width])-0.1, np.array([0+x, 0+x, 0+x, 1*height+x, 1*height+x]), "*", ms=10, color="green", zorder=12)

# Overlay the truss, highlighting the damaged members with their damage values.
truss_plot(ax=ax, dmembers=dmembers, dvalues=damage[dmembers],
           text_kwargs2={"fontsize": 7},
           text_kwargs={"fontsize": 7},
          )

plt.savefig("../output/results/figures/dtruss.png")


In [None]:
# index start at 0
# One dict per load case: a direction key ("x" or "y") and the node indices it lists.
Fss = [{"y": [3]}, {"y": [1]}, {"y": [5]}, {"x": [7]}, {"x": [13]}]

EXTERNAL_FORCE = data_gen.getF(Fss, -F02, nodes.position.shape)
print(EXTERNAL_FORCE.shape)

# One bar series per column of EXTERNAL_FORCE, i.e. per load case.
for case_number, force_vector in enumerate(EXTERNAL_FORCE.T, start=1):
    plt.bar(range(len(force_vector)), force_vector, label=f"Force case {case_number}")
plt.legend()

In [None]:
training_RUNS = RUNS
TSTRIDE = STRIDE * 1


def _acceleration_for(force):
    """Flattened accelerations produced by ``generate_data`` for one load vector."""
    return generate_data(force,
                         runs=training_RUNS,
                         stride=TSTRIDE,
                         f_args=f_args,
                         Ffunc=ftype)


# Map over the load cases (the columns of EXTERNAL_FORCE).
ACCELERATION = vmap(_acceleration_for)(EXTERNAL_FORCE.T)

print(f"Total training time: {training_RUNS*dt*TSTRIDE}")
TIME = np.arange(training_RUNS) * dt*TSTRIDE

In [None]:
cnodes = [3]
fig, axs = plt.subplots(len(Fss), DIM, figsize=(6*DIM, len(Fss)*3), sharex=True, sharey=True,
                        gridspec_kw = dict(wspace=0.02, hspace=0.04))

for j in range(len(Fss)):
    for k in range(DIM):
        ax = axs[j, k]
        if k==0:
            ax.set_ylabel(f"A ({U_Length} / {U_Time}$^2$)")
        if j==len(Fss)-1:
            ax.set_xlabel(f"Time ({U_Time})")

        component = 'x' if k==0 else 'y'
        for i in cnodes:
            # Braced sub/superscripts so node numbers of 10 and above render whole.
            ax.text(0.5, 0.8, f"Case: {j+1}, $A_{{{i+1}}}^{{{component}}}$",
                    transform=ax.transAxes, ha="center")
            ax.plot(TIME, ACCELERATION[j].reshape(-1, N, DIM)[:, i, k])


# The noise cell further down also writes acc1.png, which used to overwrite this
# figure. Saving under a separate name keeps both; acc1.png is still the noise figure.
plt.savefig("../output/results/figures/acc1_clean.png")

### Adding noise to data

In [None]:
ACCELERATION_noise = utils.add_noise(0.2, ACCELERATION)
ACCELERATION_noise.shape

In [None]:
cnodes = [3]
n_cases = len(Fss)
fig, axs = plt.subplots(n_cases, DIM, figsize=(6*DIM, n_cases*2.5), sharex=True, sharey=True,
                        gridspec_kw={"wspace": 0.12, "hspace": 0.04})
for j in range(n_cases):
    # Clean and noisy accelerations of load case j, viewed as (samples, N, DIM).
    clean_j = ACCELERATION[j].reshape(-1, N, DIM)
    noisy_j = ACCELERATION_noise[j].reshape(-1, N, DIM)
    for k in range(DIM):
        ax = axs[j, k]
        if k == 0:
            ax.set_ylabel(f"A ({U_Length} / {U_Time}$^2$)")
        if j == n_cases - 1:
            ax.set_xlabel(f"Time ({U_Time})")

        component = 'x' if k==0 else 'y'
        for i in cnodes:
            ax.text(0.8, 0.8, f"Case: {j+1}, $A_{i+1}^{component}$",
                    transform=ax.transAxes, ha="center", bbox=dict(color="w", alpha=0.7))
            ax.plot(TIME, clean_j[:, i, k], lw=2, label="Acceleration")
            ax.scatter(TIME, noisy_j[:, i, k],
                       marker="o", s=20, fc="none", ec="r",
                       label="Acceleration + noise")

axs[0, 0].legend(ncol=2, bbox_to_anchor=(0, 1), loc=3)
plt.savefig("../output/results/figures/acc1.png", dpi=300)


## ML Model Training

### ML Model initialization

In [None]:
reload()

N = len(nodes.position)
Ne = edges.type.shape[-1]


def EA_fn(d):
    """Member stiffness for a damage fraction ``d``: ``EA0 * (1 - d)``."""
    return EA0 * (1 - d)


model_options = dict(
    N=N, Ne=Ne, dim=DIM,
    runs=training_RUNS, stride=TSTRIDE,
    constraints=constraints,
    use_dissipative_forces=False,
    EA_fn=EA_fn,
    NODAL_MASS0=0.0,
    M_1=M_1,
    activation=jax.nn.leaky_relu,
    hidden_dim=8,
)
model = models.TrussGraphModel(key, **model_options)


### Prediction from untrained model

In [None]:
# Initial state handed to the model: positions, zero velocity, zero acceleration,
# graph, start time and step size.
STATE = (1*R, 0*R, 0*R, graph_truss, 0.0, dt)

# Element 2 (acceleration) of the model's train=False output, one entry per load
# case (column of EXTERNAL_FORCE).
n_cases = EXTERNAL_FORCE.shape[1]
ACCELERATION = jnp.array([
    model((STATE, EXTERNAL_FORCE[:, case], (f_args, ftype)), train=False)[2]
    for case in range(n_cases)
])

# Flatten each sample to N*DIM values; sizes come from the data rather than the
# literals 5 and 28, so other load-case counts or truss sizes work unchanged.
ACCELERATION_noise = utils.add_noise(0.2, ACCELERATION.reshape(n_cases, -1, N * DIM))

In [None]:
reload()

def plot_acc(model, name="untrained", F_case=0):
    """Plot exact, ML-predicted and noisy training accelerations for one load case.

    The y-component (index 1) of the acceleration is drawn for a fixed set of
    node indices, one panel per node, and the figure is saved as
    ``../output/results/figures/acc_{name}.png``.

    Reads the notebook-level names ``training_RUNS``, ``TSTRIDE``, ``STATE``,
    ``EXTERNAL_FORCE``, ``f_args``, ``ftype``, ``ACCELERATION_noise``, ``N``,
    ``DIM``, ``generate_traj``, ``U_Length`` and ``U_Time``.

    Parameters
    ----------
    model
        Model to evaluate; a copy configured with ``runs=training_RUNS`` is used.
    name : str
        Suffix of the saved figure's file name.
    F_case : int
        0-based load case (column of ``EXTERNAL_FORCE``). Defaults to 0, the
        case that was previously hard-coded.

    Returns
    -------
    (fig, (traj_exact_model, As_data, traj_ml, _t))
    """
    model = models.copy_model(model, runs=training_RUNS, override=False)
    model_input = (STATE, EXTERNAL_FORCE[:, F_case], (f_args, ftype))

    # Trajectory from the data generator; only its time axis (element -2) is used.
    traj_data = generate_traj(EXTERNAL_FORCE[:, F_case], runs=training_RUNS,
                              stride=TSTRIDE, f_args=f_args, Ffunc=ftype)
    _t = traj_data[-2]

    # The same model evaluated with train=False (exact path) and with its default path.
    traj_exact_model = model(model_input, train=False)
    traj_ml = model(model_input)

    # Noisy accelerations used as training targets, viewed as (samples, N, DIM).
    As_data = ACCELERATION_noise[F_case][:training_RUNS].reshape(-1, N, DIM)

    print(f"Force case: {F_case}")

    cnodes = sorted([1, 3, 5, 9, 11])
    with plt.rc_context():
        _n = len(cnodes)
        fig, axs = plt.subplots(_n, 1, figsize=(12, 3*_n), sharex=True, sharey=True)
        axs = np.array(axs)
        for ax, i in zip(axs.ravel(), cnodes):
            ax.plot(_t, traj_exact_model[2][:, i, 1], "-", c=pt.colors[0], lw=2, label="Exact model")
            ax.plot(_t, traj_ml[2][:, i, 1], "--", c="r", lw=2, label="ML model")
            ax.scatter(_t, As_data[:, i, 1],
                       marker="o",
                       fc="none",
                       ec="r",
                       lw=1,
                       s=20, label="Training data (model + noise)")

            ax.set_ylabel("$A^y_{" + str(i+1) + "}$" + f"({U_Length / U_Time}$^2$)")
        ax.set_xlabel(f"Time ({U_Time})")
        plt.sca(axs[0])
        # 1-based case number, matching the other figures in this notebook.
        plt.title(f"Excitation case: {F_case+1}")
        plt.legend(bbox_to_anchor=(0, 1.1), loc=3, ncol=3)

    plt.savefig(f"../output/results/figures/acc_{name}.png", dpi=300)
    return fig, (traj_exact_model, As_data, traj_ml, _t)

fig, untrained_data = plot_acc(model, name="untrained")

### Model training

In [None]:
data = [((STATE, EXTERNAL_FORCE[:, i], (f_args, ftype)), ACCELERATION_noise[i]) for i in range(len(ACCELERATION_noise))]

In [None]:
import optax

# Sensor nodes, written as 1-based numbers and shifted to 0-based indices.
snodes = jnp.array([2, 10, 4, 12, 6]) - 1
# Flattened indices 2*node + 1 followed by 2*node; with DIM == 2 these would be the
# y- then x-components of each sensor node -- TODO confirm the flattening order.
subset_sensors = jnp.array(list(snodes * 2 + 1) + list(snodes * 2))
lr = 1.0e-3
Epochs = 1000
print_every = 100

optim = optax.adam(lr)

trN = 10
# loss_fn looks up the notebook-level trN when it is called, not when it is defined.
def loss_fn(model, data):
    return src.train.loss(model, data, M=None, trN=trN, sensors=subset_sensors)
_train = src.train.train(loss=loss_fn)

# Train a copy of the untrained model configured with runs=trN.
updated_model = models.copy_model(model, runs=trN)
updated_model10, extras = _train(updated_model, data, [], optim, Epochs, print_every)


In [None]:
plt.semilogy(range(len(extras[0])), extras[0])
plt.xlabel("Epochs")
plt.ylabel("Loss value")

plt.savefig("../output/results/figures/loss.png", dpi=300)

In [None]:
reload()

# `dd` was used here without ever being defined at notebook level (NameError on a
# fresh kernel). `damage` holds percentages -- train_an_example converts the same
# quantity with damage/100 -- so the damage fraction is damage / 100.
dd = damage / 100
# Undamaged stiffness per member, then the true damaged stiffness.
EA = np.array([E0*A0]*len(edges.EA))
EAd = EA * (1 - dd)

(fig, ax), (EAd, EAe, eps) = pt.compareEA(updated_model10, graph_truss, EA_fn, EAd, epsilon=0.05)
# Keep a copy of the estimate for the error-vs-training-length plot.
EAe10 = 1*EAe

In [None]:
fig, trained10_data = plot_acc(updated_model10, name="trained10")


In [None]:
# Continue training with trN = 20, starting from the model trained with trN = 10.
trN = 20
# loss_fn looks up the notebook-level trN when it is called, not when it is defined.
def loss_fn(model, data):
    return src.train.loss(model, data, M=None, trN=trN, sensors=subset_sensors)
_train = src.train.train(loss=loss_fn)

optim = optax.adam(lr)
# NOTE: rebinds updated_model10 to a copy configured with runs=trN.
updated_model10 = models.copy_model(updated_model10, runs=trN)
updated_model20, extras = _train(updated_model10, data, [], optim, Epochs, print_every)



In [None]:
# Index by the recorded history length (as the first loss plot does), so x and y
# always have the same length even if the trainer does not log exactly Epochs values.
plt.semilogy(range(len(extras[0])), extras[0])
plt.xlabel("Epochs")
plt.ylabel("Loss value")

In [None]:
# `dd` is not defined anywhere at notebook level; `damage` holds percentages
# (train_an_example uses damage/100), so derive the damage fraction here.
dd = damage / 100
EA = np.array([E0*A0]*len(edges.EA))
EAd = EA * (1 - dd)

(fig, ax), (EAd, EAe, eps) = pt.compareEA(updated_model20, graph_truss, EA_fn, EAd, epsilon=0.05)
EAe20 = 1*EAe

In [None]:
fig, trained20_data = plot_acc(updated_model20, name="trained20")




In [None]:
# Continue training with trN = 40, starting from the model trained with trN = 20.
trN = 40
# loss_fn looks up the notebook-level trN when it is called, not when it is defined.
def loss_fn(model, data):
    return src.train.loss(model, data, M=None, trN=trN, sensors=subset_sensors)
_train = src.train.train(loss=loss_fn)

optim = optax.adam(lr)
# NOTE: rebinds updated_model20 to a copy configured with runs=trN.
updated_model20 = models.copy_model(updated_model20, runs=trN)
updated_model40, extras = _train(updated_model20, data, [], optim, Epochs, print_every)


In [None]:
# Index by the recorded history length (as the first loss plot does), so x and y
# always have the same length even if the trainer does not log exactly Epochs values.
plt.semilogy(range(len(extras[0])), extras[0])
plt.xlabel("Epochs")
plt.ylabel("Loss value")


In [None]:
# `dd` is not defined anywhere at notebook level; `damage` holds percentages
# (train_an_example uses damage/100), so derive the damage fraction here.
dd = damage / 100
EA = np.array([E0*A0]*len(edges.EA))
EAd = EA * (1 - dd)

(fig, ax), (EAd, EAe, eps) = pt.compareEA(updated_model40, graph_truss, EA_fn, EAd, epsilon=0.05)
EAe40 = 1*EAe

In [None]:
fig, trained40_data = plot_acc(updated_model40, name="trained40")


In [None]:
# Continue training with trN = 60, starting from the model trained with trN = 40.
trN = 60
def loss_fn(model, data):
    return src.train.loss(model, data, M=None, trN=trN, sensors=subset_sensors)
_train = src.train.train(loss=loss_fn)

optim = optax.adam(lr)
updated_model40 = models.copy_model(updated_model40, runs=trN)
updated_model60, extras = _train(updated_model40, data, [], optim, Epochs, print_every)

# x-axis from the recorded history length, so it always matches extras[0].
plt.semilogy(range(len(extras[0])), extras[0])
plt.xlabel("Epochs")
plt.ylabel("Loss value")
plt.show()

# `dd` is not defined anywhere at notebook level; `damage` holds percentages
# (train_an_example uses damage/100), so derive the damage fraction here.
dd = damage / 100
EA = np.array([E0*A0]*len(edges.EA))
EAd = EA * (1 - dd)
(fig, ax), (EAd, EAe, eps) = pt.compareEA(updated_model60, graph_truss, EA_fn, EAd, epsilon=0.05)
EAe60 = 1*EAe
plt.show()



In [None]:
fig, trained60_data = plot_acc(updated_model60, name="trained60")
plt.show()


In [None]:
# Continue training with trN = 80, starting from the model trained with trN = 60.
trN = 80
def loss_fn(model, data):
    return src.train.loss(model, data, M=None, trN=trN, sensors=subset_sensors)
_train = src.train.train(loss=loss_fn)

optim = optax.adam(lr)
updated_model60 = models.copy_model(updated_model60, runs=trN)
updated_model80, extras = _train(updated_model60, data, [], optim, Epochs, print_every)


# x-axis from the recorded history length, so it always matches extras[0].
plt.semilogy(range(len(extras[0])), extras[0])
plt.xlabel("Epochs")
plt.ylabel("Loss value")
plt.show()

# `dd` is not defined anywhere at notebook level; `damage` holds percentages
# (train_an_example uses damage/100), so derive the damage fraction here.
dd = damage / 100
EA = np.array([E0*A0]*len(edges.EA))
EAd = EA * (1 - dd)
(fig, ax), (EAd, EAe, eps) = pt.compareEA(updated_model80, graph_truss, EA_fn, EAd, epsilon=0.05)
EAe80 = 1*EAe
plt.show()



In [None]:

# Plot the model produced by the trN = 80 stage. The previous argument,
# updated_model60, is only the starting point of that stage.
fig, trained80_data = plot_acc(updated_model80, name="trained80")
plt.show()


In [None]:
# Continue training with trN = 100, starting from the model trained with trN = 80.
trN = 100

Epochs = 1000

def loss_fn(model, data):
    return src.train.loss(model, data, M=None, trN=trN, sensors=subset_sensors)
_train = src.train.train(loss=loss_fn)

optim = optax.adam(lr)
updated_model80 = models.copy_model(updated_model80, runs=trN)
updated_model100, extras = _train(updated_model80, data, [], optim, Epochs, print_every)

# x-axis from the recorded history length, so it always matches extras[0].
plt.semilogy(range(len(extras[0])), extras[0])
plt.xlabel("Epochs")
plt.ylabel("Loss value")

In [None]:
# `dd` is not defined anywhere at notebook level; `damage` holds percentages
# (train_an_example uses damage/100), so derive the damage fraction here.
dd = damage / 100
EA = np.array([E0*A0]*len(edges.EA))
EAd = EA * (1 - dd)

(fig, ax), (EAd, EAe, eps) = pt.compareEA(updated_model100, graph_truss, EA_fn, EAd, epsilon=0.05)
EAe100 = 1*EAe



In [None]:
fig, trained100_data = plot_acc(updated_model100, name="trained100")


### Trained model analysis

In [None]:
from functools import partial

def fn(EAe, which="mean", reference=None):
    """Absolute percentage error of estimated stiffnesses against a reference.

    Parameters
    ----------
    EAe : array-like
        Estimated member stiffnesses.
    which : {"mean", "max"}
        Statistic of the per-member absolute percentage error to return.
    reference : array-like, optional
        True stiffnesses. Defaults to the notebook-level ``EAd``.

    Returns
    -------
    float
        Mean or maximum of ``|100 * (EAe - reference) / reference|``.

    Raises
    ------
    ValueError
        If ``which`` is neither ``"mean"`` nor ``"max"`` (previously this
        silently returned ``None``).
    """
    if reference is None:
        reference = EAd
    abs_pct_error = np.abs(100 * (EAe - reference) / reference)
    if which == "mean":
        return abs_pct_error.mean()
    if which == "max":
        return abs_pct_error.max()
    raise ValueError(f"which must be 'mean' or 'max', got {which!r}")

fig, axs = plt.subplots(1, 2, figsize=(12, 6))

# Training lengths and the stiffness estimates obtained after each stage.
XX = [10, 20, 40, 60, 80, 100]
YY = [EAe10, EAe20, EAe40, EAe60, EAe80, EAe100]

mean_errors = list(map(fn, YY))
max_errors = list(map(partial(fn, which="max"), YY))
# Scale the x-axis from the plotted values instead of whatever trN last held.
x_upper = max(XX)*1.1

# Raw strings below: "\%" is not a valid escape sequence in a plain string literal.
plt.sca(axs[0])
plt.plot(XX, mean_errors)
plt.scatter(XX, mean_errors)
plt.xlabel("$N_{training}$")
plt.ylabel(r"$Mean_{Error (\%)}$")
plt.xlim(0, x_upper)

plt.sca(axs[1])
plt.plot(XX, max_errors)
plt.scatter(XX, max_errors)
plt.xlabel("$N_{training}$")
plt.ylabel(r"$Max_{Error (\%)}$")
plt.xlim(0, x_upper)

plt.savefig("../output/results/figures/error_mean_max.png", dpi=300)




In [None]:
DATA__ = [untrained_data, trained10_data, trained20_data, trained40_data, trained60_data, trained80_data, trained100_data]

# The tuples above come from plot_acc, which plots load case index 0. F_case was
# read here before any notebook-level definition (NameError on a fresh kernel).
F_case = 0

with plt.rc_context():
    _n = len(DATA__)
    fig, axs = plt.subplots(_n, 1, figsize=(12, 2*_n), sharex=True)
    axs = np.array(axs)
    # Node index whose y-acceleration is compared across training stages.
    i = 3
    for ax, d_, trn in zip(axs.ravel(), DATA__, ["Untrained", 10, 20, 40, 60, 80, 100]):
        traj_exact_model, As_data, traj_ml, _t = d_

        ax.plot(_t, traj_exact_model[2][:, i, 1], "-", c=pt.colors[0], lw=2, label="Exact model")
        ax.plot(_t, traj_ml[2][:, i, 1], "--", c="r", lw=2, label="ML model")
        ax.scatter(_t, As_data[:, i, 1],
                   marker="o",
                   fc="none",
                   ec="r",
                   lw=1,
                   s=20, label="Training data (with noise)")

        ax.set_ylabel("$A^y_{" + str(i+1) + "}$" + f"({U_Length / U_Time}$^2$)")
        if isinstance(trn, str):
            ax.text(0.025, 0.8, "Untrained", transform=ax.transAxes)
        else:
            ax.text(0.025, 0.8, "$N_{training}$" +f" = {trn}", transform=ax.transAxes)

    ax.set_xlabel(f"Time ({U_Time})")
    plt.sca(axs[0])
    plt.legend(bbox_to_anchor=(0, 1.1), loc=3, ncol=3)
    plt.title(f"Excitation case: {F_case+1}")

    plt.savefig("../output/results/figures/comp_train.png", dpi=300)


In [None]:
# The short horizon is n_ times the training length; the long one is 5 times that.
n_ = 2
ttN = n_ * trN

F_case = 0

traj_exact = generate_traj(EXTERNAL_FORCE[:, F_case], runs=ttN, stride=TSTRIDE, f_args=f_args, Ffunc=ftype)
traj_trained_ml = models.copy_model(updated_model100, runs=ttN, stride=TSTRIDE)((STATE, EXTERNAL_FORCE[:, F_case], (f_args, ftype)))

traj_exact2 = generate_traj(EXTERNAL_FORCE[:, F_case], runs=5*ttN, stride=TSTRIDE, f_args=f_args, Ffunc=ftype)
traj_trained_ml2 = models.copy_model(updated_model100, runs=5*ttN, stride=TSTRIDE)((STATE, EXTERNAL_FORCE[:, F_case], (f_args, ftype)))

_t = traj_exact[-2]
_t2 = traj_exact2[-2]

# First long-horizon sample past the end of the short horizon. It is the same for
# every panel, so it is computed once here.
nn_ = np.argmax(_t2 > _t[-1])

# Slanted tick marks drawn where the axis is broken between the two columns.
d = 1  # proportion of vertical to horizontal extent of the slanted line
break_marker = dict(marker=[(-1, -d), (1, d)], markersize=12,
                    linestyle="none", color='k', mec='k', mew=1, clip_on=False)

cnodes = sorted([1, 3, 5, 9, 11])
fig, axs_xy = plt.subplots(len(cnodes), 2, figsize=(12, 2*len(cnodes)), sharey=True)
fig.subplots_adjust(wspace=0.02)
axs = axs_xy[:, 0]
axs2 = axs_xy[:, 1]

for ax, ax2, i in zip(axs.ravel(), axs2.ravel(), cnodes):
    # Left panel: short horizon, split into training and extrapolation regions.
    ax.plot(_t, traj_exact[2][:, i, 1], lw=1, c="k", alpha=1.0, label="Exact Model")
    ax.plot(_t, traj_trained_ml[2][:, i, 1], "--", c="r", label="Trained ML")

    ax.axvspan(0, _t[-1]/n_, alpha=0.2, color="y", label="Training region")
    ax.axvspan(_t[-1]/n_, _t[-1], alpha=0.2, color="g", label="Extrapolation")
    ax.set_ylabel(r"$A^y_{" + f"{i+1}" + "}$" + f"({U_Length / U_Time}$^2$)")

    # Right panel: the rest of the long horizon.
    ax2.plot(_t2[nn_:], traj_exact2[2][nn_:, i, 1], lw=1, c="k", alpha=1.0, label="Exact Model")
    ax2.plot(_t2[nn_:], traj_trained_ml2[2][nn_:, i, 1], "--", c="r", label="Trained ML")
    ax2.axvspan(_t2[nn_], _t2[-1], alpha=0.2, color="g", label="Extrapolation")

    # Hide the facing spines and their ticks so the two panels read as one broken axis.
    ax.spines.right.set_visible(False)
    ax.tick_params(which='both', right=False)
    ax2.spines.left.set_visible(False)
    ax2.tick_params(which='both', left=False)

    ax.plot([1, 1], [0, 1], transform=ax.transAxes, **break_marker)
    ax2.plot([0, 0], [0, 1], transform=ax2.transAxes, **break_marker)

# Shared x-label, placed relative to the last right-hand panel.
ax.text(-0.15, -0.4, "Time (s)", transform=ax2.transAxes)
plt.sca(axs[0])
plt.legend(bbox_to_anchor=(-0.05, 1.1), loc=3, ncol=4)
plt.text(1, 1.1, f"Excitation case: {F_case+1}", transform=axs[0].transAxes, ha="center")

plt.savefig("../output/results/figures/extrapolation.png", dpi=300)


In [None]:
reload()

# `dd` is not defined anywhere at notebook level; `damage` holds percentages
# (train_an_example uses damage/100), so derive the damage fraction here.
dd = damage / 100
EA = np.array([E0*A0]*len(edges.EA))
EAd = EA * (1 - dd)

pt.compareEA(updated_model100, graph_truss, EA_fn, EAd, epsilon=0.05)

plt.savefig("../output/results/figures/error_final.png", dpi=300)

## Ablation studies

In [None]:
def get_damage(nedges, dnumber=0.8, dlevel=20.0, rng=None):
    """Draw a random damage pattern (in percent) for ``nedges`` members.

    Values are sampled as ``dlevel + 10 * N(0, 1)``, shifted so the minimum is
    zero, and anything below 5 is set to 0. A random fraction ``dnumber`` of
    the members is then reset to 0.

    Parameters
    ----------
    nedges : int
        Number of members.
    dnumber : float
        Fraction of members forced to be undamaged.
    dlevel : float
        Centre of the normal distribution the raw values are drawn from.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Random source. When omitted, NumPy's global state is used exactly as
        before, so existing ``np.random.seed`` based callers get the same
        numbers. Passing a dedicated generator keeps the draw independent of
        other code that touches the global state.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(nedges,)``; each entry is 0 or at least 5.
    """
    source = np.random if rng is None else rng
    # np.random.randn(n) is a thin wrapper around standard_normal(n), so the
    # legacy stream is unchanged.
    damage = dlevel + source.standard_normal(nedges) * 10.0
    if damage.size == 0:
        # Nothing to damage; min() of an empty array would raise.
        return damage
    damage -= damage.min()

    damage[damage < 5] = 0.0

    _ch = source.choice(len(damage), int(len(damage) * dnumber), replace=False)
    damage[_ch] = 0.0

    return damage

In [None]:
# DT is the sampling interval; dt = DT / STRIDE is the step size handed to the solver.
Totaltime = 0.1 # s
sfreq = 1.0e3 # 1 kHz
DT = 1 / sfreq
STRIDE = 10

# Round before converting: int() truncates, and a float ratio can land just below
# the intended integer (0.3 / 0.1 == 2.9999999999999996 would give 2 samples).
RUNS = int(round(Totaltime / DT))

dt = DT/STRIDE
print(f"Runs: {RUNS}, dt: {dt} s, total time: {RUNS*DT} s, \n\
sampling: {DT} s, sfreq: {1/DT} Hz")

# Loaded nodes read (as a notebook-level name) by train_an_example.
f_nodes = [2, 4]

# One dict per load case: a direction key ("x" or "y") and the node indices it lists.
Fss = [
        {"y": [3]},
        {"y": [1]},
        {"y": [5]},
        {"x": [7]},
        {"x": [13]}
    ]

In [None]:
reload()

def train_an_example(
        truss, E0, A0, ρ, width, F0,
        Fss,
        snodes,
        dnumber,
        dlevel,
        seed=42,
        RUNS = 100,
        STRIDE = 10,
        DT = 0.1,
        rate=1.0,
        training_RUNS = 50,
        trNs = (100,),
        noise_fs = 0.2,
        lr = 1.0e-3,
        Epochs = 400,
        print_every = 100,
        example = "examplenone"
        ):
    """Build one randomly damaged truss, train on noisy data from it, save the figures.

    A damage pattern is drawn with ``get_damage`` (which uses NumPy's global
    random state unless seeded by the caller), the dynamics are simulated for
    every load case in ``Fss``, and for each noise level a fresh model is
    trained successively with each training length in ``trNs``. All figures go
    to ``../output/results/{example}/``.

    Besides its arguments the function reads the notebook-level names
    ``f_nodes``, ``truss_plot``, ``get_damage``, ``U_Length`` and ``U_Time``.
    Note that ``truss_plot`` draws the notebook-level ``truss``, not the
    ``truss`` argument.

    Parameters
    ----------
    truss, E0, A0, ρ, width, F0
        Truss object, material/section values, constraint width and load amplitude.
    Fss : list of dict
        Load cases forwarded to ``data_gen.getF``.
    snodes : array of int
        0-based sensor node indices used to build the loss's sensor subset.
    dnumber, dlevel : float
        Forwarded to ``get_damage``.
    seed : int
        Only used to label output files; it does not seed any generator here.
    RUNS, STRIDE, DT
        Number of samples, steps per sample and sampling interval of the
        preview simulation; ``dt = DT / STRIDE`` is the step size.
    rate : float
        Unused; kept for backward compatibility.
    training_RUNS : int
        Number of samples generated per load case for training.
    trNs : sequence of int
        Training lengths, applied one after another to the same model.
    noise_fs : float or sequence of float
        Noise level(s) passed to ``utils.add_noise``. A single number is accepted.
    lr, Epochs, print_every
        Optimiser learning rate and the arguments forwarded to the trainer.
    example : str
        Name of the output sub-directory.
    """
    import os

    # The default is a single number; accept that as well as a sequence
    # (iterating over a bare float raised TypeError).
    if isinstance(noise_fs, (int, float)):
        noise_fs = [noise_fs]

    base_id = f"seed_{seed}_dnumber_{dnumber}_dlevel_{dlevel}"
    fdirname = f"../output/results/{example}"
    os.makedirs(fdirname, exist_ok=True)

    graph_globals = utils.GLOBAL(PE = jnp.array([0.0]),
                                 KE =  jnp.array([0.0]))

    # NOTE: the model key is fixed, so every call initialises the model identically.
    key = jax.random.PRNGKey(0)

    damage = get_damage(truss.nedges, dnumber=dnumber, dlevel=dlevel)
    nodes, edges, full_dict = nc.get_nodes_edges(truss, E0, A0, ρ, damage=damage)
    graph_truss = GraphsTuple(**full_dict, globals=graph_globals)
    gn = GraphNetwork(nc.update_edge_fn, nc.update_node_fn, nc.update_global_fn)
    graph_truss = gn(graph_truss)

    N, DIM = nodes.position.shape
    tmodel = dyn.DYN(gn, N=N, DIM=DIM)

    F_factor = 1
    freq_load = 10
    F02 = F0 * F_factor
    print(f"freq_load: {freq_load} Hz")
    w = 2*np.pi*freq_load
    ftype = "sin"
    f_args = (w, )

    constraints = eg.get_constraints(width)
    external_force = eg.get_eF(jnp.array(f_nodes),
                               -F02,
                               nodes.position.shape,
                               func=ftype,
                               f_args=f_args,
                              )

    R = jnp.array(truss.nodes["position"])
    V = jnp.array(truss.nodes["position"])*0
    x_vec, v_vec = R.flatten(), V.flatten()
    M = tmodel.d2L_dv2(x_vec, v_vec, graph_truss)
    M_1 =  jnp.linalg.pinv(M)

    acc = dyn.acceleration(tmodel,
                           external_force=external_force,
                           constraints=constraints,
                           use_dissipative_forces=False,
                           constant_mass_inv=M_1
                          )

    dt = DT/STRIDE

    t = 0.0

    # Total time covers STRIDE steps of dt per sample, matching the
    # "Total training time" message printed further down.
    print(f"Time step: {dt}, Total time: {dt * training_RUNS * STRIDE}, Sampling every: {DT}, Frequecy: {1/DT}")

    A = acc(R, 0*V, t, graph_truss)
    state = (R, 0*V, A, graph_truss, t, dt)

    apply = utils.get_apply(acc)
    traj = utils.solve_dynamics(state, apply, runs=RUNS, stride=STRIDE)
    nc.plot_traj_energy([traj], tmodel, graph_truss)

    generate_traj = data_gen.get_gen_traj(tmodel, state,
                                          constraints=constraints,
                                          constant_mass_inv=M_1,
                                          use_dissipative_forces=False)

    def generate_data(*args, **kwargs):
        traj = generate_traj(*args, **kwargs)
        return traj[2].reshape(-1, N*DIM)

    #############################
    # Plot truss
    #############################
    dmembers = np.argwhere(damage>0.0).ravel()
    fig_truss, ax = plt.subplots(figsize=(8, 8), dpi=300)
    truss_plot(ax=ax, dmembers=dmembers, dvalues=damage[dmembers],
               text_kwargs2={"fontsize": 7},
               text_kwargs={"fontsize": 7},
              )
    plt.savefig(f"{fdirname}/dtruss_{base_id}.png")
    plt.close(fig_truss)

    #############################
    # Noise-free training data
    #############################
    EXTERNAL_FORCE = data_gen.getF(Fss, -F02, nodes.position.shape)
    ACCELERATION = vmap(lambda x: generate_data(x, runs=training_RUNS, stride=STRIDE,
                                        Ffunc=ftype, f_args=f_args))(EXTERNAL_FORCE.T)
    print(f"Total training time: {training_RUNS*dt*STRIDE}")

    TIME = np.array(range(0, training_RUNS)) * dt*STRIDE

    # Undamaged stiffness from the arguments (previously the notebook-level EA0).
    EA_undamaged = E0 * A0

    def EA_fn(d):
        return EA_undamaged * (1 - d)

    subset_sensors = jnp.array(list(snodes * 2 + 1) + list(snodes * 2))

    for noise_f in noise_fs:

        run_id = f"{base_id}_noise_{noise_f}"

        ACCELERATION_noise = utils.add_noise(noise_f, ACCELERATION)

        STATE = (1*R, 0*R, 0*R, graph_truss, 0.0, dt)
        data = [((STATE, EXTERNAL_FORCE[:, i], (f_args, ftype)), ACCELERATION_noise[i]) for i in range(len(ACCELERATION_noise))]

        #############################
        # Plot acc with noise
        #############################
        cnodes = [3]
        fig_acc, axs = plt.subplots(len(Fss), DIM, figsize=(6*DIM, len(Fss)*3), sharex=True, sharey=False,
                                    gridspec_kw = dict(wspace=0.12, hspace=0.04))
        for j in range(len(Fss)):
            for k in range(DIM):
                ax = axs[j, k]
                if k==0:
                    ax.set_ylabel(f"A ({U_Length} / {U_Time}$^2$)")
                if j==len(Fss)-1:
                    ax.set_xlabel(f"Time ({U_Time})")

                for i in cnodes:
                    ax.text(0.8, 0.8, f"Case: {j+1}, $A_{i+1}^{'x' if k==0 else 'y'}$",
                            transform=ax.transAxes, ha="center", bbox=dict(color="w", alpha=0.7))
                    ax.plot(TIME, ACCELERATION[j].reshape(-1, N, DIM)[:, i, k], lw=2,
                           label="Acceleration",)
                    ax.scatter(TIME, ACCELERATION_noise[j].reshape(-1, N, DIM)[:, i, k],
                               marker="o",
                               s=20,
                               fc="none",
                               ec="r",
                               label="Acceleration + noise",
                              )

        axs[0, 0].legend(ncol=2, bbox_to_anchor=(0, 1), loc=3)
        plt.savefig(f"{fdirname}/acc_withnoise_{run_id}.png", dpi=300)
        plt.close(fig_acc)

        ##############################
        # Fresh model per noise level
        ##############################
        Ne = edges.type.shape[-1]

        model = models.TrussGraphModel(key, N=N, Ne=Ne, dim=DIM,
                                       runs=training_RUNS, stride=STRIDE,
                                       constraints=constraints,
                                       use_dissipative_forces=False,
                                       EA_fn=EA_fn,
                                       NODAL_MASS0=0.0,
                                       M_1=M_1,
                                       activation=jax.nn.leaky_relu,
                                       hidden_dim=8,
                                      )

        updated_model = models.copy_model(model, runs=training_RUNS)

        for trN in trNs:
            print(f"Training for {trN} runs.")

            updated_model = models.copy_model(updated_model, runs=trN)

            def loss_fn(model, data):
                return src.train.loss(model, data, M=None, trN=trN, sensors=subset_sensors)

            _train = src.train.train(loss=loss_fn)

            optim = optax.adam(lr)
            updated_model, extras = _train(updated_model, data, [], optim, Epochs, print_every)

            fig_loss, ax = plt.subplots()
            plt.semilogy(range(len(extras[0])), extras[0])
            plt.xlabel("Epochs")
            plt.ylabel("Loss value")

            plt.savefig(f"{fdirname}/loss_{run_id}_trN_{trN}.png", dpi=300)

            # True damaged stiffness: damage is in percent.
            EA = np.array([E0*A0]*len(edges.EA))
            dd = damage/100
            EAd = EA * (1 - dd)

            (fig_ea, ax), (EAd, EAe, eps) = pt.compareEA(updated_model, graph_truss,
                                                         EA_fn, EAd, epsilon=0.05)
            plt.savefig(f"{fdirname}/damage_es_pr_{run_id}_trN_{trN}.png", dpi=300)

            pt.compareDamage(updated_model, graph_truss, dd, epsilon=0.02)
            plt.savefig(f"{fdirname}/damage_final_{run_id}_trN_{trN}.png", dpi=300)

            # Absolute percentage error of the estimated stiffness per member.
            abs_pct_error = np.abs(100 * (EAe - EAd) / EAd)
            _mean_e = abs_pct_error.mean()
            _max_e = abs_pct_error.max()

            # Tiny re-save of the current figure whose file name records the errors.
            plt.savefig(f"{fdirname}/error_{run_id}_trN_{trN}_mean_{_mean_e: .2f}_max_{_max_e: .2f}.png", dpi=10)

            # Release the figures owned by this iteration; the loop otherwise
            # leaves several open figures per training length.
            plt.close(fig_loss)
            plt.close(fig_ea)

    print("Done.")
        

### Noise level

In [None]:
%matplotlib agg
import time

# Training settings, forwarded to train_an_example as `lr` and `Epochs`.
lr = 1e-3
Epochs = 1_000

# Damage settings, passed positionally to train_an_example.
# NOTE(review): presumably dnumber controls how many members are damaged and
# dlevel is the damage in percent -- confirm against train_an_example.
dnumber = 0.8
dlevel = 20.0

# The listed labels are shifted down by one before use.
# NOTE(review): presumably 1-based node labels turned into 0-based indices --
# confirm against the truss node numbering.
snodes = jnp.array([2, 10, 4, 12, 6]) - 1

# Swept values, forwarded as `noise_fs` and `trNs` respectively.
noise_levels = [0.0, 0.1, 0.2, 0.3]
trNs = [10, 20, 40, 60, 80, 100]

# Each inner list is one batch of seeds; the cell that follows starts one
# thread per seed of a batch and waits for the batch to finish.
seeds = [
    [102, 435, 860, 270],
    [106, 71, 700, 20],
    [614, 121, 123, 567],
]

def fn(seed):
    """Run the noise-level study ("example1") for a single seed.

    Prints the seed, seeds NumPy's global random generator with it, waits one
    second, and then calls ``train_an_example`` with the notebook-level
    settings (truss/material globals, ``snodes``, ``dnumber``, ``dlevel``,
    ``trNs``, ``noise_levels``, ``lr``, ``Epochs``). Returns ``seed``.

    NOTE(review): this function is used as a thread target and
    ``np.random.seed`` sets process-wide state, so seeds running concurrently
    may overwrite one another's RNG state -- confirm that reproducibility does
    not rely on it.
    """
    print(seed)
    np.random.seed(seed)
    time.sleep(1)

    model_inputs = (truss, E0, A0, ρ, width, F0, Fss, snodes, dnumber, dlevel)
    run_options = dict(
        seed=seed,
        RUNS=RUNS,
        STRIDE=STRIDE,
        DT=DT,
        training_RUNS=100,
        trNs=trNs,
        noise_fs=noise_levels,
        lr=lr,
        Epochs=Epochs,
        print_every=100,
        example="example1",
    )
    train_an_example(*model_inputs, **run_options)

    print(seed, "Done")
    return seed

In [None]:
import threading

# Process the seeds batch by batch: every seed of a batch runs `fn` in its own
# thread, and the next batch starts only once all threads of the current batch
# have been joined.
for seed in seeds:
    print(seed)
    # Name each thread after its seed (instead of a shared constant name) so
    # that an uncaught exception reported for a thread identifies the run it
    # came from.
    threads = [
        threading.Thread(target=fn, args=(s,), name=f"seed-{s}")
        for s in seed
    ]

    for thread in threads:
        thread.start()

    for thread in threads:
        thread.join()

print("Done")


### Damage level

In [None]:
%matplotlib agg
import time

# Training settings, forwarded to train_an_example as `lr` and `Epochs`.
lr = 1e-3
Epochs = 1_000

# Damage settings: `dnumber` is fixed, while `fn` below calls the training
# routine once per entry of `dlevels`.
# NOTE(review): presumably dlevels are damage percentages -- confirm against
# train_an_example.
dnumber = 0.8
dlevels = [10.0, 20.0, 30.0, 40.0]

# The listed labels are shifted down by one before use.
# NOTE(review): presumably 1-based node labels turned into 0-based indices --
# confirm against the truss node numbering.
snodes = jnp.array([2, 10, 4, 12, 6]) - 1

# Forwarded as `noise_fs` and `trNs` respectively.
noise_levels = [0.1]
trNs = [10, 20, 40, 60, 80, 100]

# A single batch of seeds; one thread is started per seed.
seeds = [
    [102, 435, 860, 270, 124],
]

def fn(seed):
    """Run the damage-level study ("example1_dlevel") for a single seed.

    Prints the seed, seeds NumPy's global random generator with it, waits one
    second, and then calls ``train_an_example`` once for every entry of the
    notebook-level ``dlevels`` list, keeping all other settings fixed.
    Returns ``seed``.

    NOTE(review): this function is used as a thread target and
    ``np.random.seed`` sets process-wide state, so seeds running concurrently
    may overwrite one another's RNG state -- confirm that reproducibility does
    not rely on it.
    """
    print(seed)
    np.random.seed(seed)
    time.sleep(1)

    for dlevel in dlevels:
        model_inputs = (truss, E0, A0, ρ, width, F0, Fss, snodes, dnumber, dlevel)
        run_options = dict(
            seed=seed,
            RUNS=RUNS,
            STRIDE=STRIDE,
            DT=DT,
            training_RUNS=100,
            trNs=trNs,
            noise_fs=noise_levels,
            lr=lr,
            Epochs=Epochs,
            print_every=100,
            example="example1_dlevel",
        )
        train_an_example(*model_inputs, **run_options)

    print(seed, "Done")
    return seed


import threading

# Process the seeds batch by batch: every seed of a batch runs `fn` in its own
# thread, and the next batch starts only once all threads of the current batch
# have been joined.
for seed in seeds:
    print(seed)
    # Name each thread after its seed (instead of a shared constant name) so
    # that an uncaught exception reported for a thread identifies the run it
    # came from.
    threads = [
        threading.Thread(target=fn, args=(s,), name=f"seed-{s}")
        for s in seed
    ]

    for thread in threads:
        thread.start()

    for thread in threads:
        thread.join()

print("Done")




### Number of damaged members

In [None]:
%matplotlib agg
import time

# Training settings, forwarded to train_an_example as `lr` and `Epochs`.
lr = 1e-3
Epochs = 1_000

# Damage settings: `dlevel` is fixed, while `fn` below calls the training
# routine once per entry of `dnumbers`.
# NOTE(review): presumably dnumbers control how many members are damaged and
# dlevel is the damage in percent -- confirm against train_an_example.
dnumbers = [0.6, 0.7, 0.8, 0.9]
dlevel = 20.0

# The listed labels are shifted down by one before use.
# NOTE(review): presumably 1-based node labels turned into 0-based indices --
# confirm against the truss node numbering.
snodes = jnp.array([2, 10, 4, 12, 6]) - 1

# Forwarded as `noise_fs` and `trNs` respectively.
noise_levels = [0.1]
trNs = [10, 20, 40, 60, 80, 100]

# A single batch of seeds; one thread is started per seed.
seeds = [
    [102, 435, 860, 270, 124],
]

def fn(seed):
    """Run the damaged-members study ("example1_dnumber") for a single seed.

    Prints the seed, seeds NumPy's global random generator with it, waits one
    second, and then calls ``train_an_example`` once for every entry of the
    notebook-level ``dnumbers`` list, keeping all other settings fixed.
    Returns ``seed``.

    NOTE(review): this function is used as a thread target and
    ``np.random.seed`` sets process-wide state, so seeds running concurrently
    may overwrite one another's RNG state -- confirm that reproducibility does
    not rely on it.
    """
    print(seed)
    np.random.seed(seed)
    time.sleep(1)

    for dnumber in dnumbers:
        model_inputs = (truss, E0, A0, ρ, width, F0, Fss, snodes, dnumber, dlevel)
        run_options = dict(
            seed=seed,
            RUNS=RUNS,
            STRIDE=STRIDE,
            DT=DT,
            training_RUNS=100,
            trNs=trNs,
            noise_fs=noise_levels,
            lr=lr,
            Epochs=Epochs,
            print_every=100,
            example="example1_dnumber",
        )
        train_an_example(*model_inputs, **run_options)

    print(seed, "Done")
    return seed


import threading

# Process the seeds batch by batch: every seed of a batch runs `fn` in its own
# thread, and the next batch starts only once all threads of the current batch
# have been joined.
for seed in seeds:
    print(seed)
    # Name each thread after its seed (instead of a shared constant name) so
    # that an uncaught exception reported for a thread identifies the run it
    # came from.
    threads = [
        threading.Thread(target=fn, args=(s,), name=f"seed-{s}")
        for s in seed
    ]

    for thread in threads:
        thread.start()

    for thread in threads:
        thread.join()

print("Done")


