# In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os,sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import mdtraj
import time

# In [None]:
# Report the module-level inverse temperature BETA (defined outside this
# chunk); generateFES captures it as its default ``at_beta``.
print(" ".join(("Define Beta to use for averaging in FES computation. Beta = ", str(round(BETA, 3)), "identified")))
def generateFES(input_states,bins=(300,300),title="Free Energy Surface",at_beta=BETA,binskip=25):
    """Plot a 2D free-energy surface over the (Phi, Psi) plane.

    The surface is ``-(1/at_beta) * log(count)`` of a 2D histogram of
    ``input_states[:, 0]`` (Phi) and ``input_states[:, 1]`` (Psi) over
    ``[-pi, pi] x [-pi, pi]``. The counts are not normalised, so the absolute
    offset of the surface depends on the number of samples.

    Args:
        input_states: array whose first two columns are the Phi and Psi angles
            in radians.
        bins: anything ``np.histogram2d`` accepts (a pair of counts, a single
            count, or explicit edges).
        title: figure title.
        at_beta: inverse temperature used to scale ``-log(count)``. The colour
            bar is labelled kJ/mol, which presumes ``at_beta`` is in mol/kJ.
        binskip: label every ``binskip``-th bin edge on each axis.

    Empty bins produce ``+inf``; imshow leaves such non-finite pixels blank.
    """
    print("FES generated at beta=",at_beta)
    hg,x_entries,y_entries=np.histogram2d(input_states[:,0],input_states[:,1],bins=bins,range=((-np.pi,np.pi),(-np.pi,np.pi)))
    # Empty bins give log(0) = -inf, i.e. FES = +inf. That is expected, so
    # silence numpy's divide-by-zero warning instead of spamming the notebook.
    with np.errstate(divide="ignore"):
        # Transpose so rows index Psi, then flip so the highest Psi bin is row 0
        # (drawn at the top with imshow's default origin="upper").
        fes=(-1/at_beta)*np.log(hg.transpose()[::-1,:])
    plt.figure(figsize=(8,6))
    plt.title(title,fontsize=24)
    # Map pixels to radians so the ticks can be placed at the true bin edges.
    # With (left, right, bottom, top) extent the y axis increases upward,
    # matching the flipped rows above. Previously the y labels were reversed
    # and offset by half a pixel.
    plt.imshow(fes,extent=(x_entries[0],x_entries[-1],y_entries[0],y_entries[-1]))
    plt.xticks(x_entries[::binskip],np.round(x_entries[::binskip],2),rotation=90,fontsize=16)
    plt.yticks(y_entries[::binskip],np.round(y_entries[::binskip],2),fontsize=16)
    cbar=plt.colorbar()
    cbar.set_label("Free Energy (kJ/mol)",fontsize=18)
    plt.xlabel("$\\Phi$ (rad)",fontsize=21)
    plt.ylabel("$\\Psi$ (rad)",fontsize=21)

def getFES1D(values,bins=None):
    """Estimate a 1D free-energy profile, in units of kT, from samples.

    Args:
        values: array of samples.
        bins: histogram bin edges (or anything ``np.histogram`` accepts). The
            default is 64 evenly spaced edges spanning the data range padded
            by 0.1 on each side.

    Returns:
        ``(centers, fes)``: ``centers`` are the bin midpoints. ``fes`` is
        ``-log(p)`` with ``p`` the pseudocount-smoothed bin probability,
        shifted so its minimum is 0.
    """
    edges = bins if bins is not None else np.linspace(np.min(values) - 0.1, np.max(values) + 0.1, 64)
    counts, edges = np.histogram(values, bins=edges)
    # A pseudocount of 1 per bin keeps empty bins finite (no log(0)).
    padded = counts.astype(float) + 1
    probabilities = padded / np.sum(padded)
    centers = (edges[1:] + edges[:-1]) / 2
    free_energy = -np.log(probabilities)
    return centers, free_energy - np.min(free_energy)

def getFES2D(data,bins=None):
    """Estimate a 2D free-energy surface, in units of kT, from two-column samples.

    Args:
        data: array indexed as ``data[:, 0]`` and ``data[:, 1]`` (the two
            coordinates).
        bins: forwarded to ``np.histogram2d``; defaults to ``(64, 64)``.

    Returns:
        ``(FES_grid, xlab, ylab)``:
        ``FES_grid`` is ``-log`` of the normalised 2D histogram, shifted so its
        minimum is 0. ``FES_grid[i, j]`` corresponds to x-bin ``i`` and y-bin
        ``j``. Empty bins are ``+inf``; there is no pseudocount here, unlike in
        ``getFES1D``.
        ``xlab`` and ``ylab`` are the x and y bin edges.
    """
    if bins is None: bins=(64,64)
    hg_grid,xlab,ylab=np.histogram2d(data[:,0],data[:,1],bins=bins)
    hg_grid/=np.sum(hg_grid)
    # Empty bins have probability 0, so -log(0) = +inf. That is the intended
    # value, so silence numpy's divide-by-zero RuntimeWarning for it.
    with np.errstate(divide="ignore"):
        FES_grid=-np.log(hg_grid)
    FES_grid-=np.nanmin(FES_grid)
    return FES_grid,xlab,ylab

# Getting indices for histogramming
def arghistogram(data,bins=None):
    """Return the histogram bin index of every element of ``data``.

    Args:
        data: array of values to bin.
        bins: monotonic bin edges. When omitted, ``PES_GRID_RES`` evenly
            spaced edges spanning ``[min(data), max(data)]`` are used.
            ``PES_GRID_RES`` is a module-level constant defined outside this
            chunk.

    Returns:
        ``(indices, edges)``: ``indices = np.digitize(data, edges) - 1``, and
        ``edges`` are the bin edges actually used.

    NOTE(review): with ``np.digitize``'s default ``right=False``, a value equal
    to the last edge (e.g. ``max(data)`` with the default edges) gets index
    ``len(edges) - 1``, and a value below the first edge gets ``-1``. Confirm
    that callers size and index their grids accordingly.
    """
    edges = bins if bins is not None else np.linspace(np.min(data), np.max(data), PES_GRID_RES)
    return np.digitize(data, bins=edges) - 1, edges

def arghistogram2D(data,bins=None):
    """Return per-column bin indices for the first two columns of ``data``.

    Args:
        data: array whose columns 0 and 1 are binned independently.
        bins: pair of edge arrays ``(edges_col0, edges_col1)``. When omitted,
            each column gets ``PES_GRID_RES`` evenly spaced edges over its own
            ``[min, max]``. ``PES_GRID_RES`` is defined outside this chunk.

    Returns:
        ``(indices, bins)``: ``indices`` is an int array of shape ``(len(data), 2)``
        holding ``np.digitize(column, edges) - 1`` for each column, and ``bins``
        are the edges used.
    """
    if bins is None:
        bins = tuple(np.linspace(np.min(data[:, axis]), np.max(data[:, axis]), PES_GRID_RES) for axis in (0, 1))
    per_axis = [np.digitize(data[:, axis], bins=bins[axis]) - 1 for axis in (0, 1)]
    return np.stack(per_axis, axis=-1).astype(int), bins

# Backmapping
def backmap(db,lst,batch_size=32,log_time=2500):
    """Map every entry of ``lst`` to the index of its closest entry in ``db``.

    ``lst`` is processed in slices of ``batch_size``. Distances come from
    ``env.cross_distance2`` (defined outside this chunk), and the arg-min over
    ``dim=1`` selects the nearest ``db`` row for each item in the slice.

    Args:
        db: reference set (needs ``.shape``).
        lst: query set (needs ``.shape``, ``len`` and slicing).
        batch_size: number of queries per distance call.
        log_time: print progress roughly once every ``log_time`` queries.

    Returns:
        1D numpy integer array of length ``len(lst)``; element ``k`` is the
        index into ``db`` of the closest entry to ``lst[k]``. An empty ``lst``
        yields an empty int64 array.
    """
    print("DB Size:",db.shape)
    print("Input Size:",lst.shape)
    print("Processing...",flush=True)
    ret=[]
    for i in range(0,len(lst),batch_size):
        lext=lst[i:i+batch_size]
        # Presumably squared Euclidean distances of shape (len(lext), len(db)),
        # equivalent to
        # torch.sum((db[None, :, :] - lext[:, None, :]) ** 2, dim=-1).
        # TODO confirm against env.cross_distance2.
        dists=env.cross_distance2(lext,db)
        closest=torch.argmin(dists,dim=1).cpu().numpy()
        ret.append(closest)
        # i advances in steps of batch_size, so `i % log_time == 0` only fired at
        # multiples of lcm(batch_size, log_time). Log the first batch that starts
        # inside each log_time window instead.
        if i%log_time<batch_size: print("\t",i,"of",len(lst))
    print("Done!",flush=True)
    if not ret:
        # np.concatenate([]) raises ValueError; return an empty index array.
        return np.empty(0,dtype=np.int64)
    return np.concatenate(ret)