20240306

SarahFong

Take output from DiffBind deseq2 formatting

          $HOME/EMF/US/bin/ATAC/1_prep_deseq2_ATAC.ipynb 
          
Prepare data: 
1. Clean diffbind deseq2 conc data (remove bad coordinates)
4. Format dataframe for LegNet, multitask, deepstarr multitask
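5. Split into training, validation, and test sets by chromosome (this run holds out chr8 for test and chr12 for validation; chromosomes are sampled randomly only when none are specified).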
6. Add fold num to training set

In [1]:
from Bio.SeqIO.FastaIO import SimpleFastaParser
import config_readwrite as crw
import matplotlib.pyplot as plt
import numpy as np
import os, sys
import pandas as pd

from scipy import stats

import seaborn as sns

import warnings
warnings.filterwarnings('ignore')

import plot_params as pp
pp.fonts()

('sans-serif', 'Arial', 18)

In [2]:
# Load the multi-task config that sits in the directory the notebook is started
# from. This must run before any later os.chdir call, because the path is built
# from the current working directory.
config_name =os.path.join(os.getcwd(), "config.multi.ini")
# crw.read returns two values; the second (`cfn`) is what crw.write is given
# further down when the config is saved.
# NOTE(review): presumably `cfn` is the config file path - confirm in config_readwrite.
config, cfn = crw.read(config_name)

# functions

In [3]:
def setPrefixRules(prefix):
    """Return the processing flags encoded in a dataset prefix.

    The prefix is matched by substring against the known rule names, in the
    order listed below; the first match wins.

    Parameters
    ----------
    prefix : str
        Dataset prefix, e.g. "reg.all.8tasks".

    Returns
    -------
    tuple of bool
        (QUANTILE_FILTER, MIN_CONC_FILTER, JOINT_ACCESSIBLE, CLASS_LABEL)

    Raises
    ------
    ValueError
        If the prefix matches none of the known rule names.
    """
    # rule name -> (QUANTILE_FILTER, MIN_CONC_FILTER, JOINT_ACCESSIBLE, CLASS_LABEL)
    rules = (
        ("class.all", (False, False, True, True)),
        ("class.nojoint", (False, True, False, True)),
        ("reg.all", (False, True, True, False)),
        ("reg.nojoint", (False, True, False, False)),
    )

    # match against the argument itself, never a notebook-level global
    for rule_name, flags in rules:
        if rule_name in prefix:
            return flags

    # unknown prefix: fail with a clear message instead of returning unset flags
    print('need to add rules for', prefix)
    raise ValueError(f"no processing rules defined for prefix {prefix!r}")


## make chromosome list

In [4]:
def chrList():
    """Build the chromosome name list: chr1 through chr22, then chrX and chrY."""

    # autosomes first, in numeric order
    autosomes = [f"chr{number}" for number in np.arange(1, 23)]

    # sex chromosomes go at the end
    return autosomes + ["chrX", "chrY"]

## write fa

In [5]:
def writeFa(heldout_df, heldout_fa):
    """Write a dataframe to `heldout_fa` as FASTA-style records.

    For every row, the first column becomes the `>` header line and the second
    column becomes the line that follows it.
    """
    with open(heldout_fa, "w") as handle:
        for _, record in heldout_df.iterrows():
            # only the first two values of the row are used
            header, sequence = record[:2]
            handle.write(f">{header}\n{sequence}\n")

## train test split on chromosome

In [6]:
def splitTrainTestVal(df, val_chr_list=None, test_chr_list=None):
    """Split a table into train / validation / test sets by chromosome.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with a "#chr" column holding the chromosome of each row.
    val_chr_list : list of str or str, optional
        Chromosome(s) held out for validation. When None, one chromosome is
        sampled at random from chrList(), excluding the test chromosomes.
    test_chr_list : list of str or str, optional
        Chromosome(s) held out for testing. When None, two distinct chromosomes
        are sampled at random from chrList(), excluding any validation
        chromosomes that were passed in.

    Returns
    -------
    (train, val, test) : tuple of pandas.DataFrame
        Copies of the rows on the remaining, validation and test chromosomes.
    """

    def _as_list(chroms):
        # a bare chromosome name is accepted as a one-element list
        return [chroms] if isinstance(chroms, str) else list(chroms)

    if val_chr_list is not None:
        val_chr_list = _as_list(val_chr_list)
    if test_chr_list is not None:
        test_chr_list = _as_list(test_chr_list)

    # randomly sample test chromosomes (n=2)
    if test_chr_list is None:
        print('randomly sampling chromosomes for test')

        # a fresh candidate list per call, so no shared list is mutated;
        # validation chromosomes are excluded so the two sets cannot overlap
        reserved = set(val_chr_list or [])
        candidates = [chr_ for chr_ in chrList() if chr_ not in reserved]

        # replace=False: never draw the same chromosome twice
        test_chr_list = np.random.choice(candidates, 2, replace=False).tolist()

    # randomly sample validation chromosome (n=1)
    if val_chr_list is None:
        print('randomly sampling chromosomes for val')

        # remove test chromosomes from the candidates
        reserved = set(test_chr_list)
        candidates = [chr_ for chr_ in chrList() if chr_ not in reserved]

        val_chr_list = np.random.choice(candidates, 1).tolist()

    # separate held out chromosomes from training chromosomes
    in_test = df["#chr"].isin(test_chr_list)
    in_val = df["#chr"].isin(val_chr_list)

    test = df.loc[in_test].copy()
    val = df.loc[in_val].copy()
    train = df.loc[~in_test & ~in_val].copy()

    return train, val, test


## filters

In [7]:
def minReadDepthFilter(table, MIN_CONC, col1, col2):
    """Keep rows where at least one of `col1` / `col2` is above MIN_CONC.

    Prints the table shape before and after filtering and returns a copy of
    the rows that pass.
    """
    print("before read depth filter:", table.shape)

    # a row passes when either column exceeds the threshold
    deep_enough = table[col1].gt(MIN_CONC) | table[col2].gt(MIN_CONC)
    kept = table.loc[deep_enough].copy()

    print("after:", kept.shape)

    return kept


def quantileFilter(table, quantile=0.99, ctrl_col="ctrl", us_col="US"):
    """Upper quantile filter for read count values.

    Keeps rows whose control AND ultrasound values are both strictly below
    that column's `quantile` threshold.

    Parameters
    ----------
    table : pandas.DataFrame
    quantile : float
        Quantile used as the upper cut-off for each column.
    ctrl_col, us_col : str
        Names of the control and ultrasound columns. The defaults keep the
        original "ctrl" / "US" behavior; pass cell-line specific names such
        as "ctrl.hob" / "US.hob" for renamed tables.

    Returns
    -------
    pandas.DataFrame
        Copy of the rows that pass the filter.
    """

    print("before quantile filter:", table.shape)

    # per-column thresholds
    ctrl_thresh = table[ctrl_col].quantile(quantile)
    us_thresh = table[us_col].quantile(quantile)

    # filter table; copy so later column assignments do not touch the input
    table = table.loc[(table[ctrl_col] < ctrl_thresh) &
                      (table[us_col] < us_thresh)].copy()

    print("after:", table.shape)

    return table


def jointAccessibleFilter(table, ctrl_col="ctrl", us_col="US"):
    """Remove joint accessible regions and return the rest in random order.

    Keeps only rows that are accessible in exactly one condition: control > 0
    with ultrasound == 0, or control == 0 with ultrasound > 0. Rows where both
    are > 0, or both are 0, are dropped.

    Parameters
    ----------
    table : pandas.DataFrame
    ctrl_col, us_col : str
        Names of the control and ultrasound columns. The defaults keep the
        original "ctrl" / "US" behavior; pass cell-line specific names such
        as "ctrl.hob" / "US.hob" for renamed tables.

    Returns
    -------
    pandas.DataFrame
        Shuffled rows with a fresh 0..n-1 index.
    """

    print("before joint_accessible filter:", table.shape)

    # condition-specific rows
    ctrl_only = table.loc[(table[ctrl_col] > 0) &
                          (table[us_col] == 0)].copy()

    US_only = table.loc[(table[ctrl_col] == 0) &
                        (table[us_col] > 0)].copy()

    # combine us and ctrl, then shuffle so the two groups are interleaved
    table = pd.concat([ctrl_only, US_only]).sample(
        frac=1).reset_index(drop=True)

    print("after:", table.shape)

    return table

# binarize


def classLabel(table, cols_to_label=('ctrl', 'US')):
    """Binarize read count columns: 0 stays 0, every other value becomes 1.

    Parameters
    ----------
    table : pandas.DataFrame
        Modified in place (the labelled columns are overwritten) and returned.
    cols_to_label : sequence of str
        Columns to binarize. The default keeps the original "ctrl" / "US"
        behavior; pass cell-line specific names such as
        ["ctrl.hob", "US.hob"] for renamed tables.

    Returns
    -------
    pandas.DataFrame
        The same `table` object with the labelled columns replaced.
    """
    for col in cols_to_label:
        # vectorized form of "0 if x == 0 else 1"
        table[col] = (table[col] != 0).astype(int)

    return table

# scramble df


def dfShuffle(df):
    """Return the rows of `df` in random order with a fresh 0..n-1 index."""
    shuffled = df.sample(frac=1)
    return shuffled.reset_index(drop=True)

def directionFilter(table, col="US_DIF"):
    """Keep rows where ultrasound activity is higher than control.

    `col` holds ctrl - US, so a value below 0 means US has more activity than
    control. The signature accepts both call forms used in this notebook:
    directionFilter(table) and directionFilter(table, col).

    Parameters
    ----------
    table : pandas.DataFrame
    col : str
        Name of the difference column (default "US_DIF").

    Returns
    -------
    pandas.DataFrame
        Copy of the rows with table[col] < 0.
    """

    table = table.loc[table[col] < 0].copy()

    return table

In [8]:
def directionFilter(table, col="US_DIF"):
    """Keep rows where ultrasound activity is higher than control.

    `col` holds ctrl - US, so a value below 0 means US has more activity than
    control. `col` defaults to "US_DIF", so the one-argument call form
    directionFilter(table) keeps working as well.

    Parameters
    ----------
    table : pandas.DataFrame
    col : str
        Name of the difference column, e.g. "US_DIF.hob".

    Returns
    -------
    pandas.DataFrame
        Copy of the rows with table[col] < 0.
    """

    table = table.loc[table[col] < 0].copy()

    return table

# Main

## load data

In [9]:
# cell line name -> processed table for that cell line (filled by the loop below)
collection_dict = dict()

In [12]:
# Cell lines to process; one processed table per cell line ends up in collection_dict.
CLS = ["hob", "hepg2", "k562", "bj"]

# ---- settings shared by every cell line (loop-invariant, so set once) ----

PREFIX = "reg.all.8tasks"

PATH = "/wynton/group/ahituv/data/US-MPRA/ATAC-seq/Diffbind_results"

# chromosomes held out of training
TEST_CHR = ["chr8"]
VAL_CHR = "chr12"

# peak information (PEAK_SIZE is not used in this cell)
PEAK_SIZE = 270
MIN_CONC = 2

# genome information (not used in this cell)
HG38 = "/wynton/group/ahituv/data/dna/hg38/hg38.chrom.sizes"
FA_HG38 = "/wynton/group/ahituv/data/dna/hg38/hg38.fa"

# columns written to the fasta files: header, sequence
fa_cols = ['coor.type', "seq"]

# deepstarr file inputs for training, testing
# x
TRAIN_FA = f"{PREFIX}.Sequences_Train.fa"
VAL_FA = f"{PREFIX}.Sequences_Val.fa"
TEST_FA = f"{PREFIX}.Sequences_Test.fa"

# y
TRAIN_TARGET = f"{PREFIX}.Sequences_activity_Train.txt"
VAL_TARGET = f"{PREFIX}.Sequences_activity_Val.txt"
TEST_TARGET = f"{PREFIX}.Sequences_activity_Test.txt"

# rules for data
QUANTILE_FILTER, MIN_CONC_FILTER, JOINT_ACCESSIBLE, CLASS_LABEL = setPrefixRules(
    PREFIX)

DIRECTION_FILTER = True

for CL in CLS:

    DATA_PATH = f"/wynton/home/ahituv/fongsl/EMF/US/ml_emf/data/deepstarr/deseq2/{CL}/{PREFIX}"

    # create the output directory and any missing parent (the per-cell-line
    # directory above it); a no-op when it already exists
    os.makedirs(DATA_PATH, exist_ok=True)

    # DESEQ2 information
    DIFF = f'./diffbind_results/{CL}_deseq2.csv'
    # NOTE(review): "nodiff" here vs "nondiff" in FULL below - confirm which
    # spelling the files on disk use.
    NODIFF = f'./diffbind_results/{CL}_deseq2-nodiff.csv'
    # swap the .csv extension for .bed (str.strip would strip characters from
    # both ends, not the suffix)
    DIFF_BED = os.path.splitext(DIFF)[0] + ".bed"

    # full table read below, relative to PATH
    FULL = f"{CL}_deseq2-nondiff.trimmed.full.csv"

    # base config

    section = f"{CL}-ATAC-DESEQ2"
    crw.check(config, section)

    config[section]["path"] = PATH

    config[section]["nondiff_bind_results"] = "%(path)s/" + NODIFF
    config[section]["diff_bind_results"] = "%(path)s/" + DIFF
    config[section]["diff_bind_results_bed"] = "%(path)s/" + DIFF_BED

    # deepstarr config

    section = f"{CL}.atac.deseq2.deepstarr"
    crw.check(config, section)

    config[section]["data_path"] = DATA_PATH
    config[section]["held_out_chr"] = ",".join(TEST_CHR)
    config[section]["val_chr"] = VAL_CHR

    # deepstarr prefix config
    # NOTE(review): the section name is hard-coded to "Hepg2" for every cell
    # line, so each iteration rewrites the same section (with identical
    # values). Left unchanged in case other code reads this exact name -
    # confirm whether it should be f"{CL}.atac.deseq2.deepstarr.{PREFIX}".
    section = f"Hepg2.atac.deseq2.deepstarr.{PREFIX}"
    crw.check(config, section)

    config[section]["train_fa"] = TRAIN_FA
    config[section]["val_fa"] = VAL_FA
    config[section]["test_fa"] = TEST_FA

    config[section]["train_target"] = TRAIN_TARGET
    config[section]["val_target"] = VAL_TARGET
    # "tval_target" is the key this notebook has always written; it is kept
    # next to the correctly spelled "val_target" so existing readers still work
    config[section]["tval_target"] = VAL_TARGET
    config[section]["test_target"] = TEST_TARGET

    config[section]["filter_MIN_CONC"] = str(MIN_CONC_FILTER)
    if MIN_CONC_FILTER is True:
        config[section]["MIN_CONC"] = str(MIN_CONC)

    config[section]["filter_quantile"] = str(QUANTILE_FILTER)
    config[section]["filter_jointaccessible"] = str(JOINT_ACCESSIBLE)
    config[section]["classlabel"] = str(CLASS_LABEL)

    crw.write(config, cfn)

    # deal with data tables (FULL is read relative to PATH)
    os.chdir(PATH)
    table = pd.read_csv(FULL, sep='\t').drop_duplicates()

    # fasta header / row id: "<type>|<seq.id>"
    table["coor.type"] = table["type"] + "|" + table["seq.id"]

    # difference = control - ultrasound (negative: more signal under ultrasound)
    if "Conc_Ultrasound_dif" not in list(table):
        table["Conc_Ultrasound_dif"] = table["Conc_Control"] - \
            table["Conc_Ultrasound"]

    # rename columns w cl annotated
    us_col, ctrl_col, dif_col = f"US.{CL}", f"ctrl.{CL}", f"US_DIF.{CL}"

    table.rename(columns={"Conc_Ultrasound": us_col,
                          "Conc_Control": ctrl_col,
                          "Conc_Ultrasound_dif": dif_col}, inplace=True)

    print(table.shape)

    # min read depth filter
    if MIN_CONC_FILTER is True:
        table = minReadDepthFilter(table, MIN_CONC, us_col, ctrl_col)

    # NOTE(review): the three helpers below are called without column names
    # and so read columns called "ctrl" / "US", which this table does not have
    # after the rename above. None of them runs for a "reg.all" prefix.

    # quantile filter
    if QUANTILE_FILTER is True:
        table = quantileFilter(table, quantile=0.99)

    # joint accessible filter
    if JOINT_ACCESSIBLE is False:
        table = jointAccessibleFilter(table)

    # apply class label
    if CLASS_LABEL is True:
        table = classLabel(table)
        print(table.groupby([us_col, ctrl_col, dif_col])['#chr'].count())

    # keep only rows where the difference column is below 0
    if DIRECTION_FILTER is True:
        table = directionFilter(table, dif_col)

    # train on all atac peaks

    os.chdir(DATA_PATH)

    table = dfShuffle(table)  # shuffle the table before splitting

    train, val, test = splitTrainTestVal(
        table, val_chr_list=[VAL_CHR], test_chr_list=TEST_CHR)

    # x: sequences
    writeFa(test[fa_cols], TEST_FA)
    writeFa(train[fa_cols], TRAIN_FA)
    writeFa(val[fa_cols], VAL_FA)

    # y: targets (the ultrasound column itself is deliberately left out)
    cols = ["coor.type",
            # us_col,
            ctrl_col,
            dif_col
           ]
    test[cols].to_csv(TEST_TARGET, sep='\t', index=False)
    train[cols].to_csv(TRAIN_TARGET, sep='\t', index=False)
    val[cols].to_csv(VAL_TARGET, sep='\t', index=False)

    print(val.shape, train.shape, test.shape)

    collection_dict[CL] = table

(80892, 10)
before read depth filter: (80892, 10)
after: (80892, 10)
(2823, 10) (52606, 10) (2472, 10)
(27488, 10)
before read depth filter: (27488, 10)
after: (27488, 10)
(551, 10) (9920, 10) (372, 10)
(55241, 10)
before read depth filter: (55241, 10)
after: (55241, 10)
(1449, 10) (27766, 10) (1119, 10)
(41641, 10)
before read depth filter: (41641, 10)
after: (41641, 10)
(751, 10) (13320, 10) (639, 10)


# combine all cell line ATAC data together

In [26]:
# Stack the per-cell-line tables row-wise. Each table only carries its own
# US.<cl> / ctrl.<cl> / US_DIF.<cl> columns, so the columns of the other cell
# lines come out of the concat as NaN; those gaps are then set to 0.
table = pd.concat(list(collection_dict.values()), axis=0).fillna(value=0)
table

Unnamed: 0,#chr,start_trim,end_trim,type,seq.id,US.hob,ctrl.hob,US_DIF.hob,seq,coor.type,US.hepg2,ctrl.hepg2,US_DIF.hepg2,US.k562,ctrl.k562,US_DIF.k562,US.bj,ctrl.bj,US_DIF.bj
0,chrX,48492419,48492690,hob.146430,chrX:48492419-48492690,4.985740,4.964249,-0.021490,CCAGAGACAATGTGGCCAGGCTCCGGAGGGCTGGGAAGATGAGCAA...,hob.146430|chrX:48492419-48492690,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000
1,chr20,4076708,4076979,hob.47070,chr20:4076708-4076979,4.637884,4.347412,-0.290472,ttttgatgaattccaatttattgatttcgatgttgtggggacaaga...,hob.47070|chr20:4076708-4076979,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000
2,chr17,17591665,17591936,hob.142196,chr17:17591665-17591936,8.482293,8.470101,-0.012192,CAGCCACGCGCCCCCGGAACCGGACCTATAGAGCCGGGTAAGTGCC...,hob.142196|chr17:17591665-17591936,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000
3,chr6,106746723,106746994,hob.128551,chr6:106746723-106746994,4.958441,4.903755,-0.054687,AGATGACTGCTTACAATACACAGCTTTCATATAGGGAGTGAGTGTC...,hob.128551|chr6:106746723-106746994,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000
4,chr5,177476824,177477095,hob.11791,chr5:177476824-177477095,5.505728,5.103355,-0.402373,TTCTGCATGCAGGGGTGAGGTGGGCTGGAGTCTGATCAGAAGTTGC...,hob.11791|chr5:177476824-177477095,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14705,chr14,58864931,58865202,bj.47457,chr14:58864931-58865202,0.000000,0.000000,0.000000,TTTCCTGGAAGTTGAGAAGCGTCTTTAGTCTACTGTGTTGACAGCA...,bj.47457|chr14:58864931-58865202,0.0,0.0,0.0,0.0,0.0,0.0,4.281995,3.928428,-0.353567
14706,chr5,1523913,1524184,bj.29040,chr5:1523913-1524184,0.000000,0.000000,0.000000,cgcgggtctcggggcgcgggccgaggatgcgcggcggctggagcgg...,bj.29040|chr5:1523913-1524184,0.0,0.0,0.0,0.0,0.0,0.0,4.836365,4.626367,-0.209998
14707,chr9,91887496,91887767,bj.51645,chr9:91887496-91887767,0.000000,0.000000,0.000000,ccagtcaataaatatatcacaagttttttcccactgttctgtgggt...,bj.51645|chr9:91887496-91887767,0.0,0.0,0.0,0.0,0.0,0.0,4.329817,4.137622,-0.192196
14708,chr11,46347175,46347446,bj.66233,chr11:46347175-46347446,0.000000,0.000000,0.000000,GGGATCGGGGGAGGAAAGATGCGCGTCTGGATGCGCGCAGTGCGAG...,bj.66233|chr11:46347175-46347446,0.0,0.0,0.0,0.0,0.0,0.0,5.883983,5.794700,-0.089284


In [27]:
# NOTE(review): "reg.all.8task" here vs "reg.all.8tasks" in the per-cell-line
# loop - confirm the difference is intended.
PREFIX = "reg.all.8task"
DATA_PATH = f"/wynton/home/ahituv/fongsl/EMF/US/ml_emf/data/deepstarr/deseq2/all/{PREFIX}"

# create the output directory and any missing parent (e.g. `all/`);
# a no-op when it already exists
os.makedirs(DATA_PATH, exist_ok=True)

os.chdir(DATA_PATH)

table = dfShuffle(table)  # shuffle the table before splitting

# same held-out chromosomes as the per-cell-line datasets
# (VAL_CHR and TEST_CHR are set in the per-cell-line cell)
train, val, test = splitTrainTestVal(
    table, val_chr_list=[VAL_CHR], test_chr_list=TEST_CHR)

In [28]:
# Files for the combined (all cell lines) dataset.

# x
TRAIN_FA = f"{PREFIX}.Sequences_Train.fa"
VAL_FA = f"{PREFIX}.Sequences_Val.fa"
TEST_FA = f"{PREFIX}.Sequences_Test.fa"

# y
TRAIN_TARGET = f"{PREFIX}.Sequences_activity_Train.txt"
VAL_TARGET = f"{PREFIX}.Sequences_activity_Val.txt"
TEST_TARGET = f"{PREFIX}.Sequences_activity_Test.txt"

writeFa(test[fa_cols], TEST_FA)
writeFa(train[fa_cols], TRAIN_FA)
writeFa(val[fa_cols], VAL_FA)

# One (ctrl, US_DIF) target pair per cell line, in CLS order:
# ctrl.hob, US_DIF.hob, ctrl.hepg2, US_DIF.hepg2, ...
# Built from CLS so the task columns cannot drift from the cell lines processed.
cols = ["coor.type"] + [f"{measure}.{cell_line}"
                        for cell_line in CLS
                        for measure in ("ctrl", "US_DIF")]

test[cols].to_csv(TEST_TARGET, sep='\t', index=False)
train[cols].to_csv(TRAIN_TARGET, sep='\t', index=False)
val[cols].to_csv(VAL_TARGET, sep='\t', index=False)

print(val.shape, train.shape, test.shape)

(5574, 19) (103612, 19) (4602, 19)
