In [1]:
import sympy as sp

# FLOP optimal chunk size

In this notebook we derive the FLOP-optimal chunk size $L$ for mLSTMsig, i.e. the chunk size that minimizes the total FLOP count.


In [2]:
from flops_mlstm import simpl_comp_flop_cwp_sig_total

from flops_mlstm import d_qk, d_hv, L, F_causal

In [3]:
# Display the simplified total-FLOP expression imported from flops_mlstm;
# this is the starting point of the derivation below.
simpl_comp_flop_cwp_sig_total

1.0*(L**2*(2.0*F_causal*T*d_hv + 2.0*F_causal*T*d_qk + 6.0*F_causal*T + 1.0*T) + L*(4.0*T*d_hv*d_qk + 1.0*T*d_hv + 2.0*T*d_qk + 11.0*T) + 2.0*T*d_hv*d_qk + 5.0*T)/L

In [4]:
# p_qk is the factor relating the two head dimensions: d_qk = p_qk * d_hv
p_qk = sp.Symbol("p_qk")

In [5]:
# Starting point: the total FLOP count of mLSTMsig.
# Step 1) express the qk head dimension through the hv head dimension,
#         i.e. replace d_qk by p_qk * d_hv.
flops_total_sig_cwp_subs = simpl_comp_flop_cwp_sig_total.subs({d_qk: p_qk * d_hv})
flops_total_sig_cwp_subs

1.0*(L**2*(2.0*F_causal*T*d_hv*p_qk + 2.0*F_causal*T*d_hv + 6.0*F_causal*T + 1.0*T) + L*(4.0*T*d_hv**2*p_qk + 2.0*T*d_hv*p_qk + 1.0*T*d_hv + 11.0*T) + 2.0*T*d_hv**2*p_qk + 5.0*T)/L

In [6]:
# Step 2) take the derivative of the total FLOP count with respect to the
#         chunk size L; its zero locates the minimum.
diff_flops_total_sig_cwp_subs = sp.simplify(flops_total_sig_cwp_subs.diff(L))
diff_flops_total_sig_cwp_subs

2.0*F_causal*T*d_hv*p_qk + 2.0*F_causal*T*d_hv + 6.0*F_causal*T + 1.0*T - 2.0*T*d_hv**2*p_qk/L**2 - 5.0*T/L**2

In [7]:
# 3) Set the derivative to zero and solve for L.
# The stationarity condition only contains L as L**2, so solve() yields a
# +/- sqrt(...) root pair. Picking the root by list position (`[1]`) relies on
# SymPy's ordering of the solutions and raises IndexError if only one root is
# returned (e.g. when L carries a `positive=True` assumption).
# Instead, select the positive root explicitly: evaluate each candidate with all
# remaining symbols set to 1 (they are assumed to be positive quantities) and
# keep the candidate that evaluates to a positive number.
stationary_points = sp.solve(sp.Eq(diff_flops_total_sig_cwp_subs, 0), L)
positive_probe = {sym: 1 for sym in diff_flops_total_sig_cwp_subs.free_symbols - {L}}
L_optimal = next(sol for sol in stationary_points if sol.subs(positive_probe) > 0)
L_optimal

sqrt((2.0*d_hv**2*p_qk + 5.0)/(2.0*F_causal*d_hv*p_qk + 2.0*F_causal*d_hv + 6.0*F_causal + 1.0))

In [8]:
# Sanity check against a hand calculation (matches):
#   p_qk = 0.5, d_hv = 512, F_causal = 1
#   => sqrt((2 * 512**2 * 0.5 + 5) / (2 * 512 * 0.5 + 2 * 512 + 6 + 1))
#    = sqrt(262149 / 1543) ~= 13.03
hand_calc_point = {p_qk: 0.5, d_hv: 512, F_causal: 1.0}
L_optimal.subs(hand_calc_point)

13.0344028558834