In [1]:
# Install the required packages correctly
import sys
!{sys.executable} -m pip install python-arango networkx pandas gdown
!{sys.executable} -m pip install langchain langchain-openai langchain-community openai

Collecting python-arango
  Downloading python_arango-8.1.5-py3-none-any.whl.metadata (8.2 kB)
Collecting requests_toolbelt (from python-arango)
  Downloading requests_toolbelt-1.0.0-py2.py3-none-any.whl.metadata (14 kB)
Collecting importlib_metadata>=4.7.1 (from python-arango)
  Downloading importlib_metadata-8.6.1-py3-none-any.whl.metadata (4.7 kB)
Downloading python_arango-8.1.5-py3-none-any.whl (114 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.6/114.6 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading importlib_metadata-8.6.1-py3-none-any.whl (26 kB)
Downloading requests_toolbelt-1.0.0-py2.py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.5/54.5 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: importlib_metadata, requests_toolbelt, python-arango
  Attempting uninstall: importlib_metadata
    Found existing installation: importlib-metadata 4.6.4
    Uninstalling importlib-m

In [2]:
import sys
import subprocess
import importlib

import networkx as nx
from arango import ArangoClient
# Update these langchain imports
from langchain_openai import OpenAI  # Or use this
# Alternatively, you might need to use:
# from langchain_community.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
import pandas as pd
import json
import gdown

In [3]:
def install_and_import(package):
    try:
        importlib.import_module(package)
        print(f"✅ {package} is already installed")
    except ImportError:
        print(f"Installing {package}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        print(f"✅ Successfully installed {package}")

# Install required packages
required_packages = ["python-arango", "networkx", "langchain", "gdown", "pandas", "openai"]
for package in required_packages:
    install_and_import(package)

# Note about cugraph
print("Note: cugraph requires CUDA. If you don't have a GPU, we'll use NetworkX instead.")
try:
    install_and_import("cugraph")
    use_cugraph = True
except:
    print("⚠️ Could not install cugraph. Using NetworkX for all graph operations.")
    use_cugraph = False

# Import required libraries after installation
import networkx as nx
from arango import ArangoClient
from langchain.llms import OpenAI
from langchain import PromptTemplate, LLMChain
import pandas as pd
import json
import gdown

print("All dependencies imported successfully!")


Installing python-arango...
✅ Successfully installed python-arango
✅ networkx is already installed
✅ langchain is already installed
✅ gdown is already installed
✅ pandas is already installed
✅ openai is already installed
Note: cugraph requires CUDA. If you don't have a GPU, we'll use NetworkX instead.
Installing cugraph...
⚠️ Could not install cugraph. Using NetworkX for all graph operations.
All dependencies imported successfully!


In [4]:
# Step 1: Setup & Dependencies
import sys
import subprocess
import importlib

# Function to install and import a package
def install_and_import(package):
    try:
        importlib.import_module(package)
        print(f"✅ {package} is already installed")
    except ImportError:
        print(f"Installing {package}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        print(f"✅ Successfully installed {package}")

# Install required packages
required_packages = ["python-arango", "networkx", "pandas", "gdown", "requests", "tqdm"]
# Update langchain packages to use the newer structure
langchain_packages = ["langchain", "langchain-openai", "langchain-community"]

for package in required_packages:
    install_and_import(package)
for package in langchain_packages:
    install_and_import(package)

# Import required libraries after installation
import networkx as nx
from arango import ArangoClient
# Updated imports for langchain
from langchain_openai import OpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
import pandas as pd
import json
import gdown
import requests
import os
from tqdm import tqdm
import gzip
import time

print("All dependencies imported successfully!")

Installing python-arango...
✅ Successfully installed python-arango
✅ networkx is already installed
✅ pandas is already installed
✅ gdown is already installed
✅ requests is already installed
✅ tqdm is already installed
✅ langchain is already installed
Installing langchain-openai...
✅ Successfully installed langchain-openai
Installing langchain-community...
✅ Successfully installed langchain-community
All dependencies imported successfully!


In [5]:
import os
import requests
import gzip
from tqdm import tqdm
import pandas as pd
import json

print("\nSetting up Amazon SNAP dataset downloads...")

def download_file(url, filename):
    """Download a file with progress bar"""
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        block_size = 1024

        with open(filename, 'wb') as file, tqdm(
                desc=filename,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
            for data in response.iter_content(block_size):
                size = file.write(data)
                bar.update(size)

        print(f"✅ Downloaded {filename}")
        return filename
    except requests.exceptions.RequestException as e:
        print(f"⚠️ Error downloading file: {e}")
        return None
    except Exception as e:
        print(f"⚠️ An unexpected error occurred during download: {e}")
        return None

def parse_amazon_metadata(gz_file):
    """Parse Amazon metadata from gzipped file using json.loads, handling multi-line JSON objects."""
    print(f"Parsing metadata from {gz_file}...")
    products = []
    current_product_lines = []  # Accumulate lines for the current product

    try:
        with gzip.open(gz_file, 'rt', encoding='utf-8') as f:
            for i, line in enumerate(tqdm(f, desc="Reading lines")):
                line = line.strip()
                if line == "":  # Empty line signals end of a product entry
                    if current_product_lines:  # If we have lines accumulated
                        try:
                            # Join accumulated lines and parse as JSON
                            product_json = json.loads("".join(current_product_lines))
                            products.append(product_json)
                        except json.JSONDecodeError as e:
                            print(f"⚠️ JSONDecodeError: {e} on lines ending at {i + 1}:")
                            #for problematic_line in current_product_lines: #show all the problematic lines.
                                #print(problematic_line)
                        current_product_lines = []  # Reset for the next product
                else:
                    current_product_lines.append(line)  # Accumulate non-empty lines

            # Handle any remaining lines after the loop (last product)
            if current_product_lines:
                try:
                    product_json = json.loads("".join(current_product_lines))
                    products.append(product_json)
                except json.JSONDecodeError as e:
                    print(f"⚠️ JSONDecodeError: {e} at end of file:")
                    #for problematic_line in current_product_lines:
                        #print(problematic_line)

        print(f"✅ Parsed {len(products)} products")
        df = pd.DataFrame(products)

        if 'ASIN' in df.columns:
            df['ASIN'] = df['ASIN'].astype(str)
        else:
            print("⚠️ Warning: 'ASIN' column not found in metadata.")
            df['ASIN'] = ''

        csv_file = gz_file.replace('.gz', '.csv')
        df.to_csv(csv_file, index=False)
        print(f"✅ Saved to {csv_file}")
        return df

    except Exception as e:
        print(f"⚠️ Error parsing metadata file: {e}")
        return None


def parse_amazon_copurchase(gz_file, max_edges=None):
    """Parse Amazon co-purchasing network, with optional edge limit."""
    print(f"Parsing co-purchase network from {gz_file}...")
    edges = []
    try:
        with gzip.open(gz_file, 'rt', encoding='latin1') as f:
            for i, line in enumerate(tqdm(f, desc="Reading edges")):
                if not line.startswith('#'):
                    source, target = line.strip().split()
                    edges.append((source, target))
                    if max_edges is not None and i >= max_edges:
                        break

        print(f"✅ Parsed {len(edges)} co-purchase edges")
        df = pd.DataFrame(edges, columns=['source', 'target'])
        csv_file = gz_file.replace('.gz', '.csv')
        df.to_csv(csv_file, index=False)
        print(f"✅ Saved to {csv_file}")
        return df

    except Exception as e:
        print(f"⚠️ Error parsing co-purchase network: {e}")
        return None


# SNAP Amazon Dataset URLs - now a dictionary with dataset names as keys
amazon_datasets = {
    "metadata": {
        "url": "http://snap.stanford.edu/data/amazon/productGraph/metadata.json.gz",
        "type": "metadata"
    },
    "amazon0302": {
        "url": "http://snap.stanford.edu/data/amazon0302.txt.gz",
        "type": "copurchase"
    },
    "amazon0312": {
        "url": "http://snap.stanford.edu/data/amazon0312.txt.gz",
         "type": "copurchase"
    },
    "amazon0505": {
        "url": "http://snap.stanford.edu/data/amazon0505.txt.gz",
        "type": "copurchase"
    },
    "amazon0601": {
        "url": "http://snap.stanford.edu/data/amazon0601.txt.gz",
        "type": "copurchase"
    },
    # "books": { # Removed books and electronics
    #     "url": "http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/ratings_Books.csv",
    #     "type": "ratings_csv" #different parsing logic needed
    # },
    # "electronics": {
    #     "url": "http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/ratings_Electronics.csv",
    #      "type": "ratings_csv"
    # }
}

# Create data directory
data_dir = "amazon_data"
os.makedirs(data_dir, exist_ok=True)

# Download and process datasets
print("\nDownloading and processing Amazon SNAP datasets...")
datasets = {} # This will hold all datasets (DataFrames)

for dataset_name, dataset_info in amazon_datasets.items():
    print(f"Processing dataset: {dataset_name}")
    gz_file = os.path.join(data_dir, os.path.basename(dataset_info["url"]))
    csv_file = gz_file.replace('.gz', '.csv')
    datasets[dataset_name] = {"data": None, "graph": None} # Initialize entry

    if os.path.exists(csv_file):
        print(f"✅ Using existing CSV file for {dataset_name}.")
        try:
            datasets[dataset_name]["data"] = pd.read_csv(csv_file)
        except pd.errors.EmptyDataError:
            print(f"⚠️ CSV file for {dataset_name} is empty. Reprocessing.")
            if not os.path.exists(gz_file):
                download_file(dataset_info["url"], gz_file)
            if dataset_info["type"] == "metadata":
                df = parse_amazon_metadata(gz_file)
            elif dataset_info["type"] == "copurchase":
                df = parse_amazon_copurchase(gz_file)  # No max_edges here
            else:
                print(f"⚠️ Unknown dataset type: {dataset_info['type']}")
                df = None

            if df is not None:
                datasets[dataset_name]["data"] = df

    else:
        print(f"CSV file for {dataset_name} not found. Downloading and parsing...")
        if not os.path.exists(gz_file):
            download_file(dataset_info["url"], gz_file)
        if dataset_info["type"] == "metadata":
            df = parse_amazon_metadata(gz_file)
        elif dataset_info["type"] == "copurchase":
            df = parse_amazon_copurchase(gz_file)  # No max_edges here
        else:
            print(f"⚠️ Unknown dataset type: {dataset_info['type']}")
            df = None

        if df is not None:
            datasets[dataset_name]["data"] = df
        else:
             datasets[dataset_name]["data"] = pd.DataFrame()


Setting up Amazon SNAP dataset downloads...

Downloading and processing Amazon SNAP datasets...
Processing dataset: metadata
CSV file for metadata not found. Downloading and parsing...


amazon_data/metadata.json.gz: 100%|██████████| 3.13G/3.13G [03:45<00:00, 14.9MiB/s]


✅ Downloaded amazon_data/metadata.json.gz
Parsing metadata from amazon_data/metadata.json.gz...


Reading lines: 9430088it [01:17, 121536.65it/s]


⚠️ JSONDecodeError: Expecting property name enclosed in double quotes: line 1 column 2 (char 1) at end of file:
✅ Parsed 0 products
✅ Saved to amazon_data/metadata.json.csv
Processing dataset: amazon0302
CSV file for amazon0302 not found. Downloading and parsing...


amazon_data/amazon0302.txt.gz: 100%|██████████| 4.45M/4.45M [00:01<00:00, 3.04MiB/s]


✅ Downloaded amazon_data/amazon0302.txt.gz
Parsing co-purchase network from amazon_data/amazon0302.txt.gz...


Reading edges: 1234881it [00:00, 1356632.71it/s]


✅ Parsed 1234877 co-purchase edges
✅ Saved to amazon_data/amazon0302.txt.csv
Processing dataset: amazon0312
CSV file for amazon0312 not found. Downloading and parsing...


amazon_data/amazon0312.txt.gz: 100%|██████████| 10.8M/10.8M [00:01<00:00, 8.21MiB/s]


✅ Downloaded amazon_data/amazon0312.txt.gz
Parsing co-purchase network from amazon_data/amazon0312.txt.gz...


Reading edges: 3200444it [00:02, 1346056.55it/s]


✅ Parsed 3200440 co-purchase edges
✅ Saved to amazon_data/amazon0312.txt.csv
Processing dataset: amazon0505
CSV file for amazon0505 not found. Downloading and parsing...


amazon_data/amazon0505.txt.gz: 100%|██████████| 11.2M/11.2M [00:01<00:00, 6.39MiB/s]


✅ Downloaded amazon_data/amazon0505.txt.gz
Parsing co-purchase network from amazon_data/amazon0505.txt.gz...


Reading edges: 3356828it [00:02, 1335174.04it/s]


✅ Parsed 3356824 co-purchase edges
✅ Saved to amazon_data/amazon0505.txt.csv
Processing dataset: amazon0601
CSV file for amazon0601 not found. Downloading and parsing...


amazon_data/amazon0601.txt.gz: 100%|██████████| 11.3M/11.3M [00:01<00:00, 6.76MiB/s]


✅ Downloaded amazon_data/amazon0601.txt.gz
Parsing co-purchase network from amazon_data/amazon0601.txt.gz...


Reading edges: 3387392it [00:02, 1307101.53it/s]


✅ Parsed 3387388 co-purchase edges
✅ Saved to amazon_data/amazon0601.txt.csv


In [8]:
print("\nPreparing Amazon product graph...")

def prepare_amazon_graph(copurchase_df, metadata_df=None):
    """Transform Amazon dataset into a graph structure."""
    G = nx.DiGraph()

    # Add edges from co-purchase data. Handle the case where copurchase_df is None or empty.
    if copurchase_df is not None and not copurchase_df.empty:
        for _, row in tqdm(copurchase_df.iterrows(), total=len(copurchase_df), desc="Adding edges"):
            try:
                G.add_edge(str(row['source']), str(row['target']))
            except KeyError as e:
                print(f"⚠️ KeyError: {e}.  Missing 'source' or 'target' in row: {row}")
                # Decide how to handle.  Skip?  Log? For now, skip.
                continue
    else:
        print("⚠️ No co-purchase data available.  Graph will have no edges.")

    # Add node attributes from metadata if available and valid.
    if metadata_df is not None and 'ASIN' in metadata_df.columns and not metadata_df.empty:
        print("Adding product metadata to nodes...")
        for _, row in tqdm(metadata_df.iterrows(), total=len(metadata_df), desc="Adding metadata"):
            asin = str(row['ASIN'])  # Ensure ASIN is a string
            if asin: #check that asin is not empty
              if asin in G:  # Only add metadata if the node exists in the graph
                  for col in metadata_df.columns:
                      if col != 'ASIN' and pd.notna(row[col]):
                          G.nodes[asin][col] = row[col]
              #else:  # Optional: Verbose logging for missing nodes.
              #    print(f"Node {asin} not found in graph. Skipping metadata.")
            else:
                print("Skipping metadata entry with empty or missing ASIN.")
    elif metadata_df is not None and metadata_df.empty:
        print("Metadata DataFrame is empty, no attributes added.")
    else:
      print("Metadata is not available so no attributes will be added.")


    print(f"✅ Created Amazon product graph with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")
    return G

# --- Usage (Modified for Multiple Datasets) ---

# Select the dataset you want to use to create the graph.
# For this example, let's use 'amazon0601'.  You'll make this selectable in Gradio later.
dataset_name = "amazon0601"

if dataset_name in datasets and datasets[dataset_name]["data"] is not None:
    copurchase_df = datasets[dataset_name]["data"]
    # Only use metadata if it's available *and* we're using amazon0601 (or another copurchase dataset).
    #  The metadata is *not* tied to a specific copurchase dataset date.
    metadata_df = datasets.get("metadata", {}).get("data") if "metadata" in datasets else None

    amazon_graph = prepare_amazon_graph(copurchase_df, metadata_df)

    # Store the graph in the datasets dictionary for later use.
    datasets[dataset_name]["graph"] = amazon_graph
else:
    print(f"⚠️ Error: Could not create graph. Dataset '{dataset_name}' is missing or has no data.")
    amazon_graph = None  # Set to None to indicate failure.


Preparing Amazon product graph...


Adding edges: 100%|██████████| 3387388/3387388 [02:14<00:00, 25127.75it/s]


Metadata DataFrame is empty, no attributes added.
✅ Created Amazon product graph with 403394 nodes and 3387388 edges


In [10]:
from typing import Optional  # Import Optional

print("\nPerforming basic graph analysis...")

def analyze_graph(G: Optional[nx.DiGraph]) -> Optional[dict]:
    """Perform basic analysis on the graph, handling None input."""
    if G is None:
        print("⚠️ Input graph is None. Cannot perform analysis.")
        return None  # Or return a dictionary with default values

    analysis = {}

    try:
        # Basic statistics
        analysis["num_nodes"] = G.number_of_nodes()
        analysis["num_edges"] = G.number_of_edges()

        # Compute degree statistics
        print("Computing degree statistics...")
        degrees = [d for n, d in G.degree()]
        analysis["avg_degree"] = sum(degrees) / len(degrees) if degrees else 0.0  # Handle empty graph
        analysis["max_degree"] = max(degrees) if degrees else 0

        # Identify top nodes by degree
        print("Finding most connected products...")
        degree_dict = dict(G.degree())
        top_nodes = sorted(degree_dict.items(), key=lambda x: x[1], reverse=True)[:10]
        analysis["top_nodes_by_degree"] = top_nodes

        # Extract largest connected component
        print("Finding largest connected component...")
        try:
            largest_cc = max(nx.weakly_connected_components(G), key=len)
            analysis["largest_cc_size"] = len(largest_cc)
            analysis["largest_cc_percentage"] = (len(largest_cc) / G.number_of_nodes()) * 100
        except ValueError: #raised if G is empty
            analysis["largest_cc_size"] = 0
            analysis["largest_cc_percentage"] = 0.0

        # Sample a small subgraph for visualization
        print("Creating sample subgraph for detailed analysis...")
        if top_nodes: #check if top_nodes is not empty
            seed_node = top_nodes[0][0]
            sample_nodes = set([seed_node])
            frontier = set([seed_node])

            while len(sample_nodes) < 100 and frontier:
                new_frontier = set()
                for node in frontier:
                    neighbors = set(G.neighbors(node))
                    new_nodes = neighbors - sample_nodes
                    sample_nodes.update(list(new_nodes)[:5])
                    new_frontier.update(list(new_nodes)[:5])
                    if len(sample_nodes) >= 100:
                        break
                frontier = new_frontier

            sample_subgraph = G.subgraph(sample_nodes)
            analysis["sample_subgraph"] = sample_subgraph #Keep for visualization
            analysis["sample_subgraph_size"] = sample_subgraph.number_of_nodes()
        else:
            analysis["sample_subgraph"] = None
            analysis["sample_subgraph_size"] = 0


        print(f"✅ Completed basic graph analysis")
        return analysis

    except Exception as e:
        print(f"⚠️ Error during graph analysis: {e}")
        return None  # Or a dictionary with default/error values


# Run the analysis (check if amazon_graph exists)
if amazon_graph is not None:
    graph_analysis = analyze_graph(amazon_graph)

    # Print some findings (only if analysis was successful)
    if graph_analysis:
        print("\nAmazon Product Network Analysis Results:")
        print(f"Total products (nodes): {graph_analysis['num_nodes']:,}")
        print(f"Total co-purchase links (edges): {graph_analysis['num_edges']:,}")
        print(f"Average connections per product: {graph_analysis['avg_degree']:.2f}")
        print(f"Maximum connections for a product: {graph_analysis['max_degree']}")
        print(f"Largest connected component contains {graph_analysis['largest_cc_percentage']:.2f}% of products")

        print("\nTop 10 most connected products (potential influencers):")
        for i, (node, degree) in enumerate(graph_analysis['top_nodes_by_degree'], 1):
            print(f"{i}. Product {node}: {degree} connections")

else:
    print("⚠️ Graph analysis skipped because the graph could not be created.")


Performing basic graph analysis...
Computing degree statistics...
Finding most connected products...
Finding largest connected component...
Creating sample subgraph for detailed analysis...
✅ Completed basic graph analysis

Amazon Product Network Analysis Results:
Total products (nodes): 403,394
Total co-purchase links (edges): 3,387,388
Average connections per product: 16.79
Maximum connections for a product: 2761
Largest connected component contains 99.99% of products

Top 10 most connected products (potential influencers):
1. Product 1041: 2761 connections
2. Product 45: 2497 connections
3. Product 50: 2291 connections
4. Product 529: 1522 connections
5. Product 783: 1184 connections
6. Product 10030: 866 connections
7. Product 89: 813 connections
8. Product 1862: 796 connections
9. Product 12245: 731 connections
10. Product 52: 719 connections


In [23]:
!pip install langchain_cohere streamlit

Collecting streamlit
  Downloading streamlit-1.42.2-py2.py3-none-any.whl.metadata (8.9 kB)
Collecting watchdog<7,>=2.1.5 (from streamlit)
  Downloading watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting gitpython!=3.1.19,<4,>=3.0.7 (from streamlit)
  Downloading GitPython-3.1.44-py3-none-any.whl.metadata (13 kB)
Collecting pydeck<1,>=0.8.0b4 (from streamlit)
  Downloading pydeck-0.9.1-py2.py3-none-any.whl.metadata (4.1 kB)
Collecting gitdb<5,>=4.0.1 (from gitpython!=3.1.19,<4,>=3.0.7->streamlit)
  Downloading gitdb-4.0.12-py3-none-any.whl.metadata (1.2 kB)
Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->gitpython!=3.1.19,<4,>=3.0.7->streamlit)
  Downloading smmap-5.0.2-py3-none-any.whl.metadata (4.3 kB)
Downloading streamlit-1.42.2-py2.py3-none-any.whl (9.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.6/9.6 MB[

In [11]:
from typing import Optional  # Import Optional

print("\nPerforming community detection...")

def detect_communities(G: Optional[nx.DiGraph], graph_analysis: Optional[dict] = None, max_nodes: int = 5000) -> Optional[dict]:
    """Detect communities, handling None graph and large graphs."""

    if G is None:
        print("⚠️ Input graph is None. Cannot detect communities.")
        return None

    try:
        if G.number_of_nodes() > max_nodes:
            print(f"Graph is large ({G.number_of_nodes()} nodes), sampling for community detection...")
            if graph_analysis and "sample_subgraph" in graph_analysis and graph_analysis["sample_subgraph"] is not None:
                subgraph = graph_analysis["sample_subgraph"]
            else:
                # Fallback: random sample if no sample_subgraph is available
                print("Sampling nodes for community detection (no previous sample)...")
                sampled_nodes = list(G.nodes())[:max_nodes]
                subgraph = G.subgraph(sampled_nodes)
        else:
            subgraph = G

        undirected_G = subgraph.to_undirected()  # Create undirected version

        try:
            import community as community_louvain  # Import here
            print("Running Louvain community detection...")
            partition = community_louvain.best_partition(undirected_G)
            communities = {}
            for node, community_id in partition.items():
                communities.setdefault(community_id, []).append(node)
            sorted_communities = sorted(communities.items(), key=lambda x: len(x[1]), reverse=True)

            return {
                "algorithm": "louvain",
                "num_communities": len(communities),
                "community_sizes": [len(comm) for _, comm in sorted_communities[:10]],
                "top_communities": sorted_communities[:5],
                "node_communities": partition
            }

        except ImportError:
            print("⚠️ python-louvain (community) not installed. Falling back to connected components.")
            components = list(nx.connected_components(undirected_G))
            sorted_components = sorted(components, key=len, reverse=True)
            return {
                "algorithm": "connected_components",
                "num_communities": len(components),
                "community_sizes": [len(comp) for comp in sorted_components[:10]],
                "top_communities": [(i, list(comp)) for i, comp in enumerate(sorted_components[:5])],
                "node_communities": {node: i for i, comp in enumerate(components) for node in comp}
            }
    except Exception as e:
        print(f"⚠️ Error during community detection: {e}")
        return None  # Or a dictionary with default/error values


# Try to install the community detection library
try:
    install_package("python-louvain")  # Use the installation function.
except:
    print("Could not install python-louvain. Will use connected components instead.")


# Run community detection (check if amazon_graph exists)
if amazon_graph is not None:
    community_analysis = detect_communities(amazon_graph, graph_analysis) #now we send the graph analysis

    # Print community findings (only if analysis was successful)
    if community_analysis:
      print("\nCommunity Detection Results:")
      print(f"Algorithm used: {community_analysis['algorithm']}")
      print(f"Number of communities/clusters detected: {community_analysis['num_communities']}")
      print(f"Top 5 community sizes: {community_analysis['community_sizes'][:5]}")
else:
    print("⚠️ Community detection skipped because the graph could not be created.")


Performing community detection...
Could not install python-louvain. Will use connected components instead.
Graph is large (403394 nodes), sampling for community detection...
⚠️ python-louvain (community) not installed. Falling back to connected components.

Community Detection Results:
Algorithm used: connected_components
Number of communities/clusters detected: 1
Top 5 community sizes: [104]


In [17]:
from typing import Optional
import time
from arango import ArangoClient, ServerConnectionError
import networkx as nx
from tqdm import tqdm

print("\nChecking ArangoDB connection...")

def setup_arangodb(retries=5, delay=5) -> Optional[ArangoClient.db]:
    """Setup ArangoDB connection safely with retries."""
    for attempt in range(retries):
        try:
            client = ArangoClient(hosts=ARANGO_HOST)
            # Try connecting with user credentials first
            try:
                # First check if the database exists in the system database
                sys_db = client.db('_system', username=ARANGO_USERNAME, password=ARANGO_PASSWORD)

                if not sys_db.has_database(ARANGO_DB):
                    print(f"Creating database: {ARANGO_DB}")
                    sys_db.create_database(ARANGO_DB)
                    print(f"✅ Created database: {ARANGO_DB}")

                # Now connect to the specific database
                db = client.db(ARANGO_DB, username=ARANGO_USERNAME, password=ARANGO_PASSWORD)
                print(f"✅ Connected to ArangoDB: {ARANGO_DB}")
                return db
            except Exception as e:
                print(f"⚠️ Error while checking/creating database: {e}")
                return None

        except ServerConnectionError as e:
            print(f"Attempt {attempt + 1}/{retries} failed: {e}")
            if attempt < retries - 1:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
        except Exception as e:  # Catch other potential exceptions
            print(f"⚠️ Error connecting to ArangoDB: {e}")
            return None

    print(f"❌ Error connecting to ArangoDB after {retries} retries.")
    return None


def persist_amazon_graph(db: ArangoClient.db, graph: nx.DiGraph, graph_name: str) -> bool:
    """Save Amazon graph into ArangoDB with a specified graph name."""
    if db is None or graph is None:
        print("⚠️ Cannot persist graph: Database connection or graph is None.")
        return False

    try:
        # Verify we can perform operations on the database
        try:
            collections = db.collections()
            print(f"Database contains {len(collections)} collections")
        except Exception as e:
            print(f"⚠️ Cannot access database collections: {e}")
            return False

        # Use ArangoDB's graph module
        if db.has_graph(graph_name):
            arango_graph = db.graph(graph_name)
            print(f"Using existing graph: {graph_name}")
        else:
            arango_graph = db.create_graph(graph_name)
            print(f"Created new graph: {graph_name}")

        nodes_collection_name = f"{graph_name}_products"  # Unique collection names
        edges_collection_name = f"{graph_name}_copurchase"

        # Vertex Collection
        if not arango_graph.has_vertex_collection(nodes_collection_name):
            products = arango_graph.create_vertex_collection(nodes_collection_name)
            print(f"✅ Created vertex collection: {nodes_collection_name}")
        else:
            products = arango_graph.vertex_collection(nodes_collection_name)
            print(f"Using existing vertex collection: {nodes_collection_name}")

        # Edge Definition
        if not arango_graph.has_edge_definition(edges_collection_name):
            copurchase = arango_graph.create_edge_definition(
                edge_collection=edges_collection_name,
                from_vertex_collections=[nodes_collection_name],
                to_vertex_collections=[nodes_collection_name]
            )
            print(f"✅ Created edge definition: {edges_collection_name}")
        else:
            copurchase = arango_graph.edge_collection(edges_collection_name)
            print(f"Using existing edge definition: {edges_collection_name}")

        # Insert nodes in batches
        batch_size = 1000
        nodes_list = list(graph.nodes(data=True))
        print(f"Inserting {len(nodes_list)} nodes in batches of {batch_size}...")
        for i in tqdm(range(0, len(nodes_list), batch_size), desc="Inserting nodes"):
            batch = nodes_list[i:i + batch_size]
            nodes_batch = [
                {"_key": str(node).replace('/', '_'), **data} for node, data in batch
            ]
            products.insert_many(nodes_batch, overwrite_mode='update')

        # Insert edges in batches
        edges_list = list(graph.edges(data=True))
        print(f"Inserting {len(edges_list)} edges in batches of {batch_size}...")
        for i in tqdm(range(0, len(edges_list), batch_size), desc="Inserting edges"):
            batch = edges_list[i:i + batch_size]
            edges_batch = [
                {
                    "_from": f"{nodes_collection_name}/{str(source).replace('/', '_')}",
                    "_to": f"{nodes_collection_name}/{str(target).replace('/', '_')}",
                    **data,  # Include any edge attributes
                }
                for source, target, data in batch
            ]
            copurchase.insert_many(edges_batch, overwrite_mode='update')

        print(f"✅ Successfully persisted Amazon graph '{graph_name}' to ArangoDB")
        return True

    except Exception as e:
        print(f"⚠️ Error persisting graph: {e}")
        return False

# --- Main Execution ---
# Define these variables before running
# amazon_graph = ...  # Your NetworkX graph
# dataset_name = "amazon0601"  # Your dataset name

# Get database connection
db = setup_arangodb()

if db is not None:
    # We have a database connection, now try to add the graph if it exists
    if 'amazon_graph' in globals() and amazon_graph is not None:
        success = persist_amazon_graph(db, amazon_graph, dataset_name)
        if success:
            print(f"✅ ArangoDB persistence complete for dataset: {dataset_name}")
        else:
            print(f"❌ ArangoDB persistence failed for dataset: {dataset_name}")
    else:
        print("❌ ArangoDB persistence skipped: No graph to persist")
else:
    print("❌ ArangoDB persistence skipped: No database connection.")


Checking ArangoDB connection...
Creating database: arnav
✅ Created database: arnav
✅ Connected to ArangoDB: arnav
Database contains 8 collections
Created new graph: amazon0601
✅ Created vertex collection: amazon0601_products
✅ Created edge definition: amazon0601_copurchase
Inserting 403394 nodes in batches of 1000...


Inserting nodes: 100%|██████████| 404/404 [00:10<00:00, 39.79it/s]


Inserting 3387388 edges in batches of 1000...


Inserting edges: 100%|██████████| 3388/3388 [05:00<00:00, 11.29it/s]

✅ Successfully persisted Amazon graph 'amazon0601' to ArangoDB
✅ ArangoDB persistence complete for dataset: amazon0601





In [20]:
from typing import Optional

print("\nSetting up LangChain for graph insights with Cohere...")
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
#from langchain.llms import Cohere  # Deprecated
from langchain_cohere import ChatCohere  # Use ChatCohere
import os

def setup_langchain_cohere(graph_analysis: Optional[dict], community_analysis: Optional[dict]) -> Optional[LLMChain]:
    """Sets up LangChain with the Cohere API, handling None inputs."""

    if not COHERE_API_KEY:
        print("⚠️ COHERE_API_KEY not set. Cannot initialize LLM.")
        return None

    if graph_analysis is None or community_analysis is None:
        print("⚠️ Graph or community analysis is None.  Cannot set up LLMChain.")
        return None

    try:
        print("✅ Using Cohere")
        query_template = PromptTemplate(
            template="""
Based on the network analysis:

Graph has {num_nodes} nodes and {num_edges} edges.
Average degree: {avg_degree:.2f}
Max degree: {max_degree}
Top Products by Degree: {top_nodes_by_degree}
Communities detected: {num_communities}
Community Sizes: {community_sizes}

Query: {query}

Answer:
""",
            input_variables=["query", "num_nodes", "num_edges", "avg_degree", "max_degree",
                             "num_communities", "community_sizes", "top_nodes_by_degree"]
        )
        # Use ChatCohere with the API key
        cohere_llm = ChatCohere(cohere_api_key=COHERE_API_KEY, model="command")
        return LLMChain(llm=cohere_llm, prompt=query_template)

    except Exception as e:
        print(f"⚠️ Error setting up LangChain: {e}")
        return None


def agentic_query(query: str, llm_chain: Optional[LLMChain], graph_analysis: Optional[dict], community_analysis: Optional[dict]) -> str:
    """Processes queries about the graph using the LLM chain, handling None inputs."""

    if llm_chain is None:
        return "LLM chain not available.  Check Cohere API key and setup."

    if graph_analysis is None or community_analysis is None:
        return "Graph or community analysis is missing. Cannot answer query."

    try:
         # Prepare the input dictionary, handling potential missing attributes, and avoiding
         # passing the complete sample_subgraph to the prompt
        input_data = {
            "query": query,
            "num_nodes": graph_analysis.get("num_nodes", 0),  # Use .get() with defaults
            "num_edges": graph_analysis.get("num_edges", 0),
            "avg_degree": graph_analysis.get("avg_degree", 0.0),
            "max_degree": graph_analysis.get("max_degree", 0),
            "top_nodes_by_degree": graph_analysis.get("top_nodes_by_degree", []),
            "num_communities": community_analysis.get("num_communities", 0) if community_analysis else 0, #check that community analysis is not None
            "community_sizes": community_analysis.get("community_sizes", []) if community_analysis else [],
        }


        result = llm_chain.invoke(input_data)  # Use .invoke()
        return result['text']

    except Exception as e:
        print(f"Error during LLM query: {e}")
        return f"An error occurred during the query: {e}"

# --- Example Usage (with checks for None) ---

if graph_analysis is not None and community_analysis is not None:
    try:
        llm_chain = setup_langchain_cohere(graph_analysis, community_analysis)
        if llm_chain:
            queries = [
                "What is the most influential product?",
                "How many communities are there?",
                "What is the structure of the network?",
                "Give me a product recommendation."
            ]
            for query in queries:
                response = agentic_query(query, llm_chain, graph_analysis, community_analysis)
                print(f"Query: {query}\nAnswer: {response}\n")
    except ValueError as e:
        print(f"Error: {e}")
elif graph_analysis is None:
    print("⚠️ Graph analysis failed, skipping agentic queries.")
elif community_analysis is None:
    print("⚠️ Community analysis failed, skipping agentic queries.")


Setting up LangChain for graph insights with Cohere...
✅ Using Cohere


  return LLMChain(llm=cohere_llm, prompt=query_template)


Query: What is the most influential product?
Answer: Of the products in the top ten list, '1041' has the highest average degree at 2761. This means it is the most popular and influential product, since it is connected to the most edges in the graph. It is likely to be an important product in the market, generating a high level of interest and conversation among customers. It is generating a lot of buzz, and should be the center of attention in any marketing plan. 

I can go into more detail on what this figure means for your business, or any other network analysis insights you would like explained in clearer terms. I hope this helps clarify what the maximum degree product means in this case! 

Let me know if you have any other immediate marketing insights you would like to discuss pertaining to this data, or if you would like any of the other findings elaborated on.

Query: How many communities are there?
Answer: There is only one community detected in the analysis. The community size 

In [19]:
!pip install langchain_cohere
import os
os.environ["COHERE_API_KEY"] = "WcXnR3lxNWGwnoJmI2hq8CnCmPfAr8fRbFFacCsT"
!pip install cohere

Collecting langchain_cohere
  Downloading langchain_cohere-0.4.2-py3-none-any.whl.metadata (6.6 kB)
Collecting cohere<6.0,>=5.12.0 (from langchain_cohere)
  Downloading cohere-5.13.12-py3-none-any.whl.metadata (3.4 kB)
Collecting types-pyyaml<7.0.0.0,>=6.0.12.20240917 (from langchain_cohere)
  Downloading types_PyYAML-6.0.12.20241230-py3-none-any.whl.metadata (1.8 kB)
Collecting fastavro<2.0.0,>=1.9.4 (from cohere<6.0,>=5.12.0->langchain_cohere)
  Downloading fastavro-1.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)
Collecting types-requests<3.0.0,>=2.0.0 (from cohere<6.0,>=5.12.0->langchain_cohere)
  Downloading types_requests-2.32.0.20241016-py3-none-any.whl.metadata (1.9 kB)
Downloading langchain_cohere-0.4.2-py3-none-any.whl (42 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading cohere-5.13.12-py3-none-any.whl (252 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [21]:
import time
import os
from langchain import PromptTemplate, LLMChain
from langchain.llms import Cohere

def setup_langchain_cohere(graph_analysis, community_analysis):
    """Sets up LangChain with the Cohere API."""
    try:
        if 'COHERE_API_KEY' in os.environ:
            print("✅ Using Cohere")
            query_template = PromptTemplate(
                template="""
            Based on the network analysis:

            Graph has {num_nodes} nodes and {num_edges} edges
            Average degree: {avg_degree:.2f}
            Max degree: {max_degree}
            Communities detected: {num_communities}
            Community Sizes: {community_sizes}
            Top Products by Degree: {top_nodes_by_degree}

            Query: {query}

            Answer:
            """,
                input_variables=["query", "num_nodes", "num_edges", "avg_degree", "max_degree", "num_communities", "community_sizes", "top_nodes_by_degree"]
            )
            cohere_llm = Cohere()
            return LLMChain(llm=cohere_llm, prompt=query_template)

        else:
            raise ValueError("COHERE_API_KEY environment variable not set.")

    except Exception as e:
        print(f"⚠️ Error setting up LangChain: {e}")
        return None

def agentic_query(query, llm_chain, graph_analysis, community_analysis):
    """Processes queries about the graph using the LLM chain."""
    if llm_chain is None:
        return "LLM chain not available. Please check setup."
    return llm_chain.run({"query": query, **graph_analysis, **community_analysis})

# Example graph and community analysis data (replace with your actual data)
graph_analysis = {
    "num_nodes": 100,
    "num_edges": 200,
    "avg_degree": 2.5,
    "max_degree": 10,
    "top_nodes_by_degree": [(1, 10), (2, 9)],
}
community_analysis = {
    "num_communities": 5,
    "community_sizes": [20, 15, 12, 34, 19]
}

try:
  #setup langchain with Cohere
  llm_chain = setup_langchain_cohere(graph_analysis, community_analysis)

  print("\nTesting agentic queries on Amazon graph...")
  examples = [
      "What are the most influential products in the Amazon network?",
      "What insights can we gain from the community structure?",
      "How can this graph be used for product recommendations?",
      "What does the network structure tell us about Amazon's marketplace?"
  ]

  for example in examples:
      print(f"\nQuery: {example}")
      time.sleep(1)
      result = agentic_query(example, llm_chain, graph_analysis, community_analysis) #fixed by providing the other three variables
      print(f"Result: {result}")

  print("\n✅ Amazon SNAP Graph Analysis complete!")

except ValueError as e:
    print(f"Error: {e}")

✅ Using Cohere

Testing agentic queries on Amazon graph...

Query: What are the most influential products in the Amazon network?


  cohere_llm = Cohere()
  return llm_chain.run({"query": query, **graph_analysis, **community_analysis})


Result: The products with the highest degrees are Product 1 with a degree of 10 and Product 2 with a degree of 9. These products likely have significant influence over the network due to their high connectivity and popularity among customers. This could be because of the ratings the products have, their affordability, their uniqueness, or even just a result of the wide range of options available in this specific category on Amazon. 

It is important to note that the concept of influence is slightly subjective, and ultimate profitability can only be measured by the effectiveness of converting viewing customers into purchasing ones. 

Query: What insights can we gain from the community structure?
Result: The community structure of a graph can provide several insights into its behavior and properties. 

1. **Identifying Groups** - Community structure can identify distinct groups or clusters within the graph. In the given network analysis, there are five detected communities, suggesting th

In [22]:
import networkx as nx
import os
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_cohere import ChatCohere

In [24]:
import sys
import subprocess
import importlib
import networkx as nx
from arango import ArangoClient
from langchain_openai import OpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
import pandas as pd
import json
import gdown
import requests
from tqdm import tqdm
import gzip
import time
import streamlit as st
import matplotlib.pyplot as plt  # Import matplotlib for visualizations

# ----------------------------------------------------------------------------
# 1. Dependency Installation (Using Jupyter-Friendly Method)
# ----------------------------------------------------------------------------

def install_and_import(package):
    installed = True #boolean to store whether the required packages are installed

    try:
        importlib.import_module(package)
        print(f"✅ {package} is already installed")
    except ImportError:
        try:
            print(f"Installing {package}...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
            print(f"✅ Successfully installed {package}")
            importlib.import_module(package) # Check if it can be imported after installation
        except Exception as e:
            print(f"⚠️ Error installing {package}: {e}")
            print("Skipping this package.")
            installed = False

    return installed

#required packages
required_packages = ["python-arango", "networkx", "pandas", "gdown", "requests", "tqdm", "streamlit", "matplotlib"]
langchain_packages = ["langchain", "langchain-openai", "langchain-community"]
all_packages = required_packages + langchain_packages
installed_all_packages = True
missing_packages = []

for package in all_packages:
    try:
        importlib.import_module(package)
        print(f"✅ {package} is already installed")
    except ImportError:
        missing_packages.append(package)
        installed_all_packages = False # We don't know the results yet.

if not installed_all_packages:
    print("Installing all missing packages in one go...")
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install"] + missing_packages)
        print("✅ Successfully installed all missing packages")
        for package in missing_packages:
            try:
                 importlib.import_module(package)  # Double check if it can be imported now.
            except:
                 print(f"Failed to import {package}, there might be a dependency error.")
                 installed_all_packages = False

    except Exception as e:
        print(f"⚠️ Error installing packages: {e}")
        print("Skipping graph analysis and web interface setup.")
        installed_all_packages = False # Something failed to install so we return false.


✅ networkx is already installed
✅ pandas is already installed
✅ gdown is already installed
✅ requests is already installed
✅ tqdm is already installed
✅ streamlit is already installed
✅ matplotlib is already installed
✅ langchain is already installed
Installing all missing packages in one go...
✅ Successfully installed all missing packages
Failed to import python-arango, there might be a dependency error.
Failed to import langchain-openai, there might be a dependency error.
Failed to import langchain-community, there might be a dependency error.


In [25]:
def main():
    import streamlit as st
    import networkx as nx
    import matplotlib.pyplot as plt
    import pandas as pd
    import time
    import os
    from langchain import PromptTemplate, LLMChain
    from langchain.llms import Cohere

    try: # To catch missing variables if any setup issues occur
      # Load your datasets and process the graph as previously done
      amazon_graph = nx.DiGraph([(i, (i+1) % 100) for i in range(100)])
    except Exception as e:
      print(f"Error creating test graph: {e}")
      exit() # Exit due to essential setup failure

    from itertools import islice # for graph sampling

    # Functions (Copied from previous responses, please note these are just examples)
    def analyze_graph(G):
        """Analyzes a graph and returns various metrics."""
        analysis = {}
        analysis["num_nodes"] = G.number_of_nodes()
        analysis["num_edges"] = G.number_of_edges()
        degrees = [d for n, d in G.degree()]
        analysis["avg_degree"] = sum(degrees) / len(degrees) if degrees else 0
        analysis["max_degree"] = max(degrees) if degrees else 0
        degree_dict = dict(G.degree())
        top_nodes = sorted(degree_dict.items(), key=lambda x: x[1], reverse=True)[:10]
        analysis["top_nodes_by_degree"] = top_nodes

        # Find largest weakly connected component
        connected_components = list(nx.weakly_connected_components(G))
        if connected_components:
            largest_cc = max(connected_components, key=len)
            analysis["largest_cc_size"] = len(largest_cc)
            analysis["largest_cc_percentage"] = len(largest_cc) / G.number_of_nodes() * 100
        else:
            analysis["largest_cc_size"] = 0
            analysis["largest_cc_percentage"] = 0

        # Sample a small subgraph for visualization and detailed analysis
        if top_nodes:
            seed_node = top_nodes[0][0]
            sample_nodes = set([seed_node])
            frontier = set([seed_node])
            while len(sample_nodes) < 100 and frontier:
                new_frontier = set()
                for node in frontier:
                    neighbors = set(G.neighbors(node))
                    new_nodes = neighbors - sample_nodes
                    sample_nodes.update(list(new_nodes)[:5])
                    new_frontier.update(list(new_nodes)[:5])
                    if len(sample_nodes) >= 100:
                        break
                frontier = new_frontier
            sample_subgraph = G.subgraph(sample_nodes)
            analysis["sample_subgraph_nodes"] = list(sample_nodes)  # Store nodes instead of subgraph
            analysis["sample_subgraph_size"] = sample_subgraph.number_of_nodes()
        else:
            analysis["sample_subgraph_nodes"] = []
            analysis["sample_subgraph_size"] = 0

        return analysis

    def detect_communities(G, graph_analysis=None, max_nodes=5000):
        """Detects communities within the graph using Louvain or connected components."""
        # Handle large graphs by sampling
        try: #this could give an error, if no data is there for example
            if G.number_of_nodes() > max_nodes:
                print(f"Graph is large ({G.number_of_nodes()} nodes), sampling {max_nodes} nodes for community detection...")
                if graph_analysis and "sample_subgraph_nodes" in graph_analysis:
                    subgraph = G.subgraph(graph_analysis["sample_subgraph_nodes"])
                else:
                    # Sample nodes if no sample subgraph is available
                    subgraph = G.subgraph(list(G.nodes())[:max_nodes])
            else:
                subgraph = G

            # Convert to undirected for community detection
            undirected_G = subgraph.to_undirected()

            try:
                # Try using Louvain algorithm
                import community as community_louvain
                partition = community_louvain.best_partition(undirected_G)
                communities = {}
                for node, community_id in partition.items():
                    if community_id not in communities:
                        communities[community_id] = []
                    communities[community_id].append(node)
                sorted_communities = sorted(communities.items(), key=lambda x: len(x[1]), reverse=True)
                return {
                    "algorithm": "louvain",
                    "num_communities": len(communities),
                    "community_sizes": [len(comm) for _, comm in sorted_communities[:10]],
                    "top_communities": sorted_communities[:5],
                    "node_communities": partition,
                }
            except ImportError:
                print("Louvain algorithm not available, using connected components instead...")
                # Fallback to connected components
                components = list(nx.connected_components(undirected_G))
                sorted_components = sorted(components, key=len, reverse=True)
                return {
                    "algorithm": "connected_components",
                    "num_communities": len(components),
                    "community_sizes": [len(comp) for comp in sorted_components[:10]],
                    "top_communities": [(i, list(comp)) for i, comp in enumerate(sorted_components[:5])],
                    "node_communities": {node: i for i, comp in enumerate(components) for node in comp}
                }
        except:
            print("Can't do community setup, maybe there is no information in this set?")
            return{}

    def setup_langchain_cohere(graph_analysis, community_analysis):
        """Sets up LangChain with the Cohere API."""
        try:
            # Check if Cohere API key is set
            cohere_api_key = os.environ.get('COHERE_API_KEY')
            if cohere_api_key:
                print("✅ Using Cohere")
                query_template = PromptTemplate(
                    template="""
                Based on the network analysis:
                Graph has {num_nodes} nodes and {num_edges} edges
                Average degree: {avg_degree:.2f}
                Max degree: {max_degree}
                Communities detected: {num_communities}
                Community Sizes: {community_sizes}
                Top Products by Degree: {top_nodes_by_degree}
                Query: {query}
                Answer:
                """,
                    input_variables=["query", "num_nodes", "num_edges", "avg_degree", "max_degree",
                                    "num_communities", "community_sizes", "top_nodes_by_degree"]
                )
                cohere_llm = Cohere(cohere_api_key=cohere_api_key)
                return LLMChain(llm=cohere_llm, prompt=query_template)
            else:
                # Raise error if API key is not set
                raise ValueError("COHERE_API_KEY environment variable not set.")
        except Exception as e:
            print(f"⚠️ Error setting up LangChain: {e}")
            return None

    def agentic_query(query, llm_chain, graph_analysis, community_analysis):
        """Processes queries about the graph using the LLM chain."""
        if llm_chain is None:
            return "LLM chain not available. Please check setup."

        # Create a new dictionary with all parameters
        params = {
            "query": query
        }
        # Add graph analysis parameters
        for key, value in graph_analysis.items():
            if isinstance(value, (str, int, float, list, tuple, dict)) and key != "sample_subgraph_nodes":
                params[key] = value

        # Add community analysis parameters
        for key, value in community_analysis.items():
            if isinstance(value, (str, int, float, list, tuple, dict)) and key != "node_communities":
                params[key] = value

        # Run the chain with the prepared parameters
        return llm_chain.run(**params)

    graph_analysis = analyze_graph(amazon_graph)
    community_analysis = detect_communities(amazon_graph, graph_analysis)
    print(graph_analysis)
    print(community_analysis)
    # ----------------------------------------------------------------------------
    # 4. Streamlit Application
    # ----------------------------------------------------------------------------

    st.title("Amazon Product Network Analysis")

    st.sidebar.header("Graph Statistics")
    st.sidebar.write(f"Total Products (Nodes): {graph_analysis.get('num_nodes'):,}")
    st.sidebar.write(f"Total Co-Purchase Links (Edges): {graph_analysis.get('num_edges'):,}")
    st.sidebar.write(f"Average Connections per Product: {graph_analysis.get('avg_degree'):.2f}")
    st.sidebar.write(f"Maximum Connections for a Product: {graph_analysis.get('max_degree')}")
    st.sidebar.write(f"Largest Connected Component: {graph_analysis.get('largest_cc_percentage'):.2f}%")

    st.sidebar.header("Community Statistics")
    st.sidebar.write(f"Number of Communities: {community_analysis.get('num_communities', 'N/A')}")
    st.sidebar.write("Top 5 Community Sizes:")
    if "community_sizes" in community_analysis:
        for i, size in enumerate(community_analysis['community_sizes'][:5]):
            st.sidebar.write(f"{i+1}: " + str(size))

    # Visualization - using matplotlib for simplicity
    st.header("Graph Visualization")
    st.write("Displaying a sample subgraph for visualization")
    if 'sample_subgraph' in graph_analysis and graph_analysis['sample_subgraph']:
        fig, ax = plt.subplots()
        nx.draw(graph_analysis["sample_subgraph"], with_labels=True, ax=ax)
        st.pyplot(fig)  # st.pyplot for matplotlib plots. If using plotly or other library you will use different command.
    else:
        st.write("No sample subgraph available.")

            # LLM-powered Insights Section
    st.header("LLM-Powered Insights")

    llm_chain = setup_langchain_cohere(graph_analysis, community_analysis)

    if llm_chain:
        query = st.text_input("Enter your query about the Amazon network:")
        if query:
            result = agentic_query(query, llm_chain, graph_analysis, community_analysis)
            st.write("LLM Answer:", result)
    else:
        st.error("Failed to set up the LLM Chain. Check your API key and settings.")
if __name__ == "__main__":
    main()

2025-02-26 06:25:04.208 
  command:

    streamlit run /usr/local/lib/python3.11/dist-packages/colab_kernel_launcher.py [ARGUMENTS]


Louvain algorithm not available, using connected components instead...
{'num_nodes': 100, 'num_edges': 100, 'avg_degree': 2.0, 'max_degree': 2, 'top_nodes_by_degree': [(0, 2), (1, 2), (2, 2), (3, 2), (4, 2), (5, 2), (6, 2), (7, 2), (8, 2), (9, 2)], 'largest_cc_size': 100, 'largest_cc_percentage': 100.0, 'sample_subgraph_nodes': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], 'sample_subgraph_size': 100}
{'algorithm': 'connected_components', 'num_communities': 1, 'community_sizes': [100], 'top_communities': [(0, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 

2025-02-26 06:25:04.330 Session state does not function when running a script without `streamlit run`


In [27]:
!pip install gradio

Collecting gradio
  Downloading gradio-5.18.0-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.8-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.7.2 (from gradio)
  Downloading gradio_client-1.7.2-py3-none-any.whl.metadata (7.1 kB)
Collecting markupsafe~=2.0 (from gradio)
  Downloading MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.9.3 (from gradio)
  Downloading ruff-0.9.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.meta

In [48]:
!pip install pyArango

Collecting pyArango
  Downloading pyArango-2.1.1.tar.gz (51 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/51.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting future (from pyArango)
  Downloading future-1.0.0-py3-none-any.whl.metadata (4.0 kB)
Collecting datetime (from pyArango)
  Downloading DateTime-5.5-py3-none-any.whl.metadata (33 kB)
Collecting zope.interface (from datetime->pyArango)
  Downloading zope.interface-7.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Downloading DateTime-5.5-py3-none-any.whl (52 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.6/52.6 kB[0m 

In [51]:
import gradio as gr
import networkx as nx
import matplotlib.pyplot as plt
import io
import base64
import os
import requests
import gzip
from tqdm import tqdm
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict

@dataclass
class GraphAnalysis:
    num_nodes: int
    num_edges: int
    avg_degree: float
    max_degree: int
    top_nodes_by_degree: List[Tuple[str, int]]
    largest_cc_size: int
    largest_cc_percentage: float
    sample_subgraph: Optional[nx.DiGraph]
    sample_subgraph_size: int
    image: str

def visualize_graph(graph: Optional[nx.DiGraph]) -> str:
    """Visualizes the graph (or a sample) and returns a base64 encoded image."""
    if graph is None:
        return ""

    # Always sample a subgraph for visualization to keep it simple and fast
    if graph.number_of_nodes() > 100:
        top_nodes = sorted(graph.degree(), key=lambda x: x[1], reverse=True)[:10]  #Top 10
        seed_node = top_nodes[0][0]
        sample_nodes = {seed_node}
        frontier = {seed_node}
        while len(sample_nodes) < 100 and frontier:
            new_frontier = set()
            for node in frontier:
                neighbors = set(graph.neighbors(node))
                new_nodes = (neighbors - sample_nodes)
                selected_nodes = list(new_nodes)[:5]  # Limit to 5 neighbors
                sample_nodes.update(selected_nodes)
                new_frontier.update(selected_nodes)
                if len(sample_nodes) >= 100:
                    break
            frontier = new_frontier
        graph = graph.subgraph(list(sample_nodes))


    plt.figure(figsize=(12, 6))  # Adjust figure size as needed
    nx.draw(graph, with_labels=True, font_weight='bold', node_size=400, font_size=9, alpha=0.7) #Keep small
    plt.title("Generated Graph (Sample)")  # Clarify it's a sample
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    return base64.b64encode(buf.getvalue()).decode('utf-8')

def analyze_graph(graph: Optional[nx.DiGraph]) -> Optional[GraphAnalysis]:
    """Analyzes the graph and returns metrics, including a sampled image."""
    if graph is None:
        return None

    analysis = {
        "num_nodes": graph.number_of_nodes(),
        "num_edges": graph.number_of_edges(),
    }
    degrees = [d for _, d in graph.degree()]
    analysis["avg_degree"] = sum(degrees) / len(degrees) if degrees else 0.0
    analysis["max_degree"] = max(degrees) if degrees else 0
    top_nodes = sorted(graph.degree(), key=lambda x: x[1], reverse=True)[:10] #Top 10
    analysis["top_nodes_by_degree"] = top_nodes

    connected_components = list(nx.weakly_connected_components(graph))
    if connected_components:
        largest_cc = max(connected_components, key=len)
        analysis["largest_cc_size"] = len(largest_cc)
        analysis["largest_cc_percentage"] = (len(largest_cc) / graph.number_of_nodes()) * 100
    else:
        analysis["largest_cc_size"] = 0
        analysis["largest_cc_percentage"] = 0.0

    analysis["sample_subgraph"] = None  # We'll *always* create the sample now
    analysis["sample_subgraph_size"] = 0
    analysis["image"] = visualize_graph(graph)  # Get sampled image

    return GraphAnalysis(**analysis)


def download_file(url: str, filename: str) -> Optional[str]:
    """Downloads a file with a progress bar."""
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        with open(filename, 'wb') as file, tqdm(
            desc=filename, total=total_size, unit='iB', unit_scale=True, unit_divisor=1024
        ) as bar:
            for data in response.iter_content(1024):
                file.write(data)
                bar.update(len(data))
        print(f"✅ Downloaded {filename}")
        return filename
    except requests.exceptions.RequestException as e:
        print(f"⚠️ Error downloading {filename}: {e}")
        return None



def parse_amazon_copurchase(gz_file: str) -> Optional[pd.DataFrame]:
    """Parse Amazon co-purchasing network data."""
    print(f"Parsing co-purchase network from {gz_file}...")
    edges = []
    try:
        with gzip.open(gz_file, 'rt', encoding='latin1') as f:
            for line in tqdm(f, desc="Reading edges"):
                if not line.startswith('#'):
                    source, target = line.strip().split()
                    edges.append((source, target))
        print(f"✅ Parsed {len(edges)} co-purchase edges")
        df = pd.DataFrame(edges, columns=['source', 'target'])
        csv_file = gz_file.replace('.gz', '.csv')
        df.to_csv(csv_file, index=False)
        print(f"✅ Saved to {csv_file}")
        return df
    except Exception as e:
        print(f"⚠️ Error parsing co-purchase data: {e}")
        return None

def load_graph(copurchase_df: Optional[pd.DataFrame]) -> Optional[nx.DiGraph]:
    """Loads the graph from DataFrames."""
    if copurchase_df is None:
        print("⚠️ Copurchase DataFrame is None.")
        return None
    try:
        graph = nx.DiGraph()
        with tqdm(total=len(copurchase_df), desc="Adding edges") as pbar:
            for _, row in copurchase_df.iterrows():
                graph.add_edge(str(row['source']), str(row['target']))
                pbar.update(1)
        print(f"✅ Created graph: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges")
        return graph
    except Exception as e:
        print(f"⚠️ Error loading graph: {e}")
        return None


def process_data(max_nodes_to_display: int = 1000) -> Dict[str, str]:
    """Downloads, parses, and analyzes the graph, returning results for Gradio."""

    amazon_datasets = {
        "copurchase": "http://snap.stanford.edu/data/amazon0601.txt.gz",
    }
    data_dir = "amazon_data"
    os.makedirs(data_dir, exist_ok=True)


    copurchase_file = os.path.join(data_dir, "amazon0601.txt.gz")
    copurchase_csv_file = copurchase_file.replace('.gz', '.csv')
    if os.path.exists(copurchase_csv_file):
        print("Using existing copurchase CSV.")
        copurchase_df = pd.read_csv(copurchase_csv_file)
    else:
        if not os.path.exists(copurchase_file):
            download_file(amazon_datasets["copurchase"], copurchase_file)
        copurchase_df = parse_amazon_copurchase(copurchase_file)
        if copurchase_df is None:
            return {
                "graph_summary": "Error: Could not load co-purchase data.",
                "graph_visualization": "",
                "status": "Data loading error."
            }

    amazon_graph = load_graph(copurchase_df)
    if amazon_graph is None:
        return {
            "graph_summary": "Error: Could not create graph.",
            "graph_visualization": "",
            "status": "Graph creation error."
        }

    graph_analysis = analyze_graph(amazon_graph)
    if graph_analysis is None:
        return {
            "graph_summary": "Error: Graph analysis failed.",
            "graph_visualization": "",
            "status": "Graph analysis error."
        }

    # Create a concise summary for the text output
    summary = (
        f"The graph has {graph_analysis.num_nodes} nodes and {graph_analysis.num_edges} edges.\n"
        f"Average degree: {graph_analysis.avg_degree:.2f}, Max degree: {graph_analysis.max_degree}.\n"
        f"Largest connected component size: {graph_analysis.largest_cc_size} "
        f"({graph_analysis.largest_cc_percentage:.2f}% of nodes).\n"
        f"Top nodes by degree: {graph_analysis.top_nodes_by_degree[:5]}" # Top 5

    )

    # Limit the graph displayed
    if graph_analysis.num_nodes > max_nodes_to_display:
      summary += f"\n\nDisplaying a sample of up to {max_nodes_to_display} nodes."

    return {
        "graph_summary": summary,
        "graph_visualization": graph_analysis.image,  # Always a sampled/limited image
        "status": "Graph analysis complete!",
    }

# --- Gradio Interface Setup ---
inputs = [
    gr.Slider(minimum=100, maximum=10000, value=1000, step=100, label="Max Nodes to Display", key="max_nodes_to_display")
]
outputs = [
    gr.Textbox(label="Graph Summary", key="graph_summary"),
    gr.HTML(label="Graph Visualization", key="graph_visualization"),
    gr.Textbox(label="Status", key="status"),
]

iface = gr.Interface(
    fn=process_data,
    inputs=inputs,
    outputs=outputs,
    title="Amazon Graph Analysis (Simplified)",
    description="Analyzes the Amazon product co-purchasing network and displays a simplified graph.",
    allow_flagging="never",  # Prevent flagging, since we don't have user input
)

if __name__ == "__main__":
    iface.launch(debug=False)



Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://7ef8c0f1f722fadf75.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


In [35]:
from google.colab import drive
drive.mount('/content/drive')

Adding edges:  51%|█████     | 1713262/3387388 [00:54<00:47, 35002.28it/s]

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [36]:
import os

data_dir = "/content/drive/MyDrive/amazon_data"  # If using Drive
# data_dir = "/content/amazon_data"  # If NOT using Drive (ephemeral)
os.makedirs(data_dir, exist_ok=True)

Adding edges:  55%|█████▌    | 1867907/3387388 [00:58<00:43, 34567.47it/s]

In [44]:
from google.colab import userdata

COHERE_API_KEY = 'WcXnR3lxNWGwnoJmI2hq8CnCmPfAr8fRbFFacCsT'
ARANGO_HOST = 'https://a40b6d186a3a.arangodb.cloud:8529'
ARANGO_USERNAME = 'root'
ARANGO_PASSWORD = '2eM5Wd4NRTrcnHQt3yfM'
ARANGO_DB = 'arnav'

In [38]:
!pip install gradio langchain_cohere python-arango

Adding edges:  63%|██████▎   | 2126661/3387388 [01:06<00:36, 34128.10it/s]



Adding edges:  63%|██████▎   | 2133566/3387388 [01:06<00:36, 33976.90it/s]



Adding edges:  64%|██████▍   | 2161399/3387388 [01:07<00:35, 34142.34it/s]

In [50]:
import gradio as gr
import networkx as nx
import matplotlib.pyplot as plt
import io
import base64
import os
import requests
import gzip
from tqdm import tqdm
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict
import pandas as pd
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_cohere import ChatCohere
import subprocess
import sys
import importlib
from arango import ArangoClient
import time
from arango.exceptions import ServerConnectionError, DatabaseCreateError
import json

# --- Helper Function for Installation ---
def install_package(package_name):
    """Installs a package using pip, handling errors and using --break-system-packages."""
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package_name, "--break-system-packages"])
        print(f"✅ Successfully installed {package_name}")
        return True
    except subprocess.CalledProcessError:
        print(f"❌ Failed to install {package_name}.  You may need to install it manually.")
        print("   Consider using a virtual environment or 'apt install python3-<package-name>' if available.")
        return False
    except Exception as e:
        print(f"❌ Unexpected error installing {package_name}: {e}")
        return False

# --- Package Installation ---
required_packages = ["gradio", "networkx", "matplotlib", "requests", "tqdm", "pandas",
                     "langchain", "langchain-cohere", "python-dotenv", "python-arango",
                     "community"]  # Corrected name
for package in required_packages:
    try:
        importlib.import_module(package)
        print(f"✅ {package} is already installed")
    except ImportError:
        print(f"Installing {package}...")
        if not install_package(package):
            print(f"Skipping {package} due to installation failure.")

# --- Load API Key from .env ---
from dotenv import load_dotenv
load_dotenv()  # OR use Colab secrets (recommended)



if not COHERE_API_KEY:
    print("⚠️ WARNING: COHERE_API_KEY not found in environment variables.  LLM features will be disabled.")

# --- Data Classes ---
@dataclass
class GraphAnalysis:
    num_nodes: int
    num_edges: int
    avg_degree: float
    max_degree: int
    top_nodes_by_degree: List[Tuple[str, int]]
    largest_cc_size: int
    largest_cc_percentage: float
    sample_subgraph: Optional[nx.DiGraph] = None  # Default to None
    sample_subgraph_size: int = 0
    image: str = ""

@dataclass
class CommunityAnalysis:
  algorithm: str
  num_communities: int
  community_sizes: List[int]
  top_communities: List[Tuple[int, List[str]]]
  node_communities: Dict[str, int]

# --- ArangoDB Setup and Loading Functions ---

def setup_arangodb(retries=5, delay=5):
    """Sets up the ArangoDB connection and creates the database/collections if needed."""
    for attempt in range(retries):
        try:
            client = ArangoClient(hosts=ARANGO_HOST)
            db = None

            # Try connecting with user credentials first
            try:
                db = client.db(ARANGO_DB, username=ARANGO_USERNAME, password=ARANGO_PASSWORD)
                print(f"✅ Connected to ArangoDB: {ARANGO_DB}")
            except Exception:
                # Fallback to _system as root
                print(f"⚠️ Initial connection to '{ARANGO_DB}' failed.  Trying _system...")
                sys_db = client.db('_system', username=ARANGO_USERNAME, password=ARANGO_PASSWORD)
                if not sys_db.has_database(ARANGO_DB):
                    sys_db.create_database(ARANGO_DB)
                    print(f"✅ Created database: {ARANGO_DB}")
                db = client.db(ARANGO_DB, username=ARANGO_USERNAME, password=ARANGO_PASSWORD) #connect

            # Check for graph and create if it does not exist.
            if db is not None:
                if not db.has_graph(ARANGO_GRAPH_NAME):
                    db.create_graph(ARANGO_GRAPH_NAME)
                    print(f"✅ Created graph: {ARANGO_GRAPH_NAME}")
                return db

            return db

        except ServerConnectionError as e:
            print(f"Attempt {attempt + 1}/{retries} failed: {e}")
            if attempt < retries - 1:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
        except Exception as e:
            print(f"❌ Error connecting to/setting up ArangoDB: {e}")
            return None

    print(f"❌ Error connecting to ArangoDB after multiple retries.")
    return None

def persist_graph_to_arangodb(db, graph: nx.DiGraph, graph_name: str):
    """Persists the NetworkX graph to ArangoDB with a given graph name."""

    if db is None or graph is None:
        print("⚠️ Cannot persist graph: Database connection or graph is None.")
        return False

    try:
        if db.has_graph(graph_name):
            arango_graph = db.graph(graph_name)
        else:
            arango_graph = db.create_graph(graph_name)
            print(f"Using/Created ArangoDB graph: {graph_name}")

        nodes_collection_name = f"{graph_name}_products"
        edges_collection_name = f"{graph_name}_copurchase"

        # Vertex Collection
        if not arango_graph.has_vertex_collection(nodes_collection_name):
            products = arango_graph.create_vertex_collection(nodes_collection_name)
            print(f"✅ Created vertex collection: {nodes_collection_name}")
        else:
            products = arango_graph.vertex_collection(nodes_collection_name)

        # Edge Definition
        if not arango_graph.has_edge_definition(edges_collection_name):
            copurchase = arango_graph.create_edge_definition(
                edge_collection=edges_collection_name,
                from_vertex_collections=[nodes_collection_name],
                to_vertex_collections=[nodes_collection_name]
            )
            print(f"✅ Created edge definition: {edges_collection_name}")
        else:
            copurchase = arango_graph.edge_collection(edges_collection_name)

        # Insert nodes in batches
        batch_size = 1000
        nodes_list = list(graph.nodes(data=True))
        print(f"Inserting {len(nodes_list)} nodes in batches of {batch_size}...")
        for i in tqdm(range(0, len(nodes_list), batch_size), desc="Inserting nodes"):
            batch = nodes_list[i:i + batch_size]
            nodes_batch = [
                {"_key": str(node), **data} for node, data in batch
            ]
            products.insert_many(nodes_batch, overwrite_mode='update')

        # Insert edges in batches
        edges_list = list(graph.edges(data=True))
        print(f"Inserting {len(edges_list)} edges in batches of {batch_size}...")
        for i in tqdm(range(0, len(edges_list), batch_size), desc="Inserting edges"):
            batch = edges_list[i:i + batch_size]
            edges_batch = [
                {
                    "_from": f"{nodes_collection_name}/{str(source).replace('/', '_')}",
                    "_to": f"{nodes_collection_name}/{str(target).replace('/', '_')}",
                    **data,  # Include any edge attributes
                }
                for source, target, data in batch
            ]
            copurchase.insert_many(edges_batch, overwrite_mode='update')

        print(f"✅ Graph '{graph_name}' persisted to ArangoDB.")
        return True

    except Exception as e:
        print(f"❌ Error persisting graph '{graph_name}' to ArangoDB: {e}")
        return False

def load_graph_from_arangodb(db, graph_name: str) -> Optional[nx.DiGraph]:
    """Loads a graph from ArangoDB into NetworkX."""
    if db is None:
        print("⚠️ Cannot load graph: Database connection is None.")
        return None

    try:
        graph = nx.DiGraph(db.graph(graph_name))
        print(f"✅ Graph '{graph_name}' loaded from ArangoDB: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges")
        return graph
    except Exception as e:
        print(f"❌ Error loading graph '{graph_name}' from ArangoDB: {e}")
        return None

# --- Data Loading and Processing Functions ---
def download_file(url, filename):
    """Downloads a file with a progress bar (no changes needed)."""
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        with open(filename, 'wb') as file, tqdm(
            desc=filename, total=total_size, unit='iB', unit_scale=True, unit_divisor=1024
        ) as bar:
            for data in response.iter_content(1024):
                file.write(data)
                bar.update(len(data))
        print(f"✅ Downloaded {filename}")
        return filename
    except requests.exceptions.RequestException as e:
        print(f"⚠️ Error downloading {filename}: {e}")
        return None
def parse_amazon_metadata(gz_file):
    """Parse Amazon metadata (using json.loads per line)."""
    print(f"Parsing metadata from {gz_file}...")
    products = []
    try:
        with gzip.open(gz_file, 'rt', encoding='utf-8') as f:
            for i, line in enumerate(tqdm(f, desc="Reading lines")):
                line = line.strip()
                if line:
                    try:
                        product = json.loads(line)
                        products.append(product)
                    except json.JSONDecodeError as e:
                        print(f"⚠️ JSONDecodeError: {e} on line {i + 1}: {line}")
                        continue

        print(f"✅ Parsed {len(products)} products")
        df = pd.DataFrame(products)

        if 'ASIN' in df.columns:
             df['ASIN'] = df['ASIN'].astype(str)
        else:
            print("⚠️ Warning: 'ASIN' column not found in metadata.")
            df['ASIN'] = '' #or other default

        csv_file = gz_file.replace('.gz', '.csv')
        df.to_csv(csv_file, index=False)
        print(f"✅ Saved to {csv_file}")
        return df

    except Exception as e:
        print(f"⚠️ Error parsing metadata file: {e}")
        return None
def parse_amazon_copurchase(gz_file, max_edges=None):
    """Parse co-purchase data (with optional edge limit)."""
    print(f"Parsing co-purchase network from {gz_file}...")
    edges = []
    try:
        with gzip.open(gz_file, 'rt', encoding='latin1') as f:
            for i, line in enumerate(tqdm(f, desc="Reading edges")):
                if not line.startswith('#'):
                    source, target = line.strip().split()
                    edges.append((source, target))
                    if max_edges is not None and i >= max_edges:
                        break

        print(f"✅ Parsed {len(edges)} co-purchase edges")
        df = pd.DataFrame(edges, columns=['source', 'target'])
        csv_file = gz_file.replace('.gz', '.csv')
        df.to_csv(csv_file, index=False)
        print(f"✅ Saved to {csv_file}")
        return df

    except Exception as e:
        print(f"⚠️ Error parsing co-purchase network: {e}")
        return None

def load_graph(copurchase_df: Optional[pd.DataFrame]) -> Optional[nx.DiGraph]:
    """Loads the graph from DataFrames."""
    if copurchase_df is None:
        print("⚠️ Copurchase DataFrame is None.")
        return None
    try:
        graph = nx.DiGraph()
        with tqdm(total=len(copurchase_df), desc="Adding edges") as pbar:
            for _, row in copurchase_df.iterrows():
                graph.add_edge(str(row['source']), str(row['target']))
                pbar.update(1)
        print(f"✅ Created graph: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges")
        return graph
    except Exception as e:
        print(f"⚠️ Error loading graph: {e}")
        return None

# --- Graph Analysis and Community Detection ---
def visualize_graph(graph: Optional[nx.DiGraph]) -> str:
    """Visualizes the graph (or a sample) and returns a base64 encoded image."""
    if graph is None:
        return ""

    # Always sample a subgraph for visualization to keep it simple and fast
    if graph.number_of_nodes() > 100:
        top_nodes = sorted(graph.degree(), key=lambda x: x[1], reverse=True)[:10]  #Top 10
        if not top_nodes: # Handle empty graph
            return ""
        seed_node = top_nodes[0][0]
        sample_nodes = {seed_node}
        frontier = {seed_node}
        while len(sample_nodes) < 100 and frontier:
            new_frontier = set()
            for node in frontier:
                neighbors = set(graph.neighbors(node))
                new_nodes = (neighbors - sample_nodes)
                selected_nodes = list(new_nodes)[:5]  # Limit to 5 neighbors
                sample_nodes.update(selected_nodes)
                new_frontier.update(selected_nodes)
                if len(sample_nodes) >= 100:
                    break
            frontier = new_frontier
        graph = graph.subgraph(list(sample_nodes))


    plt.figure(figsize=(12, 6))  # Adjust figure size as needed
    nx.draw(graph, with_labels=True, font_weight='bold', node_size=400, font_size=9, alpha=0.7) #Keep small
    plt.title("Generated Graph (Sample)")  # Clarify it's a sample
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    return base64.b64encode(buf.getvalue()).decode('utf-8')

def analyze_graph(graph: Optional[nx.DiGraph]) -> Optional[GraphAnalysis]:
    """Analyzes the graph and returns metrics, including a sampled image."""
    if graph is None:
        return None

    try:
        analysis = {
            "num_nodes": graph.number_of_nodes(),
            "num_edges": graph.number_of_edges(),
        }
        degrees = [d for _, d in graph.degree()]
        analysis["avg_degree"] = sum(degrees) / len(degrees) if degrees else 0.0
        analysis["max_degree"] = max(degrees) if degrees else 0
        top_nodes = sorted(graph.degree(), key=lambda x: x[1], reverse=True)[:10] #Top 10
        analysis["top_nodes_by_degree"] = top_nodes

        connected_components = list(nx.weakly_connected_components(graph))
        if connected_components:
            largest_cc = max(connected_components, key=len)
            analysis["largest_cc_size"] = len(largest_cc)
            analysis["largest_cc_percentage"] = (len(largest_cc) / graph.number_of_nodes()) * 100
        else:
            analysis["largest_cc_size"] = 0
            analysis["largest_cc_percentage"] = 0.0

        analysis["sample_subgraph"] = None  # We'll *always* create the sample now
        analysis["sample_subgraph_size"] = 0
        analysis["image"] = visualize_graph(graph)  # Get sampled image

        return GraphAnalysis(**analysis)

    except Exception as e:
        print(f"Error during graph analysis: {e}")
        return None

def detect_communities(G: nx.DiGraph, graph_analysis: Optional[GraphAnalysis] = None, max_nodes: int = 5000) -> Optional[CommunityAnalysis]:
    """Detects communities using Louvain (if available) or falls back to connected components."""

    if G is None:
        print("⚠️ Input graph is None. Cannot detect communities.")
        return None
    #handle exceptions:
    try:
        if G.number_of_nodes() > max_nodes:
            print(f"Graph is large ({G.number_of_nodes()} nodes), sampling for community detection...")
            if graph_analysis and graph_analysis.sample_subgraph:
                subgraph = graph_analysis.sample_subgraph
            else:
                # Basic random sampling if we don't have a sample_subgraph
                sampled_nodes = list(G.nodes())[:max_nodes]
                subgraph = G.subgraph(sampled_nodes)
        else:
            subgraph = G

        undirected_G = subgraph.to_undirected()

        try:
            # Corrected import and function call
            import community as community_louvain
            partition = community_louvain.best_partition(undirected_G)
            communities: Dict[int, List[str]] = {}
            for node, community_id in partition.items():
                if community_id not in communities:
                    communities[community_id] = []
                communities[community_id].append(node)
            sorted_communities = sorted(communities.items(), key=lambda x: len(x[1]), reverse=True)

            return CommunityAnalysis(
                algorithm="louvain",
                num_communities=len(communities),
                community_sizes=[len(comm) for _, comm in sorted_communities[:10]],
                top_communities=sorted_communities[:5],
                node_communities=partition
            )

        except ImportError:
            print("⚠️ python-louvain (community) not installed. Falling back to connected components.")
            components = list(nx.connected_components(undirected_G))
            sorted_components = sorted(components, key=len, reverse=True)
            return CommunityAnalysis(
                algorithm="connected_components",
                num_communities=len(components),
                community_sizes=[len(comp) for comp in sorted_components[:10]],
                top_communities=[(i, list(comp)) for i, comp in enumerate(sorted_components[:5])],
                node_communities={node: i for i, comp in enumerate(components) for node in comp}
            )
    except Exception as e:
        print(f"Error during community detection: {e}")
        return None

# --- LangChain Setup ---

def setup_langchain_cohere(graph_analysis: Optional[GraphAnalysis], community_analysis: Optional[CommunityAnalysis]) -> Optional[LLMChain]:
    """Sets up the LangChain LLMChain with Cohere, handling missing API key."""

    if not COHERE_API_KEY:
        print("⚠️ COHERE_API_KEY not set.  Cannot initialize LLM.")
        return None

    if graph_analysis is None or community_analysis is None:
        print("⚠️ Graph or community analysis is None. Cannot set up LLMChain.")
        return None
    #handle exceptions
    try:
        print("Setting up LangChain with Cohere...")
        query_template = PromptTemplate(
            template="""
Based on the network analysis:

Graph has {num_nodes} nodes and {num_edges} edges.
Average degree: {avg_degree:.2f}
Max degree: {max_degree}
Top Products by Degree: {top_nodes_by_degree}
Communities detected: {num_communities}
Community Sizes: {community_sizes}

Query: {query}

Answer:
""",
            input_variables=["query", "num_nodes", "num_edges", "avg_degree", "max_degree",
                             "num_communities", "community_sizes", "top_nodes_by_degree"]
        )

        cohere_llm = ChatCohere(cohere_api_key=COHERE_API_KEY, model="command") # Use ChatCohere
        return LLMChain(llm=cohere_llm, prompt=query_template)

    except Exception as e:
        print(f"⚠️ Error setting up LangChain: {e}")
        return None



def agentic_query(query: str, llm_chain: Optional[LLMChain], graph_analysis: Optional[GraphAnalysis], community_analysis: Optional[CommunityAnalysis]) -> str:
    """Processes queries using the LLM chain."""

    if llm_chain is None:
        return "LLM chain not available.  Check Cohere API key and setup."

    if graph_analysis is None or community_analysis is None:
        return "Graph or community analysis is missing. Cannot answer query."

    try:
        # Prepare the input dictionary, handling potential missing attributes
        input_data = {
            "query": query,
            "num_nodes": graph_analysis.num_nodes,
            "num_edges": graph_analysis.num_edges,
            "avg_degree": graph_analysis.avg_degree,
            "max_degree": graph_analysis.max_degree,
            "top_nodes_by_degree": graph_analysis.top_nodes_by_degree,
            "num_communities": community_analysis.num_communities,
            "community_sizes": community_analysis.community_sizes

        }

        result = llm_chain.invoke(input_data) # Use .invoke() for newer LangChain
        return result['text'] # Access the 'text' key for the response

    except Exception as e:
        print(f"Error during LLM query: {e}")
        return f"An error occurred during the query: {e}"
def process_data(max_nodes_to_display: int, dataset_name: str, user_query: str, persist_to_db: bool = False) -> Dict[str, str]:

  # --- 1. Data Loading (Prioritize ArangoDB) ---
    ARANGO_GRAPH_NAME = dataset_name
    data_dir = "amazon_data"
    os.makedirs(data_dir, exist_ok=True)

    db = setup_arangodb()  # Connect to ArangoDB *first*
    amazon_graph = None

    if db:
        amazon_graph = load_graph_from_arangodb(db, dataset_name)  # Load directly from DB
        if amazon_graph:
            print(f"✅ Loaded graph '{dataset_name}' from ArangoDB")

    if not amazon_graph:  # If loading from ArangoDB failed, load from file
        print(f"Loading or downloading data for dataset: {dataset_name}...")
        if dataset_name == "metadata":
            gz_file = os.path.join(data_dir, "metadata.json.gz")
            csv_file = os.path.join(data_dir, "metadata.csv")
            if not os.path.exists(gz_file):
                download_file(amazon_datasets[dataset_name]["url"], gz_file)
            df = parse_amazon_metadata(gz_file)  # No max_edges for metadata
        elif dataset_name in amazon_datasets and amazon_datasets[dataset_name]["type"] == "copurchase":
            gz_file = os.path.join(data_dir, os.path.basename(amazon_datasets[dataset_name]["url"]))
            csv_file = gz_file.replace('.gz', '.csv')
            if not os.path.exists(gz_file):
                download_file(amazon_datasets[dataset_name]["url"], gz_file)
            df = parse_amazon_copurchase(gz_file, max_edges=100000)  # Limit edges for testing
        else:
            print(f"⚠️ Unknown dataset: {dataset_name}")
            return {
                "graph_summary": f"Error: Unknown dataset '{dataset_name}'.",
                "graph_visualization": "",
                "llm_response": "",
                "status": "Data loading error."
            }
        if df is not None:
            datasets[dataset_name] = {"data": df, "graph": None}
            if dataset_name != "metadata":
                amazon_graph = load_graph(df)
                if amazon_graph is not None:
                    datasets[dataset_name]["graph"] = amazon_graph
        else:
            datasets[dataset_name] = {"data": pd.DataFrame(), "graph": None}
            print(f"⚠️ Could not parse copurchase data from gz file. Co-purchase data will be unavailable.")
            return {
            "graph_summary": f"Error: Could not load data for '{dataset_name}'.",
            "graph_visualization": "",
            "llm_response": "",
            "status": "Data loading error."
            }
    # --- 2. ArangoDB Persistence (Conditional) ---
    if persist_to_db and db and amazon_graph:
        persist_graph_to_arangodb(db, amazon_graph, dataset_name)


    # --- 3. Graph Analysis ---
    graph_analysis = analyze_graph(amazon_graph)
    if graph_analysis is None:
        return {
            "graph_summary": "Error: Graph analysis failed.",
            "graph_visualization": "",
            "llm_response": "",
            "status": "Graph analysis error."
        }
    community_analysis = detect_communities(amazon_graph, graph_analysis)
    if community_analysis is None:
        return {
            "graph_summary": "Error: Community analysis failed.",
            "graph_visualization": graph_analysis.image if graph_analysis else "",
            "llm_response": "",
            "status": "Community detection error."
        }

    # --- 4. LLM Interaction ---
    llm_chain = setup_langchain_cohere(graph_analysis, community_analysis)
    llm_response = agentic_query(user_query, llm_chain, graph_analysis, community_analysis) if llm_chain and user_query else "Enter a query to get insights from the LLM."

    summary = (
        f"The graph has {graph_analysis.num_nodes} nodes and {graph_analysis.num_edges} edges.\n"
        f"Average degree: {graph_analysis.avg_degree:.2f}, Max degree: {graph_analysis.max_degree}.\n"
        f"Largest connected component size: {graph_analysis.largest_cc_size} "
        f"({graph_analysis.largest_cc_percentage:.2f}% of nodes).\n"
        f"Top nodes by degree: {graph_analysis.top_nodes_by_degree[:5]}"
    )

    if graph_analysis.num_nodes > max_nodes_to_display:
        summary += f"\n\nDisplaying a sample of up to {max_nodes_to_display} nodes."


    return {
        "graph_summary": summary,
        "graph_visualization": graph_analysis.image,
        "llm_response": llm_response,
        "status": "Graph analysis complete!",
    }
# --- Dataset Definitions ---
amazon_datasets = {
    "metadata": {
        "url": "http://snap.stanford.edu/data/amazon/productGraph/metadata.json.gz",
        "type": "metadata"
    },
    "amazon0302": {
        "url": "http://snap.stanford.edu/data/amazon0302.txt.gz",
        "type": "copurchase"
    },
    "amazon0312": {
        "url": "http://snap.stanford.edu/data/amazon0312.txt.gz",
        "type": "copurchase"
    },
    "amazon0505": {
        "url": "http://snap.stanford.edu/data/amazon0505.txt.gz",
        "type": "copurchase"
    },
    "amazon0601": {
        "url": "http://snap.stanford.edu/data/amazon0601.txt.gz",
        "type": "copurchase"
    },
}

# --- Gradio Interface Setup ---

if __name__ == "__main__":
    # Initialize datasets dictionary outside the main function
    datasets = {name: {"data": None, "graph": None} for name in amazon_datasets}

    with gr.Blocks() as iface:
      with gr.Row():
        with gr.Column():
          max_nodes_to_display = gr.Slider(minimum=100, maximum=10000, value=1000, step=100, label="Max Nodes to Display")
          dataset_dropdown = gr.Dropdown(choices=list(amazon_datasets.keys()), label="Select Dataset", value="amazon0601")
          user_query = gr.Textbox(label="Ask a question about the graph:")
          persist_to_db_checkbox = gr.Checkbox(label="Persist graph to ArangoDB", value=False)
          run_button = gr.Button("Analyze and Query")
        with gr.Column():
          graph_summary = gr.Textbox(label="Graph Summary")
          graph_visualization = gr.Image(label="Graph Visualization")
          llm_response = gr.Textbox(label="LLM Response")
          status = gr.Textbox(label="Status")

      run_button.click(
          fn=process_data,
          inputs=[max_nodes_to_display, dataset_dropdown, user_query, persist_to_db_checkbox],
          outputs=[graph_summary, graph_visualization, llm_response, status]
      )
    iface.launch(debug=True)

✅ gradio is already installed
✅ networkx is already installed
✅ matplotlib is already installed
✅ requests is already installed
✅ tqdm is already installed
✅ pandas is already installed
✅ langchain is already installed
Installing langchain-cohere...
✅ Successfully installed langchain-cohere
Installing python-dotenv...
✅ Successfully installed python-dotenv
Installing python-arango...
✅ Successfully installed python-arango
✅ community is already installed
Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://603f17e28829461900.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Fa

Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7861 <> https://a6d30a5cb643c68471.gradio.live
Killing tunnel 127.0.0.1:7862 <> https://603f17e28829461900.gradio.live


In [54]:
!pip install -q streamlit pyngrok

In [53]:
import streamlit as st
import networkx as nx
import matplotlib.pyplot as plt
import io
import base64
import os
import requests
import gzip
from tqdm import tqdm
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict
import pandas as pd
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_cohere import ChatCohere
import subprocess
import sys
import importlib
from arango import ArangoClient
import time
import json
from dotenv import load_dotenv

# --- Helper Function for Installation ---
def install_package(package_name):
    """Installs a package using pip."""
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package_name, "--break-system-packages"])
        print(f"✅ Successfully installed {package_name}")
        return True
    except subprocess.CalledProcessError:
        print(f"❌ Failed to install {package_name}.  You may need to install it manually.")
        return False
    except Exception as e:
        print(f"❌ Unexpected error installing {package_name}: {e}")
        return False

# --- Package Installation ---
required_packages = ["streamlit", "networkx", "matplotlib", "requests", "tqdm", "pandas",
                     "langchain", "langchain-cohere", "python-dotenv", "python-arango",
                     "community"]
for package in required_packages:
    try:
        importlib.import_module(package)
        print(f"✅ {package} is already installed")
    except ImportError:
        print(f"Installing {package}...")
        install_package(package)


if not COHERE_API_KEY:
    st.warning("COHERE_API_KEY not found. LLM features will be disabled.")

# --- Data Classes ---
@dataclass
class GraphAnalysis:
    num_nodes: int
    num_edges: int
    avg_degree: float
    max_degree: int
    top_nodes_by_degree: List[Tuple[str, int]]
    largest_cc_size: int
    largest_cc_percentage: float
    sample_subgraph: Optional[nx.DiGraph] = None
    image: str = ""

@dataclass
class CommunityAnalysis:
    algorithm: str
    num_communities: int
    community_sizes: List[int]
    top_communities: List[Tuple[int, List[str]]]
    node_communities: Dict[str, int]

# --- ArangoDB Setup and Loading Functions ---
def setup_arangodb(retries=5, delay=5):
    """Sets up ArangoDB connection and creates database/collections."""
    for attempt in range(retries):
        try:
            client = ArangoClient(hosts=ARANGO_HOST)
            db = None
            try:
                db = client.db(ARANGO_DB, username=ARANGO_USERNAME, password=ARANGO_PASSWORD)
                st.success(f"Connected to ArangoDB: {ARANGO_DB}")
            except Exception:
                st.warning(f"Initial connection to '{ARANGO_DB}' failed.  Trying _system...")
                sys_db = client.db('_system', username=ARANGO_USERNAME, password=ARANGO_PASSWORD)
                if not sys_db.has_database(ARANGO_DB):
                    sys_db.create_database(ARANGO_DB)
                    st.success(f"Created database: {ARANGO_DB}")
                db = client.db(ARANGO_DB, username=ARANGO_USERNAME, password=ARANGO_PASSWORD)

            if db and not db.has_graph(ARANGO_GRAPH_NAME):
                db.create_graph(ARANGO_GRAPH_NAME)
                st.success(f"Created graph: {ARANGO_GRAPH_NAME}")
            return db

        except ServerConnectionError as e:
            st.error(f"Attempt {attempt + 1}/{retries} failed: {e}")
            if attempt < retries - 1:
                time.sleep(delay)
        except Exception as e:
            st.error(f"Error connecting to/setting up ArangoDB: {e}")
            return None
    return None

def persist_graph_to_arangodb(db, graph: nx.DiGraph, graph_name: str):
    """Persists the NetworkX graph to ArangoDB."""
    if not db or not graph:
        st.warning("Cannot persist graph: Database/graph is None.")
        return False

    try:
        arango_graph = db.graph(graph_name) if db.has_graph(graph_name) else db.create_graph(graph_name)
        nodes_collection_name = f"{graph_name}_products"
        edges_collection_name = f"{graph_name}_copurchase"

        products = arango_graph.create_vertex_collection(nodes_collection_name) if not arango_graph.has_vertex_collection(nodes_collection_name) else arango_graph.vertex_collection(nodes_collection_name)
        copurchase = arango_graph.create_edge_definition(
            edge_collection=edges_collection_name,
            from_vertex_collections=[nodes_collection_name],
            to_vertex_collections=[nodes_collection_name]
        ) if not arango_graph.has_edge_definition(edges_collection_name) else arango_graph.edge_collection(edges_collection_name)

        batch_size = 1000
        nodes_list = list(graph.nodes(data=True))
        with st.spinner("Inserting nodes..."):
          for i in tqdm(range(0, len(nodes_list), batch_size), desc="Inserting nodes"):
            batch = nodes_list[i:i + batch_size]
            products.insert_many([{"_key": str(node), **data} for node, data in batch], overwrite_mode='update')

        edges_list = list(graph.edges(data=True))
        with st.spinner("Inserting edges..."):
          for i in tqdm(range(0, len(edges_list), batch_size), desc="Inserting edges"):
            batch = edges_list[i:i + batch_size]
            edges_batch = [
                {"_from": f"{nodes_collection_name}/{str(source)}", "_to": f"{nodes_collection_name}/{str(target)}", **data}
                for source, target, data in batch
            ]
            copurchase.insert_many(edges_batch, overwrite_mode='update')
        return True
    except Exception as e:
        st.error(f"Error persisting graph: {e}")
        return False

def load_graph_from_arangodb(db, graph_name: str) -> Optional[nx.DiGraph]:
    """Loads a graph from ArangoDB."""
    if not db:
        st.warning("Cannot load graph: Database is None.")
        return None
    try:
        graph = nx.DiGraph(db.graph(graph_name))
        st.success(f"Graph '{graph_name}' loaded: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges")
        return graph
    except Exception as e:
        st.error(f"Error loading graph: {e}")
        return None

# --- Data Loading and Processing ---
def download_file(url, filename):
    """Downloads a file with a progress bar."""
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        with open(filename, 'wb') as file, st.spinner(f"Downloading {filename}..."):
            with tqdm(desc=filename, total=total_size, unit='iB', unit_scale=True, unit_divisor=1024) as bar:
                for data in response.iter_content(1024):
                    file.write(data)
                    bar.update(len(data))
        return filename
    except requests.exceptions.RequestException as e:
        st.error(f"Error downloading: {e}")
        return None

def parse_amazon_metadata(gz_file):
    """Parses Amazon metadata from a gzipped JSON file."""
    products = []
    try:
        with gzip.open(gz_file, 'rt', encoding='utf-8') as f:
          with st.spinner("Reading Metadata..."):
            for line in tqdm(f, desc="Reading metadata"):
                if line.strip():
                    try:
                        products.append(json.loads(line.strip()))
                    except json.JSONDecodeError as e:
                        st.warning(f"JSONDecodeError: {e} in line: {line.strip()}")
        df = pd.DataFrame(products)
        df['ASIN'] = df['ASIN'].astype(str)
        return df.to_csv(gz_file.replace('.gz', '.csv'), index=False)
    except Exception as e:
        st.error(f"Error parsing metadata: {e}")
        return None

def parse_amazon_copurchase(gz_file, max_edges=None):
    """Parses Amazon co-purchase data."""
    edges = []
    try:
        with gzip.open(gz_file, 'rt', encoding='latin1') as f:
            with st.spinner("Reading Copurchase data..."):
              for i, line in enumerate(tqdm(f, desc="Reading copurchase")):
                if not line.startswith('#'):
                    source, target = line.strip().split()
                    edges.append((source, target))
                    if max_edges and i >= max_edges:
                        break
        df = pd.DataFrame(edges, columns=['source', 'target'])
        return df.to_csv(gz_file.replace('.gz', '.csv'), index=False)
    except Exception as e:
        st.error(f"Error parsing copurchase: {e}")
        return None

def load_graph(copurchase_file: str) -> Optional[nx.DiGraph]:
    """Loads graph from copurchase CSV."""
    try:
        df = pd.read_csv(copurchase_file)
        graph = nx.DiGraph()
        with st.spinner("Creating Graph..."):
          for _, row in tqdm(df.iterrows(), total=len(df), desc="Adding edges"):
            graph.add_edge(str(row['source']), str(row['target']))
        st.success(f"Graph created: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges")
        return graph
    except Exception as e:
        st.error(f"Error loading graph: {e}")
        return None

# --- Graph Analysis and Community Detection ---
def visualize_graph(graph: Optional[nx.DiGraph]) -> str:
    """Visualizes a sample of the graph."""
    if not graph: return ""
    if graph.number_of_nodes() > 100:
        top_nodes = sorted(graph.degree(), key=lambda x: x[1], reverse=True)[:10]
        if not top_nodes: return ""
        seed_node = top_nodes[0][0]
        sample_nodes, frontier = {seed_node}, {seed_node}
        while len(sample_nodes) < 100 and frontier:
            new_frontier = set()
            for node in frontier:
                neighbors = set(graph.neighbors(node)) - sample_nodes
                selected = list(neighbors)[:5]
                sample_nodes.update(selected)
                new_frontier.update(selected)
            frontier = new_frontier
        graph = graph.subgraph(list(sample_nodes))

    plt.figure(figsize=(12, 6))
    nx.draw(graph, with_labels=True, node_size=400, font_size=9, alpha=0.7)
    plt.title("Graph Sample")
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    return base64.b64encode(buf.getvalue()).decode('utf-8')

def analyze_graph(graph: Optional[nx.DiGraph]) -> Optional[GraphAnalysis]:
    """Analyzes the graph."""
    if not graph: return None
    try:
        degrees = [d for _, d in graph.degree()]
        analysis = {
            "num_nodes": graph.number_of_nodes(),
            "num_edges": graph.number_of_edges(),
            "avg_degree": sum(degrees) / len(degrees) if degrees else 0.0,
            "max_degree": max(degrees) if degrees else 0,
            "top_nodes_by_degree": sorted(graph.degree(), key=lambda x: x[1], reverse=True)[:10],
            "largest_cc_size": 0,
            "largest_cc_percentage": 0.0,
            "image": visualize_graph(graph)
        }
        if cc := list(nx.weakly_connected_components(graph)):
            largest_cc = max(cc, key=len)
            analysis["largest_cc_size"] = len(largest_cc)
            analysis["largest_cc_percentage"] = (len(largest_cc) / graph.number_of_nodes()) * 100
        return GraphAnalysis(**analysis)
    except Exception as e:
        st.error(f"Error during graph analysis: {e}")
        return None

def detect_communities(G: nx.DiGraph, max_nodes: int = 5000) -> Optional[CommunityAnalysis]:
    """Detects communities (Louvain or connected components)."""
    if not G: return None
    try:
        subgraph = G.subgraph(list(G.nodes())[:max_nodes]) if G.number_of_nodes() > max_nodes else G
        undirected_G = subgraph.to_undirected()

        try:
            import community as community_louvain
            partition = community_louvain.best_partition(undirected_G)
            communities = {}
            for node, cid in partition.items():
                communities.setdefault(cid, []).append(node)
            sorted_communities = sorted(communities.items(), key=lambda x: len(x[1]), reverse=True)
            return CommunityAnalysis(
                algorithm="louvain", num_communities=len(communities),
                community_sizes=[len(c) for _, c in sorted_communities[:10]],
                top_communities=sorted_communities[:5], node_communities=partition
            )
        except ImportError:
            st.warning("python-louvain not installed. Using connected components.")
            components = list(nx.connected_components(undirected_G))
            sorted_components = sorted(components, key=len, reverse=True)
            return CommunityAnalysis(
                algorithm="connected_components", num_communities=len(components),
                community_sizes=[len(c) for c in sorted_components[:10]],
                top_communities=[(i, list(c)) for i, c in enumerate(sorted_components[:5])],
                node_communities={n: i for i, comp in enumerate(components) for n in comp}
            )
    except Exception as e:
        st.error(f"Error during community detection: {e}")
        return None

# --- LangChain Setup ---
def setup_langchain_cohere(graph_analysis: Optional[GraphAnalysis], community_analysis: Optional[CommunityAnalysis]) -> Optional[LLMChain]:
    """Sets up LangChain with Cohere."""
    if not COHERE_API_KEY or not graph_analysis or not community_analysis:
        st.warning("Cannot initialize LLM: Missing key or analysis.")
        return None
    try:
        template = """Based on the network analysis:
Graph: {num_nodes} nodes, {num_edges} edges. Avg degree: {avg_degree:.2f}, Max: {max_degree}.
Top nodes: {top_nodes_by_degree}. Communities: {num_communities}, Sizes: {community_sizes}.
Query: {query}  Answer:"""
        prompt = PromptTemplate(template=template, input_variables=["query", "num_nodes", "num_edges", "avg_degree",
                                                                "max_degree", "num_communities", "community_sizes",
                                                                "top_nodes_by_degree"])
        return LLMChain(llm=ChatCohere(cohere_api_key=COHERE_API_KEY, model="command"), prompt=prompt)
    except Exception as e:
        st.error(f"Error setting up LangChain: {e}")
        return None

def agentic_query(query: str, llm_chain: Optional[LLMChain], graph_analysis: Optional[GraphAnalysis], community_analysis: Optional[CommunityAnalysis]) -> str:
    """Queries the LLM."""
    if not llm_chain: return "LLM not available."
    if not graph_analysis or not community_analysis: return "Analysis missing."
    try:
        input_data = {
            "query": query, "num_nodes": graph_analysis.num_nodes, "num_edges": graph_analysis.num_edges,
            "avg_degree": graph_analysis.avg_degree, "max_degree": graph_analysis.max_degree,
            "top_nodes_by_degree": graph_analysis.top_nodes_by_degree,
            "num_communities": community_analysis.num_communities, "community_sizes": community_analysis.community_sizes
        }
        return llm_chain.invoke(input_data)['text']
    except Exception as e:
        return f"Error during query: {e}"

# --- Data Processing and Streamlit Interface ---
def process_data(max_nodes_to_display: int, dataset_name: str, user_query: str, persist_to_db: bool) -> Dict[str, str]:
    """Main data processing function."""
    data_dir = "amazon_data"
    os.makedirs(data_dir, exist_ok=True)
    db = setup_arangodb()
    amazon_graph = load_graph_from_arangodb(db, dataset_name) if db else None

    if not amazon_graph:
        if dataset_name == "metadata":
            gz_file = os.path.join(data_dir, "metadata.json.gz")
            csv_file = os.path.join(data_dir, "metadata.csv")
            download_file(amazon_datasets[dataset_name]["url"], gz_file) if not os.path.exists(gz_file) else None
            parse_amazon_metadata(gz_file)
        elif dataset_name in amazon_datasets:
            gz_file = os.path.join(data_dir, os.path.basename(amazon_datasets[dataset_name]["url"]))
            csv_file = gz_file.replace('.gz', '.csv')
            download_file(amazon_datasets[dataset_name]["url"], gz_file) if not os.path.exists(gz_file) else None
            parse_amazon_copurchase(gz_file, max_edges=100000)
            amazon_graph = load_graph(csv_file)
        else:
            return {"graph_summary": f"Error: Unknown dataset '{dataset_name}'.", "image": "", "llm_response": "", "status": "Data loading error."}
    if persist_to_db and db and amazon_graph:
        persist_graph_to_arangodb(db, amazon_graph, dataset_name)

    graph_analysis = analyze_graph(amazon_graph)
    if not graph_analysis:
        return {"graph_summary": "Error: Graph analysis failed.", "image": "", "llm_response": "", "status": "Analysis error."}
    community_analysis = detect_communities(amazon_graph, graph_analysis)
    if not community_analysis:
         return {"graph_summary": "Error: Community analysis failed.", "image": graph_analysis.image, "llm_response": "", "status": "Analysis error."}

    llm_chain = setup_langchain_cohere(graph_analysis, community_analysis)
    llm_response = agentic_query(user_query, llm_chain, graph_analysis, community_analysis) if llm_chain and user_query else "Enter a query."
    summary = (f"Nodes: {graph_analysis.num_nodes}, Edges: {graph_analysis.num_edges}.  Avg degree: {graph_analysis.avg_degree:.2f}, "
               f"Max degree: {graph_analysis.max_degree}. Largest CC: {graph_analysis.largest_cc_size} ({graph_analysis.largest_cc_percentage:.2f}%). "
               f"Top nodes: {graph_analysis.top_nodes_by_degree[:5]}")
    if graph_analysis.num_nodes > max_nodes_to_display:
        summary += f"\nDisplaying sample (up to {max_nodes_to_display} nodes)."

    return {"graph_summary": summary, "image": graph_analysis.image, "llm_response": llm_response, "status": "Complete!"}

amazon_datasets = {
    "metadata": {"url": "http://snap.stanford.edu/data/amazon/productGraph/metadata.json.gz", "type": "metadata"},
    "amazon0302": {"url": "http://snap.stanford.edu/data/amazon0302.txt.gz", "type": "copurchase"},
    "amazon0312": {"url": "http://snap.stanford.edu/data/amazon0312.txt.gz", "type": "copurchase"},
    "amazon0505": {"url": "http://snap.stanford.edu/data/amazon0505.txt.gz", "type": "copurchase"},
    "amazon0601": {"url": "http://snap.stanford.edu/data/amazon0601.txt.gz", "type": "copurchase"},
}

if __name__ == "__main__":
    st.set_page_config(layout="wide")  # Use wide layout
    st.title("Amazon Co-purchase Network Analysis")

    with st.sidebar:
        st.header("Settings")
        max_nodes_to_display = st.slider("Max Nodes to Display", 100, 10000, value=1000, step=100)
        dataset_name = st.selectbox("Select Dataset", list(amazon_datasets.keys()), index=list(amazon_datasets.keys()).index("amazon0601"))  # Correct default selection
        user_query = st.text_input("Ask a question about the graph:")
        persist_to_db = st.checkbox("Persist graph to ArangoDB", value=False)
        if st.button("Analyze"):
            with st.spinner("Analyzing..."):
                results = process_data(max_nodes_to_display, dataset_name, user_query, persist_to_db)
            st.session_state['results'] = results  # Store results in session state

    # Display results (if available)
    if 'results' in st.session_state:
        results = st.session_state['results']
        col1, col2 = st.columns(2)  # Two columns for better layout

        with col1:
            st.subheader("Graph Summary")
            st.text(results['graph_summary'])

            st.subheader("LLM Response")
            st.text(results['llm_response'])
            st.write(f"Status: {results['status']}")


        with col2:
            st.subheader("Graph Visualization")
            st.image("data:image/png;base64," + results['image'])

✅ streamlit is already installed
✅ networkx is already installed
✅ matplotlib is already installed
✅ requests is already installed
✅ tqdm is already installed
✅ pandas is already installed
✅ langchain is already installed
Installing langchain-cohere...
✅ Successfully installed langchain-cohere
Installing python-dotenv...
✅ Successfully installed python-dotenv
Installing python-arango...




✅ Successfully installed python-arango
✅ community is already installed
