In [5]:
"""Tools for handing the strided irreps layout."""

import math

import torch

from e3nn.o3 import Irreps


class StridedLayout:
    """Index bookkeeping for irreps that all share one multiplicity.

    Two orderings of the last tensor axis are handled:

    * "catted" (default) layout: for every irrep in turn, all ``mul`` copies
      of that irrep are stored back to back.
    * "strided" layout: for every multiplicity index in turn, one copy of
      each irrep is stored back to back, followed by ``base_dim -
      base_irreps.dim`` padding slots, i.e. a flattened ``[mul, base_dim]``
      array.

    Raises:
        ValueError: if the irreps do not all have the same multiplicity.
        AssertionError: if ``pad_to_multiple`` is not one of 1, 2, 4, 8.
    """

    # The irreps being laid out (all entries share the same multiplicity).
    irreps: Irreps
    # The same irreps with every multiplicity replaced by 1.
    base_irreps: Irreps
    # Each strided row is padded up to a multiple of this value.
    pad_to_multiple: int
    # Size of the last axis in strided layout (``base_dim * mul``).
    dim: int
    # Size of one strided row, padding included.
    base_dim: int
    # The shared multiplicity (0 for empty irreps).
    mul: int

    def __init__(self, irreps: Irreps, pad_to_multiple: int = 1):
        irreps = Irreps(irreps)
        if not self.can_be_strided(irreps):
            raise ValueError(f"Irreps `{irreps}` cannot be strided.")
        self.irreps = irreps
        self.base_irreps = Irreps([(1, ir) for _, ir in irreps])
        self.mul = self.irreps[0].mul if len(irreps) > 0 else 0
        assert self.irreps.dim == self.base_irreps.dim * self.mul
        self.pad_to_multiple = pad_to_multiple
        assert self.pad_to_multiple in (1, 2, 4, 8)

        # Round the unpadded row length up to the requested multiple.
        self.base_dim = int(
            math.ceil(self.base_irreps.dim / self.pad_to_multiple)
            * self.pad_to_multiple
        )
        pad_by = self.base_dim - self.base_irreps.dim
        self.dim = self.base_dim * self.mul

        # indexes_to_strided[s] = catted position stored at strided slot s
        # (padding slots are never written and keep the initial value 0).
        # indexes_to_catted[c]  = strided slot holding catted position c.
        self.indexes_to_strided = torch.zeros(self.dim, dtype=torch.long)
        self.indexes_to_catted = torch.zeros(self.irreps.dim, dtype=torch.long)

        # Start of every irrep's block in the catted layout; this does not
        # depend on the multiplicity index, so compute it once.
        block_starts = [
            self.irreps[:irrep_i].dim for irrep_i in range(len(self.base_irreps))
        ]

        i: int = 0
        for mul_i in range(self.mul):
            for irrep_i, (_, irrep) in enumerate(self.base_irreps):
                strided_indexes = torch.arange(start=i, end=i + irrep.dim)
                catted_indexes = (
                    torch.arange(irrep.dim)
                    + block_starts[irrep_i]
                    + irrep.dim * mul_i
                )
                self.indexes_to_strided[strided_indexes] = catted_indexes
                self.indexes_to_catted[catted_indexes] = strided_indexes
                i += irrep.dim
            # pad out this line of the [mul, k] shape
            i += pad_by

        # They should be inverses:
        assert torch.all(
            self.indexes_to_strided[self.indexes_to_catted]
            == torch.arange(self.irreps.dim)
        )

    @staticmethod
    def can_be_strided(irreps: Irreps) -> bool:
        """Check whether ``irreps`` is compatible with strided layout.

        Empty irreps are accepted; otherwise every entry must have the same
        multiplicity as the first one.
        """
        irreps = Irreps(irreps)
        if len(irreps) == 0:
            return True
        return all(irreps[0].mul == mul for mul, ir in irreps)

    def to_strided(self, x: torch.Tensor) -> torch.Tensor:
        """Convert a tensor from default to strided layout.

        Only the last axis is reordered. When the layout has padding, the
        padding slots of the result are zero.
        """
        if self.base_dim == self.base_irreps.dim:
            # No padding: a plain gather covers every output slot.
            return x[..., self.indexes_to_strided]
        # With padding, gathering through ``indexes_to_strided`` would copy
        # ``x[..., 0]`` into every padding slot (those entries are still 0).
        # Scatter into a zero tensor instead so the padding stays zero.
        out = x.new_zeros(x.shape[:-1] + (self.dim,))
        out[..., self.indexes_to_catted] = x
        return out

    def to_catted(self, x: torch.Tensor) -> torch.Tensor:
        """Convert a tensor from strided to default layout.

        Only the last axis is reordered; padding slots are dropped.
        """
        return x[..., self.indexes_to_catted]

In [6]:
# Scratch check: math.ceil rounds a float up to the next integer
# (the same call StridedLayout uses to pad base_dim).
math.ceil(5.1)

6

In [7]:
# Example layout: three irreps sharing multiplicity 32, default (no) padding.
out_layout = StridedLayout("32x0e + 32x1o + 32x2e")

In [15]:
# Both index maps have the same length here because the layout is unpadded.
print(out_layout.indexes_to_catted.shape)
print(out_layout.indexes_to_strided.shape)
# Entry s is the catted position that lands in strided slot s.
out_layout.indexes_to_strided

torch.Size([288])
torch.Size([288])


tensor([  0,  32,  33,  34, 128, 129, 130, 131, 132,   1,  35,  36,  37, 133,
        134, 135, 136, 137,   2,  38,  39,  40, 138, 139, 140, 141, 142,   3,
         41,  42,  43, 143, 144, 145, 146, 147,   4,  44,  45,  46, 148, 149,
        150, 151, 152,   5,  47,  48,  49, 153, 154, 155, 156, 157,   6,  50,
         51,  52, 158, 159, 160, 161, 162,   7,  53,  54,  55, 163, 164, 165,
        166, 167,   8,  56,  57,  58, 168, 169, 170, 171, 172,   9,  59,  60,
         61, 173, 174, 175, 176, 177,  10,  62,  63,  64, 178, 179, 180, 181,
        182,  11,  65,  66,  67, 183, 184, 185, 186, 187,  12,  68,  69,  70,
        188, 189, 190, 191, 192,  13,  71,  72,  73, 193, 194, 195, 196, 197,
         14,  74,  75,  76, 198, 199, 200, 201, 202,  15,  77,  78,  79, 203,
        204, 205, 206, 207,  16,  80,  81,  82, 208, 209, 210, 211, 212,  17,
         83,  84,  85, 213, 214, 215, 216, 217,  18,  86,  87,  88, 218, 219,
        220, 221, 222,  19,  89,  90,  91, 223, 224, 225, 226, 2

In [17]:
# Entry c is the strided slot that holds catted position c.
out_layout.indexes_to_catted

tensor([  0,   9,  18,  27,  36,  45,  54,  63,  72,  81,  90,  99, 108, 117,
        126, 135, 144, 153, 162, 171, 180, 189, 198, 207, 216, 225, 234, 243,
        252, 261, 270, 279,   1,   2,   3,  10,  11,  12,  19,  20,  21,  28,
         29,  30,  37,  38,  39,  46,  47,  48,  55,  56,  57,  64,  65,  66,
         73,  74,  75,  82,  83,  84,  91,  92,  93, 100, 101, 102, 109, 110,
        111, 118, 119, 120, 127, 128, 129, 136, 137, 138, 145, 146, 147, 154,
        155, 156, 163, 164, 165, 172, 173, 174, 181, 182, 183, 190, 191, 192,
        199, 200, 201, 208, 209, 210, 217, 218, 219, 226, 227, 228, 235, 236,
        237, 244, 245, 246, 253, 254, 255, 262, 263, 264, 271, 272, 273, 280,
        281, 282,   4,   5,   6,   7,   8,  13,  14,  15,  16,  17,  22,  23,
         24,  25,  26,  31,  32,  33,  34,  35,  40,  41,  42,  43,  44,  49,
         50,  51,  52,  53,  58,  59,  60,  61,  62,  67,  68,  69,  70,  71,
         76,  77,  78,  79,  80,  85,  86,  87,  88,  89,  94,  

In [None]:
# sparse matrix multiplication
from typing import Tuple
from packaging import version

import torch

from e3nn.util.jit import compile_mode

# Implementation switch: False selects the native torch sparse classes defined
# below, True selects the torch_sparse (`SparseTensor` / `spmm_add`) variant.
_USE_PYG_SPARSE: bool = False

# True when the installed torch is at least 1.10.0; gates the CSR code path.
_TORCH_IS_GE_1_10: bool = version.parse(torch.__version__) >= version.parse("1.10.0")

if not _USE_PYG_SPARSE:

    class _ExplicitGradSpmm(torch.autograd.Function):
        """``torch.mm(sparse, a)`` with a hand-written backward pass.

        The backward is itself expressed as a ``torch.mm`` call. No gradient
        is returned for the sparse operand.
        """

        @staticmethod
        def forward(ctx, sparse, a):
            product = torch.mm(sparse, a)
            # Only the sparse factor is needed to form the gradient w.r.t. ``a``.
            ctx.save_for_backward(sparse)
            return product

        @staticmethod
        def backward(ctx, grad_output):
            sparse = ctx.saved_tensors[0]
            grad_a = torch.mm(sparse.t(), grad_output)
            # First slot (the sparse matrix) gets no gradient.
            return None, grad_a

    # NOTE: the CSR counterpart (`_remake_sparse_csr`) is defined further down and needs torch >= 1.10
    @torch.jit.script
    def _remake_sparse_coo(i, v, shape: Tuple[int, int]):
        """Assemble a sparse COO tensor of size ``shape`` from indices ``i`` and values ``v``.

        The result lives on ``v``'s device with ``v``'s dtype and is flagged
        as coalesced without being re-coalesced.
        """
        rebuilt = torch.sparse_coo_tensor(
            indices=i, values=v, size=shape, dtype=v.dtype, device=v.device
        )
        # The index/value pair is taken from a matrix that was coalesced in
        # ExplicitGradSpmmCOO.__init__, so just set the flag (private torch API).
        rebuilt._coalesced_(True)
        assert rebuilt.is_coalesced()
        return rebuilt

    @compile_mode("trace")
    class ExplicitGradSpmmCOO(torch.nn.Module):
        shape: Tuple[int, int]

        def __init__(self, mat: torch.Tensor):
            super().__init__()
            assert mat.is_sparse
            assert mat.ndim == 2
            mat = mat.coalesce()
            # To workaround https://github.com/pytorch/pytorch/issues/63987,
            # save indices and values explicitly
            self.register_buffer("_i", mat.indices())
            self.register_buffer("_v", mat.values())
            self.shape = tuple(mat.shape)

        def forward(self, x):
            # TODO: support csr
            sp = _remake_sparse_coo(self._i, self._v, self.shape)
            if self.training:
                # Use a custom autograd function for 2nd derivatives
                # torch.mm doesn't do double derivatives with sparse w3j
                tmp = _ExplicitGradSpmm.apply(sp, x)
            else:
                # For inference, assume only first derivatives necessary
                tmp = torch.mm(sp, x)
            return tmp

        def _make_tracing_inputs(self, n: int):
            return [
                {
                    "forward": (
                        torch.randn(
                            self.shape[-1],
                            3,
                            device=self._v.device,
                            dtype=self._v.dtype,
                        ),
                    )
                }
                for _ in range(n)
            ]

    if _TORCH_IS_GE_1_10:

        @torch.jit.script
        def _remake_sparse_csr(crow, col, v, shape: Tuple[int, int]) -> torch.Tensor:
            """Assemble a sparse CSR tensor of size ``shape`` from its three component tensors.

            The result lives on ``v``'s device with ``v``'s dtype.
            """
            rebuilt = torch.sparse_csr_tensor(
                crow_indices=crow,
                col_indices=col,
                values=v,
                size=shape,
                dtype=v.dtype,
                device=v.device,
                layout=torch.sparse_csr,
            )
            return rebuilt

        @compile_mode("trace")
        class ExplicitGradSpmmCSR(torch.nn.Module):
            shape: Tuple[int, int]

            def __init__(self, mat: torch.Tensor):
                super().__init__()
                assert mat.is_sparse_csr
                assert mat.ndim == 2
                # To workaround https://github.com/pytorch/pytorch/issues/63987,
                # save indices and values explicitly
                self.register_buffer("_crow", mat.crow_indices())
                self.register_buffer("_col", mat.col_indices())
                self.register_buffer("_v", mat.values())
                self.shape = tuple(mat.shape)

            def forward(self, x):
                # TODO: support csr
                sp = _remake_sparse_csr(self._crow, self._col, self._v, self.shape)
                if self.training:
                    # Use a custom autograd function for 2nd derivatives
                    # torch.mm doesn't do double derivatives with sparse w3j
                    tmp = _ExplicitGradSpmm.apply(sp, x)
                else:
                    # For inference, assume only first derivatives necessary
                    tmp = torch.mm(sp, x)
                return tmp

            def _make_tracing_inputs(self, n: int):
                return [
                    {
                        "forward": (
                            torch.randn(
                                self.shape[-1],
                                3,
                                device=self._v.device,
                                dtype=self._v.dtype,
                            ),
                        )
                    }
                    for _ in range(n)
                ]

    def ExplicitGradSpmm(mat):
        """Build the sparse-matmul module that matches the layout of ``mat``.

        Args:
            mat: a sparse COO tensor, or (with torch >= 1.10) a sparse CSR tensor.

        Returns:
            An ``ExplicitGradSpmmCOO`` or ``ExplicitGradSpmmCSR`` wrapping ``mat``.

        Raises:
            TypeError: if ``mat`` is neither sparse COO nor a supported sparse CSR tensor.
        """
        if mat.is_sparse:
            return ExplicitGradSpmmCOO(mat)
        # `is_sparse_csr` is read defensively in case the attribute is missing;
        # the CSR module only exists when the version flag is set.
        if getattr(mat, "is_sparse_csr", False) and _TORCH_IS_GE_1_10:
            return ExplicitGradSpmmCSR(mat)
        raise TypeError(
            "ExplicitGradSpmm expects a sparse COO tensor, or a sparse CSR tensor "
            f"with torch >= 1.10, but got layout {getattr(mat, 'layout', None)}"
        )

else:  # _USE_PYG_SPARSE

    from torch_sparse import SparseTensor
    from torch_sparse.matmul import spmm_add

    class ExplicitGradSpmm(torch.nn.Module):
        """Sparse-matmul module backed by torch_sparse instead of native torch sparse tensors."""

        def __init__(self, mat):
            super().__init__()
            # Densify first, then let torch_sparse build its own representation.
            dense = mat.to_dense()
            self.mat = SparseTensor.from_dense(dense)

        def forward(self, x):
            result = spmm_add(self.mat, x)
            return result
