# In [4]:
import biotite
import biotite.structure as struc
import biotite.structure.io as strucio
import biotite.structure.io.dcd as dcd
import numpy as np
import copy
import matplotlib.pyplot as plt
import pandas as pd
import os
import re

def get_frame(reference_pdb : str, dcd_file : str, solvent="TIP"):
    """Compute the per-frame RMSD of a DCD trajectory relative to its first frame.

    Atoms whose element is "H" and atoms whose residue name equals
    ``solvent`` are removed, every frame is superimposed onto frame 0, and
    the RMSD of each superimposed frame against frame 0 is returned.

    Parameters
    ----------
    reference_pdb : str
        Reference PDB file that still contains hydrogens and solvent; its
        atoms are assumed to be in the same order as the trajectory atoms.
    dcd_file : str
        DCD trajectory file.
    solvent : str
        Residue name of the solvent to exclude (default "TIP").

    Returns
    -------
    numpy.ndarray
        One RMSD value per frame, in the trajectory's coordinate units.
    """
    template = strucio.load_structure(reference_pdb)
    # Boolean atom mask: keep heavy atoms that are not solvent.
    atom_mask = (template.element != "H") & (template.res_name != solvent)
    template_filter = template[atom_mask]

    trajectory = dcd.DCDFile.read(dcd_file)
    coord = trajectory.get_coord()[:, atom_mask, :]

    # Build one AtomArrayStack that shares the template annotations instead of
    # deep-copying the template once per frame and stacking the copies.
    stack_traj = struc.from_template(template_filter, coord)

    # Superimpose every frame onto frame 0, then measure the deviation.
    superimposed, _ = struc.superimpose(stack_traj[0], stack_traj)
    return struc.rmsd(stack_traj[0], superimposed)

def plot_rmsd(rmsd,total_step,dcdfreq,step_size=2,unit="fs"):
    """Plot the RMSD of the trailing part of a trajectory against time.

    Parameters
    ----------
    rmsd : array-like
        Per-frame RMSD values, e.g. the array returned by ``get_frame``.
    total_step : int
        Number of MD steps to show. The last ``total_step // dcdfreq``
        frames of ``rmsd`` are plotted, or all of them if ``rmsd`` is shorter.
    dcdfreq : int
        Number of MD steps between consecutive DCD frames.
    step_size : float
        Integration time step in fs.
    unit : str
        Time unit of the x axis: "fs", "ps", "ns" or "us".

    Raises
    ------
    ValueError
        If ``unit`` is not one of the supported units.
    """
    fs_per_unit = {"fs": 1, "ps": 1000, "ns": 1000 * 1000, "us": 1000 * 1000 * 1000}
    if unit not in fs_per_unit:
        raise ValueError(f"do u think {unit} is a reasonable unit? ")
    # Time between two consecutive frames, in the requested unit.
    frame_time = step_size * dcdfreq / fs_per_unit[unit]

    rmsd = np.asarray(rmsd)
    # Clamp the frame count so x and y always have the same length. This also
    # avoids rmsd[-0:] selecting the whole array when total_step < dcdfreq.
    n_frames = max(0, min(int(total_step // dcdfreq), len(rmsd)))
    times = np.arange(n_frames) * frame_time

    plt.plot(times, rmsd[len(rmsd) - n_frames:])
    plt.xlabel(f"Time ({unit})")
    # Raw string: "\A" is not a valid escape sequence in a normal literal.
    plt.ylabel(r"RMSD ($\AA$)")
    plt.show()

def plot_log(time_data , kind = "energy (kcal/mol)", output_freq = 100, step_size = 2 , unit = "fs"):
    """Plot one series of logged values against simulation time.

    Parameters
    ----------
    time_data : array-like or str
        The values to plot. If a string is given, it is treated as the path
        of a tab-separated file without a header (e.g. a .dat file exported
        from VMD), and its second column (index 1) is plotted.
    kind : str
        Y-axis label. Typical series: BOND, ANGLE, DIHED, IMPRP, ELECT, VDW,
        BOUNDARY, MISC, KINETIC, TOTAL, TEMP, TOTAL2, TOTAL3, TEMPAVG.
    output_freq : int
        Number of MD steps between consecutive samples.
    step_size : float
        Integration time step in fs.
    unit : str
        Time unit of the x axis: "fs", "ps", "ns" or "us".

    Raises
    ------
    ValueError
        If ``unit`` is not one of the supported units.
    """
    if isinstance(time_data, str):
        table = pd.read_csv(time_data, sep='\t', header=None)
        time_data = np.array(table[1])
    values = np.array(time_data)

    # Each step up the unit ladder divides the femtosecond interval by 1000.
    for thousands, unit_name in enumerate(("fs", "ps", "ns", "us")):
        if unit == unit_name:
            break
    else:
        raise ValueError(f"do u think {unit} is a reasonable unit?")

    sample_interval = output_freq * step_size
    for _ in range(thousands):
        sample_interval = sample_interval / 1000

    sample_times = np.arange(0, len(values)) * sample_interval
    plt.plot(sample_times, values)
    plt.xlabel(f"Time ({unit})")
    plt.ylabel(f"{kind}")
    plt.show()

def tabular_namd_log(log_file,pressure=True):
    """Parse a NAMD log file into an energy table plus run metadata.

    Parameters
    ----------
    log_file : str
        Path to the NAMD log (stdout) file.
    pressure : bool
        If True, also read the target pressure from the
        "TARGET PRESSURE IS" line; that line must then be present.

    Returns
    -------
    data_deduplicated : pandas.DataFrame
        One row per line whose first token is "ENERGY:". Columns run from TS
        to GPRESSAVG. When a TS value repeats, only its first row is kept.
    meta_data : dict
        The dictionary contains these keys:

        - "minimize": only if a "TCL: Minimizing for N steps" line exists.
        - "equi_<i>": one entry per "TCL: Running for N steps" line, in order.
        - "production": equal to the last run.
        - "time_step", "energy_output", "dcd_freq" and "temperature".
        - "pressure": only when ``pressure`` is True.
    data_production : pandas.DataFrame
        The last ``production // energy_output + 1`` rows of
        ``data_deduplicated``.

    Raises
    ------
    IndexError
        If a required "TCL: Running" or "Info:" line is missing.
    """
    with open(log_file, "r") as f:
        log_text = f.read()

    minimize_steps = re.findall(r'TCL: Minimizing for (\d+) steps', log_text)
    run_steps = re.findall(r'TCL: Running for (\d+) steps', log_text)

    meta_data = {}
    if minimize_steps:
        meta_data["minimize"] = int(minimize_steps[0])
    # Update rather than rebind, so that the "minimize" entry is kept.
    meta_data.update({f"equi_{i}": int(steps) for i, steps in enumerate(run_steps)})
    # The last "Running for" stage is treated as the production run.
    meta_data['production'] = int(run_steps[-1])
    meta_data["time_step"] = float(re.findall(r'Info: TIMESTEP\s+(\d+(?:\.\d+)?)', log_text)[0])
    meta_data["energy_output"] = int(re.findall(r'Info: ENERGY OUTPUT STEPS\s+(\d+)', log_text)[0])
    meta_data["dcd_freq"] = int(re.findall(r'Info: DCD FREQUENCY\s+(\d+)', log_text)[0])
    # Accept fractional temperatures such as 310.15 instead of truncating them.
    meta_data["temperature"] = float(
        re.findall(r'Info: LANGEVIN TEMPERATURE\s+(\d+(?:\.\d+)?)', log_text)[0]
    )
    if pressure:
        meta_data["pressure"] = float(
            re.findall(r'Info:\s+TARGET PRESSURE IS\s+(\d+(?:\.\d+)?)', log_text)[0]
        )

    # Parse the ENERGY lines in-process: no shell command or temporary file.
    # The timestep stays an int; all remaining columns become floats.
    rows = []
    for line in log_text.splitlines():
        fields = line.split()
        if fields and fields[0] == "ENERGY:":
            rows.append([int(fields[1])] + [float(value) for value in fields[2:]])

    title =  ['TS','BOND','ANGLE','DIHED','IMPRP','ELECT','VDW','BOUNDARY','MISC','KINETIC','TOTAL','TEMP','POTENTIAL','TOTAL3','TEMPAVG','PRESSURE','GPRESSURE','VOLUME','PRESSAVG','GPRESSAVG']
    data = pd.DataFrame(rows, columns=title)
    # Remove rows with a repeated timestep, keeping the first occurrence.
    data_deduplicated = data.drop_duplicates(subset=['TS'], keep='first')

    # The +1 presumably keeps the sample at the first production step as well.
    n_production_rows = meta_data['production'] // meta_data['energy_output'] + 1
    data_production = data_deduplicated.iloc[-n_production_rows:, :]
    return data_deduplicated, meta_data, data_production