In [2]:
import arxiv
import time
from tqdm import tqdm

# Set up the search query for AI-related papers
search_query = 'cat:cs.AI OR cat:cs.LG OR cat:cs.CL'  # AI, Machine Learning, and Computational Linguistics categories

# Create a client with appropriate parameters
client = arxiv.Client(
    page_size = 100,  # Number of results per request
    delay_seconds = 3,  # Delay between requests to be respectful to arXiv's servers
    num_retries = 3
)

# Set up the search parameters
search = arxiv.Search(
    query = search_query,
    max_results = 1000,
    sort_by = arxiv.SortCriterion.SubmittedDate
)

# Create a list to store the papers
papers = []

# Fetch the papers with a progress bar
print("Fetching papers from arXiv...")
try:
    for paper in tqdm(client.results(search), total=1000):
        papers.append({
            'title': paper.title,
            'authors': [author.name for author in paper.authors],
            'summary': paper.summary,
            'published': paper.published,
            'pdf_url': paper.pdf_url,
            'arxiv_id': paper.entry_id.split('/')[-1]
        })
        
        if len(papers) >= 1000:
            break
            
except Exception as e:
    print(f"An error occurred: {e}")

print(f"\nSuccessfully retrieved {len(papers)} papers")

Fetching papers from arXiv...


100%|█████████▉| 999/1000 [00:36<00:00, 27.64it/s]


Successfully retrieved 1000 papers





In [4]:
# Convert datetime to string before saving
for paper in papers:
    paper['published'] = paper['published'].isoformat()

import json
with open('arxiv_papers.json', 'w', encoding='utf-8') as f:
    json.dump(papers, f, ensure_ascii=False, indent=2)

In [1]:
import time
from tqdm import tqdm
import json
import networkx as nx
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity


class PaperGraph:
    def __init__(self, model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
        self.G = nx.DiGraph()

        # Determine device: GPU or CPU
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")

        # Load models
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)

        # Possible relationship candidates
        self.relationship_candidates = [
            "extends",
            "challenges",
            "refines",
            "not connected"
        ]

    def _analyze_relationship(self, paper1_summary, paper2_summary):
        """Use LLM to analyze the relationship between two papers"""
        prompt = f"""Analyze the relationship between these two research papers:

Paper A: {paper1_summary}

Paper B: {paper2_summary}

Determine if Paper A extends and builds upon, challenges view, refines or is not connected to Paper B. 
Respond with exactly one word: 'extends', 'challenges', 'refines', or 'unrelated'.
"""

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=1,
                temperature=0.1,
                do_sample=False
            )
        
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Map response to relationship type
        if 'extends' in response.lower():
            return 'extends'
        elif 'challenges' in response.lower():
            return 'challenges'
        elif 'refines' in response.lower():
            return 'refines'
        else:
            return 'not connected'

    def add_papers(self, papers):
        """Add papers as nodes and compute their embeddings"""
        print("Computing paper embeddings...")

        for paper in tqdm(papers):
            text = f"{paper['title']} {paper['summary']}"
            embedding = self.encoder.encode(text)
            
            self.G.add_node(paper['arxiv_id'],
                           title=paper['title'],
                           authors=paper['authors'],
                           summary=paper['summary'],
                           embedding=embedding)

    def find_relationships(self, similarity_threshold=0.6):
        """Find relationships between papers using embeddings and LLM"""
        relationships = []
        paper_ids = list(self.G.nodes())
        
        print("Analyzing paper relationships...")
        for i, paper1_id in enumerate(tqdm(paper_ids[:-1])):
            embedding1 = self.G.nodes[paper1_id]['embedding']
            
            for paper2_id in paper_ids[i+1:]:
                embedding2 = self.G.nodes[paper2_id]['embedding']
                
                # First filter by similarity
                similarity = cosine_similarity([embedding1], [embedding2])[0][0]
                
                if similarity > similarity_threshold:
                    # Use LLM to determine relationship
                    relationship = self._analyze_relationship(
                        self.G.nodes[paper1_id]['summary'],
                        self.G.nodes[paper2_id]['summary']
                    )
                    if relationship != 'not connected':
                        relationships.append((paper1_id, paper2_id, relationship))
        
        return relationships

    def add_relationships(self, relationships):
        """Add edges to the graph based on found relationships"""
        for source, target, rel_type in relationships:
            self.G.add_edge(source, target, relationship=rel_type)

    def visualize(self):
        """Visualize the paper relationship graph"""
        plt.figure(figsize=(15, 15))
        
        # Create layout
        pos = nx.spring_layout(self.G)
        
        # Draw nodes
        nx.draw_networkx_nodes(self.G, pos, node_size=100)
        
        # Create color map for relationships
        rel_to_color = {
            'extends': 'green',
            'challenges': 'red',
            'refines': 'black'
        }
        
        # Draw edges with different colors
        for rel_type, color in rel_to_color.items():
            edges = [(u, v) for (u, v, d) in self.G.edges(data=True) 
                    if d['relationship'] == rel_type]
            nx.draw_networkx_edges(self.G, pos, edgelist=edges, 
                                 edge_color=color, 
                                 arrows=True, arrowsize=10)
        
        # Add legend
        legend_elements = [plt.Line2D([0], [0], color=c, label=l) 
                         for l, c in rel_to_color.items()]
        plt.legend(handles=legend_elements)
        
        plt.title("Paper Relationship Graph")
        plt.axis('off')
        plt.show()

        # Print relationship statistics
        print("\nRelationship Statistics:")
        for rel_type in rel_to_color.keys():
            count = sum(1 for _, _, d in self.G.edges(data=True) 
                       if d['relationship'] == rel_type)
            print(f"{rel_type}: {count} relationships")

    def save_graph(self, filename='paper_graph.pkl'):
        """Save the graph with all its attributes"""
        # Convert numpy arrays to lists for serialization
        for node in self.G.nodes():
            self.G.nodes[node]['embedding'] = self.G.nodes[node]['embedding'].tolist()
        
        # Save the graph
        import pickle
        with open(filename, 'wb') as f:
            pickle.dump(self.G, f)
        print(f"Graph saved to {filename}")

    @classmethod
    def load_graph(cls, filename='paper_graph.pkl'):
        """Load a saved graph"""
        import pickle
        graph = cls()
        
        with open(filename, 'rb') as f:
            graph.G = pickle.load(f)
        
        # Convert lists back to numpy arrays
        for node in graph.G.nodes():
            graph.G.nodes[node]['embedding'] = np.array(graph.G.nodes[node]['embedding'])
        
        print(f"Graph loaded from {filename}")
        return graph

    def export_for_gnn(self, filename='graph_data.npz'):
        """Export graph data in a format suitable for GNN training"""
        # Get node features (embeddings)
        node_ids = list(self.G.nodes())
        node_features = np.array([self.G.nodes[node]['embedding'] for node in node_ids])
        
        # Create edge index and edge attributes
        edges = list(self.G.edges(data=True))
        edge_index = np.array([[node_ids.index(src), node_ids.index(dst)] 
                              for src, dst, _ in edges]).T
        
        # Convert edge types to integers
        relationship_to_int = {
            'supports': 1,
            'contradicts': 2,
            'refines': 3,
            'not connected': 0
        }
        edge_attr = np.array([relationship_to_int[d['relationship']] 
                             for _, _, d in edges])
        
        # Save the data
        np.savez(filename,
                 node_features=node_features,
                 edge_index=edge_index,
                 edge_attr=edge_attr,
                 node_ids=node_ids)
        print(f"Graph data exported to {filename}")


  from .autonotebook import tqdm as notebook_tqdm


Loading NLP models...


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Computing paper embeddings...


100%|██████████| 1000/1000 [00:46<00:00, 21.41it/s]


Analyzing paper relationships...


 15%|█▌        | 151/999 [41:16<3:51:46, 16.40s/it] 


KeyboardInterrupt: 