In [1]:
%load_ext autoreload
%autoreload 2

## [Setup + Utils]

In [2]:
import py3Dmol

def traj_to_pdb(traj, downsample=1, tail=20):
    models = ""
    traj = traj[::downsample] + [traj[-1]] * tail
    for i, p in enumerate(tqdm(traj)):
        models += f"MODEL {i + 1}\n"
        models += p.to_pdb_str()
        models += "\nENDMDL\n"
    return models

def plot_py3dmol_traj(traj, window_size=(400, 400), duration=10000):
    # interval = duration / len(traj)
    v = py3Dmol.view(*window_size)
    models = traj_to_pdb(traj)
    v.addModelsAsFrames(models, 'pdb')
    v.setStyle({})
    v.addStyle({'atom': 'CA'}, {'sphere': {'radius': 0.5, 'color': 'darkgray'}})
    v.addStyle({'chain': 'A'}, {'stick': {}})
    v.setBackgroundColor("rgb(0,0,0)", 0)
    v.animate({'loop': 'forward'})
    return v

In [3]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = '3'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"  # add this
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

In [4]:
import sys 
sys.path.append('../../')
import jax
import jax.numpy as jnp
import haiku as hk
import optax
from einops import repeat, rearrange
import functools
from collections import defaultdict
import plotly as plt
import e3nn_jax as e3nn
from tqdm import tqdm
from model.base.utils import inner_stack, inner_split
import functools
from moleculib.graphics.py3Dmol import plot_py3dmol_grid
import numpy as np
from kheiron.pipeline.utils import register_pytree
from moleculib.protein.datum import ProteinDatum
register_pytree(ProteinDatum)
import pickle



## Prepare Data

### -> Choose Dataset

#### 1. Fast Folding

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm
from moleculib.protein.dataset import FastFoldingDataset
import mdtraj as md
dataset = FastFoldingDataset(protein="chignolin", tau=0, buffer=100)

In [None]:
from torch.utils.data import DataLoader
x0 = dataset[0]
batch_size = 32
pad_size = len(x0)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=min(batch_size, 32), collate_fn=lambda x: x)

#### 2. ATLAS EIF4

In [None]:
from moleculib.traj.dataset import AtlasEIF4EDataset
dataset = AtlasEIF4EDataset(base_path='/mas/projects/molecularmachines/db/PREPROCESSED')

#### 3. AdK Equilibrium

In [None]:
from moleculib.traj.dataset import AdKEquilibriumDataset
dataset = AdKEquilibriumDataset(base_path='/mas/projects/molecularmachines/db/PREPROCESSED')

#### 4. Timewarp

In [5]:
from moleculib.traj.dataset import TimewarpDataset

dataset = TimewarpDataset(
    dataset="2AA-1-big",
    split="train",
    tau=0,
    max_files=1,
)
dataset.splits = { 'train': [dataset[i] for i in tqdm(range(1000))]}

100%|██████████| 1000/1000 [00:08<00:00, 115.59it/s]


### -> Preprocess

In [7]:
from kheiron.pipeline.trainer import batch_dict

x0 = dataset.splits['train'][1]
batch_size = 16 
pad_size = len(x0)
dataloader = [ batch_dict([dataset.splits['train'][i].to_dict() for i in range(j, j+batch_size) if i < len(dataset.splits['train'])]) for j in range(0, len(dataset.splits['train']), batch_size)]

In [8]:
dataset.splits['train'][0]

<moleculib.protein.datum.ProteinDatum at 0x7f6983af4df0>

In [10]:
from moleculib.graphics.py3Dmol import plot_py3dmol_grid
plot_py3dmol_grid([[dataset.splits['train'][1]]])

<py3Dmol.view at 0x7f6982508be0>

## Model Setup 

In [26]:

from functools import partial
from typing import List

from model.base.protein import protein_to_tensor_cloud, tensor_cloud_to_protein
# from model.base.denoiser import Denoiser
from model.generative.stochastic_interpolants import TensorCloudMirrorInterpolant
from configs.mirror_interpolant import ProteinMirrorInterpolant


EPS = jnp.array([1.0, 3.0])

import functools
import math
import haiku as hk
import jax.numpy as jnp
import e3nn_jax as e3nn
from typing import List
from model.base.spatial_convolution import CompleteSpatialConvolution, kNNSpatialConvolution
from model.base.self_interaction import SelfInteraction
from model.base.utils import TensorCloud
from model.base.layer_norm import EquivariantLayerNorm
import jax


model_params = dict(
    coord_layers=[16, 16, 16],
    feature_layers=[16, 16],
    k=8,
    k_seq=16,
    radial_cut=32.0,
    leading_shape=(pad_size,),
    var_features=0.5,
    var_coords=0.5,
)

model = hk.transform(lambda *a, **ka: ProteinMirrorInterpolant(**model_params)(*a, **ka))
sample = hk.transform(lambda *a, **ka: ProteinMirrorInterpolant(**model_params).sample(*a, **ka))
rng_seq = hk.PRNGSequence(42)
params = model.init(next(rng_seq), batch_dict([x0.to_dict()]))
# num_params = sum([jnp.prod(p.shape) for p in jax.tree_leaves(params)])
# print(f"Number of parameters: {num_params}")

In [39]:
def batch_dict(list_):
    keys = list_[0].keys()
    return {k: jnp.stack([d[k] for d in list_]) for k in keys if list_[0][k] is not None}


def unbatch_dict(d):
    keys = d.keys()
    return [{k: d[k][i] for k in keys} for i in range(d[list(keys)[0]].shape[0])]


@functools.partial(jax.jit, static_argnums=(2,))
def _simulate(key, prot, interp_steps=100, eps=jnp.array([1.0,1.0])):
    keys = jax.random.split(key, prot['atom_coord'].shape[0]) 
    return jax.vmap(lambda k, p: sample.apply(params, k, p, num_steps=1, interp_steps=interp_steps, eps=eps))(keys, prot)

def rollout(prot, num_steps, interp_steps, eps):
    traj = []
    prot = batch_dict(prot)
    for _ in tqdm(range(num_steps)):
        prot, _ = _simulate(next(rng_seq), prot, interp_steps=interp_steps, eps=eps)
        traj.append(unbatch_dict(prot))
    prot = unbatch_dict(prot)
    return prot, traj

_ = rollout([x0.to_dict()], num_steps=1, interp_steps=300, eps=jnp.array([1.0,1.0]))

(1, 2, 14, 3)


100%|██████████| 1/1 [00:22<00:00, 22.62s/it]


In [40]:
jax.config.update('jax_disable_jit', False)
jax.config.update('jax_debug_nans', False)

optimizer = optax.adam(1e-2, 0.9, 0.999)
opt_state = optimizer.init(params)

from model.losses import LossPipe
from model.losses import MirrorInterpolantLoss

loss_fn = LossPipe(
    loss_list=[
        MirrorInterpolantLoss(),
    ]
)

def loss(params, rng, data):
    def _apply_loss(key, datum):
        output = model.apply(params, key, datum)
        _, loss_, metrics = loss_fn(key, output, datum, 0)
        return loss_, {'loss': loss_, ** metrics }
    rng_keys = jax.random.split(rng, len(data.residue_token))
    losses, metrics = jax.vmap(_apply_loss)(rng_keys, data)
    metrics = { k: jnp.mean(v) for k, v in metrics.items() }
    return jnp.mean(losses), [metrics]

@jax.jit
def update(rng, params, opt_state, data) -> List:
    key, rng = jax.random.split(rng, 2)
    grads, [metrics] = jax.grad(
        loss,
        has_aux=True,
        argnums=0,
    )(params, key, data)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, metrics


## Train 

In [41]:
num_epochs = 10
metrics = defaultdict(list)

In [42]:
total_step = 0
import pickle
for epoch in range(num_epochs):
    # bar = tqdm(dataloader * 50)
    bar = tqdm(dataloader)
    for step, data in enumerate(bar):
        if len(data) != batch_size: continue
        data = batch_dict([batch_dict([d.to_dict()]) for d in data])
        params, opt_state, step_metrics = update(
            next(rng_seq), params, opt_state, data
        )
        for k, v in step_metrics.items(): metrics[k].append(v)
        bar.set_postfix({k: f'{float(v):.2e}' for k, v in step_metrics.items()})

        if (total_step % 1000) == 0:
            # out, traj = rollout(data[0], 1, interp_steps=100, eps=EPS)
            # plot_py3dmol_traj(traj).show()
            with open('p.pkl', 'wb') as f: pickle.dump(params, f)
        total_step += 1

100%|██████████| 63/63 [00:00<00:00, 355640.85it/s]
100%|██████████| 63/63 [00:00<00:00, 650840.28it/s]
100%|██████████| 63/63 [00:00<00:00, 636725.67it/s]
100%|██████████| 63/63 [00:00<00:00, 665594.84it/s]
100%|██████████| 63/63 [00:00<00:00, 431766.59it/s]
100%|██████████| 63/63 [00:00<00:00, 451694.28it/s]
100%|██████████| 63/63 [00:00<00:00, 526376.80it/s]
100%|██████████| 63/63 [00:00<00:00, 400972.92it/s]
100%|██████████| 63/63 [00:00<00:00, 416127.80it/s]
100%|██████████| 63/63 [00:00<00:00, 542589.63it/s]


In [None]:
import matplotlib.pyplot as plt
downsample = 5
window = 100

fig, axs = plt.subplots(len(metrics.items()), 1, figsize=(8, 6))
for i, (metric, array) in enumerate(metrics.items()):
    array = np.array(array)[::downsample]
    axs[i].plot(array - np.min(array))
    axs[i].set_ylabel(metric)
    axs[i].set_yscale('log')
fig.show()

## Simulate

In [None]:
out, traj = rollout(x0, 1, internal_traj=True)
# plot_py3dmol_traj(traj[::5], window_size=(400, 400))

In [None]:
out, traj = rollout(x0, 500, interp_steps=100, eps=jnp.array([1.0, 1.0]))

In [None]:
trajs = []
# for _ in range(10):
out, traj = rollout(out, 3000, interp_steps=300, eps=jnp.array([1.0, 1.0]))
    # trajs.append(traj)


In [None]:
# traj = [ x for t in trajs for x in t ]

In [None]:
# traj[0].residue_token
plot_py3dmol_traj(traj[-100:], window_size=(400, 400))

In [None]:
# out.to_pdb_str()
# 
# repeat(np.arange(len(out.residue_token)), "r -> r a", a=14)[out.atom_mask]
# expand it with standard np as a new dimension


In [None]:
len(out.residue_token)

In [None]:
v = plot_py3dmol_traj(traj[::4])
v.show()
html = v._make_html()

html_path = os.path.join(f"chignolin.html")
with open(html_path, "w") as f:
    f.write(html)

In [None]:
# real_traj = dataset.splits['train'][:len(traj)]
real_traj = dataset.splits['train']

## Benchmarks

In [None]:
import mdtraj as md
from tempfile import gettempdir
import os 

def protein_datum_to_md_traj(traj, downsample=1):
    pdbstr = traj_to_pdb(traj, tail=0)
    with open(f'{gettempdir()}/tmp.pdb', 'w') as f:
        f.write(pdbstr)
    mdtraj_datum = md.load(f'{gettempdir()}/tmp.pdb')
    os.remove(f'{gettempdir()}/tmp.pdb')
    return mdtraj_datum


md_traj = protein_datum_to_md_traj(traj)
real_md_traj = protein_datum_to_md_traj(dataset.splits['train'][:1000])

### Evaluate Chemistry

In [None]:
from moleculib.protein.metrics import StandardChemicalDeviation

def measure_chemistry(traj):
    metrics = defaultdict(list)
    chem = StandardChemicalDeviation()
    for step in tqdm(range(0, len(traj), 4)):
        metrics_ = chem(traj[step])
        for k, v in metrics_.items():
            metrics[k].append(v)
    return metrics

real_metrics = measure_chemistry(real_traj)
fake_metrics = measure_chemistry(traj)

In [None]:
import matplotlib.pyplot as plt
window = 100
fig, axs = plt.subplots(len(fake_metrics.items()), 1, figsize=(8, 10))
# for i, (metric, array) in enumerate(real_metrics.items()):
    # axs[i].plot(array)
for i, (metric, array) in enumerate(fake_metrics.items()):
    axs[i].set_title(metric)
    axs[i].plot(array)
fig.legend(['real', 'fake'])
fig.tight_layout()
fig.show()

In [None]:
# data_transform = plat.cfg['trainer']
from hydra_zen import instantiate


### Load Pretrained Model

In [None]:
from kheiron.pipeline.registry import Registry

registry_path = '/mas/projects/molecularmachines/experiments/generative/allanc3'
registry = Registry('ophiuchus', registry_path)
model_name = 'youthful-shape-1240'
from functools import reduce

plat = registry.get_platform(model_name, read_only=True)

data_transforms = instantiate(plat.cfg['trainer']['dataset']['transform'])
data_transform = lambda datum: reduce(lambda x, f: f.transform(x), data_transforms, datum)

params = plat.get_params(-1)
premodel = plat.instantiate_model()
transform = hk.transform(lambda *a, **ka: premodel().sample(*a, **ka))

@functools.partial(jax.jit, static_argnums=(1,))
def _simulate(prot, interp_steps=100):
    return transform.apply(params, next(rng_seq), prot, num_steps=1, interp_steps=interp_steps)

def rollout(prot, num_steps, interp_steps, internal_traj=False):
    traj = []
    for _ in tqdm(range(num_steps)):
        out, int_traj = _simulate(prot, interp_steps=interp_steps)
        if internal_traj:
            internals = [tensor_cloud_to_protein(tc, prot) for tc in inner_split(inner_split(int_traj)[0])]
            traj.extend(internals)
        else:
            traj.append(out)
        prot = out
    return out, traj


rng_seq = hk.PRNGSequence(42)
x0 = data_transform(dataset.splits['train'][1])

x1, traj = rollout(x0, 50, 1000)
plot_py3dmol_traj(traj, window_size=(400, 400))
# rename paths so that the first folder, of name "folder" gets a "_1" appended to iter

In [None]:
from typing import Dict, Any
def compute_rmsd(traj: md.Trajectory, ref: md.Trajectory) -> float:
    """Calculate the RMSD (in Angstroms)"""
    return 10 * md.rmsd(traj, ref)


def compute_rmsf(traj: md.Trajectory, ref: md.Trajectory, selector: str = 'name CA') -> float:
    """Calculate the RMSF on alpha-carbons (in Angstroms)"""
    # Select the alpha-carbon atoms
    indices = ref.topology.select(selector)
    return 10 * md.rmsf(traj.atom_slice(indices), ref.atom_slice(indices))


def compute_radius_of_gyration(traj: md.Trajectory) -> float:
    """Calculate the radius of gyration (in Angstroms)"""
    return 10 * md.compute_rg(traj)


def compute_SASA(traj: md.Trajectory) -> float:
    """Calculate the solvent accessible surface area (in Angstroms^2)"""
    return 100 * md.shrake_rupley(traj, mode='residue')




def compute_contact_maps(traj: md.Trajectory, contacts: str, scheme: str, distance_cutoff: float) -> np.ndarray:
    distances, residue_pairs = md.compute_contacts(traj, contacts=contacts, scheme=scheme, ignore_nonprotein=True)
    distances = distances * 10  # C,nvert to Angstroms
    contact_maps = md.geometry.squareform(distances, residue_pairs)
    contact_maps[contact_maps > distance_cutoff] = 0
    return contact_maps


def compute_KL_divergence(hist1: np.ndarray, hist2: np.ndarray) -> float:
    """Compute the Kullback-Leibler divergence between two histograms."""
    div = hist1 * np.log(hist1 / hist2)
    div = np.where(hist1 == 0, 0, div)
    return np.sum(div)


def compute_JS_divergence(hist1: np.ndarray, hist2: np.ndarray) -> float:
    """Compute the Jensen-Shannon divergence between two histograms."""
    mix = (hist1 + hist2) / 2
    return (compute_KL_divergence(hist1, mix) + compute_KL_divergence(hist2, mix)) / 2
            

def norm_of_difference_in_CA_positions(traj: md.Trajectory, diff_index: int) -> np.ndarray:
    """Compute the norm of the difference in alpha-carbon positions between timesteps separated by `diff_index`."""
    # Select the alpha-carbon (CA) atoms.
    ca_indices = traj.topology.select('name CA')
    num_residues = len(ca_indices)
    num_frames = traj.n_frames

    alpha_carbon_positions = traj.xyz[:, ca_indices, :]
    assert alpha_carbon_positions.shape == (num_frames, num_residues, 3)

    # Convert to Angstroms.
    alpha_carbon_positions = alpha_carbon_positions * 10

    # Take the difference in CA coordinates between timesteps separated by `diff_index`.
    alpha_carbon_positions_diff = alpha_carbon_positions[diff_index:] - alpha_carbon_positions[:-diff_index]
    assert alpha_carbon_positions_diff.shape == (num_frames - diff_index, num_residues, 3)

    # Compute the norm of the difference vector.
    alpha_carbon_positions_norm = np.linalg.norm(alpha_carbon_positions_diff, axis=-1)
    assert alpha_carbon_positions_norm.shape == (num_frames - diff_index, num_residues)

    alpha_carbon_positions_sq_norm = np.square(alpha_carbon_positions_norm)
    assert alpha_carbon_positions_sq_norm.shape == (num_frames - diff_index, num_residues)

    alpha_carbon_positions_rmsf = np.sqrt(np.mean(alpha_carbon_positions_sq_norm, axis=-2))
    return alpha_carbon_positions_rmsf


def compute_metrics(traj: md.Trajectory, 
                    contact_map_distance_cutoff: float = 8.0,
                    superpose: bool = True,
                    ref: md.Trajectory = None) -> Dict[str, Any]:
    metrics = {}    


    """Compute the RMSD, RMSF, radius of gyration, phi and psi angles, and contact map of the trajectory."""
    # Align the trajectory to the reference structure
    if ref is not None:
        if superpose:
            print("Aligning to the reference structure")
            traj = traj.superpose(ref)
        else:
            print("Not aligning to the reference structure")

        rmsd = compute_rmsd(traj, ref)
        rmsf = compute_rmsf(traj, ref)
        gyration_radius = compute_radius_of_gyration(traj)

        metrics = {
            'rmsd': rmsd,
            'rmsf': rmsf,
            'gyration_radius': gyration_radius,
        }
    


### Measure Ramachandran

In [None]:
def compute_phi_and_psi_angles(traj: md.Trajectory) -> float:
    # filter_backbone 
    # traj = traj.atom_slice(traj.top.select('backbone'))
    _, phi_angles = md.compute_phi(traj)
    _, psi_angles = md.compute_psi(traj)

    phi_angles = np.rad2deg(phi_angles).flatten() # dropping the first one?
    psi_angles = np.rad2deg(psi_angles).flatten()
    return phi_angles, psi_angles

def compute_phi_psi_histogram(phi_angles: np.ndarray, psi_angles: np.ndarray) -> np.ndarray:
    return np.histogram2d(
        phi_angles, psi_angles, bins=np.linspace(-180, 180, 60), density=True
    )[0]



def compute_KL_divergence(hist1: np.ndarray, hist2: np.ndarray) -> float:
    """Compute the Kullback-Leibler divergence between two histograms."""
    hist2 = np.where(hist2 == 0, 1, hist2)
    div = hist1 * np.nan_to_num(np.log(hist1 / hist2))
    div = np.where(hist1 == 0, 0, div)
    return np.sum(div)

hists =[]


for label, traj_ in zip(('Reference', 'Model'), (real_md_traj, md_traj)):
    phi_angles, psi_angles = compute_phi_and_psi_angles(traj_)

    plt.hist2d(
        phi_angles, psi_angles, bins=60, density=True,
        range=[[-180, 180], [-180, 180]],
    ) 
    plt.gca().set_aspect('equal', adjustable='box')
    plt.xlabel('$\phi$', fontsize=20)
    plt.ylabel('$\psi$', fontsize=20)

    plt.xlim(-180, 180)
    plt.xticks(np.arange(-180, 181, 60))

    plt.ylim(-180, 180)
    plt.yticks(np.arange(-180, 181, 60))
    
    plt.title(label)

    plt.show()

    hists.append(
        compute_phi_psi_histogram(phi_angles, psi_angles)
    )


js = compute_JS_divergence(*hists)
print(f"JS divergence: {js:.5f}")

kl = compute_KL_divergence(*hists)
print(f"KL divergence: {kl:.5f}")



### Pairwise Distance Distribution

In [None]:

contact_map_distance_cutoff = 8.0
for label, traj_ in zip(('Reference', 'Model'), (real_md_traj, md_traj)):
    c_map = compute_contact_maps(
        _traj, contacts='all', scheme='ca', distance_cutoff=contact_map_distance_cutoff
    )

    z = c_map[0]
    plt.imshow(
        z,
        cmap=plt.cm.viridis,
    )
    plt.gca().set_aspect('equal', adjustable='box')

    plt.xlabel('Residue index', fontsize=13)
    plt.ylabel('Residue index', fontsize=13)
    
    plt.title(label)

    plt.show()

In [None]:
def get_pwd(traj):
    distances = md.geometry.contact.squareform(*md.compute_contacts(traj, ignore_nonprotein=True))
    return distances.mean(0)[np.triu_indices_from(distances[0], k=3)]

compute_JS_divergence(get_pwd(real_traj), get_pwd(md_traj))

In [None]:
def compute_velocity(traj):
    ca_velocities = []
    sc_velocities = []
    for i in range(1, len(traj)):
        v = traj[i].atom_coord - traj[i-1].atom_coord
        ca_velocities.append(np.linalg.norm(v[..., 1, :], axis=-1))
        sc_velocities.append(np.linalg.norm(v[..., 2:, :], axis=-1))
    ca_velocities = np.stack(ca_velocities).mean(-1)
    sc_velocities = np.stack(sc_velocities).mean(-1)
    return ca_velocities, sc_velocities

ca_velocities, sc_velocities = compute_velocity(traj)
# import matplotlib.pyplot as plt
plt.plot(ca_velocities)

### TICA Analysis

In [None]:
import mdtraj
from pathlib import Path
# get the xtc path and get its directory
reference = Path(dataset.files[0])
reference_dir = reference.parent
reference = mdtraj.load(reference_dir / 'filtered.pdb')
ca_indices = reference.top.select('name CA')

In [None]:
from einops import rearrange

real_traj_ca_coord = dataset.coords[:, ca_indices]
pairwise_distances = np.linalg.norm(rearrange(real_traj_ca_coord, 'b n d -> b n () d') - rearrange(real_traj_ca_coord, 'b m d -> b 1 m d'), axis=-1)
pairwise_distances = pairwise_distances.reshape(-1, pairwise_distances.shape[-1] ** 2)

# real_traj_ca_coord.shape

In [None]:
ca_coords = [step.atom_coord[..., 1, :] for step in traj]
pairwise_distances_model = np.linalg.norm(rearrange(ca_coords, 'b n d -> b n () d') - rearrange(ca_coords, 'b m d -> b 1 m d'), axis=-1)
pairwise_distances_model = pairwise_distances_model.reshape(-1, pairwise_distances_model.shape[-1] ** 2)

In [None]:

import matplotlib.pyplot as plt
import numpy as np

from pyemma.coordinates import tica
real_traj = dataset.splits['train']

# print(len(real_traj))


tic1, tic2 = tica(pairwise_distances, lag=100, dim=2).get_output()[0].T



In [None]:
plt.figure(figsize=(5, 5))
import matplotlib as mpl
plt.hist2d(tic1, tic2, bins=60, cmap='viridis', density=True, norm=mpl.colors.LogNorm())
# plt.gca().set_aspect('equal', adjustable='box')

## Random

In [None]:
# from moleculib.graphics.py3Dmol import plot_py3dmol_grid
# from moleculib.protein.datum import ProteinDatum
# from model.base.protein import protein_to_tensor_cloud
# from model.generative.stochastic_interpolants import NormalDistribution
# from model.base.utils import TensorCloud

# # trp_cage1 = ProteinDatum.fetch_pdb_id('1l2y', model=3)
# trp_cage2 = ProteinDatum.fetch_pdb_id('1l2y', model=6)
# plot_py3dmol_grid([[trp_cage2]], (500, 500),colors=colors).show()

# # trp_cage = trp_cage[-6:-3]
# # prng = hk.PRNGSequence(42)
# # error = NormalDistribution('14x1e', coords_scale=3, irreps_scale=0.7).sample(next(prng), (20, ))

# # colors = list(Color("violet").range_to(Color("red"), 20))
# # colors = [c.hex for c in colors]

# # plot_py3dmol_grid([[trp_cage]], (500, 500), colors=colors).show()
# plot_py3dmol_grid([[protein_to_tensor_cloud(trp_cage2)]], (500, 500), radius=0.1, mid=0.8, colors=colors).show()
# # from colour import Color


# # latent = protein_to_tensor_cloud(trp_cage)
# # plot_py3dmol_grid([[error]], (500, 500), radius=0.1, mid=0.8, colors=colors).show()

# # plot_py3dmol_grid([[1.0 * error + latent.replace(irreps_array=latent.irreps_array.filter('1e'))]], (500, 500), radius=0.1, mid=0.8, colors=colors).show()

