In [None]:
import urllib.request
from LinkPredictor import EnhancedLinkPredictor
import networkx as nx
import numpy as np
import random
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')
import os
import urllib.request
import random

In [None]:
# Functions to load classical citation networks

def _extract_tgz(archive_path):
    """Extract a gzip-compressed tar archive into the current directory."""
    import tarfile

    with tarfile.open(archive_path, 'r:gz') as tar:
        # The archive was just downloaded, so prefer the 'data' extraction
        # filter when this Python version provides it (it refuses members that
        # would be written outside the destination directory).
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(filter='data')
        else:
            tar.extractall()


def _fetch_tgz(url, archive_path):
    """Download a .tgz archive, extract it, and always remove the archive."""
    try:
        urllib.request.urlretrieve(url, archive_path)
        _extract_tgz(archive_path)
    finally:
        # Clean up even when the download or extraction fails
        if os.path.exists(archive_path):
            os.remove(archive_path)


def _fetch_gz(url, archive_path, target_path):
    """Download a .gz file and decompress it to target_path."""
    import gzip
    import shutil

    # Decompress into a temporary name first: callers treat the existence of
    # target_path as "already downloaded", so a half-written file must never
    # be left under that name.
    partial_path = target_path + '.part'
    try:
        urllib.request.urlretrieve(url, archive_path)
        with gzip.open(archive_path, 'rb') as f_in:
            with open(partial_path, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
        os.replace(partial_path, target_path)
    finally:
        for leftover in (archive_path, partial_path):
            if os.path.exists(leftover):
                os.remove(leftover)


def _parse_pubmed_citation(line):
    """Return (source, target) paper ids for one PubMed cites row, or None."""
    fields = line.strip().split('\t')
    # Rows in the distributed file look like
    # "<edge id>\tpaper:<source>\t|\tpaper:<target>", so pick the endpoints by
    # their "paper:" prefix instead of by column position.
    prefixed = [field for field in fields if field.startswith('paper:')]
    if len(prefixed) >= 2:
        source, target = prefixed[0], prefixed[1]
    elif len(fields) == 2:
        # Plain two-column "source\ttarget" rows are accepted as well
        source, target = fields
    else:
        # Header lines, blank lines and malformed rows
        return None
    return source.replace('paper:', ''), target.replace('paper:', '')


def _load_cora(G, rng):
    """Populate G with the Cora citation graph (integer paper ids)."""
    print("Loading Cora dataset...")

    # Download dataset if not exists
    if not os.path.exists('cora'):
        print(f"Downloading Cora dataset...")
        _fetch_tgz('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz', 'cora.tgz')

    # Read the cora.cites file. Per the dataset README each row is
    # "<cited paper id>\t<citing paper id>", so the edge goes from the second
    # column to the first.
    with open('cora/cora.cites', 'r') as f:
        for line in f:
            row = line.strip()
            if not row:
                continue
            cited, citing = map(int, row.split('\t'))
            G.add_edge(citing, cited)

    # Add paper metadata: first column is the paper id, last column the label
    with open('cora/cora.content', 'r') as f:
        for line in f:
            row = line.strip()
            if not row:
                continue
            parts = row.split('\t')
            paper_id = int(parts[0])

            if paper_id not in G:
                G.add_node(paper_id)

            G.nodes[paper_id]['label'] = parts[-1]
            # Synthetic year attribute for recency calculation (random, 2000-2015)
            G.nodes[paper_id]['year'] = rng.randint(2000, 2015)


def _load_arxiv(G, rng):
    """Populate G with the arXiv HEP-TH citation graph (string paper ids)."""
    print("Loading arXiv HEP-TH citation network...")
    extracted_file = 'cit-HepTh.txt'

    # Download dataset if not exists
    if not os.path.exists(extracted_file):
        print(f"Downloading arXiv HEP-TH citation network...")
        _fetch_gz('https://snap.stanford.edu/data/cit-HepTh.txt.gz',
                  'cit-HepTh.txt.gz', extracted_file)

    # Read the citation network, skipping '#' comment lines
    with open(extracted_file, 'r') as f:
        for line in f:
            if line.startswith('#'):
                continue
            parts = line.strip().split('\t')
            if len(parts) == 2:
                citing, cited = parts
                G.add_edge(citing, cited)

    # Add random years for papers (1992-2003 as per the dataset description)
    for node in G.nodes():
        G.nodes[node]['year'] = rng.randint(1992, 2003)


def _load_pubmed(G, rng):
    """Populate G with the PubMed-Diabetes citation graph (string paper ids)."""
    print("Loading PubMed dataset...")

    # Download dataset if not exists
    if not os.path.exists('Pubmed-Diabetes'):
        print(f"Downloading PubMed dataset...")
        _fetch_tgz('https://linqs-data.soe.ucsc.edu/public/Pubmed-Diabetes.tgz',
                   'Pubmed-Diabetes.tgz')

    cite_file = 'Pubmed-Diabetes/data/Pubmed-Diabetes.DIRECTED.cites.tab'
    with open(cite_file, 'r') as f:
        # Skip header line (further non-edge lines are rejected by the parser)
        next(f, None)
        for line in f:
            edge = _parse_pubmed_citation(line)
            if edge is not None:
                G.add_edge(*edge)

    # Node attributes are best-effort: on any failure fall back to random years
    try:
        node_file = 'Pubmed-Diabetes/data/Pubmed-Diabetes.NODE.paper.tab'
        with open(node_file, 'r') as f:
            # Skip header line
            next(f, None)
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) >= 3:
                    paper_id = parts[0].replace('paper:', '')
                    if paper_id in G:
                        # The year is synthetic (random); the second column is
                        # stored verbatim as the label.
                        G.nodes[paper_id]['year'] = rng.randint(1990, 2010)
                        G.nodes[paper_id]['label'] = parts[1]
    except Exception as e:
        print(f"Could not load node attributes for PubMed: {e}")
        # Add random years
        for node in G.nodes():
            G.nodes[node]['year'] = rng.randint(1990, 2010)


def load_citation_dataset(dataset_name, seed=None):
    """
    Load a classical citation network dataset

    The dataset is downloaded into the current working directory on first use
    and read from there afterwards. Every paper that receives attributes gets a
    synthetic, randomly drawn 'year'; Cora and PubMed papers also get a 'label'.

    Args:
        dataset_name: Name of the dataset ('cora', 'arxiv', or 'pubmed'),
            case-insensitive
        seed: Optional seed for the synthetic 'year' attributes. With the
            default (None) the global `random` module state is used, so years
            differ between runs unless the caller seeds `random` itself.

    Returns:
        NetworkX DiGraph object whose edges point from the citing paper to the
        cited paper, or None if the name is unknown or loading fails (the
        error is printed rather than raised).
    """
    loaders = {
        'cora': _load_cora,
        'arxiv': _load_arxiv,
        'pubmed': _load_pubmed,
    }

    try:
        loader = loaders.get(dataset_name.lower())
        if loader is None:
            print(f"Unknown dataset '{dataset_name}', please choose from: cora, arxiv, or pubmed")
            return None

        G = nx.DiGraph()
        # A private generator keeps a seeded call from disturbing global state
        rng = random.Random(seed) if seed is not None else random
        loader(G, rng)

        print(f"Loaded {dataset_name} dataset")
        return G

    except Exception as e:
        print(f"Error loading dataset {dataset_name}: {e}")
        print("Please check your internet connection or try again later.")
        return None

In [None]:
# Configure plt for better visualization
plt.style.use('seaborn-v0_8-darkgrid')
# Suppress warnings
warnings.filterwarnings('ignore')
# List of datasets to analyze
datasets = ['cora', 'pubmed', 'arxiv']  # You can add other datasets as needed

# Evaluation settings passed to EnhancedLinkPredictor.evaluate
NUM_TEST_NODES = 400
TEST_RATIO = 0.3

# Store the results for each dataset
all_results = {}

for dataset_name in datasets:
    print(f"\n{'='*50}")
    print(f"Analyzing {dataset_name.upper()} dataset")
    print(f"{'='*50}")

    # Load the dataset
    citation_graph = load_citation_dataset(dataset_name)

    # load_citation_dataset reports failures by returning None (after printing
    # the reason); skip this dataset instead of handing None to the predictor
    # so one failed download does not abort the remaining datasets.
    if citation_graph is None:
        print(f"Skipping {dataset_name}: dataset could not be loaded")
        continue

    # Initialize the link predictor
    predictor = EnhancedLinkPredictor(citation_graph)

    # Evaluate link prediction on this dataset
    results = predictor.evaluate(num_test_nodes=NUM_TEST_NODES, test_ratio=TEST_RATIO)

    # Store results
    all_results[dataset_name] = results