In [3]:
%load_ext autoreload
%autoreload 2

import os
# These variables must be set before JAX is imported (next cell) to take effect.
# Enumerate GPUs in PCI bus order so the index below is stable (see issue #152).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Pin this notebook to physical GPU 2 — host-specific, adjust for your machine.
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

# Do not let XLA preallocate most of the GPU memory; allocate on demand instead.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"


In [4]:

import sys
# Make the parent directory (repository root) importable so `model.*` resolves.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path: sys.path.append(module_path)


import chex
import pickle
from typing import NamedTuple, Any, List
import plotly.graph_objects as go
import jax
import jax.numpy as jnp
import haiku as hk
import optax
from einops import repeat, rearrange
import functools
from fastprogress import master_bar
from collections import defaultdict
import plotly as plt
import math
import e3nn_jax as e3nn
from tqdm import tqdm
from colour import Color
from model.base.utils import inner_stack, inner_split

from functools import reduce



## Tyrosine Autoencoder

In [None]:
from moleculib.protein.dataset import MonomerDataset
from moleculib.protein.transform import DescribeChemistry, MaybeMirror

# Root of the local PDB mirror. The default is a host-specific absolute path;
# set the PDB_PATH environment variable to point at another copy.
PDB_PATH = os.environ.get('PDB_PATH', '/mas/projects/molecularmachines/db/PDB')

# Monomers with max_resolution 1.5 (presumably Å) and 16-64 residues.
protein_ds = MonomerDataset(
    base_path=PDB_PATH,
    attrs="all",
    max_resolution=1.5,
    min_sequence_length=16,
    max_sequence_length=64,
    frac=1.0,
    transform=[
        MaybeMirror(hand="right"),
        DescribeChemistry(),
    ]
)

In [None]:
from moleculib.protein.datum import ProteinDatum 
from moleculib.protein.alphabet import all_residues

target_token = all_residues.index('TYR')


def _cut_residue(protein, index):
    """Return a one-residue ``ProteinDatum`` holding residue ``index`` of ``protein``."""
    window = slice(index, index + 1)

    # Slicing a NumPy array yields a view: copy the masks before editing them,
    # otherwise the writes below leak back into `protein`'s own masks.
    bonds_mask = protein.bonds_mask[window].copy()
    # Disable the last bond slot and the last two angle slots of the cut residue
    # (presumably the entries reaching into the neighbouring residue — TODO confirm).
    bonds_mask[..., -1] = False

    angles_mask = protein.angles_mask[window].copy()
    angles_mask[..., -1] = False
    angles_mask[..., -2] = False

    return ProteinDatum(
        protein.idcode,
        protein.resolution,
        sequence=protein.sequence[window],
        residue_token=protein.residue_token[window],
        residue_index=protein.residue_index[window],
        residue_mask=protein.residue_mask[window],
        chain_token=protein.chain_token[window],
        atom_token=protein.atom_token[window],
        atom_coord=protein.atom_coord[window],
        atom_mask=protein.atom_mask[window],
        flips_list=protein.flips_list[window],
        flips_mask=protein.flips_mask[window],
        # Atom indices appear to be flattened as residue * 14 + atom slot; shift
        # them so they address the cut residue's atoms (assumption — TODO confirm).
        bonds_list=protein.bonds_list[window] - index * 14,
        bonds_mask=bonds_mask,
        angles_list=protein.angles_list[window] - index * 14,
        angles_mask=angles_mask,
    )


residue_ds = []
print('Collecting Tyrosines')
for protein in tqdm(protein_ds):
    for index in range(len(protein.residue_token)):
        if int(protein.residue_token[index]) != target_token:
            continue
        # Keep only residues that are present and have their CA atom (slot 1).
        if not protein.residue_mask[index].all():
            continue
        if not protein.atom_mask[index][1].all():
            continue
        residue_ds.append(_cut_residue(protein, index))

In [None]:
# Number of tyrosine residues collected.
len(residue_ds)

In [None]:
# Sanity check: bond indices of the last collected residue after re-basing.
residue_ds[-1].bonds_list

In [None]:
import pickle 
from copy import deepcopy

# Cache the collected residues so the dataset scan can be skipped on later runs.
TYROSINE_CACHE = 'tyrosine.pkl'
with open(TYROSINE_CACHE, 'wb') as cache_file:
    pickle.dump(residue_ds, cache_file)

In [5]:
import pickle 
from copy import deepcopy

# Reload the cached residues. pickle can execute code while loading, so only
# open cache files this notebook wrote itself.
TYROSINE_CACHE = 'tyrosine.pkl'
with open(TYROSINE_CACHE, 'rb') as cache_file:
    residue_ds = pickle.load(cache_file)

In [None]:
from moleculib.graphics.py3Dmol import plot_py3dmol_grid

# Render the first two residues side by side in a single grid row.
first_row = [residue_ds[k] for k in (0, 1)]
plot_py3dmol_grid([first_row])

In [None]:
def make_loader(ds, n):
    """Split ``ds`` into consecutive batches of ``n`` items; the last one may be shorter.

    Order is preserved (no shuffling). Works for any sliceable sequence.
    """
    num_batches = (len(ds) + n - 1) // n  # ceiling division
    batches = []
    for batch_idx in range(num_batches):
        start = batch_idx * n
        batches.append(ds[start:start + n])
    return batches
batch_size = 64
# The residue list is repeated 10x, so one pass over `dataloader` sees each residue ten times.
# NOTE(review): there is no shuffling — every epoch visits the batches in the same order.
dataloader = make_loader(residue_ds * 10, batch_size)

In [None]:
from model.base.self_interaction import SelfInteraction
from model.base.utils import InternalState
from moleculib.protein.datum import ProteinDatum
import chex 

@chex.dataclass
class ModelOutput:
    """Container returned by the autoencoders in this notebook.

    Every field defaults to None; each model fills in only what it produces.
    """
    # Predicted per-atom vectors; not populated by the models defined in this notebook.
    v1_hat: e3nn.IrrepsArray = None
    # Intermediate state exposed for analysis (read later as `.hidden.irreps_array`).
    # NOTE(review): annotated as a list, but the models here store a single state.
    hidden: List[InternalState] = None
    # Reconstructed datum.
    datum: ProteinDatum = None
    # Auxiliary loss reported by the decoder (set by OphiuchusAutoencoder only).
    atom_perm_loss: jnp.ndarray = None

In [None]:
from copy import deepcopy

class OrderedAutoencoder(hk.Module):
    """Per-residue E(3)-equivariant autoencoder over atom offsets from CA.

    The input features are the 14 atom offsets relative to CA (atom slot 1) plus a
    32-dim residue-type embedding. They pass through two ``SelfInteraction``
    stacks (first half of ``layers`` as encoder, second half as decoder). The
    model then reads back 14 vectors, which are re-anchored at CA.

    Args:
        layers: multiplicity of ``basis`` for each self-interaction layer.
        basis: irreps repeated ``dim`` times per layer.
        vocab_size: number of rows in the residue-token embedding table; it must
            exceed every token id fed in. The default is assumed to be the size
            of the residue alphabet — TODO confirm against
            ``len(moleculib.protein.alphabet.all_residues)``.
    """

    def __init__(self, layers, basis=e3nn.Irreps('0e + 1e + 1e'), vocab_size=23):
        super().__init__()
        self.basis = basis
        self.layers = [dim * self.basis for dim in layers]
        self.vocab_size = vocab_size

    def __call__(self, x: ProteinDatum):
        # Offsets of every atom from CA (slot 1); absent atoms contribute zero.
        ca_pos = x.atom_coord[..., 1:2, :]
        aa_pos = jnp.where(
            x.atom_mask[..., None],
            x.atom_coord - ca_pos,
            0
        )

        # The table needs one row per token id. A 1-row table made every id > 0
        # out of range (JAX gathers do not raise on that), so the embedding
        # carried no residue identity.
        v0 = e3nn.IrrepsArray('32x0e', hk.Embed(self.vocab_size, 32)(x.residue_token))
        v1 = e3nn.IrrepsArray('14x1e', jnp.array(rearrange(aa_pos, '... a c -> ... (a c)')))
        v = e3nn.concatenate((v0, v1), axis=-1)

        # Single point at the origin, since all features are CA-relative.
        state = InternalState(
            irreps_array=v,
            mask_irreps_array=jnp.array(x.atom_mask[..., 1]),
            coord=jnp.array([[0, 0, 0]]),
            mask_coord=jnp.array([1]).astype(jnp.bool_),
        )

        # Encoder half; its output is exposed as `hidden`.
        hidden = SelfInteraction(
            self.layers[:len(self.layers) // 2],
            full_square=False,
            chunk_factor=0,
            norm_last=False,
        )(state)

        state = SelfInteraction(
            self.layers[len(self.layers) // 2:],
            full_square=True,
            chunk_factor=0,
            norm_last=False,
        )(hidden)

        # Read out 14 offset vectors and translate them back to CA.
        v1_hat = e3nn.haiku.Linear('14x1e')(state.irreps_array).array

        datum = deepcopy(x)
        datum.atom_coord = jnp.array(rearrange(v1_hat, '... (a c) -> ... a c', a=14, c=3)) + ca_pos

        return ModelOutput(datum=datum, hidden=hidden)

    

# Layer multiplicities: first half encodes, second half decodes.
layers = [8, 8, 8]


def _ordered_forward(*args):
    return OrderedAutoencoder(layers)(*args)


model = hk.transform(_ordered_forward)

metrics = defaultdict(list)

# Seeded key stream so parameter initialisation is reproducible.
rng_seq = hk.PRNGSequence(42)
params = model.init(next(rng_seq), residue_ds[0])

optimizer = optax.adam(learning_rate=1e-2, b1=0.9, b2=0.999)
opt_state = optimizer.init(params)


In [None]:
from functools import partial
from typing import Any, List, Tuple
import jax
import jax.numpy as jnp
import haiku as hk
import e3nn_jax as e3nn
from model.base.self_interaction import SelfInteraction
from einops import rearrange, repeat
from moleculib.protein.alphabet import (
    all_residues_atom_mask,
    all_residues_atom_tokens,
    flippable_arr,
    flippable_mask,
)

from model.base.utils import safe_normalize
import jax.lax.linalg as lax_linalg

@jax.custom_jvp
def eig(a):
    """Eigendecomposition of a general square matrix with a custom JVP.

    Returns:
        (w, v): eigenvalues and right eigenvectors (columns of ``v``) as
        produced by ``jax.lax.linalg.eig``. The derivative rule is attached
        with ``eig.defjvp``.
    """
    # Only right eigenvectors are needed; skip computing the left ones.
    w, vr = lax_linalg.eig(a, compute_left_eigenvectors=False)
    return w, vr

from jax.numpy.linalg import solve

import jax.lax as lax

@eig.defjvp
def eig_jvp_rule(primals, tangents):
    """Forward-mode derivative of ``eig`` for eigenvalues and eigenvectors.

    With A = V diag(w) V^{-1} and tangent dA:
        dw = diag(V^{-1} dA V)
        dV = V (F * (V^{-1} dA V)) - V * corrections
    where F_ij = 1 / (w_j - w_i) off the diagonal and 0 on it. The correction
    removes from each eigenvector's tangent its component along that
    eigenvector, keeping its norm fixed to first order (assumes unit-norm
    eigenvectors — TODO confirm for the backend in use). Repeated eigenvalues
    make F divide by ~0, so the tangents are not finite there.
    """
    a, = primals
    da, = tangents

    w, v = eig(a)

    eye = jnp.eye(a.shape[-1], dtype=a.dtype)
    # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
    # Adding `eye` keeps the diagonal at 1/1, and subtracting it afterwards zeros it;
    # only the diagonal is protected, not (near-)degenerate off-diagonal pairs.
    Fmat = (jnp.reciprocal(eye + w[..., jnp.newaxis, :] - w[..., jnp.newaxis])
            - eye)
    # Batched inputs (ndim > 2) use batch_matmul; full precision for the products.
    dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
                  precision=lax.Precision.HIGHEST)
    # V^{-1} dA V, computed with a solve instead of an explicit inverse.
    vinv_da_v = dot(solve(v, da), v)
    du = dot(v, jnp.multiply(Fmat, vinv_da_v))
    # Per-column v_j^H du_j: the part of each eigenvector tangent along itself.
    corrections = (jnp.conj(v) * du).sum(-2, keepdims=True)
    dv = du - v * corrections
    dw = jnp.diagonal(vinv_da_v, axis1=-2, axis2=-1)
    return (w, v), (dw, dv)




from model.base.protein import ProteinDatumEncoder, ProteinDatumDecoder


def sqrt_2tensor(y):
    """Recover a vector x (up to sign) from a degree-2 irrep y.

    Intended inverse relations:
        sqrt_2tensor(e3nn.sh(2, x, False)) == x or -x
        e3nn.sh(2, sqrt_2tensor(y), False) == y

    Method: the rows of A = generators(2) @ y are L_i y. A 2e tensor built
    from x is invariant under rotations about x, so sum_i x_i L_i y = 0.
    Hence x is a null vector of conj(A) @ A.T, and the eigenvector with the
    smallest eigenvalue gives its direction. That eigenvalue measures how far
    y is from an exact sh(2, x); it is returned so callers can use it as a
    loss.

    Args:
        y: array of shape (5,), one 2e irrep.

    Returns:
        (x, residual): x has shape (3,). Its length is set from
        mean(y**2) ** (1/4), computed with small clamps for gradient safety;
        whether this matches |x| exactly depends on e3nn's sh normalisation —
        TODO confirm. residual is the smallest eigenvalue.

    NOTE(review): `safe_eigh` is not defined or imported in the visible notebook,
    so this raises NameError on a fresh kernel unless it comes from elsewhere.
    The code assumes it returns ascending eigenvalues with eigenvectors as
    columns.
    """
    assert y.shape == (5,)
    A = e3nn.generators(2) @ y  # presumably (3, 5): row i is L_i y
    A = jnp.conj(A) @ A.T  # (3, 3) Gram matrix; x spans its null space
    val, vec = safe_eigh(A)
    x = vec.T[0]  # first is the smallest eigenvalue
    safe_sqrt = lambda x: jnp.sqrt(jnp.maximum(x, 1e-7))
    # Fourth root of the mean square: y is quadratic in x, so this scales linearly with |x|.
    x = x * safe_sqrt(safe_sqrt(jnp.mean(y**2)))
    return x, val[0]


class ProteinDatumDecoder(hk.Module):
    """Decode a per-residue ``InternalState`` back into a ``ProteinDatum``.

    This definition shadows the ``ProteinDatumDecoder`` imported from
    ``model.base.protein`` in the same cell.

    For each residue the decoder predicts:
    - 25 logits (23 residue types plus two extra channels) from the invariant
      ``0e`` channels;
    - for each of the 23 residue types, 14 atom-offset vectors (``1e``) and
      one ``2e`` tensor.

    The outputs for the selected residue type are kept. Each flippable atom
    pair is then re-placed symmetrically about its first atom, along a
    direction recovered from the ``2e`` tensor. Finally, everything is
    translated to ``state.coord``.

    Args:
        irreps: irreps of the hidden layers. If they contain no ``2e``, a
            ``{mul}x2e`` block is appended using the first irrep's multiplicity.
        interact: stored but not read by this class.
        depth: number of hidden ``SelfInteraction`` layers before the output layer.
        rescale: width multiplier for the hidden layers (truncated to int).
    """

    def __init__(
        self,
        irreps: e3nn.Irreps,
        interact: bool = True,
        depth: int = 0,
        rescale: float = 2.0,
    ):
        super().__init__()
        self.depth = depth
        if not str(irreps.filter('2e')):
            dim = irreps[0].mul
            irreps = irreps + e3nn.Irreps(f'{dim}x2e')
        self.irreps = irreps
        self.interact = interact
        self.rescale = rescale

    def __call__(
        self, state: InternalState, sequence_token: jnp.ndarray = None
    ) -> Tuple[ProteinDatum, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]:
        """Decode ``state`` into a datum.

        Args:
            state: hidden state. ``irreps_array`` has shape (seq_len, dim), and
                ``coord`` (seq_len, 3) is used as each residue's CA position.
            sequence_token: residue tokens to decode. When None, they are the
                argmax of the predicted logits, with positions after the argmax
                of the end-of-sequence logit set to token 0.

        Returns:
            ``(datum, (res_logits, sos_logits, eos_logits), atom_perm_loss)``,
            where ``atom_perm_loss`` is the masked mean eigen-residual from
            ``sqrt_2tensor``.
        """
        seq_len = state.irreps_array.shape[0]
        assert state.irreps_array.shape == (seq_len, state.irreps_array.irreps.dim)
        assert state.mask.shape == (seq_len,)
        assert state.coord.shape == (seq_len, 3)

        # Output layer: 25 scalars, 14 vectors per residue type and one 2e tensor
        # per residue type (23 types).
        all_vecs_decoded = SelfInteraction(
            [self.irreps * int(self.rescale)] * self.depth
            + [e3nn.Irreps(f"{25}x0e + {23 * 14}x1e + {23}x2e").regroup()],
            chunk_factor=2,
            norm_last=False,
        )(state)

        # Residue-type logits come from a separate MLP on the invariant channels.
        invariants = state.irreps_array.filter(keep="0e")
        logits = e3nn.haiku.MultiLayerPerceptron(
            [invariants.irreps.dim, invariants.irreps.num_irreps, 25], act=jax.nn.silu,
            output_activation=False
        )(invariants).array

        res_logits, sos_logits, eos_logits = (
            logits[..., :23],
            logits[..., -1],
            logits[..., -2],
        )

        if sequence_token is None:
            sequence_token = jnp.argmax(res_logits, axis=-1)
            # Pad (token 0) every position after the most likely end of sequence.
            sequence_token = jnp.where(
                jnp.arange(len(sequence_token)) > eos_logits.argmax(-1),
                0, sequence_token
            )

        seq_len = sequence_token.shape[0]
        assert sequence_token.shape == (seq_len,)
        state = jax.tree_util.tree_map(lambda x: x[:seq_len], state)

        # Split the 23 per-type blocks onto their own axis and keep each
        # residue's own type.
        all_vecs_decoded = all_vecs_decoded.irreps_array.filter("1e + 2e").mul_to_axis(
            23
        )
        vecs_decoded = jax.vmap(lambda arr, idx: arr[idx])(
            all_vecs_decoded, sequence_token
        )
        vecs3 = rearrange(
            vecs_decoded.filter("1e").array, "... (a e) -> ... a e", a=14, e=3
        )
        # One 2e tensor per residue, shared by both flippable pairs.
        vecs5 = vecs_decoded.filter("2e").array
        vecs5 = repeat(vecs5, "... e -> ... a e", a=2)

        # One-hot (seq_len, 23) selector of each residue's type.
        logit_extract = repeat(sequence_token, "r -> r l", l=23) == repeat(
            jnp.arange(0, 23), "l -> () l"
        )

        atom_token = (logit_extract[..., None] * all_residues_atom_tokens[None]).sum(-2)
        atom_mask = (logit_extract[..., None] * all_residues_atom_mask[None]).sum(-2)
        assert atom_token.shape == (seq_len, 14), atom_token.shape
        assert atom_mask.shape == (seq_len, 14), atom_mask.shape

        # Atom-slot indices of each residue's (up to two) flippable pairs, plus validity.
        flips = (logit_extract[..., None, None] * flippable_arr[None]).sum(-3)
        flips_mask = (
            (logit_extract[..., None, None] * flippable_mask[None]).sum(-3).squeeze(-1)
        )

        flips = jnp.where(flips_mask[..., None], flips, 0)
        assert flips.shape == (seq_len, 2, 2), flips.shape

        vecs3 = vecs3 * atom_mask[..., None]
        vecs5 = vecs5 * flips_mask[..., None]
        assert vecs3.shape == (seq_len, 14, 3), vecs3.shape
        assert vecs5.shape == (seq_len, 2, 5), vecs5.shape

        # flips_extract[r, p, k, a] is nonzero iff atom slot a is member k of
        # valid flippable pair p of residue r.
        flips_extract = (
            repeat(flips, "... -> ... a", a=14)
            == repeat(jnp.arange(0, 14), "a -> () a")
        ) * flips_mask[..., None, None]

        flippable = (vecs3[..., None, None, :, :] * flips_extract[..., None]).sum(-2)
        # The decoded offset of each pair's first atom serves as the pair centre.
        center = flippable[..., 0, :]

        # Separation direction recovered from the 2e tensor. The eigen-residual
        # becomes the atom-permutation loss.
        diff, atom_perm_loss = jax.vmap(jax.vmap(sqrt_2tensor))(vecs5)
        diff = safe_normalize(diff)

        assert diff.shape == (seq_len, 2, 3), diff.shape

        # Half-separation length: a learned non-negative scalar plus a floor of 0.8.
        diff_scale = jnp.square(e3nn.haiku.Linear('1x0e')(state.irreps_array.filter('0e')).array)
        diff_scale = jnp.concatenate([diff_scale, diff_scale], axis=-1)
        diff_scale = diff_scale + 0.8

        diff = diff * diff_scale[..., None]

        sym_vecs3 = jnp.stack([center + diff, center - diff], axis=-2)
        assert sym_vecs3.shape == (seq_len, 2, 2, 3), sym_vecs3.shape

        atom_perm_loss = (atom_perm_loss * flips_mask).sum() / (flips_mask.sum() + 1e-6)

        # Scatter the symmetric positions back into the 14 atom slots.
        sym_vecs3_aggregate = (sym_vecs3[..., None, :] * flips_extract[..., None]).sum(
            (-3, -4)
        )
        # Substitute exactly the atoms that belong to a valid flippable pair.
        # Testing the coordinate sum for > 0 would miss atoms whose offset
        # components happen to sum to <= 0.
        subs_mask = flips_extract.sum((-3, -2)) > 0

        vecs3 = jnp.where(subs_mask[..., None], sym_vecs3_aggregate, vecs3)
        # CA (slot 1) sits exactly at the residue coordinate.
        vecs3 = vecs3.at[..., 1, :].set(0.0)

        ca_coord = state.coord

        atom_coord = ca_coord[..., None, :] + vecs3 * atom_mask[..., None]
        assert atom_coord.shape == (seq_len, 14, 3), atom_coord.shape

        datum = ProteinDatum(
            idcode=None,
            resolution=None,  # ophiuchus doesn't label resolution
            sequence=None,  # str and jax dont like each other
            residue_token=sequence_token,
            residue_index=jnp.arange(res_logits.shape[0]),
            residue_mask=sequence_token != 0,
            chain_token=None,  # TODO(Allan)
            atom_token=atom_token,  # TODO(Allan)
            atom_coord=atom_coord,
            atom_mask=atom_mask,
        )

        return (
            datum,
            (res_logits, sos_logits, eos_logits),
            atom_perm_loss,
        )



class OphiuchusAutoencoder(hk.Module):
    """Residue-level autoencoder: encoder -> two SelfInteraction stacks -> decoder.

    ``hidden`` in the returned ModelOutput is the state after the first
    (encoder-half) stack. ``atom_perm_loss`` is the decoder's auxiliary loss.
    """

    def __init__(self, layers, basis=e3nn.Irreps('0e + 1e + 2e')):
        super().__init__()
        self.basis = basis
        self.layers = [int(width) * self.basis for width in layers]
        # Echo each layer's irreps; this runs whenever the module is constructed.
        for irreps in self.layers:
            print(irreps)

    def __call__(self, x: ProteinDatum):
        split = len(self.layers) // 2

        encoded = ProteinDatumEncoder(irreps=self.layers[0], interact=False)(x)
        hidden = SelfInteraction(
            self.layers[:split], full_square=False, chunk_factor=0, norm_last=False,
        )(encoded)
        decoded_state = SelfInteraction(
            self.layers[split:], full_square=False, chunk_factor=0, norm_last=True,
        )(hidden)

        datum, _, perm_loss = ProteinDatumDecoder(depth=2, irreps=self.layers[-1])(
            decoded_state, x.residue_token
        )
        return ModelOutput(datum=datum, hidden=hidden, atom_perm_loss=perm_loss)

# Layer multiplicities: first half encodes, second half decodes.
layers = [8, 8, 8]


def _ophiuchus_forward(*args):
    return OphiuchusAutoencoder(layers)(*args)


model = hk.transform(_ophiuchus_forward)

metrics = defaultdict(list)

# Seeded key stream so parameter initialisation is reproducible.
rng_seq = hk.PRNGSequence(42)
params = model.init(next(rng_seq), residue_ds[0])

optimizer = optax.adam(learning_rate=1e-2, b1=0.9, b2=0.999)
opt_state = optimizer.init(params)


In [None]:
from typing import Dict 
from model.losses import LossFunction
from model.base.utils import safe_norm

from model.losses import LossPipe, AtomPermLoss, BondLoss, AngleLoss, ClashLoss, InternalVectorLoss

# Only the internal-vector loss (weight 100) is active; AtomPermLoss(1e-4),
# BondLoss(2.0, 0) and AngleLoss(2.0, 0) are currently disabled.
active_losses = [InternalVectorLoss(100, norm_only=False)]
loss_pipe = LossPipe(active_losses)

def loss(params, rng, data, step):
    """Mean loss over a batch of data.

    Each datum gets its own RNG key. The model and ``loss_pipe`` are vmapped
    over the stacked batch. ``step`` is accepted but unused; the loss pipe is
    always called with step 0.

    Returns:
        ``(mean_loss, [predictions, per-metric batch means])``.
    """
    per_example_keys = jax.random.split(rng, len(data))
    stacked = inner_stack(data)

    def forward(key, datum):
        return model.apply(params, key, datum)

    def score(prediction, target):
        return loss_pipe(None, prediction, target, 0)

    outputs = jax.vmap(forward)(per_example_keys, stacked)
    outputs, per_example_loss, per_example_metrics = jax.vmap(score)(outputs, stacked)
    batch_metrics = {name: value.mean() for name, value in per_example_metrics.items()}
    return per_example_loss.mean(), [outputs, batch_metrics]

@jax.jit
def update(rng, params, opt_state, data, step) -> List:
    """One jitted optimisation step.

    Returns:
        ``(predictions, new_params, new_opt_state, metrics)``.
    """
    loss_key, _ = jax.random.split(rng, 2)
    grad_fn = jax.grad(loss, argnums=0, has_aux=True)
    grads, aux = grad_fn(params, loss_key, data, step)
    predictions, metrics = aux
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return predictions, new_params, new_opt_state, metrics

metrics = defaultdict(list)
num_epochs = 500
steps_per_epoch = len(dataloader)
for epoch in range(num_epochs):
    progress = tqdm(dataloader)
    for step, data in enumerate(progress):
        # `update` is jitted and specialises on batch shape, so the ragged last
        # batch is skipped.
        if len(data) == batch_size:
            predictions, params, opt_state, step_metrics = update(
                next(rng_seq), params, opt_state, data, epoch * steps_per_epoch + step
            )
            for name, value in step_metrics.items():
                metrics[name].append(value)
            progress.set_postfix({name: f'{float(value):.2e}' for name, value in step_metrics.items()})

In [None]:
downsample = 1
transparent = 'rgba(0,0,0,0)'
for metric, array in metrics.items():
    array = array[::downsample]
    trace = go.Scatter(
        x=jnp.arange(len(array)),
        y=array,
        line=dict(color="rgb(100, 100, 100)"),
    )
    fig = go.Figure(data=[trace])
    fig.update_layout(title=metric, paper_bgcolor=transparent, plot_bgcolor=transparent)
    fig.show()

In [None]:
@jax.jit
def autoencode(params, key, x):
    """Jitted single-datum forward pass of the current ``model`` (returns its ModelOutput)."""
    reconstruction = model.apply(params, key, x)
    return reconstruction

In [None]:
# Side-by-side view of the first five inputs and their reconstructed coordinates.
for original in residue_ds[:5]:
    reconstructed_coord = autoencode(params, next(rng_seq), original).datum.atom_coord
    reconstructed = deepcopy(original)
    reconstructed.atom_coord = reconstructed_coord
    plot_py3dmol_grid([[original, reconstructed]], sphere=True).show()

In [19]:
# An arbitrary residue used for the side-chain rotation scan.
sample_residue = residue_ds[3]

In [20]:
import numpy as np

def rotation_matrix(axis, theta):
    """Return the 3x3 matrix for a rotation of ``theta`` radians about ``axis``.

    Built from Euler–Rodrigues parameters. ``axis`` need not be unit length; it
    is normalised here. Applied to column vectors (``R @ v``), the rotation is
    counter-clockwise when looking from the tip of ``axis`` toward the origin.
    Right-multiplying row vectors (``v @ R``) rotates by ``-theta`` instead.

    Raises:
        ValueError: if ``axis`` has zero length (no rotation axis is defined).
    """
    axis = np.asarray(axis)
    norm = math.sqrt(np.dot(axis, axis))
    if norm == 0:
        raise ValueError("rotation axis must be non-zero")
    axis = axis / norm
    a = math.cos(theta / 2.0)
    b, c, d = -axis * math.sin(theta / 2.0)
    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
    return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
                     [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
                     [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])


In [21]:
coord = sample_residue.atom_coord
# Rotate atom slots 6+ about the axis running from slot 4 to slot 5
# (presumably CB->CG for tyrosine, i.e. a chi2 scan — TODO confirm atom order).
rotation_axis = coord[:, 5] - coord[:, 4]
rotation_axis = rotation_axis / jnp.linalg.norm(rotation_axis, axis=-1, keepdims=True)
rotating_vectors = coord[:, 6:]

# The pivot and pivot-relative offsets are the same for every angle.
pivot = coord[0, 4:5]
local_vectors = rotating_vectors[0] - pivot

rotated = []
for theta in np.linspace(0, 2 * np.pi, 12):
    rotation_matrix_ = rotation_matrix(rotation_axis[0], theta)
    turned = np.matmul(local_vectors, rotation_matrix_) + pivot
    after_datum = deepcopy(sample_residue)
    after_datum.atom_coord = jnp.concatenate(
        [coord[:, :6], np.expand_dims(turned, axis=0)], axis=1
    )
    rotated.append(after_datum)


In [25]:
# plot_py3dmol_grid is imported in an earlier cell.
# Show the rotated conformers split into two rows.
plot_py3dmol_grid([rotated[:len(rotated)//2], rotated[len(rotated)//2:]], sphere=True, window_size=(200, 200)).show()


In [20]:
# Encoder-side hidden irreps for each rotated conformer (one RNG draw each, in order).
hiddens = [
    autoencode(params, next(rng_seq), datum).hidden.irreps_array
    for datum in rotated
]


In [21]:
# First and second components of residue 0's 1e channels, one point per conformer.
first_vectors = [hidden.filter('1e').array[0] for hidden in hiddens]
x = [vector[0] for vector in first_vectors]
y = [vector[1] for vector in first_vectors]

In [27]:
import matplotlib.pyplot as plt

# One colour per rotation frame; the hue wraps with the rotation angle. Sizing
# `c` from the data keeps it the same length as x/y, whereas a fixed 100-point
# linspace made scatter reject the input.
theta = np.linspace(0, 2 * np.pi, len(x))
fig, ax = plt.subplots()
# `s` is scatter's marker-size argument (area in points**2); `markersize` is not accepted.
points = ax.scatter(x, y, c=theta, cmap='hsv', s=100)
ax.set_xlabel('hidden 1e component 0', fontsize=16)
ax.set_ylabel('hidden 1e component 1', fontsize=16)
ax.tick_params(axis='both', labelsize=16)
cbar = fig.colorbar(points, ax=ax)
cbar.ax.tick_params(labelsize=16)
plt.show()

NameError: name 'x' is not defined

In [None]:
# NOTE(review): stale cell — `premodel` is not defined in the visible notebook, so
# this raises NameError on a fresh kernel. Callers below unpack the result as a
# pair whose elements expose `.irreps_array`. The `params` trained above belong to
# OphiuchusAutoencoder and presumably do not fit this model; confirm or delete.
_sample = hk.transform(lambda *args: premodel().sample())
@jax.jit
def sample(params, key):
    """Jitted call of the transformed sampler with the given params and RNG key."""
    return _sample.apply(params, key)

_ = sample(params, next(rng_seq))  # result discarded; presumably a jit warm-up

In [None]:
# NOTE(review): `plot_3d` and `vec` are not defined in the visible notebook, and
# `sample` depends on the undefined `premodel`; this cell fails on a fresh kernel.
for _ in range(3):
    # Spin-animate the first entry of the first returned element's irreps array.
    view = plot_3d([vec(sample(params, next(rng_seq))[0].irreps_array[0])])
    view.spin()
    view.show()

In [None]:

def plot_traj(traj, window_size=(300, 300), duration=10000, color='gray'):
    """Animate a trajectory of vector sets as arrows from the origin with py3Dmol.

    Args:
        traj: sequence of frames; each frame is an iterable of (x, y, z) vectors.
        window_size: viewer (width, height).
        duration: total animation length, split evenly across frames
            (presumably milliseconds, as py3Dmol's animate interval).
        color: currently unused. Frames are always coloured along a red-to-green
            gradient; the parameter is kept so existing calls keep working.

    Returns:
        The py3Dmol view; call ``.show()`` to render it.

    NOTE(review): `py3Dmol` is not imported in the visible notebook; add
    `import py3Dmol` before calling this.
    """
    interval = duration / len(traj)
    v = py3Dmol.view(width=window_size[0], height=window_size[1])
    # A distinct loop name so the gradient no longer silently shadows `color`.
    gradient = Color('red').range_to(Color('green'), len(traj))
    for frame_idx, (frame, frame_color) in enumerate(zip(traj, gradient)):
        for (x, y, z) in frame:
            v.addArrow({"start": {"x": 0.0, "y":0.0, "z":0.0}, "end": {"x": float(x), "y": float(y), "z": float(z)}, 'radius': 0.05, 'mid':0.8, 'color': str(frame_color), 'frame': frame_idx + 1})
    v.setBackgroundColor('rgb(0,0,0)', 0)
    v.animate({'loop': 'forward', 'interval': interval})
    return v


In [None]:
# NOTE(review): depends on `sample` (built on the undefined `premodel`) and on `vec`,
# which is not defined in the visible notebook; this fails on a fresh kernel.
_, traj = sample(params, next(rng_seq))
# Follow the first item of each trajectory step.
vec_traj = vec(traj.irreps_array[:, 0, :])
v = plot_traj(vec_traj[::10])  # every 10th frame
v.show()