In [2]:
from itertools import count
import json
import argparse
import os
import random
import pandas as pd
import numpy as np
import pyBigWig

NARROWPEAK_SCHEMA = ["chr", "start", "end", "1", "2", "3", "4", "5", "6", "summit"]

peak_regions_df = pd.read_csv("/mnt/lab_data2/vir/tf_chr_atlas/temp/docker_modelling/ENCSR142IGM_peaks_inliers.bed.gz", sep='\t', names=NARROWPEAK_SCHEMA)
peak_regions_df['group']='peak'
peak_regions_df['ind']=range(len(peak_regions_df))
nonpeak_regions_df = pd.read_csv("/mnt/lab_data2/vir/tf_chr_atlas/temp/docker_modelling/ENCSR142IGM_gc_neg_only.bed.gz", sep='\t', names=NARROWPEAK_SCHEMA)
nonpeak_regions_df['group']='nonpeak'
nonpeak_regions_df['ind']=range(len(nonpeak_regions_df))


all_regions_df = pd.concat([peak_regions_df,nonpeak_regions_df])
all_regions_df['pos']=all_regions_df['start']+all_regions_df['summit']
all_regions_df.sort_values(by=['chr', 'pos'], inplace=True)
all_regions_df=all_regions_df.reset_index(drop=True)

print("Creating Splits")

group_dict = {}

inputlen=2114
max_jitter=32

cur_chrom = ''
cur_group = ''
last_pos = 0
for index,row in all_regions_df.iterrows():
    if cur_chrom != '':
        if row['chr'] != cur_chrom:
            cur_chrom = row['chr']
            cur_group += 1
            group_dict[cur_group] = [row]
        else:
            if row['pos'] <= int(last_pos) + int(inputlen) + int(2 * max_jitter):
                group_dict[cur_group].append(row)
            else:
                cur_group += 1
                group_dict[cur_group] = [row]
    else:
        cur_chrom = row['chr']
        cur_group = 0
        group_dict[cur_group] = [row]
    last_pos = row['pos']
    
groups = []
group_counts = []
bigwig = '/mnt/lab_data2/vir/tf_chr_atlas/temp/docker_modelling/ENCSR142IGM_plus.bigWig'
bw = pyBigWig.open(bigwig)

for group in group_dict:
    groups.append(group)
    sum = 0
    for element in group_dict[group]:
        labels = bw.values(element['chr'], int(element['pos'] - (inputlen // 2)), int(element['pos'] + (inputlen // 2)))
        labels = np.array(labels)
        labels = np.nan_to_num(labels)
        labels = np.sum(labels)
        sum += labels
    group_counts.append(sum)
group_df = pd.DataFrame({'groups': groups, 'group_counts': group_counts})
group_df.sort_values(by='group_counts', inplace=True)



Creating Splits


In [4]:
group_fold_dict

{'fold4': []}

In [5]:
group_fold_dict = {}
number_of_folds=5
for fold in range(number_of_folds):
    group_fold_dict[f"fold{fold}"]=[]

count = 0
valid_used = []

for index,row in group_df.iterrows():
    if index % 10000 == 0:
        print(index)
    if count % 2 == 0:
        test_or_valid = 'valid'
    else:
        test_or_valid = 'test'
    test_or_valid_fold = random.choice([i for i in range(number_of_folds) if i not in valid_used])
    for fold in range(number_of_folds):
        if fold != test_or_valid_fold:
            group_fold_dict[f"fold{fold}"].append('train')
        else:
            group_fold_dict[f"fold{fold}"].append(test_or_valid)
    count += 1
    valid_used.append(test_or_valid_fold)
    if len(valid_used) == number_of_folds:
        valid_used = []


for fold in range(number_of_folds):
    group_df['fold' + str(fold)] = group_fold_dict['fold' + str(fold)]

print("Saving Splits")
for fold in range(number_of_folds):
    for split in ['valid','train','test']:
        temp_lst = [group_dict.get(key) for key in group_df['groups'][group_df[f"fold{fold}"]==split]] 
        peak_indices = [i['ind'] for b in map(lambda x:[x] if not isinstance(x, list) else x, temp_lst) for i in b if i['group']=='peak']
        nonpeak_indices = [i['ind'] for b in map(lambda x:[x] if not isinstance(x, list) else x, temp_lst) for i in b if i['group']=='nonpeak']
        f = open(f"{args.output_path}/loci_{split}_indices_fold{fold}.txt", "w")
        for items in peak_indices:
            f.writelines(str(items)+'\n')
        f.close()
        f = open(f"{args.output_path}/background_{split}_indices_fold{fold}.txt", "w")
        for items in nonpeak_indices:
            f.writelines(str(items)+'\n')
        f.close()

210000
0
80000
200000
150000
140000
60000
40000
120000
90000
70000
170000
20000
130000
110000
50000
10000
160000
190000
180000
30000
100000
Saving Splits


NameError: name 'args' is not defined

In [None]:


peak_regions = pd.read_csv(args.peaks, sep='\t', names=NARROWPEAK_SCHEMA)
nonpeak_regions = pd.read_csv(args.nonpeaks, sep='\t', names=NARROWPEAK_SCHEMA)

print("Loading Data")

peak_chroms = peak_regions['chr']
peak_pos = peak_regions['start'] + peak_regions['summit']

nonpeak_chroms = nonpeak_regions['chr']
nonpeak_pos = nonpeak_regions['start'] + nonpeak_regions['summit']

all_chroms = peak_chroms.tolist() + nonpeak_chroms.tolist()
all_pos = peak_pos.tolist() + nonpeak_pos.tolist()

print("Creating Splits")

all_df = pd.DataFrame({'chr': all_chroms, 'pos': all_pos})
all_df.sort_values(by=['chr', 'pos'], inplace=True)

group_dict = {}

cur_chrom = ''
cur_group = ''
last_pos = 0
for index,row in all_df.iterrows():
    if cur_chrom != '':
        if row['chr'] != cur_chrom:
            cur_chrom = row['chr']
            cur_group += 1
            group_dict[cur_group] = [(row['chr'], row['pos'])]
        else:
            if row['pos'] <= int(last_pos) + int(args.inputlen) + int(2 * args.max_jitter):
                group_dict[cur_group].append((row['chr'], row['pos']))
            else:
                cur_group += 1
                group_dict[cur_group] = [(row['chr'], row['pos'])]
    else:
        cur_chrom = row['chr']
        cur_group = 0
        group_dict[cur_group] = [(row['chr'], row['pos'])]
    last_pos = row['pos']

groups = []
group_counts = []

bw = pyBigWig.open(args.bigwig)

for group in group_dict:
    groups.append(group)
    sum = 0
    for element in group_dict[group]:
        labels = bw.values(element[0], int(element[1] - (args.inputlen // 2)), int(element[1] + (args.inputlen // 2)))
        labels = np.array(labels)
        labels = np.nan_to_num(labels)
        labels = np.sum(labels)
        sum += labels
    group_counts.append(sum)

group_df = pd.DataFrame({'groups': groups, 'group_counts': group_counts})
group_df.sort_values(by='group_counts', inplace=True)
group_fold_dict = {'fold0': [], 'fold1': [], 'fold2': [], 'fold3': [], 'fold4': []}

count = 0
valid_used = []

for index,row in group_df.iterrows():
    if index % 10000 == 0:
        print(index)
    if count % 2 == 0:
        test_or_valid = 'valid'
    else:
        test_or_valid = 'test'
    test_or_valid_fold = random.choice([i for i in range(5) if i not in valid_used])
    for fold in range(5):
        if fold != test_or_valid_fold:
            group_fold_dict['fold' + str(fold)].append('train')
        else:
            group_fold_dict['fold' + str(fold)].append(test_or_valid)
    count += 1
    valid_used.append(test_or_valid_fold)
    if len(valid_used) == 5:
        valid_used = []

for fold in range(5):
    group_df['fold' + str(fold)] = group_fold_dict['fold' + str(fold)]

all_dict = {'chr': [], 'pos': [], 'fold0': [], 'fold1': [], 'fold2': [], 'fold3': [], 'fold4': []}

for index,row in group_df.iterrows():
    for element in group_dict[row['groups']]:
        all_dict['chr'].append(element[0])
        all_dict['pos'].append(element[1])
        all_dict['fold0'].append(row['fold0'])
        all_dict['fold1'].append(row['fold1'])
        all_dict['fold2'].append(row['fold2'])
        all_dict['fold3'].append(row['fold3'])
        all_dict['fold4'].append(row['fold4'])

splits_df = pd.DataFrame(all_dict)

print("Saving Splits")

splits_df.to_csv(args.output_prefix + '.splits.tsv', sep='\t')