In [None]:
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from typing import List
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score
from scipy.stats import pearsonr
from scipy.stats import spearmanr
from scipy.stats import chisquare

In [None]:
def dihedral_term(phi, n, phase, k=1):
    """
    Evaluate a single cosine term of a dihedral potential.

    The value computed is ``k * (1 + cos(n * phi - phase))``.

    Parameters
    ----------
    phi : float
        Dihedral angle, passed (after scaling and shifting) to ``np.cos``
    n : int
        Periodicity, i.e. the multiplier applied to ``phi``
    phase : float
        Phase shift subtracted from ``n * phi``
    k : float
        Force constant that scales the term. Defaults to 1.

    Returns
    -------
    float
        The term evaluated at ``phi``
    """
    shifted_angle = n*phi - phase
    cosine_part = 1 + np.cos(shifted_angle)
    return k*cosine_part

In [None]:
def V(phi: float, period: List[int], phases: List[float], force_const: List[float] = None ) -> float:
    """
    Proper dihedral potential: the sum of one dihedral term per periodicity.

    Parameters
    ----------
    phi : float
        Dihedral angle
    period : List[int]
        Periodicities, one per term
    phases : List[float]
        Phase shifts, one per term (same order as ``period``)
    force_const : List[float]
        Force constants, one per term. If None, every force constant is 1.

    Returns
    -------
    float
        Value of the dihedral potential at phi
    """
    # Fall back to unit force constants, one for each periodicity.
    if force_const is None:
        force_const = [1] * len(period)

    assert len(period) == len(phases), "n and phase must have the same length"
    assert len(period) == len(force_const), "n and force_const must have the same length"

    # Accumulate the individual terms, starting from 0.
    return sum(
        dihedral_term(phi, n, phase, k)
        for n, phase, k in zip(period, phases, force_const)
    )

In [None]:
# Selection of the torsion scan to fit. These three values are interpolated
# into the data and figure paths used further down
# ('fragment%s/torsion%s/qm-mm_torsion%s_seed%s...').
fragment = 5
dihedral = "1234"
seed_no = 4

In [None]:
# Candidate phase shifts (radians) that each cosine term may take.
# (np and product are already imported in the notebook's first cell.)
phases = [0, np.pi]

# Highest periodicity considered: candidate models use periods 1..n for
# n = 1..max_periodicity.
max_periodicity = 3

# Unit placeholder force constants, one per possible term. The actual force
# constants are obtained from the fit; these only fill the third tuple slot.
force_constants = [1] * max_periodicity

# Enumerate every candidate model as a tuple
#     (periods, phase_combo, unit_force_constants)
# where periods is [1, ..., n], phase_combo assigns one phase to each period
# (all 2**n assignments), and the last item holds n ones. Models are ordered by
# increasing n, then in itertools.product order of the phases.
all_combinations = [
    (list(range(1, num_periods + 1)), phase_combo, force_constants[:num_periods])
    for num_periods in range(1, max_periodicity + 1)
    for phase_combo in product(phases, repeat=num_periods)
]


In [None]:
# Scan grid in degrees: -170, -160, ..., 180 (-180 never gets scanned).
forward = np.arange(-170, 190, 10).tolist()

# The same grid converted to radians.
phi = [angle * np.pi / 180.0 for angle in forward]

mse_list = []  # MSE of each fit, in the same order as all_combinations
k_list = []    # fitted force constants (column vector) of each fit, same order

# this is the QM data minus the array saved in Sire energy decomposition (QM - (MM_total - MM_torsion))
# It is loaded once because it does not depend on the combination being fitted.
Vref = np.load('./profiles_torsions/individual_conformer_scans/fragment%s/torsion%s/qm-mm_torsion%s_seed%s.npy' % (fragment, dihedral, dihedral, seed_no))

# Right-hand side of the least-squares problem as a column vector.
b = np.asarray(Vref).reshape(-1, 1)
if b.shape[0] != len(phi):
    raise ValueError(
        "Reference profile has %d points but the scan grid has %d angles"
        % (b.shape[0], len(phi))
    )

# Column of angles so that one broadcast call fills a whole design matrix.
phi_column = np.asarray(phi).reshape(-1, 1)

for combination in all_combinations:
    # Local name for the phases so the notebook-level `phases` list is not overwritten.
    period, phase_combo = combination[0], combination[1]

    # Design matrix: a[i, j] is the dihedral term for angle i and period j,
    # evaluated with force constant 1 because the force constants are what we fit.
    a = dihedral_term(phi_column, np.asarray(period), np.asarray(phase_combo), 1)

    # Linear least squares a @ k ~= b. lstsq avoids explicitly inverting a.T @ a;
    # k has shape (len(period), 1).
    k, _residuals, _rank, _singular_values = np.linalg.lstsq(a, b, rcond=None)
    k_list.append(k)

    mse = mean_squared_error(Vref, a @ k)
    mse_list.append(mse)

    # Plot the dihedral potential (all possible solutions for given number of periods)
    plt.plot(phi, Vref, 'o', label="Residual potential")
    plt.plot(phi, a @ k, label="Fitted potential")
    plt.title('Periodicity %s - Phases %s - MSE %s' % (period[-1], list(phase_combo), round(mse,3)))
    plt.xlabel("Dihedral angle (rad)")
    plt.ylabel("Dihedral potential (kcal/mol)")
    plt.legend()
    plt.show()

In [None]:
# Lowest mean squared error among all fitted combinations.
min(mse_list)

In [None]:
# Position of the lowest-MSE fit (first occurrence if several are equal).
# mse_list, k_list and all_combinations are filled in the same order, so this
# index addresses all three.
min_index = mse_list.index(min(mse_list))

In [None]:
# (periods, phases) of the lowest-MSE combination; the slice drops the trailing
# list of unit force constants.
all_combinations[min_index][:-1]

In [None]:
# Fitted force constants of the lowest-MSE combination (one row per period).
k_list[min_index]

In [None]:
# Index of the fit that the remaining cells use; by default the lowest-MSE fit.
desired_index = mse_list.index(min(mse_list))
desired_index # or some other index as the min MSE doesn't always give the best solution

In [None]:
# A notebook cell only renders its final expression, so both values are bundled
# into one tuple: ((periods, phases), fitted force constants) of the chosen fit.
all_combinations[desired_index][:-1], k_list[desired_index]

In [None]:
# prepare periods, phases, k for plotting

# Each stored entry is (periods, phases, unit force constants); only the first
# two items are needed here.
fitted_periods, fitted_phases = all_combinations[desired_index][:2]
print(fitted_periods, fitted_phases)

In [None]:
# Flatten the fitted column vector of force constants into a plain list of
# floats, one per period.
fitted_k = [row.tolist()[0] for row in k_list[desired_index]]

fitted_k

In [None]:
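# If a fitted k is negative, flip its phase (0 <-> np.pi) and use |k| instead: the curve keeps the same shape
# and only shifts by a constant, since for k < 0: k*(1 + cos(n*phi)) = |k|*(1 + cos(n*phi - np.pi)) - 2*|k|.
# e.g. if the k for the 1st period is -0.111 and has phase 0, we can use a k of 0.111 with a phase of np.pi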

In [None]:

# Scan grid in degrees (-170, -160, ..., 180; -180 is not part of the scan),
# then the same grid in radians.
forward = np.arange(-170, 190, 10).tolist()
phi = [angle * np.pi / 180.0 for angle in forward]

period = fitted_periods
phases = fitted_phases
force_const = fitted_k

# Evaluate the fitted potential on the grid. It is stored as an array so that the
# shift below is explicit array arithmetic instead of relying on a list being
# coerced by a NumPy scalar.
Vfit = np.asarray([V(p, period, phases, force_const) for p in phi])

Vref = np.load('./profiles_torsions/individual_conformer_scans/fragment%s/torsion%s/qm-mm_torsion%s_seed%s.npy' % (fragment, dihedral, dihedral, seed_no))
# Plot the dihedral potential
plt.plot(phi, Vref, 'o', label="Residual potential")
# The fitted curve is shifted so that its minimum sits at zero.
plt.plot(phi, Vfit - Vfit.min(), label="Fitted potential")
plt.xlabel("Dihedral angle (rad)")
plt.ylabel("Dihedral potential (kcal/mol)")
plt.legend()
# Save before show(): the figure is written to disk while it is still current.
plt.savefig('./profiles_torsions/individual_conformer_scans/fragment%s/torsion%s/qm-mm_torsion%s_seed%s_fitting.png' % (fragment, dihedral, dihedral, seed_no))
plt.show()


In [None]:
# Once the k's are found, the respective torsion in the frcmod file needs to be updated (periods, k's, phases).
# If the torsion being fitted has a multiplicity other than 1 (the number in the first column of the frcmod, e.g. 4, 6, 9),
# the k's from the fitting need to be multiplied by that number in the frcmod file.