# GNN Exploration — Holley Vehicle Fitment Recommendations

1. Load all nodes/edges from BigQuery
2. Graph statistics: degree distributions, sparsity, connected components
3. Cold-start analysis: % of users with 0 interaction edges vs. % of users with vehicle edges (expected: ~100%)
4. Prototype a graph attention model (`HolleyGAT`) on the heterogeneous user/product/vehicle graph
5. Visualize sample subgraphs (cold/warm/hot users)

In [None]:
import sys
sys.path.insert(0, '..')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from collections import Counter

from src.gnn.data_loader import GNNDataLoader
from src.gnn.graph_builder import HolleyGraphBuilder

## 1. Load Data

In [None]:
# Pull every node/edge table; `data` maps table name -> table
# (presumably pandas DataFrames — later cells index them by column).
loader = GNNDataLoader()
data = loader.load_all()

print("Table sizes:")
for table_name, table in data.items():
    row_count = len(table)
    print(f"  {table_name}: {row_count:,} rows")

In [None]:
# Short aliases for the node and edge tables used throughout the notebook.
users, products, vehicles = (
    data['user_nodes'],
    data['product_nodes'],
    data['vehicle_nodes'],
)
edges_up, edges_pv, edges_uv, edges_pp = (
    data['edges_user_product'],
    data['edges_product_vehicle'],
    data['edges_user_vehicle'],
    data['edges_product_product'],
)

# Report sizes in a fixed order: nodes first, then each edge type.
size_report = (
    ("Users", users),
    ("Products", products),
    ("Vehicles", vehicles),
    ("Edges user→product", edges_up),
    ("Edges product→vehicle", edges_pv),
    ("Edges user→vehicle", edges_uv),
    ("Edges product↔product", edges_pp),
)
for label, table in size_report:
    print(f"{label}: {len(table):,}")

## 2. Graph Statistics

In [None]:
# Engagement tier distribution, as a share of all user nodes.
tier_counts = users['engagement_tier'].value_counts()
print("Engagement tiers:")
for tier, count in tier_counts.items():
    print(f"  {tier}: {count:,} ({count/len(users)*100:.1f}%)")

# Sparsity: fraction of all possible (src, dst) pairs that carry at least one
# edge. Duplicate pairs are collapsed first so repeated rows don't inflate it.
density_specs = [
    ("user→product", edges_up, ['user_id', 'sku'], len(users) * len(products)),
    ("product→vehicle", edges_pv, ['sku', 'vehicle_id'], len(products) * len(vehicles)),
]
print("\nEdge density (unique pairs / possible pairs):")
for label, edges, pair_cols, possible in density_specs:
    unique_pairs = len(edges[pair_cols].drop_duplicates())
    density = unique_pairs / possible if possible else float('nan')
    print(f"  {label}: {unique_pairs:,} / {possible:,} = {density:.2e}")

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# User→product degree distribution. groupby only yields users that have at
# least one edge, so zero-degree (cold) users are not in this histogram; they
# are counted separately in the cold-start section.
user_degrees = edges_up.groupby('user_id').size()
axes[0].hist(user_degrees, bins=50, edgecolor='black')
axes[0].set_title('User→Product Degree Distribution')
axes[0].set_xlabel('Degree')
axes[0].set_yscale('log')

# Product→vehicle degree (fitment breadth); products with no fitment edges
# are likewise absent.
pv_degrees = edges_pv.groupby('sku').size()
axes[1].hist(pv_degrees, bins=50, edgecolor='black')
axes[1].set_title('Product→Vehicle Degree (Fitment Breadth)')
axes[1].set_xlabel('Degree')
axes[1].set_yscale('log')

# Co-purchase edge weight distribution
axes[2].hist(edges_pp['weight'], bins=50, edgecolor='black')
axes[2].set_title('Co-Purchase Edge Weight Distribution')
axes[2].set_xlabel('Weight (log(1+count))')

plt.tight_layout()
plt.show()

## 3. Cold-Start Analysis

In [None]:
# Users with interaction edges vs vehicle edges.
# Edge-table user_ids are intersected with the node table so every percentage
# below shares the same denominator; otherwise orphan ids (present in edges
# but not in user_nodes) could push a share above 100%.
all_users = set(users['user_id'].unique())
edge_user_ids = set(edges_up['user_id'].unique()) | set(edges_uv['user_id'].unique())
orphan_user_ids = edge_user_ids - all_users

users_with_interactions = set(edges_up['user_id'].unique()) & all_users
users_with_vehicles = set(edges_uv['user_id'].unique()) & all_users

cold_users = all_users - users_with_interactions
cold_users_with_vehicles = cold_users & users_with_vehicles


def _pct(count, total):
    """Return ``count`` as a percentage of ``total``; 0.0 when ``total`` is 0."""
    return count / total * 100 if total else 0.0


n_total = len(all_users)
if orphan_user_ids:
    print(f"Warning: {len(orphan_user_ids):,} user_ids in edge tables are missing from user_nodes (excluded)")

print(f"Total users: {n_total:,}")
print(f"Users with interaction edges: {len(users_with_interactions):,} ({_pct(len(users_with_interactions), n_total):.1f}%)")
print(f"Users with vehicle edges: {len(users_with_vehicles):,} ({_pct(len(users_with_vehicles), n_total):.1f}%)")
print(f"Cold users (no interactions): {len(cold_users):,} ({_pct(len(cold_users), n_total):.1f}%)")
print(f"\nKey insight: {_pct(len(cold_users), n_total):.0f}% of users are cold-start for interactions")
print(f"But {_pct(len(users_with_vehicles), n_total):.0f}% have vehicle edges → GNN can propagate info via vehicle→product→user")
# The figure the insight actually relies on: cold users reachable via a vehicle.
print(f"Cold users with vehicle edges: {len(cold_users_with_vehicles):,} ({_pct(len(cold_users_with_vehicles), len(cold_users)):.1f}% of cold users)")

## 4. Build PyG Graph & Prototype Training

In [None]:
# Turn the raw tables into the heterogeneous graph object used for training
# (a PyG graph, per this section's heading).
builder = HolleyGraphBuilder()
hetero = builder.build(data)

print(hetero)
part_type_count = builder.num_part_types
print(f"\nNum part types: {part_type_count}")

In [None]:
# Quick CPU training smoke test on the graph built above (small subset or full).
from src.gnn.model import HolleyGAT
from src.gnn.trainer import GNNTrainer, TrainConfig

# Embedding table sizes come straight from the per-type node counts.
num_user_nodes = hetero['user'].num_nodes
num_product_nodes = hetero['product'].num_nodes
num_vehicle_nodes = hetero['vehicle'].num_nodes

model = HolleyGAT(
    num_users=num_user_nodes,
    num_products=num_product_nodes,
    num_vehicles=num_vehicle_nodes,
    num_part_types=builder.num_part_types,
)

trainer = GNNTrainer(model, hetero, TrainConfig(epochs=10, patience=5), device='cpu')
history = trainer.train()

# One line per tracked loss curve.
plt.figure(figsize=(8, 4))
for history_key, curve_label in (('train_loss', 'Train'), ('val_loss', 'Val')):
    plt.plot(history[history_key], label=curve_label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('GNN Training Loss')
plt.show()

## 5. Visualize Sample Subgraphs

In [None]:
def visualize_user_subgraph(user_id, data, builder, ax, title,
                            max_products=10, max_vehicle_products=5):
    """Draw a user's local neighbourhood onto a matplotlib axis.

    The drawn graph contains the user node, up to ``max_products`` distinct
    SKUs from the user's interaction edges, every vehicle linked to the user,
    and up to ``max_vehicle_products`` distinct SKUs linked to each of those
    vehicles.

    Args:
        user_id: Value matched against the ``user_id`` column of the
            user→product and user→vehicle edge tables.
        data: Mapping of table name -> table; reads ``edges_user_product``,
            ``edges_user_vehicle`` and ``edges_product_vehicle``.
        builder: Not used by this function; kept for call compatibility.
        ax: Matplotlib axis to draw on.
        title: Title set on ``ax``.
        max_products: Cap on interacted SKUs drawn (for readability).
        max_vehicle_products: Cap on fitment SKUs drawn per vehicle.

    Node colours: blue = user, red = vehicle, green = product the user
    interacted with, light green = product reached only through a vehicle.
    """
    G = nx.Graph()
    user_node = f"U:{user_id}"
    G.add_node(user_node, color='blue', size=300)

    up = data['edges_user_product']
    uv = data['edges_user_vehicle']
    pv = data['edges_product_vehicle']  # hoisted: same table for every vehicle

    # User→product edges. Deduplicate first so repeated interactions with one
    # SKU don't use up the readability cap.
    user_products = up.loc[up['user_id'] == user_id, 'sku'].drop_duplicates()
    for sku in user_products.iloc[:max_products]:
        G.add_node(f"P:{sku}", color='green', size=100)
        G.add_edge(user_node, f"P:{sku}")

    # User→vehicle edges, plus a few vehicle→product fitment edges per vehicle.
    user_vehicles = uv.loc[uv['user_id'] == user_id, 'vehicle_id'].drop_duplicates()
    for vid in user_vehicles:
        vehicle_node = f"V:{vid}"
        G.add_node(vehicle_node, color='red', size=200)
        G.add_edge(user_node, vehicle_node)

        fitting_products = pv.loc[pv['vehicle_id'] == vid, 'sku'].drop_duplicates()
        for sku in fitting_products.iloc[:max_vehicle_products]:
            product_node = f"P:{sku}"
            if product_node not in G:
                # Keep interacted products green; only new ones get light green.
                G.add_node(product_node, color='lightgreen', size=80)
            G.add_edge(vehicle_node, product_node)

    colors = [G.nodes[n].get('color', 'gray') for n in G.nodes]
    sizes = [G.nodes[n].get('size', 100) for n in G.nodes]
    pos = nx.spring_layout(G, seed=42)  # fixed seed -> reproducible layout
    nx.draw(G, pos, ax=ax, node_color=colors, node_size=sizes,
            with_labels=False, edge_color='gray', alpha=0.7)
    ax.set_title(title)


# Pick one user per tier (the first listed) and draw their neighbourhood.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

for ax, tier in zip(axes, ['cold', 'warm', 'hot']):
    tier_users = users.loc[users['engagement_tier'] == tier, 'user_id']
    if tier_users.empty:
        # Label a missing tier explicitly instead of leaving a blank frame.
        ax.set_title(f"{tier.title()} User (none found)")
        ax.set_axis_off()
        continue
    sample_user = tier_users.iloc[0]
    visualize_user_subgraph(sample_user, data, builder, ax, f"{tier.title()} User")

plt.suptitle(
    'Sample User Subgraphs by Engagement Tier\n'
    'Blue=User, Red=Vehicle, Green=Interacted Product, Light green=Fitment-only Product',
    y=1.02,
)
plt.tight_layout()
plt.show()