-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #293 from pfafflabatuiuc/analyzer
Adding Analyzer functionality
- Loading branch information
Showing
13 changed files
with
1,366 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from typing import Tuple, Any, Optional, Union, Dict, List | ||
from collections import OrderedDict | ||
from dataclasses import dataclass | ||
|
||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
import lmfit | ||
|
||
|
||
class Parameter: | ||
|
||
def __init__(self, name: str, value: Any = None, **kw: Any): | ||
self.name = name | ||
self.value = value | ||
self._attrs = {} | ||
for k, v in kw.items(): | ||
self._attrs[k] = v | ||
|
||
def __getattr__(self, key: str) -> Any: | ||
return self._attrs[key] | ||
|
||
|
||
class Parameters(OrderedDict): | ||
"""A collection of parameters""" | ||
|
||
def add(self, name: str, **kw: Any) -> None: | ||
"""Add/overwrite a parameter in the collection.""" | ||
self[name] = Parameter(name, **kw) | ||
|
||
|
||
class AnalysisResult(object): | ||
|
||
def __init__(self, parameters: Dict[str, Union[Dict[str, Any], Any]]): | ||
self.params = Parameters() | ||
for k, v in parameters.items(): | ||
if isinstance(v, dict): | ||
self.params.add(k, **v) | ||
else: | ||
self.params.add(k, value=v) | ||
|
||
def eval(self, *args: Any, **kwargs: Any) -> np.ndarray: | ||
"""Analysis types that produce data (like filters or fits) should implement this. | ||
""" | ||
raise NotImplementedError | ||
|
||
|
||
class Analysis(object): | ||
"""Basic analysis object. | ||
Parameters | ||
---------- | ||
coordinates | ||
may be a single 1d numpy array (for a single coordinate) or a tuple | ||
of 1d arrays (for multiple coordinates). | ||
data | ||
a 1d array of data | ||
""" | ||
|
||
def __init__(self, coordinates: Union[Tuple[np.ndarray, ...], np.ndarray], | ||
data: np.ndarray): | ||
"""Constructor of `Analysis`. """ | ||
self.coordinates = coordinates | ||
self.data = data | ||
|
||
def analyze(self, coordinates: Union[Tuple[np.ndarray, ...], np.ndarray], data: np.ndarray, *args: Any, | ||
**kwargs: Any) -> AnalysisResult: | ||
"""Needs to be implemented by each inheriting class.""" | ||
raise NotImplementedError | ||
|
||
def run(self, *args: Any, **kwargs: Any) -> AnalysisResult: | ||
return self.analyze(self.coordinates, self.data, **kwargs) | ||
|
||
|
||
# def analyze(analysis_class: Analysis, coordinates: Union[Tuple[np.ndarray, ...], np.ndarray], | ||
# data: np.ndarray, **kwarg: Any) -> AnalysisResult: | ||
# analysis = analysis_class(coordinates, data) | ||
# return analysis.run(**kwarg) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from . import generic_functions, experiment_functions | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from typing import Tuple, Any, Optional, Union, Dict, List | ||
|
||
import numpy as np | ||
import lmfit | ||
from plottr.analyzer.fitters.fitter_base import Fit, FitResult | ||
|
||
|
||
class T1_Decay(Fit): | ||
@staticmethod | ||
def model(coordinates: np.ndarray, amp: float, tau: float) -> np.ndarray: # type: ignore[override] | ||
""" amp * exp(-1.0 * x / tau)""" | ||
return amp * np.exp(-1.0 * coordinates / tau) | ||
@staticmethod | ||
def guess(coordinates: Union[Tuple[np.ndarray, ...], np.ndarray], | ||
data: np.ndarray) -> Dict[str, Any]: | ||
return dict(amp=1, tau=2) | ||
|
||
|
||
class T2_Ramsey(Fit): | ||
@staticmethod | ||
def model(coordinates: np.ndarray, amp: float, tau: float, freq: float, phase: float) -> np.ndarray: # type: ignore[override] | ||
""" amp * exp(-1.0 * x / tau) * sin(2 * PI * freq * x + phase) """ | ||
return amp * np.exp(-1.0 * coordinates / tau) * \ | ||
np.sin(2 * np.pi * freq * coordinates + phase) | ||
|
||
@staticmethod | ||
def guess(coordinates: Union[Tuple[np.ndarray, ...], np.ndarray], | ||
data: np.ndarray) -> Dict[str, Any]: | ||
return dict(amp=1, tau=2, freq=3, phase=4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from typing import Tuple, Any, Union, Dict | ||
|
||
import numpy as np | ||
import lmfit | ||
|
||
from ..base import Analysis, AnalysisResult | ||
|
||
|
||
class FitResult(AnalysisResult): | ||
|
||
def __init__(self, lmfit_result: lmfit.model.ModelResult): | ||
self.lmfit_result = lmfit_result | ||
self.params = lmfit_result.params | ||
|
||
def eval(self, *args: Any, **kwargs: Any) -> np.ndarray: | ||
return self.lmfit_result.eval(*args, **kwargs) | ||
|
||
|
||
class Fit(Analysis): | ||
|
||
@staticmethod | ||
def model(*arg: Any, **kwarg: Any) -> np.ndarray: | ||
raise NotImplementedError | ||
|
||
def analyze(self, coordinates: Union[Tuple[np.ndarray, ...], np.ndarray], data: np.ndarray, | ||
dry: bool = False, params: Dict[str, Any] = {}, *args: Any, **fit_kwargs: Any) -> FitResult: | ||
model = lmfit.model.Model(self.model) | ||
|
||
_params = lmfit.Parameters() | ||
for pn, pv in self.guess(coordinates, data).items(): | ||
_params.add(pn, value=pv) | ||
for pn, pv in params.items(): | ||
if isinstance(pv, lmfit.Parameter): | ||
_params[pn] = pv | ||
else: | ||
_params[pn].set(value=pv) | ||
|
||
if dry: | ||
for pn, pv in _params.items(): | ||
pv.set(vary=False) | ||
lmfit_result = model.fit(data, params=_params, | ||
coordinates=coordinates, **fit_kwargs) | ||
|
||
return FitResult(lmfit_result) | ||
|
||
@staticmethod | ||
def guess(coordinates: Union[Tuple[np.ndarray, ...], np.ndarray], | ||
data: np.ndarray) -> Dict[str, Any]: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from typing import Tuple, Any, Optional, Union, Dict, List | ||
|
||
import numpy as np | ||
import lmfit | ||
|
||
from plottr.analyzer.fitters.fitter_base import Fit, FitResult | ||
|
||
|
||
class Cosine(Fit): | ||
@staticmethod | ||
def model(coordinates: np.ndarray, # type: ignore[override] | ||
A: float, f: float, phi: float, of: float) -> np.ndarray: | ||
"""$A \cos(2 \pi f x + \phi) + of$""" | ||
return A * np.cos(2 * np.pi * coordinates * f + phi) + of | ||
|
||
@staticmethod | ||
def guess(coordinates: Union[Tuple[np.ndarray, ...], np.ndarray], | ||
data: np.ndarray) -> Dict[str, float]: | ||
of = np.mean(data) | ||
A = (np.max(data) - np.min(data)) / 2. | ||
|
||
# Making sure that coordinates is ndarray. | ||
# Changing the type in the signature will create a different mypy error. | ||
assert isinstance(coordinates, np.ndarray) | ||
fft_val = np.fft.rfft(data)[1:] | ||
fft_frq = np.fft.rfftfreq(data.size, | ||
np.mean(coordinates[1:] - coordinates[:-1]))[1:] | ||
idx = np.argmax(np.abs(fft_val)) | ||
f = fft_frq[idx] | ||
phi = np.angle(fft_val[idx]) | ||
|
||
return dict(A=A, of=of, f=f, phi=phi) | ||
|
||
|
||
class Exponential(Fit): | ||
@staticmethod | ||
def model(coordinates: np.ndarray, a: float, b: float) -> np.ndarray: # type: ignore[override] | ||
""" a * b ** x""" | ||
return a * b ** coordinates | ||
|
||
@staticmethod | ||
def guess(coordinates: Union[Tuple[np.ndarray, ...], np.ndarray], | ||
data: np.ndarray) -> Dict[str, float]: | ||
return dict(a=1, b=2) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from typing import Any, Union, Tuple | ||
import numpy as np | ||
|
||
from ..base import Analysis, AnalysisResult | ||
|
||
|
||
class FindMax(Analysis): | ||
"""A simple example class to illustrate the concept.""" | ||
|
||
def analyze(self, coordinates: Union[Tuple[np.ndarray, ...], np.ndarray], | ||
data: np.ndarray, *args: Any, **kwargs: Any) -> AnalysisResult: | ||
i = np.argmax(data) | ||
|
||
return AnalysisResult( | ||
dict( | ||
max_val=data[i], | ||
max_pos=coordinates[i] | ||
) | ||
) |
Oops, something went wrong.