Skip to content

Commit 976fd5b

Browse files
All hail mypy
1 parent eb50ca6 commit 976fd5b

File tree

3 files changed

+38
-8
lines changed

3 files changed

+38
-8
lines changed

pytensor/link/numba/dispatch/linalg/dot/banded.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,14 @@ def _gbmv(
5656
A_banded = A_to_banded(A, kl=kl, ku=ku)
5757

5858
incx = x.strides[0] // x.itemsize
59-
incy = y.strides[0] // y.itemsize if y is not None else 1
60-
6159
offx = 0 if incx >= 0 else -x.size + 1
62-
offy = 0 if incy >= 0 else -y.size + 1
60+
61+
if y is not None:
62+
incy = y.strides[0] // y.itemsize
63+
offy = 0 if incy >= 0 else -y.size + 1
64+
else:
65+
incy = 1
66+
offy = 0
6367

6468
return fn(
6569
m=m,

pytensor/link/numba/dispatch/linalg/dot/general.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Callable
2+
from typing import cast as type_cast
23

34
import numpy as np
45
from numba.core.extending import overload
@@ -33,12 +34,16 @@ def _matrix_vector_product(
3334
(fn,) = linalg.get_blas_funcs(("gemv",), (A, x))
3435

3536
incx = x.strides[0] // x.itemsize
36-
incy = y.strides[0] // y.itemsize if y is not None else 1
37-
3837
offx = 0 if incx >= 0 else -x.size + 1
39-
offy = 0 if incy >= 0 else -y.size + 1
4038

41-
return fn(
39+
if y is not None:
40+
incy = y.strides[0] // y.itemsize if y is not None else 1
41+
offy = 0 if incy >= 0 else -y.size + 1
42+
else:
43+
incy = 1
44+
offy = 0
45+
46+
res = fn(
4247
alpha=alpha,
4348
a=A,
4449
x=x,
@@ -52,6 +57,8 @@ def _matrix_vector_product(
5257
trans=trans,
5358
)
5459

60+
return type_cast(np.ndarray, res)
61+
5562

5663
@overload(_matrix_vector_product)
5764
def matrix_vector_product_impl(
@@ -63,7 +70,15 @@ def matrix_vector_product_impl(
6370
overwrite_y: bool = False,
6471
trans: int = 1,
6572
) -> Callable[
66-
[float, np.ndarray, np.ndarray, float, np.ndarray, int, int, int, int, int],
73+
[
74+
np.ndarray,
75+
np.ndarray,
76+
np.ndarray,
77+
np.ndarray | None,
78+
np.ndarray | None,
79+
bool,
80+
int,
81+
],
6782
np.ndarray,
6883
]:
6984
ensure_lapack()

pytensor/tensor/blas.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1941,12 +1941,23 @@ def banded_gemv(
19411941
out: Tensor
19421942
The matrix multiplication result
19431943
"""
1944+
A = as_tensor_variable(A)
1945+
x = as_tensor_variable(x)
1946+
19441947
if alpha is None:
19451948
alpha = pt.ones((), dtype=A.type.dtype)
1949+
else:
1950+
alpha = as_tensor_variable(alpha)
1951+
19461952
if beta is None:
19471953
beta = pt.zeros((), dtype=A.type.dtype)
1954+
else:
1955+
beta = as_tensor_variable(beta)
1956+
19481957
if y is None:
19491958
y = pt.empty(A.shape[:-1], dtype=A.type.dtype)
1959+
else:
1960+
y = as_tensor_variable(y)
19501961

19511962
return Blockwise(BandedGEMV(lower_diags, upper_diags, overwrite_y=False))(
19521963
A, x, y, alpha, beta

0 commit comments

Comments
 (0)