# In [None]:
import plumed
from matplotlib import pyplot as plt
import numpy as np
import MDAnalysis as md
from MDAnalysis.analysis import distances
import pandas as pd
import itertools
import random
import matplotlib
from deeptime.decomposition import TICA
from deeptime.covariance import KoopmanWeightingEstimator
from deeptime.clustering import MiniBatchKMeans
from deeptime.markov import TransitionCountEstimator
from deeptime.markov.msm import MaximumLikelihoodMSM
from deeptime.plots import plot_implied_timescales
from deeptime.util.validation import implied_timescales
import networkx as nx
from copy import deepcopy
from numpy.random import multinomial
import subprocess
import os
import math
from scipy.stats import pearsonr

import warnings
# NOTE(review): this silences every warning for the whole session (numpy divide-by-zero, deeptime, ...)
warnings.filterwarnings('ignore')
###############################USER DEFINE REGION##################################

### Environment: change this according to your own environment, make sure you can run gmx_mpi/gmx/plumed inside this notebook
# The generated shell scripts call gmx, gmx_mpi, mpirun and plumed by bare name, so they must be found on PATH.
os.environ['PATH'] = '/usr/local/Climber:/usr/local/pymol:/usr/local/gromacs/bin:/usr/local/plumed/bin:/usr/local/openmpi/bin:/usr/local/cuda-12.2/bin:/usr/local/clash:/home/mingyuan/.local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/snap/bin'
os.environ['LD_LIBRARY_PATH'] = '/usr/local/gromacs/lib:/usr/local/plumed/lib:/usr/local/libtorch/lib:/usr/local/openmpi/lib:/usr/local/cuda-12.2/lib64:'

### Hyperparameters
# Run setup
sim_name = 'ala2'                       # base name of the .tpr/.xtc files (mdrun -deffnm)
colvar = 'CV/COLVAR'                    # prefix of the per-trajectory feature files: '{colvar}_{i}'
topol = 'traj_and_dat/input.pdb'        # first frame of walker 0 (written after round 0); PLUMED MOLINFO structure
seed_ref = 'traj_and_dat/seed_ref.pdb'  # first frame of walker 0 (trjconv group 0); topology used to read the .xtc files
n_sim = 16                              # number of parallel walkers per round (mpirun -np)
n_steps = 20000                         # MD steps per walker per round (mdrun -nsteps)
# Progress control
n_rounds = 20                           # total number of rounds, round 0 included
sim_idx = 0                             # round to start from: 0 = fresh start, k > 0 = resume with rounds 0..k-1 already done
# hardware-related
ntomp = 1                               # OpenMP threads per MPI rank (-ntomp / -cpus-per-rank)
gpu_id = '0,1'                          # passed verbatim to mdrun -gpu_id
n_jobs = 32                             # worker count for deeptime MiniBatchKMeans
# TICA parameters (for adaptive sampling)
tica_lagtime = 20                       # lag time in saved frames (TICA and the stitched supplementary segments)
dim = None                              # TICA output dimension; None lets var_cutoff decide
var_cutoff = 0.95                       # kinetic-variance cutoff passed to deeptime TICA
koopman = True                          # fit TICA with Koopman reweighting
# Markov State Model parameters
msm_lagtime = 20                        # lag time in frames for transition counting
# PCCA parameters
n_metastable_sets = 30                  # requested PCCA macrostates (capped at the number of MSM states)
# CV machine learning & convergence check
convergence_check = True                # train a TICA CV model after every round and test its convergence
num_cvs = 2                             # TICA components used as OPES CVs
patience = 2                            # number of most recent model comparisons that must pass the criteria
convergence_criteria = [0.99,0.95]      # |Pearson r| thresholds for tIC1 and tIC2 against the newest model
# OPES parameters
barrier = 20                            # OPES_METAD BARRIER (in PLUMED energy units)
opes_steps = 500000                     # mdrun -nsteps for the OPES run; None means sim_idx * n_steps

def initialize(sim_name,n_sim):
    """Create the directory layout of a fresh run by writing and executing ``initialize.sh``.

    The script does the following:
      * builds ``{sim_name}.tpr`` with ``gmx grompp`` from ``../4-npt/npt.gro``;
      * copies the .tpr into one numbered directory per walker (``0`` .. ``n_sim-1``);
      * creates ``CV``, ``traj_and_dat``, ``figures`` and ``opes``;
      * moves the .tpr into ``opes/``.

    Returns None.
    """
    script_lines = [
        'gmx grompp -f ../mdp/md.mdp -c ../4-npt/npt.gro -p ../1-topol/topol.top -o {sim_name}.tpr\n'.format(sim_name=sim_name),
        'for i in `seq 0 {n_sim_0}`\n'.format(n_sim_0=n_sim-1),
        'do\n',
        '    mkdir $i\n',
        '    cp {sim_name}.tpr $i\n'.format(sim_name=sim_name),
        'done\n',
        'mkdir CV traj_and_dat figures opes\n',
        'mv {sim_name}.tpr opes/\n'.format(sim_name=sim_name),
    ]
    with open('initialize.sh','w+') as script:
        script.write(''.join(script_lines))
    for command in ('chmod u+x initialize.sh', './initialize.sh'):
        subprocess.run([command], shell=True)
    return None

def gmx_mpirun(sim_name,sim_idx,n_sim,n_steps,ntomp,gpu_id):
    """Run the walkers of round `sim_idx` as one multi-directory ``gmx_mpi mdrun``.

    The walker directories are ``sim_idx*n_sim`` .. ``(sim_idx+1)*n_sim - 1``.
    The command is written to ``mpirun.sh`` and executed. Non-bonded, PME,
    bonded and update work are all requested on the GPU.

    Returns None.
    """
    first_dir = sim_idx * n_sim
    multidir = ' '.join(str(run_dir) for run_dir in range(first_dir, first_dir + n_sim))
    command = 'mpirun -np {n_sim} -cpus-per-rank {ntomp} gmx_mpi mdrun -v -deffnm {sim_name} -multidir {multidir} -pme gpu -nb gpu -bonded gpu -nsteps {n_steps} -ntomp {ntomp} -gpu_id {gpu_id} -update gpu'.format(
        n_sim=n_sim,ntomp=ntomp,sim_name=sim_name,multidir=multidir,n_steps=n_steps,gpu_id=gpu_id)
    with open('mpirun.sh','w+') as script:
        script.write(command)
    subprocess.run(['chmod u+x mpirun.sh'],shell=True)
    subprocess.run(['./mpirun.sh'],shell=True)
    return None
    
# Feature functions
def select_dihedrals(universe,dihedral_type,start_res,end_res):
    """Return, for each requested dihedral type, the residue ids that carry it.

    Parameters
    ----------
    universe : MDAnalysis.Universe
        Only ``universe.residues.resids`` and ``universe.residues.resnames`` are read.
    dihedral_type : container of str
        Tested with ``in`` for 'phi', 'psi', 'omega', 'chi1', 'chi2', 'chi3' and 'chi4'.
    start_res, end_res : int
        Inclusive range of residue ids to consider.

    Returns
    -------
    dict
        Maps each requested type to a list of residue ids. Keys are inserted
        in the order phi, psi, omega, chi1, chi2, chi3, chi4.

    Residue names are looked up by offset from the first resid, so the
    resids are assumed to be contiguous.
    """
    def resname(resid):
        return universe.residues.resnames[resid - universe.residues.resids[0]]

    # (dihedral name, predicate telling whether a residue id carries it)
    rules = [
        # the first residue has no phi (if not capped)
        ('phi', lambda resid: resid != universe.residues.resids[0]),
        # the last residue has no psi (if not capped)
        ('psi', lambda resid: resid != universe.residues.resids[-1]),
        # the first residue has no omega (if not capped)
        ('omega', lambda resid: resid != universe.residues.resids[0]),
        ('chi1', lambda resid: resname(resid) not in ('GLY', 'ALA')),
        ('chi2', lambda resid: resname(resid) not in ('GLY', 'ALA', 'CYS', 'SER', 'THR', 'VAL')),
        ('chi3', lambda resid: resname(resid) in ('ARG', 'GLN', 'GLU', 'LYS', 'MET')),
        ('chi4', lambda resid: resname(resid) in ('ARG', 'LYS')),
    ]

    selection = {}
    for name, carries_dihedral in rules:
        if name in dihedral_type:
            selection[name] = [resid for resid in range(start_res, end_res + 1) if carries_dihedral(resid)]
    return selection
    
def write_features(colvar,sim_idx,n_sim,sim_name):
    """Compute PLUMED features for every trajectory of round `sim_idx` via ``features.sh``.

    For each walker i of the round, the script does the following:
      * copies ``i/{sim_name}.xtc`` to ``traj_and_dat/i.xtc``;
      * runs ``plumed driver`` with ``traj_and_dat/features.dat`` and ``traj_and_dat/ref.dat``;
      * moves the resulting ``COLVAR``/``COLVAR_ref`` to ``CV/COLVAR_i``/``CV/COLVAR_ref_i``.

    `colvar` is accepted but unused; the output paths are fixed in the script.

    Returns None.
    """
    first_idx = sim_idx * n_sim
    last_idx = (sim_idx + 1) * n_sim - 1
    script_lines = [
        'start={idx_start}\n'.format(idx_start=first_idx),
        'end={idx_end}\n'.format(idx_end=last_idx),
        'for i in `seq $start $end`\n',
        'do\n',
        '    cp $i/{sim_name}.xtc traj_and_dat/\n'.format(sim_name=sim_name),
        '    mv traj_and_dat/{sim_name}.xtc traj_and_dat/$i.xtc\n'.format(sim_name=sim_name),
        '    plumed driver --plumed traj_and_dat/features.dat --mf_xtc traj_and_dat/$i.xtc\n',
        '    plumed driver --plumed traj_and_dat/ref.dat --mf_xtc traj_and_dat/$i.xtc\n',
        '    mv COLVAR CV/COLVAR_$i\n',
        '    mv COLVAR_ref CV/COLVAR_ref_$i\n',
        'done\n',
    ]
    with open('features.sh','w+') as script:
        script.write(''.join(script_lines))
    for command in ('chmod u+x features.sh', './features.sh'):
        subprocess.run([command], shell=True)
    return None

### Analysis functions
def read_features(colvar,sim_idx,n_sim):
    """Load the feature COLVAR files of every trajectory sampled so far.

    Reads ``{colvar}_{i}`` for i in ``0 .. sim_idx*n_sim - 1`` with
    ``plumed.read_as_pandas``. It drops the ``time`` column and every raw
    angle column (names starting with phi/psi/chi/omega), so only the
    remaining features (e.g. their sin/cos transforms) are kept.

    Returns
    -------
    traj : list of pandas.DataFrame
        One filtered table per trajectory.
    data : list of numpy.ndarray
        The same tables as float32 arrays.
    """
    angle_prefixes = ('phi', 'psi', 'chi', 'omega')
    traj = []
    for traj_no in range(sim_idx * n_sim):
        table = plumed.read_as_pandas('{}_{}'.format(colvar, traj_no)).drop(columns=['time'])
        angle_columns = [name for name in table.columns if name.startswith(angle_prefixes)]
        traj.append(table.drop(columns=angle_columns))
    data = [table.to_numpy(dtype='float32') for table in traj]
    return traj,data

def data_supplement(sim_idx,data,lagtime):
    """Build short stitched trajectories that bridge each seed frame and its restart.

    How trajectories are numbered:
      * In round r (1 <= r < sim_idx), walker j was restarted from the frame
        listed on row j of ``round{r}_seed.txt`` (rows are ``traj_no frame``,
        written by ``write_gmxfile``).
      * It produced trajectory ``r * n_sim + j``, where n_sim is the number of
        seeds in that round's file.

    For each seed, this function concatenates two pieces:
      * up to ``lagtime - 1`` frames that precede the seed in its parent trajectory;
      * the first ``lagtime`` frames of the child trajectory.

    This way, time-lagged pairs that straddle the restart point are available
    to TICA and the MSM.

    Parameters
    ----------
    sim_idx : int
        Number of completed rounds; seed files of rounds ``1 .. sim_idx-1`` are read.
    data : list of np.ndarray
        Per-trajectory 2D feature arrays, indexed by trajectory number.
    lagtime : int
        Lag time in frames.

    Returns
    -------
    list of np.ndarray
        One stitched segment per seed with frame index > 0 (seeds at frame 0
        have no history and are skipped). Empty when ``sim_idx <= 1``.
    """
    data_supp = []
    for round_i in range(1, sim_idx):
        # ndmin=2 keeps a single-seed file as shape (1, 2) instead of (2,)
        round_seeds = np.loadtxt('round{i}_seed.txt'.format(i=round_i), dtype=int, ndmin=2)
        # Walkers of round r occupy trajectory indices r*n_sim .. r*n_sim + n_sim - 1
        first_child = round_i * len(round_seeds)
        for walker, (parent, frame) in enumerate(round_seeds):
            if frame == 0:
                continue
            start_frame = max(0, frame - lagtime + 1)
            data_supp_pre = data[parent][start_frame:frame, :]
            # continuation: the trajectory actually started from this seed
            data_supp_post = data[first_child + walker][:lagtime, :]
            data_supp.append(np.concatenate([data_supp_pre, data_supp_post]))
    return data_supp
    
def run_TICA(data,data_supp,lagtime,dim=None,var_cutoff=None,koopman=True):
    """Fit deeptime TICA on `data` plus `data_supp` and project both.

    Parameters
    ----------
    data : list of np.ndarray
        Per-trajectory feature arrays.
    data_supp : list of np.ndarray
        Supplementary (stitched) segments; used for fitting and projected separately.
    lagtime : int
        TICA lag time in frames.
    dim, var_cutoff :
        Passed to ``deeptime.decomposition.TICA``.
    koopman : bool
        If True, the fit is weighted with a model from ``KoopmanWeightingEstimator``.

    Returns
    -------
    tica : fitted TICA model
    tica_output : ``tica.transform(data)``
    tica_output_concat : ``tica_output`` concatenated along the frame axis
    tica_output_supp : list with the projection of each supplementary segment
    """
    training_set = data + data_supp
    estimator = TICA(lagtime=lagtime,dim=dim,var_cutoff=var_cutoff)
    fit_kwargs = {}
    if koopman == True:
        weighting = KoopmanWeightingEstimator(lagtime=lagtime)
        fit_kwargs['weights'] = weighting.fit(training_set).fetch_model()
    tica = estimator.fit(training_set, **fit_kwargs).fetch_model()

    tica_output = tica.transform(data)
    tica_output_concat = np.concatenate(tica_output)
    tica_output_supp = [tica.transform(segment) for segment in data_supp]
    return tica,tica_output,tica_output_concat,tica_output_supp

def calculate_nmicro(data_concat):
    """Return the number of k-means microstates for a dataset of ``data_concat.shape[0]`` frames.

    Uses the htmd heuristic ``round(0.6 * log10(N / 1000) * 1000 + 50)`` with
    a floor of 100; see
    https://github.com/Acellera/htmd/blob/master/htmd/adaptive/adaptivebandit.py
    """
    n_frames = data_concat.shape[0]
    htmd_estimate = np.round(0.6 * np.log10(n_frames / 1000) * 1000 + 50)
    return int(max(100, htmd_estimate))
    
def run_kmeans(tica_output,tica_output_supp,tica_output_concat,n_microstates,n_jobs):
    """Cluster TICA-projected frames into `n_microstates` with deeptime MiniBatchKMeans.

    Returns
    -------
    assignments : np.ndarray
        Microstate label per frame, reshaped to ``(-1, tica_output.shape[1])``.
        This assumes `tica_output` is an array of equal-length trajectories.
    assignments_concat : np.ndarray
        Microstate label of every frame of `tica_output_concat`.
    cluster_centers : np.ndarray
        The fitted cluster centers.
    assignments_supp : list of np.ndarray
        Labels of each supplementary segment in `tica_output_supp`.
    """
    estimator = MiniBatchKMeans(n_clusters=n_microstates,batch_size=10000,max_iter=100,init_strategy='kmeans++',n_jobs=n_jobs)
    clustering = estimator.fit(tica_output_concat).fetch_model()
    assignments_concat = clustering.transform(tica_output_concat)
    frames_per_traj = tica_output.shape[1]
    assignments = assignments_concat.reshape(-1, frames_per_traj)
    assignments_supp = [clustering.transform(segment) for segment in tica_output_supp]
    return assignments,assignments_concat,clustering.cluster_centers,assignments_supp
    
def build_MSM(msm_lagtime,assignments,assignments_supp):
    """Estimate a maximum-likelihood MSM from the discrete trajectories.

    Transitions are counted in sliding-window mode at `msm_lagtime` over the
    regular trajectories followed by the supplementary segments.

    Returns
    -------
    counts : the fitted transition count model
    msm : the MSM fitted on `counts`
    """
    dtrajs = [*assignments, *assignments_supp]
    count_estimator = TransitionCountEstimator(lagtime=msm_lagtime, count_mode='sliding')
    counts = count_estimator.fit_fetch(dtrajs)
    return counts, MaximumLikelihoodMSM().fit_fetch(counts)
    
def run_PCCA(msm,n_metastable_sets):
    """Lump the MSM's states into `n_metastable_sets` macrostates; returns ``msm.pcca(...)``'s result."""
    pcca_model = msm.pcca(n_metastable_sets=n_metastable_sets)
    return pcca_model

### Adaptive seeding functions
def fix_disconnected(counts,n_microstates,msm,pcca,n_metastable_sets=None):
    """Extend PCCA labels and the stationary distribution to all microstates.

    Only the first (most populated) connected set of `counts` is treated as
    covered by the MSM/PCCA. Every other connected set gets its own
    pseudo-macrostate label and zero stationary probability. This keeps
    disconnected microstates reachable during seeding.

    NOTE(review): this assumes the MSM's active states are exactly the
    microstates outside the disconnected sets, in increasing microstate
    order. TODO: confirm against ``msm.count_model.state_symbols``.

    Parameters
    ----------
    counts : transition count model
        Must provide ``connected_sets(...)``.
    n_microstates : int
        Total number of microstates.
    msm :
        Provides ``stationary_distribution`` over its active states.
    pcca :
        Provides ``assignments`` (macrostate label per active state).
    n_metastable_sets : int, optional
        First label given to pseudo-macrostates. Defaults to one past the
        largest PCCA label, so pseudo-labels never collide with PCCA labels.
        (Previously the module-level ``n_metastable_sets`` was read implicitly.)

    Returns
    -------
    n_macro_disconnected : int
        Number of pseudo-macrostates added.
    pcca_assignments : np.ndarray of int, shape (n_microstates,)
        Macrostate label of every microstate.
    stationary_distribution : np.ndarray of float, shape (n_microstates,)
        Stationary probability of every microstate (0 for disconnected ones).

    Raises
    ------
    ValueError
        If the number of connected microstates differs from the number of
        PCCA assignments.
    """
    ### Locate the connected and disconnected sets:
    sets = counts.connected_sets(connectivity_threshold=0,directed=True,sort_by_population=True)
    disconnected_sets = sets[1:]
    n_macro_disconnected = len(disconnected_sets)

    pcca_labels = np.asarray(pcca.assignments, dtype=int)
    if n_metastable_sets is None:
        n_metastable_sets = int(pcca_labels.max()) + 1 if pcca_labels.size else 0

    pcca_assignments = np.zeros(n_microstates,dtype=int)
    stationary_distribution = np.zeros(n_microstates,)
    connected = np.ones(n_microstates,dtype=bool)

    # 1. each disconnected set becomes a new pseudo-macrostate with zero probability
    for offset, states in enumerate(disconnected_sets):
        states = np.asarray(states, dtype=int)
        pcca_assignments[states] = n_metastable_sets + offset
        connected[states] = False

    # 2. the remaining microstates take PCCA labels / MSM probabilities in order
    n_connected = int(connected.sum())
    if n_connected != len(pcca_labels):
        raise ValueError(
            'fix_disconnected: {} connected microstates but {} PCCA assignments'.format(n_connected, len(pcca_labels)))
    pcca_assignments[connected] = pcca_labels
    stationary_distribution[connected] = msm.stationary_distribution

    return n_macro_disconnected,pcca_assignments,stationary_distribution
    
    
def _seed_microstates(pcca_assignments,counts_micro,unique_macro,macrostate_seed):
    """Split each macrostate's seed budget over its microstates, weighted by 1/frame count.

    Returns
    -------
    seed_idx : list of int
        Seeded microstate ids, one entry per seed.
    counts_micro_i_log : dict
        For every macrostate that received seeds, maps its label to
        ``[counts_in_macro, microstate_seed]``; used for plotting.
    """
    seed_idx = []
    counts_micro_i_log = {}
    for macro_label, n_sample in zip(unique_macro, macrostate_seed):
        # frame counts of this macrostate's microstates, 0 everywhere else
        counts_in_macro = np.where(pcca_assignments == macro_label, counts_micro, 0)
        with np.errstate(divide='ignore'):
            # let 1/0 = 0 so empty / foreign microstates are never drawn
            inverse_counts = np.where(counts_in_macro == 0, 0.0, 1.0 / counts_in_macro)
        prob_micro = inverse_counts / np.sum(inverse_counts)
        microstate_seed = multinomial(n_sample,prob_micro)
        if n_sample != 0:
            counts_micro_i_log[macro_label] = [counts_in_macro,microstate_seed]
        # microstate id repeated once per seed it received
        seed_idx.extend(np.repeat(np.arange(len(microstate_seed)), microstate_seed).tolist())
    return seed_idx,counts_micro_i_log


def _pick_seed_frames(assignments,seed_idx):
    """Return a uniformly random ``[traj_index, frame_index]`` of each seeded microstate."""
    conf_seed = []
    for seed in seed_idx:
        conf_idx = np.argwhere(assignments == seed)
        conf_seed.append(conf_idx[np.random.randint(conf_idx.shape[0])])
    return conf_seed


def _plot_seed_tree(n_macro_nodes,n_frames,unique_macro,counts_macro,macrostate_seed,counts_micro_i_log,sim_idx):
    """Draw the root -> macrostate -> microstate seeding tree to figures/Visualize_Count_Macro_{sim_idx}.png."""
    labels = {}
    color_map = []  # must follow node insertion order, which nx.draw uses
    G = nx.Graph()

    # root node 0: total number of frames
    labels[0] = n_frames
    color_map.append('red')

    # macrostate nodes (id = label + 1), labelled "frames,seeds"; orange if seeded
    for macro_idx, n_frames_macro, seed_counts_i in zip(unique_macro, counts_macro, macrostate_seed):
        labels[macro_idx+1] = str(n_frames_macro)+',{seed_counts_i}'.format(seed_counts_i=seed_counts_i)
        color_map.append('orange' if seed_counts_i != 0 else 'blue')
        G.add_edge(0,macro_idx+1)

    # microstate nodes of seeded macrostates; ids that could clash with a
    # macrostate node id are shifted by 10000. Green if the microstate was seeded.
    for macro_i, (counts_in_macro, micro_seed) in counts_micro_i_log.items():
        for idx in np.flatnonzero(counts_in_macro):
            node = idx + 10000 if idx <= n_macro_nodes else idx
            G.add_edge(macro_i+1,node)
            labels[node] = str(counts_in_macro[idx])+','+str(micro_seed[idx])
            color_map.append('green' if micro_seed[idx] != 0 else 'blue')

    pos = nx.spring_layout(G)
    # custom labels are drawn slightly to the right of each node
    shifted_pos = {node: node_pos + [0.1, 0] for node, node_pos in pos.items()}

    fig,ax = plt.subplots(figsize=(14,14))
    nx.draw(G, pos, with_labels=True,font_color='white',node_color=color_map,node_size=[1000]*len(G))
    nx.draw_networkx_labels(G, shifted_pos, labels=labels, horizontalalignment="left")
    ax.axis("off")
    plt.savefig('figures/Visualize_Count_Macro_{sim_idx}.png'.format(sim_idx=sim_idx),dpi=600)
    # release the figure; this runs once per round
    plt.close(fig)


def count_macro(n_sim,n_macro_disconnected,pcca_assignments,assignments,assignments_concat,sim_idx):
    """Pick `n_sim` seed frames for the next round by hierarchical inverse-population sampling.

    Sampling proceeds in three steps:
      1. A macrostate is drawn with probability proportional to 1/(its frame count).
      2. Within that macrostate, a microstate is drawn with probability
         proportional to 1/(its frame count).
      3. A uniformly random frame of that microstate is taken.

    The selection tree is saved to ``figures/Visualize_Count_Macro_{sim_idx}.png``.

    Parameters
    ----------
    n_sim : int
        Number of seeds to draw.
    n_macro_disconnected : int
        Number of pseudo-macrostates for disconnected microstates. Kept for
        call compatibility; the macrostate label range is now taken from
        `pcca_assignments` itself.
    pcca_assignments : array of int, shape (n_microstates,)
        Macrostate label of every microstate.
    assignments : np.ndarray of int
        Microstate label of every frame, one row per trajectory.
    assignments_concat : np.ndarray of int
        `assignments` flattened.
    sim_idx : int
        Round index, used only in the figure file name.

    Returns
    -------
    list of np.ndarray
        One ``[traj_index, frame_index]`` pair per seed.
    """
    pcca_assignments = np.asarray(pcca_assignments)
    n_microstates = len(pcca_assignments)

    # Macrostate seeding
    macro_timeseries = pcca_assignments[assignments_concat]
    unique_macro, counts_macro = np.unique(macro_timeseries, return_counts=True)
    prob_macro = (1 / counts_macro) / np.sum(1 / counts_macro)
    macrostate_seed = multinomial(n_sim,prob_macro)

    # Microstate seeding. Counts are indexed by microstate id (np.unique
    # would skip empty microstates and misalign positions with ids).
    counts_micro = np.bincount(assignments_concat, minlength=n_microstates)
    seed_idx,counts_micro_i_log = _seed_microstates(pcca_assignments,counts_micro,unique_macro,macrostate_seed)

    conf_seed = _pick_seed_frames(assignments,seed_idx)

    # Visualization: microstate node ids must stay clear of macrostate node ids (label + 1)
    n_macro_nodes = int(pcca_assignments.max()) + 1 if n_microstates else 0
    _plot_seed_tree(n_macro_nodes,assignments_concat.shape[0],unique_macro,counts_macro,macrostate_seed,counts_micro_i_log,sim_idx)

    return conf_seed
    
    
def write_gmxfile(sim_idx,n_sim,seed_ref,conf_seed,sim_name=None):
    """Write the seed structures of round `sim_idx` and prepare their run inputs.

    Steps:
      * Seed i is the selected frame of ``traj_and_dat/{traj_no}.xtc``, read
        with `seed_ref` as topology. It is written to ``{i + sim_idx*n_sim}.gro``.
      * The seed list is saved to ``round{sim_idx}_seed.txt`` (one
        ``traj_no frame`` row per seed; read back by ``data_supplement``).
      * ``grompp.sh`` is written and run. It creates each new walker
        directory, runs ``gmx grompp`` on the seed .gro, and moves
        ``{sim_name}.tpr`` into the directory.

    Parameters
    ----------
    sim_idx : int
        Round being prepared.
    n_sim : int
        Walkers per round.
    seed_ref : str
        Topology/structure file used to read the .xtc trajectories.
    conf_seed : sequence of (traj_no, frame) pairs
        Seed frames, e.g. as returned by ``count_macro``.
    sim_name : str, optional
        Name of the .tpr to build. Defaults to the module-level ``sim_name``,
        which this function previously read implicitly.

    Returns None.
    """
    if sim_name is None:
        # backward compatibility with callers that do not pass sim_name
        sim_name = globals()['sim_name']

    ### .gro seed files generation
    # Open only the trajectories that actually provide a seed, once each.
    universes = {}
    for i,seed in enumerate(conf_seed):
        traj_no, frame = int(seed[0]), int(seed[1])
        if traj_no not in universes:
            universes[traj_no] = md.Universe(seed_ref,'traj_and_dat/{i}.xtc'.format(i=traj_no))
        u_traj = universes[traj_no]
        u_traj.atoms.write('{i}.gro'.format(i=i+sim_idx*n_sim),frames=u_traj.trajectory[frame:frame+1])

    np.savetxt('round{sim_idx}_seed.txt'.format(sim_idx=sim_idx),conf_seed,fmt='%s')

    idx_start = sim_idx * n_sim
    idx_end = (sim_idx + 1) * n_sim - 1
    script_lines = [
        'start={idx_start}\n'.format(idx_start=idx_start),
        'end={idx_end}\n'.format(idx_end=idx_end),
        'for i in `seq $start $end`\n',
        'do\n',
        '    mkdir $i\n',
        '    gmx grompp -f ../mdp/md.mdp -p ../1-topol/topol.top -c $i.gro -o {sim_name}.tpr\n'.format(sim_name=sim_name),
        '    mv {sim_name}.tpr $i\n'.format(sim_name=sim_name),
        'done\n',
    ]
    with open('grompp.sh','w+') as f:
        f.write(''.join(script_lines))
    subprocess.run(['chmod u+x grompp.sh'], shell=True)
    subprocess.run(['./grompp.sh'], shell=True)

    return None

### Machine learning CV related functions
def cv_ref_projection():
    """Placeholder: project TICA eigenvectors onto the 2D reference surface.

    Not implemented yet; always returns None.
    """
    # TODO: print figures which project tica eigenvectors on 2d ref surface
    return None

def calculate_cv_sigma(tica_output_concat,num_cvs):
    """Return the standard deviation of each of the first `num_cvs` columns as a float64 array.

    Used as the initial OPES ``SIGMA`` of every CV.
    """
    column_stds = (np.std(tica_output_concat.T[cv]) for cv in range(num_cvs))
    return np.fromiter(column_stds, dtype=float, count=num_cvs)
    
def cv_convergence(data,tica_lagtime,num_cvs,tica_cv_models,sim_idx):
    """Compare every earlier TICA model with the newest one on the current data.

    Each model in `tica_cv_models` projects `data`. For each earlier model and
    each of the first `num_cvs` components, the absolute Pearson correlation
    with the same component of the last model is computed. The absolute value
    makes the comparison insensitive to eigenvector sign flips.

    `tica_lagtime` and `sim_idx` are accepted but unused.

    Returns
    -------
    np.ndarray, shape (len(tica_cv_models) - 1, num_cvs)
    """
    # (model, frame, component) projections of the whole dataset
    projections = np.stack([np.concatenate(model.transform(data)) for model in tica_cv_models])
    reference = projections[-1]

    correlations = np.zeros((projections.shape[0] - 1, num_cvs))
    for row, projection in enumerate(projections[:-1]):
        for cv in range(num_cvs):
            correlations[row, cv] = abs(pearsonr(projection[:, cv], reference[:, cv])[0])
    return correlations

def tica_plumed(barrier,feature_dat,traj,tica,num_cvs,sigma):
    """Write ``opes/plumed.dat``, biasing the first `num_cvs` TICA components with OPES_METAD.

    The file contains, in order:
      1. A new ``MOLINFO STRUCTURE=input.pdb``.
      2. Every line of `feature_dat` except its MOLINFO and PRINT lines.
      3. One ``COMBINE`` per CV, computing
         ``tica_i = sum_j c_ij * (x_j - mean_j)``, with coefficients from
         ``tica.singular_vectors_left`` and means from ``tica.mean_0``.
      4. phi/psi torsions for monitoring.
      5. The multiple-walker OPES_METAD block.
      6. A final PRINT.

    ``traj_and_dat/input.pdb`` is copied into ``opes/``.

    Parameters
    ----------
    barrier : float
        OPES ``BARRIER``.
    feature_dat : str
        PLUMED input defining the features.
    traj : list of pandas.DataFrame
        Feature tables. Only the columns of the first table are used, giving
        the feature order that matches the TICA input.
    tica :
        Fitted TICA model (``mean_0``, ``singular_vectors_left``).
    num_cvs : int
        Number of TICA components to bias.
    sigma : sequence of float
        Initial OPES ``SIGMA`` per CV.

    Returns None.
    """
    import shutil

    # Feature order matching the TICA input columns
    arg_string = ','.join(traj[0].columns)
    # Full precision and no padding spaces: np.array2string pads elements for
    # alignment, which PLUMED would split into separate keywords.
    parameters_string = ','.join(str(v) for v in tica.mean_0)
    opes_arg = ','.join('tica{i}'.format(i=i) for i in range(num_cvs))
    sigma_string = ','.join(str(sigma[i]) for i in range(num_cvs))

    # Keep the feature definitions; drop their MOLINFO (rewritten below with
    # the copied input.pdb) and PRINT (replaced by the final PRINT).
    feature_lines = []
    with open(feature_dat,'r') as g:
        for line in g:
            first_word = (line.split() or [''])[0]
            if first_word in ('MOLINFO', 'PRINT'):
                continue
            feature_lines.append(line if line.endswith('\n') else line + '\n')

    with open('opes/plumed.dat','w+') as f:
        f.write('MOLINFO STRUCTURE=input.pdb\n')
        f.writelines(feature_lines)

        for i in range(num_cvs):
            coeff_string = ','.join(str(value) for value in tica.singular_vectors_left.T[i])
            f.write('tica{i}: COMBINE ARG={arg_string} COEFFICIENTS={coeff_string} PARAMETERS={parameters_string} PERIODIC=NO\n'.format(i=i,arg_string=arg_string,coeff_string=coeff_string,parameters_string=parameters_string))

        #### NOTE: system-specific (alanine dipeptide) torsions, printed for monitoring
        f.write('phi: TORSION ATOMS=@phi-2\n')
        f.write('psi: TORSION ATOMS=@psi-2\n')
        ####

        f.write('opes: OPES_METAD ...\n')
        f.write('    ARG={opes_arg}\n'.format(opes_arg=opes_arg))
        f.write('    PACE=500 BARRIER={barrier}\n'.format(barrier=barrier))
        f.write('    SIGMA={sigma_string}\n'.format(sigma_string=sigma_string))
        # a single STATE file one level up, next to the walker directories
        f.write('    STATE_RFILE=../STATE\n')
        f.write('    STATE_WFILE=../STATE\n')
        f.write('    NLIST\n')
        f.write('    WALKERS_MPI\n')
        f.write('...\n')
        f.write('PRINT ARG=phi,psi,{opes_arg},opes.* STRIDE=500 FILE=COLVAR\n'.format(opes_arg=opes_arg))

    shutil.copy('traj_and_dat/input.pdb', 'opes/')

    return None

# OPES related
def opes_seed(tica_output,tica_output_concat,tica_output_supp,n_sim,n_jobs,sim_idx,seed_ref):
    """Pick one starting structure per OPES walker by k-means clustering in CV space.

    Steps:
      * The TICA-projected frames are clustered into `n_sim` clusters.
      * The clustering is plotted (first two components) to ``figures/OPES_seed.png``.
      * For each cluster center, the frame of `tica_output` closest to it is
        written to ``opes/{i}.gro``.
      * The chosen ``traj_no frame`` pairs are saved to ``opes/opes_seed.txt``.

    Parameters
    ----------
    tica_output : np.ndarray
        Projections indexed ``[trajectory, frame, component]``.
    tica_output_concat : np.ndarray
        `tica_output` concatenated along frames; at least 2 components are plotted.
    tica_output_supp : list of np.ndarray
        Supplementary projections (only forwarded to ``run_kmeans``).
    n_sim : int
        Number of OPES walkers (= clusters).
    n_jobs : int
        Forwarded to ``run_kmeans``.
    sim_idx : int
        Number of completed rounds; trajectories ``0 .. sim_idx*n_sim-1`` exist.
    seed_ref : str
        Topology used to read the .xtc trajectories.

    Returns None.
    """
    # 2D clustering
    assignments,assignments_concat,cluster_centers,assignments_supp = run_kmeans(tica_output,tica_output_supp,tica_output_concat,n_sim,n_jobs)

    # Visualization: one discrete color per cluster (palette repeats beyond 16 walkers)
    palette = ['black','indigo','darkslateblue','steelblue','teal','darkcyan','lightseagreen','mediumseagreen','slategrey','yellowgreen','greenyellow','gold','yellow','darkviolet','violet','pink']
    fig,ax = plt.subplots()
    cmap = matplotlib.colors.ListedColormap([palette[k % len(palette)] for k in range(n_sim)])
    norm = matplotlib.colors.BoundaryNorm(np.arange(-0.5,n_sim), cmap.N)
    sc = ax.scatter(tica_output_concat[:,0],tica_output_concat[:,1],c=assignments_concat,cmap=cmap,norm=norm,s=8)
    ax.plot(cluster_centers[:,0],cluster_centers[:,1],'o',ms=5,color='red')
    fig.colorbar(sc,label='cluster',ticks=np.arange(n_sim))
    ax.set_xlabel(r'KTICA tIC1 $\hat \phi_1$')
    ax.set_ylabel(r'KTICA tIC2 $\hat \phi_2$')
    fig.savefig('figures/OPES_seed.png',dpi=600,bbox_inches='tight')
    plt.close(fig)

    # Seeding: the (trajectory, frame) closest to each cluster center
    opes_seed_idx = []
    for i in range(n_sim):
        dist_to_center_i = np.linalg.norm(tica_output - cluster_centers[i],axis=2)
        nearest = np.unravel_index(np.argmin(dist_to_center_i), dist_to_center_i.shape)
        opes_seed_idx.append(np.array(nearest))

    # Open each needed trajectory once
    universes = {}
    for i,seed in enumerate(opes_seed_idx):
        traj_no, frame = int(seed[0]), int(seed[1])
        if traj_no not in universes:
            universes[traj_no] = md.Universe(seed_ref,'traj_and_dat/{i}.xtc'.format(i=traj_no))
        u_traj = universes[traj_no]
        u_traj.atoms.write('opes/{i}.gro'.format(i=i),frames=u_traj.trajectory[frame:frame+1])

    np.savetxt('opes/opes_seed.txt',opes_seed_idx,fmt='%s')

    return None

def opes_mpirun(opes_steps,sim_idx,n_steps,n_sim,ntomp,sim_name,gpu_id,num_cvs,barrier):
    """Prepare and launch the multiple-walker OPES simulation.

    Steps:
      * ``opes/opes_grompp.sh`` builds one directory ``opes/i`` per walker,
        with its .tpr (from ``opes/i.gro``), ``input.pdb`` and ``plumed.dat``.
      * ``opes/mpirun_opes.sh`` then runs all walkers with ``-plumed plumed.dat``
        and ``-cpi`` from inside ``opes/``.
      * When `opes_steps` is None, ``sim_idx * n_steps`` steps are used.

    `num_cvs` and `barrier` are accepted but unused.

    Returns None.
    """
    grompp_lines = [
        'for i in `seq 0 {n_sim}`\n'.format(n_sim=n_sim-1),
        'do\n',
        '    mkdir opes/$i\n',
        '    gmx grompp -f ../mdp/md.mdp -p ../1-topol/topol.top -c opes/$i.gro -o opes/{sim_name}.tpr\n'.format(sim_name=sim_name),
        '    mv opes/{sim_name}.tpr opes/$i.gro opes/$i\n'.format(sim_name=sim_name),
        '    cp traj_and_dat/input.pdb opes/$i/\n',
        '    cp opes/plumed.dat opes/$i/\n',
        'done\n',
    ]
    with open('opes/opes_grompp.sh','w+') as script:
        script.write(''.join(grompp_lines))
    subprocess.run(['chmod u+x opes/opes_grompp.sh'], shell=True)
    subprocess.run(['./opes/opes_grompp.sh'], shell=True)

    n_opes_steps = sim_idx * n_steps if opes_steps is None else opes_steps
    multidir = ' '.join(str(walker) for walker in range(n_sim))

    run_lines = [
        'cd opes/\n',
        'mpirun -np {n_sim} -cpus-per-rank {ntomp} gmx_mpi mdrun -v -deffnm {sim_name} -multidir {multidir} -pme gpu -nb gpu -bonded gpu -nsteps {opes_steps} -ntomp {ntomp} -gpu_id {gpu_id} -plumed plumed.dat -cpi\n'.format(
            n_sim=n_sim,ntomp=ntomp,sim_name=sim_name,multidir=multidir,opes_steps=n_opes_steps,gpu_id=gpu_id),
    ]
    with open('opes/mpirun_opes.sh','w+') as script:
        script.write(''.join(run_lines))

    subprocess.run(['chmod u+x ./opes/mpirun_opes.sh'],shell=True)
    subprocess.run(['./opes/mpirun_opes.sh'],shell=True)

    return None

### Main program
# Flow:
#   * Round 0 runs n_sim unbiased walkers from the equilibrated structure.
#   * Every later round reseeds n_sim walkers from under-sampled MSM states.
#   * If convergence_check is set, a TICA CV model is trained after each round
#     on all data gathered so far and compared with the earlier models.
#   * Once they agree (or n_rounds is reached), the CVs are written to
#     opes/plumed.dat and a multiple-walker OPES run is started.

# Initialization
tica_cv_models = []
data_supp = []

# Retrieve all the past tica models if convergence check is required (resuming at sim_idx != 0)
if sim_idx != 0 and convergence_check:

    traj,data = read_features(colvar,sim_idx,n_sim)   # To do: This is repetitive

    for sim_i in range(1,sim_idx+1):
        data_sim_i = data[:(sim_i*n_sim)]
        if sim_i == 1:
            data_supp_i = []
        else:
            data_supp_i = data_supplement(sim_i,data_sim_i,tica_lagtime)
        # pass only the supplementary segments: run_TICA itself prepends data_sim_i
        tica_cv_model_i = run_TICA(data_sim_i,data_supp_i,tica_lagtime,num_cvs,None,koopman)[0]
        tica_cv_models.append(tica_cv_model_i)

# First round parallel simulation
if sim_idx == 0:

    # Initialize files and directories
    initialize(sim_name,n_sim)
    gmx_mpirun(sim_name,sim_idx,n_sim,n_steps,ntomp,gpu_id)

    ### Write topol and seed_ref file
    subprocess.run(['echo 1 1 | gmx trjconv -f 0/{sim_name}.xtc -s 0/{sim_name}.tpr -o {topol} -pbc mol -ur compact -center -dump 0'.format(sim_name=sim_name,topol=topol)],shell=True)
    subprocess.run(['echo 0 0 | gmx trjconv -f 0/{sim_name}.xtc -s 0/{sim_name}.tpr -o {seed_ref} -pbc mol -ur compact -center -dump 0'.format(sim_name=sim_name,seed_ref=seed_ref)],shell=True)

    #########################################USER DEFINE REGION######################################################

    u = md.Universe('traj_and_dat/input.pdb')
    heavy_atom = u.select_atoms('name CA or name CB or name O or name N or name CH3 or name C')
    ids = list(heavy_atom.ids)
    atom_pairs = list(itertools.combinations(ids, 2))

    # Features
    with open('traj_and_dat/features.dat','w+') as f:
        f.writelines('MOLINFO STRUCTURE=traj_and_dat/input.pdb\n')
        count = 0
        for pair in atom_pairs:
            atom1 = pair[0]
            atom2 = pair[1]
            count = count + 1
            f.writelines('pair{count}: DISTANCE ATOMS={atom1},{atom2}\n'.format(count=count,atom1=atom1,atom2=atom2))
        f.writelines('PRINT ARG=* STRIDE=1 FILE=COLVAR')

    # Physical intuitive CV as reference
    with open('traj_and_dat/ref.dat','w+') as f:
        f.writelines('MOLINFO STRUCTURE=traj_and_dat/input.pdb\n')
        f.writelines('phi: TORSION ATOMS=@phi-2\n')
        f.writelines('psi: TORSION ATOMS=@psi-2\n')
        f.writelines('PRINT ARG=* STRIDE=1 FILE=COLVAR_ref\n')

    #################################################################################################################

    # Write features
    write_features(colvar,sim_idx,n_sim,sim_name)

    sim_idx = sim_idx + 1

    ### CV convergence check
    if convergence_check:
        traj,data = read_features(colvar,sim_idx,n_sim)        # TODO: This is a repetitive action
        tica_0 = run_TICA(data,data_supp,tica_lagtime,num_cvs,None,koopman)[0]
        tica_cv_models.append(tica_0)

# n_done: number of completed rounds (trajectories 0 .. n_done*n_sim-1 exist)
n_done = sim_idx
cv_written = False

# Adaptive Sampling starts from round 1
for sim_idx in range(sim_idx,n_rounds):
    ### Adaptive sampling
    # Read features
    traj,data = read_features(colvar,sim_idx,n_sim)
    if sim_idx > 1:
        data_supp = data_supplement(sim_idx,data,tica_lagtime)
    # Perform TICA
    tica,tica_output,tica_output_concat,tica_output_supp = run_TICA(data,data_supp,tica_lagtime,dim,var_cutoff,koopman)
    # K-means clustering
    n_microstates = calculate_nmicro(tica_output_concat)
    assignments,assignments_concat,cluster_centers,assignments_supp = run_kmeans(tica_output,tica_output_supp,tica_output_concat,n_microstates,n_jobs)
    counts,msm = build_MSM(msm_lagtime,assignments,assignments_supp)
    # Lumping microstates into macrostates by PCCA. NB: THIS STEP IS VERY SLOW. BE PATIENT.
    n_metastable_sets_tmp = min(n_metastable_sets, msm.transition_matrix.shape[0])
    pcca = run_PCCA(msm,n_metastable_sets_tmp)
    # Seed generation
    n_macro_disconnected,pcca_assignments,stationary_distribution = fix_disconnected(counts,n_microstates,msm,pcca)
    conf_seed = count_macro(n_sim,n_macro_disconnected,pcca_assignments,assignments,assignments_concat,sim_idx)
    # Write files
    write_gmxfile(sim_idx,n_sim,seed_ref,conf_seed)
    # Run adaptive MD
    gmx_mpirun(sim_name,sim_idx,n_sim,n_steps,ntomp,gpu_id)
    write_features(colvar,sim_idx,n_sim,sim_name)
    n_done = sim_idx + 1

    ### CV convergence check & machine learning CV
    if convergence_check:
        sim_idx = n_done
        traj,data = read_features(colvar,sim_idx,n_sim)        # TODO: This is a repetitive action
        data_supp = data_supplement(sim_idx,data,tica_lagtime)
        # Machine learning CV
        tica_i,tica_output_i,tica_output_concat_i,tica_output_supp_i = run_TICA(data,data_supp,tica_lagtime,num_cvs,None,koopman)
        tica_cv_models.append(tica_i)

        correlations = cv_convergence(data,tica_lagtime,num_cvs,tica_cv_models,sim_idx)
        np.savetxt('opes/correlations_round{sim_idx}.txt'.format(sim_idx=sim_idx-1),correlations)

        # Converged when, in each of the last `patience` comparisons, every
        # listed component exceeds its threshold (element-wise, then all()).
        recent = correlations[-patience:]
        converged = sim_idx > patience and all(
            (recent[:, j] > criterion).all() for j, criterion in enumerate(convergence_criteria))

        if converged or sim_idx == n_rounds:
            # Compute initial sigma for OPES simulation
            sigma = calculate_cv_sigma(tica_output_concat_i,num_cvs)
            # Write plumed files
            tica_plumed(barrier,'traj_and_dat/features.dat',traj,tica_i,num_cvs,sigma)
            cv_written = True
            break

# No CV model has been written yet when the convergence check is disabled
# or no adaptive round was left to run: train one on all data gathered so far.
if not cv_written:
    sim_idx = n_done
    traj,data = read_features(colvar,sim_idx,n_sim)
    data_supp = data_supplement(sim_idx,data,tica_lagtime) if sim_idx > 1 else []
    tica_i,tica_output_i,tica_output_concat_i,tica_output_supp_i = run_TICA(data,data_supp,tica_lagtime,num_cvs,None,koopman)
    sigma = calculate_cv_sigma(tica_output_concat_i,num_cvs)
    tica_plumed(barrier,'traj_and_dat/features.dat',traj,tica_i,num_cvs,sigma)

# Perform multiple-walkers OPES simulation
opes_seed(tica_output_i,tica_output_concat_i,tica_output_supp_i,n_sim,n_jobs,sim_idx,seed_ref)
opes_mpirun(opes_steps,sim_idx,n_steps,n_sim,ntomp,sim_name,gpu_id,num_cvs,barrier)