# In [None]:
import os
import numpy as np
import pandas as pd
import torch
# import torch.nn.functional as F
from pyspark.sql import functions as F
import torch_geometric
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from sklearn.metrics import classification_report, precision_score, recall_score, f1_score
from pyspark.sql.types import DoubleType
from pyspark.sql.functions import col


# Module-wide Spark session; SparkGNN.load_data reads its CSV inputs through it.
spark = (
    SparkSession.builder
    .appName("TransactionClassificationGNN")
    .getOrCreate()
)

# Limit on how many fields Spark prints when it renders a plan as a string.
spark.conf.set("spark.sql.debug.maxToStringFields", "100")

# Seed both NumPy and PyTorch so random draws (e.g. the node split
# permutation) are repeatable between runs.
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

class SparkGNN:
    """Train and evaluate a two-layer GCN on a transaction graph loaded via Spark.

    The pipeline reads three CSV files from ``input_dir`` (node features, node
    classes and an edge list), assembles them into a ``torch_geometric`` ``Data``
    object, trains a GCN on the labelled nodes and reports metrics on a held-out
    subset of them.
    """

    # Integer label codes written by ``preprocess_data``.  The codes are fixed
    # explicitly (instead of being derived from class frequencies) so that the
    # known-node mask and the report's class names always refer to the same
    # classes.  UNKNOWN_LABEL lies outside ``[0, num_classes)`` on purpose.
    ILLICIT_LABEL = 0
    LICIT_LABEL = 1
    UNKNOWN_LABEL = -1

    def __init__(self, input_dir, num_features=166, num_classes=2):
        """Store the configuration and pick the torch device.

        Args:
            input_dir: Directory containing the three input CSV files.
            num_features: Number of feature columns that follow ``txId`` in the
                features file.
            num_classes: Number of output classes of the model.
        """
        self.input_dir = input_dir
        self.num_features = num_features
        self.num_classes = num_classes
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def load_data(self):
        """Read the features, classes and edge-list CSV files.

        The features file is read without a header and its columns are renamed
        to ``txId, V1, ..., V<num_features>``; ``toDF`` requires the file to have
        exactly ``num_features + 1`` columns.  A ``class_mapped`` column is added
        to the classes frame: ``'1'`` -> ``'illicit'``, ``'2'`` -> ``'licit'``,
        anything else -> ``'unknown'``.

        Returns:
            Tuple ``(features_df, classes_df, edgelist_df)`` of Spark DataFrames.
            All columns are strings because no schema is inferred.
        """
        features_path = os.path.join(self.input_dir, 'elliptic_txs_features.csv')
        classes_path = os.path.join(self.input_dir, 'elliptic_txs_classes.csv')
        edgelist_path = os.path.join(self.input_dir, 'elliptic_txs_edgelist.csv')

        features_df = spark.read.csv(features_path, header=False)
        classes_df = spark.read.csv(classes_path, header=True)
        edgelist_df = spark.read.csv(edgelist_path, header=True)

        features_cols = ['txId'] + [f'V{i}' for i in range(1, self.num_features + 1)]
        features_df = features_df.toDF(*features_cols)

        classes_df = classes_df.withColumn(
            'class_mapped',
            F.when(F.col('class') == '1', 'illicit')
            .when(F.col('class') == '2', 'licit')
            .otherwise('unknown')
        )

        return features_df, classes_df, edgelist_df

    def preprocess_data(self, features_df, classes_df):
        """Join features with classes and add ``features`` / ``label`` columns.

        Args:
            features_df: Frame with ``txId`` and numeric ``V1..V<num_features>``.
            classes_df: Frame with ``txId`` and ``class_mapped``.

        Returns:
            The joined frame with a ``features`` vector column and a numeric
            ``label`` column holding ILLICIT_LABEL, LICIT_LABEL or UNKNOWN_LABEL.
        """
        feature_cols = [f'V{i}' for i in range(1, self.num_features + 1)]
        # Build the join once and reuse it.
        joined_df = features_df.join(classes_df, 'txId')

        assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')

        # Explicit mapping: a frequency-ordered StringIndexer would hand out
        # codes by class count, so 0/1 would not reliably mean illicit/licit.
        label_column = (
            F.when(F.col('class_mapped') == 'illicit', float(self.ILLICIT_LABEL))
            .when(F.col('class_mapped') == 'licit', float(self.LICIT_LABEL))
            .otherwise(float(self.UNKNOWN_LABEL))
        )

        return assembler.transform(joined_df).withColumn('label', label_column)

    @staticmethod
    def _split_masks(labels, num_classes, train_ratio=0.8, val_ratio=0.1):
        """Randomly split the labelled nodes into train/val/test boolean masks.

        Args:
            labels: 1-D integer tensor of node labels.
            num_classes: Labels in ``[0, num_classes)`` count as labelled; every
                other value is left out of all three masks.
            train_ratio: Fraction of labelled nodes used for training.
            val_ratio: Fraction of labelled nodes used for validation; the
                remainder goes to the test mask.

        Returns:
            Tuple ``(train_mask, val_mask, test_mask)`` of bool tensors with the
            same length as ``labels``.
        """
        num_nodes = labels.numel()
        known_mask = (labels >= 0) & (labels < num_classes)
        # as_tuple=True always yields a 1-D index tensor, including when
        # exactly one node is labelled.
        known_indices = known_mask.nonzero(as_tuple=True)[0]
        known_indices = known_indices[torch.randperm(known_indices.numel())]

        train_size = int(train_ratio * known_indices.numel())
        val_size = int(val_ratio * known_indices.numel())

        train_mask = torch.zeros(num_nodes, dtype=torch.bool)
        val_mask = torch.zeros(num_nodes, dtype=torch.bool)
        test_mask = torch.zeros(num_nodes, dtype=torch.bool)

        train_mask[known_indices[:train_size]] = True
        val_mask[known_indices[train_size:train_size + val_size]] = True
        test_mask[known_indices[train_size + val_size:]] = True

        return train_mask, val_mask, test_mask

    def create_graph_data(self, processed_df, edgelist_df):
        """Build the ``torch_geometric`` graph object from the Spark frames.

        Args:
            processed_df: Output of ``preprocess_data`` (needs ``txId``,
                ``features`` and ``label``).
            edgelist_df: Frame with ``txId1`` and ``txId2`` columns.

        Returns:
            A ``Data`` object on ``self.device`` with ``x``, ``edge_index``,
            ``y`` and ``train_mask`` / ``val_mask`` / ``test_mask``.
        """
        # Collect ids, features and labels in ONE action: separate actions on
        # the same frame are not guaranteed to return rows in the same order,
        # which would misalign node indices, features and labels.
        rows = processed_df.select('txId', 'features', 'label').collect()

        tx_id_mapping = {row['txId']: idx for idx, row in enumerate(rows)}

        # Stack into one ndarray first; building a tensor from a list of
        # arrays is very slow.
        node_features = torch.tensor(
            np.array([row['features'].toArray() for row in rows]),
            dtype=torch.float
        )
        node_labels = torch.tensor(
            [int(row['label']) for row in rows],
            dtype=torch.long
        )

        # Keep only edges whose endpoints are both known nodes.  The filter
        # runs on the driver with the dict that already lives there, rather
        # than embedding every node id into the Spark plan through isin().
        # Edges are kept in the direction given (txId1 -> txId2).
        edge_index_list = [
            (tx_id_mapping[src], tx_id_mapping[dst])
            for src, dst in edgelist_df.select('txId1', 'txId2').collect()
            if src in tx_id_mapping and dst in tx_id_mapping
        ]

        # reshape(-1, 2) keeps the result at shape (2, 0) when no edge survives.
        edge_index = (
            torch.tensor(edge_index_list, dtype=torch.long)
            .reshape(-1, 2)
            .t()
            .contiguous()
        )

        data = Data(x=node_features, edge_index=edge_index, y=node_labels)

        train_mask, val_mask, test_mask = self._split_masks(
            node_labels, self.num_classes
        )
        data.train_mask = train_mask
        data.val_mask = val_mask
        data.test_mask = test_mask

        return data.to(self.device)

    class GCN(torch.nn.Module):
        """Two-layer graph convolutional network with a 16-unit hidden layer."""

        def __init__(self, num_features, num_classes):
            super().__init__()
            self.conv1 = GCNConv(num_features, 16)
            self.conv2 = GCNConv(16, num_classes)

        def forward(self, data):
            """Return per-node log-probabilities for the graph in ``data``."""
            x, edge_index = data.x, data.edge_index
            x = self.conv1(x, edge_index)
            # Use torch's own ops: the module-level name ``F`` is bound to
            # pyspark.sql.functions, which has no relu/log_softmax.
            x = torch.relu(x)
            x = self.conv2(x, edge_index)
            return torch.log_softmax(x, dim=1)

    def train(self, data, num_epochs=100):
        """Train a fresh GCN on the nodes selected by ``data.train_mask``.

        Args:
            data: Graph object produced by ``create_graph_data``.
            num_epochs: Number of full-graph optimisation steps.

        Returns:
            The trained model (left in training mode).
        """
        model = self.GCN(num_features=data.num_features, num_classes=self.num_classes).to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005)
        # The model already returns log-probabilities, so the matching loss is
        # NLLLoss (CrossEntropyLoss would apply log-softmax a second time).
        criterion = torch.nn.NLLLoss()

        model.train()
        for epoch in range(num_epochs):
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out[data.train_mask], data.y[data.train_mask])
            loss.backward()
            optimizer.step()

            if epoch % 10 == 0:
                print(f'Epoch {epoch:03d}, Loss: {loss.item():.4f}')

        return model

    def evaluate(self, model, data):
        """Score ``model`` on the nodes selected by ``data.test_mask``.

        Prints a classification report and returns the weighted precision,
        recall and F1 score.

        Returns:
            Dict with keys ``'precision'``, ``'recall'`` and ``'f1_score'``.
        """
        model.eval()
        with torch.no_grad():
            out = model(data)
            pred = out.argmax(dim=1)

            test_mask = data.test_mask
            y_true = data.y[test_mask].cpu().numpy()
            y_pred = pred[test_mask].cpu().numpy()

            precision = precision_score(y_true, y_pred, average='weighted')
            recall = recall_score(y_true, y_pred, average='weighted')
            f1 = f1_score(y_true, y_pred, average='weighted')

            print("\nClassification Report:")
            # Pin the label order so the two names always line up with the
            # label codes, even if one class is absent from the test split.
            print(classification_report(
                y_true,
                y_pred,
                labels=[self.ILLICIT_LABEL, self.LICIT_LABEL],
                target_names=['illicit', 'licit']
            ))

            return {
                'precision': precision,
                'recall': recall,
                'f1_score': f1
            }

    def run_pipeline(self):
        """Run load -> preprocess -> graph construction -> train -> evaluate.

        Returns:
            Tuple ``(model, metrics)`` where ``metrics`` is the dict returned by
            ``evaluate``.
        """
        features_df, classes_df, edgelist_df = self.load_data()

        # Cast every feature column to double in a single projection; calling
        # withColumn once per column in a loop builds a very deep query plan.
        features_df = features_df.select(*[
            col(col_name) if col_name == 'txId'
            else col(col_name).cast(DoubleType()).alias(col_name)
            for col_name in features_df.columns
        ])

        processed_df = self.preprocess_data(features_df, classes_df)

        graph_data = self.create_graph_data(processed_df, edgelist_df)

        model = self.train(graph_data)

        metrics = self.evaluate(model, graph_data)

        return model, metrics

if __name__ == '__main__':
    # Directory holding the three input CSV files; the empty string makes
    # os.path.join resolve them relative to the current working directory.
    input_directory = ''
    try:
        gnn_pipeline = SparkGNN(input_directory)
        model, metrics = gnn_pipeline.run_pipeline()
        print("\nFinal Metrics:")
        for metric, value in metrics.items():
            print(f"{metric.capitalize()}: {value:.4f}")
    finally:
        # Stop the Spark session even if the pipeline raised, so the session
        # is not left running after a failure.
        spark.stop()

# Captured notebook output (stderr) from a run of the cell above:
# your 131072x1 screen size is bogus. expect trouble
# 24/12/03 13:41:10 WARN Utils: Your hostname, KrystalXPS resolves to a loopback address: 127.0.1.1; using 10.255.255.254 instead (on interface lo)
# 24/12/03 13:41:10 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
# Setting default log level to "WARN".
# To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
# 24/12/03 13:41:11 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
# 24/12/03 13:41:57 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
# [Stage 10:>                                                         (0 + 8) / 8]
#
# :