# In [None]:
def pagerank_reduce_func(nodes, damp=None, n=None):
    """PageRank reduce step for DGL message passing.

    Sums the ``'pagerank_pv'`` messages each destination node received and
    applies the damped update ``(1 - damp) / n + damp * sum(messages)``.

    Parameters
    ----------
    nodes : DGL NodeBatch
        Batch of destination nodes. ``nodes.mailbox['pagerank_pv']`` holds the
        incoming messages; they are summed over dim 1 (DGL's per-node message
        axis).
    damp : float, optional
        Damping factor. If omitted, the module-level ``DAMP`` is used, which
        preserves the original behaviour.
    n : int, optional
        Node count used in the uniform ``(1 - damp) / n`` term. If omitted, the
        module-level ``N`` is used.

    Returns
    -------
    dict
        ``{'pagerank_pv': tensor}`` with the updated scores.
    """
    # The original read DAMP / N only as globals, which cannot see values that
    # are local to a caller such as pagerank(); allow them to be passed in
    # explicitly while keeping the global fallback for existing callers.
    if damp is None:
        damp = DAMP
    if n is None:
        n = N
    msgs = torch.sum(nodes.mailbox['pagerank_pv'], dim=1)
    pv = (1 - damp) / n + damp * msgs
    return {'pagerank_pv': pv}

# In [None]:
def pagerank_message_func(edges):
    """PageRank message step for DGL message passing.

    Every edge carries its source node's current score (``'pagerank_pv'``)
    divided by that source's stored out-degree (``'pagerank_deg'``), so a
    node's score is split evenly across its outgoing edges.

    Parameters
    ----------
    edges : DGL EdgeBatch
        Batch of edges whose ``src`` data holds ``'pagerank_pv'`` and
        ``'pagerank_deg'``.

    Returns
    -------
    dict
        ``{'pagerank_pv': tensor}`` with one message per edge.
    """
    source = edges.src
    share = source['pagerank_pv'] / source['pagerank_deg']
    return {'pagerank_pv': share}

# In [None]:
def pagerank(g, user_label='user', product_label='product', edge_label='purchase',
             rev_edge_label='review', DAMP=0.85, reverse=False):
    """Run a PageRank propagation step on a bipartite DGL heterograph.

    Every user and product node starts with a uniform score of ``1 / N``,
    where ``N`` is the total number of nodes in ``g``. Each node's out-degree
    is stored alongside the score. Scores are then propagated once from users
    to products along ``edge_label``. When ``reverse`` is True, a second step
    propagates from products to users along ``rev_edge_label``. Each updated
    node receives::

        (1 - DAMP) / N + DAMP * sum(source_pv / source_out_degree)

    Only a single step is performed per direction; the scores are not iterated
    to convergence.

    Parameters
    ----------
    g : DGL heterograph
        Graph to update in place. Expected to contain the two node types named
        by ``user_label`` and ``product_label``.
    user_label : str, optional
        Name of the user node type.
    product_label : str, optional
        Name of the product node type.
    edge_label : str, optional
        Name of the user -> product edge type.
    rev_edge_label : str, optional
        Name of the product -> user edge type.
    DAMP : float, optional
        Damping factor. This is the weight given to scores propagated from
        neighbours; the remaining ``1 - DAMP`` share is spread uniformly as
        ``(1 - DAMP) / N`` over every updated node.
    reverse : bool, optional
        If True, also propagate from products back to users after the
        user -> product step.

    Returns
    -------
    DGL heterograph
        The same ``g``, with ``'pagerank_pv'`` (score) and ``'pagerank_deg'``
        (out-degree) stored in its node data.
    """
    N = g.num_nodes()  # total across all node types
    N_user = g.num_nodes(user_label)
    N_product = g.num_nodes(product_label)

    def _reduce(nodes):
        # Closure over this call's DAMP and N. A module-level reducer reading
        # globals would not see these arguments.
        msgs = torch.sum(nodes.mailbox['pagerank_pv'], dim=1)
        return {'pagerank_pv': (1 - DAMP) / N + DAMP * msgs}

    g.nodes[user_label].data['pagerank_pv'] = torch.ones(N_user) / N
    g.nodes[product_label].data['pagerank_pv'] = torch.ones(N_product) / N
    # Out-degree along the edge type each node sends messages on; the message
    # function divides the source score by this value.
    g.nodes[user_label].data['pagerank_deg'] = g.out_degrees(
        g.nodes(user_label), etype=edge_label).float()
    g.nodes[product_label].data['pagerank_deg'] = g.out_degrees(
        g.nodes(product_label), etype=rev_edge_label).float()

    g.multi_update_all({edge_label: (pagerank_message_func, _reduce)}, "sum")

    if reverse:
        g.multi_update_all({rev_edge_label: (pagerank_message_func, _reduce)}, "sum")

    return g
    