### Making wrapper for QM9 dataset

In [2]:
# QM9 -> dgl

import os
import sys

import dgl
import numpy as np
import torch


from torch.utils.data import Dataset, DataLoader

from scipy.constants import physical_constants

hartree2eV = physical_constants['hartree-electron volt relationship'][0]
DTYPE = np.float32
DTYPE_INT = np.int32

class QM9Dataset(Dataset):
    """QM9 dataset."""
    num_bonds = 4
    atom_feature_size = 6 
    input_keys = ['mol_id', 'num_atoms', 'num_bonds', 'x', 'one_hot', 
                  'atomic_numbers', 'edge']
    unit_conversion = {'mu': 1.0,
                       'alpha': 1.0,
                       'homo': hartree2eV,
                       'lumo': hartree2eV,
                       'gap': hartree2eV, 
                       'r2': 1.0, 
                       'zpve': hartree2eV, 
                       'u0': hartree2eV, 
                       'u298': hartree2eV, 
                       'h298': hartree2eV,
                       'g298': hartree2eV,
                       'cv': 1.0} 

    def __init__(self, file_address: str, task: str, mode: str='train', 
            transform=None, fully_connected: bool=False): 
        """Create a dataset object

        Args:
            file_address: path to data
            task: target task ["homo", ...]
            mode: [train/val/test] mode
            transform: data augmentation functions
            fully_connected: return a fully connected graph
        """
        self.file_address = file_address
        self.task = task
        self.mode = mode
        self.transform = transform
        self.fully_connected = fully_connected

        # Encode and extra bond type for fully connected graphs
        self.num_bonds += fully_connected

        self.load_data()
        self.len = len(self.targets)
        print(f"Loaded {mode}-set, task: {task}, source: {self.file_address}, length: {len(self)}")

    
    def __len__(self):
        return self.len

    
    def load_data(self):
        # Load dict and select train/valid/test split
        data = torch.load(self.file_address)
        data = data[self.mode]
    
        # Filter out the inputs
        self.inputs = {key: data[key] for key in self.input_keys}

        # Filter out the targets and population stats
        self.targets = data[self.task]

        # TODO: use the training stats unlike the other papers
        self.mean = np.mean(self.targets)
        self.std = np.std(self.targets)


    def get_target(self, idx, normalize=True):
        target = self.targets[idx]
        if normalize:
            target = (target - self.mean) / self.std
        return target


    def norm2units(self, x, denormalize=True, center=True):
        # Convert from normalized to QM9 representation
        if denormalize:
            x = x * self.std
            # Add the mean: not necessary for error computations
            if not center:
                x += self.mean
        x = self.unit_conversion[self.task] * x
        return x


    def to_one_hot(self, data, num_classes):
        one_hot = np.zeros(list(data.shape) + [num_classes])
        one_hot[np.arange(len(data)),data] = 1
        return one_hot


    def _get_adjacency(self, n_atoms):
        # Adjust adjacency structure
        seq = np.arange(n_atoms)
        src = seq[:,None] * np.ones((1,n_atoms), dtype=np.int32)
        dst = src.T
        ## Remove diagonals and reshape
        src[seq, seq] = -1
        dst[seq, seq] = -1
        src, dst = src.reshape(-1), dst.reshape(-1)
        src, dst = src[src > -1], dst[dst > -1]
            
        return src, dst


    def get(self, key, idx):
        return self.inputs[key][idx]


    def connect_fully(self, edges, num_atoms):
        """Convert to a fully connected graph"""
        # Initialize all edges: no self-edges
        adjacency = {}
        for i in range(num_atoms):
            for j in range(num_atoms):
                if i != j:
                    # assigning new type of connection if originally not connected
                    adjacency[(i, j)] = self.num_bonds - 1 

        # Add bonded edges
        for idx in range(edges.shape[0]):
            adjacency[(edges[idx,0], edges[idx,1])] = edges[idx,2]
            adjacency[(edges[idx,1], edges[idx,0])] = edges[idx,2]

        # Convert to numpy arrays
        src = []
        dst = []
        w = []
        for edge, weight in adjacency.items():
            src.append(edge[0])
            dst.append(edge[1])
            w.append(weight)

        return np.array(src), np.array(dst), np.array(w)


    def connect_partially(self, edge):
        src = np.concatenate([edge[:,0], edge[:,1]])
        dst = np.concatenate([edge[:,1], edge[:,0]])
        w = np.concatenate([edge[:,2], edge[:,2]])
        return src, dst, w


    def __getitem__(self, idx):
        # Load node features
        num_atoms = self.get('num_atoms', idx) # number of atoms
        x = self.get('x', idx)[:num_atoms].astype(DTYPE) # coordinates of atoms
        one_hot = self.get('one_hot', idx)[:num_atoms].astype(DTYPE)
        atomic_numbers = self.get('atomic_numbers', idx)[:num_atoms].astype(DTYPE)

        # Load edge features
        num_bonds = self.get('num_bonds', idx)
        edge = self.get('edge', idx)[:num_bonds]
        edge = np.asarray(edge, dtype=DTYPE_INT)

        # Load target
        y = self.get_target(idx, normalize=True).astype(DTYPE)
        y = np.array([y])

        # Augmentation on the coordinates
        if self.transform:
            x = self.transform(x).astype(DTYPE)

        # Create nodes
        if self.fully_connected:
            src, dst, w = self.connect_fully(edge, num_atoms)
        else:
            src, dst, w = self.connect_partially(edge)
        w = self.to_one_hot(w, self.num_bonds).astype(DTYPE)

        # Create graph
        G = dgl.DGLGraph((src, dst))

        # Add node features to graph
        G.ndata['x'] = torch.tensor(x) #[num_atoms,3]
        G.ndata['f'] = torch.tensor(np.concatenate([one_hot, atomic_numbers], -1)[...,None]) #[num_atoms,6,1]

        # Add edge features to graph
        G.edata['d'] = torch.tensor(x[dst] - x[src]) #[num_atoms,3]
        G.edata['w'] = torch.tensor(w) #[num_atoms,4]

        return G, y



def collate(samples):
    graphs, y = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(y)

dataset = QM9Dataset('./QM9_data/QM9_data.pt', "homo", mode='train', fully_connected=True)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate)

iter_dataloader = iter(dataloader) # so I can use next
for i in range(1):
    data = next(iter_dataloader)
    print("MINIBATCH")
    print(data[0]) 
    print(data[1].shape) # batch size -> connected graph of size batch

  data = torch.load(self.file_address)


Loaded train-set, task: homo, source: ./QM9_data/QM9_data.pt, length: 100000
MINIBATCH
Graph(num_nodes=575, num_edges=10254,
      ndata_schemes={'x': Scheme(shape=(3,), dtype=torch.float32), 'f': Scheme(shape=(6, 1), dtype=torch.float32)}
      edata_schemes={'d': Scheme(shape=(3,), dtype=torch.float32), 'w': Scheme(shape=(5,), dtype=torch.float32)})
torch.Size([32, 1])


In [3]:
import matplotlib.pyplot as plt

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device

device(type='cuda')

In [5]:
def _get_adjacency(n_atoms):
    # Adjust adjacency structure
    seq = np.arange(n_atoms)
    src = seq[:,None] * np.ones((1,n_atoms), dtype=np.int32)
    dst = src.T
    ## Remove diagonals and reshape
    src[seq, seq] = -1
    dst[seq, seq] = -1
    src, dst = src.reshape(-1), dst.reshape(-1)
    src, dst = src[src > -1], dst[dst > -1]

    return src, dst

_get_adjacency(10) # from src to dst fully connected in this case, no self connections

# all the connections are represented in the form
# src[0]   src[1]   ...
# dst[0]   dst[1]   ...
# w[0]     w[1]

(array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9,
        9, 9]),
 array([1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 3, 4,
        5, 6, 7, 8, 9, 0, 1, 2, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 5, 6, 7, 8,
        9, 0, 1, 2, 3, 4, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 7, 8, 9, 0, 1, 2,
        3, 4, 5, 6, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 9, 0, 1, 2, 3, 4, 5, 6,
        7, 8]))

### DGL example

In [6]:
import dgl
import dgl.function as fn


dataset = data



# initialize dgl graph
G = dataset[0]

ntype = 'x'
# ndata is a dict
# retrieve data from all nodes labels ntype (position example)
print(G.ndata[ntype]) 

d = 5
G.ndata[f'out{d}'] = torch.tensor(np.zeros((len(G.ndata['x']), 5)))
# retrieve output features of type d from node data
print(G.ndata[f'out{d}']) 


etype = 'd'
# retrive data from all edges labeled etype
print(G.edata[etype].shape)

di = 'd'
do = 'w'
G.edata[f'({di},{do})'] = torch.tensor(np.random.rand(G.edata[etype].shape[0], 3, 4))
# retrive edge kernels that transform from type di to type di
print(G.edata[f'({di},{do})'])

e = 'd' # edge feature (distance)
v = 'x' # node features (cartesian vector)
m = 'm' # output message on the edge
# calling built in dgl fubction e_dot_v that computes a message on
# edge by performing element-wise dot between features of e and v
# and stores it as edge message labeled 'm'
f = fn.e_dot_v(e, v, m)

# applies the function f to update the features of the edges with function
G.apply_edges(f)

print(G.edata['m'].shape)

tensor([[-0.1268,  1.4886, -0.3711],
        [-0.0493,  0.0360,  0.0684],
        [-1.3388, -0.5307, -0.0172],
        ...,
        [-1.7919, -2.3873,  1.5352],
        [ 1.7500, -0.5402,  1.3364],
        [ 0.4925,  0.5138,  2.0511]])
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        ...,
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], dtype=torch.float64)
torch.Size([10254, 3])
tensor([[[0.2344, 0.6642, 0.0072, 0.3524],
         [0.1158, 0.8405, 0.0173, 0.1057],
         [0.5976, 0.4465, 0.7831, 0.4323]],

        [[0.9599, 0.0646, 0.9775, 0.0375],
         [0.3635, 0.0576, 0.7881, 0.6512],
         [0.0792, 0.4827, 0.6080, 0.1550]],

        [[0.1088, 0.4255, 0.4270, 0.9816],
         [0.5244, 0.2549, 0.4082, 0.5101],
         [0.6430, 0.9398, 0.0098, 0.3039]],

        ...,

        [[0.1396, 0.4124, 0.5984, 0.5986],
         [0.6577, 0.0709, 0.0222, 0.6742],
         [0.7836, 0.3820, 0.1077,

### Fibers

In [7]:
#from utils.utils_profiling import * # load before other local modules
try:
    profile
except NameError:
    def profile(func):
        return func

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

from typing import Dict, List, Tuple


class Fiber(object):
    """A Handy Data Structure for Fibers"""
    def __init__(self, num_degrees: int=None, num_channels: int=None,
                 structure: List[Tuple[int,int]]=None, dictionary=None):
        """
        define fiber structure; use one num_degrees & num_channels OR structure
        OR dictionary

        :param num_degrees: degrees will be [0, ..., num_degrees-1]
        :param num_channels: number of channels, same for each degree
        :param structure: e.g. [(32, 0),(16, 1),(16,2)]
        :param dictionary: e.g. {0:32, 1:16, 2:16}
        
        Structure in the form: List[(Tuple[int, int])]. In particular Features[(num_channels, feature_degree)]
        """
        
        if structure:
            self.structure = structure
        elif dictionary:
            self.structure = [(dictionary[o], o) for o in sorted(dictionary.keys())]
        else:
            self.structure = [(num_channels, i) for i in range(num_degrees)]

            
        # assigning to dict format and computing cummulative variables
        self.multiplicities, self.degrees = zip(*self.structure)
        self.max_degree = max(self.degrees)
        self.min_degree = min(self.degrees)
        self.structure_dict = {k: v for v, k in self.structure}
        self.dict = self.structure_dict
        self.n_features = np.sum([i[0] * (2*i[1]+1) for i in self.structure])

        
        # Mapping to vec() case. f = [...] with starting ind saved in feature ind dict
        # feature_ind = {degree: starting ind}
        self.feature_indices = {}
        idx = 0
        for (num_channels, d) in self.structure:
            length = num_channels * (2*d + 1)
            self.feature_indices[d] = (idx, idx + length)
            idx += length

    def copy_me(self, multiplicity: int=None):
        s = copy.deepcopy(self.structure)
        if multiplicity is not None:
            # overwrite multiplicities
            s = [(multiplicity, o) for m, o in s]
        return Fiber(structure=s)

    @staticmethod
    def combine(f1, f2):
        new_dict = copy.deepcopy(f1.structure_dict)
        for k, m in f2.structure_dict.items():
            if k in new_dict.keys():
                new_dict[k] += m
            else:
                new_dict[k] = m
        structure = [(new_dict[k], k) for k in sorted(new_dict.keys())]
        return Fiber(structure=structure)

    @staticmethod
    def combine_max(f1, f2):
        new_dict = copy.deepcopy(f1.structure_dict)
        for k, m in f2.structure_dict.items():
            if k in new_dict.keys():
                new_dict[k] = max(m, new_dict[k])
            else:
                new_dict[k] = m
        structure = [(new_dict[k], k) for k in sorted(new_dict.keys())]
        return Fiber(structure=structure)

    @staticmethod
    def combine_selectively(f1, f2):
        # only use orders which occur in fiber f1

        new_dict = copy.deepcopy(f1.structure_dict)
        for k in f1.degrees:
            if k in f2.degrees:
                new_dict[k] += f2.structure_dict[k]
        structure = [(new_dict[k], k) for k in sorted(new_dict.keys())]
        return Fiber(structure=structure)

    @staticmethod
    def combine_fibers(val1, struc1, val2, struc2):
        """
        combine two fibers

        :param val1/2: fiber tensors in dictionary form
        :param struc1/2: structure of fiber
        :return: fiber tensor in dictionary form
        """
        struc_out = Fiber.combine(struc1, struc2)
        val_out = {}
        for k in struc_out.degrees:
            if k in struc1.degrees:
                if k in struc2.degrees:
                    val_out[k] = torch.cat([val1[k], val2[k]], -2)
                else:
                    val_out[k] = val1[k]
            else:
                val_out[k] = val2[k]
                
            # number of channels is the second dimenstion from the end I guess
            # might look like [tensor_axis = degree, channel axis, tensor-component axis]
            assert val_out[k].shape[-2] == struc_out.structure_dict[k]
        return val_out

    def __repr__(self):
        return f"{self.structure}"



def get_fiber_dict(F, struc, mask=None, return_struc=False):
    if mask is None: mask = struc
    index = 0
    fiber_dict = {}
    first_dims = F.shape[:-1]
    masked_dict = {}
    for o, m in struc.structure_dict.items():
        length = m * (2*o + 1)
        if o in mask.degrees:
            masked_dict[o] = m
            fiber_dict[o] = F[...,index:index + length].view(list(first_dims) + [m, 2*o + 1])
        index += length
    assert F.shape[-1] == index
    if return_struc:
        return fiber_dict, Fiber(dictionary=masked_dict)
    return fiber_dict


def get_fiber_tensor(F, struc):
    some_entry = tuple(F.values())[0]
    first_dims = some_entry.shape[:-2]
    res = some_entry.new_empty([*first_dims, struc.n_features])
    index = 0
    for o, m in struc.structure_dict.items():
        length = m * (2*o + 1)
        res[..., index: index + length] = F[o].view(*first_dims, length)
        index += length
    assert index == res.shape[-1]
    return res


def fiber2tensor(F, structure, squeeze=False):
    if squeeze:
        fibers = [F[f'{i}'].view(*F[f'{i}'].shape[:-2], -1) for i in structure.degrees]
        fibers = torch.cat(fibers, -1)
    else:
        fibers = [F[f'{i}'].view(*F[f'{i}'].shape[:-2], -1, 1) for i in structure.degrees]
        fibers = torch.cat(fibers, -2)
    return fibers


# Reduce fibers into single tensor cell h (I guess)
def fiber2head(F, h, structure, squeeze=False):
    if squeeze:
        fibers = [F[f'{i}'].view(*F[f'{i}'].shape[:-2], h, -1) for i in structure.degrees]
        fibers = torch.cat(fibers, -1)
    else:
        fibers = [F[f'{i}'].view(*F[f'{i}'].shape[:-2], h, -1, 1) for i in structure.degrees]
        fibers = torch.cat(fibers, -2)
    return fibers

### Basis transformation matrixes and irreducable representation

In [8]:
'''
Cache in files
'''
from functools import wraps, lru_cache
import pickle
import gzip
import os
import sys
import fcntl


class FileSystemMutex:
    '''
    Mutual exclusion of different **processes** using the file system
    '''

    def __init__(self, filename):
        self.handle = None
        self.filename = filename

    def acquire(self):
        '''
        Locks the mutex
        if it is already locked, it waits (blocking function)
        '''
        self.handle = open(self.filename, 'w')
        fcntl.lockf(self.handle, fcntl.LOCK_EX)
        self.handle.write("{}\n".format(os.getpid()))
        self.handle.flush()

    def release(self):
        '''
        Unlock the mutex
        '''
        if self.handle is None:
            raise RuntimeError()
        fcntl.lockf(self.handle, fcntl.LOCK_UN)
        self.handle.close()
        self.handle = None

    def __enter__(self):
        self.acquire()

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()


def cached_dirpklgz(dirname, maxsize=128):
    '''
    Cache a function with a directory

    :param dirname: the directory path
    :param maxsize: maximum size of the RAM cache (there is no limit for the directory cache)
    '''

    def decorator(func):
        '''
        The actual decorator
        '''

        @lru_cache(maxsize=maxsize)
        @wraps(func)
        def wrapper(*args, **kwargs):
            '''
            The wrapper of the function
            '''
            try:
                os.makedirs(dirname)
            except FileExistsError:
                pass

            indexfile = os.path.join(dirname, "index.pkl")
            mutexfile = os.path.join(dirname, "mutex")

            with FileSystemMutex(mutexfile):
                try:
                    with open(indexfile, "rb") as file:
                        index = pickle.load(file)
                except FileNotFoundError:
                    index = {}

                key = (args, frozenset(kwargs), func.__defaults__)

                try:
                    filename = index[key]
                except KeyError:
                    index[key] = filename = "{}.pkl.gz".format(len(index))
                    with open(indexfile, "wb") as file:
                        pickle.dump(index, file)

            filepath = os.path.join(dirname, filename)

            try:
                with FileSystemMutex(mutexfile):
                    with gzip.open(filepath, "rb") as file:
                        result = pickle.load(file)
            except FileNotFoundError:
                print("compute {}... ".format(filename), end="")
                sys.stdout.flush()
                result = func(*args, **kwargs)
                print("save {}... ".format(filename), end="")
                sys.stdout.flush()
                with FileSystemMutex(mutexfile):
                    with gzip.open(filepath, "wb") as file:
                        pickle.dump(result, file)
                print("done")
            return result

        return wrapper

    return decorator

In [9]:
import torch
import math
import numpy as np

class torch_default_dtype:

    def __init__(self, dtype):
        self.saved_dtype = None
        self.dtype = dtype

    def __enter__(self):
        self.saved_dtype = torch.get_default_dtype()
        torch.set_default_dtype(self.dtype)

    def __exit__(self, exc_type, exc_value, traceback):
        torch.set_default_dtype(self.saved_dtype)
        


def irr_repr(order, alpha, beta, gamma, dtype=None):
    """
    irreducible representation of SO3
    - compatible with compose and spherical_harmonics
    """
    # from from_lielearn_SO3.wigner_d import wigner_D_matrix
    from lie_learn.representations.SO3.wigner_d import wigner_D_matrix
    # if order == 1:
    #     # change of basis to have vector_field[x, y, z] = [vx, vy, vz]
    #     A = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
    #     return A @ wigner_D_matrix(1, alpha, beta, gamma) @ A.T

    # TODO (non-essential): try to do everything in torch
    # return torch.tensor(wigner_D_matrix(torch.tensor(order), alpha, beta, gamma), dtype=torch.get_default_dtype() if dtype is None else dtype)
    return torch.tensor(wigner_D_matrix(order, np.array(alpha), np.array(beta), np.array(gamma)), dtype=torch.get_default_dtype() if dtype is None else dtype)




In [10]:
### example for irrep
# Weigner_D matrix
irr_repr(1, 4., 3., 2.)

tensor([[-0.4093, -0.1068, -0.9061],
        [ 0.1283, -0.9900,  0.0587],
        [-0.9033, -0.0922,  0.4189]])

In [11]:
import torch
import math
import numpy as np

# @profile
def kron(a, b):
    """
    A part of the pylabyk library: numpytorch.py at https://github.com/yulkang/pylabyk

    Kronecker product of matrices a and b with leading batch dimensions.
    Batch dimensions are broadcast. The number of them mush
    :type a: torch.Tensor
    :type b: torch.Tensor
    :rtype: torch.Tensor
    """
    siz1 = torch.Size(torch.tensor(a.shape[-2:]) * torch.tensor(b.shape[-2:]))
    res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
    siz0 = res.shape[:-4]
    return res.reshape(siz0 + siz1)

################################################################################
# Solving the constraint coming from the stabilizer of 0 and e
################################################################################

# Get's eigenvectors for eigvalue equel close by eps to zero
def get_matrix_kernel(A, eps=1e-10):
    '''
    Compute an orthonormal basis of the kernel (x_1, x_2, ...)
    A x_i = 0
    scalar_product(x_i, x_j) = delta_ij

    :param A: matrix
    :return: matrix where each row is a basis vector of the kernel of A
    '''
    _u, s, v = torch.svd(A)

    # A = u @ torch.diag(s) @ v.t()
    kernel = v.t()[s < eps]
    return kernel


# Stacks the matrix to big matrix and does the same as previous function
def get_matrices_kernel(As, eps=1e-10):
    '''
    Computes the commun kernel of all the As matrices
    '''
    return get_matrix_kernel(torch.cat(As, dim=0), eps)


@cached_dirpklgz("cache/trans_Q")
def _basis_transformation_Q_J(J, order_in, order_out, version=3):  # pylint: disable=W0613
    """
    :param J: order of the spherical harmonics
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :return: one part of the Q^-1 matrix of the article
    """
    with torch_default_dtype(torch.float64):
        def _R_tensor(a, b, c): return kron(irr_repr(order_out, a, b, c), irr_repr(order_in, a, b, c))

        def _sylvester_submatrix(J, a, b, c):
            ''' generate Kronecker product matrix for solving the Sylvester equation in subspace J '''
            R_tensor = _R_tensor(a, b, c)  # [m_out * m_in, m_out * m_in]
            R_irrep_J = irr_repr(J, a, b, c)  # [m, m]
            return kron(R_tensor, torch.eye(R_irrep_J.size(0))) - \
                kron(torch.eye(R_tensor.size(0)), R_irrep_J.t())  # [(m_out * m_in) * m, (m_out * m_in) * m]
        
        # some random angles to enshure equivariance
        random_angles = [
            [4.41301023, 5.56684102, 4.59384642],
            [4.93325116, 6.12697327, 4.14574096],
            [0.53878964, 4.09050444, 5.36539036],
            [2.16017393, 3.48835314, 5.55174441],
            [2.52385107, 0.2908958, 3.90040975]
        ]
        null_space = get_matrices_kernel([_sylvester_submatrix(J, a, b, c) for a, b, c in random_angles])
        assert null_space.size(0) == 1, null_space.size()  # unique subspace solution
        Q_J = null_space[0]  # [(m_out * m_in) * m]
        Q_J = Q_J.view((2 * order_out + 1) * (2 * order_in + 1), 2 * J + 1)  # [m_out * m_in, m]
        assert all(torch.allclose(_R_tensor(a, b, c) @ Q_J, Q_J @ irr_repr(J, a, b, c)) for a, b, c in torch.rand(4, 3))

    assert Q_J.dtype == torch.float64
    return Q_J  # [m_out * m_in, m]

In [12]:
 _basis_transformation_Q_J(1, 1, 1, version=3).shape

  return torch.load(io.BytesIO(b))


torch.Size([9, 3])

In [13]:
# some code deconstruction of Q^lk_j^{-1}
# helper function that calculates the Kronecker product between
# type-k and type-l Wigner-D matrices for rotations a, b, c
def _R_tensor(a, b, c):
    # kron calculates the kroneker product between two matrices
    # irr_repr returns the irrep from (type, alpha, beta, gamma)
    # Remember the order A x B = kron(B, A)
    return kron(irr_repr(l, a, b, c), irr_repr(k, a, b, c))

l = 1
k = 2
a = 4.; b = 3.; c = 2.;

print("l tensor shape", irr_repr(l, a, b, c).shape)
print("k tensor shape", irr_repr(k, a, b, c).shape)
print("Kron shape", _R_tensor(a, b, c).shape)

l tensor shape torch.Size([3, 3])
k tensor shape torch.Size([5, 5])
Kron shape torch.Size([15, 15])


In [14]:
# Computes submatrix to solve sylvester equation
# AX - XB = 0
# same as (I x A - B^T x I) vec(X) = 0
def _sylvester_submatrix(J, a, b, c):
    # Calculates the Kroneker product between type-l and type-k
    # Wigner-D matrices for rotation angles a, b, c
    R_tensor = _R_tensor(a, b, c) # [(2l + 1)*(2k + 1), (2l + 1)*(2k + 1)]
    # Calculates type-J Wigner-D matrix for same rotation
    R_irrep_J = irr_repr(J, a, b, c) # [2J + 1, 2J + 1]
    # .reshape(9).reshape(3, 3) Annoying stuff due to some torch bug with data placement in memory
    return kron(R_tensor, torch.eye(R_irrep_J.size(0))) - kron(torch.eye(R_tensor.size(0)), R_irrep_J.t())

J = 3

_sylvester_submatrix(J, a, b, c).shape

torch.Size([105, 105])

In [15]:
# Check on random angles
with torch_default_dtype(torch.float64): # !!! Important otherwise zero
    # some random angles to enshure equivariance
    random_angles = [
        [4.41301023, 5.56684102, 4.59384642],
        [4.93325116, 6.12697327, 4.14574096],
        [0.53878964, 4.09050444, 5.36539036],
        [2.16017393, 3.48835314, 5.55174441],
        [2.52385107, 0.2908958, 3.90040975]
    ]
    # Calculate the vector that is solution of the homogeneous equation
    # for all sets of angles
    null_space = get_matrices_kernel([_sylvester_submatrix(J, a, b, c)
                                      for a, b, c in random_angles])
    # confirm that the solution is unique
    assert null_space.size(0) == 1, null_space.size()

In [16]:
# Final Q^lk_J compute and reshape to (2 * l + 1) * (2 * k + 1) * (2 * J + 1)
with torch_default_dtype(torch.float64): # !!! Important otherwise zero
    Q_J = null_space[0] # only one vector
    Q_J = Q_J.view((2 * l + 1) * (2 * k + 1), 2 * J + 1) # unvectorize
    assert all(torch.allclose(_R_tensor(a, b, c) @ Q_J, Q_J @ irr_repr(J, a, b, c)) 
               for a, b, c in torch.rand(4, 3)) # sanity check that is a solution
    
Q_J.shape

torch.Size([15, 7])

### Spherical harmonics

In [17]:
import time

import torch
import numpy as np
from scipy.special import lpmv as lpmv_scipy


def semifactorial(x):
    """Compute the semifactorial function x!!.

    x!! = x * (x-2) * (x-4) *...

    Args:
        x: positive int
    Returns:
        float for x!!
    """
    y = 1.
    for n in range(x, 1, -2):
        y *= n
    return y


def pochhammer(x, k):
    """Compute the pochhammer symbol (x)_k.

    (x)_k = x * (x+1) * (x+2) *...* (x+k-1)

    Args:
        x: positive int
    Returns:
        float for (x)_k
    """
    xf = float(x)
    for n in range(x+1, x+k):
        xf *= n
    return xf

def lpmv(l, m, x):
    """Associated Legendre function including Condon-Shortley phase.

    Args:
        m: int order 
        l: int degree
        x: float argument tensor
    Returns:
        tensor of x-shape
    """
    m_abs = abs(m)
    
    # P^m_J = 0 forall m > J
    if m_abs > l:
        return torch.zeros_like(x)

    # Compute P_m^m
    # P_m^m = (-1)^J (1 - x^2)^(J/2) (2J - 1)!!
    yold = ((-1)**m_abs * semifactorial(2*m_abs-1)) * torch.pow(1-x*x, m_abs/2)
    
    # Compute P_{m+1}^m
    # P_m+1^m = x (2 m + 1) P^m_m
    if m_abs != l:
        y = x * (2*m_abs+1) * yold
    else:
        y = yold

    # Compute P_{l}^m from recursion in P_{l-1}^m and P_{l-2}^m
    # P_l^m (x) = [(2 l - 1 ) / ( l - m )] P^m_{l - 1} (x) - [(l + m - 1) / (l - m)] P^m_{l - 2} (x)
    for i in range(m_abs+2, l+1):
        tmp = y
        # Inplace speedup
        y = ((2*i-1) / (i-m_abs)) * x * y
        y -= ((i+m_abs-1)/(i-m_abs)) * yold
        yold = tmp

    # P^-m_l (x) = (-1)^m (l - m)!/(l + m)! P^m_l (x)
    if m < 0:
        y *= ((-1)**m / pochhammer(l+m+1, -2*m))

    return y

In [18]:
leg = {}
J = 10
m = 5
x = torch.Tensor([0.5, 0.1, 0.2])
ans_1 = lpmv(J, m, x)

In [19]:
class SphericalHarmonics(object):
    def __init__(self):
        self.leg = {}

    def clear(self):
        self.leg = {}

    def negative_lpmv(self, l, m, y):
        """Compute negative order coefficients"""
        if m < 0:
            y *= ((-1)**m / pochhammer(l+m+1, -2*m))
        return y

    def lpmv(self, l, m, x):
        """Associated Legendre function including Condon-Shortley phase.

        Args:
            m: int order 
            l: int degree
            x: float argument tensor
        Returns:
            tensor of x-shape
        """
        # Check memoized versions
        m_abs = abs(m)
        if (l,m) in self.leg:
            return self.leg[(l,m)]
        elif m_abs > l:
            return None
        elif l == 0:
            self.leg[(l,m)] = torch.ones_like(x)
            return self.leg[(l,m)]
        
        # Check if on boundary else recurse solution down to boundary
        if m_abs == l:
            # Compute P_m^m
            y = (-1)**m_abs * semifactorial(2*m_abs-1)
            y *= torch.pow(1-x*x, m_abs/2)
            self.leg[(l,m)] = self.negative_lpmv(l, m, y)
            return self.leg[(l,m)]
        else:
            # Recursively precompute lower degree harmonics
            self.lpmv(l-1, m, x)

        # Compute P_{l}^m from recursion in P_{l-1}^m and P_{l-2}^m
        # Inplace speedup
        y = ((2*l-1) / (l-m_abs)) * x * self.lpmv(l-1, m_abs, x)
        if l - m_abs > 1:
            y -= ((l+m_abs-1)/(l-m_abs)) * self.leg[(l-2, m_abs)]
        #self.leg[(l, m_abs)] = y
        
        if m < 0:
            y = self.negative_lpmv(l, m, y)
        self.leg[(l,m)] = y

        return self.leg[(l,m)]

    def get_element(self, l, m, theta, phi):
        """Tesseral spherical harmonic with Condon-Shortley phase.

        The Tesseral spherical harmonics are also known as the real spherical
        harmonics.

        Args:
            l: int for degree
            m: int for order, where -l <= m < l
            theta: collatitude or polar angle
            phi: longitude or azimuth
        Returns:
            tensor of shape theta
        """
        assert abs(m) <= l, "absolute value of order m must be <= degree l"

        N = np.sqrt((2*l+1) / (4*np.pi))
        leg = self.lpmv(l, abs(m), torch.cos(theta))
        if m == 0:
            return N*leg
        elif m > 0:
            Y = torch.cos(m*phi) * leg
        else:
            Y = torch.sin(abs(m)*phi) * leg
        N *= np.sqrt(2. / pochhammer(l-abs(m)+1, 2*abs(m)))
        Y *= N
        return Y

    def get(self, l, theta, phi, refresh=True):
        """Tesseral harmonic with Condon-Shortley phase.

        The Tesseral spherical harmonics are also known as the real spherical
        harmonics.

        Args:
            l: int for degree
            theta: collatitude or polar angle
            phi: longitude or azimuth
        Returns:
            tensor of shape [*theta.shape, 2*l+1]
        """
        results = []
        if refresh:
            self.clear()
        for m in range(-l, l+1):
            results.append(self.get_element(l, m, theta, phi))
        return torch.stack(results, -1)

### My own implementation from spherical harmonics from the blog post

In [20]:
# First we will implement ALP (Associated Legandre Polynomials)

# function analagous to pochammer function in the SE(3) Transformer
# (J - m)!/(J + m)!
def falling_factorial(J, m):
    # computes (J + m)*(J+m - 1)*...(J-m+1)
    f = 1.
    for n in range(J + m, J - m, -1):
        f *= n
    return f


def semifactorial(x):
    """Compute the semifactorial function x!!.

    x!! = x * (x-2) * (x-4) *...

    Args:
        x: positive int
    Returns:
        float for x!!
    """
    y = 1.
    for n in range(x, 1, -2):
        y *= n
    return y

# y: Legendre polynimil for the absolute value of m
# P_J^{-m} (x) = (-1)^m (J - m)!/(J + m)! P_J^m(x)
def negative_lpmv(J, m, y):
    # check if m is negative
    if m < 0:
        # multiply y with the coefficient containing the falling 
        # factorial
        y *= ((-1)**m / falling_factorial(J, m))
    return y


In [21]:
falling_factorial(5, 5)

3628800.0

In [22]:
# recursive implementation of APL
def lpmv(J, m, x):
    # get the absolute value of m
    m_abs = abs(m)
    # check if the polynomial has already been computed
    if (J, m) in leg:
        return leg[(J, m)]
    # check if m is out of range -J to J
    elif m_abs > J:
        return None
    # if J = 0, the associated Legendre polynomial is equal to 1
    elif J == 0:
        # return tensor of 1s with the same shape as x
        leg[(l, m)] = torch.ones_like(x)
        return leg[(l, m)]
    
    # if |m| = J, compute the polynomial using the equation from step 1
    if m_abs == J:
        # P^J_J (x) = (-1)^J (1 - x^2)^(J/2) (2J - 1)!!
        # calculate coefficient term
        y = (-1)**J * semifactorial(2*J - 1)
        # multiply by the term dependent on x
        y *= torch.pow(1 - x*x, m_abs/2)
        # negative_lpmv returns y if m is positive and y multiplied by 
        # the negative coefficient defined in step 4 if m is negative
        leg[(l, m)] = negative_lpmv(l, m, y)
        return leg[(l, m)]
    else:
        # retursive call to compute lower degree polynomials up to
        # boundary m = J
        lpmv(J - 1, m, x)
        
    # if m is not on the boundary, first compute the first term of the relation
    # defined in step 3
    # P^J-1_J (x) = x (2J + 1) P^m_(J - 1) (x)
    # if m_abs = J - 1, then this calculates the relation defined in step 2
    y = ((2*J - 1)/(J - m_abs)) * x * lpmv(J - 1, m_abs, x)
    
    # P_J^m(x) = (2J - 1)/(J - m) x P^m_{J - 1}(x) - (J + m - 1)/(J - m) P^m_{J - 2} (x)
    # check if m_abs != J - 1, then add the second term defined in step 3
    if l - m_abs > 1:
        y -= ((l + m_abs -1)/(l - m_abs)) * leg[(l - 2, m_abs)]
        
    # if m is negative, return the polynomial for m_abs scaled by the 
    # negative coefficient
    if m < 0:
        y = negative_lpmv(l, m, y)
        
    leg[(l, m)] = y
    
    return leg[(l, m)]

In [23]:
leg = {}
J = 10
m = 5
x = torch.Tensor([0.5, 0.1, 0.2])
ans_2 = lpmv(J, m, x)

In [24]:
ans_1

tensor([ 30086.1719, -21961.9492, -26591.5078])

In [25]:
ans_2

tensor([-8.3058e+04, -5.3207e+01, -1.5766e+03])

In [26]:
# spherical harmonics

def get_element(J, m, theta, phi):
    assert abs(m) <= J, "m must be in the range -J to J"
    
    # calculates the first fraction in the square root
    N = np.sqrt((2*J + 1) / (4 * np.pi))
    # stores the ALP term in leg
    leg = lpmv(J, abs(m), torch.cos(theta))
    
    # multiply by the phi dependent term depending on the value of m
    if m == 0:
        # when m = 0 the other fraction in the square root cancels 
        # and the phi dependent term is 1
        return N * leg
    elif m > 0:
        Y = torch.cos(m*phi) * leg
    else:
        #print(phi.shape, leg.shape)
        Y = torch.sin(abs(m) * phi) * leg
        
    # multiply by a square root of the inverse falling factorial
    # which is the same as in the ALP
    N *= np.sqrt(2. / falling_factorial(J, abs(m)))
    # multiplies the coefficient with angle-dependent term
    Y *= N
    return Y

In [27]:
# compute Y_J^m, m in {-J, J}
def get(J, theta, phi, refresh = True):
    # initialize tensor
    results = []
    
    # loop over all possible values of m from J to J and add the 
    # computed spherical harmonic to results
    for m in range(-J, J + 1):
        results.append(get_element(J, m, theta, phi))
        
    return torch.stack(results, -1)

In [28]:
leg = {}

get(3, torch.Tensor([0.5, 0.2, 0.1]), torch.Tensor([0.5, 0.2, 0.1]))

tensor([[-6.4857e-02,  2.4532e-01, -1.0650e+00,  1.6601e+00, -2.9646e+01,
          1.5752e-01, -4.5993e-03],
        [-2.6125e-03,  2.1772e-02, -2.5473e-01,  2.5822e+00, -2.6617e+01,
          5.1495e-02, -3.8186e-03],
        [-1.7350e-04,  2.8475e-03, -6.7310e-02,  2.7433e+00, -1.4869e+01,
          1.4047e-02, -5.6087e-04]])

In [29]:
# r_ij: the relative displacement between nodes in spherical
# coordinates [radius, alpha, beta]
# beta = pi - theta (beta is 0 at south pole and pi at north pole;
# supplementary to theta
# alpha = phi (ranges from 0 to 2 pi)
# r_ij: shape (batch_size, nodes, neighbors, 3 (r_ij))
def precompute_sh(r_ij, max_J):
    # initialize dictionary where keys correspond to J and values
    # are tensors with shape (batch size, nodes, neighbors, 2J + 1)
    Y_Js = {}
    
    # calculate (2J + 1)-dimensional spherical harmonics tensors for degrees up to
    # max_J
    for J in range(max_J + 1):
        # r_ij[..., 2] extracts the values for beta for every edge in
        # the graph
        # r_ij[..., 1] extracts the values for alpha for every edge in the graph
        Y_Js[J] = get(J, theta = math.pi - r_ij[..., 2], phi = r_ij[..., 1], refresh = False)
        
        
    return Y_Js

In [30]:
leg = {}

r_ij = torch.randn((10, 1, 10, 3))/100

out = precompute_sh(r_ij, 3)

print(out.keys(), out[2].shape)

dict_keys([0, 1, 2, 3]) torch.Size([10, 1, 10, 5])


In [31]:
def get_basis_kernel(x_ij, max_degree = 2):

    # compute all spherical harmonics for every edge up to 2*maximum 
    # feature type

    Y = precompute_sh(x_ij, 2*max_degree)
    device = Y[0].device

    # initialize the dictionary where the key is the input and output degree 
    # pair and the values are all the basis kernels stored in an array of shape
    # (edges, 1, 2l+1, 1, 2k+1, 2min(l, k) + 1)
    basis = {}
    # loop through input and output degree pairs up to max_degree
    for di in range(max_degree + 1):
        for do in range(max_degree + 1):
            K_Js = [] # initialize set of basis kernels
            # loop through all values of J from |k - l| to k + l
            for J in range(abs(di - do), di + do + 1):
                # get change-of-basis matrices with shape 
                # ((2l + 1)*(2k + 1), 2J + 1) that transforms the (2J + 1)-dim spherical
                # tensor back to its original basis
                Q_J = _basis_transformation_Q_J(J, di, do)
                
                Q_J = Q_J.float().to(device)
                # Y[J] has shape (edges, 2J + 1)
                # Q_J has shape ((2l + 1)*(2k+1), 2J + 1)
                # matrix-vector multiplication to get K_J with shape
                # (edges, (2l + 1)*(2k + 1)) of the vectorized type-J basis kernels
                # W_lk = Q^lk_J @ Y_J = \sum_J Q^lk_J Y_J
                #print(Q_J.shape, Y[J].shape)
                K_J = torch.matmul(Q_J, Y[J].t())
                # Append to list of bases with shapes (2min(l, k) + 1, edges, (2l+1)*(2k+1))
                K_Js.append(K_J)
                
            # reshape for dot product with radial weights
            size = (-1, 1, 2*do + 1, 1, 2*di + 1, 2*min(di, do) + 1)
            # stack reshapes to (edges, (2l+1)*(2k + 1), 2min(l, k) + 1)
            # view reshapes to match size
            basis[f'{di}, {do}'] = torch.stack(K_Js, -1).view(*size)
            
    return basis

In [32]:
leg = {}
x_ij =  torch.randn((100, 3))/100

get_basis_kernel(x_ij)['1, 1'].shape

torch.Size([100, 1, 3, 1, 3, 3])

### Cartezian to spherical convertion

In [33]:
def get_spherical_from_cartesian_torch(cartesian, divide_radius_by=1.0):

    ###################################################################################################################
    # ON ANGLE CONVENTION
    #
    # sh has following convention for angles:
    # :param theta: the colatitude / polar angle, ranging from 0(North Pole, (X, Y, Z) = (0, 0, 1)) to pi(South Pole, (X, Y, Z) = (0, 0, -1)).
    # :param phi: the longitude / azimuthal angle, ranging from 0 to 2 pi.
    #
    # the 3D steerable CNN code therefore (probably) has the following convention for alpha and beta:
    # beta = pi - theta; ranging from 0(South Pole, (X, Y, Z) = (0, 0, -1)) to pi(North Pole, (X, Y, Z) = (0, 0, 1)).
    # alpha = phi
    #
    ###################################################################################################################

    # initialise return array
    # ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
    spherical = torch.zeros_like(cartesian)

    # indices for return array
    ind_radius = 0
    ind_alpha = 1
    ind_beta = 2

    cartesian_x = 2
    cartesian_y = 0
    cartesian_z = 1

    # get projected radius in xy plane
    # xy = xyz[:,0]**2 + xyz[:,1]**2
    r_xy = cartesian[..., cartesian_x] ** 2 + cartesian[..., cartesian_y] ** 2

    # get second angle
    # version 'elevation angle defined from Z-axis down'
    spherical[..., ind_beta] = torch.atan2(torch.sqrt(r_xy), cartesian[..., cartesian_z])
    # ptsnew[:,4] = np.arctan2(np.sqrt(xy), xyz[:,2])
    # version 'elevation angle defined from XY-plane up'
    #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy))
    # spherical[:, ind_beta] = np.arctan2(cartesian[:, 2], np.sqrt(r_xy))

    # get angle in x-y plane
    spherical[...,ind_alpha] = torch.atan2(cartesian[...,cartesian_y], cartesian[...,cartesian_x])

    # get overall radius
    # ptsnew[:,3] = np.sqrt(xy + xyz[:,2]**2)
    if divide_radius_by == 1.0:
        spherical[..., ind_radius] = torch.sqrt(r_xy + cartesian[...,cartesian_z]**2)
    else:
        spherical[..., ind_radius] = torch.sqrt(r_xy + cartesian[...,cartesian_z]**2)/divide_radius_by

    return spherical


# @profile
def get_spherical_from_cartesian(cartesian):

    ###################################################################################################################
    # ON ANGLE CONVENTION
    #
    # sh has following convention for angles:
    # :param theta: the colatitude / polar angle, ranging from 0(North Pole, (X, Y, Z) = (0, 0, 1)) to pi(South Pole, (X, Y, Z) = (0, 0, -1)).
    # :param phi: the longitude / azimuthal angle, ranging from 0 to 2 pi.
    #
    # the 3D steerable CNN code therefore (probably) has the following convention for alpha and beta:
    # beta = pi - theta; ranging from 0(South Pole, (X, Y, Z) = (0, 0, -1)) to pi(North Pole, (X, Y, Z) = (0, 0, 1)).
    # alpha = phi
    #
    ###################################################################################################################

    if torch.is_tensor(cartesian):
        cartesian = np.array(cartesian.cpu())

    # initialise return array
    # ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
    spherical = np.zeros(cartesian.shape)

    # indices for return array
    ind_radius = 0
    ind_alpha = 1
    ind_beta = 2

    cartesian_x = 2
    cartesian_y = 0
    cartesian_z = 1

    # get projected radius in xy plane
    # xy = xyz[:,0]**2 + xyz[:,1]**2
    r_xy = cartesian[..., cartesian_x] ** 2 + cartesian[..., cartesian_y] ** 2

    # get overall radius
    # ptsnew[:,3] = np.sqrt(xy + xyz[:,2]**2)
    spherical[..., ind_radius] = np.sqrt(r_xy + cartesian[...,cartesian_z]**2)

    # get second angle
    # version 'elevation angle defined from Z-axis down'
    spherical[..., ind_beta] = np.arctan2(np.sqrt(r_xy), cartesian[..., cartesian_z])
    # ptsnew[:,4] = np.arctan2(np.sqrt(xy), xyz[:,2])
    # version 'elevation angle defined from XY-plane up'
    #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy))
    # spherical[:, ind_beta] = np.arctan2(cartesian[:, 2], np.sqrt(r_xy))

    # get angle in x-y plane
    spherical[...,ind_alpha] = np.arctan2(cartesian[...,cartesian_y], cartesian[...,cartesian_x])

    return spherical

def test_coordinate_conversion():
    p = np.array([0, 0, -1])
    expected = np.array([1, 0, 0])
    assert get_spherical_from_cartesian(p) == expected
    return True

### Reference implementation

In [34]:
import time

import torch
import numpy as np
from scipy.special import lpmv as lpmv_scipy


def semifactorial(x):
    """Compute the semifactorial function x!!.

    x!! = x * (x-2) * (x-4) *...

    Args:
        x: positive int
    Returns:
        float for x!!
    """
    y = 1.
    for n in range(x, 1, -2):
        y *= n
    return y


def pochhammer(x, k):
    """Compute the pochhammer symbol (x)_k.

    (x)_k = x * (x+1) * (x+2) *...* (x+k-1)

    Args:
        x: positive int
    Returns:
        float for (x)_k
    """
    xf = float(x)
    for n in range(x+1, x+k):
        xf *= n
    return xf

def lpmv(l, m, x):
    """Associated Legendre function including Condon-Shortley phase.

    Args:
        m: int order 
        l: int degree
        x: float argument tensor
    Returns:
        tensor of x-shape
    """
    m_abs = abs(m)
    
    # P^m_J = 0 forall m > J
    if m_abs > l:
        return torch.zeros_like(x)

    # Compute P_m^m
    # P_m^m = (-1)^J (1 - x^2)^(J/2) (2J - 1)!!
    yold = ((-1)**m_abs * semifactorial(2*m_abs-1)) * torch.pow(1-x*x, m_abs/2)
    
    # Compute P_{m+1}^m
    # P_m+1^m = x (2 m + 1) P^m_m
    if m_abs != l:
        y = x * (2*m_abs+1) * yold
    else:
        y = yold

    # Compute P_{l}^m from recursion in P_{l-1}^m and P_{l-2}^m
    # P_l^m (x) = [(2 l - 1 ) / ( l - m )] P^m_{l - 1} (x) - [(l + m - 1) / (l - m)] P^m_{l - 2} (x)
    for i in range(m_abs+2, l+1):
        tmp = y
        # Inplace speedup
        y = ((2*i-1) / (i-m_abs)) * x * y
        y -= ((i+m_abs-1)/(i-m_abs)) * yold
        yold = tmp

    # P^-m_l (x) = (-1)^m (l - m)!/(l + m)! P^m_l (x)
    if m < 0:
        y *= ((-1)**m / pochhammer(l+m+1, -2*m))

    return y

In [35]:
class SphericalHarmonics(object):
    def __init__(self):
        self.leg = {}

    def clear(self):
        self.leg = {}

    def negative_lpmv(self, l, m, y):
        """Compute negative order coefficients"""
        if m < 0:
            y *= ((-1)**m / pochhammer(l+m+1, -2*m))
        return y

    def lpmv(self, l, m, x):
        """Associated Legendre function including Condon-Shortley phase.

        Args:
            m: int order 
            l: int degree
            x: float argument tensor
        Returns:
            tensor of x-shape
        """
        # Check memoized versions
        m_abs = abs(m)
        if (l,m) in self.leg:
            return self.leg[(l,m)]
        elif m_abs > l:
            return None
        elif l == 0:
            self.leg[(l,m)] = torch.ones_like(x)
            return self.leg[(l,m)]
        
        # Check if on boundary else recurse solution down to boundary
        if m_abs == l:
            # Compute P_m^m
            y = (-1)**m_abs * semifactorial(2*m_abs-1)
            y *= torch.pow(1-x*x, m_abs/2)
            self.leg[(l,m)] = self.negative_lpmv(l, m, y)
            return self.leg[(l,m)]
        else:
            # Recursively precompute lower degree harmonics
            self.lpmv(l-1, m, x)

        # Compute P_{l}^m from recursion in P_{l-1}^m and P_{l-2}^m
        # Inplace speedup
        y = ((2*l-1) / (l-m_abs)) * x * self.lpmv(l-1, m_abs, x)
        if l - m_abs > 1:
            y -= ((l+m_abs-1)/(l-m_abs)) * self.leg[(l-2, m_abs)]
        #self.leg[(l, m_abs)] = y
        
        if m < 0:
            y = self.negative_lpmv(l, m, y)
        self.leg[(l,m)] = y

        return self.leg[(l,m)]

    def get_element(self, l, m, theta, phi):
        """Tesseral spherical harmonic with Condon-Shortley phase.

        The Tesseral spherical harmonics are also known as the real spherical
        harmonics.

        Args:
            l: int for degree
            m: int for order, where -l <= m < l
            theta: collatitude or polar angle
            phi: longitude or azimuth
        Returns:
            tensor of shape theta
        """
        assert abs(m) <= l, "absolute value of order m must be <= degree l"

        N = np.sqrt((2*l+1) / (4*np.pi))
        leg = self.lpmv(l, abs(m), torch.cos(theta))
        if m == 0:
            return N*leg
        elif m > 0:
            Y = torch.cos(m*phi) * leg
        else:
            Y = torch.sin(abs(m)*phi) * leg
        N *= np.sqrt(2. / pochhammer(l-abs(m)+1, 2*abs(m)))
        Y *= N
        return Y

    def get(self, l, theta, phi, refresh=True):
        """Tesseral harmonic with Condon-Shortley phase.

        The Tesseral spherical harmonics are also known as the real spherical
        harmonics.

        Args:
            l: int for degree
            theta: collatitude or polar angle
            phi: longitude or azimuth
        Returns:
            tensor of shape [*theta.shape, 2*l+1]
        """
        results = []
        if refresh:
            self.clear()
        for m in range(-l, l+1):
            results.append(self.get_element(l, m, theta, phi))
        return torch.stack(results, -1)

### Radial function

In [36]:
class BN(nn.Module):
    """SE(3)-equvariant batch/layer normalization"""
    def __init__(self, m):
        """SE(3)-equvariant batch/layer normalization

        Args:
            m: int for number of output channels
        """
        super().__init__()
        self.bn = nn.LayerNorm(m)

    def forward(self, x):
        return self.bn(x)

# num_freq = 2 min (l, k) + 1
# mid_dim = 32
class RadialFunc(nn.Module):
    """NN parameterized radial profile function"""
    def __init__(self, num_bases, mi, mo, edge_dim: int = 0):
        """ NN parametrized radial profile function.
        
        Args:
            num_freq: number of output frequencies
            in_dim: multiplicity of input (num input channels)
            out_dim: multiplicity of output (num output channels)
            edge_dim: number of dimensions for edge embedding
        """
        
        super().__init__()
        
        self.num_bases = num_bases
        self.mi = mi
        self.mid_dim = 32
        self.mo = mo
        self.edge_dim = edge_dim
        
        self.net = nn.Sequential(
            # FFN transfroms from number of edges to mid_dim
            nn.Linear(self.edge_dim + 1, self.mid_dim),
            # Normalization of the layer output to zeros mean and std equal to 1
            BN(self.mid_dim),
            nn.ReLU(),
            # Hidden layer that does not change dim
            nn.Linear(self.mid_dim, self.mid_dim),
            # Another Norm
            BN(self.mid_dim),
            # ReLU activation
            nn.ReLU(),
            # FFN transforms from dim_dim to (2min(l, k) + 1)*mi*mo
            nn.Linear(self.mid_dim, self.num_bases*mi*mo)
        )
        
        nn.init.kaiming_uniform_(self.net[0].weight)
        nn.init.kaiming_uniform_(self.net[3].weight)
        nn.init.kaiming_uniform_(self.net[6].weight)
        
    def __repr__(self):
        return f"RadialFunc(edge_dim={self.edge_dim}, in_dim={self.mi}, out_dim={self.mo})"
    
    
    def forward(self, x):
        # calculates a single vector of radial weights given the distance
        # between nodes with the FFN
        y = self.net(x)
        # reshapes to separate the radial weights by output channels,
        # input channel, and degree J to prepare for broadcasting and element-wise
        # multiplication with the array of basis kernels of shape (-1, 1, 2l+1, 1, 2k+1, num_bases)
        #print(y.shape)
        return y.view(-1, self.mo, 1, self.mi, 1, self.num_bases)

In [37]:
# Single test

dist_ij = torch.rand((10, 1))

model_Radial = RadialFunc(3, 10, 15)

model_Radial(dist_ij).shape

torch.Size([10, 15, 1, 10, 1, 3])

### Tensor Field Network

In [38]:
# initialize the radial network for type-k inputs and type-l outputs
# the falue of edge_dim has a default value of 1 which determines
# the dimension of the input to the radial function
num_bases = 3
mi = 1
mo = 3
edge_dim = 0
rp = RadialFunc(num_bases, mi, mo, edge_dim)

In [39]:
# calls forward method of RadialFunc class which feeds the realtive
# distance into radial network
R = rp(dist_ij)

In [40]:
x_ij =  torch.randn((100, 3))/100

r_ij = x_ij[..., 0, None]

In [41]:
leg = {}
x_ij =  torch.randn((100, 3))/100

r_ij = x_ij[..., 0, None]

do = 0
di = 1

basis = get_basis_kernel(x_ij)

# 2 min(l, k) + 1
num_bases = 2*min(do, di) + 1
mi = 1
mo = 3
edge_dim = 0
rp = RadialFunc(num_bases, mi, mo, edge_dim)
R = rp(r_ij)


# R: radial weights with shape (batch_size, mo, 1, mi, 1, 2 min(di, do) + 1)
# basis[f'{self.di}, {self.do}']: tensor of basis kernels 
# basis_shape:                 (batch_size, 1, 2*do + 1, 1, 2*di + 1, 2*min(di, do) + 1)
# for input deg di and output deg do
# kernel output:               (batch_size, mo, 2*do + 1, mi, 2*di + 1)
kernel = torch.sum(R*basis[f'{di}, {do}'], -1)

In [42]:
# reshape kernel to (mo * (2*do + 1), mi * (2*di + 1)) to prepare for
# matrix-vector multiplication with hte concatenated input channels of type
# di
kernel = kernel.view(kernel.shape[0], (2 * do + 1) * mo, -1)

In [43]:
class PairwiseConv(nn.Module):
    """SE(3)-equivariant convolution between two single-type features"""
    def __init__(self, degree_in: int, nc_in: int, degree_out: int,
                 nc_out: int, edge_dim: int=0):
        """SE(3)-equivariant convolution between a pair of feature types.

        This layer performs a convolution from nc_in features of type degree_in
        to nc_out features of type degree_out.

        Args:
            degree_in: degree of input fiber
            nc_in: number of channels on input
            degree_out: degree of out order
            nc_out: number of channels on output
            edge_dim: number of dimensions for edge embedding
        """
        super().__init__()
        # Log settings
        self.degree_in = degree_in
        self.degree_out = degree_out
        self.nc_in = nc_in
        self.nc_out = nc_out

        # Functions of the degree
        self.num_freq = 2*min(degree_in, degree_out) + 1
        self.d_out = 2*degree_out + 1
        self.edge_dim = edge_dim

        # Radial profile function
        self.rp = RadialFunc(self.num_freq, nc_in, nc_out, self.edge_dim)
        
    @profile
    def forward(self, feat, basis):
        # Get radial weights
        R = self.rp(feat)
        kernel = torch.sum(R * basis[f'{self.degree_in},{self.degree_out}'], -1)
        return kernel.view(kernel.shape[0], self.d_out*self.nc_out, -1)

### Fixing representation to apply fibers

In [44]:
@profile
def get_basis(G, max_degree, compute_gradients):
    """Precompute the SE(3)-equivariant weight basis, W_J^lk(x)

    This is called by get_basis_and_r().

    Args:
        G: DGL graph instance of type dgl.DGLGraph
        max_degree: non-negative int for degree of highest feature type
        compute_gradients: boolean, whether to compute gradients during basis construction
    Returns:
        dict of equivariant bases. Keys are in the form 'd_in,d_out'. Values are
        tensors of shape (batch_size, 1, 2*d_out+1, 1, 2*d_in+1, number_of_bases)
        where the 1's will later be broadcast to the number of output and input
        channels
    """
    if compute_gradients:
        context = nullcontext()
    else:
        context = torch.no_grad()

    with context:
        cloned_d = torch.clone(G.edata['d'])

        if G.edata['d'].requires_grad:
            cloned_d.requires_grad_()
            log_gradient_norm(cloned_d, 'Basis computation flow')

        # Relative positional encodings (vector)
        r_ij = get_spherical_from_cartesian_torch(cloned_d)
        # Spherical harmonic basis
        Y = precompute_sh(r_ij, 2*max_degree)
        device = Y[0].device

        basis = {}
        for d_in in range(max_degree+1):
            for d_out in range(max_degree+1):
                K_Js = []
                for J in range(abs(d_in-d_out), d_in+d_out+1):
                    # Get spherical harmonic projection matrices
                    Q_J = _basis_transformation_Q_J(J, d_in, d_out)
                    Q_J = Q_J.float().T.to(device)

                    # Create kernel from spherical harmonics
                    K_J = torch.matmul(Y[J], Q_J)
                    K_Js.append(K_J)

                # Reshape so can take linear combinations with a dot product
                size = (-1, 1, 2*d_out+1, 1, 2*d_in+1, 2*min(d_in, d_out)+1)
                basis[f'{d_in},{d_out}'] = torch.stack(K_Js, -1).view(*size)
        return basis


def get_r(G):
    """Compute internodal distances"""
    cloned_d = torch.clone(G.edata['d'])

    if G.edata['d'].requires_grad:
        cloned_d.requires_grad_()
        log_gradient_norm(cloned_d, 'Neural networks flow')

    return torch.sqrt(torch.sum(cloned_d**2, -1, keepdim=True))


def get_basis_and_r(G, max_degree, compute_gradients=False):
    """Return equivariant weight basis (basis) and internodal distances (r).

    Call this function *once* at the start of each forward pass of the model.
    It computes the equivariant weight basis, W_J^lk(x), and internodal
    distances, needed to compute varphi_J^lk(x), of eqn 8 of
    https://arxiv.org/pdf/2006.10503.pdf. The return values of this function
    can be shared as input across all SE(3)-Transformer layers in a model.

    Args:
        G: DGL graph instance of type dgl.DGLGraph()
        max_degree: non-negative int for degree of highest feature-type
        compute_gradients: controls whether to compute gradients during basis construction
    Returns:
        dict of equivariant bases, keys are in form '<d_in><d_out>'
        vector of relative distances, ordered according to edge ordering of G
    """
    basis = get_basis(G, max_degree, compute_gradients)
    r = get_r(G)
    return basis, r

In [45]:
torch.manual_seed(42)
np.random.seed(42)

leg = {}

dataset = QM9Dataset('./QM9_data/QM9_data.pt', "homo", mode='train', fully_connected=True)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate)

iter_dataloader = iter(dataloader) # so I can use next
for i in range(1):
    data = next(iter_dataloader)
    print("MINIBATCH")
    print(data[0]) 
    print(data[1].shape) # batch size -> connected graph of size batch


atom_feature_size = dataset.atom_feature_size
num_degrees = 4
num_channels = 32
num_channels_out = num_channels*num_degrees
edge_dim = dataset.num_bonds

connection = 'skip'

# building fibers for input data
fibers = {'in': Fiber(1, atom_feature_size),
           'mid': Fiber(num_degrees, num_channels),
           'out': Fiber(1, num_channels_out)}


f_in = fibers['in']
f_out = fibers['out']
# initialize dgl graph
G = copy.deepcopy(dataset[0][0])

basis, r = get_basis_and_r(G, num_degrees - 1)

  data = torch.load(self.file_address)


Loaded train-set, task: homo, source: ./QM9_data/QM9_data.pt, length: 100000
MINIBATCH
Graph(num_nodes=591, num_edges=10696,
      ndata_schemes={'x': Scheme(shape=(3,), dtype=torch.float32), 'f': Scheme(shape=(6, 1), dtype=torch.float32)}
      edata_schemes={'d': Scheme(shape=(3,), dtype=torch.float32), 'w': Scheme(shape=(5,), dtype=torch.float32)})
torch.Size([32, 1])




In [46]:
G.ndata['f'].shape

torch.Size([15, 6, 1])

In [47]:
h = {'0': G.ndata['f']} # type 0 node feature

kernel_unary = {}
# loop over (multiplicity, degree) tuples in input fiber
for (mi, di) in f_in.structure:
    # loop over (multiplicity, degree) tuples in output fiber
    for (mo, do) in f_out.structure:
        # generate a (mi * mo) unique kernels corresponding to every input
        # and output channel pair
        # store in dictionary with key f'({di},{do})'
        kernel_unary[f'({di},{do})'] = PairwiseConv(di, mi, do, mo, edge_dim=edge_dim)
        
        
# center -> center self connections
# skip connection consideres output connection 
# as linear combination of input connections
kernel_self = {
    'skip': nn.ParameterDict(),
    'TFN': nn.ParameterDict(),
}

# in constructor
# skip connection consideres output connection 
# as linear combination of input connections
for mi, di in f_in.structure:
    # proceed if input type is also an output type
    if di in f_out.degrees:
        # extract num of output channels of the type
        mo = f_out.structure_dict[di]
        # initialize learnable mi x mo weight matrix with random integers and scale down
        # singleton dimension used to broadcast across nodes
        W = nn.Parameter(torch.randn(1, mo, mi)/np.sqrt(mi))
        kernel_self['skip'][f'{di}'] = W
        
        
# in constructor
# mixing connection consideres output self features
# as linear combination of output connections
for mo, do in f_in.structure:
    # initialize square learnable weight matrix of random integers
    # and scale down
    W = nn.Parameter(torch.randn(1, mo, mo) / np.sqrt(mo))
    kernel_self['TFN'][f"{do}"] = W
   


In [48]:
G.ndata

{'x': tensor([[ 3.4800e-02,  1.3604e+00,  1.5080e-01],
        [-9.5000e-03, -1.6100e-02,  5.3400e-02],
        [-1.1400e+00, -7.4810e-01,  2.6900e-02],
        [-1.1403e+00, -2.1319e+00, -1.2000e-02],
        [ 3.7300e-02, -2.7655e+00, -1.5400e-02],
        [-1.0070e-01, -4.0993e+00, -4.1100e-02],
        [ 1.2330e+00, -2.1386e+00,  6.0000e-04],
        [ 1.3055e+00, -6.9670e-01,  1.7000e-02],
        [ 2.3546e+00,  5.4000e-02, -7.2000e-03],
        [ 9.6010e-01,  1.7065e+00, -7.8600e-02],
        [-7.2620e-01,  1.8749e+00, -2.6370e-01],
        [-2.1488e+00, -3.6290e-01,  4.1400e-02],
        [-1.0447e+00, -4.3020e+00, -4.4600e-02],
        [ 2.1340e+00, -2.7379e+00, -1.4600e-02],
        [ 3.2064e+00, -5.0600e-01, -4.0100e-02]]), 'f': tensor([[[0.],
         [0.],
         [1.],
         [0.],
         [0.],
         [7.]],

        [[0.],
         [1.],
         [0.],
         [0.],
         [0.],
         [6.]],

        [[0.],
         [1.],
         [0.],
         [0.],
        

In [49]:
with G.local_scope():
    # Add node features to local graph scope
    for k, v in h.items():
        G.ndata[k] = v

    # Add edge features
    if 'w' in G.edata.keys():
        w = G.edata['w']
        feat = torch.cat([w, r], -1)
    else:
        feat = torch.cat([r,], -1)

    for (mi, di) in f_in.structure:
        for (mo, do) in f_out.structure:
            etype = f"({di},{do})"
            G.edata[etype] = kernel_unary[etype](feat, basis) 

    # defining the function for reduction
    # function is defined for each output feature type d
    for do in f_out.degrees:

        # edge user-defined function in DGL that computes the message for a single output feature type
        # do
        def udf_u_mul_e(do):
            # calculate neighbor -> center message for type single output
            # feature type do
            def fnc(edges):
                msg = 0
                for mi, di in f_in.structure:
                    # extract all feature channels of type di from the neighborhood
                    # nodes and condense into single vector
                    # src has shape (edges, mi*(2*di + 1), 1)
                    src = edges.src[f'{di}'].view(-1, mi*(2*di + 1), 1)

                    # extract kernel for input type di and output type do
                    edge = edges.data[f"({di},{do})"]
                    # matrix multiplication to get (mo*(2*do + 1))-dimentional vector and add to total msg
                    msg += msg + torch.matmul(edge, src)

                # reshape message to separate output channels
                msg = msg.view(msg.shape[0], -1, 2*do + 1)


                # center -> center message
                if connection == 'skip':
                    # extract all input features of type do from all nodes
                    dst = edges.dst[f'{do}']
                    # extract self-intecation weights for type do channels
                    W = kernel_self[connection][f'{do}']
                    # calculate the array of self-interaction tensors with shape (nodes, mo, 2*do +1)
                    self_int = torch.matmul(W, dst)
                    msg = msg + self_int

                # in user-defined DGL edge -> node function
                # extract weight array of shape (1, mo, mo)
                if connection == 'TFN':
                    W = kernel_self['TFN'][f'{do}']
                    # matrix multiplication that generates output feature tensor for 
                    # for degree do
                    msg = torch.matmul(W, msg)

                return {'msg': msg.view(msg.shape[0], -1, 2*do + 1)}
            return fnc

        # call update all function that takes (message_func, reduce_func) as input
        G.update_all(udf_u_mul_e(do), fn.mean('msg', f'out{do}'))

    # return a dictionary of the output node features where every degree
    # is linked to an array with shape (edges, mo, 2*do + 1) by extracting
    # the node data stored from calling update_all

    f_mid = {f'{do}': G.ndata[f'out{do}'] for do in f_out.degrees}

In [50]:
G.ndata.keys(), f_mid.keys()

(dict_keys(['x', 'f']), dict_keys(['0']))

In [51]:
f_mid_my = f_mid

### Reference model

In [52]:
class GConvSE3(nn.Module):
    """A tensor field network layer as a DGL module.

    GConvSE3 stands for a Graph Convolution SE(3)-equivariant layer. It is the
    equivalent of a linear layer in an MLP, a conv layer in a CNN, or a graph
    conv layer in a GCN.

    At each node, the activations are split into different "feature types",
    indexed by the SE(3) representation type: non-negative integers 0, 1, 2, ..
    """
    def __init__(self, f_in, f_out, self_interaction: bool=False, edge_dim: int=0, flavor='skip'):
        """SE(3)-equivariant Graph Conv Layer

        Args:
            f_in: list of tuples [(multiplicities, type),...]
            f_out: list of tuples [(multiplicities, type),...]
            self_interaction: include self-interaction in convolution
            edge_dim: number of dimensions for edge embedding
            flavor: allows ['TFN', 'skip'], where 'skip' adds a skip connection
        """
        super().__init__()
        self.f_in = f_in
        self.f_out = f_out
        self.edge_dim = edge_dim
        self.self_interaction = self_interaction
        self.flavor = flavor

        # Neighbor -> center weights
        self.kernel_unary = nn.ModuleDict()
        for (mi, di) in self.f_in.structure:
            for (mo, do) in self.f_out.structure:
                self.kernel_unary[f'({di},{do})'] = PairwiseConv(di, mi, do, mo, edge_dim=edge_dim)

        # Center -> center weights
        self.kernel_self = nn.ParameterDict()
        if self_interaction:
            assert self.flavor in ['TFN', 'skip']
            if self.flavor == 'TFN':
                for m_out, d_out in self.f_out.structure:
                    W = nn.Parameter(torch.randn(1, m_out, m_out) / np.sqrt(m_out))
                    self.kernel_self[f'{d_out}'] = W
            elif self.flavor == 'skip':
                for m_in, d_in in self.f_in.structure:
                    if d_in in self.f_out.degrees:
                        m_out = self.f_out.structure_dict[d_in]
                        W = nn.Parameter(torch.randn(1, m_out, m_in) / np.sqrt(m_in))
                        self.kernel_self[f'{d_in}'] = W



    def __repr__(self):
        return f'GConvSE3(structure={self.f_out}, self_interaction={self.self_interaction})'


    def udf_u_mul_e(self, d_out):
        """Compute the convolution for a single output feature type.

        This function is set up as a User Defined Function in DGL.

        Args:
            d_out: output feature type
        Returns:
            edge -> node function handle
        """
        def fnc(edges):
            # Neighbor -> center messages
            msg = 0
            for m_in, d_in in self.f_in.structure:
                src = edges.src[f'{d_in}'].view(-1, m_in*(2*d_in+1), 1)
                edge = edges.data[f'({d_in},{d_out})']
                msg = msg + torch.matmul(edge, src)
            msg = msg.view(msg.shape[0], -1, 2*d_out+1)

            # Center -> center messages
            if self.self_interaction:
                if f'{d_out}' in self.kernel_self.keys():
                    if self.flavor == 'TFN':
                        W = self.kernel_self[f'{d_out}']
                        msg = torch.matmul(W, msg)
                    if self.flavor == 'skip':
                        dst = edges.dst[f'{d_out}']
                        W = self.kernel_self[f'{d_out}']
                        msg = msg + torch.matmul(W, dst)

            return {'msg': msg.view(msg.shape[0], -1, 2*d_out+1)}
        return fnc

    @profile
    def forward(self, h, G=None, r=None, basis=None, **kwargs):
        """Forward pass of the linear layer

        Args:
            G: minibatch of (homo)graphs
            h: dict of features
            r: inter-atomic distances
            basis: pre-computed Q * Y
        Returns:
            tensor with new features [B, n_points, n_features_out]
        """
        with G.local_scope():
            # Add node features to local graph scope
            for k, v in h.items():
                G.ndata[k] = v

            # Add edge features
            if 'w' in G.edata.keys():
                w = G.edata['w']
                feat = torch.cat([w, r], -1)
            else:
                feat = torch.cat([r, ], -1)

            for (mi, di) in self.f_in.structure:
                for (mo, do) in self.f_out.structure:
                    etype = f'({di},{do})'
                    G.edata[etype] = self.kernel_unary[etype](feat, basis)

            # Perform message-passing for each output feature type
            for d in self.f_out.degrees:
                G.update_all(self.udf_u_mul_e(d), fn.mean('msg', f'out{d}'))

            return {f'{d}': G.ndata[f'out{d}'] for d in self.f_out.degrees}

In [53]:
torch.manual_seed(42)
np.random.seed(42)


leg = {}

dataset = QM9Dataset('./QM9_data/QM9_data.pt', "homo", mode='train', fully_connected=True)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate)

iter_dataloader = iter(dataloader) # so I can use next
for i in range(1):
    data = next(iter_dataloader)
    print("MINIBATCH")
    print(data[0]) 
    print(data[1].shape) # batch size -> connected graph of size batch


atom_feature_size = dataset.atom_feature_size
num_degrees = 4
num_channels = 32
num_channels_out = num_channels*num_degrees
edge_dim = dataset.num_bonds

connection = 'skip'

# building fibers for input data
fibers = {'in': Fiber(1, atom_feature_size),
           'mid': Fiber(num_degrees, num_channels),
           'out': Fiber(1, num_channels_out)}


f_in = fibers['in']
f_out = fibers['out']
# initialize dgl graph
G = copy.deepcopy(dataset[0][0])


h = {'0': G.ndata['f']} # type 0 node feature
basis, r = get_basis_and_r(G, num_degrees - 1)

layer = GConvSE3(f_in = f_in, 
                 f_out = f_out, 
                 edge_dim=edge_dim)

  data = torch.load(self.file_address)


Loaded train-set, task: homo, source: ./QM9_data/QM9_data.pt, length: 100000
MINIBATCH
Graph(num_nodes=591, num_edges=10696,
      ndata_schemes={'x': Scheme(shape=(3,), dtype=torch.float32), 'f': Scheme(shape=(6, 1), dtype=torch.float32)}
      edata_schemes={'d': Scheme(shape=(3,), dtype=torch.float32), 'w': Scheme(shape=(5,), dtype=torch.float32)})
torch.Size([32, 1])


In [54]:
f_mid_new = layer(h, G = G, r = r, basis = basis)

In [55]:
f_mid_new['0'].shape, f_mid_new['0'].sum()

(torch.Size([15, 128, 1]), tensor(216.5136, grad_fn=<SumBackward0>))

In [56]:
f_mid_my['0'].shape, f_mid_my['0'].sum()

(torch.Size([15, 128, 1]), tensor(380.8578, grad_fn=<SumBackward0>))

In [64]:
# GPU test

G_cu = G.to(device)

h_cu = {'0': G_cu.ndata['f']} # type 0 node feature

layer_cu = GConvSE3(f_in = f_in, 
                 f_out = f_out, 
                 edge_dim=edge_dim).to(device)

basis_cu = {key: basis[key].to(device) for key in basis}
r_cu = r.to(device)
f_mid_new_cu = layer_cu(h, G = G_cu, r = r_cu, basis = basis_cu)

In [65]:
f_mid_new_cu['0'].sum()

tensor(-219.1493, device='cuda:0', grad_fn=<SumBackward0>)

## SE3 transformer

### Query embedding

In [74]:
# First define the fiber f_mid_in with same multiplicities as value msgs with
# structure f_mid_out, 
# !!! but keep only degrees in input f_in

# building fibers for input data
fibers = {'in': Fiber(1, atom_feature_size),
           'mid': Fiber(num_degrees, num_channels),
           'out': Fiber(1, num_channels_out)}


f_in = fibers['in']
f_out = fibers['out']

f_mid_out = copy.deepcopy(f_out).structure_dict

f_mid_in = Fiber(dictionary={d: m for d, m in f_mid_out.items() if d in f_in.degrees})

f_mid_in

[(128, 0)]

In [76]:
# Initializing weights
transform = nn.ParameterDict()
# loop through all degrees in f_mid_in
for m_mid, d_mid in f_mid_in.structure:
    # extract number of input channels of degree d_mid to define
    # dimensions of weight matrix
    mi = f_in.structure_dict[d_mid]
    # initialize m_mid x mi weight matrix with random integers and scale down
    transform[str(d_mid)] = nn.Parameter(torch.randn(m_mid, mi) / np.sqrt(mi), requires_grad=True)

In [80]:
h = {'0': G.ndata['f']} # type 0 node feature

with G.local_scope():
    # Add node features to local graph scope
    for k, v in h.items():
        G.ndata[k] = v

    # Add edge features
    if 'w' in G.edata.keys():
        w = G.edata['w']
        feat = torch.cat([w, r], -1)
    else:
        feat = torch.cat([r, ], -1)

    # Perfoming linear self attention transformation
    output = {}
    # loop through output degrees in features dictionary and extract
    # features f
    for do, f in h.items():
        # check if there is a query matrix corresponding to the degree do
        if str(do) in transform.keys():
            # calculate the query for every channel of degree do
            # output has shape (mo, 2*do + 1)
            output[do] = torch.matmul(transform[str(do)], f)

In [82]:
# ref from the code
class G1x1SE3(nn.Module):
    """Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.

    This is equivalent to a self-interaction layer in TensorField Networks.
    """
    def __init__(self, f_in, f_out, learnable=True):
        """SE(3)-equivariant 1x1 convolution.

        Args:
            f_in: input Fiber() of feature multiplicities and types
            f_out: output Fiber() of feature multiplicities and types
        """
        super().__init__()
        self.f_in = f_in
        self.f_out = f_out

        # Linear mappings: 1 per output feature type
        self.transform = nn.ParameterDict()
        for m_out, d_out in self.f_out.structure:
            m_in = self.f_in.structure_dict[d_out]
            self.transform[str(d_out)] = nn.Parameter(torch.randn(m_out, m_in) / np.sqrt(m_in), requires_grad=learnable)

    def __repr__(self):
         return f"G1x1SE3(structure={self.f_out})"

    def forward(self, features, **kwargs):
        output = {}
        for k, v in features.items():
            if str(k) in self.transform.keys():
                output[k] = torch.matmul(self.transform[str(k)], v)
        return output

# initialize the function for generating the query that projects from f_in to f_mid_in
GMAB = {}
GMAB['q'] = G1x1SE3(f_in, f_mid_in)

q = GMAB['q'](h)

In [85]:
# Concatenation along channels and moments of features
# F: list where each element is an array with shape (m, 2*d + 1) for each feature degree d
# assume for now that H = 1 (single attention head)
def fiber2head(F, H, structure):
    # squeeze each array in the list into a m * (2*d + 1)-dimentional vector 
    # of all channels concatenated along the last dimention
    fibers = [F[f'{d}'].view(*F[f'{d}'].shape[:-2], H, -1) for d in structure.degrees]
    # concatenate across the last dimension of every array in the list to get a single vector
    fibers = torch.cat(fibers, -1)
    return fibers

n_heads = 1

G.ndata['q'] = fiber2head(q, 1, f_mid_in)

In [87]:
G.ndata['q'].shape

torch.Size([15, 1, 128])

### Key embedding

In [88]:
# Keys are same as TFN conv layers but no self interaction and head attention dim is added to comp
class GConvSE3Partial(nn.Module):
    """Graph SE(3)-equivariant node -> edge layer"""
    def __init__(self, f_in, f_out, edge_dim: int=0, x_ij=None):
        """SE(3)-equivariant partial convolution.

        A partial convolution computes the inner product between a kernel and
        each input channel, without summing over the result from each input
        channel. This unfolded structure makes it amenable to be used for
        computing the value-embeddings of the attention mechanism.

        Args:
            f_in: list of tuples [(multiplicities, type),...]
            f_out: list of tuples [(multiplicities, type),...]
        """
        super().__init__()
        self.f_out = f_out
        self.edge_dim = edge_dim

        # adding/concatinating relative position to feature vectors
        # 'cat' concatenates relative position & existing feature vector
        # 'add' adds it, but only if multiplicity > 1
        assert x_ij in [None, 'cat', 'add']
        self.x_ij = x_ij
        if x_ij == 'cat':
            self.f_in = Fiber.combine(f_in, Fiber(structure=[(1,1)]))
        else:
            self.f_in = f_in

        # Node -> edge weights
        self.kernel_unary = nn.ModuleDict()
        for (mi, di) in self.f_in.structure:
            for (mo, do) in self.f_out.structure:
                self.kernel_unary[f'({di},{do})'] = PairwiseConv(di, mi, do, mo, edge_dim=edge_dim)

    def __repr__(self):
        return f'GConvSE3Partial(structure={self.f_out})'

    def udf_u_mul_e(self, d_out):
        """Compute the partial convolution for a single output feature type.

        This function is set up as a User Defined Function in DGL.

        Args:
            d_out: output feature type
        Returns:
            node -> edge function handle
        """
        def fnc(edges):
            # Neighbor -> center messages
            msg = 0
            for m_in, d_in in self.f_in.structure:
                # if type 1 and flag set, add relative position as feature
                if self.x_ij == 'cat' and d_in == 1:
                    # relative positions
                    rel = (edges.dst['x'] - edges.src['x']).view(-1, 3, 1)
                    m_ori = m_in - 1
                    if m_ori == 0:
                        # no type 1 input feature, just use relative position
                        src = rel
                    else:
                        # features of src node, shape [edges, m_in*(2l+1), 1]
                        src = edges.src[f'{d_in}'].view(-1, m_ori*(2*d_in+1), 1)
                        # add to feature vector
                        src = torch.cat([src, rel], dim=1)
                elif self.x_ij == 'add' and d_in == 1 and m_in > 1:
                    src = edges.src[f'{d_in}'].view(-1, m_in*(2*d_in+1), 1)
                    rel = (edges.dst['x'] - edges.src['x']).view(-1, 3, 1)
                    src[..., :3, :1] = src[..., :3, :1] + rel
                else:
                    src = edges.src[f'{d_in}'].view(-1, m_in*(2*d_in+1), 1)
                edge = edges.data[f'({d_in},{d_out})']
                msg = msg + torch.matmul(edge, src)
            msg = msg.view(msg.shape[0], -1, 2*d_out+1)

            return {f'out{d_out}': msg.view(msg.shape[0], -1, 2*d_out+1)}
        return fnc

    @profile
    def forward(self, h, G=None, r=None, basis=None, **kwargs):
        """Forward pass of the linear layer

        Args:
            h: dict of node-features
            G: minibatch of (homo)graphs
            r: inter-atomic distances
            basis: pre-computed Q * Y
        Returns:
            tensor with new features [B, n_points, n_features_out]
        """
        with G.local_scope():
            # Add node features to local graph scope
            for k, v in h.items():
                G.ndata[k] = v

            # Add edge features
            if 'w' in G.edata.keys():
                w = G.edata['w'] # shape: [#edges_in_batch, #bond_types]
                feat = torch.cat([w, r], -1)
            else:
                feat = torch.cat([r, ], -1)
            for (mi, di) in self.f_in.structure:
                for (mo, do) in self.f_out.structure:
                    etype = f'({di},{do})'
                    G.edata[etype] = self.kernel_unary[etype](feat, basis)

            # Perform message-passing for each output feature type
            for d in self.f_out.degrees:
                G.apply_edges(self.udf_u_mul_e(d))

            return {f'{d}': G.edata[f'out{d}'] for d in self.f_out.degrees}

In [92]:
# edge dim : used to determine the dim of input (edge_dim + 1) to the radial function
# if only radial distance, the edge_dim is 0
# x_ij: type-1 displacement vector used as edge feature (more on this later)
GMAB['k'] = GConvSE3Partial(f_in, f_mid_in, edge_dim=edge_dim, x_ij = 'cat')

k = GMAB['k'](h, G = G, r = r, basis = basis)

In [95]:
# Adding the result to edge feature
G.edata['k'] = fiber2head(k, n_heads, f_mid_in)

G.edata['k'].shape

torch.Size([210, 1, 128])