# Code to replicate the analysis

## Chapters
### 1. Miscellaneous preparation steps
1. Install non-default modules and upgrade modules
1. Mount Google Drive
1. Set paths (point to the Google Drive folder to work in)
1. Download files (initial setup; skip if continuing)
1. Import modules, initialise paths

### 2. Generating data for analysis
1. Differential Gene Expression (DESeq2)
1. Reconstructing the Gene Regulatory Network

    1. Supp. Fig. 8A: Responsiveness of DEGs across experiments
    1. Reconstruction
    1. Supp. Fig. 9: Optimising parameters for filtering GRN
    1. Finalising the gene regulatory network; Fig 3A, visualised in Cytoscape
    1. TF-only GRN: Fig 4A (visualised in Cytoscape) and Supp. Fig 11
1. Extract Arabidopsis GO and TFs

### 3. Analysis and plotting
1. Figure 1 & Supp. Fig 1: Measurements and Student's t-test
1. Supp. Fig. 2: QC of RNA-seq data
1. Supp. Fig. 3: Volcano plots (DESeq2)
1. Supp. Fig. 4: Comparison of DEGs between two controls
1. Supp. Fig. 5: Upset plot (up- and down-regulated)
1. Supp. Fig. 6, 7: Intersection of DEGs across single and cross stresses (up- and down-regulated)
1. Figure 2: Summary of DEGs in Marchantia and inter-stress comparisons
1. Supp. Fig. 10: Expression of GRN TFs across experiments (clustered)
1. Figure 3B: Expression of GRN TFs across experiments
1. Figure 3C: Specific expression in GRN TFs
1. Supp. Fig. 12: Influence of TFs
1. Supp. Fig. 13: Robustly responding second-level MapMan bins across the 7 abiotic stresses
1. Figure 4B, C, D; Supp. Fig. 14, 15: Bipartite networks for robustly expressed TFs and biological processes
1. Figure 5: Annotation of Ath orthologs with evidence from literature
1. Figure 6A, B: Effects of combined stress in terms of significant L2FC
1. Figure 6C: Classification of stress interactions
1. Figure 6D: Linear regression of all experiments
1. Figure 7A-H: Linear regression by stress


# 1. Miscellaneous preparation steps

### 1.1 Install non-default modules and upgrade modules

In [None]:
# install non-default colab modules
# Restart runtime after installation and skip to next step
# upsetplot: presumably for the upset plots (Supp. Fig. 5) — not used in this excerpt.
!pip install upsetplot
# Upgrade the matplotlib version that ships with the runtime.
!pip install matplotlib --upgrade

### 1.2 Mount Google Drive

In [None]:
# Mount Google Drive at /content/gdrive so drive_path (next cell) can point into it.
from google.colab import drive
drive.mount('/content/gdrive')
# Remove Colab's bundled example data folder (not needed for this analysis).
!rm -rf /content/sample_data

In [None]:
#@title 1.3 Set path {display-mode: "form"}

#@markdown Enter the path of the directory you want to work in.

# Must end with "/": later cells build paths as drive_path + 'marchantia_stress/'.
drive_path = '/content/gdrive/My Drive/' #@param {type: 'string'}

### 1.4 Download files (first time only)

In [None]:
# Downloads necessary files to perform analyses [only need to be done once]
# https://gist.github.com/iamtekeste/3cdfd0366ebfd2c0d805 download raw files directtly from Google Drive
!wget --no-check-certificate -r "https://drive.google.com/uc?id=1cbKgWbEWtstl_2_rb06tI_D-vseprPnT&export=download" -O marchantia_stress.zip

dir_path = drive_path + 'marchantia_stress/'
dir_path_safe = dir_path.replace(' ', '\ ')
!mkdir $dir_path_safe
!unzip marchantia_stress.zip -d $dir_path_safe

### 1.5 Import modules, set paths

In [None]:
# import modules
import os
import string
%load_ext rpy2.ipython
import pandas as pd
import math
from matplotlib_venn import venn2
from matplotlib import pyplot as plt
from ast import literal_eval
import seaborn as sns
from collections import Counter
import random
from statsmodels.stats.multitest import multipletests
import numpy as np
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Working directory for all analyses (created by the download step).
dir_path = drive_path + 'marchantia_stress/'
# Shell-escaped copy of dir_path for `!` commands; '\\ ' is a literal
# backslash + space ('\ ' is an invalid escape sequence in Python).
dir_path_safe = dir_path.replace(' ', '\\ ')

# 2. Generating data for analysis

### 2.1 Differential Gene Expression

In [None]:
# Making necessary directories for outputs (DESeq2 result tables)
mpo_path = dir_path + "prep_files/mpo/deseq/"
# Shell-escaped variant, kept for any `!` commands that need it.
mpo_path_safe = dir_path_safe + "prep_files/mpo/deseq/"
if not os.path.exists(mpo_path):
    # os.makedirs creates intermediate directories (like `mkdir -p`) and
    # copes with spaces in the path without any shell escaping.
    os.makedirs(mpo_path, exist_ok=True)
    print("Directories made: " + mpo_path)

In [None]:
%%R
# Installing DESeq2
# `%%R` has to be the very first line of the cell: IPython only treats it as
# a cell magic there (after a comment it fails as an unknown line magic).
if (!requireNamespace("BiocManager", quietly = TRUE)) {
  install.packages("BiocManager")
}

BiocManager::install("DESeq2", ask = FALSE)

In [None]:
# To pull python variables
%R -i dir_path
%Rget dir_path

%R -i dir_path_safe
%Rget dir_path_safe

%R -i mpo_path
%Rget mpo_path

In [None]:
%%R
# DESeq2 (Marchantia), adapted from DESeq2_stressonly_phase1n2.R
# `%%R` must be the first line of the cell for IPython to run it as R.
#
# Fits one DESeq2 model on all samples, then contrasts every condition with
# each of the two controls (controlH2, controlD2). For each contrast it writes
# <numerator><denominator>_res.tsv (all genes, ordered by p-value) and
# <numerator><denominator>_resSig.tsv (padj < 0.05 and |log2FC| > 1, ordered
# by padj) into mpo_path.
library("DESeq2")
library("RColorBrewer")

# Printed output (contrast names and summaries) is written to phase1n2_sum.txt.
sink(paste0(mpo_path, "phase1n2_sum.txt"), type = "output")
# finally = sink(): always restore console output, even if a step fails. The
# original cell never closed the sink, so all later R output was diverted.
tryCatch({
  raw_counts <- read.table(file = paste0(dir_path, "prep_files/all_stress_raw.tsv"),
                           sep = "\t", header = TRUE)
  raw_counts <- data.frame(raw_counts, row.names = 1)

  stresses <- c("controlH2", "controlD2", "H", "C", "M", "S", "L", "D", "N",
                "HS", "HM", "HN", "CS", "CM", "CN", "SM", "ML", "NL", "MN",
                "SD", "MD", "ND", "HD", "CD", "CL", "LS", "SN")
  colData <- read.csv(paste0(dir_path, "summary_files/all_stress.txt"), sep = "\t",
                      row.names = 1, header = FALSE)
  names(colData) <- c("condition")

  dds <- DESeqDataSetFromMatrix(countData = raw_counts,
                                colData = colData,
                                design = ~condition)
  dds <- DESeq(dds)

  for (x in 1:2) {
    # Each control is compared with every condition listed after it, so the
    # controlD2-vs-controlH2 contrast is computed exactly once.
    for (i in (x + 1):length(stresses)) {
      numerator <- stresses[i]
      denominator <- stresses[x]
      if (numerator == denominator) next
      res <- results(dds, contrast = c("condition", numerator, denominator))
      res <- res[order(res$pvalue), ]
      # subset() drops rows whose padj is NA
      resSig <- subset(res, res$padj < 0.05 & abs(res$log2FoldChange) > 1)
      resSig <- resSig[order(resSig$padj), ]
      print(paste(numerator, "vs", denominator))
      summary(res)
      summary(resSig)
      out_prefix <- paste0(mpo_path, numerator, denominator)
      write.table(as.data.frame(res), file = paste0(out_prefix, "_res.tsv"),
                  quote = FALSE, sep = "\t", col.names = NA)
      write.table(as.data.frame(resSig), file = paste0(out_prefix, "_resSig.tsv"),
                  quote = FALSE, sep = "\t", col.names = NA)
    }
  }
}, finally = sink())

### 2.2 Reconstructing the Gene Regulatory Network

#### 2.2.1 Supp. Fig. 8A: Responsiveness of DEGs across experiments

In [None]:
"""
DEG count across ALL stresses
Separated by TFs and non-TFs
"""

import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt

genelist = "G:/My Drive/Projects/Marchantia_2019/phase1n2/deseq/resSig_compiled.txt"
tflist = "G:/My Drive/Projects/Marchantia_2019/grn/TF_collate/PlantTFDB_Mpov5r1_prediction_plusTFDB.txt"

genedf = pd.read_csv(genelist, sep="\t") # 82982 DGEs
tflist = [x.split("\t")[0] for x in open(tflist, "r").readlines()] # 397 TFs
genedf['type'] = genedf["gene"].apply(lambda x: "TF" if x in tflist else "notTF")

sum(genedf.type == "TF") # 1341 occurrences
len(set(genedf.gene)) # 12442 genes /19421 genes
len(set(genedf.gene[genedf.type == "TF"])) # 252 unique TFs / 397 TFs

genedf.to_csv("G:/My Drive/Projects/Marchantia_2019/grn/resSig_compiled_wType.txt", sep = "\t", index = False)

condcount = Counter(genedf.gene)
cdict = {"genecount" : [], "tfcount": [], "genes" : [], "TFs" : []}

for i in range(max(condcount.values())):
	genecount = [x for x in condcount if condcount[x] > i]
	TFcount = [x for x in genecount if x in tflist]
	cdict["genes"].append(genecount)
	cdict["TFs"].append(TFcount)
	cdict["genecount"].append(len(genecount))
	cdict["tfcount"].append(len(TFcount))

dfcount = pd.DataFrame(cdict)
dfcount.index.names = ["InMoreThan"]
dfcount.to_csv("G:/My Drive/Projects/Marchantia_2019/grn/count.txt", sep = "\t")
dfcount.plot.bar(y=["genecount", "tfcount"], use_index = True, logy = True)
plt.savefig(dir_path + 'figures/' + 'FigS8A.png', dpi=600)

#### 2.2.2 Reconstruction

In [None]:
%%R
# Matrix trimming
# Aims:
#       1) Trim matrix to only contain genes that are differentially
#          expressed in > 5 conditions
# Outputs:
#       1) Trimmed TPM matrix, with one all-zero dummy row appended

library(readr)

ofile <- file.path(dir_path, "prep_files/all_stress_mt5.tsv")
ifile <- file.path(dir_path, "prep_files/all_stress.tsv")
cfile <- file.path(dir_path, "prep_files/count.txt")

tpm_mat <- as.matrix(read.table(ifile, header = TRUE, sep = "\t",
                                row.names = 1, as.is = TRUE))

# Sample sheet: X1 = sample ID, X2 = condition. TPM columns named by X1 are
# renamed to "<X2>_<second '_'-separated token of X1>".
mat_col_ref <- read_tsv(file.path(dir_path, "summary_files/all_stress.txt"),
                        col_names = FALSE)
# Base R instead of comprehenr::to_vec() / stringr::str_split(), which are not
# attached at this point of the notebook.
id_token <- vapply(strsplit(mat_col_ref$X1, "_", fixed = TRUE),
                   function(parts) parts[[2]], character(1))
nname <- paste0(mat_col_ref$X2, "_", id_token)
names(nname) <- mat_col_ref$X1

# Rename all matching columns in one vectorised step; columns that are not in
# the sample sheet keep their name.
hit <- match(colnames(tpm_mat), names(nname))
colnames(tpm_mat)[!is.na(hit)] <- unname(nname[hit[!is.na(hit)]])

# count.txt (Supp. Fig. 8A cell): row 6 is InMoreThan = 5 and column 3 is the
# "genes" list, stored as a Python list literal such as "['g1', 'g2']".
# quote = "" keeps that field verbatim; brackets and quotes are stripped below.
count_mat <- as.matrix(read.table(cfile, header = TRUE, sep = "\t",
                                  row.names = 1, as.is = TRUE, quote = ""))

valid_str <- gsub("[]['\"]", "", count_mat[6, 3]) # remove brackets and quotes
valid_genes <- trimws(unlist(strsplit(valid_str, ",", fixed = TRUE))) # split on ','
pruned_mat <- tpm_mat[valid_genes, , drop = FALSE]
pruned_mat_dum <- rbind(pruned_mat, rep(0, ncol(pruned_mat)))

write.table(pruned_mat_dum, ofile, sep = "\t", col.names = NA)

## Note: remove the dummy (all-zero) row from the matrix manually. The updated matrix (read below as `all_stress_mt5_nodum.tsv`) is provided in the download.

In [None]:
# Change the kernel's working directory. `!cd` only changes a throw-away
# subshell, and the Python expression after it was never evaluated. The
# embedded R session (rpy2) runs in this same process, so the R cell below
# resolves 'deps/...' relative to GRN_code.
os.chdir(dir_path + "GRN_code")

In [None]:
%%R
# GRN Part 3
# Started: 17 February 2022
# Aims:
#       1) Generate the linear regression models for gene-TF and TF-TF
#       2) Regularisations: LASSO/ elastic net
#       3) Intermediate outputs per response variable
#       4) Final output as a GRN with [goi_name, TF_name, coefficients(beta), gene_name (of TFs), relvar, s, lambda, cv]

# Dependencies (in GRN_code/deps, i.e. `ddir`):
# 1) dircreater.r
# 2) wrap_elnetv3_Mpo.r (sourced below; NOTE(review): this header used to list
#    wrap_elnetv4_Mpo.r — confirm which version is intended)
# 3) TFelnetv6_Mpo.r
# note: keep the working directory at GRN_code (previous cell) in case the
# dependencies source each other through relative paths.

# Working paths (ChuckNorris)
ddir <- file.path(dir_path, "GRN_code/deps")
wdir <- file.path(dir_path, "prep_files/Mpo_GRN_models")
mat_path <- file.path(dir_path, "prep_files/all_stress_mt5_nodum.tsv")
tf_path <- file.path(dir_path, "prep_files/PlantTFDB_Mpov5r1_prediction_plusTFDB.txt")

## Script to create the elastic net derived grn
source(file.path(ddir, "wrap_elnetv3_Mpo.r"))
#library(edgeR)
library(comprehenr)
set.seed(2019)
options(stringsAsFactors = FALSE)

cmode_name <- "elnet"

# Keep the columns of `mat` whose name contains the stress code `x` as a fixed
# substring (so e.g. "H" also selects combined stresses such as "HS"),
# excluding control columns.
extract_matrix <- function(mat, x) {
  x_names <- to_vec(for (z in colnames(mat)) if (grepl(x, z, fixed = TRUE) && !grepl("control", z, fixed = TRUE)) z)
  subset(mat, select = x_names)
}

#stresses <- c("L", "D", "H", "C", "S","M", "N", "all") # uncomment and change K=3 in TFelnetv6_Mpo.r for 3 fold cross-validation.
stresses <- c("all") # K=5 in TFelnetv6_Mpo.r for 5 fold cross-validation

# The expression matrix and the TF annotation are the same for every subset,
# so they are read once, outside the loop.
full_mat <- as.matrix(read.table(mat_path, header = TRUE, sep = "\t", row.names = 1, as.is = TRUE))
#import TF annotation - names should be the same as rownames of the read data
TF <- read.delim(tf_path, header = FALSE)

for (mat_type in stresses) {
  print(paste("Building models for subset: ", mat_type))
  mat <- full_mat
  resdir <- file.path(wdir, cmode_name, mat_type)

  if (mat_type != "all") {
    mat <- extract_matrix(mat, mat_type)
  }

  # kick out genes that are completely '0' across all conditions (artifact of
  # subsetting); drop = FALSE keeps a matrix even if a single gene is left
  mat <- mat[rowSums(mat) != 0, , drop = FALSE]

  # log() cannot take zeros: warn and replace them with a tiny value. (The old
  # check, sum(mat > 0) > 0, tested for *positive* entries and so fired for
  # virtually every matrix.)
  if (any(mat == 0)) {
    print("Warning: '0' present in gene expression matrix.")
    minval <- min(mat[mat > 0])
    print(paste("Minimum expression value:", minval))
    print("Replacing zeros with 1e-12")
    mat[mat == 0] <- 0.000000000001
  }

  # normalise: log transform, then z-transform each gene (row) across samples
  log_mat <- log(mat)
  gene_dat <- t(scale(t(log_mat)))

  tfs <- TF[, 1]
  tfs <- tfs[tfs %in% rownames(gene_dat)] # grab only TFs that are in current dataset

  #Do elastic net regression analysis for each TF using all other TFs as predictors
  #THIS STEP TAKES VERY LONG - RUN ON SERVER OR WITH MORE THAN 8 THREADS
  elnet_res <- wrap_elnet(gene_dat, resdir = resdir, thrsh = 0, tfs = tfs, parallel = 64, cmode = cmode_name)
  elnet_all <- elnet_res[, c("Gene.ID", "predicted", "rel.coeff")]
  colnames(elnet_all) <- c("from", "to", "weight")
  elnet_all <- elnet_all[order(abs(elnet_all$weight), decreasing = TRUE), ]
  # make sure the output directory exists before writing into it
  dir.create(resdir, recursive = TRUE, showWarnings = FALSE)
  save(elnet_all, file = file.path(resdir, "elnet_all.obj"))
  write.table(elnet_all, file.path(resdir, "elnet_all.txt"), sep = "\t", col.names = NA)
}

In [None]:
# Return the kernel to the home directory; `!cd ~` only changed a subshell.
os.chdir(os.path.expanduser("~"))

#### 2.2.3 Supp. Fig 9: Optimising parameters for filtering GRN

In [None]:
%%R
# Finding the ideal cutoff to apply to the networks
# Aims:
#       1) General distribution of relative coefficient
#       2) Number of models per R^2 (Fig S9A)
#       3) Max coeff. per R^2
#       4) Number of nodes and edges per coeff cutoff

library(ggplot2)
library(RColorBrewer)

wdir <- file.path(dir_path, "prep_files/Mpo_GRN_models")
resname <- "elnet.txt"
stresses <- c("L", "D", "H", "C", "S","M", "N", "all_cv5")
cutoffs <- c(0.5, 0.625, 0.75)

# Read one tab-separated network table (first column holds the row names).
read_mat <- function(mat_path) {
  tab <- read.table(mat_path, header = TRUE, sep = "\t", row.names = 1, as.is = TRUE)
  as.data.frame(tab)
}

# One network per subset, named by subset: <wdir>/elnet/<stress>/elnet.txt
elnet_res <- setNames(lapply(paste(wdir, "elnet", stresses, resname, sep = "/"), read_mat),
                      stresses)

# 1) General relative coefficient distribution: one PDF of per-network
#    histograms for each R^2 (relvar) cutoff
for (cf in cutoffs) {
  elnet_cut <- setNames(lapply(elnet_res, function(x) subset(x, relvar > cf)), stresses)
  pdf(file.path(dir_path, "prep_files", paste0(cf, "_distribution_hist.pdf")), width = 8.5)
  par(mfrow = c(3, 3))
  for (s in stresses) {
    hist(elnet_cut[[s]][, "rel.coeff"],
         xlab = "Relative coefficient",
         main = paste(cf, s, "relative coefficient distribution"))
  }
  dev.off()
}

# 2) Number of models per R^2 (Fig S9A): distinct predicted genes whose model
#    has relvar above each cutoff, one line per network
cutoffs <- c(5:10)/10
cf_count <- matrix(0, nrow = length(cutoffs), ncol = length(stresses))
rownames(cf_count) <- cutoffs
colnames(cf_count) <- stresses
for (cf in cutoffs) {
  elnet_cut <- setNames(lapply(elnet_res, function(x) subset(x, relvar > cf)), stresses)
  cf_count[as.character(cf), ] <- vapply(stresses,
                                         function(s) length(unique(elnet_cut[[s]][["predicted"]])),
                                         numeric(1))
}
line_cols <- brewer.pal(length(stresses), "Dark2")
pdf(file.path(dir_path, "figures", "FigS9A.pdf"), width = 7.5, height = 6)
matplot(cf_count, type = "o", lwd = 2, col = line_cols,
        main = "Number of models at various R^2 cutoffs", xaxt = "n",
        lty = 1, xlab = "R^2 cutoff", pch = 16)
# X-axis ticks labelled with the cutoff values
axis(side = 1, at = seq_len(nrow(cf_count)), labels = cutoffs)
legend("topright", legend = stresses, col = line_cols, lty = 1, pch = 16, lwd = 2)
dev.off()

# 3) Max coeff. per R^2 (absolute): for every model (predicted gene), its R^2
#    (relvar) against the largest |rel.coeff| among its predictors, with a
#    linear fit; one panel per network.
pdf(file.path(dir_path, "prep_files", "coeff_R2_lm.pdf"))
par(mfrow = c(3, 3))
for (s in stresses) {
  nw <- elnet_res[[s]]
  unique_genes <- unique(nw[["predicted"]])
  # one R^2 per model; vapply fails loudly if a model carries several values
  R2_val <- vapply(unique_genes,
                   function(x) unique(nw$relvar[which(nw$predicted == x)]),
                   numeric(1))
  max_relcoeff <- vapply(unique_genes,
                         function(x) max(abs(nw[["rel.coeff"]][which(nw$predicted == x)])),
                         numeric(1))
  # x is R^2 and y is the maximum coefficient (the labels used to be swapped)
  plot(R2_val, max_relcoeff, main = paste("Network", s),
       xlab = "R^2 value", ylab = "Maximum relative coefficient", cex = 0.5)
  lm_obj <- lm(max_relcoeff ~ R2_val)
  abline(lm_obj, col = "red")
  mtext(paste("R^2:", round(summary(lm_obj)$r.squared, 2)), 3, adj = 0.02, line = -1, cex = 0.5)
}
dev.off()

# 4) Number of nodes and edges per coeff cutoff (global)
# For R^2 cutoff `cf`: keep each network's rows with relvar > cf; then, for
# every |rel.coeff| quantile (10%..90%), count the rows with
# |rel.coeff| >= that quantile (edges) and the distinct genes in their
# predicted / Gene.ID columns (nodes). Draws two line plots on the current
# device. Reads the globals `elnet_res` and `stresses`.
get_cutoffs <- function(cf) {
  quantiles <- seq(0.1, 0.9, 0.1)
  q_labels <- names(quantile(0, quantiles))  # "10%", "20%", ...
  node_count <- matrix(0, nrow = length(quantiles), ncol = length(stresses))
  edge_count <- matrix(0, nrow = length(quantiles), ncol = length(stresses))
  colnames(node_count) <- stresses
  rownames(node_count) <- quantiles
  colnames(edge_count) <- stresses
  rownames(edge_count) <- quantiles

  # apply relvar cutoff
  elnet_cut <- lapply(elnet_res, function(x) subset(x, relvar > cf))
  names(elnet_cut) <- stresses
  # get quantiles; rows are filled by position instead of re-parsing labels
  for (s in stresses) {
    nw_quantile <- quantile(abs(elnet_cut[[s]][, "rel.coeff"]), quantiles)
    for (k in seq_along(nw_quantile)) {
      quantile_cut <- subset(elnet_cut[[s]], abs(rel.coeff) >= nw_quantile[[k]])
      node_count[k, s] <- length(unique(c(quantile_cut[, "predicted"], quantile_cut[, "Gene.ID"])))
      edge_count[k, s] <- nrow(quantile_cut)
    }
  }
  line_cols <- brewer.pal(length(stresses), "Dark2")
  line_types <- seq_along(stresses)

  # plotting (Node); legend uses lwd (not pt.lwd) so its lines match the plot
  matplot(node_count, type = "l", lwd = 2, col = line_cols,
          main = paste("Number of nodes, R^2:", cf), xaxt = "n",
          lty = line_types, xlab = "coefficient quantiles", ylab = "Number of nodes")
  axis(side = 1, at = seq_len(nrow(node_count)), labels = q_labels)
  legend("bottomright", inset = c(0.01, 0.1), legend = stresses, col = line_cols,
         lty = line_types, cex = 0.8, lwd = 2)

  # plotting (Edges)
  matplot(edge_count, type = "l", lwd = 2, col = line_cols,
          main = paste("Number of edges, R^2:", cf), xaxt = "n",
          lty = line_types, xlab = "coefficient quantiles", ylab = "Number of edges")
  axis(side = 1, at = seq_len(nrow(edge_count)), labels = q_labels)
  legend("topright", legend = stresses, col = line_cols,
         lty = line_types, cex = 0.8, lwd = 2)
}

# One page per R^2 cutoff: node counts on top, edge counts below
pdf(file.path(dir_path, "prep_files", "network_count.pdf"), height = 11, width = 8)
for (cf in cutoffs) {
  par(mfrow = c(2, 1)); get_cutoffs(cf)
}
dev.off()

In [None]:
"""
Fig S9B -- Part 1
Check overlap of AGRIS regulatory network and elastic net
"""
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter

# Initialise Orthofinder output genes: orthogroup
OF_path = dir_path + "/prep_files" + '/Orthogroups.txt'

gene_OF = {}

with open(OF_path, 'r') as OF_file:
	for line in OF_file:
		content = line.strip("\n").split(": ")
		og = content[0]
		gene_list = content[1].split(" ")
		for gene in gene_list:
			gene_OF[gene] = og

# initialise agris: confirmed AtRegNet interactions mapped to orthogroups
agris_path = dir_path + 'prep_files/AtRegNet_confirmed.txt'
agris_out_path = dir_path + 'prep_files/AtRegNet_confirmed_OG.txt'
agris_og_nodup = dir_path + 'prep_files/AtRegNet_OG_nw.txt'
# TFLocus, TargetLocus: 1, 4
agris_df = pd.read_csv(agris_path, sep="\t", header=0, index_col=None)
# Loci without an orthogroup become NA, and those interactions are dropped.
agris_df["TFOG"] = [gene_OF.get(locus, pd.NA) for locus in agris_df.TFLocus]
agris_df["TargetOG"] = [gene_OF.get(locus, pd.NA) for locus in agris_df.TargetLocus]
agris_df = agris_df.dropna(subset=["TargetOG", "TFOG"])
agris_df.to_csv(agris_out_path, index=False, sep="\t")

# unique orthogroup-level (target, TF) interactions
no_dup = agris_df.drop_duplicates(subset=["TargetOG", "TFOG"])[["TargetOG", "TFOG"]]
no_dup.to_csv(agris_og_nodup, index=False, sep="\t")

# every orthogroup occurring in AGRIS, as target or as TF
no_dup_og = list(set(no_dup.TargetOG.to_list()) | set(no_dup.TFOG.to_list()))

# initialise elnet: per-subset elastic-net networks
# (the *_GO columns below hold orthogroup IDs from gene_OF, not Gene Ontology terms)
elnet_dir = dir_path + "/prep_files/Mpo_GRN_models/elnet/"
stresses = ["L", "D", "H", "C", "S","M", "N", "all_cv5"]
elnet = "elnet.txt"
quantile = [k / 10 for k in range(10)] # change


for i in stresses:
    # add OG to main network file (KeyError if a gene has no orthogroup)
    elnet_file = pd.read_csv(elnet_dir + i + "/" + elnet, header=0, sep="\t", index_col=0)
    elnet_file["predicted_GO"] = elnet_file.predicted.map(gene_OF.__getitem__)
    elnet_file["Gene_GO"] = elnet_file["Gene.ID"].map(gene_OF.__getitem__)
    elnet_file.to_csv(elnet_dir + i + "/elnet_GO.txt", index=False, sep="\t")

    # keep edges whose two orthogroups both occur in the AGRIS network
    both_in_agris = elnet_file.predicted_GO.isin(no_dup_og) & elnet_file.Gene_GO.isin(no_dup_og)
    elnet_in_agris = elnet_file[both_in_agris]
    elnet_in_agris.to_csv(elnet_dir + i + "/elnet_in_agris.txt", index=False, sep="\t")

    # remove duplicated interactions
    elnet_nodup = elnet_in_agris.drop_duplicates(subset=["predicted_GO", "Gene_GO"])[["predicted_GO", "Gene_GO"]]
    #elnet_nodup.to_csv(elnet_dir + i + "/elnet_in_agris_GOnw.txt", index=False, sep="\t")
		
# The actual comparison: Jaccard index between the AGRIS orthogroup pairs and
# each network's pairs after the R^2 > 0.8 filter and a |rel.coeff| quantile
# cutoff. "union" pools every network except the last one ("all_cv5").

# assign() returns a new frame, avoiding chained assignment on a slice
no_dup = no_dup.assign(target_TF=no_dup.TargetOG + "_" + no_dup.TFOG)
agris_pairs = set(no_dup.target_TF.to_list())
# float frame: writing Jaccard values into an int frame upcasts (pandas warns)
nw_real_ji = pd.DataFrame(0.0, columns=quantile, index=stresses+["union"])

# Read and R^2-filter each network once, instead of once per quantile
r2_networks = {}
for i in stresses:
    elnet_no_dup = pd.read_csv(elnet_dir + i + "/elnet_in_agris.txt", header=0, index_col=None, sep="\t")
    # R^2 > 0.8 cutoff
    elnet_no_dup = elnet_no_dup[elnet_no_dup.relvar > 0.8].copy()
    elnet_no_dup["target_TF"] = elnet_no_dup.predicted_GO + "_" + elnet_no_dup.Gene_GO
    r2_networks[i] = elnet_no_dup

for j in quantile:
    union_interactions = []
    for i in stresses:
        abs_coeff = r2_networks[i]["rel.coeff"].abs()
        # Get coefficient cutoff for corresponding quantile
        cutoff = abs_coeff.quantile(j)
        kept = set(r2_networks[i].target_TF[abs_coeff >= cutoff])
        if i != stresses[-1]:
            union_interactions.extend(kept)
        # Calculations
        nw_real_ji.loc[i, j] = len(agris_pairs & kept) / len(agris_pairs | kept)

    # for union network filtered at R^2 0.8 and corresponding quantiles
    union_set = set(union_interactions)
    nw_real_ji.loc["union", j] = len(agris_pairs & union_set) / len(agris_pairs | union_set)

# normal jaccard (GRN&AGRIS/GRN)
#nw_real_ji.plot(title="Overlap of AGRIS with GRN at various coefficient cutoffs (quantile)",
#						  xlabel = "Networks", ylabel = "Jaccard Index", style='.-').legend(bbox_to_anchor=(1, 1))
#plt.xticks(ticks = [i for i in range(9)] ,labels = ["Light", "Dark", "Heat", "Cold", "Salt", "Mannitol", "Nitrogen deficiency", "All", "Union"], rotation = 'vertical')
#plt.savefig(dir_path + 'figures/AGRIS_overlap_chart_nJI_quantile0.png', dpi = 600, bbox_inches='tight')
nw_real_ji.to_csv(dir_path + 'prep_files/all_quantile_nJI_quantile0.txt', sep="\t")

# Just some summary stats: per network, the quantile with the highest Jaccard index
best_quantile_nJI = nw_real_ji.idxmax(axis="columns")
Counter(best_quantile_nJI)  # how often each quantile wins (shown only if last in a cell)
best_quantile_nJI.to_csv(dir_path + 'prep_files/best_quantile_nJI_quantile_0.txt', sep="\t", header=None)

# Proceeding with custom cutoffs: rebuild each network at its best
# (highest-Jaccard) quantile; the per-stress networks make up the union.
best_union_interactions_nJI = []
best_quantile_overlap_nJI = nw_real_ji.iloc[:-1, :].max(axis=1).to_list()
union_df_nJI = []

for i in stresses:
    elnet_no_dup = pd.read_csv(elnet_dir + i + "/elnet_in_agris.txt", header=0, index_col=None, sep="\t")
    # R^2 > 0.8 cutoff (.copy() so the column assignments below act on a real frame)
    elnet_no_dup = elnet_no_dup[elnet_no_dup.relvar > 0.8].copy()
    elnet_no_dup["target_TF"] = elnet_no_dup.predicted_GO + "_" + elnet_no_dup.Gene_GO
    # Get coefficient cutoff for this network's best quantile
    cutoff = abs(elnet_no_dup["rel.coeff"]).quantile(best_quantile_nJI.loc[i])
    elnet_no_dup = elnet_no_dup[abs(elnet_no_dup["rel.coeff"]) >= cutoff].copy()
    elnet_no_dup["nw_source"] = i
    if i != stresses[-1]:  # the last network ("all_cv5") is not part of the union
        best_union_interactions_nJI.extend(elnet_no_dup.target_TF)
        union_df_nJI.append(elnet_no_dup)
agris_target_TF = set(no_dup.target_TF.to_list())
best_union_set = set(best_union_interactions_nJI)
intersect = list(agris_target_TF & best_union_set)
best_quantile_overlap_nJI.append(len(intersect)/len(best_union_set | agris_target_TF)) # ["L", "D", "H", "C", "S","M", "N", "all_cv5", "union"]

new_networks = ["L", "D", "H", "C", "S","M", "N", "all_cv5", "union"]
# NOTE(review): rows 2-3 used to print `best_quantile` / `best_quantile_overlap`,
# which are not defined anywhere in this notebook (NameError). Part 2 reads
# this file by line position and parses row 3 as floats, so both rows are kept
# with "nan" placeholders.
with open(dir_path + 'prep_files/overlap_best_quantile_0.txt', "w+") as bf:
    bf.write("\t" + "\t".join(new_networks) + "\n")
    bf.write("quantile\t" + "\t".join(["nan"] * (len(new_networks) - 1)) + "\tna\n")
    bf.write("custom_ratio\t" + "\t".join(["nan"] * len(new_networks)) + "\n")
    bf.write("quantile\t" + "\t".join([str(x) for x in best_quantile_nJI.iloc[:-1]]) + "\tna\n")
    bf.write("JI_ratio\t" + "\t".join([str(x) for x in best_quantile_overlap_nJI]) + "\n")

# normal JI ratio: the union network is every per-stress network stacked
# (each row tagged with its nw_source)
union_all_nJI = pd.concat(objs=union_df_nJI)
union_all_nJI.to_csv(path_or_buf=elnet_dir + '/union_nw_nJI_quantile0.txt', sep="\t", index=False)

In [None]:
"""
Supp. Fig. S9B -- Part 2
"""

import pandas as pd
import random
import seaborn as sns
from collections import defaultdict
import matplotlib.pyplot as plt

agris_path = dir_path + 'prep_files/AtRegNet_confirmed_OG.txt'
elnet_dir = dir_path + 'prep_files/Mpo_GRN_models/elnet/'
elnet_name = '/elnet_in_agris.txt'
union_path = elnet_dir + 'union_nw_quantile0.txt'
union_path_nJI = elnet_dir + 'union_nw_nJI_quantile0.txt'

# initialise agris
agris = pd.read_csv(agris_path, header=0, index_col=None, sep="\t")
agris_OG_nodup = agris.drop_duplicates(subset=["TFOG", "TargetOG"])[["TargetOG", "TFOG"]]

agris_TFOG = agris_OG_nodup.TFOG.to_list()
agris_TargetOG = agris_OG_nodup.TargetOG.to_list()
agris_targetTF_OG = agris_OG_nodup.TargetOG +"_" + agris_OG_nodup.TFOG.to_list()


new_networks = ["L", "D", "H", "C", "S","M", "N", "all_cv5", "union"]
union_networks = ["union" + str(n) for n in range(2, 8)]
all_networks = new_networks[:-1] + ["union1"] + union_networks
quantile_path = dir_path + 'prep_files/overlap_best_quantile_0.txt'
# Rows of the file written by Part 1 (first field of each row is its label):
# 1 quantile, 2 custom_ratio, 3 nJI quantile, 4 JI_ratio
with open(quantile_path, "r") as q:
    qcon = q.readlines()
rows = [line.strip("\n").split("\t") for line in qcon]
best_quantile = rows[1][1:-1]
expected_ratio = list(map(float, rows[2][1:]))
nJI_quantile = rows[3][1:-1]
expected_JI = list(map(float, rows[4][1:]))

def get_grn_OG(nw_path, nw, index, cut_ref):
    """Return the unique orthogroup-level "target_TF" pairs of one network.

    nw_path: tab-separated network file. Per-stress networks need the columns
        relvar, rel.coeff, predicted_GO and Gene_GO; the "union" file must
        already carry a target_TF column.
    nw: network name; "union" skips all filtering.
    index, cut_ref: cut_ref[index] (string or number) is the |rel.coeff|
        quantile used as the coefficient cutoff.
    Returns a list of unique "<predicted_GO>_<Gene_GO>" strings (any order).
    """
    elnet_no_dup = pd.read_csv(nw_path, header=0, index_col=None, sep="\t")
    if nw != "union":
        # R^2 > 0.8 cutoff; .copy() so adding a column below works on a real
        # frame instead of a filtered slice (SettingWithCopyWarning)
        elnet_no_dup = elnet_no_dup[elnet_no_dup.relvar > 0.8].copy()
        elnet_no_dup["target_TF"] = elnet_no_dup.predicted_GO + "_" + elnet_no_dup.Gene_GO
        # Get coefficient cutoff for corresponding quantile
        abs_coeff = elnet_no_dup["rel.coeff"].abs()
        cutoff = abs_coeff.quantile(float(cut_ref[index]))
        elnet_no_dup = elnet_no_dup[abs_coeff >= cutoff]
    return list(set(elnet_no_dup.target_TF))

def get_grn_int(nw_path, nw, index, cut_ref):
    """Return gene-level "<predicted>_<Gene.ID>" interaction strings of one network.

    For the "union" network every distinct (predicted, Gene.ID) pair is
    returned once. Otherwise rows with relvar > 0.8 and
    |rel.coeff| >= quantile(float(cut_ref[index])) are kept, in file order
    and with duplicates.
    """
    table = pd.read_csv(nw_path, header=0, index_col=None, sep="\t")
    if nw == "union":
        pairs = table.drop_duplicates(subset=["predicted", "Gene.ID"])[["predicted", "Gene.ID"]]
    else:
        table = table[table.relvar > 0.8]  # R^2 > 0.8 cutoff
        threshold = abs(table["rel.coeff"]).quantile(float(cut_ref[index]))
        pairs = table[abs(table["rel.coeff"]) >= threshold]
    return (pairs["predicted"] + "_" + pairs["Gene.ID"]).to_list()

def get_ucut_OG(nw_path, nw):
    """Orthogroup pairs of union-network edges found in at least N networks.

    nw_path: union network file with predicted, Gene.ID and target_TF columns
        (one row per per-stress network containing the edge).
    nw: name ending in the minimum count N, e.g. "union3".
    Returns a list of unique target_TF strings (any order), taking each gene
    pair's target_TF from its first row.
    """
    from collections import Counter

    elnet_no_dup = pd.read_csv(nw_path, header=0, index_col=None, sep="\t")
    cutoff = int(nw[-1])
    keys = ["predicted", "Gene.ID"]
    # Rows with a missing gene ID never match a pair (NaN != NaN), so drop them.
    pairs = elnet_no_dup.dropna(subset=keys)
    # Count rows per pair in one pass instead of re-filtering the whole table
    # for every pair (quadratic in the number of rows).
    pair_counts = Counter(zip(pairs["predicted"], pairs["Gene.ID"]))
    first_rows = pairs.drop_duplicates(subset=keys)
    OGpairs = [og for gene, tf, og in zip(first_rows["predicted"], first_rows["Gene.ID"], first_rows["target_TF"])
               if pair_counts[(gene, tf)] >= cutoff]
    return list(set(OGpairs))

def get_ucut_int(nw_path, nw):
    """Gene-level "<predicted>_<Gene.ID>" pairs found in at least N networks.

    nw_path: union network file with predicted and Gene.ID columns.
    nw: name ending in the minimum count N, e.g. "union3".
    Returns the qualifying pairs in order of first appearance.
    """
    from collections import Counter

    elnet_no_dup = pd.read_csv(nw_path, header=0, index_col=None, sep="\t")
    cutoff = int(nw[-1])
    keys = ["predicted", "Gene.ID"]
    # Rows with a missing gene ID never match a pair (NaN != NaN), so drop them.
    pairs = elnet_no_dup.dropna(subset=keys)
    # One counting pass instead of filtering the whole table per pair.
    pair_counts = Counter(zip(pairs["predicted"], pairs["Gene.ID"]))
    first_rows = pairs.drop_duplicates(subset=keys)
    return [gene + "_" + tf for gene, tf in zip(first_rows["predicted"], first_rows["Gene.ID"])
            if pair_counts[(gene, tf)] >= cutoff]


# Permutation test -- inputs
# Per network, the orthogroup-level pairs (compared with AGRIS) and the
# gene-level pairs. Per-stress / "all_cv5" / "union" networks use the best
# Jaccard quantiles (nJI_quantile) read from Part 1's output.
nw_OG_int_nJI = []
nw_int_nJI = []
for idx, network in enumerate(new_networks):
    fpath = union_path_nJI if network == "union" else elnet_dir + network + elnet_name
    nw_OG_int_nJI.append(get_grn_OG(fpath, network, idx, nJI_quantile))
    nw_int_nJI.append(get_grn_int(fpath, network, idx, nJI_quantile))

# "unionN": union edges present in at least N networks; their observed
# Jaccard index with AGRIS is appended to expected_JI.
for idx, network in enumerate(union_networks):
    fpath = union_path_nJI
    OG_pairs = get_ucut_OG(fpath, network)
    nw_OG_int_nJI.append(OG_pairs)
    nw_int_nJI.append(get_ucut_int(fpath, network))
    shared = set(agris_targetTF_OG) & set(OG_pairs)
    expected_JI.append(len(shared) / len(set(OG_pairs) | set(agris_targetTF_OG)))

# Permutation test: shuffle which TF orthogroup is attached to each AGRIS
# target (every TF and every target keeps its number of edges) and recompute
# the Jaccard index against each network.
random.seed(2019)  # reproducible null distribution and p-values (same seed as the R GRN step)
n_perm = 1000
nw_pair_sets = [set(pairs) for pairs in nw_OG_int_nJI]  # built once, not per permutation
dicto2 = defaultdict(list)
for i in range(n_perm):
    random.shuffle(agris_TFOG)
    int_pairs = {target + '_' + tf for target, tf in zip(agris_TargetOG, agris_TFOG)}
    for k, nw in enumerate(all_networks):
        intersection = len(int_pairs & nw_pair_sets[k])
        dicto2["network"].append(nw)
        dicto2["ratio"].append(intersection/len(nw_pair_sets[k] | int_pairs))
ratio_coll_nJI = pd.DataFrame.from_dict(dicto2)
# one-sided empirical p-value: share of permutations at least as high as observed
pval_nJI = [sum(ratio_coll_nJI[ratio_coll_nJI.network == x].ratio >= expected_JI[i])/n_perm for i, x in enumerate(all_networks)]
sig_nw = [x for i, x in enumerate(all_networks) if pval_nJI[i] < 0.05]
# star height: just above the top of each significant network's null distribution
sig_height = [ratio_coll_nJI[ratio_coll_nJI.network == all_networks[i]].ratio.max() + 0.001 for i, x in enumerate(pval_nJI) if x < 0.05]

# Plot result of permutation test: grey violins = null distributions of the
# Jaccard index, black dots = observed values, stars = p < 0.05
network_labels = ["Light", "Dark", "Heat", "Cold", "Salt", "Mannitol", "Nitrogen deficiency", "All",
                  "Union (1)", "Union (2)", "Union (3)", "Union (4)", "Union (5)", "Union (6)", "Union (7)"]
fig, ax = plt.subplots()
sns.violinplot(x="network", y="ratio", data=ratio_coll_nJI, ax=ax, color="silver", scale='width')
ax.set_xticklabels(network_labels, rotation=90)
ax.set_xlabel("Network")
ax.set_ylabel("Jaccard Index")
ax.plot(all_networks, expected_JI, 'ko')
ax.plot(sig_nw, sig_height, 'k*')
plt.savefig(dir_path + 'figures/FigS9B.png', dpi=600, bbox_inches='tight')

#### 2.2.4 Finalising the gene regulatory network; Fig 3A, visualised in Cytoscape

In [None]:
# Get union network (including those not in AGRIS), cutoff R2 > 0.8
elnet_name = '/elnet.txt'

# Per-stress networks pooled into the union ("all_cv5" is not included)
u_stresses = list("LDHCSMN")
 
def read_nw(nwpath, nw):
    """Read one elastic-net network and keep models with R^2 (relvar) > 0.8.

    nwpath: tab-separated network file whose first column is the row index.
    nw: network (stress) name, stored in a new "network" column.
    Returns the filtered DataFrame.
    """
    nwdf = pd.read_csv(nwpath, header=0, index_col=0, sep="\t")
    # .copy(): adding a column to a filtered slice raises SettingWithCopyWarning
    nwdf_r2cf = nwdf[nwdf.relvar > 0.8].copy()
    nwdf_r2cf["network"] = nw
    return nwdf_r2cf

# One R^2-filtered frame per stress, tagged with its network name
nw_dfs = [read_nw(elnet_dir + stress + elnet_name, stress) for stress in u_stresses]
# concatenate union and save it
union_raw = pd.concat(nw_dfs)
union_raw.to_csv(elnet_dir + "union_0.8_ignoreAGRIS_full.txt", sep="\t", index=False)

In [None]:
# Aim: To retrieve top TF (highest absolute coefficient) for each gene among all TFs
# Fig 3A, visualised in Cytoscape
# get union network (including not in AGRIS)


# comment next 2 lines if running for the first time
union_raw_p = elnet_dir + 'union_0.8_ignoreAGRIS_full.txt'
union_raw = pd.read_csv(union_raw_p, header=0, sep="\t")

# sorted: set iteration order varies between runs; this keeps output rows reproducible
unique_targets = sorted(set(union_raw.predicted.to_list())) # 5878
top_int_df = pd.DataFrame(columns = ["predicted", "Gene.ID", "coeffs", "present_in", "top_coeff", "top_stress", "TF_stat"])
len_ut = len(unique_targets)
# Lookups built once: each target's output row, and its edges (one groupby
# pass instead of list.index() plus a full-table filter per target)
target_pos = {target: k for k, target in enumerate(unique_targets)}
target_edges = {target: edges for target, edges in union_raw.groupby("predicted", sort=False)}

def get_top_TF(target):
    """Write the strongest regulator of `target` into row target_pos[target] of top_int_df.

    The top TF is the Gene.ID with the largest |rel.coeff| among the union
    edges predicting `target`. All of that TF's edges to the target (one per
    network) are recorded: coefficients (coeffs), networks (present_in) and
    TF_stat ("Activator" if all coefficients > 0, "Repressor" if none is,
    otherwise "Ambiguous").
    """
    subset = target_edges[target]
    sorted_subset = subset.reindex(subset["rel.coeff"].abs().sort_values(ascending=False).index)
    top_row = sorted_subset.iloc[0, :]
    top_TF = top_row["Gene.ID"]
    TF_subset = sorted_subset[sorted_subset["Gene.ID"] == top_TF]
    nw_list = TF_subset.network.to_list()
    coeff_list = TF_subset["rel.coeff"].to_list()
    signs = {"P" if x > 0 else "N" for x in coeff_list}
    if signs == {"P"}:
        TF_stat = "Activator"
    elif signs == {"N"}:
        TF_stat = "Repressor"
    else:
        TF_stat = "Ambiguous"
    top_int_df.loc[target_pos[target]] = [target, top_TF, coeff_list, nw_list, top_row["rel.coeff"], top_row["network"], TF_stat]

for i, target in enumerate(unique_targets):
    get_top_TF(target)
    if i % 1000 == 0:
        print("Processed: " + str(i) + "/" + str(len_ut))

top_int_df.to_csv(elnet_dir + "union_0.8_ignoreAGRIS_topTF.txt", index=False, sep="\t")

from collections import Counter
import matplotlib.pyplot as plt
stat_count = Counter(top_int_df.TF_stat.to_list())
nw_count = Counter(top_int_df.top_stress.to_list()) # Counter({'M': 536, 'C': 816, 'S': 961, 'D': 750, 'H': 816, 'N': 1086, 'L': 913})
# Separate axes: both pies used to be drawn on top of each other. plt.pie
# needs real sequences, so the dict views are converted to lists; autopct
# turns each percentage back into an absolute count.
fig, pie_axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, counts in zip(pie_axes, (stat_count, nw_count)):
    total = sum(counts.values())
    ax.pie(list(counts.values()), labels=list(counts.keys()),
           autopct=lambda pct, total=total: '{:.0f}'.format(pct * total / 100))

#### 2.2.5. TF-only GRN, Fig 4A (visualised in Cytoscape) and Supp. Fig 11

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Paths (same file name as written by the union-network cell)
elnet_path = elnet_dir + 'union_0.8_ignoreAGRIS_full.txt'
tf_path = dir_path + 'prep_files/PlantTFDB_Mpov5r1_prediction_plusTFDB.txt'

elnet_df = pd.read_csv(elnet_path, header=0, sep="\t")
# First tab-separated field of every line (no header is skipped); `with` closes the file
with open(tf_path, "r") as tf_file:
	tf_list = [x.split("\t")[0] for x in tf_file]

# network containing TFs only (both target and regulator are TFs).
# .copy(): status columns are added to this frame later.
elnet_tf_df = elnet_df[(elnet_df.predicted.isin(tf_list)) & (elnet_df["Gene.ID"].isin(tf_list))].copy()

def overall_spec(df_row):
	"""Summarise the direction of change of one gene across conditions.

	Parameters
	----------
	df_row : iterable of numbers
		One row of the specificity table. Presumably signed values per
		condition, with 0 meaning "no change" (TODO confirm).

	Returns
	-------
	str
		- "NS" if every value is 0 (or the row is empty).
		- "UP" if all non-zero values are > 0.
		- "DOWN" if no non-zero value is > 0 (NaN counts as DOWN).
		- "MIXED" otherwise.
	"""
	# retrieve status from specificity file; zeros carry no direction and are skipped
	status = {"UP" if x > 0 else "DOWN" for x in df_row if x != 0}
	if not status:
		# the old `status == ["NC"]` test could never match, so all-zero rows became "MIXED"
		return "NS"
	if status == {"UP"}:
		return "UP"
	if status == {"DOWN"}:
		return "DOWN"
	return "MIXED"

# specificity file, spec > 0.7
desc_path = elnet_dir + 'TF_Scond_specificity_07.txt'
desc_file = pd.read_csv(desc_path, header=0, sep="\t")
desc_file.columns = ["desc"] + desc_file.columns.to_list()[1:]
# gene ID = first space-separated token of the description
desc_file["gene"] = desc_file.desc.apply(lambda x: x.split(" ")[0])
# columns 1-7 hold the per-condition values (presumably the 7 stresses — confirm)
desc_file["overall_spec"] = desc_file.iloc[:,1:8].apply(overall_spec, axis=1)
# annotation file for network when viewed in cytoscape
desc_file.to_csv(dir_path + 'prep_files/network_anno.txt', index=False, sep="\t")

# for each gene, get the top DEG status in stress X containing network.
# One gene -> status lookup (first occurrence wins, as `.values[0]` did) replaces
# two full scans of desc_file per edge; genes absent from desc_file get "NS".
spec_lookup = desc_file.drop_duplicates("gene").set_index("gene").overall_spec
elnet_tf_df["predicted_stat"] = elnet_tf_df.predicted.map(spec_lookup).fillna("NS")
elnet_tf_df["geneid_stat"] = elnet_tf_df["Gene.ID"].map(spec_lookup).fillna("NS")

def check_outcome(pstat, gstat, relcoef):
	"""Classify a TF -> target edge as "expected" or "unexpected".

	pstat and gstat are the overall statuses ("UP", "DOWN", "NS" or "MIXED") of
	the predicted target and of the regulating TF. relcoef is the edge's
	relative coefficient.
	  - Same direction on both sides: "expected" only if relcoef > 0.
	  - Opposite directions: "expected" only if relcoef < 0.
	  - "NS" or "MIXED" on either side: always "unexpected".
	"""
	if {pstat, gstat} & {"NS", "MIXED"}:
		return "unexpected"
	same_direction = len({pstat, gstat}) == 1
	consistent = relcoef > 0 if same_direction else relcoef < 0
	return "expected" if consistent else "unexpected"
# Classify every edge from (target status, TF status, coefficient)
elnet_tf_df["grn_tally"] = elnet_tf_df.apply(lambda x: check_outcome(x.predicted_stat, x.geneid_stat, x["rel.coeff"]), axis=1)

# filter for unique target-TF pairs with top relative coefficient
gene_pairs = elnet_tf_df[["predicted", "Gene.ID"]].drop_duplicates().values.tolist()
# For every (predicted, Gene.ID) pair keep the first row with the largest |rel.coeff|.
# groupby(sort=False) keeps pairs in order of first appearance, as the former
# per-pair loop did. That loop rescanned the whole table for each pair and grew
# the result row by row (quadratic).
top_rows = elnet_tf_df["rel.coeff"].abs().groupby([elnet_tf_df.predicted, elnet_tf_df["Gene.ID"]], sort=False).idxmax()
elnet_tf_unique = elnet_tf_df.loc[top_rows.to_list()].reset_index(drop=True)

def plot_stats(df, ytype, start=0, end=55):
	"""Plot the share of "expected" edges and the network size against |rel.coeff| cutoffs.

	Cutoffs are i/100 for i in range(start, end) (default 0.00-0.54, step 0.01).
	At each cutoff only edges with |rel.coeff| > cutoff are kept. The figure is
	saved as figures/FigS11_<ytype>.png.

	Parameters
	----------
	df : DataFrame
		Edge table with "rel.coeff", "grn_tally", "predicted" and "Gene.ID" columns.
	ytype : {"edges", "nodes"}
		Size measure for the right-hand axis: number of edges, or number of
		distinct genes.
	start, end : int
		Range of cutoffs in hundredths.

	Returns
	-------
	list
		[cutoffs, cutoff_ratio, cutoff_size]. A ratio is NaN when no edge
		passes the cutoff.

	Raises
	------
	ValueError
		If ytype is not "edges" or "nodes".
	"""
	if ytype not in ("edges", "nodes"):
		raise ValueError("ytype must be 'edges' or 'nodes', got {!r}".format(ytype))
	cutoffs = [i/100 for i in range(start, end)]
	abs_coeff = abs(df["rel.coeff"])
	kept = [df[abs_coeff > cf] for cf in cutoffs]
	cutoff_count = [sub.grn_tally.value_counts() for sub in kept]
	cutoff_ratio = []
	for counts in cutoff_count:
		total = sum(counts)
		# .get(): value_counts has no "expected" entry when no kept edge is expected
		cutoff_ratio.append(counts.get("expected", 0) / total if total else float("nan"))
	if ytype == "edges":
		cutoff_size = [sum(x) for x in cutoff_count]
		ylab = "Number of edges"
	else:
		cutoff_size = [len(set(sub.predicted.to_list() + sub["Gene.ID"].to_list())) for sub in kept]
		ylab = "Number of nodes"

	fig, ax1 = plt.subplots()
	color = 'royalblue'
	ax1.set_xlabel('Absolute relative coefficient cutoff')
	ax1.set_ylabel('Ratio of expected versus total interactions', color=color)
	ax1.plot(cutoffs, cutoff_ratio, color=color)
	ax1.tick_params(axis='y', labelcolor=color)

	ax2 = ax1.twinx()  # shares the x axis, which is already labelled on ax1
	color = 'firebrick'
	ax2.set_ylabel(ylab, color=color)
	ax2.plot(cutoffs, cutoff_size, color=color)
	ax2.tick_params(axis='y', labelcolor=color)

	fig.tight_layout()
	plt.savefig(dir_path + 'figures/' + 'FigS11_{}.png'.format(ytype), dpi=600)
	plt.show()
	return [cutoffs, cutoff_ratio, cutoff_size]

node_details = plot_stats(elnet_tf_unique, "nodes")
edge_details = plot_stats(elnet_tf_unique, "edges")

# Find optimal cutoff (the one with the highest expected/total ratio) and output network to file
optimal_cutoff = node_details[0][node_details[1].index(max(node_details[1]))] #0.22
# .copy(): a column is added below, so work on an independent frame rather than a slice
filtered_elnet = elnet_tf_unique[abs(elnet_tf_unique["rel.coeff"]) > optimal_cutoff].copy()
filtered_elnet["abs_relcoeff"] = filtered_elnet["rel.coeff"].apply(abs)
# NOTE: the file name hard-codes "cf22", the cutoff found in the original run
filtered_elnet.to_csv(dir_path + 'prep_files/network_cf22_spec.txt', index=False, sep="\t")

### 2.3 Extract Arabidopsis GO and TFs

In [None]:
"""
Flatten GO file to 1 line per gene
"""

import pandas as pd

GOpath = dir_path + 'prep_files/ATH_GO_GOSLIM.txt'
TFpath = dir_path + 'prep_files/Ath_TF_list.txt'
opath = dir_path + 'prep_files/'

# the first 4 lines are skipped (presumably header/comment lines of the TAIR file — confirm)
GO_df = pd.read_csv(GOpath, skiprows = 4, header=None, sep="\t")
GO_df.columns = ['locus_name', 'TAIR_accession', 'object_name',
				 'relationship_type', 'GO_term', 'GO_ID', 'TAIR_Keyword_ID',
				 'Aspect', 'GOslim_term', 'Evidence_code', 'Evidence_description',
				 'Evidence_with', 'Reference','Annotator', 'Date_annotated']
TF_df = pd.read_csv(TFpath, header=0, sep="\t")
# sorted (key=str tolerates a stray NaN): TF_genes fixes the row order of the
# output table, and set iteration order of strings changes between sessions
TF_genes = sorted(set(TF_df.Gene_ID), key=str)
all_genes = GO_df.locus_name.unique().tolist()

# define codes relating to experimental evidence and high throughput experiments
exp_codes = ['EXP', 'IDA', 'IPI', 'IMP', 'IGI', 'IEP']
exp_codes_htp = ['HTP', 'HDA', 'HMP', 'HGI', 'HEP']
GO_df_EXP = GO_df[GO_df.Evidence_code.isin(exp_codes)]
EXP_genes = list(GO_df_EXP.locus_name.unique())
EXP_TF = list(set(EXP_genes) & set(TF_genes))
GO_df_HTP = GO_df[GO_df.Evidence_code.isin(exp_codes_htp)]
# genes with high-throughput evidence, excluding TFs already backed by experimental evidence
exp_tf_set = set(EXP_TF)
HTP_genes = [x for x in GO_df_HTP.locus_name.unique() if x not in exp_tf_set]
HTP_TF = list(set(HTP_genes) & set(TF_genes))
EXP_HTP_TF = list(set(EXP_TF) | set(HTP_TF))

# Unique GO terms per gene for aspect "F" (-> MolFunc) and "P" (-> BioProc),
# computed once instead of scanning the whole GO table twice per TF.
# groupby keeps row order within a gene, so the term order matches the old per-gene .unique().
MolF_by_gene = GO_df[GO_df.Aspect == "F"].groupby("locus_name").GO_term.unique()
BioP_by_gene = GO_df[GO_df.Aspect == "P"].groupby("locus_name").GO_term.unique()
exp_tf_set, htp_tf_set = set(EXP_TF), set(HTP_TF)

dicto = {'locus': [], 'MolFunc': [], 'BioProc': [], 'type': []}
for i, gene in enumerate(TF_genes):
	# parenthesised: `i+1 % 100` parses as `i + (1 % 100)`, so progress was never printed
	if (i + 1) % 100 == 0:
		print("{}/{} genes".format(i + 1, len(TF_genes)))
	MolF = list(MolF_by_gene.get(gene, []))
	BioP = list(BioP_by_gene.get(gene, []))
	dicto['locus'].append(gene)
	dicto['MolFunc'].append(MolF)
	dicto['BioProc'].append(BioP)

	# evidence class: experimental first, then high-throughput, otherwise OTHERS
	if gene in exp_tf_set:
		dicto['type'].append('EXP')
	elif gene in htp_tf_set:
		dicto['type'].append('HTP')
	else:
		dicto['type'].append('OTHERS')

EXP_TF_aspect_df = pd.DataFrame.from_dict(dicto)
EXP_TF_aspect_df.to_csv(opath + 'Ath_TFall_GOanno.txt',
						sep="\t", index=False)

In [None]:
"""
Map Mpo differentially expressed TFs to Ath
"""

# get differentially expressed Mpo TF
mpo_tf_path = dir_path + "prep_files/PlantTFDB_Mpov5r1_prediction_plusTFDB.txt"
mpo_exp_path = dir_path + "prep_files/Mpo_GRN_models/all_stress_mt5_nodum.tsv"

# NOTE(review): read_csv treats the first line of the TF file as a header here,
# while the TF-only GRN cell reads the same file line by line (no header).
# Confirm the file has a header; otherwise its first TF is dropped.
mpo_tf_list = pd.read_csv(mpo_tf_path, sep="\t").iloc[:,0]
mpo_exp_tf_list = pd.read_csv(mpo_exp_path, header=0, sep="\t").iloc[:,0]

# TFs present in the expression matrix. Sorted so the output row order is
# reproducible (set iteration order of strings varies between sessions).
mpo_nw_tf = sorted(set(mpo_tf_list) & set(mpo_exp_tf_list), key=str)

# orthogroup file: one orthogroup per line, "<OG id>: <gene> <gene> ..."
OF_path = dir_path + 'prep_files/Orthogroups.txt'
gene_OF = {}  # gene ID -> orthogroup ID

with open(OF_path, 'r') as OF_file:
	for line in OF_file:
		line = line.rstrip("\r\n")
		if not line:
			continue  # tolerate blank (e.g. trailing) lines instead of raising IndexError
		og, _, genes = line.partition(": ")
		# split() without an argument ignores repeated spaces instead of mapping "" to an OG
		for gene in genes.split():
			gene_OF[gene] = og

# Ath TF GO file (flat)
ath_go_path = dir_path + 'prep_files/Ath_TFall_GOanno.txt'
# one row per Ath TF, indexed by locus; remaining columns are MolFunc, BioProc, type
ath_go_df = pd.read_csv(ath_go_path, header=0, index_col=0, sep="\t")
# orthogroup of every Ath TF. NOTE(review): raises KeyError if a TF is missing from Orthogroups.txt
ath_go_df["OG"] = [gene_OF[x] for x in ath_go_df.index.to_list()]

# conversion
# orthogroup of each Mpo TF, aligned with mpo_nw_tf
mpo_nw_tf_og = [gene_OF[x] for x in mpo_nw_tf]

unique_og = list(set(mpo_nw_tf_og))
# per orthogroup: Ath TF loci, and their annotation rows without the trailing OG column
ath_genes = {og: ath_go_df[ath_go_df.OG == og].index.to_list() for og in unique_og}
ath_desc = {og: [list(x)[:-1] for i, x in ath_go_df[ath_go_df.OG == og].iterrows()] for og in unique_og}

compilation = pd.DataFrame.from_dict({"Mpo_genes": mpo_nw_tf, "OG": mpo_nw_tf_og})
compilation["Ath_genes"] = compilation.OG.apply(lambda x: ath_genes[x])
compilation["Ath_desc"] = compilation.OG.apply(lambda x: ath_desc[x])
# keep Mpo TFs where some Ath ortholog has a field exactly equal to "EXP" (the
# evidence 'type' column); TFs with no Ath ortholog (empty Ath_desc) are dropped
compilation_nona_expOnly = compilation[["EXP" in [z for y in x for z in y] for x in compilation.Ath_desc]]
compilation_nona_expOnly.to_csv(dir_path + 'prep_files/MpoDEGTF_AthAllTF_noNA_exp.txt', index=False, sep="\t")

# 3. Analysis and plotting

### Figure 1 & Supp. Fig 1: Measurements and Student's t-test

In [None]:
# adapted from measurements_forsupp.py
wdir = dir_path + 'prep_files/'
odir = dir_path + 'figures/'
odir_safe = dir_path_safe + 'figures/'  # shell-escaped variant of odir (kept in case other cells use it)
# os.makedirs needs no shell escaping and also creates missing parent directories
os.makedirs(odir, exist_ok=True)
infile = 'phase1n2_measurements_nooutliers.txt'

measurements = pd.read_csv(wdir + infile, sep='\t')
# Single letter to full single stress description
singled = {'C':'Cold',
		   'H':'Heat',
		   'S':'Salt',
		   'M':'Mannitol',
		   'L':'Light',
		   'D':'Dark',
		   'N':'Nitrogen'
		   }

# For plotting all controls: mean ± SD area per batch, Day 15 (Parea) and Day 21 (Earea)
areatype = ['Parea', 'Earea']
titletype = ['15', '21']
for i in range(len(areatype)):
	area = areatype[i]
	title = titletype[i]
	control_rows = measurements[measurements.Stress == 'Control'][["Batch", area]]
	control_m = control_rows.groupby('Batch', sort = False).mean()
	control_s = control_rows.groupby('Batch', sort = False).std()
	# Symmetric SD error bars: control_s has the same column name as control_m, so
	# each batch gets its own SD. The former [sd, reversed sd] list would pair
	# batch i with the SD of batch n-1-i if read as (lower, upper) errors.
	control_m.plot.bar(yerr=control_s, legend=False, title='Control (Day '+ title + ')', capsize=4)
	plt.savefig(odir + 'Control_Day' + title + '.png', dpi = 600, bbox_inches='tight')
	plt.show()

# Split measurements into single-stress, cross-stress ('mixed') and control rows.
# NOTE(review): recent pandas versions may read the literal "None" as NaN; rows
# with a NaN Condition are kept in ss_meas — confirm the Condition values.
ss_meas = measurements[~measurements.Condition.isin(['None', 'mixed'])]
cs_meas = measurements.loc[measurements.Condition == 'mixed']
control_meas = measurements.loc[measurements.Stress == 'Control']

# grouping column(s) for the controls and the matching output-file suffixes
controltype, controltitle = ['Stress'], ['_merged']

# x-axis label for each single-stress bar chart
xaxislabel = dict(
	Heat='Temperature (\u00B0C)',
	Cold='Temperature (\u00B0C)',
	Mannitol='Mannitol (mM)',
	Salt='NaCl (mM)',
	Light='Light intensity (\u03bcEm\u207b\u00b2s\u207b\u00b9)',
	Dark='Days',
	Nitrogen='KNO\u2083 (%)',
)

# Supp. Fig. 1
# t-test (control as b, following test, a). scipy's ttest_ind defaults to
# equal_var=True (Student's t-test). Result rows are written to ttest.txt,
# which is kept open for the remaining tests of this cell.
tout = open(wdir + 'ttest.txt', 'w+')
from scipy import stats as st

for ss in list(ss_meas.Stress.unique()):
	for i, c in enumerate(controltype):
		stress_rows = ss_meas[ss_meas.Stress == ss]
		control_batches = stress_rows.Batch.unique()
		# controls grown in the same batches as this stress
		batch_controls = control_meas[control_meas.Batch.isin(control_batches)]
		# Aggregate only the area columns: other (text) columns make
		# groupby().mean()/.std() raise TypeError in pandas >= 2.0
		control_mean = batch_controls.groupby(c, sort = False)[areatype].mean()
		control_std = batch_controls.groupby(c, sort = False)[areatype].std()
		stress_mean = stress_rows.groupby('Condition', sort=False)[areatype].mean()
		stress_std = stress_rows.groupby('Condition', sort=False)[areatype].std()

		if c == 'Stress': # t-test of every condition of this stress vs the merged controls
			control_df = batch_controls[['Parea', 'Earea']]
			stress_conds = stress_rows.Condition.unique()
			for k, a in enumerate(areatype):
				for scond in stress_conds:
					stress_df = stress_rows[stress_rows.Condition == scond]
					tstat, pval = st.ttest_ind(stress_df[a], control_df[a])
					tout.write(('\t').join(['Day '+ titletype[k], ss + '_' + scond, 'control_merged', str(tstat), str(pval)]) + "\n")

		# bar chart: control bar(s) followed by one bar per condition, mean ± SD
		labels = control_mean.index.to_list() + stress_mean.index.to_list()
		for j, a in enumerate(areatype): # Day 15 or 21 area
			coll_mean = list(control_mean[a]) + list(stress_mean[a])
			coll_std = list(control_std[a]) + list(stress_std[a])
			plt.bar(labels, coll_mean, yerr = coll_std, capsize=4)
			plt.title(ss + ' (Day ' + titletype[j] + ')')
			plt.xlabel(xaxislabel[ss])
			plt.ylabel('Area (mm\u00b2)')
			plt.savefig(odir + ss + '_Day' + titletype[j] + controltitle[i] + '.png', dpi = 600, bbox_inches='tight')
			plt.show()

# cross_stress plot: every cross stress against the merged controls of its batches
for i, c in enumerate(controltype):
	cs_control_batches = cs_meas.Batch.unique()
	cs_batch_controls = control_meas[control_meas.Batch.isin(cs_control_batches)]
	# Aggregate only the area columns: other (text) columns make
	# groupby().mean()/.std() raise TypeError in pandas >= 2.0
	cs_control_mean = cs_batch_controls.groupby(c, sort = False)[areatype].mean()
	cs_control_std = cs_batch_controls.groupby(c, sort = False)[areatype].std()
	cs_stress_mean = cs_meas.groupby('Stress', sort=False)[areatype].mean()
	cs_stress_std = cs_meas.groupby('Stress', sort=False)[areatype].std()
	cs_labels = cs_control_mean.index.to_list() + cs_stress_mean.index.to_list()

	if c == 'Stress': #t-test
		control_df = cs_batch_controls[['Parea', 'Earea']]
		for k, a in enumerate(areatype):
			for ss in list(cs_meas.Stress.unique()):
				stress_df = cs_meas[(cs_meas.Stress == ss)]
				tstat, pval = st.ttest_ind(stress_df[a], control_df[a])
				tout.write(('\t').join(['Day '+ titletype[k], ss + '_' + 'mixed', 'control_merged', str(tstat), str(pval)]) + "\n")

	for j, a in enumerate(areatype): # Day 15 or 21 area
		coll_mean = list(cs_control_mean[a]) + list(cs_stress_mean[a])
		coll_std = list(cs_control_std[a]) + list(cs_stress_std[a])
		plt.bar(cs_labels, coll_mean, yerr = coll_std, capsize=4)
		plt.title('Cross stress (Day ' + titletype[j] + ')')
		plt.xticks(rotation=90)
		plt.xlabel('Experiment')
		plt.ylabel('Area (mm\u00b2)')
		plt.savefig(odir + 'Cross_stress_Day' + titletype[j] + controltitle[i] + '.png', dpi = 600, bbox_inches='tight')
		plt.show()
			
# single stress reps and cross stress (control - merged)
def ss_grab(stress, condition):
	"""
	Return the rows of the global `measurements` table for one stress/condition pair.

	Parameters
	----------
	stress : string
		Value to match in the Stress column.
	condition : string
		Value to match in the Condition column.

	Returns
	-------
	sssub : dataframe
		Rows whose Stress and Condition both match.
	"""
	keep = measurements.Stress.eq(stress) & measurements.Condition.eq(condition)
	return measurements[keep]

# Representative (Stress, Condition) of each single stress shown in Fig. 1
srep_keys = [['Cold', '3'],
			 ['Heat', '33'],
			 ['Salt', '40'],
			 ['Mannitol', '100'],
			 ['Light', '435'],
			 ['Dark', '3'],
			 ['Nitrogen', '0']]
# representative single-stress rows, stacked in srep_keys order
srepdf = pd.concat([ss_grab(s, c) for s, c in srep_keys])

s_cs_meas = pd.concat([srepdf, cs_meas])

# controls from every batch used by the cross stresses or the representative single stresses
s_cs_control_batches = list(cs_meas.Batch.unique()) + list(srepdf.Batch.unique())
s_cs_controls = control_meas[control_meas.Batch.isin(s_cs_control_batches)]
# Aggregate only the area columns: other (text) columns make
# groupby().mean()/.std() raise TypeError in pandas >= 2.0
s_cs_control_mean = s_cs_controls.groupby('Stress', sort = False)[areatype].mean()
s_cs_control_std = s_cs_controls.groupby('Stress', sort = False)[areatype].std()
s_cs_stress_mean = s_cs_meas.groupby('Stress', sort=False)[areatype].mean()
s_cs_stress_std = s_cs_meas.groupby('Stress', sort=False)[areatype].std()

# Bar labels: single stresses (names longer than 2 characters) get their initial, e.g. "Cold (C)"
s_cs_meas_label = [x + ' (' + x[0] + ')' if len(x) > 2 else x for x in s_cs_stress_mean.index]
s_cs_meas_label[s_cs_meas_label.index('Light (L)')] = 'High light (L)'
s_cs_meas_label[s_cs_meas_label.index('Dark (D)')] = 'Darkness (D)'
s_cs_labels = s_cs_control_mean.index.to_list() + s_cs_meas_label

# Fig1
#t-test
# Student's t-tests for Fig. 1 (scipy ttest_ind default equal_var=True); rows go to ttest.txt
# merged controls from all batches used in Fig. 1
control_df = control_meas[control_meas.Batch.isin(s_cs_control_batches)][['Parea', 'Earea']]
## cross-stress
# each cross stress vs each of its component single stresses: every letter of the
# cross-stress code is mapped to a single-stress name through `singled`
for k, a in enumerate(areatype):
	for ss in list(cs_meas.Stress.unique()):
		stress_df = cs_meas[(cs_meas.Stress == ss)]
		for singleS in ss:
			singlecontrol = srepdf[srepdf.Stress == singled[singleS]][a]
			tstat, pval = st.ttest_ind(stress_df[a], singlecontrol)
			tout.write(('\t').join(['Day '+ titletype[k], ss + '_' + 'mixed', 'control_' + singled[singleS], str(tstat), str(pval)]) + "\n")
## single stress
# each representative single stress vs the merged controls; `cond` is the first
# Condition value of that stress in srepdf (used only in the row label)
for k, a in enumerate(areatype):
	for ss in list(srepdf.Stress.unique()):
		cond = srepdf[(srepdf.Stress == ss)].Condition.unique()[0]
		stress_df = srepdf[(srepdf.Stress == ss)][a]
		tstat, pval = st.ttest_ind(stress_df, control_df[a])
		tout.write(('\t').join(['Day '+ titletype[k], ss + '_' + cond, 'control', str(tstat), str(pval)]) + "\n")
		
# all t-test rows of this cell have been written
tout.close()	
# plotting
# Bar colours: controls (tomato), single stresses (mediumseagreen), cross stresses
# (cornflowerblue). Counts come from the data instead of a hard-coded 1/7/20 split;
# matplotlib cycles a too-short colour list, which would silently mis-colour bars.
# Single stresses precede cross stresses in s_cs_stress_mean (concat order, sort=False).
n_single = srepdf.Stress.nunique()
colour_seq = (['tomato'] * len(s_cs_control_mean)
			  + ['mediumseagreen'] * n_single
			  + ['cornflowerblue'] * (len(s_cs_stress_mean) - n_single))
for j, a in enumerate(areatype): # Day 15 or 21 area
	coll_mean = list(s_cs_control_mean[a]) + list(s_cs_stress_mean[a])
	coll_std = list(s_cs_control_std[a]) + list(s_cs_stress_std[a])
	plt.bar(s_cs_labels, coll_mean, yerr = coll_std, capsize=4, color = colour_seq)
	plt.title('Area (Day ' + titletype[j] + ')')
	plt.xticks(rotation=90)
	plt.xlabel('Experiment')
	plt.ylabel('Area (mm\u00b2)')
	plt.savefig(odir + 'fig1_Day' + titletype[j] + '.png', dpi = 600, bbox_inches='tight')
	plt.show()

### Supp. Fig 2: QC of RNA-seq data

In [None]:
# adapted from QC_scaled_updated.py
from sklearn.preprocessing import StandardScaler
from scipy.stats import pearsonr

o_dir = dir_path + 'figures/'
sumdir = dir_path + 'summary_files/'

expdesc = ['all_stress', 'diurnal_exp', 'single_stress']
targetexp = expdesc[0]
targetp = sumdir + targetexp + '.txt'
expmatp = dir_path + 'prep_files/' + targetexp + '.tsv'
# Summary file, tab-separated: field 0 = experiment ID, field 1 = description.
# Read once (it was opened twice and never closed).
with open(targetp, "r") as target_file:
	target_lines = target_file.readlines()
exps = [x.split("\t")[0] for x in target_lines]
# column label = description + "_" + second "_"-separated token of the experiment ID
labels = [x.strip().split("\t")[1] + '_' + x.split("\t")[0].split('_')[1] for x in target_lines]

df = pd.read_csv(expmatp, index_col = 0, sep = "\t", header = 0)
df.columns = labels

# Standard Scaling: z-score each column (experiment)
scaled_features = StandardScaler().fit_transform(df.values)
df_scaled = pd.DataFrame(scaled_features, index = df.index, columns = df.columns)

# plot cluster map of the correlations between experiments
sns.set(font_scale=1.6)

methods = "average"

g1 = sns.clustermap(df_scaled.corr(),
				 method = methods,
				 figsize=(20,20),
				 xticklabels=True,
				 yticklabels=True)
plt.title("All stress (scaled): " + methods)
plt.savefig(o_dir + "SuppFig2" + '.png')


# PCC of experiments: Pearson correlation (with p-value) between every pair of
# scaled experiments whose labels share the same part before the first "_"
exps = list(df_scaled.columns)
# `with` closes the file even if pearsonr fails part-way
with open(dir_path + "prep_files/mpo/all_stress_PCC.txt", "w+") as pcc_out:
	pcc_out.write("exp1\texp2\tpcc_val\tp_value\n")
	for exp1 in range(len(exps)):
		for exp2 in range(exp1):
			if exps[exp1].split("_")[0] == exps[exp2].split("_")[0]:
				pcc_val, p_value = pearsonr(df_scaled[exps[exp1]], df_scaled[exps[exp2]])
				pcc_out.write(exps[exp1] + "\t" + exps[exp2] + "\t" + str(pcc_val) + "\t" + str(p_value) + "\n")

### Supp. Fig 3: Volcano plots (DESeq2)

In [None]:
# adapted from deseq_volcano.py

wdir = dir_path + 'prep_files/mpo/deseq/'
odir = wdir + 'volcano/'
# DESeq2 result tables, e.g. controlD2controlH2_res.tsv
deseqouts = [x for x in os.listdir(wdir) if "res.tsv" in x]
control = 'controlD2controlH2_res.tsv'
# the control-vs-control table is plotted separately, not in the grid
deseqouts.remove(control)
# stress code = file-name prefix before "control"; single (1-letter) codes first, then pairs
all_stress = sorted({x.split('control')[0] for x in deseqouts})
s_stress = [x for x in all_stress if len(x) == 1]
c_stress = [x for x in all_stress if len(x) == 2]
all_stress = s_stress + c_stress

# plot control: control D2 vs control H2. Highlighted points have
# |log2FoldChange| > 1 and -log10(padj) > -log10(0.05), i.e. padj < 0.05.
controls = pd.read_csv(wdir + control, sep = "\t", header = 0, index_col = 0)
neg_log_padj = -np.log10(controls["padj"])
is_sig = np.logical_and(abs(controls['log2FoldChange']) > 1,
						neg_log_padj > -np.log10(0.05))
sns.scatterplot(x = controls['log2FoldChange'], y = neg_log_padj,
				alpha = 0.2, marker = '.', legend = False,
				edgecolor = "none", hue = is_sig)
plt.title("control " + control.split("control")[1] + " vs control H2")

# plot everything else
# Grid of volcano plots with 5 columns: each stress gets a "vs control D2" panel
# and, directly below it, a "vs control H2" panel (two result files per stress)
xlen = 5
ylen = math.ceil(len(deseqouts)/5)
fig, axs = plt.subplots(ylen, xlen, figsize=(30, 37.5), sharex='col', sharey='row')
#sns.set(font_scale=1.6)
# axcord[k] = [row, col] of the k-th panel in row-major order
axcord = []
for a in range(ylen):
	for b in range(xlen):
		axcord.append([a, b])

for i, z in enumerate(all_stress):
	# prefix z + "control" keeps e.g. "H" from matching "HM..." files; after sorting,
	# files[0] is the ...controlD2... table and files[1] the ...controlH2... one
	# (assuming exactly those two exist for z)
	files = [x for x in deseqouts if x.startswith(z+'control')]
	files.sort()
	fileD2 = pd.read_csv(wdir + files[0],
					  sep = "\t", header = 0, index_col = 0)
	fileH2 = pd.read_csv(wdir + files[1],
					  sep = "\t", header = 0, index_col = 0)
	# stress i -> column i % 5 of row 2 * (i // 5); its H2 panel is the next row (+5 panels)
	D2ax = int(((i*2)//10)*10 + ((i*2)%10)/2)
	H2ax = int(D2ax + 5)
	# Volcano plots
	# against control D2 (hue: |log2FoldChange| > 1 and padj < 0.05)
	sns.scatterplot(x = fileD2['log2FoldChange'],
				 y = -np.log10(fileD2["padj"]),
				 ax = axs[axcord[D2ax][0], axcord[D2ax][1]],
				 alpha = 0.2,
				 marker = '.',
				 legend = False,
				 edgecolor = "none",
				 hue = np.logical_and(abs(fileD2['log2FoldChange']) > 1,
						  -np.log10(fileD2["padj"]) > -np.log10(0.05)))
	axs[axcord[D2ax][0], axcord[D2ax][1]].set_title(z + " vs control D2")
	
	# against control H2
	sns.scatterplot(x= fileH2['log2FoldChange'],
				 y = -np.log10(fileH2["padj"]),
				 ax = axs[axcord[H2ax][0], axcord[H2ax][1]],
				 alpha = 0.2,
				 marker = '.',
				 legend = False,
				 edgecolor = "none",
				 hue = np.logical_and(abs(fileH2['log2FoldChange']) > 1,
						  -np.log10(fileH2["padj"]) > -np.log10(0.05)))
	axs[axcord[H2ax][0], axcord[H2ax][1]].set_title(z + " vs control H2")
	
for ax in axs.flat:
    ax.set(xlabel='log2FoldChange', ylabel='-log10 padj')

# Hide x labels and tick labels for top plots and y ticks for right plots.
for ax in axs.flat:
    ax.label_outer()

plt.savefig(dir_path + "figures/SuppFig3.png", dpi = 600)

### Supp. Fig 4: Comparison of DEGs between two controls

In [None]:
wdir = mpo_path
deseqouts = [x for x in os.listdir(wdir) if "resSig.tsv" in x]
deseqouts.remove('controlD2controlH2_resSig.tsv')
# The per-stress gene tables are written into sets/results/, which was otherwise
# only created by the following cell; make sure it exists before writing.
os.makedirs(wdir + "sets/results/", exist_ok=True)

# create subplots
xlen = 4
ylen = math.ceil(len(deseqouts)/8)
figw = xlen * 4
figh = ylen * 2.5

stress_list = ['H', 'C', 'HM', 'CM', 'M', 'CL', 'ML', 'L', 'HS', 'CS', 'SM', 'LS', 'S',
			   'HN', 'CN', 'MN', 'NL', 'SN', 'N', 'HD', 'CD', 'MD', 'SD', 'ND', 'D']
# one mosaic letter per stress, in stress_list order
f_axes = string.ascii_uppercase[:len(stress_list)]
axd = plt.figure(constrained_layout=True,
				 figsize=(figw, figh)).subplot_mosaic(
					 """
					 A......
					 .B.....
					 CDE....
					 .FGH...
					 IJKLM..
					 NOPQRS.
					 TUV.WXY
					 """,
					 gridspec_kw = {'hspace' : 0.3}
					 )

counter = 0
for i in stress_list:
	# sorted: files[0] is the ...controlD2... table, files[1] the ...controlH2... one
	files = sorted(x for x in deseqouts if x.startswith(i+'control'))
	fileD2 = pd.read_csv(wdir + files[0],
					  sep = "\t", header = 0, index_col = 0)
	fileH2 = pd.read_csv(wdir + files[1],
					  sep = "\t", header = 0, index_col = 0)
	D2index, H2index = set(fileD2.index.tolist()), set(fileH2.index.tolist())
	status = []

	# Create sets of DEGs: only vs D2, only vs H2, vs both controls
	D2only = D2index - H2index
	H2only = H2index - D2index
	D2H2 = D2index & H2index

	# Subsets of df for D2only, H2only and D2H2
	D2onlydf = fileD2[fileD2.index.isin(list(D2only))]
	H2onlydf = fileH2[fileH2.index.isin(list(H2only))]
	# pd.concat replaces DataFrame.append, which was removed in pandas 2.0
	D2H2df = pd.concat([fileD2, fileH2])
	D2H2df = D2H2df[D2H2df.index.isin(list(D2H2))]
	D2H2df.sort_index(inplace=True)

	# Create list of lists of all differentially expressed genes with corresponding status
	stress = "Mpo_" + files[0].split("controlD2")[0]
	for j in D2only:
		status.append([j, stress, 'D2', str(fileD2.loc[j, "log2FoldChange"]), str(fileD2.loc[j, "padj"]), "N/A", "N/A"])
	for k in H2only:
		status.append([k, stress, 'H2', "N/A", "N/A", str(fileH2.loc[k, "log2FoldChange"]), str(fileH2.loc[k, "padj"])])
	for m in D2H2:
		status.append([m, stress, 'D2H2', str(fileD2.loc[m, "log2FoldChange"]), str(fileD2.loc[m, "padj"]), str(fileH2.loc[m, "log2FoldChange"]), str(fileH2.loc[m, "padj"])])
	status.sort()

	# Output file with genes and status: D2, H2 or D2H2
	with open(wdir + "sets/results/" + stress + ".txt", "w+") as filo:
		filo.write("gene\tstress\tstatus\tL2FC_D2\tpadj_D2\tL2FC_H2\tpadj_H2\n")
		for n in status:
			filo.write("\t".join(n) + "\n")

	# Venn diagram in this stress's mosaic panel
	sp_ax = axd[f_axes[stress_list.index(stress.split('_')[1])]]
	venn2(subsets=(len(D2only), len(H2only), len(D2H2)),
		  set_labels = ('D2', 'H2'),
		  ax = sp_ax)
	sp_ax.set_title(stress.split('_')[1], size=14)

	# increase counter
	counter += 1

# Save once after all panels are drawn (previously re-saved on every iteration)
plt.savefig(dir_path + "figures/" + "suppfig4.png", dpi = 600)

In [None]:
# Sort genes according to whether they are the same in both controls
# adapted from compile_sigGenes_phase1n2.py
deseqdir = mpo_path
setdir = deseqdir + 'sets/results/'
# shell-escaped copy of setdir ("\\ " is a backslash followed by a space; the old
# '\ ' literal is an invalid escape sequence that newer Python versions warn about)
setdir_safe = setdir.replace(' ', '\\ ')
# os.makedirs needs no shell escaping and is a no-op if the directory exists
os.makedirs(setdir, exist_ok=True)

# Mercator annotation input, plus the output files for consistent / inconsistent genes
merfile = open(dir_path + 'mercator/MpoProt.results.txt', 'r')
ofile = open(deseqdir + 'resSig_compiled.txt', 'w+')
efile = open(deseqdir + 'resSig_failed.txt', 'w+')
setfiles = [x for x in os.listdir(setdir) if '.txt' in x]

ofile.write("\t".join(['gene', 'stress', 'L2FC_D2', 'L2FC_H2', 'annotation']) + "\n")
efile.write("\t".join(['gene', 'stress', 'L2FC_D2', 'L2FC_H2', 'annotation']) + "\n")

def get_anno(gene):
	"""Return the Mercator annotation list stored in the global `meranno` for the lower-cased `gene`."""
	return meranno[gene.lower()]

def up_down(val):
	"""Map a log2 fold change to its direction.

	Returns:
	  - "UP" for positive values,
	  - "DOWN" for negative values,
	  - "NaN" for NaN,
	  - "NC" (no change) for exactly 0, which previously raised UnboundLocalError.
	"""
	if val < 0:
		return "DOWN"
	if val > 0:
		return "UP"
	if math.isnan(val):
		return "NaN"
	return "NC"

# Mercator results -> {identifier: [[bincode, description], ...]}.
# Only 5-column lines are used; single quotes are stripped from the fields.
meranno = {}

for line in merfile:
	if len(line.rstrip().split("\t")) == 5:
		bincode, name, identifier, desc, ptype = line.rstrip().replace("'", "").split("\t")
		meranno.setdefault(identifier, []).append([bincode, desc])
# the Mercator file is fully consumed; release the handle (it was never closed)
merfile.close()

# For each stress keep genes that are DEGs against both controls (status D2H2).
# Genes with the same direction vs D2 and vs H2 go to ofile; the rest go to efile.
for i in setfiles:
	content = pd.read_csv(setdir + i, sep = "\t", header = 0)
	# .copy(): columns are overwritten below, so detach from `content` (avoids chained assignment)
	sigGenes = content[content['status'] == "D2H2"].copy()
	sigGenes['annotation'] = sigGenes['gene'].apply(get_anno)
	sigGenes['L2FC_D2'] = sigGenes['L2FC_D2'].apply(up_down)
	sigGenes['L2FC_H2'] = sigGenes['L2FC_H2'].apply(up_down)
	# remaining column order: gene, stress, L2FC_D2, L2FC_H2, annotation (matches the headers)
	sigGenes = sigGenes.drop(columns = ['status', 'padj_D2', 'padj_H2'])
	for index, row in sigGenes.iterrows():
		line = "\t".join([str(z) for z in row]) + "\n"
		if row['L2FC_D2'] != row['L2FC_H2']:
			efile.write(line)
		else:
			ofile.write(line)

ofile.close()
efile.close()

### Supp. Figure 5: Upset plot (up- and down-regulated)

In [None]:
# Fig S5 A & B (adapted from upset.py)
import upsetplot
from collections import Counter, defaultdict  # Counter is used by plot_selected below

wdir = dir_path + 'prep_files/mpo/deseq/'
odir = wdir + 'upset/'
odir_safe = dir_path_safe + 'prep_files/mpo/deseq/' + 'upset/'  # shell-escaped variant of odir

# os.makedirs needs no shell escaping and also creates missing parent directories
os.makedirs(odir, exist_ok=True)

data = pd.read_csv(wdir + 'resSig_compiled.txt', sep = '\t')

### FUNCTIONS ###

def upset_matrix(set_dict, stress_types):
	"""Build an upsetplot membership frame from the entries of `set_dict` whose key is in `stress_types`."""
	selected = {}
	for name, members in set_dict.items():
		if name in stress_types:
			selected[name] = members
	return upsetplot.from_contents(selected)


def plot_selected(cond_dict, cond_list, filename, title, orient = "horizontal", cutoff=50):
	"""Write intersection tables and draw an UpSet plot of the `cutoff` largest intersections.

	Parameters
	----------
	cond_dict : dict
		Condition code -> set of gene IDs.
	cond_list : list
		Condition codes to include.
	filename : str
		Output path prefix. Writes:
		  - <filename>_matrix.txt: every intersection,
		  - <filename>_top50.txt: the `cutoff` largest intersections (name kept
		    for compatibility).
	title : str
		Plot title.
	orient : str
		Passed to upsetplot.plot as `orientation`.
	cutoff : int
		Number of largest intersections to keep.

	The figure is saved as figures/FigS5c.png when "upreg" is in `filename`,
	otherwise as figures/FigS5d.png.
	"""
	df_set = upset_matrix(cond_dict, cond_list)
	df_set = df_set.sort_index()

	# preparation to output all data: number of genes per membership pattern (intersection)
	index_names = list(df_set.index.names)
	index_list = df_set.index.to_list()
	set_count = Counter(index_list)
	counter_list = sorted(set_count.items(), key = lambda x: x[1], reverse=True)

	# writing output to file
	with open(filename + "_matrix.txt", "w+") as ofile:
		ofile.write("\t".join(index_names + ['count', 'genes']) + "\n")
		for i in counter_list:
			glist = df_set.loc[i[0],:].id.to_list()
			ofile.write("\t".join([str(int(x)) for x in i[0]] + [str(i[1]), str(glist)]) + "\n")

	# writing the `cutoff` largest intersections to file
	set_cutoff = set_count.most_common(cutoff)
	selection = [x[0] for x in set_cutoff]
	with open(filename + "_top50.txt", "w+") as cfile:
		cfile.write("\t".join([str(index_names), 'count', 'genes']) + "\n")
		for i in range(len(set_cutoff)):
			glist = df_set.loc[set_cutoff[i][0],:].id.to_list()
			cfile.write("\t".join([str(set_cutoff[i][0]), str(set_cutoff[i][1]), str(glist)]) + "\n")
	# selection for plotting: keep the genes (rows) whose membership pattern is one of
	# the selected intersections. The former version combined per-intersection slices
	# with `+`, which is index-aligned arithmetic rather than concatenation.
	sel_matrix = df_set[df_set.index.isin(selection)]
	upsetplot.plot(sel_matrix, orientation = orient, sort_by = 'cardinality')
	plt.title(title, size=20)
	if "upreg" in filename:
		figname = 'c'
	else:
		figname = 'd'
	plt.savefig(dir_path + 'figures/FigS5' + figname + '.png', dpi=600)

### END ###

# Reshape data: for every stress code collect its up- and down-regulated genes
# (direction as called against control D2)
cond_dict_U = defaultdict(list)
cond_dict_D = defaultdict(list)
for index, row in data.iterrows():
	direction = row['L2FC_D2']
	if direction == 'UP':
		bucket = cond_dict_U
	elif direction == 'DOWN':
		bucket = cond_dict_D
	else:
		continue
	bucket[row['stress'].split("_")[1]].append(row['gene'])

all_stress_list = [x.split('_')[1] for x in data.stress.unique()]

# initialise dictionaries of up and downregulated gene sets for each condition
cond_dict_U_set = {k: set(v) for k, v in cond_dict_U.items()}
cond_dict_D_set = {k: set(v) for k, v in cond_dict_D.items()}

# Plot horizontal (default)
for gene_sets, out_name, plot_title in ((cond_dict_D_set, "all_downreg", "Downregulated genes"),
										(cond_dict_U_set, "all_upreg", "Upregulated genes")):
	plot_selected(cond_dict = gene_sets,
				  cond_list = all_stress_list,
				  filename = odir + out_name,
				  title = plot_title)

### Supp. Figs 6 & 7: Intersection of DEGs across single and cross stresses (up- and down-regulated)

In [None]:
# Supp. figs 6 & 7 (adapted from indivenn_hm.py)
from collections import defaultdict
from matplotlib_venn import venn3

wdir = dir_path + 'prep_files/mpo/deseq/'
odir = dir_path + 'figures/'
data = pd.read_csv(wdir + 'resSig_compiled.txt', sep = '\t')

# Mercator bin conversion: the file holds a Python dict literal, presumably
# top-level bin number -> bin name. `with` closes the file (it was left open).
with open(dir_path + 'prep_files/merdict.txt', 'r') as merdict_file:
	dicto = literal_eval(merdict_file.read())

all_s = [x.split("_")[1] for x in data.stress.unique()]
single = [x.split("_")[1] for x in data.stress.unique() if len(x.split("_")[1]) == 1]
cross = [x.split("_")[1] for x in data.stress.unique() if len(x.split("_")[1]) == 2]

# annotation column holds the repr of a list of [bincode, description] pairs
data.annotation = data.annotation.apply(literal_eval)
# top-level bin (number before the first ".") of each gene's first annotation, translated via dicto
data["mername"] = data.annotation.apply(lambda x: dicto[int(x[0][0].split('.')[0])])

dict_A = defaultdict(list)
dict_U = defaultdict(list)
dict_D = defaultdict(list)

def sum_to_dict(dicto, stress, reg):
	"""Append [gene set, Mercator bin names] for one stress to `dicto[stress]`.

	Rows come from the global `data` where stress == "Mpo_<stress>". With
	reg == "ALL" every such row is used; otherwise only rows whose L2FC_D2
	equals `reg`.
	"""
	rows = data[data.stress == "Mpo_" + stress]
	if reg != "ALL":
		rows = rows[rows.L2FC_D2 == reg]
	dicto[stress].extend([set(rows.gene.to_list()), rows.mername.to_list()])

def dict_to_df(dicto):
	"""Turn {stress: [gene_set, mername_list]} into a DataFrame indexed by stress with columns gene, mername."""
	return pd.DataFrame.from_dict(dicto, orient='index', columns=["gene", "mername"])

# Gene sets (and bin names) of up- and down-regulated genes for every stress;
# the combined "ALL" variant (dict_A / df_A) is currently disabled
for s in all_s:
	for reg_dict, reg in ((dict_U, "UP"), (dict_D, "DOWN")):
		sum_to_dict(reg_dict, s, reg)

df_U, df_D = dict_to_df(dict_U), dict_to_df(dict_D)

def plot_venn(df, s1, s2, c1, title, axis):
	"""Draw a 3-set Venn diagram of the gene sets of stresses s1, s2 and c1 (rows of `df`) on `axis`."""
	names = (s1, s2, c1)
	gene_sets = [df.loc[name].gene for name in names]
	venn3(gene_sets, names, ax = axis)
	axis.set_title(title, size=20)

# create subplots: 4 columns, figure size scaled by the number of rows needed for all stresses
xlen = 4
ylen = math.ceil(len(all_s) / xlen)
figw, figh = xlen * 4, ylen * 3.5

# one mosaic letter per panel
a_axes = string.ascii_uppercase[:len(all_s)]
def plot_subplot(df, title_ext):
	"""Draw one Venn panel per cross stress (its two single stresses vs the combination) and save the figure.

	df is a table produced by dict_to_df; title_ext is appended to each panel
	title and to the output file name.
	"""
	axa = plt.figure(constrained_layout=True,
				 figsize=(figw, figh)).subplot_mosaic(
					 """
					 ABCD
					 EFGH
					 IJKL
					 MNOP
					 QR..
					 """
					 )

	# mosaic letters are assigned to the cross stresses in order
	for letter, combo in zip(a_axes, cross):
		plot_venn(df, combo[0], combo[1], combo, combo + title_ext, axa[letter])
	plt.savefig(odir+'supp_fig6or7' + title_ext +'.png', dpi=600)


# Upregulated and downregulated panels (the combined "ALL" panel is disabled)
#df_col = [[df_A, ''], [df_U, "_upregulated"], [df_D, "_downregulated"]]
df_col = [[df_U, "_upregulated"], [df_D, "_downregulated"]]

for df_type, ext in df_col:
	plot_subplot(df_type, ext)

### Figure 2: Summary of DEGs in Marchantia and inter stress comparisons



In [None]:
# Fig 2A and B (adapted from DGE_count_sizecorr.py)
# NOTE: `cross` is rebound here to the compiled DEG table (it was a list of cross-stress codes)
cross = pd.read_csv(dir_path + 'prep_files/mpo/deseq/resSig_compiled.txt', sep='\t')

# "Mpo_<code>" -> "<code>"
cross.stress = [x.split('_')[1] for x in list(cross.stress)]
stress_l = list(cross.stress.unique())
# single-letter codes first, then pairs (stable sort keeps first-appearance order within each length)
stress_l.sort(key=lambda x: len(x))
# number of DEGs per stress and direction (L2FC_D2 holds the direction labels)
c_dgecount = cross.groupby(['stress', 'L2FC_D2']).count().gene.to_frame(name='count')

unstacked = c_dgecount.unstack().reindex(stress_l)
ax = unstacked.plot.bar(figsize=(7,3), stacked=True, ylabel='Number of DEGs', color=['navy', 'firebrick'])

handles, labels = ax.get_legend_handles_labels()
# legend entries are column tuples rendered like "(count, UP)"; keep only the direction,
# in reversed order. Fragile: depends on pandas' label formatting.
ax.legend(handles=handles[::-1], labels=[x.split(', ')[1].split(')')[0] for x in labels][::-1])
# NOTE(review): the file is named fig4a although this cell produces Figure 2A — confirm
plt.savefig(dir_path + 'figures/fig4a.png', dpi=600)

wdir = dir_path + 'prep_files/'
infile = 'phase1n2_measurements_nooutliers.txt'

measurements = pd.read_csv(wdir + infile, sep='\t')

def ss_grab(stress, condition, df=None):
	"""
	Slice the rows of a measurement table for one stress/condition pair.

	Parameters
	----------
	stress : string
		Stress of interest (matched against the ``Stress`` column).
	condition : string
		Condition of interest (matched against the ``Condition`` column).
	df : dataframe, optional
		Table to slice. Defaults to the module-level ``measurements`` table,
		so existing calls ``ss_grab(stress, condition)`` behave as before.

	Returns
	-------
	sssub : dataframe
		Rows of ``df`` whose Stress equals ``stress`` and whose Condition
		equals ``condition``.

	"""
	if df is None:
		df = measurements
	sssub = df[(df.Stress == stress) & (df.Condition == condition)]
	return sssub

# Per single stress, the condition level that matches the combined-stress
# experiments.
srep_keys = [['Cold', '3'],
			 ['Heat', '33'],
			 ['Salt', '40'],
			 ['Mannitol', '100'],
			 ['Light', '435'],
			 ['Dark', '3'],
			 ['Nitrogen', '0']]
# Same rows, in the same order, as starting from the Cold slice and
# concatenating the remaining slices one at a time.
srepdf = pd.concat([ss_grab(s, c) for s, c in srep_keys])

# Combined-stress plants carry Condition == 'mixed'.
crepdf = measurements[measurements.Condition == 'mixed']
m_nocon = pd.concat([srepdf, crepdf])
# Stress names longer than two characters are shortened to their initial
# ("Cold" -> "C"); shorter codes are kept.
m_nocon.Stress = [x[0] if len(x) > 2 else x for x in m_nocon.Stress]
noHL = m_nocon[m_nocon.Stress != 'HL']

# Mean plant area per stress. Select the two area columns BEFORE averaging:
# GroupBy.mean() over the whole frame raises on non-numeric columns in
# pandas >= 2.0.
avg_meas = noHL.groupby('Stress')[['Parea', 'Earea']].mean()

# DEG count per stress code; the assignment aligns on the index.
# The former `avg_meas.reindex(stress_l)` / `totaldeg.reindex(stress_l)`
# statements discarded their results (no effect) and were removed.
totaldeg = cross.groupby(['stress']).count()['L2FC_D2']
avg_meas['totaldeg'] = totaldeg
avg_meas = avg_meas[['totaldeg', 'Parea', 'Earea']]

# size plots by df
# Mean plant area (Day 15, then Day 21) against DEG count, each point
# labelled with its stress code. The first scatter creates the axes; the
# second one draws into them.
ax = None
for area_col, colour, legend_label in (('Parea', 'orange', 'Area (Day 15)'),
									   ('Earea', 'navy', 'Area (Day 21)')):
	drawn = avg_meas.plot.scatter(x='totaldeg', y=area_col, color=colour, label=legend_label, ax=ax)
	if ax is None:
		ax = drawn
	for stress_code, row in avg_meas.iterrows():
		ax.annotate(stress_code, (row['totaldeg'], row[area_col]),
					xytext=(-4,-12), textcoords='offset points')
	
# size plot with regression
from scipy import stats

q_colnames = avg_meas.columns.to_list()
# Legend names for the two area columns. Deliberately NOT called `dicto`:
# the next cell uses `dicto` as the MapMan bin-number -> bin-name mapping
# (`dicto[int(...)]`), and rebinding it here made that lookup raise KeyError.
day_labels = {'Parea' : 'Day 15', 'Earea' : 'Day 21'}

def plot_reg(df, title):
	"""
	Scatter every area column of ``df`` against its first column (DEG
	count), overlay a least-squares line per area column, label the points
	with their row index, and save the figure as fig4b.png.

	Draws on matplotlib's current axes (in this cell, presumably the axes
	created by the preceding pandas scatter).

	Parameters
	----------
	df : dataframe
		Indexed by stress code; columns as in ``q_colnames`` (first column =
		x values, remaining columns = y values).
	title : str
		Currently unused; kept for call compatibility.
	"""
	handles = []
	labels = []
	col = q_colnames[0]
	for col2 in q_colnames[1:]:
		points = plt.scatter(col, col2, data=df)
		#m, c = np.polyfit(df[col], df[col2], 1)
		m, c, r_value, p_value, std_err = stats.linregress(df[col], df[col2])
		# Annotate the rows of the frame actually plotted (previously the
		# global avg_meas was used regardless of the argument).
		for ind, dat in df.iterrows():
			plt.annotate(ind, (dat[col], dat[col2]),
			  xytext=(-4,-12), textcoords='offset points')
		plt.plot(df[col], m*df[col] + c)
		handles.append(points)
		labels.append(day_labels[col2] + ' (R\u00b2: ' + str(round(r_value**2,2)) + ', p: ' + str(round(p_value, 2)) + ')')
	# Pair every label with its own scatter; legend(labels) alone hands the
	# labels to whichever artists come first on the axes (earlier scatters,
	# regression lines, ...).
	plt.legend(handles, labels)
	plt.xlabel('DEG count')
	plt.ylabel('Size (mm\u00b2)')
	plt.savefig(dir_path + 'figures/fig4b.png', dpi=600)

plot_reg(avg_meas, 'Number of DEGs vs Size')

In [None]:
# Figure 2 C-F (adapted from plot_venn_sum.py)

from collections import defaultdict, Counter
from matplotlib_venn import venn3, venn3_circles
import random
from scipy import stats

# Input directory (DESeq2 results) and output directory for this cell.
wdir = dir_path + 'prep_files/mpo/deseq/'
odir = wdir + 'indivenn_hm/'
# Significant DEGs of all experiments in one table (stress, gene, L2FC_D2, annotation, ...).
data = pd.read_csv(wdir + 'resSig_compiled.txt', sep = '\t')

# Mercator
### DICTIONARY OF MERCATOR BINS ###
mfile = dir_path + 'mercator/MpoProt.results.txt'

meranno = defaultdict(list)  # identifier -> top-level bin names (via dicto)
merbin = defaultdict(list)   # identifier -> second-level bin codes "X.Y"
map2anno = {}                # two-part code "X.Y" -> name in the 2nd field

# Context manager so the handle is closed even if parsing fails (the file
# was previously opened and never closed).
with open(mfile, 'r') as merfile:
	merfile.readline()  # skip the header row
	for line in merfile:
		linecon = line.rstrip().replace("'", "").split("\t")
		if len(linecon) == 5:
			bincode, name, identifier, desc, ptype = linecon
			# Assumes dicto maps the top-level bin number (int) to its name;
			# it is defined earlier in the notebook.
			meranno[identifier].append(dicto[int(bincode.split('.')[0])])
			merbin[identifier].append('.'.join(bincode.split('.')[:2]))
		if len(linecon[0].split('.')) == 2:
			map2anno[linecon[0]] = linecon[1]

# Stress codes (text after "Mpo_"), in first-seen order: all of them, the
# single stresses (1-letter codes) and the combined stresses (2-letter codes).
all_s = [code.split("_")[1] for code in data.stress.unique()]
single = [code for code in all_s if len(code) == 1]
cross = [code for code in all_s if len(code) == 2]

# Parse the literal-encoded annotation lists, then map each entry's
# top-level bin number to its name via dicto.
data.annotation = data.annotation.apply(literal_eval)
data["mername"] = data.annotation.apply(
	lambda entries: [dicto[int(entry[0].split('.')[0])] for entry in entries]
)  # different from cell above, hence the repetitive code

# Per-stress summaries for all DEGs, upregulated DEGs and downregulated DEGs.
dict_A, dict_U, dict_D = (defaultdict(list) for _ in range(3))

def sum_to_dict(dicto, stress, reg):
	"""
	Append [gene set, top-level bin names, second-level bin codes] of one
	stress's DEGs to ``dicto[stress]``.

	Parameters
	----------
	dicto : defaultdict(list)
		Summary being filled.
	stress : str
		Stress code (without the "Mpo_" prefix).
	reg : str
		"ALL" for every DEG, otherwise the L2FC_D2 value to keep ("UP"/"DOWN").
	"""
	mask = data.stress == "Mpo_" + stress
	if reg != "ALL":
		mask = mask & (data.L2FC_D2 == reg)
	subset = data[mask]
	genes = set(subset.gene.to_list())
	bin_names = [name for names in subset.mername.to_list() for name in names]
	bin_codes = ['.'.join(entry[0].split('.')[:2])
				 for entries in subset.annotation.to_list() for entry in entries]
	dicto[stress].extend([genes, bin_names, bin_codes])

def dict_to_df(dicto):
	"""
	Turn a {stress: [genes, mername, mapbin2]} mapping into a dataframe with
	one row per stress and the columns gene, mername and mapbin2.
	"""
	column_names = ["gene", "mername", "mapbin2"]
	return pd.DataFrame.from_dict(dicto, orient='index', columns=column_names)

# Fill the three summaries stress by stress (ALL, UP, DOWN for each stress).
for s in all_s:
	for target, reg in ((dict_A, "ALL"), (dict_U, "UP"), (dict_D, "DOWN")):
		sum_to_dict(target, s, reg)

df_A, df_U, df_D = (dict_to_df(d) for d in (dict_A, dict_U, dict_D))

# =============================================================================
# 
# # Summary of stress response
# 
# =============================================================================

# Q1: similarity = |A ∩ B| / |A ∪ B| (Jaccard index, a fraction in [0, 1])
def ji_cal(a, b):
	"""
	Jaccard index of two sets.

	Two empty sets are treated as identical (1.0) instead of raising
	ZeroDivisionError; this keeps the diagonal of a Jaccard *distance*
	matrix at 0, which scipy's squareform requires.
	"""
	union = a | b
	if not union:
		return 1.0
	return len(a & b) / len(union)

# Q2: suppression = (A − AB)/A − (B − AB)/B (signed difference of fractions;
# callers take the absolute value where they need it)
def suppInX(a, b, ab):
	"""
	Fraction of A's genes absent from AB minus the same fraction for B.
	Returns NaN when A or B is empty (the fraction is undefined).
	"""
	if not a or not b:
		return float('nan')
	return len(a - ab) / len(a) - len(b - ab) / len(b)

# Q3: novel interaction = (AB − A − B) / AB (fraction of AB genes in neither A nor B)
def novel(a, b, ab):
	"""
	Fraction of the combined-stress set found in neither single-stress set.
	Returns NaN when AB is empty.
	"""
	if not ab:
		return float('nan')
	return len(ab - a - b) / len(ab)

def q_col(df, colnames):
	"""
	Score every combined stress for similarity, suppression and novel
	interaction, and collect the scores in a dataframe.

	Parameters
	----------
	df : dataframe
		Per-stress summary (all DEGs, or up-/down-regulated only) indexed by
		stress code, with a ``gene`` column.
	colnames : list of str
		Names for the three score columns, in the order similarity,
		suppression, novel interaction.

	Returns
	-------
	q_df : dataframe
		One row per combined stress, one column per score.
	"""
	def genes(code):
		return df.loc[code].gene

	scores = {}
	for combo in cross:
		first, second = genes(combo[0]), genes(combo[1])
		both = genes(combo)
		scores[combo] = [
			ji_cal(first, second),
			suppInX(first, second, both),
			novel(first, second, both),
		]
	return pd.DataFrame.from_dict(scores, orient="index", columns=colnames)

q_colnames = ["similarity", "suppression", "novel interaction"]

# Score tables for all, up- and down-regulated DEGs; only the latter two are plotted.
q_A, q_U, q_D = (q_col(table, q_colnames) for table in (df_A, df_U, df_D))
qdf_col = [[q_U, "Upregulated DEGs"], [q_D, "Downregulated DEGs"]]

def plot_q_subplots(df, outerax):
	"""
	Draw each score column of ``df`` as a one-row heatmap into the panels of
	the global ``axe`` mosaic named by ``outerax`` (one key per entry of
	q_colnames, same order), with two colour-bar ticks.
	"""
	ax = [axe[x] for x in outerax]
	#plt.suptitle(title, fontsize=14)
	for i, axis in enumerate(ax):
		# Suppression is a signed difference, so it gets a diverging map.
		if q_colnames[i] == "suppression":
			sns.heatmap(df[q_colnames[i]].to_frame().transpose(), cmap='coolwarm', ax=axis)
		else:
			sns.heatmap(df[q_colnames[i]].to_frame().transpose(), cmap='Blues', ax=axis)
		axis.set_yticklabels([q_colnames[i]], rotation=0)
		cbar = axis.collections[0].colorbar
		minval = round(df[q_colnames[i]].min(),2)
		maxval = round(df[q_colnames[i]].max(),2)
	
		# Move both tick values inwards in 0.01 steps to the nearest multiple
		# of 0.05; round(..., 2) absorbs float drift from repeated additions.
		# NOTE(review): for a narrow range (e.g. 0.41-0.44) minval ends up
		# above maxval; confirm this cannot happen for these data.
		while round(minval*100,2) % 5 != 0:
			minval += 0.01
		while round(maxval*100,2) % 5 != 0:
			maxval -= 0.01
		cbar.set_ticks([minval, maxval])

def plot_reg(df, outerax):
	"""
	For every pair of score columns, scatter one against the other on
	``outerax``, overlay the least-squares line and add a legend entry with
	R² and p-value.

	Parameters
	----------
	df : dataframe
		Score table with the columns listed in q_colnames.
	outerax : matplotlib axes
		Axes to draw into.

	Returns
	-------
	statscol : list of list
		[slope, intercept, r, p, stderr] from scipy.stats.linregress for
		each column pair, in plotting order.
	"""
	handles = []
	labels = []
	statscol = []
	for i, col in enumerate(q_colnames[:-1]):
		for col2 in q_colnames[i+1:]:
			points = outerax.scatter(col, col2, data=df)
			m, c, r_value, p_value, std_err = stats.linregress(df[col], df[col2])
			statscol.append([m, c, r_value, p_value, std_err])
			# m, c = np.polyfit(df[col], df[col2], 1)
			outerax.plot(df[col], m*df[col] + c)
			handles.append(points)
			# Raw string: '\m' in a plain literal is an invalid escape sequence
			# (DeprecationWarning; SyntaxWarning from Python 3.12). The runtime
			# text is unchanged.
			labels.append(col[:3] + ' v ' + col2[:3] + r' ($\mathregular{R^{2}}$: '+str(round(r_value**2,1))+', p: ' + str('{:.2f}'.format(round(p_value,2))+')'))
	# Attach each label to its own scatter: legend(labels) alone assigns the
	# labels to the first artists on the axes (scatter, line, scatter, ...).
	outerax.legend(handles, labels, fontsize="x-small")
	return statscol
	
def dum_venn(a_b, c_a, c_b, c_ab, ax, col, title, ac=20, bc=20, cc=20):
	'''
	Draw an illustrative three-set Venn diagram with chosen overlap sizes,
	colour selected regions and hide the region counts.

	Parameters
	----------
	a_b : int
		Size of A&B.
	c_a : int
		Size of A&C-B.
	c_b : int
		Size of B&C-A.
	c_ab : int
		Size of C&(A&B).
	ax : axis handle
		Axis handle of subplot to plot into.
	col : list
		List of [patch id, colour] pairs (matplotlib_venn region ids such as
		'100' or '011').
	title : str
		Panel title.
	ac : int, optional
		Size of set a. The default is 20.
	bc : int, optional
		Size of set b. The default is 20.
	cc : int, optional
		Size of set c. The default is 20.

	Returns
	-------
	None.

	'''
	# Shuffled letters serve as dummy elements; only the set sizes matter.
	dum = list(string.ascii_uppercase + string.ascii_lowercase)
	random.shuffle(dum)
	# Honour the requested size of A (it was hard-coded to 20, ignoring ac).
	a = set(dum[:ac])
	b = set(list(a)[:a_b] + [x for x in dum if x not in a][:bc-a_b])
	c = set(list(a-b)[:c_a] +
		 list(b-a)[:c_b] + list(a&b)[:c_ab] +
		 [x for x in dum if x not in a and x not in b][:cc-c_a-c_b-c_ab])

	v = venn3([a, b, c],
	   ('A', 'B', 'AB'),
	   ax = ax)
	venn3_circles([a, b, c], linewidth=1, color='k', ax=ax)
	# Empty regions have no patch / label (None); skip them instead of
	# failing with AttributeError.
	for entry in col:
		patch = v.get_patch_by_id(entry[0])
		if patch is not None:
			patch.set_color(entry[1])
	for label in v.subset_labels:
		if label is not None:
			label.set_visible(False)
	ax.set_title(title, fontsize=16)

# =============================================================================
# #
# # Initialising subplot
# #
# =============================================================================
#figsize=(figw, figh)

# Four mosaics stacked on one figure, positioned by their gridspec
# bottom/top fractions (neighbouring ranges overlap slightly): schematic
# Venn diagrams, their formulae, the score heatmaps (left column "u" and
# right column "d" receive qdf_col[0] and qdf_col[1], i.e. upregulated and
# downregulated DEGs) and the score-vs-score regressions.
top_mosaic = [["v1", "v2", "v3"]]
eq_mosaic = [	["e1", "e2", "e3"]	]
middle_mosaic = [
	["u1", "d1"],
	["u2", "d2"],
	["u3", "d3"]
]
bottom_mosaic = [["r1", "r2"]]

# NOTE: rebinds the figw/figh globals that plot_subplot reads at call time.
figw, figh = 11, 9
fig = plt.figure(figsize=(figw, figh))
axc = fig.subplot_mosaic(
	top_mosaic,
	gridspec_kw={
		"bottom": 0.75,
		"top": 1,
		#"wspace": 0.5,
		#"hspace": 0.5,
		}
	)
axd = fig.subplot_mosaic(
	eq_mosaic,
	gridspec_kw={
		"bottom": 0.55,
		"top": 0.8,
		#"wspace": 0.5,
		#"hspace": 0.5,
		}
	)

axe = fig.subplot_mosaic(
	middle_mosaic,
	gridspec_kw={
		"bottom": 0.38,
		"top": 0.6,
		#"wspace": 0.5,
		"hspace": 0.2,
		}
	)
axf = fig.subplot_mosaic(
	bottom_mosaic,
	gridspec_kw={
		"bottom": 0,
		"top": 0.3,
		#"wspace": 0.5,
		#"hspace": 0.5,
		}
	)

# The formula row is text only: hide its axes.
for axy in ['e1', 'e2', 'e3']:
	axd[axy].axis('off')
# Pin the Venn panels to the top of their slots.
for axy in ["v1", "v2", "v3"]:
	axc[axy].set_anchor('N')

# Formula text for similarity, suppression and novel interaction.
# NOTE(review): the suppression formula shown, (A-B-AB)/A - (B-A-AB)/B,
# differs from what suppInX computes, (A-AB)/A - (B-AB)/B; confirm which
# one the figure should show.
axd['e1'].text(0.39, 0.45, r"$\frac{A \cap B}{A \cup B}$", fontsize=20)
axd['e2'].text(0.04, 0.45, r"$\frac{A-B-AB}{A}-\frac{B-A-AB}{B}$", fontsize=20)
axd['e3'].text(0.29, 0.45, r"$\frac{AB-A-B}{AB}$", fontsize=20)

# Heatmaps: column `seq` of middle_mosaic gets the scores of qdf_col[seq].
for seq, x in enumerate(qdf_col):
	df_type, title = x[0], x[1]
	plot_q_subplots(df_type, [x[seq] for x in middle_mosaic])
# Only the bottom heatmap row keeps x labels, only the left column y labels.
for axy in ['u1', 'u2', 'd1', 'd2']:
	axe[axy].set_xticklabels([])
	axe[axy].xaxis.set_visible(False)
for axy in ['d1', 'd2', 'd3']:
	axe[axy].set_yticklabels([])
	axe[axy].yaxis.set_visible(False)

# Region colours for the three schematic diagrams. Region ids are
# matplotlib_venn subset codes (membership bits for A, B, AB), e.g. '110' =
# in A and B but not AB. Each diagram highlights the regions used by the
# score it illustrates.
v1col = [['100', 'white'], ['110', 'limegreen'],
		 ['101', 'white'], ['111', 'limegreen'],
		 ['010', 'white'], ['011', 'white'],
		 ['001', 'white']]
v2col = [['100', 'red'], ['110', 'white'],
		 ['101', 'white'], ['111', 'white'],
		 ['010', 'cornflowerblue'], ['011', 'white'],
		 ['001', 'white']]
v3col = [['100', 'white'], ['110', 'white'],
		 ['101', 'white'], ['111', 'white'],
		 ['010', 'white'], ['011', 'white'],
		 ['001', 'darkorchid']]
dum_venn(a_b=8,c_a=4, c_b = 4,c_ab=3,
		 ax=axc['v1'], col=v1col, title='Similarity')
dum_venn(a_b=8, c_a=10, c_b=6, c_ab=3,
		 ax=axc['v2'], col=v2col, title='Suppression')
dum_venn(a_b=8, c_a=4, c_b=4, c_ab=3,
		 ax=axc['v3'], col=v3col, title='Novel interaction')

# Score-vs-score regressions (bottom row), using |suppression|.
reg_stats = []
for i, x in enumerate(qdf_col):
	df_type, title = x[0], x[1]
	# Explicit copy: `df_type[:]` may share data with q_U / q_D and trigger
	# SettingWithCopyWarning; a copy keeps those tables signed.
	df_abs = df_type.copy()
	df_abs['suppression'] = df_abs['suppression'].abs()
	reg_stats.append(plot_reg(df_abs, axf[bottom_mosaic[0][i]]))

plt.savefig(dir_path+'figures/Fig5A_E.png', dpi=600)

In [None]:
# Fig 2H refer to l2_en_jaccard_hm.png
# =============================================================================
# 
#  Summary: What is the dominant effect of each stress?
# 
# =============================================================================
def sum_df(df):
	"""
	For every single stress, list the similarity, suppression and
	novel-interaction scores of each combined stress that contains it.

	Parameters
	----------
	df : dataframe
		Per-stress summary with a ``gene`` column, indexed by stress code.

	Returns
	-------
	df_sum : dataframe
		Indexed by "<single><partner>" (e.g. "DC" for D within "CD"), with
		columns stress, similarity, suppression, novel.
	"""
	records = {}
	for ss in single:
		combos = [c for c in cross if ss in c]
		partners = [letter for c in combos for letter in c if ss not in letter]
		triples = [(df.loc[c[0]].gene, df.loc[c[1]].gene, df.loc[c].gene) for c in combos]
		sim = [ji_cal(a, b) for a, b, _ab in triples]
		nov = [novel(a, b, ab) for a, b, ab in triples]
		sup = [suppInX(df.loc[ss].gene, df.loc[p].gene, df.loc[combos[k]].gene)
			   for k, p in enumerate(partners)]
		for k, p in enumerate(partners):
			records[ss + p] = [ss, sim[k], sup[k], nov[k]]

	return pd.DataFrame.from_dict(records, orient='index', columns=['stress', 'similarity', 'suppression', 'novel'])

# Score names and per-direction summaries (upregulated, downregulated);
# consumed by the commented-out violin plot below.
cond=['similarity', 'suppression', 'novel']
df_list = [sum_df(df_U), sum_df(df_D)]

# figs, axs = plt.subplots(3,2,
# 						 sharex=True,
# 						 sharey='row',
# 						 constrained_layout=True,
# 						 figsize=(7,6))

# for i, x in enumerate(df_list):
# 	for j, y in enumerate(cond):
# 		sns.violinplot(x='stress', y= y, data=x, ax= axs[j][i])
# for k in range(3):
# 	axs[k][1].set_ylabel('')
# for l in range(2):
# 	for m in range(2):
# 		axs[l][m].set_xlabel('')
# axs[0][0].set_title('Upregulated', fontsize=14)
# axs[0][1].set_title('Downregulated', fontsize=14)
# plt.savefig(odir + 'venn_sum.png', dpi=600)

# =============================================================================
# 
# Enrichment
# 
# =============================================================================

from statsmodels.stats.multitest import multipletests
import math
import numpy as np

# Occurrences of each top-level bin name across all annotated identifiers,
# and the list of top-level bin names (values of dicto).
ori_count = Counter([y for x in list(meranno.values()) for y in x])
mapbins = list(dicto.values())

def sig_df(df, sigcol, merdict, mapbins):
	"""
	Permutation test for MapMan bin enrichment among each stress's DEGs,
	followed by Benjamini-Hochberg correction.

	For every stress in ``all_s``, 1000 random gene sets of the DEG-set size
	are drawn (shuffle + slice) from the values of ``merdict``. A bin's raw
	p-value is (1 + number of draws hitting it at least as often as
	observed) / 1000, capped at 1.

	Parameters
	----------
	df : dataframe
		Per-stress summary indexed by stress code, with a ``gene`` column and
		the bin column named by ``sigcol``.
	sigcol : str
		Column with the bins hit by the stress's DEGs.
	merdict : dict
		Gene -> list of bins (annotation names or second-level codes).
	mapbins : list
		All bins, in output column order.

	Returns
	-------
	df_sig : dataframe
		Rows = stresses, columns = ``mapbins``; corrected p-values, None for
		bins not hit by that stress.
	"""
	n_draws = 1000
	per_stress = {}
	for stress in all_s:
		observed = Counter(df.loc[stress][sigcol])
		hit_bins = list(observed.keys())  # bins found in stress
		exceed = dict.fromkeys(hit_bins, 1)  # counts start at 1
		sample_size = len(df.loc[stress].gene)
		for _draw in range(n_draws):
			pool = list(merdict.values())
			random.shuffle(pool)
			drawn = Counter(b for bins in pool[:sample_size] for b in bins)
			for b in hit_bins:
				if drawn[b] >= observed[b]:
					exceed[b] += 1
		# Starting the count at 1 allows 1001/1000; cap at 1.
		raw_p = [min(exceed[b] / n_draws, 1.0) for b in hit_bins]
		corrected = multipletests(pvals=raw_p, alpha=0.05, method="fdr_bh")[1]
		lookup = dict(zip(hit_bins, corrected))
		per_stress[stress] = [lookup.get(b) for b in mapbins]

	return pd.DataFrame.from_dict(per_stress, orient='index', columns=mapbins)
	
def chunk(uval, dval):
	"""
	Classify a bin from its corrected p-values among upregulated (``uval``)
	and downregulated (``dval``) DEGs; NaN means the bin was not hit in that
	direction. Significance cut-off: 0.05.

	Returns
	-------
	int
		0 = not enriched, 1 = enriched among downregulated only,
		2 = enriched in both directions, 3 = enriched among upregulated only.
	"""
	# Any comparison with NaN is False, so a missing p-value never counts
	# as significant.
	up = uval < 0.05
	down = dval < 0.05
	if up and down:
		return 2
	if down:
		return 1
	if up:
		return 3
	return 0

from matplotlib.colors import ListedColormap
# One colour per enrichment category of chunk(): 0 none, 1 down, 2 up+down, 3 up.
cmap = ListedColormap(["lightgray", "royalblue", "violet", "firebrick"])
catno = 4
# Colour-bar tick at the centre of each of the catno bands when the bar
# spans 0..catno-1: (2k+1)/(2*catno) * (catno-1) for k = 0..catno-1.
cbarticks = [(x/(catno*2))*(catno-1) for x in range(1,catno*2,2)]

# =============================================================================
# 
# Enrichment (Part 2: 2nd level Mapman)
# 
# =============================================================================
# All second-level bin codes ("X.Y"), sorted numerically by X then Y.
mapbins2 = list(set([y for x in list(merbin.values()) for y in x]))
mapbins2.sort(key=lambda x: (int(x.split('.')[0]), int(x.split('.')[1])))

# Enrichment p-values per stress for up- and down-regulated DEGs (None ->
# NaN). No random seed is set in the code visible here, so the permutation
# p-values can differ between runs.
df_sig_U2 = sig_df(df_U, 'mapbin2', merbin, mapbins2)
df_sig_D2 = sig_df(df_D, 'mapbin2', merbin, mapbins2)
df_sig_U2 = df_sig_U2.fillna(value=np.nan)
df_sig_D2 = df_sig_D2.fillna(value=np.nan)

# Categorise every (bin, stress) pair with chunk(); rows = bins, columns = stresses.
cat_dict2 = {}
for mapman in list(df_sig_U2.columns):
	cat_col = []
	for stress in list(df_sig_U2.index):
		uval, dval = df_sig_U2.loc[stress, mapman], df_sig_D2.loc[stress, mapman]
		cat_col.append(chunk(uval,dval))
	cat_dict2[mapman] = cat_col

df_combined_sig2 = pd.DataFrame.from_dict(cat_dict2, orient='index', columns=list(df_sig_U2.index))
# Keep bins enriched (category > 0) in more than two stresses; the first
# filter is implied by the second.
df_combined_sig2 = df_combined_sig2.loc[df_combined_sig2.max(axis=1) > 0,:]
df_combined_sig2 = df_combined_sig2.loc[(df_combined_sig2 > 0).sum(axis=1) >2,:]
# Replace the bin codes in the index with their names from map2anno.
df_combined_sig2.reset_index(inplace=True)
df_combined_sig2['index'] = df_combined_sig2['index'].apply(lambda x: map2anno[x])
df_combined_sig2.set_index('index', inplace=True)

# =============================================================================
# 
# Plotting 2nd level mapman enrichment (df_combined_sig2)
# 
# =============================================================================
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

def jdistprep(df, axis):
	'''
	Encode each column (axis=0) or row (axis=1) of ``df`` as a list of
	"<label>_<value>" strings for Jaccard comparisons.

	Parameters
	----------
	df : dataframe
		Categorical values; labels along the other axis must be strings.
	axis : int
		0 to encode columns (default use), 1 to encode rows (df is
		transposed first).

	Returns
	-------
	list
		[dicto, dxkeys]: dicto maps each encoded column/row label to its
		list of strings; dxkeys lists those labels in order.
	'''
	frame = df.T if axis == 1 else df
	keys = frame.columns.to_list()
	labels = frame.index.to_list()
	encoded = {}
	for key in keys:
		values = frame[key].to_list()
		encoded[key] = [label + '_' + str(value) for label, value in zip(labels, values)]
	return [encoded, keys]

def jdist(df, axis=0):
	'''
	Single-linkage clustering of the columns (axis=0) or rows (axis=1) of a
	categorical dataframe by Jaccard distance, ignoring entries whose value
	is 0 (not enriched).

	Parameters
	----------
	df : df
		dataframe to be used for jdistprep / jdist calculation.
	axis : int
		axis to do sets on, 0 by column (default), 1 by row

	Returns
	-------
	linkage_matrix : ndarray
		scipy linkage matrix of the single-linkage clustering.
	jlist : list
		list of list (Jaccard distance square matrix)
	dicto : dict
		encoded "<label>_<value>" lists from jdistprep

	'''
	dicto, dxkeys = jdistprep(df, axis)

	def nonzero(entries):
		# Split on the LAST '_': labels (e.g. bin names) may contain '_',
		# in which case split('_')[1] would read part of the label.
		return set(e for e in entries if e.rsplit('_', 1)[1] != '0')

	# Build each set once instead of once per matrix cell.
	sets = {key: nonzero(dicto[key]) for key in dxkeys}
	jlist = [[1 - ji_cal(sets[key], sets[key2]) for key2 in dxkeys] for key in dxkeys]
	dists = squareform(jlist)
	linkage_matrix = linkage(dists, "single")
	return linkage_matrix, jlist, dicto

def plot_dendro(linkage_matrix, ax, orient):
	'''
	Draw an unlabelled, all-black dendrogram into a subplot.

	Parameters
	----------
	linkage_matrix : ndarray
		Linkage matrix from scipy.cluster.hierarchy.linkage.
	ax : axes
		Subplot to draw into.
	orient : str
		Orientation passed to scipy's dendrogram ('top', 'left', ...).

	Returns
	-------
	None.
	'''
	look = dict(no_labels=True, color_threshold=0, above_threshold_color='#000000')
	dendrogram(linkage_matrix, ax=ax, orientation=orient, **look)

# Cluster stresses (columns) and bins (rows) by Jaccard distance.
xmat, xlist, xdict = jdist(df_combined_sig2)
ymat, ylist , ydict = jdist(df_combined_sig2, axis=1)

# Labelled preview dendrograms, used here only to read off the leaf order.
yden = dendrogram(ymat, labels=df_combined_sig2.index.to_list(), orientation='left') #, color_threshold=0, above_threshold_color='#000000'
plt.show()
xden = dendrogram(xmat, labels=df_combined_sig2.columns.to_list(), orientation='top') #, color_threshold=0, above_threshold_color='#000000'
plt.show()

# Reorder the heatmap by leaf order; rows are reversed to match the
# top-to-bottom row order of imshow against the left-oriented dendrogram.
yorder = yden['ivl']
xorder = xden['ivl']
df_sig2_reordered = df_combined_sig2[xorder]
df_sig2_reordered = df_sig2_reordered.reindex(yorder[::-1])

# Layout: column dendrogram top-right, row dendrogram bottom-left, heatmap
# bottom-right, colour bar in the otherwise empty top-left panel.
fig, ax = plt.subplots(2,2,
					   figsize=(7.5,8.5), # (width, height)
					   constrained_layout=True,
					   gridspec_kw={'width_ratios': [1.5, 5],'height_ratios': [1, 5]})
ax0, ax1, ax2, ax3 = ax.flatten()
for i in [ax0, ax1, ax2]:
	i.axis('off')

plot_dendro(xmat, ax1, 'top')
plot_dendro(ymat, ax2, 'left')
# heatmap, tick and tick labels
# Pin the colour range to the category codes 0..catno-1 so every category
# keeps its colour (and the colour-bar ticks stay centred) even if a
# category is absent; imshow would otherwise rescale to the data min/max.
hplot = ax3.imshow(df_sig2_reordered, cmap=cmap, vmin=0, vmax=catno - 1)
ax3.yaxis.tick_right()
ax3.set_ylabel("")
ax3.set_xticks(np.arange(0, len(df_sig2_reordered.columns), 1))
ax3.set_yticks(np.arange(0, len(df_sig2_reordered), 1))

# One label colour per stress column, in clustered order (hard-coded for
# this clustering result).
xcolour = ['k'] + ['firebrick']*4 + ['gray']*6 + ['mediumseagreen']*4 + ['k']*2 + ['darkorange']*3 + ['k']*2 + ['royalblue']*3
ax3.set_xticklabels(df_sig2_reordered.columns.to_list(), rotation=90)
for i, tick_label in enumerate(ax3.get_xticklabels()):
	tick_label.set_color(xcolour[i])

# Generic second-level names that are kept with their top-level prefix.
anno_long = ['annotated', 'cellulose', 'biosynthesis', 'hemicellulose', 'pectin', 'channels', 'degradation']

def _short_bin_label(name):
	# Lower-cased second component of "<level1>.<level2>", or the whole
	# name lower-cased when that component is one of anno_long.
	leaf = name.split('.')[1]
	return name.lower() if leaf in anno_long else leaf.lower()

ax3.set_yticklabels([_short_bin_label(x) for x in df_sig2_reordered.index.to_list()])
# cbar plotting and control
axins = inset_axes(ax0,
					width="40%",  # 40% of the parent axes' width
					height="90%",  # 90% of the parent axes' height
					loc = 'center')
cbar = fig.colorbar(hplot, cax=axins, ticks = cbarticks)
cbar.ax.set_yticklabels(['N', 'D', 'UD', 'U'])

plt.savefig(dir_path+'figures/Fig5F.png', dpi=600, bbox_inches='tight')

### Supp. Fig. 10: Expression of GRN TFs across experiments (clustered)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter

# Network-gene annotation plus the TF-group and group-TF-family enrichment
# tables.
# NOTE(review): all three files are written by a LATER cell of this notebook
# (the TF-bin enrichment cell); on a fresh top-to-bottom run they must
# already exist.
nw_anno = pd.read_csv(dir_path + 'prep_files/condensed_deg_anno_nwonly.txt', header=0, sep="\t")
tf_enrich = pd.read_csv(dir_path + 'prep_files/TF_group_enrich_anno.txt', header=0, sep="\t")
bin_enrich = pd.read_csv(dir_path + 'prep_files/group_TF_enrich_anno.txt', header=0, sep="\t")

# All 25 experiments: 7 single and 18 combined stress codes.
all_stresses = ['C', 'CD', 'CL', 'CM', 'CN', 'CS', 'D', 'H', 'HD',
				   'HM', 'HN', 'HS', 'L', 'LS', 'M', 'MD', 'ML','MN', 'N',
				   'ND', 'NL', 'S', 'SD', 'SM', 'SN']

stresses = ["D", "H", "C", "L", "M", "S", "N"]
# Column order grouped by single stress: for each code in `stresses`, every
# experiment containing it (combined stresses therefore appear twice).
clustered_stresses = [y for x in stresses for y in all_stresses if x in y]

# plot tf
# TF genes only (rows with a TF family annotation). Explicit copy so the new
# column goes into an independent frame (no SettingWithCopyWarning);
# notna() replaces the unary minus on a boolean Series.
tfdf = nw_anno[nw_anno.TF_anno.notna()].copy()
#tfdf.sort_values(by=["TF_anno"], inplace=True)
tfdf["longname"] = tfdf.gene + " (" + tfdf.TF_anno + ")"

# clustering
# One row per TF ("<gene> (<family>)"), one column per experiment.
tfplot = tfdf[["longname"] + all_stresses].set_index("longname")

from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import numpy as np

def ji_cal(a, b):
	"""
	Jaccard index |a ∩ b| / |a ∪ b| of two sets.

	Two empty sets count as identical (1.0) instead of raising
	ZeroDivisionError, so a Jaccard distance matrix keeps a zero diagonal
	(required by scipy's squareform).
	"""
	union = a | b
	if not union:
		return 1.0
	return len(a & b) / len(union)

def jdistprep(df, axis):
	'''
	Represent each column (axis=0) or row (axis=1) of ``df`` as a list of
	"<label>_<value>" strings.

	Parameters
	----------
	df : dataframe
		Categorical values; labels along the other axis must be strings.
	axis : int
		0 encodes columns, 1 encodes rows (df is transposed first).

	Returns
	-------
	list
		[dicto, dxkeys] with dicto mapping each encoded label to its list of
		strings and dxkeys the labels in order.
	'''
	if axis == 1:
		df = df.T
	row_names = df.index.to_list()
	col_names = df.columns.to_list()
	as_strings = {
		name: [row_names[pos] + '_' + str(val) for pos, val in enumerate(df[name].to_list())]
		for name in col_names
	}
	return [as_strings, col_names]

def jdist(df, axis=0):
	'''
	Single-linkage clustering of the columns (axis=0) or rows (axis=1) of a
	dataframe by Jaccard distance, ignoring missing ('nan') entries.

	Parameters
	----------
	df : df
		dataframe to be used for jdistprep / jdist calculation.
	axis : int
		axis to do sets on, 0 by column (default), 1 by row

	Returns
	-------
	linkage_matrix : ndarray
		scipy linkage matrix of the single-linkage clustering.
	jlist : list
		list of list (Jaccard distance square matrix)
	dicto : dict
		encoded "<label>_<value>" lists from jdistprep

	'''
	dicto, dxkeys = jdistprep(df, axis)

	def present(entries):
		# Split on the LAST '_' so labels containing '_' are handled.
		return set(e for e in entries if e.rsplit('_', 1)[1] != 'nan')

	# Build each set once instead of once per matrix cell.
	sets = {key: present(dicto[key]) for key in dxkeys}
	jlist = [[1 - ji_cal(sets[key], sets[key2]) for key2 in dxkeys] for key in dxkeys]
	dists = squareform(jlist)
	linkage_matrix = linkage(dists, "single")
	return linkage_matrix, jlist, dicto

def plot_dendro(linkage_matrix, ax, orient):
	'''
	Render an unlabelled black dendrogram into ``ax``.

	Parameters
	----------
	linkage_matrix : ndarray
		Output of scipy.cluster.hierarchy.linkage.
	ax : axes
		Subplot to draw into.
	orient : str
		Dendrogram orientation ('top', 'left', ...).

	Returns
	-------
	None.
	'''
	dendrogram(
		linkage_matrix,
		orientation=orient,
		ax=ax,
		no_labels=True,
		color_threshold=0,
		above_threshold_color='#000000',
	)

#%% TFs with stress specific regulation
# Keep TFs whose specificity score exceeds 0.7.
spec_path = elnet_dir + 'TF_Scond_spec.txt'
spec_df = pd.read_csv(spec_path, header=0, sep="\t")
spec_df["longname"] = spec_df.gene + " (" + spec_df.TF_anno + ")"
spec_TF = spec_df[spec_df.specificity > 0.7].longname.unique().tolist()
# NOTE(review): .loc raises KeyError if a specific TF is missing from tfplot.
tfplot_subset = tfplot.loc[spec_TF]

# Cluster the TFs (rows) by Jaccard distance of their regulation patterns.
ymat, ylist, ydict = jdist(tfplot_subset, axis=1)
yden = dendrogram(ymat, labels=tfplot_subset.index.to_list(), orientation='left')
plt.show()

# Leaf order, reversed to match imshow's top-to-bottom row order.
yorder = yden['ivl']
tfplot_reordered = tfplot_subset.reindex(yorder[::-1])

# duplicate stress columns
# Encode regulation (0 = not DE, 1 = DOWN, 2 = UP) and order the columns by
# clustered_stresses, in which combined stresses appear twice.
tfplot_reordered.replace({"DOWN": 1, "UP": 2}, inplace=True)
tfplot_reordered.fillna(0, inplace=True)
tfplot_all = tfplot_reordered[clustered_stresses]


#%%
# plotting
# Colour scheme for this heatmap's three categories (0 = not DE, 1 = DOWN,
# 2 = UP). Defined here instead of inheriting the 4-colour map and 4 tick
# positions from the MapMan enrichment cell: with those, DOWN cells were
# drawn violet and the three colour-bar labels did not match the four ticks.
from matplotlib.colors import ListedColormap
cmap = ListedColormap(["lightgray", "royalblue", "firebrick"])
catno = 3
cbarticks = [(x/(catno*2))*(catno-1) for x in range(1,catno*2,2)]

fig, ax = plt.subplots(2,2,
					   figsize=(10,13), # (width, height)
					   constrained_layout=True,
					   gridspec_kw={'width_ratios': [1, 4],'height_ratios': [1, 7]})
ax0, ax1, ax2, ax3 = ax.flatten()
# ax2 holds the row dendrogram, ax0 the colour bar, ax3 the heatmap.
for i in [ax0, ax1, ax2]:
	i.axis('off')

plot_dendro(ymat, ax2, 'left')
# heatmap, tick and tick labels
# vmin/vmax pin category k to colour k even if a category is absent.
hplot = ax3.imshow(tfplot_all, cmap=cmap, vmin=0, vmax=catno - 1)
# https://stackoverflow.com/questions/10354397/python-matplotlib-y-axis-ticks-on-right-side-of-plot
ax3.yaxis.tick_right()
ax3.set_ylabel("")
ax3.set_xticks(np.arange(0, len(tfplot_all.columns), 1))
ax3.set_yticks(np.arange(0, len(tfplot_all), 1))

ax3.set_xticklabels(tfplot_all.columns.to_list(), rotation=90)
ax3.set_yticklabels(tfplot_all.index.to_list())

# xlabel colour
# NOTE(review): in clustered_stresses order (D, H, C, L, M, S, N) the column
# groups have sizes 6, 5, 6, 5, 7, 7, 7, but the colour runs below have
# sizes 6, 7, 7, 5, 6, 5, 7, so from the second group on the colours do not
# line up with the stress groups. Confirm the intended order/colours.
xcolour = ['royalblue']*6 + ["darkkhaki"]*7 + ["mediumorchid"]*7 + ["firebrick"]*5 + ["k"]*6 + ["mediumseagreen"]*5 + ["darkorange"]*7
for i, tick_label in enumerate(ax3.get_xticklabels()):
	tick_label.set_color(xcolour[i])
# ylabel colour
# TF families with more than three members get their own palette colour;
# all other families share palette[-1].
TF_fam_count = Counter(tfdf.TF_anno.tolist())
TF_fam_big = [k for k, v in TF_fam_count.items() if v > 3]
dict_col = {k : palette[i] for i, k in enumerate(TF_fam_big) if TF_fam_count[k] > 3}
ycolour = [dict_col[x.split("(")[1].split(")")[0]] if x.split("(")[1].split(")")[0] in dict_col else palette[-1] for  x in tfplot_all.index.to_list()]

for i, tick_label in enumerate(ax3.get_yticklabels()):
	tick_label.set_color(ycolour[i])

# cbar plotting and control
# https://matplotlib.org/stable/gallery/axes_grid1/demo_colorbar_with_inset_locator.html
axins = inset_axes(ax0,
					width="40%",  # 40% of the parent axes' width
					height="90%",  # 90% of the parent axes' height
					loc = 'center')
cbar = fig.colorbar(hplot, cax=axins, ticks = cbarticks)
cbar.ax.set_yticklabels(["N", 'D', 'U'])
plt.savefig(dir_path + 'figures/' + 'FigS10.png', dpi=600)

### Figure 3B: Expression of GRN TFs across experiments

In [None]:
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 27 12:15:31 2022

@author: Qiao Wen
Essentially, calculate if TF is significantly enriched for a bin
"""

import pandas as pd
from ast import literal_eval
from collections import Counter
import random

anno_path = dir_path + 'prep_files/condensed_deg_anno.txt'
anno_df = pd.read_csv(anno_path, header=0, sep="\t")
anno_df.merbin = anno_df.merbin.apply(literal_eval)

# Every gene of the GRN, as regulator ("Gene.ID") or predicted gene.
nw_path = elnet_dir + 'union_0.8_ignoreAGRIS_topTF.txt'
nw_df = pd.read_csv(nw_path, header=0, sep="\t")
nw_genes = list(set(nw_df.predicted.to_list() + nw_df["Gene.ID"].to_list()))

# Explicit copy: a column is added below (avoids SettingWithCopyWarning).
nw_anno = anno_df[anno_df.gene.isin(nw_genes)].copy()
# Frequency of each bin combination among non-TF network genes.
bin_counts = Counter(nw_anno[nw_anno.TF_anno.isna()].merbinname.to_list())
# Combinations seen more than twice form their own group ("A and B"); the
# rest are pooled.
nw_anno["group"] = [" and ".join(literal_eval(x)) if bin_counts[x] > 2 else "Other bin combinations" for x in nw_anno.merbinname.to_list()]
nw_anno.to_csv(dir_path + 'prep_files/condensed_deg_anno_nwonly.txt', index=False, sep="\t")

# notna()/~ instead of unary minus on a boolean Series.
is_tf = nw_anno.TF_anno.notna()
TF_list = nw_anno[is_tf].gene.to_list()
gene_groups = nw_anno[~is_tf].group.unique().tolist()
all_groups = nw_anno[~is_tf].group.to_list()

# enrichment TF-bin
# For each TF: permutation test (1000 draws without replacement from the
# groups of all non-TF network genes) of whether its predicted genes fall
# into each group at least as often as observed. pval = fraction of draws
# that reach the observed count.
tf_rows = []
for tf in TF_list:
	subset = nw_df[nw_df["Gene.ID"] == tf]
	annot = pd.merge(subset, nw_anno[["gene", "group"]], left_on = ['predicted'], right_on = ['gene'])
	actual_group = Counter(annot.group.to_list())
	dict_count = {k:0 for k in actual_group.keys()}
	for i in range(1000):
		random_group = Counter(random.sample(all_groups, len(annot)))
		for g in actual_group.keys():
			if random_group[g] >= actual_group[g]:
				dict_count[g] += 1
	for k, v in dict_count.items():
		tf_rows.append([tf, k, v/1000])
# Build the table once; appending rows one at a time with .loc copies the
# whole frame on every append (quadratic).
tf_enrich = pd.DataFrame(tf_rows, columns=["TF", "group", "pval"])

# Network edges annotated with the group of the predicted gene.
merged_df = pd.merge(nw_df, nw_anno[["gene", "group"]], left_on = ['predicted'], right_on = ['gene'])
# NOTE(review): count_df is not used anywhere in the visible code.
count_df = merged_df.groupby(["Gene.ID", "group", "TF_stat"]).predicted.count()

def get_details(tf, group):
	"""
	Summarise the edges from ``tf`` to genes of ``group`` in merged_df.

	Returns
	-------
	list
		[number of edges, (activators - repressors) / edges,
		 class of that ratio ('VP' > 0.6, 'P' > 0.2, 'A' > -0.2,
		 'N' > -0.6, else 'VN'), edge name "<tf> (meta) <group>"]
	"""
	mask = (merged_df["Gene.ID"] == tf) & (merged_df["group"] == group)
	edge_types = Counter(merged_df[mask].TF_stat.to_list())
	n_edges = sum(edge_types.values())
	balance = (edge_types["Activator"] - edge_types["Repressor"])/n_edges
	for cutoff, label in ((0.6, "VP"), (0.2, "P"), (-0.2, "A"), (-0.6, "N")):
		if balance > cutoff:
			break
	else:
		label = "VN"
	return [n_edges, balance, label, tf + " (meta) " + group]

# Expand get_details() into separate columns (size, ratio, ratio class,
# edge name), then write the TF-group enrichment table.
tf_enrich["details"] = tf_enrich.apply(lambda x: get_details(x.TF, x.group), axis=1)
tf_enrich["int_size"] = tf_enrich.apply(lambda x: x.details[0], axis=1)
tf_enrich["ratio"] = tf_enrich.apply(lambda x: x.details[1], axis=1)
tf_enrich["ratio_stat"] = tf_enrich.apply(lambda x: x.details[2], axis=1)
tf_enrich["edge_name"] = tf_enrich.apply(lambda x: x.details[3], axis=1)

tf_enrich = tf_enrich[['edge_name', 'TF', 'group', 'pval', 'int_size', 'ratio', 'ratio_stat']]
tf_enrich.to_csv(dir_path + 'prep_files/TF_group_enrich_anno.txt', index=False, sep="\t")

# enrichment bin-TF family
# For each gene group: permutation test of whether edges into it come from
# each TF family at least as often as observed. Edges whose predicted gene
# is itself a TF are excluded.

# TF family of each regulator = first annotation row for that gene. A dict
# lookup replaces filtering nw_anno once per edge (O(n*m) -> O(n+m)).
fam_of = {}
for gene, fam in zip(nw_anno.gene, nw_anno.TF_anno):
	fam_of.setdefault(gene, fam)
merged_df["TF_fam"] = [fam_of[x] for x in merged_df["Gene.ID"].to_list()]
noTFs_as_predicted = merged_df[~merged_df.predicted.isin(TF_list)]
all_fams = noTFs_as_predicted.TF_fam.to_list()
bin_rows = []
for group in gene_groups:
	subset = noTFs_as_predicted[noTFs_as_predicted.group == group]
	actual_fam = Counter(subset.TF_fam.to_list())
	dict_fam = {k:0 for k in actual_fam.keys()}
	for i in range(1000):
		random_fam = Counter(random.sample(all_fams, len(subset)))
		for g in actual_fam.keys():
			if random_fam[g] >= actual_fam[g]:
				dict_fam[g] += 1
	for k, v in dict_fam.items():
		bin_rows.append([k, group, v/1000])
# Build the table once instead of growing it row by row with .loc.
bin_enrich = pd.DataFrame(bin_rows, columns=["TF_fam", "group", "pval"])

# Number of edges (excluding TF predicted genes) into each gene group.
group_size_dict = {grp: len(noTFs_as_predicted[noTFs_as_predicted["group"] == grp]) for grp in gene_groups}

def get_fam_details(fam, group):
	"""
	Summarise the edges from TF family ``fam`` into gene group ``group``.

	Returns
	-------
	list
		[number of edges, edges / all edges into the group,
		 (activators - repressors) / edges,
		 class of that ratio ('VP' > 0.6, 'P' > 0.2, 'A' > -0.2,
		 'N' > -0.6, else 'VN')]
	"""
	mask = (noTFs_as_predicted["TF_fam"] == fam) & (noTFs_as_predicted["group"] == group)
	edge_types = Counter(noTFs_as_predicted[mask].TF_stat.to_list())
	n_edges = sum(edge_types.values())
	share = n_edges/group_size_dict[group]
	balance = (edge_types["Activator"] - edge_types["Repressor"])/n_edges
	for cutoff, label in ((0.6, "VP"), (0.2, "P"), (-0.2, "A"), (-0.6, "N")):
		if balance > cutoff:
			break
	else:
		label = "VN"
	return [n_edges, share, balance, label]

# Expand get_fam_details() into separate columns, then write the
# group-TF-family enrichment table.
bin_enrich["details"] = bin_enrich.apply(lambda x: get_fam_details(x.TF_fam, x.group), axis=1)
bin_enrich["int_size"] = bin_enrich.apply(lambda x: x.details[0], axis=1)
bin_enrich["rel_int_size"] = bin_enrich.apply(lambda x: x.details[1], axis=1)
bin_enrich["ratio"] = bin_enrich.apply(lambda x: x.details[2], axis=1)
bin_enrich["ratio_stat"] = bin_enrich.apply(lambda x: x.details[3], axis=1)

bin_enrich = bin_enrich[['TF_fam', 'group', 'pval', 'int_size', 'rel_int_size', 'ratio', 'ratio_stat']]
bin_enrich.to_csv(dir_path + 'prep_files/group_TF_enrich_anno.txt', index=False, sep="\t")

In [None]:
# Selected TF representatives from 3C
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

nw_anno = pd.read_csv(dir_path + 'prep_files/condensed_deg_anno_nwonly.txt', header=0, sep="\t")
tf_enrich = pd.read_csv(dir_path + 'prep_files/TF_group_enrich_anno.txt', header=0, sep="\t")
bin_enrich = pd.read_csv(dir_path + 'prep_files/group_TF_enrich_anno.txt', header=0, sep="\t")

all_stresses = ['C', 'CD', 'CL', 'CM', 'CN', 'CS', 'D', 'H', 'HD',
				   'HM', 'HN', 'HS', 'L', 'LS', 'M', 'MD', 'ML','MN', 'N',
				   'ND', 'NL', 'S', 'SD', 'SM', 'SN']
stresses = ["D", "H", "C", "L", "M", "S", "N"]
# Columns grouped by single stress (combined stresses appear twice).
clustered_stresses = [y for x in stresses for y in all_stresses if x in y]

# plot tf
# TF rows only, with a "<gene> (<family>)" display name.
tfdf = nw_anno[-nw_anno.TF_anno.isna()]
tfdf["longname"] = tfdf.gene + " (" + tfdf.TF_anno + ")"

# Display order of the selected TFs (top to bottom).
# NOTE(review): "Mp6g08290.1" and "Mp3g10500.1" appear twice, so reindex()
# below duplicates those rows in the heatmap; confirm this is intended.
tf_list_ordered = ["Mp3g21490.1", "Mp4g21220.1", "Mp6g02620.1", "Mp8g18310.1",
				   "Mp4g00180.1", "Mp8g01770.1", "Mp7g00860.1", "Mp6g04650.1",
				   "Mp5g18910.1", "Mp5g12480.1", "Mp3g06860.1", "Mp1g13740.1",
				   "Mp8g14220.1", "Mp1g02860.1", "Mp8g04130.1", "Mp2g20960.1",
				   "Mp4g12490.1", "Mp1g27060.1", "Mp5g14820.1", "Mp6g08290.1",
				   "Mp3g10500.1", "Mp6g08290.1", "Mp3g10500.1", "Mp1g25720.1",
				   "Mp1g13010.1", "Mp1g19730.1", "Mp3g19670.1", "Mp5g10280.1",
				   "Mp1g04550.1", "Mp4g22280.1", "Mp1g25140.1", "Mp2g00890.1"]

tfdf_short = tfdf[tfdf.gene.isin(tf_list_ordered)]
tfdf_short.set_index("gene", inplace=True)
tfdf_short_filtered = tfdf_short.reindex(tf_list_ordered)
# Positional slice assumed to keep just the per-experiment columns; it
# depends on the column layout of the input file — TODO confirm.
tfdf_short_filtered = tfdf_short_filtered.iloc[:, 3:-2]

# duplicate stress columns
# Encode regulation (0 = not DE, 1 = DOWN, 2 = UP) and order the columns by
# clustered_stresses.
tfdf_short_filtered.replace({"DOWN": 1, "UP": 2}, inplace=True)
tfdf_short_filtered.fillna(0, inplace=True)
tfplot_all = tfdf_short_filtered[clustered_stresses]

# Plot heatmap
fig, ax = plt.subplots(figsize=(7,6))

# heatmap, tick and tick labels
from matplotlib.colors import ListedColormap
cmap = ListedColormap(["lightgray", "royalblue", "firebrick"])
catno = 3
cbarticks = [(x/(catno*2))*(catno-1) for x in range(1,catno*2,2)]

hplot = ax.imshow(tfplot_all, cmap=cmap)
ax.yaxis.tick_right()
ax.set_ylabel("")
ax.set_xticks(np.arange(0, len(tfplot_all.columns), 1))
ax.set_yticks(np.arange(0, len(tfplot_all), 1))

ax.set_xticklabels(tfplot_all.columns.to_list(), rotation=90)
ax.set_yticklabels(tfplot_all.index.to_list())

xcolour = ["k"]*6 +  ["firebrick"]*5 + ['royalblue']*6 + ["mediumseagreen"]*5 + ["mediumorchid"]*7 + ["darkkhaki"]*7 + ["darkorange"]*7
for i, tick_label in enumerate(ax.get_xticklabels()):
	tick_text = tick_label.get_text()
	tick_label.set_color(xcolour[i])
	
plt.savefig(dir_path + 'figures/' + 'Fig3B.png', dpi=600, bbox_inches="tight")

### Figure 3C: Specific expression in GRN TFs

In [None]:
import pandas as pd

nw_anno = pd.read_csv(dir_path + 'prep_files/condensed_deg_anno_nwonly.txt', header=0, sep="\t")

all_stresses = ['C', 'CD', 'CL', 'CM', 'CN', 'CS', 'D', 'H', 'HD',
				'HM', 'HN', 'HS', 'L', 'LS', 'M', 'MD', 'ML', 'MN', 'N',
				'ND', 'NL', 'S', 'SD', 'SM', 'SN']
stresses = ["C", "S", "M", "H", "D", "L", "N"]
full_stresses = ["Cold", "Salt", "Mannitol", "Heat", "Dark", "Light", "Nitrogen deficiency"]
# For each single stress: every condition (single or combined) that includes it.
clustered_stresses = [[y for y in all_stresses if x in y] for x in stresses]

# get TFs
tfdf = nw_anno[~nw_anno.TF_anno.isna()]

def get_cluster_df(idx, cluster):
	"""
	Summarise how consistently each TF responds within one stress cluster.

	Parameters
	----------
	idx : int
		Position of the cluster; selects its label from ``full_stresses``.
	cluster : list of str
		Condition columns of ``tfdf`` belonging to the cluster.

	Returns
	-------
	DataFrame with columns cluster, gene, TF_anno, specificity, top_direction.
	``specificity`` is the count of the most frequent non-missing value over
	the cluster's columns divided by the number of columns; ``top_direction``
	is that most frequent value. TFs with no value in any cluster column are
	dropped.
	"""
	# .copy(): columns are added below, so do not write into a view of tfdf
	subset = tfdf[["gene", "TF_anno"] + cluster].dropna(subset=cluster, how="all").copy()
	cluster_len = len(cluster)
	# iloc[2:] skips gene and TF_anno
	subset["specificity"] = subset.apply(lambda x: x.iloc[2:].value_counts().max()/cluster_len, axis=1)
	# iloc[2:-1] additionally skips the specificity column just added
	subset["top_direction"] = subset.apply(lambda x: x.iloc[2:-1].value_counts().idxmax(), axis=1)
	subset["cluster"] = full_stresses[idx]

	return subset[["cluster", "gene", "TF_anno", "specificity", "top_direction"]]

df_list = [get_cluster_df(i, cluster) for i, cluster in enumerate(clustered_stresses)]
cluster_df = pd.concat(df_list)

cluster_df.to_csv(elnet_dir + 'TF_Scond_spec.txt', index=False, sep="\t")

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

stresses = ["C", "S", "M", "H", "D", "L", "N"]
spec_path = elnet_dir + 'TF_Scond_spec.txt'
spec_df = pd.read_csv(spec_path, header=0, sep="\t")
spec_df["longname"] = spec_df.gene + " (" + spec_df.TF_anno + ")"
spec_df_subset = spec_df[spec_df.specificity > 0.7]
spec_TF = spec_df_subset.longname.unique().tolist()

def update_df(idx, col, val, df):
	"""Write ``val`` into ``df`` at row ``idx``, column ``col`` (in place)."""
	df.loc[idx, col] = val

# initialise TF x stress dataframes. The direction frames first receive
# "UP"/"DOWN" strings, so they are object dtype (writing strings into a
# float64 frame is deprecated in pandas).
plot_spec_df = pd.DataFrame(np.nan, index = spec_TF, columns = stresses)
plot_dir_df = pd.DataFrame(np.nan, index = spec_TF, columns = stresses, dtype=object)
plot_spec_dir_df = pd.DataFrame(np.nan, index = spec_TF, columns = stresses, dtype=object)
# update dataframes; the first letter of the cluster name ("Cold" -> "C", ...)
# is the stress code used as the column label
for longname, cluster, specificity, top_direction in zip(spec_df_subset.longname, spec_df_subset.cluster,
													   spec_df_subset.specificity, spec_df_subset.top_direction):
	update_df(df=plot_spec_df, idx=longname, col=cluster[0], val=specificity)
	update_df(df=plot_dir_df, idx=longname, col=cluster[0], val=top_direction)
	update_df(df=plot_spec_dir_df, idx=longname, col=cluster[0], val=top_direction)
# change values to suit plotting
plot_dir_df = plot_dir_df.replace({"DOWN": 1, "UP": 2}).fillna(0).astype(float)

plot_spec_dir_df = plot_spec_dir_df.replace({"DOWN": -1, "UP": 1}).astype(float)
# signed specificity: +spec for UP, -spec for DOWN, 0 where not specific
plot_dendro = (plot_spec_dir_df * plot_spec_df).fillna(0)
#%% colour palette
from collections import Counter
palette = sns.color_palette()
palette.append((166/255,163/255,162/255))  # grey; last entry, used for all other families
TF_list =  plot_spec_df.index.to_list()
TF_fam_count = Counter([x.split("(")[1].split(")")[0] for x in TF_list])
TF_fam_big = [k for k, v in TF_fam_count.items() if v > 3]
# one palette colour per family with more than 3 specific TFs
# NOTE(review): more families than palette entries would raise IndexError
dict_col = {k : palette[i] for i, k in enumerate(TF_fam_big)}

# to get status of TF (ycolour)
elnet_path = elnet_dir + 'union_0.8_ignoreAGRIS_topTF.txt'
elnet_df = pd.read_csv(elnet_path, header=0, sep="\t")

def get_TFstat_str(tf):
	"""Most frequent TF_stat value among the network rows whose Gene.ID is ``tf``."""
	stat_dict = Counter(elnet_df[elnet_df["Gene.ID"] == tf].TF_stat)
	return max(stat_dict, key=stat_dict.get)

def get_TFstat(tf):
	"""Row colour for ``tf``: limegreen = Activator, yellow = Ambiguous, firebrick otherwise."""
	return {"Activator": "limegreen", "Ambiguous": "yellow"}.get(get_TFstat_str(tf), "firebrick")

# Count number of Activators, Ambiguous and Repressors in all 95 Mpo TFs.
tf_list = elnet_df["Gene.ID"].unique().tolist()

tf_stat_str_list = [get_TFstat_str(x) for x in tf_list]
stat_count = Counter(tf_stat_str_list) # Counter({'Activator': 75, 'Repressor': 19, 'Ambiguous': 1})

# TF_list holds "gene (family)" labels; colour by the gene part
ycolour = [get_TFstat(tf) for tf in [x.split(" (")[0] for x in TF_list]]

#%% Plotting
g = sns.clustermap(plot_dendro, cmap="coolwarm",
				   yticklabels=True, row_colors = ycolour,
				   figsize=(8, 14), cbar_pos=(0.02, 0.9, 0.05, 0.08),
				   dendrogram_ratio=(0.2,0.1))

ax = g.ax_heatmap

# Colour each row label by its TF family (the text between the parentheses);
# families without a dedicated colour fall back to the last palette entry.
fallback_colour = palette[-1]
for label in ax.get_yticklabels():
	family = label.get_text().split("(")[1].split(")")[0]
	label.set_color(dict_col.get(family, fallback_colour))
plt.savefig(dir_path + 'figures/' + 'Fig3C.png', dpi=600)

#%% Fill df with more info
# number of stresses with a non-zero signed specificity for each TF
plot_dendro["spec_count"] = (plot_dendro != 0).sum(axis=1)
plot_dendro.to_csv(elnet_dir + 'TF_Scond_specificity_07.txt', sep="\t")

### Supp. Fig. 12: Influence of TFs

In [None]:
"""
Extent of influence for each transcription factor
"""

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

tf_nw_p = dir_path + 'prep_files/network_cf22_spec.txt'
top_nw_p = elnet_dir + 'union_0.8_ignoreAGRIS_topTF.txt'

tf_nw = pd.read_csv(tf_nw_p, header=0, sep="\t")
top_nw = pd.read_csv(top_nw_p, header=0, sep="\t")

tf_list = top_nw["Gene.ID"].unique().tolist()


def _linked(network, match_col, key, out_col, tfs_only=False):
	"""``out_col`` values of the ``network`` rows whose ``match_col`` equals
	``key`` (file order, duplicates kept); optionally only members of tf_list."""
	linked = network.loc[network[match_col] == key, out_col].to_list()
	if tfs_only:
		linked = [g for g in linked if g in tf_list]
	return linked


# direct targets of each TF in the top-TF network, and the TFs among them
tf_direct_genes = [_linked(top_nw, "Gene.ID", tf, "predicted") for tf in tf_list]
tf_direct_gcount = list(map(len, tf_direct_genes))
tf_direct_tf = [_linked(top_nw, "Gene.ID", tf, "predicted", tfs_only=True) for tf in tf_list]
tf_direct_tfcount = list(map(len, tf_direct_tf))
direct_genedf = pd.DataFrame({"gene_list": tf_direct_genes,
							  "gene_count": tf_direct_gcount,
							  "tf_list": tf_direct_tf,
							  "tf_count": tf_direct_tfcount},
							 index=tf_list)

# outgoing and incoming TF-to-TF edges in the TF network
tf_direct_tf = [_linked(tf_nw, "Gene.ID", tf, "predicted", tfs_only=True) for tf in tf_list]
tf_direct_tfcount = list(map(len, tf_direct_tf))
tf_incoming_tf = [_linked(tf_nw, "predicted", tf, "Gene.ID", tfs_only=True) for tf in tf_list]
tf_incoming_tfcount = list(map(len, tf_incoming_tf))
direct_tfdf = pd.DataFrame({"tf_list": tf_direct_tf,
							"tf_count": tf_direct_tfcount,
							"incoming_tf_list": tf_incoming_tf,
							"incoming_tf_count": tf_incoming_tfcount},
						   index=tf_list)

# histogram of outgoing TF-edge counts
tf_count = direct_tfdf.tf_count.value_counts().sort_index()
plt.bar(tf_count.index.to_list(), tf_count.to_list())
plt.title("Number of TFs in first neighbourhood")
plt.xlabel("Count")
plt.ylabel("Frequency")
plt.show()


def _downstream_genes(tf):
	"""Direct targets of ``tf``, the TFs it regulates (TF network) and those
	TFs' direct targets, de-duplicated."""
	reach = set(direct_genedf.loc[tf, "gene_list"])
	child_tfs = direct_tfdf.loc[tf, "tf_list"]
	reach.update(child_tfs)
	for child in child_tfs:
		reach.update(direct_genedf.loc[child, "gene_list"])
	return list(reach)


# get number of outgoing edges (genes and TFs)
all_direct_genes = [_downstream_genes(tf) for tf in tf_list]
len_all_direct_genes = list(map(len, all_direct_genes))

direct_tfdf["direct_genes"] = all_direct_genes
direct_tfdf["gene_count"] = len_all_direct_genes

direct_tfdf.to_csv(dir_path + 'prep_files/influence_nw_anno.txt', sep="\t")

#%% get stress specificity source (for labelling)
desc_path = elnet_dir + 'TF_Scond_specificity_07.txt'
desc_file = pd.read_csv(desc_path, header=0, sep="\t")
# the first (unnamed) column holds the "gene (family)" row labels
desc_file.columns = ["desc", *desc_file.columns.to_list()[1:]]
desc_file["gene"] = [label.split(" ")[0] for label in desc_file["desc"]]

# the seven per-stress signed-specificity columns
stress_index = list(desc_file.columns[1:8])

def stress_source(df_row):
	"""Comma-separated names of the stress columns in which ``df_row`` is non-zero."""
	return ", ".join(name for name, value in zip(stress_index, df_row) if value != 0)

desc_file["stress_source"] = desc_file.iloc[:, 1:8].apply(stress_source, axis=1)

#%% get stress score: estimated impact (and similarity?)
valid_genes = desc_file.gene.to_list()
direct_tfdf["stress_source"] = [desc_file[desc_file.gene == tf].stress_source.values[0] if tf in valid_genes else "X" for tf in direct_tfdf.index.to_list()]
stresses = ["C", "S", "M", "H", "D", "L", "N"]


def _tfs_for(stress):
	"""Rows of direct_tfdf whose stress_source label contains ``stress``."""
	mask = [stress in src for src in direct_tfdf.stress_source.to_list()]
	return direct_tfdf[mask]


stress_tf_frames = [_tfs_for(s) for s in stresses]
# plain TF-ID lists (like the tf_nbh / gene_nbh columns) instead of nesting
# whole DataFrames inside the summary table
stress_tf_list = [frame.index.to_list() for frame in stress_tf_frames]
stress_tf_count = [len(x) for x in stress_tf_list]
# stress-specific TFs plus the TFs they regulate
stress_tfnbh_list = [list(set(frame.index.to_list()) | {t for tfs in frame.tf_list.to_list() for t in tfs})
					 for frame in stress_tf_frames]
stress_tfnbh_count = [len(x) for x in stress_tfnbh_list]
# [40, 9, 16, 19, 34, 59, 56]
# direct targets of those TFs plus the TFs themselves. The TF IDs are added as
# whole strings; the previous version iterated over them and so added their
# single characters to the neighbourhood.
stress_nbh_list = [list({g for genes in direct_genedf.loc[tfs].gene_list.to_list() for g in genes} | set(tfs))
				   for tfs in stress_tfnbh_list]
stress_nbh_count = [len(x) for x in stress_nbh_list]

stress_stats = pd.DataFrame({"stress" : stresses,
							 "tf" : stress_tf_list,
							 "tf_count": stress_tf_count,
							 "tf_nbh": stress_tfnbh_list,
							 "tf_nbh_count": stress_tfnbh_count,
							 "gene_nbh": stress_nbh_list,
							 "gene_nbh_count": stress_nbh_count
							 })

# the numeric columns (tf_count, tf_nbh_count, gene_nbh_count) are plotted
stress_stats.plot.bar("stress", logy=True)
plt.legend(["TFs", "TFs in neighbourhood", "Genes in neighbourhood"], bbox_to_anchor=(1.0, 1.02))
plt.ylabel("log(Count)")
plt.xticks([i for i in range(len(stresses))], ["Cold", "Salt", "Mannitol", "Heat", "Dark", "Light", "Nitrogen\ndeficiency"], rotation="horizontal")
plt.savefig(dir_path + 'figures/' + 'FigS12.png', dpi=600, bbox_inches="tight")

### Supp. Fig 13: Robustly responding second-level MapMan bins across the 7 abiotic stresses

In [None]:
from ast import literal_eval
import random
from collections import defaultdict, Counter
from statsmodels.stats.multitest import multipletests
import math
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Top-level bin number -> bin name (file content is a Python dict literal).
with open(dir_path + 'mercator/merdict.txt', "r") as merdict_file:
	dicto = literal_eval(merdict_file.read())

meranno = defaultdict(list)   # gene -> top-level bin names
merbin = defaultdict(list)    # gene -> 2nd-level bin codes ("x.y")
map2anno = {}                 # 2nd-level bin code -> bin name

mfile = dir_path + 'mercator/MpoProt.results.txt'
# `with` closes the results file (it was previously left open)
with open(mfile, 'r') as merfile:
	merfile.readline()  # skip header
	for line in merfile:
		linecon = line.rstrip().replace("'", "").split("\t")
		if len(linecon) == 5:
			bincode, name, identifier, desc, ptype = linecon
			meranno[identifier].append(dicto[int(bincode.split('.')[0])])
			merbin[identifier].append('.'.join(bincode.split('.')[:2]))
		if len(linecon[0].split('.')) == 2:
			map2anno[linecon[0]] = linecon[1]

#%% Initialise
wdir = mpo_path
data = pd.read_csv(wdir + 'resSig_compiled.txt', sep = '\t')

# stress codes are the part after the underscore ("Mpo_HD" -> "HD");
# single stresses have one letter, combined stresses two
all_s = [label.split("_")[1] for label in data.stress.unique()]
single = [code for code in all_s if len(code) == 1]
cross = [code for code in all_s if len(code) == 2]

# annotation cells are Python literals; parse them, then map each entry's
# top-level bin number to its name through dicto
data["annotation"] = data["annotation"].apply(literal_eval)
data["mername"] = data["annotation"].apply(lambda anns: [dicto[int(a[0].split('.')[0])] for a in anns])

dict_A = defaultdict(list)
dict_U = defaultdict(list)
dict_D = defaultdict(list)

def sum_to_dict(dicto, stress, reg):
	"""
	Append three summaries of the DEGs of ``stress`` to ``dicto[stress]``:
	the set of gene IDs, the flat list of top-level MapMan names and the flat
	list of 2nd-level bin codes. ``reg`` is "ALL" (no filter) or a value of
	L2FC_D2 used to restrict the DEGs.
	"""
	selected = data.stress == "Mpo_" + stress
	if reg != "ALL":
		selected = selected & (data.L2FC_D2 == reg)
	rows = data[selected]
	level1_names = [name for names in rows.mername.to_list() for name in names]
	level2_bins = ['.'.join(a[0].split('.')[:2]) for anns in rows.annotation.to_list() for a in anns]
	dicto[stress].extend([set(rows.gene.to_list()), level1_names, level2_bins])

def dict_to_df(dicto):
	"""One row per stress, with columns gene, mername and mapbin2."""
	return pd.DataFrame.from_dict(dicto, orient='index', columns=["gene", "mername", "mapbin2"])

for s in all_s:
	for target, reg in ((dict_A, "ALL"), (dict_U, "UP"), (dict_D, "DOWN")):
		sum_to_dict(target, s, reg)

df_U = dict_to_df(dict_U)
df_D = dict_to_df(dict_D)

#%% Definitions for enrichment
def sig_df(df, sigcol, merdict, mapbins):
	"""
	Calculates and corrects MapMan bin enrichment p-values for all stresses
	with a permutation test, and returns them as a dataframe.

	For each stress in the global ``all_s``, the bins in ``df.loc[s][sigcol]``
	are counted. 1000 times, the values of ``merdict`` are shuffled and the
	first len(df.loc[s].gene) of them (per-gene bin lists) are pooled; each
	observed bin scores a hit when its random count is >= its observed count.
	The p-value is (1 + hits) / 1000, capped at 1, and p-values are
	Benjamini-Hochberg corrected within each stress.

	NOTE(review): the +1 pseudo-count is divided by 1000 rather than 1001,
	hence the cap; the conventional form is (hits + 1) / (1000 + 1).
	NOTE(review): multipletests divides by the number of tests, so a stress
	with no bins in ``df`` would fail here - confirm every stress has DEGs.

	Parameters
	----------
	df : dataframe
		df containing genes and corresponding mapman bins of DEGs
		(one row per stress).
	sigcol : str
		column name to use for enrichment
	merdict : dict
		gene -> list of mapman annotations / 2nd level bins; its values are
		the background sampled in the permutations
	mapbins : list
		list of mapman annotation/bins to use (output columns)

	Returns
	-------
	df_sig : dataframe
		corrected p-values with one row per stress (index = ``all_s``) and
		one column per entry of ``mapbins``; bins not observed for a stress
		are left empty (None).

	"""
	sig_sum = {}
	for s in all_s:
		s_count = Counter(df.loc[s][sigcol])
		valid_bins = list(s_count.keys()) # bins found in stress
		# initialise count dictionary (starts at 1: pseudo-count)
		sig_count = {}
		for key in valid_bins:
			sig_count[key] = 1
		# random simulations (the shuffle order determines the result; keep
		# the call sequence unchanged for reproducibility under a fixed seed)
		for i in range(1000):
			shuffle = list(merdict.values())
			random.shuffle(shuffle)
			sub = shuffle[:len(df.loc[s].gene)]
			sub_count = Counter([y for x in sub for y in x])
			for mapman in valid_bins:
				if sub_count[mapman] >= s_count[mapman]:
					sig_count[mapman] += 1
		# p-value calculation
		pval_coll = []
		for mapman in valid_bins:
			pval = sig_count[mapman]/1000
			# correction for pval > 1 (only 1001/1000 is possible; rounds to 1.0)
			if pval <= 1:
				pval_coll.append(pval)
			else:
				pval_coll.append(float(round(pval)))
		
		# BH correction for multiple testing
		y = multipletests(pvals=pval_coll, alpha=0.05, method="fdr_bh")[1]
		# spread the corrected p-values over the full bin list (None = not observed)
		all_bins_corr_pval = []
		for mapman in mapbins:
			if mapman in valid_bins:
				all_bins_corr_pval.append(y[valid_bins.index(mapman)])
			else:
				all_bins_corr_pval.append(None)
		sig_sum[s] = all_bins_corr_pval

	# orient='index': stresses become rows, bins become columns
	df_sig = pd.DataFrame.from_dict(sig_sum, orient='index', columns=mapbins)
	return df_sig
	
def chunk(uval, dval):
	"""
	Classify one bin/stress cell from its corrected up- and down-regulation
	p-values (NaN = no DEGs in that direction; 0.05 threshold).

	Returns 0 = not enriched, 1 = enriched among down-regulated genes only,
	2 = enriched in both directions, 3 = enriched among up-regulated only.
	"""
	up_sig = (not math.isnan(uval)) and uval < 0.05
	down_sig = (not math.isnan(dval)) and dval < 0.05
	if up_sig and down_sig:
		return 2
	if down_sig:
		return 1
	if up_sig:
		return 3
	return 0


# =============================================================================
# 
# Enrichment (Part 2: 2nd level Mapman)
# 
# =============================================================================
# all 2nd-level bin codes, sorted numerically ("1.2" < "1.10" < "2.1")
mapbins2 = list(set([y for x in list(merbin.values()) for y in x]))
mapbins2.sort(key=lambda code: tuple(int(part) for part in code.split('.')[:2]))

df_sig_U2 = sig_df(df_U, 'mapbin2', merbin, mapbins2)
df_sig_D2 = sig_df(df_D, 'mapbin2', merbin, mapbins2)
df_sig_U2 = df_sig_U2.fillna(value=np.nan)
df_sig_D2 = df_sig_D2.fillna(value=np.nan)

# bin -> per-stress category from chunk() (0 none, 1 down, 2 both, 3 up)
stress_rows = list(df_sig_U2.index)
cat_dict2 = {
	mapman: [chunk(df_sig_U2.loc[st, mapman], df_sig_D2.loc[st, mapman]) for st in stress_rows]
	for mapman in df_sig_U2.columns
}

df_combined_sig2 = pd.DataFrame.from_dict(cat_dict2, orient='index', columns=stress_rows)
# keep bins enriched (category > 0) in more than two conditions
df_combined_sig2 = df_combined_sig2.loc[(df_combined_sig2 > 0).sum(axis=1) > 2, :]

# replace bin codes by bin names; the index is named "index" as before
df_combined_sig2.index = [map2anno[code] for code in df_combined_sig2.index]
df_combined_sig2.index.name = 'index'
df_combined_sig2 = df_combined_sig2.replace({0: np.nan, 1: "DOWN", 2: "UP_DOWN", 3: "UP"})

df_combined_sig2.to_csv(mpo_path + 'map2_collapsed_df.txt', sep="\t")

#%% collapse to cluster
all_stresses = ['C', 'CD', 'CL', 'CM', 'CN', 'CS', 'D', 'H', 'HD',
				'HM', 'HN', 'HS', 'L', 'LS', 'M', 'MD', 'ML', 'MN', 'N',
				'ND', 'NL', 'S', 'SD', 'SM', 'SN']
stresses = ["C", "S", "M", "H", "D", "L", "N"]
full_stresses = ["Cold", "Salt", "Mannitol", "Heat", "Dark", "Light", "Nitrogen deficiency"]
# for each single stress, every condition (single or combined) that includes it
clustered_stresses = [[y for y in all_stresses if x in y] for x in stresses]

def get_cluster_df(idx, cluster, df):
	"""
	Summarise how consistently each bin (row of ``df``) responds within one
	stress cluster.

	Returns columns index (bin name), cluster (``full_stresses[idx]``),
	specificity (count of the most frequent non-missing category over the
	cluster's columns / number of columns) and top_direction (that category).
	Bins with no value in any cluster column are dropped.
	"""
	# .copy(): columns are added below, so do not write into a view of df
	subset = df[cluster].dropna(subset=cluster, how="all").copy()
	cluster_len = len(cluster)
	# https://www.codegrepper.com/code-examples/python/find+max+value+index+in+value+count+pandas
	subset["specificity"] = subset.apply(lambda x: x.value_counts().max()/cluster_len, axis=1)
	# iloc[:-1] excludes the specificity column just added
	subset["top_direction"] = subset.apply(lambda x: x.iloc[:-1].value_counts().idxmax(), axis=1)
	subset["cluster"] = full_stresses[idx]
	subset = subset.reset_index()
	return subset[["index", "cluster", "specificity", "top_direction"]]

df_list = [get_cluster_df(i, cluster, df_combined_sig2) for i, cluster in enumerate(clustered_stresses)]
cluster_df = pd.concat(df_list)
cluster_df.to_csv(mpo_path + "mapman_spec_df.txt", sep="\t", index=None)
#%% preparation for plotting
def update_df(idx, col, val, df):
	"""Write ``val`` into ``df`` at row ``idx``, column ``col`` (in place)."""
	df.loc[idx, col] = val

#%% For filtered heatmap; specificity > 0.7

cluster_df = cluster_df[cluster_df.specificity > 0.7]
cluster_df.to_csv(mpo_path + "mapman_spec_df_07.txt", sep="\t", index=None)
# row labels: the second dot-separated component of each bin name
mapbin2_list = [x.split(".")[1] for x in cluster_df["index"].unique().tolist()]

# placeholder for "no robust response"; it is masked out of the heatmap so it
# stays distinguishable from a genuine 0 (an UP_DOWN bin)
NO_RESPONSE = 0.00000000001

# initialise dataframes; the direction frame first receives "UP"/"DOWN"/
# "UP_DOWN" strings, so it is object dtype (writing strings into float64 is
# deprecated in pandas)
plot_spec_df = pd.DataFrame(np.nan, index = mapbin2_list, columns = stresses)
plot_spec_dir_df = pd.DataFrame(np.nan, index = mapbin2_list, columns = stresses, dtype=object)
# update dataframes; the first letter of the cluster name is the stress code
for bin_name, cluster, specificity, direction in zip(cluster_df["index"], cluster_df["cluster"],
													  cluster_df["specificity"], cluster_df["top_direction"]):
	row_label = bin_name.split(".")[1]
	update_df(df=plot_spec_df, idx=row_label, col=cluster[0], val=specificity)
	update_df(df=plot_spec_dir_df, idx=row_label, col=cluster[0], val=direction)
# change values to suit plotting: signed specificity (-spec DOWN, 0 UP_DOWN, +spec UP)
plot_spec_dir_df = plot_spec_dir_df.replace({"DOWN": -1, "UP_DOWN": 0, "UP": 1}).astype(float)
plot_dendro = (plot_spec_dir_df * plot_spec_df).fillna(NO_RESPONSE)

# Plot
sns.clustermap(plot_dendro, cmap="coolwarm",
			   yticklabels=True, mask = plot_dendro == NO_RESPONSE, figsize=(6, 8),
			   cbar_pos=(0.02, 0.9, 0.05, 0.08), dendrogram_ratio=(0.2,0.1)
			   )
plt.savefig(dir_path + 'figures/' + 'FigS13.png', dpi=600)

### Figure 4B, C, D; Supp. Fig. 14, 15: Bipartite networks for robustly expressed TFs and biological processes
(Fig. 4C, D and Supp. Fig. 15 are processed here but visualised in Cytoscape.)

In [None]:
# Preparation for Fig 4C
import pandas as pd
from ast import literal_eval
from collections import defaultdict, Counter
import random
from statsmodels.stats.multitest import multipletests
import numpy as np

mapman_spec07_path = mpo_path + "mapman_spec_df_07.txt"
mapman_spec07_df = pd.read_csv(mapman_spec07_path, header=0, sep="\t")

spec_lvl1bin = [x.split(".")[0] for x in mapman_spec07_df["index"].to_list()]

#%% mercator stuff
with open(dir_path + 'mercator/merdict.txt', "r") as merdict_file:
	merdict = literal_eval(merdict_file.read())
merbin = defaultdict(list)   # gene -> 2nd-level bin NAMES (first two parts of the name)
map2anno = {}                # 2nd-level bin code -> bin name

# NOTE(review): the other Mercator-parsing cells read
# 'mercator/MpoProt.results.txt' (no 'results/' folder) - confirm which is correct.
mfile = dir_path + 'mercator/results/MpoProt.results.txt'
# `with` closes the file (it was previously left open)
with open(mfile, 'r') as merfile:
	merfile.readline()  # skip header
	for line in merfile:
		linecon = line.rstrip().replace("'", "").split("\t")
		if len(linecon) == 5:
			bincode, name, identifier, desc, ptype = linecon
			merbin[identifier].append('.'.join(name.split('.')[:2]))
		if len(linecon[0].split('.')) == 2:
			map2anno[linecon[0]] = linecon[1]

#%% TF-network enrichment (2nd level mapman)
nw_path = elnet_dir + "union_0.8_ignoreAGRIS_topTF.txt"
nw_df = pd.read_csv(nw_path, header=0, sep="\t")
# NOTE(review): merbin is a defaultdict, so looking up an unannotated target
# inserts an empty list into merbin, which then also enters anno_list below.
nw_df["predicted_map2"] = nw_df.apply(lambda x: merbin[x.predicted.lower()], axis=1)

TF_list = nw_df["Gene.ID"].unique().tolist()
anno_list = list(merbin.values())
sig_df = pd.DataFrame(0, index=list(map2anno.values()), columns=TF_list)

# Permutation test per TF: how often does a random gene set of the same size
# as the TF's target set hit each observed bin at least as often?
for tf in TF_list:
	nbh = nw_df[nw_df["Gene.ID"] == tf]
	s_count = Counter([y for x in nbh.predicted_map2.to_list() for y in x])
	valid_bins = list(s_count.keys())
	unknown = [b for b in valid_bins if b not in sig_df.index]
	if unknown:
		raise KeyError("MapMan bins of %s missing from map2anno: %s" % (tf, unknown))
	# counts start at 1 (pseudo-count); a local dict avoids a DataFrame .loc
	# write per increment. The random.sample calls are unchanged.
	hits = dict.fromkeys(valid_bins, 1)
	for i in range(1000):
		sub = random.sample(anno_list, len(nbh))
		sub_count = Counter([y for x in sub for y in x])
		for mapman in valid_bins:
			if sub_count[mapman] >= s_count[mapman]:
				hits[mapman] += 1
	# NaN marks bins not targeted by this TF
	sig_df[tf] = [hits.get(x, np.nan) for x in sig_df.index.to_list()]

# 1001/1000 would exceed 1; cap, then convert counts to p-values
sig_df.replace(1001, 1000, inplace=True)
sig_df /= 1000

#%% BH correction
sig_df_corr = sig_df.copy()
df_idx = sig_df_corr.index.to_list()
for tf in TF_list:
	pval_series = sig_df[tf].dropna()
	if pval_series.empty:
		# multipletests divides by the number of tests; nothing to correct
		sig_df_corr[tf] = np.nan
		continue
	corrected = multipletests(pvals=pval_series.to_list(), alpha=0.05, method="fdr_bh")[1]
	corr_of = {}
	for bin_name, value in zip(pval_series.index.tolist(), corrected):
		corr_of.setdefault(bin_name, value)  # first occurrence wins, as before
	sig_df_corr[tf] = [corr_of.get(x, np.nan) for x in df_idx]

#%% Flatten
# Long format: one row per (TF, bin) with a corrected p-value; rows are
# collected in lists instead of growing DataFrames one row at a time.
all_rows = []
sig_rows = []
for tf in TF_list:
	for mapbin, pval in zip(df_idx, sig_df_corr[tf].to_list()):
		if np.isnan(pval):
			continue
		all_rows.append([tf, mapbin, pval])
		if pval < 0.05:
			sig_rows.append([tf, mapbin, pval])

sig_bin_flat = pd.DataFrame(sig_rows, columns=["TF", "map2bin", "corr_pval"])
all_bin_flat = pd.DataFrame(all_rows, columns=["TF", "map2bin", "corr_pval"])

sig_bin_flat.to_csv(elnet_dir + 'TF_map2bin_enrich_sig.txt', sep="\t", index=False)
all_bin_flat.to_csv(elnet_dir + 'TF_map2bin_enrich_all.txt', sep="\t", index=False)

In [None]:
# Fig 4B
import pandas as pd
from ast import literal_eval
from collections import defaultdict, Counter
import matplotlib.pyplot as plt

TF_spec_path = elnet_dir + 'TF_Scond_specificity_07.txt'
TF_spec = pd.read_csv(TF_spec_path, header=0, sep="\t")

mapman_spec_path = mpo_path + 'mapman_spec_df_07.txt'
mapman_spec = pd.read_csv(mapman_spec_path, header=0, sep="\t")
# signed bin specificity: -spec DOWN, 0 UP_DOWN, +spec UP
mapman_spec["top_dir_val"] = mapman_spec.top_direction.replace({"DOWN": -1, "UP_DOWN": 0, "UP": 1})
mapman_spec["corr_specificity"] = mapman_spec.specificity * mapman_spec.top_dir_val

enrich_path = elnet_dir + 'TF_map2bin_enrich_sig.txt'
enriched_df = pd.read_csv(enrich_path, header=0, sep="\t")

stresses = ["C", "S", "M", "H", "D", "L", "N"]

#%% TF-enriched map2 (For Supp. Fig. 14 and 15)
for s in stresses:
	TF_subset = TF_spec[TF_spec[s] != 0]
	mapman_subset = mapman_spec[[x.startswith(s) for x in mapman_spec.cluster.to_list()]]

	# first column of TF_spec holds "gene (family)" labels
	spec_TF_long = TF_subset.iloc[:, 0].to_list()
	spec_TF = [x.split(" (")[0] for x in spec_TF_long]
	spec_mapman = mapman_subset["index"].to_list()

	# .copy(): columns are added below, so do not write into a view of enriched_df
	enriched_subset = enriched_df[(enriched_df.TF.isin(spec_TF)) & (enriched_df.map2bin.isin(spec_mapman))].copy()
	if len(enriched_subset) != 0:
		# first-match lookups: TF -> its specificity for s, bin -> signed specificity
		tf_spec_of = {}
		for gene, value in zip(spec_TF, TF_subset[s].to_list()):
			tf_spec_of.setdefault(gene, value)
		mapman_spec_of = {}
		for bin_name, value in zip(spec_mapman, mapman_subset.corr_specificity.to_list()):
			mapman_spec_of.setdefault(bin_name, value)
		enriched_subset["TF_spec"] = enriched_subset.TF.map(tf_spec_of)
		enriched_subset["mapman_spec"] = enriched_subset.map2bin.map(mapman_spec_of)
		enriched_subset["TF_dir"] = ["UP" if v > 0 else "DOWN" for v in enriched_subset.TF_spec]
		enriched_subset["mapman_dir"] = ["UP" if v > 0 else "DOWN" for v in enriched_subset.mapman_spec]
		enriched_subset.to_csv(elnet_dir + 'TF_map2_spec_bipartite_' + s + '.txt', sep="\t", index=False)

enriched_all = enriched_df[(enriched_df.TF.isin([x.split(" (")[0] for x in TF_spec.iloc[:,0].to_list()])) & (enriched_df.map2bin.isin(mapman_spec["index"].to_list()))]

#%% TF-map2 (ratio of map2/all genes in map2)

# mapman stuff
with open(dir_path + 'mercator/merdict.txt', "r") as merdict_file:
	merdict = literal_eval(merdict_file.read())
merbin = defaultdict(list)   # gene -> 2nd-level bin NAMES (first two parts of the name)
map2anno = {}                # 2nd-level bin code -> bin name

mfile = dir_path + 'mercator/MpoProt.results.txt'
# `with` closes the file (it was previously left open)
with open(mfile, 'r') as merfile:
	merfile.readline()  # skip header
	for line in merfile:
		linecon = line.rstrip().replace("'", "").split("\t")
		if len(linecon) == 5:
			bincode, name, identifier, desc, ptype = linecon
			merbin[identifier].append('.'.join(name.split('.')[:2]))
		if len(linecon[0].split('.')) == 2:
			map2anno[linecon[0]] = linecon[1]

# number of annotated genes in each 2nd-level bin
bin_size = Counter([y for x in list(merbin.values()) for y in x])

# nw stuff
nw_path = elnet_dir + "union_0.8_ignoreAGRIS_topTF.txt"
nw_df = pd.read_csv(nw_path, header=0, sep="\t")
nw_df["predicted_map2"] = nw_df.apply(lambda x: merbin[x.predicted.lower()], axis=1)

#%% tf_stat anno

spec_path = elnet_dir + 'TF_Scond_spec.txt'
spec_df = pd.read_csv(spec_path, header=0, sep="\t")
# .copy(): columns are added below, so do not write into a view of spec_df
tf_unique = spec_df.drop_duplicates(subset=['gene'])[["gene", "TF_anno"]].copy()

elnet_path = elnet_dir + "union_0.8_ignoreAGRIS_topTF.txt"
elnet_df = pd.read_csv(elnet_path, header=0, sep="\t")

def get_TFstat(tf):
	"""Colour for the most frequent TF_stat of ``tf``: limegreen = Activator,
	yellow = Ambiguous, firebrick otherwise."""
	stat_dict = Counter(elnet_df[elnet_df["Gene.ID"] == tf].TF_stat)
	TF_stat_max = max(stat_dict, key=stat_dict.get)
	if TF_stat_max == 'Activator':
		return("limegreen")
	elif TF_stat_max == 'Ambiguous':
		return("yellow")
	else:
		return("firebrick")
tf_unique["stat_colour"] = tf_unique.apply(lambda x: get_TFstat(x.gene), axis=1)
tf_unique["top_stat"] = tf_unique.apply(lambda x: "Activator" if x.stat_colour == "limegreen" else "Ambiguous" if x.stat_colour == "yellow" else "Repressor", axis=1)

tf_unique.to_csv(elnet_dir + 'TF_top_stat_anno.txt', sep="\t", index=False)

#%% calculate
# For each stress: link every stress-specific TF to every stress-specific
# 2nd-level bin containing at least one of its targets, recording the share
# of the bin's annotated genes that the TF targets (Mapman_ratio).
df_collate = []
for s in stresses:
	TF_subset = TF_spec[TF_spec[s] != 0][["Unnamed: 0", s]]
	mapman_subset = mapman_spec[[x.startswith(s) for x in mapman_spec.cluster.to_list()]]
	# (bin name, signed specificity) pairs, read by column name rather than position
	bin_specs = list(zip(mapman_subset["index"].to_list(), mapman_subset["corr_specificity"].to_list()))
	rows = []
	for i, tf in enumerate([x.split(" (")[0] for x in TF_subset.iloc[:,0].to_list()]):
		tfsub = nw_df[nw_df["Gene.ID"] == tf]
		tf_spec = TF_subset.iloc[i, 1]
		tf_type = tf_unique[tf_unique.gene == tf].top_stat.to_list()[0]
		bin_count = Counter([y for x in tfsub.predicted_map2.to_list() for y in x])
		bin_norm = {k: v/bin_size[k] for k, v in bin_count.items()}
		for m, map2_spec in bin_specs:
			if m in bin_norm:
				rows.append([tf, m, tf_spec, map2_spec, bin_norm[m], tf_type])
	counter = pd.DataFrame(rows, columns=["TF", "Mapman", "TF_specificity", "Mapman_specificity", "Mapman_ratio", "TF_type"])
	# bipartite graphs (tab indentation throughout; mixed spaces raised TabError)
	counter.to_csv(elnet_dir + 'TF_map2_spec_bipartite_ratio_' + s + '.txt', sep="\t", index=False)
	counter['stress'] = s
	df_collate.append(counter)
	TF_unique = counter.drop_duplicates(subset=['TF'])
	mapman_unique = counter.drop_duplicates(subset=['Mapman'])
	node = TF_unique.TF.to_list() + mapman_unique.Mapman.to_list()
	specificity = TF_unique.TF_specificity.to_list() + mapman_unique.Mapman_specificity.to_list()
	combined = ["\t".join([node[i], str(specificity[i])]) + "\n" for i in range(len(node))]
	# annotation for graphs: one "node<TAB>specificity" line per node
	with open(elnet_dir + 'TF_map2_spec_bipartite_ratio_anno' + s + '.txt', "w+") as annof:
		annof.writelines(combined)

df_all = pd.concat(df_collate)
# .copy(): df_cut05 later receives an "expected" column, so it must be an
# independent frame rather than a filtered view of df_all
df_cut05 = df_all[df_all.Mapman_ratio >= 0.05].copy()

df_all.Mapman_ratio.plot.hist()
plt.xlabel("Number of genes regulated by TFx in binY/ Total number of genes in binY")
plt.savefig(dir_path + 'figures/' + 'FigS14.png', dpi=600)

def expected_outcome(tf_spec, mapman_spec, tf_type):
	"""
	Judge whether a TF -> bin link agrees with the TF's annotated role.

	Same response direction (both > 0 or both < 0) is expected for an
	"Activator", opposite direction for a "Repressor".
	Returns "Y" (expected), "N" (not expected), "U" when mapman_spec is
	neither positive nor negative, and None when tf_spec is neither.
	"""
	tf_up = tf_spec > 0
	if not tf_up and not tf_spec < 0:
		return None
	if not (mapman_spec > 0 or mapman_spec < 0):
		return "U"
	same_direction = (mapman_spec > 0) == tf_up
	wanted_type = "Activator" if same_direction else "Repressor"
	return "Y" if tf_type == wanted_type else "N"

df_cut05["expected"] = df_cut05.apply(lambda x: expected_outcome(x.TF_specificity, x.Mapman_specificity, x.TF_type) , axis=1)
count = Counter(df_cut05.expected)
outcome_label = {"Y": "Expected", "N": "Not expected"}  # anything else -> "Ambiguous"
# own figure, so the pie is not drawn onto the FigS14 histogram axes
plt.figure()
plt.pie(count.values(), labels=[outcome_label.get(x, "Ambiguous") for x in count.keys()])
count_vals = list(count.values())[::-1]
count_keys = [outcome_label.get(x, "Ambiguous") for x in count.keys()][::-1]
# Fig 4B: stacked bar; each segment sits on the running total of the ones
# below it, so any number of categories is handled
fig, ax = plt.subplots(figsize=(1,6))
bottom = 0
for label, value in zip(count_keys, count_vals):
	ax.bar("relationship", value, width = 0.3, label=label, bottom=bottom)
	bottom += value
plt.ylabel("Count")
ax.legend(bbox_to_anchor=(1.3, -0.06))

plt.savefig(dir_path + 'figures/' + 'Fig4B.png', bbox_inches='tight', dpi=600)

### Figure 5: Annotation of Ath orthologs with evidence from literature

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
from matplotlib import colors
		
hm_path = dir_path + 'prep_files/anno_hm_compat.txt'
hm_df = pd.read_csv(hm_path, header=[0,1], index_col=[0,1], sep="\t")
# Encode evidence as 0 (none), 2 ("L" or "GO"), 1 ("T"); the Fig. 5A colour
# bar labels these NA / Literature / Observed.
hm_df.replace(np.nan, 0, inplace=True)
hm_df.replace("L", 2, inplace=True)
hm_df.replace("GO", 2, inplace=True)
hm_df.replace("T", 1, inplace=True)

# x labels: second level of the column MultiIndex, skipping the first column
xlabels = [x[1].replace("\r","") for x in hm_df.columns.to_list()[1:]]
xlabels = [x.replace("Nitrogen", "Nitrogen\ndeficiency") for x in xlabels]

# y labels: first index level; NaN (read as float) means no specific response
ylabels = [x[0] if type(x[0]) != float else "No specific response" for x in hm_df.index.to_list()]
ylabel_unique = []
ylabel_pos = []

for lab in list(set(ylabels)):
	ylabel_unique.append(lab)
	pos = [i for i, x in enumerate(ylabels) if x == lab]
	# middle of the rows carrying this label (a single row gives its own
	# position). The old `pos == 1` test compared a list to an int and never
	# matched, so this formula was always the one used.
	# NOTE(review): heatmap rows are centred at i + 0.5 (cf. the x ticks), so
	# these tick positions sit half a row above the group centre - confirm.
	ylabel_pos.append((pos[0] + pos[-1])/2)
ylabel_unique_mod = [x.replace("Light/Nitrogen", "           Light/Nitrogen") for x in ylabel_unique]
ylabel_unique_mod = [x.replace("Cold/Nitrogen", "           Cold/Nitrogen") for x in ylabel_unique_mod]

# distinct y labels in order of first appearance
yunqiue_ordered = list(dict.fromkeys(ylabels))

#%% plotting

# Row-group colours: orange / lemonchiffon / rosybrown / peachpuff repeated,
# then orange, lemonchiffon and finally gray (15 groups in total).
row_col_names = ["orange", "lemonchiffon", "rosybrown", "peachpuff"] * 3 + ["orange", "lemonchiffon", "gray"]
row_col_list = [colors.CSS4_COLORS[name] for name in row_col_names]
yrowcol = [row_col_list[yunqiue_ordered.index(lab)] for lab in ylabels]

colour_list = ListedColormap(["white", "royalblue", "limegreen"])
g = sns.clustermap(hm_df.iloc[:,1:].values, cmap=colour_list,
				   col_cluster=False, row_cluster=False,
				   cbar_kws={"ticks": [1/3, 3/3, 5/3], },
				   figsize=(6,10), row_colors=yrowcol)

heatmap_axes = g.ax_heatmap.axes
heatmap_axes.set_yticks(ylabel_pos, ylabel_unique_mod)
heatmap_axes.set_xticks([i+0.5 for i in range(len(xlabels))], xlabels, rotation=90)
g.cax.set_yticklabels(['NA', 'Observed', 'Literature'])
plt.savefig(dir_path + 'figures/' + 'Fig5A.png', dpi=600, bbox_inches="tight")

#%% stacked bar by row
stresses_lit = hm_df.columns.to_list()[1:8]
stresses_obs = hm_df.columns.to_list()[8:]
stress_lit = ["Heat", "Cold", "Salt", "Mannitol", "Light", "Dark", "Nitrogen", "nan"]
stress_obs = ["Heat", "Cold", "Salt", "Mannitol", "nan"]

def get_counts(stresses, stress, df=None):
	"""Per response group, the share of its summed evidence falling in each column.

	stresses : list of column keys of `df` to sum (MultiIndex tuples; the
	           second element becomes the row name of the result).
	stress   : group names matched as substrings of str(row_key[0]); the last
	           entry is the catch-all (e.g. "nan") and is renamed
	           "No specific response".
	df       : evidence table; defaults to the global hm_df.

	Returns a DataFrame with one column per group and one row per evidence
	column. Each group column sums to 1, or is all 0 when the group has no
	evidence at all.
	"""
	if df is None:
		df = hm_df
	row_labels = [str(x[0]) for x in df.index.to_list()]
	per_group = [df[[s in lab for lab in row_labels]][stresses].sum() for s in stress]
	coll_df = pd.concat(per_group, axis=1)
	coll_df.columns = stress[:-1] + ["No specific response"]
	coll_df.index = [x[1].replace("\r\n", "") for x in coll_df.index.to_list()]
	totals = coll_df.sum()
	# Groups with a zero (or non-positive) total get 0 instead of dividing by zero.
	return coll_df.div(totals.where(totals > 0), axis=1).fillna(0)

# Plotting
def _save_stacked_ratio_bar(table, cmap, out_name, plot_kw=None, label_kw=None):
	"""Stacked bar of table.T (one bar per response group), legend outside, saved as out_name."""
	plot_kw = plot_kw or {}
	label_kw = label_kw or {}
	table.T.plot.bar(stacked=True, cmap=cmap, **plot_kw)
	plt.legend(bbox_to_anchor = (1., 1.))
	plt.ylabel("Ratio of Arabidopsis responses", **label_kw)
	plt.xlabel("Response in Marchantia", **label_kw)
	plt.savefig(dir_path + 'figures/' + out_name, dpi=600, bbox_inches="tight")

# One colour per evidence column: 7 literature columns, 4 observed columns.
lit_col = colors.ListedColormap(["steelblue", "darkorange", "mediumseagreen", "crimson", "mediumpurple", "saddlebrown", "hotpink"])
obs_col = colors.ListedColormap(["steelblue", "darkorange", "mediumpurple", "saddlebrown"])

col_lit = get_counts(stresses_lit, stress_lit)
col_obs = get_counts(stresses_obs, stress_obs)

_save_stacked_ratio_bar(col_lit, lit_col, 'Fig5B.png')
_save_stacked_ratio_bar(col_obs, obs_col, 'Fig5C.png',
						plot_kw={"figsize": (4.9, 5.2), "fontsize": 13},
						label_kw={"fontsize": 13})

#%% for supp table S8
# Raw (un-normalised) totals per response group, plus the number of matching rows.
# Each stress list is paired with its own evidence columns. The previous loop
# summed over the unrelated global `stresses` for both tables and picked the
# output file by list length.
for stress, stress_cols, fname in [(stress_lit, stresses_lit, "lit_count.txt"),
								   (stress_obs, stresses_obs, "obs_count.txt")]:
	row_labels = [str(x[0]) for x in hm_df.index.to_list()]
	row_masks = [[s in lab for lab in row_labels] for s in stress]
	coll_df = pd.concat([hm_df[mask][stress_cols].sum() for mask in row_masks], axis=1)
	coll_df.columns = stress[:-1] + ["No specific response"]
	coll_df.index = [x[1].replace("\r\n", "") for x in coll_df.index.to_list()]
	formatted = coll_df.T
	formatted["Number of corresponding orthologs"] = [sum(mask) for mask in row_masks]
	formatted.to_csv(dir_path + '/prep_files/' + fname, sep="\t")


### Figure 6 A, B: Effects of combined stress in terms of significant L2FC

In [None]:
# Figure 6A
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib import cm
from matplotlib.colors import ListedColormap

#%% initialisation
# Single-stress codes and their display names (same order).
stresses = list("CSMHDLN")
stresses_long = ["Cold", "Salt", "Mannitol", "Heat", "Dark", "Light", "Nitrogen deficiency"]
# Every experiment code: single stresses (one letter) and pairwise combinations (two letters).
all_stresses = ['C', 'CD', 'CL', 'CM', 'CN', 'CS',
				'D',
				'H', 'HD', 'HM', 'HN', 'HS',
				'L', 'LS',
				'M', 'MD', 'ML', 'MN',
				'N', 'ND', 'NL',
				'S', 'SD', 'SM', 'SN']
combined_stress = list(filter(lambda code: len(code) == 2, all_stresses))

#%% get averaged L2FC files
# Folder holding the per-experiment DESeq2 result tables.
deseq_dir = mpo_path
def deg_stat(H2, D2):
	"""Direction agreed by both controls: "UP" if both L2FCs are > 0, "DOWN"
	if both are < 0, otherwise "MIXED" (including zeros and NaN)."""
	if H2 > 0 and D2 > 0:
		return "UP"
	both_down = H2 < 0 and D2 < 0
	return "DOWN" if both_down else "MIXED"
	
def get_resSig(stress, base_dir=None):
	"""Genes significant against both controls (H2 and D2) with a consistent direction.

	Reads "<base_dir><stress>controlH2_resSig.tsv" and "...controlD2_resSig.tsv"
	(7 columns, the third being the L2FC). It keeps genes present in both files
	whose L2FCs share a sign, and returns columns gene, L2FC_H2, L2FC_D2,
	deg_stat ("UP"/"DOWN"), L2FC_avg (mean of the two) and stress.
	base_dir defaults to the global deseq_dir.

	The result is an independent copy (no SettingWithCopy warnings). An input
	pair with no shared or consistent genes yields an empty frame instead of an error.
	"""
	if base_dir is None:
		base_dir = deseq_dir
	per_control = []
	for control in ("H2", "D2"):
		res = pd.read_csv(base_dir + stress + "control" + control + "_resSig.tsv", header=0, sep="\t")
		res.columns = ['gene', 'baseMean', 'L2FC_' + control, 'lfcSE', 'stat', 'pvalue', 'padj']
		per_control.append(res[['gene', 'L2FC_' + control]])
	# Outer merge followed by dropna keeps only genes with a value in both files.
	merged = pd.merge(per_control[0], per_control[1], on='gene', how='outer').dropna()
	h2, d2 = merged['L2FC_H2'], merged['L2FC_D2']
	status = pd.Series("MIXED", index=merged.index, dtype=object)
	status.loc[(h2 > 0) & (d2 > 0)] = "UP"
	status.loc[(h2 < 0) & (d2 < 0)] = "DOWN"
	merged['deg_stat'] = status
	merged_filtered = merged[merged['deg_stat'] != "MIXED"].copy()
	merged_filtered["L2FC_avg"] = (merged_filtered['L2FC_H2'] + merged_filtered['L2FC_D2']) / 2
	merged_filtered["stress"] = stress
	return merged_filtered

resSig_list = [get_resSig(x) for x in all_stresses]
resSig_df = pd.concat(resSig_list)
resSig_df.to_csv(deseq_dir + "resSig_L2FC_compiled.txt", sep="\t", index = False)

#%% get stats for each gene
def _values_per_gene(table, stress, genes, value_col):
	"""`value_col` of the first row of `table` for each gene under `stress`; "NC" if the gene has none.

	Uses one indexed lookup per stress instead of scanning the whole table
	once per gene.
	"""
	hits = table[table.stress == stress].drop_duplicates("gene", keep="first").set_index("gene")[value_col]
	present = pd.Index(genes).isin(hits.index)
	return hits.reindex(genes).astype(object).where(present, "NC").to_list()

def _direction(value):
	"""Map an L2FC to "UP" (> 0) or "DOWN" (otherwise), passing "NC" through."""
	if value == "NC":
		return value
	return "UP" if float(value) > 0 else "DOWN"

cds_path = dir_path + 'prep_files/Mpo.cds.fasta'
# Gene IDs = FASTA header lines without the leading ">" and trailing newline.
with open(cds_path, "r") as fasta:
	all_genes = [line[1:-1] for line in fasta if ">" in line]

deg_path = mpo_path + 'resSig_L2FC_compiled.txt'
deg_df = pd.read_csv(deg_path, header=0, sep="\t")

int_class_cols = ["gene", "Xname", "Yname", "XYname", "Xl2fc", "Yl2fc", "XYl2fc"]
# One row per (combined stress XY, gene): averaged L2FC under X, Y and XY, or "NC".
per_combo = []
for sxy in combined_stress:
	print("Calculating: " + sxy)
	sx, sy = sxy[0], sxy[1]
	per_combo.append(pd.DataFrame({
		"gene": all_genes,
		"Xname": sx,
		"Yname": sy,
		"XYname": sxy,
		"Xl2fc": _values_per_gene(deg_df, sx, all_genes, "L2FC_avg"),
		"Yl2fc": _values_per_gene(deg_df, sy, all_genes, "L2FC_avg"),
		"XYl2fc": _values_per_gene(deg_df, sxy, all_genes, "L2FC_avg"),
	}, columns=int_class_cols))
int_class_df = pd.concat(per_combo, ignore_index=True) if per_combo else pd.DataFrame(columns=int_class_cols)

int_class_df.to_csv(mpo_path + 'stress_int_class_l2fc.txt', sep="\t", index=None)

# get reverse of XY i.e. YX for the sake of counting
int_class_rev_df = int_class_df[["gene", "Yname", "Xname", "XYname", "Yl2fc", "Xl2fc", "XYl2fc"]]
int_class_rev_df.columns = int_class_cols

int_class_all_df = pd.concat([int_class_df, int_class_rev_df])
for l2fc_col, stat_col in (("Xl2fc", "Xstat"), ("Yl2fc", "Ystat"), ("XYl2fc", "XYstat")):
	int_class_all_df[stat_col] = [_direction(v) for v in int_class_all_df[l2fc_col].to_list()]
int_class_all_df["Xstat_Ystat"] = int_class_all_df.Xstat + "_" + int_class_all_df.Ystat
int_class_all_df["all_stat"] = int_class_all_df["Xstat_Ystat"] + "_" + int_class_all_df["XYstat"]

# "NC" (no significant change) counts as an L2FC of 0 from here on.
for col in ["Xl2fc", "Yl2fc", "XYl2fc"]:
	int_class_all_df[col] = int_class_all_df[col].replace("NC", 0)
	int_class_all_df[col] = int_class_all_df[col].apply(float)
# NOTE(review): Figure 6B reads mpo_path + 'stress_int_class_l2fc_extended.txt'
# (Xl2fc/Yl2fc/XYl2fc numeric), but nothing in this cell writes it. Confirm it
# comes from the downloaded prep files, or save int_class_all_df there.

#%% all_stat / Xstat_Ystat
def _mean_l2fc_by_stress(class_col):
	"""Mean Xl2fc/Yl2fc/XYl2fc per class of `class_col`, one 3-column group per stress X.

	Columns form a two-level index (Xname, l2fc), with stresses in the order of `stresses`.
	"""
	means = int_class_all_df.groupby(["Xname", class_col])[["Xl2fc", "Yl2fc", "XYl2fc"]].mean()
	flat = pd.concat([means.xs(code, level="Xname") for code in stresses], axis=1)
	outer = [code for code in stresses for _ in range(3)]
	flat.columns = pd.MultiIndex.from_tuples(list(zip(outer, flat.columns.to_list())), names=["Xname", "l2fc"])
	return flat

# Full X/Y/XY status pattern (exploratory heatmap, not saved).
merged_flat = _mean_l2fc_by_stress("all_stat")
plt.figure(figsize=(6, 7))
sns.heatmap(merged_flat, cmap="coolwarm", yticklabels=True)

# X/Y status pattern only: saved table and Figure 6A.
merged_flat = _mean_l2fc_by_stress("Xstat_Ystat")
merged_flat.to_csv(mpo_path + 'stress_int_class_l2fc_avg.txt', sep="\t")

plt.figure(figsize=(11, 2.3))
sns.heatmap(merged_flat, cmap="coolwarm", yticklabels=True, annot=merged_flat, fmt=",.1f")
plt.savefig(dir_path + 'figures/' + 'Fig6A.png', dpi=600)

In [None]:
#Figure 6B: Scatter plot

int_class_all_df = pd.read_csv(mpo_path + 'stress_int_class_l2fc_extended.txt', sep="\t", header=0)
# XY_val: how far the combined-stress L2FC lies beyond BOTH single-stress values:
#   XYl2fc - max(X, Y) when above both, XYl2fc - min(X, Y) when below both, else 0.
# skipna=False makes any NaN input fall through to 0, matching the old per-row version.
single_l2fc = int_class_all_df[["Xl2fc", "Yl2fc"]]
upper_single = single_l2fc.max(axis=1, skipna=False)
lower_single = single_l2fc.min(axis=1, skipna=False)
xy_l2fc = int_class_all_df["XYl2fc"]
int_class_all_df["XY_val"] = (xy_l2fc - upper_single).where(
	xy_l2fc > upper_single, (xy_l2fc - lower_single).where(xy_l2fc < lower_single, 0))
int_class_all_df.to_csv(mpo_path + 'stress_int_class_l2fc_extended.txt', sep="\t", index=None)

# Set limits for scatter plot for general overview.
# .copy() so the new colour column is written to an independent frame.
int_class_all_df_sub = int_class_all_df[(int_class_all_df.Xl2fc.abs() < 20) & (int_class_all_df.Yl2fc.abs() < 20)].copy()
# Red: combined effect beyond both singles upward; blue: downward; grey: in between.
int_class_all_df_sub["XY_relstat"] = "lightgray"
int_class_all_df_sub.loc[int_class_all_df_sub.XY_val > 0, "XY_relstat"] = "firebrick"
int_class_all_df_sub.loc[int_class_all_df_sub.XY_val < 0, "XY_relstat"] = "royalblue"

#%% plot
int_class_all_df_sub.plot.scatter(x="Xl2fc", y="Yl2fc", marker='.',
								  c="XY_relstat", linewidths=0, alpha=0.2,
								  xlim=(-12,12), ylim=(-12,12))
plt.savefig(dir_path + 'figures/' + 'Fig6B.png', dpi=600)

### Figure 6C: Classification of stress interactions

In [None]:
"""
Classification of interactions
"""

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

stresses = ["C", "S", "M", "H", "D", "L", "N"]
stresses_long = ["Cold", "Salt", "Mannitol", "Heat", "Dark", "Light", "Nitrogen deficiency"]
all_stresses = ['C', 'CD', 'CL', 'CM', 'CN', 'CS', 'D', 'H', 'HD',
				'HM', 'HN', 'HS', 'L', 'LS', 'M', 'MD', 'ML','MN',
				'N', 'ND', 'NL', 'S', 'SD', 'SM', 'SN']
combined_stress = [x for x in all_stresses if len(x) == 2]

def _values_per_gene(table, stress, genes, value_col):
	"""`value_col` of the first row of `table` for each gene under `stress`; "NC" if the gene has none.

	Uses one indexed lookup per stress instead of scanning the whole table
	once per gene.
	"""
	hits = table[table.stress == stress].drop_duplicates("gene", keep="first").set_index("gene")[value_col]
	present = pd.Index(genes).isin(hits.index)
	return hits.reindex(genes).astype(object).where(present, "NC").to_list()

cds_path = dir_path + 'prep_files/Mpo.cds.fasta'
# Gene IDs = FASTA header lines without the leading ">" and trailing newline.
with open(cds_path, "r") as fasta:
	all_genes = [line[1:-1] for line in fasta if ">" in line]
deg_path = mpo_path + 'resSig_compiled.txt'
deg_df = pd.read_csv(deg_path, header=0, sep="\t")
int_class_cols = ["gene", "Xname", "Yname", "XYname", "Xstat", "Ystat", "XYstat"]

# get status (up-, down-regulated and no change) for each gene of stress X, Y and XY.
# NOTE(review): the string concatenation of Xstat/Ystat below requires the
# L2FC_D2 column of resSig_compiled.txt to hold status strings (e.g. UP/DOWN),
# not numbers. Confirm against the file.
per_combo = []
for sxy in combined_stress:
	print("Calculating: " + sxy)
	sx, sy = sxy[0], sxy[1]
	per_combo.append(pd.DataFrame({
		"gene": all_genes,
		"Xname": sx,
		"Yname": sy,
		"XYname": sxy,
		"Xstat": _values_per_gene(deg_df, "Mpo_" + sx, all_genes, "L2FC_D2"),
		"Ystat": _values_per_gene(deg_df, "Mpo_" + sy, all_genes, "L2FC_D2"),
		"XYstat": _values_per_gene(deg_df, "Mpo_" + sxy, all_genes, "L2FC_D2"),
	}, columns=int_class_cols))
int_class_df = pd.concat(per_combo, ignore_index=True) if per_combo else pd.DataFrame(columns=int_class_cols)

int_class_df.to_csv(mpo_path + 'stress_int_classification.txt', sep="\t", index=None)

# get reverse of XY i.e. YX for the sake of counting
int_class_rev_df = int_class_df[["gene", "Yname", "Xname", "XYname", "Ystat", "Xstat", "XYstat"]]
int_class_rev_df.columns = int_class_cols

int_class_all_df = pd.concat([int_class_df, int_class_rev_df])
int_class_all_df["Xstat_Ystat"] = int_class_all_df.Xstat + "_" + int_class_all_df.Ystat

int_class_all_df.to_csv(mpo_path + 'stress_int_classification_extended.txt', sep="\t", index=None)

int_class_count = int_class_all_df.groupby(["Xname", "Xstat_Ystat", "XYstat"]).count()["Yname"]
int_class_count.to_csv(mpo_path + 'stress_int_classification_count.txt', sep="\t")

#%% For collated plot, separate norm
# Full X_Y_XY status pattern, then gene counts per (stress X, pattern).
int_class_all_df["all_stat"] = int_class_all_df.Xstat_Ystat + "_" + int_class_all_df.XYstat
all_unstack = int_class_all_df.groupby(["Xname", "all_stat"])["Yname"].count().unstack()

#%% Plotting
# Nine panels of three adjacent pattern columns each (27 patterns in total).
# Colour = share of the panel's row total per stress; annotation = raw count.
fig, ax = plt.subplots(9, 1, sharex=True, constrained_layout=True, figsize=(8,10))
axes = ax.flatten()

for panel, start in enumerate(range(0, 27, 3)):
	counts = all_unstack.iloc[:, start:start + 3]
	shares = counts.div(counts.sum(axis=1), axis=0)
	sns.heatmap(shares.T, cmap="Blues", annot=counts.T, fmt=",.0f", ax=axes[panel])

# Blank all y labels; keep the x label only on the bottom panel.
for panel, panel_ax in enumerate(axes):
	panel_ax.set_ylabel('')
	if panel < 8:
		panel_ax.set_xlabel('')
plt.savefig(dir_path + 'figures/' + 'Fig6C.png', dpi=600)

### Figure 6D: Linear regression of all experiments
Note: a package used in this cell is not compatible with Colab; please run this cell locally.

In [None]:
# Figure 6D
import matplotlib.pyplot as plt
import numpy as np
import sklearn.linear_model
import pandas as pd
from sklearn.utils import shuffle
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import seaborn as sns

# directory paths
import os

deseq_dir = mpo_path
linreg_dir = deseq_dir + "linreg/"
# Create the output folder with the standard library instead of the IPython
# `!mkdir` shell escape. This needs no manual escaping of spaces, works outside
# IPython, and does not fail when the folder already exists.
os.makedirs(linreg_dir, exist_ok=True)

# data paths
avg_path = deseq_dir + "stress_int_class_l2fc_avg.txt"
norm_path = deseq_dir + "stress_int_class_l2fc_allres.txt"
extended_path = deseq_dir + "stress_int_class_l2fc_extended_allres.txt"

#%% for averaged linreg
# Two header rows give (stress X, l2fc kind) column pairs. Stack each
# stress's (Xl2fc, Yl2fc, XYl2fc) block on top of the others.
avg_df = pd.read_csv(avg_path, sep="\t", header=[0,1], index_col = 0)
stress_col = avg_df.columns.get_level_values(0).unique().tolist()
avg_flat = pd.concat([avg_df[x] for x in stress_col])
avg_flat_shuffled = shuffle(avg_flat)

# Model the combined-stress L2FC (Sxy) from the two single-stress L2FCs.
X_train = avg_flat_shuffled[["Xl2fc", "Yl2fc"]].values
Y_train = avg_flat_shuffled["XYl2fc"].values

# Training
model = sklearn.linear_model.LinearRegression()
model.fit(X_train, Y_train)

coefs = model.coef_
intercept = model.intercept_

# In-sample fit statistics (there is no held-out set).
Y_pred = model.predict(X_train)
MAE = round(mean_absolute_error(Y_train, Y_pred), 2)
# RMSE as sqrt(MSE): mean_squared_error(..., squared=False) was deprecated in
# scikit-learn 1.4 and removed in 1.6.
RMSE = round(float(np.sqrt(mean_squared_error(Y_train, Y_pred))), 2)
r2 = round(r2_score(Y_train, Y_pred), 2)

minx, maxx = avg_flat_shuffled.Xl2fc.min()-1, avg_flat_shuffled.Xl2fc.max()+1
miny, maxy = avg_flat_shuffled.Yl2fc.min()-1, avg_flat_shuffled.Yl2fc.max()+1

# for mesh: unit-spaced grid over the data range, z = fitted plane
xs = np.tile(np.arange(minx, maxx), (len(np.arange(miny, maxy)),1))
ys = np.tile(np.arange(miny, maxy), (len(np.arange(minx, maxx)),1)).T
zs = xs*coefs[0]+ys*coefs[1]+intercept

# Plot
fig = plt.figure(constrained_layout=True)
ax = fig.add_subplot(111, projection='3d')
ax.set_xlabel("Sx")
ax.set_ylabel("Sy")
ax.set_zlabel("Sxy")

ax.plot_surface(xs,ys,zs, alpha=0.5)
# NOTE(review): stems are drawn at the *predicted* Sxy (which lies on the plane);
# Figure 7's plot_3d draws observed values instead. Confirm this is intended.
ax.stem(X_train[:,0], X_train[:,1], Y_pred, bottom=-6,
		basefmt=" ", markerfmt="C1o", linefmt="C1-")
eqn_str = "Sxy = {:.2f} + {:.2f} Sx + {:.2f} Sy".format(intercept, coefs[0], coefs[1])
err_str = "MAE: {}    RMSE: {}   R".format(MAE, RMSE) + u"\u00b2" + ": {}".format(r2)

# Place the equation and error text above the plot box.
z_bound = ax.get_zbound()
x_lower_bound = ax.get_xbound()[0]
y_bound = ax.get_ybound()
z_bound_dist = z_bound[1] - z_bound[0]
ax.text(x=x_lower_bound - x_lower_bound/4, y=(y_bound[1] - y_bound[0])/2, z = z_bound[1] + z_bound_dist/6, s=eqn_str)
ax.text(x=x_lower_bound - x_lower_bound/4, y=(y_bound[1] - y_bound[0])/2, z = z_bound[1] + 0.7, s=err_str)
plt.savefig(dir_path + 'figures/' + 'Fig6D.png', dpi=600, bbox_inches="tight")
plt.show()

### Figure 7 A-H: Linear regression by stress

In [None]:
#%% Plotting
def plot_3d(X_test, Y_test, title, mae, rmse, r2, xs, ys, zs, coefs, intercept):
	"""3-D stem plot of observations over the fitted plane, saved as figures/Fig7_<title>.png.

	X_test : 2-column array of [Sx, Sy] per point; Y_test: observed Sxy per point.
	xs, ys, zs : mesh arrays of the fitted plane (drawn as a surface).
	mae, rmse, r2, coefs, intercept : printed as the equation / error text.
	"""
	fig = plt.figure(constrained_layout=True)
	ax = fig.add_subplot(111, projection='3d')
	for set_label, text in ((ax.set_xlabel, "Sx"), (ax.set_ylabel, "Sy"), (ax.set_zlabel, "Sxy")):
		set_label(text)

	ax.plot_surface(xs,ys,zs, alpha=0.5)
	# Stems start from whichever is lower: the plane's minimum or the lowest observation.
	plane_floor = min(v for row in zs.tolist() for v in row)
	lowest_obs = min(Y_test)
	stem_bottom = plane_floor if plane_floor < lowest_obs else lowest_obs

	# One faint vertical line per observation
	# (https://glowingpython.blogspot.com/2012/12/3d-stem-plot.html).
	sx_vals, sy_vals, sxy_vals = X_test[:,0], X_test[:,1], Y_test
	for px, py, pz in zip(sx_vals, sy_vals, sxy_vals):
		ax.plot([px, px], [py, py], [stem_bottom, pz],
		  '-', linewidth=1, color='darkgray', alpha=.1)

	# plotting a circle on the top of each stem
	ax.plot(sx_vals, sy_vals, sxy_vals, 'o', markersize=4, color='orange', label='ib', alpha=.1)

	eqn_str = "Sxy = {:.2f} + {:.2f} Sx + {:.2f} Sy".format(intercept, coefs[0], coefs[1])
	err_str = "MAE: {}    RMSE: {}   R".format(mae, rmse) + u"\u00b2" + ": {}".format(r2)

	# Place the equation and error text above the plot box.
	zlo, zhi = ax.get_zbound()
	x_lower_bound = ax.get_xbound()[0]
	ylo, yhi = ax.get_ybound()
	text_x = x_lower_bound - x_lower_bound/4
	text_y = (yhi - ylo)/3
	ax.text(x=text_x, y=text_y, z = zhi + (zhi - zlo)/8, s=eqn_str)
	ax.text(x=text_x, y=text_y, z = zhi, s=err_str)
	plt.savefig(dir_path + 'figures/' + "Fig7_" + title + ".png", dpi=600, bbox_inches="tight")
	plt.show()

#%% Process and plot
stresses = list("CSMHDLN")
stresses_long = ["Cold", "Salt", "Mannitol", "Heat", "Dark", "Light", "Nitrogen\ndeficiency"]

# Per-pair table; process_data reads its Xname, Xl2fc, Yl2fc and XYl2fc columns.
data_df = pd.read_csv(extended_path, header=0, sep="\t")
# Row count per single stress X (size of the random draw in permutation mode).
sub_len = [int((data_df["Xname"] == code).sum()) for code in stresses]
def process_data(data_type, idx, permu, j=None):
	"""Fit Sxy ~ Sx + Sy by linear regression for one stress, or for a random sample.

	data_type : single-stress code, matched against data_df.Xname.
	idx       : position of that stress in stresses_long / sub_len.
	permu     : falsy -> fit the rows whose Xname is data_type and draw the 3-D plot.
	            truthy -> fit a random sample of all rows with the same size as
	            that subset; no plot.
	j         : sampling round, used only for progress printing (optional).

	Returns [coef_Sx, coef_Sy, intercept, MAE, RMSE, R2]. The error metrics are
	in-sample and rounded to 2 decimals.
	"""
	title = stresses_long[idx]
	if not permu:
		shuffled_df = shuffle(data_df[data_df.Xname == data_type])
	else:
		if j is not None:
			if j % 1000 == 0:
				print("Doing " + data_type)
			# Parenthesised: `j+1 % 100` parses as j + 1 and never printed.
			elif (j + 1) % 100 == 0:
				print("\t{}/1000".format(j+1))
		shuffled_df = shuffle(data_df).iloc[:sub_len[idx]]

	# Data preparation
	xtrain = shuffled_df[["Xl2fc", "Yl2fc"]].values
	ytrain = shuffled_df["XYl2fc"].values

	# Training
	model = sklearn.linear_model.LinearRegression()
	model.fit(xtrain, ytrain)

	coefs = model.coef_
	intercept = model.intercept_

	Y_pred = model.predict(xtrain)
	MAE = round(mean_absolute_error(ytrain, Y_pred), 2)
	# sqrt(MSE): `squared=False` was removed from mean_squared_error in scikit-learn 1.6.
	RMSE = round(float(np.sqrt(mean_squared_error(ytrain, Y_pred))), 2)
	r2 = round(r2_score(ytrain, Y_pred), 2)

	if not permu:
		minx, maxx = -12, 12
		miny, maxy = -12, 12

		# for mesh: unit grid over [-12, 12) in both axes, z = fitted plane
		xs = np.tile(np.arange(minx, maxx), (len(np.arange(miny, maxy)),1))
		ys = np.tile(np.arange(miny, maxy), (len(np.arange(minx, maxx)),1)).T
		zs = xs*coefs[0]+ys*coefs[1]+intercept

		# Only points inside the plotted window are drawn (the fit used all points).
		plot_sub = shuffled_df[((shuffled_df.Xl2fc < 12) & (shuffled_df.Xl2fc > -12)) & ((shuffled_df.Yl2fc < 12) & (shuffled_df.Yl2fc > -12))]
		plot_x = plot_sub[["Xl2fc", "Yl2fc"]].values
		plot_y = plot_sub["XYl2fc"].values
		plot_3d(plot_x, plot_y, title, MAE, RMSE, r2, xs, ys, zs, coefs, intercept)

	return [coefs[0], coefs[1], intercept, MAE, RMSE, r2]

#%% Actual plotting
ori_out = []
s_start = 0
for i, item in enumerate(stresses[s_start:], s_start):
	print("Doing " + item)
	ori_out.append(process_data(item, i, False))

# Long-format table for the grouped bars: Sx coefficient, Sy coefficient and R2
# per fitted stress. The stress labels are sliced like `stresses` so the
# lengths still match when s_start > 0.
fitted_long = stresses_long[s_start:]
n_fit = len(ori_out)
ori_df = pd.DataFrame.from_dict({
	"yval": [x[0] for x in ori_out] + [x[1] for x in ori_out] + [x[-1] for x in ori_out],
	"Legend": ["Sx"]*n_fit + ["Sy"]*n_fit + ["R" + u"\u00b2"]*n_fit,
	"stress": fitted_long + fitted_long + fitted_long,
})
#%%
# Named bar_colors (not `colors`) so the matplotlib `colors` module imported
# by earlier cells is not shadowed.
bar_colors = ["royalblue", "crimson", "lightsteelblue"]
col_pal = sns.color_palette(bar_colors)
sns.set(font_scale=1.25)
sns.catplot(data=ori_df, kind="bar",  x="stress", y="yval", hue="Legend", aspect=1.7, palette=col_pal)
plt.ylabel("Magnitude")
plt.xlabel("Sx")
plt.savefig(dir_path + 'figures/' + "Fig7H.png", dpi=600)