<a href="https://colab.research.google.com/github/selgebali/Colabs/blob/main/PIDnetwork_graph.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PID Network Graph Visualization Script

## Overview
This script generates a **Persistent Identifier (PID) Network Graph**, visualizing the relationships and connections between entities like publications, datasets, software, people, organizations, and funders. Data is retrieved from the **DataCite GraphQL API**, processed into structured formats, and displayed as an interactive network graph using Plotly and NetworkX.

## Features
- **Data Fetching:** Retrieves entity and relationship data from the DataCite GraphQL API.
- **Node and Edge Construction:** Processes nodes and edges with attributes such as counts and weights.
- **Visualization:** Generates customizable, interactive graphs with annotated nodes and curved edges.
- **Output:** Saves the graph as an interactive HTML file.

## Prerequisites
### Required Libraries
Ensure the following Python libraries are installed:

- `requests`: For API requests.
- `pandas`: For data manipulation.
- `plotly`: For interactive graph visualization.
- `networkx`: For graph construction and layout calculations.
- `numpy`: For numerical calculations.

Install the dependencies using pip:
```bash
pip install requests pandas plotly networkx numpy
```

## Script Details
### 1. **Fetch Data**
The script sends a GraphQL query to the **DataCite GraphQL API** to fetch entity counts and relationships.

#### GraphQL Query
```graphql
{
  publications {
    totalCount
    datasetConnectionCount
    softwareConnectionCount
    personConnectionCount
    organizationConnectionCount
    funderConnectionCount
  }
  datasets {
    totalCount
    softwareConnectionCount
    personConnectionCount
    organizationConnectionCount
    funderConnectionCount
  }
  softwares {
    totalCount
    personConnectionCount
    organizationConnectionCount
    funderConnectionCount
  }
  people(query: "*") {
    totalCount
    organizationConnectionCount
  }
  organizations {
    totalCount
  }
  funders {
    totalCount
  }
}
```

### 2. **Construct DataFrames**
- **Nodes DataFrame**: Contains entity counts for publications, datasets, software, people, organizations, and funders.
- **Edges DataFrame**: Captures relationships and connection counts between entities.
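
The two frames can be assembled from plain record lists. The sketch below is a minimal illustration, assuming the response has the shape produced by the query above. `build_records`, `NODE_KEYS`, and `EDGE_FIELDS` are hypothetical names, and the sample payload uses made-up numbers. Unlike the notebook code, this version skips an edge when its connection field is absent instead of defaulting it to 0.

```python
# Hypothetical helper: turn a GraphQL payload into node and edge records.
# Top-level query field -> node label used in the graph.
NODE_KEYS = {
    "publications": "Publication",
    "datasets": "Dataset",
    "softwares": "Software",
    "people": "Person",
    "organizations": "Organization",
    "funders": "Funder",
}
# Connection-count field -> label of the node it points to.
EDGE_FIELDS = {
    "datasetConnectionCount": "Dataset",
    "softwareConnectionCount": "Software",
    "personConnectionCount": "Person",
    "organizationConnectionCount": "Organization",
    "funderConnectionCount": "Funder",
}


def build_records(payload):
    """Return (nodes, edges) as lists of dicts."""
    data = payload.get("data") or {}
    nodes, edges = [], []
    for key, label in NODE_KEYS.items():
        entity = data.get(key) or {}
        # Node count falls back to 0 when the entity is missing.
        nodes.append({"id": label, "count": entity.get("totalCount", 0)})
        for field, target in EDGE_FIELDS.items():
            if field in entity:
                edges.append({"source": label, "target": target, "count": entity[field]})
    return nodes, edges


# Made-up sample payload, for illustration only.
sample = {"data": {"publications": {"totalCount": 10, "datasetConnectionCount": 4},
                   "datasets": {"totalCount": 7}}}
nodes, edges = build_records(sample)
```

Either list can be passed directly to `pd.DataFrame(...)` to obtain the frames described above.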

### 3. **Build Graph**
A **NetworkX graph** is built using:
- **Nodes:** Attributes include `size` (entity count) and `color` (based on type).
- **Edges:** Attributes include `weight` (connection count) and curvature for better visualization.

### 4. **Visualize Graph**
- The graph is visualized using **Plotly**, with options for:
  - Curved edges to represent relationships.
  - Scaled node sizes based on entity counts.
  - Edge annotations to display connection counts.
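
The curved edges are quadratic Bezier curves whose control point is the edge midpoint raised along the y-axis. The sketch below shows that calculation in plain Python (the notebook itself uses NumPy arrays). The function name `quadratic_bezier` and the `lift` parameter are illustrative assumptions, not names from the notebook.

```python
def quadratic_bezier(p0, p1, lift=0.2, steps=100):
    """Return (xs, ys) sampled along a quadratic Bezier curve from p0 to p1.

    The control point is the segment midpoint raised by `lift` on the y-axis.
    `steps` must be at least 2 so that both endpoints are included.
    """
    (x0, y0), (x1, y1) = p0, p1
    cx, cy = (x0 + x1) / 2, (y0 + y1) / 2 + lift
    xs, ys = [], []
    for i in range(steps):
        t = i / (steps - 1)
        # Bernstein weights for the start, control, and end points.
        a, b, c = (1 - t) ** 2, 2 * (1 - t) * t, t ** 2
        xs.append(a * x0 + b * cx + c * x1)
        ys.append(a * y0 + b * cy + c * y1)
    return xs, ys


xs, ys = quadratic_bezier((0.0, 0.0), (1.0, 0.0), lift=0.2, steps=101)
```

At `t = 0.5` the curve sits half of `lift` above the straight-line midpoint, so a label meant to rest on the curve needs that offset rather than the full `lift`.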

### 5. **Save and Display**
- The graph is displayed interactively and saved as an HTML file (`graph.html`).

## Example Workflow

### Fetch Data
```python
url = "https://api.datacite.org/graphql"
response = requests.post(url, json={'query': query})
data = response.json()
```

### Process Data
```python
# Create Nodes and Edges DataFrames
nodes = pd.DataFrame([...])  # Populate nodes with counts
edges = pd.DataFrame([...])  # Populate edges with connections
```

### Build NetworkX Graph
```python
G = nx.Graph()

# Add nodes and edges with attributes
for node in nodes.itertuples():
    G.add_node(node.id, label=node.id, size=int(node.count), color=color_map.get(node.id, '#7f7f7f'))

for edge in edges.itertuples():
    G.add_edge(edge.source, edge.target, weight=int(edge.count))
```

### Visualize Graph
```python
fig = go.Figure(data=edge_traces + [node_trace], layout=go.Layout(...))
fig.show()
```

### Save Graph
```python
pio.write_html(fig, file='graph.html', auto_open=True)
```

## Customization
- **Layouts:** Experiment with `spring`, `circular`, or manually defined layouts for nodes.
- **Colors:** Modify `custom_colors` to change the color assigned to each node type.
- **Edge Styles:** Adjust edge curvature and annotations for better visibility.
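
One edge-style adjustment is the range of line widths. The notebook scales each weight with a min-max formula of the form `((weight - min) / (max - min)) * 20 + 12`. The sketch below generalizes that into a helper with explicit bounds. `scale_linear` and its parameters are hypothetical names, and the sample weights are taken from the printed edge table only as an illustration.

```python
def scale_linear(values, low, high, fallback=None):
    """Map values linearly onto [low, high].

    When all values are equal the range is undefined, so every item gets
    `fallback` (or `low` if no fallback is given).
    """
    vmin, vmax = min(values), max(values)
    if vmax == vmin:
        return [low if fallback is None else fallback] * len(values)
    span = vmax - vmin
    return [low + (v - vmin) / span * (high - low) for v in values]


# Three connection counts mapped to line widths between 12 and 32.
widths = scale_linear([17443, 50907, 27731533], low=12, high=32)
```

Because the counts span several orders of magnitude, most edges end up close to `low`. Applying a logarithm to the values before scaling is one possible way to spread them out, at the cost of widths no longer being proportional to the counts.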

## Output
- **Interactive Graph:** Displays relationships between entities interactively.
- **HTML File:** Saves the graph as `graph.html` for easy sharing.

## Debugging and Logs
- The script prints the HTTP status and any GraphQL errors when a request fails, and prints the node and edge DataFrames after a successful fetch.
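
The same checks can be routed through the standard `logging` module and turned into exceptions, so that later cells do not run with missing data. This is a minimal sketch under that assumption; `check_graphql_payload` and the logger name `pid_graph` are hypothetical and do not appear in the notebook.

```python
import logging

logger = logging.getLogger("pid_graph")  # hypothetical logger name


def check_graphql_payload(status_code, payload):
    """Return the `data` dict, or log the problem and raise RuntimeError."""
    if status_code != 200:
        logger.error("HTTP error %s", status_code)
        raise RuntimeError(f"HTTP error {status_code}")
    if payload.get("errors"):
        logger.error("GraphQL errors: %s", payload["errors"])
        raise RuntimeError("GraphQL query returned errors")
    data = payload.get("data")
    if not data:
        logger.error("Response contained no data")
        raise RuntimeError("GraphQL response contained no data")
    logger.info("Fetched %d top-level fields", len(data))
    return data
```

With `requests`, it could be called as `check_graphql_payload(response.status_code, response.json())`.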

## Contributing
Contributions are welcome! Submit issues or pull requests to enhance functionality or address bugs.

In [None]:
# Install necessary libraries
!pip install requests plotly pandas networkx

# Import necessary libraries
import requests
import pandas as pd
import plotly.graph_objects as go
import networkx as nx




In [None]:
import pandas as pd
import requests

# Define the GraphQL query
query = '''
{
  publications {
    totalCount
    datasetConnectionCount
    softwareConnectionCount
    personConnectionCount
    organizationConnectionCount
    funderConnectionCount
  }
  datasets {
    totalCount
    softwareConnectionCount
    personConnectionCount
    organizationConnectionCount
    funderConnectionCount
  }
  softwares {
    totalCount
    personConnectionCount
    organizationConnectionCount
    funderConnectionCount
  }
  people(query: "*") {
    totalCount
    organizationConnectionCount
  }
  organizations {
    totalCount
  }
  funders {
    totalCount
  }
}
'''

# Send the request to the DataCite GraphQL API
url = "https://api.datacite.org/graphql"
response = requests.post(url, json={'query': query})

# Check for errors in the response
if response.status_code != 200:
    print("Error:", response.status_code, response.text)
else:
    data = response.json()
    if "errors" in data:
        print("GraphQL Errors:", data["errors"])
    else:
        # Safely extract data
        publications = {'id': 'Publication', 'count': data.get('data', {}).get('publications', {}).get('totalCount', 0)}
        datasets = {'id': 'Dataset', 'count': data.get('data', {}).get('datasets', {}).get('totalCount', 0)}
        softwares = {'id': 'Software', 'count': data.get('data', {}).get('softwares', {}).get('totalCount', 0)}
        people = {'id': 'Person', 'count': data.get('data', {}).get('people', {}).get('totalCount', 0)}
        organizations = {'id': 'Organization', 'count': data.get('data', {}).get('organizations', {}).get('totalCount', 0)}
        funders = {'id': 'Funder', 'count': data.get('data', {}).get('funders', {}).get('totalCount', 0)}

        # Create DataFrames for nodes and edges
        nodes = pd.DataFrame([publications, datasets, softwares, people, organizations, funders])

        edges = pd.DataFrame([
            {'source': 'Publication', 'target': 'Dataset', 'count': data.get('data', {}).get('publications', {}).get('datasetConnectionCount', 0)},
            {'source': 'Publication', 'target': 'Software', 'count': data.get('data', {}).get('publications', {}).get('softwareConnectionCount', 0)},
            {'source': 'Publication', 'target': 'Person', 'count': data.get('data', {}).get('publications', {}).get('personConnectionCount', 0)},
            {'source': 'Publication', 'target': 'Organization', 'count': data.get('data', {}).get('publications', {}).get('organizationConnectionCount', 0)},
            {'source': 'Publication', 'target': 'Funder', 'count': data.get('data', {}).get('publications', {}).get('funderConnectionCount', 0)},
            {'source': 'Dataset', 'target': 'Software', 'count': data.get('data', {}).get('datasets', {}).get('softwareConnectionCount', 0)},
            {'source': 'Dataset', 'target': 'Person', 'count': data.get('data', {}).get('datasets', {}).get('personConnectionCount', 0)},
            {'source': 'Dataset', 'target': 'Organization', 'count': data.get('data', {}).get('datasets', {}).get('organizationConnectionCount', 0)},
            {'source': 'Dataset', 'target': 'Funder', 'count': data.get('data', {}).get('datasets', {}).get('funderConnectionCount', 0)},
            {'source': 'Software', 'target': 'Person', 'count': data.get('data', {}).get('softwares', {}).get('personConnectionCount', 0)},
            {'source': 'Software', 'target': 'Organization', 'count': data.get('data', {}).get('softwares', {}).get('organizationConnectionCount', 0)},
            {'source': 'Software', 'target': 'Funder', 'count': data.get('data', {}).get('softwares', {}).get('funderConnectionCount', 0)},
            {'source': 'Person', 'target': 'Organization', 'count': data.get('data', {}).get('people', {}).get('organizationConnectionCount', 0)}
        ])

        # Display DataFrames
        print(nodes)
        print(edges)

             id     count
0   Publication  24997797
1       Dataset  32547012
2      Software    599789
3        Person  22303204
4  Organization    109806
5        Funder     44529
         source        target     count
0   Publication       Dataset  27731533
1   Publication      Software     50907
2   Publication        Person   7911740
3   Publication  Organization    331244
4   Publication        Funder    161510
5       Dataset      Software     17443
6       Dataset        Person   2653970
7       Dataset  Organization    318532
8       Dataset        Funder    383383
9      Software        Person    172097
10     Software  Organization     23791
11     Software        Funder     29590
12       Person  Organization    131616


In [None]:
print(edges.head())

        source        target     count
0  Publication       Dataset  27731533
1  Publication      Software     50907
2  Publication        Person   7911740
3  Publication  Organization    331244
4  Publication        Funder    161510


In [None]:
print(nodes.head())

             id     count
0   Publication  24997797
1       Dataset  32547012
2      Software    599789
3        Person  22303204
4  Organization    109806


In [None]:
# @title
# Define your custom color palette
custom_colors = ['#243B54', '#00B1E2', '#5B88B9', '#46BCAB', '#90D7CD', '#BC2B66']

# Map node types to colors using your custom palette
node_types = ['Publication', 'Dataset', 'Software', 'Person', 'Organization', 'Funder']
color_map = dict(zip(node_types, custom_colors))

# Create a NetworkX graph
G = nx.Graph()

# Add nodes with attributes
for node in nodes.itertuples():
    G.add_node(node.id, label=node.id, size=int(node.count), color=color_map.get(node.id, '#7f7f7f'))

# Add edges with attributes
for edge in edges.itertuples():
    G.add_edge(edge.source, edge.target, weight=int(edge.count))

# Compute a spring layout as a starting point (overridden by the manual positions below)
pos = nx.spring_layout(G, k=2, iterations=300)

# Manually define positions for each node
pos = {
    'Publication': (0.5, 3.2),
    'Dataset': (1, 5),
    'Software': (1.4, 2.5),
    'Person': (0.5, -0.2),
    'Organization': (0.9, -0.6),
    'Funder': (1.2, -0.50)
}


# Extract node positions, labels, and properties
node_x, node_y = zip(*[pos[node[0]] for node in G.nodes(data=True)])
node_text = [f"{node[0]} ({node[1]['size']})" for node in G.nodes(data=True)]
node_hovertext = [f"{node[0]}<br>Count: {node[1]['size']}" for node in G.nodes(data=True)]
node_size = [node[1]['size'] for node in G.nodes(data=True)]
node_color = [node[1]['color'] for node in G.nodes(data=True)]

# Normalize node sizes for better distinction
max_size = max(node_size)

node_size = [((size / max_size) * 60) + 150 for size in node_size]  # Scale marker size: base of 150, up to 210 for the largest node

node_text = [f"{node[0]}<br>{node[1]['size']}" for node in G.nodes(data=True)]  # Label and count on two lines (replaces the one-line label above)

# Create node traces with text labels displaying counts
node_trace = go.Scatter(
    x=node_x,
    y=node_y,
    mode='markers+text',  # Display text (node labels + counts) inside the nodes
    text=node_text,  # Display both the label and count directly inside nodes
    hovertext=node_hovertext,  # Hover info
    hoverinfo='text',  # Ensure hover info is set to 'text'
    marker=dict(
        showscale=False,
        color=node_color,
        size=node_size,
        line_width=2,
        opacity=1.0),
    textposition='middle center',  # Position the text in the center of the nodes
    textfont=dict(
        size=18,  # Set the font size
        family='Arial',  # Set the font family
        color='White',  # Set the font color
        weight='bold'  # Set font weight to bold
    )
)

# Extract edge positions and hover text
edge_traces = []
edge_weights = [e[2]['weight'] for e in G.edges(data=True)]
max_weight = max(edge_weights)
min_weight = min(edge_weights)
edge_hovertext = []  # Hover text for edges

for edge in G.edges(data=True):
    x0, y0 = pos[edge[0]]
    x1, y1 = pos[edge[1]]
    weight = edge[2]['weight']

    # Adjust scaling factor for edge thickness
    if max_weight != min_weight:
        width = ((weight - min_weight) / (max_weight - min_weight)) * 20 + 12  # Scale between 12 and 32
    else:
        width = 2  # Default width if all weights are the same

    # Use a more visible color for edges
    edge_color = '#888'

    # Add hovertext for edges with connection counts
    edge_hovertext.append(f"{edge[0]} - {edge[1]}<br>Connections: {weight}")

    # Create edge trace
    edge_traces.append(go.Scatter(
        x=[x0, x1],
        y=[y0, y1],
        line=dict(width=width, color=edge_color),
        hoverinfo='text',  # Ensure hover info is set to 'text'
        hovertext=edge_hovertext[-1],  # Hover text for this edge
        mode='lines'
    ))

# Optionally: Add edge text annotations (display counts on edges)
edge_annotations = []
for edge in G.edges(data=True):
    x0, y0 = pos[edge[0]]
    x1, y1 = pos[edge[1]]
    mid_x = (x0 + x1) / 2
    mid_y = (y0 + y1) / 2
    weight = edge[2]['weight']

    # Add text annotation at the midpoint of the edge
    edge_annotations.append(dict(
        x=mid_x,
        y=mid_y,
        text=f"{weight}",
        showarrow=False,
        font=dict(size=16, weight='bold'),
        xanchor='center',
        yanchor='middle',
        align="center"
    ))


# Create the figure
fig = go.Figure(data=edge_traces + [node_trace],
             layout=go.Layout(
                title=dict(text='<br>PID Graph: Number of Nodes and Connections', font=dict(size=20)),  # `titlefont` is deprecated; set the font inside `title`
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20,l=5,r=5,t=40),
                annotations=edge_annotations,  # Add edge annotations for edge weights
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                width=1500,  # Set figure width
                height=1000,  # Set figure height
                plot_bgcolor='white',  # Set plot background color to white
                paper_bgcolor='white'  # Set overall figure background color to white
                ))

# Show the figure
fig.show()


In [None]:
import pandas as pd
import plotly.graph_objs as go
import networkx as nx
import numpy as np

# Custom color palette
custom_colors = ['#243B54', '#00B1E2', '#5B88B9', '#46BCAB', '#90D7CD', '#BC2B66']

# Node types and colors
node_types = ['Publication', 'Dataset', 'Software', 'Person', 'Organization', 'Funder']
color_map = dict(zip(node_types, custom_colors))

# Create a NetworkX graph
G = nx.Graph()

# Add nodes with attributes (assuming 'nodes' is a DataFrame)
for node in nodes.itertuples():
    G.add_node(node.id, label=node.id, size=int(node.count), color=color_map.get(node.id, '#7f7f7f'))

# Add edges with attributes (assuming 'edges' is a DataFrame)
for edge in edges.itertuples():
    G.add_edge(edge.source, edge.target, weight=int(edge.count))

# Define node positions manually
pos = {
    'Publication': (0.5, 3.2),
    'Dataset': (1, 5),
    'Software': (1.4, 2.5),
    'Person': (0.5, -0.2),
    'Organization': (0.9, -0.6),
    'Funder': (1.2, -0.50)
}

# Extract node properties
node_x, node_y = zip(*[pos[node[0]] for node in G.nodes(data=True)])
node_size = [G.nodes[node]['size'] for node in G.nodes]
node_color = [G.nodes[node]['color'] for node in G.nodes]
node_text = [f"{node} <br>({G.nodes[node]['size']})" for node in G.nodes]
node_hovertext = [f"{node}<br>Count: {G.nodes[node]['size']}" for node in G.nodes]

# Normalize node sizes
max_size = max(node_size)
node_size = [((size / max_size) * 60) + 150 for size in node_size]

# Create node traces with text labels
node_trace = go.Scatter(
    x=node_x,
    y=node_y,
    mode='markers+text',
    text=node_text,
    hovertext=node_hovertext,
    hoverinfo='text',
    marker=dict(
        showscale=False,
        color=node_color,
        size=node_size,
        line_width=2,
        opacity=1.0
    ),
    textposition='middle center',
    textfont=dict(
        size=18,
        family='Arial',
        color='White',
        weight='bold'
    )
)

# Create edge traces with curved lines based on edge weight
edge_traces = []
edge_hovertext = []  # Hover text for edges
edge_weights = [e[2]['weight'] for e in G.edges(data=True)]
max_weight = max(edge_weights)
min_weight = min(edge_weights)

for edge in G.edges(data=True):
    x0, y0 = pos[edge[0]]
    x1, y1 = pos[edge[1]]
    weight = edge[2]['weight']

    # Adjust scaling factor for edge thickness
    if max_weight != min_weight:
        width = ((weight - min_weight) / (max_weight - min_weight)) * 20 + 12  # Scale between 12 and 32
    else:
        width = 2  # Default width if all weights are the same

    # Generate a curved edge (quadratic Bezier curve through a raised control point)
    t = np.linspace(0, 1, 100)
    x_mid = (x0 + x1) / 2  # Control point x: midpoint of the edge
    y_mid = (y0 + y1) / 2 + 0.9  # Control point y: midpoint raised by 0.9; use a smaller value for a flatter curve
    x_values = (1 - t) ** 2 * x0 + 2 * (1 - t) * t * x_mid + t ** 2 * x1
    y_values = (1 - t) ** 2 * y0 + 2 * (1 - t) * t * y_mid + t ** 2 * y1
    # Use a more visible color for edges
    edge_color = '#888'

    # Create edge trace
    edge_trace = go.Scatter(
        x=x_values,
        y=y_values,
        line=dict(width=width, color=edge_color),
        hoverinfo='text',
        hovertext=f"{edge[0]} - {edge[1]}<br>Connections: {weight}",
        mode='lines'
    )
    edge_traces.append(edge_trace)

# Optionally: Add edge text annotations (display counts on edges)
edge_annotations = []
for edge in G.edges(data=True):
    x0, y0 = pos[edge[0]]
    x1, y1 = pos[edge[1]]
    mid_x = (x0 + x1) / 2
    mid_y = (y0 + y1) / 2
    weight = edge[2]['weight']

    # Add text annotation at the straight-line midpoint (the curved edge passes 0.45 above this point)
    edge_annotations.append(dict(
        x=mid_x, y=mid_y, text=f"{weight}",
        showarrow=False, font=dict(size=16, weight='bold'), align="center"
    ))

# Create the figure
fig = go.Figure(data=edge_traces + [node_trace],
                layout=go.Layout(
                    title=dict(
                        text='PID Graph: Number of Nodes and Connections',
                        font=dict(size=20),
                        x=0.5, y=0.98,  # Center-align title
                        xanchor='center', yanchor='top'
                    ),
                    showlegend=False,
                    hovermode='closest',
                    margin=dict(b=20, l=5, r=5, t=40),
                    annotations=edge_annotations,  # Add edge annotations for weights
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    width=1500, height=1000,
                    plot_bgcolor='white',
                    paper_bgcolor='white'
                ))

# Show the figure
fig.show()

In [None]:
import pandas as pd
import plotly.graph_objs as go
import networkx as nx
import numpy as np

# Custom color palette
custom_colors = ['#243B54', '#00B1E2', '#5B88B9', '#46BCAB', '#90D7CD', '#BC2B66']

# Node types and colors
node_types = ['Publication', 'Dataset', 'Software', 'Person', 'Organization', 'Funder']
color_map = dict(zip(node_types, custom_colors))

# Create a NetworkX graph
G = nx.Graph()

# Add nodes with attributes (assuming 'nodes' is a DataFrame)
for node in nodes.itertuples():
    G.add_node(node.id, label=node.id, size=int(node.count), color=color_map.get(node.id, '#7f7f7f'))

# Add edges with attributes (assuming 'edges' is a DataFrame)
for edge in edges.itertuples():
    G.add_edge(edge.source, edge.target, weight=int(edge.count))

# Generate circular layout positions
pos = nx.circular_layout(G)
scale_factor = 0.6  # Reduce the radius of the circular layout
pos = {node: (coords[0] * scale_factor, coords[1] * scale_factor) for node, coords in pos.items()}

# Extract node properties
node_x, node_y = zip(*[pos[node] for node in G.nodes])
node_size = [G.nodes[node]['size'] for node in G.nodes]
node_color = [G.nodes[node]['color'] for node in G.nodes]
node_text = [f"{node} <br>({G.nodes[node]['size']})" for node in G.nodes]
node_hovertext = [f"{node}<br>Count: {G.nodes[node]['size']}" for node in G.nodes]

# Normalize node sizes
max_size = max(node_size)
node_size = [((size / max_size) * 60) + 150 for size in node_size]

# Create node traces with text labels
node_trace = go.Scatter(
    x=node_x,
    y=node_y,
    mode='markers+text',
    text=node_text,
    hovertext=node_hovertext,
    hoverinfo='text',
    marker=dict(
        showscale=False,
        color=node_color,
        size=node_size,
        line_width=2,
        opacity=1.0
    ),
    textposition='middle center',
    textfont=dict(
        size=18,
        family='Arial',
        color='White',
        weight='bold'
    )
)

# Create edge traces with curved lines based on edge weight
edge_traces = []
edge_weights = [e[2]['weight'] for e in G.edges(data=True)]
max_weight = max(edge_weights)
min_weight = min(edge_weights)

for edge in G.edges(data=True):
    x0, y0 = pos[edge[0]]
    x1, y1 = pos[edge[1]]
    weight = edge[2]['weight']

    # Adjust scaling factor for edge thickness
    if max_weight != min_weight:
        width = ((weight - min_weight) / (max_weight - min_weight)) * 20 + 15  # Scale between 15 and 35
    else:
        width = 2  # Default width if all weights are the same

    # Generate slight curve for edges (using quadratic Bezier curve approximation)
    t = np.linspace(0, 1, 100)
    x_mid = (x0 + x1) / 2
    y_mid = (y0 + y1) / 2 + 0.2  # Adjust curve height as needed
    x_values = (1 - t) ** 2 * x0 + 2 * (1 - t) * t * x_mid + t ** 2 * x1
    y_values = (1 - t) ** 2 * y0 + 2 * (1 - t) * t * y_mid + t ** 2 * y1
    edge_color = '#888'

    # Create edge trace
    edge_trace = go.Scatter(
        x=x_values,
        y=y_values,
        line=dict(width=width, color=edge_color),
        hoverinfo='text',
        hovertext=f"{edge[0]} - {edge[1]}<br>Connections: {weight}",
        mode='lines'
    )
    edge_traces.append(edge_trace)

# Create edge text annotations with offsets for better legibility
edge_annotations = []
for edge in G.edges(data=True):
    x0, y0 = pos[edge[0]]
    x1, y1 = pos[edge[1]]

    # Calculate midpoint for text placement
    x_mid = (x0 + x1) / 2
    y_mid = (y0 + y1) / 2

    # Offset text perpendicular to the edge for better legibility
    offset_x = (y1 - y0) * 0.1  # Small perpendicular offset
    offset_y = (x0 - x1) * 0.1  # Small perpendicular offset

    weight = edge[2]['weight']

    # Add text annotation with offset
    edge_annotations.append(dict(
        x=x_mid + offset_x,
        y=y_mid + offset_y,
        text=f"{weight}",
        showarrow=False,
        font=dict(size=16, color='black', family='Arial', weight='bold'),
        align="center",
        bgcolor="rgba(255, 255, 255, 0.7)"  # Semi-transparent background
    ))

# Create the figure
fig = go.Figure(data=edge_traces + [node_trace],
                layout=go.Layout(
                    title=dict(
                        text='PID Graph: Number of Nodes and Connections',
                        font=dict(size=20),
                        x=0.5, y=0.98,  # Center-align title
                        xanchor='center', yanchor='top'
                    ),
                    showlegend=False,
                    hovermode='closest',
                    margin=dict(b=20, l=5, r=5, t=40),
                    annotations=edge_annotations,  # Add edge annotations for weights
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    width=1500, height=1000,
                    plot_bgcolor='white',
                    paper_bgcolor='white'
                ))

# Show the figure
fig.show()

In [None]:
import plotly.io as pio

pio.write_html(fig, file='graph.html', auto_open=True)

In [None]:
import pandas as pd
import plotly.graph_objs as go
import networkx as nx
import numpy as np

# Custom color palette
custom_colors = ['#243B54', '#00B1E2', '#5B88B9', '#46BCAB', '#90D7CD', '#BC2B66']

# Node types and colors
node_types = ['Publication', 'Dataset', 'Software', 'Person', 'Organization', 'Funder']
color_map = dict(zip(node_types, custom_colors))

# Create a NetworkX graph
G = nx.Graph()

# Add nodes with attributes (assuming 'nodes' is a DataFrame)
for node in nodes.itertuples():
    G.add_node(node.id, label=node.id, size=int(node.count), color=color_map.get(node.id, '#7f7f7f'))

# Add edges with attributes (assuming 'edges' is a DataFrame)
for edge in edges.itertuples():
    G.add_edge(edge.source, edge.target, weight=int(edge.count))

# Define arc positions: nodes are placed along the x-axis at equal intervals
node_order = list(G.nodes())  # Get node order
n_nodes = len(node_order)
x_positions = np.linspace(0, 1, n_nodes)  # Equally spaced positions between 0 and 1
y_positions = [0] * n_nodes  # All nodes are placed along y=0

# Assign positions to each node in the graph
pos = {node: (x, y) for node, x, y in zip(node_order, x_positions, y_positions)}

# Extract node properties
node_x, node_y = zip(*[pos[node] for node in G.nodes])
node_size = [G.nodes[node]['size'] for node in G.nodes]
node_color = [G.nodes[node]['color'] for node in G.nodes]
node_text = [f"{node} <br>({G.nodes[node]['size']})" for node in G.nodes]
node_hovertext = [f"{node}<br>Count: {G.nodes[node]['size']}" for node in G.nodes]

# Normalize node sizes
max_size = max(node_size)
node_size = [((size / max_size) * 60) + 150 for size in node_size]

# Create node traces with text labels
node_trace = go.Scatter(
    x=node_x,
    y=node_y,
    mode='markers+text',
    text=node_text,
    hovertext=node_hovertext,
    hoverinfo='text',
    marker=dict(
        showscale=False,
        color=node_color,
        size=node_size,
        line_width=2,
        opacity=1.0
    ),
    textposition='middle center',
    textfont=dict(
        size=18,
        family='Arial',
        color='White',
        weight='bold'
    )
)

# Create edge traces with arc curvature based on edge weight
edge_traces = []
edge_annotations = []  # Annotations are built in the same loop as the edge traces
edge_weights = [e[2]['weight'] for e in G.edges(data=True)]
max_weight = max(edge_weights)
min_weight = min(edge_weights)

for edge in G.edges(data=True):
    x0, y0 = pos[edge[0]]
    x1, y1 = pos[edge[1]]
    weight = edge[2]['weight']

    # Adjust scaling factor for edge thickness
    if max_weight != min_weight:
        width = ((weight - min_weight) / (max_weight - min_weight)) * 20 + 8  # Scale between 8 and 28 based on weights
    else:
        width = 2  # Default width if all weights are the same

    # Create arc curves for edges, curve height depends on node distance
    t = np.linspace(0, 1, 100)
    x_mid = (x0 + x1) / 2  # Midpoint for the x-axis
    curve_height = abs(x1 - x0) / 2  # Height of the arc proportional to distance between nodes
    y_mid = y0 + curve_height  # Midpoint for the y-axis, curve height added

    x_values = (1 - t) ** 2 * x0 + 2 * (1 - t) * t * x_mid + t ** 2 * x1
    y_values = (1 - t) ** 2 * y0 + 2 * (1 - t) * t * y_mid + t ** 2 * y1

    # Use a more visible color for edges
    edge_color = '#888'

    # Create edge trace
    edge_trace = go.Scatter(
        x=x_values,
        y=y_values,
        line=dict(width=width, color=edge_color, dash='solid'),  # Change 'dash' to other styles like 'dot', 'dashdot'
        hoverinfo='text',
        hovertext=f"{edge[0]} - {edge[1]}<br>Connections: {weight}",
        mode='lines'
    )
    edge_traces.append(edge_trace)

    # Find the highest point on the curve (maximum y-value)
    max_y_index = np.argmax(y_values)  # Index of the highest point along the curve
    top_x = x_values[max_y_index]  # X position at the highest point
    top_y = y_values[max_y_index]  # Y position at the highest point

    # Add text annotation at the highest point of the curve
    edge_annotations.append(dict(
        x=top_x, y=top_y, text=f"{weight}",
        showarrow=False,
        font=dict(size=22, weight='bold'),
        align="center"
    ))

# Create the figure
fig = go.Figure(data=edge_traces + [node_trace],
                layout=go.Layout(
                    title=dict(
                        text='PID Graph: Number of Nodes and Connections',
                        font=dict(size=20),
                        x=0.5, y=0.98,  # Center-align title
                        xanchor='center', yanchor='top'
                    ),
                    showlegend=False,
                    hovermode='closest',
                    margin=dict(b=20, l=5, r=5, t=40),
                    annotations=edge_annotations,  # Add edge annotations for weights
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    width=2000, height=900,
                    plot_bgcolor='white',
                    paper_bgcolor='white'
                ))

# Show the figure
fig.show()

In [None]:
import plotly.io as pio

pio.write_html(fig, file='graph.html', auto_open=True)