In [None]:
import mdtraj as md
import numpy as np
import datetime
import tensorflow as tf
import nglview as nv
import os, sys
import matplotlib.pyplot as plt

from tensorflow.keras.models import load_model
%cd /home/jovyan/ASMSA/mydev


# The working directory selected by the %cd above is treated as the repository
# root; register it once at the front of sys.path so the local `src` package
# imported below can be resolved.
repo_dir = os.getcwd()
if sys.path.count(repo_dir) == 0:
    sys.path.insert(0, repo_dir)

from src.utils import plot_latent_space
from src.asmsa_callbacks import callbacks
from src.ae import asmsa_ae
from src.asmsa_features import process_trajectory
from src.asmsa_split import asmsa_datasets
from src.asmsa_analysis import analyze_reconstruction_with_sincos_blocks, plot_section_errors
from src.asmsa_loss import asmsa_ae_loss

# Run-wide settings shared by the cells below.
latent_dim = 2    # latent size handed to asmsa_ae, callbacks and plot_latent_space
nn_model = "ae"   # model tag forwarded as `model=` to callbacks and plot_latent_space

In [None]:
# Input files: the trajectory (.xtc) and the structure file used as its topology.
tr = "trpcage_ds_nH.xtc"
conf = "trpcage_npt400_nH.pdb"

# Load the trajectory, then superpose it onto its own frame 0 using only the
# backbone atoms for the fit.
traj = md.load_xtc(tr, top=conf)
backbone_atoms = traj.topology.select('backbone')
traj.superpose(traj, frame=0, atom_indices=backbone_atoms)


In [None]:
# Create an nglview widget from the loaded trajectory.
view = nv.show_mdtraj(traj)

# Add a 'line' representation restricted to the 'protein' selection.
view.add_representation('line', selection='protein')
# Leaving `view` as the last expression makes the notebook display the widget.
view

In [None]:
pdb_path = conf


def _atom_indices_by_name(path, name_groups):
    """Collect 0-based atom indices from a PDB file, grouped by atom name.

    Only lines starting with "ATOM" are counted; an atom's index is its
    position among those lines, in file order. The atom name is taken from
    ``line[12:16]`` with surrounding blanks stripped. The file is read once,
    whatever the number of groups.

    Parameters
    ----------
    path : str
        Path of the PDB file to scan.
    name_groups : dict[str, set[str]]
        Maps a group label to the set of atom names that belong to it.

    Returns
    -------
    dict[str, list[int]]
        For every label, the ascending list of indices whose atom name is in
        the corresponding set (an empty list when nothing matches).
    """
    selected = {label: [] for label in name_groups}
    with open(path) as handle:
        atom_counter = 0
        for line in handle:
            if not line.startswith("ATOM"):
                continue
            name = line[12:16].strip()
            for label, names in name_groups.items():
                if name in names:
                    selected[label].append(atom_counter)
            atom_counter += 1
    return selected


def _sliding_windows(indices, width):
    """Return every run of `width` consecutive entries of `indices` as an array."""
    return np.array([indices[i:i + width] for i in range(len(indices) - width + 1)])


# Atom names treated as polar (names shared by several residues are listed once).
pol = {
    "N",    # backbone amide nitrogen
    "O",    # backbone carbonyl oxygen
    "OG",   # serine
    "OG1",  # threonine
    "OH",   # tyrosine hydroxyl oxygen (the former "OH2" entry never matched a tyrosine atom)
    "SG",   # cysteine
    "OD1",  # aspartate / asparagine
    "OD2",  # aspartate
    "OE1",  # glutamate / glutamine
    "OE2",  # glutamate
    "ND2",  # asparagine
    "NE2",  # glutamine / histidine
    "ND1",  # histidine
    "NZ",   # lysine
    "NE",   # arginine
    "NH1",  # arginine
    "NH2",  # arginine
}
# NOTE(review): tryptophan NE1 and methionine SD are not in this set — confirm whether they should be.

# NOTE(review): the indices count ATOM records only; they line up with the
# trajectory's atom indices as long as the file holds a single model and no
# HETATM records precede protein atoms — confirm for other input files.
_selections = _atom_indices_by_name(
    pdb_path,
    {
        "backbone": {"N", "CA", "C"},
        "ON": {"N", "O"},
        "polar": pol,
        "alpha": {"CA"},
        "alphabeta": {"CA", "CB"},
    },
)
backbone = _selections["backbone"]
ON = _selections["ON"]
polar = _selections["polar"]
alpha = _selections["alpha"]
alphabeta = _selections["alphabeta"]

print(f'Backbone({len(backbone)}): {backbone}')
print(f'ON({len(ON)}): {ON}')
print(f'Polar Atoms ({len(polar)}): {polar}')
print(f'Alpha C ({len(alpha)}): {alpha}')
print(f'Alpha and Beta ({len(alphabeta)}): {alphabeta}')

# Consecutive backbone atoms taken 2, 3 and 4 at a time give the index tuples
# for bonds, angles and dihedrals along the chain.
bonds = _sliding_windows(backbone, 2)
angles = _sliding_windows(backbone, 3)
dih = _sliding_windows(backbone, 4)

In [None]:
# Build the feature set from the trajectory file `tr` (the path, not the loaded
# `traj` object) and the structure file `conf`, restricted to the C-alpha
# indices in `alpha`.
feat = process_trajectory(
    traj=tr,
    conf=conf,
    atom_selection=alpha,      # 'protein' | 'backbone' | 'CA' | list of indices
    distance_mode="dense",   # 'sparse' or 'dense'
    density=2,                # only if sparse; 1..N
    include_angles=True
)

In [None]:
# Sanity check: show which entries `feat` provides and the shape of its 'dists' array.
print(f"feature_keys: {feat.keys()}, feat_distances_shape: {(feat['dists'].shape)}")

## NN preprocessing

In [None]:
# Usage: split the normalized features into train / validation / test datasets
# (plus `ds_all` and an `info` object), batched by 64 with a fixed seed.
# NOTE(review): train_size=70 / val_size=15 look like percentages, leaving the rest for test — confirm in asmsa_datasets.
ds_train, ds_val, ds_test, ds_all, info = asmsa_datasets(feat['features_normalized'], train_size=70, val_size=15, batch_size=64, seed=42)

# AE

In [None]:
# Build the autoencoder together with its encoder and decoder, sized to the
# number of columns of the normalized feature matrix, and print its summary.
autoencoder, encoder, decoder = asmsa_ae(
    n_features=feat['features_normalized'].shape[1],
    latent_dim=latent_dim)

autoencoder.summary()


In [None]:
# Feature counts reported by process_trajectory. The slices below assume the
# distance features come first (nD columns), followed by nA angle columns.
nD, nA = feat['n_distance_features'], feat['n_angle_features']

dist_slice = slice(0, nD)
ang_slice = slice(dist_slice.stop, dist_slice.stop + nA)


## Train

In [None]:
# Every run logs into its own timestamped directory under logs/autoencoder/.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = f"logs/autoencoder/{run_stamp}"
cb = callbacks(log_dir, latent_dim, monitor="val_loss", model=nn_model)

# AdamW optimizer with an explicit learning rate and weight decay.
learning_rate = 1e-4
optimizer = tf.keras.optimizers.AdamW(
    beta_1=0.9,
    beta_2=0.999,
    weight_decay=1e-5,
    learning_rate=learning_rate,
)

# Build a new autoencoder for training; this rebinds `encoder` and `decoder`.
ae, encoder, decoder = asmsa_ae(
    latent_dim=latent_dim,
    n_features=feat['features_normalized'].shape[1],
)

# Compile the autoencoder with the custom loss built from the feature counts.
ae.compile(loss=asmsa_ae_loss(nD, nA), optimizer=optimizer)


To monitor training, run TensorBoard from a terminal: `tensorboard --logdir logs/autoencoder --host localhost --port 6006`

In [None]:
# Train the compiled autoencoder on ds_train (epochs=500), validating on ds_val
# and using the callbacks in `cb`.
ae.fit(ds_train,epochs=500,validation_data=ds_val,callbacks=cb)

In [None]:
# Move the saved model file from the working directory into the models folder.
# NOTE(review): presumably ae_{latent_dim}d.keras is written during training by one of the callbacks — confirm.
# NOTE(review): `dest` is an absolute path under /home/tedeschg, while the first
# cell changes directory to /home/jovyan/ASMSA/mydev — confirm this path exists here.
src = f"ae_{latent_dim}d.keras"
dest = "/home/tedeschg/prj/ASMSA/mydev/models/"

!mv {src} {dest}

# Decode and visualize

In [None]:
# Reload the trained autoencoder from the models folder.
# NOTE(review): absolute path under /home/tedeschg — confirm it matches the machine the notebook runs on.
path = f"/home/tedeschg/prj/ASMSA/mydev/models/ae_{latent_dim}d.keras"

# The custom loss has to be supplied so the saved model can be deserialized.
# NOTE(review): presumably "recon_loss" is the name of the function returned by asmsa_ae_loss — verify.
autoencoder = load_model(
    path,
    custom_objects={"recon_loss": asmsa_ae_loss(nD, nA)}
)

# Retrieve the two halves of the model by layer name; this rebinds `encoder` and `decoder`.
encoder = autoencoder.get_layer("encoder")
decoder = autoencoder.get_layer("decoder")


In [None]:
# Pull one batch from the test set and keep its element at index 30 as the probe sample.
for batch_x, _ in ds_test.take(1):  # take(1) yields at most one batch
    sample_x = batch_x[30]

# Add a leading axis of size 1 so the single sample forms a batch, then encode it.
sample_x_batch = sample_x[tf.newaxis, ...]
test = encoder.predict(sample_x_batch)

test

In [None]:
# Latent-space origin as a single-row integer array. Building it from
# latent_dim keeps it valid for any latent size; a hard-coded [0, 0] can only
# be reshaped to (1, latent_dim) when latent_dim == 2.
# NOTE(review): `target` is not used in this cell; `test` is what gets passed to plot_latent_space.
target = np.zeros((1, latent_dim), dtype=int)

# Atom index selections from the loaded topology: backbone atoms and C-alpha atoms.
bb_indices = traj.topology.select('backbone')
ca_indices = traj.topology.select('name CA')

# Plot the latent space for the whole dataset, passing the encoded probe point `test`.
# NOTE(review): the meaning of the returned `emb` and `sample` is defined by plot_latent_space — confirm there.
emb, sample = plot_latent_space(latent_dim, encoder, ds_all, conf, tr, test, bb_indices, model=nn_model, exact=True)

In [None]:
# Reference structure loaded from the same PDB file used as topology.
rms_ref = md.load_pdb(conf)
# Reference reduced to the backbone atoms and to the C-alpha atoms.
rms_ref_bb  = rms_ref.atom_slice(bb_indices)
rms_ref_ca  = rms_ref.atom_slice(ca_indices)
# Reload the trajectory with the reference as its topology and compute the
# RMSD of its frames against the reference structure.
rms_tr = md.load_xtc(tr, top=rms_ref)
rmsd = md.rmsd(rms_tr, rms_ref)

In [None]:
# Disabled cell: the Cartesian-coordinate reconstruction below is wrapped in a
# string literal, so none of it is executed (the cell only evaluates to the string).
# NOTE(review): it references n_p and p_indices, which are not defined in the visible notebook.
'''
s = decoder.predict(sample)

coords_size = feat['coords'].shape[1]  # dimensione delle coordinate
angles_size = s.shape[1] - coords_size  # dimensione degli angoli

# Separa coordinate e angoli
s_coords = s[:, :coords_size]
s_angles = s[:, coords_size:]

# Inverti le trasformazioni separatamente
coords_orig = feat['scaler_coords'].inverse_transform(s_coords)
angles_orig = feat['scaler_angles'].inverse_transform(s_angles)

# Prendi solo le coordinate per la ricostruzione
coords_flat = coords_orig[0, :]                    
coords_p = coords_flat.reshape((n_p, 3))
mask_bb = np.isin(p_indices, bb_indices)
coords_bb = coords_p[mask_bb] 
coords_ca = coords_bb[1::4] 

new_traj = md.Trajectory(
    xyz=np.array([coords_bb]),     
    topology=rms_ref_bb.topology     
)

new_traj.save_pdb("./models/ae_reconstructed.pdb")
'''

In [None]:
# Disabled cell: viewer for the reconstructed PDB file, kept inside a string
# literal so it is not executed.
'''
view = nv.show_file('./models/ae_reconstructed.pdb')
view.clear_representations()
view.add_line() 
#view.add_cartoon()
view.center()
view
'''

In [None]:
# Report how many distance and angle features the model works with.
print(f'distances feat: {nD}, angles feat: {nA}')

In [None]:
# Decode `sample` (returned by plot_latent_space) back into feature space.
s = decoder.predict(sample)
# Original feature vector of the probe sample taken from the test batch.
# NOTE(review): comparing `orig` with `recon` assumes `sample` is the latent point of `sample_x` — confirm against plot_latent_space(exact=True).
orig = sample_x.numpy() 
# First (only expected) row of the decoder output.
recon = s[0]

In [None]:
# Column counts of the phi / psi angle arrays; an entry that is None counts as 0.
raw_angles = feat['raw_angles']
n_phi = raw_angles['phi'].shape[1] if raw_angles['phi'] is not None else 0
n_psi = raw_angles['psi'].shape[1] if raw_angles['psi'] is not None else 0

# Compare the original and reconstructed feature vectors; returns a figure and a metrics object.
fig, metrics = analyze_reconstruction_with_sincos_blocks(
    orig,
    recon,
    nD=nD,
    n_phi=n_phi,
    n_psi=n_psi,
    deltaD=0.1,
    deltaA=0.5,
    title_prefix="Test ",
)

plt.show()
print(metrics)


In [None]:
# Per-section reconstruction errors for the distance features.
fig, sec_stats = plot_section_errors(
    orig, recon,
    nD=nD, n_phi=n_phi, n_psi=n_psi,
    n_sections=20,     # number of sections the features are split into
    kind="distance",  # or "angle"
    title="Errors x section"
)

plt.show()
print(sec_stats)

In [None]:
# Per-section reconstruction errors for the angle features; this overwrites
# `fig` and `sec_stats` from the distance plot.
fig, sec_stats = plot_section_errors(
    orig, recon,
    nD=nD, n_phi=n_phi, n_psi=n_psi,
    n_sections=19,     # number of sections the features are split into
    kind="angle",  
    title="Errors x section"
)

plt.show()
print(sec_stats)