In [1]:
import os
from typing import List, Dict, Union, Optional, Tuple
from beartype import beartype
from tqdm import tqdm
from copy import deepcopy

import torch
from e3nn import o3
# from e3nn.o3._wigner import _Jd
# _Jd: Tuple[torch.Tensor] = tuple(J.detach().clone().to(dtype=torch.float32) for J in _Jd)
from e3nn.math._linalg import direct_sum
from e3nn.util.jit import compile_mode
from diffusion_edf.transforms import matrix_to_euler_angles, quaternion_to_matrix, standardize_quaternion, random_quaternions
from diffusion_edf import w3j


In [2]:
from diffusion_edf.wigner import quat_to_angle_fast, transform_feature_slice_nonscalar

class SliceAndTransform(torch.nn.Module):
    mul: torch.jit.Final[int]
    l: torch.jit.Final[int]
    dim: torch.jit.Final[int]
    start: torch.jit.Final[int]
    len: torch.jit.Final[int]
    
    def __init__(self, mul:int, l: int, start: int, end: int, allow_zero_len: bool = False):
        super().__init__()
        self.mul = mul
        self.l = l
        self.dim = 2*self.l+1
        self.register_buffer("J", w3j._Jd[l].detach().clone())
        
        if allow_zero_len:
            assert end >= start, f"end ({end}) < start ({start})"
        else:
            assert end > start, f"end ({end}) =< start ({start})"
        self.start = start
        self.len = end - start
        
    def forward(self, feature: torch.Tensor, alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
        sliced = torch.narrow(feature, dim=-1, start=self.start, length=self.len)
        assert sliced.shape[-1] == self.len, f"{sliced.shape[-1]} != {self.len}"
        if self.l == 0:
            return sliced.expand(len(alpha), len(sliced), self.len)
        else:
            return transform_feature_slice_nonscalar(feature=sliced, alpha=alpha, beta=beta, gamma=gamma, l=self.l, J=self.J)

class TransformFeatureQuaternion(torch.nn.Module):
    dim: torch.jit.Final[int]
    lmax: torch.jit.Final[int]
    # __constants__ = ['Js']
    
    def __init__(self, irreps: Union[str, o3.Irreps]):
        super().__init__()

        irreps = o3.Irreps(irreps)
        for n, (l,p) in irreps:
            if p != 1:
                raise NotImplementedError(f"E3 equivariance is not implemented! (input_irreps: {irreps})")
        self.dim = irreps.dim
        self.lmax = irreps.lmax
        
        self.transforms = torch.nn.ModuleList()
        for (mul, l), (start, end) in zip(
            tuple((mul, ir.l) for mul, ir in irreps), 
            tuple((slice_.start, slice_.stop) for slice_ in irreps.slices())
        ):
            self.transforms.append(
                SliceAndTransform(mul=mul, l=l, start=start, end=end, allow_zero_len=False)
            )
        
        
    def forward(self, feature: torch.Tensor, q: torch.Tensor) -> torch.Tensor : # (N_Q, N_D) x (N_T, 4) -> (N_T, N_Q, N_D)
        assert q.ndim == 2 and q.shape[-1] == 4, f"{q.shape}" # (nT, 4)
        assert feature.ndim == 2 and feature.shape[-1] == self.dim, f"{feature.shape}" # (nQ, D)

        # --------------------------------------------------- #
        # Return Identity if spin-0 only
        # --------------------------------------------------- #
        if self.lmax == 0:
            return feature.expand(len(q), -1, -1)
        
        # --------------------------------------------------- #
        # Quaternion to Euler angles
        # --------------------------------------------------- #
        q = standardize_quaternion(q / torch.norm(q, dim=-1, keepdim=True))
        angle = quat_to_angle_fast(q)
        alpha, beta, gamma = angle[0], angle[1], angle[2]
        
        # --------------------------------------------------- #
        # Quaternion to Euler angles
        # --------------------------------------------------- #
        feature_transformed = []
        for transform in self.transforms:
            feature_transformed.append(
                transform(feature=feature, alpha=alpha, beta=beta, gamma=gamma)
            )
        
        return torch.cat(feature_transformed, dim=-1)

In [3]:
from diffusion_edf.wigner import transform_feature_quat_
class OldTransformFeatureQuaternion(torch.nn.Module):
    def __init__(self, irreps: o3.Irreps):
        super().__init__()
        self.ls = tuple([ir.l for mul, ir in irreps])
        self.slices = tuple([(slice_.start, slice_.stop) for slice_ in irreps.slices()])
        self.Js = tuple(w3j._Jd[l] for l in self.ls)
        self.dim: int = o3.Irreps(irreps).dim

        for n, (l,p) in o3.Irreps(irreps):
            if p != 1:
                raise NotImplementedError(f"E3 equivariance is not implemented! (input_irreps: {o3.Irreps(irreps)})")
        self.lmax: int = o3.Irreps(irreps).lmax
        
    @torch.jit.ignore()
    def to(self, *args, **kwargs):
        self.Js = tuple(w3j._Jd[l].to(*args, **kwargs) for l in self.ls)
        for module in self.children():
            if isinstance(module, torch.nn.Module):
                module.to(*args, **kwargs)
        return super().to(*args, **kwargs)

    def forward(self, feature: torch.Tensor, q: torch.Tensor) -> torch.Tensor : # (N_Q, N_D) x (N_T, 4) -> (N_T, N_Q, N_D)
        assert q.ndim == 2 and q.shape[-1] == 4, f"{q.shape}" # (nT, 4)
        assert feature.ndim == 2 and feature.shape[-1] == self.dim, f"{feature.shape}" # (nQ, D)

        if self.lmax == 0:
            return feature.expand(len(q), -1, -1)
        
        feature_slices = []
        for slice_ in self.slices:
            feature_slices.append(feature[..., slice_[0]:slice_[1]])

        return transform_feature_quat_(ls=self.ls, feature_slices=feature_slices, Js=self.Js, q=q)

In [4]:
device = "cuda"
irreps = o3.Irreps("10x0e+12x1e+3x2e")

In [5]:
trans = TransformFeatureQuaternion(irreps=irreps)
trans = torch.jit.script(trans)
trans = trans.to(device)

In [6]:
feature = irreps.randn(100,-1, device=device)
q=random_quaternions(100, device=device)
for _ in tqdm(range(50)):
    trans(feature, q) 

100%|██████████| 50/50 [00:01<00:00, 47.65it/s]


In [7]:
old_trans = OldTransformFeatureQuaternion(irreps=irreps)
old_trans = old_trans.to(device)
old_trans = torch.jit.script(old_trans)

In [8]:
feature = irreps.randn(100,-1, device=device)
q=random_quaternions(100, device=device)
for _ in tqdm(range(50)):
    old_trans(feature, q)

  return forward_call(*input, **kwargs)
 does not have profile information (Triggered internally at ../torch/csrc/jit/codegen/cuda/graph_fuser.cpp:105.)
  return forward_call(*input, **kwargs)
100%|██████████| 50/50 [00:00<00:00, 305.62it/s]


In [9]:
irreps = o3.Irreps("10x0e+12x1e+3x2e")
feature = irreps.randn(100,-1, device=device)
q=random_quaternions(100, device=device)

In [10]:
for _ in tqdm(range(10000)):
    trans(feature, q) 

100%|██████████| 10000/10000 [00:03<00:00, 3016.57it/s]


In [11]:
for _ in tqdm(range(10000)):
    old_trans(feature, q) 

100%|██████████| 10000/10000 [00:03<00:00, 2945.21it/s]


New module is almost identical in speed but does not generate warning