In [1]:
from wigner_functions import *
import h5py
from scipy.sparse import *
import zarr

In [2]:
import time

In [3]:
# Multipole limits: l1, l2 run over range(lmax) and the window multipole j3 over range(wlmax).
# Plain `int` replaces `np.int`, a deprecated alias of `int` removed in NumPy 1.24.
lmax=int(5000) #1e4
wlmax=int(5e2)

In [4]:
# Run-time knobs and the multipole grids used throughout the notebook.
ncpu = 8                 # dask worker threads / processes for Wigner3j_parallel
l_step = 10              # block size of the h5py write loops; not used by the dask path
l = np.arange(lmax)      # l1, l2 grid: 0 .. lmax-1
w_l = np.arange(wlmax)   # window (j3) grid: 0 .. wlmax-1

In [5]:
import dask
import dask.array as da
from dask import delayed

from distributed import LocalCluster
from dask.distributed import Client  # may also be in scope via the wigner_functions star import

# A single in-process worker (processes=False) running `ncpu` threads.
# Worker options: http://distributed.readthedocs.io/en/latest/_modules/distributed/worker.html
worker_options = {
    'memory_limit': '50gb',
    'threads_per_worker': ncpu,
    'memory_spill_fraction': .99,
    'memory_monitor_interval': '2000ms',
}
LC = LocalCluster(n_workers=1, processes=False, **worker_options)
client = Client(LC)

In [6]:
def wigner_3j_asym(j_1,j_2,j_3,m_1,m_2,m_3): #assume j1,j2>>j3.... wiki
    """Approximate Wigner 3j symbol for j_1, j_2 >> j_3 via a Wigner small-d function.

    Computes (-1)**(j_2+m_2) * wd / sqrt(j_1+j_2+1), where
    wd = wigner_d(m_3, j_2-j_1, theta, j_3) and cos(theta) = (m_1-m_2)/(j_1+j_2+1).
    `wigner_d` comes from `wigner_functions` (argument convention assumed to be
    (m, m', theta, l) -- TODO confirm).
    """
    sj=(j_1+j_2+1)
    th=np.arccos((m_1-m_2)/sj)
    # [0,0]: a single angle and a single j_3 were requested.
    wd=wigner_d(m_3,j_2-j_1,np.atleast_1d(th),j_3)[0,0]
    return ((-1)**(j_2+m_2))*wd/np.sqrt(sj)

In [7]:
def wigner_3j_asym_H(j_1,j_2,j_3,m_1,m_2,m_3): #j1+j2+j3>>1, m1=m2=m3=0.. Hivon
    """Asymptotic Wigner 3j symbol (j_1 j_2 j_3; 0 0 0) for j_1+j_2+j_3 >> 1 (Hivon).

    With A the area of the triangle with sides j_1, j_2, j_3 (Heron:
    16*A**2 = 2(j_1²j_2² + j_2²j_3² + j_1²j_3²) - (j_1⁴ + j_2⁴ + j_3⁴)), returns
    (-1)**(J/2) * sqrt(1 / (2*pi*A)) with J = j_1 + j_2 + j_3.

    m_1, m_2, m_3 are accepted for signature compatibility only (all assumed 0).
    Meaningful only for even J and non-degenerate triangles (A > 0): for odd J the
    sign factor is not real, and A == 0 divides by zero.
    """
    # Use the function's own arguments; the previous body read undefined names j1, j2, j3.
    wj=2*(j_1*j_2)**2
    wj+=2*(j_2*j_3)**2
    wj+=2*(j_1*j_3)**2
    wj-=j_1**4+j_2**4+j_3**4   # wj == 16*A**2
    wj=wj**(-0.5)              # 1/(4A)
    wj*=2./np.pi               # 1/(2*pi*A)
    return np.sqrt(wj)*(-1)**((j_1+j_2+j_3)/2)

def wigner_3j_asym_H2(j_1,j_2,j_3,m_1,m_2,m_3): #j1+j2+j3>>1, m1=m2=m3=0.. Hivon
    """Wigner 3j symbol (j_1 j_2 j_3; 0 0 0) from its factorial closed form, in log space.

    With J = j_1 + j_2 + j_3 and h = J/2, computes
        log W = log h! - sum_i log (h - j_i)!
                + 0.5 * (sum_i log (J - 2 j_i)! - log (J+1)!)
    and returns (-1)**h * exp(log W).  `log_factorial` comes from `wigner_functions`
    (assumed to return log(n!) and to accept the float arguments produced by J/2).
    m_1, m_2, m_3 are not used.  For odd J the sign (-1)**(J/2) is not real.
    """
    J = j_1 + j_2 + j_3
    half = J / 2
    log_w = log_factorial(half)
    for j in (j_1, j_2, j_3):
        log_w -= log_factorial(half - j)
    log_w -= 0.5 * log_factorial(J + 1)
    for j in (j_1, j_2, j_3):
        log_w += 0.5 * log_factorial(J - 2 * j)
    return (-1) ** half * np.exp(log_w)

In [8]:
def wig3j_map(m1,m2,m3,j1,j2,j3,asym_fact=np.inf):
    """Wigner 3j values on the (j1, j2) grid for the first j3 value only.

    Single-process wrapper (ncpu=1) around `Wigner3j_parallel` from `wigner_functions`,
    intended to be wrapped in dask `delayed` tasks. Scalars are promoted to 1-D arrays.
    Assumes Wigner3j_parallel returns an (n1, n2, n3) array (as Wigner3j_parallel2 in
    this notebook does); the j3 axis is reduced to its first entry.
    """
    j1_arr, j2_arr, j3_arr = (np.atleast_1d(v) for v in (j1, j2, j3))
    cube = Wigner3j_parallel(m1, m2, m3, j1_arr, j2_arr, j3_arr, ncpu=1,
                             asym_fact=asym_fact)
    return cube[:, :, 0]


# dst=client.map(wig3j_map,w_l)

In [9]:
def Wigner3j_parallel2( m_1, m_2, m_3,j_1, j_2, j_3,ncpu=None,asym_fact=np.inf):
    """Wigner 3j symbols on the full (j_1, j_2, j_3) grid, evaluated with a process pool.

    Only grid points that pass the triangle inequality, |m_i| <= j_i and (when all m
    are zero) even j_1+j_2+j_3 are evaluated; every other entry stays 0.

    Args:
        m_1, m_2, m_3: m values shared by every grid point.
        j_1, j_2, j_3: 1-D arrays of j values. They must be sorted ascending, because
            grid positions are recovered with np.searchsorted.
        ncpu: pool size; defaults to cpu_count()-2, but never less than 1.
        asym_fact: forwarded as the first argument of `wigner_3j_3` (wigner_functions).

    Returns:
        float64 array of shape (len(j_1), len(j_2), len(j_3)).
    """
    if ncpu is None:
        # Leave two cores free, but on small machines still use one process.
        ncpu=max(1,cpu_count()-2)

    n1,n2,n3=len(j_1),len(j_2),len(j_3)

    # Every (j_1, j_2, j_3) combination as a row. The row order is irrelevant because
    # positions are recovered with searchsorted below; meshgrid (rather than Comb) is
    # used so the selection-rule cuts can be applied vectorised.
    c=np.array(np.meshgrid(j_1,j_2,j_3,indexing='ij')).T.reshape(-1,3)
    ja,jb,jc=c[:,0],c[:,1],c[:,2]

    # Triangle inequality and |m_i| <= j_i.
    keep=(ja+jb-jc>=0) & (ja-jb+jc>=0) & (-ja+jb+jc>=0)
    keep&=(abs(m_1)<=ja) & (abs(m_2)<=jb) & (abs(m_3)<=jc)
    # With all m zero the symbol vanishes whenever j_1+j_2+j_3 is odd.
    if m_1==0 and m_2==0 and m_3==0:
        keep&=(ja+jb+jc)%2==0
    c=c[keep]

    p=Pool(ncpu)
    try:
        d_mat=p.map(partial(wigner_3j_3,asym_fact, m_1, m_2, m_3),c,chunksize=100)
    finally:
        # map() blocks until every result is back (or re-raises a worker error), so
        # terminating here never discards work and always frees the worker processes.
        p.terminate()
        p.join()

    dd=np.zeros((n1,n2,n3))
    dd[np.searchsorted(j_1,c[:,0]),
       np.searchsorted(j_2,c[:,1]),
       np.searchsorted(j_3,c[:,2])]=d_mat
    return dd

In [10]:
# Forwarded to wig3j_map (whose default is np.inf) and appended to the output filename.
# Presumably the threshold for switching to an asymptotic 3j approximation -- TODO confirm
# in wigner_functions.
asym_fact=100

In [11]:
# if lmax>500:
#     fname='temp/dask_wig3j_big_{i}.zarr'
# else:
#     fname='temp/dask_wig3j_test_{i}.zarr'
# Output path template; {lmax}, {wlmax} and {i} are filled in with str.format() at write time.
# A finite asym_fact gets its own suffix so approximated and reference runs do not collide.
# `!=` rather than `is not`: an identity test only matches the np.inf object itself,
# so e.g. float('inf') would wrongly produce an "_asyminf" filename.
fname='temp/dask_wig3j_l{lmax}_w{wlmax}_{i}.zarr'
if asym_fact != np.inf:
    fname='temp/dask_wig3j_l{lmax}_w{wlmax}_{i}_asym'+str(asym_fact)+'.zarr'

In [12]:
# Show the output filename template and the multipole limits used for this run.
fname,lmax,wlmax

('temp/dask_wig3j_l{lmax}_w{wlmax}_{i}_asym100.zarr', 5000, 500)

In [13]:
# Bin edges for splitting the l1/l2 axes into dask blocks: 0, 100, then every `step` up to lmax.
# `int` replaces the removed `np.int`; `step` is at least 1 so np.arange never gets a zero
# step for tiny lmax; edges beyond lmax (the fixed 100 when lmax < 100) are dropped so each
# block's declared shape matches its actual slice of `l`. np.unique already sorts.
step=max(1,int(min(5e3,lmax/10)))
lb=np.unique(np.append([0,100,lmax],np.arange(100,lmax,step)))
lb=lb[lb<=lmax]
lb,lb.shape

(array([   0,  100,  600, 1100, 1600, 2100, 2600, 3100, 3600, 4100, 4600,
        5000]), (12,))

In [None]:
m1=-2
m2=2
m3=0

arrs=[da.hstack([da.vstack([da.from_delayed(delayed(wig3j_map)(m1,m2,m3,l[lb[i]:lb[i+1]],l[lb[k]:lb[k+1]],
                                                               np.atleast_1d(j3),asym_fact),
                    shape=(lb[i+1]-lb[i],lb[k+1]-lb[k]),dtype='float32') 
                    for i in np.arange(len(lb)-1)]) 
                     for k in np.arange(len(lb)-1)])
                      for j3 in w_l]
arrs2=da.stack(arrs)
arrs2=arrs2.rechunk(chunks=(1,1000,1000))
%time arrs2.to_zarr(fname.format(i=2,lmax=lmax,wlmax=wlmax),overwrite=True)

tornado.application - ERROR - Exception in callback functools.partial(<function wrap.<locals>.null_wrapper at 0x7fe667444840>, <Future finished exception=OSError("Timed out trying to connect to 'inproc://136.152.250.183/12571/1' after 10 s: connect() didn't finish in time")>)
Traceback (most recent call last):
  File "/usr/lib/python3.7/site-packages/distributed/comm/core.py", line 186, in connect
    quiet_exceptions=EnvironmentError)
  File "/usr/lib/python3.7/site-packages/tornado/gen.py", line 1133, in run
    value = future.result()
tornado.util.TimeoutError: Timeout

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3.7/site-packages/tornado/ioloop.py", line 758, in _run_callback
    ret = callback()
  File "/usr/lib/python3.7/site-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/lib/python3.7/site-packages/tornado/ioloop.py", line 779, in _discard

tornado.application - ERROR - Exception in callback functools.partial(<function wrap.<locals>.null_wrapper at 0x7fe6672132f0>, <Future finished exception=OSError("Timed out trying to connect to 'inproc://136.152.250.183/12571/1' after 10 s: connect() didn't finish in time")>)
Traceback (most recent call last):
  File "/usr/lib/python3.7/site-packages/distributed/comm/core.py", line 186, in connect
    quiet_exceptions=EnvironmentError)
  File "/usr/lib/python3.7/site-packages/tornado/gen.py", line 1133, in run
    value = future.result()
tornado.util.TimeoutError: Timeout

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3.7/site-packages/tornado/ioloop.py", line 758, in _run_callback
    ret = callback()
  File "/usr/lib/python3.7/site-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/lib/python3.7/site-packages/tornado/ioloop.py", line 779, in _discard

  wj=(-1)**(J/2)*np.exp(logwj)


[436] 1099 599 done wig time,size:  9.9269118309021 1998 [635 199 436]




In [None]:
# Display the stacked dask array built in the previous cell.
arrs2

In [None]:
# Presumably a comparison of an asym_fact=50 run against the run without the asym suffix
# (lmax=100, wlmax=30). z1 stays a zarr array; z2 is loaded into NumPy.
asym_store = './temp/dask_wig3j_l100_w30_0_asym50.zarr/'
ref_store = './temp/dask_wig3j_l100_w30_0.zarr/'
z1 = zarr.open(asym_store)
z2 = np.array(zarr.open(ref_store))

In [None]:
# zarr arrays do not provide .reshape; load into NumPy first (as done for z2 above).
np.asarray(z1).reshape(1,3000,100)

In [None]:
# Deliberate stop so "Run All" halts here; the previous undefined name `crash` did the
# same via a NameError, which reads like an accident.
raise RuntimeError("Deliberate stop: run the cells below manually if needed.")

In [None]:
# Open a previously written zarr test store.
zaa_in=zarr.open('temp/dask_test0.zarr')

In [None]:
# Weighted sum over axis 0 of zaa_in with weights w_l; einsum converts the zarr array
# to an in-memory NumPy array first.
np.einsum('i,ijk->jk',w_l,zaa_in)

In [None]:
# Orthogonal (outer) indexing: the first 5 entries of w_l on axis 0, [1, 3] on axis 1, [1, 5] on axis 2.
zaa_in.oindex[np.int32(w_l[:5]), [1, 3],[1,5]]

In [None]:
# Display the zarr array.
zaa_in

In [None]:
# Wrap the zarr store as a dask array, reusing its on-disk chunking.
aa2 = da.from_array(zaa_in, chunks=zaa_in.chunks)

In [None]:
# Display the dask wrapper.
aa2

In [None]:
m_1=0
m_2=0
m_3=0
fname='temp/wigner_test_big.h5'

# Fill dataset "0" (m_1=m_2=m_3=0), shape (lmax, lmax, wlmax), in l_step x l_step blocks.
# Mode 'w' truncates any existing file.
with h5py.File(fname,'w') as f:
    dst = f.create_dataset("0", shape=(lmax,lmax,wlmax),
                           dtype=np.float32)
    for lm in range(0,lmax,l_step):
        # Clip at lmax so the last block still matches the dataset slice shape.
        l_t=np.arange(lm,min(lm+l_step,lmax))
        for lm2 in range(0,lmax,l_step):
            # Block pairs this far apart cannot satisfy |l1-l2| <= j3 < wlmax; leave them
            # unwritten (previously each skip was printed, flooding the output).
            if lm2>lm+l_step+wlmax or lm>lm2+l_step+wlmax:
                continue
            l_t2=np.arange(lm2,min(lm2+l_step,lmax))
            dst[lm:lm+l_step,lm2:lm2+l_step,:]=Wigner3j_parallel( m_1, m_2, m_3,l_t, l_t2 , w_l,ncpu=ncpu)

In [None]:
# Redundant: the `with` block above already closed the file (closing again is presumably a no-op).
f.close()

In [None]:
m_1=2
m_2=-2
m_3=0

# Add dataset "2" (m_1=2, m_2=-2, m_3=0) to the file named by `fname`, which the m=0 cell
# sets to 'temp/wigner_test_big.h5'. require_dataset lets a re-run reuse the existing
# dataset instead of failing because "2" already exists.
with h5py.File(fname,'a') as f:
    dst = f.require_dataset("2", shape=(lmax,lmax,wlmax),
                            dtype=np.float32)
    for lm in range(0,lmax,l_step):
        # Clip at lmax so the last block still matches the dataset slice shape.
        l_t=np.arange(lm,min(lm+l_step,lmax))
        for lm2 in range(0,lmax,l_step):
            # Block pairs this far apart cannot satisfy |l1-l2| <= j3 < wlmax; skip them
            # silently (previously each skip was printed).
            if lm2>lm+l_step+wlmax or lm>lm2+l_step+wlmax:
                continue
            l_t2=np.arange(lm2,min(lm2+l_step,lmax))
            dst[lm:lm+l_step,lm2:lm2+l_step,:]=Wigner3j_parallel( m_1, m_2, m_3,l_t, l_t2 , w_l,ncpu=ncpu)

In [None]:
# Redundant: the `with` block above already closed the file (closing again is presumably a no-op).
f.close()

In [None]:
# Read back from a different test file than the one written by the cells above.
fname='temp/wigner_test.h5'

In [None]:
# Opened read-only without `with`; closed explicitly by a later `f.close()` cell.
f = h5py.File(fname, 'r')


In [None]:
# Dataset "0"; presumably the m_1=m_2=m_3=0 layout (l1, l2, j3) used by the write cells -- confirm for this file.
d=f["0"]

In [None]:
# Shape of the dataset on disk.
d.shape

In [None]:
# Unit weights, one per window multipole in w_l (same shape and dtype as w_l).
w = np.ones(w_l.shape, dtype=w_l.dtype)

In [None]:
# One-shot contraction of the last axis of d with w (reads the whole dataset into memory).
# Result shape follows d.shape[:2] instead of a hard-coded (500, 500).
M=np.dot(d,w)
M2=np.zeros_like(M)    # filled block by block in a later cell for comparison

In [None]:
x=d[w_l[3:5],:,:]# rows 3 and 4 of axis 0 (values of w_l[3:5] used as indices) -- NOTE(review): axis 0 is l1 in the write cells' layout; confirm intent

In [None]:
# x is already in memory, so fancy indexing with the first 5 w_l values on axis 1 is plain NumPy.
x[:,np.int32(w_l[:5]),:]

In [None]:
# Same contraction as M, done l_step rows at a time to bound memory. Iterate over the
# dataset's own first axis: the global lmax (5000 here) need not match this file's extent.
for lm in range(0, d.shape[0], l_step):
    M2[lm:lm+l_step,:]=np.dot(d[lm:lm+l_step,:,:],w)

In [None]:
# Exact-equality check of the chunked vs one-shot contraction; np.allclose would tolerate rounding differences.
np.all(M2==M)

In [None]:
# Release the read-only handle opened above.
f.close()

In [None]:
# Scratch: an empty 100x100 sparse matrix (coo_matrix comes from the scipy.sparse star import).
coo_matrix((100,100))