Skip to content

Commit

Permalink
Merge pull request #14 from ericpre/fix_hyperspy_compatibility
Browse files Browse the repository at this point in the history
Fix hyperspy compatibility
  • Loading branch information
CCampJr committed Oct 25, 2021
2 parents a0f8829 + e7d83a0 commit b2d2876
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 94 deletions.
34 changes: 17 additions & 17 deletions pymcr/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ def __init__(self, axis=-1, copy=False):

def transform(self, A):
""" Apply cumsum nonnegative constraint"""

if self.copy:
return A*(_np.cumsum(A, self.axis) > 0)
else:
A *= (_np.cumsum(A, self.axis) > 0)
return A


class ConstraintZeroEndPoints(Constraint):
"""
Expand Down Expand Up @@ -168,7 +168,7 @@ class ConstraintZeroCumSumEndPoints(Constraint):
def __init__(self, nodes=None, axis=-1, copy=False):
""" A must be non-negative"""
super().__init__(copy)

self.nodes = nodes
if [0, 1, -1].count(axis) != 1:
raise TypeError('Axis must be 0, 1, or -1')
Expand All @@ -177,7 +177,7 @@ def __init__(self, nodes=None, axis=-1, copy=False):

def transform(self, A):
""" Apply cumsum nonnegative constraint"""

meaner = A.mean(self.axis)

if self.nodes:
Expand Down Expand Up @@ -227,7 +227,7 @@ def transform(self, A):
else:
A -= meaner[:, None]
return A


class ConstraintNorm(Constraint):
"""
Expand Down Expand Up @@ -282,7 +282,7 @@ def __init__(self, axis=-1, fix=None, copy=False):

def transform(self, A):
""" Apply normalization constraint """

if self.copy:
if self.axis == 0:
if not self.fix: # No fixed axes
Expand Down Expand Up @@ -311,7 +311,7 @@ def transform(self, A):

return A * scaler
else: # Overwrite original data
if A.dtype != _np.float:
if A.dtype != float:
raise TypeError('A.dtype must be float for',
'in-place math (copy=False)')

Expand Down Expand Up @@ -412,7 +412,7 @@ def transform(self, A):
return A
else:
return A


class _CutExclude(Constraint):
"""
Expand Down Expand Up @@ -447,7 +447,7 @@ def __init__(self, value=0, axis_sumnz=None, exclude=None,
def _make_excl_mat(self, A_shape):
X, Y = _np.meshgrid(_np.arange(A_shape[1]), _np.arange(A_shape[0]))
if self.exclude is None:
self._excl_mat = _np.zeros(X.shape, dtype=_np.bool)
self._excl_mat = _np.zeros(X.shape, dtype=bool)
else:
if self.exclude_axis == 0:
self._excl_mat = _np.in1d(Y.ravel(), self.exclude).reshape(Y.shape)
Expand Down Expand Up @@ -594,15 +594,15 @@ def __init__(self, value=0, copy=False):

def transform(self, A):
""" Apply compress-above value constraint"""

if self.copy:
return A*(A <= self.value) + self.value*(A > self.value)
else:
temp = self.value*(A > self.value)
A *= (A <= self.value)
A += temp
return A

class ConstraintPlanarize(Constraint):
"""
Set a particular target to a plane
Expand Down Expand Up @@ -669,16 +669,16 @@ def __init__(self, target, shape, use_vals_above=None, use_vals_below=None,
def _setup_xy(self, scaler):

self.scaler = scaler
self._x = scaler*_np.arange(self.shape[1], dtype=_np.float)
self._y = scaler*_np.arange(self.shape[0], dtype=_np.float)
self._x = scaler*_np.arange(self.shape[1], dtype=float)
self._y = scaler*_np.arange(self.shape[0], dtype=float)

self._X, self._Y = _np.meshgrid(self._x, self._y)
self._X = self._X.ravel()
self._Y = self._Y.ravel()

def transform(self, A):
""" Set targets, t, to fit planes """

if (self.scaler is None) | (self.recalc):
self._setup_xy(1e3 * _np.abs(A.max() - A.min()))

Expand Down Expand Up @@ -726,11 +726,11 @@ def transform(self, A):
return A2
else:
return A


if __name__ == '__main__': # pragma: no cover
A = _np.array([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]]).astype(_np.float)
A_transform = _np.array([[1, 2, 3, 4], [4, 0, 0, 0], [0, 0, 0, 0]]).astype(_np.float)
A = _np.array([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]]).astype(float)
A_transform = _np.array([[1, 2, 3, 4], [4, 0, 0, 0], [0, 0, 0, 0]]).astype(float)

constr = ConstraintCutAbove(copy=True, value=4)
out = constr.transform(A)
Expand Down
10 changes: 5 additions & 5 deletions pymcr/mcr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
""" MCR Main Class for Computation"""
import sys as _sys
import copy as _copy

import numpy as _np
import logging as _logging
Expand Down Expand Up @@ -304,6 +303,7 @@ def fit(self, D, C=None, ST=None, st_fix=None, c_fix=None, c_first=True,
the docs.
"""
D = _np.asanyarray(D)

if verbose:
_logger.setLevel(_logging.DEBUG)
Expand All @@ -314,11 +314,11 @@ def fit(self, D, C=None, ST=None, st_fix=None, c_fix=None, c_first=True,
temp = self.fit_kwargs.get('C')
if (temp is not None) & (C is None):
C = temp

temp = self.fit_kwargs.get('ST')
if (temp is not None) & (ST is None):
ST = temp

temp = self.fit_kwargs.get('st_fix')
if (temp is not None) & (st_fix is None):
st_fix = temp
Expand All @@ -344,8 +344,8 @@ def fit(self, D, C=None, ST=None, st_fix=None, c_fix=None, c_first=True,
raise TypeError(
err_str1 + 'unless c_fix and st_fix are both provided')
else:
self.C_ = C
self.ST_ = ST
self.C_ = _np.asanyarray(C) if C is not None else C
self.ST_ = _np.asanyarray(ST) if ST is not None else ST

self.n_increase = 0
self.n_above_min = 0
Expand Down
Loading

0 comments on commit b2d2876

Please sign in to comment.