In [None]:
import numpy as np
import pandas as pd
import gzip
import random
import argparse
from os.path import exists
import sys
import itertools


In [None]:
def translate_cds(cds):
    """Translate a coding sequence into one-letter amino-acid codes.

    The sequence is read in frame from its first base, case-insensitively;
    trailing bases that do not complete a codon are ignored.

    Parameters
    ----------
    cds : str
        Nucleotide sequence made of A/C/G/T (any case).

    Returns
    -------
    str
        One letter per complete codon; stop codons (TAA, TAG, TGA) become "*".

    Raises
    ------
    KeyError
        If a codon contains anything other than A/C/G/T (e.g. N).
    """
    codon_table = {"TTT": "F", "TTC": "F", "TTA": "L", "TTG": "L",
                   "TCT": "S", "TCC": "S", "TCA": "S", "TCG": "S",
                   "TAT": "Y", "TAC": "Y", "TAA": "*", "TAG": "*",
                   "TGT": "C", "TGC": "C", "TGA": "*", "TGG": "W",
                   "CTT": "L", "CTC": "L", "CTA": "L", "CTG": "L",
                   "CCT": "P", "CCC": "P", "CCA": "P", "CCG": "P",
                   "CAT": "H", "CAC": "H", "CAA": "Q", "CAG": "Q",
                   "CGT": "R", "CGC": "R", "CGA": "R", "CGG": "R",
                   "ATT": "I", "ATC": "I", "ATA": "I", "ATG": "M",
                   "ACT": "T", "ACC": "T", "ACA": "T", "ACG": "T",
                   "AAT": "N", "AAC": "N", "AAA": "K", "AAG": "K",
                   "AGT": "S", "AGC": "S", "AGA": "R", "AGG": "R",
                   "GTT": "V", "GTC": "V", "GTA": "V", "GTG": "V",
                   "GCT": "A", "GCC": "A", "GCA": "A", "GCG": "A",
                   "GAT": "D", "GAC": "D", "GAA": "E", "GAG": "E",
                   "GGT": "G", "GGC": "G", "GGA": "G", "GGG": "G"}

    seq = cds.upper()  # upper-case once rather than once per codon
    n_full = len(seq) - len(seq) % 3  # drop any incomplete trailing codon
    return "".join(codon_table[seq[i:i + 3]] for i in range(0, n_full, 3))


def make_nt_seq(aa_seq, pptness = False):
    """Back-translate a protein string into a random synonymous DNA sequence.

    Parameters
    ----------
    aa_seq : str
        One-letter amino-acid codes (upper case, the 20 standard residues).
    pptness : bool
        False (default): each codon is drawn uniformly from its synonyms.
        Otherwise: the synonyms are shuffled and the first codon with the most
        pyrimidines (C/T) is taken, so ties are broken at random.

    Returns
    -------
    str
        Upper-case DNA, three bases per residue.

    Raises
    ------
    KeyError
        For residues not in the table (including the stop symbol "*").
    """
    synonymous = {
        "A": ["GCT", "GCC", "GCA", "GCG"],
        "I": ["ATT", "ATC", "ATA"],
        "R": ["CGT", "CGC", "CGA", "CGG", "AGA", "AGG"],
        "L": ["CTT", "CTC", "CTA", "CTG", "TTA", "TTG"],
        "N": ["AAT", "AAC"],
        "K": ["AAA", "AAG"],
        "D": ["GAT", "GAC"],
        "M": ["ATG"],
        "F": ["TTT", "TTC"],
        "C": ["TGT", "TGC"],
        "P": ["CCT", "CCC", "CCA", "CCG"],
        "Q": ["CAA", "CAG"],
        "S": ["TCT", "TCC", "TCA", "TCG", "AGT", "AGC"],
        "E": ["GAA", "GAG"],
        "T": ["ACT", "ACC", "ACA", "ACG"],
        "W": ["TGG"],
        "G": ["GGT", "GGC", "GGA", "GGG"],
        "Y": ["TAT", "TAC"],
        "H": ["CAT", "CAC"],
        "V": ["GTT", "GTC", "GTA", "GTG"],
    }

    def pyrimidine_count(codon):
        return codon.count("T") + codon.count("C")

    uniform_choice = pptness == False
    codons = []
    for residue in aa_seq:
        options = synonymous[residue]
        if uniform_choice:
            codons.append(random.choice(options))
            continue
        random.shuffle(options)
        # max() returns the first of several equal maxima, i.e. a strict ">" scan.
        codons.append(max(options, key=pyrimidine_count))

    return "".join(codons)

def make_random_seq(l):
    """Return a random DNA string of length ``l``, each base drawn uniformly from A/C/G/T."""
    alphabet = ("A", "C", "G", "T")
    drawn = random.choices(alphabet, k=l)
    return "".join(drawn)


def make_ppt(l, frac):
    """Build a random polypyrimidine-tract-like sequence.

    Each of the ``l`` upper-case bases is a pyrimidine (C or T) when a
    uniform draw is <= ``frac``, otherwise a purine (A or G).
    """
    def draw_base():
        pool = ["C", "T"] if random.uniform(0, 1) <= frac else ["A", "G"]
        return random.choice(pool)

    return "".join(draw_base() for _ in range(l))


def random_mut(seq, rate, skip_splice_sites = True):
    """Randomly re-draw bases of ``seq``.

    Each eligible position is replaced, with probability ``rate``, by a base
    drawn uniformly from A/C/G/T. The new base may equal the original.

    Parameters
    ----------
    seq : str
        Sequence to mutate.
    rate : float
        Per-base mutation probability.
    skip_splice_sites : bool
        When True (default), the first two and last two positions are never
        mutated. Presumably these are an intron's GT/AG dinucleotides; confirm
        for the caller. When False, every position is eligible.

    Returns
    -------
    str
        Mutated sequence of the same length.
    """
    nts = ["A", "C", "G", "T"]
    last = len(seq) - 1
    # Built once per call; the flag was previously accepted but ignored.
    protected = {0, 1, last - 1, last} if skip_splice_sites else set()

    out = []
    for i, base in enumerate(seq):
        # Draw only for eligible positions, so protected bases never consume RNG.
        if i not in protected and random.uniform(0, 1) <= rate:
            out.append(random.choice(nts))
        else:
            out.append(base)

    return ''.join(out)


def remove_NY(initial_seq, pyrimidine_chance):
    """Resolve IUPAC N and Y codes into concrete lower-case bases.

    - ``N``/``n`` becomes a random one of a/t/c/g.
    - ``Y``/``y`` becomes ``make_ppt(1, pyrimidine_chance).lower()``: a
      pyrimidine (c/t) with probability ``pyrimidine_chance``, otherwise a
      purine (a/g).
    - Every other character is kept unchanged.
    """
    resolved = []
    for base in initial_seq:
        code = base.upper()
        if code == "N":
            resolved.append(random.choice(["a", "t", "c", "g"]))
        elif code == "Y":
            resolved.append(make_ppt(1, pyrimidine_chance).lower())
        else:
            resolved.append(base)
    return ''.join(resolved)


def mutate_codons(seq, aa_seq, n, mutable_codons, pptness = False):
    """Apply ``n`` random synonymous codon re-draws to a CDS.

    Each step picks a codon index from ``mutable_codons`` (with replacement)
    and replaces that codon with ``make_nt_seq`` of its amino acid. The
    replacement is upper case and may equal the original codon.

    Parameters
    ----------
    seq : str
        Coding sequence whose length is exactly ``3 * len(aa_seq)``.
    aa_seq : str
        Translation of ``seq``.
    n : int
        Number of codon re-draws.
    mutable_codons : sequence of int
        Codon indices that may be changed. If empty, ``seq`` is returned unchanged.
    pptness : bool
        Passed to ``make_nt_seq`` (prefer pyrimidine-rich codons).

    Returns
    -------
    str
        The recoded sequence (same length and translation).
    """
    assert len(seq) == 3 * len(aa_seq), "CDS length must be exactly 3x the protein length"
    if not mutable_codons:
        return seq  # nothing is allowed to change; random.choice([]) would raise

    for _ in range(n):
        codon_index = random.choice(mutable_codons)
        new_codon = make_nt_seq(aa_seq[codon_index], pptness)
        assert len(new_codon) == 3, "replacement codon must be exactly 3 nt"
        start = 3 * codon_index
        seq = seq[:start] + new_codon + seq[start + 3:]

    return seq


def mut_cds(old_cds, mut_n, mutable_codons, original_purines, less_purines = True, pptness = False,
            max_attempts = 500, purine_kmers = None):
    """Propose a synonymous recoding of ``old_cds`` that shifts purine 4-mer content.

    Up to ``max_attempts`` candidates are drawn with ``mutate_codons``. Each
    candidate gets ``mut_n`` random codon re-draws among ``mutable_codons``.
    The first candidate that differs from ``old_cds`` and satisfies the purine
    criterion is returned:

    * ``less_purines`` True:  purine-4-mer count <= ``original_purines``
    * ``less_purines`` False: purine-4-mer count >  ``original_purines``

    Parameters
    ----------
    old_cds : str
        Current coding sequence.
    mut_n : int
        Codon re-draws per candidate.
    mutable_codons : sequence of int
        Codon indices that may change.
    original_purines : int
        Reference purine-4-mer count to compare against.
    less_purines : bool
        Direction of the criterion (see above).
    pptness : bool
        Passed through to ``make_nt_seq`` (prefer pyrimidine-rich codons).
    max_attempts : int
        Number of candidates to try before giving up.
    purine_kmers : list of str, optional
        Upper-case motifs to count. Occurrences are counted with ``str.count``,
        so they are non-overlapping. Defaults to the notebook-level ``purine_4mers``.

    Returns
    -------
    str
        The accepted candidate. If no attempt qualifies, a notice is printed and
        ``old_cds`` is returned unchanged.
    """
    if purine_kmers is None:
        purine_kmers = purine_4mers  # notebook global defined in the configuration cell

    aa_seq = translate_cds(old_cds)

    def purine_count(seq):
        upper = seq.upper()
        return sum(upper.count(kmer) for kmer in purine_kmers)

    for _ in range(max_attempts):
        new_cds = mutate_codons(old_cds, aa_seq, mut_n, mutable_codons, pptness)
        if new_cds == old_cds:
            continue
        count = purine_count(new_cds)
        accepted = count <= original_purines if less_purines else count > original_purines
        if accepted:
            return new_cds

    print("No further CDS mutations.")
    return old_cds


def get_args(argv = None):
    """Parse command-line options for the intron/CDS optimiser.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. ``None`` (default) parses ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        The parsed options plus two derived attributes:
        ``n`` is an alias of ``n_iterations`` (the optimisation loops read
        ``args.n``), and ``aa`` is the translation of ``initial_cds``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--intron_csv", help="A csv of introns. First column = insert position, second = intron sequence. Don't use column headers")
    # Required: the protein translation below needs a CDS (None would crash in translate_cds).
    parser.add_argument("-c", "--initial_cds", required=True, help="The CDS sequence. Only lower case letters can be mutated.")
    parser.add_argument("-o", "--output", help="output filename")
    parser.add_argument("--upstream_seq", help="5'utr", default="")
    parser.add_argument("--downstream_seq", help="3'utr", default="")
    parser.add_argument("-n", "--n_iterations", type=int, default=500)
    parser.add_argument("--skip_tf", action="store_true", help="do not load SpliceAI/TensorFlow (tests only)")
    parser.add_argument("--chance_intron_move", default=0.3, type=float)
    parser.add_argument("--max_intron_move", default=15, type=int)
    parser.add_argument("--n_codons_mut", default=5, type=int)
    parser.add_argument("--chance_intron_mut", default=0.3, type=float)
    parser.add_argument("--intron_mut_rate", default=0.05, type=float, help="fraction of intronic bases that get mutated")
    args = parser.parse_args(argv)

    args.n = args.n_iterations
    args.aa = translate_cds(args.initial_cds)

    return args


def mut_noncoding(seq, positions_to_mut, n_mut):
    """Randomly re-draw bases at chosen positions of a sequence.

    ``n_mut`` positions are sampled *with replacement* from
    ``positions_to_mut``. Each sampled position is overwritten with a random
    lower-case base (a/t/c/g), which may equal the original base.

    Raises AssertionError if ``seq`` contains no lower-case character at all.
    The sampled positions themselves are not checked for case.
    """
    assert any(base.islower() for base in seq), "Only lower case bases can be mutated!"

    bases = list(seq)
    for position in random.choices(positions_to_mut, k=n_mut):
        bases[position] = random.choice(["a", "t", "c", "g"])
    return ''.join(bases)

In [None]:
# Notebook stand-in for get_args(): start from an empty namespace and fill in
# the attributes the command-line parser would otherwise provide.
parser = argparse.ArgumentParser()
args = parser.parse_args("")

# Every input and output lives under this project directory; edit it in one place.
BASE_DIR = "/camp/lab/ulej/home/users/farawar/GASR/export_reporter"

# Set to an int for reproducible runs (seeds both `random` and NumPy's global RNG).
SEED = None
if SEED is not None:
    random.seed(SEED)
    np.random.seed(SEED)

cds_file = BASE_DIR + "/ppig_deoptimised_4m_with_mSc.txt"
utr5_file = BASE_DIR + "/5utr.txt"
utr3_file = BASE_DIR + "/3utr.txt"

args.intron_csv = BASE_DIR + "/best_introns/best_introns.csv"
# args.intron_csv = BASE_DIR + "/best_introns/test_introns.csv"


def _read_lower_seq(path):
    """Return a sequence file's contents, trailing whitespace stripped and lower-cased."""
    with open(path) as handle:
        return handle.read().rstrip().lower()


# Lower-casing the CDS makes every codon mutable (only all-lower-case codons are recoded).
args.initial_cds = _read_lower_seq(cds_file)
args.upstream_seq = _read_lower_seq(utr5_file)
args.downstream_seq = _read_lower_seq(utr3_file)

args.output = BASE_DIR + "/best_introns/output/reporter"  # prefix for every output file
args.n = 1000                  # iterations of each optimisation loop
args.skip_tf = False           # intended to skip loading SpliceAI for quick tests
args.chance_intron_move = 0.3  # per-iteration chance of shifting one intron
args.max_intron_move = 20      # max shift (nt) of an intron insertion point
args.n_codons_mut = 5          # random codon re-draws per CDS proposal
args.chance_intron_mut = 0.5   # per-iteration chance of mutating one intron
args.intron_mut_rate = 0.05    # fraction of the chosen intron's bases mutated
args.aa = translate_cds(args.initial_cds)

purines = ['A', 'G']

# All 16 purine-only 4-mers, in itertools.product order (AAAA first).
raw_purine_4mers = [''.join(i) for i in itertools.product(purines, repeat = 4)]
# Motifs counted as "purine-rich": exclude the AAAA homopolymer and any 4-mer with
# three or more Gs (which also excludes every 4-mer containing GGG).
purine_4mers = [mer for mer in raw_purine_4mers if mer != "AAAA" and mer.count("G") < 3]
# Non-overlapping motif occurrences (str.count) in the starting CDS.
original_purine_4mers = sum([args.initial_cds.upper().count(i) for i in purine_4mers])

In [None]:
# SpliceAI (Keras/TensorFlow) is slow to load. Skip it for quick tests either via
# args.skip_tf (set in the configuration cell) or a --skip_tf command-line flag.
load_tf = not (getattr(args, "skip_tf", False) or "--skip_tf" in sys.argv)

if load_tf:
    from keras.models import load_model
    from pkg_resources import resource_filename
    from spliceai.utils import one_hot_encode

    paths = ('models/spliceai{}.h5'.format(x) for x in range(1, 6))
    models = [load_model(resource_filename('spliceai', x)) for x in paths]

    def get_probs(input_sequence):
        """Return per-base (acceptor, donor) probabilities for ``input_sequence``.

        The sequence is padded with 5,000 Ns on each side. Predictions from the
        five SpliceAI models are averaged. Channel 1 of the output is returned
        as the acceptor probability and channel 2 as the donor probability.
        Both are 1-D NumPy arrays with one value per input base.
        """
        context = 10000
        x = one_hot_encode('N' * (context // 2) + input_sequence + 'N' * (context // 2))[None, :]
        y = np.mean([model.predict(x) for model in models], axis=0)

        acceptor_prob = y[0, :, 1]
        donor_prob = y[0, :, 2]

        return acceptor_prob, donor_prob
else:
    print("NOT LOADING TENSOR FLOW! THIS IS JUST FOR TESTS")

    def get_probs(input_sequence):
        """Test stand-in: all-zero acceptor and donor probabilities.

        Returns float NumPy arrays like the real model, so downstream NumPy
        arithmetic (e.g. normalising weight vectors) behaves the same.
        """
        n = len(input_sequence)
        return np.zeros(n), np.zeros(n)


In [None]:
# Read introns: one "<insert position in CDS>,<intron sequence>" pair per row, no header.
intron_d = {}
with open(args.intron_csv) as file:
    for line in file:
        split = line.strip().split(",")
        if split == [""]:
            continue  # skip blank lines, e.g. a trailing newline
        intron_d[int(split[0])] = split[1].strip()

# Sequence assembly and splice-site bookkeeping walk introns in dict order and
# assume increasing insert positions, so order them by position.
intron_d = dict(sorted(intron_d.items()))

total_intron_l = sum([len(a) for a in intron_d.values()])

# Codon indices eligible for recoding: all-lower-case codons are mutable,
# all-upper-case codons are protected, and mixed case is rejected.
mutable_codons = []
for i in range(int(len(args.initial_cds)/3)):
    this_codon = args.initial_cds[3*i:3*i+3]
    total_lower = sum([1 for a in this_codon if a.islower()])
    assert total_lower in [0, 3], "Codons must either be upper or lower case. Cannot be a mix!"
    if total_lower == 3:
        mutable_codons.append(i)

best_score = -10000000  # sentinel far below any realistic score: iteration 0 always becomes best
best_switches = {position: 0 for position in intron_d.keys()}  # per-intron shift (nt) of its insert position
best_intron_d = dict(intron_d)
n_purines = original_purine_4mers

In [None]:
# Low-purine optimisation (greedy hill-climb). Each iteration proposes:
#   1. a synonymous CDS recoding with no more purine 4-mers and GC < 55 %,
#   2. optionally a mutation of one intron,
#   3. optionally a shift of one intron's insertion point.
# The candidate is kept if its score (SpliceAI probability at the intended
# sites minus probability everywhere else) beats the best so far.
for i in range(args.n):

    if i % 50 == 0:
        print("iteration " + str(i))

    if i == 0:
        # Score the unmodified design first so every best_* variable is defined.
        best_cds = args.initial_cds
        new_cds = args.initial_cds
        new_switches = dict(best_switches)
        new_intron_d = dict(intron_d)
        break_loop = False
    else:
        # --- 1. CDS proposal: keep the first of up to 100 attempts whose GC
        # content is below 55 %. Fall back to the current best CDS if none is.
        new_cds = best_cds
        for _ in range(100):
            candidate_cds = mut_cds(best_cds, args.n_codons_mut, mutable_codons, n_purines, True)
            new_gc = sum([candidate_cds.upper().count(nt) for nt in ['G', 'C']]) / len(candidate_cds)
            if new_gc - .5 < 0.05:
                new_cds = candidate_cds
                break

        # --- 2. Intron mutation.
        new_intron_d = dict(best_intron_d)
        if random.random() < args.chance_intron_mut:
            all_keys = list(new_intron_d.keys())

            # Pick an intron, favouring those whose intended sites are weakest.
            # Use the best design's predictions and site indices: their intron
            # layout matches best_intron_d, while the last evaluated candidate's
            # may not. +0.05 keeps every intron selectable.
            site_support = np.array(
                [best_donor_prob[d] + best_acceptor_prob[a] for d, a in zip(best_donors, best_acceptors)],
                dtype=float,
            )
            inverse_summed_probs = (2 - site_support) + 0.05
            choice_weights = inverse_summed_probs / inverse_summed_probs.sum()
            key_index = np.random.choice(len(all_keys), p=choice_weights)
            key_choice = all_keys[key_index]

            seq_list = list(new_intron_d[key_choice])

            # Per-base probabilities from the intron's donor index (base just
            # before the intron) to its acceptor index (base just after it).
            # np.array(...) copies: a NumPy slice is a view, and zeroing it in
            # place would silently alter best_donor_prob / best_acceptor_prob.
            window = slice(best_donors[key_index], best_acceptors[key_index] + 1)
            current_donor_probs = np.array(best_donor_prob[window], dtype=float)
            current_donor_probs[0] = 0  # ignore the intended donor itself
            current_acceptor_probs = np.array(best_acceptor_prob[window], dtype=float)
            current_acceptor_probs[-1] = 0  # ignore the intended acceptor itself

            # Smooth over 9 nt, then drop the two flanking exon bases so one weight
            # remains per intron base. Mutations concentrate where SpliceAI
            # predicts other (cryptic) donor/acceptor signal.
            kernel = np.ones(9) / 9
            smooth_donors = np.convolve(current_donor_probs, kernel, mode="same")[1:-1]
            smooth_acceptors = np.convolve(current_acceptor_probs, kernel, mode="same")[1:-1]
            smooth_probs = smooth_donors + smooth_acceptors + 0.01  # floor keeps every base selectable
            smooth_probs = smooth_probs / smooth_probs.sum()

            int_len = len(seq_list)
            n_mut = round(int_len * args.intron_mut_rate)
            mutation_index = np.random.choice(int_len, size=n_mut, p=smooth_probs, replace=False)
            replacement_nt = random.choices(["A", "C", "G", "T"], k=n_mut)

            for base_index, nt in zip(mutation_index, replacement_nt):
                seq_list[base_index] = nt

            new_intron_d[key_choice] = ''.join(seq_list)

        # --- 3. Intron move: re-draw one intron's shift in [-max, +max].
        new_switches = dict(best_switches)
        if random.random() < args.chance_intron_move:
            key_choice = random.choice(list(new_switches.keys()))
            new_switches[key_choice] = random.randint(-args.max_intron_move, args.max_intron_move)

    # Assemble 5'UTR + CDS with introns inserted at their shifted positions + 3'UTR.
    new_combined_seq = args.upstream_seq
    p = 0
    for intron_start, intron_seq in new_intron_d.items():
        intron_start2 = intron_start + new_switches[intron_start]
        new_combined_seq += new_cds[p:intron_start2]
        p = intron_start2
        new_combined_seq += intron_seq
    new_combined_seq += new_cds[p:]
    new_combined_seq += args.downstream_seq

    # Intended splice-site indices in new_combined_seq. `pos` is the first intron
    # base; the donor is scored one base earlier (the last base before the intron).
    # NOTE(review): confirm this offset against SpliceAI's donor indexing.
    # Each acceptor is the first base after the intron.
    donors = []
    acceptors = []
    pos = len(args.upstream_seq)
    prev_intron_start = 0
    for intron_start, intron_seq in new_intron_d.items():
        intron_start2 = intron_start + new_switches[intron_start]
        pos += intron_start2 - prev_intron_start
        prev_intron_start = intron_start2
        donors.append(pos - 1)
        pos += len(intron_seq)
        acceptors.append(pos)

    acceptor_prob, donor_prob = get_probs(new_combined_seq)

    # Reward probability at the intended sites; penalise it everywhere else.
    acceptor_set = set(acceptors)
    donor_set = set(donors)
    score = sum([acceptor_prob[a] for a in acceptors])
    score += sum([donor_prob[a] for a in donors])
    score += sum([-b for a, b in enumerate(acceptor_prob) if a not in acceptor_set])
    score += sum([-b for a, b in enumerate(donor_prob) if a not in donor_set])

    if score > best_score:
        best_score = score
        best_cds = new_cds
        n_purines = sum([best_cds.upper().count(i) for i in purine_4mers])
        best_switches = dict(new_switches)
        best_combined = new_combined_seq
        best_donor_prob = donor_prob
        best_acceptor_prob = acceptor_prob
        best_intron_d = dict(new_intron_d)
        best_donors = donors
        best_acceptors = acceptors

        print("")
        print(score)
        print(best_switches)
        print("donor_scores:")
        print([donor_prob[a] for a in donors])
        print("acceptor_scores:")
        print([acceptor_prob[a] for a in acceptors])
        print("Purine content:")
        print(n_purines)
        print("GC content:")
        print(sum([best_cds.upper().count(i) for i in ['G', 'C']])/len(best_cds))


In [None]:
# Write the best low-purine design. full_sequence is best_combined, the assembled
# sequence that produced best_score (new_combined_seq is only the last candidate tried).
with open(args.output + ".low_purine.output.csv", 'w') as file:
    file.write("key,value\n")
    file.write("full_sequence," + best_combined + "\n")
    file.write("cds," + best_cds + "\n")
    file.write("score," + str(best_score) + "\n")
    # Coordinates index full_sequence. "start" is the donor index (the base just
    # before the intron); "end" is start + intron length (the intron's last base).
    for j, (intron_seq, donor_index) in enumerate(zip(best_intron_d.values(), best_donors), start=1):
        file.write("intron_" + str(j) + "_start," + str(donor_index) + "\n")
        file.write("intron_" + str(j) + "_end," + str(donor_index + len(intron_seq)) + "\n")

# Per-base SpliceAI predictions for the best design.
with open(args.output + ".low_purine.predictions.csv", 'w') as file:
    file.write("position,donor,acceptor\n")
    for position, (d, a) in enumerate(zip(best_donor_prob, best_acceptor_prob)):
        file.write(str(position) + "," + str(d) + "," + str(a) + "\n")

In [None]:
# Choose overlap windows between consecutive intron insertion points. Each window
# is placed where the low-purine CDS has the most purine 4-mers, then upper-cased
# so the high-purine recoding below leaves it untouched.

# Shifted insertion points in CDS coordinates. This assumes best_switches is in
# ascending position order, as the intron CSV is read. NOTE(review): confirm.
breakpoints = [i + j for i, j in best_switches.items()]
midpoints = [round((prev + nxt) / 2) for prev, nxt in zip(breakpoints, breakpoints[1:])]

span = range(-30, 30)  # candidate offsets of the window around each midpoint

overlap_starts = []
overlap_ends = []
best_cds_upper = best_cds.upper()

for mid in midpoints:
    # 29-nt windows [mid - 15 + s, mid + 14 + s) for every offset s in span.
    possible_breaks = [best_cds_upper[(mid - 15) + s:(mid + 14) + s] for s in span]
    break_purines = [sum([window.count(kmer) for kmer in purine_4mers]) for window in possible_breaks]

    index_max = np.argmax(break_purines)  # first window with the most purine 4-mers
    overlap_starts.append((mid - 15) + span[index_max])
    overlap_ends.append((mid + 14) + span[index_max])

# Lower case = mutable, upper case = protected overlap.
overlap_cds = best_cds.lower()
for start, end in zip(overlap_starts, overlap_ends):
    overlap_cds = overlap_cds[:start] + overlap_cds[start:end].upper() + overlap_cds[end:]

with open(args.output + ".overlap_cds_positions.csv", 'w') as file:
    file.write("overlap_start, overlap_end, overlap_sequence" + "\n")
    for start, end in zip(overlap_starts, overlap_ends):
        file.write(str(start) + ", " + str(end) + ", " + str(overlap_cds[start:end]) + "\n")

deopt_purine_content = sum([overlap_cds.upper().count(i) for i in purine_4mers])
# Reference for the high-purine search: the best low-purine score. The loop
# variable `score` only holds whichever candidate was evaluated last.
deopt_score = best_score
deopt_gc = sum([overlap_cds.upper().count(i) for i in ['G', 'C']])/len(overlap_cds)

# A codon is mutable only if all three bases are lower case; a single
# upper-case (overlap) base protects the whole codon.
mutable_codons = []
for codon_index in range(len(overlap_cds) // 3):
    codon = overlap_cds[3 * codon_index:3 * codon_index + 3]
    if all(base.islower() for base in codon):
        mutable_codons.append(codon_index)

In [None]:
# High-purine optimisation. Start from the low-purine CDS with the overlaps
# protected and the intron layout fixed. Recode towards MORE purine 4-mers while
# keeping GC near the low-purine design and the splicing predictions essentially
# unchanged.
MIN_SITE_PROB = 0.98   # every intended donor and acceptor must stay at or above this
MAX_SCORE_DROP = 0.1   # allowed score loss relative to the low-purine design
MAX_GC_SHIFT = 0.03    # allowed |GC - deopt_gc|

for i in range(args.n):

    if i % 50 == 0:
        print("iteration " + str(i))

    if i == 0:
        best_cds_p = overlap_cds
        new_cds_p = overlap_cds
        best_score = deopt_score
    else:
        # Keep the first of up to 100 proposals whose GC is close to the
        # low-purine design. Fall back to the current best if none is.
        new_cds_p = best_cds_p
        for _ in range(100):
            candidate_cds = mut_cds(best_cds_p, args.n_codons_mut, mutable_codons, n_purines, False)
            new_gc = sum([candidate_cds.upper().count(nt) for nt in ['G', 'C']]) / len(candidate_cds)
            if abs(new_gc - deopt_gc) < MAX_GC_SHIFT:
                new_cds_p = candidate_cds
                break

    # Assemble with the best low-purine intron layout (unchanged in this loop).
    new_combined_seq_p = args.upstream_seq
    p = 0
    for intron_start, intron_seq in best_intron_d.items():
        intron_start2 = intron_start + best_switches[intron_start]
        new_combined_seq_p += new_cds_p[p:intron_start2]
        p = intron_start2
        new_combined_seq_p += intron_seq
    new_combined_seq_p += new_cds_p[p:]
    new_combined_seq_p += args.downstream_seq

    acceptor_prob, donor_prob = get_probs(new_combined_seq_p)

    # Site indices for this layout are best_donors / best_acceptors. The loop-1
    # variables donors / acceptors belong to its last candidate, which may differ.
    acceptor_set = set(best_acceptors)
    donor_set = set(best_donors)
    score = sum([acceptor_prob[a] for a in best_acceptors])
    score += sum([donor_prob[a] for a in best_donors])
    score += sum([-b for a, b in enumerate(acceptor_prob) if a not in acceptor_set])
    score += sum([-b for a, b in enumerate(donor_prob) if a not in donor_set])

    # Accept only if every intended site stays strong AND the overall score
    # stays within MAX_SCORE_DROP of the low-purine design.
    sites_ok = (all(acceptor_prob[a] >= MIN_SITE_PROB for a in best_acceptors)
                and all(donor_prob[d] >= MIN_SITE_PROB for d in best_donors))
    if sites_ok and (deopt_score - score) < MAX_SCORE_DROP:
        best_cds_p = new_cds_p
        best_score = score
        n_purines = sum([best_cds_p.upper().count(i) for i in purine_4mers])
        best_combined_p = new_combined_seq_p
        best_donor_prob = donor_prob
        best_acceptor_prob = acceptor_prob
        current_gc = sum([best_cds_p.upper().count(nt) for nt in ['G', 'C']]) / len(best_cds_p)
        print("")
        print(score)
        print("Purine content")
        print(n_purines)
        print("GC content")
        print(current_gc)
        print("donor_scores:")
        print([donor_prob[a] for a in best_donors])
        print("acceptor_scores:")
        print([acceptor_prob[a] for a in best_acceptors])


In [None]:
# Write the high-purine design. The full sequence is rebuilt from the accepted CDS
# (best_cds_p) and the fixed intron layout, so it always matches the "cds" row.
# new_combined_seq_p is only the last candidate evaluated, which may have been rejected.
full_sequence_p = args.upstream_seq
p = 0
for intron_start, intron_seq in best_intron_d.items():
    intron_start2 = intron_start + best_switches[intron_start]
    full_sequence_p += best_cds_p[p:intron_start2] + intron_seq
    p = intron_start2
full_sequence_p += best_cds_p[p:] + args.downstream_seq

with open(args.output + ".high_purine.output.csv", 'w') as file:
    file.write("key,value\n")
    file.write("full_sequence," + full_sequence_p + "\n")
    file.write("cds," + best_cds_p + "\n")
    file.write("score," + str(best_score) + "\n")
    # Recoding is synonymous (same length) and the layout is unchanged, so the
    # low-purine donor indices still apply. "start" is the base just before the
    # intron; "end" is start + intron length (the intron's last base).
    for j, (intron_seq, donor_index) in enumerate(zip(best_intron_d.values(), best_donors), start=1):
        file.write("intron_" + str(j) + "_start," + str(donor_index) + "\n")
        file.write("intron_" + str(j) + "_end," + str(donor_index + len(intron_seq)) + "\n")

# Per-base predictions. NOTE(review): if the high-purine loop accepted nothing,
# best_donor_prob / best_acceptor_prob still hold the low-purine predictions.
with open(args.output + ".high_purine.predictions.csv", 'w') as file:
    file.write("position,donor,acceptor\n")
    for position, (d, a) in enumerate(zip(best_donor_prob, best_acceptor_prob)):
        file.write(str(position) + "," + str(d) + "," + str(a) + "\n")