Skip to content

Commit

Permalink
Fix guess() routines in thermal models module and ensure they are bei…
Browse files Browse the repository at this point in the history
…ng tested
  • Loading branch information
gb119 committed Dec 14, 2019
1 parent df86a64 commit af1b415
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 15 deletions.
7 changes: 6 additions & 1 deletion Stoner/analysis/fitting/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,8 +719,11 @@ def __lmfit_one(self, model, data, params, prefix, columns, scale_covar, **kargs
header = kargs.pop("header", "")
residuals = kargs.pop("residuals", False)
output = kargs.pop("output", "row")
nan_policy = kargs.pop("nan_policy", "raise")
kargs[model.independent_vars[0]] = data[0]
fit = model.fit(data[1], params, scale_covar=scale_covar, weights=1.0 / data[2], **kargs)
fit = model.fit(
data[1], params, scale_covar=scale_covar, weights=1.0 / data[2], nan_policy=nan_policy, **kargs
)
if fit.success:
row = self._record_curve_fit_result(
model,
Expand Down Expand Up @@ -1374,6 +1377,7 @@ def lmfit(self, model, xcol=None, ycol=None, p0=None, sigma=None, **kargs):
data, scale_covar, _ = self._assemnle_data_to_fit(xcol, ycol, sigma, bounds, scale_covar)
model, prefix = _prep_lmfit_model(model, kargs)
p0, single_fit = _prep_lmfit_p0(model, data[1], data[0], p0, kargs)
nan_policy = kargs.pop("nan_policy", getattr(model, "nan_policy", "omit"))

if single_fit:
ret_val = self.__lmfit_one(
Expand All @@ -1388,6 +1392,7 @@ def lmfit(self, model, xcol=None, ycol=None, p0=None, sigma=None, **kargs):
replace=replace,
output=output,
residuals=residuals,
nan_policy=nan_policy,
)
else: # chi^2 mode
pn = p0
Expand Down
24 changes: 19 additions & 5 deletions Stoner/analysis/fitting/models/thermal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import numpy as np
import scipy.constants as consts
from scipy.optimize import curve_fit

try:
from lmfit import Model
Expand Down Expand Up @@ -120,7 +121,11 @@ def vftEquation(x, A, DE, x_0):
:outname: vft
"""
_kb = consts.physical_constants["Boltzmann constant"][0] / consts.physical_constants["elementary charge"][0]
return A * np.exp(-DE / (_kb * (x - x_0)))
X = np.where(np.isclose(x, x_0), 1e-8, x - x_0)
y = A * np.exp(-DE / (_kb * X))
if np.any(np.isnan(y)):
breakpoint()
return y


class Arrhenius(Model):
Expand Down Expand Up @@ -234,7 +239,7 @@ def guess(self, data, x=None, **kwargs):

d1, d2 = 1.0, 0.0
if x is not None:
d1, d2 = np.polyfit(-1.0 / x, np.log(data), 1)
d1, d2 = np.polyfit(-1.0 / x, np.log(data / x), 1)
pars = self.make_params(A=np.exp(d2), DE=_kb * d1, n=1.0)
return update_param_vals(pars, self.prefix, **kwargs)

Expand Down Expand Up @@ -265,17 +270,26 @@ class VFTEquation(Model):

display_names = ["A", r"\Delta E", "x_0"]

nan_policy = "omit"

def __init__(self, *args, **kwargs):
"""Configure Initial fitting function."""
super(VFTEquation, self).__init__(vftEquation, *args, **kwargs)

def guess(self, data, x=None, **kwargs):
"""Guess paramneters from a set of data."""
_kb = consts.physical_constants["Boltzmann constant"][0] / consts.physical_constants["elementary charge"][0]

d1, d2, x0 = 1.0, 0.0, 1.0
yy = np.log(data)
if x is not None:
# Getting a good x_0 is critical, so we first of all use poly fit to look
x0 = x[np.argmin(np.abs(data))]
d1, d2 = np.polyfit(-1.0 / (x - x0), np.log(data), 1)
pars = self.make_params(A=np.exp(d2), dE=_kb * d1, x_0=x0)

def _find_x0(x, d1, d2, x0):
return -d1 / (x - x0) + d2

popt, pcov = curve_fit(_find_x0, x, yy, p0=[x0, 20, 10])
d1, d2, x0 = popt
pars = self.make_params(A=np.exp(d2), DE=_kb * d1, x_0=x0)
print(pars)
return update_param_vals(pars, self.prefix, **kwargs)
4 changes: 2 additions & 2 deletions Stoner/plot/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ def __call__(self, value, pos=None):
elif value != 0.0:
power = _np_.floor(_np_.log10(_np_.abs(value)))
pre = _np_.ceil(power / 3.0) * 3
power = power % 3
if pre == 0:
if -1 <= power <= 3 or pre == 0:
ret = "${}\\,\\mathrm{{{}}}$".format(_round(value, 4), self.unit)
else:
power = power % 3
v = _round(value / (10 ** pre), 4)
if _np_.abs(v) < 0.1:
v *= 1000
Expand Down
8 changes: 5 additions & 3 deletions doc/samples/Fitting/modArrhenius.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@
)
d.setas = "xyy"
d.plot(fmt=["r.", "b-"])
d.annotate_fit(SF.modArrhenius, x=0.2, y=0.5)
d.annotate_fit(SF.modArrhenius, x=0.2, y=0.5, mode="eng")

# lmfit using lmfit guesses
fit = SF.ModArrhenius()
p0 = [1e6, 0.5, 1.5]
d.lmfit(fit, p0=p0, result=True, header="lmfit")
d.lmfit(fit, result=True, header="lmfit")
d.setas = "x..y"
d.plot()
d.annotate_fit(SF.ModArrhenius, x=0.2, y=0.25, prefix="ModArrhenius")
d.annotate_fit(
SF.ModArrhenius, x=0.2, y=0.25, prefix="ModArrhenius", mode="eng"
)

d.title = "Modified Arrhenius Test Fit"
d.ylabel = "Rate"
Expand Down
8 changes: 5 additions & 3 deletions doc/samples/Fitting/nDimArrhenius.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
d.curve_fit(SF.nDimArrhenius, p0=[1e6, 0.5, 2], result=True, header="curve_fit")
d.setas = "xyy"
d.plot(fmt=["r.", "b-"])
d.annotate_fit(SF.nDimArrhenius, x=0.25, y=0.3)
d.annotate_fit(SF.nDimArrhenius, x=0.25, y=0.3, mode="eng")

# lmfit using lmfit guesses
fit = SF.NDimArrhenius()
p0 = fit.guess(R, x=T)
d.lmfit(fit, p0=p0, result=True, header="lmfit")
d.lmfit(fit, result=True, header="lmfit")
d.setas = "x..y"
d.plot(fmt="g-")
d.annotate_fit(SF.NDimArrhenius, x=0.25, y=0.05, prefix="NDimArrhenius")
d.annotate_fit(
SF.NDimArrhenius, x=0.25, y=0.05, prefix="NDimArrhenius", mode="eng"
)

d.title = "n-D Arrhenius Test Fit"
d.ylabel = "Rate"
Expand Down
2 changes: 1 addition & 1 deletion doc/samples/Fitting/vftEquation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

# lmfit uses some guesses
p0 = params
d.lmfit(VFTEquation, p0=p0, result=True, header="lmfit")
d.lmfit(VFTEquation, result=True, header="lmfit")

# Plot these results too
d.setas = "x..yy"
Expand Down

0 comments on commit af1b415

Please sign in to comment.