# Install and import all required packages

In [None]:
!pip install pandas h5py joblib matplotlib mpmath sklearn pyensembl
# !pip install combat

In [None]:
# Download the Ensembl release 104 human annotation that pyensembl uses in get_gene_sample_matrix.
# NOTE(review): pyensembl.EnsemblRelease() without an argument may default to a different release — confirm it resolves to 104.
!pyensembl install --release 104 --species homo_sapiens

In [None]:
import pandas as pd
import numpy as np
import os, h5py, argparse, sys
from collections import Counter
from joblib import Parallel, delayed
# from combat.pycombat import pycombat
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.impute import KNNImputer, SimpleImputer
import pyensembl

# Define all functions

In [None]:
def get_filenames(meth_raw_dir : str) -> (list, list):
    """Collect the Illumina methylation array files found one level below *meth_raw_dir*.

    Every sub-directory of *meth_raw_dir* is scanned (plain files directly in
    *meth_raw_dir* are ignored). Files ending in 'gdc_hg38.txt' are treated as
    450K arrays and files ending in 'level3betas.txt' as EPIC/850K arrays;
    anything else (e.g. 27K arrays, logs, annotations) is only counted.

    :param meth_raw_dir: Path to directory where Illumina folders are downloaded
    :return: (450K file paths, EPIC/850K file paths)
    """
    paths_450, paths_850 = [], []
    n_skipped = 0
    for entry in os.listdir(meth_raw_dir):
        folder = os.path.join(meth_raw_dir, entry)
        if not os.path.isdir(folder):
            continue
        for name in os.listdir(folder):
            full_path = os.path.join(folder, name)
            if name.endswith('gdc_hg38.txt'):
                paths_450.append(full_path)
            elif name.endswith('level3betas.txt'):
                paths_850.append(full_path)
            else:
                n_skipped += 1
    print("Skipped {} arrays".format(n_skipped))
    print("Found {} 450k arrays and {} epic arrays".format(len(paths_450), len(paths_850)))
    return paths_450, paths_850

def _basename(path : str) -> str:
    """Return the last component of *path*, accepting both forward and back slashes as separators."""
    return path.replace('\\', '/').rsplit('/', 1)[-1]


def _sample_submitter_id(path : str) -> str:
    """Return the TCGA sample barcode prefix (first four '-' fields) encoded in a 450K file name.

    The barcode is taken from the sixth '.'-separated field of the file name,
    e.g. for a name shaped like
    'jhu-usc.edu_LUAD.HumanMethylation450.4.lvl-3.TCGA-05-4244-01A-01D-1105-05.gdc_hg38.txt'
    the result is 'TCGA-05-4244-01A'.
    """
    barcode = _basename(path).split('.')[5]
    return '-'.join(barcode.split('-')[:4])  # Refer TCGA barcode for this information


def get_metadata(filenames_450 : list, filenames_850 : list,
                 epic_metadata_path : str = r'C:\Users\HP\OneDrive\Documents\Bayes\methluad\File_metadata.txt',
                 sample450_path : str = r'C:\Users\HP\OneDrive\Documents\Bayes\methluad\sample450.tsv') -> (pd.DataFrame,pd.DataFrame):
    """Get tumor/normal sample, sample ID and case ID metadata and map it to the 450K and EPIC/850K file names.

    -> EPIC arrays: query the GDC API with the file IDs from your manifest to get the fields
       (https://docs.gdc.cancer.gov/API/Users_Guide/Search_and_Retrieval/#example-http-post-request).
    -> 450K arrays: download the biospecimen files and use their sample file for the mapping.
    -> Sample types are coded 0 for every tumor type ('Primary Tumor', 'Recurrent Tumor') and
       1 for normal tissue ('Solid Tissue Normal'); any other value is left unchanged.

    :param filenames_450: paths of the 450K array files
    :param filenames_850: paths of the EPIC/850K array files
    :param epic_metadata_path: tab-separated GDC file-metadata export for the EPIC arrays
    :param sample450_path: tab-separated biospecimen sample table for the 450K arrays
    :return: metadata DataFrames for the 450K and the EPIC files. Both are right merges onto the
             given file lists, so every file appears and files without metadata get NaN fields.
    """
    sample_type_codes = {'Primary Tumor': 0, 'Recurrent Tumor': 0, 'Solid Tissue Normal': 1}

    def encode(sample_types : pd.Series) -> pd.Series:
        # map known sample types to 0/1 and leave anything else untouched
        return sample_types.map(lambda value: sample_type_codes.get(value, value))

    epic_columns = {'cases.0.case_id': 'case_id',
                    'cases.0.samples.0.sample_id': 'sample_id',
                    'cases.0.samples.0.sample_type': 'sample_type',
                    'file_name': 'file_name_short'}
    metaepic = pd.read_csv(epic_metadata_path, header=0, sep="\t", usecols=list(epic_columns))
    # rename by name: usecols keeps the file's column order, not the order of the list
    metaepic = metaepic.rename(columns=epic_columns)
    metaepic['sample_type'] = encode(metaepic['sample_type'])
    df_paired = pd.DataFrame({'file_name_short': [_basename(f) for f in filenames_850],
                              'file_name': list(filenames_850)})
    # Right merge keeps only cases that have a file in our directory
    meta850df = metaepic.merge(df_paired, how='right', on='file_name_short')
    print ("Got metadata for epic data")

    cols = ['case_id', 'sample_submitter_id', 'sample_id', 'sample_type']
    meta450 = pd.read_csv(sample450_path, header=0, sep="\t", usecols=cols)
    meta450['sample_type'] = encode(meta450['sample_type'])
    df_paired = pd.DataFrame({'sample_submitter_id': [_sample_submitter_id(f) for f in filenames_450],
                              'file_name': list(filenames_450)})
    # Right merge keeps only cases that have a file in our directory
    meta450df = meta450.merge(df_paired, how='right', on='sample_submitter_id')
    print ("Got metadata for 450 data")
    return meta450df, meta850df

def _pick_tss_gene(names : str, groups : str) -> str:
    """Choose the gene a TSS-proximal CpG site is assigned to.

    :param names: ';'-separated gene names of the site
    :param groups: ';'-separated gene regions, position-aligned with *names*
    :return: the common gene when all names agree; otherwise the first gene whose region
             starts with 'TSS', falling back to the first gene if no aligned TSS entry exists
    """
    genes = names.split(';')
    regions = groups.split(';')
    if len(set(genes)) == 1:  # all genes are the same: keep it
        return genes[0]
    for gene, region in zip(genes, regions):
        if region.startswith('TSS'):
            return gene
    return genes[0]


def get_methylation_map(annotation_path : str, methylation_map_path : str,
                        manifest_tss_path : str = r'C:\Users\HP\OneDrive\Documents\Bayes\manifest_tss.csv') -> pd.DataFrame:
    """Build a map from EPIC CpG probes near a TSS to one target gene each.

    Use the 'Infinium MethylationEPIC v1.0 B5 Manifest File'.
    Download from https://sapac.support.illumina.com/array/array_kits/infinium-methylationepic-beadchip-kit/downloads.html.
    For column heading explanations, see the 'Infinium MethylationEPIC Manifest Column Headings' file.

    Steps:
    - keep only 'cg' probes that are not on chromosome X and have a gene annotation
      (UCSC or GENCODE).
    - fill missing GENCODE names/groups from the UCSC columns.
    - keep sites whose group mentions TSS200 or TSS1500.
    - assign each site one gene (see _pick_tss_gene).

    :param annotation_path: manifest CSV (the first 7 lines are skipped as a preamble)
    :param methylation_map_path: where the final IlmnID -> Gene CSV is written
    :param manifest_tss_path: where the filtered TSS manifest (before gene assignment) is written
    :return: DataFrame with columns 'IlmnID' and 'Gene'
    """
    # keep only specified columns
    cols = ['IlmnID','CHR','UCSC_RefGene_Name','UCSC_RefGene_Group','GencodeCompV12_NAME',
            'GencodeCompV12_Accession', 'GencodeCompV12_Group','Methyl450_Loci']
    manifest = pd.read_csv(annotation_path, usecols = cols, skiprows = 7)
    # keep only CpG ('cg') probes; every other probe/control row is dropped
    manifest = manifest[manifest['IlmnID'].astype(str).str.startswith('cg')]
    # remove sites in X chromosome to prevent gender bias
    manifest = manifest[manifest['CHR'] != 'X']
    # remove sites mapped to no gene in either annotation
    has_gene = manifest['UCSC_RefGene_Name'].notnull() | manifest['GencodeCompV12_NAME'].notnull()
    manifest = manifest[has_gene].copy()

    # merge the GENCODE and UCSC gene mappings, then delete the UCSC columns
    manifest['GencodeCompV12_NAME'] = manifest['GencodeCompV12_NAME'].fillna(manifest['UCSC_RefGene_Name'])
    manifest['GencodeCompV12_Group'] = manifest['GencodeCompV12_Group'].fillna(manifest['UCSC_RefGene_Group'])
    manifest = manifest.drop(columns = ['CHR', 'UCSC_RefGene_Name','UCSC_RefGene_Group'])
    # keep only cpg sites that are 1500 bp or 200 bp around TSS
    near_tss = manifest['GencodeCompV12_Group'].str.contains('TSS200|TSS1500', na = False)
    manifest_tss = manifest[near_tss].reset_index(drop = True)
    manifest_tss.to_csv(manifest_tss_path)

    # keep, for each site, a gene that is associated with its TSS
    manifest_tss['Gene'] = [_pick_tss_gene(names, groups) for names, groups
                            in zip(manifest_tss['GencodeCompV12_NAME'], manifest_tss['GencodeCompV12_Group'])]
    cpg_gene_map = manifest_tss[['IlmnID','Gene']].copy() # final cpg map with cpg sites and target gene
    cpg_gene_map.to_csv(methylation_map_path)
    print ("Computed mapping of cpg sites to target genes")
    return cpg_gene_map

def _paired_file_lists(meta : pd.DataFrame) -> pd.Series:
    """Return one list of file names per case that has both a tumor (0) and a normal (1) sample."""
    grouped = meta.groupby("case_id")
    file_lists = grouped['file_name'].apply(list)
    has_both = grouped['sample_type'].agg(lambda types: {0, 1}.issubset(set(types))).astype(bool)
    return file_lists[has_both].reset_index(drop=True)


def filter_pairedsamples(meta450 : pd.DataFrame, metaepic : pd.DataFrame) -> (pd.DataFrame,pd.DataFrame):
    """Keep only cases that have at least one tumor and at least one normal array.

    A case with several arrays of only one sample type (e.g. two tumors) is not
    considered paired.

    :param meta450: 450K metadata with 'case_id', 'file_name' and 'sample_type' (0 tumor / 1 normal)
    :param metaepic: EPIC metadata with the same columns
    :return: two Series (450K, EPIC); each element is the list of file names of one paired case
    """
    meta450_paired = _paired_file_lists(meta450)
    metaepic_paired = _paired_file_lists(metaepic)
    print ("Filtered to get {} 450 and {} 850 matched arrays".format(len(meta450_paired),len(metaepic_paired)))
    return meta450_paired, metaepic_paired
        
def _type_label(code) -> str:
    """Render a sample-type code for a column name so that 0.0 (a float after a merge with NaNs) becomes '0'."""
    try:
        as_float = float(code)
    except (TypeError, ValueError):
        return str(code)
    return str(int(as_float)) if as_float.is_integer() else str(code)


def _load_batch(filenames : list, meta : pd.DataFrame, batch : str, read_kwargs : dict) -> pd.DataFrame:
    """Read the 'Beta_value' column of every file into one CpG x sample frame.

    Each column is named '<file>|<sample type>|<batch>'; the sample type is looked
    up in *meta* by 'file_name' (the first matching row wins).
    """
    sample_types = meta.drop_duplicates('file_name').set_index('file_name')['sample_type']
    columns = []
    for f in filenames:
        beta = pd.read_csv(f, sep = "\t", index_col = 0, **read_kwargs)['Beta_value']
        columns.append(beta.rename(f + "|" + _type_label(sample_types[f]) + "|" + batch))
    if not columns:
        return pd.DataFrame()
    return pd.concat(columns, axis = 1)


def concat_samples(filenames_450 : list, filenames_850 : list, paired=True) -> pd.DataFrame:
    """Create a CpG x sample DataFrame of all Illumina samples passed.

    Pass paired=True to keep only cases selected by filter_pairedsamples.
    Columns are named '<file>|<tumor>|<batch>' with tumor=0 for tumor and 1 for
    normal, and batch=0 for 450K and 1 for EPIC/850K arrays. Only CpG sites
    present on the EPIC arrays are kept; 450K-only probes are dropped.

    :param filenames_450: paths of the 450K array files
    :param filenames_850: paths of the EPIC/850K array files
    :param paired: restrict to tumor/normal-paired cases
    :return: DataFrame indexed by 'Composite Element REF'
    """
    print ("Getting metadata..")
    meta450, metaepic = get_metadata(filenames_450, filenames_850)

    print ("Obtaining list of paired samples..")
    if paired==True:
        meta450_paired, metaepic_paired = filter_pairedsamples(meta450, metaepic)
        filenames_450 = [j for i in meta450_paired for j in i]
        filenames_850 = [j for i in metaepic_paired for j in i]

    print ("Concatenating {} 450 and {} 850 Samples".format(len(filenames_450),len(filenames_850)))

    totaldf1 = _load_batch(filenames_450, meta450, '0', {})
    print ("Concatenated 450 arrays")
    # EPIC files are read with explicit column names, as before
    totaldf2 = _load_batch(filenames_850, metaepic, '1', {'names': ['Composite Element REF', 'Beta_value']})
    print ("Concatenated 850 arrays")

    # keep cpg sites in EPIC array - 450 probes not in epic are not considered.
    # Each column already carries its own label, so labels cannot get out of step with the data.
    totaldf = totaldf2.join(totaldf1, how = 'left')
    totaldf.index.name = 'Composite Element REF'
    print ("Concatenated all arrays")
    return totaldf

def plot_pca(dataframe : pd.DataFrame, tumor : list, title : str = None):
    """Scatter-plot the first two principal components of *dataframe*, coloured by *tumor*.

    The input frame is not modified.

    :param dataframe: numeric matrix with one row per sample and one column per feature
    :param tumor: one label per row, used as point colours (0 = tumor, 1 = normal elsewhere in this file)
    :param title: optional plot title, set before the figure is shown
    :return: DataFrame with 'PC1'/'PC2' coordinates, indexed by 'sample_type'
    """
    tumor = list(tumor)
    # standardise every feature to unit variance before PCA
    scaled_data = StandardScaler().fit_transform(dataframe)
    pca = PCA(2).fit(scaled_data)
    components = pca.transform(scaled_data)
    per_var = np.round(pca.explained_variance_ratio_ * 100, decimals = 1)
    pca_df = pd.DataFrame(components, columns = ['PC1', 'PC2'],
                          index = pd.Index(tumor, name = 'sample_type'))
    plt.scatter(pca_df.PC1, pca_df.PC2, c = tumor)
    plt.xlabel('PC1 - {0}%'.format(per_var[0]))
    plt.ylabel('PC2 - {0}%'.format(per_var[1]))
    if title is not None:
        plt.title(title)
    plt.show()
    return pca_df
    
def _fill_with_row_mean(df : pd.DataFrame) -> pd.DataFrame:
    """Replace every NaN in *df* with the mean of its row (CpG site) over df's columns."""
    if df.shape[1] == 0:
        return df.copy()
    row_mean = df.mean(axis=1)
    return pd.concat([column.fillna(row_mean) for _, column in df.items()], axis=1)


def _knn_fill(df : pd.DataFrame) -> pd.DataFrame:
    """KNN-impute *df* with sklearn's KNNImputer, keeping df's row and column labels."""
    if df.shape[1] == 0:
        return df.copy()
    # sklearn treats rows as observations, so neighbours are searched among CpG sites
    return pd.DataFrame(KNNImputer().fit_transform(df), index=df.index, columns=df.columns)


def impute(totaldf : pd.DataFrame, impute='mean', output_path : str = None) -> (pd.DataFrame,list):
    """Impute missing beta values separately within tumor and within normal samples.

    Rows that are missing in every sample are dropped first. Columns must be named
    '<file>|<type>|<batch>' with type 0 = tumor and 1 = normal.

    :param totaldf: CpG x sample matrix, e.g. from concat_samples
    :param impute: 'mean' (fill with the row mean over samples of the same type) or
                   'knn' (sklearn KNNImputer); 'dca' and 'pca' are reserved but not implemented
    :param output_path: where to write the gzipped TSV; None keeps the original per-method path
    :return: the imputed matrix (tumor columns first, then normal) and the integer batch id of each of its columns
    :raises ValueError: for an unknown method
    :raises NotImplementedError: for 'dca' or 'pca'
    """
    impute_type = ['mean', 'knn', 'dca', 'pca']
    if impute not in impute_type:
        raise ValueError("Invalid impute type. Expected one of: %s" % impute_type)
    if impute in ('dca', 'pca'):
        raise NotImplementedError("Imputation method '%s' is not implemented" % impute)

    # drop rows that have only NAs
    totaldf = totaldf.dropna(how='all')
    dftum = totaldf.loc[:, totaldf.columns.str.contains('|0|',regex=False)]
    dfnor = totaldf.loc[:, totaldf.columns.str.contains('|1|',regex=False)]

    # impute missing probes
    if impute=='knn':
        print ("Imputing using KNN")
        dfn = pd.concat([_knn_fill(dftum), _knn_fill(dfnor)], axis=1)
        default_path = r'C:\Users\HP\OneDrive\Documents\Bayes\methluad\knn_imputed_df.tsv.gz'
    else:
        print ("Imputing using mean")
        dfn = pd.concat([_fill_with_row_mean(dftum), _fill_with_row_mean(dfnor)], axis=1)
        default_path = r'C:\Users\HP\OneDrive\Documents\Bayes\methluad\mean_imputed_df.tsv.gz'

    # batch id is the last '|'-separated field of every column name
    batch = [int(colname.split('|')[-1]) for colname in dfn.columns]
    dfn.to_csv(default_path if output_path is None else output_path, sep='\t', compression='gzip')
    return dfn , batch

def batch_correct(df_imputed : pd.DataFrame, batch : list) -> pd.DataFrame:
    """Remove the 450K/EPIC batch effect with ComBat.

    Zeros are treated as missing values. Rows that are entirely missing are dropped
    first, then every row that still has a missing value, so only complete rows
    reach pycombat. Both counts are printed. To inspect the result, see plot_pca.

    :param df_imputed: CpG x sample matrix, e.g. the first output of impute
    :param batch: batch id of each column of df_imputed, in column order (0 = 450K, 1 = EPIC)
    :return: the batch-corrected matrix returned by pycombat
    :raises ValueError: if batch does not have exactly one entry per column
    """
    if len(batch) != df_imputed.shape[1]:
        raise ValueError("Expected one batch id per column ({} columns), got {}".format(
            df_imputed.shape[1], len(batch)))
    print ("Performing batch correction..")
    dfn = df_imputed.replace(0, np.nan)
    df = dfn.dropna(how='all')
    print ("Removing {} rows with only NAs and 0s in dataframe after imputation".format(len(dfn.index)-len(df.index)))
    complete = df.dropna()
    print ("Removing {} further rows that still contain NAs or 0s".format(len(df.index)-len(complete.index)))
    # NOTE: pycombat must be in scope (package import or the PyCombat cell of this notebook)
    return pycombat(complete, batch)

def get_dm_cpg(df_corrected : pd.DataFrame, de='subtract', paired=True, paired_files=None) -> pd.DataFrame:
    """Compute tumor-vs-normal differential methylation for every CpG site.

    Columns of df_corrected must be named '<file>|<type>|<batch>' with type
    0 = tumor and 1 = normal (as produced by concat_samples).

    :param df_corrected: CpG x sample matrix, e.g. the output of batch_correct
    :param de: 'subtract' (tumor - normal) or 'log2fc' (log2(tumor) - log2(normal)); used in both modes
    :param paired: if True, one column per case comparing the mean of that case's tumor files with
                   the mean of its normal files; otherwise a single Series comparing the mean of
                   all tumor columns with the mean of all normal columns
    :param paired_files: optional iterable of per-case lists of file paths. When None (paired mode only),
                         it is derived via get_metadata/filter_pairedsamples from the notebook-level
                         filenames_450 / filenames_850.
    :return: DataFrame (paired) or Series (unpaired) indexed like df_corrected
    :raises ValueError: if de is not a supported method
    """
    de_type = ['subtract', 'log2fc']
    if de not in de_type:
        raise ValueError("Invalid differential expression compute type. Expected one of: %s" % de_type)

    def difference(tumor_values, normal_values):
        if de == 'log2fc':
            return np.log2(tumor_values) - np.log2(normal_values)
        return tumor_values - normal_values

    columns = df_corrected.columns
    dftum = df_corrected.loc[:, columns.str.contains('|0|', regex=False)].copy()
    dfnor = df_corrected.loc[:, columns.str.contains('|1|', regex=False)].copy()
    # keep only the file path so columns can be matched against the per-case file lists
    dftum.columns = [name.split('|')[0] for name in dftum.columns]
    dfnor.columns = [name.split('|')[0] for name in dfnor.columns]

    if not paired:
        # compute differential methylation of each cpg site from the mean of all tumor/normal columns
        return difference(dftum.mean(axis=1), dfnor.mean(axis=1))

    if paired_files is None:
        meta450, metaepic = get_metadata(filenames_450, filenames_850)
        meta450_paired, metaepic_paired = filter_pairedsamples(meta450, metaepic)
        paired_files = pd.concat([meta450_paired, metaepic_paired], ignore_index=True)

    tumor_files = set(dftum.columns)
    normal_files = set(dfnor.columns)
    dm_case = []
    for case in paired_files:
        tumorcasefiles = [f for f in case if f in tumor_files]
        normalcasefiles = [f for f in case if f not in tumor_files and f in normal_files]
        # average the case's arrays of each type, then compare tumor with normal
        dm_case.append(difference(dftum[tumorcasefiles].mean(axis=1), dfnor[normalcasefiles].mean(axis=1)))
    if not dm_case:
        return pd.DataFrame(index=df_corrected.index)
    return pd.concat(dm_case, axis=1)
    
def _protein_coding_gene_names(release : int = None) -> set:
    """Return the non-empty names of all protein-coding genes known to pyensembl."""
    ensembl = pyensembl.EnsemblRelease() if release is None else pyensembl.EnsemblRelease(release)
    return {gene.gene_name for gene in ensembl.genes()
            if gene.biotype == 'protein_coding' and gene.gene_name}


def get_gene_sample_matrix(dmcpg_df : pd.DataFrame, methylation_map_path : str, release : int = None) -> pd.DataFrame:
    """Aggregate CpG-level values into a gene x sample matrix restricted to protein-coding genes.

    Each gene's row is the mean over the CpG sites mapped to it in the CSV
    written by get_methylation_map.

    :param dmcpg_df: CpG x sample (or CpG x case) values indexed by probe ID, or with a
                     'Composite Element REF' column; a Series is treated as a single column
    :param methylation_map_path: CSV with 'IlmnID' and 'Gene' columns
    :param release: Ensembl release passed to pyensembl (e.g. 104); None keeps pyensembl's default
    :return: DataFrame indexed by 'Gene'
    """
    cpg_gene_map = pd.read_csv(methylation_map_path, usecols = ['IlmnID','Gene'])
    cpg_gene_map = cpg_gene_map.rename(columns = {'IlmnID': 'Composite Element REF'})

    values = dmcpg_df.to_frame() if isinstance(dmcpg_df, pd.Series) else dmcpg_df
    if 'Composite Element REF' in values.columns:
        values = values.set_index('Composite Element REF')
    merged = cpg_gene_map.merge(values, how = 'left', left_on = 'Composite Element REF', right_index = True)
    # drop the probe-ID column so only numeric columns are averaged per gene
    gene_sample_matrix = merged.drop(columns = 'Composite Element REF').groupby('Gene').mean()

    # keep only protein coding genes from ensembl
    coding_genes = _protein_coding_gene_names(release)
    return gene_sample_matrix[gene_sample_matrix.index.isin(coding_genes)]

# Main Code 

In [None]:
# Input/output locations (machine-specific Windows paths).
annotation = r'C:\Users\HP\Downloads\infinium-methylationepic-v-1-0-b5-manifest-file.csv'  # manifest read by get_methylation_map
methylation_map = r'C:\Users\HP\OneDrive\Documents\Bayes\methluad\cpg_gene_mapEPIC.csv'  # CpG -> gene map written by get_methylation_map
meth_dir = r'C:\Users\HP\Downloads\methluad'  # its sub-directories are scanned by get_filenames
filenames_450, filenames_850 = get_filenames(meth_dir)

In [None]:
# Build the CpG -> gene map; it is written to `methylation_map` and the returned frame is not kept here.
get_methylation_map(annotation,methylation_map)
# Concatenate the arrays kept by filter_pairedsamples (paired=True) into one CpG x sample matrix and cache it.
totaldf = concat_samples(filenames_450,filenames_850, True) # takes time
totaldf.to_csv(r'C:\Users\HP\OneDrive\Documents\Bayes\methluad\totaldf.tsv.gz', sep='\t', compression='gzip')

In [None]:
# Reload the cached matrix so the slow concatenation above can be skipped on later runs.
totaldf = pd.read_csv(r'C:\Users\HP\OneDrive\Documents\Bayes\methluad\totaldf.tsv.gz', sep='\t', compression='gzip')
totaldf = totaldf.set_index('Composite Element REF', drop=True)
totaldf

In [None]:
# Mean-impute within tumor and within normal samples, then ComBat-correct the 450K/EPIC batch effect.
# NOTE(review): pycombat must be in scope; its import near the top is commented out, so presumably the PyCombat cell at the end must run first.
df_imputed_mean, batch = impute(totaldf, 'mean')
# NOTE(review): the line below passes 'knn' where batch_correct expects a batch list; a KNN run needs impute(totaldf, 'knn') first.
# df_corrected_knn = batch_correct(totaldf, 'knn')
df_corrected_mean = batch_correct(df_imputed_mean, batch)
df_corrected_mean.to_csv(r'C:\Users\HP\OneDrive\Documents\Bayes\methluad\df_corrected_mean.tsv.gz', sep='\t', compression='gzip')

In [None]:
# Reload the batch-corrected matrix written above.
# NOTE(review): assumes pycombat's output kept the 'Composite Element REF' index name so the CSV has that column — confirm.
df_corrected = pd.read_csv(r'C:\Users\HP\OneDrive\Documents\Bayes\methluad\df_corrected_mean.tsv.gz', sep='\t', compression='gzip')
df_corrected = df_corrected.set_index('Composite Element REF', drop=True)
df_corrected

In [None]:
# NOTE(review): `cpg_de` is not defined in any visible cell; presumably it should come from get_dm_cpg(df_corrected, ...) — confirm the intended `de`/`paired` settings.
genesampledf = get_gene_sample_matrix(cpg_de, methylation_map)
# genesampledf.to_csv(r'C:\Users\HP\OneDrive\Documents\Bayes\methluad\genesampledf.tsv.gz', sep='\t', compression='gzip')
# Average over all columns -> one value per gene (a Series).
# NOTE(review): later cells index genesampledf[0] / genesampledf[1] and call .mean(axis=1), which expects a per-sample frame instead.
genesampledf = genesampledf.mean(axis=1).dropna()
genesampledf

# Identify differentially methylated genes

In [None]:
# get list of top 1000 LUAD genes from disgenet (first 1000 rows of the export; assumed to be ranked — confirm)
target = pd.read_csv(r'C:\Users\HP\OneDrive\Documents\Bayes\methluad\C0152013_disease_gda_summary.tsv', 
                     usecols = ['Gene'], sep = '\t', nrows = 1000).Gene.tolist()
# Additional gene lists; presumably down-/up-regulated sets as the names suggest — confirm their source.
mirdown = pd.read_csv(r'C:\Users\HP\Downloads\top_1000.csv', usecols = ['genes']).genes.tolist()
mirup = pd.read_csv(r'C:\Users\HP\Downloads\Positive_aggregate_genes - Sheet1.csv', usecols = ['Gene']).Gene.tolist()
# target =  pd.read_csv(r'C:\Users\HP\Downloads\DisGeNET_luad.csv', usecols = ['Gene Names'])['Gene Names'].tolist()
print (len(target))

In [None]:
# Genes whose mean over the columns labelled 0 (tumor) is above 0.80 (hyper) or below 0.20 (hypo).
# NOTE(review): .abs() only matters if the values can be negative (e.g. differences) — confirm what genesampledf holds here.
tumormean = genesampledf[0].mean(axis=1)
hypermeth = tumormean[tumormean.abs()> 0.80]
hypometh = tumormean[tumormean.abs()< 0.20]
meth = set(hypermeth.index).union(set(hypometh.index))
mir = set(mirdown).union(set(mirup))
print (len(meth), len(mir))
# Overlap of the methylation genes, the other gene lists, and their union with the DisGeNET list
epigenes = meth.union(mir)
commonmeth = meth.intersection(set(target))
print (len(commonmeth))
commonmir = mir.intersection(set(target))
print (len(commonmir))
common = epigenes.intersection(set(target))
print (len(common))

1. Using beta-value cutoffs in tumor samples only (0.8/0.2 in the cell above; 0.75/0.25 and 0.7/0.25 in the cells below)

In [None]:
# NOTE(review): `meta` is not defined in any visible cell.
genesampledf.columns = meta['tumor']
tumormean = genesampledf[0].mean(axis=1)  # per-gene mean over columns labelled 0 (tumor)
hypermeth = tumormean[tumormean.abs()> 0.75]
hypometh = tumormean[tumormean.abs()< 0.25]
print (len(hypermeth), len(hypometh))

# Overlap with the full DisGeNET list (top 1k)
# NOTE(review): the printed labels say 0.7, but the cutoff used above is 0.75.
common_hyper1k = list(set(hypermeth.index).intersection(set(target)))
common_hypo1k = list(set(hypometh.index).intersection(set(target)))
print ("beta threshold 0.7 and 0.25 for top 1k genes--", len(common_hyper1k),",", len(common_hypo1k))
print (common_hyper1k)

# Overlap with the first 500 DisGeNET genes
common_hyper500 = list(set(hypermeth.index).intersection(set(target[:500])))
common_hypo500 = list(set(hypometh.index).intersection(set(target[:500])))
print ("beta threshold 0.7 and 0.25 for top 500 genes--", len(common_hyper500),",", len(common_hypo500))

In [None]:
# test the difference before batch correction
# NOTE(review): `genesample` and `meta` are not defined in any visible cell.
genesample.columns = meta['tumor']
tumormean = genesample[0].mean(axis=1)  # per-gene mean over columns labelled 0 (tumor)
hypermeth = tumormean[tumormean.abs()> 0.7]
hypometh = tumormean[tumormean.abs()< 0.25]
print (len(hypermeth), len(hypometh))

# NOTE(review): the printed labels say 0.8 and 0.2, but the cutoffs used above are 0.7 and 0.25.
common_hyper1k = list(set(hypermeth.index).intersection(set(target)))
common_hypo1k = list(set(hypometh.index).intersection(set(target)))
print ("beta threshold 0.8 and 0.2 for top 1k genes--", len(common_hyper1k),",", len(common_hypo1k))

common_hyper500 = list(set(hypermeth.index).intersection(set(target[:500])))
common_hypo500 = list(set(hypometh.index).intersection(set(target[:500])))
print ("beta threshold 0.8 and 0.2 for top 500 genes--", len(common_hyper500),",", len(common_hypo500))
print (common_hyper500)

2. Using a log2 fold-change (log2fc) cutoff of ±0.5 between tumor and normal samples collectively

In [None]:
normalmean = genesampledf[1].mean(axis=1)  # per-gene mean over columns labelled 1 (normal)

# compute log2 fold changes (for each sample, divide by mean normal values. Then compute mean across samples)
fc = np.log2(genesampledf[0].divide(normalmean, axis=0))
fc_nan = fc.replace([np.inf, -np.inf], np.nan)
print ("Got {} invalid values after computing log2 fold changes".format(fc_nan.isnull().sum().sum()))
# drops every gene with an invalid value in any tumor column (NaN, or inf from division by 0 or 0+eta)
ratio = fc_nan.dropna(axis=0) # remove NaN and inf (from division by 0 or 0+eta)
# calculate mean of all log2fc samples and save to directory
final_mean_meth=ratio.mean(axis=1)
final_mean_meth.to_csv(r'C:\Users\HP\OneDrive\Documents\Bayes\methluad\finalmeth.tsv', sep='\t')
print (final_mean_meth)

In [None]:
# Hyper-/hypomethylated genes: mean log2 fold change above 0.5 or below -0.5.
hypermeth = final_mean_meth[final_mean_meth> 0.5]
hypometh = final_mean_meth[final_mean_meth< -0.5]
print (len(hypermeth), len(hypometh))

# NOTE(review): the printed labels mention beta thresholds 0.8/0.2, but this cell uses log2fc cutoffs of ±0.5.
common_hyper1k = list(set(hypermeth.index).intersection(set(target)))
common_hypo1k = list(set(hypometh.index).intersection(set(target)))
print ("beta threshold 0.8 and 0.2 for top 1k genes--", len(common_hyper1k),",", len(common_hypo1k))
print (common_hyper1k)

common_hyper500 = list(set(hypermeth.index).intersection(set(target[:500])))
common_hypo500 = list(set(hypometh.index).intersection(set(target[:500])))
print ("beta threshold 0.8 and 0.2 for top 500 genes--", len(common_hyper500),",", len(common_hypo500))

3. Using a log2 fold-change (log2fc) cutoff of 0.5 between tumor and normal samples case-wise

# PyCombat

In [None]:
# this is the code from pycombat. run in case of too many dependency errors
import numpy as np
from math import exp
from multiprocessing import Pool, cpu_count
from functools import partial
import mpmath as mp
import pandas as pd

def model_matrix(info, intercept=True, drop_first=True):
    """Build a dummy-coded design matrix from batch (or covariate) labels.

    Arguments:
        info {list} -- one list of labels, or a list of such lists (one per covariate)
        intercept {bool} -- append a constant column of 1.0
        drop_first {bool} -- drop the first level of each covariate (pandas get_dummies)
    Returns:
        matrix -- float numpy array, one row per sample
    """
    covariates = info if isinstance(info[0], list) else [info]
    labelled = {f"col{idx}": [str(value) for value in values]
                for idx, values in enumerate(covariates)}
    design = pd.get_dummies(pd.DataFrame(labelled), drop_first=drop_first, dtype=float)
    if intercept:
        design["intercept"] = 1.0
    return design.to_numpy()


def all_1(list_of_elements):
    """Check whether every element equals 1.

    Plain Python lists are accepted as well as numpy arrays (the previous
    version called `.all()` on `list == 1`, which fails for lists).

    Arguments:
        list_of_elements {list} -- list or array of elements
    Returns:
        bool -- True iff all elements are 1 (vacuously True when empty)
    """
    return np.all(np.asarray(list_of_elements) == 1)


# aprior and bprior are useful to compute "hyper-prior values"
# -> prior parameters used to estimate the prior gamma distribution for multiplicative batch effect
# aprior - calculates empirical hyper-prior values

def compute_prior(prior, gamma_hat, mean_only):
    """Compute one of the empirical hyper-prior values of ComBat.

    With m = mean(gamma_hat) and s2 = var(gamma_hat):
    'a' -> (2*s2 + m^2) / s2 and 'b' -> (m*s2 + m^3) / s2.

    Arguments:
        prior {char} -- 'a' or 'b', the prior to compute
        gamma_hat {matrix} -- batch-effect estimates the prior is fitted to
        mean_only {bool} -- True iff mean_only selected
    Returns:
        float -- the prior (1 when mean_only; None for any other prior name)
    """
    if mean_only:
        return 1
    m = np.mean(gamma_hat)
    s2 = np.var(gamma_hat)
    if prior == 'a':
        numerator = 2*s2 + m*m
    elif prior == 'b':
        numerator = m*s2 + m*m*m
    else:
        return None
    return numerator / s2


def postmean(g_bar, d_star, t2_n, t2_n_g_hat):
    """Posterior estimate of the additive batch effect.

    Returns (t2_n_g_hat + d_star * g_bar) / (t2_n + d_star), element-wise.

    Arguments:
        g_bar {matrix} -- additive batch effect
        d_star {matrix} -- multiplicative batch effect
        t2_n {matrix} -- weight term applied to g_hat by the caller
        t2_n_g_hat {matrix} -- t2_n multiplied by g_hat
    Returns:
        matrix -- estimated additive batch effect
    """
    numerator = t2_n_g_hat + d_star*g_bar
    denominator = np.asarray(t2_n + d_star)
    return np.divide(numerator, denominator)


def postvar(sum2, n, a, b):
    """Posterior estimate of the multiplicative batch effect.

    Returns (0.5 * sum2 + b) / (0.5 * n + a - 1), element-wise.

    Arguments:
        sum2 {vector} -- sums of squared residuals
        n {[type]} -- sample counts matching sum2
        a {float} -- aprior
        b {float} -- bprior
    Returns:
        matrix -- estimated multiplicative batch effect
    """
    numerator = np.multiply(0.5, sum2) + b
    denominator = np.multiply(0.5, n) + a - 1
    return np.divide(numerator, denominator)


def it_sol(sdat, g_hat, d_hat, g_bar, t2, a, b, conv=0.0001, exit_iteration=10e5):
    """iterative solution for Empirical Bayesian method

    Alternates postmean (additive effect) and postvar (multiplicative effect)
    updates until the largest relative change of either estimate is at most
    ``conv`` or ``exit_iteration`` iterations have run.

    Arguments:
        sdat {matrix} -- data of one batch (presumably genes x samples of that batch — TODO confirm)
        g_hat {matrix} -- average additive batch effect
        d_hat {matrix} -- average multiplicative batch effect
        g_bar {matrix} -- additive batch effect
        t2 {matrix} -- prior term multiplied by the row lengths (presumably the additive-effect prior variance — confirm)
        a {float} -- aprior
        b {float} -- bprior
    Keyword Arguments:
        conv {float} -- convergence criterion (default: {0.0001})
        exit_iteration {float} -- maximum number of iterations before exit (default: {10e5}, i.e. 1,000,000)
    Returns:
        array list -- estimated additive and multiplicative batch effect
    """

    n = [len(i) for i in np.asarray(sdat)]  # length of each row of sdat
    t2_n = np.multiply(t2, n)
    t2_n_g_hat = np.multiply(t2_n, g_hat)
    g_old = np.ndarray.copy(g_hat)
    d_old = np.ndarray.copy(d_hat)
    change = 1
    count = 0  # number of steps needed (for diagnostic only)
    # convergence criteria, if new-old < conv, then stop
    while (change > conv) and (count < exit_iteration):
        g_new = postmean(g_bar, d_old, t2_n, t2_n_g_hat)  # updated additive batch effect
        # NOTE(review): only g_new[0][0] is broadcast across the columns here — confirm this indexing matches g_new's shape
        sum2 = np.sum(np.asarray(np.square(
            sdat-np.outer(g_new[0][0], np.ones(np.ma.size(sdat, axis=1))))), axis=1)
        d_new = postvar(sum2, n, a, b)  # updated multiplicative batch effect
        # NOTE(review): relative change divides by g_old / d_old, so a zero estimate yields inf/NaN
        change = max(np.amax(np.absolute(g_new-np.asarray(g_old))/np.asarray(g_old)), np.amax(
            np.absolute(d_new-d_old)/d_old))  # maximum relative difference between new and old estimate
        g_old = np.ndarray.copy(g_new)  # save value for g
        d_old = np.ndarray.copy(d_new)  # save value for d
        count += 1
    adjust = np.asarray([g_new, d_new])
    return(adjust)  # remove parenthesis in returns

# int_eprior - Monte Carlo integration function to find nonparametric adjustments
# Johnson et al (Biostatistics 2007, supp.mat.) show that we can estimate the multiplicative and additive batch effects with an integral
# This integral is numerically computed through Monte Carlo inegration (iterative method)


def int_eprior(sdat, g_hat, d_hat, precision):
    """ int_eprior - Monte Carlo integration function to find nonparametric adjustments
        Johnson et al (Biostatistics 2007, supp.mat.) show that we can estimate the multiplicative and additive batch effects with an integral
        This integral is numerically computed through Monte Carlo integration (iterative method).

        For each row i of sdat, the estimates of all other rows (entry i removed
        from g_hat / d_hat) are averaged, weighted by the normal likelihood LH of
        row i's data under each of those estimates.
    Arguments:
        sdat {matrix} -- data matrix
        g_hat {matrix} -- average additive batch effect
        d_hat {matrix} -- average multiplicative batch effect
        precision {float} -- mpmath precision to use; None selects plain numpy floats
    Returns:
        array list -- estimated additive and multiplicative batch effect
    """
    g_star = []
    d_star = []
    # use this variable to only print error message once if approximation used
    test_approximation = 0
    for i in range(len(sdat)):
        # additive batch effects of the other rows (entry i removed)
        g = np.asarray(np.delete(np.transpose(g_hat), i))
        # multiplicative batch effects of the other rows (entry i removed)
        d = np.asarray(np.delete(np.transpose(d_hat), i))
        # NOTE(review): np.repeat(..., axis=1) below needs x to be 2-D, i.e. sdat[i] a 1 x n row (e.g. np.matrix) — confirm
        x = np.asarray(np.transpose(sdat[i]))
        n = len(x)
        j = [1]*n
        dat = np.repeat(x, len(np.transpose(g)), axis=1)
        resid2 = np.square(dat-g)
        sum2 = np.dot(np.transpose(resid2), j)  # sum of squared residuals per candidate estimate
        # /begin{handling high precision computing}
        temp_2d = 2*d
        if (precision == None):  # NOTE(review): `precision is None` is the idiomatic test
            LH = np.power(1/(np.pi*temp_2d), n/2)*np.exp(np.negative(sum2)/(temp_2d))

        else:  # only if precision parameter informed
            # increase the precision of the computing (if negative exponential too close to 0)
            # NOTE(review): this sets an attribute on the mpmath module; mpmath's working precision is normally mpmath.mp.dps — confirm it takes effect
            mp.dps = precision
            buf_exp = np.array(list(map(mp.exp, np.negative(sum2)/(temp_2d))))
            buf_pow = np.array(list(map(partial(mp.power, y=n/2), 1/(np.pi*temp_2d))))
            #print(buf_exp.dtype, buf_pow.dtype)
            LH = buf_pow*buf_exp  # likelihood
        # /end{handling high precision computing}
        LH = np.nan_to_num(LH)  # corrects NaNs in likelihood
        if np.sum(LH) == 0 and test_approximation == 0:
            test_approximation = 1  # this message won't appear again
            print("###\nValues too small, approximation applied to avoid division by 0.\nPrecision mode can correct this problem, but increases computation time.\n###")

        if np.sum(LH) == 0: # correction for LH full of 0.0
            # replace zero likelihoods by exp(-745) so the weighted average below does not divide by 0
            LH[LH == 0] = np.exp(-745)
            g_star.append(np.sum(g*LH)/np.sum(LH))
            d_star.append(np.sum(d*LH)/np.sum(LH))
        else:
            g_star.append(np.sum(g*LH)/np.sum(LH))
            d_star.append(np.sum(d*LH)/np.sum(LH))
    adjust = np.asarray([np.asarray(g_star), np.asarray(d_star)])
    return(adjust)


def param_fun(i, s_data, batches, mean_only, gamma_hat, gamma_bar, delta_hat, t2, a_prior, b_prior):
    """Parametric estimation of the batch effects of batch *i*.

    In mean-only mode the additive effect comes straight from postmean and the
    multiplicative effect is a list of 1s; otherwise both come from it_sol
    applied to the columns of s_data that belong to batch *i*.

    Arguments:
        i {int} -- batch index
        s_data {matrix} -- standardised data
        batches {list list} -- list of list of batches' elements
        mean_only {bool} -- True iff mean_only selected
        gamma_hat {matrix} -- average additive batch effect
        gamma_bar {matrix} -- estimated additive batch effect
        delta_hat {matrix} -- average multiplicative batch effect
        t2 {matrix} -- per-batch prior term
        a_prior {float} -- aprior
        b_prior {float} -- bprior
    Returns:
        array list -- [additive effect, multiplicative effect]
    """
    if not mean_only:
        # columns of batch i, with the original row orientation restored
        batch_data = np.transpose(np.transpose(s_data)[batches[i]])
        estimates = it_sol(batch_data, gamma_hat[i], delta_hat[i], gamma_bar[i],
                           t2[i], a_prior[i], b_prior[i])
        return [estimates[0], estimates[1]]
    scaled_t2 = np.multiply(t2[i], 1)
    weighted_gamma = np.multiply(scaled_t2, gamma_hat[i])
    additive = postmean(gamma_bar[i], 1, scaled_t2, weighted_gamma)
    return [additive, [1]*len(s_data)]

def nonparam_fun(i, mean_only, delta_hat, s_data, batches, gamma_hat, precision):
    """Non-parametric estimation of the batch effects of batch *i* via int_eprior.

    In mean-only mode delta_hat[i] is overwritten in place with a list of 1s
    before the estimation, so the caller's delta_hat is modified.

    Arguments:
        i {int} -- batch index
        mean_only {bool} -- True iff mean_only selected
        delta_hat {matrix} -- estimated multiplicative batch effect
        s_data {matrix} -- standardised data
        batches {list list} -- list of list of batches' elements
        gamma_hat {matrix} -- estimated additive batch effect
        precision {float} -- level of precision for precision computing
    Returns:
        array list -- [additive effect, multiplicative effect]
    """
    if mean_only:
        delta_hat[i] = [1]*len(delta_hat[i])
    batch_data = np.transpose(np.transpose(s_data)[batches[i]])
    estimates = int_eprior(batch_data, gamma_hat[i], delta_hat[i], precision)
    return [estimates[0], estimates[1]]

############
# pyComBat #
############


def check_mean_only(mean_only):
    """Print a notice when the mean-only variant of ComBat is selected.

    Arguments:
        mean_only {boolean} -- user's choice about mean_only
    Returns:
        None -- the function only prints
    """
    if not (mean_only == True):
        return
    print("Using mean only version")
        print("Using mean only version")


def define_batchmod(batch):
    """Build the batch design matrix: one dummy column per batch level, no intercept, no dropped level.

    Arguments:
        batch {list} -- list of batch id
    Returns:
        batchmod {matrix} -- model matrix for batches
    """
    return model_matrix(list(batch), intercept=False, drop_first=False)


def check_ref_batch(ref_batch, batch, batchmod):
    """check ref_batch option and treat it if needed
    Arguments:
        ref_batch {int} -- the reference batch label, or None for no reference batch
        batch {list} -- list of batch id
        batchmod {matrix} -- model matrix related to batches (one column per batch level)
    Returns:
        ref {int} -- position of the reference batch among the sorted unique batch
            levels (None when ref_batch is None)
        batchmod {matrix} -- batchmod with the reference batch column set to 1
            (the matrix is modified in place and returned)
    Raises:
        ValueError -- if ref_batch is not one of the levels of batch
    """
    if ref_batch is None:
        return(None, batchmod)  # default settings
    if ref_batch not in batch:
        # raise instead of exit(0): exit reports success and kills the interpreter/kernel
        raise ValueError("Reference level ref.batch must be one of the levels of batch.")
    print("Using batch "+str(ref_batch) +
          " as a reference batch.")
    # column of batchmod that corresponds to the reference batch
    ref = np.where(np.unique(batch) == ref_batch)[0][0]
    # updates batchmod with reference
    batchmod[:, ref] = 1
    return(ref, batchmod)


def treat_batches(batch):
    """treat batches
    Arguments:
        batch {list} -- batch label of every sample
    Returns:
        n_batch {int} -- number of batches
        batches {list of int32 arrays} -- for each unique batch level (in sorted order),
            the positions of the samples that belong to it
        n_batches {int list} -- number of samples in each batch
        n_array {int} -- total number of samples
    """
    batch = pd.Series(batch)
    # compute the sorted unique levels once instead of on every loop iteration
    levels = np.unique(batch)
    n_batch = len(levels)  # number of batches
    print("Found "+str(n_batch)+" batches.")
    # list of arrays, contains the list of positions for each batch
    batches = [np.where(batch == level)[0].astype(np.int32) for level in levels]
    n_batches = [len(positions) for positions in batches]
    if 1 in n_batches:
        # no variance if only one sample in a batch - mean_only should be used
        print("\nOne batch has only one sample, try setting mean_only=True.\n")
    n_array = sum(n_batches)
    return(n_batch, batches, n_batches, n_array)


def treat_covariates(batchmod, mod, ref, n_batch):
    """Build the full design matrix (batch + covariates) and check that it is estimable.
    Arguments:
        batchmod {matrix} -- model matrix for batch (one column per batch level)
        mod {list} -- covariate(s) for each sample; an empty list (or None) means
            no covariates
        ref {int} -- column index of the reference batch, or None
        n_batch {int} -- number of batches
    Returns:
        design {matrix} -- transposed design matrix: one row per retained column, batch
            rows first, then covariate rows. Columns flagged by all_1 are dropped,
            except the reference batch column.
    Raises:
        ValueError -- if the design matrix is rank deficient because a covariate is
            confounded with batch or the covariates are confounded with each other
    """
    # design matrix for sample conditions; len() works for lists, arrays and Series
    # (`mod == []` is ambiguous/broken for numpy and pandas inputs)
    if mod is None or len(mod) == 0:
        design = batchmod
    else:
        mod_matrix = model_matrix(mod, intercept=True)
        design = np.concatenate((batchmod, mod_matrix), axis=1)
    # columns flagged by all_1 (presumably constant all-ones columns such as an intercept)
    check = list(map(all_1, np.transpose(design)))
    if ref is not None:
        check[ref] = False  # the reference is kept even though its column is all 1s
    design = design[:, ~np.array(check)]
    design = np.transpose(design)

    print("Adjusting for "+str(len(design)-len(np.transpose(batchmod))) +
          " covariate(s) or covariate level(s).")

    # if the matrix is not invertible, report which kind of confounding occurs
    if np.linalg.matrix_rank(design) < len(design):
        if len(design) == n_batch + 1:  # case 1: covariate confounded with a batch
            raise ValueError(
                "The covariate is confounded with batch. Remove the covariate and rerun pyComBat.")
        if len(design) > n_batch + 1:
            # rows after the batch rows are the covariates (R: design[, -(1:n.batch)])
            covariates = design[n_batch:]
            if np.linalg.matrix_rank(covariates) < len(covariates):
                # case 2: covariates confounded with each other
                raise ValueError(
                    "The covariates are confounded! Please remove one or more of the covariates so the design is not confounded.")
            # case 3: at least a covariate confounded with a batch
            raise ValueError(
                "At least one covariate is confounded with batch. Please remove confounded covariates and rerun pyComBat")
    return(design)


def check_NAs(dat):
    """Detect NaN values in the data matrix.

    The whole matrix is summed: any NaN entry propagates to a NaN total, which
    avoids allocating a boolean mask the size of the data. (A matrix holding both
    +inf and -inf also sums to NaN and is reported here as well.)
    Arguments:
        dat {matrix} -- the data matrix
    Returns:
        NAs {bool} -- True iff NaNs were detected in the data matrix
    """
    has_nan = np.isnan(np.sum(dat))
    if not has_nan:
        return has_nan
    print("Found missing data values. Please remove all missing values before proceeding with pyComBat.")
    return has_nan


def calculate_mean_var(design, batches, ref, dat, NAs, ref_batch, n_batches, n_batch, n_array):
    """ calculates the Normalisation factors
    Arguments:
        design {matrix} -- transposed model matrix for all covariates (one row per column)
        batches {int list} -- for each batch, the positions (columns) of its samples
        ref {int} -- index of the reference batch (used only when ref_batch is not None)
        dat {matrix} -- data matrix (genes x samples)
        NAs {bool} -- presence of NaNs in the data matrix
        ref_batch {int} -- reference batch label, or None
        n_batches {int list} -- list of batches lengths
        n_batch {int} -- number of batches
        n_array {int} -- total size of dataset
    Returns:
        B_hat {matrix} -- regression coefficients corresponding to the design matrix
        grand_mean {matrix} -- mean for each gene (batch-size weighted, or the reference
            batch coefficients when ref_batch is given)
        var_pooled {matrix} -- pooled variance for each gene
    Raises:
        ValueError -- if NAs is True (missing values are not supported)
    """
    print("Standardizing Data across genes.")
    if NAs:
        # without this guard, B_hat would be undefined below (UnboundLocalError)
        raise ValueError("NaN value is not accepted")

    design_t = np.transpose(design)
    # B_hat is the vector of regression coefficients corresponding to the design matrix
    B_hat = np.linalg.solve(np.dot(design, design_t), np.dot(design, np.transpose(dat)))

    if ref_batch is not None:
        # mean and variance are taken from the reference batch only
        grand_mean = np.transpose(B_hat[ref])
        ref_dat = np.transpose(np.transpose(dat)[batches[ref]])
        var_pooled = np.dot(np.square(ref_dat - np.transpose(np.dot(
            design_t[batches[ref]], B_hat))), [1/n_batches[ref]]*n_batches[ref])
    else:
        # mean weighted by batch size, variance of residuals over all samples
        grand_mean = np.dot(np.transpose(
            [i / n_array for i in n_batches]), B_hat[0:n_batch])
        var_pooled = np.dot(np.square(
            dat - np.transpose(np.dot(design_t, B_hat))), [1/n_array]*n_array)

    return(B_hat, grand_mean, var_pooled)


def calculate_stand_mean(grand_mean, n_array, design, n_batch, B_hat):
    """ transform the format of the mean for substraction
    Arguments:
        grand_mean {matrix} -- mean for each gene
        n_array {int} -- total size of dataset
        design {matrix} -- transposed design matrix for all covariates including batch
            (batch rows first); not modified
        n_batch {int} -- number of batches
        B_hat {matrix} -- regression coefficients corresponding to the design matrix
    Returns:
        stand_mean {matrix} -- standardised mean (genes x samples), i.e. grand_mean
            repeated for each sample plus the covariate contribution
    """
    # np.asmatrix instead of np.mat: the np.mat alias was removed in NumPy 2.0
    stand_mean = np.dot(np.transpose(np.asmatrix(grand_mean)), np.asmatrix([1]*n_array))
    # corrects the mean with design matrix information
    if design is not None:
        covariate_design = design.copy()
        covariate_design[0:n_batch] = 0  # keep only the covariate rows' contribution
        stand_mean = stand_mean + \
            np.transpose(np.dot(np.transpose(covariate_design), B_hat))
    return(stand_mean)


def standardise_data(dat, stand_mean, var_pooled, n_array):
    """standardise the data: substract mean and divide by standard deviation
    Arguments:
        dat {matrix} -- data matrix (genes x samples)
        stand_mean {matrix} -- standardised mean (genes x samples)
        var_pooled {matrix} -- pooled variance for each gene
        n_array {int} -- total size of dataset
    Returns:
        s_data {matrix} -- standardised data matrix
    """
    # per-gene standard deviation repeated across all samples;
    # np.asmatrix instead of np.mat: the np.mat alias was removed in NumPy 2.0
    std_matrix = np.dot(np.transpose(np.asmatrix(np.sqrt(var_pooled))), np.asmatrix([1]*n_array))
    s_data = (dat - stand_mean) / std_matrix
    return(s_data)


def fit_model(design, n_batch, s_data, batches, mean_only, par_prior, precision, ref_batch, ref, NAs):
    """Fit the location/scale model and estimate the empirical-Bayes batch effects.
    Arguments:
        design {matrix} -- transposed design matrix (batch rows first, then covariates)
        n_batch {int} -- number of batches
        s_data {matrix} -- standardised data matrix (genes x samples)
        batches {list} -- for each batch, the positions (columns) of its samples
        mean_only {bool} -- True iff only the additive batch effect is adjusted
        par_prior {bool} -- True for parametric, False for non-parametric estimation
        precision {float} -- level of precision passed to the non-parametric estimation
        ref_batch {int} -- reference batch label, or None
        ref {int} -- index of the reference batch (used only when ref_batch is not None)
        NAs {bool} -- presence of NaNs in the data matrix
    Returns:
        gamma_star {matrix} -- adjusted additive batch effect (n_batch x genes)
        delta_star {matrix} -- adjusted multiplicative batch effect (n_batch x genes)
        batch_design {matrix} -- the batch rows of the design matrix
    Raises:
        ValueError -- if NAs is True (missing values are not supported)
    """
    print("Fitting L/S model and finding priors.")
    if NAs:
        # without this guard, gamma_hat would be undefined below (UnboundLocalError)
        raise ValueError("NaN value is not accepted")

    # fraction of design matrix related to batches
    batch_design = design[0:n_batch]
    n_genes = len(s_data)

    # gamma_hat is the least-squares estimate of the additive batch effect
    gamma_hat = np.linalg.solve(np.dot(batch_design, np.transpose(batch_design)),
                                np.dot(batch_design, np.transpose(s_data)))

    # delta_hat is the vector of estimated multiplicative batch effect
    if mean_only:
        # no variance if mean_only == True; one independent array of 1s per batch
        delta_hat = [np.asarray([1]*n_genes) for _ in batches]
    else:
        # per-gene variance of the standardised data within each batch
        delta_hat = [np.squeeze(np.asarray(np.transpose(np.transpose(s_data)[positions]).var(axis=1)))
                     for positions in batches]

    gamma_bar = list(map(np.mean, gamma_hat))  # vector of means for gamma_hat
    t2 = list(map(np.var, gamma_hat))  # vector of variances for gamma_hat

    # hyper priors computed from delta_hat by compute_prior
    a_prior = list(
        map(partial(compute_prior, 'a', mean_only=mean_only), delta_hat))
    b_prior = list(
        map(partial(compute_prior, 'b', mean_only=mean_only), delta_hat))

    # initialise gamma and delta for parameters estimation
    gamma_star = np.empty((n_batch, n_genes))
    delta_star = np.empty((n_batch, n_genes))

    if par_prior:
        # use param_fun function for parametric adjustments (cf. function definition)
        print("Finding parametric adjustments.")
        results = list(map(partial(param_fun,
                                   s_data=s_data,
                                   batches=batches,
                                   mean_only=mean_only,
                                   gamma_hat=gamma_hat,
                                   gamma_bar=gamma_bar,
                                   delta_hat=delta_hat,
                                   t2=t2,
                                   a_prior=a_prior,
                                   b_prior=b_prior), range(n_batch)))
    else:
        # use nonparam_fun for non-parametric adjustments (cf. function definition)
        print("Finding nonparametric adjustments")
        results = list(map(partial(nonparam_fun, mean_only=mean_only, delta_hat=delta_hat,
                                   s_data=s_data, batches=batches, gamma_hat=gamma_hat, precision=precision), range(n_batch)))

    for i in range(n_batch):  # store the results in gamma/delta_star
        gamma_star[i], delta_star[i] = results[i][0], results[i][1]

    # the reference batch is not supposed to be modified; `is not None` so that a
    # reference label of 0 (or any other falsy label) is still honoured
    if ref_batch is not None:
        gamma_star[ref] = [0] * n_genes
        delta_star[ref] = [1] * n_genes

    return(gamma_star, delta_star, batch_design)


def adjust_data(s_data, gamma_star, delta_star, batch_design, n_batches, var_pooled, stand_mean, n_array, ref_batch, ref, batches, dat):
    """Adjust the data -- corrects for estimated batch effects
    Arguments:
        s_data {matrix} -- standardised data matrix (genes x samples); NOTE: it is
            modified in place (the correction is written through a transposed view)
        gamma_star {matrix} -- estimated additive batch effect (n_batch x genes)
        delta_star {matrix} -- estimated multiplicative batch effect (n_batch x genes)
        batch_design {matrix} -- batch rows of the design matrix
        n_batches {int list} -- list of batches lengths
        var_pooled {matrix} -- pooled variance for each gene
        stand_mean {matrix} -- standardised mean (genes x samples)
        n_array {int} -- total size of dataset
        ref_batch {int} -- reference batch label, or None
        ref {int} -- index of the reference batch (used only when ref_batch is not None)
        batches {int list} -- for each batch, the positions (columns) of its samples
        dat {matrix} -- original data matrix (genes x samples), used to restore the
            reference batch
    Returns:
        bayes_data [matrix] -- data adjusted for correction of batch effects (genes x samples)
    """
    # Now we adjust the data:
    # 1. substract additive batch effect (gamma_star)
    # 2. divide by multiplicative batch effect (delta_star)
    print("Adjusting the Data")
    bayes_data = np.transpose(s_data)  # samples x genes view
    for j, positions in enumerate(batches):  # for each batch, specific correction
        additive = np.dot(np.transpose(batch_design)[positions], gamma_star)
        multiplicative = np.transpose(
            np.outer(np.sqrt(delta_star[j]), np.asarray([1]*n_batches[j])))
        bayes_data[positions] = (bayes_data[positions] - additive) / multiplicative

    # renormalise the data after correction:
    # 1. multiply by standard deviation
    # 2. add mean
    bayes_data = np.multiply(np.transpose(bayes_data), np.outer(
        np.sqrt(var_pooled), np.asarray([1]*n_array))) + stand_mean

    # correction for reference batch: its samples are columns of the genes x samples
    # result, and `is not None` keeps a reference label of 0 from being ignored
    if ref_batch is not None:
        bayes_data[:, batches[ref]] = dat[:, batches[ref]]

    # returns the data corrected for batch effects
    return bayes_data


def pycombat(data, batch, mod=None, par_prior=True, prior_plots=False, mean_only=False, ref_batch=None, precision=None, **kwargs):
    """Corrects batch effect in microarray expression data. Takes an gene expression file and a list of known batches corresponding to each sample.
    Arguments:
        data {matrix} -- The expression matrix (dataframe). It contains the information about the gene expression (rows) for each sample (columns).
        batch {list} -- List of batch indexes. The batch list describes the batch for each sample. The batches list has as many elements as the number of columns in the expression matrix.
    Keyword Arguments:
        mod {list} -- List (or list of lists) of covariate(s) indexes. The mod list describes the covariate(s) for each sample. Each mod list has as many elements as the number of columns in the expression matrix (default: {None}, meaning no covariates).
        par_prior {bool} -- False for non-parametric estimation of batch effects (default: {True}).
        prior_plots {bool} -- True if requires to plot the priors (default: {False} -- Not implemented yet, currently ignored).
        mean_only {bool} -- True iff just adjusting the means and not individual batch effects (default: {False}).
        ref_batch {int} -- reference batch selected (default: {None}).
        precision {float} -- level of precision for precision computing (default: {None}).
        **kwargs -- accepted for compatibility, currently ignored.
    Returns:
        bayes_data_df -- The expression dataframe adjusted for batch effects, with the same index and columns as data.
    Raises:
        ValueError -- if batch does not have one entry per column of data, if data contains NaN values,
            or (from the helpers) if the reference batch or the design is invalid.
    """
    # None instead of a shared mutable [] default; an empty list means "no covariates"
    if mod is None:
        mod = []

    list_samples = data.columns
    list_genes = data.index
    dat = data.values

    # fail early with a clear message instead of an obscure shape error later on
    if len(batch) != dat.shape[1]:
        raise ValueError(
            "The batch list must have one entry per sample (column of data): "
            "got {} batch labels for {} samples.".format(len(batch), dat.shape[1]))

    check_mean_only(mean_only)

    batchmod = define_batchmod(batch)
    ref, batchmod = check_ref_batch(ref_batch, batch, batchmod)
    n_batch, batches, n_batches, n_array = treat_batches(batch)
    design = treat_covariates(batchmod, mod, ref, n_batch)
    NAs = check_NAs(dat)
    if NAs:
        raise ValueError("NaN value is not accepted")

    B_hat, grand_mean, var_pooled = calculate_mean_var(
        design, batches, ref, dat, NAs, ref_batch, n_batches, n_batch, n_array)
    stand_mean = calculate_stand_mean(
        grand_mean, n_array, design, n_batch, B_hat)
    s_data = standardise_data(dat, stand_mean, var_pooled, n_array)
    gamma_star, delta_star, batch_design = fit_model(
        design, n_batch, s_data, batches, mean_only, par_prior, precision, ref_batch, ref, NAs)
    bayes_data = adjust_data(s_data, gamma_star, delta_star, batch_design,
                             n_batches, var_pooled, stand_mean, n_array, ref_batch, ref, batches, dat)

    bayes_data_df = pd.DataFrame(bayes_data,
                                 columns=list_samples,
                                 index=list_genes)

    return(bayes_data_df)