In [1]:
# Dependency bootstrap: install whatever is missing, then attach everything.
#
# `graph` and `Rgraphviz` are distributed through Bioconductor, so they are
# installed with BiocManager (itself installed on demand) rather than
# install.packages(). Bioconductor packages are handled first because pcalg
# depends on `graph`.
# requireNamespace() is used for the availability probe so nothing is attached
# as a side effect of checking; attachment happens once, via library(), below.
local({
    bioc_pkgs <- c("graph", "Rgraphviz")
    cran_pkgs <- c("pcalg", "ggm", "fastICA", "data.table", "reticulate", "MASS")

    not_installed <- function(pkgs) {
        present <- vapply(pkgs, requireNamespace, logical(1), quietly = TRUE)
        pkgs[!present]
    }

    missing_bioc <- not_installed(bioc_pkgs)
    if (length(missing_bioc) > 0) {
        if (!requireNamespace("BiocManager", quietly = TRUE)) {
            install.packages("BiocManager")
        }
        BiocManager::install(missing_bioc)
    }

    missing_cran <- not_installed(cran_pkgs)
    if (length(missing_cran) > 0) {
        install.packages(missing_cran)
    }
})

# Attach order is kept as in the original script so name masking is unchanged.
library(pcalg)
library(graph)
library(ggm)
library(Rgraphviz)
library(fastICA)
library(data.table)
library(reticulate)
library(MASS)

set.seed(42)

#' Load the TEP CSV files and stack them into one labelled numeric matrix.
#'
#' Reads `d00.csv` (labelled FaultBinary = 0) and every existing file among
#' `d01.csv` ... `d21.csv` (labelled FaultBinary = 1) from `data_path`, keeps
#' the first `n_vars` entries of the fixed variable list below plus the
#' `FaultBinary` column, and row-binds everything.
#'
#' @param data_path Directory containing the `dNN.csv` files.
#' @param n_vars Number of variables to keep, counted from the start of the
#'   variable list (1 to 15). Fractional values are truncated.
#' @return A matrix with `n_vars + 1` columns; the last one is `FaultBinary`.
load_tep_data <- function(data_path, n_vars) {
   ## selected_vars <- c("XMV.11.", "XMEAS.17.", "XMEAS.20.", "XMV.10.", 
   ##                   "XMEAS.18.", "XMEAS.5.", "XMEAS.24.", "XMEAS.9.", 
   ##                   "XMEAS.21.", "XMEAS.8.", "XMEAS.39.", "XMEAS.1.",
   ##                   "XMEAS.37.", "XMEAS.6.", "XMEAS.14.")

    selected_vars <- c("XMV.11.", "XMV.5.", "XMV.4.", "XMEAS.33.", 
                      "XMV.8.", "XMV.2.", "XMEAS.10.", "XMEAS.17.", 
                      "XMEAS.32.", "XMEAS.18.", "XMEAS.27.", "XMEAS.34.",
                      "XMEAS.41.", "XMEAS.35.", "XMEAS.38.")

    # Validate up front: `1:n_vars` silently yields c(1, 0) for n_vars = 0 and
    # NA names for n_vars beyond the list, both of which fail confusingly later.
    if (!is.numeric(n_vars) || length(n_vars) != 1L || is.na(n_vars) ||
        n_vars < 1 || n_vars > length(selected_vars)) {
        stop(sprintf("n_vars must be a single number between 1 and %d",
                     length(selected_vars)),
             call. = FALSE)
    }
    keep_cols <- c(selected_vars[seq_len(as.integer(n_vars))], "FaultBinary")

    # Read one CSV, attach the class label and keep only the wanted columns.
    # read.csv() is used so that column names are sanitised (e.g. "XMV.11.").
    read_labelled <- function(csv_file, fault_flag) {
        dt <- as.data.table(read.csv(csv_file))
        dt$FaultBinary <- fault_flag
        dt[, keep_cols, with = FALSE]
    }

    normal_file <- file.path(data_path, "d00.csv")
    if (!file.exists(normal_file)) {
        stop("Normal-operation file not found: ", normal_file, call. = FALSE)
    }

    # Fault files are optional: any of d01..d21 that is absent is skipped.
    fault_files <- file.path(data_path, sprintf("d%02d.csv", 1:21))
    fault_files <- fault_files[file.exists(fault_files)]

    all_data <- c(
        list(read_labelled(normal_file, 0)),
        lapply(fault_files, read_labelled, fault_flag = 1)
    )

    as.matrix(rbindlist(all_data))
}

#' Standardise a numeric matrix and add a tiny Gaussian jitter.
#'
#' Columns are centred and scaled with `scale()`. Entries that come out as
#' NA/NaN or infinite (e.g. from a zero-variance column) are set to 0, and
#' N(0, 1e-10) noise is then added to every entry.
#'
#' @param data_matrix Numeric matrix, one column per variable.
#' @return A matrix of the same shape as `data_matrix`.
preprocess_data <- function(data_matrix) {
    z <- scale(data_matrix)

    # For numeric data "NA or infinite" is exactly "not finite".
    z[!is.finite(z)] <- 0

    noise <- rnorm(prod(dim(z)), mean = 0, sd = 1e-10)
    z + matrix(noise, nrow = nrow(z))
}

#' Draw an equal number of normal and fault rows at random.
#'
#' @param data_matrix Matrix with a `FaultBinary` column (0 = normal,
#'   1 = fault).
#' @param sample_size Target total number of rows. Each class contributes
#'   `floor(sample_size / 2)` rows, or fewer if a class is smaller than that.
#' @return The selected rows of `data_matrix`, normal rows first, then fault
#'   rows. Uses the global RNG.
create_balanced_dataset <- function(data_matrix, sample_size = 2000) {
    labels <- data_matrix[, "FaultBinary"]
    normal_idx <- which(labels == 0)
    fault_idx <- which(labels == 1)

    n_samples <- floor(min(sample_size / 2, length(normal_idx), length(fault_idx)))

    # Sample positions rather than calling sample(idx, n): when `idx` has
    # length one, sample() treats it as 1:idx and can return rows of the wrong
    # class. For longer vectors this is the same draw sample() would make.
    pick <- function(idx) idx[sample.int(length(idx), n_samples)]

    sampled_normal <- pick(normal_idx)
    sampled_fault <- pick(fault_idx)

    data_matrix[c(sampled_normal, sampled_fault), , drop = FALSE]
}

#' Run five structure-learning procedures on one data matrix.
#'
#' The data are passed through `preprocess_data()`, then FCI, RFCI and PC are
#' run on the resulting correlation matrix with `gaussCItest`. Two further
#' hand-written procedures, labelled NOTEARS and LINGAM in the output, are run
#' inside `tryCatch()` so a failure in either yields an empty result plus a
#' warning instead of aborting the whole call.
#'
#' @param data Numeric matrix with column names (one column per variable).
#' @param alpha Significance level passed to `fci()`, `rfci()` and `pc()`.
#' @return A list with elements `fci`, `rfci`, `pc` (results of the respective
#'   calls), `notears` (p x p weight matrix with variable names as dimnames)
#'   and `lingam` (a directed `graphNEL`).
perform_causal_discovery <- function(data, alpha = 0.01) {
    processed_data <- preprocess_data(data)
    n <- nrow(processed_data)
    var_names <- colnames(data)
    n_nodes <- ncol(processed_data)

    suffStat <- list(C = cor(processed_data), n = n)

    cat("\nPerforming FCI algorithm...\n")
    fci_result <- fci(suffStat,
                     indepTest = gaussCItest,
                     alpha = alpha,
                     labels = var_names)

    cat("\nPerforming RFCI algorithm...\n")
    rfci_result <- rfci(suffStat,
                       indepTest = gaussCItest,
                       alpha = alpha,
                       labels = var_names)

    cat("Performing PC algorithm...\n")
    pc_result <- pc(suffStat,
                   indepTest = gaussCItest,
                   alpha = alpha,
                   labels = var_names)

    cat("\nPerforming NOTEARS algorithm...\n")
    # NOTE(review): this is a simplified gradient iteration, not a full NOTEARS
    # solver; `h` uses an elementwise exp() and only serves as a stopping
    # test. Confirm this is the intended approximation.
    # The original kept an accumulator named `alpha` here that was never read;
    # because tryCatch() evaluates its expression in this function's frame, it
    # overwrote the `alpha` argument. It has been removed.
    notears_result <- tryCatch({
        X <- scale(processed_data)
        p <- ncol(X)
        W <- matrix(0, p, p)
        lambda1 <- 0.1
        max_iter <- 100
        h_tol <- 1e-8
        rho_max <- 1e+16
        rho <- 1.0

        # Loop-invariant Gram matrix, hoisted out of the iteration.
        XtX <- t(X) %*% X

        for (iter in seq_len(max_iter)) {
            grad <- XtX %*% W - XtX + lambda1 * sign(W)
            W <- W - grad / rho
            h <- sum(exp(W * W)) - p
            if (abs(h) <= h_tol) {
                break
            }
            rho <- min(2 * rho, rho_max)
        }

        # Route a diverged run to the fallback below rather than handing
        # NaN/Inf weights to callers that test `W[i, j] != 0`.
        if (any(!is.finite(W))) {
            stop("iterates diverged to non-finite values")
        }

        W[abs(W) < 0.1] <- 0
        # Label both dimensions so callers can look variables up by name.
        dimnames(W) <- list(var_names, var_names)
        W
    }, error = function(e) {
        warning("NOTEARS failed: ", conditionMessage(e))
        matrix(0, ncol = n_nodes, nrow = n_nodes,
               dimnames = list(var_names, var_names))
    })

    cat("\nPerforming LINGAM algorithm...\n")
    # NOTE(review): ICA-based heuristic (invert the fastICA unmixing matrix,
    # zero the smallest 30% of absolute entries, normalise by the largest);
    # no row permutation or causal ordering step is performed here.
    lingam_result <- tryCatch({
        X <- scale(processed_data)
        ica_result <- fastICA(X, n.comp = ncol(X))
        W <- ica_result$K %*% ica_result$W
        B <- solve(W)
        B_abs <- abs(B)
        threshold <- quantile(B_abs[B_abs > 0], 0.3)
        B[B_abs < threshold] <- 0
        diag(B) <- 0

        if (max(abs(B)) > 0) {
            B <- B / max(abs(B))
        }

        lingam_graph <- new("graphNEL",
                           nodes = var_names,
                           edgemode = "directed")

        # B[i, j] != 0 is drawn as an edge from variable j to variable i.
        for (i in seq_len(nrow(B))) {
            for (j in seq_len(ncol(B))) {
                if (B[i, j] != 0) {
                    lingam_graph <- addEdge(var_names[j],
                                          var_names[i],
                                          lingam_graph,
                                          weights = B[i, j])
                }
            }
        }
        lingam_graph
    }, error = function(e) {
        warning("LINGAM failed: ", conditionMessage(e))
        new("graphNEL", nodes = var_names, edgemode = "directed")
    })

    list(
        fci = fci_result,
        rfci = rfci_result,
        pc = pc_result,
        notears = notears_result,
        lingam = lingam_result
    )
}

#' Plot every learned graph to PDF and summarise edge statistics.
#'
#' Writes `<save_path_base>_<method>_<k>vars.pdf` for FCI, RFCI, PC, NOTEARS
#' and LINGAM, plus `<save_path_base>_summary_<k>vars.csv`, where `k` is
#' `ncol(data) - 1`.
#'
#' @param results List returned by `perform_causal_discovery()` (elements
#'   `fci`, `rfci`, `pc`, `notears`, `lingam`).
#' @param data Matrix the results were computed from; its column names label
#'   the nodes and must include `FaultBinary`.
#' @param save_path_base Path prefix for all output files.
#' @return A list with `edge_summary` (data frame: Method, Total_Edges,
#'   Fault_Connected_Variables) and `fault_connections` (per-method character
#'   vectors of variables adjacent to `FaultBinary`). `Total_Edges` counts
#'   unordered variable pairs joined by an edge in either direction.
analyze_results <- function(results, data, save_path_base) {
    var_names <- colnames(data)
    n_vars <- ncol(data) - 1

    fault_idx <- match("FaultBinary", var_names)
    if (is.na(fault_idx)) {
        stop("`data` must contain a 'FaultBinary' column", call. = FALSE)
    }

    # Enhanced graph attributes
    graph_attrs <- list(
        node = list(
            fontsize = 16,
            width = 2.0,
            height = 2.0,
            fixedsize = TRUE,
            shape = "circle",
            style = "filled",
            fillcolor = "white"
        ),
        edge = list(
            len = 3.0,
            arrowsize = 1.0,
            color = "black"
        ),
        graph = list(
            rankdir = "TB",
            splines = "curved",
            overlap = "scalexy",
            margin = "0.5,0.5",
            nodesep = 0.75,
            ranksep = 1.5
        )
    )

    # Build an output file name for one method tag.
    output_file <- function(tag, ext) {
        sprintf("%s_%s_%dvars.%s", save_path_base, tag, n_vars, ext)
    }

    # Open a PDF device, run `draw()`, and always close the device again,
    # even when plotting fails, so no device is left open.
    render_pdf <- function(file, draw) {
        pdf(file, width = 15, height = 15)
        on.exit(dev.off(), add = TRUE)
        par(mar = c(2, 2, 4, 2), oma = c(1, 1, 1, 1), cex = 1.2)
        draw()
        invisible(file)
    }

    # Plot FCI
    cat("\nPlotting FCI graph...\n")
    render_pdf(output_file("fci", "pdf"), function() {
        plot(results$fci, attrs = graph_attrs)
    })

    # Plot RFCI
    cat("\nPlotting RFCI graph...\n")
    render_pdf(output_file("rfci", "pdf"), function() {
        plot(results$rfci, attrs = graph_attrs)
    })

    # Plot PC
    cat("Plotting PC graph...\n")
    render_pdf(output_file("pc", "pdf"), function() {
        plot(results$pc,
             main = sprintf("PC Causal Graph - %d variables", n_vars),
             cex.main = 1.5,
             cex = 0.8)
    })

    # Plot NOTEARS: W[i, j] != 0 is drawn as an edge from variable i to j.
    cat("\nPlotting NOTEARS graph...\n")
    W <- results$notears
    notears_graph <- new("graphNEL",
                        nodes = var_names,
                        edgemode = "directed")
    for (i in seq_len(nrow(W))) {
        for (j in seq_len(ncol(W))) {
            # isTRUE() skips NA/NaN weights instead of erroring in `if`.
            if (isTRUE(W[i, j] != 0)) {
                notears_graph <- addEdge(var_names[i],
                                       var_names[j],
                                       notears_graph)
            }
        }
    }
    render_pdf(output_file("notears", "pdf"), function() {
        plot(notears_graph, attrs = graph_attrs)
    })

    # Plot LINGAM
    cat("\nPlotting LINGAM graph...\n")
    if (is(results$lingam, "graphNEL")) {
        render_pdf(output_file("lingam", "pdf"), function() {
            plot(results$lingam, attrs = graph_attrs)
        })
    }

    # Calculate summary statistics
    adjacency <- list(
        FCI = as(results$fci@amat, "matrix"),
        RFCI = as(results$rfci@amat, "matrix"),
        PC = as(results$pc@graph, "matrix"),
        NOTEARS = results$notears,
        LINGAM = as(results$lingam, "matrix")
    )

    # Variables linked to FaultBinary in either direction. Names come from
    # `var_names` rather than each matrix's own colnames, which may be absent
    # (the NOTEARS weight matrix is a plain numeric matrix).
    linked_to_fault <- function(adj) {
        var_names[which(adj[, fault_idx] != 0 | adj[fault_idx, ] != 0)]
    }

    # Number of unordered variable pairs with a non-zero entry in either
    # direction. Halving the count of non-zero cells is only right when every
    # edge fills both cells; directed edges fill just one.
    count_edges <- function(adj) {
        linked <- (adj != 0) | (t(adj) != 0)
        sum(linked[upper.tri(linked)], na.rm = TRUE)
    }

    fault_connections <- lapply(adjacency, linked_to_fault)

    edge_summary <- data.frame(
        Method = names(adjacency),
        Total_Edges = unname(vapply(adjacency, count_edges, numeric(1))),
        Fault_Connected_Variables = unname(lengths(fault_connections))
    )

    # Save summary to CSV
    write.csv(edge_summary, output_file("summary", "csv"))

    list(
        edge_summary = edge_summary,
        fault_connections = fault_connections
    )
}

#' Run the whole pipeline for several variable-set sizes.
#'
#' For each size: load the data, draw a balanced sample, run causal discovery
#' and write the plots and summary.
#'
#' @param data_path Directory holding the TEP CSV files.
#' @param save_path_base Path prefix for all output files; its parent
#'   directory must already exist.
#' @param var_sets Numeric vector of variable counts to analyse. Defaults to
#'   the four sizes the script originally hard-coded.
#' @return A named list (`vars_<k>`) whose elements hold `causal_results` and
#'   `comparison` for each size.
main <- function(data_path, save_path_base, var_sets = c(7, 10, 12, 15)) {
    if (!is.numeric(var_sets) || length(var_sets) == 0 || anyNA(var_sets)) {
        stop("var_sets must be a non-empty numeric vector", call. = FALSE)
    }

    # Fail before the expensive computation rather than at the first pdf().
    out_dir <- dirname(save_path_base)
    if (!dir.exists(out_dir)) {
        stop("Output directory does not exist: ", out_dir, call. = FALSE)
    }

    results_all <- list()

    for (n_vars in var_sets) {
        cat(sprintf("\nProcessing analysis with %d variables...\n", n_vars))
        full_data <- load_tep_data(data_path, n_vars)
        balanced_data <- create_balanced_dataset(full_data)
        results <- perform_causal_discovery(balanced_data)
        comparison <- analyze_results(results, balanced_data, save_path_base)
        results_all[[sprintf("vars_%d", n_vars)]] <- list(
            causal_results = results,
            comparison = comparison
        )
    }

    results_all
}

# Script entry point: run the pipeline and keep everything it returns in
# `results`. Both paths are relative to the current working directory.
# Previously used data directory: "Downloads/data_tep/"
results <- main(data_path = "Downloads/tep2py/data_tep/", save_path_base = "Downloads/shap")

Loading required package: pcalg

Loading required package: graph

Loading required package: BiocGenerics


Attaching package: ‘BiocGenerics’


The following objects are masked from ‘package:stats’:

    IQR, mad, sd, var, xtabs


The following objects are masked from ‘package:base’:

    anyDuplicated, aperm, append, as.data.frame, basename, cbind,
    colnames, dirname, do.call, duplicated, eval, evalq, Filter, Find,
    get, grep, grepl, intersect, is.unsorted, lapply, Map, mapply,
    match, mget, order, paste, pmax, pmax.int, pmin, pmin.int,
    Position, rank, rbind, Reduce, rownames, sapply, saveRDS, setdiff,
    table, tapply, union, unique, unsplit, which.max, which.min


Loading required package: ggm

Loading required package: Rgraphviz

Loading required package: grid

Loading required package: fastICA

Loading required package: data.table

Loading required package: reticulate

Loading required package: MASS




Processing analysis with 7 variables...

Performing FCI algorithm...

Performing RFCI algorithm...
Performing PC algorithm...

Performing NOTEARS algorithm...

Performing LINGAM algorithm...

Plotting FCI graph...

Plotting RFCI graph...
Plotting PC graph...

Plotting NOTEARS graph...

Plotting LINGAM graph...

Processing analysis with 10 variables...

Performing FCI algorithm...

Performing RFCI algorithm...
Performing PC algorithm...

Performing NOTEARS algorithm...

Performing LINGAM algorithm...

Plotting FCI graph...

Plotting RFCI graph...
Plotting PC graph...

Plotting NOTEARS graph...

Plotting LINGAM graph...

Processing analysis with 12 variables...

Performing FCI algorithm...

Performing RFCI algorithm...
Performing PC algorithm...

Performing NOTEARS algorithm...

Performing LINGAM algorithm...

Plotting FCI graph...

Plotting RFCI graph...
Plotting PC graph...

Plotting NOTEARS graph...

Plotting LINGAM graph...

Processing analysis with 15 variables...

Performing FCI a