In [105]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import numpy as np
import sys
sys.path.append('../creme/')
import custom_model
import shuffle
import creme
import utils
import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import pyranges as pr
import shutil
import pickle
import kipoiseq
from itertools import combinations
import scipy
import os
import gene as bgene
import shuffle
from scipy.stats import pearsonr
import tensorflow_hub as hub

import copy
from pymemesuite.fimo import FIMO
from pymemesuite.common import MotifFile
import Bio.SeqIO
from pymemesuite.common import Sequence

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [132]:
model = custom_model.Enformer(track_index=[4824, 5110, 5111], bin_index=[447, 448])

In [72]:
data_dir = '../data/'
genome_path = f'{data_dir}/GRCh38.primary_assembly.genome.fa'
seq_parser = utils.SequenceParser(genome_path)
model_name = 'enformer'

result_dir = utils.make_dir(f"{utils.make_dir(f'../results/XSTREME/')}/FIMO/")
csv_dir = f'../results/summary_csvs/{model_name}/'
cre_df_path = f'{csv_dir}/sufficient_CREs.csv'
cre_set = pd.read_csv(cre_df_path)

In [173]:
cre_set = cre_set[cre_set['cell_line']=='K562']

In [203]:
[g for g in cre_set['seq_id'].values if 'KAT6A' in g]

[]

In [207]:
np.sort([g.split('_')[0] for g in cre_set['seq_id'].values])

array(['ABO', 'ACSL6', 'ACSL6', 'ADGRG1', 'ADGRG1', 'ADGRG1', 'AK7',
       'AK7', 'AK7', 'AKAP13', 'ALDOC', 'ALKBH2', 'ATG101', 'BBS5',
       'CA13', 'CA13', 'CA13', 'CCDC80', 'CCDC88C', 'CCPG1', 'CFAP20DC',
       'CHAF1B', 'CHAF1B', 'CHST11', 'CHST11', 'CLUAP1', 'CNTLN', 'CNTLN',
       'COMMD6', 'CPNE3', 'CREG1', 'CXCL2', 'CXXC5', 'CXXC5', 'CYB5D1',
       'DAB2', 'DNASE1L1', 'DYSF', 'DYSF', 'DYSF', 'EEF1A2', 'EEF1A2',
       'ENSG00000257355', 'ENSG00000268133', 'ENSG00000269026',
       'ENSG00000269533', 'ENSG00000269533', 'ENSG00000273734',
       'ENSG00000276302', 'ENSG00000284931', 'ENSG00000288208',
       'ENSG00000288208', 'ENSG00000288208', 'ENSG00000288208',
       'ENSG00000288208', 'ENSG00000288796', 'ENSG00000290263',
       'ENSG00000290263', 'ETV3', 'FAN1', 'FMNL1', 'FOXO3B', 'FZD6',
       'G0S2', 'G0S2', 'G0S2', 'G0S2', 'G0S2', 'G0S2', 'GABRE', 'GABRE',
       'GAD1', 'GADD45B', 'GADD45B', 'GADD45GIP1', 'GATA2', 'GEMIN7',
       'GFI1B', 'GFI1B', 'GPR146', 'HACD

In [174]:
row = cre_set.iloc[0]
row

Unnamed: 0                                         19
(MUT - CONTROL) / WT                         0.333129
(MUT - CONTROL) / CONTROL                   15.364138
seq_id                       CHST11_chr12_104456947_+
control                                      3.605964
wt                                          166.30945
mut                                          59.00849
tile_start                                     100804
tile_end                                       105804
context                                     enhancing
cell_line                                        K562
Normalized CRE effect                        0.333129
tile class                                   Enhancer
Name: 98, dtype: object

In [175]:
meme_path = '../results/XSTREME/K562_enhancers.meme'

In [176]:
N_shuffles = 10
tile_size = 5000
# load all motifs from meme file
motif_file = MotifFile(meme_path)
motifs = []
for motif in motif_file:
    motifs.append(motif)

fimo = FIMO(both_strands=True)

fasta_dir = utils.make_dir(f'{result_dir}/todo/')

In [177]:


['ACGT'[i] for i in np.array(motif.frequencies).argmax(axis=1)]


['C', 'A', 'A', 'A', 'G', 'T', 'G', 'C', 'T', 'G', 'G', 'G', 'A', 'T', 'T']

In [178]:
chrom, start, strand = row['seq_id'].split('_')[1:]

In [179]:

seq_wt = seq_parser.extract_seq_centered(chrom, int(start), strand, model_seq_length, onehot=False).upper()  # wt str whole seq
seq_onehot_wt = utils.one_hot_encode(seq_wt)  # wt onehot whole seq



In [180]:
cre_start, cre_end = row['tile_start'], row['tile_end']  # coordinate of tile start
cre_wt = seq_wt[cre_start:cre_end]  # wt CRE str seq
cre_onehot_wt = (seq_onehot_wt[cre_start:cre_start + tile_size]).copy()  # wt CRE onehot



In [181]:
fasta_path = f'{fasta_dir}/{row["seq_id"]}_{cre_start}_{cre_end}.fa'  # one seq fasta file
result_path = fasta_path.split('.')[-2] + '.pickle'
result_path

'/results/XSTREME//FIMO//todo//CHST11_chr12_104456947_+_100804_105804.pickle'

In [182]:
with open(fasta_path, 'w') as f:
    f.write(f'>{row["seq_id"]}_{cre_start}_{cre_end} \n')
    f.write(cre_wt)

In [183]:
sequences = [Sequence(str(record.seq), name=record.id.encode())
             for record in Bio.SeqIO.parse(fasta_path, "fasta")]  # weird seq format for fimo

In [184]:
mask = np.zeros((len(motifs), tile_size))
for i, motif in enumerate(motifs):
    pattern = fimo.score_motif(motif, sequences, motif_file.background)
    for m in pattern.matched_elements:
        mask[i, m.start:m.stop] = 1

In [185]:
index_bool = [True if m else False for m in np.max(mask, axis=0)]  # mask showing motif positions
non_index_bool = [not e for e in index_bool]  # mask showing where motifs are not
random_mask = copy.copy(index_bool)
np.random.shuffle(random_mask)

In [186]:
motif_search_res = utils.read_pickle(f'../results/motifs_500,50_batch_1,10_shuffle_10_thresh_0.9,0.7/{row["cell_line"]}/{row["seq_id"]}_{cre_start}_{cre_end}.pickle')

In [187]:
control_sequences = motif_search_res['control_sequences']

In [188]:
result_summary = {}
for label, mask in {"motifs": index_bool, "non-motifs": non_index_bool, "random": random_mask, "all": [True for _ in range(5000)]}.items(): 
    # take shuffled CRE from controlseqs
    cre_motifs = control_sequences[:,cre_start: cre_end,:].copy() # start with shuffled CRE
    cre_motifs[:, mask, :] = cre_onehot_wt[mask] # embed 
    # embed CRE with just the motifs back into seqs
    test_seqs = control_sequences.copy()
    test_seqs[:, cre_start: cre_end,:] = cre_motifs
    
    test_pred = model.predict(test_seqs)
    result_summary[label] = test_pred.mean(axis=0).mean(axis=0)
    

In [189]:
result_summary

{'motifs': array([ 9.566814, 50.804127,  6.219971], dtype=float32),
 'non-motifs': array([ 6.897575 , 41.955494 ,  4.8685007], dtype=float32),
 'random': array([ 6.068802, 31.875778,  3.377602], dtype=float32),
 'all': array([ 11.413242, 144.04369 ,  60.437805], dtype=float32)}

In [139]:
test_pred.mean(axis=0).mean(axis=0)

array([0.12163265, 0.4629516 , 1.4371657 ], dtype=float32)

In [138]:
motif_search_res

{'wt': 52.86592,
 'mut': 32.34838,
 'control': 0.1775859,
 500: {'scores': [1.6889794,
   1.7525357,
   1.6812135,
   1.5986915,
   1.4148751,
   1.2027652,
   0.9831877,
   0.5926066],
  'bps': [4500.0, 4000.0, 3500.0, 3000.0, 2500.0, 2000.0, 1500.0, 1000.0],
  'all_removed_tiles': array([[101804., 102304.],
         [103804., 104304.],
         [101304., 101804.],
         [105304., 105804.],
         [100804., 101304.],
         [104804., 105304.],
         [102304., 102804.],
         [102804., 103304.]]),
  'insert_coords': [array([102804, 103304]),
   array([103304, 103804]),
   array([104304, 104804])]},
 50: {'scores': [1.1931223, 0.30012962],
  'bps': [1000.0, 500.0],
  'all_removed_tiles': array([[102804., 102854.],
         [103754., 103804.],
         [102854., 102904.],
         [104354., 104404.],
         [104304., 104354.],
         [102954., 103004.],
         [103554., 103604.],
         [103604., 103654.],
         [103704., 103754.],
         [103104., 103154.],
   