In [None]:
# Install/upgrade plotly and networkx in the notebook environment.
# NOTE(review): versions are unpinned, so each run may pick up a newer release — consider pinning.
!pip install plotly networkx --upgrade

In [None]:
# Import python packages
import streamlit as st
import pandas as pd
import streamlit as st
import networkx as nx
import plotly

import warnings; warnings.simplefilter('ignore')

# We can also use Snowpark for our analyses!
# Obtain the Snowpark session attached to this notebook and set its current
# database/schema to TEMP.PUBLIC, so unqualified table names used with
# `session.table(...)` later in the notebook resolve there.
from snowflake.snowpark.context import get_active_session
session = get_active_session()
session.use_database('TEMP')
session.use_schema('PUBLIC')


In [None]:
-- Use ACCOUNTADMIN to set up the roles, grants and aggregate table below
USE ROLE ACCOUNTADMIN;

-- Create a consumer role for users of the GDS application:
-- it may use/operate warehouse WH_XS and holds the application's app_user role
CREATE ROLE IF NOT EXISTS gds_role;
GRANT USAGE, OPERATE ON WAREHOUSE WH_XS TO ROLE GDS_ROLE;
GRANT APPLICATION ROLE NEO4J_GRAPH_ANALYTICS.app_user TO ROLE gds_role;

-- Create a consumer role for administrators of the GDS application
CREATE ROLE IF NOT EXISTS gds_admin_role;
GRANT APPLICATION ROLE NEO4J_GRAPH_ANALYTICS.app_admin TO ROLE gds_admin_role; 
 
-- Give the application and gds_role access to database TEMP and its PUBLIC schema.
-- NOTE(review): the second grant names the schema as plain PUBLIC (not TEMP.PUBLIC),
-- so it resolves against the session's current database — confirm that is TEMP.
GRANT ALL ON DATABASE TEMP TO APPLICATION NEO4J_GRAPH_ANALYTICS;
GRANT ALL ON SCHEMA PUBLIC TO APPLICATION NEO4J_GRAPH_ANALYTICS;
GRANT USAGE ON DATABASE TEMP TO role gds_role;

GRANT USAGE ON SCHEMA TEMP.PUBLIC TO APPLICATION NEO4J_GRAPH_ANALYTICS;
GRANT USAGE ON SCHEMA TEMP.PUBLIC TO role gds_role;

-- (Re)build the aggregated transaction table: one row per
-- (sourceNodeId, targetNodeId) pair with the summed transaction_amount
-- taken from temp.public.P2P_TRANSACTIONS.
CREATE OR REPLACE TABLE temp.public.P2P_AGG_TRANSACTIONS (
	SOURCENODEID NUMBER(38,0),
	TARGETNODEID NUMBER(38,0),
	TOTAL_AMOUNT FLOAT
) AS
SELECT sourceNodeId, targetNodeId, SUM(transaction_amount) AS total_amount
FROM temp.public.P2P_TRANSACTIONS
GROUP BY sourceNodeId, targetNodeId;

-- Table-level grants are issued after the CREATE so the table built above is
-- covered by "ALL TABLES"; future tables are covered for gds_role only.
-- NOTE(review): the "ALL TABLES ... TO role gds_role" grant appears twice below.
GRANT ALL ON ALL TABLES IN SCHEMA TEMP.PUBLIC TO APPLICATION NEO4J_GRAPH_ANALYTICS;
GRANT ALL ON ALL TABLES IN SCHEMA TEMP.PUBLIC TO role gds_role;
GRANT ALL ON SCHEMA TEMP.PUBLIC to role gds_role;
GRANT ALL ON ALL TABLES IN SCHEMA TEMP.PUBLIC to role gds_role;
GRANT ALL PRIVILEGES ON FUTURE TABLES IN SCHEMA TEMP.PUBLIC TO role gds_role;
-- Role hierarchy: gds_role is granted to gds_admin_role, which is granted to sysadmin
GRANT ROLE gds_role to role gds_admin_role;
GRANT ROLE gds_admin_role to role sysadmin;




In [None]:
-- Start a Neo4j Graph Analytics session before calling any gds.* function.
-- NOTE(review): 'CPU_X64_XS' is presumably the compute pool / size selector — confirm against the application docs.
CALL NEO4J_GRAPH_ANALYTICS.GDS.CREATE_SESSION('CPU_X64_XS');

In [None]:
-- Preview up to ten rows of the P2P_USERS node table.
SELECT *
FROM P2P_USERS
LIMIT 10;

In [None]:
-- Preview up to ten rows of the P2P_W_SHARED_CARD relationship table.
SELECT *
FROM P2P_W_SHARED_CARD
LIMIT 10;

In [None]:
-- Drop any existing 'entity_linking_graph' projection before it is re-created below.
-- NOTE(review): 'failIfMissing': false presumably makes this a no-op when the
-- graph does not exist yet — confirm against the application docs.
SELECT NEO4J_GRAPH_ANALYTICS.gds.graph_drop(
    'entity_linking_graph',
    { 'failIfMissing': false }
);

In [None]:
import networkx as nx
import plotly.graph_objects as go


def size_scale(lst, bounds=(5, 10)):
    """Linearly rescale ``lst`` so its minimum maps to ``bounds[0]`` and its maximum to ``bounds[1]``.

    An empty (falsy) input yields ``[]``. When every value is identical the
    span is treated as 1, so every output equals ``bounds[0]``.
    """
    if not lst:
        return []

    largest = max(lst)
    smallest = min(lst)
    # A constant input would give a zero span; use 1 to avoid dividing by zero.
    span = 1 if largest == smallest else largest - smallest

    scaled = []
    for value in lst:
        fraction = (value - smallest) / span
        scaled.append((bounds[1] - bounds[0]) * fraction + bounds[0])
    return scaled

def make_graph_from_wcc_ids(wcc_ids, scale_prop="CENTRALITY"):
    """Build a directed graph of the users that belong to the given WCC communities.

    Loads the ``P2P_USERS`` and ``P2P_W_SHARED_CARD`` tables through the
    notebook-level Snowpark ``session``, keeps the users whose ``WCC_ID`` is in
    ``wcc_ids`` and the edges whose two endpoints are both among those users.

    Args:
        wcc_ids: Collection of ``WCC_ID`` values to include.
        scale_prop: Name of the user column copied onto every node (the
            plotting code uses it to size markers).

    Returns:
        A ``networkx.DiGraph`` whose nodes carry ``NODEID``, ``scale_prop`` and
        ``FRAUD_TRANSFER_FLAG`` attributes. Users without any qualifying edge
        are included as isolated nodes.

    Raises:
        KeyError: If ``NODEID``, ``WCC_ID``, ``scale_prop`` or
            ``FRAUD_TRANSFER_FLAG`` is missing from the user table.
    """
    # The full edge table is still published as a module-level name, as the
    # original implementation did, in case other cells read it.
    global e_df
    user_df = session.table('P2P_USERS').to_pandas()
    e_df = session.table('P2P_W_SHARED_CARD').to_pandas()

    # Validate up front so a table that has not been enriched yet (no WCC_ID /
    # centrality columns) fails with a clear KeyError instead of an
    # AttributeError from attribute-style column access.
    required_columns = ('NODEID', 'WCC_ID', scale_prop, 'FRAUD_TRANSFER_FLAG')
    if any(column not in user_df.columns for column in required_columns):
        raise KeyError(f"Missing required columns in user_df. Available columns: {user_df.columns.tolist()}")

    n_df = user_df[user_df.WCC_ID.isin(wcc_ids)]

    user_ids = n_df.NODEID.tolist()
    e_df_filtered = e_df[(e_df.SOURCENODEID.isin(user_ids)) & (e_df.TARGETNODEID.isin(user_ids))]

    G = nx.from_pandas_edgelist(e_df_filtered, source='SOURCENODEID', target='TARGETNODEID', create_using=nx.DiGraph())

    # from_pandas_edgelist only creates nodes that appear in an edge; add the
    # remaining community members so singleton communities are not drawn empty.
    G.add_nodes_from(user_ids)

    # Assign attributes to nodes
    attributes = dict(zip(n_df.NODEID, n_df[['NODEID', scale_prop, 'FRAUD_TRANSFER_FLAG']].to_dict(orient="records")))
    nx.set_node_attributes(G, attributes)

    return G

def plot_graph(G, title="Subgraph", scale_prop="CENTRALITY", seed=None):
    """Draw ``G`` as a Plotly network figure with fraud-flagged nodes in red.

    Args:
        G: networkx graph; nodes may carry ``scale_prop`` and
            ``FRAUD_TRANSFER_FLAG`` attributes.
        title: Figure title.
        scale_prop: Node attribute used to size the markers (a node lacking it
            is treated as 1). Sizes are rescaled into the 10-30 range.
        seed: Optional seed forwarded to ``nx.spring_layout`` for a
            reproducible layout; ``None`` (default) keeps the random layout.

    Returns:
        A ``plotly.graph_objects.Figure`` holding one edge trace and one node trace.
    """
    pos = nx.spring_layout(G, seed=seed)

    # Edge traces: each edge contributes its two endpoints followed by None,
    # which breaks the line so edges are drawn as separate segments.
    edge_x, edge_y = [], []
    for source, target in G.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(x=edge_x, y=edge_y, line=dict(width=0.5, color='#888'), hoverinfo='none', mode='lines')

    # Node traces
    node_x, node_y, node_size, node_color, node_text = [], [], [], [], []
    for node, attrs in G.nodes(data=True):
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        # Hover text is the raw attribute dict of the node.
        node_text.append(str(attrs))

        # Get centrality value for scaling
        node_size.append(round(attrs.get(scale_prop, 1), 3))

        # Color nodes based on fraud flag
        fraud_flag = attrs.get('FRAUD_TRANSFER_FLAG', 0)
        node_color.append('red' if fraud_flag == 1 else 'blue')

    node_size = size_scale(node_size, (10, 30))

    node_trace = go.Scatter(
        x=node_x, y=node_y, mode='markers', hoverinfo='text',
        marker=dict(
            size=node_size,
            color=node_color,
            line=dict(width=2, color='DarkSlateGrey')
        ),
        text=node_text
    )

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            # `titlefont_size` is a deprecated layout attribute that newer
            # Plotly releases reject; the nested title.font form is the
            # supported spelling.
            title=dict(text=title, font=dict(size=16)),
            showlegend=False, hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
        )
    )

    return fig

def plot_pr_graph_from_wcc_ids(wcc_ids):
    """Build the graph for the given WCC ids and return its Plotly figure.

    Nodes are sized by the ``CENTRALITY`` column and the figure title embeds
    ``wcc_ids``.
    """
    size_column = "CENTRALITY"
    community_graph = make_graph_from_wcc_ids(wcc_ids, size_column)
    figure_title = f"WCC Community: {wcc_ids}"
    return plot_graph(community_graph, title=figure_title, scale_prop=size_column)



In [None]:
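# Example call, left commented out: it needs the WCC_ID and CENTRALITY columns,
# which are only added to P2P_USERS by a later cell in this notebook.
#plot_pr_graph_from_wcc_ids([433])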

In [None]:
-- Project P2P users (node table) linked via shared cards (relationship table)
-- into a graph named 'entity_linking_graph'.
SELECT neo4j_graph_analytics.gds.graph_project(
    'entity_linking_graph',
    {
        'nodeTable': 'temp.public.p2p_users',
        'relationshipTable': 'temp.public.p2p_w_shared_card'
    }
);

In [None]:
-- Compute weakly connected components (WCC) on 'entity_linking_graph',
-- passing 'wcc_id' as the mutateProperty name.
SELECT neo4j_graph_analytics.gds.wcc(
    'entity_linking_graph',
    { 'mutateProperty': 'wcc_id' }
);

In [None]:
-- Write the 'wcc_id' node property to table temp.public.P2P_COMPONENTS.
SELECT neo4j_graph_analytics.gds.write_nodeproperties(
    'entity_linking_graph',
    {
        'nodeProperties': ['wcc_id'],
        'table': 'temp.public.P2P_COMPONENTS'
    }
);

In [None]:
-- Run PageRank on 'entity_linking_graph', passing 'score' as the mutateProperty name.
SELECT NEO4J_GRAPH_ANALYTICS.gds.page_rank(
    'entity_linking_graph',
    { 'mutateProperty': 'score' }
);

In [None]:
-- Write the 'score' node property to table temp.public.shared_card_transaction_pagerank.
SELECT NEO4J_GRAPH_ANALYTICS.gds.write_nodeproperties(
    'entity_linking_graph',
    {
        'nodeProperties': ['score'],
        'table': 'temp.public.shared_card_transaction_pagerank'
    }
);

In [None]:
-- Enrich p2p_users with the graph results written by the earlier cells.
-- Both steps are re-runnable: the column is added only if missing and the
-- UPDATE overwrites the value for every row matched on nodeId.

-- 1) centrality: the PageRank score rounded to 3 decimals
ALTER TABLE p2p_users
ADD COLUMN IF NOT EXISTS centrality float;
UPDATE p2p_users
SET centrality = ROUND(shared_card_transaction_pagerank.score, 3)
FROM shared_card_transaction_pagerank
WHERE p2p_users.nodeId = shared_card_transaction_pagerank.nodeId;

-- 2) wcc_id: the weakly-connected-component id from p2p_components
ALTER TABLE p2p_users
ADD COLUMN IF NOT EXISTS wcc_id INT; 
UPDATE p2p_users
SET wcc_id = p2p_components.wcc_id
FROM p2p_components
WHERE p2p_users.nodeId = p2p_components.nodeId;

In [None]:
-- Show the ten p2p_users rows that come first when ordered by wcc_id.
SELECT *
FROM p2p_users
ORDER BY wcc_id
LIMIT 10;

In [None]:
-- One row per WCC community: member count, number of members with a positive
-- fraud_transfer_flag, and the array of member node ids.
CREATE OR REPLACE VIEW resolved_p2p_users AS
SELECT
    p2p_components.wcc_id,
    COUNT(*) AS user_count,
    TO_NUMBER(SUM(CASE WHEN p2p_users.fraud_transfer_flag > 0 THEN 1 ELSE 0 END)) AS fraud_flags,
    ARRAY_AGG(p2p_users.nodeId) AS user_ids
FROM p2p_users
JOIN p2p_components
    ON p2p_users.nodeId = p2p_components.nodeId
GROUP BY p2p_components.wcc_id
ORDER BY fraud_flags DESC;


In [None]:
-- List every community, those with the most fraud-flagged members first.
SELECT *
FROM resolved_p2p_users
ORDER BY fraud_flags DESC;

In [None]:
-- Preview up to ten rows of P2P_USERS.
SELECT *
FROM P2P_USERS
LIMIT 10;

In [None]:
# Plot the community whose WCC_ID is 4016.
# NOTE(review): the id is hard-coded — presumably picked from the community listing; confirm it exists in your data.
plot_pr_graph_from_wcc_ids([4016])

In [None]:
import streamlit as st

# Interactive explorer: choose a WCC community from the resolved_p2p_users view
# and render its graph.
st.subheader("Explore Graph Communities")
resolved_p2p_users = session.table('resolved_p2p_users').to_pandas()


selected_comm_id = st.selectbox(
   "Select a Community",
   resolved_p2p_users['WCC_ID'].tolist(),
   placeholder="Select WCC ID...",
)

# Compare against None explicitly: a WCC id of 0 is a valid selection but is
# falsy, so a plain truthiness test would silently skip that community.
if selected_comm_id is not None:
    comm_ids = resolved_p2p_users[resolved_p2p_users['WCC_ID'] == selected_comm_id]['WCC_ID'].tolist()
    st.plotly_chart(plot_pr_graph_from_wcc_ids(comm_ids))


In [None]:
-- Stop the Neo4j Graph Analytics session once the analysis is finished
CALL neo4j_graph_analytics.gds.stop_session();