In [None]:
import requests
import json
import re
import io
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
from pathlib import Path

datapath = Path("../data")

pd.set_option('display.max_rows', 50)

# Put the project's helper-module directory on the import path (presumably
# where the DEA, misc and process modules imported further down live).
modpath = Path("../scripts")
modpath_abs = os.path.abspath(modpath)
# Guard against duplicates: without this check, every re-run of the cell
# appends the same directory to sys.path again.
if modpath_abs not in sys.path:
    sys.path.append(modpath_abs)

In [None]:
def make_volcano(tab, lfc=0, FDR=0.05, ax=None):
    """Draw a volcano plot (logFC vs. -log10 FDR) for a DEA results table.

    Parameters
    ----------
    tab : pandas.DataFrame
        Results table with ``"logFC"`` and ``"FDR"`` columns.
    lfc : float, default 0
        When positive, dashed vertical guides are drawn at ``+lfc`` and
        ``-lfc``. Purely visual: it does not change which points are
        highlighted.
    FDR : float, default 0.05
        Rows with ``FDR`` strictly below this value are highlighted and
        counted in the title; a dashed horizontal line marks the cutoff.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``), which
        matches the previous pyplot-state behaviour.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the plot was drawn on.
    """
    if ax is None:
        ax = plt.gca()

    fdr = tab["FDR"]
    is_sig = fdr < FDR

    # -log10(0) is +inf and matplotlib masks non-finite points, so genes whose
    # FDR underflowed to 0 (the most significant ones) would silently vanish.
    # Floor zeros at the smallest positive FDR present instead.
    positive = fdr[fdr > 0]
    floor = positive.min() if len(positive) else np.finfo(float).tiny
    neg_log_fdr = -np.log10(fdr.clip(lower=floor))

    # All genes in grey first, then significant ones overplotted in colour.
    sns.scatterplot(x=tab["logFC"], y=neg_log_fdr, edgecolor=None, color="grey", ax=ax)
    sns.scatterplot(x=tab.loc[is_sig, "logFC"], y=neg_log_fdr[is_sig], edgecolor=None, ax=ax)
    ax.set_ylabel("-log10 FDR")
    ax.axhline(-np.log10(FDR), ls="--", color="red")
    if lfc > 0:
        ax.axvline(lfc, ls="--", color="red")
        ax.axvline(-lfc, ls="--", color="red")
    ax.set_title(f"DEGs: {is_sig.sum()}")
    return ax

In [None]:
# Input selection. `metafile` starts as None (no sample metadata / design file)
# so that switching to a dataset without one cannot pick up a stale path.
metafile = None

# Other available comparisons; uncomment exactly one `dffile` line to switch
# (and comment out the active pair below):
# dffile = "../data/GSETB2/LNPL/LNPL.csv"          # Control vs Latent
# dffile = "../data/GSETB3/LTAT/LTAT.csv"          # Latent vs Active
# dffile = "../data/GSETB/LWPL/LWPL.csv"           # Control vs Active
# dffile = "../data/breast_lumab/LUMAB/LUMAB.csv"  # Luminal A vs Luminal B

# Active dataset: Control vs IPF, which ships a metadata file.
dffile = "../data/GSEPN/GIPF/GIPF.csv"
metafile = "../data/GSEPN/GIPF/GIPF.meta.csv"

# First column of each CSV becomes the row index.
df = pd.read_csv(dffile, index_col=0)
display(df.head())
print(df.shape)

# Metadata is only loaded when a file was configured above.
if metafile:
    meta = pd.read_csv(metafile, index_col=0)
    display(meta.head())

In [None]:
from DEA import run_dea

# A configured metadata file doubles as the design; otherwise the string
# "unpaired" is passed through to run_dea.
design = metafile or "unpaired"
print("Design:", design)
lfc = 0     # fold-change threshold forwarded to run_dea (and used by the plots)
FDR = 0.05  # significance cutoff used by the plots and metrics below

# Per-method keyword arguments for run_dea.
edgerqlf_kwargs = dict(
    filter_expr=True,
    cols_to_keep=["logFC", "logCPM", "FDR"],
    lfc=lfc,
    design=design,
    check_gof=False,
    verbose=False,
)
edgerlrt_kwargs = dict(
    filter_expr=True,
    cols_to_keep=["logFC", "logCPM", "FDR"],
    lfc=lfc,
    design=design,
    check_gof=False,
    verbose=False,
)
deseq2_kwargs = dict(
    cols_to_keep=["logFC", "logCPM", "FDR"],
    lfc=lfc,
    design=design,
)

# Reference ("truth") run on the full data set with the "edgerqlf" method.
# To use another method instead, replace the two lines below with one of:
#   outfile_dea_truth = "../data/test/truth.lrt.csv"
#   run_dea(df, outfile_dea_truth, "edgerlrt", True, **edgerlrt_kwargs)
# or
#   outfile_dea_truth = "../data/test/truth.wald.csv"
#   run_dea(df, outfile_dea_truth, "deseq2", True, **deseq2_kwargs)
outfile_dea_truth = "../data/test/truth.qlf.csv"
run_dea(df, outfile_dea_truth, "edgerqlf", True, **edgerqlf_kwargs)

In [None]:
# Load the full-data ("truth") results and plot them.
tab_truth = pd.read_csv(outfile_dea_truth, index_col=0)
make_volcano(tab_truth, FDR=FDR, lfc=lfc);

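## Subsamples

Re-run the same differential expression analysis on a subsample of the count matrix and compare its significance calls against the full-data ("truth") results above.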

In [None]:
from misc import paired_replicate_sampler

# Size argument for the sampler (presumably replicates drawn per group —
# TODO confirm against misc.paired_replicate_sampler).
N = 12

# Seed the global numpy RNG so the subsample is reproducible after a kernel
# restart. NOTE(review): this only takes effect if paired_replicate_sampler
# draws from np.random's global state; if it uses its own Generator or the
# `random` module, seed that instead.
SEED = 0
np.random.seed(SEED)

# The sampler returns a sequence; element 0 is used as the subsampled matrix.
df_sub = paired_replicate_sampler(df, N)[0]
print(df_sub.shape)

if design not in ["paired", "unpaired"]:
    # `design` is a metadata-file path here: restrict the metadata to the
    # sampled columns and write it out as the subsample's design file.
    meta_sub = meta.loc[df_sub.columns]
    design_sub = "../data/test/design.csv"
    # Make sure the output directory exists before writing.
    Path(design_sub).parent.mkdir(parents=True, exist_ok=True)
    meta_sub.to_csv(design_sub)
else:
    design_sub = design

print(design_sub)

In [None]:
from DEA import run_dea

# Same method settings as the full-data run, but pointed at the subsample's
# design (`design_sub`: either the written design file or the design string).
edgerqlf_kwargs = dict(
    filter_expr=True,
    cols_to_keep=["logFC", "logCPM", "FDR"],
    lfc=lfc,
    design=design_sub,
    check_gof=False,
    verbose=False,
)
edgerlrt_kwargs = dict(
    filter_expr=True,
    cols_to_keep=["logFC", "logCPM", "FDR"],
    lfc=lfc,
    design=design_sub,
    check_gof=False,
    verbose=False,
)
deseq2_kwargs = dict(
    cols_to_keep=["logFC", "logCPM", "FDR"],
    lfc=lfc,
    design=design_sub,
)

# Subsample run with the "edgerqlf" method. To use another method instead,
# replace the two lines below with one of:
#   outfile_dea = "../data/test/test.lrt.csv"
#   run_dea(df_sub, outfile_dea, "edgerlrt", True, **edgerlrt_kwargs)
# or
#   outfile_dea = "../data/test/test.wald.csv"
#   run_dea(df_sub, outfile_dea, "deseq2", True, **deseq2_kwargs)
outfile_dea = "../data/test/test.qlf.csv"
run_dea(df_sub, outfile_dea, "edgerqlf", True, **edgerqlf_kwargs)

In [None]:
# Results of the subsample run; first CSV column is the gene index.
tab = pd.read_csv(filepath_or_buffer=outfile_dea, index_col=0)
tab.head(5)

In [None]:
make_volcano(tab, FDR=FDR, lfc=lfc);

In [None]:
from process import sklearn_metrics

# Compare significance calls only on genes present in both result tables.
# NOTE(review): genes found in just one table (e.g. removed by expression
# filtering in one of the runs) are excluded rather than counted as missed
# calls, which may inflate recall — confirm this is intended.
common = tab_truth.index.intersection(tab.index)
true = tab_truth.loc[common, "FDR"] < FDR
pred = tab.loc[common, "FDR"] < FDR

mcc, prec, rec = sklearn_metrics(true, pred)

# Confusion-matrix cells as boolean masks over `common`.
TP = true & pred
FP = pred & ~true
TN = ~(true | pred)
FN = true & ~pred
assert sum(mask.sum() for mask in (TP, FP, TN, FN)) == len(common)

# Fixed-width report, one metric per line.
print(
    f"MCC: {mcc:>10.2f}",
    f"Precision: {prec:.2f}",
    f"Recall: {rec:>7.2f}",
    "===============",
    f"True: {true.sum():>9}",
    f"Pred: {pred.sum():>9}",
    "===============",
    f"TP: {TP.sum():>11}",
    f"FP: {FP.sum():>11}",
    f"TN: {TN.sum():>11}",
    f"FN: {FN.sum():>11}",
    sep="\n",
)