In [1]:
import sys
from pathlib import Path


# Add the 'src' directory to the Python path
src_path = Path('../src').resolve()
sys.path.append(str(src_path))

from aspire.volume import Volume
import mrcfile
import numpy as np 
import numpy.linalg as LA 
from scipy.io import savemat
from viewing_direction import *
from utils import *
from aspire.basis.basis_utils import lgwt
from volume import *
from moments import * 
import matplotlib.pyplot as plt


import jax
import jax.numpy as jnp
from jax import grad, jit 
from jax.numpy.linalg import norm

In [2]:
%%time

# get ground truth a and b 

c = 10
centers = np.random.normal(0,1,size=(c,3))
centers /= LA.norm(centers, axis=1, keepdims=True)
w_vmf = np.random.uniform(0,1,c)
w_vmf = w_vmf/np.sum(w_vmf)


ngrid = 50 
_ths = np.pi*np.arange(ngrid)/ngrid
_phs = 2*np.pi*np.arange(ngrid)/ngrid

ths, phs = np.meshgrid(_ths,_phs,indexing='ij')
ths, phs = ths.flatten(), phs.flatten()

grid = Grid_3d(type='spherical', ths=ths, phs=phs)


kappa = 5
f_vmf = vMF_density(centers,w_vmf,kappa,grid)
f_vmf = f_vmf*np.sin(ths)
f_vmf = f_vmf.reshape((ngrid,ngrid))



def my_fun(th,ph):
    grid = Grid_3d(type='spherical', ths=np.array([th]),phs=np.array([ph]))
    return 4*np.pi*vMF_density(centers,w_vmf,kappa,grid)[0]

ell_max_half_view = 4
sph_coef, indices = sph_harm_transform(my_fun, ell_max_half_view)
rot_coef = sph_t_rot_coef(sph_coef, ell_max_half_view)
rot_coef[0] = 1
sph_r_t_c , sph_c_t_r =  get_sph_r_t_c_mat(ell_max_half_view)
b = np.real(sph_c_t_r @ rot_coef)
rot_coef = sph_r_t_c @ b
b = b[1:]




CPU times: user 1.71 s, sys: 558 ms, total: 2.26 s
Wall time: 114 ms


In [3]:
%%time

# get the spherical FB coefficient of the volume
with mrcfile.open('../data/emd_34948.map') as mrc:
    data = mrc.data


data = data/LA.norm(data.flatten())
Vol = Volume(data)
ds_res = 64 
Vol = Vol.downsample(ds_res)
vol = Vol.asnumpy()
vol = vol[0]
savemat('vol.mat',{'vol':vol})


ell_max_vol = 5
# spherical bessel transform 
vol_coef, k_max, r0, indices_vol = sphFB_transform(vol, ell_max_vol)
sphFB_r_t_c, sphFB_c_t_r = get_sphFB_r_t_c_mat(ell_max_vol, k_max, indices_vol)
a = np.real(sphFB_c_t_r @ vol_coef)
vol_coef = sphFB_r_t_c @ a 

CPU times: user 15.7 s, sys: 2.2 s, total: 17.9 s
Wall time: 7.43 s


In [4]:
%%time

# form the moments 
r2_max = 250 
r3_max = 100 
tol2 = 1e-10
tol3 = 1e-6 
grid = get_2d_unif_grid(ds_res,1/ds_res)
grid = Grid_3d(xs=grid.xs, ys=grid.ys, zs=np.zeros(grid.ys.shape))

opts = {}
opts['r2_max'] = r2_max
opts['r3_max'] = r3_max
opts['tol2'] = tol2 
opts['tol3'] = tol3 
opts['grid'] = grid

subMoMs = coef_t_subspace_moments(vol_coef, ell_max_vol, k_max, r0, indices_vol, rot_coef, ell_max_half_view, opts)
m1_emp = subMoMs['m1']
m2_emp = subMoMs['m2']
m3_emp = subMoMs['m3']
U2 = subMoMs['U2']
U3 = subMoMs['U3']


print(m1_emp.shape)
print(m2_emp.shape)
print(m3_emp.shape)

getting the first moment
getting the second moment
getting the third moment
(122, 1)
(122, 122)
(53, 53, 53)
CPU times: user 22min 47s, sys: 28min 30s, total: 51min 18s
Wall time: 2min 39s


In [5]:
%%time 

quadrature_rules = {} 
quadrature_rules['m2'] = load_so3_quadrature(2*ell_max_vol, 2*ell_max_half_view)
quadrature_rules['m3'] = load_so3_quadrature(3*ell_max_vol, 2*ell_max_half_view)

subspaces = {}
subspaces['m2'] = U2 
subspaces['m3'] = U3 

Phi_precomps, Psi_precomps = precomputation(ell_max_vol, k_max, r0, indices_vol, ell_max_half_view, subspaces, quadrature_rules, grid)

2025-02-11 05:08:09,012 INFO [jax._src.xla_bridge] Unable to initialize backend 'cuda': 
2025-02-11 05:08:09,014 INFO [jax._src.xla_bridge] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-02-11 05:08:09,027 INFO [jax._src.xla_bridge] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
CPU times: user 1h 55min 44s, sys: 2h 39min 20s, total: 4h 35min 4s
Wall time: 12min


In [6]:
%%time 

xtrue =  np.concatenate([a,b])
na = len(a)
nb = len(b)
view_constr, rhs, _ = get_linear_ineqn_constraint(ell_max_half_view)
A_constr = np.zeros([len(rhs), len(xtrue)])
A_constr[:,na:] = view_constr 

a0 = np.random.normal(0,1,a.shape)
b0 = np.zeros(b.shape)


x0 = xtrue+1e-4*np.concatenate([a0,b0])


CPU times: user 29.1 ms, sys: 39.4 ms, total: 68.4 ms
Wall time: 2.93 ms


In [7]:
# Export the moments, the Phi/Psi precomputations, the second element of each
# SO(3) quadrature rule (saved as w_so3_*), the parameters and the constraint
# to .mat files in the working directory.
savemat('moms_emp.mat', {'m1_emp': m1_emp, 'm2_emp': m2_emp, 'm3_emp': m3_emp})
savemat('precomps.mat', {
    'Phi_precomps_m2': Phi_precomps['m2'],
    'Phi_precomps_m3': Phi_precomps['m3'],
    'Psi_precomps_m2': Psi_precomps['m2'],
    'Psi_precomps_m3': Psi_precomps['m3'],
})
savemat('w_so3.mat', {'w_so3_m2': quadrature_rules['m2'][1], 'w_so3_m3': quadrature_rules['m3'][1]})
savemat('params.mat', {'a': a, 'b': b, 'xtrue': xtrue, 'x0': x0})
savemat('constraint.mat', {'A_constr': A_constr, 'rhs': rhs})

In [None]:
# Sanity check: evaluate the cost and its gradient at the ground truth xtrue.
# l1, l2, l3 are the squared norms of the three moments, forwarded to find_cost_grad.
l1, l2, l3 = (LA.norm(mom.flatten())**2 for mom in (m1_emp, m2_emp, m3_emp))


def objective(x):
    """Return the (cost, gradient) pair from find_cost_grad at x.

    Uses the global moments, quadrature rules, precomputations and l1/l2/l3.
    """
    return find_cost_grad(x, quadrature_rules, Phi_precomps, Psi_precomps, m1_emp, m2_emp, m3_emp, l1, l2, l3)


f, g = objective(xtrue)
print(f, LA.norm(g))

In [8]:
# Constrained moment least-squares fit. Disabled by default because it is
# expensive: set RUN_MOMENT_LS = True to run it.
RUN_MOMENT_LS = False


def moment_LS(x0, quadrature_rules, Phi_precomps, Psi_precomps, m1_emp, m2_emp, m3_emp, A_constr, b_constr, l1=None, l2=None, l3=None):
    """Fit the parameter vector by minimizing the moment cost under linear constraints.

    Minimizes ``find_cost_grad`` (which returns the cost and its gradient) with
    ``scipy.optimize.minimize(method='trust-constr')`` starting from ``x0``,
    subject to ``A_constr @ x <= b_constr``.

    Parameters
    ----------
    x0 : ndarray
        Initial parameter vector.
    quadrature_rules, Phi_precomps, Psi_precomps
        Quadrature rules and precomputations, forwarded to ``find_cost_grad``.
    m1_emp, m2_emp, m3_emp : ndarray
        Target moments.
    A_constr, b_constr : ndarray
        Linear inequality constraint ``A_constr @ x <= b_constr``.
    l1, l2, l3 : float, optional
        Values forwarded to ``find_cost_grad``; each defaults to the squared
        norm of the corresponding moment.

    Returns
    -------
    scipy.optimize.OptimizeResult
    """
    # Imported here: `minimize` was never imported anywhere in the notebook.
    from scipy.optimize import LinearConstraint, minimize

    if l1 is None:
        l1 = LA.norm(m1_emp.flatten())**2
    if l2 is None:
        l2 = LA.norm(m2_emp.flatten())**2
    if l3 is None:
        l3 = LA.norm(m3_emp.flatten())**2

    # one-sided constraint: -inf <= A_constr @ x <= b_constr
    linear_constraint = LinearConstraint(A_constr, -np.inf, b_constr)

    def objective(x):
        # returns (cost, gradient), hence jac=True below
        return find_cost_grad(x, quadrature_rules, Phi_precomps, Psi_precomps, m1_emp, m2_emp, m3_emp, l1, l2, l3)

    options = {'disp': True, 'verbose': 3, 'maxiter': 200, 'initial_tr_radius': 1.0}
    return minimize(objective, x0, method='trust-constr', jac=True, constraints=[linear_constraint], options=options)


if RUN_MOMENT_LS:
    res = moment_LS(x0, quadrature_rules, Phi_precomps, Psi_precomps, m1_emp, m2_emp, m3_emp, A_constr, rhs, l1=None, l2=None, l3=0)