In [None]:
"""
Assumptions:
1. The graph is bipartite, with exactly 2 types of nodes (for example, 'user' and 'item').
2. There are no user -> user or item -> item edges, so the graph is strictly bipartite.
3. If there is an edge type (for example, 'rates') from user to item, then there is a matching
   reverse edge type ('rev-rates') from item to user.
4. Assumption 3 is required because the best way to find n-hop neighbors is with in_edges() and
   out_edges(), and both of these functions require the edge type to be specified. When we move from
   the 1-hop neighbors to the 2-hop neighbors, the edge type changes (for example, 1st hop:
   user -> item, 2nd hop: item -> user).
5. The print statements can be removed at a later point. I've left them in to help the rest of the
   team understand the code.
"""
def _successors_by_type(graph, node_id, canonical_etypes):
    """Collect the out-neighbors of ``node_id`` over the given edge types.

    Args:
        graph: heterograph exposing ``out_edges(node, etype=..., form='uv')``.
        node_id: ID of the source node.
        canonical_etypes: iterable of ``(src_ntype, etype, dst_ntype)`` triples
            to follow; every triple is visited, even if an earlier one yields
            no edges.

    Returns:
        A set of ``(dst_ntype, dst_id)`` pairs. The node type is part of the
        key because IDs of different node types are not comparable.
    """
    reached = set()
    for canonical_etype in canonical_etypes:
        # Passing the full triple keeps the lookup unambiguous when two
        # canonical edge types share the same edge-type name.
        _, dst_ids = graph.out_edges(node_id, etype=canonical_etype, form='uv')
        dst_ntype = canonical_etype[2]
        reached.update((dst_ntype, int(dst_id)) for dst_id in dst_ids)
    return reached


def clustering_coefficient(graph, verbose=True):
    """Store a bipartite clustering coefficient on every node of ``graph``.

    The coefficient follows the networkx bipartite definition ("dot" mode):
    for a node ``u`` with neighbor set ``N(u)``,

        c_uv = |N(u) & N(v)| / |N(u) | N(v)|
        c_u  = mean of c_uv over every v in N(N(u)) excluding u itself

    and ``c_u = 0`` when ``u`` has no second-order neighbors (this covers
    isolated nodes).
    https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.bipartite.cluster.clustering.html

    Neighbors are found by following outgoing edges only, so the reverse edge
    types described in the assumptions at the top of the file must exist.

    Args:
        graph: heterograph whose node types each carry a ``'node_ID'`` feature.
            NOTE(review): the ``'node_ID'`` values are passed straight to
            ``out_edges`` as node IDs, so they are assumed to be the graph's
            own node IDs in storage order -- confirm against the graph
            builder.
        verbose: when true (the default), print the per-node-type and per-node
            intermediate results.

    Returns:
        The same ``graph`` object, with a float32 tensor of shape
        ``(num_nodes, 1)`` stored under ``'clustering_coeff'`` for every node
        type.
    """
    canonical_etypes = graph.canonical_etypes

    for ntype in graph.ntypes:
        node_id_tensor = graph.nodes[ntype].data['node_ID']
        node_ids_list = node_id_tensor.tolist()

        outgoing_etypes = [c for c in canonical_etypes if c[0] == ntype]
        # Edge types that lead back into `ntype`, grouped by the node type
        # they start from, so the second hop only follows edges that really
        # originate at the first-hop neighbor's type.
        returning_etypes = {}
        for c in canonical_etypes:
            if c[2] == ntype:
                returning_etypes.setdefault(c[0], []).append(c)

        if verbose:
            print("ntype: ", ntype)
            print("node IDs: ", node_ids_list)
            print("outgoing edge types: ", [c[1] for c in outgoing_etypes])

        # N(v) is needed once for v itself and once for every node that has v
        # as a second-order neighbor, so cache both hops.
        first_hop_cache = {}   # node id of `ntype` -> {(nbr_ntype, nbr_id)}
        back_hop_cache = {}    # (nbr_ntype, nbr_id) -> {node ids of `ntype`}

        def first_hop(nid):
            if nid not in first_hop_cache:
                first_hop_cache[nid] = _successors_by_type(
                    graph, nid, outgoing_etypes)
            return first_hop_cache[nid]

        clus_coeff_list = []
        for node_id in node_ids_list:
            own_neighbors = first_hop(node_id)
            if verbose:
                print("first degree neighbors for node ", node_id, "are ",
                      {nbr_id for _, nbr_id in own_neighbors})

            # Second-order neighbors: nodes of the same type that share at
            # least one first-degree neighbor with `node_id`.
            sec_degree_neighbors = set()
            for nbr_key in own_neighbors:
                if nbr_key not in back_hop_cache:
                    nbr_ntype, nbr_id = nbr_key
                    back_hop_cache[nbr_key] = {
                        back_id for _, back_id in _successors_by_type(
                            graph, nbr_id, returning_etypes.get(nbr_ntype, ()))
                    }
                sec_degree_neighbors |= back_hop_cache[nbr_key]
            # The reverse edges always lead back to the node itself.
            sec_degree_neighbors.discard(node_id)
            if verbose:
                print("second degree neighbors for node ", node_id, "are ",
                      sec_degree_neighbors)

            if sec_degree_neighbors:
                # `own_neighbors` is non-empty here, so every union below is
                # non-empty and the division is safe.
                overlap_sum = 0.0
                for other_id in sec_degree_neighbors:
                    other_neighbors = first_hop(other_id)
                    overlap_sum += (len(own_neighbors & other_neighbors)
                                    / len(own_neighbors | other_neighbors))
                clus_coeff = overlap_sum / len(sec_degree_neighbors)
            else:
                clus_coeff = 0.0

            if verbose:
                print("clusering coeff for node id: ", node_id,
                      " of node type: ", ntype, " is: ", clus_coeff)
            clus_coeff_list.append(clus_coeff)

        # reshape(-1, 1) also yields a valid (0, 1) tensor for a node type
        # with no nodes, where stacking an empty list would raise.
        graph.nodes[ntype].data['clustering_coeff'] = torch.tensor(
            clus_coeff_list, dtype=torch.float32, device=node_id_tensor.device
        ).reshape(-1, 1)

    return graph