In [1]:
import sys, os
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from typing import List, Tuple, Callable, Any, Dict

from methods.PolyCG.polycg.transforms.transform_marginals import send_to_back_permutation
from methods.midstep_composites import midstep_composition_transformation, midstep_groundstate, midstep_groundstate_se3
from methods.midstep_composites import midstep_composition_transformation_correction
from methods.read_nuc_data import read_nucleosome_triads, GenStiffness
from methods.free_energy import calculate_midstep_triads, midstep_excess_vals
from methods.PolyCG.polycg.SO3 import so3

from methods.nearest_correlation import nearcorr

# from methods.PolyCG.polycg.SO3 import so3
# from methods.PolyCG.polycg.transforms.transform_SO3 import euler2rotmat_so3
# from methods.PolyCG.polycg.cgnaplus import cgnaplus_bps_params
# from methods.PolyCG.polycg.transforms.transform_algebra2group import algebra2group_lintrans, group2algebra_lintrans

# Compact, fixed-point console output for the large stiffness matrices printed below.
np.set_printoptions(suppress=True, precision=5, linewidth=250)

-----
-----
# Methods
-----
-----

### Matrix Methods

In [2]:
def plot_matrix(mat,dims=6,cmap='gray_r'):
    """Display a (square) matrix as an image with its ``dims x dims`` block grid.

    Black lines mark every block boundary across the whole matrix; red
    rectangles outline the diagonal blocks. The -0.6 / -0.4 offsets (instead of
    the pixel edge at -0.5) are kept from the original layout. The y axis runs
    from -0.5 to n-0.5, so row 0 is drawn at the bottom. The figure is shown
    with ``plt.show()``; nothing is returned.

    Parameters
    ----------
    mat : array_like
        Matrix to display. Only ``len(mat)`` is used for both axes, so it is
        assumed square. It is copied, never modified.
    dims : int
        Block size; ``len(mat)//dims`` blocks are outlined.
    cmap : str
        Matplotlib colormap name passed to ``imshow``.
    """
    data = np.copy(mat)
    n = len(data)
    edge_lo, edge_hi = 0 - 0.5, n - 0.5

    fig, ax = plt.subplots(figsize=(17./2.54*1.3, 17.2/2.54))
    img = ax.imshow(data, interpolation='none', cmap=cmap)

    for block in range(1, n//dims + 1):
        hi = block*dims
        lo = (block-1)*dims
        # block separators spanning the full matrix
        ax.plot([hi-0.6, hi-0.6], [edge_lo, edge_hi], color='black', lw=0.5)
        ax.plot([edge_lo, edge_hi], [hi-0.4, hi-0.4], color='black', lw=0.5)
        # red outline of the diagonal block covering rows/cols [lo, hi)
        ax.plot([hi-0.6, hi-0.6], [lo-0.5, hi-0.5], color='red', lw=0.5)
        ax.plot([lo-0.6, lo-0.6], [lo-0.5, hi-0.5], color='red', lw=0.5)
        ax.plot([lo-0.5, hi-0.5], [hi-0.4, hi-0.4], color='red', lw=0.5)
        ax.plot([lo-0.5, hi-0.5], [lo-0.4, lo-0.4], color='red', lw=0.5)

    ax.set_xlim((edge_lo, edge_hi))
    ax.set_ylim((edge_lo, edge_hi))
    fig.colorbar(img, ax=ax)
    plt.show()
    
def regularize_matrix(A, epsilon=1e-6):
    """Return a symmetric copy of ``A`` whose eigenvalues are all >= ``epsilon``.

    The input is first symmetrized (``0.5*(A + A.T)``) to remove floating-point
    asymmetries. It is then diagonalized with ``np.linalg.eigh``, every
    eigenvalue below ``epsilon`` is raised to ``epsilon``, and the matrix is
    rebuilt from the same eigenvectors. Eigenvalues that are already
    >= ``epsilon`` are untouched, so a well-conditioned SPD input comes back
    unchanged up to round-off.

    Parameters
    ----------
    A : np.ndarray, shape (n, n)
        Real square matrix.
    epsilon : float
        Lower bound imposed on the eigenvalues.

    Returns
    -------
    np.ndarray, shape (n, n), dtype float64
        Exactly symmetric, regularized matrix.
    """
    # ensure A is symmetric to avoid floating-point asymmetries
    A_sym = 0.5 * (A + A.T)
    # Eigen-decomposition for symmetric (real) matrices
    eigenvals, Q = np.linalg.eigh(A_sym)
    # Lift eigenvalues below epsilon to epsilon
    eigenvals_clipped = np.clip(eigenvals, epsilon, None)
    # Q * e scales column j of Q by e[j]: same as Q @ np.diag(e) without
    # building the dense diagonal matrix (saves one O(n^3) product).
    A_reg = (Q * eigenvals_clipped) @ Q.T
    # Q diag(e) Q^T is symmetric only up to round-off; symmetrize so the
    # result is exactly symmetric, as downstream symmetry checks expect.
    A_reg = 0.5 * (A_reg + A_reg.T)
    return A_reg.astype(np.float64)

def block_sum(M,dims):
    """Reduce ``M`` to an N x N float matrix of per-block maxima, N = len(M)//dims.

    Entry (r, c) is ``np.max`` over the ``dims x dims`` block in block-row r and
    block-column c.

    NOTE(review): despite the name, this takes the block *maximum*, not the
    sum. Confirm which reduction is intended before relying on it.

    Raises
    ------
    ValueError
        If ``len(M)`` is not a multiple of ``dims``.
    """
    n_blocks, remainder = divmod(len(M), dims)
    if remainder:
        raise ValueError('Matrix dimension not a multiple of dims.')
    maxima = [
        [np.max(M[r*dims:(r+1)*dims, c*dims:(c+1)*dims]) for c in range(n_blocks)]
        for r in range(n_blocks)
    ]
    return np.array(maxima, dtype=float).reshape(n_blocks, n_blocks)

def is_positive_definite(matrix):
    """Return True if ``matrix`` is square, (numerically) symmetric and has only
    strictly positive eigenvalues.

    Non-square or non-2D inputs return False. The original version crashed in
    ``np.allclose`` on a shape mismatch, or silently broadcast (e.g. (1, n)
    vs (n, 1)).
    A matrix that fails the ``np.allclose(matrix, matrix.T)`` symmetry test
    prints 'Not allclose' and returns False.

    Eigenvalues are computed with ``np.linalg.eigvalsh`` on the symmetrized
    matrix. This is the symmetric solver: its eigenvalues are always real, so
    there are no complex values with round-off imaginary parts to compare
    against 0.
    """
    matrix = np.asarray(matrix)
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        return False
    if not np.allclose(matrix, matrix.T):
        print('Not allclose')
        return False
    eigenvalues = np.linalg.eigvalsh(0.5 * (matrix + matrix.T))
    return bool(np.all(eigenvalues > 0))

def is_regular(A):
    """Return True if ``A`` is square and has full (numerical) rank, i.e. is invertible.

    Rank is decided by ``np.linalg.matrix_rank`` with its default SVD
    tolerance, so nearly singular matrices may be reported as singular.
    Non-square matrices return False.
    """
    n_rows, n_cols = A.shape[0], A.shape[1]
    return n_rows == n_cols and np.linalg.matrix_rank(A) == n_rows


### Selection Methods

In [3]:
def get_midstep_locations(left_open: int, right_open: int, base_locations = None):
    """Return the midstep constraint locations that remain after opening the ends.

    Starts from ``base_locations``, or the built-in list of 28 default locations
    when it is None. Drops ``left_open`` entries from the front and
    ``right_open`` entries from the back. If more entries would be dropped than
    exist, an empty list is returned. This guard also prevents
    ``len - right_open`` from becoming a negative (wrap-around) slice stop.

    The result is a slice of the base sequence.
    """
    if base_locations is not None:
        locations = base_locations
    else:
        locations = [
            2, 6, 14, 17, 24, 29,
            34, 38, 45, 49, 55, 59,
            65, 69, 76, 80, 86, 90,
            96, 100, 107, 111, 116, 121,
            128, 131, 139, 143
        ]
    count = len(locations)
    if left_open + right_open > count:
        return []
    stop = count - right_open
    return locations[left_open:stop]

def submatrix(M, dim: int = 6, left_open: int = 0, right_open: int = 0, marginalize=False):
    """Extract the central part of a block matrix with ``dim x dim`` blocks.

    Drops the first ``left_open`` and the last ``right_open`` block rows and
    columns.

    With ``marginalize=True``, the same block is taken from ``inv(M)`` and
    inverted again. This equals the Schur complement of the dropped blocks.
    If ``M`` is a stiffness (precision) matrix, that is the stiffness of the
    marginal distribution of the kept coordinates. Otherwise a plain slice
    (view) of ``M`` is returned.
    """
    n_blocks = len(M)//dim
    keep = slice(left_open*dim, (n_blocks - right_open)*dim)
    if not marginalize:
        return M[keep, keep]
    covariance = np.linalg.inv(M)
    return np.linalg.inv(covariance[keep, keep])

### Binding Model Methods

In [4]:
def Hinverse(Psi):
    """Truncated series I + X/2 + X^2/12 - X^4/720 + X^6/30240 with X = hat(Psi).

    These are the leading terms of the Bernoulli-number series X/(1 - exp(-X)).
    Given the name, this is presumably the inverse of the SO(3) Jacobian
    H(Psi); confirm the convention against ``so3``. Returns a 3x3 array.
    """
    X = so3.hat_map(Psi)
    X2 = X @ X
    X4 = X2 @ X2
    X6 = X4 @ X2
    result = np.eye(3)
    # coefficients B_k / k! of the series; odd terms beyond k = 1 vanish
    for coeff, power in ((0.5, X), (1./12, X2), (-1./720, X4), (1./30240, X6)):
        result += coeff * power
    return result
    
def coordinate_transformation(muk0s,sks):
    """Build the linear map B and offset Pbar relating binding-site deviations
    to midstep composite coordinates.

    For each k, the relative transform ``sig0 = inv(muk0s[k]) @ muk0s[k+1]``
    between consecutive reference frames is compared with the midstep ground
    state ``sks[k]``. Both are indexed via ``[:3,:3]`` (rotation) and ``[:3,3]``
    (translation), so presumably 4x4 homogeneous SE(3) matrices.

    Returns
    -------
    B : np.ndarray, shape (6*len(sks), 6*len(muk0s))
        Block-bidiagonal: row block k has ``Bkm`` in column block k and ``Bkp``
        in column block k+1.
    Pbar : np.ndarray, shape (6*len(sks),)
        Per step: rotational part ``so3.rotmat2euler(Sk.T @ Sig)`` followed by
        translational part ``Sk.T @ (sig - sk)``.

    Requires len(muk0s) >= len(sks) + 1 (``muk0s[k+1]`` is read for every k).
    """
    n_steps = len(sks)
    B = np.zeros((n_steps*6, len(muk0s)*6))
    Pbar = np.zeros(n_steps*6)
    for k in range(n_steps):
        rel = np.linalg.inv(muk0s[k]) @ muk0s[k+1]
        Sig, sig = rel[:3,:3], rel[:3,3]
        Sk, sk = sks[k,:3,:3], sks[k,:3,3]
        SkT = Sk.T

        Psi = so3.rotmat2euler(SkT @ Sig)
        Hi = Hinverse(Psi)

        # coupling to frame k
        left = np.zeros((6,6))
        left[:3,:3] = -Hi @ Sig.T
        left[3:,:3] = SkT @ so3.hat_map(sig)
        left[3:,3:] = -SkT
        # coupling to frame k+1
        right = np.zeros((6,6))
        right[:3,:3] = Hi
        right[3:,3:] = SkT @ Sig

        rows = slice(6*k, 6*k+6)
        B[rows, 6*k:6*k+6] = left
        B[rows, 6*k+6:6*k+12] = right

        Pbar[6*k:6*k+3] = Psi
        Pbar[6*k+3:6*k+6] = SkT @ (sig - sk)
    return B, Pbar



### Free Energy

In [7]:

def binding_model_free_energy(
    free_gs: np.ndarray,
    free_M: np.ndarray,
    nuc_mu0: np.ndarray,
    nuc_K: np.ndarray,
    left_open: int,
    right_open: int,
    NUCLEOSOME_TRIADS: np.ndarray,
    use_correction: bool = True,
) -> Dict[str, float]:
    """Free energy of nucleosome-bound DNA with an elastic binding model (WORK IN PROGRESS).

    The DNA stiffness ``free_M`` is re-expressed in midstep composite
    coordinates and split into remaining (R) and composite (M) parts. The
    composites are coupled through ``coordinate_transformation`` to
    binding-site deviations ``alpha`` with stiffness ``nuc_K``.

    Parameters
    ----------
    free_gs : np.ndarray
        Ground state of the free DNA. It is passed to the ``midstep_*`` helpers
        (TODO confirm the expected layout).
    free_M : np.ndarray
        Stiffness matrix of the free DNA. The code groups its indices in
        blocks of 6.
    nuc_mu0 : np.ndarray
        Binding-site reference frames for ``coordinate_transformation``
        (presumably 4x4 homogeneous transforms).
    nuc_K : np.ndarray
        Binding-site stiffness, added to ``B.T @ M_Mp @ B``. It must therefore
        be (6*len(nuc_mu0), 6*len(nuc_mu0)).
    left_open, right_open : int
        Number of midstep constraint locations released at the left/right
        end (see ``get_midstep_locations``).
    NUCLEOSOME_TRIADS : np.ndarray
        Nucleosome triads. Only used for the unused ``FIXED_midstep_triads``.
    use_correction : bool
        Only read in the unreachable code after ``sys.exit()``.

    Returns
    -------
    dict
        Keys 'F', 'F_entropy', 'F_enthalpy', 'F_jacob', 'F_free' (presumably
        in units of kT). Currently a dict is only returned when at most one
        midstep constraint remains.

    FIXME: with two or more constraints the function prints the diagnostic
    values ``C`` / ``C_reg`` and then calls ``sys.exit()``. That raises
    ``SystemExit`` and aborts the notebook run. Everything after that call is
    unreachable and references names that are not defined in this function
    (``stiffmat``, ``groundstate``, ``select_partial``, ``K``).
    """

    # constraint locations that stay bound (deduplicated and sorted)
    midstep_constraint_locations = get_midstep_locations(left_open, right_open)
    midstep_constraint_locations = sorted(list(set(midstep_constraint_locations)))

    # At most one constraint: return the Gaussian free energy of the free DNA,
    # F = 0.5*logdet(free_M) - 0.5*n*log(2*pi).
    if len(midstep_constraint_locations) <= 1:
        n = len(free_M)
        F_pi = -0.5*n * np.log(2*np.pi)
        # matrix term (logdet_sign is ignored: free_M is assumed positive definite)
        logdet_sign, logdet = np.linalg.slogdet(free_M)
        F_mat = 0.5*logdet
        F = F_mat + F_pi
        Fdict = {
            'F': F,
            'F_entropy' : F,
            'F_enthalpy': 0,
            'F_jacob'   : 0,
            'F_free'    : F
        }
        return Fdict

    # DELETE!
    # Find midstep triads in fixed framework for comparison
    # NOTE(review): FIXED_midstep_triads is not used by any reachable code below.
    FIXED_midstep_triads = calculate_midstep_triads(
        midstep_constraint_locations,
        NUCLEOSOME_TRIADS
    )

    # midstep ground states; coordinate_transformation reads sks[k,:3,:3] / sks[k,:3,3]
    sks = midstep_groundstate_se3(free_gs,midstep_constraint_locations)

    # find composite transformation
    transform, replaced_ids = midstep_composition_transformation(
        free_gs,
        midstep_constraint_locations
    )

    # transform stiffness matrix: T^{-T} M T^{-1}, consistent with new
    # coordinates x' = transform @ x
    inv_transform = np.linalg.inv(transform)
    M_transformed = inv_transform.T @ free_M @ inv_transform

    # rearrange stiffness matrix: the 6 components of every replaced step are
    # presumably moved to the end (they are sliced as [NF:] below)
    full_replaced_ids = list()
    for i in range(len(replaced_ids)):
        full_replaced_ids += [6*replaced_ids[i]+j for j in range(6)]

    P = send_to_back_permutation(len(free_M),full_replaced_ids)
    M_rearranged = P @ M_transformed @ P.T

    # select M and R submatrices
    # R: remaining coordinates (first NF), M: composite coordinates (last NC)
    N  = len(M_rearranged)
    NC = len(full_replaced_ids)
    NF = N-NC

    M_R = M_rearranged[:NF,:NF]
    M_M = M_rearranged[NF:,NF:]
    M_RM = M_rearranged[:NF,NF:]

    # Calculate M block marginal: Schur complement M_M - M_RM^T M_R^{-1} M_RM
    # (stiffness of the composites with R integrated out), symmetrized to
    # remove round-off asymmetry.
    M_Mp = M_M - M_RM.T @ np.linalg.inv(M_R) @ M_RM
    M_Mp = 0.5*(M_Mp+M_Mp.T)


    ##############################################
    # Binding Model
    ##############################################

    # Calculate Incidence Matrix
    # composite coordinates are linearized as Y = Pbar + B @ alpha
    B, Pbar = coordinate_transformation(nuc_mu0,sks)

    # epsilon passed to regularize_matrix; only used for the C_reg diagnostic
    reg = 10
    # Hessian in alpha of E(alpha) = 0.5 alpha^T nuc_K alpha + 0.5 Y^T M_Mp Y
    Kcomb = nuc_K + B.T @ M_Mp @ B
    # Kcomb = regularize_matrix(Kcomb,epsilon=reg)

    # print(is_positive_definite(nuc_K))
    # print(is_positive_definite(Kcomb))
    # eig = np.linalg.eigvals(Kcomb)
    # print(np.sort(eig))

    # calculate ground state: alpha solves Kcomb alpha = -B^T M_Mp Pbar (the
    # minimizer of E above); Y_C is the corresponding composite configuration
    alpha = -np.linalg.inv(Kcomb) @ B.T @ M_Mp @ Pbar
    Y_C = Pbar + B @ alpha

    # C is the minimum of E: 0.5 Pbar^T (M_Mp - M_Mp B Kcomb^{-1} B^T M_Mp) Pbar.
    # C_reg repeats it with Kcomb regularized (eigenvalues clipped at reg).
    C = 0.5* Pbar.T @ ( M_Mp - M_Mp @ B @ np.linalg.inv(Kcomb) @ B.T @ M_Mp ) @ Pbar
    print(f'C = {C}')
    C = 0.5* Pbar.T @ ( M_Mp - M_Mp @ B @ np.linalg.inv(regularize_matrix(Kcomb,epsilon=reg)) @ B.T @ M_Mp ) @ Pbar
    print(f'C_reg = {C}')

    # minimizer of the remaining coordinates R given Y_C
    gamma = -np.linalg.inv(M_R) @ M_RM @ Y_C

    # Refinement: rebuild the composite transformation around the current
    # ground-state estimate and repeat the solve. iters = 1, so the loop body
    # runs exactly once.
    # NOTE(review): the loop variable shadows the builtin `iter`.
    iters = 1
    for iter in range(iters):

        # map (gamma, Y_C) back to the original step coordinates
        gs_transf_perm = np.concatenate((gamma,Y_C))
        gs_transf = P.T @ gs_transf_perm
        gs = inv_transform @ gs_transf

        gs = gs.reshape((len(gs)//6,6))
        # find composite transformation
        # NOTE(review): called with -gs; the sign convention is defined in
        # midstep_composition_transformation_correction -- confirm.
        transform, replaced_ids = midstep_composition_transformation_correction(
            free_gs,
            midstep_constraint_locations,
            -gs
        )

        # transform stiffness matrix
        inv_transform = np.linalg.inv(transform)
        M_transformed = inv_transform.T @ free_M @ inv_transform

        # rearrange stiffness matrix
        full_replaced_ids = list()
        for i in range(len(replaced_ids)):
            full_replaced_ids += [6*replaced_ids[i]+j for j in range(6)]

        P = send_to_back_permutation(len(free_M),full_replaced_ids)
        M_rearranged = P @ M_transformed @ P.T

        # select M and R submatrices
        N  = len(M_rearranged)
        NC = len(full_replaced_ids)
        NF = N-NC

        M_R = M_rearranged[:NF,:NF]
        M_M = M_rearranged[NF:,NF:]
        M_RM = M_rearranged[:NF,NF:]

        # Calculate M block marginal
        M_Mp = M_M - M_RM.T @ np.linalg.inv(M_R) @ M_RM
        M_Mp = 0.5*(M_Mp+M_Mp.T)

        ##############################################
        # Binding Model
        ##############################################

        # Calculate Incidence Matrix
        # NOTE(review): nuc_mu0 and sks are unchanged, so B and Pbar equal the
        # values computed before the loop.
        B, Pbar = coordinate_transformation(nuc_mu0,sks)

        Kcomb = nuc_K + B.T @ M_Mp @ B
        # Kcomb = regularize_matrix(Kcomb,epsilon=reg)

        # print(is_positive_definite(nuc_K))
        # print(is_positive_definite(Kcomb))
        # eig = np.linalg.eigvals(Kcomb)
        # print(np.sort(eig))

        # calculate ground state
        alpha = -np.linalg.inv(Kcomb) @ B.T @ M_Mp @ Pbar
        Y_C = Pbar + B @ alpha
        gamma = -np.linalg.inv(M_R) @ M_RM @ Y_C

        C = 0.5* Pbar.T @ ( M_Mp - M_Mp @ B @ np.linalg.inv(Kcomb) @ B.T @ M_Mp ) @ Pbar
        print(f'C = {C}')
        C = 0.5* Pbar.T @ ( M_Mp - M_Mp @ B @ np.linalg.inv(regularize_matrix(Kcomb,epsilon=reg)) @ B.T @ M_Mp ) @ Pbar
        print(f'C_reg = {C}')

    # C = 0.5* Pbar.T @ ( M_Mp - M_Mp @ B @ np.linalg.inv(Kcomb) @ B.T @ M_Mp ) @ Pbar
    # print(f'C = {C}')

    # logdet_sign, logdet = np.linalg.slogdet(Kcomb)
    # print(logdet)

    # find contraint excess values
    # excess_vals = midstep_excess_vals(
    #     free_gs,
    #     midstep_constraint_locations,
    #     FIXED_midstep_triads
    # )
    # C = excess_vals.flatten()

    # print(Y_C.shape)
    # print(C.shape)

    # for i in range(len(C)//6):
    #     print('####################')
    #     print(Y_C[6*i:6*(i+1)])
    #     print(C[6*i:6*(i+1)])

    # print(Sks.shape)

    # print(Sks[:2])

    # FIXME: debugging stop. sys.exit() raises SystemExit, which aborts the
    # whole notebook run. Nothing below this line is reachable.
    sys.exit()

    # ------------------------------------------------------------------
    # UNREACHABLE. This tail appears to be carried over from an earlier
    # fixed-constraint implementation. `stiffmat`, `groundstate`,
    # `select_partial` and `K` are not defined in this function, so it would
    # raise NameError if the sys.exit() above were removed.
    # ------------------------------------------------------------------

    # transform stiffness matrix
    inv_transform = np.linalg.inv(transform)
    stiffmat_transformed = inv_transform.T @ stiffmat @ inv_transform

    # rearrange stiffness matrix
    full_replaced_ids = list()
    for i in range(len(replaced_ids)):
        full_replaced_ids += [6*replaced_ids[i]+j for j in range(6)]

    P = send_to_back_permutation(len(stiffmat),full_replaced_ids)
    stiffmat_rearranged = P @ stiffmat_transformed @ P.T

    # select fluctuating, constraint and coupling part of matrix
    N  = len(stiffmat)
    NC = len(full_replaced_ids)
    NF = N-NC

    MF = stiffmat_rearranged[:NF,:NF]
    MC = stiffmat_rearranged[NF:,NF:]
    MM = stiffmat_rearranged[NF:,:NF]

    MFi = np.linalg.inv(MF)
    b = MM.T @ C

    ########################################
    ########################################
    if use_correction:
        alpha = -MFi @ b

        gs_transf_perm = np.concatenate((alpha,C))
        gs_transf = P.T @ gs_transf_perm
        gs = inv_transform @ gs_transf

        gs = gs.reshape((len(gs)//6,6))
        # find composite transformation
        transform, replaced_ids = midstep_composition_transformation_correction(
            groundstate,
            midstep_constraint_locations,
            -gs
        )

        # transform stiffness matrix
        inv_transform = np.linalg.inv(transform)
        stiffmat_transformed = inv_transform.T @ stiffmat @ inv_transform

        # NOTE(review): reuses the permutation P built before the correction
        stiffmat_rearranged = P @ stiffmat_transformed @ P.T

        # select fluctuating, constraint and coupling part of matrix
        N  = len(stiffmat)
        NC = len(full_replaced_ids)
        NF = N-NC

        MF = stiffmat_rearranged[:NF,:NF]
        MC = stiffmat_rearranged[NF:,NF:]
        MM = stiffmat_rearranged[NF:,:NF]

        MFi = np.linalg.inv(MF)
        b = MM.T @ C

    ########################################
    ########################################

    # constant energies
    F_const_C =  0.5 * C.T @ MC @ C
    F_const_b = -0.5 * b.T @ MFi @ b

    F_enthalpy = F_const_C + F_const_b


    K_partial = select_partial(K,left_open=left_open,right_open=right_open)
    Mtot_k = np.copy(stiffmat_rearranged)
    Mtot_k[NF:,NF:] += K_partial

    # print(stiffmat_rearranged[NF:NF+6,NF:NF+6])
    # print(Mtot_k[NF:NF+6,NF:NF+6])


    # entropy term
    n = len(Mtot_k)
    logdet_sign, logdet = np.linalg.slogdet(Mtot_k)
    F_pi = -0.5*n * np.log(2*np.pi)
    # matrix term
    F_mat = 0.5*logdet
    F_entropy = F_pi + F_mat
    # log-Jacobian of the coordinate transformation
    F_jacob = np.log(np.linalg.det(transform))

    # free energy of unconstrained DNA
    ff_logdet_sign, ff_logdet = np.linalg.slogdet(stiffmat)
    ff_pi = -0.5*len(stiffmat) * np.log(2*np.pi)
    F_free = 0.5*ff_logdet + ff_pi

    # prepare output
    Fdict = {
        'F': F_entropy + F_jacob + F_enthalpy,
        'F_entropy' : F_entropy + F_jacob,
        'F_enthalpy': F_enthalpy,
        'F_jacob'   : F_jacob,
        'F_free'    : F_free
    }
    return Fdict
        

# --- Inputs for the binding-model free-energy evaluation ---------------------
# Sequence-dependent DNA stiffness / ground-state generator.
genstiff = GenStiffness(method='hybrid')

# Nucleosome triads from the state file (paths are relative to the notebook's
# working directory).
triadfn = 'methods/State/Nucleosome.state'
nuctriads = read_nucleosome_triads(triadfn)

# Calibrated binding-model parameters.
BM_Kcomb_fn = 'MDParams/Calibrated/reg_Kcomb.npy'
BM_K_fn     = 'MDParams/Calibrated/reg_K.npy'
BM_mu0_fn   = 'MDParams/Calibrated/reg_mu0.npy'

nuc_Kcomb = np.load(BM_Kcomb_fn)  # NOTE(review): not used in this cell
nuc_K     = np.load(BM_K_fn)
nuc_mu0   = np.load(BM_mu0_fn)

# Number of constraint locations released at each end (0, 0 keeps all of them).
left_open, right_open = 0, 0

seq = "ATCAATATCCACCTGCAGATACTACCAAAAGTGTATTTGGAAACTGCTCCATCAAAAGGCATGTTCAGCTGGAATCCAGCTGAACATGCCTTTTGATGGAGCAGTTTCCAAATACACTTTTGGTAGTATCTGCAGGTGGATATTGAT"

free_M, free_gs = genstiff.gen_params(seq, use_group=True)

# NOTE: with two or more remaining constraints this call currently ends in
# sys.exit() inside binding_model_free_energy (SystemExit).
binding_model_free_energy(
    free_gs=free_gs,
    free_M=free_M,
    nuc_mu0=nuc_mu0,
    nuc_K=nuc_K,
    left_open=left_open,
    right_open=right_open,
    NUCLEOSOME_TRIADS=nuctriads,
    use_correction=True,
)


# print(is_positive_definite(nuc_K))
# nuc_K = 0.5*(nuc_K+nuc_K.T)

# Kpos = nearcorr(nuc_K, tol=[], flag=0, max_iterations=1000, n_pos_eig=0,
#              weights=None, verbose=False,
#              except_on_too_many_iterations=False)


# print(is_positive_definite(Kpos))

C = 146.81882933342965
C_reg = 146.81882933342962
C = 132.6730292205418
C_reg = 132.67302922054188


SystemExit: 