In [27]:
import sys
from pathlib import Path


# Add the 'src' directory to the Python path
src_path = Path('../src').resolve()
sys.path.append(str(src_path))

from aspire.volume import Volume
import mrcfile
import numpy as np 
import numpy.linalg as LA 
from viewing_direction import *
from utils import *
from aspire.basis.basis_utils import lgwt
from volume import *
from moments import * 
import matplotlib.pyplot as plt


# import jax
# import jax.numpy as jnp
# from jax import grad, jit 
# from jax.numpy.linalg import norm

In [28]:
%%time

# get ground truth a and b

# Seed the legacy global NumPy RNG so that the random centers and weights
# below (and any later np.random draws in this kernel session) are
# reproducible after Restart Kernel -> Run All.
RNG_SEED = 0
np.random.seed(RNG_SEED)

# Random mixture: c unit-norm centers in R^3 and c positive weights summing to 1.
c = 10
centers = np.random.normal(0, 1, size=(c, 3))
centers /= LA.norm(centers, axis=1, keepdims=True)  # scale each row to unit length
w_vmf = np.random.uniform(0, 1, c)
w_vmf = w_vmf / np.sum(w_vmf)  # normalize the weights to sum to one


# Tensor-product angular grid: ngrid values of theta in [0, pi) times
# ngrid values of phi in [0, 2*pi), flattened to 1-D arrays of length ngrid**2.
ngrid = 50
_ths = np.pi * np.arange(ngrid) / ngrid
_phs = 2 * np.pi * np.arange(ngrid) / ngrid

ths, phs = np.meshgrid(_ths, _phs, indexing='ij')
ths, phs = ths.flatten(), phs.flatten()

# Evaluate the mixture density on the flattened (theta, phi) grid, weight it
# by sin(theta), and fold the result back into an (ngrid, ngrid) array
# indexed as [theta, phi] (meshgrid above uses indexing='ij').
kappa = 5
grid = Grid_3d(type='spherical', ths=ths, phs=phs)

f_vmf = vMF_density(centers, w_vmf, kappa, grid) * np.sin(ths)
f_vmf = f_vmf.reshape((ngrid, ngrid))



def my_fun(th,ph):
    """Return 4*pi times the vMF mixture density at one direction (th, ph).

    The mixture parameters `centers`, `w_vmf` and `kappa` are read from the
    notebook's global namespace. A one-point spherical grid is built for the
    given angles and the single density value is extracted from the result.
    """
    single_point = Grid_3d(type='spherical', ths=np.array([th]), phs=np.array([ph]))
    density = vMF_density(centers, w_vmf, kappa, single_point)
    return 4*np.pi*density[0]

# Expand my_fun in spherical harmonics up to degree ell_max_half_view and
# convert the result with sph_t_rot_coef; the leading coefficient is then
# pinned to 1.
ell_max_half_view = 4
sph_coef, indices = sph_harm_transform(my_fun, ell_max_half_view)
rot_coef = sph_t_rot_coef(sph_coef, ell_max_half_view)
rot_coef[0] = 1

# Round-trip through the real parameterization: map with sph_c_t_r, keep the
# real part, and map back with sph_r_t_c so rot_coef is consistent with the
# real vector. (Names suggest complex<->real conversion matrices - TODO confirm.)
sph_r_t_c, sph_c_t_r = get_sph_r_t_c_mat(ell_max_half_view)
b_with_leading = np.real(sph_c_t_r @ rot_coef)
rot_coef = sph_r_t_c @ b_with_leading

# The unknown vector b excludes the first entry of the real vector.
b = b_with_leading[1:]




CPU times: user 25.1 ms, sys: 7.41 ms, total: 32.5 ms
Wall time: 18.5 ms


In [29]:
%%time

# get the spherical FB coefficient of the volume

# Read the map, scale it to unit Euclidean norm, and downsample it to
# ds_res voxels per side; Volume.asnumpy() returns a stack, of which the
# first (only) entry is kept.
with mrcfile.open('../data/emd_34948.map') as mrc:
    data = mrc.data

data = data / LA.norm(data.ravel())
ds_res = 64
Vol = Volume(data).downsample(ds_res)
vol = Vol.asnumpy()[0]


# spherical bessel transform
ell_max_vol = 5
vol_coef, k_max, r0, indices_vol = sphFB_transform(vol, ell_max_vol)

# Round-trip through the real parameterization a so that vol_coef is exactly
# the image of a real coefficient vector under sphFB_r_t_c.
sphFB_r_t_c, sphFB_c_t_r = get_sphFB_r_t_c_mat(ell_max_vol, k_max, indices_vol)
a = np.real(sphFB_c_t_r @ vol_coef)
vol_coef = sphFB_r_t_c @ a

CPU times: user 12.9 s, sys: 2.46 s, total: 15.4 s
Wall time: 8.36 s


In [30]:
%%time

# form the moments
r2_max, r3_max = 250, 100
tol2, tol3 = 1e-10, 1e-6

# Evaluation grid: the 2D uniform grid embedded in 3D with z = 0.
grid_2d = get_2d_unif_grid(ds_res, 1/ds_res)
grid = Grid_3d(xs=grid_2d.xs, ys=grid_2d.ys, zs=np.zeros(grid_2d.ys.shape))

opts = {
    'r2_max': r2_max,
    'r3_max': r3_max,
    'tol2': tol2,
    'tol3': tol3,
    'grid': grid,
}

subMoMs = coef_t_subspace_moments(
    vol_coef, ell_max_vol, k_max, r0, indices_vol,
    rot_coef, ell_max_half_view, opts,
)
m1_emp, m2_emp, m3_emp = subMoMs['m1'], subMoMs['m2'], subMoMs['m3']
U2, U3 = subMoMs['U2'], subMoMs['U3']


# Report the shapes of the three moments, one per line.
print(m1_emp.shape, m2_emp.shape, m3_emp.shape, sep='\n')

getting the first moment
getting the second moment
getting the third moment
(121, 1)
(121, 121)
(51, 51, 51)
CPU times: user 21min 38s, sys: 26min 33s, total: 48min 12s
Wall time: 2min 31s


In [31]:
%%time 

# SO(3) quadrature rules and subspaces for the second and third moments,
# keyed by moment name, followed by the precomputation that consumes them.
quadrature_rules = {
    'm2': load_so3_quadrature(2*ell_max_vol, 2*ell_max_half_view),
    'm3': load_so3_quadrature(3*ell_max_vol, 2*ell_max_half_view),
}

subspaces = {'m2': U2, 'm3': U3}

Phi_precomps, Psi_precomps = precomputation(
    ell_max_vol, k_max, r0, indices_vol, ell_max_half_view,
    subspaces, quadrature_rules, grid,
)

CPU times: user 1h 45min 33s, sys: 2h 21min 10s, total: 4h 6min 44s
Wall time: 10min 47s


In [32]:
%%time 

# Stack the ground-truth parameters as x = [a, b].
na, nb = len(a), len(b)
xtrue = np.concatenate([a, b])

# Linear inequality constraint on x: only the b-part (columns na onward)
# is constrained; the columns belonging to a stay zero.
view_constr, rhs, _ = get_linear_ineqn_constraint(ell_max_half_view)
A_constr = np.zeros((len(rhs), na + nb))
A_constr[:, na:] = view_constr

# Starting point: ground truth with a 1e-4-scaled Gaussian perturbation on
# the a-part; the b-part is left unperturbed (b0 is all zeros).
a0 = np.random.normal(0, 1, a.shape)
b0 = np.zeros(b.shape)
x0 = xtrue + 1e-4 * np.concatenate([a0, b0])


CPU times: user 75.8 ms, sys: 39.3 ms, total: 115 ms
Wall time: 5.02 ms


In [49]:
from scipy.io import savemat

# Write the empirical moments, precomputed operators, quadrature weights,
# parameters and constraint data to .mat files in the working directory.
savemat(
    'moms_emp.mat',
    {'m1_emp': m1_emp, 'm2_emp': m2_emp, 'm3_emp': m3_emp},
)
savemat(
    'precomps.mat',
    {
        'Phi_precomps_m2': Phi_precomps['m2'],
        'Phi_precomps_m3': Phi_precomps['m3'],
        'Psi_precomps_m2': Psi_precomps['m2'],
        'Psi_precomps_m3': Psi_precomps['m3'],
    },
)
# Only element [1] of each quadrature rule is exported.
savemat(
    'w_so3.mat',
    {
        'w_so3_m2': quadrature_rules['m2'][1],
        'w_so3_m3': quadrature_rules['m3'][1],
    },
)
savemat(
    'params.mat',
    {'a': a, 'b': b, 'xtrue': xtrue, 'x0': x0},
)
savemat(
    'constraint.mat',
    {'A_constr': A_constr, 'rhs': rhs},
)

In [39]:
# NOTE(review): this whole cell is commented out; the recorded output below
# it comes from an earlier run that ended in a KeyboardInterrupt.
# NOTE(review): `minimize` is called but only `LinearConstraint` is imported
# here - presumably scipy.optimize.minimize; confirm it is in scope before
# re-enabling this cell.
# %%time 
# from scipy.optimize import LinearConstraint

# def moment_LS(x0, quadrature_rules, Phi_precomps, Psi_precomps, m1_emp, m2_emp, m3_emp, A_constr, b_constr, l1=None, l2=None, l3=None):
    
#     if l1 is None:
#         l1 = LA.norm(m1_emp.flatten())**2
#     if l2 is None:
#         l2 = LA.norm(m2_emp.flatten())**2
#     if l3 is None:
#         l3 = LA.norm(m3_emp.flatten())**2
    
#     # linear_constraint = {'type': 'ineq', 'fun': lambda x: b_constr - A_constr @ x}
#     linear_constraint = LinearConstraint(A_constr, -np.inf, b_constr)
#     objective = lambda x: find_cost_grad(x, quadrature_rules, Phi_precomps, Psi_precomps, m1_emp, m2_emp, m3_emp, l1, l2, l3)
#     # options = {'disp':True, 'iprint':3, 'maxiter':200}
#     options = {'disp':True, 'verbose':3, 'maxiter':200,'initial_tr_radius': 1.0}
#     result = minimize(objective, x0, method='trust-constr', jac=True, constraints=[linear_constraint], options=options)

#     return result 


# res = moment_LS(x0, quadrature_rules, Phi_precomps, Psi_precomps, m1_emp, m2_emp, m3_emp, A_constr, rhs, l1=None, l2=None, l3=0)

| niter |f evals|CG iter|  obj func   |tr radius |   opt    |  c viol  | penalty  |barrier param|CG stop|
|-------|-------|-------|-------------|----------|----------|----------|----------|-------------|-------|
|   1   |   1   |   0   | +1.0251e-02 | 1.00e+00 | 2.59e+01 | 0.00e+00 | 1.00e+00 |  1.00e-01   |   0   |
|   2   |   2   |   1   | +1.0251e-02 | 1.00e-01 | 2.59e+01 | 0.00e+00 | 1.00e+00 |  1.00e-01   |   2   |
|   3   |   3   |   4   | +7.5155e+01 | 7.00e-01 | 2.14e+03 | 0.00e+00 | 5.90e+03 |  1.00e-01   |   2   |
|   4   |   4   |   7   | +9.4188e+02 | 4.90e+00 | 7.41e+03 | 0.00e+00 | 1.77e+04 |  1.00e-01   |   2   |
|   5   |   5   |  31   | +2.3478e+00 | 2.88e+01 | 4.45e+02 | 0.00e+00 | 1.77e+04 |  1.00e-01   |   4   |
|   6   |   6   |  54   | +1.2345e+00 | 2.88e+01 | 3.00e+02 | 0.00e+00 | 1.77e+04 |  1.00e-01   |   4   |
|   7   |   7   |  77   | +1.5743e-01 | 2.88e+01 | 6.97e+01 | 0.00e+00 | 1.77e+04 |  1.00e-01   |   4   |
|   8   |   8   |  95   | +1.1686e-01 | 2.88e+

KeyboardInterrupt: 