# Preprocessing

## Featurizing

In [None]:
import mdtraj as md
import numpy as np
import datetime
import tensorflow as tf
import nglview as nv

# IPython magic: change the kernel's working directory to this hard-coded path.
# Everything below (os.getcwd() and the relative data file names) depends on it.
%cd /home/jovyan/ASMSA/mydev

import os, sys

# Put the current working directory at the front of sys.path unless it is
# already listed -- presumably so that the utils / asmsa_callbacks / vae
# imports that follow resolve from this directory (TODO confirm).
repo_dir = os.getcwd()   
if repo_dir not in sys.path:
    sys.path.insert(0, repo_dir)

from utils import split_dataset, plot_latent_space, process_trajectory
from asmsa_callbacks import callbacks
from vae import asmsa_beta_vae

# Size of the latent space; passed to callbacks(), asmsa_beta_vae() and
# plot_latent_space() in later cells.
latent_dim = 2
# Model tag forwarded as the `model=` argument of callbacks() and
# plot_latent_space().
nn_model = 'vae'

In [None]:
# Input files: trajectory (XTC) and the structure file used as its topology.
tr = "trpcage_ds_nH.xtc"
conf = "trpcage_npt400_nH.pdb"

traj = md.load_xtc(tr, top=conf)

# Align every frame onto frame 0 of the same trajectory, fitting only on the
# atoms returned by the 'backbone' selection.
backbone_atoms = traj.top.select('backbone')
traj.superpose(traj, frame=0, atom_indices=backbone_atoms)

In [None]:
# Interactive viewer for the aligned trajectory, with a line representation
# added for the 'protein' selection.
view = nv.show_mdtraj(traj)
view.add_representation(repr_type='line', selection='protein')

# Last expression of the cell: renders the widget.
view

In [None]:
# Featurize the trajectory with the project helper. It returns seven values;
# their exact contents are defined in utils.process_trajectory (not visible
# here) -- the names suggest CA/backbone atom indices and counts, the
# normalized feature matrix, the fitted scaler and the raw coordinates.
(
    ca_indices, n_ca,
    bb_indices, n_bb,
    features_normalized, scaler, coords,
) = process_trajectory(tr, conf)

# Last expression of the cell: show the feature-matrix shape.
features_normalized.shape

## NN preprocessing

In [None]:
# Usage: split the normalized features into train / validation / test
# datasets plus one dataset covering everything.
# NOTE(review): train_size=70 / val_size=15 look like percentages -- confirm
# against utils.split_dataset.
ds_train, ds_val, ds_test, ds_all = split_dataset(
    features_normalized,
    train_size=70,
    val_size=15,
    batch_size=64,
    seed=42,
)

# VAE

In [None]:
# English rendering of the note kept in the string below:
# "Batch Norm, if used, goes before the layer activation."
# The string is a bare expression, so the cell displays it as its output.
'''
Batch Norm, nel caso, va prima della layer activation)
'''

In [None]:
# One log directory per run, named after the current local time.
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "logs/autoencoder/" + timestamp

# Training callbacks from the project helper, monitoring "val_loss".
cb = callbacks(log_dir, latent_dim, model=nn_model, monitor="val_loss")

# Build the beta-VAE; the input width is the number of feature columns.
n_features = features_normalized.shape[1]
beta_vae, encoder, decoder = asmsa_beta_vae(
    n_features=n_features,
    latent_dim=latent_dim,
    beta=0.0001,
)

In [None]:
# Train for up to 500 epochs on the training split, validating on ds_val;
# the callbacks built earlier are attached to the run.
beta_vae.fit(
    ds_train,
    validation_data=ds_val,
    epochs=500,
    callbacks=cb,
)

To monitor training with TensorBoard, run this in a terminal:

`tensorboard --logdir logs/autoencoder --host localhost --port 6006`

# Decode and visualize

In [None]:
# Persist both halves of the trained model, encoder first.
for net, out_path in ((encoder, 'encoder_vae.keras'), (decoder, 'decoder_vae.keras')):
    net.save(out_path)


In [None]:
# Pull a single batch from the test split and keep the element at position 30
# (hard-coded, so the batch must contain more than 30 samples).
# take(1) yields at most one batch, so the loop body runs at most once.
for batch_x, _ in ds_test.take(1):
    sample_x = batch_x[30]

# The encoder is called on a batch, so prepend a batch axis of size 1.
sample_x_batch = tf.expand_dims(sample_x, 0)

# The encoder returns three outputs; the third one is stored as `test` and
# handed to plot_latent_space() in the next cell.
z_mean, z_log_var, test = encoder.predict(sample_x_batch)

# Last expression of the cell: display z_mean.
z_mean

In [None]:
# A (1, latent_dim) array holding [0, -2]; it is not passed to the call below.
target = np.reshape(np.array([0, -2]), (1, latent_dim))

# Project helper: returns `emb` and `sample`; `sample` is what the decoder is
# run on in a later cell.
emb, sample = plot_latent_space(
    latent_dim, encoder, ds_all, conf, tr, test, bb_indices,
    model=nn_model,
    exact=False,
)

In [None]:
# Reference structure, plus two atom subsets of it selected by bb_indices and
# ca_indices respectively.
rms_ref = md.load_pdb(conf)
rms_ref_bb, rms_ref_ca = (rms_ref.atom_slice(idx) for idx in (bb_indices, ca_indices))

# RMSD of every trajectory frame against frame 0 of the reference structure.
rms_tr = md.load_xtc(tr, top=rms_ref)
rmsd = md.rmsd(rms_tr, rms_ref, frame=0)

In [None]:
# Atom indices matched by the "protein" selection, and how many there are.
p_indices = traj.top.select("protein")
n_p = len(p_indices)

In [None]:
# Decode `sample` and map the output back through the scaler.
s = decoder.predict(sample)
s_orig = scaler.inverse_transform(s)

# From the first decoded row keep the leading coords.shape[1] values and view
# them as one (x, y, z) triple per protein atom.
n_coord_values = coords.shape[1]
coords_flat = s_orig[0, :n_coord_values]
coords_p = coords_flat.reshape((n_p, 3))

# Keep only the rows whose atom index also appears in bb_indices.
mask_bb = np.isin(p_indices, bb_indices)
coords_bb = coords_p[mask_bb]
# Every 4th backbone row starting at row 1 -- presumably the CA atoms
# (TODO confirm the backbone atom ordering).
coords_ca = coords_bb[1::4]

# Single-frame trajectory on the backbone-only topology, written out as PDB.
new_traj = md.Trajectory(
    xyz=coords_bb[np.newaxis, ...],
    topology=rms_ref_bb.topology,
)
new_traj.save_pdb("vae_reconstructed.pdb")

In [None]:
# Show the reconstructed structure with a line representation only.
view = nv.show_file('vae_reconstructed.pdb')
view.clear_representations()
view.add_line()  # view.add_cartoon() was left disabled in the original cell
view.center()

# Last expression of the cell: renders the widget.
view

In [None]:
# Mean squared difference between the first decoded row and the test sample
# picked earlier (both taken before the scaler is inverted).
squared_err = np.square(s[0] - sample_x.numpy())
mse = np.mean(squared_err)
print("MSE ricostruzione:", mse)

In [None]:
from asmsanalysis import analyze_reconstruction, plot_section_errors
import matplotlib.pyplot as plt
# Usage with the data from the earlier cells
# (assumes sample_x and s[0] are already defined).
orig = sample_x.numpy()
recon = s[0]

# Full analysis: returns a figure and a mapping of metric name -> value.
fig, metrics = analyze_reconstruction(orig, recon, title_prefix="Beta-VAE ")

# Per-section errors only.
fig_sections, section_stats = plot_section_errors(orig, recon, n_sections=25,
                                                   title="Analisi Errori per Sezione")

plt.show()

# Print the metrics.
print("\n=== METRICHE RICOSTRUZIONE ===")
for metric, value in metrics.items():
    # Fix: the original only built this string and discarded it, so nothing
    # was printed after the header. An expression inside a loop body is never
    # echoed by the notebook; it has to be printed explicitly.
    print(f"{metric.upper()}: {value:.6f}")

In [None]:
import nglview as nv

# Start from an empty widget and load two structures into it for comparison.
view = nv.NGLWidget()

# --- first model: the VAE reconstruction written earlier in this notebook ---
comp1 = view.add_component("vae_reconstructed.pdb")
comp1.clear_representations()  # optional: drop the default representation
comp1.add_representation("line", color="skyblue")

# --- second model: another PDB to overlay (this file is not produced by the
# visible cells, so it must already exist on disk) ---
comp2 = view.add_component("ae_reconstructed.pdb")
comp2.clear_representations()
comp2.add_representation("line", color="orange")  # distinct color per model

# Center the scene; the last expression renders the widget.
view.center()
view
