In [None]:
def evaluate_model(sindy, test_data, params, encoder):
    """Evaluate a trained SINDy autoencoder on test data; plot and print results.

    The model is switched to eval mode and run once on the test set. The
    latent trajectory produced by the encoder is then compared against a
    trajectory integrated with ``odeint`` from the first latent state using
    the model's SINDy coefficients. Finally, three relative mean-squared
    errors are printed.

    Args:
        sindy: Callable model; ``sindy(x, dx)`` must return the 7-tuple
            ``(x, dx, dz_predict, dz, x_decode, dx_decode, sindy_coeffs)``.
        test_data: Mapping with NumPy arrays under ``"x"``, ``"dx"`` and
            ``"t"``. ``"t"`` is indexed as ``[:, 0]`` to obtain the time grid.
        params: Mapping providing ``"poly_order"`` and ``"include_sine"``,
            which are forwarded to ``sindy_library_pt``.
        encoder: Object exposing a callable ``net`` that maps inputs to the
            latent state.

    Returns:
        None. Results are shown as matplotlib figures and printed to stdout.

    Notes:
        Relies on the module-level names ``device`` and ``sindy_library_pt``.
        The plots index latent components 0 and 1, so the latent state must
        have at least two components.
    """

    def to_numpy(tensor):
        # Detach from autograd and move to host memory before converting.
        return tensor.detach().cpu().numpy()

    def relative_error(reference, estimate):
        # Mean-squared error normalised by the mean square of the reference.
        return np.mean((reference - estimate)**2) / np.mean(reference**2)

    sindy.eval()
    x_test = torch.from_numpy(test_data["x"]).to(device)
    dx_test = torch.from_numpy(test_data["dx"]).to(device)

    with torch.no_grad():
        outputs = sindy(x_test, dx_test)
        x, dx, dz_predict, dz, x_decode, dx_decode, sindy_coeffs = outputs
        # Encode x and dx concatenated along dim 0, then keep only the first
        # x.shape[0] rows (presumably the rows belonging to x).
        stacked = torch.cat((x, dx))
        z = encoder.net(stacked)[:x.shape[0]]

    z0 = to_numpy(z[0])
    t = test_data["t"][:, 0]
    # NOTE: coefficients are used exactly as returned by the model; no
    # coefficient mask is applied here.
    coeffs = to_numpy(sindy_coeffs)
    poly_order = params["poly_order"]
    include_sine = params["include_sine"]

    def latent_rhs(state, _time):
        # Right-hand side for odeint: library(state) @ coeffs, flattened.
        state_tensor = torch.tensor(state[None, :], dtype=torch.float64).to(device)
        library = sindy_library_pt(
            state_tensor, state.shape[0], poly_order, include_sine
        ).cpu().numpy()
        return (library @ coeffs).flatten()

    z_sim = odeint(latent_rhs, z0, t)

    # Plot latent dynamics: encoder output (grey) vs. simulated (dashed),
    # one subplot per latent component.
    z_np = to_numpy(z)
    plt.figure(figsize=(3, 2))
    for component in range(2):
        plt.subplot(2, 1, component + 1)
        plt.plot(z_np[:, component], color='#888888', linewidth=2)
        plt.plot(z_sim[:, component], '--', linewidth=2)
        plt.axis('off')
    plt.show()

    # Phase portrait of the simulated latent trajectory.
    plt.figure(figsize=(3, 3))
    plt.plot(z_sim[:, 0], z_sim[:, 1], linewidth=2)
    plt.axis('equal')
    plt.axis('off')
    plt.show()

    # Compute relative errors
    decoder_x_error = relative_error(to_numpy(x), to_numpy(x_decode))
    decoder_dx_error = relative_error(to_numpy(dx), to_numpy(dx_decode))
    sindy_dz_error = relative_error(to_numpy(dz), to_numpy(dz_predict))

    print(f'Decoder relative error: {decoder_x_error:.6f}')
    print(f'Decoder relative SINDy error: {decoder_dx_error:.6f}')
    print(f'SINDy relative error, z: {sindy_dz_error:.6f}')

In [None]:
# Load hyperparameters; the context manager closes the file handle
# deterministically instead of leaving it to the garbage collector.
with open("params.json", encoding="utf-8") as params_file:
    params = json.load(params_file)

# Rebuild the model with the same architecture used when it was saved.
encoder = AutoEncoder(params, "encoder")
decoder = AutoEncoder(params, "decoder")
sindy = SINDy(encoder, decoder, device=device, params=params).to(device)

# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
sindy.load_state_dict(torch.load("sindy_model.pt", map_location=device))
sindy.eval()  # Optional: evaluate_model() also switches to evaluation mode

# Only the test split is used below; the other splits stay bound as
# notebook globals.
training_data_rd, validation_data_rd, test_data_rd = get_rd_data(random=True)
evaluate_model(sindy, test_data_rd, params, encoder)