Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
move the binary search in a separate function
  • Loading branch information
segasai committed Mar 3, 2023
1 parent 8f6514b commit 2bdf2b7
Showing 1 changed file with 37 additions and 33 deletions.
70 changes: 37 additions & 33 deletions py/minimint/mist_interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,40 @@ def prepare(eep_prefix,
bolom.prepare(bolom_prefix, outp_prefix, filters)


def _binary_search(mass, logage, neep, FF):
# This will be our working subset
xind = np.arange(len(mass))
# This will be updated mask of bad points outside limits
bads = np.zeros(len(mass), dtype=bool)
# these will be left/right of the binary search
lefts = np.zeros(len(mass), dtype=int)
rights = np.zeros(len(mass), dtype=int) + neep - 1
curlefts = lefts
currights = rights

# binary search
while len(xind) > 0:
LV, RV = [FF(_, xind) for _ in [curlefts, currights]]
LA = logage[xind]
props = (curlefts + currights) // 2
MV = FF(props, xind)
curbad = (LA < LV) | (LA >= RV) # we'll exclude them
bads[xind[curbad]] = True
x1 = LA >= MV
x2 = LA < MV
curlefts[x1] = props[x1]
currights[x2] = props[x2]
currights[(~x1) & (~x2)] = props[(~x1) & (~x2)]
# we stop for either right-left==1 or for bads
exclude = (currights == curlefts + 1) | curbad
lefts[xind[exclude]] = curlefts[exclude]
rights[xind[exclude]] = currights[exclude]
xind = xind[~exclude]
curlefts = curlefts[~exclude]
currights = currights[~exclude]
return lefts, rights, bads


class TheoryInterpolator:

def __init__(self, prefix=None):
Expand Down Expand Up @@ -380,7 +414,7 @@ def getMaxMass(self, logage, feh):
Parameters:
-----------
logage: float
Log10 of age
Log10 of age
feh: float
Metallicity
Expand Down Expand Up @@ -457,44 +491,14 @@ def __get_eep_coeffs(self, mass, logage, feh):
C3 = x * (1 - y)
C4 = x * y

# This will be our working subset
xind = np.arange(len(mass))
# This will be updated mask of bad points outside limits
bads = np.zeros(len(mass), dtype=bool)
# these will be left/right of the binary search
lefts = np.zeros(len(mass), dtype=int)
rights = np.zeros(len(mass), dtype=int) + self.neep - 1
curlefts = lefts
currights = rights

def FF(curi):
def FF(curi, xind):
return (
C1[xind] * self.logage_grid[l1feh[xind], l1mass[xind], curi] +
C2[xind] * self.logage_grid[l1feh[xind], l2mass[xind], curi] +
C3[xind] * self.logage_grid[l2feh[xind], l1mass[xind], curi] +
C4[xind] * self.logage_grid[l2feh[xind], l2mass[xind], curi])

# binary search
while len(xind) > 0:
LV, RV = [FF(_) for _ in [curlefts, currights]]
LA = logage[xind]
props = (curlefts + currights) // 2
MV = FF(props)
curbad = (LA < LV) | (LA >= RV) # we'll exclude them
bads[xind[curbad]] = True
x1 = LA >= MV
x2 = LA < MV
curlefts[x1] = props[x1]
currights[x2] = props[x2]
currights[(~x1) & (~x2)] = props[(~x1) & (~x2)]
# we stop for either right-left==1 or for bads
exclude = (currights == curlefts + 1) | curbad
lefts[xind[exclude]] = curlefts[exclude]
rights[xind[exclude]] = currights[exclude]
xind = xind[~exclude]
curlefts = curlefts[~exclude]
currights = currights[~exclude]

lefts, rights, bads = _binary_search(mass, logage, self.neep, FF)
bads = bads | (rights >= self.neep)
lefts[bads] = 0
rights[bads] = 1
Expand Down

0 comments on commit 2bdf2b7

Please sign in to comment.