<a href="https://colab.research.google.com/github/tqiaowen/marchantia-stress/blob/main/marchantia_stress.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Code to replicate analysis

## Chapters
### 1. Miscellaneous preparation steps
1. Install non-default modules and upgrade modules
1. Mount Google Drive
1. Set paths (point to google drive folder to work in)
1. Download files (initial set up, skip if continuing)
1. Import modules, initialise paths

### 2. Generating data for analysis
1. Differential Gene Expression (DESeq2)
1. Supp. Fig 4: Comparison of DEGs between two controls
1. Diurnal Gene Expression (JTK_cycle)

### 3. Analysis and plotting
1. Figure 1 & Supp. Fig 1: Measurements and Student's t-test
1. Figure 2: Interspecies comparison (Biological processes)
1. Supp. Fig 5: Interspecies comparison (Gene families)
1. Figure 3: Stress responsiveness
1. Figure 4: Upset plot and summary of DEGs in Marchantia
1. Figure 5, Supp. Fig. 6 & 7: Inter-stress (Marchantia only) comparison
1. Figure 6: Diurnal gene expression
1. Supp. Fig 2: QC of RNA-seq data
1. Supp. Fig 3: Volcano plots (DESeq2)

1. Supp. Fig 8: Overview of diurnal data

### 4. Experimental
1. Download RNA-seq data
1. Mapping and generating expression matrix

# 1. Miscellaneous preparation steps

### 1.1 Install non-default modules and upgrade modules

In [None]:
# install non-default colab modules
# Restart runtime after installation and skip to next step
# upsetplot is installed because it is not assumed to be present in the runtime.
!pip install upsetplot
# matplotlib is upgraded; later cells call Figure.subplot_mosaic (Supp. Fig 4).
!pip install matplotlib --upgrade

### 1.2 Mount Google Drive

In [None]:
# Mount Google Drive at /content/gdrive so the working directory set in step 1.3 is reachable.
from google.colab import drive
drive.mount('/content/gdrive')
# Remove the sample_data folder from /content (not used by this notebook).
!rm -rf /content/sample_data

In [None]:
#@title 1.3 Set path {display-mode: "form"}

#@markdown Enter the path of the directory you want to work in.

# Keep the trailing slash: later cells build paths by plain string
# concatenation (e.g. drive_path + 'marchantia_stress/').
drive_path = '/content/gdrive/My Drive/' #@param {type: 'string'}

### 1.4 Download files (first time only)

In [None]:
# Downloads necessary files to perform analyses [only need to be done once]
# https://gist.github.com/iamtekeste/3cdfd0366ebfd2c0d805 download raw files directtly from Google Drive
!wget --no-check-certificate -r "https://drive.google.com/uc?id=1cbKgWbEWtstl_2_rb06tI_D-vseprPnT&export=download" -O marchantia_stress.zip

dir_path = drive_path + 'marchantia_stress/'
dir_path_safe = dir_path.replace(' ', '\ ')
!mkdir $dir_path_safe
!unzip marchantia_stress.zip -d $dir_path_safe

### 1.5 Import modules, set paths

In [None]:
# import modules
import os
import string
%load_ext rpy2.ipython
import pandas as pd
import math
from matplotlib_venn import venn2
from matplotlib import pyplot as plt
from ast import literal_eval
import seaborn as sns
from collections import Counter
import random
from statsmodels.stats.multitest import multipletests
import numpy as np
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Project directory on Google Drive; every later path is built from it by concatenation.
dir_path = drive_path + 'marchantia_stress/'
# Shell-safe variant with spaces backslash-escaped, for use in `!`/`%cd` commands.
# The backslash is written as '\\' because '\ ' is an invalid escape sequence in Python.
dir_path_safe = dir_path.replace(' ', '\\ ')

# 2. Generating data for analysis

### 2.1 Differential Gene Expression

In [None]:
# Making necessary directories for outputs
# Shell-escaped variants, kept for any cell that passes these paths to a `!` command.
mpo_path_safe = dir_path_safe + "prep_files/mpo/deseq/"
osa_path_safe = dir_path_safe + "prep_files/osa/deseq/"
# Unescaped paths: these are what Python's os functions must be given.
# (The escaped variants contain literal backslashes and never match an existing path.)
mpo_path = dir_path + "prep_files/mpo/deseq/"
osa_path = dir_path + "prep_files/osa/deseq/"
for out_dir in (mpo_path, osa_path):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print("Directories made: " + out_dir)

In [None]:
%%R
# Installing DESeq2
# (%%R has to be the first line of the cell to be recognised as a cell magic.)
if (!requireNamespace("BiocManager", quietly = TRUE)) {
  install.packages("BiocManager")
}

BiocManager::install("DESeq2", ask = FALSE)

In [None]:
# To pull python variables
# Unescaped DESeq2 output directories for Marchantia (mpo) and rice (osa).
mpo_path = dir_path + "prep_files/mpo/deseq/"
osa_path = dir_path + "prep_files/osa/deseq/"

# `%R -i name` pushes the Python variable into the R session used by the %%R cells.
# `%Rget name` reads it back from R -- presumably as a check that the transfer worked.
%R -i dir_path
%Rget dir_path

%R -i dir_path_safe
%Rget dir_path_safe

%R -i mpo_path
%Rget mpo_path

%R -i osa_path
%Rget osa_path

In [None]:
%%R
# adapted from DESeq2_stressonly_phase1n2.R
# DESeq2 (Marchantia)
# (%%R has to be the first line of the cell to be recognised as a cell magic.)
library('DESeq2')
library('RColorBrewer')

# Divert printed output (contrast labels and summary() tables) to a summary file.
sink(paste0(mpo_path, "phase1n2_sum.txt"), type = "output")

# tryCatch(..., finally = sink()) restores normal output when the analysis ends,
# including when a step fails; otherwise later R cells would keep printing into the file.
tryCatch({
  raw_counts <- read.table(file = paste0(dir_path, 'prep_files/all_stress_raw.tsv'), sep = '\t', header = TRUE)
  raw_counts <- data.frame(raw_counts, row.names = 1)

  # The first two entries are the controls; everything after them is a stress condition.
  stresses <- c("controlH2", "controlD2", "H", "C", "M", "S", "L", "D", "N",
                "HS", "HM", "HN", "CS", "CM", "CN", "SM", "ML", "NL", "MN",
                "SD", "MD", "ND", "HD", "CD", "CL", "LS", "SN")
  colData <- read.csv(paste0(dir_path, 'summary_files/all_stress.txt'), sep = '\t', row.names = 1, header = FALSE)
  names(colData) <- c('condition')

  dds <- DESeqDataSetFromMatrix(countData = raw_counts,
                                colData = colData,
                                design = ~condition)
  dds <- DESeq(dds)

  # Pass 1 (x = 1): every later entry, including controlD2, versus controlH2.
  # Pass 2 (x = 2): every stress versus controlD2.
  for (x in 1:2) {
    control <- stresses[x]
    for (i in (x + 1):length(stresses)) {
      treatment <- stresses[i]
      res <- results(dds, contrast = c("condition", treatment, control))
      res <- res[order(res$pvalue), ]
      # Significant genes: adjusted p < 0.05 and more than a two-fold change.
      resSig <- subset(res, res$padj < 0.05 & abs(res$log2FoldChange) > 1)
      resSig <- resSig[order(resSig$padj), ]
      print(paste(treatment, 'vs', control))
      summary(res)
      summary(resSig)
      # Output files are named <treatment><control>_res.tsv / _resSig.tsv.
      out_prefix <- paste0(mpo_path, treatment, control)
      write.table(as.data.frame(res), file = paste0(out_prefix, '_res.tsv'),
                  quote = FALSE, sep = '\t', col.names = NA)
      write.table(as.data.frame(resSig), file = paste0(out_prefix, '_resSig.tsv'),
                  quote = FALSE, sep = '\t', col.names = NA)
    }
  }
}, finally = sink())

In [None]:
%%R
# DESeq2 (Rice) adapted from DESeq2_Osa.R

library('DESeq2')
library('RColorBrewer')

raw_counts <- read.table(file = paste0(dir_path, 'prep_files/expmat_Osa_raw.tsv'), sep = '\t', header = TRUE)
raw_counts <- data.frame(raw_counts, row.names = 1)

# Consecutive pairs: (treatment, control) for each of the four experiments.
annotations <- c("1913_salt", "1913_control",
                 "5941_cold", "5941_control",
                 "ERP003982_salt", "ERP003982_control",
                 "GSE57950_drought", "GSE57950_control")

# Per-experiment sample labels, one entry per selected count column.
# NOTE(review): assumes the label order matches the column order in the count table -- confirm.
md1913 <- c("1913_salt", "1913_control", "1913_control",
            "1913_salt", "1913_salt", "1913_control")
md5941 <- c(rep("5941_control", 3), rep("5941_cold", 3),
            rep("5941_control", 3), rep("5941_cold", 3))
md3982 <- c("ERP003982_salt", "ERP003982_control",
            rep("ERP003982_salt", 2), rep("ERP003982_control", 2))
md57950 <- c(rep("GSE57950_control", 6), rep("GSE57950_drought", 6))

mdlist <- list(md1913, md5941, md3982, md57950)

# Experiment k uses annotations[2k - 1] (treatment), annotations[2k] (control) and mdlist[[k]].
for (k in seq_along(mdlist)) {
  treated <- annotations[2 * k - 1]
  control <- annotations[2 * k]

  # Select the count columns whose name contains the experiment identifier.
  experiment_id <- strsplit(treated, "_")[[1]][1]
  df <- raw_counts[, grep(experiment_id, names(raw_counts))]
  sampleMetaData <- data.frame(condition = mdlist[[k]])

  dds <- DESeqDataSetFromMatrix(countData = df,
                                colData = sampleMetaData,
                                design = ~condition)
  dds <- DESeq(dds)

  res <- results(dds, contrast = c("condition", treated, control))
  res <- res[order(res$pvalue), ]
  # Significant genes: adjusted p < 0.05 and more than a two-fold change.
  resSig <- subset(res, res$padj < 0.05 & abs(res$log2FoldChange) > 1)
  resSig <- resSig[order(resSig$padj), ]

  print(paste(treated, 'vs', control))
  summary(res)
  summary(resSig)

  # File stem is <treatment label><control condition>, e.g. 1913_saltcontrol.
  out_stem <- paste0(osa_path, treated, strsplit(control, "_")[[1]][2])
  write.table(as.data.frame(res), file = paste0(out_stem, '_res.tsv'),
              quote = FALSE, sep = '\t', col.names = NA)
  write.table(as.data.frame(resSig), file = paste0(out_stem, '_resSig.tsv'),
              quote = FALSE, sep = '\t', col.names = NA)
}

### 2.2 Supp. Fig 4: Comparison of DEGs between two controls

In [None]:
# For every stress, compare the DEG sets obtained against the two controls
# (controlD2 vs controlH2), write the per-gene status table and draw a Venn diagram.
wdir = dir_path + 'prep_files/mpo/deseq/'
deseqouts = [x for x in os.listdir(wdir) if "resSig.tsv" in x]
deseqouts.remove('controlD2controlH2_resSig.tsv')

# Create the output directories up front: the status tables are written here
# before the cell that otherwise creates 'sets/results/' has run.
set_out_dir = wdir + "sets/results/"
fig_out_dir = dir_path + "figures/"
os.makedirs(set_out_dir, exist_ok=True)
os.makedirs(fig_out_dir, exist_ok=True)

# create subplots
xlen = 4
ylen = math.ceil(len(deseqouts)/8)
figw = xlen * 4
figh = ylen * 2.5

# stress_list[k] is drawn in the mosaic panel labelled with the k-th capital letter.
stress_list = ['H', 'C', 'HM', 'CM', 'M', 'CL', 'ML', 'L', 'HS', 'CS', 'SM', 'LS', 'S',
               'HN', 'CN', 'MN', 'NL', 'SN', 'N', 'HD', 'CD', 'MD', 'SD', 'ND', 'D']
f_axes = string.ascii_uppercase[:len(stress_list)]
axd = plt.figure(constrained_layout=True,
                 figsize=(figw, figh)).subplot_mosaic(
                     """
                     A......
                     .B.....
                     CDE....
                     .FGH...
                     IJKLM..
                     NOPQRS.
                     TUV.WXY
                     """,
                     gridspec_kw = {'hspace' : 0.3}
                     )

for panel_idx, i in enumerate(stress_list):
    # After sorting, the controlD2 file comes before the controlH2 file.
    files = sorted(x for x in deseqouts if x.startswith(i + 'control'))
    if len(files) < 2:
        raise FileNotFoundError("Expected controlD2 and controlH2 resSig files for stress " + i)
    fileD2 = pd.read_csv(wdir + files[0], sep = "\t", header = 0, index_col = 0)
    fileH2 = pd.read_csv(wdir + files[1], sep = "\t", header = 0, index_col = 0)
    D2index, H2index = set(fileD2.index.tolist()), set(fileH2.index.tolist())

    # Create sets
    D2only = D2index - H2index
    H2only = H2index - D2index
    D2H2 = D2index & H2index

    # Create list of lists of all differentially expressed genes with corresponding status
    stress = "Mpo_" + i
    status = []
    for j in D2only:
        status.append([j, stress, 'D2', str(fileD2.loc[j, "log2FoldChange"]), str(fileD2.loc[j, "padj"]), "N/A", "N/A"])
    for k in H2only:
        status.append([k, stress, 'H2', "N/A", "N/A", str(fileH2.loc[k, "log2FoldChange"]), str(fileH2.loc[k, "padj"])])
    for m in D2H2:
        status.append([m, stress, 'D2H2', str(fileD2.loc[m, "log2FoldChange"]), str(fileD2.loc[m, "padj"]), str(fileH2.loc[m, "log2FoldChange"]), str(fileH2.loc[m, "padj"])])
    status.sort()

    # Output file with genes and status: D2, H2 or D2H2
    with open(set_out_dir + stress + ".txt", "w+") as filo:
        filo.write("gene\tstress\tstatus\tL2FC_D2\tpadj_D2\tL2FC_H2\tpadj_H2\n")
        for n in status:
            filo.write("\t".join(n) + "\n")

    # Venn diagram
    sp_ax = axd[f_axes[panel_idx]]
    venn2(subsets=(len(D2only), len(H2only), len(D2H2)),
          set_labels = ('D2', 'H2'),
          ax = sp_ax)
    sp_ax.set_title(i, size=14)

# Save once, after all panels are drawn (previously re-saved at 600 dpi on every iteration).
plt.savefig(fig_out_dir + "suppfig4.png", dpi = 600)

In [None]:
# Sort genes according to whether they are the same in both controls
# adapted from compile_sigGenes_phase1n2.py
deseqdir = dir_path + 'prep_files/mpo/deseq/'
setdir = deseqdir + 'sets/results/'
# Shell-escaped variant; '\\ ' because '\ ' is an invalid escape sequence in Python.
setdir_safe = setdir.replace(' ', '\\ ')
# Create the directory (and missing parents) without shelling out; no-op when it exists.
os.makedirs(setdir, exist_ok=True)

# Mercator annotation input and the two outputs:
#   resSig_compiled.txt -- genes whose direction agrees between the two controls
#   resSig_failed.txt   -- genes whose direction differs between the two controls
merfile = open(dir_path + 'mercator/MpoProt.results.txt', 'r')
ofile = open(deseqdir + 'resSig_compiled.txt', 'w+')
efile = open(deseqdir + 'resSig_failed.txt', 'w+')
setfiles = [x for x in os.listdir(setdir) if '.txt' in x]

header = "\t".join(['gene', 'stress', 'L2FC_D2', 'L2FC_H2', 'annotation']) + "\n"
ofile.write(header)
efile.write(header)

def get_anno(gene):
    """Return the annotation stored in the global `meranno` dict for `gene`.

    The lookup key is the lower-cased gene identifier; a KeyError is raised
    when the gene has no entry.
    """
    return meranno[gene.lower()]

def up_down(val):
    """Convert a log2 fold change into a direction label.

    Parameters
    ----------
    val : float
        log2 fold change.

    Returns
    -------
    str
        "DOWN" for negative values, "UP" for positive values, and "NaN" for
        values without a direction (NaN, or an exact 0 -- which previously
        raised UnboundLocalError because no branch assigned the result).
    """
    if val < 0:
        return "DOWN"
    if val > 0:
        return "UP"
    # NaN compares False against everything, so NaN and 0 both end up here.
    return "NaN"

# identifier -> list of [bincode, description] pairs parsed from the Mercator results
meranno = {}

for line in merfile:
    # Only complete 5-column records are used; single quotes are stripped from every field.
    fields = line.rstrip().replace("'", "").split("\t")
    if len(fields) == 5:
        bincode, name, identifier, desc, ptype = fields
        meranno.setdefault(identifier, []).append([bincode, desc])
merfile.close()

try:
    for i in setfiles:
        content = pd.read_csv(setdir + i, sep = "\t", header = 0)
        # Keep genes that are significant against both controls.
        # .copy() makes the columns assigned below belong to an independent frame.
        sigGenes = content[content['status'] == "D2H2"].copy()
        sigGenes['annotation'] = sigGenes['gene'].apply(get_anno)
        sigGenes['L2FC_D2'] = sigGenes['L2FC_D2'].apply(up_down)
        sigGenes['L2FC_H2'] = sigGenes['L2FC_H2'].apply(up_down)
        sigGenes = sigGenes.drop(columns = ['status', 'padj_D2', 'padj_H2'])
        for index, row in sigGenes.iterrows():
            record = "\t".join([str(z) for z in row]) + "\n"
            if row['L2FC_D2'] != row['L2FC_H2']:
                # Direction differs between the two controls.
                efile.write(record)
            else:
                ofile.write(record)
finally:
    # Close the outputs even if one of the set files cannot be processed.
    ofile.close()
    efile.close()

### 2.3 Diurnal gene expression

In [None]:
# Switch the notebook's working directory to the JTK folder: the R cell below
# loads its script and input tables through relative paths.
jtk_dir = dir_path_safe + "JTK/"
%cd $jtk_dir

In [None]:
%%R
# https://towardsdatascience.com/how-to-install-packages-in-r-google-colab-423e8928cd2e
#system(paste("cd", paste0(dir_path_safe, "JTK/")))
# Runs JTK_CYCLE on the diurnal expression matrix; all files are read from and
# written to the current working directory.
source("JTK_CYCLEv3.1.R")

project <- "Mpo_JTK"

options(stringsAsFactors=FALSE)
annot <- read.delim("annot_diur.txt")
data <- read.delim("expmat_diur.txt")

# The first column holds the row identifiers; move it into the row names.
rownames(data) <- data[,1]
data <- data[,-1]
jtkdist(6, 3)       # 6 total time points, 3 replicates per time point

periods <- 6:6       # only one period is tested: 6 time points per cycle (x 4 h spacing = 24 h)
jtk.init(periods,4)  # 4 is the number of hours between time points

cat("JTK analysis started on",date(),"\n")
flush.console()

st <- system.time({
  # jtkx() is called per row; JTK.ADJP, JTK.PERIOD, JTK.LAG and JTK.AMP are read after each call.
  res <- apply(data,1,function(z) {
    jtkx(z)
    c(JTK.ADJP,JTK.PERIOD,JTK.LAG,JTK.AMP)
  })
  res <- as.data.frame(t(res))
  # Benjamini-Hochberg adjustment of the first result column (ADJ.P).
  bhq <- p.adjust(unlist(res[,1]),"BH")
  res <- cbind(bhq,res)
  colnames(res) <- c("BH.Q","ADJ.P","PER","LAG","AMP")
  results <- cbind(annot,res,data)
  # Sort by ascending ADJ.P, ties broken by descending amplitude.
  results <- results[order(res$ADJ.P,-res$AMP),]
})
print(st)

save(results,file=paste("JTK",project,"rda",sep="."))
# TRUE/FALSE spelled out: T and F are ordinary variables that can be reassigned.
write.table(results,file=paste("JTK",project,"txt",sep="."),row.names=FALSE,col.names=TRUE,quote=FALSE,sep="\t")

In [None]:
# Return to the project directory (an earlier cell changed into its JTK/ subfolder).
%cd $dir_path_safe

In [None]:
# adapted from clean_Mpo.py
# To prepare and format Mpo JTK results to Camilla supp standard.
# NR -- not rhythmic: genes with ADJ.P >= 0.05
# NE -- not expressed: genes in the expression matrix that are absent from the JTK output
#       (per the original note: no TPM > 1 across all timepoints and replicates)

expanno = dir_path + "summary_files/diurnal_exp.txt"

# label conversion to experiment annotation
annodict = {}
with open(expanno, "r") as expannof:
    for line in expannof:
        label, actual = line.strip().split("\t")
        # New column name: annotation plus the second '_'-separated field of the label.
        annodict[label] = actual + '_' + label.split('_')[1]

diurlabels = ["gene"]
diurlabels.extend(list(annodict.keys()))

# select only diurnal experiments
diurexpmat = dir_path + 'prep_files/diurnal_exp.tsv'
diuronly = pd.read_csv(diurexpmat, sep='\t')
mpogenes = diuronly.gene.to_list()
diuronly.set_index("gene", inplace=True)
diuronly.columns = [annodict[x] for x in diuronly.columns.to_list()]

# prepping annotation file for JTK_Cycle/supp.
# identifier -> name of the first Mercator record seen for that identifier
meranno = {}
merp = dir_path + 'mercator/MpoProt.results.txt'
with open(merp, 'r') as merfile:
    merfile.readline()  # skip the header line
    for line in merfile:
        fields = line.rstrip().replace("'", "").split("\t")
        if len(fields) == 5:
            bincode, name, identifier, desc, ptype = fields
            meranno.setdefault(identifier, name)

# JTK output with expmat_diur
JTKout = pd.read_csv(dir_path + "JTK/JTK.Mpo_JTK.txt", sep = "\t")
JTKgenes = JTKout.Probe.to_list()
# Set lookup avoids scanning the whole JTK gene list once per gene.
jtk_gene_set = set(JTKgenes)
NEgenes = [x for x in mpogenes if x not in jtk_gene_set]

# for supp
colused = list(JTKout.columns)
colunwanted = ['BH.Q', 'PER','AMP']
for i in colunwanted:
    colused.remove(i)

# to format it to similar format of Camilla's supp material
# .copy(): columns of forsupp are overwritten later, so it must not be a view of JTKout.
forsupp = JTKout[colused].copy()
# Report genes that are not significantly rhythmic as "NR" instead of the raw JTK value
def NRcheck(num):
    """Format an adjusted p-value for the supplementary table.

    Returns "NR" when `num` is at least 0.05; otherwise the value in
    two-decimal scientific notation (e.g. "1.23E-02").
    """
    return "NR" if num >= 0.05 else f"{num:.2E}"
# Replace every ADJ.P value with its formatted string ("NR" or scientific notation).
forsupp['ADJ.P'] = forsupp['ADJ.P'].apply(NRcheck)

def phaseCheck(adjval, lagval):
    """Derive the reported phase from the JTK lag.

    Parameters
    ----------
    adjval : str
        Formatted ADJ.P value; "NR" marks a non-rhythmic gene.
    lagval : number
        LAG value from the JTK output.

    Returns
    -------
    "NR" for non-rhythmic genes; otherwise `lagval + 2`, reduced by 24 once
    when that sum is 24 or more.
    """
    if adjval == "NR":
        return "NR"
    shifted = lagval + 2
    return shifted - 24 if shifted >= 24 else shifted

# Convert LAG to the reported phase ("NR" for non-rhythmic genes).
forsupp["LAG"] = forsupp.apply(lambda row: phaseCheck(row["ADJ.P"], row["LAG"]), axis = 1)

# to format genes that are not expressed (NE) and excluded in JTK analysis to supp file
# Each row: annotation name, "NE", "NE", then the expression values of the sample columns.
NEcollect = {}
for j in NEgenes:
    NEcollect[j] = [meranno[j.lower()], "NE", "NE"] + diuronly.loc[j, colused[4:]].to_list()

# Genes are the dict keys, i.e. columns; transpose so that each gene becomes a row.
NEdf = pd.DataFrame(NEcollect, index = colused[1:])
NEdf = NEdf.transpose()
NEdf.reset_index(inplace=True)
NEdf.columns = colused
# combine formatted JTK output and NE genes
# pd.concat replaces DataFrame.append, which no longer exists in pandas 2.
combined = pd.concat([forsupp, NEdf], ignore_index = True)
combined.sort_values("Probe", inplace=True, ignore_index = True)
# write to directory and ready for use (for analysis)
cleaned = dir_path + "diurnal/"
os.makedirs(cleaned, exist_ok=True)  # make sure the output directory exists
combined.to_csv(cleaned + "Mpo_supp.txt", index = False, sep = "\t")

# 3. Analysis and plotting

### Figure 1 & Supp. Fig 1: Measurements and Student's t-test

In [None]:
# adapted from measurements_forsupp.py
wdir = dir_path + 'prep_files/'
odir = dir_path + 'figures/'
odir_safe = dir_path_safe + 'figures/'
# Create the figure directory without shelling out; no-op when it already exists.
os.makedirs(odir, exist_ok=True)
infile = 'phase1n2_measurements_nooutliers.txt'

measurements = pd.read_csv(wdir + infile, sep='\t')
# Single letter to full single stress description
singled = {'C':'Cold',
           'H':'Heat',
           'S':'Salt',
           'M':'Mannitol',
           'L':'Light',
           'D':'Dark',
           'N':'Nitrogen'
           }

# For plotting all controls
# areatype[k] is the area column plotted under the day label titletype[k].
areatype = ['Parea', 'Earea']
titletype = ['15', '21']
control_rows = measurements[measurements.Stress == 'Control']
for area, title in zip(areatype, titletype):
    per_batch = control_rows[["Batch", area]].groupby('Batch', sort = False)
    control_m = per_batch.mean()
    control_s = per_batch.std()
    # Symmetric error bars: one standard deviation per batch, matched to the bars by column name.
    # (The previous call passed a reversed copy of the values as a second row.)
    control_m.plot.bar(yerr=control_s, legend=False, title='Control (Day '+ title + ')', capsize=4)
    plt.savefig(odir + 'Control_Day' + title + '.png', dpi = 600, bbox_inches='tight')
    plt.show()

# df with only single stress measurements
# notna() keeps control rows out even when read_csv has parsed the literal "None" as a
# missing value (NOTE(review): newer pandas treats "None" as NA by default -- confirm).
ss_meas = measurements[measurements.Condition.notna()
                       & (measurements.Condition != 'None')
                       & (measurements.Condition != 'mixed')]
# df with only crossed stress measurements
cs_meas = measurements[measurements.Condition == 'mixed']
# df with only controls
control_meas = measurements[measurements.Stress == 'Control']

controltype = ['Stress']
controltitle = ['_merged']

xaxislabel = {'Heat': 'Temperature (\u00B0C)',
              'Cold': 'Temperature (\u00B0C)',
              'Mannitol': 'Mannitol (mM)',
              'Salt': 'NaCl (mM)',
              'Light': 'Light intensity (\u03bcEm\u207b\u00b2s\u207b\u00b9)',
              'Dark': 'Days',
              'Nitrogen': 'KNO\u2083 (%)'}

# Supp. Fig. 1
# t-test (control as b, following test, a)
# tout stays open: later parts of this cell keep writing to it and close it at the end.
tout = open(wdir + 'ttest.txt', 'w+')
from scipy import stats as st

for ss in list(ss_meas.Stress.unique()):
    stress_rows = ss_meas[ss_meas.Stress == ss]
    # Controls are restricted to the batches in which this stress was measured.
    control_batches = stress_rows.Batch.unique()
    control_rows = control_meas[control_meas.Batch.isin(control_batches)]
    for i, c in enumerate(controltype):
        # Aggregate only the area columns: mean()/std() over the remaining
        # string columns raises a TypeError in pandas 2.
        control_mean = control_rows.groupby(c, sort = False)[areatype].mean()
        control_std = control_rows.groupby(c, sort = False)[areatype].std()
        stress_mean = stress_rows.groupby('Condition', sort=False)[areatype].mean()
        stress_std = stress_rows.groupby('Condition', sort=False)[areatype].std()

        if c == 'Stress': # t-test
            control_df = control_rows[['Parea', 'Earea']]
            stress_conds = stress_rows.Condition.unique()
            for k, a in enumerate(areatype):
                for scond in stress_conds:
                    stress_df = stress_rows[stress_rows.Condition == scond]
                    tstat, pval = st.ttest_ind(stress_df[a], control_df[a])
                    tout.write(('\t').join(['Day '+ titletype[k], ss + '_' + scond, 'control_merged', str(tstat), str(pval)]) + "\n")

        labels = control_mean.index.to_list() + stress_mean.index.to_list()
        for j, a in enumerate(areatype): # Day 15 or 21 area
            coll_mean = list(control_mean[a]) + list(stress_mean[a])
            coll_std = list(control_std[a]) + list(stress_std[a])
            plt.bar(labels, coll_mean, yerr = coll_std, capsize=4)
            plt.title(ss + ' (Day ' + titletype[j] + ')')
            plt.xlabel(xaxislabel[ss])
            plt.ylabel('Area (mm\u00b2)')
            plt.savefig(odir + ss + '_Day' + titletype[j] + controltitle[i] + '.png', dpi = 600, bbox_inches='tight')
            plt.show()

# cross_stress plot
# Controls are restricted to the batches in which cross-stress samples were measured.
cs_control_batches = cs_meas.Batch.unique()
cs_control_rows = control_meas[control_meas.Batch.isin(cs_control_batches)]
for i, c in enumerate(controltype):
    # Aggregate only the area columns: mean()/std() over the remaining
    # string columns raises a TypeError in pandas 2.
    cs_control_mean = cs_control_rows.groupby(c, sort = False)[areatype].mean()
    cs_control_std = cs_control_rows.groupby(c, sort = False)[areatype].std()
    cs_stress_mean = cs_meas.groupby('Stress', sort=False)[areatype].mean()
    cs_stress_std = cs_meas.groupby('Stress', sort=False)[areatype].std()
    cs_labels = cs_control_mean.index.to_list() + cs_stress_mean.index.to_list()

    if c == 'Stress': #t-test
        control_df = cs_control_rows[['Parea', 'Earea']]
        for k, a in enumerate(areatype):
            for ss in list(cs_meas.Stress.unique()):
                stress_df = cs_meas[(cs_meas.Stress == ss)]
                tstat, pval = st.ttest_ind(stress_df[a], control_df[a])
                tout.write(('\t').join(['Day '+ titletype[k], ss + '_' + 'mixed', 'control_merged', str(tstat), str(pval)]) + "\n")

    for j, a in enumerate(areatype): # Day 15 or 21 area
        coll_mean = list(cs_control_mean[a]) + list(cs_stress_mean[a])
        coll_std = list(cs_control_std[a]) + list(cs_stress_std[a])
        plt.bar(cs_labels, coll_mean, yerr = coll_std, capsize=4)
        plt.title('Cross stress (Day ' + titletype[j] + ')')
        plt.xticks(rotation=90)
        plt.xlabel('Experiment')
        plt.ylabel('Area (mm\u00b2)')
        plt.savefig(odir + 'Cross_stress_Day' + titletype[j] + controltitle[i] + '.png', dpi = 600, bbox_inches='tight')
        plt.show()
			
# single stress reps and cross stress (control - merged)
def ss_grab(stress, condition):
    """
    Select the rows of the global `measurements` table for one stress at one
    condition.

    Parameters
    ----------
    stress : string
        Value to match in the Stress column.
    condition : string
        Value to match in the Condition column.

    Returns
    -------
    sssub : dataframe
        Rows where both columns match.

    """
    matches = (measurements.Stress == stress) & (measurements.Condition == condition)
    return measurements[matches]

# Representative condition of each single stress shown in Fig. 1
srep_keys = [['Cold', '3'],
             ['Heat', '33'],
             ['Salt', '40'],
             ['Mannitol', '100'],
             ['Light', '435'],
             ['Dark', '3'],
             ['Nitrogen', '0']]
# One concat over all (stress, condition) pairs, in the listed order.
srepdf = pd.concat([ss_grab(s, c) for s, c in srep_keys])

s_cs_meas = pd.concat([srepdf, cs_meas])

s_cs_control_batches = list(cs_meas.Batch.unique()) + list(srepdf.Batch.unique())
s_cs_control_rows = control_meas[control_meas.Batch.isin(s_cs_control_batches)]
# Aggregate only the area columns: mean()/std() over the remaining
# string columns raises a TypeError in pandas 2.
s_cs_control_mean = s_cs_control_rows.groupby('Stress', sort = False)[areatype].mean()
s_cs_control_std = s_cs_control_rows.groupby('Stress', sort = False)[areatype].std()
s_cs_stress_mean = s_cs_meas.groupby('Stress', sort=False)[areatype].mean()
s_cs_stress_std = s_cs_meas.groupby('Stress', sort=False)[areatype].std()

# Names longer than two characters (the single stresses) get their initial appended,
# e.g. 'Heat (H)'; shorter names are kept as they are.
s_cs_meas_label = [x + ' (' + x[0] + ')' if len(x) > 2 else x for x in s_cs_stress_mean.index]
s_cs_meas_label[s_cs_meas_label.index('Light (L)')] = 'High light (L)'
s_cs_meas_label[s_cs_meas_label.index('Dark (D)')] = 'Darkness (D)'
s_cs_labels = s_cs_control_mean.index.to_list() + s_cs_meas_label

# Fig1
#t-test
control_df = s_cs_control_rows[['Parea', 'Earea']]
## cross-stress
# Each cross stress is tested against each of its component single stresses;
# the cross-stress name is iterated one letter at a time (e.g. 'HS' -> 'H', 'S').
for k, a in enumerate(areatype):
    for ss in list(cs_meas.Stress.unique()):
        stress_df = cs_meas[(cs_meas.Stress == ss)]
        for singleS in ss:
            singlecontrol = srepdf[srepdf.Stress == singled[singleS]][a]
            tstat, pval = st.ttest_ind(stress_df[a], singlecontrol)
            tout.write(('\t').join(['Day '+ titletype[k], ss + '_' + 'mixed', 'control_' + singled[singleS], str(tstat), str(pval)]) + "\n")
## single stress
for k, a in enumerate(areatype):
    for ss in list(srepdf.Stress.unique()):
        cond = srepdf[(srepdf.Stress == ss)].Condition.unique()[0]
        stress_df = srepdf[(srepdf.Stress == ss)][a]
        tstat, pval = st.ttest_ind(stress_df, control_df[a])
        tout.write(('\t').join(['Day '+ titletype[k], ss + '_' + cond, 'control', str(tstat), str(pval)]) + "\n")

tout.close()
# plotting
# Bar colours: control, then the 7 single stresses, then the cross stresses.
colour_seq = ['tomato'] +  ['mediumseagreen']*7 + ['cornflowerblue']*20
for j, a in enumerate(areatype): # Day 15 or 21 area
    coll_mean = list(s_cs_control_mean[a]) + list(s_cs_stress_mean[a])
    coll_std = list(s_cs_control_std[a]) + list(s_cs_stress_std[a])
    plt.bar(s_cs_labels, coll_mean, yerr = coll_std, capsize=4, color = colour_seq)
    plt.title('Area (Day ' + titletype[j] + ')')
    plt.xticks(rotation=90)
    plt.xlabel('Experiment')
    plt.ylabel('Area (mm\u00b2)')
    plt.savefig(odir + 'fig1_Day' + titletype[j] + '.png', dpi = 600, bbox_inches='tight')
    plt.show()

### Figure 2: Interspecies comparison (Biological processes)


In [None]:
# adapted from cross_spe_mapman.py
### FUNCTION ###
def anno_split(row):
    """Return, as an int, the part before the first '.' of the first element
    of the first entry in row['annotation']."""
    first_code = row['annotation'][0][0]
    top_level, _, _ = first_code.partition(".")
    return int(top_level)
def bin_count(row):
    """Map row['rel_count'] onto one of four fixed levels.

    >= 0.5 gives 0.5, >= 0.25 gives 0.35, any other positive value gives 0.2,
    and everything else gives 0.
    """
    count = row['rel_count']
    for lower_bound, level in ((0.5, 0.5), (0.25, 0.35)):
        if count >= lower_bound:
            return level
    return 0.2 if count > 0.0 else 0
def label_color(xlabel):
    """Return the colour name for the first stress keyword contained in
    `xlabel`; labels without a known keyword get "slategrey"."""
    # Order matters: the first keyword found in the label decides the colour.
    keyword_colours = (
        ("heat", "firebrick"),
        ("cold", "steelblue"),
        ("light", "darkorange"),
        ("dark", "black"),
        ("salt", "rebeccapurple"),
        ("mannitol", "mediumvioletred"),
        ("nitrogen", "forestgreen"),
    )
    for keyword, colour in keyword_colours:
        if keyword in xlabel:
            return colour
    return "slategrey"

def species_color(xlabel):
    """Return the colour name for the first species code contained in
    `xlabel`, or None when the label contains no known code."""
    # Order matters: the first code found in the label decides the colour.
    species_colours = (
        ("Ath", "firebrick"),
        ("Cpa", "steelblue"),
        ("Cre", "darkorange"),
        ("Osa", "rebeccapurple"),
        ("Mpo", "forestgreen"),
    )
    for code, colour in species_colours:
        if code in xlabel:
            return colour
    return None

### PATHS ###
wdir = dir_path + 'prep_files/'
setres = wdir + 'Figure2_alldata_compiled_updated.txt'
jdir = wdir + 'proteomes/'
mdict = wdir + 'merdict.txt'

### DICTIONARY OF MERCATOR BINS ###
# The file holds a Python literal that is parsed with literal_eval.
with open(mdict, 'r') as mdict_file:
    dicto = literal_eval(mdict_file.read())

### LOAD GENE PER SPECIES AND COUNT
# initialise species
spedicto = {'ARATH' : 'Ath',
            'CHLRE' : 'Cre',
            'CYAPA' : 'Cpa',
            'MARPO' : 'Mpo',
            'ORYSA' : 'Osa'}
species_list = list(spedicto.values())

# initialise gene count in species [for % of DGEs]
Gdicto = {}
# {"species" : number of '>' header lines in its proteome file}
for pepfile in [x for x in os.listdir(jdir) if '.ini' not in x]:
    with open(jdir + pepfile, "r") as peppy:
        species, genes = pepfile.split('.fa')[0], []
        for lini in peppy:
            if '>' in lini:
                genes.append(lini.strip().split('>')[1])
        Gdicto[spedicto[species]] = len(genes)

# Ath gene-name conversion tables, both built in one pass over ARATH.fa:
#   athdict  -- lower-cased name -> original name (for the mercator output)
#   athdict2 -- capitalised name -> original name (for the DGE table)
athdict = {}
athdict2 = {}
with open(jdir + "ARATH.fa", "r") as athgenes:
    for lini in athgenes:
        if '>' in lini:
            genename = lini.strip().split('>')[1]
            athdict[genename.lower()] = genename
            athdict2[genename.lower().capitalize()] = genename

### LOAD SIGNIFICANTLY DIFFERENTIAL GENE TABLE ###
sigtable = pd.read_csv(setres, sep='\t', header=0, index_col=0)
sigtable = sigtable.reset_index()
# Assign the result back: an inplace replace on a selected column may act on a
# temporary copy and leave the table unchanged.
sigtable["gene"] = sigtable["gene"].replace(athdict2)
sigtable = sigtable.set_index("gene")

### DICTIONARY OF MERCATOR ANNOTATION ###
merdir = wdir + 'mercator_results/'
merlist = [x for x in os.listdir(merdir) if '.results.txt' in x]
# Read mercator annotations (list of lists) as lists instead of string
sigtable['annotation'] = sigtable['annotation'].apply(literal_eval)

# meranno:  identifier -> [species, [top-level bin codes]]
# map2anno: two-level bin code -> bin name
meranno = {}
map2anno = {}
for i in merlist:
    sp = i.split("Prot")[0]
    with open(merdir + i, 'r') as merfile:
        merfile.readline()  # skip the header line
        for line in merfile:
            linecon = line.rstrip().replace("'", "").split("\t")
            if len(linecon) == 5:
                bincode, name, identifier, desc, ptype = linecon
                if identifier not in meranno:
                    meranno[identifier] = [sp, [bincode.split('.')[0]]]
                else:
                    meranno[identifier][1].append(bincode.split('.')[0])
            if len(linecon[0].split('.')) == 2:
                map2anno[linecon[0]] = linecon[1]

merdf = pd.DataFrame.from_dict(meranno, orient = 'index', columns = ['species', 'code'])
merdf = merdf.reset_index()
# Same reason as above: assign the replaced column back instead of replacing inplace.
merdf["index"] = merdf["index"].replace(athdict)
merdf = merdf.set_index("index")

all_s = list(set(sigtable.stress.to_list()))

def bin_collate(updown):
    '''
    Collect, for every stress in the global `all_s`, the genes of `sigtable`
    regulated in the given direction and the top-level bins of their annotations.

    Parameters
    ----------
    updown : str
        Value to match in the L2FC_D2 column (the caller passes 'UP' or 'DOWN').

    Returns
    -------
    df_beans : dataframe
        One row per stress with two columns: 'gene' (list of gene ids) and
        'bins' (list of top-level bin codes, one per annotation entry).

    '''
    collated = {}
    for stress in all_s:
        selected = sigtable[(sigtable.stress == stress) & (sigtable.L2FC_D2 == updown)]
        genes = selected.index.to_list()
        # Flatten the annotation entries of all genes and keep the part of
        # each bin code before the first '.'.
        beans = [entry[0].split('.')[0]
                 for annotations in selected.annotation.to_list()
                 for entry in annotations]
        collated[stress] = [genes, beans]
    return pd.DataFrame.from_dict(collated, orient='index', columns=['gene', 'bins'])
	
def sig_df(df, sigcol, mapbins, simno=1000):
    """
    Calculate and correct mapman bin enrichment p-values for all stresses in
    the global `all_s`, using random gene sets drawn from the global `merdf`.

    Parameters
    ----------
    df : dataframe
        df containing genes and corresponding mapman bins of DEGs
        (indexed by stress, with a 'gene' column).
    sigcol : str
        column name to use for enrichment
    mapbins : list
        list of mapman annotation/bins to use
    simno : int, optional
        number of random simulations per stress (default 1000).

    Returns
    -------
    df_sig : dataframe
        df summarising enrichment (BH-corrected p-value) for each stress (row)
        and each mapman bin (column); NaN where the bin does not occur in
        the stress.

    """
    sig_sum = {}
    for s in all_s:
        s_count = Counter(df.loc[s][sigcol])
        valid_bins = list(s_count.keys()) # bins found in stress
        if not valid_bins:
            # No bins observed for this stress: nothing to test.
            sig_sum[s] = [np.nan] * len(mapbins)
            continue
        # Counts start at 1, so the smallest possible p-value is 1/simno.
        sig_count = dict.fromkeys(valid_bins, 1)
        # The species is the part of the stress label before the first '_'.
        sp = s.split('_')[0]
        n_genes = len(df.loc[s].gene)
        # The pool of per-gene bin lists does not change between simulations,
        # so it is built once and reshuffled in place.
        pool = merdf[merdf.species == sp].code.to_list()
        # random simulations: draw as many genes as there are DEGs
        for _ in range(simno):
            random.shuffle(pool)
            sub_count = Counter(y for x in pool[:n_genes] for y in x)
            for mapman in valid_bins:
                if sub_count[mapman] >= s_count[mapman]:
                    sig_count[mapman] += 1
        # p-value calculation, capped at 1 (the count can reach simno + 1)
        pval_coll = [min(sig_count[mapman] / simno, 1.0) for mapman in valid_bins]

        # BH correction for multiple testing
        corrected = multipletests(pvals=pval_coll, alpha=0.05, method="fdr_bh")[1]
        corrected_by_bin = dict(zip(valid_bins, corrected))
        sig_sum[s] = [corrected_by_bin.get(mapman, np.nan) for mapman in mapbins]

    df_sig = pd.DataFrame.from_dict(sig_sum, orient='index', columns=mapbins)
    return df_sig

def chunk(uval, dval):
    """Combine the up- and down-regulation p-values of one bin into a category.

    Returns 2 when both values are below 0.05, 3 when only `uval` is,
    1 when only `dval` is, and 0 otherwise. NaN never counts as below 0.05.
    """
    up_enriched = uval < 0.05
    down_enriched = dval < 0.05
    if up_enriched and down_enriched:
        # differentially up and downregulated in bin
        return 2
    if down_enriched:
        # differentially downregulated
        return 1
    if up_enriched:
        # differentially upregulated
        return 3
    # not differentially regulated, or not enriched
    return 0

# initialisation for enrichment
#mapbins = list(map2anno.keys()) #level 2
mapbins = [str(bin_id) for bin_id in dicto] # level 1

# segregate stress and associated genes and mapman bins into up and downregulated df respectively
df_U = bin_collate('UP')
df_D = bin_collate('DOWN')

# df of significance values
df_sig_U = sig_df(df_U, 'bins', mapbins)
df_sig_D = sig_df(df_D, 'bins', mapbins)

# One list of categories per mapman bin, ordered like the stresses in df_sig_U.
cat_dict = {
    mapman: [chunk(df_sig_U.loc[stress, mapman], df_sig_D.loc[stress, mapman])
             for stress in df_sig_U.index]
    for mapman in df_sig_U.columns
}

df_combined_sig = pd.DataFrame.from_dict(cat_dict, orient='index', columns=list(df_sig_U.index))
# remove bins w/o enrichment
has_enrichment = df_combined_sig.max(axis=1) > 0
df_combined_sig = df_combined_sig.loc[has_enrichment, :]
# keep bins enriched in more than 2 stresses
# NOTE(review): the original comment said "at least 2", but the test is `> 2` -- confirm which is intended.
enriched_count = (df_combined_sig > 0).sum(axis=1)
df_combined_sig = df_combined_sig.loc[enriched_count > 2, :]
# remove stresses w/o enrichment
stress_has_enrichment = df_combined_sig.max() > 0
df_combined_sig = df_combined_sig.loc[:, stress_has_enrichment]
	
def ji_cal(a, b):
    """Jaccard index of two sets: size of the intersection over size of the union.

    Two empty sets are treated as identical (index 1.0) instead of raising
    ZeroDivisionError.
    """
    union = a | b
    if not union:
        return 1.0
    return len(a & b) / len(union)

def jdistprep(df, axis=0):
	'''
	Convert df to lists of "<label>_<value>" strings (for calculation of JD of X axis)

	Parameters
	----------
	df : dataframe
		dataframe of categorical variables to be converted.
	axis : int
		axis to do sets on, 0 by column (default), 1 by row

	Returns
	-------
	list
		[dicto, dxkeys]: dicto maps every column label (row label if
		axis == 1) to a list with one "<label on the other axis>_<value>"
		string per cell; dxkeys is the list of those labels in their
		original order.

	'''
	if axis == 1:
		df = df.T
	dxkeys = df.columns.to_list()
	# str() so that non-string index labels can be joined with '_' as well
	dykeys = [str(label) for label in df.index.to_list()]
	dicto = {}
	for col in dxkeys:
		dicto[col] = [label + '_' + str(val) for label, val in zip(dykeys, df[col].to_list())]
	return [dicto, dxkeys]

def jdist(df, axis=0):
	'''
	Construct jaccard distance square matrix and cluster it (single linkage)

	Parameters
	----------
	df : df
		dataframe to be used for jiprep/ jdist calculation.
	axis : int
		axis to do sets on, 0 by column (default), 1 by row

	Returns
	-------
	linkage_matrix : ndarray
		single-linkage matrix computed from the condensed jaccard distances.
	jlist : list
		list of list (jaccard distance square matrix)
	dicto : dict
		dictionary of list ("<label>_<value>" strings as built by jdistprep)

	'''
	dicto, dxkeys = jdistprep(df, axis)
	# Only non-zero categories take part in the comparison. The category is the
	# text after the LAST underscore: the label in front of it may itself
	# contain underscores (e.g. a "<species>_<stress>" name), in which case
	# split('_')[1] would read a piece of the label instead of the category.
	nonzero = {}
	for key in dxkeys:
		nonzero[key] = set([x for x in dicto[key] if x.rsplit('_', 1)[1] != '0'])
	jlist = []
	for key in dxkeys:
		col = []
		for key2 in dxkeys:
			col.append(1 - ji_cal(nonzero[key], nonzero[key2]))
		jlist.append(col)
	# squareform turns the square matrix into the condensed form linkage expects
	dists = squareform(jlist)
	linkage_matrix = linkage(dists, "single")
	return linkage_matrix, jlist, dicto

def plot_dendro(linkage_matrix, ax, orient):
	'''
	Draws an unlabelled dendrogram into a subplot

	Parameters
	----------
	linkage_matrix : array
		linkage matrix to draw (first value returned by jdist).
	ax : axes
		axis of subplot to plot to.
	orient : str
		orientation of dendrogram to be plotted.

	Returns
	-------
	None.

	'''
	# fixed look: no leaf labels, a colour threshold of 0 and black as the above-threshold colour
	look = {'no_labels': True,
			'color_threshold': 0,
			'above_threshold_color': '#000000'}
	dendrogram(linkage_matrix, ax=ax, orientation=orient, **look)

# cluster the columns (axis=0, default) and the rows (axis=1) of the category table
xmat, xlist, xdict = jdist(df_combined_sig)
ymat, ylist , ydict = jdist(df_combined_sig, axis=1)

# labelled dendrograms, drawn on their own; their 'ivl' entries give the leaf order
yden = dendrogram(ymat, labels=df_combined_sig.index.to_list(), orientation='left')
plt.show()
xden = dendrogram(xmat, labels=df_combined_sig.columns.to_list(), orientation='top')
plt.show()

yorder = yden['ivl']
xorder = xden['ivl']
# reorder the table to follow the dendrograms: columns in leaf order, rows in reversed leaf order
df_sig_reordered = df_combined_sig[xorder]
df_sig_reordered = df_sig_reordered.reindex(yorder[::-1])

"""
Custom plot
"""

# Create plot with subplot
# 6 rows x 2 columns; ax.flatten() below walks the grid row by row, so the
# right-hand column holds ax_2, ax_4, ax_6, ax_8, ax2 and ax4 (top to bottom)
fig, ax = plt.subplots(6,2, constrained_layout=True,
					   figsize=(16.3, 18), # (width, height)
					   gridspec_kw={'width_ratios': [1, 8.3],
					 'height_ratios': [0.5,0.5,0.5,0.5,1,5.3]})
plt.rcParams['font.size'] = '16'
ax_1, ax_2, ax_3, ax_4, ax_5, ax_6, ax_7, ax_8, ax1, ax2, ax3, ax4 = ax.flatten()

# note: this loop rebinds the name `ax` (until here the array of axes) to single axes
for ax in [ax_1, ax_2, ax_3, ax_4, ax_5, ax_6, ax_7, ax_8, ax1, ax2, ax3, ax4]:
	ax.tick_params(axis='both', which='major', labelsize=16)

ax_1.axis('off') # empty
#ax_2.axis('off') # DGE % of genes
ax_3.axis('off') # empty
#ax_4.axis('off') # DGE % of TFs
ax_5.axis('off') # empty
#ax_6.axis('off') #  DGE % of kinases
ax_7.axis('off') # empty
#ax_8.axis('off') # %DGE up/down reg
ax1.axis('off') # cbar
ax2.axis('off') # dendrogram drawn on top (xmat, clustering of the heatmap columns)
ax3.axis('off') # dendrogram drawn on the left (ymat, clustering of the heatmap rows)

# ax_2 DGE % of genes
### STATISTICS OF DGEs ###
# per stress: number of rows in sigtable divided by Gdicto[<species prefix of the stress name>]
# NOTE(review): Gdicto must hold a number per species here (presumably the gene count) - confirm,
# a later cell rebinds Gdicto to lists of gene names
DGEperdict = {}
for sx in sigtable.stress.unique():
	DGEperdict[sx] = (len(sigtable[sigtable.stress == sx])/Gdicto[sx.split("_")[0]])*100
DGEper = pd.DataFrame(DGEperdict, index = ["DGEs"])
DGEper = DGEper.transpose()
# complement to 100 %, needed for the stacked bar
DGEper["NotDGE"] = DGEper.apply(lambda row: 100 - row, axis=0)
# same stress order as the heatmap columns
DGEper = DGEper.reindex(xorder)

DGEper.plot.bar(stacked = True,
				color = {"DGEs":"firebrick", "NotDGE":"darkgrey"},
				edgecolor = "black",
				ylim = [0, 50],
				ax = ax_2)
handles, labels = ax_2.get_legend_handles_labels()
# drop the last legend entry (the grey complement)
ax_2.legend(handles=handles[:-1], labels=labels[:-1],
          loc='center left', bbox_to_anchor=(1, 0.5))
ax_2.set_ylabel("% genes", rotation = 90, fontsize=16)
ax_2.yaxis.set_label_coords(-0.06,0.36)
ax_2.axes.get_xaxis().set_visible(False)

# ax_4 DGE % of DGEs that are TFs
tfdir = dir_path + 'tf_kinases/'
tfpaths = [x for x in os.listdir(tfdir) if ".ini" not in x and ".TF." in x]
# {gene : annotation} for every gene whose annotation is not "NoFunction"
tfdict = {}

for file in tfpaths:
	# lookup kept from the original: raises KeyError if the file prefix is not a key of spedicto
	tempspe = spedicto[file.split(".")[0]]
	# context manager so that the annotation file is closed after reading
	with open(tfdir + file, "r") as content:
		for line in content:
			gene, anno = line.strip().split("\t")
			if anno!= "NoFunction":
				tfdict[gene] = anno
	
sigmod = sigtable.reset_index()
# vectorised membership test (replaces a row-wise apply with the same result)
is_tf = sigmod["gene"].isin(list(tfdict))
# DataFrame.append was removed in pandas 2.0: collect the subsets, concatenate once
tfsubsets = []
for sx in sigtable.stress.unique():
	tfsubset = sigmod[(sigmod.stress == sx) & is_tf]
	tfsubsets.append(tfsubset)
if tfsubsets:
	tfdf = pd.concat(tfsubsets, ignore_index = True)
else:
	tfdf = pd.DataFrame(columns = list(sigmod.columns))
	
# per stress: percentage of the DEGs that are annotated TFs
TFperdict = {}
for sx in sigtable.stress.unique():
	TFperdict[sx] = (len(tfdf[tfdf.stress == sx])/len(sigmod[sigmod.stress == sx]))*100
TFper = pd.DataFrame(TFperdict, index = ["TFs"])
TFper = TFper.transpose()
# complement to 100 %, needed for the stacked bar
TFper["NotTFs"] = TFper.apply(lambda row: 100 - row, axis=0)
# same stress order as the heatmap columns
TFper = TFper.reindex(xorder)

TFper.plot.bar(stacked = True,
				color = {"TFs":"forestgreen", "NotTFs":"darkgrey"},
				edgecolor = "black",
				ax = ax_4,
				ylim = [0,25])
TFhandles, TFlabels = ax_4.get_legend_handles_labels()
# drop the last legend entry (the grey complement)
ax_4.legend(handles=TFhandles[:-1], labels=TFlabels[:-1],
          loc='center left', bbox_to_anchor=(1, 0.5))
ax_4.set_ylabel("% DEGs", rotation = 90, fontsize=16)
ax_4.yaxis.set_label_coords(-0.06,0.36)
ax_4.axes.get_xaxis().set_visible(False)

# ax_6 DGE % of DGEs that are kinases
kindir = dir_path + 'tf_kinases/'
kinpaths = [x for x in os.listdir(kindir) if ".ini" not in x and ".kinases." in x]
# {gene : [anno, anno1, anno2]} for every gene whose first annotation is not "NoFunction"
kindict = {}

for file in kinpaths:
	# lookup kept from the original: raises KeyError if the file prefix is not a key of spedicto
	tempspe = spedicto[file.split(".")[0]]
	# context manager so that the annotation file is closed after reading
	with open(kindir + file, "r") as content:
		for line in content:
			gene, anno, anno1, anno2 = line.strip().split("\t")
			if anno!= "NoFunction":
				kindict[gene] = [anno, anno1, anno2]
	
# vectorised membership tests (replace the row-wise applies with the same result);
# the Ath stresses are matched on the upper-cased gene name, the others on the name as is
kin_genes = list(kindict)
is_kin = sigmod["gene"].isin(kin_genes)
is_kin_upper = sigmod["gene"].str.upper().isin(kin_genes)
# DataFrame.append was removed in pandas 2.0: collect the subsets, concatenate once
kinsubsets = []
for sx in sigtable.stress.unique():
	if sx.split("_")[0] == "Ath":
		kinsubset = sigmod[(sigmod.stress == sx) & is_kin_upper]
	else:
		kinsubset = sigmod[(sigmod.stress == sx) & is_kin]
	kinsubsets.append(kinsubset)
if kinsubsets:
	kindf = pd.concat(kinsubsets, ignore_index = True)
else:
	kindf = pd.DataFrame(columns = list(sigmod.columns))
	
# per stress: percentage of the DEGs that are annotated kinases
kinperdict = {}
for sx in sigtable.stress.unique():
	kinperdict[sx] = (len(kindf[kindf.stress == sx])/len(sigmod[sigmod.stress == sx]))*100
kinper = pd.DataFrame(kinperdict, index = ["kinases"])
kinper = kinper.transpose()
# complement to 100 %, needed for the stacked bar
kinper["NotKinases"] = kinper.apply(lambda row: 100 - row, axis=0)
# same stress order as the heatmap columns
kinper = kinper.reindex(xorder)

kinper.plot.bar(stacked = True,
				color = {"kinases":"darkgoldenrod", "NotKinases":"darkgrey"},
				edgecolor = "black",
				ax = ax_6,
				ylim = [0,25])
kinhandles, kinlabels = ax_6.get_legend_handles_labels()
# drop the last legend entry (the grey complement)
ax_6.legend(handles=kinhandles[:-1], labels=kinlabels[:-1],
          loc='center left', bbox_to_anchor=(1, 0.5))
ax_6.set_ylabel("% DEGs", rotation = 90, fontsize=16)
ax_6.yaxis.set_label_coords(-0.06,0.36)
ax_6.axes.get_xaxis().set_visible(False)

# ax_8 DGE, up/down ratio
# per stress: [% DOWN, % UP] among the rows of sigtable labelled "UP" or "DOWN" in L2FC_D2
uddict = {}
for sx in sigtable.stress.unique():
	upcount = len(sigtable[(sigtable.stress == sx) & (sigtable.L2FC_D2 == "UP")])
	downcount = len(sigtable[(sigtable.stress == sx) & (sigtable.L2FC_D2 == "DOWN")])
	total = upcount + downcount
	uddict[sx] = [downcount/(total)*100, upcount/(total)*100]
udper = pd.DataFrame(uddict, index = ["downregulated", "upregulated"])
udper = udper.transpose()
# same stress order as the heatmap columns
udper = udper.reindex(xorder)
udper.plot.bar(stacked = True,
				color = {"upregulated":"firebrick", "downregulated":"navy"},
				edgecolor = "white",
				yticks = [0, 50, 100],
				ax = ax_8)
dgehandles, dgelabels = ax_8.get_legend_handles_labels()
# reversed so that the legend lists "upregulated" first
ax_8.legend(handles=dgehandles[::-1], labels=dgelabels[::-1],
			loc='center left', bbox_to_anchor=(1, 0.5))
ax_8.set_ylabel("% DEGs", rotation = 90, fontsize=16)
ax_8.yaxis.set_label_coords(-0.06,0.36)
ax_8.axes.get_xaxis().set_visible(False)

# ax2/3 Dendrogram

plot_dendro(xmat, ax2, 'top')
plot_dendro(ymat, ax3, 'left')

# ax4 Heatmap
from matplotlib.colors import ListedColormap
# one colour per category, listed in the order of the category values 0, 1, 2, 3
# NOTE(review): imshow scales colours to the min/max of the data by default, so the
# colour/category pairing presumably relies on both 0 and 3 being present - confirm
cmap = ListedColormap(["lightgray", "royalblue", "violet", "firebrick"])
catno = 4

hplot = ax4.imshow(df_sig_reordered, cmap=cmap)
ax4.yaxis.tick_right()
ax4.set_ylabel("")
# one tick per column / row of the heatmap
ax4.set_xticks(np.arange(0, len(df_sig_reordered.columns), 1))
ax4.set_yticks(np.arange(0, len(df_sig_reordered), 1))

xcolour = [species_color(x) for x in xorder]
# column labels: long stress names replaced by their one-letter code, '_' by a space
newlabel = df_sig_reordered.columns.to_list()
longlabel = ['heat', 'cold', 'light', 'dark', 'salt', 'mannitol', 'nitrogen']
shortlabel = ['H', 'C', 'L', 'D', 'S', 'M', 'N']
for i, y in enumerate(longlabel):
	newlabel = [x.replace(y, shortlabel[i]) for x in newlabel]
newlabel = [x.replace('_', ' ') for x in newlabel]
ax4.set_xticklabels(newlabel, rotation=90, fontsize=18)

# colour each column label with the colour species_color returned for that column
for i, tick_label in enumerate(ax4.get_xticklabels()):
	tick_text = tick_label.get_text()
	tick_label.set_color(xcolour[i])

# anno_long is not referenced again in this cell
anno_long = ['annotated', 'cellulose', 'biosynthesis', 'hemicellulose', 'pectin', 'channels', 'degradation']
# row labels: the dicto entry of each bin number
ax4.set_yticklabels([dicto[int(x)] for x in df_sig_reordered.index.to_list()], fontsize=18)

# colourbar
# tick positions at the centre of each of the catno equal bands between 0 and catno-1
cbarticks = [(x/(catno*2))*(catno-1) for x in range(1,catno*2,2)]
axins = inset_axes(ax1,
					width="40%",
					height="90%", 
					loc = 'center')
cbar = fig.colorbar(hplot, cax=axins, ticks = cbarticks)
# labels for categories 0-3 (cf. chunk): none, down, up and down, up
cbar.ax.set_yticklabels(['N', 'D', 'UD', 'U'])

#plt.tight_layout()
plt.savefig(dir_path + 'figures/fig2.png',
				dpi=600,
				bbox_inches='tight')

# reset rcparams
import matplotlib as mpl
mpl.rcParams.update(mpl.rcParamsDefault)

### Supp. Fig 5: Interspecies comparison (Gene families)

In [None]:
wdir = dir_path + 'prep_files/'
jdir = wdir + 'proteomes/'
OFfile = 'Orthogroups.txt'
DGEfile = wdir + 'Figure2_alldata_compiled_updated.txt'
jhdir = wdir + 'interspeGF/'
jhdir_safe = dir_path_safe + 'prep_files/interspeGF/'
if not os.path.exists(jhdir):
	!mkdir $jhdir_safe

# initialise species
spedicto = {'ARATH' : 'Ath',
			'CHLRE' : 'Cre',
			'CYAPA' : 'Cpa',
			'MARPO' : 'Mpo',
			'ORYSA' : 'Osa'}
species_list = list(spedicto.values())
		
# initialise genes in species
Gdicto = {}
# {"species" : ["gene1", "gene2"...]}
for pepfile in [x for x in os.listdir(jdir) if '.ini' not in x]:
	with open(jdir + pepfile, "r") as peppy:
		species, genes = pepfile.split('.fa')[0], []
		for lini in peppy:
			if '>' in lini:
				genes.append(lini.strip().split('>')[1])
		Gdicto[spedicto[species]] = genes

# initialise orthofinder groups and corresponding genes by species
OFdicto = {}
OG_list = []
#{"OGx_spe":["gene1", "gene2"]}
with open(wdir+OFfile, "r") as content:
	for line in content:
		og, val = line.split(': ')[0], line.rstrip().split(': ')[1]
		OG_list.append(og)
		for spe in spedicto.values():
			OFdicto[og + "_" + spe] = [x for x in val.split(' ') if x in Gdicto[spe]]

        
# initialise DGEs
DGEdicto = {}
# {"gene1_stress" : "UP/DOWN"}
spe_stress = []
with open(DGEfile, "r") as dgecon:
	dgecon.readline()
	for lino in dgecon:
		gene, stress, L2FC_D2, L2FC_H2, annotation = lino.strip().split("\t")
		# to account for difference in gene names in this file and in Orthogroups.txt and ARATH.fa
		if stress.split("_")[0] == "Ath":
			gene = gene.upper()
		elif stress.split("_")[0] == "Cre":
			gene = gene.split(".t")[0]
		DGEdicto[gene +  '_' + stress] = L2FC_D2
		if stress not in spe_stress:
			spe_stress.append(stress)

# Functions
def get_ortho_status(og, spe, stress):
	"""
	Summarise the response of the genes of one species in an orthogroup to a stress

	Parameters
	----------
	og : str
		orthogroup name.
	spe : str
		species code (3 letters).
	stress : str
		type of stress.

	Returns
	-------
	status : str
		"UP": at least one gene UP, none DOWN
		"DOWN": at least one gene DOWN, none UP
		"AMB": ambiguous, UP and DOWN detected
		"NC" : all genes present have no significant DGEs
		"None" (a string): no gene of the species present in orthogroup
		(should DGEdicto hold labels other than "UP"/"DOWN", the list of
		distinct labels can be returned as is)

	"""
	members = OFdicto[og + "_" + spe]
	if len(members) == 0:
		return "None" # no gene in the orthogroup
	# DGEdicto keys look like gene + "_" + spe + "_" + stress; genes without an entry are "NC"
	suffix = "_" + spe + "_" + stress
	status = list(set(DGEdicto.get(gene + suffix, "NC") for gene in members))
	if len(status) > 1: # more than 1 type
		if "UP" in status and "DOWN" in status:
			return "AMB"
		if "UP" in status:
			return "UP"
		if "DOWN" in status:
			return "DOWN"
		return status
	if status[0] in ("NC", "DOWN", "UP"): # only 1 type
		return status[0]
	return status

# {spe_stress item : [status of every orthogroup, in the order of OG_list]}
all_statuses = {}
for item in spe_stress:
	# split on the first '_' only: species code in front, stress type (which may contain '_') behind
	species, stress_type = item.split("_", 1)
	# container for collecting all the orthogroup status for item in spe_stress
	spe_stress_stat = []
	for orthogroup in OG_list:
		spe_stress_stat.append(get_ortho_status(orthogroup, species, stress_type))
	all_statuses[item] = spe_stress_stat

# Initialize dataframe for OG (rows)/spe_stress (cols)
OGstats = pd.DataFrame(all_statuses, index = OG_list, columns = spe_stress)
# written out as a tab-separated file in the interspeGF folder
OGstats.to_csv(jhdir + 'OGstats.txt', sep="\t")

# Calculate Jaccard distance

# Functions
def cal_jd(cond1, cond2):
	"""
	Score one orthogroup for a pair of experiments

	Parameters
	----------
	cond1 : str
		OG status in the first experiment
	cond2 : str
		OG status in the second experiment

	Returns
	-------
	score : int/None
		None: "None" found in either status, or both are "NC" (not scored)
		0: statuses do not match
		1: statuses match
	"""
	# 1 or more OG absent
	if "None" in cond1 or "None" in cond2:
		return None
	# status do not match
	if cond1 != cond2:
		return 0
	# from here on both statuses are equal: a shared "NC" is not scored
	return None if cond1 == "NC" else 1

# {experiment : [distance to every experiment]} and {experiment : [number of OGs scored]}
jd_dict = {}
jd_counts = {}
for i in range(len(spe_stress)):
	# container for jaccard distances
	jd_con = []
	# container for counts of OG per comparison
	count_con = []
	for j in range(len(spe_stress)):
		if spe_stress[i] == spe_stress[j]:
			# an experiment compared with itself: distance 0
			jd_con.append(0)
			# reflects the number of OG != "None"
			count_con.append(sum(OGstats[spe_stress[i]] != "None"))
		else:
			interdf = OGstats[[spe_stress[i], spe_stress[j]]]
			# cal_jd gives 1 (match), 0 (mismatch) or None (not scored) for every OG
			scoreseries = interdf.apply(lambda row: cal_jd(row[spe_stress[i]], row[spe_stress[j]]), axis=1)
			# distance = 1 - matches / scored OGs (count() leaves out the OGs that were not scored)
			jd_con.append(1 - (scoreseries.sum()/scoreseries.count()))
			count_con.append(scoreseries.count())
	jd_dict[spe_stress[i]] = jd_con
	jd_counts[spe_stress[i]] = count_con
	
# Initialize dataframe for spe_stress (rows)/spe_stress (cols) [jaccard distance]
JDstats = pd.DataFrame(jd_dict, index = spe_stress, columns = spe_stress)
JDstats.to_csv(jhdir + 'JDstats.txt', sep="\t")
# Initialize dataframe for spe_stress (rows)/ spe_stress (cols)
# [OG count used to calculate jaccard distance]
JDcounts = pd.DataFrame(jd_counts, index = spe_stress, columns = spe_stress)
JDcounts.to_csv(jhdir + 'JDcounts.txt', sep="\t")

# clustermap -- distance matrix

def label_color(xlabel):
	'''
	Colour for a tick label, chosen by the stress named in it

	Parameters
	----------
	xlabel : str
		label text; checked for a long stress name or a 'Mpo_<letter>' code.

	Returns
	-------
	str
		colour name; "slategrey" when no stress is recognised.

	'''
	# checked from top to bottom, the first hit decides
	palette = (("heat", 'Mpo_H', "firebrick"),
			   ("cold", 'Mpo_C', "steelblue"),
			   ("light", 'Mpo_L', "darkorange"),
			   ("dark", 'Mpo_D', "black"),
			   ("salt", 'Mpo_S', "rebeccapurple"),
			   ("mannitol", 'Mpo_M', "mediumvioletred"),
			   ("nitrogen", 'Mpo_N', "forestgreen"))
	for longname, mpocode, colour in palette:
		if longname in xlabel or mpocode in xlabel:
			return colour
	return "slategrey"

import scipy.spatial as sp, scipy.cluster.hierarchy as hc

# species prefix of every experiment, mapped to one colour letter per species
colnames = JDstats.index.to_series().apply(lambda row: row.split("_")[0])
specoldict = dict(zip(colnames.unique(), "rgbcy"))
specol = colnames.map(specoldict)

# Deliberately NOT stored under the name `linkage`: that would rebind the
# global used by jdist() to call scipy's linkage() function, and every later
# call of jdist() in the session would fail.
jd_linkage = hc.linkage(sp.distance.squareform(JDstats), method='single')
# same linkage for rows and columns, so both axes end up in the same order
g = sns.clustermap(JDstats,
				   row_linkage = jd_linkage,
				   col_linkage = jd_linkage,
				   row_colors = specol,
				   xticklabels = True,
				   yticklabels=True)
# tick labels: long stress names replaced by their one-letter code, '_' by a space
newlabel = [x.get_text() for x in g.ax_heatmap.axes.get_xticklabels()]
longlabel = ['heat', 'cold', 'light', 'dark', 'salt', 'mannitol', 'nitrogen']
shortlabel = ['H', 'C', 'L', 'D', 'S', 'M', 'N']
for i, y in enumerate(longlabel):
	newlabel = [x.replace(y, shortlabel[i]) for x in newlabel]
newlabel = [x.replace('_', ' ') for x in newlabel]
g.ax_heatmap.axes.set_xticklabels(newlabel, rotation=90, fontsize=16)

# colour the row labels by stress, using the original (long) label text
for tick_label in g.ax_heatmap.axes.get_yticklabels():
	tick_text = tick_label.get_text()
	tick_label.set_color(label_color(tick_text))
g.ax_heatmap.axes.set_yticklabels(newlabel, fontsize=16)

plt.savefig(dir_path + 'figures/suppfig5.png')

### Figure 3: Stress responsiveness

In [None]:
# adapted from stress_res_og.py
OFpath = dir_path + 'prep_files/Orthogroups.txt'
jhdir = dir_path + 'prep_files/interspeGF/'
# {start of gene ID : species code}, used by spe_finder; two prefixes map to 'Osa'
prefix = {'Cpa|' : 'Cpa',
		  'Cre' : 'Cre',
		  'Mp' : 'Mpo',
		  'ChrUn' : 'Osa',
		  'LOC_Os' : 'Osa',
		  'AT' : 'Ath'}
# species codes in the order of prefix ('Osa' is listed twice)
spelist = list(prefix.values())

def spe_finder(gene):
	'''
	Look up the species of a gene from the start of its ID

	Parameters
	----------
	gene : str
		Gene ID.

	Returns
	-------
	spestat : str
		Species code of the prefix (key of prefix) the gene ID starts with,
		'Other' if no prefix fits.

	'''
	hits = [spe for start, spe in prefix.items() if gene.startswith(start)]
	# if several prefixes fit, the one listed last in prefix decides
	return hits[-1] if hits else 'Other'

# order of species
spe_order = ['Cpa', 'Cre', 'Mpo', 'Osa', 'Ath']
# [class name, member species] pairs; every class contains the members of the one before it
spe_class = [
	['Angiosperm', ['Osa', 'Ath']],
	['Embryophyte', ['Mpo', 'Osa', 'Ath']],
	['Viridiplantae', spe_order[1:]],
	['Archaeplastida', spe_order],
]

def speclass(spelist):
	'''
	Find the class of spe_class that covers a group of species

	Parameters
	----------
	spelist : list
		species codes; codes found in no class are ignored.

	Returns
	-------
	str
		name of the class with the highest position in spe_class among the
		first classes that contain each species.

	'''
	ranks = []
	for spe in spelist:
		# position of the first class that lists the species, if any
		rank = next((i for i, c in enumerate(spe_class) if spe in c[1]), None)
		if rank is not None:
			ranks.append(rank)
	return spe_class[max(ranks)][0]

# initialise OG stats df
# keep_default_na=False: the status "None" is a real value in this table. Recent
# pandas versions list "None" among the strings read_csv turns into NaN by
# default, which would break every later comparison with 'None'.
OGstats = pd.read_csv(jhdir + 'OGstats.txt', sep="\t", index_col=0, keep_default_na=False)
# shorten the column names: long stress names replaced by their one-letter code
# ('mannitol' and 'drought' both become 'M')
newlabel = OGstats.columns.to_list()
longlabel = ['heat', 'cold', 'light', 'dark', 'salt', 'mannitol', 'nitrogen', 'drought']
shortlabel = ['H', 'C', 'L', 'D', 'S', 'M', 'N', 'M']
for i, y in enumerate(longlabel):
	newlabel = [x.replace(y, shortlabel[i]) for x in newlabel]
OGstats.columns = newlabel


# initialise DGEs
DGEfile = dir_path + 'prep_files/Figure2_alldata_compiled_updated.txt'
# DGElist: all gene names of the file (turned into a set below)
DGElist = []
# {"gene1" : [distinct bin numbers of the gene]}
DGEbins = {}
# spe_stress is reset here and not filled in this block
spe_stress = []
with open(DGEfile, "r") as dgecon:
	dgecon.readline() # skip the first line
	for lino in dgecon:
		gene, stress, L2FC_D2, L2FC_H2, annotation = lino.strip().split("\t")
		# to account for difference in gene names in this file and in Orthogroups.txt and ARATH.fa
		if stress.split("_")[0] == "Ath":
			gene = gene.upper()
		elif stress.split("_")[0] == "Cre":
			gene = gene.split(".t")[0]
		DGElist.append(gene)
		# annotation is parsed as a Python literal; of every entry, the part of its first
		# element in front of the first '.' is kept as an int
		# NOTE(review): presumably the top-level Mapman bin number - confirm against the file
		binlist = [int(x[0].replace("'", "").split('.')[0]) for x in literal_eval(annotation)]
		# a gene listed under several stresses keeps the bins of its last line
		DGEbins[gene] = list(set(binlist))
DGElist = set(DGElist)

# {OG : [genes of the OG that are in DGElist]}
og_genes = {}
spespec_og = {} # list of species specific OGs, excludes OGs of species not included in analysis
other_og = {} # # list of non-species specific OGs, excludes OGs of species not included in analysis
with open(OFpath, 'r') as OFfile:
	for line in OFfile:
		og, val = line.split(': ')[0], line.rstrip().split(': ')[1]
		# distinct species (or 'Other') of the genes in the orthogroup
		og_species = list(set([spe_finder(x) for x in val.split(' ')]))
		if len(og_species) == 1 and og_species[0] in spelist:
			# all genes from one analysed species: {OG : species code}
			spespec_og[og] = og_species[0]
			og_genes[og] = list(set(val.split(' ')) & DGElist)
		elif len(og_species) > 1:
			# genes from several species: {OG : class name returned by speclass}
			ogclass = speclass(og_species)
			other_og[og] = ogclass
			og_genes[og] = list(set(val.split(' ')) & DGElist)


# =============================================================================
# 
# Stress-responsive OGs
# 
# =============================================================================
from collections import Counter, defaultdict

# phylostrata columns of the count tables built by counts_per_stress
df_coln = ['Archaeplastida', 'Viridiplantae', 'Embryophyte', 'Angiosperm']

# omit OGs that contain only 'None' across all stresses
oglist = OGstats.index.to_list()
suboglist = [x for x in oglist if set(OGstats.loc[x].to_list()) != {'None'}] 
# kept OGs that are species specific, and all other kept OGs (both sorted by name)
val_spespec_og = list(set(spespec_og) & set(suboglist))
val_spespec_og.sort()
val_other_og = list(set(suboglist) - set(val_spespec_og))
val_other_og.sort()

# df that contains only OGs that are not made up of 'None'
subdf = OGstats.loc[suboglist]

# dictionary to contain all statuses
# {OG : ["<stress>_<status>", ...]}, filled by update_stress_stat
og_stress_stat = defaultdict(list)

def update_stress_stat(newstat, phyla, dfcount, og, stresstype):
	'''
	Record the status of one orthogroup for one stress

	Parameters
	----------
	newstat : str
		status to record (row label of dfcount).
	phyla : str
		column label of dfcount to count the status under.
	dfcount : dataframe
		table of counts, changed in place.
	og : str
		orthogroup name.
	stresstype : str
		stress label put in front of the status.

	Returns
	-------
	None.

	'''
	# one more orthogroup with this status in this column
	dfcount.loc[newstat, phyla] = dfcount.loc[newstat, phyla] + 1
	# remembered per orthogroup as "<stress>_<status>" in the global og_stress_stat
	og_stress_stat[og].append('_'.join([stresstype, newstat]))
	
# to get counts per stress
stresslist = shortlabel[:-1]
def counts_per_stress(slabel):
	'''
	Count the orthogroups (OGs) per response status for one stress

	Every OG that gets a status is also recorded in the global
	og_stress_stat (through update_stress_stat).

	Parameters
	----------
	slabel : str
		short stress label; the experiments of the stress are the OGstats
		columns that contain '_' + slabel.

	Returns
	-------
	dataframe
		counts; rows are the phylostrata of df_coln followed by the species
		that have the stress, columns are 'UP', 'DOWN', 'AMB', 'MIXED', 'NR'.

	'''
	stresstype = '_' + slabel
	valid_exp = [x for x in OGstats.columns.to_list() if stresstype in x]
	valid_spe = [x.split('_')[0] for x in valid_exp]
	# experiments and species sorted according to spe_order
	ordered_exp = [y for x in spe_order for y in valid_exp if x in y]
	unique_spe = [x for x in spe_order if x in valid_spe]
	stresssub = subdf[ordered_exp] # df containing only exps of required stress type
	
	# all counters start at 0: one row per status, one column per phylostratum/species
	dumdict = {'UP': [0 for x in range(4 + len(unique_spe))],
			   'DOWN': [0 for x in range(4 + len(unique_spe))],
			   'AMB': [0 for x in range(4 + len(unique_spe))],
			   'MIXED': [0 for x in range(4 + len(unique_spe))],
			   'NR': [0 for x in range(4 + len(unique_spe))]}
	
	dfcount = pd.DataFrame.from_dict(dumdict, orient='index', columns = df_coln + unique_spe)
	for og, ogclass in spespec_og.items(): # species specific OGs
		kcount = Counter(stresssub.loc[og].to_list())
		if kcount['None'] != len(valid_spe): # ignore if OG not valid for stress
			if slabel == 'S' and ogclass == 'Osa':
				# NOTE(review): special case, presumably because Osa has two 'S' experiments - confirm
				if kcount['NC'] == 2: # for species specific OG that is 'NC' [Osa]
					update_stress_stat('NR', ogclass, dfcount, og, slabel)
				else: # kcount['NC']!= 2
					# FIX: the original condition was `x != 'None' and 'NC'`, where the
					# bare string 'NC' is always true, so 'NC' was never filtered out and
					# an OG that is e.g. UP in one experiment and NC in the other was
					# counted as 'MIXED' instead of 'UP'
					stat = [x for x in kcount if x != 'None' and x != 'NC']
					if len(stat) == 0: # nothing but 'None'/'NC' left
						update_stress_stat('NR', ogclass, dfcount, og, slabel)
					elif len(stat) == 1: # only one type of UP/DOWN/AMB
						update_stress_stat(stat[0], ogclass, dfcount, og, slabel)
					else: # mixture of UP/DOWN/AMB
						update_stress_stat('MIXED', ogclass, dfcount, og, slabel)
			else:
				if 'NC' not in kcount: # for species specific OG that is not 'NC'
					# the status that occurs exactly once
					stat = [x for x in kcount if kcount[x] == 1]
					update_stress_stat(stat[0], ogclass, dfcount, og, slabel)
				else: # for species specific OG that is 'NC'
					update_stress_stat('NR', ogclass, dfcount, og, slabel)
	
	for og, ogclass in other_og.items(): # non-species specific OGs
		kcount = Counter(stresssub.loc[og].to_list())	
		if kcount['None'] != len(valid_spe): # ignore if OG not valid for stress
			notnil = [x for x in list(kcount.keys()) if x != 'None' and x != 'NC']
			if len(notnil) > 1: # contains combination of UP/DOWN/AMB
				update_stress_stat('MIXED', ogclass, dfcount, og, slabel)
			elif len(notnil) == 1: # contains only one type of status apart from 'NR'
				if kcount[notnil[0]] > 1: # if UP/DOWN/AMB appear more than once
					update_stress_stat(notnil[0], ogclass, dfcount, og, slabel)
				else: # if UP/DOWN/AMB only appear once
					update_stress_stat('NR', ogclass, dfcount, og, slabel)
			else: # OGs that only have 'None' and 'NC'	
				update_stress_stat('NR', ogclass, dfcount, og, slabel)
		else: # all 'None', meaning that OG is valid in other species not present in this analysis
			update_stress_stat('NR', ogclass, dfcount, og, slabel)
	return dfcount.T

# one count table per stress, in the order of stresslist
dH, dC, dL, dD, dS, dM, dN = [counts_per_stress(x) for x in stresslist]
dflist = dH, dC, dL, dD, dS, dM, dN
wdir = dir_path + 'phylostrata/'
wdir_safe = dir_path_safe + 'phylostrata/'
# create the output folder if it does not exist yet
if not os.path.exists(wdir):
    !mkdir $wdir_safe
# written as <stress label>_df.txt (tab-separated)
for i, df in enumerate(dflist):
	df.to_csv(wdir + stresslist[i] + '_df.txt', sep="\t")

# =============================================================================
# 
# Quantifying stress responsiveness of Orthogroups
# 
# =============================================================================

# intermediate container for new df
# {OG : [phylostratum/species, statuses other than 'NR', number of those, genes of the OG that are DGEs]}
resog_dict = {}
# NOTE(review): the loop variable rebinds the global name `dicto`, which other cells use
# for a different dictionary - confirm nothing after this cell relies on the earlier value
for dicto in [other_og, spespec_og]:
	for og, phyla in dicto.items():
		reslist = [x for x in og_stress_stat[og] if 'NR' not in x]
		resog_dict[og] = [phyla, reslist, len(reslist), og_genes[og]]

resog_df = pd.DataFrame.from_dict(resog_dict, orient='index', columns=['Phylostrata', 'Responsive in', 'Count', 'Genes'])
resog_df.to_csv(wdir + 'resog_df.txt', sep="\t")

sorder = df_coln + spe_order
# number of OGs per phylostratum/species (rows) and number of stresses responded to (columns)
grouped_count = resog_df.groupby(['Phylostrata','Count']).count()['Genes'].unstack().reindex(sorder)
# the same table as percentage of each row total
grouped_per = grouped_count.copy()
for row in sorder:
	grouped_per.loc[row] = grouped_per.loc[row].apply(lambda x: (x/grouped_count.loc[row].sum())*100)

# 'Angiosperm' is left out of everything plotted or written below
sorder.remove('Angiosperm')

g = grouped_count.loc[sorder].plot.bar(stacked=True, ylabel = 'Count')
g.legend(bbox_to_anchor=(1, 0.75))

# only the columns for 1 to 7 stresses (selecting them requires all seven to exist)
g2 = grouped_per.loc[sorder,[i for i in range(1,8)]].plot.bar(stacked=True, ylabel='Percentage (%)')
g2.legend(bbox_to_anchor=(1, 0.75))

# Percentage of OGs from various phylostrata that are responsive in respective number of stresses (x-axis)
# each column is divided by its own total (taken over all rows of grouped_count)
grouped_per_bycount = grouped_count.copy()
for col in grouped_per_bycount.columns.to_list():
	grouped_per_bycount[col] = grouped_per_bycount[col].apply(lambda x: (x/grouped_count[col].sum())*100)
g3 = grouped_per_bycount.loc[sorder,[i for i in range(1,8)]].T.plot.bar(stacked=True, ylabel='Percentage (%)')
g3.legend(bbox_to_anchor=(1, 1))

# log y of number og OGs responsive in respective number of stresses (x axis)
g = grouped_count.loc[sorder].T.plot.bar(logy=True, ylabel = 'Number of OGs')
g.legend(bbox_to_anchor=(1, 0.75))

# counts with the number of stresses as rows, written to file
countbysres = grouped_count.loc[sorder].T
countbysres.to_csv(wdir + 'countbystressres.txt', sep='\t')

# Mapman bins
import seaborn as sns
import math
# merdict.txt holds a Python literal; read inside a context manager so the
# file is closed again (the bare open(...).read() left the handle open)
with open(dir_path + 'prep_files/merdict.txt', 'r') as merfile:
	merdict = literal_eval(merfile.read())
# By Phylo
catdict = {} # Mapman bin count for Phylostrata that are stress responsive (Count > 0)
for cat in sorder:
	# bins of every DGE of every responsive OG of the phylostratum/species, one entry per gene and bin
	catbins = [z for x in resog_df[(resog_df.Phylostrata == cat) & (resog_df.Count > 0)].Genes.to_list() for y in x for z in DGEbins[y]]
	catdict[cat] = Counter(catbins)

# rows: bins (sorted by number, then renamed through merdict), columns: sorder
catdf = pd.DataFrame.from_dict(catdict)
catdf.sort_index(inplace=True)
catdf.reset_index(inplace=True)
catdf.columns = ['Mapman bins'] + sorder
catdf['Mapman bins'] = catdf['Mapman bins'].apply(lambda x: merdict[x])
catdf.set_index('Mapman bins', inplace=True)

# column-wise percentages and their log2
catperdf = catdf.copy()
catperlogdf = catdf.copy()
for x in catperdf.columns.to_list():
	total = catdf[x].sum()
	catperdf[x] = catdf[x].apply(lambda x: (x/total)*100)
	catperlogdf[x] = catdf[x].apply(lambda x: math.log((x/total)*100,2))

# missing bins become 0 % (the log table keeps its NaN)
catperdf.fillna(float(0), inplace=True)

f = sns.clustermap(catperdf, yticklabels=True, col_cluster=False) # to get linkage for logged values (percentages can be filled 0 but cannot fill NaN with 0 for logged values)
row_linkage = f.dendrogram_row.linkage

sns.clustermap(catperlogdf, yticklabels=True, col_cluster=False, row_linkage = row_linkage, cmap='coolwarm')

# By Stress responsiveness
countdict = {} # Mapman bin count per number of stresses (1 to 7) the OGs are responsive in
for count in [i for i in range(1,8)]:
	# bins of every DGE of every OG responsive in exactly `count` stresses
	countbins = [z for x in resog_df[(resog_df.Count == count)].Genes.to_list() for y in x for z in DGEbins[y]]
	countdict[str(count)] = Counter(countbins)
# rows: bins (sorted by number, then renamed through merdict), columns: 1 to 7
countdf = pd.DataFrame.from_dict(countdict)
countdf.sort_index(inplace=True)
countdf.reset_index(inplace=True)
countdf.columns = ['Mapman bins'] + [i for i in range(1,8)]
countdf['Mapman bins'] = countdf['Mapman bins'].apply(lambda x: merdict[x])
countdf.set_index('Mapman bins', inplace=True)

# column normalised
countperdf = countdf.copy()
countperlogdf = countdf.copy()
for x in countperdf.columns.to_list():
	total = countdf[x].sum()
	countperdf[x] = countdf[x].apply(lambda x: (x/total)*100)
	countperlogdf[x] = countdf[x].apply(lambda x: math.log((x/total)*100,2))

countperdf.fillna(float(0), inplace=True)

f2 = sns.clustermap(countperdf, yticklabels=True, col_cluster=False, cmap='coolwarm') # to get linkage for logged values (percentages can be filled 0 but cannot fill NaN with 0 for logged values)
row_linkage2 = f2.dendrogram_row.linkage

sns.clustermap(countperlogdf, yticklabels=True, col_cluster=False, row_linkage = row_linkage2, cmap='coolwarm')

# row normalised
countper_rownorm_df = countdf.copy()
countper_rownorm_logdf = countdf.copy()
for x in countper_rownorm_df.index.to_list():
	total = countdf.loc[x].sum()
	countper_rownorm_df.loc[x] = countdf.loc[x].apply(lambda x: (x/total)*100)
	countper_rownorm_logdf.loc[x] = countdf.loc[x].apply(lambda x: math.log((x/total)*100, 2))

countper_rownorm_df.fillna(float(0), inplace=True)

f3 = sns.clustermap(countper_rownorm_df, yticklabels=True, col_cluster=False,cmap='coolwarm') # to get linkage for logged values (percentages can be filled 0 but cannot fill NaN with 0 for logged values)
row_linkage3 = f3.dendrogram_row.linkage

# only this last clustermap (row normalised, log2) is saved
sns.clustermap(countper_rownorm_logdf, yticklabels=True, col_cluster=False,
			   row_linkage = row_linkage3, cmap='coolwarm', figsize=(5,6))
plt.savefig(dir_path + 'figures/fig3a', dpi=600)

### Figure 4: Upset plot and summary of DEGs in Marchantia

In [None]:
# Fig 4A and B (adapted from DGE_count_sizecorr.py)
cross = pd.read_csv(dir_path + 'prep_files/mpo/deseq/resSig_compiled.txt', sep='\t')

# keep only the part of the stress name between the first and second '_'
cross.stress = [x.split('_')[1] for x in list(cross.stress)]
# stresses sorted by the length of their label
stress_l = list(cross.stress.unique())
stress_l.sort(key=lambda x: len(x))
# number of genes per stress and L2FC_D2 value
c_dgecount = cross.groupby(['stress', 'L2FC_D2']).count().gene.to_frame(name='count')

unstacked = c_dgecount.unstack().reindex(stress_l)
ax = unstacked.plot.bar(figsize=(7,3), stacked=True, ylabel='Number of DEGs', color=['navy', 'firebrick'])

handles, labels = ax.get_legend_handles_labels()
# legend in reversed order; of each "(count, <value>)" label only the <value> part is shown
ax.legend(handles=handles[::-1], labels=[x.split(', ')[1].split(')')[0] for x in labels][::-1])
plt.savefig(dir_path + 'figures/fig4a.png', dpi=600)

# measurements table, used by ss_grab and the size plots
wdir = dir_path + 'prep_files/'
infile = 'phase1n2_measurements_nooutliers.txt'

measurements = pd.read_csv(wdir + infile, sep='\t')

def ss_grab(stress, condition):
	"""
	Slice the rows of one stress/condition pair out of measurements

	Parameters
	----------
	stress : string
		Stress of interest.
	condition : string
		Condition of interest.

	Returns
	-------
	sssub : dataframe
		rows of measurements with that Stress and that Condition.

	"""
	stress_rows = measurements.Stress == stress
	condition_rows = measurements.Condition == condition
	return measurements[stress_rows & condition_rows]

# [stress, condition] pair that represents each single stress
srep_keys = [['Cold', '3'],
			 ['Heat', '33'],
			 ['Salt', '40'],
			 ['Mannitol', '100'],
			 ['Light', '435'],
			 ['Dark', '3'],
			 ['Nitrogen', '0']]
# start with the first pair, then add the others one by one
srepdf = ss_grab(srep_keys[0][0], srep_keys[0][1])

for s, c in srep_keys[1:]:
	srepdf = pd.concat([srepdf, ss_grab(s, c)])

# rows with Condition 'mixed' are added as they are
crepdf = measurements[measurements.Condition == 'mixed']
m_nocon = pd.concat([srepdf, crepdf])
# stress names longer than two characters are cut down to their first letter
m_nocon.Stress = [x[0] if len(x) > 2 else x for x in m_nocon.Stress]
noHL = m_nocon[m_nocon.Stress != 'HL']

# The two numeric columns are selected BEFORE averaging: from pandas 2.0 on,
# GroupBy.mean() raises on non-numeric columns (such as the string Condition
# column) instead of silently dropping them.
avg_meas = noHL.groupby('Stress')[['Parea', 'Earea']].mean()

totaldeg = cross.groupby(['stress']).count()['L2FC_D2']
# The assignment aligns totaldeg with avg_meas on the index labels (stress
# names), so no reordering is needed. The reindex(stress_l) calls that used to
# stand here threw their result away and had no effect.
avg_meas['totaldeg'] = totaldeg
avg_meas = avg_meas[['totaldeg', 'Parea', 'Earea']]

# size plots by df
ax = avg_meas.plot.scatter(x='totaldeg', y='Parea', color='orange', label='Area (Day 15)')

# write the stress label next to every point
for ind, dat in avg_meas.iterrows():
    ax.annotate(ind, (dat['totaldeg'], dat['Parea']),
				xytext=(-4,-12), textcoords='offset points')

avg_meas.plot.scatter(x='totaldeg', y='Earea', color='navy', label='Area (Day 21)', ax=ax)
for ind, dat in avg_meas.iterrows():
    ax.annotate(ind, (dat['totaldeg'], dat['Earea']),
				xytext=(-4,-12), textcoords='offset points')
	
# size plot with regression
from scipy import stats

# first column is the x variable, the remaining columns are plotted against it
q_colnames = avg_meas.columns.to_list()
# legend names of the y columns
dicto = {'Parea' : 'Day 15', 'Earea' : 'Day 21'}

def plot_reg(df, title):
	'''
	Scatter plot with a linear regression line per y column, saved as fig4b.png

	Uses the globals q_colnames (columns to plot) and dicto (legend names).

	Parameters
	----------
	df : dataframe
		data to plot; needs the columns listed in q_colnames.
	title : str
		accepted for compatibility with existing calls; not drawn on the figure.

	Returns
	-------
	None.

	'''
	handles = []
	labels = []
	col = q_colnames[0]
	for col2 in q_colnames[1:]:
		plt.scatter(col, col2, data=df)
		#m, c = np.polyfit(df[col], df[col2], 1)
		m, c, r_value, p_value, std_err = stats.linregress(df[col], df[col2])
		# annotate the points of the dataframe that is plotted (was the global avg_meas)
		for ind, dat in df.iterrows():
			plt.annotate(ind, (dat[col], dat[col2]),
			  xytext=(-4,-12), textcoords='offset points')
		regline, = plt.plot(df[col], m*df[col] + c)
		handles.append(regline)
		labels.append(dicto[col2] + ' (R\u00b2: ' + str(round(r_value**2,2)) + ', p: ' + str(round(p_value, 2)) + ')')
	# Handles are passed explicitly: with labels only, matplotlib pairs them with
	# the artists in an order that depends on its version, so a label could end
	# up on the wrong scatter/line.
	plt.legend(handles, labels)
	plt.xlabel('DEG count')
	plt.ylabel('Size (mm\u00b2)')
	plt.savefig(dir_path + 'figures/fig4b.png', dpi=600)
	
plot_reg(avg_meas, 'Number of DEGs vs Size')

In [None]:
# Fig 4C and D (adapted from upset.py)
import upsetplot
from collections import defaultdict

wdir = dir_path + 'prep_files/mpo/deseq/'
odir = wdir + 'upset/'
# shell-escaped spelling of odir, kept for any `!command` that needs it
odir_safe = dir_path_safe + 'prep_files/mpo/deseq/' + 'upset/'

# create the output folder (and any missing parents) without shelling out;
# exist_ok makes re-running the cell harmless
os.makedirs(odir, exist_ok=True)

# compiled table of significant DEGs (tab separated)
data = pd.read_csv(wdir + 'resSig_compiled.txt', sep = '\t')

### FUNCTIONS ###

def upset_matrix(set_dict, stress_types):
	"""Build the upsetplot membership table from the entries of set_dict whose key is listed in stress_types."""
	wanted = {}
	for name, members in set_dict.items():
		if name in stress_types:
			wanted[name] = members
	return upsetplot.from_contents(wanted)


def plot_selected(cond_dict, cond_list, filename, title, orient = "horizontal", cutoff=50):
	"""
	Write the gene-set intersections to text files and draw an UpSet plot of
	the largest ones.

	Parameters
	----------
	cond_dict : dict
		Maps a stress name to the set of gene ids found under it.
	cond_list : list
		Stress names to include.
	filename : str
		Output prefix. "<filename>_matrix.txt" lists every intersection and
		"<filename>_top<cutoff>.txt" the `cutoff` largest ones. A prefix
		containing "upreg" saves the figure as Fig4c, anything else as Fig4d.
	title : str
		Figure title.
	orient : str, optional
		Orientation passed to upsetplot.plot. The default is "horizontal".
	cutoff : int, optional
		Number of largest intersections to export and plot. The default is 50.

	Returns
	-------
	None.

	"""
	# imported locally because this cell only imports defaultdict
	from collections import Counter

	df_set = upset_matrix(cond_dict, cond_list)
	df_set = df_set.sort_index()

	# every membership combination with its gene count, largest first
	index_names = list(df_set.index.names)
	set_count = Counter(df_set.index.to_list())
	ranked = set_count.most_common()

	# writing output to file
	with open(filename + "_matrix.txt", "w+") as ofile:
		ofile.write("\t".join(index_names + ['count', 'genes']) + "\n")
		for combo, count in ranked:
			glist = df_set.loc[combo,:].id.to_list()
			ofile.write("\t".join([str(int(x)) for x in combo] + [str(count), str(glist)]) + "\n")

	# writing the `cutoff` largest combinations to file; the name follows
	# cutoff, so the default still produces "_top50.txt"
	set_cutoff = ranked[:cutoff]
	selection = [combo for combo, count in set_cutoff]
	with open(filename + "_top" + str(cutoff) + ".txt", "w+") as cfile:
		cfile.write("\t".join([str(index_names), 'count', 'genes']) + "\n")
		for combo, count in set_cutoff:
			glist = df_set.loc[combo,:].id.to_list()
			cfile.write("\t".join([str(combo), str(count), str(glist)]) + "\n")

	# selection for plotting: keep the rows that belong to a selected
	# combination (replaces summing DataFrames, which relied on index alignment)
	sel_matrix = df_set[df_set.index.isin(selection)]
	upsetplot.plot(sel_matrix, orientation = orient, sort_by = 'cardinality')
	plt.title(title, size=20)
	if "upreg" in filename:
		figname = 'c'
	else:
		figname = 'd'
	plt.savefig(dir_path + 'figures/Fig4' + figname + '.png', dpi=600)

### END ###

# Collect, per stress, the genes that are up- or downregulated under it
cond_dict_U = defaultdict(list)
cond_dict_D = defaultdict(list)
by_direction = {'UP': cond_dict_U, 'DOWN': cond_dict_D}
for stress_name, direction, gene_id in zip(data['stress'], data['L2FC_D2'], data['gene']):
	target = by_direction.get(direction)
	if target is not None:
		# the stress code is the part after the species prefix
		target[stress_name.split("_")[1]].append(gene_id)

all_stress_list = [x.split('_')[1] for x in data.stress.unique()]

# turn the gene lists into sets for each condition
cond_dict_U_set = {k: set(v) for k, v in cond_dict_U.items()}
cond_dict_D_set = {k: set(v) for k, v in cond_dict_D.items()}

# Plot horizontal (default): downregulated first, then upregulated
plot_selected(cond_dict = cond_dict_D_set,
			  cond_list = all_stress_list,
			  filename = odir + "all_downreg",
			  title = "Downregulated genes")

plot_selected(cond_dict = cond_dict_U_set,
			  cond_list = all_stress_list,
			  filename = odir + "all_upreg",
			  title = "Upregulated genes")

### Supp Figs 6 & 7, Figure 5: Inter-stress (Marchantia only) comparison

In [None]:
# Supp. figs 6 & 7 (adapted from indivenn_hm.py)
from collections import defaultdict
from matplotlib_venn import venn3

wdir = dir_path + 'prep_files/mpo/deseq/'
odir = dir_path + 'figures/'
data = pd.read_csv(wdir + 'resSig_compiled.txt', sep = '\t')

# Mercator bin conversion: the file holds a Python literal that is indexed
# below by the integer top-level bin number
with open(dir_path + 'prep_files/merdict.txt', 'r') as merdict_file:
	dicto = literal_eval(merdict_file.read())

# stress codes are the part after the species prefix; one-letter codes are
# single stresses, two-letter codes are combined stresses
# NOTE(review): `cross` and `dicto` replace objects of the same name used by the Fig 4 cell
all_s = [x.split("_")[1] for x in data.stress.unique()]
single = [x for x in all_s if len(x) == 1]
cross = [x for x in all_s if len(x) == 2]

# annotation is stored as text; parse it and look up the name of the
# top-level bin of the first annotation entry
data.annotation = data.annotation.apply(literal_eval)
data["mername"] = data.annotation.apply(lambda x: dicto[int(x[0][0].split('.')[0])])

dict_A = defaultdict(list)
dict_U = defaultdict(list)
dict_D = defaultdict(list)

def sum_to_dict(dicto, stress, reg):
	"""
	Append the gene set and the list of bin names of one stress to dicto[stress].

	reg is "UP" or "DOWN" to restrict to that L2FC_D2 value, or "ALL" for no
	restriction. Rows are taken from the module-level `data` table.
	"""
	keep = data.stress == "Mpo_" + stress
	if reg != "ALL":
		keep = keep & (data.L2FC_D2 == reg)
	subset = data[keep]
	dicto[stress].extend([set(subset.gene.to_list()), subset.mername.to_list()])

def dict_to_df(dicto):
	"""Tabulate dicto (stress -> [gene set, bin names]) with one row per stress."""
	return pd.DataFrame.from_dict(dicto, orient='index', columns=["gene", "mername"])

# fill the up/down dictionaries for every stress, then tabulate them
for s in all_s:
	#sum_to_dict(dict_A, s, "ALL")
	for target, direction in ((dict_U, "UP"), (dict_D, "DOWN")):
		sum_to_dict(target, s, direction)

#df_A = dict_to_df(dict_A)
df_U, df_D = dict_to_df(dict_U), dict_to_df(dict_D)

def plot_venn(df, s1, s2, c1, title, axis):
	"""Draw a three-set Venn diagram of the gene sets stored in df for s1, s2 and c1 into axis, with a title."""
	names = (s1, s2, c1)
	gene_sets = [df.loc[name].gene for name in names]
	venn3(gene_sets, names, ax = axis)
	axis.set_title(title, size=20)

# figure geometry: a grid four panels wide, with enough rows for one panel per stress
xlen = 4
ylen = math.ceil(len(all_s)/xlen)
figw, figh = xlen * 4, ylen * 3.5

# one uppercase panel label per stress
a_axes = string.ascii_uppercase[:len(all_s)]
def plot_subplot(df, title_ext):
	"""
	Draw one Venn panel per combined stress (the two single stresses against
	their combination) on an 18-slot mosaic and save it as
	'supp_fig6or7<title_ext>.png' in odir.
	"""
	canvas = plt.figure(constrained_layout=True, figsize=(figw, figh))
	axa = canvas.subplot_mosaic(
					 """
					 ABCD
					 EFGH
					 IJKL
					 MNOP
					 QR..
					 """
					 )

	# panels are filled in the order of `cross`, labels in alphabetical order
	for label, st in zip(a_axes, cross):
		plot_venn(df, st[0], st[1], st, st + title_ext, axa[label])
	plt.savefig(odir+'supp_fig6or7' + title_ext +'.png', dpi=600)


#df_col = [[df_A, ''], [df_U, "_upregulated"], [df_D, "_downregulated"]]
df_col = [[df_U, "_upregulated"], [df_D, "_downregulated"]]

# one multi-panel Venn figure per regulation direction
for df_type, ext in df_col:
	plot_subplot(df_type, ext)

In [None]:
# Figure 5 (adapted from plot_venn_sum.py)

from collections import defaultdict, Counter
from matplotlib_venn import venn3, venn3_circles
import random
from scipy import stats

wdir = dir_path + 'prep_files/mpo/deseq/'
odir = wdir + 'indivenn_hm/'
data = pd.read_csv(wdir + 'resSig_compiled.txt', sep = '\t')

# Mercator
### DICTIONARY OF MERCATOR BINS ###
mfile = dir_path + 'mercator/MpoProt.results.txt'

meranno = defaultdict(list)	# identifier -> names of its top-level bins
merbin = defaultdict(list)	# identifier -> its bin codes cut to two levels
map2anno = {}	# two-level bin code -> second column of that line

# NOTE(review): `dicto` must still be the Mercator bin dictionary loaded by the
# Supp. Fig 6/7 cell -- run that cell first
with open(mfile, 'r') as merfile:
	# the first line is read and discarded (presumably a header)
	merfile.readline()
	for line in merfile:
		linecon = line.rstrip().replace("'", "").split("\t")
		# only lines with all five fields describe an annotated identifier
		if len(linecon) == 5:
			bincode, name, identifier, desc, ptype = linecon
			meranno[identifier].append(dicto[int(bincode.split('.')[0])])
			merbin[identifier].append('.'.join(bincode.split('.')[:2]))
		# lines whose bin code has exactly two levels provide the bin name
		if len(linecon[0].split('.')) == 2:
			map2anno[linecon[0]] = linecon[1]


all_s = [x.split("_")[1] for x in data.stress.unique()]
single = [x.split("_")[1] for x in data.stress.unique() if len(x.split("_")[1]) == 1]
cross = [x.split("_")[1] for x in data.stress.unique() if len(x.split("_")[1]) == 2]

data.annotation = data.annotation.apply(literal_eval)
data["mername"] = data.annotation.apply(lambda x: [dicto[int(y[0].split('.')[0])] for y in x]) # different from cell above, hence the repetitive code

dict_A = defaultdict(list)
dict_U = defaultdict(list)
dict_D = defaultdict(list)

def sum_to_dict(dicto, stress, reg):
	"""
	Append three entries for one stress to dicto[stress]: the gene set, the
	flattened top-level bin names and the flattened two-level bin codes.

	reg is "UP" or "DOWN" to restrict to that L2FC_D2 value, or "ALL" for no
	restriction. Rows are taken from the module-level `data` table.
	"""
	keep = data.stress == "Mpo_" + stress
	if reg != "ALL":
		keep = keep & (data.L2FC_D2 == reg)
	subset = data[keep]

	top_bins = []
	for names in subset.mername.to_list():
		top_bins.extend(names)

	second_bins = []
	for annotation in subset.annotation.to_list():
		for entry in annotation:
			second_bins.append('.'.join(entry[0].split('.')[:2]))

	dicto[stress].extend([set(subset.gene.to_list()), top_bins, second_bins])

def dict_to_df(dicto):
	"""Tabulate dicto (stress -> [gene set, bin names, two-level bin codes]) with one row per stress."""
	return pd.DataFrame.from_dict(dicto, orient='index', columns=["gene", "mername", "mapbin2"])

# fill the all/up/down dictionaries for every stress, then tabulate them
for s in all_s:
	for target, direction in ((dict_A, "ALL"), (dict_U, "UP"), (dict_D, "DOWN")):
		sum_to_dict(target, s, direction)

df_A, df_U, df_D = dict_to_df(dict_A), dict_to_df(dict_U), dict_to_df(dict_D)

# =============================================================================
# 
# # Summary of stress response
# 
# =============================================================================

# Q1 : ji_cal(a, b) [fraction between 0 and 1]
def ji_cal(a, b):
	"""
	Jaccard index of two sets: size of the intersection over size of the union.

	Two empty sets are treated as identical and give 1.0 instead of raising
	ZeroDivisionError.
	"""
	union = a | b
	if not union:
		return 1.0
	return len(a & b) / len(union)
# Q2: (A − AB)/A − (B − AB)/B [signed difference; abs() is applied by the caller where needed]
def suppInX(a, b, ab):
	"""Fraction of a that is missing from ab, minus the fraction of b that is missing from ab."""
	lost_from_a = len(a - ab) / len(a)
	lost_from_b = len(b - ab) / len(b)
	return lost_from_a - lost_from_b

# Q3: (AB - A - B) / AB [fraction]
def novel(a, b, ab):
	"""Fraction of ab that is found in neither a nor b."""
	only_in_ab = (ab - a) - b
	return len(only_in_ab) / len(ab)

def q_col(df, colnames):
	"""
	Compute the three summary parameters for every combined stress.

	Parameters
	----------
	df : dataframe
		Table of gene sets to use (all genes, upregulated or downregulated
		only); its `gene` column is read for both single stresses and for the
		combined stress.
	colnames : list
		Column names of the returned dataframe, in the order
		ji_cal, suppInX, novel.

	Returns
	-------
	q_df : dataframe
		One row per combined stress in `cross`, one column per parameter.

	"""
	q_dict = {}
	for st in cross:
		# st is a two-letter code: its letters name the two single stresses
		a, b, ab = (df.loc[key].gene for key in (st[0], st[1], st))
		q_dict[st] = [ji_cal(a,b), suppInX(a,b,ab), novel(a,b,ab)]
	return pd.DataFrame.from_dict(q_dict, orient="index", columns = colnames)

q_colnames = ["similarity", "suppression", "novel interaction"]

# parameter tables for all, upregulated and downregulated DEGs
q_A = q_col(df_A, q_colnames)
q_U = q_col(df_U, q_colnames)
q_D = q_col(df_D, q_colnames)
qdf_col = [[q_U, "Upregulated DEGs"], [q_D, "Downregulated DEGs"]]

def plot_q_subplots(df, outerax):
	"""
	Draw each parameter column of df as a one-row heatmap.

	outerax lists the keys of the module-level `axe` mosaic to draw into, in
	the order of q_colnames. "suppression" uses the diverging 'coolwarm'
	palette, the other parameters use 'Blues'.
	"""
	panels = [axe[key] for key in outerax]
	for i, axis in enumerate(panels):
		name = q_colnames[i]
		palette = 'coolwarm' if name == "suppression" else 'Blues'
		strip = df[name].to_frame().transpose()
		sns.heatmap(strip, cmap=palette, ax=axis)
		axis.set_yticklabels([name], rotation=0)
		cbar = axis.collections[0].colorbar

		# colourbar ticks: step the rounded min upwards and the rounded max
		# downwards in 0.01 steps until each sits on a multiple of 0.05
		minval = round(df[name].min(),2)
		maxval = round(df[name].max(),2)
		while round(minval*100,2) % 5 != 0:
			minval += 0.01
		while round(maxval*100,2) % 5 != 0:
			maxval -= 0.01
		cbar.set_ticks([minval, maxval])

def plot_reg(df, outerax):
	"""
	Scatter every pair of parameter columns against each other with a
	least-squares regression line.

	Parameters
	----------
	df : dataframe
		Table holding the columns named in the module-level q_colnames.
	outerax : axes
		Axes to draw into.

	Returns
	-------
	statscol : list
		One [slope, intercept, r, p, stderr] entry per column pair, in
		plotting order.

	"""
	labels = []
	statscol = []
	for i, col in enumerate(q_colnames[:-1]):
		# pair each column only with the columns after it
		for col2 in q_colnames[i+1:]:
			outerax.scatter(col, col2, data=df)
			m, c, r_value, p_value, std_err = stats.linregress(df[col], df[col2])
			statscol.append([m, c, r_value, p_value, std_err])
			outerax.plot(df[col], m*df[col] + c)
			# raw string: "\m" is not a valid escape in a normal string literal
			labels.append(col[:3] + ' v ' + col2[:3] + r' ($\mathregular{R^{2}}$: ' + str(round(r_value**2,1)) + ', p: ' + '{:.2f}'.format(round(p_value,2)) + ')')
	outerax.legend(labels, fontsize="x-small")
	return statscol
	
def dum_venn(a_b, c_a, c_b, c_ab, ax, col, title, ac=20, bc=20, cc=20):
	'''
	Draw a schematic three-set Venn diagram (labelled A, B, AB) from dummy
	letters, colour selected regions and hide the region counts.

	Parameters
	----------
	a_b : int
		Size of A&B.
	c_a : int
		Size of A&C-B.
	c_b : int
		Size of B&C-A.
	c_ab : int
		Size of C&(A&B).
	ax : axis handle
		Axis handle of subplot to plot into.
	col : list
		List containing lists of patch id and corresponding colour.
	title : str
		Title shown above the diagram.
	ac : int, optional
		Size of set a. The default is 20.
	bc : int, optional
		Size of set b. The default is 20.
	cc : int, optional
		Size of set c. The default is 20.

	Returns
	-------
	None.

	'''
	# pool of 52 dummy elements; which letters end up in which set is random,
	# only the region sizes matter for the drawing
	dum = list(string.ascii_uppercase + string.ascii_lowercase)
	random.shuffle(dum)
	# honour ac (this used to be hard-coded to 20, the default)
	a = set(dum[:ac])
	b = set(list(a)[:a_b] + [x for x in dum if x not in a][:bc-a_b])
	c = set(list(a-b)[:c_a] +
		 list(b-a)[:c_b] + list(a&b)[:c_ab] +
		 [x for x in dum if x not in a and x not in b][:cc-c_a-c_b-c_ab])

	v = venn3([a, b, c],
	   ('A', 'B', 'AB'),
	   ax = ax)
	venn3_circles([a, b, c], linewidth=1, color='k', ax=ax)
	# skip regions for which no patch / label object is returned
	for entry in col:
		patch = v.get_patch_by_id(entry[0])
		if patch is not None:
			patch.set_color(entry[1])
	for label in v.subset_labels:
		if label is not None:
			label.set_visible(False)
	ax.set_title(title, fontsize=16)

# =============================================================================
# #
# # Initialising subplot
# #
# =============================================================================
#figsize=(figw, figh)

# Four stacked bands of axes share one figure. Each band is its own mosaic,
# placed through the bottom/top fractions of its gridspec; the bands overlap
# slightly, so the order in which they are created is kept as is.
top_mosaic = [["v1", "v2", "v3"]]
eq_mosaic = [	["e1", "e2", "e3"]	]
middle_mosaic = [
	["u1", "d1"],
	["u2", "d2"],
	["u3", "d3"]
]
bottom_mosaic = [["r1", "r2"]]

figw, figh = 11, 9
fig = plt.figure(figsize=(figw, figh))
# band 1: schematic Venn diagrams
axc = fig.subplot_mosaic(
	top_mosaic,
	gridspec_kw={
		"bottom": 0.75,
		"top": 1,
		#"wspace": 0.5,
		#"hspace": 0.5,
		}
	)
# band 2: formula text
axd = fig.subplot_mosaic(
	eq_mosaic,
	gridspec_kw={
		"bottom": 0.55,
		"top": 0.8,
		#"wspace": 0.5,
		#"hspace": 0.5,
		}
	)

# band 3: parameter heatmaps (plot_q_subplots looks its axes up in `axe`)
axe = fig.subplot_mosaic(
	middle_mosaic,
	gridspec_kw={
		"bottom": 0.38,
		"top": 0.6,
		#"wspace": 0.5,
		"hspace": 0.2,
		}
	)
# band 4: regression scatter plots
axf = fig.subplot_mosaic(
	bottom_mosaic,
	gridspec_kw={
		"bottom": 0,
		"top": 0.3,
		#"wspace": 0.5,
		#"hspace": 0.5,
		}
	)

# the formula axes only carry text; the Venn axes are anchored to the top of their cells
for axy in ['e1', 'e2', 'e3']:
	axd[axy].axis('off')
for axy in ["v1", "v2", "v3"]:
	axc[axy].set_anchor('N')

# NOTE(review): the middle formula is written differently from what suppInX
# computes ((A-AB)/A - (B-AB)/B) -- confirm which form is intended
axd['e1'].text(0.39, 0.45, r"$\frac{A \cap B}{A \cup B}$", fontsize=20)
axd['e2'].text(0.04, 0.45, r"$\frac{A-B-AB}{A}-\frac{B-A-AB}{B}$", fontsize=20)
axd['e3'].text(0.29, 0.45, r"$\frac{AB-A-B}{AB}$", fontsize=20)

# column 0 of middle_mosaic (u1-u3) receives the first table in qdf_col,
# column 1 (d1-d3) the second
for seq, x in enumerate(qdf_col):
	df_type, title = x[0], x[1]
	plot_q_subplots(df_type, [x[seq] for x in middle_mosaic])
# x tick labels are kept on the bottom row only, y tick labels on the left column only
for axy in ['u1', 'u2', 'd1', 'd2']:
	axe[axy].set_xticklabels([])
	axe[axy].xaxis.set_visible(False)
for axy in ['d1', 'd2', 'd3']:
	axe[axy].set_yticklabels([])
	axe[axy].yaxis.set_visible(False)

# region colours per schematic, as [patch id, colour] pairs passed to dum_venn
v1col = [['100', 'white'], ['110', 'limegreen'],
		 ['101', 'white'], ['111', 'limegreen'],
		 ['010', 'white'], ['011', 'white'],
		 ['001', 'white']]
v2col = [['100', 'red'], ['110', 'white'],
		 ['101', 'white'], ['111', 'white'],
		 ['010', 'cornflowerblue'], ['011', 'white'],
		 ['001', 'white']]
v3col = [['100', 'white'], ['110', 'white'],
		 ['101', 'white'], ['111', 'white'],
		 ['010', 'white'], ['011', 'white'],
		 ['001', 'darkorchid']]
dum_venn(a_b=8,c_a=4, c_b = 4,c_ab=3,
		 ax=axc['v1'], col=v1col, title='Similarity')
dum_venn(a_b=8, c_a=10, c_b=6, c_ab=3,
		 ax=axc['v2'], col=v2col, title='Suppression')
dum_venn(a_b=8, c_a=4, c_b=4, c_ab=3,
		 ax=axc['v3'], col=v3col, title='Novel interaction')

# regression panels: one per table in qdf_col, using the magnitude of suppression
reg_stats = []
for i, x in enumerate(qdf_col):
	df_type, title = x[0], x[1]
	# explicit copy so the abs() below never touches (or warns about) the original table
	df_abs = df_type.copy()
	df_abs["suppression"] = df_abs["suppression"].abs()
	reg_stats.append(plot_reg(df_abs, axf[bottom_mosaic[0][i]]))

plt.savefig(dir_path+'figures/Fig5A_E.png', dpi=600)

In [None]:
# Fig 5F refer to l2_en_jaccard_hm.png
# =============================================================================
# 
#  Summary: What is the dominant effect of each stress?
# 
# =============================================================================
def sum_df(df):
	"""
	Summarise, for every single stress, the three parameters of each combined
	stress it takes part in.

	Returns a dataframe indexed by the single stress followed by its partner
	(e.g. 'C' + 'D'), with columns stress, similarity, suppression, novel.
	Suppression is computed from the single stress against its partner.
	"""
	dicto_sum = {}
	for ss in single:
		# combined stresses containing ss, and the other letter of each
		ori = [pair for pair in cross if ss in pair]
		relcross = [letter for pair in ori for letter in pair if ss not in letter]
		for i, partner in enumerate(relcross):
			pair = ori[i]
			first, second, both = df.loc[pair[0]].gene, df.loc[pair[1]].gene, df.loc[pair].gene
			dicto_sum[ss+partner] = [
				ss,
				ji_cal(first, second),
				suppInX(df.loc[ss].gene, df.loc[partner].gene, both),
				novel(first, second, both),
			]

	return pd.DataFrame.from_dict(dicto_sum, orient='index', columns=['stress', 'similarity', 'suppression', 'novel'])

# parameter column names produced by sum_df (only referenced by the commented-out violin plot below)
cond=['similarity', 'suppression', 'novel']
# per-stress summaries: [upregulated, downregulated]
df_list = [sum_df(df_U), sum_df(df_D)]

# figs, axs = plt.subplots(3,2,
# 						 sharex=True,
# 						 sharey='row',
# 						 constrained_layout=True,
# 						 figsize=(7,6))

# for i, x in enumerate(df_list):
# 	for j, y in enumerate(cond):
# 		sns.violinplot(x='stress', y= y, data=x, ax= axs[j][i])
# for k in range(3):
# 	axs[k][1].set_ylabel('')
# for l in range(2):
# 	for m in range(2):
# 		axs[l][m].set_xlabel('')
# axs[0][0].set_title('Upregulated', fontsize=14)
# axs[0][1].set_title('Downregulated', fontsize=14)
# plt.savefig(odir + 'venn_sum.png', dpi=600)

# =============================================================================
# 
# Enrichment
# 
# =============================================================================

from statsmodels.stats.multitest import multipletests
import math
import numpy as np

# how often each bin name occurs across all identifiers in meranno
ori_count = Counter([y for x in list(meranno.values()) for y in x])
# all bin names held in the Mercator dictionary `dicto`
mapbins = list(dicto.values())

def sig_df(df, sigcol, merdict, mapbins, n_sim=1000):
	"""
	Calculates and corrects the mapman bin enrichment p-value for all stresses
	by random sampling. Returns a dataframe.

	Parameters
	----------
	df : dataframe
		df containing genes and corresponding mapman bins of DEGs.
	sigcol : str
		column name to use for enrichment
	merdict : dict
		corresponding dictionary of mapman annotation/ 2nd level bins to use
	mapbins : list
		list of mapman annotation/bins to use
	n_sim : int, optional
		Number of random gene draws per stress. The default is 1000.

	Returns
	-------
	df_sig : dataframe
		df summarising enrichment (BH-corrected p-value) for each stress (row)
		and each mapman bin (column); bins absent from a stress hold None.

	"""
	if n_sim < 1:
		raise ValueError("n_sim must be a positive integer")

	sig_sum = {}
	for s in all_s:
		s_count = Counter(df.loc[s][sigcol])
		valid_bins = list(s_count.keys()) # bins found in stress
		n_genes = len(df.loc[s].gene)
		# every bin starts at 1, so a p-value can never be exactly 0
		sig_count = dict.fromkeys(valid_bins, 1)
		# random simulations: draw as many annotated genes as the stress has
		# DEGs and count how often a bin is hit at least as often as observed
		for _ in range(n_sim):
			shuffle = list(merdict.values())
			random.shuffle(shuffle)
			sub_count = Counter([y for x in shuffle[:n_genes] for y in x])
			for mapman in valid_bins:
				if sub_count[mapman] >= s_count[mapman]:
					sig_count[mapman] += 1
		# p-value calculation; counts start at 1, so the ratio can reach
		# (n_sim + 1) / n_sim and is capped at 1
		pval_coll = [min(sig_count[mapman] / n_sim, 1.0) for mapman in valid_bins]

		# BH correction for multiple testing (skipped when the stress has no bins)
		if pval_coll:
			corrected = multipletests(pvals=pval_coll, alpha=0.05, method="fdr_bh")[1]
		else:
			corrected = []
		corr_by_bin = dict(zip(valid_bins, corrected))
		sig_sum[s] = [corr_by_bin.get(mapman) for mapman in mapbins]

	df_sig = pd.DataFrame.from_dict(sig_sum, orient='index', columns=mapbins)
	return df_sig
	
def chunk(uval, dval):
	"""
	Turn the corrected p-values of the up- and downregulated test of one bin
	into a category: 0 not enriched (or not tested), 1 enriched among
	downregulated genes only, 2 enriched in both, 3 enriched among upregulated
	genes only. A value counts as enriched when it is below 0.05; NaN never does.

	NOTE: a later cell defines another, one-argument `chunk` that replaces this one.
	"""
	up_enriched = bool(uval < 0.05)	# NaN compares False
	down_enriched = bool(dval < 0.05)
	if up_enriched and down_enriched:
		return 2
	if down_enriched:
		return 1
	if up_enriched:
		return 3
	return 0

from matplotlib.colors import ListedColormap
# one colour per category value 0..3 returned by chunk(uval, dval)
cmap = ListedColormap(["lightgray", "royalblue", "violet", "firebrick"])
catno = 4
# colourbar tick positions: centres of catno equal bands spanning 0..catno-1
cbarticks = [(x/(catno*2))*(catno-1) for x in range(1,catno*2,2)]

# =============================================================================
# 
# Enrichment (Part 2: 2nd level Mapman)
# 
# =============================================================================
# all two-level bin codes, ordered numerically by first then second level
second_level_codes = {code for codes in merbin.values() for code in codes}
mapbins2 = sorted(second_level_codes,
				  key=lambda code: (int(code.split('.')[0]), int(code.split('.')[1])))

# corrected enrichment p-values (upregulated first, then downregulated); None -> NaN
df_sig_U2 = sig_df(df_U, 'mapbin2', merbin, mapbins2).fillna(value=np.nan)
df_sig_D2 = sig_df(df_D, 'mapbin2', merbin, mapbins2).fillna(value=np.nan)

# category (see chunk) of every bin under every stress
stress_rows = list(df_sig_U2.index)
cat_dict2 = {
	mapman: [chunk(df_sig_U2.loc[stress, mapman], df_sig_D2.loc[stress, mapman])
			 for stress in stress_rows]
	for mapman in list(df_sig_U2.columns)
}

df_combined_sig2 = pd.DataFrame.from_dict(cat_dict2, orient='index', columns=stress_rows)
# keep bins enriched somewhere, then only those enriched in more than two stresses
df_combined_sig2 = df_combined_sig2.loc[df_combined_sig2.max(axis=1) > 0,:]
df_combined_sig2 = df_combined_sig2.loc[(df_combined_sig2 > 0).sum(axis=1) >2,:]
# replace the bin codes in the index by their names
df_combined_sig2 = df_combined_sig2.reset_index()
df_combined_sig2['index'] = df_combined_sig2['index'].apply(lambda x: map2anno[x])
df_combined_sig2 = df_combined_sig2.set_index('index')

# =============================================================================
# 
# Plotting 2nd level mapman enrichment (df_combined_sig2)
# 
# =============================================================================
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

def jdistprep(df, axis):
	'''
	Turn every column (or row) of df into a list of "<label>_<value>" strings,
	ready for set comparison.

	Parameters
	----------
	df : dataframe
		dataframe of categorical variables to be converted.
	axis : int
		0 to build one list per column, 1 to build one list per row.

	Returns
	-------
	list
		[dicto, dxkeys]: dicto maps each column (row) name to its list of
		strings, dxkeys lists those names in order.

	'''
	table = df.T if axis == 1 else df
	dxkeys = table.columns.to_list()
	dykeys = table.index.to_list()
	dicto = {}
	for col in dxkeys:
		values = table[col].to_list()
		dicto[col] = [label + '_' + str(value) for label, value in zip(dykeys, values)]
	return [dicto, dxkeys]

def jdist(df, axis=0):
	'''
	Cluster the columns (or rows) of df by Jaccard distance.

	Entries with category 0 are ignored, so two columns are compared on the
	set of "<label>_<category>" strings with a non-zero category.

	Parameters
	----------
	df : df
		dataframe to be used for jdistprep/ jdist calculation.
	axis : int
		axis to do sets on, 0 by column (default), 1 by row

	Returns
	-------
	linkage_matrix : ndarray
		single-linkage matrix computed from the Jaccard distances.
	jlist : list
		list of list (jaccard distance square matrix)
	dicto : dict
		dictionary of list, as returned by jdistprep

	'''
	dicto, dxkeys = jdistprep(df, axis)
	# the category is the text after the LAST underscore, so labels that
	# themselves contain underscores are handled correctly
	nonzero = {
		key: set(x for x in dicto[key] if x.rsplit('_', 1)[1] != '0')
		for key in dxkeys
	}
	jlist = [[1 - ji_cal(nonzero[key], nonzero[key2]) for key2 in dxkeys] for key in dxkeys]
	# squareform turns the square matrix into the condensed form linkage expects
	dists = squareform(jlist)
	linkage_matrix = linkage(dists, "single")
	return linkage_matrix, jlist, dicto

def plot_dendro(linkage_matrix, ax, orient):
	'''
	Plots an unlabelled, all-black dendrogram into a subplot.

	Parameters
	----------
	linkage_matrix : array
		Linkage matrix to draw (as returned by jdist).
	ax : axes
		axis of subplot to plot to.
	orient : str
		orientation of dendrogram to be plotted.

	Returns
	-------
	None.

	'''
	black = '#000000'
	# threshold 0 puts every link "above threshold", i.e. in the single colour
	dendrogram(linkage_matrix,
			   ax=ax,
			   orientation=orient,
			   no_labels=True,
			   color_threshold=0,
			   above_threshold_color=black)

# linkage over the columns (stresses) and over the rows (bins)
xmat, xlist, xdict = jdist(df_combined_sig2)
ymat, ylist , ydict = jdist(df_combined_sig2, axis=1)

# these two labelled dendrograms are drawn only to obtain the leaf order ('ivl')
yden = dendrogram(ymat, labels=df_combined_sig2.index.to_list(), orientation='left') #, color_threshold=0, above_threshold_color='#000000'
plt.show()
xden = dendrogram(xmat, labels=df_combined_sig2.columns.to_list(), orientation='top') #, color_threshold=0, above_threshold_color='#000000'
plt.show()

# reorder the table to follow the dendrogram leaves (row order reversed)
yorder = yden['ivl']
xorder = xden['ivl']
df_sig2_reordered = df_combined_sig2[xorder]
df_sig2_reordered = df_sig2_reordered.reindex(yorder[::-1])

# 2x2 grid: colourbar (top left), column dendrogram (top right),
# row dendrogram (bottom left), heatmap (bottom right)
fig, ax = plt.subplots(2,2,
					   figsize=(7.5,8.5), # (width, height)
					   constrained_layout=True,
					   gridspec_kw={'width_ratios': [1.5, 5],'height_ratios': [1, 5]}) # constrained_layout=True,
ax0, ax1, ax2, ax3 = ax.flatten()
for i in [ax0, ax1, ax2]:
	i.axis('off')

plot_dendro(xmat, ax1, 'top')
plot_dendro(ymat, ax2, 'left')
# heatmap, tick and tick labels
hplot = ax3.imshow(df_sig2_reordered, cmap=cmap)
ax3.yaxis.tick_right()
ax3.set_ylabel("")
ax3.set_xticks(np.arange(0, len(df_sig2_reordered.columns), 1))
ax3.set_yticks(np.arange(0, len(df_sig2_reordered), 1))


# NOTE(review): 25 hard-coded tick colours tied to one specific column order;
# more than 25 columns raises IndexError below, and a different clustering
# order would colour the wrong labels
xcolour = ['k'] + ['firebrick']*4 + ['gray']*6 + ['mediumseagreen']*4 + ['k']*2 + ['darkorange']*3 + ['k']*2 + ['royalblue']*3
ax3.set_xticklabels(df_sig2_reordered.columns.to_list(), rotation=90)

for i, tick_label in enumerate(ax3.get_xticklabels()):
	tick_text = tick_label.get_text()
	tick_label.set_color(xcolour[i])

# row labels: show only the lower-cased second dot-separated field, unless
# that field is one of the words in anno_long, in which case the whole
# lower-cased label is shown
anno_long = ['annotated', 'cellulose', 'biosynthesis', 'hemicellulose', 'pectin', 'channels', 'degradation']
ax3.set_yticklabels([(lambda x: x.split('.')[1].lower() if x.split('.')[1] not in anno_long else x.lower())(x) for x in df_sig2_reordered.index.to_list()])
# cbar plotting and control
axins = inset_axes(ax0,
					width="40%",  # width = 50% of parent_bbox width
					height="90%",  # height : 5%
					loc = 'center')
cbar = fig.colorbar(hplot, cax=axins, ticks = cbarticks)
# labels for category values 0..3
cbar.ax.set_yticklabels(['N', 'D', 'UD', 'U'])

plt.savefig(dir_path+'figures/Fig5F.png', dpi=600, bbox_inches='tight') # no N

### Figure 6: Diurnal gene expression

In [None]:
#Fig 6 A to D, adapted from Mpo_panel1.py
from scipy.stats import zscore

wdir = dir_path + "diurnal/"
Mpodf = pd.read_csv(wdir + "Mpo_supp.txt", sep = "\t", index_col = 0)

# LAG holds "NE" / "NR" markers or a phase value; drop "NE", then "NR"
Mpo_exp_only = Mpodf[Mpodf.LAG != "NE"]
# .copy(): columns of this table are reassigned and sorted in place below
Mpo_rhy_only = Mpo_exp_only[Mpo_exp_only.LAG != "NR"].copy()

# share of rhythmic genes among all genes and among the "NE"-filtered genes
perall = (len(Mpo_rhy_only) / len(Mpodf))*100
perexp = (len(Mpo_rhy_only) / len(Mpo_exp_only))*100

# Subset rhythmic genes only
Mpo_rhy_only["LAG"] = Mpo_rhy_only["LAG"].astype(int)
# builtin float: the np.float alias no longer exists in current NumPy
Mpo_rhy_only["ADJ.P"] = Mpo_rhy_only["ADJ.P"].astype(float)
Mpo_rhy_only.sort_values(["LAG", "ADJ.P"], inplace=True)

# normalisation of rhythmic gene expression: z-score per gene over all
# columns from "ZT2_1" onwards
all_columns = Mpo_rhy_only.columns.to_list()
rhy_zscore = Mpo_rhy_only[all_columns[all_columns.index("ZT2_1"):]]
rhy_zscore = rhy_zscore.transpose().apply(zscore).transpose()
plt.figure(figsize=(10,20))


"""
Panel 1a) plot
"""
# time point of every expression column (the part before "_")
colnames = [x.split("_")[0] for x in rhy_zscore.columns.to_list()]
# RGB colour per time point: ZT below 12 is drawn yellow, the rest grey
condcoldict = {}
for x in list(set(colnames)):
	if int(x.split("ZT")[1]) < 12:
		condcoldict[x] = (255,255,153) #"khaki"
	else:
		condcoldict[x] = (160,160,160) #"lightslategrey"
		
# one-row image with one coloured cell per expression column
condcol = np.array([[condcoldict[x] for x in colnames]])

# thin colour strip on top, heatmap below
fig, ax = plt.subplots(2,1,
					   figsize=(4,8), # (width, height)
					   gridspec_kw={'height_ratios': [0.091, 3.9]})
fig.subplots_adjust(hspace=0.01)
ax1, ax2= ax.flatten()

ax1.imshow(condcol)
# Set gridlines
# ticks at cell edges, every three columns, carry the grid lines
ax1.set_xticks(np.arange(-.5, 18, 3))
ax1.set_yticks(np.arange(-.5, 1, 1))
ax1.grid(color='k', linestyle='-', linewidth=1)
ax1.set_xticklabels([])
ax1.set_yticklabels([])
ax1.xaxis.set_ticks_position('none')
ax1.yaxis.set_ticks_position('none')
ax1.set_anchor('W')

# genes are in the row order set by the earlier sort on LAG and ADJ.P
sns.heatmap(rhy_zscore, cmap = "coolwarm", ax = ax2, yticklabels=False, xticklabels=True, center=0)
ax2.set_ylabel("Genes")
plt.savefig(dir_path + 'figures/Fig6A.png', dpi = 600)
plt.show()

"""
Panel 1B) plot
"""
# number of rhythmic genes per LAG value, and the same as a percentage
rhycount = len(Mpo_rhy_only)
LAGcount = Mpo_rhy_only.groupby("LAG").count().annotation.to_frame()
LAGcount.columns = ["count"]
# apply() runs on the whole one-column frame; the result is stored as "percent"
LAGcount["percent"] = LAGcount.apply(lambda x: (x/rhycount)*100)
g = LAGcount.percent.plot(xticks = LAGcount.index.to_list(),
						  yticks = [0, 18],
						  ylim = [0,18],
						  ylabel = "% rhythmic genes",
						  xlabel = "Phase",
						  color = "k")
# vertical marker at phase 12
g.axvline(12, color = "k")
plt.savefig(dir_path + 'figures/Fig6B.png', dpi = 600)
plt.show()

"""
Panel 1C) plot
"""
# two orthogroup tables (tab separated, orthogroup id as index)
fp = dir_path + "diurnal/Ferrari_2019/SD14_compat.txt"
qp = dir_path + "diurnal/Ferrari_2019/OF_20210623_compat.tsv"
camortho = pd.read_csv(fp, sep="\t", index_col=0)
camortho.Ath = camortho.Ath.str.upper()
qortho = pd.read_csv(qp, sep="\t", index_col=0)
# strip every ".<digits>" suffix from the Osa ids
# (raw string: "\." is not a valid escape in a normal string literal)
qortho.Osa = qortho.Osa.str.replace(r"\.[0-9]*", "", regex=True)

# orthogroup ids in the first table and, position by position, their
# counterparts in the second table
camgrps = ["OG0000156", "OG0000215", "OG0000679", "OG0004739", "OG0004944", "OG0005370"]
qgrps = ["OG0000167", "OG0000301", "OG0000399", "OG0003516", "OG0004855", "OG0004502"]

# first-table id -> label built from the second-table id and a family name
cgrpdict = {"OG0000156":"OG0000167 (Cyclin A, B)",
			"OG0000215":"OG0000301 (Cyclin D)",
			"OG0000679":"OG0000399 (CDK)",
			"OG0004739":"OG0003516 (Timeless)",
			"OG0004944":"OG0004855 (DNA primase)",
			"OG0005370":"OG0004502 (DNA polymerase)"}

# for every orthogroup pair, collect the ids present only in the second table
# and keep those containing "Mp"
mpogrpgenes = {}
for i in range(len(camgrps)):
	testc = camortho.loc[camgrps[i],:].to_list()
	# empty cells are read as float NaN and are skipped
	testc = [x.split(", ") for x in testc if type(x) != float]
	testcs = [x for a in testc for x in a]
	testq = qortho.loc[qgrps[i],:].to_list()
	testq = [x.split(", ") for x in testq if type(x) != float]
	testqs = [x for a in testq for x in a]
	# ids only in the first table are printed for inspection
	c_s = set(testcs) - set(testqs)
	print(camgrps[i])
	print(str(list(c_s)))
	q_s = list(set(testqs) - set(testcs))
	mpogrpgenes[cgrpdict[camgrps[i]]] = [x for x in q_s if "Mp" in x]
	
# keep only the genes that are rhythmic (present in rhy_zscore)
mpocyclegenes = [x for a in list(mpogrpgenes.values()) for x in a if x in rhy_zscore.index]
timepoints = ["ZT2", "ZT6", "ZT10", "ZT14", "ZT18", "ZT22"]
# dataframe normalised timepoints for mpocyclegenes only
cycle_zscore = rhy_zscore.loc[mpocyclegenes].transpose()
# dictionary of mpocyclegenes and their corresponding OG information
cycle_grp = {}
for k, v in mpogrpgenes.items():
	for item in v:
		cycle_grp[item] = k
qwOGgrps = list(mpogrpgenes.keys())
# to transpose and create new df that contains the average zscore of replicates
cycle_dict = {}
for t in timepoints:
	cycle_dict[t] = cycle_zscore.loc[[t+"_1", t+"_2", t+"_3"],:].mean().to_list()
cycle_df = pd.DataFrame(cycle_dict, columns = timepoints, index = [x + ": " + cycle_grp[x] for x in cycle_zscore.columns.to_list()]).transpose()

colours = ["maroon", "orangered", "forestgreen", "midnightblue", "mediumorchid", "steelblue", "darkseagreen"]

# line colour per "<gene>: <group label>", one colour per orthogroup
cycle_cols = {}
for keys in cycle_grp.keys():
	cycle_cols[keys + ": " + cycle_grp[keys]] = colours[qwOGgrps.index(cycle_grp[keys])]

# NOTE(review): the first call draws the same data with default colours into
# its own figure; only the second, coloured figure is the one saved below
cycle_df.plot().legend(bbox_to_anchor=(0.81, -0.1))
cycle_df.plot(color = cycle_cols).legend(bbox_to_anchor=(0.81, -0.1))
plt.savefig(dir_path + 'figures/Fig6C.png',
			dpi = 600,
			bbox_inches='tight')

"""
Panel 1D) Mercator by phase
"""
import seaborn as sns

# bin name = text before the first "." of the annotation, capitalised
Mpo_rhy_only["MapMan bins"] = Mpo_rhy_only.apply(lambda x: x.annotation.split(".")[0].capitalize(), axis=1)
# gene counts per bin (rows) and LAG (columns); missing combinations become 0
mer_grouped = Mpo_rhy_only.groupby(["MapMan bins", "LAG"]).count().annotation.unstack(fill_value=0)
phases = mer_grouped.columns.to_list()
# convert counts to the fraction of each bin's genes found at each LAG
binsum = mer_grouped.sum(axis=1)
for phase in phases:
	mer_grouped[phase] = mer_grouped[phase]/binsum
# the clustermap is drawn to obtain the row order and the row linkage
g = sns.clustermap(mer_grouped)
plt.show()
# label text is taken from between the first pair of single quotes of each
# tick label's string form
# NOTE(review): this would break for a bin name containing an apostrophe
g_ytick = [str(x).split("'")[1] for x in g.ax_heatmap.get_yticklabels()]
drow = g.dendrogram_row.linkage
mer_grouped_reordered = mer_grouped.reindex(g_ytick)
sns.heatmap(mer_grouped_reordered, yticklabels=True)
mer_grouped.transpose().plot().legend(bbox_to_anchor=(0.72, -0.1))

# subplots with linkage
# rows in reversed clustermap order, drawn next to the row dendrogram
mer_grouped_dendro = mer_grouped.reindex(g_ytick[::-1])
from scipy.cluster.hierarchy import dendrogram
figii, axii = plt.subplots(1,2,
					   figsize=(10,6), # (width, height
					   constrained_layout=True,
					   gridspec_kw={'width_ratios': [1.9, 8.1]})
ax1ii, ax2ii= axii.flatten()
ax1ii.axis("off")
dendrogram(drow, no_labels=True, ax=ax1ii, orientation='left', color_threshold=0, above_threshold_color='#000000')
sns.heatmap(mer_grouped_dendro, yticklabels=True, ax = ax2ii)
ax2ii.set_ylabel("")
plt.show()

# to plot heatmap by chunks
def chunk(num):
	"""
	Bin a fraction into a category: 0 for exactly zero, 1-4 for values below
	0.1, 0.2, 0.3 and 0.4 respectively, 5 for everything else.

	NOTE: this replaces the two-argument `chunk` defined for Fig 5F.
	"""
	if num == 0:
		return 0
	for cat, upper in enumerate((0.1, 0.2, 0.3, 0.4), start=1):
		if num < upper:
			return cat
	return 5

# rows in clustermap order, every fraction replaced by its category (see chunk)
mer_grouped_chunk = mer_grouped.reindex(g_ytick)
for col in mer_grouped_chunk:
	mer_grouped_chunk[col] = mer_grouped_chunk[col].apply(lambda x: chunk(x))
sns.heatmap(mer_grouped_chunk, yticklabels=True)

"""
Mercator count binned by percentage with custom colormap
"""
from matplotlib.colors import ListedColormap
# row dendrogram on the left, categorical heatmap on the right
figm, (axm1, axm2) = plt.subplots(1,2,
								  figsize=(6.3,6), # (width, height
								  constrained_layout=True,
								  gridspec_kw={'width_ratios': [1.6, 8.4]})

axm1.axis("off")
dendrogram(drow,
		   no_labels=True,
		   ax=axm1, orientation='left',
		   color_threshold=0,
		   above_threshold_color='#000000')
# one colour per category 0..5
# NOTE(review): this rebinds the module-level `cmap` used by the Fig 5F cell
cmap = ListedColormap(["gray", "lightsteelblue", "lightgreen", "palegoldenrod", "coral", "indianred"])
mplot = axm2.imshow(mer_grouped_chunk, cmap=cmap, interpolation="none")

axm2.set_xticks(np.arange(0, len(mer_grouped_chunk.columns), 1))
axm2.set_yticks(np.arange(0, len(mer_grouped_chunk), 1))
axm2.set_xticklabels(mer_grouped_chunk.columns.to_list())
axm2.set_yticklabels(mer_grouped_chunk.index.to_list())
# ticks at 5/12, 15/12, ... 55/12: the centres of six equal bands spanning 0..5
cbar = figm.colorbar(mplot,
					 ax=axm2,
					 ticks = [x/12 for x in np.arange(5,6*10,10)],
					 label="% rhythmic genes in bin")
#cbar = fig.colorbar(cax, ticks=[-1, 0, 1])
# NOTE(review): chunk puts a value of exactly 0.4 into the last category, so that label covers >= 40%
cbar.ax.set_yticklabels(['None', '0-9%', '10-19%', '20-29%', '30-39%', '>40%'])  # vertically oriented colorbar
plt.savefig(dir_path + 'figures/Fig6D.png', dpi = 600)

In [None]:
# Fig 6E and F, adapted from 1to1ortho.py
# Marchantia genes that are expressed (LAG is not "NE") and, of those, rhythmic (LAG is not "NR").
Mpo_exp_only = Mpodf.loc[Mpodf["LAG"] != "NE"]
Mpo_rhy_only = Mpo_exp_only.loc[Mpo_exp_only["LAG"] != "NR"]
Mpo_rhy_genes = Mpo_rhy_only.index.tolist()

# Parallel per-species settings; position k in every list belongs to species[k].
# NOTE(review): daylength/night are used further down as counts of heatmap columns
# and freq as the step between sampled time points - confirm against the datasets.
species = ["Cpa", "Ppu", "Cre", "Kni", "Ppa", "Smo", "Pab", "Osa", "Ath"]
night = [8, 12, 12, 12, 8, 12, 8, 6, 6]
daylength = [16, 12, 12, 12, 16, 12, 16, 6, 6]
freq = [1, 1, 1, 1, 1, 1, 1, 2, 2]

# Orthologue tables (Mpo versus each species) and per-species rhythmicity tables.
odir = dir_path + "diurnal/Orthologues_Mpo/"
ofiles = ["".join(("Mpo__v__", code, ".tsv")) for code in species]

camdir = dir_path + "diurnal/"
camfiles = ["".join((code, "_supp.txt")) for code in species]

### FUNCTION ###
def lag_diff(a, b):
	"""
	Signed LAG difference a - b on a 24 h clock, taking the shorter way round.

	Parameters
	----------
	a : int
		LAG value of species X.
	b : int
		LAG value of Mpo.

	Returns
	-------
	int
		a - b when its magnitude is at most 12; otherwise the same
		difference shifted by 24 towards zero.
	"""
	delta = a - b
	if not abs(delta) > 12:
		return delta
	# more than half a day apart: go round the clock the other way
	return delta + 24 if delta < 0 else delta - 24
### END ###

# Fig 6E/F: one column of panels per species.
# Lower-case mosaic keys (top row) hold the LAG-difference histograms,
# upper-case keys (bottom row) the Mpo-vs-species LAG count heatmaps.
f_axes = string.ascii_uppercase[:len(species)]
d_axes = string.ascii_lowercase[:len(species)]
axd = plt.figure(constrained_layout=True,
				 figsize=(27,6)).subplot_mosaic(
    """
	abcdefghi
	ABCDEFGHI
    """
)

for z in range(len(species)):
	spe = species[z]
	spe_lag = spe + "_LAG"
	hist_ax = axd[d_axes[z]]
	heat_ax = axd[f_axes[z]]

	# get 1 to 1 orthologues: rows where both cells list exactly one gene
	ol_df = pd.read_csv(odir + ofiles[z], sep="\t", index_col=0)
	one_to_one = (ol_df["Mpo"].apply(lambda cell: len(cell.split(", ")) == 1)
				  & ol_df[spe].apply(lambda cell: len(cell.split(", ")) == 1))
	# explicit copy so the column assignments below act on an independent frame
	ol_df = ol_df[one_to_one].copy()
	if spe == "Osa":
		# keep only the part of the rice identifier before the first "."
		# (assigned back to the column instead of a chained in-place replace,
		# which does not modify the frame under pandas Copy-on-Write)
		ol_df["Osa"] = ol_df["Osa"].apply(lambda gene: gene.split(".")[0])

	# get LAGs
	cam_f = pd.read_csv(camdir + camfiles[z], sep="\t", index_col=0)
	ol_df["Mpo_LAG"] = ol_df.apply(lambda row: Mpodf.LAG.loc[row.Mpo], axis=1)
	# genes missing from the species table get None and are dropped below
	ol_df[spe_lag] = ol_df.apply(lambda row: cam_f.phase.get(row[spe], None), axis=1)

	# exclude NE and NR (and missing values) in either Mpo or species LAG
	for lag_col in ["Mpo_LAG", spe_lag]:
		keep = ~ol_df[lag_col].isin(["NE", "NR"]) & ol_df[lag_col].notna()
		ol_df = ol_df[keep].copy()

	# convert to numeric, then fold LAG 24 onto 0; done after the conversion
	# so that it also applies when the phase column was read in as numbers
	ol_df["Mpo_LAG"] = pd.to_numeric(ol_df["Mpo_LAG"])
	ol_df[spe_lag] = pd.to_numeric(ol_df[spe_lag]).replace(24, 0)

	# calculating smallest lag diff
	ol_df["LAG_diff"] = ol_df.apply(lambda row: lag_diff(row[spe_lag], row.Mpo_LAG), axis=1)

	# Plot LAG diff: one bin per distinct difference value
	sns.histplot(ol_df.LAG_diff,
			  bins=len(ol_df.LAG_diff.unique()),
			  kde=True,
			  ax=hist_ax,)
	if z == 0:
		hist_ax.set_ylabel("Count", fontsize=14)
	else:
		hist_ax.set_ylabel("")
	hist_ax.set_xlabel("")

	# Plot for 1 to 1 ortho: count orthologue pairs per (Mpo LAG, species LAG) cell
	mpo_tp = list(range(0,24,2)) # rows of the heatmap
	other_tp = list(range(0,24,freq[z])) # columns of the heatmap
	sum_dict = {}
	for o in other_tp:
		sum_dict["ZT" + str(o)] = [
			int(((ol_df.Mpo_LAG == m) & (ol_df[spe_lag] == o)).sum())
			for m in mpo_tp
		]
	sum_df = pd.DataFrame(sum_dict, columns = ["ZT" + str(y) for y in other_tp], index = ["ZT" + str(x) for x in mpo_tp])

	heat_ax.imshow(sum_df, cmap="Blues", aspect="auto")
	# ticks sit on the cell edges; only first, light/dark boundary and last are labelled
	heat_ax.set_xticks(np.arange(-.5,len(other_tp)-1))
	heat_ax.set_yticks(np.arange(-.5,len(mpo_tp)-1))
	xticklist = [other_tp[0]] + ["" for x in range(0,daylength[z]-1)] + [other_tp[daylength[z]]] + ["" for x in range(0,night[z]-2)] + [other_tp[-1]]
	yticklist = [mpo_tp[0]] + ["" for x in range(0,6-1)] + [mpo_tp[6]] + ["" for x in range(0,6-2)] + [mpo_tp[-1]]
	heat_ax.set_xticklabels(xticklist)
	heat_ax.set_yticklabels(yticklist)
	# black lines at column daylength[z] and row 6
	# NOTE(review): presumably the light/dark transition for each axis - confirm
	heat_ax.axvline(daylength[z]-0.5, color="k")
	heat_ax.axhline(6-0.5, color="k")
	heat_ax.set_xlabel(spe, fontsize=14)
	if z == 0:
		heat_ax.set_ylabel("Mpo", fontsize=14)
	else:
		heat_ax.set_yticks([])
		heat_ax.set_yticklabels([])
plt.savefig(dir_path + 'figures/Fig6E_F.png', dpi = 600)

### Supp. Fig 2: QC of RNA-seq data

In [None]:
# adapted from QC_scaled_updated.py
from sklearn.preprocessing import StandardScaler
from scipy.stats import pearsonr

o_dir = dir_path + 'figures/'
sumdir = dir_path + 'summary_files/'

# Pick the experiment set (index 0 = all_stress) and derive its file paths.
expdesc = ['all_stress', 'diurnal_exp', 'single_stress']
targetexp = expdesc[0]
targetp = sumdir + targetexp + '.txt'
expmatp = dir_path + 'prep_files/' + targetexp + '.tsv'

# Read the summary table once and close it; blank lines are ignored so a
# trailing newline at the end of the file does not raise an IndexError.
with open(targetp, "r") as summary_file:
	summary_lines = [x for x in summary_file.readlines() if x.strip()]
exps = [x.split("\t")[0] for x in summary_lines]
# Column label = second tab field + "_" + second "_"-separated part of the first field.
labels = [x.strip().split("\t")[1] + '_' + x.split("\t")[0].split('_')[1] for x in summary_lines]

# Expression matrix; its columns are renamed with the labels built above, in file order.
df = pd.read_csv(expmatp, index_col = 0, sep = "\t", header = 0)
df.columns = labels

# Standardise every column of the expression matrix (StandardScaler works column-wise).
scaled_features = StandardScaler().fit_transform(df.values)
df_scaled = pd.DataFrame(data=scaled_features,
						 index=df.index,
						 columns=df.columns)

# Clustered heatmap of the column-by-column correlation matrix.
sns.set(font_scale=1.6)

methods = "average"

clustermap_kwargs = dict(method=methods,
						 figsize=(20,20),
						 xticklabels=True,
						 yticklabels=True)
g1 = sns.clustermap(df_scaled.corr(), **clustermap_kwargs)
plt.title("All stress (scaled): " + methods)
plt.savefig(o_dir + "SuppFig2" + '.png')


# PCC of experiments: Pearson correlation for every pair of columns whose
# labels share the same prefix (text before the first "_").
exps = list(df_scaled.columns)
# "with" guarantees the table is flushed and closed even if pearsonr raises.
with open(dir_path + "prep_files/mpo/all_stress_PCC.txt", "w+") as pcc_out:
	pcc_out.write("exp1\texp2\tpcc_val\tp_value\n")
	for exp1 in range(len(exps)):
		# exp2 < exp1, so each unordered pair is written exactly once
		for exp2 in range(exp1):
			if exps[exp1].split("_")[0] == exps[exp2].split("_")[0]:
				pcc_val, p_value = pearsonr(df_scaled[exps[exp1]], df_scaled[exps[exp2]])
				pcc_out.write(exps[exp1] + "\t" + exps[exp2] + "\t" + str(pcc_val) + "\t" + str(p_value) + "\n")

### Supp. Fig 3: Volcano plots (DESeq2)

In [None]:
# adapted from deseq_volcano.py

wdir = dir_path + 'prep_files/mpo/deseq/'
odir = wdir + 'volcano/'

# DESeq2 result tables; the control-vs-control table is plotted on its own.
deseqouts = [x for x in os.listdir(wdir) if "res.tsv" in x] # controlD2controlH2_res.tsv
control = 'controlD2controlH2_res.tsv'
deseqouts.remove(control)

# Stress codes are the text in front of "control"; one-letter codes are
# listed before two-letter codes, each group in alphabetical order.
all_stress = sorted({x.split('control')[0] for x in deseqouts})
s_stress = [x for x in all_stress if len(x) == 1]
c_stress = [x for x in all_stress if len(x) == 2]
all_stress = s_stress + c_stress


def _volcano_panel(res, ax=None):
	"""Scatter log2FoldChange against -log10(padj) for one DESeq2 table.

	Points with |log2FoldChange| > 1 and -log10(padj) > -log10(0.05) get a
	separate hue. With ax=None seaborn draws on the current axes.
	"""
	neg_log_padj = -np.log10(res["padj"])
	is_deg = np.logical_and(abs(res['log2FoldChange']) > 1,
							neg_log_padj > -np.log10(0.05))
	return sns.scatterplot(x = res['log2FoldChange'],
						   y = neg_log_padj,
						   ax = ax,
						   alpha = 0.2,
						   marker = '.',
						   legend = False,
						   edgecolor = "none",
						   hue = is_deg)


def _read_deseq(name):
	"""Load one DESeq2 result table from wdir, indexed by its first column."""
	return pd.read_csv(wdir + name, sep = "\t", header = 0, index_col = 0)


# plot control
controls = _read_deseq(control)
_volcano_panel(controls)
plt.title("control " + control.split("control")[1] + " vs control H2")

# plot everything else: five stresses per pair of rows,
# upper row of each pair = vs control D2, lower row = vs control H2
xlen = 5
ylen = math.ceil(len(deseqouts)/5)
fig, axs = plt.subplots(ylen, xlen, figsize=(30, 37.5), sharex='col', sharey='row')

for i, z in enumerate(all_stress):
	# after sorting, the first file is the D2 comparison and the second the H2 one
	files = sorted(x for x in deseqouts if x.startswith(z+'control'))
	fileD2 = _read_deseq(files[0])
	fileH2 = _read_deseq(files[1])
	d2_row, col = 2 * (i // xlen), i % xlen

	# Volcano plots
	# against control D2
	_volcano_panel(fileD2, ax = axs[d2_row, col])
	axs[d2_row, col].set_title(z + " vs control D2")

	# against control H2
	_volcano_panel(fileH2, ax = axs[d2_row + 1, col])
	axs[d2_row + 1, col].set_title(z + " vs control H2")

# Label every panel, then hide x labels and tick labels for top plots
# and y ticks for right plots.
for ax in axs.flat:
	ax.set(xlabel='log2FoldChange', ylabel='-log10 padj')
	ax.label_outer()

plt.savefig(dir_path + "figures/SuppFig3.png", dpi = 600)

### Supp. Fig 8: Overview of diurnal data

In [None]:
# adapted from QC_scaled_updated.py

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

o_dir = dir_path + 'figures/'
sumdir = dir_path + 'summary_files/'

# Pick the experiment set (index 1 = diurnal_exp) and derive its file paths.
expdesc = ['all_stress', 'diurnal_exp', 'single_stress']
targetexp = expdesc[1]
targetp = sumdir + targetexp + '.txt'
expmatp = dir_path + 'prep_files/' + targetexp + '.tsv'

# Read the summary table once and close it; blank lines are ignored so a
# trailing newline at the end of the file does not raise an IndexError.
with open(targetp, "r") as summary_file:
	summary_lines = [x for x in summary_file.readlines() if x.strip()]
exps = [x.split("\t")[0] for x in summary_lines]
# Column label = second tab field + "_" + second "_"-separated part of the first field.
labels = [x.strip().split("\t")[1] + '_' + x.split("\t")[0].split('_')[1] for x in summary_lines]

# Expression matrix; its columns are renamed with the labels built above, in file order.
df = pd.read_csv(expmatp, index_col = 0, sep = "\t", header = 0)
df.columns = labels

# =============================================================================
#
# PCA for diurnal by samples (Panel A)
# The scaled matrix is transposed before fitting, so each point is one column
# (sample) of df, grouped by the label prefix before the first "_".
#
# =============================================================================
pca = PCA(n_components=2)
# StandardScaler standardises each column of df
diurnal_transformed = StandardScaler().fit_transform(df.values)
pcomp = pca.fit_transform(diurnal_transformed.T)
p_df = pd.DataFrame(data = pcomp, columns = ['PC1', 'PC2'])
# fraction of variance explained by each component, shown in the axis labels
pc1, pc2 = pca.explained_variance_ratio_

finalDf = p_df.copy()
finalDf['target'] = [x.split('_')[0] for x in df.columns.to_list()]
#plot PCA figure
PCA_plot = o_dir + 'SuppFig8A.png'

fig = plt.figure(figsize = (8,5))
ax = fig.add_subplot(1,1,1) 
ax.set_xlabel('Principal Component 1: ' + str(round(pc1, 2)), fontsize = 15)
ax.set_ylabel('Principal Component 2: ' + str(round(pc2, 2)), fontsize = 15)

# one group name per block of three consecutive columns
# NOTE(review): assumes columns are ordered as triplicates - confirm
targets = [x.split('_')[0] for i, x in enumerate(df.columns.to_list()) if i%3 == 0]
colors = ['r', 'y', 'g', 'b', 'c', 'm']
# zip() stops at the shorter list, so at most six groups are drawn
for target, color in zip(targets,colors):
    indicesToKeep = finalDf['target'] == target
    ax.scatter(finalDf.loc[indicesToKeep, 'PC1']
               , finalDf.loc[indicesToKeep, 'PC2']
               , c = color
               , s = 50
			   , alpha = 0.5
			   , edgecolors = 'k')

ax.legend(targets, loc = 'upper right', fontsize='xx-small')
ax.grid()

plt.savefig(PCA_plot, dpi=600)

# =============================================================================
#
# PCA for diurnal by genes (Panel B)
# Each point is one row (gene) of df, coloured by its LAG value from Mpo_supp.txt.
#
# =============================================================================

# JTK_output info
wdir = dir_path + "diurnal/"
Mpodf = pd.read_csv(wdir + "Mpo_supp.txt", sep = "\t", index_col = 0)

#PCA part
# NOTE: this appends a 'target' column to df itself, so the Panel A code above
# cannot be re-run on df afterwards without reloading it.
df['target'] = [Mpodf.loc[x].LAG for x in df.index.to_list()]
# drop rows labelled 'NE'
diurnal_filt = df[df.target != 'NE']
# scale on the transposed matrix (all columns except 'target'), i.e. each gene is standardised
diurnal_scaled = StandardScaler().fit_transform(diurnal_filt.iloc[:,:-1].T.values)

pca = PCA(n_components=2)
# transpose back so that rows (= points) are genes
pcomp = pca.fit_transform(diurnal_scaled.T)
p_df = pd.DataFrame(data = pcomp, columns = ['PC1', 'PC2'])
pc1, pc2 = pca.explained_variance_ratio_

finalDf = p_df.copy()
finalDf['target'] = diurnal_filt.target.to_list()

#plot PCA figure
PCA_plot = o_dir + 'SuppFig8B.png'

fig = plt.figure(figsize = (8,5))
ax = fig.add_subplot(1,1,1) 
ax.set_xlabel('Principal Component 1: ' + str(round(pc1, 2)), fontsize = 15)
ax.set_ylabel('Principal Component 2: ' + str(round(pc2, 2)), fontsize = 15)

# Legend order: numeric LAG values ascending, then the last sorted label.
# NOTE(review): assumes the LAG labels are strings and that exactly one
# non-numeric label (presumably "NR") sorts last - confirm.
targets = list(finalDf.target.unique())
targets.sort()
num_only = [int(x) for x in targets[:-1]]
num_only.sort()
new_targets = [str(x) for x in num_only] + [targets[-1]]
# twelve RGB colours followed by grey for the final label
colors = [(1,1,0), (1,0.75,0), (1,0.5,0),
		  (1,0.25,0), (1,0,0.25), (1,0,0.5),
		  (1,0,0.75), (1,0,1), (0.75,0,1),
		  (0.5,0,1), (0.25,0,1), (0,0,1),
		  (0.75,0.75,0.75)]
for target, color in zip(new_targets,colors):
    indicesToKeep = finalDf['target'] == target
    ax.scatter(finalDf.loc[indicesToKeep, 'PC1']
               , finalDf.loc[indicesToKeep, 'PC2']
               , color = color
               , s = 50
			   , alpha = 0.5
			   , edgecolors = None)

ax.legend(new_targets, bbox_to_anchor=(1, 1), fontsize='x-small')
ax.grid()

plt.savefig(PCA_plot, dpi=600)


# 4. Experimental

### 4.1 Download RNA-seq experiments (experimental)

In [None]:
kal_dir = dir_path + 'kal_out/'
def kal_index():
	'''()->(None)
	Placeholder for building the kallisto index; not implemented.

	The original definition had no body, which is a syntax error that
	prevents the whole cell from being parsed.
	'''
	raise NotImplementedError("kal_index is a placeholder and has not been implemented")

def get_ftp_links(RunID):
	'''(str)->(lst,str)
	Return ftp link in the paired and unpaired format for the RunID specified

	The run directory is <first 6 chars>/<RunID> for 9-character accessions.
	For 10 to 12 characters an extra sub-directory is inserted, made of the
	characters after the 9th, left-padded with zeros to three characters.

	NOTE(review): the returned paths are relative (no protocol or host) -
	confirm that callers prepend the server address before downloading.

	Raises ValueError for accessions shorter than 9 or longer than 12
	characters (previously an UnboundLocalError at the return statement).
	'''
	id_len = len(RunID)
	if id_len == 9:
		dirs = RunID[:6] + "/" + RunID
	elif 9 < id_len <= 12:
		dir2 = "0"*(12 - id_len) + RunID[9:] + "/"
		dirs = RunID[:6] + "/" + dir2 + RunID
	else:
		raise ValueError("Unsupported run accession length: " + repr(RunID))
	ftp_link_paired = [dirs + "/" + RunID + "_1.fastq.gz",
				 dirs + "/" + RunID + "_2.fastq.gz"]
	ftp_link_unpaired = dirs + "/" + RunID + ".fastq.gz"
	return ftp_link_paired, ftp_link_unpaired

def kal_single(outname, index, SpotLen, flink):
	'''(str, str, int, str)->(None)
	Run kallisto quant in single-end mode on reads streamed with curl.

	outname is the output directory (-o), index the kallisto index (-i),
	SpotLen is passed as -l with a fixed -s 20, and flink is the address
	given to curl. Uses 2 threads. IPython shell magic: only runs in a notebook.
	'''
	!kallisto quant -i $index -o $outname --single -l $SpotLen -s 20 -t 2 <(curl $flink)

def kal_paired(outname, index, flink1, flink2):
	!kallisto quant -i $index -o $outname -t 2 <(curl $flink1 $flink2)

# Download Rice experiments
# NOTE(review): sum_dir, completed and osa_idx are not defined in this cell;
# they must already exist in the notebook session - confirm where they are set.
kal_osa = kal_dir + 'osa/'
# os.makedirs also creates kal_dir when missing and, unlike an unquoted
# shell "mkdir", copes with spaces in the path (e.g. ".../My Drive/...").
os.makedirs(kal_osa, exist_ok=True)
RunTable = pd.read_csv(sum_dir + "selected_Osa.txt",
			  sep = "\t", header = 0)

run_columns = zip(RunTable["Run"], RunTable["Study"],
				  RunTable["Layout"], RunTable["Spot_length"])
for i, (runid, study, liblay, spotlen) in enumerate(run_columns):
	# skip experiments already listed as completed
	if study + "_" + runid in completed:
		continue
	path_paired, path_single = get_ftp_links(runid)
	print(str(i) + "\t" + path_single.split("/")[-1].split(".fastq.gz")[0] + "\t" + liblay + "\n")
	if liblay == "SINGLE":
		# was "pathsingle", an undefined name
		kal_single(kal_osa + study+'_'+runid, osa_idx, spotlen, path_single)
	elif liblay == "PAIRED":
		kal_paired(kal_osa + study+'_'+runid, osa_idx, path_paired[0], path_paired[1])

### 4.2 Generate expression matrix (experimental)

In [None]:
# Generation of gene expression matrix and kallisto statistics
def kal_extract(kout, exps):
	'''(str,list)->(dict,dict)
	Return dictionary containing tpm and raw expression value

	kout is the kallisto output directory (with trailing slash) and exps the
	experiment folder names to look for inside it. Folders without an
	abundance.tsv are skipped. Both dictionaries map the first column of
	abundance.tsv to a list of strings, one per experiment found, in the
	order of exps: the last column (tpm) in the first dictionary, the
	second-to-last column (est_counts) rounded to an integer in the second.

	Side effect: stores the matching header line in the module-level name
	output_header ('gene\\t<exp1>\\t<exp2>\\t...'), which write_expmat reads.
	'''
	global output_header

	dicto = {}
	dicto_raw = {}

	output_header = 'gene\t'
	for folder in exps:
		filep = kout + folder + '/abundance.tsv'
		if not os.path.exists(filep):
			continue
		print('In directory ' + folder)
		output_header += folder + '\t'
		with open(filep, 'r') as content:
			content.readline()  # skip the column header
			for line in content:
				values = line.rstrip().split('\t')
				# ignore blank or truncated lines (replaces popping the '' key afterwards)
				if len(values) < 2 or values[0] == '':
					continue
				gene = values[0]
				tpm = values[-1]
				raw = str(round(float(values[-2])))
				dicto.setdefault(gene, []).append(tpm)
				dicto_raw.setdefault(gene, []).append(raw)
	return dicto, dicto_raw
 
def write_expmat(filepath, dicttouse, exps=None):
	'''(str, dict[, list])->(None)
	Writes expression matrix to file from dictionary

	One tab-separated line per key: the key followed by its list of string
	values. The first line is the column header. When exps (list of
	experiment names) is given the header is "gene" followed by those names;
	otherwise the module-level output_header string is used (minus its
	trailing tab), which must have been defined beforehand.
	'''
	if exps is not None:
		header = "\t".join(["gene"] + list(exps))
	else:
		# resolved before the file is opened so a missing header does not leave an empty file
		header = output_header[:-1]
	with open(filepath, "w+") as output_file:
		output_file.write(header + "\n")
		for key, value in dicttouse.items():
			output_file.write("\t".join([key] + list(value)) + "\n")

def kal_stats(kout, outname="kallisto_stats.txt"):
	'''(str[, str])->(None)
	Writes summary of mapping statistics of kallisto runs to file

	kout is the kallisto output directory (with trailing slash). Every entry
	in it that contains a run_info.json contributes one line, in sorted
	order; other entries (including a stats file from an earlier run) are
	skipped. The table is written to kout + outname.
	'''
	import json

	fields = ["n_processed", "n_pseudoaligned", "n_unique",
			  "p_pseudoaligned", "p_unique"]
	rows = []
	for folder in sorted(os.listdir(kout)):
		info_path = kout + folder + '/run_info.json'
		if not os.path.isfile(info_path):
			continue
		# run_info.json is JSON, so parse it as such
		with open(info_path, 'r') as info_file:
			kallisto_json = json.load(info_file)
		rows.append(folder + "\t" + "\t".join(str(kallisto_json[f]) for f in fields) + "\n")

	with open(kout + outname, "w+") as output_file:
		output_file.write("experiment\tn_processed\tn_pseudoaligned\tn_unique\tp_pseudoaligned\tp_unique\n")
		output_file.writelines(rows)

# Marchantia
sum_dir = dir_path + 'summary_files/'

expdesc = ['all_stress', 'diurnal_exp', 'single_stress', 'cross_stress']
for targetexp in expdesc:
	targetp = sum_dir + targetexp + '.txt'
	expmatp = dir_path + 'prep_files/' + targetexp + '.tsv'
	expmatrawp = dir_path + 'prep_files/' + targetexp + '_raw.tsv'

	# experiment names are the first tab field of each summary line
	with open(targetp, "r") as summary_file:
		mpo_exps = [x.split("\t")[0] for x in summary_file.readlines()]
	mpo_tpm, mpo_raw = kal_extract(kal_dir + 'mpo/', mpo_exps)
	# TPM and raw counts each go to their own file (the former nested loop
	# wrote both tables to both paths, leaving raw counts in each)
	write_expmat(expmatp, mpo_tpm)
	write_expmat(expmatrawp, mpo_raw)
# the stats file name is fixed inside kal_stats, so only the directory is passed
kal_stats(kal_dir + 'mpo/')

# Rice
RunTable = pd.read_csv(sum_dir + "selected_Osa.txt",
			  sep = "\t", header = 0)
osa_runs = RunTable.Run.to_list()
osa_study = RunTable.Study.to_list()
# folder names follow the <Study>_<Run> pattern
osa_exps = [osa_study[i] + '_' + x for i, x in enumerate(osa_runs)]
expmatp = dir_path + 'prep_files/' + 'expmat_Osa.tsv'
expmatrawp = dir_path + 'prep_files/' + 'expmat_Osa_raw.tsv'
osa_tpm, osa_raw = kal_extract(kal_dir + 'osa/', osa_exps)
write_expmat(expmatp, osa_tpm)
write_expmat(expmatrawp, osa_raw)
kal_stats(kal_dir + 'osa/')