### Making wrapper for QM9 dataset

In [1]:
# QM9 -> dgl

import os
import sys

import dgl
import numpy as np
import torch

from torch.utils.data import Dataset, DataLoader

from scipy.constants import physical_constants

hartree2eV = physical_constants['hartree-electron volt relationship'][0]
DTYPE = np.float32
DTYPE_INT = np.int32

class QM9Dataset(Dataset):
    """QM9 dataset."""
    num_bonds = 4
    atom_feature_size = 6 
    input_keys = ['mol_id', 'num_atoms', 'num_bonds', 'x', 'one_hot', 
                  'atomic_numbers', 'edge']
    unit_conversion = {'mu': 1.0,
                       'alpha': 1.0,
                       'homo': hartree2eV,
                       'lumo': hartree2eV,
                       'gap': hartree2eV, 
                       'r2': 1.0, 
                       'zpve': hartree2eV, 
                       'u0': hartree2eV, 
                       'u298': hartree2eV, 
                       'h298': hartree2eV,
                       'g298': hartree2eV,
                       'cv': 1.0} 

    def __init__(self, file_address: str, task: str, mode: str='train', 
            transform=None, fully_connected: bool=False): 
        """Create a dataset object

        Args:
            file_address: path to data
            task: target task ["homo", ...]
            mode: [train/val/test] mode
            transform: data augmentation functions
            fully_connected: return a fully connected graph
        """
        self.file_address = file_address
        self.task = task
        self.mode = mode
        self.transform = transform
        self.fully_connected = fully_connected

        # Encode and extra bond type for fully connected graphs
        self.num_bonds += fully_connected

        self.load_data()
        self.len = len(self.targets)
        print(f"Loaded {mode}-set, task: {task}, source: {self.file_address}, length: {len(self)}")

    
    def __len__(self):
        return self.len

    
    def load_data(self):
        # Load dict and select train/valid/test split
        data = torch.load(self.file_address)
        data = data[self.mode]
    
        # Filter out the inputs
        self.inputs = {key: data[key] for key in self.input_keys}

        # Filter out the targets and population stats
        self.targets = data[self.task]

        # TODO: use the training stats unlike the other papers
        self.mean = np.mean(self.targets)
        self.std = np.std(self.targets)


    def get_target(self, idx, normalize=True):
        target = self.targets[idx]
        if normalize:
            target = (target - self.mean) / self.std
        return target


    def norm2units(self, x, denormalize=True, center=True):
        # Convert from normalized to QM9 representation
        if denormalize:
            x = x * self.std
            # Add the mean: not necessary for error computations
            if not center:
                x += self.mean
        x = self.unit_conversion[self.task] * x
        return x


    def to_one_hot(self, data, num_classes):
        one_hot = np.zeros(list(data.shape) + [num_classes])
        one_hot[np.arange(len(data)),data] = 1
        return one_hot


    def _get_adjacency(self, n_atoms):
        # Adjust adjacency structure
        seq = np.arange(n_atoms)
        src = seq[:,None] * np.ones((1,n_atoms), dtype=np.int32)
        dst = src.T
        ## Remove diagonals and reshape
        src[seq, seq] = -1
        dst[seq, seq] = -1
        src, dst = src.reshape(-1), dst.reshape(-1)
        src, dst = src[src > -1], dst[dst > -1]
            
        return src, dst


    def get(self, key, idx):
        return self.inputs[key][idx]


    def connect_fully(self, edges, num_atoms):
        """Convert to a fully connected graph"""
        # Initialize all edges: no self-edges
        adjacency = {}
        for i in range(num_atoms):
            for j in range(num_atoms):
                if i != j:
                    # assigning new type of connection if originally not connected
                    adjacency[(i, j)] = self.num_bonds - 1 

        # Add bonded edges
        for idx in range(edges.shape[0]):
            adjacency[(edges[idx,0], edges[idx,1])] = edges[idx,2]
            adjacency[(edges[idx,1], edges[idx,0])] = edges[idx,2]

        # Convert to numpy arrays
        src = []
        dst = []
        w = []
        for edge, weight in adjacency.items():
            src.append(edge[0])
            dst.append(edge[1])
            w.append(weight)

        return np.array(src), np.array(dst), np.array(w)


    def connect_partially(self, edge):
        src = np.concatenate([edge[:,0], edge[:,1]])
        dst = np.concatenate([edge[:,1], edge[:,0]])
        w = np.concatenate([edge[:,2], edge[:,2]])
        return src, dst, w


    def __getitem__(self, idx):
        # Load node features
        num_atoms = self.get('num_atoms', idx) # number of atoms
        x = self.get('x', idx)[:num_atoms].astype(DTYPE) # coordinates of atoms
        one_hot = self.get('one_hot', idx)[:num_atoms].astype(DTYPE)
        atomic_numbers = self.get('atomic_numbers', idx)[:num_atoms].astype(DTYPE)

        # Load edge features
        num_bonds = self.get('num_bonds', idx)
        edge = self.get('edge', idx)[:num_bonds]
        edge = np.asarray(edge, dtype=DTYPE_INT)

        # Load target
        y = self.get_target(idx, normalize=True).astype(DTYPE)
        y = np.array([y])

        # Augmentation on the coordinates
        if self.transform:
            x = self.transform(x).astype(DTYPE)

        # Create nodes
        if self.fully_connected:
            src, dst, w = self.connect_fully(edge, num_atoms)
        else:
            src, dst, w = self.connect_partially(edge)
        w = self.to_one_hot(w, self.num_bonds).astype(DTYPE)

        # Create graph
        G = dgl.DGLGraph((src, dst))

        # Add node features to graph
        G.ndata['x'] = torch.tensor(x) #[num_atoms,3]
        G.ndata['f'] = torch.tensor(np.concatenate([one_hot, atomic_numbers], -1)[...,None]) #[num_atoms,6,1]

        # Add edge features to graph
        G.edata['d'] = torch.tensor(x[dst] - x[src]) #[num_atoms,3]
        G.edata['w'] = torch.tensor(w) #[num_atoms,4]

        return G, y



def collate(samples):
    graphs, y = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(y)

dataset = QM9Dataset('./QM9_data/QM9_data.pt', "homo", mode='train', fully_connected=True)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate)

iter_dataloader = iter(dataloader) # so I can use next
for i in range(1):
    data = next(iter_dataloader)
    print("MINIBATCH")
    print(data[0]) 
    print(data[1].shape) # batch size -> connected graph of size batch

Loaded train-set, task: homo, source: ./QM9_data/QM9_data.pt, length: 100000
MINIBATCH
Graph(num_nodes=558, num_edges=9450,
      ndata_schemes={'x': Scheme(shape=(3,), dtype=torch.float32), 'f': Scheme(shape=(6, 1), dtype=torch.float32)}
      edata_schemes={'d': Scheme(shape=(3,), dtype=torch.float32), 'w': Scheme(shape=(5,), dtype=torch.float32)})
torch.Size([32, 1])


  return batched_graph, torch.tensor(y)


In [2]:
# Some tests of some parts
data = torch.load('./QM9_data/QM9_data.pt')

# max number of atoms times one hot encoding of type
print(data['train']['one_hot'][1][0]) 

# Edge src, dst, type
print(data['train']['edge'][1][0]) 
print(data['train']['edge'][1][1])

[False  True False False False]
[4 5 1]
[ 2 12  0]


In [3]:
def _get_adjacency(n_atoms):
    # Adjust adjacency structure
    seq = np.arange(n_atoms)
    src = seq[:,None] * np.ones((1,n_atoms), dtype=np.int32)
    dst = src.T
    ## Remove diagonals and reshape
    src[seq, seq] = -1
    dst[seq, seq] = -1
    src, dst = src.reshape(-1), dst.reshape(-1)
    src, dst = src[src > -1], dst[dst > -1]

    return src, dst

_get_adjacency(10) # from src to dst fully connected in this case, no self connections

# all the connections are represented in the form
# src[0]   src[1]   ...
# dst[0]   dst[1]   ...
# w[0]     w[1]

(array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9,
        9, 9]),
 array([1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 3, 4,
        5, 6, 7, 8, 9, 0, 1, 2, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 5, 6, 7, 8,
        9, 0, 1, 2, 3, 4, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 7, 8, 9, 0, 1, 2,
        3, 4, 5, 6, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 9, 0, 1, 2, 3, 4, 5, 6,
        7, 8]))

### DGL example

In [4]:
data

{'train': {'mol_id': array(['gdb_27331', 'gdb_71477', 'gdb_79811', ..., 'gdb_62673',
         'gdb_132497', 'gdb_10505'], dtype='<U10'),
  'A': array([3.32485, 4.067  , 2.09048, ..., 3.09118, 3.12223, 3.19768]),
  'B': array([1.36526, 1.20944, 1.93684, ..., 1.36482, 1.30785, 1.34433]),
  'C': array([0.96881, 1.09028, 1.66715, ..., 1.18845, 0.92701, 1.08382]),
  'mu': array([3.0485, 3.2222, 1.617 , ..., 2.5608, 2.4198, 3.326 ]),
  'alpha': array([74.35, 74.76, 76.71, ..., 74.9 , 67.48, 71.83]),
  'homo': array([-0.1975, -0.2405, -0.2391, ..., -0.2432, -0.2487, -0.2371]),
  'lumo': array([-0.009 , -0.0223,  0.0552, ...,  0.0687, -0.0362, -0.0185]),
  'gap': array([0.1885, 0.2182, 0.2943, ..., 0.3119, 0.2125, 0.2185]),
  'r2': array([1134.1372, 1125.2513,  972.2918, ..., 1104.5796, 1137.5235,
         1137.4255]),
  'zpve': array([0.11425 , 0.134161, 0.172351, ..., 0.156702, 0.100119, 0.163796]),
  'u0': array([-453.976855, -421.777683, -403.126103, ..., -422.970801,
         -486.005005,

In [5]:
G

NameError: name 'G' is not defined

In [None]:
import dgl
import dgl.function as fn


dataset = data



# initialize dgl graph
G = dataset[0]

ntype = 'x'
# ndata is a dict
# retrieve data from all nodes labels ntype (position example)
print(G.ndata[ntype]) 

d = 5
G.ndata[f'out{d}'] = torch.tensor(np.zeros((len(G.ndata['x']), 5)))
# retrieve output features of type d from node data
print(G.ndata[f'out{d}']) 


etype = 'd'
# retrive data from all edges labeled etype
print(G.edata[etype].shape)

di = 'd'
do = 'w'
G.edata[f'({di},{do})'] = torch.tensor(np.random.rand(G.edata[etype].shape[0], 3, 4))
# retrive edge kernels that transform from type di to type di
print(G.edata[f'({di},{do})'])

e = 'd' # edge feature (distance)
v = 'x' # node features (cartesian vector)
m = 'm' # output message on the edge
# calling built in dgl fubction e_dot_v that computes a message on
# edge by performing element-wise dot between features of e and v
# and stores it as edge message labeled 'm'
f = fn.e_dot_v(e, v, m)

# applies the function f to update the features of the edges with function
G.apply_edges(f)

print(G.edata['m'].shape)

In [None]:
?fn.e_dot_v

In [6]:
dataset[0][0]

Graph(num_nodes=15, num_edges=210,
      ndata_schemes={'x': Scheme(shape=(3,), dtype=torch.float32), 'f': Scheme(shape=(6, 1), dtype=torch.float32)}
      edata_schemes={'d': Scheme(shape=(3,), dtype=torch.float32), 'w': Scheme(shape=(5,), dtype=torch.float32)})

### Fibers

In [7]:
#from utils.utils_profiling import * # load before other local modules
try:
    profile
except NameError:
    def profile(func):
        return func

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

from typing import Dict, List, Tuple


class Fiber(object):
    """A Handy Data Structure for Fibers"""
    def __init__(self, num_degrees: int=None, num_channels: int=None,
                 structure: List[Tuple[int,int]]=None, dictionary=None):
        """
        define fiber structure; use one num_degrees & num_channels OR structure
        OR dictionary

        :param num_degrees: degrees will be [0, ..., num_degrees-1]
        :param num_channels: number of channels, same for each degree
        :param structure: e.g. [(32, 0),(16, 1),(16,2)]
        :param dictionary: e.g. {0:32, 1:16, 2:16}
        
        Structure in the form: List[(Tuple[int, int])]. In particular Features[(num_channels, feature_degree)]
        """
        
        if structure:
            self.structure = structure
        elif dictionary:
            self.structure = [(dictionary[o], o) for o in sorted(dictionary.keys())]
        else:
            self.structure = [(num_channels, i) for i in range(num_degrees)]

            
        # assigning to dict format and computing cummulative variables
        self.multiplicities, self.degrees = zip(*self.structure)
        self.max_degree = max(self.degrees)
        self.min_degree = min(self.degrees)
        self.structure_dict = {k: v for v, k in self.structure}
        self.dict = self.structure_dict
        self.n_features = np.sum([i[0] * (2*i[1]+1) for i in self.structure])

        
        # Mapping to vec() case. f = [...] with starting ind saved in feature ind dict
        # feature_ind = {degree: starting ind}
        self.feature_indices = {}
        idx = 0
        for (num_channels, d) in self.structure:
            length = num_channels * (2*d + 1)
            self.feature_indices[d] = (idx, idx + length)
            idx += length

    def copy_me(self, multiplicity: int=None):
        s = copy.deepcopy(self.structure)
        if multiplicity is not None:
            # overwrite multiplicities
            s = [(multiplicity, o) for m, o in s]
        return Fiber(structure=s)

    @staticmethod
    def combine(f1, f2):
        new_dict = copy.deepcopy(f1.structure_dict)
        for k, m in f2.structure_dict.items():
            if k in new_dict.keys():
                new_dict[k] += m
            else:
                new_dict[k] = m
        structure = [(new_dict[k], k) for k in sorted(new_dict.keys())]
        return Fiber(structure=structure)

    @staticmethod
    def combine_max(f1, f2):
        new_dict = copy.deepcopy(f1.structure_dict)
        for k, m in f2.structure_dict.items():
            if k in new_dict.keys():
                new_dict[k] = max(m, new_dict[k])
            else:
                new_dict[k] = m
        structure = [(new_dict[k], k) for k in sorted(new_dict.keys())]
        return Fiber(structure=structure)

    @staticmethod
    def combine_selectively(f1, f2):
        # only use orders which occur in fiber f1

        new_dict = copy.deepcopy(f1.structure_dict)
        for k in f1.degrees:
            if k in f2.degrees:
                new_dict[k] += f2.structure_dict[k]
        structure = [(new_dict[k], k) for k in sorted(new_dict.keys())]
        return Fiber(structure=structure)

    @staticmethod
    def combine_fibers(val1, struc1, val2, struc2):
        """
        combine two fibers

        :param val1/2: fiber tensors in dictionary form
        :param struc1/2: structure of fiber
        :return: fiber tensor in dictionary form
        """
        struc_out = Fiber.combine(struc1, struc2)
        val_out = {}
        for k in struc_out.degrees:
            if k in struc1.degrees:
                if k in struc2.degrees:
                    val_out[k] = torch.cat([val1[k], val2[k]], -2)
                else:
                    val_out[k] = val1[k]
            else:
                val_out[k] = val2[k]
                
            # number of channels is the second dimenstion from the end I guess
            # might look like [tensor_axis = degree, channel axis, tensor-component axis]
            assert val_out[k].shape[-2] == struc_out.structure_dict[k]
        return val_out

    def __repr__(self):
        return f"{self.structure}"



def get_fiber_dict(F, struc, mask=None, return_struc=False):
    if mask is None: mask = struc
    index = 0
    fiber_dict = {}
    first_dims = F.shape[:-1]
    masked_dict = {}
    for o, m in struc.structure_dict.items():
        length = m * (2*o + 1)
        if o in mask.degrees:
            masked_dict[o] = m
            fiber_dict[o] = F[...,index:index + length].view(list(first_dims) + [m, 2*o + 1])
        index += length
    assert F.shape[-1] == index
    if return_struc:
        return fiber_dict, Fiber(dictionary=masked_dict)
    return fiber_dict


def get_fiber_tensor(F, struc):
    some_entry = tuple(F.values())[0]
    first_dims = some_entry.shape[:-2]
    res = some_entry.new_empty([*first_dims, struc.n_features])
    index = 0
    for o, m in struc.structure_dict.items():
        length = m * (2*o + 1)
        res[..., index: index + length] = F[o].view(*first_dims, length)
        index += length
    assert index == res.shape[-1]
    return res


def fiber2tensor(F, structure, squeeze=False):
    if squeeze:
        fibers = [F[f'{i}'].view(*F[f'{i}'].shape[:-2], -1) for i in structure.degrees]
        fibers = torch.cat(fibers, -1)
    else:
        fibers = [F[f'{i}'].view(*F[f'{i}'].shape[:-2], -1, 1) for i in structure.degrees]
        fibers = torch.cat(fibers, -2)
    return fibers


# Reduce fibers into single tensor cell h (I guess)
def fiber2head(F, h, structure, squeeze=False):
    if squeeze:
        fibers = [F[f'{i}'].view(*F[f'{i}'].shape[:-2], h, -1) for i in structure.degrees]
        fibers = torch.cat(fibers, -1)
    else:
        fibers = [F[f'{i}'].view(*F[f'{i}'].shape[:-2], h, -1, 1) for i in structure.degrees]
        fibers = torch.cat(fibers, -2)
    return fibers

### Basis transformation matrixes and irreducable representation

In [8]:
'''
Cache in files
'''
from functools import wraps, lru_cache
import pickle
import gzip
import os
import sys
import fcntl


class FileSystemMutex:
    '''
    Mutual exclusion of different **processes** using the file system
    '''

    def __init__(self, filename):
        self.handle = None
        self.filename = filename

    def acquire(self):
        '''
        Locks the mutex
        if it is already locked, it waits (blocking function)
        '''
        self.handle = open(self.filename, 'w')
        fcntl.lockf(self.handle, fcntl.LOCK_EX)
        self.handle.write("{}\n".format(os.getpid()))
        self.handle.flush()

    def release(self):
        '''
        Unlock the mutex
        '''
        if self.handle is None:
            raise RuntimeError()
        fcntl.lockf(self.handle, fcntl.LOCK_UN)
        self.handle.close()
        self.handle = None

    def __enter__(self):
        self.acquire()

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()


def cached_dirpklgz(dirname, maxsize=128):
    '''
    Cache a function with a directory

    :param dirname: the directory path
    :param maxsize: maximum size of the RAM cache (there is no limit for the directory cache)
    '''

    def decorator(func):
        '''
        The actual decorator
        '''

        @lru_cache(maxsize=maxsize)
        @wraps(func)
        def wrapper(*args, **kwargs):
            '''
            The wrapper of the function
            '''
            try:
                os.makedirs(dirname)
            except FileExistsError:
                pass

            indexfile = os.path.join(dirname, "index.pkl")
            mutexfile = os.path.join(dirname, "mutex")

            with FileSystemMutex(mutexfile):
                try:
                    with open(indexfile, "rb") as file:
                        index = pickle.load(file)
                except FileNotFoundError:
                    index = {}

                key = (args, frozenset(kwargs), func.__defaults__)

                try:
                    filename = index[key]
                except KeyError:
                    index[key] = filename = "{}.pkl.gz".format(len(index))
                    with open(indexfile, "wb") as file:
                        pickle.dump(index, file)

            filepath = os.path.join(dirname, filename)

            try:
                with FileSystemMutex(mutexfile):
                    with gzip.open(filepath, "rb") as file:
                        result = pickle.load(file)
            except FileNotFoundError:
                print("compute {}... ".format(filename), end="")
                sys.stdout.flush()
                result = func(*args, **kwargs)
                print("save {}... ".format(filename), end="")
                sys.stdout.flush()
                with FileSystemMutex(mutexfile):
                    with gzip.open(filepath, "wb") as file:
                        pickle.dump(result, file)
                print("done")
            return result

        return wrapper

    return decorator

In [9]:
import torch
import math
import numpy as np

class torch_default_dtype:

    def __init__(self, dtype):
        self.saved_dtype = None
        self.dtype = dtype

    def __enter__(self):
        self.saved_dtype = torch.get_default_dtype()
        torch.set_default_dtype(self.dtype)

    def __exit__(self, exc_type, exc_value, traceback):
        torch.set_default_dtype(self.saved_dtype)
        


def irr_repr(order, alpha, beta, gamma, dtype=None):
    """
    irreducible representation of SO3
    - compatible with compose and spherical_harmonics
    """
    # from from_lielearn_SO3.wigner_d import wigner_D_matrix
    from lie_learn.representations.SO3.wigner_d import wigner_D_matrix
    # if order == 1:
    #     # change of basis to have vector_field[x, y, z] = [vx, vy, vz]
    #     A = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
    #     return A @ wigner_D_matrix(1, alpha, beta, gamma) @ A.T

    # TODO (non-essential): try to do everything in torch
    # return torch.tensor(wigner_D_matrix(torch.tensor(order), alpha, beta, gamma), dtype=torch.get_default_dtype() if dtype is None else dtype)
    return torch.tensor(wigner_D_matrix(order, np.array(alpha), np.array(beta), np.array(gamma)), dtype=torch.get_default_dtype() if dtype is None else dtype)




In [10]:
### example for irrep
# Weigner_D matrix
irr_repr(1, 4., 3., 2.)

tensor([[-0.4093, -0.1068, -0.9061],
        [ 0.1283, -0.9900,  0.0587],
        [-0.9033, -0.0922,  0.4189]])

In [11]:
import torch
import math
import numpy as np
from torch import kron

################################################################################
# Solving the constraint coming from the stabilizer of 0 and e
################################################################################

# Get's eigenvectors for eigvalue equel close by eps to zero
def get_matrix_kernel(A, eps=1e-10):
    '''
    Compute an orthonormal basis of the kernel (x_1, x_2, ...)
    A x_i = 0
    scalar_product(x_i, x_j) = delta_ij

    :param A: matrix
    :return: matrix where each row is a basis vector of the kernel of A
    '''
    _u, s, v = torch.svd(A)

    # A = u @ torch.diag(s) @ v.t()
    kernel = v.t()[s < eps]
    return kernel


# Stacks the matrix to big matrix and does the same as previous function
def get_matrices_kernel(As, eps=1e-10):
    '''
    Computes the commun kernel of all the As matrices
    '''
    return get_matrix_kernel(torch.cat(As, dim=0), eps)


@cached_dirpklgz("cache/trans_Q")
def _basis_transformation_Q_J(J, order_in, order_out, version=3):  # pylint: disable=W0613
    """
    :param J: order of the spherical harmonics
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :return: one part of the Q^-1 matrix of the article
    """
    with torch_default_dtype(torch.float64):
        def _R_tensor(a, b, c): return kron(irr_repr(order_out, a, b, c), irr_repr(order_in, a, b, c))

        def _sylvester_submatrix(J, a, b, c):
            ''' generate Kronecker product matrix for solving the Sylvester equation in subspace J '''
            R_tensor = _R_tensor(a, b, c)  # [m_out * m_in, m_out * m_in]
            R_irrep_J = irr_repr(J, a, b, c)  # [m, m]
            return kron(R_tensor, torch.eye(R_irrep_J.size(0))) - \
                kron(torch.eye(R_tensor.size(0)).t(), R_irrep_J.t())  # [(m_out * m_in) * m, (m_out * m_in) * m]
        
        # some random angles to enshure equivariance
        random_angles = [
            [4.41301023, 5.56684102, 4.59384642],
            [4.93325116, 6.12697327, 4.14574096],
            [0.53878964, 4.09050444, 5.36539036],
            [2.16017393, 3.48835314, 5.55174441],
            [2.52385107, 0.2908958, 3.90040975]
        ]
        null_space = get_matrices_kernel([_sylvester_submatrix(J, a, b, c) for a, b, c in random_angles])
        assert null_space.size(0) == 1, null_space.size()  # unique subspace solution
        Q_J = null_space[0]  # [(m_out * m_in) * m]
        Q_J = Q_J.view((2 * order_out + 1) * (2 * order_in + 1), 2 * J + 1)  # [m_out * m_in, m]
        assert all(torch.allclose(_R_tensor(a, b, c) @ Q_J, Q_J @ irr_repr(J, a, b, c)) for a, b, c in torch.rand(4, 3))

    assert Q_J.dtype == torch.float64
    return Q_J  # [m_out * m_in, m]

In [12]:
 _basis_transformation_Q_J(1, 1, 1, version=3)

tensor([[ 9.0557e-18, -1.1249e-16, -7.4619e-17],
        [ 1.3221e-17,  1.2627e-17, -4.0825e-01],
        [-9.2577e-17,  4.0825e-01,  1.0174e-16],
        [ 2.5382e-17, -1.7291e-16,  4.0825e-01],
        [ 8.9704e-17,  4.9655e-18,  3.9990e-17],
        [-4.0825e-01, -2.6838e-17, -3.0802e-17],
        [-3.0997e-17, -4.0825e-01, -7.5482e-18],
        [ 4.0825e-01,  3.4930e-17,  3.7093e-17],
        [-4.0149e-18, -3.0270e-17, -1.6411e-17]], dtype=torch.float64)

In [13]:
from torch import kron

# some code deconstruction of Q^lk_j^{-1}
# helper function that calculates the Kronecker product between
# type-k and type-l Wigner-D matrices for rotations a, b, c
def _R_tensor(a, b, c):
    # kron calculates the kroneker product between two matrices
    # irr_repr returns the irrep from (type, alpha, beta, gamma)
    # Remember the order A x B = kron(B, A)
    return kron(irr_repr(l, a, b, c), irr_repr(k, a, b, c))

l = 1
k = 2
a = 4.; b = 3.; c = 2.;

print("l tensor shape", irr_repr(l, a, b, c).shape)
print("k tensor shape", irr_repr(k, a, b, c).shape)
print("Kron shape", _R_tensor(a, b, c).shape)

l tensor shape torch.Size([3, 3])
k tensor shape torch.Size([5, 5])
Kron shape torch.Size([15, 15])


In [14]:
# Computes submatrix to solve sylvester equation
# AX - XB = 0
# same as (I x A - B^T x I) vec(X) = 0
def _sylvester_submatrix(J, a, b, c):
    # Calculates the Kroneker product between type-l and type-k
    # Wigner-D matrices for rotation angles a, b, c
    R_tensor = _R_tensor(a, b, c) # [(2l + 1)*(2k + 1), (2l + 1)*(2k + 1)]
    # Calculates type-J Wigner-D matrix for same rotation
    R_irrep_J = irr_repr(J, a, b, c) # [2J + 1, 2J + 1]
    # .reshape(9).reshape(3, 3) Annoying stuff due to some torch bug with data placement in memory
    return kron(R_tensor, torch.eye(R_irrep_J.size(0))) - kron(torch.eye(R_tensor.size(0)).t(), R_irrep_J.t())

J = 3

_sylvester_submatrix(J, a, b, c).shape

torch.Size([105, 105])

In [15]:
# Check on random angles
with torch_default_dtype(torch.float64): # !!! Important otherwise zero
    # some random angles to enshure equivariance
    random_angles = [
        [4.41301023, 5.56684102, 4.59384642],
        [4.93325116, 6.12697327, 4.14574096],
        [0.53878964, 4.09050444, 5.36539036],
        [2.16017393, 3.48835314, 5.55174441],
        [2.52385107, 0.2908958, 3.90040975]
    ]
    # Calculate the vector that is solution of the homogeneous equation
    # for all sets of angles
    null_space = get_matrices_kernel([_sylvester_submatrix(J, a, b, c)
                                      for a, b, c in random_angles])
    # confirm that the solution is unique
    assert null_space.size(0) == 1, null_space.size()

In [16]:
# Final Q^lk_J compute and reshape to (2 * l + 1) * (2 * k + 1) * (2 * J + 1)
with torch_default_dtype(torch.float64): # !!! Important otherwise zero
    Q_J = null_space[0] # only one vector
    Q_J = Q_J.view((2 * l + 1) * (2 * k + 1), 2 * J + 1) # unvectorize
    assert all(torch.allclose(_R_tensor(a, b, c) @ Q_J, Q_J @ irr_repr(J, a, b, c)) 
               for a, b, c in torch.rand(4, 3)) # sanity check that is a solution
    
Q_J.shape

torch.Size([15, 7])

### Spherical harmonics

In [45]:
import time

import torch
import numpy as np
from scipy.special import lpmv as lpmv_scipy


def semifactorial(x):
    """Compute the semifactorial function x!!.

    x!! = x * (x-2) * (x-4) *...

    Args:
        x: positive int
    Returns:
        float for x!!
    """
    y = 1.
    for n in range(x, 1, -2):
        y *= n
    return y


def pochhammer(x, k):
    """Compute the pochhammer symbol (x)_k.

    (x)_k = x * (x+1) * (x+2) *...* (x+k-1)

    Args:
        x: positive int
    Returns:
        float for (x)_k
    """
    xf = float(x)
    for n in range(x+1, x+k):
        xf *= n
    return xf

def lpmv(l, m, x):
    """Associated Legendre function including Condon-Shortley phase.

    Args:
        m: int order 
        l: int degree
        x: float argument tensor
    Returns:
        tensor of x-shape
    """
    m_abs = abs(m)
    
    # P^m_J = 0 forall m > J
    if m_abs > l:
        return torch.zeros_like(x)

    # Compute P_m^m
    # P_m^m = (-1)^J (1 - x^2)^(J/2) (2J - 1)!!
    yold = ((-1)**m_abs * semifactorial(2*m_abs-1)) * torch.pow(1-x*x, m_abs/2)
    
    # Compute P_{m+1}^m
    # P_m+1^m = x (2 m + 1) P^m_m
    if m_abs != l:
        y = x * (2*m_abs+1) * yold
    else:
        y = yold

    # Compute P_{l}^m from recursion in P_{l-1}^m and P_{l-2}^m
    # P_l^m (x) = [(2 l - 1 ) / ( l - m )] P^m_{l - 1} (x) - [(l + m - 1) / (l - m)] P^m_{l - 2} (x)
    for i in range(m_abs+2, l+1):
        tmp = y
        # Inplace speedup
        y = ((2*i-1) / (i-m_abs)) * x * y
        y -= ((i+m_abs-1)/(i-m_abs)) * yold
        yold = tmp

    # P^-m_l (x) = (-1)^m (l - m)!/(l + m)! P^m_l (x)
    if m < 0:
        y *= ((-1)**m / pochhammer(l+m+1, -2*m))

    return y

In [46]:
leg = {}
J = 10
m = 5
x = torch.Tensor([0.5, 0.1, 0.2])
ans_1 = lpmv(J, m, x)

In [47]:
class SphericalHarmonics(object):
    def __init__(self):
        self.leg = {}

    def clear(self):
        self.leg = {}

    def negative_lpmv(self, l, m, y):
        """Compute negative order coefficients"""
        if m < 0:
            y *= ((-1)**m / pochhammer(l+m+1, -2*m))
        return y

    def lpmv(self, l, m, x):
        """Associated Legendre function including Condon-Shortley phase.

        Args:
            m: int order 
            l: int degree
            x: float argument tensor
        Returns:
            tensor of x-shape
        """
        # Check memoized versions
        m_abs = abs(m)
        if (l,m) in self.leg:
            return self.leg[(l,m)]
        elif m_abs > l:
            return None
        elif l == 0:
            self.leg[(l,m)] = torch.ones_like(x)
            return self.leg[(l,m)]
        
        # Check if on boundary else recurse solution down to boundary
        if m_abs == l:
            # Compute P_m^m
            y = (-1)**m_abs * semifactorial(2*m_abs-1)
            y *= torch.pow(1-x*x, m_abs/2)
            self.leg[(l,m)] = self.negative_lpmv(l, m, y)
            return self.leg[(l,m)]
        else:
            # Recursively precompute lower degree harmonics
            self.lpmv(l-1, m, x)

        # Compute P_{l}^m from recursion in P_{l-1}^m and P_{l-2}^m
        # Inplace speedup
        y = ((2*l-1) / (l-m_abs)) * x * self.lpmv(l-1, m_abs, x)
        if l - m_abs > 1:
            y -= ((l+m_abs-1)/(l-m_abs)) * self.leg[(l-2, m_abs)]
        #self.leg[(l, m_abs)] = y
        
        if m < 0:
            y = self.negative_lpmv(l, m, y)
        self.leg[(l,m)] = y

        return self.leg[(l,m)]

    def get_element(self, l, m, theta, phi):
        """Tesseral spherical harmonic with Condon-Shortley phase.

        The Tesseral spherical harmonics are also known as the real spherical
        harmonics.

        Args:
            l: int for degree
            m: int for order, where -l <= m < l
            theta: collatitude or polar angle
            phi: longitude or azimuth
        Returns:
            tensor of shape theta
        """
        assert abs(m) <= l, "absolute value of order m must be <= degree l"

        N = np.sqrt((2*l+1) / (4*np.pi))
        leg = self.lpmv(l, abs(m), torch.cos(theta))
        if m == 0:
            return N*leg
        elif m > 0:
            Y = torch.cos(m*phi) * leg
        else:
            Y = torch.sin(abs(m)*phi) * leg
        N *= np.sqrt(2. / pochhammer(l-abs(m)+1, 2*abs(m)))
        Y *= N
        return Y

    def get(self, l, theta, phi, refresh=True):
        """Tesseral harmonic with Condon-Shortley phase.

        The Tesseral spherical harmonics are also known as the real spherical
        harmonics.

        Args:
            l: int for degree
            theta: collatitude or polar angle
            phi: longitude or azimuth
        Returns:
            tensor of shape [*theta.shape, 2*l+1]
        """
        results = []
        if refresh:
            self.clear()
        for m in range(-l, l+1):
            results.append(self.get_element(l, m, theta, phi))
        return torch.stack(results, -1)

### My own implementation from spherical harmonics from the blog post

In [48]:
# First we will implement ALP (Associated Legandre Polynomials)

# function analagous to pochammer function in the SE(3) Transformer
# (J - m)!/(J + m)!
def falling_factorial(J, m):
    # computes (J + m)*(J+m - 1)*...(J-m+1)
    f = 1.
    for n in range(J + m, J - m, -1):
        f *= n
    return f


def semifactorial(x):
    """Compute the semifactorial function x!!.

    x!! = x * (x-2) * (x-4) *...

    Args:
        x: positive int
    Returns:
        float for x!!
    """
    y = 1.
    for n in range(x, 1, -2):
        y *= n
    return y

# y: Legendre polynimil for the absolute value of m
# P_J^{-m} (x) = (-1)^m (J - m)!/(J + m)! P_J^m(x)
def negative_lpmv(J, m, y):
    # check if m is negative
    if m < 0:
        # multiply y with the coefficient containing the falling 
        # factorial
        y *= ((-1)**m / falling_factorial(J, m))
    return y


In [49]:
falling_factorial(5, 5)

3628800.0

In [59]:
# recursive implementation of APL
def lpmv(J, m, x):
    # get the absolute value of m
    m_abs = abs(m)
    # check if the polynomial has already been computed
    if (J, m) in leg:
        return leg[(J, m)]
    # check if m is out of range -J to J
    elif m_abs > J:
        return None
    # if J = 0, the associated Legendre polynomial is equal to 1
    elif J == 0:
        # return tensor of 1s with the same shape as x
        leg[(l, m)] = torch.ones_like(x)
        return leg[(l, m)]
    
    # if |m| = J, compute the polynomial using the equation from step 1
    if m_abs == J:
        # P^J_J (x) = (-1)^J (1 - x^2)^(J/2) (2J - 1)!!
        # calculate coefficient term
        y = (-1)**J * semifactorial(2*J - 1)
        # multiply by the term dependent on x
        y *= torch.pow(1 - x*x, m_abs/2)
        # negative_lpmv returns y if m is positive and y multiplied by 
        # the negative coefficient defined in step 4 if m is negative
        leg[(l, m)] = negative_lpmv(l, m, y)
        return leg[(l, m)]
    else:
        # retursive call to compute lower degree polynomials up to
        # boundary m = J
        lpmv(J - 1, m, x)
        
    # if m is not on the boundary, first compute the first term of the relation
    # defined in step 3
    # P^J-1_J (x) = x (2J + 1) P^m_(J - 1) (x)
    # if m_abs = J - 1, then this calculates the relation defined in step 2
    y = ((2*J - 1)/(J - m_abs)) * x * lpmv(J - 1, m_abs, x)
    
    # P_J^m(x) = (2J - 1)/(J - m) x P^m_{J - 1}(x) - (J + m - 1)/(J - m) P^m_{J - 2} (x)
    # check if m_abs != J - 1, then add the second term defined in step 3
    if l - m_abs > 1:
        y -= ((l + m_abs -1)/(l - m_abs)) * leg[(l - 2, m_abs)]
        
    # if m is negative, return the polynomial for m_abs scaled by the 
    # negative coefficient
    if m < 0:
        y = negative_lpmv(l, m, y)
        
    leg[(l, m)] = y
    
    return leg[(l, m)]

In [60]:
leg = {}
J = 10
m = 5
x = torch.Tensor([0.5, 0.1, 0.2])
ans_2 = lpmv(J, m, x)

In [61]:
ans_1

tensor([ 30086.1719, -21961.9492, -26591.5078])

In [62]:
ans_2

tensor([-8.3058e+04, -5.3207e+01, -1.5766e+03])

In [63]:
# spherical harmonics

def get_element(J, m, theta, phi):
    assert abs(m) <= J, "m must be in the range -J to J"
    
    # calculates the first fraction in the square root
    N = np.sqrt((2*J + 1) / (4 * np.pi))
    # stores the ALP term in leg
    leg = lpmv(J, abs(m), torch.cos(theta))
    
    # multiply by the phi dependent term depending on the value of m
    if m == 0:
        # when m = 0 the other fraction in the square root cancels 
        # and the phi dependent term is 1
        return N * leg
    elif m > 0:
        Y = torch.cos(m*phi) * leg
    else:
        Y = torch.sin(abs(m) * phi) * leg
        
    # multiply by a square root of the inverse falling factorial
    # which is the same as in the ALP
    N *= np.sqrt(2. / falling_factorial(J, abs(m)))
    # multiplies the coefficient with angle-dependent term
    Y *= N
    return Y

In [67]:
# compute Y_J^m, m in {-J, J}
def get(J, theta, phi, refresh = True):
    # initialize tensor
    results = []
    
    # loop over all possible values of m from J to J and add the 
    # computed spherical harmonic to results
    for m in range(-J, J + 1):
        results.append(get_element(J, m, theta, phi))
        
    return torch.stack(results, -1)

In [72]:
get(3, torch.Tensor([0.5, 0.2, 0.1]), torch.Tensor([0.5, 0.2, 0.1]))

tensor([[-6.4857e-02,  2.4532e-01,        -inf,         inf,        -inf,
          1.5752e-01, -4.5993e-03],
        [-2.6125e-03,  2.1772e-02,        -inf,         inf,        -inf,
          5.1495e-02, -3.8186e-03],
        [-1.7350e-04,  2.8475e-03,        -inf,         inf,        -inf,
          1.4047e-02, -5.6087e-04]])

In [None]:
# r_ij: the relative displacement between nodes in spherical
# coordinates [radius, alpha, beta]
# beta = pi - theta (beta is 0 at south pole and pi at north pole;
# supplementary to theta
# alpha = phi (ranges from 0 to 2 pi)
# r_ij: shape (batch_size, nodes, neighbors, 3 (r_ij))
def precompute_sh(r_ij, max_J):
    # initialize dictionary where 