In [1]:
import sys
import gzip
import argparse
import pickle
from statistics import mean, median
import numpy as np
import pandas as pd
from Bio.Seq import Seq
import pyfastx



# --- Example inputs for solve_NMD (gene CCNL2, chr1, minus strand) ---
# Chromosome name, passed to fa.fetch() inside solve_NMD.
chrom = 'chr1'
# Strand of the gene: '+' or '-' (solve_NMD branches on this value).
strand = '-'
# Candidate splice junctions as (coord, coord) integer pairs, smaller coordinate first.
junc = [(1395537, 1398233), (1389477, 1390230), (1388743, 1388816), (1388743, 1390230), (1388065, 1390230), (1388743, 1388950), (1398342, 1398597), (1387869, 1387954), (1387582, 1387777), (1391553, 1392679), (1388065, 1388950), (1391575, 1392679), (1398418, 1398597), (1390865, 1392679), (1390865, 1391297), (1393460, 1395394), (1390865, 1393396), (1388510, 1388626), (1388065, 1388492), (1391575, 1392790), (1392803, 1393396), (1390309, 1390315), (1388065, 1388626), (1398671, 1399019), (1390371, 1390459), (1395514, 1398233), (1390865, 1392790), (1389473, 1390230), (1390563, 1390766), (1398671, 1398988), (1389047, 1390230), (1388830, 1390230)]
# Start / stop codons as sets of (coord, coord) pairs.
start_codons = {(1398663, 1398665), (1399304, 1399306), (1390856, 1390858)}
stop_codons = {(1387231, 1387233), (1392782, 1392784)}
# Gene label; solve_NMD writes it as the first field of each protein record.
gene_name = 'CCNL2'
# NOTE: exonLcutoff and verbose are defined here but are not passed to the
# solve_NMD call in this notebook; the function's own defaults (same values) apply.
exonLcutoff=1000
verbose=True
# Genome FASTA handle. solve_NMD reads this module-level name via `global fa`,
# so this cell must be executed before solve_NMD is called.
fa = pyfastx.Fasta('genome.fa')

In [9]:
# Run the NMD solver on the example gene. The cell defining solve_NMD
# (shown further down as In [8]) must have been executed before this one.
# verbose / exonLcutoff are not passed, so the function defaults are used.
junc_pass, junc_fail, proteins = solve_NMD(chrom, strand, junc, start_codons, stop_codons,gene_name)

Depth 1, Seed L = 2
Depth 2, Seed L = 8
Depth 3, Seed L = 46
Depth 4, Seed L = 116
Depth 5, Seed L = 107
Depth 6, Seed L = 51
Depth 7, Seed L = 53
Depth 8, Seed L = 8
Depth 9, Seed L = 6
junction pass:(1390865, 1393396)
junction pass:(1388065, 1388950)
junction pass:(1389047, 1390230)
junction pass:(1390309, 1390315)
junction pass:(1388065, 1388492)
junction pass:(1388510, 1388626)
junction pass:(1388743, 1390230)
junction pass:(1388830, 1390230)
junction pass:(1388743, 1388816)
junction long_exon:(1388743, 1388816)
junction long_exon:(1390865, 1392790)
junction long_exon:(1390865, 1391297)
junction long_exon:(1391575, 1392679)
junction long_exon:(1391575, 1392790)
junction long_exon:(1390865, 1393396)


In [8]:
def ptc_pos_from_prot(prot, sub):
    '''
    Return every index at which `sub` occurs in `prot`, in ascending order.

    The search restarts one position after each hit, so overlapping
    occurrences are all reported. `prot` only needs a str-like
    ``find(sub, start)`` method. An empty list is returned when there is
    no occurrence.
    '''
    hits = []
    idx = prot.find(sub, 0)
    while idx != -1:
        hits.append(idx)
        # resume the scan just past the hit that was recorded
        idx = prot.find(sub, idx + 1)
    return hits
   

def check_utrs(junc,utrs):
    '''
    Tell whether a junction lies close to any UTR.

    Returns True as soon as one (utr_a, utr_b) pair in `utrs` has its first
    coordinate less than 100 away from junc[0], or its second coordinate
    less than 100 away from junc[1]; returns False otherwise (including
    when `utrs` is empty).
    '''
    return any(
        abs(junc[0] - utr_a) < 100 or abs(junc[1] - utr_b) < 100
        for utr_a, utr_b in tuple(utrs)
    )

def solve_NMD(chrom, strand, junc, start_codons, stop_codons,gene_name, 
              verbose = True, exonLcutoff = 1000):
    '''
    Compute whether there is a possible combination that uses the junction without
    inducing a PTC. We start with all annotated stop codon and go backwards.

    Paths are grown from every stop codon towards the start codons. A path is a
    flat list of coordinates ``[stop, ss, ss, ss, ss, ..., start]``: element 0
    is the stop codon, each following pair is one junction, and a complete path
    ends with a start-codon coordinate. Consecutive pairs starting at index 0
    are the exons that get translated.

    Parameters
    ----------
    chrom : str
        Sequence name handed to ``fa.fetch``.
    strand : str
        ``"+"`` or anything else (treated as minus strand).
    junc : iterable of (int, int)
        Candidate junctions. The caller's object is NOT modified; a sorted
        copy is used internally.
    start_codons, stop_codons : iterable of (int, int)
        Codon coordinate pairs.
    gene_name : str
        Written as the first field of each record in ``proteins``.
    verbose : bool
        When true, progress and newly passing junctions are written to stdout.
    exonLcutoff : int
        A path is only extended to a junction / start codon that is closer
        than this to the current path end.

    Returns
    -------
    junc_pass : dict
        ``{'normal': {junction: n}, 'long_exon': {junction: n}}``. ``n`` counts
        the complete PTC-free paths using the junction; junctions added only by
        the terminus propagation at the end keep a count of 0.
    junc_fail : dict
        ``{junction: n}`` counting partial paths containing a PTC.
    proteins : list of str
        One tab-separated, newline-terminated record per complete PTC-free path:
        gene_name, chrom, strand, '-'-joined path coordinates, protein string.

    Notes
    -----
    Sequences are read from the module-level ``fa`` object, which must exist
    before any path is translated.
    '''

    global fa

    # Exons longer than this many nucleotides are translated with '@' as the
    # stop symbol, so their in-frame stops are tracked as "long_exon" instead
    # of as ordinary PTCs ('*').
    long_exon_nt = 407

    # Work on a sorted copy so the caller's junction list is left untouched.
    junc = sorted(junc)
    if strand == "+":
        junc.reverse()

    # Every stop codon seeds one path (a one-element list).
    seed = []
    for c in stop_codons:
        if strand == "+":
            seed.append([c[1]])
        else:
            seed.append([c[0]])

    # seed starts with just stop codon and then a possible 3'ss-5'ss junction
    # without introducing a PTC [stop_codon,3'ss, 5'ss, 3'ss, ..., start_codon]

    junc_pass = {'normal':{}, 'long_exon':{}}
    junc_fail = {}
    path_pass = {'normal':[], 'long_exon':[]}
    proteins = []
    # stop position found in a short exon -> that exon's (lo, hi) coordinates
    short_ptcs = {}

    # (second-to-last coord, last coord, leftover bases) -> partial paths
    # sharing that end state; only the first path per terminus is extended.
    dic_terminus = {'normal': {}, 'long_exon': {}}

    depth = 0

    # Each iteration extends every surviving path by one junction. The loop
    # ends when no path can be extended any further.
    while len(seed) > 0:
        new_seed = []
        final_check = []
        depth += 1
        if verbose:
            sys.stdout.write("Depth %s, Seed L = %s\n"%(depth, len(seed)))

        for s in seed:
            # first check that the seed paths are good
            bool_ptc = False
            # reset per path so a stale value from a previous path is never read
            bool_long_exon = False
            leftover = ''
            if len(s) > 0:
                leftover = Seq("")
                allprot = Seq("")

                # loop through the exons of this path, translating each one
                for i in range(0, len(s)-1, 2):
                    exon_coord = s[i:i+2]
                    exon_coord.sort()
                    exon_coord = tuple(exon_coord)
                    exlen = exon_coord[1]-exon_coord[0]

                    # number of bases that do not fill a whole codon once this
                    # exon is joined to the bases carried over so far
                    startpos = (len(leftover)+exlen+1)%3
                    if strand == '+':
                        seq = Seq(fa.fetch(chrom, (exon_coord[0],exon_coord[1])))+leftover
                        # exon length rule
                        if exlen + 1 > long_exon_nt:
                            prot = seq[startpos:].translate(stop_symbol = '@')
                        else:
                            prot = seq[startpos:].translate()

                            #store ptc position for checking with long_exon PTCs later
                            # NOTE(review): `leftover` here is still the value carried in
                            # from the previous exon, whereas the '+' branch of the final
                            # check below uses the already-updated leftover - confirm
                            # which one is intended.
                            ptc_pos = ptc_pos_from_prot(prot, '*')
                            ptc_coord = [exon_coord[1] - (x+1)*3 - len(leftover) for x in ptc_pos]

                            #don't want to keep the actual stop codon
                            if i == 0 and len(ptc_coord) > 0: ptc_coord.pop(-1)
                            for k in ptc_coord:
                                short_ptcs[k] = exon_coord
                        leftover = seq[:startpos]
                        allprot = prot+allprot
                    else:
                        seq = leftover+Seq(fa.fetch(chrom, (exon_coord[0],exon_coord[1])))
                        if startpos > 0:
                            leftover = seq[-startpos:]
                        else:
                            leftover = Seq("")
                        seq = seq.reverse_complement()

                        if exlen + 1 > long_exon_nt:
                            prot = seq[startpos:].translate(stop_symbol = '@')
                        else:
                            prot = seq[startpos:].translate()
                            ptc_pos = ptc_pos_from_prot(prot, '*')
                            ptc_coord = [exon_coord[0] + (x+1)*3 - len(leftover) for x in ptc_pos]
                            if i == 0 and len(ptc_coord) > 0: ptc_coord.pop(-1)
                            for k in ptc_coord:
                                short_ptcs[k] = exon_coord

                        allprot = prot+allprot

                    #found a PTC in this transcript if any element but the last is a stop codon
                    bool_ptc = "*" in allprot[:-1]
                    bool_long_exon = '@' in allprot[:-1]

            # if we found a PTC, count every junction of this path in junc_fail
            # and do not extend the path any further
            if bool_ptc:
                #This transcript failed
                for i in range(1, len(s)-1, 2):
                    j_coord = s[i:i+2]
                    j_coord.sort()
                    j_coord = tuple(j_coord)
                    if j_coord not in junc_fail:
                        junc_fail[j_coord] = 0
                    junc_fail[j_coord] += 1

                continue

            # passed
            # If the path is more than a bare stop codon, register it under its
            # terminus (last two coordinates plus leftover bases). A path whose
            # terminus was already seen is only recorded, not extended again.
            if len(s) > 2:
                terminus = (s[-2],s[-1],leftover)

                if not bool_long_exon:
                    if terminus in dic_terminus['normal']:
                        dic_terminus['normal'][terminus].append(tuple(s))
                        continue
                    else:
                        dic_terminus['normal'][terminus] = [tuple(s)]
                else:
                    if terminus in dic_terminus['long_exon']:
                        dic_terminus['long_exon'][terminus].append(tuple(s))
                        continue
                    else:
                        dic_terminus['long_exon'][terminus] = [tuple(s)]

            last_pos = s[-1]

            #remove any termini that rely on a long_exon PTC that is also a short exon PTC
            # short_ptcs is not modified during this pass, so its exon set is built once
            ptc_exons = set(short_ptcs.values())
            for k in dic_terminus['long_exon']:
                to_remove = []
                for v in dic_terminus['long_exon'][k]:
                    if len(v) < 2: continue
                    for j in range(0, len(v)-1, 2):
                        exon_coord = [v[j], v[j+1]]
                        exon_coord.sort()
                        exon_coord = tuple(exon_coord)
                        if exon_coord in ptc_exons:
                            to_remove.append(v)
                            break
                for rem in to_remove:
                    dic_terminus['long_exon'][k].remove(rem)

            # if a start codon lies beyond the path end and closer than
            # exonLcutoff, queue path + start codon for the final check
            for start in start_codons:
                if strand == "+" and last_pos > start[0] and abs(last_pos-start[0]) < exonLcutoff:
                    final_check.append(s+[start[0]])
                elif strand == "-" and last_pos < start[1] and abs(last_pos-start[1]) < exonLcutoff:
                    final_check.append(s+[start[1]])

            # extend the path by every junction reachable from its end
            for j0,j1 in junc:
                if strand == "+" and last_pos > j1 and abs(last_pos-j1) < exonLcutoff:
                    new_seed.append(s+[j1,j0])
                if strand == "-" and last_pos < j0 and abs(last_pos-j0) < exonLcutoff:
                    new_seed.append(s+[j0,j1])

        # All seeds of this depth are processed; now translate the complete
        # candidate paths (stop ... start) and classify the PTC-free ones.
        # check that the possible final paths are good
        for s in final_check:
            leftover = Seq("")
            allprot = Seq("")
            for i in range(0, len(s)-1, 2):
                exon_coord = s[i:i+2]
                exon_coord.sort()
                exon_coord = tuple(exon_coord)
                exlen = exon_coord[1]-exon_coord[0]
                startpos = (len(leftover)+exlen+1)%3
                if strand == "+":
                    seq = Seq(fa.fetch(chrom, (exon_coord[0],exon_coord[1])))+leftover
                    leftover = seq[:startpos]
                    if exlen + 1 > long_exon_nt:
                        prot = seq[startpos:].translate(stop_symbol = '@')
                        #only allow long_exon tag to persist if PTCs introduced are solely present in a long exon
                        # NOTE(review): `leftover` was already updated above, unlike in
                        # the '+' branch of the seed loop - confirm which is intended.
                        ptc_pos = ptc_pos_from_prot(prot, '@')
                        ptc_coord = [exon_coord[1] - (x+1)*3 - len(leftover) for x in ptc_pos]
                        check = [x in short_ptcs.keys() for x in ptc_coord]
                        if sum(check) > 0:
                            # a stop here was also seen in a short exon: treat as real PTC
                            prot = seq[startpos:].translate()
                    else:
                        prot = seq[startpos:].translate()
                    allprot = prot+allprot
                else:
                    seq = leftover+Seq(fa.fetch(chrom, (exon_coord[0],exon_coord[1])))
                    if startpos > 0:
                        leftover = seq[-startpos:]
                    else:
                        leftover = Seq("")
                    seq = seq.reverse_complement()
                    if exlen + 1 > long_exon_nt:
                        prot = seq[startpos:].translate(stop_symbol = '@')
                        ptc_pos = ptc_pos_from_prot(prot, '@')
                        ptc_coord = [exon_coord[0] + (x+1)*3 - len(leftover) for x in ptc_pos]
                        check = [x in short_ptcs.keys() for x in ptc_coord]
                        if sum(check) > 0:
                            prot = seq[startpos:].translate()
                    else:
                        prot = seq[startpos:].translate()
                    allprot = prot+allprot
            bool_ptc = "*" in allprot[:-1]
            bool_long_exon = '@' in allprot[:-1]

            # record path + start codon as passing if no PTC was found
            if not bool_ptc:
                # all pass
                proteins.append("\t".join([gene_name,chrom,strand, "-".join([str(x) for x in s]), str(allprot)])+'\n')
                if bool_long_exon:
                    path_pass['long_exon'].append(tuple(s))
                else:
                    path_pass['normal'].append(tuple(s))
                # Junctions are the pairs at indices (1,2), (3,4), ...; the final
                # element is the start codon, so stop at len(s)-1 to avoid
                # recording a one-element tuple for it.
                for i in range(1, len(s)-1, 2):
                    j_coord = s[i:i+2]
                    j_coord.sort()
                    j_coord = tuple(j_coord)
                    if not bool_long_exon:
                        if j_coord not in junc_pass['normal']:
                            junc_pass['normal'][j_coord] = 0
                        junc_pass['normal'][j_coord] += 1
                    else:
                        if j_coord not in junc_pass['long_exon']:
                            junc_pass['long_exon'][j_coord] = 0
                        junc_pass['long_exon'][j_coord] += 1

        seed = new_seed

    # Path growth is finished. For every terminus, check whether one of its
    # partial paths is the prefix of a passing path; if so, all partial paths
    # sharing that terminus pass as well.
    while True:
        new_paths = []
        for terminus in dic_terminus['normal']:
            terminus_pass = False
            for path_subset in dic_terminus['normal'][terminus]:
                for path in path_pass['normal']:
                    if path[:len(path_subset)] == path_subset:
                        terminus_pass = True
                        break

            # the terminus is part of a passing path: add its partial paths to
            # the passing paths and their junctions to junc_pass if missing
            if terminus_pass:
                subsets_to_check = dic_terminus['normal'][terminus]
                for path_subset in subsets_to_check:
                    if path_subset in path_pass['normal']: continue
                    new_paths.append(path_subset)
                    path_pass['normal'].append(path_subset)
                    for i in range(1, len(path_subset), 2):
                        j_coord = list(path_subset[i:i+2])
                        j_coord.sort()
                        j_coord = tuple(j_coord)
                        if j_coord not in junc_pass['normal']:
                            junc_pass['normal'][j_coord] = 0
                            if verbose:
                                sys.stdout.write("junction pass:" + str(j_coord) + '\n')

        # newly added passing paths may let further termini pass, so repeat
        # until a round adds nothing
        if len(new_paths) == 0:
            break

    #Now do the same for the long_exon termini and paths
    #We do not care if a terminus is long_exon or not, as long as it ends up leading to a long_exon path
    combined_keys = dic_terminus['long_exon'].keys() | dic_terminus['normal'].keys()
    combined_termini = {key: dic_terminus['long_exon'].get(key, []) + dic_terminus['normal'].get(key, []) for key in combined_keys}
    while True:
        new_paths = []
        for terminus in combined_termini:
            terminus_pass = False
            for path_subset in combined_termini[terminus]:
                for path in path_pass['long_exon']:
                    if path[:len(path_subset)] == path_subset:
                        terminus_pass = True
                        break

            # same propagation as above, into the long_exon category
            if terminus_pass:
                subsets_to_check = combined_termini[terminus]
                for path_subset in subsets_to_check:
                    if path_subset in path_pass['long_exon']: continue
                    new_paths.append(path_subset)
                    path_pass['long_exon'].append(path_subset)
                    for i in range(1, len(path_subset), 2):
                        j_coord = list(path_subset[i:i+2])
                        j_coord.sort()
                        j_coord = tuple(j_coord)
                        if j_coord not in junc_pass['long_exon']:
                            junc_pass['long_exon'][j_coord] = 0
                            if verbose:
                                sys.stdout.write("junction long_exon:" + str(j_coord) + '\n')

        # repeat until a round adds no new passing path
        if len(new_paths) == 0:
            break

    return junc_pass, junc_fail, proteins