Skip to content

Commit

Permalink
Add simple np_fns for use when jax is not installed.
Browse files Browse the repository at this point in the history
When jax is installed, any numpy array will use the jax_fns xnp backend.
When jax is not installed, now all numpy arrays will use the np_fns xnp
backend.
  • Loading branch information
gpleiss committed Aug 21, 2023
1 parent 91ce584 commit 3345985
Show file tree
Hide file tree
Showing 2 changed files with 245 additions and 0 deletions.
241 changes: 241 additions & 0 deletions cola/np_fns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
import logging
import sys

Check warning on line 2 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L1-L2

Added lines #L1 - L2 were not covered by tests

import numpy as np
from scipy.linalg import block_diag as _block_diag, lu as _lu, solve_triangular
from scipy.signal import convolve2d

Check warning on line 6 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L4-L6

Added lines #L4 - L6 were not covered by tests


class NumpyNotImplementedError(NotImplementedError):
def __init__(self):
fn_name = sys._getframe(1).f_code.co_name
super().__init__(f"{fn_name} is not implemented for the numpy backend.")

Check warning on line 12 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L9-L12

Added lines #L9 - L12 were not covered by tests


abs = np.abs
all = np.all
allclose = np.allclose
any = np.any
arange = np.arange
argsort = np.argsort
block_diag = _block_diag
cholesky = np.linalg.cholesky
clip = np.clip
complex64 = np.complex64
concat = np.concatenate
concatenate = np.concatenate
conj = np.conj
copy = np.copy
cos = np.cos
eig = np.linalg.eig
eigh = np.linalg.eigh
exp = np.exp
float32 = np.float32
float64 = np.float64
int32 = np.int32
int64 = np.int64
inv = np.linalg.inv
isreal = np.isreal
kron = np.kron
log = np.log
long = np.int64
lu = _lu
max = np.max
maximum = np.maximum
mean = np.mean
min = np.min
moveaxis = np.moveaxis
nan_to_num = np.nan_to_num
ndarray = np.ndarray
norm = np.linalg.norm
normal = np.random.normal
ones_like = np.ones_like
prod = np.prod
qr = np.linalg.qr
reshape = np.reshape
roll = np.roll
sign = np.sign
sin = np.sin
slogdet = np.linalg.slogdet
solve = np.linalg.solve
solvetri = solve_triangular
sort = np.sort
sqrt = np.sqrt
stack = np.stack
sum = np.sum
svd = np.linalg.svd
where = np.where

Check warning on line 67 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L15-L67

Added lines #L15 - L67 were not covered by tests


def PRNGKey(key):
raise NumpyNotImplementedError()

Check warning on line 71 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L70-L71

Added lines #L70 - L71 were not covered by tests


def Parameter(array):
return array

Check warning on line 75 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L74-L75

Added lines #L74 - L75 were not covered by tests


def array(arr, dtype=None, device=None):
return np.array(arr, dtype=dtype)

Check warning on line 79 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L78-L79

Added lines #L78 - L79 were not covered by tests


def canonical(loc, shape, dtype, device=None):
vec = np.zeros(shape=shape, dtype=dtype)
vec = vec.at[loc].set(1.0)
return vec

Check warning on line 85 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L82-L85

Added lines #L82 - L85 were not covered by tests


def cast(array, dtype):
return array.astype(dtype)

Check warning on line 89 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L88-L89

Added lines #L88 - L89 were not covered by tests


def convolve(in1, in2, mode="same"):
in12 = np.pad(

Check warning on line 93 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L92-L93

Added lines #L92 - L93 were not covered by tests
in1,
(
(in2.shape[0] // 2, (in2.shape[0] + 1) // 2 - 1),
(in2.shape[1] // 2, (in2.shape[1] + 1) // 2 - 1),
),
"symmetric",
)
out = convolve2d(in12, in2, mode="valid")
return out # ,boundary='symm')

Check warning on line 102 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L101-L102

Added lines #L101 - L102 were not covered by tests


def device(device_name):
raise NumpyNotImplementedError()

Check warning on line 106 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L105-L106

Added lines #L105 - L106 were not covered by tests


def diag(v, diagonal=0):
return np.diag(v, k=diagonal)

Check warning on line 110 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L109-L110

Added lines #L109 - L110 were not covered by tests


def dynamic_slice(operand, start_indices, slice_sizes):
raise NumpyNotImplementedError()

Check warning on line 114 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L113-L114

Added lines #L113 - L114 were not covered by tests


def expand(array, axis):
return np.expand_dims(array, dimensions=(axis,))

Check warning on line 118 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L117-L118

Added lines #L117 - L118 were not covered by tests


def eye(n, m=None, dtype=None, device=None):
del device
return np.eye(N=n, M=m, dtype=dtype)

Check warning on line 123 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L121-L123

Added lines #L121 - L123 were not covered by tests


def fixed_normal_samples(shape, dtype=None):
key = PRNGKey(4)
z = normal(key, shape, dtype=dtype)
return z

Check warning on line 129 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L126-L129

Added lines #L126 - L129 were not covered by tests


def for_loop(lower, upper, body_fun, init_val):
raise NumpyNotImplementedError()

Check warning on line 133 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L132-L133

Added lines #L132 - L133 were not covered by tests


def get_default_device():
return None

Check warning on line 137 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L136-L137

Added lines #L136 - L137 were not covered by tests


def get_device(array):
return None

Check warning on line 141 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L140-L141

Added lines #L140 - L141 were not covered by tests


def grad(fun):
raise NumpyNotImplementedError()

Check warning on line 145 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L144-L145

Added lines #L144 - L145 were not covered by tests


def is_array(array):
return isinstance(array, np.ndarray)

Check warning on line 149 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L148-L149

Added lines #L148 - L149 were not covered by tests


def jit(fn, static_argnums=None):
raise NumpyNotImplementedError()

Check warning on line 153 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L152-L153

Added lines #L152 - L153 were not covered by tests


def jvp(fun, primals, tangents, has_aux=False):
raise NumpyNotImplementedError()

Check warning on line 157 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L156-L157

Added lines #L156 - L157 were not covered by tests


def jvp_derivs(fun, primals, tangents, create_graph=True):
raise NumpyNotImplementedError()

Check warning on line 161 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L160-L161

Added lines #L160 - L161 were not covered by tests


def linear_transpose(fun, primals, duals):
raise NumpyNotImplementedError()

Check warning on line 165 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L164-L165

Added lines #L164 - L165 were not covered by tests


def lu_solve(a, b):
return solve(a, b)

Check warning on line 169 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L168-L169

Added lines #L168 - L169 were not covered by tests


def move_to(arr, device, dtype):
if dtype is not None:
arr = arr.astype(dtype)
if device is not None:
raise RuntimeError(

Check warning on line 176 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L172-L176

Added lines #L172 - L176 were not covered by tests
"move_to does not take in a device argument for the numpy backend."
)
return arr

Check warning on line 179 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L179

Added line #L179 was not covered by tests


def next_key(key):
raise NumpyNotImplementedError()

Check warning on line 183 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L182-L183

Added lines #L182 - L183 were not covered by tests


def ones(shape, dtype):
return np.ones(shape=shape, dtype=dtype)

Check warning on line 187 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L186-L187

Added lines #L186 - L187 were not covered by tests


def pbar_while(errorfn, tol, desc='', every=1, hide=False):
raise NumpyNotImplementedError()

Check warning on line 191 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L190-L191

Added lines #L190 - L191 were not covered by tests


def permute(array, axes):
return np.transpose(array, axes=axes)

Check warning on line 195 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L194-L195

Added lines #L194 - L195 were not covered by tests


def randn(*shape, dtype=None, key=None):
if key is None:
print("Non keyed randn used. To be deprecated soon.")
logging.warning("Non keyed randn used. To be deprecated soon.")
out = np.random.randn(*shape)
if dtype is not None:
out = out.astype(dtype)
return out

Check warning on line 205 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L198-L205

Added lines #L198 - L205 were not covered by tests
else:
z = normal(key, shape, dtype=dtype)
return z

Check warning on line 208 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L207-L208

Added lines #L207 - L208 were not covered by tests


def update_array(array, update, *slices):
return array.at[slices].set(update)

Check warning on line 212 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L211-L212

Added lines #L211 - L212 were not covered by tests


def vjp(fun, *primals, has_aux=False):
raise NumpyNotImplementedError()

Check warning on line 216 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L215-L216

Added lines #L215 - L216 were not covered by tests


def vjp_derivs(fun, primals, duals, create_graph=True):
raise NumpyNotImplementedError()

Check warning on line 220 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L219-L220

Added lines #L219 - L220 were not covered by tests


def vmap(fun, in_axes=0, out_axes=0):
raise NumpyNotImplementedError()

Check warning on line 224 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L223-L224

Added lines #L223 - L224 were not covered by tests


def while_loop(cond_fun, body_fun, init_val):
raise NumpyNotImplementedError()

Check warning on line 228 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L227-L228

Added lines #L227 - L228 were not covered by tests


def while_loop_no_jit(cond_fun, body_fun, init_val):
raise NumpyNotImplementedError()

Check warning on line 232 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L231-L232

Added lines #L231 - L232 were not covered by tests


def while_loop_winfo(errorfn, tol, every=1, desc="", pbar=False, **kwargs):
raise NumpyNotImplementedError()

Check warning on line 236 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L235-L236

Added lines #L235 - L236 were not covered by tests


def zeros(shape, dtype, device=None):
del device
return np.zeros(shape=shape, dtype=dtype)

Check warning on line 241 in cola/np_fns.py

View check run for this annotation

Codecov / codecov/patch

cola/np_fns.py#L239-L241

Added lines #L239 - L241 were not covered by tests
4 changes: 4 additions & 0 deletions cola/ops/operator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from cola.utils import export
from numbers import Number
import numpy as np

Array = Dtype = Any
export(Array)
Expand All @@ -27,6 +28,9 @@ def get_library_fns(dtype: Dtype):
]:
import cola.torch_fns as fns
return fns
elif dtype in [np.float32, np.float64, np.complex64, np.complex128, np.int32, np.int64]:
import cola.np_fns as fns
return fns

Check warning on line 33 in cola/ops/operator_base.py

View check run for this annotation

Codecov / codecov/patch

cola/ops/operator_base.py#L31-L33

Added lines #L31 - L33 were not covered by tests
except ImportError:
pass
raise ImportError("No supported array library found")
Expand Down

0 comments on commit 3345985

Please sign in to comment.