# Point cloud classification with Pointnet using Keras


## Introduction

Classification, detection and segmentation of unordered 3D point sets i.e. point clouds
is a core problem in computer vision. This example implements the seminal point cloud
deep learning paper [PointNet (Qi et al., 2017)](https://arxiv.org/abs/1612.00593). 


## Setup

If using colab first install trimesh with `!pip install trimesh`.


In [10]:
!pip install trimesh
import os
import glob
import trimesh
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from matplotlib import pyplot as plt

tf.random.set_seed(1234)




## Load dataset

We use the ModelNet10 model dataset, the smaller 10 class version of the ModelNet40
dataset. First download the data:


In [11]:
DATA_DIR = tf.keras.utils.get_file(
    "modelnet.zip",
    "http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip",
    extract=True,
)
DATA_DIR = os.path.join(os.path.dirname(DATA_DIR), "ModelNet10")


We can use the `trimesh` package to read and visualize the `.off` mesh files.


In [12]:
# mesh = trimesh.load(os.path.join(DATA_DIR, "chair/train/chair_0001.off"))
# mesh.show()


To convert a mesh file to a point cloud we first need to sample points on the mesh
surface. `.sample()` performs a unifrom random sampling. Here we sample at 2048 locations
and visualize in `matplotlib`.


In [13]:
# points = mesh.sample(2048)

# fig = plt.figure(figsize=(5, 5))
# ax = fig.add_subplot(111, projection="3d")
# ax.scatter(points[:, 0], points[:, 1], points[:, 2])
# ax.set_axis_off()
# plt.show()


To generate a `tf.data.Dataset()` we need to first parse through the ModelNet data
folders. Each mesh is loaded and sampled into a point cloud before being added to a
standard python list and converted to a `numpy` array. We also store the current
enumerate index value as the object label and use a dictionary to recall this later.


In [14]:

def parse_dataset(num_points=2048):

    train_points = []
    train_labels = []
    test_points = []
    test_labels = []
    class_map = {}
    folders = glob.glob(os.path.join(DATA_DIR, "[!README]*"))

    for i, folder in enumerate(folders):
        print("processing class: {}".format(os.path.basename(folder)))
        # store folder name with ID so we can retrieve later
        class_map[i] = folder.split("/")[-1]
        # gather all files
        train_files = glob.glob(os.path.join(folder, "train/*"))
        test_files = glob.glob(os.path.join(folder, "test/*"))

        for f in train_files:
            train_points.append(trimesh.load(f).sample(num_points))
            train_labels.append(i)

        for f in test_files:
            test_points.append(trimesh.load(f).sample(num_points))
            test_labels.append(i)

    return (
        np.array(train_points),
        np.array(test_points),
        np.array(train_labels),
        np.array(test_labels),
        class_map,
    )



Set the number of points to sample and batch size and parse the dataset. This can take
~5minutes to complete.


In [15]:
NUM_POINTS = 2048
NUM_CLASSES = 10
BATCH_SIZE = 32

train_points, test_points, train_labels, test_labels, CLASS_MAP = parse_dataset(
    NUM_POINTS
)


processing class: chair
processing class: bed
processing class: dresser
processing class: sofa
processing class: bathtub
processing class: desk
processing class: toilet
processing class: night_stand
processing class: monitor
processing class: table


In [16]:
train_points.shape

(3991, 2048, 3)

In [17]:
def centralize_shapes(points, scale=.95):
   center = np.mean(points, axis=-1, keepdims=True)
  #  print(center.shape)
   diam = .5 * (np.amax(points, axis=(-2,-1), keepdims=True) - np.amin(points, axis=(-2,-1), keepdims=True))
  #  print(diam.shape)
   return scale* (points - center)/diam

In [18]:
train_points = centralize_shapes(train_points)
test_points = centralize_shapes(test_points)

In [19]:
## create kernels 
import scipy
from tqdm.notebook import trange

def create_features(points, epsilon, lambda_par, use_node_feats=False, k_vals=32, choose_from='all'):
    left_eigs = []
    right_eigs = []

    for i in trange(points.shape[0]):
      Cs = scipy.spatial.distance.cdist(points[i], points[i], "minkowski", p=1)
      Cs[Cs > epsilon] = 0
      if use_node_feats is False:
        Cs = (Cs + Cs.T)/2
        eigvals = np.real(np.linalg.eigvals(Cs)) # do exponential outside so that we can try more smoothening factors
        if choose_from =='left' :
          left_eigs.append(np.sort(eigvals)[: k_vals])
        elif choose_from == 'right' :
          right_eigs.append(np.sort(eigvals)[-k_vals :])
        else : 
          a = np.sort(eigvals)
          left_eigs.append(a[: k_vals]), right_eigs.append(a[-k_vals :])

      else :
        Cs1 = scipy.linalg.expm(lambda_par * Cs) # can be vectorized
        Cs1 = (
        Cs1 + Cs1.T
        ) / 2

        right_eigs.append(np.mean(np.einsum("ij, jk -> ik", Cs1, points[i]), axis=0))

    return left_eigs, right_eigs
    


In [20]:
epsilon = .1
lambda_par = .5

In [21]:
train_feats = create_features(train_points, epsilon=epsilon, lambda_par=lambda_par)


  0%|          | 0/3991 [00:00<?, ?it/s]

In [22]:
len(train_feats[0])

3991

In [23]:
test_feats = create_features(test_points, epsilon=epsilon, lambda_par=lambda_par)

  0%|          | 0/908 [00:00<?, ?it/s]

In [24]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report

In [25]:
## let us try a bunch of things 
eps = [.5, .1, -.5, -.1, -1.]
deeps = [10,15,20,25]
for e in eps : 
  train = np.exp(e*np.real(np.array(train_feats[0])))
  test = np.exp(e*np.real(np.array(test_feats[0])))
  for d in deeps :
    clf = RandomForestClassifier(max_depth=d, random_state=0)
    clf.fit(np.array(train), train_labels)
    preds = clf.predict(test)
    print(f"Epsilon is {e} and depth is {d}")
    print(classification_report(test_labels, preds))
    print("*"*40)

Epsilon is 0.5 and depth is 10
              precision    recall  f1-score   support

           0       0.35      0.56      0.43       100
           1       0.15      0.29      0.19       100
           2       0.43      0.42      0.42        86
           3       0.29      0.50      0.36       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.29      0.17      0.22       100
           7       0.35      0.26      0.30        86
           8       0.22      0.34      0.27       100
           9       0.05      0.01      0.02       100

    accuracy                           0.27       908
   macro avg       0.21      0.25      0.22       908
weighted avg       0.22      0.27      0.23       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is 0.5 and depth is 15
              precision    recall  f1-score   support

           0       0.34      0.54      0.42       100
           1       0.13      0.24      0.17       100
           2       0.46      0.43      0.44        86
           3       0.29      0.46      0.35       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.31      0.22      0.26       100
           7       0.38      0.24      0.30        86
           8       0.21      0.33      0.26       100
           9       0.13      0.05      0.07       100

    accuracy                           0.27       908
   macro avg       0.22      0.25      0.23       908
weighted avg       0.23      0.27      0.24       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is 0.5 and depth is 20
              precision    recall  f1-score   support

           0       0.34      0.55      0.42       100
           1       0.14      0.26      0.18       100
           2       0.44      0.41      0.42        86
           3       0.28      0.46      0.35       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.29      0.15      0.20       100
           7       0.37      0.27      0.31        86
           8       0.18      0.28      0.22       100
           9       0.08      0.03      0.04       100

    accuracy                           0.25       908
   macro avg       0.21      0.24      0.21       908
weighted avg       0.22      0.25      0.22       908

****************************************
Epsilon is 0.5 and depth is 25
              precision    recall  f1-score   support

           0       0.36      0.60      0.45       100
           1       0.15      

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is 0.1 and depth is 15
              precision    recall  f1-score   support

           0       0.35      0.54      0.42       100
           1       0.14      0.25      0.18       100
           2       0.43      0.41      0.42        86
           3       0.29      0.47      0.36       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.32      0.24      0.27       100
           7       0.41      0.27      0.32        86
           8       0.21      0.32      0.25       100
           9       0.11      0.04      0.06       100

    accuracy                           0.27       908
   macro avg       0.23      0.25      0.23       908
weighted avg       0.24      0.27      0.24       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is 0.1 and depth is 20
              precision    recall  f1-score   support

           0       0.34      0.55      0.42       100
           1       0.13      0.25      0.17       100
           2       0.48      0.43      0.45        86
           3       0.27      0.45      0.34       100
           4       0.00      0.00      0.00        50
           5       0.12      0.01      0.02        86
           6       0.27      0.14      0.18       100
           7       0.44      0.30      0.36        86
           8       0.21      0.32      0.25       100
           9       0.09      0.04      0.06       100

    accuracy                           0.26       908
   macro avg       0.24      0.25      0.23       908
weighted avg       0.24      0.26      0.24       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is 0.1 and depth is 25
              precision    recall  f1-score   support

           0       0.34      0.58      0.43       100
           1       0.17      0.30      0.21       100
           2       0.44      0.41      0.42        86
           3       0.28      0.45      0.35       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.35      0.23      0.28       100
           7       0.38      0.26      0.31        86
           8       0.20      0.29      0.23       100
           9       0.07      0.03      0.04       100

    accuracy                           0.27       908
   macro avg       0.22      0.25      0.23       908
weighted avg       0.23      0.27      0.24       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -0.5 and depth is 10
              precision    recall  f1-score   support

           0       0.36      0.57      0.44       100
           1       0.15      0.31      0.20       100
           2       0.47      0.42      0.44        86
           3       0.29      0.53      0.38       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.32      0.18      0.23       100
           7       0.39      0.27      0.32        86
           8       0.23      0.34      0.27       100
           9       0.09      0.02      0.03       100

    accuracy                           0.28       908
   macro avg       0.23      0.26      0.23       908
weighted avg       0.24      0.28      0.24       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -0.5 and depth is 15
              precision    recall  f1-score   support

           0       0.36      0.57      0.44       100
           1       0.16      0.29      0.20       100
           2       0.44      0.43      0.43        86
           3       0.29      0.50      0.37       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.37      0.23      0.28       100
           7       0.37      0.23      0.29        86
           8       0.20      0.30      0.24       100
           9       0.12      0.05      0.07       100

    accuracy                           0.28       908
   macro avg       0.23      0.26      0.23       908
weighted avg       0.24      0.28      0.25       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -0.5 and depth is 20
              precision    recall  f1-score   support

           0       0.34      0.54      0.42       100
           1       0.15      0.25      0.19       100
           2       0.42      0.41      0.41        86
           3       0.32      0.55      0.40       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.26      0.18      0.21       100
           7       0.34      0.22      0.27        86
           8       0.23      0.35      0.27       100
           9       0.11      0.04      0.06       100

    accuracy                           0.27       908
   macro avg       0.22      0.25      0.22       908
weighted avg       0.23      0.27      0.24       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -0.5 and depth is 25
              precision    recall  f1-score   support

           0       0.33      0.55      0.41       100
           1       0.16      0.28      0.21       100
           2       0.40      0.38      0.39        86
           3       0.31      0.48      0.38       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.36      0.25      0.30       100
           7       0.33      0.22      0.27        86
           8       0.22      0.33      0.26       100
           9       0.10      0.04      0.06       100

    accuracy                           0.27       908
   macro avg       0.22      0.25      0.23       908
weighted avg       0.23      0.27      0.24       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -0.1 and depth is 10
              precision    recall  f1-score   support

           0       0.36      0.57      0.44       100
           1       0.15      0.30      0.20       100
           2       0.47      0.42      0.44        86
           3       0.29      0.53      0.38       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.29      0.17      0.22       100
           7       0.38      0.26      0.31        86
           8       0.23      0.34      0.27       100
           9       0.09      0.02      0.03       100

    accuracy                           0.28       908
   macro avg       0.23      0.26      0.23       908
weighted avg       0.24      0.28      0.24       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -0.1 and depth is 15
              precision    recall  f1-score   support

           0       0.38      0.58      0.46       100
           1       0.15      0.26      0.19       100
           2       0.43      0.42      0.42        86
           3       0.30      0.52      0.38       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.36      0.24      0.29       100
           7       0.34      0.21      0.26        86
           8       0.21      0.33      0.26       100
           9       0.14      0.05      0.07       100

    accuracy                           0.28       908
   macro avg       0.23      0.26      0.23       908
weighted avg       0.24      0.28      0.25       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -0.1 and depth is 20
              precision    recall  f1-score   support

           0       0.34      0.56      0.43       100
           1       0.15      0.25      0.19       100
           2       0.43      0.42      0.43        86
           3       0.30      0.52      0.38       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.31      0.21      0.25       100
           7       0.37      0.23      0.29        86
           8       0.21      0.32      0.26       100
           9       0.12      0.05      0.07       100

    accuracy                           0.27       908
   macro avg       0.23      0.26      0.23       908
weighted avg       0.24      0.27      0.24       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -0.1 and depth is 25
              precision    recall  f1-score   support

           0       0.34      0.57      0.43       100
           1       0.15      0.25      0.19       100
           2       0.42      0.42      0.42        86
           3       0.31      0.48      0.37       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.36      0.25      0.29       100
           7       0.36      0.23      0.28        86
           8       0.20      0.31      0.25       100
           9       0.07      0.03      0.04       100

    accuracy                           0.27       908
   macro avg       0.22      0.25      0.23       908
weighted avg       0.23      0.27      0.24       908

****************************************
Epsilon is -1.0 and depth is 10
              precision    recall  f1-score   support

           0       0.37      0.58      0.45       100
           1       0.15    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -1.0 and depth is 15
              precision    recall  f1-score   support

           0       0.36      0.56      0.44       100
           1       0.16      0.29      0.20       100
           2       0.44      0.43      0.43        86
           3       0.30      0.52      0.38       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.37      0.23      0.28       100
           7       0.37      0.23      0.29        86
           8       0.21      0.31      0.25       100
           9       0.13      0.05      0.07       100

    accuracy                           0.28       908
   macro avg       0.23      0.26      0.23       908
weighted avg       0.24      0.28      0.25       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -1.0 and depth is 20
              precision    recall  f1-score   support

           0       0.35      0.55      0.43       100
           1       0.14      0.24      0.18       100
           2       0.41      0.40      0.40        86
           3       0.31      0.53      0.39       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.28      0.19      0.22       100
           7       0.33      0.22      0.27        86
           8       0.22      0.34      0.27       100
           9       0.11      0.04      0.06       100

    accuracy                           0.27       908
   macro avg       0.21      0.25      0.22       908
weighted avg       0.22      0.27      0.23       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -1.0 and depth is 25
              precision    recall  f1-score   support

           0       0.34      0.56      0.42       100
           1       0.15      0.26      0.19       100
           2       0.40      0.38      0.39        86
           3       0.31      0.48      0.38       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.34      0.24      0.28       100
           7       0.33      0.22      0.26        86
           8       0.22      0.33      0.26       100
           9       0.09      0.04      0.06       100

    accuracy                           0.27       908
   macro avg       0.22      0.25      0.22       908
weighted avg       0.23      0.27      0.24       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [26]:
eps = [.5, .1, -.5, -.1, -1.]
deeps = [10,15,20,25]
for e in eps : 
  train = np.exp(e*np.real(np.array(train_feats[1])))
  test = np.exp(e*np.real(np.array(test_feats[1])))
  for d in deeps :
    clf = RandomForestClassifier(max_depth=d, random_state=0)
    clf.fit(np.array(train), train_labels)
    preds = clf.predict(test)
    print(f"Epsilon is {e} and depth is {d}")
    print(classification_report(test_labels, preds))
    print("*"*40)

Epsilon is 0.5 and depth is 10
              precision    recall  f1-score   support

           0       0.36      0.58      0.45       100
           1       0.18      0.31      0.23       100
           2       0.58      0.49      0.53        86
           3       0.33      0.55      0.41       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.40      0.41      0.41       100
           7       0.51      0.42      0.46        86
           8       0.25      0.35      0.29       100
           9       0.27      0.07      0.11       100

    accuracy                           0.34       908
   macro avg       0.29      0.32      0.29       908
weighted avg       0.30      0.34      0.30       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is 0.5 and depth is 15
              precision    recall  f1-score   support

           0       0.38      0.60      0.46       100
           1       0.20      0.32      0.25       100
           2       0.58      0.49      0.53        86
           3       0.29      0.51      0.37       100
           4       0.00      0.00      0.00        50
           5       0.20      0.01      0.02        86
           6       0.42      0.41      0.42       100
           7       0.51      0.41      0.45        86
           8       0.29      0.38      0.33       100
           9       0.23      0.09      0.13       100

    accuracy                           0.34       908
   macro avg       0.31      0.32      0.30       908
weighted avg       0.32      0.34      0.31       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is 0.5 and depth is 20
              precision    recall  f1-score   support

           0       0.34      0.54      0.42       100
           1       0.23      0.32      0.27       100
           2       0.56      0.47      0.51        86
           3       0.30      0.50      0.38       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.40      0.42      0.41       100
           7       0.49      0.42      0.45        86
           8       0.24      0.35      0.29       100
           9       0.33      0.14      0.20       100

    accuracy                           0.33       908
   macro avg       0.29      0.32      0.29       908
weighted avg       0.30      0.33      0.31       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is 0.5 and depth is 25
              precision    recall  f1-score   support

           0       0.34      0.61      0.44       100
           1       0.21      0.34      0.26       100
           2       0.55      0.48      0.51        86
           3       0.34      0.49      0.40       100
           4       0.00      0.00      0.00        50
           5       0.14      0.01      0.02        86
           6       0.43      0.38      0.40       100
           7       0.49      0.41      0.44        86
           8       0.24      0.33      0.28       100
           9       0.22      0.09      0.13       100

    accuracy                           0.33       908
   macro avg       0.30      0.31      0.29       908
weighted avg       0.31      0.33      0.30       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is 0.1 and depth is 10
              precision    recall  f1-score   support

           0       0.36      0.58      0.44       100
           1       0.18      0.32      0.23       100
           2       0.58      0.49      0.53        86
           3       0.33      0.55      0.41       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.41      0.39      0.40       100
           7       0.51      0.41      0.45        86
           8       0.24      0.34      0.28       100
           9       0.27      0.07      0.11       100

    accuracy                           0.33       908
   macro avg       0.29      0.31      0.29       908
weighted avg       0.30      0.33      0.30       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is 0.1 and depth is 15
              precision    recall  f1-score   support

           0       0.37      0.59      0.46       100
           1       0.22      0.35      0.27       100
           2       0.58      0.49      0.53        86
           3       0.29      0.49      0.37       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.40      0.42      0.41       100
           7       0.51      0.40      0.44        86
           8       0.30      0.40      0.34       100
           9       0.26      0.10      0.14       100

    accuracy                           0.34       908
   macro avg       0.29      0.32      0.30       908
weighted avg       0.31      0.34      0.31       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is 0.1 and depth is 20
              precision    recall  f1-score   support

           0       0.37      0.59      0.45       100
           1       0.22      0.31      0.26       100
           2       0.57      0.49      0.53        86
           3       0.32      0.50      0.39       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.40      0.43      0.41       100
           7       0.50      0.43      0.46        86
           8       0.24      0.37      0.29       100
           9       0.36      0.13      0.19       100

    accuracy                           0.34       908
   macro avg       0.30      0.32      0.30       908
weighted avg       0.31      0.34      0.31       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is 0.1 and depth is 25
              precision    recall  f1-score   support

           0       0.36      0.61      0.45       100
           1       0.21      0.33      0.26       100
           2       0.55      0.48      0.51        86
           3       0.33      0.50      0.40       100
           4       0.00      0.00      0.00        50
           5       0.14      0.01      0.02        86
           6       0.41      0.36      0.38       100
           7       0.49      0.41      0.44        86
           8       0.25      0.36      0.29       100
           9       0.22      0.09      0.13       100

    accuracy                           0.33       908
   macro avg       0.29      0.31      0.29       908
weighted avg       0.31      0.33      0.30       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -0.5 and depth is 10
              precision    recall  f1-score   support

           0       0.36      0.58      0.44       100
           1       0.20      0.37      0.26       100
           2       0.51      0.44      0.48        86
           3       0.29      0.49      0.37       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.41      0.38      0.40       100
           7       0.48      0.37      0.42        86
           8       0.25      0.35      0.29       100
           9       0.32      0.07      0.11       100

    accuracy                           0.32       908
   macro avg       0.28      0.31      0.28       908
weighted avg       0.30      0.32      0.29       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -0.5 and depth is 15
              precision    recall  f1-score   support

           0       0.36      0.60      0.45       100
           1       0.22      0.34      0.27       100
           2       0.58      0.44      0.50        86
           3       0.30      0.47      0.37       100
           4       0.00      0.00      0.00        50
           5       0.40      0.02      0.04        86
           6       0.40      0.42      0.41       100
           7       0.49      0.45      0.47        86
           8       0.27      0.37      0.31       100
           9       0.24      0.09      0.13       100

    accuracy                           0.34       908
   macro avg       0.33      0.32      0.30       908
weighted avg       0.34      0.34      0.31       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -0.5 and depth is 20
              precision    recall  f1-score   support

           0       0.33      0.57      0.42       100
           1       0.23      0.33      0.27       100
           2       0.49      0.45      0.47        86
           3       0.29      0.44      0.35       100
           4       0.00      0.00      0.00        50
           5       0.14      0.01      0.02        86
           6       0.38      0.38      0.38       100
           7       0.48      0.37      0.42        86
           8       0.27      0.39      0.32       100
           9       0.33      0.14      0.20       100

    accuracy                           0.33       908
   macro avg       0.29      0.31      0.28       908
weighted avg       0.31      0.33      0.30       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -0.5 and depth is 25
              precision    recall  f1-score   support

           0       0.35      0.57      0.43       100
           1       0.20      0.31      0.25       100
           2       0.57      0.48      0.52        86
           3       0.29      0.46      0.36       100
           4       0.00      0.00      0.00        50
           5       0.11      0.01      0.02        86
           6       0.37      0.37      0.37       100
           7       0.47      0.40      0.43        86
           8       0.26      0.36      0.30       100
           9       0.26      0.11      0.15       100

    accuracy                           0.32       908
   macro avg       0.29      0.31      0.28       908
weighted avg       0.30      0.32      0.30       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -0.1 and depth is 10
              precision    recall  f1-score   support

           0       0.36      0.58      0.45       100
           1       0.21      0.38      0.27       100
           2       0.52      0.44      0.48        86
           3       0.29      0.49      0.37       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.41      0.38      0.40       100
           7       0.47      0.37      0.42        86
           8       0.26      0.36      0.30       100
           9       0.33      0.08      0.13       100

    accuracy                           0.33       908
   macro avg       0.29      0.31      0.28       908
weighted avg       0.30      0.33      0.29       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -0.1 and depth is 15
              precision    recall  f1-score   support

           0       0.36      0.59      0.45       100
           1       0.23      0.37      0.28       100
           2       0.60      0.47      0.52        86
           3       0.30      0.48      0.37       100
           4       0.00      0.00      0.00        50
           5       0.50      0.03      0.07        86
           6       0.40      0.40      0.40       100
           7       0.49      0.43      0.46        86
           8       0.26      0.35      0.30       100
           9       0.24      0.10      0.14       100

    accuracy                           0.34       908
   macro avg       0.34      0.32      0.30       908
weighted avg       0.35      0.34      0.31       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -0.1 and depth is 20
              precision    recall  f1-score   support

           0       0.33      0.56      0.41       100
           1       0.22      0.34      0.26       100
           2       0.49      0.44      0.47        86
           3       0.29      0.44      0.35       100
           4       0.00      0.00      0.00        50
           5       0.17      0.01      0.02        86
           6       0.39      0.39      0.39       100
           7       0.46      0.37      0.41        86
           8       0.27      0.37      0.31       100
           9       0.32      0.12      0.18       100

    accuracy                           0.32       908
   macro avg       0.29      0.30      0.28       908
weighted avg       0.31      0.32      0.29       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -0.1 and depth is 25
              precision    recall  f1-score   support

           0       0.35      0.57      0.44       100
           1       0.19      0.29      0.23       100
           2       0.54      0.45      0.49        86
           3       0.30      0.48      0.37       100
           4       0.00      0.00      0.00        50
           5       0.17      0.01      0.02        86
           6       0.38      0.37      0.37       100
           7       0.48      0.41      0.44        86
           8       0.26      0.37      0.31       100
           9       0.26      0.11      0.15       100

    accuracy                           0.32       908
   macro avg       0.29      0.31      0.28       908
weighted avg       0.30      0.32      0.30       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -1.0 and depth is 10
              precision    recall  f1-score   support

           0       0.36      0.58      0.44       100
           1       0.20      0.37      0.26       100
           2       0.51      0.44      0.48        86
           3       0.29      0.49      0.37       100
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        86
           6       0.41      0.38      0.40       100
           7       0.48      0.37      0.42        86
           8       0.25      0.35      0.29       100
           9       0.32      0.07      0.11       100

    accuracy                           0.32       908
   macro avg       0.28      0.31      0.28       908
weighted avg       0.30      0.32      0.29       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -1.0 and depth is 15
              precision    recall  f1-score   support

           0       0.36      0.60      0.45       100
           1       0.22      0.34      0.27       100
           2       0.58      0.44      0.50        86
           3       0.30      0.47      0.37       100
           4       0.00      0.00      0.00        50
           5       0.40      0.02      0.04        86
           6       0.40      0.42      0.41       100
           7       0.49      0.45      0.47        86
           8       0.27      0.37      0.31       100
           9       0.24      0.09      0.13       100

    accuracy                           0.34       908
   macro avg       0.33      0.32      0.30       908
weighted avg       0.34      0.34      0.31       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -1.0 and depth is 20
              precision    recall  f1-score   support

           0       0.33      0.57      0.42       100
           1       0.23      0.33      0.27       100
           2       0.48      0.44      0.46        86
           3       0.29      0.44      0.35       100
           4       0.00      0.00      0.00        50
           5       0.14      0.01      0.02        86
           6       0.39      0.39      0.39       100
           7       0.47      0.37      0.42        86
           8       0.27      0.40      0.33       100
           9       0.34      0.14      0.20       100

    accuracy                           0.33       908
   macro avg       0.30      0.31      0.29       908
weighted avg       0.31      0.33      0.30       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epsilon is -1.0 and depth is 25
              precision    recall  f1-score   support

           0       0.35      0.57      0.43       100
           1       0.20      0.30      0.24       100
           2       0.57      0.48      0.52        86
           3       0.29      0.46      0.36       100
           4       0.00      0.00      0.00        50
           5       0.11      0.01      0.02        86
           6       0.37      0.37      0.37       100
           7       0.48      0.41      0.44        86
           8       0.26      0.37      0.30       100
           9       0.26      0.11      0.15       100

    accuracy                           0.32       908
   macro avg       0.29      0.31      0.28       908
weighted avg       0.30      0.32      0.30       908

****************************************


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [27]:
import scipy
import numpy as np
import matplotlib.pylab as pl
from mpl_toolkits.mplot3d import Axes3D  # noqa
from scipy import linalg 
from typing import List, Dict, Tuple, Callable
import random
from scipy.sparse import csr_matrix
from scipy.special import softmax
import time
from tqdm.notebook import trange

def random_projection_creator(num_random_features: int, 
                              dim: int, 
                              scaling=0, 
                              ortho=True) -> np.ndarray:
    seed = 0
    np.random.seed(seed)
    if not ortho:
        final_matrix = np.random.normal(size=(num_random_features, dim))
    else:
        nb_full_blocks = int(num_random_features / dim)
        block_list = []
        for _ in range(nb_full_blocks):
            unstructured_block = np.random.normal(size=(dim, dim))
            q, _ = np.linalg.qr(unstructured_block)
            q = np.transpose(q)
            block_list.append(q)
        remaining_rows = num_random_features - nb_full_blocks * dim
        if remaining_rows > 0:
            unstructured_block = np.random.normal(size=(dim, dim))
            q, _ = np.linalg.qr(unstructured_block)
            q = np.transpose(q)
            block_list.append(q[0:remaining_rows])
        final_matrix = np.vstack(block_list)

    if scaling == 0:
        multiplier = np.linalg.norm(np.random.normal(size=(num_random_features, dim)), axis=1)
    elif scaling == 1:
        multiplier = np.sqrt(float(dim)) * np.ones((num_random_features))
    else:
        raise ValueError('Scaling must be one of {0, 1}. Was %s' % scaling)

    return np.matmul(np.diag(multiplier), final_matrix)


def fourier_transform(input_projection: np.ndarray, epsilon: float, norm_type='L1'):
    if norm_type == 'L1':
        return np.prod(np.sin(2.0 * epsilon* input_projection) / input_projection)
    elif norm_type == 'L2':
        pass
    else:
        pass 

def density_function(input_projection: np.ndarray) -> np.ndarray:
    """
    density function of the probabilistic distribution applied by the below 
    random_projection_creator to construct projections
    
    both the input and output are 1d numpy arrays
    """
    dim = len(input_projection)
    length = linalg.norm(input_projection)
    return (1.0 / np.power(2.0 * np.pi, dim / 2.0)) * np.exp(-length**2 / 2.0)


def construct_random_features(positions: np.ndarray, 
                              random_projection_creator: Callable, 
                              density_function: Callable, 
                              num_rand_features: int, 
                              fourier_transform: Callable, 
                              epsilon: float):
    """
    this function is used by graph diffusion GFIntegrator
    
    parameter definitions are the same as in graph_diffusion_gf_integrator
    """
    dim = len(positions[0])
    projection_matrix = random_projection_creator(num_rand_features, dim)
    projected_positions = np.einsum('md,nd->nm', projection_matrix, positions)
    exp_projected_positions = np.exp(2.0 * np.pi * 1j * projected_positions)
    ft_with_eps = lambda x: fourier_transform(x, epsilon)
    fts = np.apply_along_axis(ft_with_eps, 1, projection_matrix)
    dens = np.apply_along_axis(density_function, 1, projection_matrix)
    renormalizers = fts / dens
    rfs = np.einsum('nm,m->nm',exp_projected_positions, renormalizers)
    return (1.0 / np.sqrt(num_rand_features)) * rfs

class DFGFIntegrator(object):
    
    def __init__(self, positions: np.ndarray, 
                 epsilon: float, 
                 lambda_par: float, 
                 num_rand_features: int, 
                 dim: int,
                 random_projection_creator: Callable, 
                 density_function: Callable, 
                 fourier_transform: Callable):

        self._positions = positions
        self._a_matrix = construct_random_features(positions, 
                                                   random_projection_creator, 
                                                   density_function,
                                                   num_rand_features,
                                                   fourier_transform,
                                                   epsilon)          
        self._bt_matrix = construct_random_features(-positions, 
                                                    random_projection_creator,
                                                    density_function, 
                                                    num_rand_features,
                                                    fourier_transform,
                                                    epsilon)
        self._bt_matrix = np.transpose(self._bt_matrix)
        bta = np.matmul(self._bt_matrix, self._a_matrix)
        self._invbta_matrix = linalg.inv(bta)     
        self._expbta_matrix = linalg.expm(lambda_par * bta)
        self._expbta_matrix -= np.identity(num_rand_features)  
         
    def integrate_graph_field(self, field: np.ndarray):
        res = np.einsum('mn,n...->m...',self._bt_matrix, field)
        res = np.einsum('am,m...->a...', self._invbta_matrix, res)
        res = np.einsum('am,m...->a...', self._expbta_matrix, res)
        res = np.einsum('nm,m...->n...', self._a_matrix, res)
        return np.real(field + res)

In [28]:
def create_rfd_feats(points, choose_from='all', k_vals=32, epsilon=.1, lambda_par=-.1):
    left_eigs = []
    right_eigs = []
    print(points.shape)
    for i in trange(points.shape[0]):
      integrator = DFGFIntegrator(
                points[i],
                epsilon,
                lambda_par,
                16,
                3,
                random_projection_creator,
                density_function,
                fourier_transform,
            ) 
      eigvals = np.real(np.linalg.eigvals(integrator.integrate_graph_field(np.eye(points[i].shape[0]))))
      if choose_from =='left' :
        left_eigs.append(np.sort(eigvals)[: k_vals]) #k_vals means a subset of the number of eigenvalues of the kernel matrix
      elif choose_from == 'right' :
        right_eigs.append(np.sort(eigvals)[-k_vals :])
      else : 
        a = np.sort(eigvals)
        left_eigs.append(a[: k_vals]), right_eigs.append(a[-k_vals :])

    return left_eigs, right_eigs



In [29]:
test_feats = create_rfd_feats(test_points)

(908, 2048, 3)


  0%|          | 0/908 [00:00<?, ?it/s]

In [30]:
train_feats = create_rfd_feats(train_points)

(3991, 2048, 3)


  0%|          | 0/3991 [00:00<?, ?it/s]

In [31]:
deeps = [10,15,20,25]
for d in deeps :
    clf = RandomForestClassifier(max_depth=d, random_state=0)
    clf.fit(np.array(train_feats[0]), train_labels)
    preds = clf.predict(test_feats[0])
    print(f"depth is {d}")
    print(classification_report(test_labels, preds))
    print("*"*40)

depth is 10
              precision    recall  f1-score   support

           0       0.69      0.90      0.78       100
           1       0.75      0.86      0.80       100
           2       0.73      0.64      0.68        86
           3       0.54      0.91      0.68       100
           4       1.00      0.20      0.33        50
           5       0.41      0.31      0.36        86
           6       0.81      0.71      0.76       100
           7       0.64      0.65      0.65        86
           8       0.75      0.76      0.75       100
           9       0.52      0.35      0.42       100

    accuracy                           0.66       908
   macro avg       0.68      0.63      0.62       908
weighted avg       0.67      0.66      0.64       908

****************************************
depth is 15
              precision    recall  f1-score   support

           0       0.71      0.89      0.79       100
           1       0.81      0.84      0.82       100
           2 