Skip to content

Commit

Permalink
continue refactoring x
Browse files Browse the repository at this point in the history
  • Loading branch information
segasai committed Mar 5, 2023
1 parent 99a2a37 commit 171ec24
Showing 1 changed file with 30 additions and 23 deletions.
53 changes: 30 additions & 23 deletions py/minimint/mist_interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,10 @@ def prepare(eep_prefix,
bolom.prepare(bolom_prefix, outp_prefix, filters)


def _binary_search(bads, logage, neep, FF):
def _binary_search(bads, logage, neep, getAge):
"""
Peform a binary search on a grid to find pts
such as FF(pt)<logage<FF(pt+1)
such as getAge(pt)<logage<getAge(pt+1)
Returns:
lefts:
Expand Down Expand Up @@ -276,13 +276,13 @@ def _binary_search(bads, logage, neep, FF):
# A we either in the situation of boundaries having
# 2 finite values or
# B one finite on the left and the other one nan
leftY, rightY = [FF(_, curgood) for _ in [leftX, rightX]]
leftY, rightY = [getAge(_, curgood) for _ in [leftX, rightX]]

while True:
targY = logage[curgood]

propX = (leftX + rightX) // 2
propY = FF(propX, curgood)
propY = getAge(propX, curgood)

# It is written in this way to also include nans
x1 = propY <= targY # option 1
Expand All @@ -291,7 +291,7 @@ def _binary_search(bads, logage, neep, FF):
leftX[x1] = propX[x1]
rightX[x2] = propX[x2]
rightX[x3] = propX[x3]
leftY, rightY = [FF(_, curgood) for _ in [leftX, rightX]]
leftY, rightY = [getAge(_, curgood) for _ in [leftX, rightX]]
# we stop for either right-left==1 or for bads
curbad = (targY < leftY) | (targY >= rightY) # we'll exclude them
curbad2 = (rightX == leftX + 1) & np.isnan(rightY) # this is option B
Expand Down Expand Up @@ -404,6 +404,8 @@ def __call__(self, mass, logage, feh):
l1mass_good, l2mass_good, cureep))
xret[curkey] = curr[0] + eep_frac_good * (curr[1] - curr[0])
# perfoming the linear interpolation with age
# the formula is (1-eep_frac) * V_left + eep_frac V_right
# so V_left + eep_frac * (V_right - V_left)

ret = {}
for k in ['logg', 'logteff', 'logl', 'phase']:
Expand Down Expand Up @@ -443,24 +445,25 @@ def getLogAgeFromEEP(self, mass, eep, feh, returnJac=False):
self.umass, l1feh, l2feh,
l1mass, l2mass)

xind = ~bad
goodsel = ~bad

def FF(cureep):
return _interpolator(self.logage_grid, C11[xind], C12[xind],
C21[xind], C22[xind], l1feh[xind],
l2feh[xind], l1mass[xind], l2mass[xind],
cureep)

retage = mass * 0
Fe1 = FF(eep1)
Fe2 = FF(eep2)
retage[xind] = Fe1 * (1 - eep_frac) + (eep_frac) * Fe2
def getAge(cureep):
return _interpolator(self.logage_grid, C11[goodsel], C12[goodsel],
C21[goodsel], C22[goodsel], l1feh[goodsel],
l2feh[goodsel], l1mass[goodsel],
l2mass[goodsel], cureep)

ret_logage = np.zeros_like(mass)
logage1 = getAge(eep1)
logage2 = getAge(eep2)
# these are boundaries in the age grid
ret_logage[goodsel] = logage1 * (1 - eep_frac) + (eep_frac) * logage2
if returnJac:
jac = mass * 0
jac[xind] = Fe2 - Fe1
ret = (retage, jac)
jac[goodsel] = logage2 - logage1
ret = (ret_logage, jac)
else:
ret = retage
ret = ret_logage
return ret

def getMaxMassMS(self, logage, feh):
Expand All @@ -480,6 +483,7 @@ def getMaxMassMS(self, logage, feh):
self.phase_grid[R['l2feh'], R['l1mass'], eep],
self.phase_grid[R['l1feh'], R['l2mass'], eep],
self.phase_grid[R['l2feh'], R['l2mass'], eep])
# this is a max phase among interpolation box vertices
if phase > 0.5 or bad:
i2 = ix
else:
Expand Down Expand Up @@ -559,22 +563,25 @@ def _get_eep_coeffs(self, mass, logage, feh):
self.umass, l1feh, l2feh,
l1mass, l2mass)

def FF(cureep, subset):
def getAge(cureep, subset):
return _interpolator(self.logage_grid, C11[subset], C12[subset],
C21[subset], C22[subset], l1feh[subset],
l2feh[subset], l1mass[subset], l2mass[subset],
cureep)

lefts, rights, bads = _binary_search(bads, logage, self.neep, FF)
lefts, rights, bads = _binary_search(bads, logage, self.neep, getAge)
LV = np.zeros(len(mass))
RV = LV + 1
LV[~bads] = FF(lefts[~bads], ~bads)
RV[~bads] = FF(rights[~bads], ~bads)
LV[~bads] = getAge(lefts[~bads], ~bads)
RV[~bads] = getAge(rights[~bads], ~bads)
eep_frac = (logage - LV) / (RV - LV)
# eep_frac is the coefficient for interpolation in EEP axis
# 0<=eep_frac<1
# eep1 is the position in the EEP axis (essentially floor(EEP))
# 0<=eep1<neep
# eep_frac is zero if the pt is close to left edge, and one if close to
# right edge, so interpolation then needs to be done as
# (1-eep_frac) * V_left + eep_frac * V_right
return dict(C11=C11,
C12=C12,
C21=C21,
Expand Down

0 comments on commit 171ec24

Please sign in to comment.