diff --git a/requirements.txt b/requirements.txt index df29eac795..d8ebb01cfa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ pip<24.0 immutabledict<2.2.6 +typing-extensions>=4.7,<5 diff --git a/xdsl/interpreters/shaped_array.py b/xdsl/interpreters/shaped_array.py index 9b082ee656..36a51d277b 100644 --- a/xdsl/interpreters/shaped_array.py +++ b/xdsl/interpreters/shaped_array.py @@ -1,11 +1,12 @@ from __future__ import annotations import operator +from dataclasses import dataclass from itertools import accumulate, product from math import prod from typing import Generic, Iterable, TypeAlias, TypeVar -from attr import dataclass +from typing_extensions import Self _T = TypeVar("_T") @@ -64,7 +65,7 @@ def indices(self) -> Iterable[tuple[int, ...]]: """ yield from product(*(range(dim) for dim in self.shape)) - def transposed(self, dim0: int, dim1: int) -> ShapedArray[_T]: + def transposed(self, dim0: int, dim1: int) -> Self: """ Returns a new ShapedArray, with the dimensions `dim0` and `dim1` transposed. """