# Integrated MD Analysis Workflow

This notebook demonstrates a comprehensive analysis pipeline that integrates:
- Structural analysis (RMSD, native contacts, hydrogen bonds)
- Correlation and PCA analysis
- Visualization of results

The workflow shows how to run all of these analyses together in one place, without switching between different tools.

In [None]:
# Import required libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path

# Import MD-Toolkit modules
from mdtoolkit.core import TrajectoryHandler
from mdtoolkit.structure import RMSDAnalysis, ContactAnalysis, NativeContacts, HydrogenBonds
from mdtoolkit.dynamics import CorrelationAnalysis, PCAAnalysis
from mdtoolkit.visualization import (
    plot_rmsd, plot_rmsf, plot_correlation_matrix, 
    plot_pca, plot_free_energy_landscape, plot_contacts
)

# Set up paths (relative to the notebook's working directory)
topology_file = "example_data/protein.pdb"
trajectory_file = "example_data/trajectory.xtc"
output_dir = Path("analysis_results")
# parents=True lets output_dir point at a nested location without a separate
# mkdir; exist_ok=True keeps this cell safe to re-run.
output_dir.mkdir(parents=True, exist_ok=True)

print("MD-Toolkit loaded successfully!")

## 1. Load Trajectory

First, we load the trajectory using the unified `TrajectoryHandler`, then align every frame to the first frame using the Cα atoms.

In [None]:
# Open the topology/trajectory pair through the unified handler.
# Set in_memory=True for smaller trajectories.
traj = TrajectoryHandler(
    topology=topology_file,
    trajectory=trajectory_file,
    in_memory=False,
)

frame_count, atom_count = traj.n_frames, traj.n_atoms
print(f"Loaded trajectory with {frame_count} frames and {atom_count} atoms")
first_time, last_time = traj.time[0], traj.time[-1]
print(f"Simulation time: {first_time:.1f} to {last_time:.1f} ps")

# Align the trajectory to its first frame, fitting on the C-alpha atoms.
ca_selection = "protein and name CA"
traj.align_trajectory(selection=ca_selection)
print("Trajectory aligned")

## 2. Structural Analysis

### 2.1 RMSD Analysis

In [None]:
# Selections whose RMSD is tracked separately. The two "domain" entries
# hard-code a split at residue 100 (1-100 and 101-200) — adjust these ranges
# to match the protein being analysed.
selections = {
    "backbone": "protein and backbone",
    "ca": "protein and name CA",
    "domain1": "protein and resid 1:100 and name CA",
    "domain2": "protein and resid 101:200 and name CA",
}

# align_selection is presumably the atom set used for fitting before each
# RMSD is measured (TODO confirm against mdtoolkit).
rmsd_analysis = RMSDAnalysis(
    trajectory=traj,
    align_selection="protein and name CA",
    analysis_selections=selections,
)
rmsd_results = rmsd_analysis.run()

# One RMSD trace per selection, plotted against the backbone time axis
# converted to ns.
rmsd_data = dict((label, per_selection['rmsd'])
                 for label, per_selection in rmsd_results.items())
time_ns = rmsd_results['backbone']['time'] / 1000
fig, ax = plot_rmsd(
    time=time_ns,
    rmsd=rmsd_data,
    title="RMSD Analysis by Domain",
    save_path=output_dir / "rmsd_analysis.png",
)
plt.show()

# Convergence of the C-alpha RMSD; window_size=100 is passed through to
# mdtoolkit (presumably measured in frames — confirm).
convergence = rmsd_analysis.calculate_convergence(selection_name="ca", window_size=100)
final_convergence = convergence['convergence_index'][-1]
print(f"Final RMSD convergence index: {final_convergence:.3f}")

### 2.2 Contact Analysis

In [None]:
# Fraction of native contacts (Q) over the C-alpha atoms with a contact
# radius of 8.0 (units as defined by mdtoolkit — presumably Å).
native_contacts = NativeContacts(
    trajectory=traj,
    selection="protein and name CA",
    radius=8.0,
)
nc_results = native_contacts.run()

# Hydrogen bonds with a 3.5 distance cutoff and a 150.0 angle cutoff.
hbonds = HydrogenBonds(
    trajectory=traj,
    distance_cutoff=3.5,
    angle_cutoff=150.0,
)
hb_results = hbonds.run()

# Two stacked panels: Q on top, H-bond count below; time converted to ns.
fig, axes = plt.subplots(2, 1, figsize=(12, 8))
ax_q, ax_hb = axes

ax_q.plot(nc_results['time'] / 1000, nc_results['q'], linewidth=2)
ax_q.set_ylabel('Q (Fraction of Native Contacts)')
ax_q.set_title('Native Contact Analysis')
ax_q.grid(True, alpha=0.3)

ax_hb.plot(hb_results['time'] / 1000, hb_results['n_hbonds'], linewidth=2, color='green')
ax_hb.set_xlabel('Time (ns)')
ax_hb.set_ylabel('Number of H-bonds')
ax_hb.set_title('Hydrogen Bond Analysis')
ax_hb.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(output_dir / "contact_analysis.png", dpi=300)
plt.show()

# H-bonds present in more than half of the frames.
persistent_hbonds = hbonds.get_persistent_hbonds(persistence_cutoff=0.5)
print(f"Found {len(persistent_hbonds)} persistent hydrogen bonds (>50% of frames)")

## 3. Correlation and PCA Analysis

### 3.1 Correlation Analysis

In [None]:
# Pearson correlation of C-alpha motions (align=True requests alignment
# inside mdtoolkit before positions are used).
corr_analysis = CorrelationAnalysis(
    trajectory=traj,
    selection="protein and name CA",
    align=True,
)

# extract_positions() runs before the matrix is computed. Its return value is
# not used below; the call may populate internal state (TODO confirm).
positions = corr_analysis.extract_positions()
corr_matrix = corr_analysis.calculate_correlation_matrix(method="pearson")
res_corr = corr_analysis.calculate_residue_correlation()

# Side-by-side heatmaps sharing a diverging colour scale fixed to [-1, 1].
fig, axes = plt.subplots(1, 2, figsize=(16, 7))
heatmap_panels = [
    (corr_matrix, 'Atom-level Correlation Matrix', 'Atom Index'),
    (res_corr, 'Residue-level Correlation Matrix', 'Residue Index'),
]
for panel_ax, (matrix, heading, index_label) in zip(axes, heatmap_panels):
    heatmap = panel_ax.imshow(matrix, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
    panel_ax.set_title(heading)
    panel_ax.set_xlabel(index_label)
    panel_ax.set_ylabel(index_label)
    plt.colorbar(heatmap, ax=panel_ax)

plt.tight_layout()
plt.savefig(output_dir / "correlation_analysis.png", dpi=300)
plt.show()

print(f"Correlation analysis complete: {corr_matrix.shape[0]} atoms analyzed")

### 3.2 PCA Analysis with Validation

In [None]:
# PCA on C-alpha coordinates (align=True), computed with two backends
# (MDAnalysis and scikit-learn) so the results can be cross-checked.
pca_analysis = PCAAnalysis(
    trajectory=traj,
    selection="protein and name CA",
    align=True
)

# Run MDAnalysis PCA
mda_pca = pca_analysis.run_mda_pca()
print(f"MDAnalysis PCA: {mda_pca['n_components']} components")

# Run sklearn PCA for validation
sklearn_pca = pca_analysis.run_sklearn_pca(n_components=10)
print(f"sklearn PCA: {sklearn_pca['n_components']} components")

# Compare the two PCA runs
validation = pca_analysis.validate_pca()
print("\nPCA Validation:")
print(f"  Variance match: {validation['variance_match']}")
print(f"  Max variance difference: {validation['max_variance_diff']:.4f}")
print(f"  Projection correlations (first 3 PCs): {validation['projection_correlations']}")

# Cosine content of a PC projection close to 1 is the signature of random
# diffusion rather than converged sampling (Hess, 2002). 0.7 is used here as
# the warning threshold.
N_COSINE_PCS = 3
COSINE_WARN_THRESHOLD = 0.7
cosine_content = pca_analysis.calculate_cosine_content(n_components=N_COSINE_PCS)
print(f"\nCosine content (first {N_COSINE_PCS} PCs): {cosine_content}")
# Report which PCs (1-based) exceed the threshold, not just that some do.
flagged_pcs = [pc + 1 for pc, cc in enumerate(cosine_content) if cc > COSINE_WARN_THRESHOLD]
if flagged_pcs:
    print(f"Warning: High cosine content detected in PC(s) {flagged_pcs} - possible random diffusion")

### 3.3 PCA Visualization

In [None]:
# Plot variance explained: scree plot (left) and cumulative variance (right)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Scree plot of the leading components. Clamp the count to the number of
# components actually returned: with a fixed 20, a PCA with fewer components
# passes bar() x and height arrays of different lengths and it raises.
pc_variance = np.asarray(mda_pca['variance'])
n_show = min(20, len(pc_variance))
axes[0].bar(np.arange(1, n_show + 1), pc_variance[:n_show], color='steelblue')
axes[0].set_xlabel('Principal Component')
# NOTE(review): MDAnalysis' PCA stores raw eigenvalues in `variance`. If
# mdtoolkit passes those through unchanged, a label such as 'Variance (Å²)'
# would be more accurate than 'Variance Explained' — confirm.
axes[0].set_ylabel('Variance Explained')
axes[0].set_title('Scree Plot')
axes[0].grid(True, alpha=0.3)

# Cumulative variance (the 0.9 reference line assumes fractions in [0, 1])
cumulative_variance = np.asarray(mda_pca['cumulated_variance'])
axes[1].plot(np.arange(1, len(cumulative_variance) + 1),
             cumulative_variance,
             marker='o', linewidth=2)
axes[1].axhline(y=0.9, color='r', linestyle='--', label='90% variance')
axes[1].set_xlabel('Number of Components')
axes[1].set_ylabel('Cumulative Variance Explained')
axes[1].set_title('Cumulative Variance Explained')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(output_dir / "pca_variance.png", dpi=300)
plt.show()

# Plot PCA projections (one row per frame, one column per PC)
projections = mda_pca['transformed']

# 2D projection plot, coloured by frame index to show the time ordering
fig, ax = plot_pca(
    projections=projections,
    pc_x=0,
    pc_y=1,
    color_by=np.arange(len(projections)),
    title="PCA Projection (PC1 vs PC2)",
    save_path=output_dir / "pca_projection.png"
)
plt.show()

# Free energy landscape over PC1/PC2 at 300 K
fig, ax = plot_free_energy_landscape(
    x=projections[:, 0],
    y=projections[:, 1],
    bins=50,
    temperature=300.0,
    title="Free Energy Landscape",
    xlabel="PC1",
    ylabel="PC2",
    save_path=output_dir / "free_energy_landscape.png"
)
plt.show()

## 4. Integrated Analysis Summary

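Combine the per-frame metrics from every analysis (RMSD, native contacts, hydrogen bonds, and the first two PCs) into a single table. Then check how they correlate with one another and plot them side by side to understand the system dynamics.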

In [None]:
# Create summary DataFrame: one row per frame, one column per metric.
# Times are converted from ps to ns, matching the earlier plots.
summary_data = {
    'Time (ns)': rmsd_results['ca']['time'] / 1000,
    'RMSD (Å)': rmsd_results['ca']['rmsd'],
    'Native Contacts (Q)': nc_results['q'],
    'H-bonds': hb_results['n_hbonds'],
    'PC1': projections[:, 0],
    'PC2': projections[:, 1]
}

# The metrics come from independent analyses, so they can only be joined
# row by row if each covers the same number of frames. Report the offending
# lengths explicitly instead of relying on pandas' generic error.
series_lengths = {name: len(values) for name, values in summary_data.items()}
if len(set(series_lengths.values())) != 1:
    raise ValueError(f"Per-frame metrics have different lengths: {series_lengths}")

df_summary = pd.DataFrame(summary_data)

# Calculate correlations between metrics (DataFrame.corr defaults to Pearson)
metric_corr = df_summary.corr()
print("\nCorrelations between analysis metrics:")
print(metric_corr.round(3))

# Plot integrated analysis: scatter plots on the top row, time series below
fig, axes = plt.subplots(3, 2, figsize=(14, 10))

# RMSD vs Native Contacts, points coloured by simulation time
rmsd_q_points = axes[0, 0].scatter(df_summary['RMSD (Å)'], df_summary['Native Contacts (Q)'],
                                   c=df_summary['Time (ns)'], cmap='viridis', alpha=0.6)
# Without a colorbar the time encoding of the point colours cannot be read.
fig.colorbar(rmsd_q_points, ax=axes[0, 0], label='Time (ns)')
axes[0, 0].set_xlabel('RMSD (Å)')
axes[0, 0].set_ylabel('Native Contacts (Q)')
axes[0, 0].set_title(f"Correlation: {metric_corr.loc['RMSD (Å)', 'Native Contacts (Q)']:.3f}")

# RMSD vs PC1, points coloured by simulation time
rmsd_pc1_points = axes[0, 1].scatter(df_summary['RMSD (Å)'], df_summary['PC1'],
                                     c=df_summary['Time (ns)'], cmap='viridis', alpha=0.6)
fig.colorbar(rmsd_pc1_points, ax=axes[0, 1], label='Time (ns)')
axes[0, 1].set_xlabel('RMSD (Å)')
axes[0, 1].set_ylabel('PC1')
axes[0, 1].set_title(f"Correlation: {metric_corr.loc['RMSD (Å)', 'PC1']:.3f}")

# Time series of all metrics
axes[1, 0].plot(df_summary['Time (ns)'], df_summary['RMSD (Å)'], label='RMSD')
axes[1, 0].set_ylabel('RMSD (Å)')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

axes[1, 1].plot(df_summary['Time (ns)'], df_summary['Native Contacts (Q)'],
               label='Q', color='orange')
axes[1, 1].set_ylabel('Native Contacts (Q)')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

axes[2, 0].plot(df_summary['Time (ns)'], df_summary['PC1'], label='PC1', color='green')
axes[2, 0].set_xlabel('Time (ns)')
axes[2, 0].set_ylabel('PC1')
axes[2, 0].legend()
axes[2, 0].grid(True, alpha=0.3)

axes[2, 1].plot(df_summary['Time (ns)'], df_summary['H-bonds'], label='H-bonds', color='red')
axes[2, 1].set_xlabel('Time (ns)')
axes[2, 1].set_ylabel('Number of H-bonds')
axes[2, 1].legend()
axes[2, 1].grid(True, alpha=0.3)

plt.suptitle('Integrated Analysis Summary', fontsize=16)
plt.tight_layout()
plt.savefig(output_dir / "integrated_analysis.png", dpi=300)
plt.show()

# Save summary data
df_summary.to_csv(output_dir / "analysis_summary.csv", index=False)
print(f"\nAnalysis complete! Results saved to {output_dir}")

## 5. Export Results

Save the RMSD, correlation, and PCA results to disk for further use.

In [None]:
# Export each analysis's results under its own path inside output_dir.
rmsd_analysis.save_results(output_dir / "rmsd", format="csv")

# Correlation results, without the raw extracted positions
corr_analysis.save_results(output_dir / "correlation", save_positions=False)

# PCA results, including the per-frame projections
pca_analysis.save_results(output_dir / "pca", save_projections=True)

print("All results exported successfully!")
print("\nOutput directory contents:")
# glob() order is filesystem-dependent; sort for a stable, reproducible
# listing and mark subdirectories with a trailing slash.
for entry in sorted(output_dir.glob("*")):
    suffix = "/" if entry.is_dir() else ""
    print(f"  - {entry.name}{suffix}")