# In [None]:

## Loading libraries
# Install pROC only if it is missing, so re-running the notebook does not
# reinstall it (and require network access / a CRAN mirror) every time.
if (!requireNamespace("pROC", quietly = TRUE)) {
  install.packages("pROC")
}
suppressPackageStartupMessages({
  library(dplyr)
  library(data.table)
  library(tidyr)
  library(pROC)
  library(ggplot2)
})

## Step 1: importing data
# Phenotypes; later steps expect IID, t2d_cc, age, genetic_sex and PC1-PC10.
df_pheno <- fread("t2d.pheno")
# Unrelated-sample list read as a single IID column. read.table() defaults to
# header = FALSE, so a header line would be read as data - TODO confirm format.
df_afr <- read.table("unrel.txt", col.names = "IID")

# Restrict phenotypes to the unrelated samples; keep a data.table because the
# scoring step joins on it with data.table syntax.
df_afr_pheno <- df_pheno %>%
  filter(IID %in% df_afr$IID) %>%
  as.data.table()

## Step 2: Define helper function for score file processing
# Placeholder hook: hands `file_list` back unchanged. `method_name` is
# accepted for a future per-method transform but is currently ignored.
process_pgs_method <- function(method_name, file_list) {
  file_list
}

## Step 3: Define Nagelkerke R² function using your provided formula
# Nagelkerke's pseudo-R² for a fitted binomial glm.
#
# For ungrouped 0/1 outcomes the residual deviance D equals -2 * log-likelihood,
# so the Cox-Snell R² is 1 - exp((D - D0) / n), where D0 is the null deviance.
# Its maximum attainable value is 1 - exp(-D0 / n), and Nagelkerke's R² is the
# ratio of the two.
#
# Args:
#   rr: a fitted `glm` object.
# Returns:
#   list(N = number of observations used in the fit, R2 = Nagelkerke R²).
NagelkerkeR2 <- function(rr) {
  # rr$model is only stored when glm(model = TRUE) (the default). Otherwise
  # fall back to the stored response (glm(y = TRUE) default).
  n <- if (!is.null(rr$model)) nrow(rr$model) else length(rr$y)
  # Full component names: `rr$dev` / `rr$null` only resolved through `$`
  # partial matching, which is fragile and warns under
  # options(warnPartialMatchDollar = TRUE).
  dev <- rr$deviance
  null_dev <- rr$null.deviance
  R2 <- (1 - exp((dev - null_dev) / n)) / (1 - exp(-null_dev / n))
  list(N = n, R2 = R2)
}

## Step 4: Define all score files
# Every method's profile sits in the same GenoPred output tree and only the
# tool sub-directory differs. The result is a named list
# score_files[[method]][[gwas_source]] -> profile path. local() keeps the
# helper values out of the global environment.
score_files <- local({
  pgs_root <- "genopred/output/ukb/pgs/TRANS"
  tool_dirs <- c(
    PT_Clump = "ptclump",
    MegaPRS  = "megaprs",
    QuickPRS = "quickprs",
    Lassosum = "lassosum",
    DBSLMM   = "dbslmm",
    PRS_CS   = "prscs"
  )
  lapply(tool_dirs, function(tool_dir) {
    list(
      T2D_Suzuki = file.path(pgs_root, tool_dir, "T2D_SUZUKI",
                             "ukb-T2D_SUZUKI-TRANS.profiles")
    )
  })
})


## Step 5: Function to process one score file using your Nagelkerke's R²
# Score one PGS profile file against the phenotypes.
#
# Reads `score_path`, inner-joins it to `pheno` on IID and, for every candidate
# score column (names matching `prs_pattern`), compares a covariate-only
# logistic model with covariates + that score. The column with the largest
# incremental Nagelkerke R² is treated as the best PRS. It is added to the
# merged data as BEST_PRS_raw and BEST_PRS_std (z-scored with scale()).
#
# Args:
#   score_path:  profile file readable by fread(), with an IID column.
#   method_name: PGS method label (used in messages and the result).
#   gwas_source: GWAS label (used in messages and the result).
#   pheno:       phenotype table; defaults to the global df_afr_pheno.
#   outcome:     binary outcome column for the logistic models.
#   covariates:  adjustment covariates shared by the base and full models.
#   prs_pattern: regex selecting candidate score columns.
# Returns:
#   NULL (after printing why) if the file, the IID column, the IID overlap or
#   the score columns are missing, or if no model could be fitted. Otherwise a
#   list with these elements:
#     method, gwas_source,
#     r2_table (sorted by Incremental_R2, descending),
#     best_param, best_incremental_r2, best_full_r2, base_r2, n_observations,
#     df (the merged data including the BEST_PRS_* columns).
process_score_file <- function(score_path, method_name, gwas_source,
                               pheno = df_afr_pheno,
                               outcome = "t2d_cc",
                               covariates = c("age", "genetic_sex",
                                              paste0("PC", 1:10)),
                               prs_pattern = "^T2D_") {
  cat("Processing:", method_name, "-", gwas_source, "\n")

  if (!file.exists(score_path)) {
    cat("File not found:", score_path, "\n")
    return(NULL)
  }

  scores <- fread(score_path)

  if (!"IID" %in% names(scores)) {
    cat("IID column not found in:", score_path, "\n")
    return(NULL)
  }

  # Ad-hoc join via `on =` so the caller's phenotype table is not re-sorted by
  # reference (setkey() would do that). Duplicate non-key columns coming from
  # `pheno` receive the "i." prefix.
  df <- scores[as.data.table(pheno), on = "IID", nomatch = 0]

  if (nrow(df) == 0) {
    cat("No overlapping IIDs found for:", method_name, "-", gwas_source, "\n")
    return(NULL)
  }

  prs_cols <- grep(prs_pattern, names(df), value = TRUE)

  if (length(prs_cols) == 0) {
    cat("No PRS columns found in:", score_path, "\n")
    return(NULL)
  }

  # Convert once; every per-column fit works on a data.frame view.
  df_frame <- as.data.frame(df)

  r2_results <- lapply(prs_cols, function(col) {
    tryCatch(
      .fit_incremental_r2(df_frame, col, outcome, covariates),
      error = function(e) {
        cat("Error processing PRS column:", col, "| Error:", e$message, "\n")
        NULL
      }
    )
  })
  r2_results <- Filter(Negate(is.null), r2_results)

  if (length(r2_results) == 0) {
    cat("No valid R² results for:", method_name, "-", gwas_source, "\n")
    return(NULL)
  }

  r2_df <- do.call(rbind, r2_results)
  r2_df <- r2_df[order(-r2_df$Incremental_R2), ]

  # Best PRS parameter = top row by incremental R².
  best_prs_col <- r2_df$PRS[1]

  # Raw and standardised copies of the best score for downstream steps.
  set(df, j = "BEST_PRS_raw", value = df[[best_prs_col]])
  set(df, j = "BEST_PRS_std", value = as.numeric(scale(df[["BEST_PRS_raw"]])))

  list(
    method = method_name,
    gwas_source = gwas_source,
    r2_table = r2_df,
    best_param = best_prs_col,
    # Row 1 is the best column. max() would disagree with it and return NA
    # whenever any column's R² is NA.
    best_incremental_r2 = r2_df$Incremental_R2[1],
    best_full_r2 = r2_df$Full_R2[1],
    base_r2 = r2_df$Base_R2[1],
    n_observations = r2_df$N[1],
    df = df
  )
}

# Fit the base (covariates only) and full (PRS + covariates) logistic models on
# the same complete-case rows. Returns a one-row data.frame of their Nagelkerke
# R² values and the difference. glm() drops incomplete rows per model, so
# without this restriction a PRS with missing values would be compared against
# a base model fitted on more individuals.
.fit_incremental_r2 <- function(df_frame, prs_col, outcome, covariates) {
  model_vars <- unique(c(outcome, prs_col, covariates))
  dat <- df_frame[, model_vars, drop = FALSE]
  dat <- dat[complete.cases(dat), , drop = FALSE]

  # Backticks keep non-syntactic column names usable inside formulas.
  quote_names <- function(v) paste0("`", v, "`")
  base_formula <- as.formula(paste(
    quote_names(outcome), "~",
    paste(quote_names(covariates), collapse = " + ")
  ))
  full_formula <- as.formula(paste(
    quote_names(outcome), "~",
    paste(quote_names(c(prs_col, covariates)), collapse = " + ")
  ))

  base_model <- glm(base_formula, data = dat, family = "binomial")
  full_model <- glm(full_formula, data = dat, family = "binomial")

  r2_base <- NagelkerkeR2(base_model)
  r2_full <- NagelkerkeR2(full_model)

  data.frame(
    PRS = prs_col,
    Base_R2 = r2_base$R2,
    Full_R2 = r2_full$R2,
    Incremental_R2 = r2_full$R2 - r2_base$R2,
    N = r2_full$N,
    stringsAsFactors = FALSE
  )
}

## Step 6: Process all methods and GWAS sources
# Results are nested as all_results[[method]][[gwas_source]]. A method whose
# score files were all rejected is left out entirely.
all_results <- list()

for (method in names(score_files)) {
  method_results <- list()

  for (gwas_source in names(score_files[[method]])) {
    score_path <- score_files[[method]][[gwas_source]]
    result <- process_score_file(score_path, method, gwas_source)

    # process_score_file() has already printed why it returned NULL.
    if (is.null(result)) next

    method_results[[gwas_source]] <- result
    cat(" ", method, "-", gwas_source,
        "| Best PRS:", result$best_param,
        "| Incremental R²:", round(result$best_incremental_r2, 6),
        "| Full R²:", round(result$best_full_r2, 4),
        "| N:", result$n_observations, "\n")
  }

  if (length(method_results) == 0) next
  all_results[[method]] <- method_results
}

## Step 7: Compare best incremental R² across all methods
# One summary row per (method, GWAS source). Rows are collected in a list and
# bound once instead of rbind()-ing onto a growing data.frame.
comparison_rows <- list()
for (method in names(all_results)) {
  for (gwas_source in names(all_results[[method]])) {
    result <- all_results[[method]][[gwas_source]]
    comparison_rows[[length(comparison_rows) + 1L]] <- data.frame(
      Method = method,
      GWAS_Source = gwas_source,
      Best_PRS = result$best_param,
      Base_R2 = result$base_r2,
      Full_R2 = result$best_full_r2,
      Incremental_R2 = result$best_incremental_r2,
      N = result$n_observations,
      # Keep labels as character on R < 4.0 too. Later steps index
      # all_results by these values, and a factor would index by position.
      stringsAsFactors = FALSE
    )
  }
}
comparison_df <- if (length(comparison_rows) > 0) {
  do.call(rbind, comparison_rows)
} else {
  data.frame()
}

# Check if we have any results
if (nrow(comparison_df) == 0) {
  stop("No valid results were generated. Please check your file paths and data.")
}

# Order by best incremental R² (largest first)
comparison_df <- comparison_df[order(-comparison_df$Incremental_R2), ]

cat("\n", strrep("=", 70), "\n", sep = "")
cat("FINAL RANKING OF ALL METHODS (by Incremental Nagelkerke's R²)\n")
cat("Using your provided Nagelkerke R² formula\n")
cat(strrep("=", 70), "\n", sep = "")
print(comparison_df, digits = 6)
cat(strrep("=", 70), "\n", sep = "")

# Best overall method = top row after sorting
best_overall <- comparison_df[1, ]
cat("\n BEST OVERALL METHOD:", best_overall$Method, "-", best_overall$GWAS_Source, "\n")
cat("Incremental R²:", round(best_overall$Incremental_R2, 6), "\n")
cat("Full Model R²:", round(best_overall$Full_R2, 4), "\n")
cat("Base Model R²:", round(best_overall$Base_R2, 4), "\n")
cat("Sample Size:", best_overall$N, "\n")
cat("Best PRS parameter:", best_overall$Best_PRS, "\n")

## Step 7a: Save results to CSV
# write.csv() and ggsave() do not create directories, so ensure that
# prs_output/ exists before the first write.
dir.create("prs_output", showWarnings = FALSE, recursive = TRUE)
write.csv(comparison_df, "prs_output/method_comparison_r2.csv", row.names = FALSE)
cat("\n Results saved to: prs_output/method_comparison_r2.csv\n")

## Step 7b: Create visualization
# Bar labels: the method name, plus the GWAS source only when some method has
# more than one source. Otherwise identical labels would be stacked into a
# single bar by geom_col().
comparison_df$Method_Label <- if (anyDuplicated(comparison_df$Method) > 0) {
  paste(comparison_df$Method, comparison_df$GWAS_Source, sep = " - ")
} else {
  as.character(comparison_df$Method)
}

# Horizontal bar chart of incremental R², bars ordered by value.
p1 <- ggplot(comparison_df, aes(x = reorder(Method_Label, Incremental_R2), y = Incremental_R2)) +
  geom_col(fill = "#451bcf", alpha = 0.8) +
  geom_text(aes(label = round(Incremental_R2, 6)), hjust = -0.2, size = 3) +
  # Headroom at the top of the value axis so the text labels are not clipped.
  scale_y_continuous(expand = expansion(mult = c(0, 0.15))) +
  coord_flip() +
  labs(
    title = "Comparison of PGS Methods Performance",
    subtitle = paste("Nagelkerke R² | Best Method:", best_overall$Method),
    x = "Method",
    y = "Incremental Nagelkerke's R²"
  ) +
  theme_minimal() +
  theme(plot.title = element_text(hjust = 0.5),
        plot.subtitle = element_text(hjust = 0.5, size = 10),
        axis.text.y = element_text(size = 10))

# Save the plot
ggsave("prs_output/method_comparison_incremental_r2.png", p1, width = 12, height = 8, dpi = 300)
cat(" Visualisation saved to: prs_output/method_comparison_incremental_r2.png\n")

# Display the plot
print(p1)

## Step 8: Verify the function works
cat("\n", strrep("-", 70), "\n", sep = "")
cat("VERIFICATION: Testing your Nagelkerke R² function\n")
cat(strrep("-", 70), "\n", sep = "")

# Sanity checks on NagelkerkeR2() and the comparison table:
#  1. An intercept-only model explains nothing, so its R² must be ~0.
#  2. Every base/full R² must lie in [0, 1].
#  3. Nested models fitted on the same rows cannot lose fit, so incremental
#     R² should not be negative (beyond numerical noise).
best_method <- as.character(best_overall$Method)
best_source <- as.character(best_overall$GWAS_Source)
verify_df <- all_results[[best_method]][[best_source]]$df
null_r2 <- NagelkerkeR2(glm(t2d_cc ~ 1, data = verify_df, family = "binomial"))$R2
r2_values <- c(comparison_df$Base_R2, comparison_df$Full_R2)
checks <- c(
  "Intercept-only model R² is 0" = abs(null_r2) < 1e-6,
  "All base/full R² values are within [0, 1]" =
    all(r2_values >= 0 & r2_values <= 1, na.rm = TRUE),
  "No incremental R² is negative" =
    all(comparison_df$Incremental_R2 > -1e-8, na.rm = TRUE)
)
for (check_name in names(checks)) {
  cat(if (checks[[check_name]]) "[OK]    " else "[FAIL]  ", check_name, "\n", sep = "")
}



## ===============================
## Step 9: AUC (ROC) for best PRS
## ===============================
# pROC is installed and attached in the setup cell at the top of the notebook.
#
# For every (method, GWAS source) row of comparison_df:
#   AUC_PRS      - ROC of BEST_PRS_std alone;
#   AUC_Adjusted - ROC of the fitted probabilities from
#                  glm(t2d_cc ~ BEST_PRS_std + age + genetic_sex + PC1..PC10).
# NOTE: AUC_Adjusted is computed on the same rows the model was fitted on
# (in-sample), so it is optimistic relative to held-out data.
base_terms   <- c("age", "genetic_sex", paste0("PC", 1:10))
full_formula <- reformulate(c("BEST_PRS_std", base_terms), response = "t2d_cc")

n_rows     <- nrow(comparison_df)
auc_rows   <- vector("list", n_rows)
roc_curves <- list()

for (i in seq_len(n_rows)) {
  # as.character() guards against factor columns, which `[[` would treat as
  # positional indices.
  meth <- as.character(comparison_df$Method[i])
  src  <- as.character(comparison_df$GWAS_Source[i])
  res  <- all_results[[meth]][[src]]
  df   <- res$df

  # Fix direction = "<": controls (the first level of t2d_cc, i.e. 0 for a
  # 0/1 outcome) are expected to score lower. The default "auto" picks the
  # direction from the group medians, so an inversely associated score would
  # be silently reported with AUC > 0.5 (upward bias).
  roc_prs <- roc(df$t2d_cc, df$BEST_PRS_std, direction = "<", quiet = TRUE)
  auc_prs <- as.numeric(auc(roc_prs))

  # na.exclude pads the predictions with NA for rows glm() dropped, so
  # pred_full stays aligned with df$t2d_cc; roc() then removes the NA pairs.
  # With the default na.omit the two vectors differ in length and roc() errors.
  full_fit  <- glm(full_formula, data = df, family = binomial(),
                   na.action = na.exclude)
  pred_full <- predict(full_fit, type = "response")
  roc_adj   <- roc(df$t2d_cc, pred_full, direction = "<", quiet = TRUE)
  auc_adj   <- as.numeric(auc(roc_adj))

  auc_rows[[i]] <- data.frame(
    Method       = meth,
    GWAS_Source  = src,
    AUC_PRS      = auc_prs,
    AUC_Adjusted = auc_adj,
    N            = nrow(df),
    stringsAsFactors = FALSE
  )

  key <- paste(meth, src, sep = " - ")
  roc_curves[[key]] <- list(prs = roc_prs, adj = roc_adj)
}

auc_df <- do.call(rbind, auc_rows)

# Save the per-method AUC summary
write.csv(auc_df, "prs_output/best_prs_auc_summary.csv", row.names = FALSE)
cat("AUC summary saved to: prs_output/best_prs_auc_summary.csv\n")

# Long format for plotting: one row per (method label, AUC metric).
auc_df$Method_Label <- paste(auc_df$Method, auc_df$GWAS_Source, sep = " - ")
auc_long <- pivot_longer(
  auc_df,
  cols = c(AUC_PRS, AUC_Adjusted),
  names_to = "Metric",
  values_to = "AUC"
)

# Dodged bar chart comparing PRS-only and covariate-adjusted AUC per label.
p_auc <- ggplot(auc_long, aes(x = reorder(Method_Label, AUC), y = AUC, fill = Metric)) +
  geom_col(position = position_dodge(width = 0.7), width = 0.6) +
  coord_flip() +
  scale_fill_manual(values = c(AUC_PRS = "#4C78A8", AUC_Adjusted = "#F58518")) +
  labs(
    title = "AUC by Method (Best PRS)",
    subtitle = "Blue: PRS-only | Orange: PRS + covariates",
    x = "Method + GWAS Source",
    y = "AUC"
  ) +
  theme_minimal(base_size = 11) +
  theme(
    plot.title = element_text(hjust = 0.5),
    plot.subtitle = element_text(hjust = 0.5)
  )

ggsave("prs_output/best_prs_auc_barchart.png", p_auc, width = 11, height = 8, dpi = 300)
cat("AUC bar chart saved to: prs_output/best_prs_auc_barchart.png\n")

# Optional: overlay ROC curves (PRS + covariates model) for the top `top_k`
# method labels, ranked by adjusted AUC.
top_k <- 6
top_methods <- head(auc_df$Method_Label[order(-auc_df$AUC_Adjusted)], top_k)

# Build a combined ROC coordinate data frame, one block of rows per label.
roc_plot_parts <- lapply(top_methods, function(lbl) {
  r <- roc_curves[[lbl]]$adj
  # transpose = FALSE requests one row per threshold; older pROC releases
  # defaulted to a transposed matrix.
  pts <- as.data.frame(pROC::coords(r, "all",
                                    ret = c("specificity", "sensitivity"),
                                    transpose = FALSE))
  curve <- data.frame(
    FPR = 1 - pts$specificity,
    TPR = pts$sensitivity,
    Method_Label = lbl,
    stringsAsFactors = FALSE
  )
  # Sort into a monotone staircase. geom_path() draws points in this order,
  # whereas geom_line() re-sorts by x alone and zig-zags at tied FPR values.
  curve[order(curve$FPR, curve$TPR), ]
})
roc_plot_df <- do.call(rbind, roc_plot_parts)

p_roc <- ggplot(roc_plot_df, aes(x = FPR, y = TPR, color = Method_Label)) +
  # `size` is kept for compatibility with older ggplot2 (newer versions
  # prefer `linewidth`).
  geom_path(size = 1) +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "grey60") +
  labs(
    title = "ROC Curves (PRS + Covariates)",
    subtitle = paste("Top", length(top_methods), "methods by adjusted AUC"),
    x = "False Positive Rate",
    y = "True Positive Rate",
    color = "Method"
  ) +
  theme_minimal(base_size = 11) +
  theme(plot.title = element_text(hjust = 0.5), plot.subtitle = element_text(hjust = 0.5))

ggsave("prs_output/roc_curves_top_methods.png", p_roc, width = 10, height = 8, dpi = 300)
cat("ROC curves saved to: prs_output/roc_curves_top_methods.png\n")



## ==========================================
## Step 10: Correlation heatmap of best PRS
## ==========================================
# Collect IID + BEST_PRS_std per method/source, renaming the score column to
# its "Method - Source" label.
prs_pairs <- list()
for (method in names(all_results)) {
  for (src in names(all_results[[method]])) {
    res <- all_results[[method]][[src]]
    label <- paste(method, src, sep = " - ")
    dt <- res$df[, .(IID, BEST_PRS_std)]
    setnames(dt, "BEST_PRS_std", label)
    prs_pairs[[label]] <- dt
  }
}

# Named `prs_labels` rather than `labels`, which would mask base::labels().
prs_labels <- names(prs_pairs)
k <- length(prs_labels)
corr_mat <- matrix(NA_real_, nrow = k, ncol = k,
                   dimnames = list(prs_labels, prs_labels))

# Pearson correlation over the IIDs each pair shares. Pairs with fewer than 10
# shared IIDs stay NA. The matrix is symmetric, so only the upper triangle
# (j >= i) is computed and then mirrored, which halves the number of merges.
for (i in seq_len(k)) {
  for (j in seq.int(i, k)) {
    merged <- merge(prs_pairs[[prs_labels[i]]], prs_pairs[[prs_labels[j]]],
                    by = "IID", all = FALSE)
    # merged columns: IID, score i, score j (suffixed .x/.y when i == j)
    if (nrow(merged) >= 10) {
      # A constant score makes cor() warn and return NA; it is shown as NA.
      r <- suppressWarnings(
        cor(merged[[2]], merged[[3]], use = "pairwise.complete.obs")
      )
      corr_mat[i, j] <- r
      corr_mat[j, i] <- r
    }
  }
}

# Long format (one row per cell) via as.table(), without extra packages.
melted <- as.data.frame(as.table(corr_mat))
names(melted) <- c("X", "Y", "Correlation")

p_heat <- ggplot(melted, aes(x = X, y = Y, fill = Correlation)) +
  geom_tile(color = "white") +
  scale_fill_gradient2(low = "#00204D", mid = "#FFFFE0", high = "#f55d42", na.value = "grey80",
    midpoint = 0, limits = c(-1, 1)
  ) +
  geom_text(aes(label = ifelse(is.na(Correlation), "", sprintf("%.2f", Correlation))), size = 3) +
  labs(
    title = "Correlation Heatmap of Best PRS across Methods",
    x = NULL, y = NULL
  ) +
  theme_minimal(base_size = 11) +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1),
    plot.title  = element_text(hjust = 0.5)
  )
print(p_heat)
ggsave("prs_output/best_prs_correlation_heatmap.png", p_heat, width = 11, height = 9, dpi = 300)
cat("Correlation heatmap saved to: prs_output/best_prs_correlation_heatmap.png\n")
