# In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import AllChem
from scipy.spatial.transform import Rotation as R
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
import py3Dmol
from joblib import Parallel, delayed
from numba import njit, prange


# Van der Waals radii keyed by element symbol (the string Atom.GetSymbol()
# returns). Units are Å: overlap volumes derived from these radii are reported
# in Å^3 elsewhere in this file. Elements missing from this table fall back to
# 1.70 (the carbon value) in get_vdw_radii().
# NOTE(review): several transition metals share a flat 2.00; these look like
# placeholder values rather than tabulated radii — confirm the source table.
vdw_radii = {
    "H": 1.20, "He": 1.40, "Li": 1.82, "Be": 1.53, "B": 1.92, "C": 1.70,
    "N": 1.55, "O": 1.52, "F": 1.47, "Ne": 1.54, "Na": 2.27, "Mg": 1.73,
    "Al": 1.84, "Si": 2.10, "P": 1.80, "S": 1.80, "Cl": 1.75, "Ar": 1.88,
    "K": 2.75, "Ca": 2.31, "Sc": 2.11, "Ti": 2.00, "V": 2.00, "Cr": 2.00,
    "Mn": 2.00, "Fe": 2.00, "Co": 2.00, "Ni": 1.63, "Cu": 1.40, "Zn": 1.39,
    "Ga": 1.87, "Ge": 2.11, "As": 1.85, "Se": 1.90, "Br": 1.85, "Kr": 2.02,
    "Rb": 3.03, "Sr": 2.49, "Y": 2.23, "Zr": 2.16, "Nb": 2.08, "Mo": 2.10,
    "Tc": 2.00, "Ru": 2.00, "Rh": 2.00, "Pd": 1.63, "Ag": 1.72, "Cd": 1.58,
    "In": 1.93, "Sn": 2.17, "Sb": 2.06, "Te": 2.06, "I": 1.98, "Xe": 2.16,
    "Cs": 3.43, "Ba": 2.68, "La": 2.39, "Ce": 2.35, "Pr": 2.40, "Nd": 2.39,
    "Pm": 2.38, "Sm": 2.36, "Eu": 2.35, "Gd": 2.32, "Tb": 2.31, "Dy": 2.29,
    "Ho": 2.28, "Er": 2.27, "Tm": 2.25, "Yb": 2.24, "Lu": 2.23, "Hf": 2.17,
    "Ta": 2.17, "W": 2.10, "Re": 2.05, "Os": 2.00, "Ir": 2.00, "Pt": 1.75,
    "Au": 1.66, "Hg": 1.55, "Tl": 1.96, "Pb": 2.02, "Bi": 2.07, "Po": 1.97,
    "At": 2.02, "Rn": 2.20, "Fr": 3.48, "Ra": 2.83, "Ac": 2.60, "Th": 2.40,
    "Pa": 2.44, "U": 2.40
}


def get_atom_coords(mol, random_seed=-1):
    """Embed *mol* in 3D, MMFF-optimize it, and return its atomic coordinates.

    The molecule is modified in place: RDKit replaces its conformer with a
    freshly embedded and force-field-optimized one.

    Args:
        mol: RDKit molecule (normally with explicit hydrogens already added).
        random_seed: seed forwarded to ``EmbedMolecule``. The default -1 is
            RDKit's own default; pass a non-negative value for reproducible
            embeddings.

    Returns:
        ``(n_atoms, 3)`` float ndarray of positions in atom-index order.

    Raises:
        ValueError: if RDKit fails to embed the molecule (``EmbedMolecule``
            returns -1). Without this check the failure would surface later as
            an unrelated "bad conformer id" error.
    """
    if AllChem.EmbedMolecule(mol, randomSeed=random_seed) == -1:
        raise ValueError("RDKit failed to generate a 3D conformer for the molecule")
    # NOTE(review): the MMFF return code (non-zero = not converged / no
    # parameters) is ignored here, as before; coordinates are then just the
    # embedded or partially optimized ones.
    AllChem.MMFFOptimizeMolecule(mol)
    return np.array(mol.GetConformer().GetPositions())


def get_vdw_radii(mol, atom_indices):
    """Look up a van der Waals radius for each requested atom.

    Args:
        mol: RDKit molecule the indices refer to.
        atom_indices: iterable of atom indices; output order follows it.

    Returns:
        1-D ndarray of radii from the module-level ``vdw_radii`` table, keyed by
        element symbol. Symbols absent from the table get 1.70 (the carbon value).
    """
    fallback_radius = 1.70
    return np.array([
        vdw_radii.get(mol.GetAtomWithIdx(idx).GetSymbol(), fallback_radius)
        for idx in atom_indices
    ])


def dfs_fragment(mol, start_atom, exclude_atoms):
    """Collect the atoms reachable from *start_atom* without entering *exclude_atoms*.

    The walk follows bonds (``GetNeighbors``) with an explicit stack, so deep
    chains cannot hit Python's recursion limit.

    Args:
        mol: RDKit-like molecule exposing ``GetAtomWithIdx``.
        start_atom: index where the walk begins; it is always part of the
            result, even if it also appears in *exclude_atoms*.
        exclude_atoms: indices that are never stepped onto. Excluding the
            partner atom of a bond therefore yields one side of that bond.

    Returns:
        set of reachable atom indices.
    """
    seen = set()
    pending = [start_atom]
    while pending:
        current = pending.pop()
        if current in seen:
            continue
        seen.add(current)
        pending.extend(
            nbr.GetIdx()
            for nbr in mol.GetAtomWithIdx(current).GetNeighbors()
            if nbr.GetIdx() not in seen and nbr.GetIdx() not in exclude_atoms
        )
    return seen


def get_rotating_fragments_by_bond(mol, atom_A_idx, atom_B_idx):
    """Split the molecule into the two sides of the A-B bond.

    Args:
        mol: RDKit molecule.
        atom_A_idx, atom_B_idx: indices of the two bonded atoms whose torsion
            is to be scanned.

    Returns:
        ``(fragment1, fragment2)`` as sorted lists of atom indices. fragment1
        holds A and everything reachable from it without crossing B, and
        fragment2 likewise for B. Sorting makes the order deterministic.

    Raises:
        ValueError: if A and B are not directly bonded, or if the bond lies in a
            ring. In that case the two sides are connected by another path, they
            share atoms, and a rigid rotation about A-B is not defined.
    """
    if mol.GetBondBetweenAtoms(atom_A_idx, atom_B_idx) is None:
        raise ValueError(f"atoms {atom_A_idx} and {atom_B_idx} are not bonded")
    fragment1 = dfs_fragment(mol, atom_A_idx, {atom_B_idx})
    fragment2 = dfs_fragment(mol, atom_B_idx, {atom_A_idx})
    if fragment1 & fragment2:
        raise ValueError(
            f"bond {atom_A_idx}-{atom_B_idx} is part of a ring; "
            "it does not split the molecule into two rotatable fragments"
        )
    return sorted(fragment1), sorted(fragment2)


def rotate_fragment(fragment_coords, axis, angle):
    """Rotate points about an axis through the origin.

    Args:
        fragment_coords: ``(n, 3)`` array of points. Translate them so the
            desired pivot is at the origin before calling.
        axis: 3-vector giving the rotation direction. It is normalized here, so
            its length does not affect the rotation angle. Previously a
            non-unit axis silently scaled the angle.
        angle: rotation angle in radians (right-hand rule about *axis*).

    Returns:
        ``(n, 3)`` ndarray of rotated points.

    Raises:
        ValueError: if *axis* has zero length.
    """
    axis = np.asarray(axis, dtype=float)
    axis_norm = np.linalg.norm(axis)
    if axis_norm == 0.0:
        raise ValueError("rotation axis must have non-zero length")
    rotation = R.from_rotvec(angle * axis / axis_norm)
    return rotation.apply(fragment_coords)


def find_open_sidechains(mol, atom_A_idx, atom_B_idx):
    """List side chains that can be spun around bonds next to one end of A-B.

    For every neighbour C of atom A other than B, and every neighbour D of C
    (D != A) that is not a ring atom, the side chain is everything reachable
    from D without passing back through C.

    Args:
        mol: RDKit molecule.
        atom_A_idx: the bond end whose neighbourhood is searched.
        atom_B_idx: the other bond end, which is skipped as a neighbour.

    Returns:
        list of lists ``[C, D, *members]``. C and D define the rotation axis;
        ``members`` are the atoms to move and always include D itself, which
        lies on the axis. Terminal D atoms (e.g. hydrogens) give entries whose
        rotation leaves them in place.
    """
    found = []
    for c_atom in mol.GetAtomWithIdx(atom_A_idx).GetNeighbors():
        c_idx = c_atom.GetIdx()
        if c_idx == atom_B_idx:
            continue
        for d_atom in mol.GetAtomWithIdx(c_idx).GetNeighbors():
            d_idx = d_atom.GetIdx()
            if d_idx == atom_A_idx or d_atom.IsInRing():
                continue
            branch = dfs_fragment(mol, d_idx, {c_idx})
            found.append([c_idx, d_idx, *branch])
    return found


def rotate_sidechain(sidechain_indices, axis_start_idx, axis_end_idx, coords, angle):
    """Spin selected atoms about the line through two other atoms.

    Args:
        sidechain_indices: index array of the rows of *coords* to move.
        axis_start_idx: atom whose position is the pivot point of the axis.
        axis_end_idx: atom that sets the axis direction (pivot -> this atom).
        coords: ``(n_atoms, 3)`` array; modified in place.
        angle: rotation angle in radians (right-hand rule about the axis).

    Returns:
        The same *coords* array, after the selected rows have been rotated.
    """
    pivot = coords[axis_start_idx]
    direction = coords[axis_end_idx] - pivot
    unit_axis = direction / np.linalg.norm(direction)
    spin = R.from_rotvec(angle * unit_axis)
    coords[sidechain_indices] = spin.apply(coords[sidechain_indices] - pivot) + pivot
    return coords


@njit(parallel=True)
def monte_carlo_overlap_numba(fragment1_coords, fragment2_coords, fragment1_radii, fragment2_radii, min_coords, max_coords, num_samples=10000):
    """Monte Carlo estimate of the volume lying inside both fragments' sphere unions.

    Draws ``num_samples`` points uniformly in the axis-aligned box
    [min_coords, max_coords]. It counts the points that fall inside at least one
    sphere of fragment 1 AND inside at least one sphere of fragment 2, then
    scales that hit fraction by the box volume.

    Args:
        fragment1_coords, fragment2_coords: sphere centres, indexed ``[k, 0..2]``
            (so shaped (n_atoms, 3)).
        fragment1_radii, fragment2_radii: sphere radii, same order as the coords.
        min_coords, max_coords: opposite corners of the sampling box. Overlap
            outside this box is not counted.
        num_samples: number of random points. The hit count is divided by it,
            so it must be positive.

    Returns:
        Estimated overlap volume in cubed coordinate units.

    NOTE(review): np.random inside numba-compiled code presumably draws from
    numba's own generator state, so seeding NumPy from Python may not make this
    reproducible — confirm against the numba docs.
    """
    overlap_count = 0
    num_fragment1_atoms = fragment1_coords.shape[0]
    num_fragment2_atoms = fragment2_coords.shape[0]
    # Samples are independent; the scalar ``overlap_count += 1`` inside prange
    # relies on numba's parallel-reduction support.
    for i in prange(num_samples):
        point = np.empty(3)
        for j in range(3):
            point[j] = np.random.uniform(min_coords[j], max_coords[j])
        # First test membership in fragment 1; stop at the first containing sphere.
        inside_frag1 = False
        for k in range(num_fragment1_atoms):
            dx = fragment1_coords[k, 0] - point[0]
            dy = fragment1_coords[k, 1] - point[1]
            dz = fragment1_coords[k, 2] - point[2]
            dist_sq = dx * dx + dy * dy + dz * dz
            if dist_sq <= fragment1_radii[k] * fragment1_radii[k]:
                inside_frag1 = True
                break
        # Skip the fragment-2 scan for points that cannot be in the overlap.
        if not inside_frag1:
            continue
        inside_frag2 = False
        for k in range(num_fragment2_atoms):
            dx = fragment2_coords[k, 0] - point[0]
            dy = fragment2_coords[k, 1] - point[1]
            dz = fragment2_coords[k, 2] - point[2]
            dist_sq = dx * dx + dy * dy + dz * dz
            if dist_sq <= fragment2_radii[k] * fragment2_radii[k]:
                inside_frag2 = True
                break
        if inside_frag2:
            overlap_count += 1
    # Hit fraction times box volume = estimated overlap volume.
    box_volume = (max_coords[0] - min_coords[0]) * (max_coords[1] - min_coords[1]) * (max_coords[2] - min_coords[2])
    overlap_volume = (overlap_count / num_samples) * box_volume
    return overlap_volume


def monte_carlo_overlap(all_coords, fragment1_indices, fragment2_indices, fragment1_radii, fragment2_radii, num_samples=10000):
    """Estimate the van der Waals overlap volume between two fragments.

    Args:
        all_coords: ``(n_atoms, 3)`` coordinates of the whole structure.
        fragment1_indices, fragment2_indices: atom indices of each fragment.
        fragment1_radii, fragment2_radii: radii aligned with those indices.
        num_samples: Monte Carlo points passed to the numba kernel.

    Returns:
        Estimated overlap volume. Returns 0.0 when either fragment is empty or
        when the fragments' radius-padded bounding boxes do not intersect.
    """
    if len(fragment1_indices) == 0 or len(fragment2_indices) == 0:
        return 0.0
    fragment1_coords = all_coords[fragment1_indices]
    fragment2_coords = all_coords[fragment2_indices]
    pad1 = np.asarray(fragment1_radii)[:, np.newaxis]
    pad2 = np.asarray(fragment2_radii)[:, np.newaxis]
    # A point inside both sphere unions lies inside BOTH fragments' padded
    # bounding boxes. Sampling only their intersection is therefore unbiased,
    # and it wastes far fewer samples than the union box used previously.
    min_coords = np.maximum(np.min(fragment1_coords - pad1, axis=0),
                            np.min(fragment2_coords - pad2, axis=0))
    max_coords = np.minimum(np.max(fragment1_coords + pad1, axis=0),
                            np.max(fragment2_coords + pad2, axis=0))
    if np.any(max_coords <= min_coords):
        # Boxes are disjoint (or only touch), so the overlap has zero volume.
        return 0.0
    return monte_carlo_overlap_numba(fragment1_coords, fragment2_coords, fragment1_radii, fragment2_radii, min_coords, max_coords, num_samples)


def optimize_overlap_sidechains(mol, updated_coords, atom_A_idx, atom_B_idx, fragment1_indices, fragment2_indices, fragment1_radii, fragment2_radii, num_samples, num_iterations, sidechains_A, sidechains_B):
    """Randomly re-orient open side chains and keep the lowest-overlap geometry.

    Each iteration copies *updated_coords*, and spins every side chain (A side
    first, then B side) by an independent uniform angle in [0, 2*pi). It then
    scores the result with ``monte_carlo_overlap``.

    Args:
        mol, atom_A_idx, atom_B_idx: unused; kept for interface compatibility.
        updated_coords: ``(n_atoms, 3)`` starting coordinates (not modified).
        fragment1_indices, fragment2_indices, fragment1_radii, fragment2_radii,
            num_samples: forwarded to ``monte_carlo_overlap``.
        num_iterations: number of random trials.
        sidechains_A, sidechains_B: lists of ``[C, D, *members]`` as produced
            by ``find_open_sidechains``.

    Returns:
        ``(best_overlap, best_coords)``. When there are no side chains, or
        *num_iterations* < 1, the unmodified geometry is scored once and returned.
    """
    # Build (axis start, axis end, member index array) once rather than per
    # iteration. A-side chains come before B-side ones, so random angles are
    # drawn in the same order as before.
    sidechains = [
        (chain[0], chain[1], np.array(chain[2:]))
        for chain in list(sidechains_A) + list(sidechains_B)
    ]

    def _overlap(coords):
        return monte_carlo_overlap(coords, fragment1_indices, fragment2_indices, fragment1_radii, fragment2_radii, num_samples)

    if not sidechains or num_iterations < 1:
        # Nothing to vary: repeated evaluations would all score the same
        # geometry, and keeping the minimum of noisy estimates only biases the
        # result low. (num_iterations < 1 used to return (inf, None).)
        return _overlap(updated_coords), updated_coords.copy()

    best_overlap = float('inf')
    best_coords = None
    for _ in range(num_iterations):
        trial = updated_coords.copy()
        for start_idx, end_idx, members in sidechains:
            angle = np.random.uniform(0, 2 * np.pi)
            trial = rotate_sidechain(members, start_idx, end_idx, trial, angle)
        overlap = _overlap(trial)
        if overlap < best_overlap:
            best_overlap = overlap
            # `trial` is a fresh copy each iteration, so no further copy is needed.
            best_coords = trial
    return best_overlap, best_coords


def update_molecule_conformer(mol, new_coords):
    """Write *new_coords* into the conformer returned by ``mol.GetConformer()``.

    Args:
        mol: RDKit molecule with at least one conformer.
        new_coords: indexable per-atom positions of 3 values each. It needs at
            least ``mol.GetNumAtoms()`` rows; any extra rows are ignored.
    """
    conf = mol.GetConformer()
    n_atoms = mol.GetNumAtoms()
    for atom_idx in range(n_atoms):
        px, py, pz = new_coords[atom_idx]
        conf.SetAtomPosition(atom_idx, (px, py, pz))


def calculate_overlap_volume(smiles, atom_A_idx, atom_B_idx, num_samples, num_iterations, num_angles,
                             max_angle=2.1 * np.pi, show_structures=True):
    """Scan the A-B torsion and estimate the vdW overlap between its two sides.

    Steps:
    1. Build the molecule from SMILES with explicit hydrogens, then embed and
       MMFF-optimize it (via ``get_atom_coords``).
    2. Split it at the A-B bond.
    3. Rotate the B-side fragment about the axis through A toward B.
    4. At each angle, randomly re-orient open side chains on both bond ends and
       keep the smallest Monte Carlo overlap estimate
       (``optimize_overlap_sidechains``).

    Args:
        smiles: SMILES string of the molecule.
        atom_A_idx, atom_B_idx: indices of the bonded atoms defining the torsion.
        num_samples: Monte Carlo points per overlap estimate.
        num_iterations: random side-chain trials per angle.
        num_angles: number of angles scanned.
        max_angle: last scanned angle in radians (inclusive). The default 2.1*pi
            reproduces the original scan, which runs slightly past 360 degrees.
        show_structures: when True, display the geometry chosen for each angle
            whose integer degree value ends in 0 or 1.

    Returns:
        ``(angles, overlap_volumes)``: two ndarrays of length *num_angles*, with
        angles in radians.

    Raises:
        ValueError: if the SMILES cannot be parsed. Helpers may raise too, e.g.
            for a failed embedding or a ring/unbonded A-B pair.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {smiles!r}")
    mol = Chem.AddHs(mol)
    # get_atom_coords embeds and optimizes the molecule itself, so a separate
    # EmbedMolecule call here would simply be overwritten.
    atom_coords = get_atom_coords(mol)
    fragment1_indices, fragment2_indices = get_rotating_fragments_by_bond(mol, atom_A_idx, atom_B_idx)
    fragment1_radii = get_vdw_radii(mol, fragment1_indices)
    fragment2_radii = get_vdw_radii(mol, fragment2_indices)

    # Torsion axis: unit vector from A toward B, passing through atom A.
    pivot = atom_coords[atom_A_idx]
    axis_vector = atom_coords[atom_B_idx] - pivot
    axis_vector = axis_vector / np.linalg.norm(axis_vector)

    sidechains_A = find_open_sidechains(mol, atom_A_idx, atom_B_idx)
    sidechains_B = find_open_sidechains(mol, atom_B_idx, atom_A_idx)

    angles = np.linspace(0, max_angle, num_angles)
    overlap_volumes = []
    for angle in angles:
        updated_coords = atom_coords.copy()
        updated_coords[fragment2_indices] = rotate_fragment(atom_coords[fragment2_indices] - pivot, axis_vector, angle) + pivot
        overlap_vol, best_coords = optimize_overlap_sidechains(mol, updated_coords, atom_A_idx, atom_B_idx, fragment1_indices, fragment2_indices, fragment1_radii, fragment2_radii, num_samples, num_iterations, sidechains_A, sidechains_B)
        overlap_volumes.append(overlap_vol)
        print(f"Rotation angle: {np.degrees(angle):.2f} degrees, Overlap volume: {overlap_vol:.4f} Å^3")
        if show_structures and int(angle * 180 / np.pi) % 10 < 2:
            # Show the geometry that produced overlap_vol, including the chosen
            # side-chain orientations, not the pre-optimization coordinates.
            update_molecule_conformer(mol, updated_coords if best_coords is None else best_coords)
            show_molecule(mol)
    return angles, np.array(overlap_volumes)


def moving_average_circular(data, window_size):
    """Boxcar (moving-average) smoothing that treats *data* as periodic.

    Output sample ``t`` is the mean of
    ``data[t - window_size//2 .. t + (window_size - 1)//2]``, with indices taken
    modulo ``len(data)``. For an odd window this is centred on ``t``. For an
    even window it leans one sample to the left, the same alignment the
    previous implementation produced.

    Args:
        data: 1-D sequence of numbers.
        window_size: number of samples averaged; must be >= 1.

    Returns:
        float ndarray with the same length as *data* (empty input gives empty
        output).

    Raises:
        ValueError: if *window_size* < 1.
    """
    values = np.asarray(data, dtype=float)
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    if values.size == 0:
        return values.copy()
    left = window_size // 2
    right = (window_size - 1) // 2
    # np.pad(mode='wrap') keeps wrapping even when the padding is longer than
    # the data. The previous slice-based padding came up short in that case and
    # misaligned the result.
    padded = np.pad(values, (left, right), mode='wrap')
    return np.convolve(padded, np.ones(window_size) / window_size, mode='valid')


def find_extrema(angles, overlap_volumes, sigma=2, window_size=3):
    """Smooth an overlap-vs-angle curve and locate its local maxima and minima.

    Smoothing is a circular moving average of *window_size* samples followed by
    ``gaussian_filter1d`` with the given *sigma*. Peaks come from ``find_peaks``
    on the smoothed curve, and valleys from ``find_peaks`` on its negation, so
    the first and last samples are never reported.

    Args:
        angles: ndarray of angles aligned with *overlap_volumes*.
        overlap_volumes: 1-D sequence of overlap estimates.
        sigma: Gaussian smoothing width, in samples.
        window_size: moving-average window length.

    Returns:
        ``(peak_angles, peak_values, valley_angles, valley_values, smoothed)``;
        the values are taken from the smoothed curve.
    """
    smoothed = gaussian_filter1d(moving_average_circular(overlap_volumes, window_size), sigma=sigma)
    peak_idx = find_peaks(smoothed)[0]
    valley_idx = find_peaks(-smoothed)[0]
    return (
        angles[peak_idx],
        smoothed[peak_idx],
        angles[valley_idx],
        smoothed[valley_idx],
        smoothed,
    )


def save_overlap_volume_results(angles, overlap_volumes, sigma, window_size, root_path="./", csv_filename='overlap_volumes.csv', png_filename='overlap_volume_vs_angle.png'):
    """Smooth the scan, mark peaks and valleys, and save a CSV and a plot.

    Args:
        angles: ndarray of scan angles in radians.
        overlap_volumes: overlap estimates aligned with *angles*.
        sigma, window_size: smoothing parameters forwarded to ``find_extrema``.
        root_path: output directory, created if it does not exist.
        csv_filename: CSV of angle (degrees), raw volume and smoothed volume.
        png_filename: plot of raw and smoothed curves with annotated extrema.

    Returns:
        ``(peak_angles_360, peak_values, valley_angles_360, valley_values)``,
        with angles in degrees.
    """
    peak_angles, peak_values, valley_angles, valley_values, smoothed_volumes = find_extrema(angles, overlap_volumes, sigma=sigma, window_size=window_size)
    peak_angles_360 = peak_angles * 180 / np.pi
    valley_angles_360 = valley_angles * 180 / np.pi
    angles_360 = angles * 180 / np.pi

    if root_path:
        # to_csv / savefig fail on a missing directory.
        os.makedirs(root_path, exist_ok=True)

    df = pd.DataFrame({
        'Angle (degrees)': angles_360,
        'Original Overlap Volume (Å^3)': overlap_volumes,
        'Smoothed Overlap Volume (Å^3)': smoothed_volumes
    })
    df.to_csv(os.path.join(root_path, csv_filename), index=False)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(angles_360, overlap_volumes, marker='o', label='Original Overlap Volumes')
    ax.plot(angles_360, smoothed_volumes, marker='.', label='Smoothed Overlap Volumes')
    ax.scatter(peak_angles_360, peak_values, color='red', label='Peaks', zorder=5)
    for angle, value in zip(peak_angles_360, peak_values):
        ax.text(angle + 5, value, f'{value:.2f}\n{angle:.1f}°', color='red', ha='center', va='bottom')
    ax.scatter(valley_angles_360, valley_values, color='blue', label='Valleys', zorder=5)
    for angle, value in zip(valley_angles_360, valley_values):
        ax.text(angle + 5, value, f'{value:.2f}\n{angle:.1f}°', color='blue', ha='center', va='top')
    ax.set_title('Overlap Volume vs. Rotation Angle with Peaks and Valleys')
    ax.set_xlabel('Rotation Angle (degrees)')
    ax.set_ylabel('Overlap Volume (Å^3)')
    ax.grid(True)
    ax.legend()
    fig.savefig(os.path.join(root_path, png_filename))
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    return peak_angles_360, peak_values, valley_angles_360, valley_values


def show_molecule(mol, confId=-1):
    """Render one conformer of *mol* in an interactive py3Dmol viewer.

    The 400x400 view draws sticks plus a VDW surface at 0.8 opacity, zoomed to
    fit the molecule.

    Args:
        mol: RDKit molecule with 3D coordinates.
        confId: conformer id passed to ``Chem.MolToMolBlock`` (-1 by default).
    """
    molfile_text = Chem.MolToMolBlock(mol, confId=confId)
    stick_style = {'stick': {}}
    surface_options = {'opacity': 0.8, 'colorscheme': 'default'}
    viewer = py3Dmol.view(width=400, height=400)
    viewer.addModel(molfile_text, 'mol')
    viewer.setStyle(stick_style)
    viewer.addSurface(py3Dmol.VDW, surface_options)
    viewer.zoomTo()
    viewer.show()


if __name__ == "__main__":
    # Molecule and the two bonded atoms whose torsion is scanned.
    # Raw string: the SMILES contains a "\C" bond marker. In a plain string
    # literal that is an invalid escape (DeprecationWarning, or SyntaxWarning
    # on Python 3.12+); the raw string has the identical value.
    smiles = r'OC1=C(/C([C@@H](O)CC=C)=C\C2=CC=CC=C2)C3=CC=CC=C3C=C1'
    atom_A_idx = 2
    atom_B_idx = 3
    # Monte Carlo points per overlap estimate, random side-chain trials per
    # angle, and number of torsion angles scanned.
    num_samples = 200000
    num_iterations = 2000
    num_angles = 180
    # Smoothing used when locating peaks and valleys.
    window_size = 6
    sigma = 3
    angles, overlap_volumes = calculate_overlap_volume(smiles, atom_A_idx, atom_B_idx, num_samples, num_iterations, num_angles=num_angles)
    peak_angles_360, peak_values, valley_angles_360, valley_values = save_overlap_volume_results(angles, overlap_volumes, sigma, window_size, root_path="./", csv_filename='overlap_volumes.csv', png_filename='overlap_volume_vs_angle_entry201.png')
    print("Peak Angles (degrees):", peak_angles_360)
    print("Peak Values (Å^3):", peak_values)
    print("Valley Angles (degrees):", valley_angles_360)
    print("Valley Values (Å^3):", valley_values)
