In [None]:
import halotools
import numpy as np
from astropy.utils.misc import NumpyRNGContext
from math import gamma


In [None]:

from halotools.mock_observables.two_point_clustering.clustering_helpers import process_optional_input_sample2

from halotools.mock_observables.mock_observables_helpers import (enforce_sample_has_correct_shape,
    get_separation_bins_array, get_period, get_num_threads)
from halotools.mock_observables.pair_counters.mesh_helpers import _enforce_maximum_search_length

from halotools.mock_observables.pair_counters import npairs_3d, marked_npairs_3d

from halotools.custom_exceptions import HalotoolsError

from halotools.mock_observables.two_point_clustering.tpcf_estimators import _TP_estimator
from halotools.mock_observables.two_point_clustering.marked_tpcf import _marked_tpcf_process_args,marked_pair_counts


In [None]:
from halotools.mock_observables.two_point_clustering import tpcf

In [None]:
def nball_volume(R, k=3):
    """
    Volume of a k-dimensional ball (the interior of an n-sphere) of radius R.

    Evaluates V_k(R) = pi**(k/2) / Gamma(k/2 + 1) * R**k, where Gamma is the
    scalar `math.gamma`. Used to build the analytical random pair counts.

    Parameters
    ----------
    R : float or numpy array
        Radius (or radii); an array input yields an array of volumes.
    k : number, optional
        Dimension of the ball. Default is 3 (ordinary sphere).

    Returns
    -------
    Volume(s) with the same shape as `R`.
    """
    half_dim = k/2.0
    # Volume of the unit ball in k dimensions; scaled by R**k below.
    unit_ball_volume = np.pi**half_dim/gamma(half_dim+1.0)
    return unit_ball_volume*R**k
    
def _random_counts(N1,N2,NR, rbins, period, num_threads,
        _sample1_is_sample2,do_RR=True, do_DR=True, approx_cell1_size=None,
        approx_cell2_size=None, approx_cellran_size=None,PBCs=None):
    """
    Analytical data-random and random-random pair counts per radial bin.

    Modified version of the helper of the same name in
    https://halotools.readthedocs.io/en/latest/_modules/halotools/mock_observables/two_point_clustering/tpcf.html#tpcf
    Only the analytical branch is implemented: no random catalogue is used,
    the counts are computed from spherical-shell volumes and mean densities.

    Parameters
    ----------
    N1, N2 : number
        (Possibly weighted) number of points in sample 1 and sample 2.
    NR : number
        Number of random points to assume.
    rbins : numpy array
        Radial bin edges; one count is returned per pair of adjacent edges.
    period : numpy array
        Box side lengths; their product is used as the total volume.
        Must provide a `.prod()` method.
    num_threads, _sample1_is_sample2, do_RR, do_DR, approx_cell1_size,
    approx_cell2_size, approx_cellran_size, PBCs
        Accepted for signature compatibility; not read by this function.

    Returns
    -------
    D1R, D2R, RR : numpy arrays of length len(rbins) - 1
        D1R = NR * N1 * dv / V, D2R = NR * N2 * dv / V, RR = NR**2 * dv / V,
        where dv is the shell volume of each bin and V the box volume.
    """
    # Shell volume of each bin = difference of consecutive sphere volumes.
    shell_volumes = np.diff(nball_volume(rbins))
    box_volume = period.prod()

    # Data-random counts for sample 1: NR * (shell volume * number density).
    density1 = N1/box_volume
    D1R = (NR)*(shell_volumes*density1)

    # Data-random counts for sample 2, same construction.
    density2 = N2/box_volume
    D2R = (NR)*(shell_volumes*density2)

    # Random-random counts: NR**2 pairs spread over the box volume.
    random_pair_density = (NR**2)/box_volume
    RR = (shell_volumes*random_pair_density)

    return D1R, D2R, RR

In [None]:
def weighted_tpcf(sample1, rbins, sample2=None,
        weights1=None, weights2=None, period=None, do_auto=True, do_cross=True,
        num_threads=1, weight_func_id=1,
        normalize_weights=True, seed=None,estimator='Natural'):
    """
    Weighted two-point correlation function with analytical random counts.

    Weighted pair counts are obtained from halotools' `marked_pair_counts`
    and compared against analytical random counts from `_random_counts`,
    using the sum of the weights as the effective number of points in each
    sample. The result is passed through halotools' `_TP_estimator`.

    Parameters
    ----------
    sample1 : array_like
        Positions of the first sample (validated by
        `_marked_tpcf_process_args`).
    rbins : array_like
        Radial bin edges.
    sample2 : array_like, optional
        Positions of a second sample for cross-correlations.
    weights1, weights2 : array_like, optional
        Per-point weights for sample1 / sample2.
    period : optional
        Periodic box size, forwarded to `_marked_tpcf_process_args`. The
        analytical randoms take the product of the processed period as the
        box volume, so a periodic box is required in practice.
    do_auto, do_cross : bool, optional
        Which correlations to compute when two distinct samples are given.
    num_threads : int, optional
        Forwarded to the pair counter.
    weight_func_id : int, optional
        Weighting function id forwarded to `marked_pair_counts`.
        NOTE(review): normalising by the sum of weights presumably only makes
        sense for a multiplicative weighting function (id 1) - confirm
        before using other ids.
    normalize_weights : bool, optional
        If True, supplied weights are rescaled to unit mean and missing
        weights are replaced by ones.
    seed : optional
        Forwarded to `_marked_tpcf_process_args`.
    estimator : str, optional
        Estimator name forwarded to `_TP_estimator`. Default 'Natural'.

    Returns
    -------
    xi_11 if only one distinct sample is in play; otherwise
    (xi_11, xi_12, xi_22), xi_12, or (xi_11, xi_22) depending on
    `do_auto` / `do_cross`. None if neither is requested.
    """
    if normalize_weights:
        if weights1 is not None:
            # asarray so that plain lists/tuples of weights are accepted.
            weights1 = np.asarray(weights1, dtype=float)
            weights1 = weights1/weights1.mean()
        else:
            weights1 = np.ones(len(sample1))
        if weights2 is not None:
            weights2 = np.asarray(weights2, dtype=float)
            weights2 = weights2/weights2.mean()
        elif sample2 is not None:
            weights2 = np.ones(len(sample2))

    # Fixed settings for the marked-tpcf argument processing: the marks are
    # used as given (no shuffling) and a single iteration is requested.
    randomize_marks = False
    normalize_by = 'random_marks'
    iterations = 1
    function_args = (sample1, rbins, sample2, weights1, weights2,
        period, do_auto, do_cross, num_threads,
        weight_func_id, normalize_by, iterations, randomize_marks, seed)

    sample1, rbins, sample2, weights1, weights2, period, do_auto, do_cross, num_threads,\
        weight_func_id, normalize_by, _sample1_is_sample2, PBCs,\
        randomize_marks = _marked_tpcf_process_args(*function_args)

    # Effective (weighted) number of points in each sample.
    N1 = weights1.sum()
    # Decide on the processed flag rather than `sample2 is not None`: the
    # argument processing may have substituted sample1 for a missing sample2
    # (halotools behaviour - see its docs), making the None check unreliable.
    if _sample1_is_sample2:
        N2 = N1*1.
    else:
        N2 = weights2.sum()
    # Use as many analytical randoms as there are (weighted) sample1 points.
    NR = N1*1.

    # calculate weighted pairs
    W1W1, W1W2, W2W2 = marked_pair_counts(sample1, sample2, rbins, period,
        num_threads, do_auto, do_cross, weights1, weights2, weight_func_id, _sample1_is_sample2)

    # Random counts are always computed analytically (no random catalogue).
    D1R, D2R, RR = _random_counts(N1=N1, N2=N2, NR=NR, rbins=rbins, period=period,
        num_threads=num_threads, _sample1_is_sample2=_sample1_is_sample2,
        do_RR=True, do_DR=True)

    print('weighted number of galaxies: ',N1,N2,NR)

    # W1W1 may be None when auto-correlations were not requested, so guard
    # the diagnostic instead of crashing before the estimator runs.
    auto_pairs = W1W1.sum()/1.e6 if W1W1 is not None else None
    print('pair_counts: ',auto_pairs,RR.sum()/1.e6,D1R.sum()/1.e6)

    # run results through the estimator and return relevant/user specified results.
    if _sample1_is_sample2:
        xi_11 = _TP_estimator(W1W1, D1R, RR, N1, N1, NR, NR, estimator)
        return xi_11

    if (do_auto is True) and (do_cross is True):
        xi_11 = _TP_estimator(W1W1, D1R, RR, N1, N1, NR, NR, estimator)
        xi_12 = _TP_estimator(W1W2, D1R, RR, N1, N2, NR, NR, estimator)
        xi_22 = _TP_estimator(W2W2, D2R, RR, N2, N2, NR, NR, estimator)
        return xi_11, xi_12, xi_22
    elif (do_cross is True):
        xi_12 = _TP_estimator(W1W2, D1R, RR, N1, N2, NR, NR, estimator)
        return xi_12
    elif (do_auto is True):
        xi_11 = _TP_estimator(W1W1, D1R, RR, N1, N1, NR, NR, estimator)
        xi_22 = _TP_estimator(W2W2, D2R, RR, N2, N2, NR, NR, estimator)
        return xi_11, xi_22

    # Neither auto nor cross correlations requested: nothing to return.
    return None

In [None]:
# Sizes of the two mock samples and side length of the periodic box.
N_gal1=10000
N_gal2=20000

period=40

# Seeded generator so that the mock catalogues (and every result derived
# from them) are reproducible on a fresh kernel run.
SEED = 42
rng = np.random.default_rng(SEED)

# Uniform random positions inside the box, one (x, y, z) row per point.
sample1 = rng.uniform(0, period, size=(N_gal1, 3))
sample2 = rng.uniform(0, period, size=(N_gal2, 3))

# Uniform random weights in [1, 20) for each point.
weights1 = rng.uniform(1, 20, size=N_gal1)
weights2 = rng.uniform(1, 20, size=N_gal2)

# 20 logarithmically spaced bin edges between 0.1 and 10.
rbins=np.logspace(-1,1,20)

# The default un-weighted calculation for testing
xi0_11, xi0_12, xi0_22=tpcf(sample1, rbins, sample2=sample2,
        period=period, do_auto=True, do_cross=True,
        num_threads=1, seed=None,estimator='Natural') 

In [None]:
# Weighted calculation on the same samples and bins as the un-weighted
# reference, so the two sets of correlation functions can be compared.
xi_11, xi_12, xi_22 = weighted_tpcf(
    sample1=sample1,
    rbins=rbins,
    sample2=sample2,
    weights1=weights1,
    weights2=weights2,
    period=period,
    do_auto=True,
    do_cross=True,
    num_threads=1,
    weight_func_id=1,
    normalize_weights=True,
    seed=None,
    estimator='Natural',
)

In [None]:
# Weighted correlation functions versus bin mid-points.
# NOTE(review): `plot` and `hlines` are not imported in the cells shown;
# presumably they come from a pylab namespace (e.g. `%pylab inline`) - confirm.
rb = 0.5*(rbins[1:]+rbins[:-1])  # arithmetic centre of each radial bin
for xi_weighted in (xi_11, xi_12, xi_22):
    plot(rb, xi_weighted)
# Zero line for reference: uniform random samples should scatter around 0.
hlines(0,1,10,color='k')
# loglog()

In [None]:
# Un-weighted reference correlation functions versus bin mid-points.
# NOTE(review): `plot` and `hlines` are not imported in the cells shown;
# presumably they come from a pylab namespace (e.g. `%pylab inline`) - confirm.
rb = 0.5*(rbins[1:]+rbins[:-1])  # arithmetic centre of each radial bin
for xi_reference in (xi0_11, xi0_12, xi0_22):
    plot(rb, xi_reference)
# Zero line for reference: uniform random samples should scatter around 0.
hlines(0,1,10,color='k')
# loglog()