## **Calculation of Binomial Moments and Reconstruction of Probability Mass Function**  

**Stochastic Kinetics of mRNA Molecules in a General Transcription Model**  

*Yuntao Lu and Yunxin Zhang*  

School of Mathematical Sciences, Fudan University, Shanghai 200433, China  

Email: `yuntaolu22@m.fudan.edu.cn` and `xyz@fudan.edu.cn`  

This script accompanies our manuscript titled **"Stochastic Kinetics of mRNA Molecules in a General Transcription Model"**.  

It implements two core functions:  

1. **`moment(D0, D1, M, threshold=1e-32)`**  
   Computes the first $M + 1$ binomial moments $B_0, B_1, \ldots, B_M$ (with $B_0 = 1$) of the mRNA copy number distribution, using the parameters $D_0$ and $D_1$ defined in Model (1) of our manuscript.  

2. **`distribution(binomial_moments, N, threshold=1e-8)`**  
   Reconstructs the probability mass function $[P_0, P_1, \ldots, P_{N-1}]$ from binomial moments via Equation (7) in our manuscript, where $N$ specifies the desired vector length.

In [1]:
import scipy.linalg as linalg
import numpy as np
import time
import math

In [2]:
import Parameters_for_Figures

In [3]:
# D0=Parameters_for_Figures.D0_3a
# D1=Parameters_for_Figures.D1_3a

In [4]:
# D0=Parameters_for_Figures.D0_3b
# D1=Parameters_for_Figures.D1_3b

In [5]:
# D0=Parameters_for_Figures.D0_3c
# D1=Parameters_for_Figures.D1_3c

In [6]:
D0=Parameters_for_Figures.D0_3d
D1=Parameters_for_Figures.D1_3d

In [8]:
M=62
threshold=1e-64

In [9]:
# ==============================================
# INPUT VALIDATION AND SAFETY CHECKS
# ==============================================

# Type checking for input parameters
if not isinstance(D0, np.ndarray) or not isinstance(D1, np.ndarray):
    raise TypeError("Input matrices D0 and D1 must be NumPy arrays")
if not isinstance(M, int):
    raise TypeError("M must be an integer")

# Matrix dimension validation
if D0.ndim != 2 or D1.ndim != 2:
    raise ValueError("Input matrices must be 2-dimensional arrays")

n0, m0 = D0.shape
n1, m1 = D1.shape

# Check if matrices are square
if n0 != m0:
    raise ValueError("D0 must be a square matrix")
if n1 != m1:
    raise ValueError("D1 must be a square matrix")

# Check matrix dimension consistency
if n0 != n1:
    raise ValueError("D0 and D1 must have the same dimension")
n = n0  # Dimension of the system

# Validate moment count
if M < 0:
    raise ValueError("M must be a non-negative integer")

In [10]:
# Hydrolysis rate (normalized to 1)
d = 1.0  

# Construct Q-matrix D
D = D0 + D1

In [11]:
# ==============================================
# Q-MATRIX VALIDATION
# ==============================================
tol = 1e-8  # Numerical tolerance for checks

# Validate Q-matrix properties:
# 1. Row sums should be zero (within tolerance)
row_sums = np.sum(D, axis=1)
if not np.allclose(row_sums, 0, atol=tol):
    max_error = np.max(np.abs(row_sums))
    raise ValueError(f"Invalid generator matrix: Row sums should be zero "
                     f"(max error: {max_error:.2e})")

# 2. Off-diagonal elements should be non-negative
D_offdiag = D.copy()
np.fill_diagonal(D_offdiag, 0)
if np.any(D_offdiag < -tol):
    min_value = np.min(D_offdiag)
    raise ValueError(f"Invalid generator matrix: Negative off-diagonals "
                     f"(min value: {min_value:.2e})")

# 3. Diagonal elements should be non-positive
diag_elements = np.diag(D)
if np.any(diag_elements > tol):
    max_value = np.max(diag_elements)
    raise ValueError(f"Invalid generator matrix: Positive diagonals "
                     f"(max value: {max_value:.2e})")

# Check numerical stability
inf_norm_D1 = linalg.norm(D1, np.inf)
if inf_norm_D1 > 9:
    print(f"Warning: High matrix norm ||D1||_inf = {inf_norm_D1:.2f} - "
           "potential numerical instability")



In [12]:
# ==============================================
# COMPUTATION of Invariant Distribution of D
# ==============================================
try:
    # Construct linear system to calculate stationary distribution of the underlying Markov chain characterized by D:
    # D^T π = 0 with constraint sum(π) = 1
    DT = D.T.copy()
    
    # Replace first row with constraint equation
    DT[0, :] = np.ones(n)
    
    # Right-hand side vector: [1, 0, 0, ...]
    b = np.zeros(n)
    b[0] = 1.0
    
    # Check matrix condition number
    cond_num = np.linalg.cond(DT)
    if cond_num > 1e12:
        print(f"Warning: Ill-conditioned matrix (cond={cond_num:.2e})")
    
    # Solve for stationary distribution pi (pi is a 2D NumPy array of shape (1, n))
    pi = np.linalg.solve(DT, b).reshape(1, n)

except np.linalg.LinAlgError as e:
    print(f"Linear system solver failed: {str(e)}.")
    print(f'Make sure that the Q-matrix D is irreducible.')
    raise

In [13]:
# Vector of ones (e is a 2D NumPy array of shape (n, 1))
e = np.ones((n, 1))

In [14]:
# ==============================================
# BINOMIAL MOMENTS COMPUTATION
# ==============================================
binomial_moments = [1.0]  # strat with B_0 = 1

# # Early return for trivial case
# if M == 0:
#     return binomial_moments

# Compute first moment B1
B1 = (pi @ D1 @ e) / d
binomial_moments.append(float(B1[0, 0]))


# Compute higher-order moments (B2 to BM)
if M > 1:
    B_vec = pi @ D1  # Initialize moment vector

    index_at_threshold = None
    
    for i in range(2, M + 1):
        # Construct matrix for current iteration
        A = (i - 1) * d * np.eye(n) - D
        
        # Compute the inverse of (i-1)*d*I - D using scipy.linalg.inv()    
        invA = linalg.inv(A)

        # Compute next vector: B_vec = B_vec * inv(A) * D1
        B_vec = B_vec @ invA @ D1
        moment_i = (B_vec @ e) / (i * d)
        moment_val = float(moment_i[0, 0])
        
        # Store computed moment
        binomial_moments.append(moment_val)
        
        # Early termination if moments become negligible
        if moment_val < threshold:
            if index_at_threshold is None:
                index_at_threshold = i
                print(f"Moments below threshold at i={i}")
                break

In [15]:
# ==============================================
# FINAL VALIDATION AND OUTPUT
# ==============================================
# Verify we computed the correct number of moments
if len(binomial_moments) != M + 1:
    print(f"Expected {M+1} moments, got {len(binomial_moments)}. ")
    print(f'\nEarly stop due to threshold={threshold}')

# # Print performance statistics
# total_time = time.time() - start_time
# print(f"Computation completed in {total_time:.4f} seconds")
# print(f"\nComputed {len(binomial_moments)} binomial moments")

In [16]:
# print(binomial_moments)

In [17]:
def stable_comb(n, k, b):
    """
    Stable computation of math.comb(n, k) * b

    This function avoids floating-point overflow/underflow by using logarithmic
    transformations and gamma functions. Suitable for cases where b is extremely small.

    Args:
        n: Total items (integer >= 0)
        k: Items to choose (integer, 0 <= k <= n)
        b: Scaling factor (float >= 0)

    Returns:
        math.comb(n, k) * b as float, or 0.0 for edge cases

    Raises:
        ValueError: For invalid input parameters
    """
    # Validate input parameters
    if b < 0:
        raise ValueError(f"Scaling factor b must be non-negative. Received: {b}")
    if b == 0:
        return 0.0
    if not isinstance(n, int) or not isinstance(k, int):
        raise TypeError(f"Both n and N must be integers. Received types: n={type(n)}, N={type(k)}")
    if n < 0:
        raise ValueError(f"n must be non-negative. Received: {n}")
    if k < 0 or k > n:
        # Combination is zero when N is out of range
        raise ValueError(f"k must satisfy 0 <= k <= n. Received: {k}")

    try:
        # Compute combination using logarithmic gamma functions
        # log(comb(n, N)) = lgamma(n+1) - lgamma(N+1) - lgamma(n-N+1)
        log_comb = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)

        # Add logarithm of scaling factor
        log_total = log_comb + math.log(b)

        # Handle potential underflow
        if log_total < -700:  # exp(-700) ≈ 10^-304
            return 0.0

        return math.exp(log_total)

    except ValueError as e:
        # Handle domain errors in math functions
        print(f"Math domain error: {e}")
        raise
    except OverflowError as e:
        # Handle overflow in intermediate calculations
        print(f"Overflow in logarithmic calculation: {e}")
        return 0.0

In [18]:
N=31
threshold=1e-5
# Validate inputs

M = len(binomial_moments) - 1  # Last available moment index (Note that len(binomial_moments) will be M+1)

if N>int(M/2):
    print(f'Input N is greater than M/2, which may cause instability. We replace N with floor(M/2)={math.floor(M/2)}')
N=min(N, math.floor(M/2))

In [19]:
prob = [0.0] * N  # Initialize probabilities for n=0 to (N-1)

In [20]:
# ==============================================
# Start timing the computation
# start_time = time.time()
# ==============================================

index_at_threshold=None

# Compute probabilities for n=0 to N-1
for n in range(N):
    total = 0.0
    sign = 1  # Starts positive for j=n
    cumulative_error = 0.0
    
    for j in range(n, M + 1):
        # Compute term: (-1)^(j-n) * C(j, n) * B_j
        term = sign * stable_comb(j,n,binomial_moments[j])
        total += term
        
        # # Track floating-point error accumulation
        # cumulative_error += abs(term) * 1e-15
            
        sign = -sign  # Flip sign for next term

    
    # Apply floating-point error correction
    prob[n]= total
    # prob[n] = max(0.0, total) if abs(total) < cumulative_error else total

    # Early termination if moments become negligible
    if total < threshold:
        if index_at_threshold is None:
            index_at_threshold = n
            print(f"Probability Density Function Value below threshold at n={n}")
            break

Probability Density Function Value below threshold at n=23


In [21]:
# ==============================================
# FINAL VALIDATION AND OUTPUT
# ==============================================
# Verify we computed the correct number of moments
if len(prob) != N:
    print(f"Expected {N} probability density function value, got {len(prob)}.")
    print(f"\n Two Posible Reasons:")
    print(f'\n Early stop happened due to threshold={threshold}')
    print(f"\n Input N is greater than M/2. Increase M is recommended")
    
# Print summary
# end_time = time.time()
# print(f"Computed {N} probabilities in {end_time-start_time:.6f} seconds")

s=sum(prob)
print(f'Sum of calculated probability is: {s}')
if abs(s-1)>1e-1 and abs(s-1)<1:
    print(f"Probability density function value after {N} is not negligible.")
    print(f"\n Increase N is recommended.")
elif abs(s-1)>1:
    print(f"There is severe floating point error, please increase M.")  

Sum of calculated probability is: 0.9999982110301964


In [22]:
print(prob)

[0.0004380982321958292, 0.0034011022789675015, 0.013188173669002639, 0.034057158947054175, 0.06589436022281203, 0.10189167260460202, 0.13116414925172945, 0.14458366998006553, 0.139319979747191, 0.11921883521899714, 0.09173111416441423, 0.0641064382139709, 0.04103098903487286, 0.02422044765188877, 0.013264703259713153, 0.006774658529114452, 0.0032411114997792237, 0.001458225134142408, 0.0006191444505676432, 0.0002488547416168939, 9.495069667535604e-05, 3.447804906221649e-05, 1.1941864649572523e-05, 3.95358711138508e-06, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]


In [23]:
# np.save('fig3_MAIN_a.npy', prob)
# np.save('fig3_MAIN_b.npy', prob)
# np.save('fig3_MAIN_c.npy', prob)
# np.save('fig3_MAIN_d.npy', prob)

## Visualization of Steady-state Probability Distribution of mRNA Copy Number

In [22]:
import matplotlib.pyplot as plt

In [23]:
plt.figure(figsize=(8, 8),dpi=100)
plt.rcParams['axes.linewidth'] = 3
# plt.plot(list(range(len(prob))), (np.array(prob)).flatten(),
# color='red', linestyle='--', marker='o', markerfacecolor='blue',
# markeredgecolor='green',linewidth=2,label='Theoretical Result')
plt.stem(list(range(len(prob))), (np.array(prob)).flatten(),linefmt='r--',
         markerfmt='bo',basefmt="",label='Theoretical Result')
plt.setp(plt.gca().get_lines(), markersize=15)
# plt.xlim(0, 40)
# plt.ylim(0, 0.1)
plt.xticks(
    # ticks=list(range(0,45,5)),
           fontsize=40,
           color='blue',
           # fontname='Times New Roman',
           fontname='DejaVu Sans',
           ha='right')
plt.yticks(
    # ticks=np.arange(0.02, 0.12, 0.02),
           fontsize=40,
           color='blue',
           rotation=90,
           # fontname='Times New Roman',
           fontname='DejaVu Sans',
           ha='right')
# plt.title(f'No.{r}Distribution with threshold={threshold},
# end at index {2*index_at_threshold}, sum of prob={s}, d={d}, timing={timing:.2f}s
# \n D0={D0},D1={D1}',fontname='Times New Roman',fontsize=20,fontweight='bold')
plt.title(f'Theoretical Result Verified by Stochastic Simulation',
          # fontname='Times New Roman',
           fontname='DejaVu Sans',
          fontsize=30,fontweight='bold',color='green')
plt.xlabel('Number of mRNA Molecules',fontsize=35,fontweight='bold',
           labelpad=20,
           # fontname='Times New Roman',
           fontname='DejaVu Sans',
           color='green')
plt.ylabel('Probability',fontsize=35,fontweight='bold',
           labelpad=40,
           # fontname='Times New Roman',
           fontname='DejaVu Sans',
           color='green',rotation=75)
plt.tick_params(direction='out',width=3,length=10)
# plt.axhline(0, color='blue', linewidth=2)
# plt.axhline(0, color='blue', linewidth=2)
# from SSA_5_plot import RNAnumber
# plt.hist(RNAnumber,bins=np.arange(-0.5, 41.5, 1), density=True, color='#228B22', edgecolor='green', alpha=0.3,
#              linewidth=2,align='mid',label='Stochastic Simulation')
# plt.legend(loc='upper right', frameon=True, shadow=True, markerscale=2,
#            fancybox=True,prop={'family': 'Times New Roman','size': 40})
# print("Saving figure...")
# plt.savefig('Example2.pdf', format='pdf',edgecolor='black', dpi=800)
# plt.show()
plt.close()