In [1]:
import numpy as np
import pandas as pd

import torch

from tqdm.notebook import tqdm
from collections import defaultdict

In [2]:
ref_df = pd.read_parquet('/kaggle/input/dataset-with-embeddings/reference_edges.parquet')

In [3]:
ref_dict = {}
for i, row in tqdm(ref_df.iterrows(), total = ref_df.shape[0]):
#     print(row)
    ref_dict[row['id']] = row['refers']
#     break
    
# ref_dict

  0%|          | 0/57017 [00:00<?, ?it/s]

In [9]:
connected_components = defaultdict(set)


def dfs(node):
    """
    The key is understanding the recursion
    The recursive assumption is:
        After calling `dfs(node)`, the `connected_components` dict contains all the connected as keys,
        and the values are *the same set* that contains all the connected nodes.
    """
    global connected_components, ref_dict
    if node not in connected_components:
        # this is important, so neighbors won't try to traverse current node
        connected_components[node] = set()
        if node in ref_dict:
            for next_ in ref_dict[node]:
                dfs(next_)
                # according the recursive assumption, connected_component of `next_` is also the one of `node`
                connected_components[node] = connected_components[next_]

        # all that's left is add the current node
        connected_components[node].add(node)

for node_ in tqdm(ref_dict):
    dfs(node_)


# get all connected components and convert to tuples, so they are hashable
connected_comp_as_tuples = map(tuple, connected_components.values())

# use ``set`` to make the list of connected components distinct (without repetition)
unique_components = list(set(connected_comp_as_tuples))

  0%|          | 0/57017 [00:00<?, ?it/s]

In [17]:
ref_df[ref_df['id'] == '1711.06420']

Unnamed: 0,id,refers


In [13]:
all_papers = []
for i in range(len(unique_components)):
    all_papers.extend(unique_components[i])

In [16]:
len(all_papers)

74786

In [10]:
paper_id = '1711.0642'

for i in range(len(unique_components)):
    if paper_id in unique_components[i]:
        print(unique_components[i])
        break

In [33]:
cit_df = pd.read_parquet('/kaggle/input/dataset-with-embeddings/citation_edges.parquet')

In [35]:
cit_dict = {}
for i, row in tqdm(cit_df.iterrows(), total = cit_df.shape[0]):
#     print(row)
    cit_dict[row['id']] = row['cited_by']
#     break
    
# ref_dict

  0%|          | 0/42049 [00:00<?, ?it/s]

In [36]:
connected_components = defaultdict(set)


def dfs(node):
    """
    The key is understanding the recursion
    The recursive assumption is:
        After calling `dfs(node)`, the `connected_components` dict contains all the connected as keys,
        and the values are *the same set* that contains all the connected nodes.
    """
    global connected_components, cit_dict
    if node not in connected_components:
        # this is important, so neighbors won't try to traverse current node
        connected_components[node] = set()
        if node in cit_dict:
            for next_ in cit_dict[node]:
                dfs(next_)
                # according the recursive assumption, connected_component of `next_` is also the one of `node`
                connected_components[node] = connected_components[next_]

        # all that's left is add the current node
        connected_components[node].add(node)

for node_ in tqdm(cit_dict):
    dfs(node_)


# get all connected components and convert to tuples, so they are hashable
connected_comp_as_tuples = map(tuple, connected_components.values())

# use ``set`` to make the list of connected components distinct (without repetition)
unique_components = set(connected_comp_as_tuples)

  0%|          | 0/42049 [00:00<?, ?it/s]

In [37]:
len(unique_components)

33115

In [None]:
class Node:
    def __init__(self, node_id, authors=None, title=None, categories=None, abstract=None, update_date=None, abstract_embedding=None):
        self.id = node_id
        self.authors = self.parse_authors(authors)
        self.title = self.remove_newlines(title)
        self.categories = categories if categories else []
        self.abstract = self.remove_newlines(abstract)
        self.update_date = update_date
        self.abstract_embedding = abstract_embedding
        
        self.degree = 0
        self.num_citations = 0
        self.num_references = 0

    # function to store authors in list format
    def parse_authors(self, authors_string):
        if authors_string:
            authors_list = []
            for author in authors_string.split(" and "):
                authors_list.extend(author.split(", "))
            return authors_list
        else:
            return []

    # def remove_newlines(self, text):
    #     if text:
    #         return text.replace("\n", "")
    #     else:
    #         return None
    
    # function to replace newline characters ("\n") with an empty string
    def remove_newlines(self, text):
        return text.replace("\n", " ")
    
    def update_degrees(self, degree, citations, references):
        self.degree = degree
        self.num_citations = citations
        self.num_references = references
    
class Graph:
    def __init__(self):
        self.nodes = {}
        self.undirected_edges = {}
        self.citation_edges = {}
        self.reference_edges = {}
        
        self.node_degrees = {}
        self.num_citations = {}
        self.num_references = {}
        
        self.v_to_i = {}

    # function to add nodes in the Graph data structure
    def add_node(self, node_id, authors=None, title=None, categories=None, abstract=None, update_date=None):
        if node_id not in self.nodes:
            self.nodes[node_id] = Node(node_id, authors, title, categories, abstract, update_date)

    # function to add edges in the Graph data structure
    def add_edge(self, node1, node2):
        if node1 in self.nodes and node2 in self.nodes:
            # conditions to add nodes to the edge lists of themselves and the neighbour node
            if node1 not in self.undirected_edges:
                self.undirected_edges[node1] = []
                self.node_degrees[node1] = 0
            if node2 not in self.undirected_edges:
                self.undirected_edges[node2] = []
                self.node_degrees[node2] = 0
            if node2 not in self.undirected_edges[node1]:
                self.undirected_edges[node1].append(node2)
                self.node_degrees[node1] += 1
            if node1 not in self.undirected_edges[node2]:
                self.undirected_edges[node2].append(node1)
                self.node_degrees[node2] += 1
                
            # when node1 references node2 i.e. referring
            if node1 not in self.reference_edges:
                self.reference_edges[node1] = []
                self.num_references[node1] = 0
            if node2 not in self.reference_edges[node1]:
                self.reference_edges[node1].append(node2)
                self.num_references[node1] += 1
            
                
            # when node2 is referred by node1 i.e. citation
            if node2 not in self.citation_edges:
                self.citation_edges[node2] = []
                self.num_citations[node2] = 0
            if node1 not in self.citation_edges[node2]:
                self.citation_edges[node2].append(node1)
                self.num_citations[node2] += 1
                
    def update_degrees(self):
        for node_id, node_obj in tqdm(self.nodes.items(), total=len(list(self.nodes.keys()))):
            try:
                node_obj.update_degrees(degree=self.node_degrees[node_id], citations=self.num_citations.get(node_id, 0), references=self.num_references.get(node_id, 0))
            except KeyError:
                continue

    # function to return list of nodes present in graph
    def get_nodes(self):
        return list(self.nodes.keys())
    
    def get_adjacency_matrix(self):
        self.v_to_i = {}
        for i, node_id in enumerate(list(self.nodes.keys())):
            self.v_to_i[node_id] = i
        
        matrix = np.zeros((len(self.nodes.keys()), len(self.nodes.keys())), dtype=np.float16)
#         edge_list = self.get_edge_list()
        
        for node_id in list(self.reference_edges.keys()):
            adjacent_nodes = self.reference_edges[node_id]
            
            for adj_node in adjacent_nodes:
                matrix[self.v_to_i[node_id], self.v_to_i[adj_node]] = 1
#                 matrix[self.v_to_i[adj_node], self.v_to_i[node_id]] = 1

        gc.collect()
        
        return matrix
    
    def get_adjacency_list(self):
        return self.undirected_edges, self.citation_edges, self.reference_edges
    
    def remove_isolated_nodes(self):
        node_ids = list(self.nodes.keys())
        for id in tqdm(node_ids):
            if self.nodes[id].degree == 0:
                del self.nodes[id]

    # function to return list of edges present in graph
    def get_edge_list(self):
        edge_list = []
        for node, connected_nodes in tqdm(self.undirected_edges.items(), total=len(list(self.undirected_edges.keys()))):
            for connected_node in connected_nodes:
                edge_list.append((node, connected_node))
        return edge_list


In [None]:
graph = Graph()

for index, row in tqdm(df.iterrows(), total=df.shape[0]):
    node_id = row['id']
    authors = row['authors']
    title = row['title']
    categories = row['categories']
    abstract = row['abstract']
    update_date = row['update_date']

    graph.add_node(node_id, authors=authors, title=title, categories=categories, abstract=abstract, update_date=update_date)
    
gc.collect()