In [1]:
#!/usr/bin/env python
# coding: utf-8

# In[1]:


import os
import json
import subprocess
import shutil
os.environ["CUDA_VISIBLE_DEVICES"] = '-1' ### run on CPU

import tensorflow as tf
print(tf.__version__)
if tf.__version__[0] == '1':
    tf.compat.v1.enable_eager_execution()

import numpy as np
import pandas as pd
import pysam
import matplotlib.pyplot as plt
from cooltools.lib.numutils import set_diag
import sys
sys.path.append("/wynton/home/hernandez/shirondru/pollard_lab/basenji")
from basenji import dataset, dna_io, seqnn

import cooler
import cooltools
from cooltools.lib.numutils import observed_over_expected
from cooltools.lib.numutils import adaptive_coarsegrain
from cooltools.lib.numutils import interpolate_bad_singletons
from cooltools.lib.numutils import interp_nan, set_diag
from cooltools.lib.plotting import *


import kipoiseq
from kipoiseq import Interval
import pyfaidx

from collections import Counter
import argparse
import sys




# In[2]:


### load params, specify model ###

# Directory that holds params.json and model_best.h5 (absolute, machine-specific path).
model_dir = '/wynton/home/hernandez/shirondru/pollard_lab/'
params_file = os.path.join(model_dir, 'params.json')
model_file = os.path.join(model_dir, 'model_best.h5')

# Parse the JSON configuration, then pull out its 'model' and 'train' sections.
with open(params_file) as params_open:
    params = json.load(params_open)
params_model = params['model']
params_train = params['train']

# Instantiate the network from the 'model' section of the configuration.
seqnn_model = seqnn.SeqNN(params_model)


# In[3]:


### restore model ###
# Restore the saved model file (`model_file`, i.e. model_best.h5) into the
# SeqNN instance created in the previous cell.
# Note: run `%%bash get_model.sh` first if you have not already downloaded
# the model.
seqnn_model.restore(model_file)
print('successfully loaded')


# In[4]:


### names of targets ###
data_dir = os.path.join(model_dir, "basenji/manuscripts/akita/data/")

# Tab-separated table of targets; the columns read below are 'index',
# 'identifier' and 'file'. os.path.join avoids the double slash that plain
# string concatenation produced (data_dir already ends with '/').
hic_targets = pd.read_csv(os.path.join(data_dir, 'targets.txt'), sep='\t')
hic_file_dict_num = dict(zip(hic_targets['index'].values, hic_targets['file'].values))
hic_file_dict = dict(zip(hic_targets['identifier'].values, hic_targets['file'].values))
hic_num_to_name_dict = dict(zip(hic_targets['index'].values, hic_targets['identifier'].values))

# read data parameters
data_stats_file = os.path.join(data_dir, 'statistics.json')
with open(data_stats_file) as data_stats_open:
    data_stats = json.load(data_stats_open)
seq_length = data_stats['seq_length']
target_length = data_stats['target_length']
hic_diags = data_stats['diagonal_offset']
# Crop and full sequence length, both expressed in units of pool_width.
target_crop = data_stats['crop_bp'] // data_stats['pool_width']
target_length1 = data_stats['seq_length'] // data_stats['pool_width']


# In[5]:


# Genome FASTA passed to FastaStringExtractor when building reference and
# alternate sequences (absolute, machine-specific path).
fasta_file = '/wynton/home/hernandez/shirondru/pollard_lab/data/hg38_genome.fa'


### for converting from a flattened upper-triangular vector to a symmetric matrix ###

def from_upper_triu(vector_repr, matrix_len, num_diags):
    """Expand a flattened upper-triangular vector into a symmetric matrix.

    Args:
        vector_repr: values for the upper triangle starting at diagonal
            offset `num_diags`, in the order given by `np.triu_indices`.
        matrix_len: side length of the square output matrix.
        num_diags: offset of the first diagonal that carries values; the
            diagonals with offsets in (-num_diags, num_diags) are set to NaN.

    Returns:
        A (matrix_len, matrix_len) array equal to the filled upper triangle
        plus its transpose.
    """
    upper = np.zeros((matrix_len, matrix_len))
    rows, cols = np.triu_indices(matrix_len, num_diags)
    upper[rows, cols] = vector_repr

    # Mask the diagonals closest to the main diagonal before mirroring.
    offset = 1 - num_diags
    while offset < num_diags:
        set_diag(upper, np.nan, offset)
        offset += 1

    return upper + upper.T


def preprocess_from_cool(myseq_str, genome_hic_cool):
    """Fetch a region from a cooler and turn it into a smoothed log(obs/exp) map.

    Steps, in order: adaptive coarse-graining of the balanced matrix using the
    raw counts, observed-over-expected normalisation, natural log, clipping to
    [-2, 2], NaN interpolation, zeroing of diagonals -1..1, and a Gaussian
    convolution.

    Args:
        myseq_str: region string passed to the cooler's `fetch`
            (e.g. "chr1:1000-2000").
        genome_hic_cool: object exposing `.matrix(balance=...).fetch(region)`
            (presumably a `cooler.Cooler` -- confirm against the caller).

    Returns:
        Tuple `(seq_hic, num_counts, seq_hic_obs)`: the processed map, the sum
        of the raw counts in the region, and the balanced matrix as fetched.
    """
    # Imported up front so a missing astropy fails before the expensive work;
    # kept function-local because the notebook does not import it globally.
    from astropy.convolution import Gaussian2DKernel, convolve

    print("Seq-str: ", myseq_str)

    # Fetch the raw counts once and reuse them for both the total and the
    # coarse-graining step (the original fetched the same matrix twice).
    raw_counts = genome_hic_cool.matrix(balance=False).fetch(myseq_str)
    num_counts = np.sum(raw_counts)
    seq_hic_obs = genome_hic_cool.matrix(balance=True).fetch(myseq_str)

    seq_hic_smoothed = adaptive_coarsegrain(
        seq_hic_obs,
        raw_counts,
        cutoff=3, max_levels=8)

    # Observed/expected is computed over the non-NaN bins only.
    seq_hic_nan = np.isnan(seq_hic_smoothed)
    seq_hic_obsexp = observed_over_expected(seq_hic_smoothed, ~seq_hic_nan)[0]
    seq_hic_obsexp = np.log(seq_hic_obsexp)
    seq_hic_obsexp = np.clip(seq_hic_obsexp, -2, 2)

    # Fill NaNs by interpolation, replace any that remain, and clip again.
    seq_hic_obsexp = interp_nan(seq_hic_obsexp)
    seq_hic_obsexp = np.nan_to_num(seq_hic_obsexp)
    seq_hic = np.clip(seq_hic_obsexp, -2, 2)

    # Zero the main diagonal and its two neighbours.
    for i in [-1, 0, 1]:
        set_diag(seq_hic, 0, i)

    kernel = Gaussian2DKernel(x_stddev=1, x_size=5)
    seq_hic = convolve(seq_hic, kernel)
    return seq_hic, num_counts, seq_hic_obs


def get_expt(region_chr, region_start, region_stop):
    """Return the centre-cropped, preprocessed Hi-C map for a genomic region.

    Builds a "chrom:start-stop" region string, runs `preprocess_from_cool` on
    it, and trims the same margin from every side of the result.

    NOTE(review): this relies on the globals `genome_hic_cool`,
    `target_length` and `target_length_cropped`; the first and last are not
    defined in the visible part of the notebook -- confirm they are set
    before this function is called.
    """
    region = "{}:{}-{}".format(region_chr, region_start, region_stop)
    processed = preprocess_from_cool(region, genome_hic_cool)[0]

    # Same margin is removed from the start and the end of both axes.
    margin = int((target_length - target_length_cropped)/2)
    keep = slice(margin, target_length - margin)
    return processed[keep, keep]


# @title `variant_centered_sequences`

class FastaStringExtractor:
    """Extract upper-case sequence strings from a FASTA file via pyfaidx.

    Intervals that run past either end of a chromosome are padded with 'N'
    so the returned string covers the full requested span.
    """

    def __init__(self, fasta_file):
        """Open `fasta_file` and record the length of every sequence in it."""
        self.fasta = pyfaidx.Fasta(fasta_file)
        self._chromosome_sizes = {
            name: len(record) for name, record in self.fasta.items()
        }

    def extract(self, interval: Interval, **kwargs) -> str:
        """Return the sequence for `interval`, N-padded where it leaves the chromosome.

        Args:
            interval: interval whose `chrom`, `start` and `end` are read.
            **kwargs: accepted and ignored.

        Returns:
            Upper-case sequence string with 'N' padding on either side for
            the parts of the interval outside the chromosome.
        """
        chrom_len = self._chromosome_sizes[interval.chrom]

        # Clip the request to the part that actually lies on the chromosome.
        clipped = Interval(interval.chrom,
                           max(interval.start, 0),
                           min(interval.end, chrom_len))

        # pyfaidx expects a 1-based start, hence the +1.
        record = self.fasta.get_seq(clipped.chrom,
                                    clipped.start + 1,
                                    clipped.stop)
        bases = str(record.seq).upper()

        # Re-create the clipped-off flanks as runs of N.
        left_pad = 'N' * max(-interval.start, 0)
        right_pad = 'N' * max(interval.end - chrom_len, 0)
        return left_pad + bases + right_pad

    def close(self):
        """Close the underlying pyfaidx handle."""
        return self.fasta.close()


def variant_generator(vcf_file, gzipped=False):
  """Yields a kipoiseq.dataclasses.Variant for each ALT allele of each VCF row.

  Lines starting with '#' or 'CHROM' and blank lines are skipped. Only the
  first five tab-separated columns (chrom, pos, id, ref, alt) are read, and
  rows with several comma-separated ALT alleles yield one variant per allele.

  Args:
    vcf_file: path to the VCF file.
    gzipped: if True, the file is opened with gzip in text mode.
  """
  # Local import: gzip is not imported at the top of this notebook, so the
  # gzipped=True path previously failed with a NameError.
  import gzip

  opener = gzip.open if gzipped else open
  with opener(vcf_file, 'rt') as f:
    for line in f:
      # Drop the line terminator so it cannot leak into the last parsed field
      # (it ended up in ALT when the file has exactly five columns).
      line = line.rstrip('\r\n')
      if not line.strip() or line.startswith('#') or line.startswith('CHROM'):
        continue
      chrom, pos, variant_id, ref, alt_list = line.split('\t')[:5]
      # Split ALT alleles and return individual variants as output.
      for alt in alt_list.split(','):
        yield kipoiseq.dataclasses.Variant(chrom=chrom, pos=pos,
                                           ref=ref, alt=alt, id=variant_id)


def one_hot_encode(sequence):
  """One-hot encode a DNA string with kipoiseq and return it as float32."""
  encoded = kipoiseq.transforms.functional.one_hot_dna(sequence)
  return encoded.astype(np.float32)


def get_N_composition(seq: str):
    """
    Get % of N's in input sequence

    Input:
        seq: string of sequence
    Returns: percentage of (upper-case) 'N' characters in `seq`, rounded to
        two decimals; 0 when there are none, when the rounded percentage is
        zero, or when `seq` is empty.
    """
    # An empty sequence previously raised ZeroDivisionError.
    if len(seq) == 0:
        return 0

    # Only the N fraction is needed, so count it directly.
    n_percent = round(seq.count('N') / len(seq) * 100, 2)
    if n_percent > 0:
        return n_percent
    return 0

def variant_centered_sequences(vcf_file, sequence_length, gzipped=False,
                               chr_prefix=''):
  """Yield reference/alternate sequences around each variant in a VCF.

  For every variant produced by `variant_generator`, an interval at the
  variant position is resized to `sequence_length`, and the sequence is
  extracted twice from the global `fasta_file`: once without variants
  (reference) and once with the variant applied (alternate), both anchored at
  the interval centre.

  Args:
    vcf_file: path to the VCF file.
    sequence_length: length the interval is resized to.
    gzipped: forwarded to `variant_generator`.
    chr_prefix: string prepended to each variant's chromosome name.

  Yields:
    Dict with 'inputs' (the 'ref' and 'alt' sequence strings) and 'metadata'
    (chrom, pos, id, ref, alt, and the N percentage of each sequence).
  """
  fasta_extractor = FastaStringExtractor(fasta_file)
  seq_extractor = kipoiseq.extractors.VariantSeqExtractor(
    reference_sequence=fasta_extractor)

  try:
    for variant in variant_generator(vcf_file, gzipped=gzipped):
      interval = Interval(chr_prefix + variant.chrom,
                          variant.pos, variant.pos)
      interval = interval.resize(sequence_length)
      # Anchor position relative to the start of the interval.
      center = interval.center() - interval.start

      reference = seq_extractor.extract(interval, [], anchor=center)
      ref_N_composition = get_N_composition(reference)

      alternate = seq_extractor.extract(interval, [variant], anchor=center)
      alt_N_composition = get_N_composition(alternate)
      yield {'inputs': {'ref': reference,
                        'alt': alternate},
             'metadata': {'chrom': chr_prefix + variant.chrom,
                          'pos': variant.pos,
                          'id': variant.id,
                          'ref': variant.ref,
                          'alt': variant.alt,
                          'ref_N_composition':ref_N_composition,
                          'alt_N_composition':alt_N_composition}}
  finally:
    # Release the FASTA handle when the generator finishes or is closed;
    # it was previously left open.
    fasta_extractor.close()
    
    
def msd(alternate_prediction, reference_prediction):
    """Mean squared difference between alt and ref predictions.

    The element-wise squared difference is averaged along axis 1, ignoring
    NaNs, and the result is flattened to 1-D (one value per cell line,
    according to the original note -- assumes axis 1 is the axis to average
    over; confirm against the prediction shape).
    """
    squared_delta = np.square(alternate_prediction - reference_prediction)
    per_target = np.nanmean(squared_delta, axis = 1)
    return per_target.reshape(-1)

def max_diff(alternate_prediction, reference_prediction):
    """Maximum absolute difference between alt and ref predictions.

    Takes |alt - ref| element-wise, reduces with the maximum along axis 1 and
    flattens the result to 1-D (one value per cell line, according to the
    original note). NaNs are not ignored here.
    """
    abs_delta = abs(alternate_prediction - reference_prediction)
    largest = np.max(abs_delta, axis = 1)
    return largest.reshape(-1)


# In[7]:








2.4.1
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
sequence (InputLayer)           [(None, 1048576, 4)] 0                                            
__________________________________________________________________________________________________
stochastic_reverse_complement ( ((None, 1048576, 4), 0           sequence[0][0]                   
__________________________________________________________________________________________________
stochastic_shift (StochasticShi (None, 1048576, 4)   0           stochastic_reverse_complement[0][
__________________________________________________________________________________________________
re_lu (ReLU)                    (None, 1048576, 4)   0           stochastic_shift[0][0]           
______________________________________________________________________________________

successfully loaded


In [2]:
# Path of the VCF file to work with (absolute, machine-specific path).
# NOTE(review): the variable is named vcf_ALS but the path points at a
# PsychENCODE GWAS test file -- confirm this is the intended input.
vcf_ALS = "/wynton/home/hernandez/shirondru/pollard_lab/GWASPredictions/PsychENCODE_GWAS_Predictions/Sei/TEST/PsychENCODE_GWASVariants_TEST_SeiNoHeader.vcf00.vcf/PsychENCODE_GWASVariants_TEST_SeiNoHeader.vcf00.vcf"

In [None]:
# Display the VCF path defined above (the bare name `vcf` was undefined and
# would raise NameError on a fresh kernel).
vcf_ALS

In [None]:
# Manual check on one hand-picked insertion (T -> TGGGC at chr1:84732076):
# build the reference and alternate sequences the same way
# variant_centered_sequences does.
seq_extractor = kipoiseq.extractors.VariantSeqExtractor(
    reference_sequence=FastaStringExtractor(fasta_file))
variant = kipoiseq.dataclasses.Variant(chrom='chr1', pos=84732076, ref='T', alt='TGGGC', id='')

# Interval at the variant position, resized to the model's input length.
interval = Interval('' + variant.chrom,
                        variant.pos, variant.pos)
interval = interval.resize(seq_length)
# Anchor position relative to the start of the interval.
center = interval.center() - interval.start

# Reference: no variants applied.
reference = seq_extractor.extract(interval, [], anchor=center)
ref_N_composition = get_N_composition(reference)

# Alternate: same interval with the variant applied.
alternate = seq_extractor.extract(interval, [variant], anchor=center)

In [31]:
# The indel is an insertion that adds 4 nt (T -> TGGGC), so everything
# downstream of it is shifted right by 4 nt in the alternate sequence. The
# alternate window therefore loses the final GTGT of the reference and ends
# with the preceding AACT instead.
# (Uses `alternate` / `reference` from the cell above; the old names `alt`
# and `ref` are not defined in this notebook.)
print(alternate[-4:])
print(reference[-8:])


# The alternate and reference sequences still begin with the same bases,
# because the sequence upstream of the indel is unaffected.
print(alternate[0:4])
print(reference[0:4])

AACT
AACTGTGT
CTTT
CTTT


In [50]:
# Same single-variant extraction as the earlier cell (T -> TGGGC at
# chr1:84732076); it (re)defines `reference`, `alternate` and `center` for
# the checks in the cells below.
seq_extractor = kipoiseq.extractors.VariantSeqExtractor(
    reference_sequence=FastaStringExtractor(fasta_file))
variant = kipoiseq.dataclasses.Variant(chrom='chr1', pos=84732076, ref='T', alt='TGGGC', id='')

# Interval at the variant position, resized to the model's input length.
interval = Interval('' + variant.chrom,
                        variant.pos, variant.pos)
interval = interval.resize(seq_length)
# Anchor position relative to the start of the interval.
center = interval.center() - interval.start

# Reference: no variants applied.
reference = seq_extractor.extract(interval, [], anchor=center)
ref_N_composition = get_N_composition(reference)

# Alternate: same interval with the variant applied.
alternate = seq_extractor.extract(interval, [variant], anchor=center)

In [53]:
window = len(reference)

# Both slices start one base before the midpoint of the window.
# The first T printed from the reference sequence is the reference allele;
# the TGGGC at the start of the alternate slice is the alt allele, of which
# GGGC is the inserted part.
# Printing 100 nt to the right of the reference allele shows what follows the
# indel site in the reference; the alternate slice is 4 nt longer so that both
# lines end at the same downstream base. Comparing the two shows the shift is
# applied as described in section 5 of:
# http://samtools.github.io/hts-specs/VCFv4.2.pdf
print(reference[(window//2)-1:(window//2)+100])
print(alternate[(window//2)-1:(window//2)+104])



TTTTTTTTTTTTGAGACAGAGTCTCACTCTGTCACCCAGGCTGGAGTGCAGTGGCACGATCTTGGCTCACTGCAAGCTCCACCTCCCAGGTTCACACTATT
TGGGCTTTTTTTTTTTGAGACAGAGTCTCACTCTGTCACCCAGGCTGGAGTGCAGTGGCACGATCTTGGCTCACTGCAAGCTCCACCTCCCAGGTTCACACTATT


In [59]:
# The indel is an insertion that adds 4 nt, so everything downstream of it is
# shifted right by 4 nt in the alternate sequence. The alternate window
# therefore loses the final GTGT of the reference and ends with the preceding
# AACT instead.
print(alternate[-4:])
print(reference[-8:])


# The alternate and reference sequences still begin with the same bases,
# because the sequence upstream of the indel is unaffected.
print(alternate[0:4])
print(reference[0:4])

AACT
AACTGTGT
CTTT
CTTT


In [58]:
# Number of bases from the anchor offset `center` to the end of the reference
# window.
len(reference[center:])

524288