Skip to content

Commit

Permalink
[NumpyVectorArrayView] add type promotion to inplace operations
Browse files Browse the repository at this point in the history
  • Loading branch information
sdrave committed May 8, 2017
1 parent 656979d commit bc3a2dd
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/pymor/vectorarrays/numpy.py
Expand Up @@ -499,6 +499,10 @@ def __iadd__(self, other):
assert self.base.check_ind_unique(self.ind)
if self.base._refcount[0] > 1:
self._deep_copy()
other_dtype = other.base._array.dtype if other.is_view else other._array.dtype
common_dtype = np.promote_types(self.base._array.dtype, other_dtype)
if self.base._array.dtype != common_dtype:
self.base._array = self.base._array.astype(common_dtype)
self.base.array[self.ind] += other.base._array[other.ind] if other.is_view else other._array[:other._len]
return self

Expand All @@ -515,6 +519,10 @@ def __isub__(self, other):
assert self.base.check_ind_unique(self.ind)
if self.base._refcount[0] > 1:
self._deep_copy()
other_dtype = other.base._array.dtype if other.is_view else other._array.dtype
common_dtype = np.promote_types(self.base._array.dtype, other_dtype)
if self.base._array.dtype != common_dtype:
self.base._array = self.base._array.astype(common_dtype)
self.base._array[self.ind] -= other.base._array[other.ind] if other.is_view else other._array[:other._len]
return self

Expand All @@ -529,6 +537,10 @@ def __imul__(self, other):
assert self.base.check_ind_unique(self.ind)
if self.base._refcount[0] > 1:
self._deep_copy()
other_dtype = other.dtype if isinstance(other, np.ndarray) else type(other)
common_dtype = np.promote_types(self.base._array.dtype, other_dtype)
if self.base._array.dtype != common_dtype:
self.base._array = self.base._array.astype(common_dtype)
self.base._array[self.ind] *= other
return self

Expand Down

0 comments on commit bc3a2dd

Please sign in to comment.