In [8]:
import pandas as pd
import networkx as nx
import plotly.graph_objects as go
import random

# Read the three filtered CSV exports. The feature table is read with
# header=None, so its columns initially carry integer positions as names.
classes_df = pd.read_csv('../data/filtered/filtered_classes.csv')
edgelist_df = pd.read_csv('../data/filtered/filtered_edgelist.csv')
features_df = pd.read_csv('../data/filtered/filtered_features.csv', header=None)

# Label the first column 'txId' and every following column feature_1, feature_2, ...
n_columns = features_df.shape[1]
features_df.columns = ['txId', *(f'feature_{idx}' for idx in range(1, n_columns))]

# Create a directed graph
G = nx.DiGraph()

# Index the class labels by transaction id once, so each node's label is a
# constant-time dict lookup instead of a full boolean-mask scan of classes_df
# per feature row. keep='first' preserves the earlier `.values[0]` choice
# should a txId appear more than once in classes_df.
label_by_tx = (
    classes_df.drop_duplicates(subset='txId', keep='first')
    .set_index('txId')['class']
    .to_dict()
)

# Pull ids and feature vectors column-wise. DataFrame.iterrows() coerces each
# row to a single dtype, which can turn integer txIds into floats when the
# feature columns are floats; reading the txId column directly keeps its dtype.
tx_ids = features_df['txId'].tolist()
feature_rows = features_df.iloc[:, 1:].to_numpy().tolist()

# Fail early, with a readable message, if any transaction has no class label
# (previously this surfaced as an IndexError from `.values[0]`).
unlabelled = [tx_id for tx_id in tx_ids if tx_id not in label_by_tx]
if unlabelled:
    raise KeyError(
        f"{len(unlabelled)} transaction(s) in features_df have no entry in "
        f"classes_df, e.g. txId={unlabelled[0]}"
    )

# Add nodes with features and labels
G.add_nodes_from(
    (tx_id, {'features': features, 'label': label_by_tx[tx_id]})
    for tx_id, features in zip(tx_ids, feature_rows)
)

# Add edges (txId1 -> txId2). Endpoints that are not in features_df are still
# created by networkx, but without 'features' / 'label' attributes.
G.add_edges_from(zip(edgelist_df['txId1'].tolist(), edgelist_df['txId2'].tolist()))

# Subsample nodes for visualization
num_nodes = 5000  # Number of nodes to visualize

# random.sample() requires a sequence: passing the NodeView directly is
# deprecated on Python 3.9/3.10 and raises TypeError on 3.11+, so materialise
# the nodes into a list first.
all_nodes = list(G.nodes())

# A dedicated, seeded generator gives the same subsample on every run without
# touching the global random state.
sample_rng = random.Random(42)

# Cap the request at the graph size so a graph smaller than num_nodes is shown
# in full instead of raising ValueError.
sampled_nodes = sample_rng.sample(all_nodes, min(num_nodes, len(all_nodes)))

# .copy() detaches H from G so it is an independent graph, not a view.
H = G.subgraph(sampled_nodes).copy()

# Create positions for the nodes using spring_layout.
# The layout is seeded so the node coordinates (and therefore the rendered
# figure) are the same on every run for the same graph H.
pos = nx.spring_layout(H, seed=42)

# Flatten every edge into the coordinate lists as (start, end, None); the
# trailing None after each pair separates one edge's points from the next.
edge_x, edge_y = [], []
for src, dst in H.edges():
    (src_x, src_y), (dst_x, dst_y) = pos[src], pos[dst]
    edge_x.extend((src_x, dst_x, None))
    edge_y.extend((src_y, dst_y, None))

# Per-node coordinates and class label, all in H.nodes() iteration order so
# the three lists stay aligned index by index.
node_x = [pos[n][0] for n in H.nodes()]
node_y = [pos[n][1] for n in H.nodes()]
node_color = [H.nodes[n]['label'] for n in H.nodes()]

# Create the edge trace: a single line trace drawn from edge_x / edge_y
edge_trace = go.Scatter(
    x=edge_x, y=edge_y,
    line=dict(width=0.5, color='#888'),
    hoverinfo='none',
    mode='lines')

# Hover text for each node: its transaction id and class label
node_text = [
    f"TxId: {node}<br>Class: {H.nodes[node]['label']}"
    for node in H.nodes()
]

# Create the node trace. Text and colours are passed to the constructor
# directly instead of being patched onto the trace afterwards.
# NOTE(review): a colorscale maps numeric values - confirm the 'class' column
# holds numbers rather than strings.
node_trace = go.Scatter(
    x=node_x, y=node_y,
    mode='markers',
    hoverinfo='text',
    text=node_text,
    marker=dict(
        showscale=True,
        colorscale='YlGnBu',
        color=node_color,
        size=10,
        colorbar=dict(
            thickness=15,
            # Colours encode the class label (node_color), not node degree, so
            # the colorbar is titled accordingly. The nested title dict replaces
            # the flat `titleside` shorthand, which plotly >= 6 rejects.
            title=dict(text='Transaction Class', side='right'),
            xanchor='left',
        ),
    ))

# Create the figure: edges are drawn first so the node markers sit on top
fig = go.Figure(
    data=[edge_trace, node_trace],
    layout=go.Layout(
        # The title font is set through title.font; the flat `titlefont_size`
        # shorthand is rejected by plotly >= 6.
        title=dict(
            text='Bitcoin Transaction Network (Subsampled)',
            font=dict(size=16),
        ),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        # Caption anchored in paper coordinates near the bottom-left corner
        annotations=[dict(
            text="Subsampled network visualization of Bitcoin transactions",
            showarrow=False,
            xref="paper", yref="paper",
            x=0.005, y=-0.002)],
        xaxis=dict(showgrid=False, zeroline=False),
        yaxis=dict(showgrid=False, zeroline=False),
    ),
)

fig.show()


In [4]:
# Total number of nodes in the full (un-sampled) graph
G.number_of_nodes()

46564