# Triton Meets ArangoDB: Amazon Product Recommendation Application using GraphSage Algorithm

In this notebook, we will build an Amazon Product Recommendation application by leveraging three technologies at a time i.e. Graph Machine Learning, Nvidia's Triton inference server and ArangoDB. 

In [1]:
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch_geometric.data import NeighborSampler
from torch_geometric.nn import SAGEConv
import os.path as osp
import pandas as pd
import numpy as np
import collections
from pandas.core.common import flatten
# importing obg datatset
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
from pandas.core.common import flatten
import seaborn as sns
import matplotlib.pyplot as plt
sns.set(rc={'figure.figsize':(14.7,4.27)})
sns.set_theme(style="ticks")
import collections
from scipy.special import softmax
import umap
from arango import ArangoClient
import oasis
import pprint
import tritongrpcclient
import tritongrpcclient.model_config_pb2 as mc
import tritonhttpclient
from tritonclientutils import triton_to_np_dtype
from tritonclientutils import InferenceServerException
from scipy.special import softmax
import mplcursors



In [2]:
root = "/home/sachin/Desktop/arangoml/datasets"


In [3]:
dataset = PygNodePropPredDataset('ogbn-products', root)

In [4]:
# getting train val test split idx based on sales ranking
split_idx = dataset.get_idx_split()
evaluator = Evaluator(name='ogbn-products')
data = dataset[0]

In [5]:
# storing train, validation and test node indices
train_idx = split_idx['train']
valid_idx = split_idx['valid']
test_idx = split_idx['test']


In [6]:
# test node indexes
test_idx

tensor([ 235938,  235939,  235940,  ..., 2449026, 2449027, 2449028])

In [7]:
# neighborhood sampling of test nodes
test_loader = NeighborSampler(data.edge_index, node_idx=test_idx,
                              sizes=[15, 10, 5], batch_size=1,
                              shuffle=False, num_workers=12)

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
# node feature matrix
x = data.x.to(device)

In [10]:
# labels
y = data.y.to(device)

In [11]:
# creating adjs for performing a trace on the GraphSage model
# will contain only edge_idx and size attributes

max_nodes = 1000
def create_triton_input(dummy_n_ids, dummy_adjs):
    edge_list_0 = []
    edge_list_1 = []
    edge_list_2 = []
    edge_adjs = []
    for idx, e_idx in enumerate(dummy_adjs[0]):
        if idx == 0:
            edge_list_0.append(e_idx[0])
            #edge_list_0.append(e_idx[1])
            edge_list_0.append(torch.tensor(np.asarray(e_idx[2])))
        elif idx == 1:
            edge_list_1.append(e_idx[0])
            #edge_list_1.append(e_idx[1])
            edge_list_1.append(torch.tensor(np.asarray(e_idx[2])))
        else:
            edge_list_2.append(e_idx[0])
            #edge_list_2.append(e_idx[1])
            edge_list_2.append(torch.tensor(np.asarray(e_idx[2])))
    
    # creating edge indexes
    edge_index_0 = edge_list_0[0]
    edge_index_0 = edge_index_0.to(device)
    edge_size_0 = edge_list_0[1]
    edge_size_0 = edge_size_0.to(device)

    edge_index_1 = edge_list_1[0]
    edge_index_1 = edge_index_1.to(device)
    edge_size_1 = edge_list_1[1]
    edge_size_1 = edge_size_1.to(device)

    edge_index_2 = edge_list_2[0]
    edge_index_2 = edge_index_2.to(device)
    edge_size_2 = edge_list_2[1]
    edge_size_2 = edge_size_2.to(device)
    
    
    # add padding to node feature matrix
    dummy_x = x[dummy_n_ids[0]]
    total_nodes = dummy_x.size(0)
    nodes_padded = max_nodes - total_nodes
    dummy_x_pad = F.pad(input=dummy_x, pad=(0, 0, 0, nodes_padded), mode='constant', value=0)
    dummy_x_pad = dummy_x_pad.to(device)
    
    return dummy_x_pad, edge_index_0, edge_size_0, edge_index_1, edge_size_1, edge_index_2, edge_size_2

In [12]:
# setting inputs and output names as same in config file
input_name = ['input__0', 'input__1', 'input__2', 'input__3', 'input__4', 'input__5', 'input__6']
output_name = ['output__0', 'output__1', 'output__2']
VERBOSE = False
from tritonclient.utils import *

## Client-Side Script to Interact with Triton Inference Server

The run_inference function computes the node embeddings of a given node at three different layers of trained GraphSage model and returns the same. This function requires 7 inputs:

{ node_matrix: Padded node feature matrix consiting of nodes involved in the computation      graph

  edge_index_0: adjacency list for all the edges involved at the Hop-3 (layer-3)
  
  edge_size_0 : shape of the bipartite graph at Hop-3

  edge_index_1: adjacency list for all the edges involved at the Hop-2 (layer-2)
  
  edge_size_1 : shape of the bipartite graph at Hop-2
  
  edge_index_2: adjacency list for all the edges involved at the Hop-1 (layer-1)
  
  edge_size_2 : shape of the bipartite graph at Hop-1
}

Note: Neighborhood sampler returns ajacency list in reversed order

In [13]:
def run_inference(node_matrix, edge_index_0, edge_size_0, edge_index_1, edge_size_1, edge_index_2, edge_size_2, model_name='graph_embeddings', url='127.0.0.1:8000', model_version='1'):
    triton_client = tritonhttpclient.InferenceServerClient(
        url=url, verbose=VERBOSE)
    model_metadata = triton_client.get_model_metadata(
        model_name=model_name, model_version=model_version)
    model_config = triton_client.get_model_config(
        model_name=model_name, model_version=model_version)
    # I have restricted the input sequence length to 256

    input_node_matrix = node_matrix
    input_node_matrix  = np.array(input_node_matrix.cpu(), dtype=np.float32)
    
    # edges_indexes and sizes
    ed_index_0 = np.array(edge_index_0.cpu(), dtype=np.int64)
    ed_index_1 = np.array(edge_index_1.cpu(), dtype=np.int64)
    ed_index_2 = np.array(edge_index_2.cpu(), dtype=np.int64)
    
    ed_size_0 = np.array(edge_size_0.cpu(), dtype=np.int64)
    ed_size_1 = np.array(edge_size_1.cpu(), dtype=np.int64)
    ed_size_2 = np.array(edge_size_2.cpu(), dtype=np.int64)
    


    input0 = tritonhttpclient.InferInput(input_name[0], (1000,  100), 'FP32')
    input0.set_data_from_numpy(input_node_matrix, binary_data=False)
    
    #layer-1
    input1 = tritonhttpclient.InferInput(input_name[1], ed_index_0.shape, 'INT64')
    input1.set_data_from_numpy(ed_index_0, binary_data=False)
    #size
    input2 = tritonhttpclient.InferInput(input_name[2], (2,), 'INT64')
    input2.set_data_from_numpy(ed_size_0, binary_data=False)
    
    #layer-2
    input3 = tritonhttpclient.InferInput(input_name[3], ed_index_1.shape, 'INT64')
    input3.set_data_from_numpy(ed_index_1, binary_data=False)
    #size
    input4 = tritonhttpclient.InferInput(input_name[4], (2,), 'INT64')
    input4.set_data_from_numpy(ed_size_1, binary_data=False)
    
    #layer-3
    input5 = tritonhttpclient.InferInput(input_name[5], ed_index_2.shape, 'INT64')
    input5.set_data_from_numpy(ed_index_2, binary_data=False)
    #size
    input6 = tritonhttpclient.InferInput(input_name[6], (2,), 'INT64')
    input6.set_data_from_numpy(ed_size_2, binary_data=False)
    
    output0 = tritonhttpclient.InferRequestedOutput(output_name[0],  binary_data=False)
    output1 = tritonhttpclient.InferRequestedOutput(output_name[1],  binary_data=False)
    output2 = tritonhttpclient.InferRequestedOutput(output_name[2],  binary_data=False)
    
    response = triton_client.infer(model_name, model_version=model_version, inputs=[input0, input1, input2, input3, input4, input5, input6], outputs=[output0, output1, output2])
    # layer-1 embeddings
    embeddings_layer_1 = response.as_numpy('output__0')
    # layer-2 embeddings
    embeddings_layer_2 = response.as_numpy('output__1')
    # # layer-3 embeddings
    embeddings_layer_3 = response.as_numpy('output__2')
    return embeddings_layer_1, embeddings_layer_2, embeddings_layer_3

In [14]:
# load integer to real product category label mapping
df = pd.read_csv('/home/sachin/Desktop/arangoml/datasets/ogbn_products/mapping/labelidx2productcategory.csv.gz')

In [15]:
label_idx, prod_cat = df.iloc[: ,0].values, df.iloc[: ,1].values
label_mapping = dict(zip(label_idx, prod_cat))

For the demonstration purpose we will use first 5000 test nodes for the Inference


In [16]:
# selecting test nodes and its adjacency matrix

layer_3_embs = []
layer_2_embs = []
for idx, (batch_size, n_id, adjs) in enumerate(test_loader):
        print("idx:", idx)
        dummy_n_ids = []
        dummy_adjs = []
        dummy_n_ids.append(n_id)
        dummy_adjs.append(adjs)
        
        if len(dummy_n_ids[0]) == 1:
            print("Found Disconnected Node in the graph at index:", idx)
            layer_3_embs.append("Disconnected Node")
        elif idx == 5000:
            break
        else:
            # creating triton input
            dummy_x_pad, edge_index_0, edge_size_0, edge_index_1, edge_size_1, edge_index_2, edge_size_2 = create_triton_input(dummy_n_ids, dummy_adjs)
            # generating node embeddings for test node from Triton Server
            emb1, emb2, emb3 = run_inference(dummy_x_pad, edge_index_0, edge_size_0, edge_index_1, edge_size_1, edge_index_2, edge_size_2)
            layer_3_embs.append(emb3[0])

        

idx: 0
idx: 1
idx: 2
idx: 3
idx: 4
idx: 5
idx: 6
idx: 7
idx: 8
idx: 9
idx: 10
idx: 11
idx: 12
idx: 13
idx: 14
idx: 15
idx: 16
idx: 17
idx: 18
idx: 19
idx: 20
idx: 21
idx: 22
idx: 23
idx: 24
idx: 25
idx: 26
idx: 27
idx: 28
idx: 29
idx: 30
idx: 31
idx: 32
idx: 33
idx: 34
idx: 35
Found Disconnected Node in the graph at index: 35
idx: 36
idx: 37
idx: 38
idx: 39
idx: 40
idx: 41
idx: 42
idx: 43
idx: 44
idx: 45
idx: 46
idx: 47
idx: 48
idx: 49
idx: 50
idx: 51
idx: 52
idx: 53
idx: 54
idx: 55
idx: 56
idx: 57
idx: 58
idx: 59
idx: 60
idx: 61
idx: 62
idx: 63
idx: 64
idx: 65
idx: 66
idx: 67
idx: 68
idx: 69
idx: 70
idx: 71
idx: 72
idx: 73
idx: 74
idx: 75
idx: 76
idx: 77
idx: 78
idx: 79
idx: 80
idx: 81
idx: 82
idx: 83
idx: 84
idx: 85
idx: 86
idx: 87
idx: 88
idx: 89
idx: 90
idx: 91
idx: 92
idx: 93
idx: 94
idx: 95
idx: 96
idx: 97
idx: 98
idx: 99
idx: 100
idx: 101
idx: 102
idx: 103
idx: 104
idx: 105
idx: 106
idx: 107
idx: 108
idx: 109
idx: 110
idx: 111
idx: 112
idx: 113
idx: 114
idx: 115
idx: 116
idx: 11

idx: 827
idx: 828
idx: 829
idx: 830
Found Disconnected Node in the graph at index: 830
idx: 831
idx: 832
idx: 833
idx: 834
idx: 835
idx: 836
idx: 837
idx: 838
idx: 839
idx: 840
idx: 841
idx: 842
idx: 843
idx: 844
idx: 845
idx: 846
idx: 847
Found Disconnected Node in the graph at index: 847
idx: 848
idx: 849
idx: 850
idx: 851
idx: 852
idx: 853
idx: 854
idx: 855
idx: 856
idx: 857
idx: 858
idx: 859
idx: 860
idx: 861
idx: 862
Found Disconnected Node in the graph at index: 862
idx: 863
Found Disconnected Node in the graph at index: 863
idx: 864
idx: 865
idx: 866
idx: 867
idx: 868
idx: 869
idx: 870
idx: 871
idx: 872
idx: 873
idx: 874
idx: 875
idx: 876
idx: 877
idx: 878
idx: 879
idx: 880
idx: 881
idx: 882
idx: 883
idx: 884
idx: 885
idx: 886
idx: 887
idx: 888
idx: 889
idx: 890
idx: 891
idx: 892
idx: 893
idx: 894
idx: 895
idx: 896
idx: 897
idx: 898
idx: 899
idx: 900
idx: 901
idx: 902
idx: 903
idx: 904
idx: 905
idx: 906
idx: 907
idx: 908
idx: 909
idx: 910
idx: 911
idx: 912
idx: 913
idx: 914
idx:

idx: 1592
idx: 1593
idx: 1594
idx: 1595
Found Disconnected Node in the graph at index: 1595
idx: 1596
idx: 1597
idx: 1598
idx: 1599
idx: 1600
idx: 1601
idx: 1602
idx: 1603
idx: 1604
idx: 1605
idx: 1606
idx: 1607
idx: 1608
idx: 1609
idx: 1610
idx: 1611
idx: 1612
idx: 1613
idx: 1614
idx: 1615
idx: 1616
idx: 1617
idx: 1618
idx: 1619
idx: 1620
idx: 1621
idx: 1622
idx: 1623
idx: 1624
idx: 1625
idx: 1626
idx: 1627
idx: 1628
idx: 1629
idx: 1630
idx: 1631
idx: 1632
idx: 1633
idx: 1634
idx: 1635
idx: 1636
idx: 1637
idx: 1638
idx: 1639
idx: 1640
idx: 1641
idx: 1642
idx: 1643
idx: 1644
idx: 1645
idx: 1646
idx: 1647
idx: 1648
idx: 1649
idx: 1650
idx: 1651
idx: 1652
idx: 1653
idx: 1654
idx: 1655
idx: 1656
idx: 1657
idx: 1658
idx: 1659
idx: 1660
idx: 1661
idx: 1662
idx: 1663
idx: 1664
idx: 1665
idx: 1666
idx: 1667
idx: 1668
idx: 1669
idx: 1670
Found Disconnected Node in the graph at index: 1670
idx: 1671
idx: 1672
idx: 1673
idx: 1674
idx: 1675
idx: 1676
idx: 1677
idx: 1678
idx: 1679
idx: 1680
idx: 1

idx: 2338
Found Disconnected Node in the graph at index: 2338
idx: 2339
idx: 2340
idx: 2341
idx: 2342
idx: 2343
idx: 2344
idx: 2345
idx: 2346
idx: 2347
idx: 2348
idx: 2349
idx: 2350
idx: 2351
idx: 2352
idx: 2353
idx: 2354
idx: 2355
idx: 2356
idx: 2357
idx: 2358
idx: 2359
idx: 2360
idx: 2361
idx: 2362
idx: 2363
idx: 2364
idx: 2365
idx: 2366
idx: 2367
idx: 2368
idx: 2369
idx: 2370
idx: 2371
idx: 2372
idx: 2373
idx: 2374
idx: 2375
idx: 2376
Found Disconnected Node in the graph at index: 2376
idx: 2377
idx: 2378
idx: 2379
idx: 2380
idx: 2381
idx: 2382
idx: 2383
idx: 2384
idx: 2385
idx: 2386
idx: 2387
idx: 2388
idx: 2389
idx: 2390
idx: 2391
idx: 2392
idx: 2393
idx: 2394
idx: 2395
idx: 2396
idx: 2397
idx: 2398
idx: 2399
idx: 2400
idx: 2401
idx: 2402
idx: 2403
idx: 2404
idx: 2405
idx: 2406
idx: 2407
idx: 2408
idx: 2409
idx: 2410
idx: 2411
idx: 2412
idx: 2413
idx: 2414
idx: 2415
idx: 2416
idx: 2417
idx: 2418
idx: 2419
idx: 2420
idx: 2421
idx: 2422
idx: 2423
idx: 2424
idx: 2425
idx: 2426
idx: 2

idx: 3121
idx: 3122
idx: 3123
Found Disconnected Node in the graph at index: 3123
idx: 3124
idx: 3125
idx: 3126
idx: 3127
idx: 3128
idx: 3129
idx: 3130
idx: 3131
idx: 3132
idx: 3133
idx: 3134
idx: 3135
idx: 3136
idx: 3137
idx: 3138
idx: 3139
idx: 3140
idx: 3141
idx: 3142
idx: 3143
idx: 3144
idx: 3145
idx: 3146
idx: 3147
idx: 3148
idx: 3149
idx: 3150
idx: 3151
idx: 3152
idx: 3153
idx: 3154
idx: 3155
idx: 3156
idx: 3157
idx: 3158
idx: 3159
idx: 3160
idx: 3161
idx: 3162
idx: 3163
idx: 3164
idx: 3165
idx: 3166
idx: 3167
idx: 3168
idx: 3169
idx: 3170
idx: 3171
idx: 3172
idx: 3173
idx: 3174
idx: 3175
idx: 3176
idx: 3177
idx: 3178
idx: 3179
Found Disconnected Node in the graph at index: 3179
idx: 3180
idx: 3181
idx: 3182
idx: 3183
idx: 3184
idx: 3185
idx: 3186
idx: 3187
idx: 3188
idx: 3189
idx: 3190
idx: 3191
idx: 3192
idx: 3193
idx: 3194
idx: 3195
idx: 3196
idx: 3197
idx: 3198
idx: 3199
idx: 3200
idx: 3201
idx: 3202
idx: 3203
idx: 3204
idx: 3205
idx: 3206
idx: 3207
idx: 3208
idx: 3209
idx: 3

idx: 3864
idx: 3865
idx: 3866
idx: 3867
idx: 3868
idx: 3869
idx: 3870
idx: 3871
idx: 3872
idx: 3873
idx: 3874
idx: 3875
idx: 3876
idx: 3877
idx: 3878
idx: 3879
idx: 3880
idx: 3881
idx: 3882
idx: 3883
idx: 3884
idx: 3885
idx: 3886
idx: 3887
idx: 3888
idx: 3889
idx: 3890
idx: 3891
idx: 3892
idx: 3893
idx: 3894
idx: 3895
idx: 3896
idx: 3897
idx: 3898
idx: 3899
idx: 3900
idx: 3901
idx: 3902
idx: 3903
idx: 3904
idx: 3905
idx: 3906
idx: 3907
idx: 3908
idx: 3909
idx: 3910
idx: 3911
idx: 3912
idx: 3913
idx: 3914
Found Disconnected Node in the graph at index: 3914
idx: 3915
idx: 3916
idx: 3917
idx: 3918
idx: 3919
idx: 3920
idx: 3921
idx: 3922
idx: 3923
Found Disconnected Node in the graph at index: 3923
idx: 3924
idx: 3925
idx: 3926
idx: 3927
idx: 3928
idx: 3929
idx: 3930
idx: 3931
idx: 3932
idx: 3933
idx: 3934
idx: 3935
idx: 3936
idx: 3937
idx: 3938
idx: 3939
idx: 3940
idx: 3941
idx: 3942
idx: 3943
idx: 3944
idx: 3945
idx: 3946
idx: 3947
idx: 3948
Found Disconnected Node in the graph at index:

idx: 4602
idx: 4603
idx: 4604
idx: 4605
idx: 4606
idx: 4607
idx: 4608
idx: 4609
Found Disconnected Node in the graph at index: 4609
idx: 4610
idx: 4611
idx: 4612
idx: 4613
idx: 4614
idx: 4615
idx: 4616
idx: 4617
idx: 4618
idx: 4619
idx: 4620
idx: 4621
idx: 4622
idx: 4623
idx: 4624
idx: 4625
idx: 4626
idx: 4627
idx: 4628
idx: 4629
idx: 4630
idx: 4631
idx: 4632
idx: 4633
idx: 4634
idx: 4635
idx: 4636
idx: 4637
idx: 4638
idx: 4639
idx: 4640
idx: 4641
idx: 4642
idx: 4643
idx: 4644
idx: 4645
idx: 4646
idx: 4647
idx: 4648
idx: 4649
idx: 4650
idx: 4651
idx: 4652
idx: 4653
idx: 4654
idx: 4655
idx: 4656
idx: 4657
idx: 4658
idx: 4659
idx: 4660
idx: 4661
idx: 4662
idx: 4663
idx: 4664
idx: 4665
idx: 4666
idx: 4667
idx: 4668
idx: 4669
idx: 4670
idx: 4671
idx: 4672
idx: 4673
idx: 4674
idx: 4675
idx: 4676
idx: 4677
idx: 4678
idx: 4679
idx: 4680
idx: 4681
idx: 4682
idx: 4683
idx: 4684
idx: 4685
idx: 4686
idx: 4687
idx: 4688
idx: 4689
idx: 4690
idx: 4691
idx: 4692
idx: 4693
idx: 4694
idx: 4695
idx: 469

In [61]:
# connectiong to arangodb
# Initialize the ArangoDB client.
client = ArangoClient("http://127.0.0.1:8529")

In [62]:
## Connect to the database
#amazon_db = oasis.connect_python_arango(login)
amazon_db = client.db('_system', username='root', password='amritsar')

In [63]:
test_idx_lb = 235938
test_idx_mb = test_idx_lb + len(layer_3_embs)
test_idx_ub = 2449028

In [50]:
# load dataset
#! ./arangorestore -c none --create-collection true --server.endpoint "tcp://127.0.0.1:8529" --server.username "root" --server.database "_system" --server.password "amritsar" --default-replication-factor 3  --input-directory "./ogbn-product_dataset"

In [67]:
batch = []
BATCH_SIZE = 250
batch_idx = 1
index = 0
# collection in which we will store are inference results
product_collection = amazon_db["Products"]

In the below cell we will update the amazon product recommendation graph (obgn-products) dataset stored inside the ArangoDB with the node embeddings and their corresponding product category predictions for the 5000 test nodes.

In [68]:
for idx in range(test_idx_lb, test_idx_mb):
    update_doc = {}
    product_id = "Products/" + str(idx)
    update_doc["_id"] = product_id
    if layer_3_embs[index] == "Disconnected Node":
        update_doc["predicted_node_embeddings"] = layer_3_embs[index]
        update_doc["predicted_product"] = str(-1)
    else:
        update_doc["predicted_node_embeddings"] = layer_3_embs[index].tolist()
        update_doc["predicted_product"] = str(label_mapping[np.argmax(layer_3_embs[index], axis=-1)])
    batch.append(update_doc)
    last_record = (idx == (test_idx_mb - 1))
    index +=1
    
    if index % BATCH_SIZE == 0:
        print("Inserting batch %d" % (batch_idx))
        batch_idx += 1
        product_collection.update_many(batch)
        batch = []   
    if last_record and len(batch) > 0:
        print("Inserting batch the last batch!")
        product_collection.update_many(batch)


Inserting batch 1
Inserting batch 2
Inserting batch 3
Inserting batch 4
Inserting batch 5
Inserting batch 6
Inserting batch 7
Inserting batch 8
Inserting batch 9
Inserting batch 10
Inserting batch 11
Inserting batch 12
Inserting batch 13
Inserting batch 14
Inserting batch 15
Inserting batch 16
Inserting batch 17
Inserting batch 18
Inserting batch 19
Inserting batch 20


## Amazon Product Recommendation with AQL 
Products which can be bought together with a query product

In [78]:
# product ids for demo 235940, 240930
cursor = amazon_db.aql.execute(
"""
  FOR p in Products
    FILTER p._id == "Products/236435"
    RETURN { "predicted_node_embeddings": p.predicted_node_embeddings, "product_cat": p.product_cat }
""")

# Iterate through the result cursor
for doc in cursor:
  print(doc)

{'predicted_node_embeddings': [8.11427116394043, 1.2199492454528809, -3.4377615451812744, -2.5239882469177246, -3.2495641708374023, 5.2369794845581055, -1.8155913352966309, -8.527416229248047, -4.056517601013184, -1.9635710716247559, 0.19732767343521118, -2.9490034580230713, -1.4104312658309937, -9.393174171447754, -11.7564058303833, -6.5331902503967285, -2.5783517360687256, -4.2324137687683105, 3.3259811401367188, 0.030963152647018433, -4.7573771476745605, 3.299964427947998, -12.81488037109375, -0.5681211948394775, -5.670832633972168, 1.744966745376587, 1.7140165567398071, -10.516529083251953, -7.2565412521362305, -21.127893447875977, -6.552777290344238, -10.345855712890625, -8.504383087158203, -11.93044662475586, -14.149519920349121, -14.586654663085938, -6.9217448234558105, -7.073517799377441, -14.168991088867188, -11.582767486572266, -11.457265853881836, -16.460329055786133, -12.390432357788086, -12.25572395324707, -12.473586082458496, -12.95902156829834, -12.110896110534668], 'pro

In [79]:
cursor = amazon_db.aql.execute(
"""
LET descr_emb = (
  FOR p in Products
    FILTER p._id == "Products/236435"
    FOR j in RANGE(0, 46)
      RETURN TO_NUMBER(NTH(p.predicted_node_embeddings,j))
)

LET descr_mag = (
  SQRT(SUM(
    FOR i IN RANGE(0, 47)
      RETURN POW(TO_NUMBER(NTH(descr_emb, i)), 2)
  ))
)

LET dau = (

    FOR v in Products
    FILTER HAS(v, "predicted_node_embeddings")

    LET v_mag = (SQRT(SUM(
      FOR k IN RANGE(0, 47)
        RETURN POW(TO_NUMBER(NTH(v.predicted_node_embeddings, k)), 2)
    )))

    LET numerator = (SUM(
      FOR i in RANGE(0,46)
          RETURN TO_NUMBER(NTH(descr_emb, i)) * TO_NUMBER(NTH(v.predicted_node_embeddings, i))
    ))

    LET cos_sim = (numerator)/(descr_mag * v_mag)

    RETURN {"product": v._id, "product_cat": v.product_cat, "cos_sim": cos_sim}

    )

FOR du in dau
    SORT du.cos_sim DESC
    LIMIT 5000
    RETURN {"product_cat": du.product_cat, "cos_sim": du.cos_sim} 
""")


In [80]:
# Iterate through the result cursor
for doc in cursor:
    print(doc)

{'product_cat': 'Home & Kitchen', 'cos_sim': 1}
{'product_cat': 'Home & Kitchen', 'cos_sim': 0.9915181195954686}
{'product_cat': 'Home & Kitchen', 'cos_sim': 0.9903672509511374}
{'product_cat': 'Home & Kitchen', 'cos_sim': 0.9898367275336624}
{'product_cat': 'Home & Kitchen', 'cos_sim': 0.9896782150608073}
{'product_cat': 'Home & Kitchen', 'cos_sim': 0.9882337658793278}
{'product_cat': 'Home & Kitchen', 'cos_sim': 0.9880318458111959}
{'product_cat': 'Home & Kitchen', 'cos_sim': 0.9872162913201172}
{'product_cat': 'Home & Kitchen', 'cos_sim': 0.9859483691799504}
{'product_cat': 'Home & Kitchen', 'cos_sim': 0.9858877208046664}
{'product_cat': 'Patio, Lawn & Garden', 'cos_sim': 0.9853313674545409}
{'product_cat': 'Patio, Lawn & Garden', 'cos_sim': 0.9851588622389956}
{'product_cat': 'Home & Kitchen', 'cos_sim': 0.9850942093242058}
{'product_cat': 'Home & Kitchen', 'cos_sim': 0.984943823143214}
{'product_cat': 'Patio, Lawn & Garden', 'cos_sim': 0.9838077985619076}
{'product_cat': 'Home & K

Here we're using the cosine similarity to retrieve the product which can be bought together with a query product. The cosine similarity is calculated between the node embeddings of a query product and all the other 4999 products:
$$
 \frac{
  \sum\limits_{i=1}^{n}{a_i b_i}
  }{
      \sqrt{\sum\limits_{j=1}^{n}{a_j^2}}
      \sqrt{\sum\limits_{k=1}^{n}{b_k^2}}
  }
$$



Once we calculate the cosine similarities, we can then SORT the products and return the highly likely bought together product with a query product!

## Note: 
1)For the demo purpose I am using 5000 test nodes but definitely if we go for more nodes results can be improved.

2) Also there is a lot of room for the improvement of the accuracy of model which can be achieved using hyperparameter tuning for e.g setting different number of search depths or playing with size of hidden layers like we have used 256 in our experiment. Another interesting thing to experiment would be using different neighborhood sampling techniques like random walk. 

Therefore, I can left this as a HomeWork for you !!