In [None]:
import tensorflow as tf
import os
from joblib import Parallel,delayed
import tensorflow_hub as hub
import joblib
import gzip
import kipoiseq
from kipoiseq import Interval
import pyfaidx
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import biomart
from scipy.stats import zscore
from pandas import HDFStore
import h5py
import itertools
import argparse
import sys
import datetime
# ---------------------------------------------------------------------------
# Resource locations (absolute paths on the cluster filesystem).
# ---------------------------------------------------------------------------
# Saved Enformer model; handed to `Enformer(...)` further down.
model_path = "/wynton/home/hernandez/shirondru/pollard_lab/enformer"
# Reference genome FASTA read by `FastaStringExtractor`.
fasta_file = '/wynton/home/hernandez/shirondru/pollard_lab/data/hg38_genome.fa'
# Gzipped ClinVar VCF.
clinvar_vcf = '/wynton/home/hernandez/shirondru/pollard_lab/data/clinvar.vcf.gz'


# Gene annotation table (knownGene, hg38, from the UCSC Genome Browser).
# Candidate column names, kept for reference only (not applied on load):
# cols = ['chrom','txStart','txEnd','ENST','strand','cdsStart','cdsEnd','exonCount','exonStarts','exonEnds','ENST_y','A','B','C','D','Gene','E','F']
gene_annotations = pd.read_csv(
    "/wynton/home/hernandez/shirondru/pollard_lab/data/knownGene.tsv",
    sep='\t',
)


# Target (track) metadata from the Basenji2 dataset.
# Cite: Kelley et al. Cross-species regulatory sequence activity prediction.
#       PLoS Comput. Biol. 16, e1008050 (2020).
df_targets = pd.read_csv(
    "/wynton/home/hernandez/shirondru/pollard_lab/data/enformer_df_targets.csv"
)

# In[2]:


# @title `Enformer`, `EnformerScoreVariantsNormalized`, `EnformerScoreVariantsPCANormalized`,
# Sequence length constant, in base pairs.
# NOTE(review): not referenced by any visible cell (the scratch cells below use
# `seq_length` instead); presumably the Enformer input window length — confirm.
SEQUENCE_LENGTH = 393216

class Enformer:
  """Thin wrapper around a saved Enformer model loaded through tensorflow_hub."""

  def __init__(self, tfhub_url):
    """Loads the module at `tfhub_url` and keeps its `.model` attribute."""
    loaded_module = hub.load(tfhub_url)
    self._model = loaded_module.model

  def predict_on_batch(self, inputs):
    """Runs the model on `inputs` and returns its outputs as numpy arrays.

    Returns:
      A dict with the same keys as the model's output dict; every value has
      been converted with `.numpy()`.
    """
    model_outputs = self._model.predict_on_batch(inputs)
    as_numpy = {}
    for head_name, head_tensor in model_outputs.items():
      as_numpy[head_name] = head_tensor.numpy()
    return as_numpy

  @tf.function
  def contribution_input_grad(self, input_sequence,
                              target_mask, output_head='human'):
    """Computes gradient-times-input contribution scores for one sequence.

    Args:
      input_sequence: A single (unbatched) model input; a batch axis is added
        here. Presumably a one-hot (length, 4) tensor — confirm with callers.
      target_mask: Weights multiplied with the selected head's output; the
        weighted sum is divided by the sum of the mask.
      output_head: Key of the model output dict to differentiate.

    Returns:
      The gradient of the masked mean prediction with respect to the input,
      multiplied elementwise by the input and summed over the last axis.
    """
    batched_sequence = input_sequence[tf.newaxis]
    mask_total = tf.reduce_sum(target_mask)

    with tf.GradientTape() as tape:
      # The input is a plain tensor, so it has to be watched explicitly.
      tape.watch(batched_sequence)
      head_output = self._model.predict_on_batch(batched_sequence)[output_head]
      masked_mean = tf.reduce_sum(
          target_mask[tf.newaxis] * head_output) / mask_total

    grad_times_input = tape.gradient(masked_mean, batched_sequence) * batched_sequence
    per_position = tf.squeeze(grad_times_input, axis=0)
    return tf.reduce_sum(per_position, axis=-1)


class EnformerScoreVariantsRaw:
  """Scores variants as the difference between alt and ref model predictions."""

  def __init__(self, tfhub_url, organism='human'):
    """Builds the underlying `Enformer` and remembers which output head to read."""
    self._organism = organism
    self._model = Enformer(tfhub_url)

  def predict_on_batch(self, inputs):
    """Returns mean(alt prediction) - mean(ref prediction), averaged over axis 1.

    Args:
      inputs: Mapping with 'ref' and 'alt' entries, each passed unchanged to
        the wrapped model's `predict_on_batch`.
    """
    head = self._organism
    ref_tracks = self._model.predict_on_batch(inputs['ref'])[head]
    alt_tracks = self._model.predict_on_batch(inputs['alt'])[head]

    alt_mean = alt_tracks.mean(axis=1)
    ref_mean = ref_tracks.mean(axis=1)
    return alt_mean - ref_mean


class EnformerScoreVariantsNormalized:
  """Raw variant scores passed through the first step of a saved transform pipeline."""

  def __init__(self, tfhub_url, transform_pkl_path,
               organism='human'):
    """Loads the model and the transform stored at `transform_pkl_path`.

    Only organism='human' is accepted (checked with an assertion).
    """
    assert organism == 'human', 'Transforms only compatible with organism=human'
    self._model = EnformerScoreVariantsRaw(tfhub_url, organism)
    # NOTE: joblib.load unpickles the file; only point this at trusted files.
    with tf.io.gfile.GFile(transform_pkl_path, 'rb') as pkl_handle:
      pipeline = joblib.load(pkl_handle)
    first_step = pipeline.steps[0]
    self._transform = first_step[1]  # StandardScaler.

  def predict_on_batch(self, inputs):
    """Returns the raw alt-minus-ref scores after applying the stored transform."""
    raw_scores = self._model.predict_on_batch(inputs)
    return self._transform.transform(raw_scores)


class EnformerScoreVariantsPCANormalized:
  """Raw variant scores passed through a saved transform, truncated to the leading features."""

  def __init__(self, tfhub_url, transform_pkl_path,
               organism='human', num_top_features=500):
    """Loads the model and the whole transform object stored at `transform_pkl_path`.

    Args:
      tfhub_url: Location of the saved Enformer model.
      transform_pkl_path: File holding the joblib-serialised transform.
      organism: Output head read by the underlying raw scorer.
      num_top_features: Number of leading transformed columns to keep.
    """
    self._num_top_features = num_top_features
    self._model = EnformerScoreVariantsRaw(tfhub_url, organism)
    # NOTE: joblib.load unpickles the file; only point this at trusted files.
    with tf.io.gfile.GFile(transform_pkl_path, 'rb') as pkl_handle:
      self._transform = joblib.load(pkl_handle)

  def predict_on_batch(self, inputs):
    """Returns the first `num_top_features` columns of the transformed scores."""
    raw_scores = self._model.predict_on_batch(inputs)
    transformed = self._transform.transform(raw_scores)
    return transformed[:, :self._num_top_features]


# TODO(avsec): Add feature description: Either PCX, or full names.


# In[51]:


# @title `variant_centered_sequences`

# @title `variant_centered_sequences`

class FastaStringExtractor:
    """Returns upper-cased sequence strings for genomic intervals of a FASTA file.

    Intervals that run past either end of a chromosome are clipped before the
    lookup and the missing bases are filled in with 'N'.
    """

    def __init__(self, fasta_file):
        """Opens `fasta_file` with pyfaidx and records the length of every record."""
        self.fasta = pyfaidx.Fasta(fasta_file)
        sizes = {}
        for name, record in self.fasta.items():
            sizes[name] = len(record)
        self._chromosome_sizes = sizes

    def extract(self, interval: Interval, **kwargs) -> str:
        """Returns the sequence for `interval`, N-padded where it leaves the chromosome.

        Extra keyword arguments are accepted and ignored.
        """
        chrom_len = self._chromosome_sizes[interval.chrom]

        # Clip the request to [0, chromosome length].
        clipped = Interval(interval.chrom,
                           max(interval.start, 0),
                           min(interval.end, chrom_len),
                           )

        # pyfaidx takes a 1-based start, hence the +1.
        record = self.fasta.get_seq(clipped.chrom,
                                    clipped.start + 1,
                                    clipped.stop)
        bases = str(record.seq).upper()

        # Number of bases that were clipped off on each side.
        n_upstream = max(-interval.start, 0)
        n_downstream = max(interval.end - chrom_len, 0)
        return 'N' * n_upstream + bases + 'N' * n_downstream

    def close(self):
        """Closes the underlying pyfaidx handle."""
        return self.fasta.close()


def variant_generator(vcf_file, gzipped=False,skip_lines = 0,max_lines = np.inf):
  """Yields a kipoiseq.dataclasses.Variant for each ALT allele of each VCF row.

  Args:
    vcf_file: Path to a tab-separated VCF(-like) file whose first five columns
      are CHROM, POS, ID, REF and ALT. Further columns are ignored.
    gzipped: If True the file is opened with gzip in text mode.
    skip_lines: Accepted for interface compatibility; currently NOT applied.
    max_lines: Accepted for interface compatibility; currently NOT applied.

  Yields:
    One Variant per comma-separated ALT allele. Lines that are blank, start
    with '#', or start with 'CHROM' (header rows) are skipped.
  """
  def _open(path):
    return gzip.open(path, 'rt') if gzipped else open(path)

  with _open(vcf_file) as f:
    for line in f:
      # Drop the line terminator first: in a file with exactly five columns it
      # would otherwise stay attached to the ALT allele.
      line = line.rstrip('\r\n')
      if not line.strip():
        continue  # Blank line; nothing to unpack.
      if line.startswith('#') or line.startswith('CHROM'):
        continue

      chrom, pos, id, ref, alt_list = line.split('\t')[:5]

      # Split ALT alleles and return individual variants as output.
      for alt in alt_list.split(','):
        yield kipoiseq.dataclasses.Variant(chrom=chrom, pos=pos,
                                           ref=ref, alt=alt, id=id)
       



def one_hot_encode(sequence):
  """Returns kipoiseq's one-hot DNA encoding of `sequence`, cast to float32."""
  encoded = kipoiseq.transforms.functional.one_hot_dna(sequence)
  return encoded.astype(np.float32)


def variant_centered_sequences(vcf_file, sequence_length, gzipped=False,
                               chr_prefix=''):
  """Yields one-hot ref/alt sequences centred on every variant of a VCF file.

  The reference genome is read from the module-level `fasta_file` path. The
  FASTA handle opened here is closed when the generator is exhausted, closed
  or garbage-collected.

  Args:
    vcf_file: Path handed to `variant_generator`.
    sequence_length: Width the variant-centred interval is resized to.
    gzipped: Forwarded to `variant_generator`.
    chr_prefix: String prepended to each variant's chromosome name before the
      FASTA lookup (e.g. 'chr' when the VCF omits it).

  Yields:
    A dict with 'inputs' ({'ref', 'alt'} one-hot float32 arrays) and
    'metadata' ({'chrom', 'pos', 'id', 'ref', 'alt'}) for each variant.
  """
  fasta_reader = FastaStringExtractor(fasta_file)
  seq_extractor = kipoiseq.extractors.VariantSeqExtractor(
    reference_sequence=fasta_reader)

  try:
    for variant in variant_generator(vcf_file, gzipped=gzipped):
      chrom = chr_prefix + variant.chrom
      interval = Interval(chrom, variant.pos, variant.pos)
      interval = interval.resize(sequence_length)
      # Anchor expressed relative to the start of the resized interval.
      center = interval.center() - interval.start

      reference = seq_extractor.extract(interval, [], anchor=center)
      alternate = seq_extractor.extract(interval, [variant], anchor=center)

      yield {'inputs': {'ref': one_hot_encode(reference),
                        'alt': one_hot_encode(alternate)},
             'metadata': {'chrom': chrom,
                          'pos': variant.pos,
                          'id': variant.id,
                          'ref': variant.ref,
                          'alt': variant.alt}}
  finally:
    # Release the FASTA file handle; the original left it open.
    fasta_reader.close()
# In[63]:


# Load the Enformer model from `model_path` (loading happens inside Enformer.__init__).
model = Enformer(model_path)
# Open the reference genome once for ad-hoc sequence lookups.
# NOTE(review): `fasta_extractor` is not used by any visible cell — confirm it is still needed.
fasta_extractor = FastaStringExtractor(fasta_file)


# In[55]:



#take variable number of MAF>0.05 1KG variants and get model(alt) - model(ref) predictions
#take sum or max along sequence axis to get variant score for each track. Save these scores + variant position and allele metadata




In [6]:
# Path to a VCF of variants.
# NOTE(review): the variable name says "ALS" but the path points at a PsychENCODE
# GWAS TEST file, and no visible cell reads this variable — confirm which is intended.
vcf_ALS = "/wynton/home/hernandez/shirondru/pollard_lab/GWASPredictions/PsychENCODE_GWAS_Predictions/Sei/TEST/PsychENCODE_GWASVariants_TEST_SeiNoHeader.vcf00.vcf/PsychENCODE_GWASVariants_TEST_SeiNoHeader.vcf00.vcf"

In [31]:
# NOTE(review): `alt` and `ref` are not assigned in any visible cell, so this cell
# fails on a fresh kernel; the In [59] cell further down repeats the same check
# with `alternate` / `reference`. Confirm where `alt` / `ref` came from.
# The indel is an insertion that adds 4 nt,
# and we see that inserting 4 nt into the alt allele shifts the downstream sequence to the right by 4 nt,
# so the trailing GTGT is missing from the end of the alt sequence and only the 4 preceding nt (AACT) appear.
print(alt[-4:])
print(ref[-8:])


# However, the alt and ref seqs have the same sequence at the beginning, before the indel.
print(alt[0:4])
print(ref[0:4])

AACT
AACTGTGT
CTTT
CTTT


In [50]:
# Extract the reference and alternate sequences around a single hand-written
# insertion (T -> TGGGC on chr1) to inspect how the indel is applied.
seq_extractor = kipoiseq.extractors.VariantSeqExtractor(
    reference_sequence=FastaStringExtractor(fasta_file))
variant = kipoiseq.dataclasses.Variant(chrom='chr1', pos=84732076, ref='T', alt='TGGGC', id='')

# Zero-width interval at the variant position, then widened to the window size.
interval = Interval('' + variant.chrom,
                        variant.pos, variant.pos)
# NOTE(review): `seq_length` is not defined in any visible cell (only
# SEQUENCE_LENGTH is) — confirm its value and where it is set.
interval = interval.resize(seq_length)
# Anchor expressed relative to the start of the resized interval.
center = interval.center() - interval.start

# Empty variant list -> unmodified reference sequence.
reference = seq_extractor.extract(interval, [], anchor=center)
# NOTE(review): `get_N_composition` is not defined in any visible cell — confirm.
ref_N_composition = get_N_composition(reference)

# Same interval with the variant applied.
alternate = seq_extractor.extract(interval, [variant], anchor=center)

In [53]:
window = len(reference)

# The first T printed from the ref seq is the reference allele.
# The leading TGGGC printed from the alt seq is the alt allele; GGGC is the insertion.
# Printing the 100 nt to the right of the ref allele shows what follows the indel site in the ref seq.
# Doing the same with the alt seq (4 nt longer slice) shows the shift is implemented according to section 5 of:
# http://samtools.github.io/hts-specs/VCFv4.2.pdf
print(reference[(window//2)-1:(window//2)+100])
print(alternate[(window//2)-1:(window//2)+104])



TTTTTTTTTTTTGAGACAGAGTCTCACTCTGTCACCCAGGCTGGAGTGCAGTGGCACGATCTTGGCTCACTGCAAGCTCCACCTCCCAGGTTCACACTATT
TGGGCTTTTTTTTTTTGAGACAGAGTCTCACTCTGTCACCCAGGCTGGAGTGCAGTGGCACGATCTTGGCTCACTGCAAGCTCCACCTCCCAGGTTCACACTATT


In [59]:
# The indel is an insertion that adds 4 nt,
# and we see that inserting 4 nt into the alt allele shifts the downstream sequence to the right by 4 nt,
# so the trailing GTGT is missing from the end of the alt sequence and only the 4 preceding nt (AACT) appear.
print(alternate[-4:])
print(reference[-8:])


# However, the alt and ref seqs have the same sequence at the beginning, before the indel.
print(alternate[0:4])
print(reference[0:4])

AACT
AACTGTGT
CTTT
CTTT


In [58]:
# Number of bases from the anchor offset (`center`) to the end of the reference sequence.
len(reference[center:])

524288