In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from itertools import combinations

import torch
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from se3cnn.util.plot import spherical_harmonics_coeff_to_sphere

In [None]:
import plotly.graph_objects as go

In [None]:
import vistools
import otp
import cgae.cgae

In [None]:
def traces_lines(ar, color='red'):
    """Build one plotly 3-D line segment for every unordered pair of points in `ar`.

    Parameters
    ----------
    ar : array-like
        Sequence of points. Each point is unpacked into x, y, z, so each element must
        have length 3 (presumably `ar` has shape (n_points, 3) — TODO confirm callers).
    color : str, optional
        Line colour passed to plotly (default 'red').

    Returns
    -------
    list of go.Scatter3d
        One trace per pair, i.e. n_points * (n_points - 1) / 2 traces.
    """
    def trace_two(pair):
        # `pair` stacks two points; transposing yields the (x, y, z) coordinate pairs.
        x, y, z = pair.T
        return go.Scatter3d(
            x=x,
            y=y,
            z=z,
            mode='lines',
            line=dict(color=color, width=4),
        )

    # `combinations` comes from itertools (imported in the top import cell).
    return [trace_two(np.asarray(pair)) for pair in combinations(ar, 2)]

# Load a saved training run

In [None]:
def load_batched_xyz(pkl):
    """Regenerate and batch the training data described by a saved run.

    Side effect: sets ``pkl['args'].device`` to ``'cpu'`` before calling
    ``otp.data`` and ``otp.batch`` with those args.

    Returns
    -------
    tuple
        ``(n_batches, xyz, forces, features)`` as unpacked from ``otp.batch``.
    """
    run_args = pkl['args']
    # Override the device stored at training time (presumably so data is built on CPU).
    run_args.device = 'cpu'
    raw_xyz, raw_forces, raw_features = otp.data(run_args)
    n_batches, batched_xyz, batched_forces, batched_features = otp.batch(
        raw_xyz, raw_forces, raw_features, run_args
    )
    return n_batches, batched_xyz, batched_forces, batched_features

In [None]:
ls -lht *.pkl

In [None]:
# Select the saved run to analyze here.
PICKLE = 'se_long.pkl'

# NOTE(review): torch.load unpickles arbitrary Python objects (this file carries the
# training `args`), so only load trusted files. On PyTorch >= 2.6 the default
# weights_only=True may refuse such objects and weights_only=False would be needed —
# confirm against the installed torch version.
pkl = torch.load(PICKLE, map_location='cpu')
# Rebuild the batched dataset; only the coordinates and features are kept here.
_, xyz, _, features = load_batched_xyz(pkl)

In [None]:
# Summarize how much training the saved run covers.
steps = len(pkl['dynamics'])
epochs = pkl['dynamics'][-1]['epoch']
print(f"There are {steps} steps.")
print(f"Corresponding to {epochs} epochs.")
# Guard the ratio: the last recorded epoch can be 0 (e.g. if epochs are 0-indexed and the
# run stopped during the first one — TODO confirm indexing), which would raise ZeroDivisionError.
if epochs:
    print(f"i.e. about {steps/epochs} steps/epoch")
else:
    print("i.e. steps/epoch is undefined (last recorded epoch is 0)")

In [None]:
# Training curves: total and force-matching loss on top, autoencoder loss below.
fig, ax = plt.subplots(nrows=2, ncols=1, sharex=True)
loss = [record['loss'] for record in pkl['dynamics']]
loss_ae = [record['loss_ae'] for record in pkl['dynamics']]
loss_fm = [record['loss_fm'] for record in pkl['dynamics']]

ax[0].plot(loss, label='total loss')
ax[0].plot(loss_fm, label='force match')
ax[0].set_ylabel('loss')
ax[0].legend()

ax[1].plot(loss_ae, label='autoenc')
ax[1].set_ylabel('loss')
ax[1].set_xlabel('step')
ax[1].legend()

fig.tight_layout()
# Save BEFORE showing: with the inline backend plt.show() closes the figure,
# so a later plt.savefig would write an empty image.
fig.savefig('loss.png')
plt.show()

# Visualize a single saved summary frame

In [None]:
# Index into pkl['summaries'] and the example index within that summary.
frame, example = 777, 0

In [None]:
# Pull everything needed for the plots out of the chosen summary frame once.
summary = pkl['summaries'][frame]
temp = summary['temp']  # presumably the Gumbel-softmax temperature — TODO confirm
batch = summary['batch']
geo = xyz[batch, example].detach().numpy()
# Soft and straight-through assignments ('gumble' spelling matches the saved keys).
gumble = summary['gumble'][example]
st_gumble = summary['st_gumble'][example]

cg_xyz = summary['cg_xyz'][example].detach().numpy()

print(temp, batch)

In [None]:
# Assignment heatmap, transposed so rows are (presumably) CG beads and columns atoms.
gumble_map = gumble.detach().cpu().numpy().T
n_cg = gumble_map.shape[0]

fig, ax = plt.subplots()
ax.imshow(gumble_map, aspect=4)
ax.set_yticks(np.arange(n_cg))
ax.set_yticklabels([f"CG{i + 1}" for i in range(n_cg)])
ax.set_xlabel('atom')
ax.set_ylabel('CG bead')
ax.set_title(f'Gumbel assignment (frame {frame}, example {example})')
plt.show()

In [None]:
def ylms_to_surface(ylms, center, angular_resolution=100, r_scale=0.5):
    """Convert spherical-harmonic coefficients into a surface placed at `center`.

    Parameters
    ----------
    ylms : array-like
        Spherical-harmonic coefficients for one CG bead, in the format expected by
        ``vistools.sh_coeff_to_xyz_signal``.
    center : array-like
        Passed as the centre to ``vistools.xyz_signal_to_surface``
        (presumably a 3-vector — TODO confirm).
    angular_resolution : int, optional
        Angular sampling passed to ``vistools.sh_coeff_to_xyz_signal`` (default 100).
    r_scale : float, optional
        Radial scale passed to ``vistools.sh_coeff_to_xyz_signal`` (default 0.5).

    Returns
    -------
    The object returned by ``vistools.xyz_signal_to_surface`` (used as a plotly trace).
    """
    xyz_signal = vistools.sh_coeff_to_xyz_signal(
        ylms, angular_resolution=angular_resolution, r_scale=r_scale
    )
    return vistools.xyz_signal_to_surface(xyz_signal, center)

# Colour palette; colormap maps index 0, 1, 2 -> colour name (presumably CG bead index).
colors = ['red', 'green', 'blue']
colormap = dict(enumerate(colors))

In [None]:
# Predicted spherical-harmonic coefficients per CG bead for the chosen example.
# Columns [:36] are plotted as the "Hydrogen signal" and [36:] as the "Carbon signal"
# (36 = (5 + 1)**2, presumably one channel up to l_max = 5 — TODO confirm).
cg_sph = pkl['summaries'][frame]['pred_sph'][example].detach().numpy()
cg_sph_0, cg_sph_1 = cg_sph[:, :36], cg_sph[:, 36:]

In [None]:
# Hydrogen signal: one surface per CG bead from cg_sph_0, plus atoms coloured by assignment.
data = [
    ylms_to_surface(ylms=cg_sph_0[i], center=cg_xyz[i])
    for i in range(len(colors))
]

# Loop-invariant: compute each atom's colour once instead of once per colour.
atom_colors = vistools.assignment_to_color(st_gumble, colormap)
for color in colors:
    mask = atom_colors == color
    data += [vistools.trace_pts(geo[mask.flatten()], color=color)]

fig = go.Figure(data=data)
fig.update_layout(title='Hydrogen signal')
fig.show()

In [None]:
# Carbon signal: one surface per CG bead from cg_sph_1, plus atoms coloured by assignment.
data = [
    ylms_to_surface(ylms=cg_sph_1[i], center=cg_xyz[i])
    for i in range(len(colors))
]

# Loop-invariant: compute each atom's colour once instead of once per colour.
atom_colors = vistools.assignment_to_color(st_gumble, colormap)
for color in colors:
    mask = atom_colors == color
    data += [vistools.trace_pts(geo[mask.flatten()], color=color)]

fig = go.Figure(data=data)
fig.update_layout(title='Carbon signal')
fig.show()