# Preprocessing

## Featurizing

In [None]:
import mdtraj as md
import numpy as np
import datetime
import tensorflow as tf
import nglview as nv
import os, sys


from sklearn.preprocessing import StandardScaler, MinMaxScaler
from tensorflow.keras.models import load_model

%cd /home/jovyan/ASMSA/mydev


repo_dir = os.getcwd()   
if repo_dir not in sys.path:
    sys.path.insert(0, repo_dir)

from src.utils import split_dataset, plot_latent_space, process_trajectory
from src.asmsa_callbacks import callbacks
from src.ae import asmsa_ae

nn_model = 'ae'
latent_dim = 2

In [None]:
# Input trajectory (XTC) and the topology/reference structure it refers to (PDB).
tr = "trpcage_ds_nH.xtc"
conf = "trpcage_npt400_nH.pdb"

traj = md.load_xtc(tr, top=conf)

# Align every frame onto the first one using backbone atoms only, which removes
# global rotation and translation from the trajectory (in place).
backbone_atoms = traj.topology.select('backbone')
traj.superpose(traj, frame=0, atom_indices=backbone_atoms)


In [None]:
# DISABLED: the code below is kept inside an inert string literal and never runs.
# If re-enabled, it would balance the trajectory into three RMSD classes:
#   - RMSD (nm) of every frame against the PDB structure over protein atoms;
#   - labels: 0 = folded (<= 0.5 nm), 1 = semi-unfolded (0.5, 1.0], 2 = unfolded (> 1.0);
#   - undersample each class to the size of the smallest one (random_state=42);
#   - save the sorted, balanced frames to a new XTC, then reload and superpose it,
#     rebinding `traj`, `tr`, `conf` and `backbone_atoms` for the rest of the notebook.
# NOTE(review): if any class is empty, n_target is 0 and the balanced trajectory
# would be empty; add a guard before re-enabling. The step comments inside the
# string are in Italian (string content is left unchanged here).
'''
import mdtraj as md
import numpy as np
from sklearn.utils import resample

# 1. Carica topologia e traiettoria
traj = md.load_xtc('trpcage_ds_nH.xtc', top='trpcage_npt400_nH.pdb')

# 2. Definisci la struttura di riferimento folded (frame 0)
ref = md.load_pdb('trpcage_npt400_nH.pdb')
atom_sel = traj.topology.select('protein')

# 3. Calcola l’RMSD di ogni frame rispetto al reference (in nm)
rmsd = md.rmsd(traj, ref, frame=0, atom_indices=atom_sel)

# 4. Imposta le soglie (in nm)
t1 = 0.5   # fino a 0.5 nm → folded
t2 = 1.0   # da 0.5 a 1.0 nm → semi-unfolded; >1.0 nm → unfolded

# 5. Crea le etichette (0=folded, 1=semi-unfolded, 2=unfolded)
labels = np.empty_like(rmsd, dtype=int)
labels[rmsd <= t1]                  = 0
labels[(rmsd >  t1) & (rmsd <= t2)] = 1
labels[rmsd >  t2]                  = 2

# 6. Raggruppa indici per ciascuna classe
idx0 = np.where(labels == 0)[0]  # folded
idx1 = np.where(labels == 1)[0]  # semi-unfolded
idx2 = np.where(labels == 2)[0]  # unfolded

print(f"Counts before balancing: folded={len(idx0)}, semi-unfolded={len(idx1)}, unfolded={len(idx2)}")

# 7. Bilancia con undersampling alla classe meno numerosa
n_target = min(len(idx0), len(idx1), len(idx2))
idx0_bal = resample(idx0, replace=False, n_samples=n_target, random_state=42)
idx1_bal = resample(idx1, replace=False, n_samples=n_target, random_state=42)
idx2_bal = resample(idx2, replace=False, n_samples=n_target, random_state=42)

# 8. Combina, ordina e crea la traiettoria bilanciata
idx_balanced = np.sort(np.concatenate([idx0_bal, idx1_bal, idx2_bal]))
balanced = traj.slice(idx_balanced)

# 9. Salva il nuovo XTC bilanciato
balanced.save_xtc('traj_balanced_3class_0.5_1.0nm.xtc')

print(f"Balanced dataset: {n_target} frames per classe, totale {len(idx_balanced)} frames.")

tr = "traj_balanced_3class_0.5_1.0nm.xtc"
conf = "trpcage_npt400_nH.pdb"

traj = md.load_xtc(tr, top=conf)
backbone_atoms = traj.topology.select('backbone')
traj.superpose(traj, 0, atom_indices=backbone_atoms)
'''

In [None]:
# Interactive 3-D viewer of the superposed trajectory, protein drawn as lines.
view = nv.show_mdtraj(traj)
view.add_representation(repr_type='line', selection='protein')
view

In [None]:
# Featurize the trajectory files with the project helper. Besides the normalized
# feature matrix it returns CA/backbone atom selections, the fitted scaler
# (used later for inverse_transform) and the coordinate array.
# NOTE(review): the helper re-reads `tr`/`conf` from disk, so whether the
# features are superposed depends on process_trajectory itself — confirm there.
(
    ca_indices, n_ca,
    bb_indices, n_bb,
    features_normalized, scaler, coords,
) = process_trajectory(tr, conf)
features_normalized.shape

## NN preprocessing

In [None]:
# Split the normalized features into batched train / validation / test sets,
# plus a dataset over all samples. Sizes are presumably percentages
# (70 / 15 / remaining 15) — confirm in src.utils.split_dataset.
ds_train, ds_val, ds_test, ds_all = split_dataset(
    features_normalized,
    train_size=70,
    val_size=15,
    batch_size=64,
    seed=42,
)

# AE

In [None]:
'''
Batch Norm, if used, goes before the layer's activation.
'''

In [None]:
# Build the autoencoder once just to inspect its architecture; the model that
# is trained is created again in the training cell.
autoencoder, encoder, decoder = asmsa_ae(
    n_features=features_normalized.shape[1],
    latent_dim=latent_dim,
)
autoencoder.summary()


In [None]:
# One timestamped log directory per run, under logs/autoencoder (the directory
# TensorBoard is pointed at).
log_dir = f"logs/autoencoder/{datetime.datetime.now():%Y%m%d-%H%M%S}"
# Project callbacks, monitoring the validation loss.
cb = callbacks(log_dir, latent_dim, monitor="val_loss", model=nn_model)

# AdamW = Adam with decoupled weight decay.
learning_rate = 1e-4
optimizer = tf.keras.optimizers.AdamW(
    learning_rate=learning_rate,
    weight_decay=1e-5,
    beta_1=0.9,
    beta_2=0.999,
)

# Fresh, untrained autoencoder for this training run.
ae, encoder, decoder = asmsa_ae(
    n_features=features_normalized.shape[1],
    latent_dim=latent_dim,
)

# Keras loss objects combined into the weighted reconstruction loss below.
mse_fn = tf.keras.losses.MeanSquaredError()
mae_fn = tf.keras.losses.MeanAbsoluteError()


def recon_loss(y_true, y_pred):
    """Return the weighted reconstruction loss 0.8 * MSE + 0.2 * MAE.

    Keep the module-level name: the reload cell passes this function to
    ``load_model`` via ``custom_objects={"recon_loss": ...}``.
    """
    squared_err = mse_fn(y_true, y_pred)
    absolute_err = mae_fn(y_true, y_pred)
    return 0.8 * squared_err + 0.2 * absolute_err


# Compile the autoencoder with the custom reconstruction loss.
ae.compile(optimizer=optimizer, loss=recon_loss)


tensorboard --logdir logs/autoencoder --host localhost --port 6006

In [None]:
# Train for 200 epochs (the project callbacks may stop earlier — see
# src.asmsa_callbacks), validating on ds_val after each epoch. `shuffle` is not
# passed: its default is already True, and Keras ignores it for tf.data inputs.
ae.fit(
    ds_train,
    epochs=200,
    validation_data=ds_val,
    callbacks=cb,
)

In [None]:
src = f"ae_{latent_dim}d.keras"
dest = "/home/tedeschg/prj/ASMSA/mydev/models/"

!mv {src} {dest}

# Decode and visualize

In [None]:
path = f"/home/tedeschg/prj/ASMSA/mydev/models/ae_{latent_dim}d.keras"

autoencoder = load_model(
    path,
    custom_objects={"recon_loss": recon_loss}
)

encoder = autoencoder.get_layer("encoder")
decoder = autoencoder.get_layer("decoder")


In [None]:
# Take one sample from the first test batch and encode it into latent space.
# next(iter(...)) fetches the first batch directly: an empty dataset raises
# StopIteration here instead of a later NameError on `sample_x`.
sample_index = 30
batch_x, _ = next(iter(ds_test.take(1)))
# Clamp the hard-coded index so a first batch smaller than 31 rows does not
# raise IndexError (the last available sample is used instead).
sample_index = min(sample_index, int(batch_x.shape[0]) - 1)
sample_x = batch_x[sample_index]

# The encoder expects a batch dimension: shape (1, n_features).
sample_x_batch = tf.expand_dims(sample_x, axis=0)
test = encoder.predict(sample_x_batch)

test

In [None]:
# Hard-coded 2-D latent point, currently not used by the call below. It is
# written as an explicit (1, 2) array: the previous `.reshape(1, latent_dim)`
# raised ValueError for any latent_dim != 2 and aborted the cell.
target = np.array([[7.5, -10.0]])
# Plot the latent space with the encoded test sample. `sample` is later passed
# to the decoder — presumably a latent point selected by plot_latent_space
# (exact=True); confirm in src.utils.
emb, sample = plot_latent_space(latent_dim, encoder, ds_all, conf, tr, test, bb_indices, model=nn_model, exact=True)

In [None]:
# Reference structure, and the trajectory loaded against its topology.
rms_ref = md.load_pdb(conf)
rms_tr = md.load_xtc(tr, top=rms_ref)

# Backbone-only and CA-only views of the reference; the backbone topology is
# reused for the reconstructed structure.
rms_ref_bb = rms_ref.atom_slice(bb_indices)
rms_ref_ca = rms_ref.atom_slice(ca_indices)

# Per-frame RMSD (nm) of the trajectory against frame 0 of the reference.
rmsd = md.rmsd(rms_tr, rms_ref)

In [None]:
# Indices of all protein atoms and their count, used to reshape decoded
# coordinates into (n_p, 3).
p_indices = traj.topology.select("protein")
n_p = p_indices.shape[0]

In [None]:
# Decode the chosen latent point, then map the decoder output back from
# normalized feature space with the fitted scaler.
s = decoder.predict(sample)
s_orig = scaler.inverse_transform(s)

# The first coords.shape[1] features are treated as the flattened xyz of all
# protein atoms; the reshape raises if the sizes disagree.
# NOTE(review): the feature layout comes from process_trajectory — confirm there.
coords_flat = s_orig[0, :coords.shape[1]]
coords_p = coords_flat.reshape((n_p, 3))
# Keep only backbone atoms, in the order they appear in p_indices.
mask_bb = np.isin(p_indices, bb_indices)
coords_bb = coords_p[mask_bb]
# Every 4th backbone atom starting at index 1: the CA only if the backbone is
# ordered N, CA, C, O per residue (TODO confirm). Currently unused.
coords_ca = coords_bb[1::4]

# Single-frame trajectory with the reference's backbone-only topology.
new_traj = md.Trajectory(
    xyz=coords_bb[np.newaxis, ...],
    topology=rms_ref_bb.topology,
)

# save_pdb does not create missing directories, so make sure ./models exists.
os.makedirs("./models", exist_ok=True)
new_traj.save_pdb("./models/ae_reconstructed.pdb")


In [None]:
# Show the reconstructed backbone as lines, centred in the viewer.
reconstructed_pdb = './models/ae_reconstructed.pdb'
view = nv.show_file(reconstructed_pdb)
view.clear_representations()
view.add_line()  # alternative rendering: view.add_cartoon()
view.center()
view

In [None]:
# Mean squared error, in normalized feature space, between the test sample and
# the decoder output.
# NOTE(review): only meaningful if `sample` is the encoding of `sample_x` —
# confirm what plot_latent_space(..., exact=True) returns.
residual = s[0] - sample_x.numpy()
mse = np.mean(residual ** 2)
print("MSE ricostruzione:", mse)

In [None]:
from src.asmsanalysis import analyze_reconstruction, plot_section_errors
import matplotlib.pyplot as plt

# Compare the normalized test sample with its reconstruction; `sample_x` and
# `s` come from the encode/decode cells above.
orig = sample_x.numpy()
recon = s[0]

# Full analysis: a figure plus a mapping of metric name -> scalar value.
fig, metrics = analyze_reconstruction(orig, recon, title_prefix="Autoencoder ")

# Error breakdown with the feature vector split into 25 sections.
fig_sections, section_stats = plot_section_errors(
    orig, recon, n_sections=25, title="Analisi Errori per Sezione"
)

plt.show()

# Print the metrics. Previously the f-string was built but never printed, so
# this loop produced no output.
print("\n=== METRICHE RICOSTRUZIONE ===")
for metric, value in metrics.items():
    print(f"{metric.upper()}: {value:.6f}")