In [1]:
#pairwise-bootstrapping?
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import subprocess
import ete3

import os 

# Symbol -> integer index lookup shared by the frequency and scoring helpers
# below (it is their default `alphabet` argument). "-" (gap) is index 0, the
# twenty amino-acid letters are 1-20 and "X" is 21.
alphabet = {"-":0, "A": 1, "R": 2, "N":3, "D":4, "C":5, "Q":6, "E":7, "G":8,
            "H":9, "I":10, "L":11, "K":12, "M":13, "F":14, "P":15, 
            "S":16, "T":17, "W":18, "Y":19, "V":20, "X":21}

import math
INF = math.inf

import random

from Bio import Align

from IPython.display import clear_output

In [None]:
def file_to_seq_dict(fasta_path):
    """Read a FASTA file into a ``{header: sequence}`` dictionary.

    Header lines start with ">"; the header text (without the ">" and
    surrounding whitespace) becomes the key. All following lines up to the
    next header are stripped, have every "X" replaced by "-", and are
    concatenated into the value. A repeated header restarts its sequence.

    Blank lines are ignored wherever they occur.

    Args:
        fasta_path: path of the FASTA file to read.

    Returns:
        dict mapping header text to the concatenated sequence string.

    Raises:
        ValueError: if sequence data appears before the first header line.
    """
    seqDict = {}
    cur = None
    with open(fasta_path) as f:
        for line in f:
            if not line.strip():
                # Blank lines carry no data; previously one placed before the
                # first header crashed with KeyError('').
                continue
            if line.startswith(">"):
                cur = line[1:].strip()
                seqDict[cur] = ""
            elif cur is None:
                raise ValueError(
                    "sequence data found before the first '>' header in " + str(fasta_path))
            else:
                seqDict[cur] += line.strip().replace("X", "-")
    return seqDict
                

In [1]:
def seq_dict_to_matrix(seqDict):
    """Turn a ``{label: sequence}`` dict into a character matrix plus labels.

    The number of columns is taken from the first sequence in the dict.

    Args:
        seqDict: dict mapping labels to equal-length sequence strings.

    Returns:
        (sequence_matrix, labels): ``sequence_matrix`` is an object array of
        shape (number of sequences, sequence length) holding one stripped
        character per cell; ``labels`` is an object array of shape
        (number of sequences, 1) with the labels in dict order.
    """
    entries = list(seqDict.items())
    row_count = len(entries)
    column_count = len(entries[0][1])
    labels = np.empty((row_count, 1), dtype=object)
    sequence_matrix = np.empty((row_count, column_count), dtype=object)
    for row, (label, seq) in enumerate(entries):
        labels[row, 0] = label
        sequence_matrix[row] = np.array([symbol.strip() for symbol in seq])
    return sequence_matrix, labels


In [1]:
def makeShell(name, commands, nodes='1', cores='1', mem='10', days='7', out=None):
    """Write a SLURM batch script called ``name + '.sh'`` and return its path.

    The script consists of a bash shebang, a fixed set of ``#SBATCH``
    directives filled from the arguments, and then ``commands`` joined by
    newlines (no trailing newline is written).

    Args:
        name: job name; also the script path without the ".sh" suffix.
        commands: list of shell command strings to run in the job.
        nodes: value for ``#SBATCH -N`` (string).
        cores: value for ``#SBATCH -n`` (string).
        mem: value for ``#SBATCH --mem`` in G (string).
        days: day count for ``#SBATCH -t`` (string).
        out: stem of the ``#SBATCH -o`` file; defaults to ``name``.

    Returns:
        The path of the script that was written.
    """
    if out is None:
        out = name
    script_path = name + '.sh'
    script_text = ("#!/bin/bash\n\n"
                   "#SBATCH -p sched_mit_g4nier\n"
                   "#SBATCH -t " + days + "-00:00:00\n"
                   "#SBATCH -N " + nodes + "\n"
                   "#SBATCH -n " + cores + "\n"
                   "#SBATCH --mem=" + mem + "G\n"
                   "#SBATCH -J " + name + "\n"
                   "#SBATCH -o " + out + ".out\n" +
                   '\n'.join(commands))
    # The context manager closes the handle even if the write fails.
    with open(script_path, 'w') as file:
        file.write(script_text)
    return script_path


In [4]:
def make_tree(alignmentpath, name):
    """Write an IQ-TREE job script for ``alignmentpath`` and submit it with sbatch.

    The script is produced by ``makeShell`` (10G memory, 5 cores) and loads the
    iqtree module before running iqtree on the alignment.

    Args:
        alignmentpath: path of the alignment file passed to ``iqtree -s``.
        name: job name / script stem handed to ``makeShell``.
    """
    iqtree_call = ("iqtree -s " + alignmentpath +
                   " -nt 5 -bb 1000 -alrt 1000 -m MFP -mset WAG,LG,JTT -msub nuclear")
    job_lines = ["module add engaging/iqtree/1.6.3", iqtree_call]
    script_path = makeShell(name, job_lines, mem="10", cores="5")
    subprocess.run(["sbatch", script_path])

In [3]:
def get_supports(tree_path, metric=None):
    """Collect one support value per named internal node of a tree file.

    Every non-leaf node with a non-empty name must have a name of the form
    "<alrt>/<bb>" (exactly one "/"); otherwise the unpacking raises ValueError.

    Args:
        tree_path: tree file (or string) loaded with ``ete3.Tree(..., format=1)``.
        metric: "alrt" to collect the part before the slash, "bb" to collect
            the part after it. Any other value (including the default None)
            collects nothing.

    Returns:
        list of floats in traversal order (empty unless metric is "alrt" or "bb").
    """
    tree = ete3.Tree(tree_path, format=1)
    labelled_internal = (node.name for node in tree.traverse()
                         if not (node.is_leaf() or node.name == ""))
    supports = []
    for label in labelled_internal:
        alrt, bb = label.split("/")
        if metric == "alrt":
            chosen = alrt
        elif metric == "bb":
            chosen = bb
        else:
            chosen = None
        if chosen is not None:
            supports.append(float(chosen))
    return supports

In [2]:
def make_unique(items):
    """Return the items with later duplicates dropped, keeping first-seen order.

    Membership is tested with ``in`` on a list, so items only need to support
    ``==`` (they do not have to be hashable).
    """
    kept = []
    for candidate in items:
        if candidate in kept:
            continue
        kept.append(candidate)
    return kept

def vector_angle(v1, v2):
    """Angle between two vectors in degrees (0 to 180).

    Both vectors are scaled to unit length first; the cosine is clipped to
    [-1, 1] before ``arccos`` so rounding error cannot push it out of range.
    """
    unit_a = v1 / np.linalg.norm(v1)
    unit_b = v2 / np.linalg.norm(v2)
    cosine = np.clip(np.dot(unit_a, unit_b), -1.0, 1.0)
    return 180 * np.arccos(cosine) / np.pi

def cos_vector_angle(v1, v2):
    """Cosine of the angle between two vectors, clipped to [-1, 1]."""
    unit_a = v1 / np.linalg.norm(v1)
    unit_b = v2 / np.linalg.norm(v2)
    cosine = np.dot(unit_a, unit_b)
    return np.clip(cosine, -1.0, 1.0)

def cos_vector_angle_exclude(v1, v2, exclude=0):
    """Cosine similarity of two vectors with one component left out.

    If either vector has the value 1 at index ``exclude`` the full vectors are
    compared instead; otherwise that component is removed from both vectors
    before calling ``cos_vector_angle``.
    """
    if not (v1[exclude] == 1 or v2[exclude] == 1):
        # Splice out the excluded component from each vector.
        v1, v2 = [np.concatenate((vec[:exclude], vec[exclude + 1:])) for vec in (v1, v2)]
    return cos_vector_angle(v1, v2)

def cos_vector_angle_dilute(v1, v2, dilute=0):
    """Cosine similarity after halving the weight of one component.

    If either vector has the value 1 at index ``dilute`` the vectors are
    compared unchanged. Otherwise, for each vector with value ``g`` at that
    index, every component is scaled by ``((1-g) + g/2) / (1-g)`` and the
    ``dilute`` component is then set to ``g/2``.
    """
    if v1[dilute] == 1 or v2[dilute] == 1:
        return cos_vector_angle(v1, v2)
    rescaled = []
    for vec in (v1, v2):
        share = vec[dilute]
        adjusted = vec.copy() * (((1-share)+(share/2))/(1-share))
        adjusted[dilute] = share/2
        rescaled.append(adjusted)
    return cos_vector_angle(rescaled[0], rescaled[1])

def id_poor_sites(alignmentFrequencyMatrix, sequenceFrequencyMatrix, gapDistribution, 
                  metric=cos_vector_angle, threshold=0.2):
    """Return the alignment column indices whose fit score is below ``threshold``.

    Scores come from ``get_site_fit_metrics`` called with the same matrices,
    gap pattern and ``metric``.
    """
    site_scores = get_site_fit_metrics(alignmentFrequencyMatrix, sequenceFrequencyMatrix,
                                       gapDistribution, metric=metric)
    below_threshold = np.argwhere(site_scores < threshold)
    return below_threshold[:, 0]

def get_site_fit_metrics(alignmentFrequencyMatrix, sequenceFrequencyMatrix, gapDistribution, metric=vector_angle):
    """Score every alignment column against the matching column of a sequence profile.

    Walks the alignment columns in order. Where ``gapDistribution[i]`` is "-"
    the score is ``INF``; otherwise ``metric`` is applied to the next unused
    column of ``sequenceFrequencyMatrix`` and column ``i`` of
    ``alignmentFrequencyMatrix``.

    Args:
        alignmentFrequencyMatrix: 2-D array, one column per alignment column.
        sequenceFrequencyMatrix: 2-D array, one column per non-gap position.
        gapDistribution: indexable of symbols, one per alignment column.
        metric: callable taking two column vectors.

    Returns:
        1-D numpy array with one score per alignment column.
    """
    _, alignment_len = alignmentFrequencyMatrix.shape
    _, _ = sequenceFrequencyMatrix.shape  # both inputs must be 2-D
    next_residue = 0
    scores = []
    for column in range(alignment_len):
        if gapDistribution[column] == "-":
            scores.append(INF)
            continue
        scores.append(metric(sequenceFrequencyMatrix[:, next_residue],
                             alignmentFrequencyMatrix[:, column]))
        next_residue += 1
    return np.array(scores)

def get_aa_frequencies_by_site(alignmentMatrix, alphabet=alphabet):
    """Per-column symbol frequencies of an alignment.

    Args:
        alignmentMatrix: 2-D array (sequences x columns) of symbols that are
            all keys of ``alphabet`` (a missing symbol raises KeyError).
        alphabet: dict mapping each symbol to its row index.

    Returns:
        Array of shape (len(alphabet), number of columns); entry [k, i] is the
        fraction of sequences whose column i holds the symbol with index k.
    """
    num_seqs, alignment_len = alignmentMatrix.shape
    counts = np.zeros((len(alphabet), alignment_len))
    for column in range(alignment_len):
        for row in range(num_seqs):
            symbol_index = alphabet[alignmentMatrix[row, column]]
            counts[symbol_index, column] += 1
    return counts / num_seqs

def modifyAlignmentBasedOnPairwiseFrequencies(alignmentMatrix, pairwiseFrequencies, 
                                              delete_bad=False, metric=cos_vector_angle, threshold=0.2,
                                              max_deletions_per_col=1, max_deletions_per_seq=1, 
                                              max_gaps_per_col=1):
    """Mark the poorly fitting sites of each sequence in an alignment.

    For row ``i`` the transposed ``pairwiseFrequencies[i]`` is compared with
    the alignment's column frequencies via ``id_poor_sites``. Flagged cells are
    replaced by "X" when ``delete_bad`` is true, otherwise lower-cased.

    Only the first ``len(pairwiseFrequencies)`` rows of the result are filled.
    ``max_deletions_per_col``, ``max_deletions_per_seq`` and
    ``max_gaps_per_col`` are accepted but not used.

    Returns:
        Object array with the same shape as ``alignmentMatrix``.
    """
    alignmentFrequencyMatrix = get_aa_frequencies_by_site(alignmentMatrix)
    updated_alignmentMatrix = np.empty(alignmentMatrix.shape, dtype=object)
    for row in range(len(pairwiseFrequencies)):
        profile = pairwiseFrequencies[row].T
        original_row = alignmentMatrix[row, :]
        flagged_sites = id_poor_sites(alignmentFrequencyMatrix, profile,
                                      original_row, metric=metric, threshold=threshold)
        edited_row = original_row.copy()
        for site in flagged_sites:
            edited_row[site] = "X" if delete_bad else edited_row[site].lower()
        updated_alignmentMatrix[row, :] = edited_row
    return updated_alignmentMatrix

In [3]:
def generate_alignment_score_matrix(seq, align_freq, alphabet=alphabet, 
                                    substitution_matrix=None, exclude=[0]):
    """Score every (sequence position, alignment column) pair, scaled to [0, 1].

    For each feasible pair the raw score is the substitution score of the
    residue against the column, averaged over the column's symbol frequencies
    (excluded symbols removed and the rest renormalised), minus the smallest
    entry of the substitution matrix. Feasible scores are then min-max scaled
    over the whole matrix. Infeasible cells are 0.

    A cell (s, a) is feasible when the residue is not an excluded symbol,
    ``s <= a``, and at least as many columns remain from ``a`` as residues
    remain from ``s``.

    Args:
        seq: sequence of symbols, all keys of ``alphabet``.
        align_freq: array of shape (number of symbols, alignment length) with
            per-column symbol frequencies, rows ordered by ``alphabet`` index.
        alphabet: dict mapping symbols to row indices of ``align_freq``.
        substitution_matrix: square score matrix. If it exposes an
            ``alphabet`` attribute (as Biopython substitution matrices do),
            rows/columns are looked up by letter; a plain array is assumed to
            be ordered by ``alphabet`` index already. Defaults to BLOSUM62.
        exclude: ``alphabet`` indices to ignore (default: index 0, the gap).

    Returns:
        Float array of shape (len(seq), alignment length).

    Raises:
        ValueError: if no cell is feasible, or if a non-excluded symbol is
            missing from the substitution matrix's alphabet.

    Notes:
        * If all feasible scores are equal the scaling divides by zero and the
          feasible cells become NaN, as before.
        * A column whose non-excluded frequencies sum to 0 produces NaN scores.
        * Fixes relative to the previous version: excluded residues used to
          write to ``score_matrix[s, a]`` with ``a`` undefined or stale;
          ``substitution_matrix == None`` raised for array arguments; integer
          indexing was shifted by one against BLOSUM62's own letter order
          ("ARND...", with no leading gap); the per-row normalisation only
          fired for NaN and did nothing; leftover debug prints are gone.
    """
    if substitution_matrix is None:
        substitution_matrix = Align.substitution_matrices.load("BLOSUM62")
    excluded = set(exclude)
    seq_len = len(seq)
    num_symbols, align_len = align_freq.shape
    floor = float(np.min(substitution_matrix))
    inv_alphabet = {v: k for k, v in alphabet.items()}
    matrix_alphabet = getattr(substitution_matrix, "alphabet", None)

    def matrix_index(symbol_index):
        # Translate an index in this file's alphabet into the substitution
        # matrix's own row/column index.
        if matrix_alphabet is None:
            return symbol_index
        return matrix_alphabet.index(inv_alphabet[symbol_index])

    kept = [x for x in range(num_symbols) if x not in excluded]
    kept_columns = [matrix_index(x) for x in kept]

    score_matrix = np.zeros((seq_len, align_len))
    feasible = np.zeros((seq_len, align_len), dtype=bool)
    for s in range(seq_len):
        symbol_index = alphabet[seq[s]]
        if symbol_index in excluded:
            # Excluded residues score nothing anywhere in their row.
            continue
        partner_row = substitution_matrix[matrix_index(symbol_index)]
        partner_scores = np.array([partner_row[c] for c in kept_columns], dtype=float)
        for a in range(align_len):
            if (s > a) or ((align_len - a) < (seq_len - s)):
                continue
            site_freqs = align_freq[kept, a]
            site_freqs = site_freqs / np.sum(site_freqs)
            score_matrix[s, a] = np.sum(partner_scores * site_freqs) - floor
            feasible[s, a] = True

    feasible_scores = score_matrix[feasible]
    lowest = np.min(feasible_scores)
    highest = np.max(feasible_scores)
    normalized = np.zeros((seq_len, align_len))
    normalized[feasible] = (feasible_scores - lowest) / (highest - lowest)
    return normalized

In [4]:
def matrix_to_dictionary(matrix, labels):
    """Join each row of a character matrix into a string keyed by its label.

    Args:
        matrix: 2-D array of single-character strings (sequences x sites).
        labels: indexable where ``labels[i][0]`` is the label of row ``i``.

    Returns:
        dict mapping each label to its row concatenated into one string.
    """
    row_count, _ = matrix.shape
    return {labels[row][0]: "".join(matrix[row, :]) for row in range(row_count)}

def get_match_matrix(frequencyMatrix, alignmentFrequencies, metric=cos_vector_angle):
    """Score each sequence position against each feasible alignment column.

    Position ``s`` may occupy column ``a`` when ``s <= a`` and enough columns
    remain after ``a`` for the residues after ``s``, i.e.
    ``a <= align_length - (seq_length - s)``. Infeasible cells stay 0.

    The bounds are inclusive, matching the feasibility test in
    ``generate_alignment_score_matrix``. The previous strict inequalities
    never scored the diagonal or the tail boundary, so a residue sitting in
    column 0 (or an ungapped sequence) always scored 0.

    Args:
        frequencyMatrix: array of shape (seq_length, alphabet size).
        alignmentFrequencies: array of shape (alphabet size, align_length).
        metric: callable applied to one row of ``frequencyMatrix`` and one
            column of ``alignmentFrequencies``.

    Returns:
        Array of shape (seq_length, align_length).
    """
    _, align_length = alignmentFrequencies.shape
    seq_length, _ = frequencyMatrix.shape
    match_matrix = np.zeros((seq_length, align_length))
    for s in range(seq_length):
        last_feasible = align_length - (seq_length - s)
        for a in range(s, last_feasible + 1):
            match_matrix[s, a] = metric(frequencyMatrix[s, :], alignmentFrequencies[:, a])
    return match_matrix

In [5]:
def random_pathfind(match_matrix, true_random=False, fixed_sites=None):
    """Build one random increasing path of columns through ``match_matrix``.

    A path assigns each row (sequence position) a column. Three modes:

    * ``true_random``: draw ``seq_length`` distinct columns uniformly and sort.
    * ``fixed_sites`` given: keep the listed ``(row, column)`` pairs and fill
      the gaps between them (and after the last one) with sorted random
      columns drawn without replacement.
    * otherwise: a greedy randomised walk that starts from a column drawn
      with probability proportional to row 0's scores.

    Args:
        match_matrix: array of shape (seq_length, align_length).
        true_random: use the uniform mode; takes precedence over fixed_sites.
        fixed_sites: iterable of ``(row, column)`` pairs in increasing order,
            or None.

    Returns:
        (scores, total, columns): per-row scores along the path, their sum,
        and the chosen column for each row.

    Notes:
        The greedy start column used to be drawn before the mode was chosen.
        That draw fails when ``align_length == seq_length`` or row 0 sums to
        0, which broke the other two modes even though they never use it. It
        now happens only in the greedy mode, so the other modes consume one
        fewer random number than before.
    """
    seq_length, align_length = match_matrix.shape

    if true_random:
        sites = np.random.choice(list(range(align_length)), size=seq_length, replace=False)
        sites = [x for x in np.sort(sites)]
        scores = [match_matrix[s, sites[s]] for s in range(seq_length)]
        return scores, np.sum(scores), sites

    if fixed_sites is not None:
        sites = []
        last_s, last_a = -1, -1
        for s, a in fixed_sites:
            if last_a + 1 != a:
                # Fill the rows strictly between two fixed sites with distinct
                # columns taken from the open interval between them.
                region_sites = np.random.choice(list(range(last_a + 1, a)),
                                                replace=False, size=s - last_s - 1)
            else:
                region_sites = []
            sites += list(np.sort(region_sites)) + [a]
            last_s, last_a = s, a
        if last_a != align_length - 1:
            tail_sites = np.random.choice(list(range(last_a + 1, align_length)),
                                          replace=False, size=seq_length - last_s - 1)
            sites += list(np.sort(tail_sites))
        scores = [match_matrix[s, sites[s]] for s in range(seq_length)]
        return scores, np.sum(scores), sites

    # Greedy mode.
    first_row = match_matrix[0, :align_length - seq_length]
    cur = np.random.choice(list(range(align_length - seq_length)),
                           p=first_row / np.sum(first_row))
    scores = []
    inds = []
    for s in range(seq_length):
        good_indices = np.argwhere([match_matrix[s, :] > np.random.random() * 0.5])[:, 1]
        good_indices = good_indices[good_indices > s]
        best = (None, 0)
        root = 0.8 * ((np.random.random()) ** 0.5)
        candidates = good_indices[good_indices > cur]
        for a in range(min([int(np.random.randint(0, 25) ** root), len(candidates)])):
            next_ind = candidates[a]
            upgrade = match_matrix[s, next_ind] * (root ** min((next_ind - cur - 1), 1))
            if upgrade > best[1]:
                best = (next_ind, upgrade)
        if best[0] is None:
            # NOTE(review): nothing bounds cur here, so it can run past the
            # last column and make the lookup below raise IndexError.
            cur = cur + 1
        elif best[0] < np.random.random() * 0.5:
            # NOTE(review): this compares a column index with a number below
            # 0.5, so it is effectively never true; best[1] (the score) may
            # have been intended. Left unchanged.
            cur = candidates[0]
        else:
            cur = best[0]
        scores.append(match_matrix[s, cur])
        inds.append(cur)
    return scores, np.sum(scores), inds
        
#No idea if this works. All my datasets are way too big for this to finish running anyway
def recursive_pathfind(match_matrix):
    """Randomised divide-and-conquer path score (experimental).

    Picks a random cell scoring above 0.8, then recurses on the sub-matrix
    above-left of it and the sub-matrix below-right of it, returning the sum
    of the three scores. A 1x1 matrix returns its only entry, and a matrix
    with no rows or no columns contributes 0.

    Previously only the exact shape (0, 0) was treated as empty, so any
    recursion into a (0, k) or (k, 0) block crashed inside ``randint``.

    Args:
        match_matrix: 2-D score array.

    Returns:
        The summed score of the cells chosen along the recursion.

    Raises:
        ValueError: if a non-empty block larger than 1x1 has no cell above
            0.8 (this case already raised ValueError, from ``randint``).
    """
    seq_length, align_length = match_matrix.shape
    if seq_length == 0 or align_length == 0:
        return 0
    if seq_length == 1 and align_length == 1:
        return match_matrix[0][0]
    good_candidates = np.argwhere(match_matrix > 0.8)
    if len(good_candidates) == 0:
        raise ValueError("no cell of the block scores above 0.8")
    s, a = good_candidates[np.random.randint(len(good_candidates))]
    # Either side may be empty (pivot on an edge); the base case makes it 0.
    return (recursive_pathfind(match_matrix[:s, :a])
            + match_matrix[s, a]
            + recursive_pathfind(match_matrix[s + 1:, a + 1:]))

    
def depth_first_pathfinding(match_matrix, best_score=-INF):
    """Depth-first search over place/skip decisions for each alignment column.

    A state is a list with one entry per column decided so far: 1 places the
    next sequence position in that column, 0 skips the column. A partial state
    is expanded only while its score so far plus the number of unplaced
    positions still exceeds the best complete score. The order in which the
    two children are pushed is randomised.

    Progress is printed and the cell output cleared on every expansion.

    Args:
        match_matrix: array of shape (seq_length, align_length).
        best_score: initial score that a complete state has to beat.

    Returns:
        ``[best score, best state]``; the state is ``[]`` if nothing beat
        ``best_score``.
    """
    seq_length, align_length = match_matrix.shape
    best_final = [best_score, []]
    stack = [[]]
    while len(stack) > 0:
        state = stack.pop()
        if len(state) >= align_length:
            # Complete assignment: keep it if it is the best so far.
            score = evaluate_score(match_matrix, state)
            if score > best_final[0]:
                best_final = [score, state]
            continue
        remainder_alignment = align_length - len(state)
        remainder_sequence = seq_length - sum(state)
        print(len(stack), len(state), best_final[0])
        clear_output(wait=True)
        if not (evaluate_score(match_matrix, state) + remainder_sequence > best_final[0]):
            continue
        skip_pushed_first = np.random.random() < 0.5
        children = []
        if remainder_alignment > remainder_sequence:
            children.append(state + [0])
        if remainder_sequence > 0:
            children.append(state + [1])
        if not skip_pushed_first:
            children.reverse()
        stack.extend(children)
    return best_final
            
def evaluate_score(match_matrix, place_skip_record):
    """Sum the scores of the cells selected by a place/skip record.

    Entry ``a`` of the record refers to column ``a``. Each entry equal to 1
    adds ``match_matrix[row, a]`` for the next unused row; other entries are
    skipped.

    Returns:
        The accumulated score (0 for an empty record).
    """
    total = 0
    next_row = 0
    for column, decision in enumerate(place_skip_record):
        if decision != 1:
            continue
        total += match_matrix[next_row, column]
        next_row += 1
    return total

def amalgamate(paths, scores, consensus=0.8):
    """Unfinished helper: stacks the inputs and locates the best-scoring path.

    The stacked matrices and the ``[best total, index of best]`` pair are
    computed but not used, ``consensus`` is ignored, and nothing is returned.
    """
    score_matrix = np.vstack(scores)
    path_matrix = np.vstack(paths)
    totals = [sum(row) for row in scores]
    best_index = max(range(len(scores)), key=lambda idx: sum(scores[idx]))
    best = [max(totals), best_index]

def perturb(match_matrix, paths, scores, mu, fixed_sites=[]):
    """Randomly move single sites of each path, keeping only improvements.

    For every path the rows are visited in random order. With probability
    ``mu`` a new column is drawn for the row from the window left open by its
    neighbours in the path; the move is kept only if it scores strictly
    higher than the current column and the row is not a fixed site.

    The inputs are not modified: every path and score list is copied first.
    Previously only the outer lists were copied, so the caller's inner lists
    were edited in place.

    Args:
        match_matrix: array of shape (seq_length, align_length).
        paths: list of paths (mutable sequences of column indices with a
            ``copy`` method, e.g. lists).
        scores: list of per-row score sequences, parallel to ``paths``.
        mu: probability of proposing a move for a row.
        fixed_sites: iterable of ``(row, column)`` pairs whose rows must not
            be moved.

    Returns:
        (paths, scores): the perturbed copies.
    """
    paths = [path.copy() for path in paths]
    scores = [score.copy() for score in scores]
    seq_length, align_length = match_matrix.shape
    fixed_sites = {s: a for s, a in fixed_sites}
    for p in range(len(paths)):
        path = paths[p]
        score = scores[p]
        scrambled = list(range(len(path)))
        np.random.shuffle(scrambled)
        for s in scrambled:
            new_match = path[s]
            if s == 0:
                if path[1] == 1:
                    # Only column 0 is available to the first row.
                    continue
                if np.random.random() < mu:
                    new_match = np.random.randint(0, path[1])
            elif s == len(path) - 1:
                if path[-2] == align_length - 2:
                    # Only the last column is available to the last row.
                    continue
                if np.random.random() < mu:
                    new_match = np.random.randint(path[-2] + 1, align_length)
            else:
                if np.random.random() < mu:
                    try:
                        new_match = np.random.randint(path[s-1] + 1, path[s+1])
                    except ValueError:
                        # Empty window between the neighbours: keep the site.
                        pass
            try:
                if match_matrix[s, new_match] > match_matrix[s, path[s]]:
                    if s not in fixed_sites:
                        path[s] = new_match
                        score[s] = match_matrix[s, path[s]]
            except IndexError:
                # A column or row outside the matrix: leave this site alone.
                pass
    return paths, scores

In [15]:

# starts is a collection of the best starts at particular indices of the sequence
# ends is the set of best matches to the corresponding start in starts
def get_best_start_ends(population, match_matrix, prev_dicts=None):
    """Rank path prefixes and suffixes at every split point.

    For each split ``s`` in ``0..seq_length`` every path is cut into a start
    (its first ``s`` columns) and an end (the rest), each paired with its
    per-row scores. These are merged with any entries carried over in
    ``prev_dicts``, sorted by total score (best first), and de-duplicated.
    Then, for each start in rank order, the best end whose first column lies
    after the start's last column is collected. Both lists are truncated to
    the population size.

    The dictionaries and lists passed in ``prev_dicts`` are not modified.
    Previously ``starts[s] += ...`` extended the caller's lists in place even
    when the caller had passed shallow copies of its dicts.

    Args:
        population: list of paths (column index per row).
        match_matrix: array of shape (seq_length, align_length).
        prev_dicts: optional ``(starts, ends)`` from an earlier call; ignored
            if it is None or either element is None.

    Returns:
        (starts, ends): dicts keyed by split point; each value is a list of
        ``(columns, scores)`` tuples.
    """
    seq_length, align_length = match_matrix.shape
    if prev_dicts is None or prev_dicts[0] is None or prev_dicts[1] is None:
        starts = {x: [] for x in range(seq_length + 1)}
        ends = {x: [] for x in range(seq_length + 1)}
    else:
        # Copy the dicts so that rebinding keys below stays local.
        starts = dict(prev_dicts[0])
        ends = dict(prev_dicts[1])
    pop_size = len(population)
    for s in range(seq_length + 1):
        starts_s = []
        ends_s = []
        for path in population:
            start = path[:s]
            end = path[s:]
            start_scores = [match_matrix[j, start[j]] for j in range(s)]
            end_scores = [match_matrix[s + j, end[j]] for j in range(seq_length - s)]
            starts_s.append((start, start_scores))
            ends_s.append((end, end_scores))
        if starts[s] is None or ends[s] is None:
            carried_starts, carried_ends = [], []
        else:
            carried_starts, carried_ends = starts[s], ends[s]
        # Build new lists (no in-place +=) so carried-over lists are untouched.
        # sort ascending then reverse: ties end up in reverse insertion order.
        ranked_starts = sorted(carried_starts + starts_s, key=lambda x: np.sum(x[1]))
        ranked_ends = sorted(carried_ends + ends_s, key=lambda x: np.sum(x[1]))
        ranked_starts.reverse()
        ranked_ends.reverse()
        ranked_starts = make_unique(ranked_starts)
        ranked_ends = make_unique(ranked_ends)

        paired_ends = []
        for start in ranked_starts:
            if (len(start[0]) == 0) or (len(start[0]) == seq_length):
                compatible = ranked_ends
            else:
                compatible = [end for end in ranked_ends if end[0][0] > start[0][-1]]
            if len(compatible) > 0:
                paired_ends.append(compatible[0])
        starts[s] = ranked_starts[:pop_size]
        ends[s] = paired_ends[:pop_size]
    return starts, ends

def generate_new_pop_from_starts_ends(starts, ends, pop_size):
    """Combine the top-ranked start and end at each interior split into paths.

    For every split point ``x`` in ``1 .. len(starts) - 2`` the first entry of
    ``starts[x]`` is concatenated with the first entry of ``ends[x]``. The
    combinations are sorted by total score (best first), de-duplicated and
    truncated to ``pop_size``.

    Split points with no start or no end are skipped (they used to raise
    IndexError), as are combinations whose end does not begin in a later
    column than the start finishes, since such a path is not increasing.

    Args:
        starts: dict/sequence keyed by split point of ``(columns, scores)`` lists.
        ends: same structure for suffixes.
        pop_size: maximum number of paths to return.

    Returns:
        List of ``[total score, per-row scores, columns]`` entries.
    """
    seq_len = len(starts)
    combined = []
    for x in range(1, seq_len - 1):
        if len(starts[x]) == 0 or len(ends[x]) == 0:
            continue
        start_path, start_scores = starts[x][0]
        end_path, end_scores = ends[x][0]
        # NOTE(review): ends[x][0] is presumably the partner chosen for
        # starts[x][0], but that is not guaranteed; check before joining.
        if len(start_path) > 0 and len(end_path) > 0 and not (end_path[0] > start_path[-1]):
            continue
        combined.append((start_path + end_path, start_scores + end_scores))
    # sort ascending then reverse: ties end up in reverse insertion order.
    combined = sorted(combined, key=lambda x: np.sum(x[1]))
    combined.reverse()
    combined = make_unique(combined)
    combined = combined[:min(pop_size, len(combined))]
    return [[sum(scores), scores, path] for path, scores in combined]



In [17]:
def find_diverse_initial_population(heatmap, pop_size=50, update_freq=100, max_rounds=1000, verbose=False, fixed_sites=None):
    """Sample random paths through ``heatmap`` and keep a pool of good ones.

    Each round fixes the given ``fixed_sites`` plus a random mutually
    compatible subset of every row's best-scoring column, fills the rest with
    ``random_pathfind``, and offers the resulting path to the pool ``good``.
    Every ``update_freq`` rounds (and on round 1) the pool mean is compared
    with the previous checkpoint; a relative gain below 0.1% counts as
    convergence, except at the checkpoint where ``i == update_freq``.

    Args:
        heatmap: array of shape (seq_length, align_length).
        pop_size: maximum size of the returned pool.
        update_freq: number of rounds between convergence checks.
        max_rounds: maximum number of sampling rounds.
        verbose: print progress at each checkpoint.
        fixed_sites: list of ``(row, column)`` pairs every path must contain;
            None (the default) means no fixed sites.

    Returns:
        (good, starts, ends): ``good`` is a list of
        ``[total, per-row scores, columns]``; ``starts`` and ``ends`` come
        from ``get_best_start_ends`` over all accepted paths.

    Notes on fixes:
        * ``fixed_sites=None`` used to crash on ``.copy()``.
        * ``candidates != []`` on an ndarray raises on recent NumPy; the
          length is tested instead.
        * ``mean != np.nan`` was always true; the check now simply avoids the
          division when ``prev`` is 0, which never converged anyway.
        * A checkpoint reached with an empty pool is skipped instead of
          crashing in ``max([])``.
        * The unused ``best_scores_per_site`` list, whose inner loop shadowed
          the round counter ``i`` and appended duplicates, is removed.
    """
    fixed_sites = [] if fixed_sites is None else list(fixed_sites)
    seq_length, align_length = heatmap.shape[0], heatmap.shape[1]
    starts, ends = None, None
    prev = 0
    mean = 0
    best = 0
    min_score = 0
    good = []
    all_seen = []

    # Column window allowed for each row given the fixed sites.
    fixed_lookup = {x[0]: x[1] for x in fixed_sites}
    site_ranges = []
    for s in range(seq_length):
        if s in fixed_lookup:
            limit = fixed_lookup[s]
            site_ranges.append((limit, limit))
        else:
            upper_limits = [x for x in fixed_sites if x[0] > s] + [[seq_length, align_length]]
            lower_limits = [x for x in fixed_sites if x[0] < s] + [[0, 0]]
            site_ranges.append((max([x[1] for x in lower_limits]), min([x[1] for x in upper_limits])))

    # Best-scoring column of each row inside its window.
    best_a_by_s = []
    for s in range(seq_length):
        low, high = site_ranges[s]
        candidates = heatmap[s, low:high]
        if len(candidates) > 0:
            best_a_by_s.append((s, int(np.argmax(candidates)) + low))
        else:
            best_a_by_s.append((s, low))

    for i in range(max_rounds + 1):
        random_fixed = get_mutually_compatible(heatmap, best_a_by_s, fixed_sites)
        np.random.shuffle(random_fixed)
        fixed_round = make_unique(sorted(fixed_sites + random_fixed, key=lambda x: x[0]))
        scores, score, inds = random_pathfind(heatmap, true_random=False, fixed_sites=fixed_round)

        at_checkpoint = (i % update_freq == 0) or (i == 1)
        if at_checkpoint and i != 0 and len(good) > 0:
            mean = round(np.mean([x for x, _, _ in good]), 2)
            converged = (prev != 0) and (((mean - prev) / prev) < 0.001) and (i != update_freq)
            if converged:
                prev = mean
                print("Converged on round", i, "\nBest:", round(best, 2))
                starts, ends = get_best_start_ends(all_seen, heatmap, prev_dicts=(starts, ends))
                while len(good) > pop_size:
                    min_ind = min([x for x in range(len(good))], key=lambda x: sum(good[x][1]))
                    good.pop(min_ind)
                return good, starts, ends
            best = round(max([x for x, _, _ in good]), 2)
            if verbose:
                clear_output(wait=True)
                print("Not converged after", i, "rounds\nBest:", round(best, 2),
                      "/", heatmap.shape[0], "Improvement:", 
                      round(100 * (mean - prev) / mean, 2), "%", "Mean: ", mean)
                print(len(all_seen), "paths found")
            prev = mean

        if score:
            if inds in all_seen or score < best * 0.95:
                continue
            all_seen.append(inds)
            if score > best:
                # A new best replaces the current worst member (if any).
                best = score
                if len(good) > 0:
                    good.pop(min_ind)
                good.append([score, scores, inds])
            elif len(good) < pop_size:
                good.append([score, scores, inds])
            elif score > min_score:
                good.pop(min_ind)
                good.append([score, scores, inds])
            else:
                continue
            min_score = min([x for x, _, _ in good])
            min_ind = min([x for x in range(len(good))], key=lambda x: sum(good[x][1]))

    # sort ascending then reverse: ties end up in reverse insertion order.
    good = sorted(good, key=lambda x: x[0])
    good.reverse()
    good = good[:min(pop_size, len(good))]

    starts, ends = get_best_start_ends(all_seen, heatmap, prev_dicts=None)
    return good, starts, ends


In [9]:
def random_refine(match_matrix, population, max_generations=200, verbose=False, fixed_sites=None):
    """Improve a population of paths by repeated random perturbation.

    Each entry of ``population`` is indexed as ``entry[1]`` (per-row scores)
    and ``entry[2]`` (path); both must have a ``copy`` method.

    Every generation starts from copies of the stored paths and applies
    ``perturb`` (with mu = 0.9) for a random number of rounds. After each
    round the stored population may be updated in two ways: if the best
    perturbed total beats the stored total at the same index, the stored
    entry at that index is copied over the slot of the worst perturbed entry
    and then replaced by the perturbed one; and the entry with the largest
    gain over its stored version is adopted whenever that gain is positive.
    A round without any gain ends the generation once ``j > prev_j / 2``.

    Args:
        match_matrix: score matrix handed to ``perturb``.
        population: list of ``(total, scores, path)``-like entries.
        max_generations: number of generations to run.
        verbose: print progress whenever the best entry improves.
        fixed_sites: forwarded to ``perturb`` unless it is None.

    Returns:
        List of ``(total, scores, path)`` tuples, one per population slot.
    """
    if verbose:
        print("Refining...")
    paths = [x[2].copy() for x in population]
    scores = [x[1].copy() for x in population]
    upgraded = []
    prev_j = 0
    # NOTE(review): `starting` is computed but never read.
    starting = max([sum(x) for x in scores])
    for i in range(max_generations):
        # Work on copies so the stored population only changes on improvement.
        p = [x.copy() for x in paths]
        s = [x.copy() for x in scores]
        # The round budget grows with prev_j and is reset to 3 if it exceeds 100.
        num_rounds = np.random.randint(prev_j * 2 + 3)
        if num_rounds > 100:
            num_rounds = 3
        for j in range(num_rounds):
            if fixed_sites == None:
                p, s = perturb(match_matrix, p, s, 0.9)
            else:
                p, s = perturb(match_matrix, p, s, 0.9, fixed_sites=fixed_sites)
            # Best perturbed total and its index.
            round_best = max([sum(x) for x in s])
            round_best_ind = max([x for x in range(len(s))], key=lambda x: sum(s[x]))
            # Largest gain of a perturbed entry over its stored version.
            round_best_value = max([sum(s[x]) - sum(scores[x]) for x in range(len(s))])
            round_best_value_ind = max([x for x in range(len(s))], key=lambda x: sum(s[x]) - sum(scores[x]))
            #round_worst = min([sum(x) for x in s])
            round_worst_ind = min([x for x in range(len(s))], key=lambda x: sum(s[x]))
            if round_best>sum(scores[round_best_ind]):
                pre = sum(scores[round_best_ind])
                # Keep the previous stored entry by copying it into the slot
                # of the worst perturbed entry, then store the improved one.
                paths[round_worst_ind] = paths[round_best_ind].copy()
                scores[round_worst_ind] = scores[round_best_ind].copy()
                paths[round_best_ind] = p[round_best_ind].copy()
                scores[round_best_ind] = s[round_best_ind].copy()
                clear_output(wait=True)
                upgraded.append(round_best_ind)
                if verbose:
                    print("generation:", i, "round:", j, "of", num_rounds, 
                          round(len(set(upgraded)) / len(paths)*100,2), "% upgraded")
                    print(round(pre, 2), "->", round(round_best, 2), "best:",
                          round(np.max([sum(x) for x in scores]), 2), "mean:",
                          round(np.mean([sum(x) for x in scores]), 2), "with", len(paths), "paths")
                prev_j = (j + prev_j + 1)//2
            if round_best_value > 0:
                #print(round_best_value, sum(s[round_best_value_ind]), sum(scores[round_best_value_ind]))
                paths[round_best_value_ind] = p[round_best_value_ind].copy()
                scores[round_best_value_ind] = s[round_best_value_ind].copy()
            else:
                #print("crank up the j")
                # No entry improved this round: end the generation early once
                # more than prev_j / 2 rounds have been tried.
                if j > prev_j / 2:
                    prev_j += 1
                    break
    return [(np.sum(scores[x]), scores[x], paths[x]) for x in range(len(paths))]

In [10]:
def get_seq_to_align(seq, sequence_matrix):
    """Column indices at which row ``seq`` of the alignment holds a non-gap.

    Args:
        seq: row index into ``sequence_matrix``.
        sequence_matrix: 2-D array of symbol strings (sequences x columns).

    Returns:
        List of column indices whose cell does not contain "-".
    """
    row = sequence_matrix[seq, :]
    _, _ = sequence_matrix.shape  # the matrix must be 2-D
    return [column for column, symbol in enumerate(row) if "-" not in symbol]

In [8]:
def align_sequence_to_matrix(match_matrix, pre_alignment=[], verbose=False):
    """Search for placements of a sequence that score at least as well as ``pre_alignment``.

    Pipeline: build a random starting pool with
    ``find_diverse_initial_population``; recombine prefixes and suffixes with
    ``get_best_start_ends`` / ``generate_new_pop_from_starts_ends`` until the
    number of unique paths stops changing; refine with ``random_refine``;
    recombine once more; finally keep the paths whose total is at least the
    pre-alignment's total.

    Args:
        match_matrix: array of shape (seq_length, align_length).
        pre_alignment: current column of each sequence position. It is
            indexed for every row, so the default ``[]`` only works for a
            matrix with zero rows.
        verbose: print progress messages.

    Returns:
        List of ``[total, per-row scores, columns]`` entries whose total is
        >= the pre-alignment score (may be empty).
    """
    seq_length, align_length = match_matrix.shape
    alignment_score = np.sum([match_matrix[s, pre_alignment[s]] for s in range(seq_length)])
    alignment_scores = [match_matrix[s, pre_alignment[s]] for s in range(seq_length)]
    max_scores = np.max(match_matrix,axis=1)
    norm_scores = alignment_scores / max_scores
    # Rows whose pre-aligned column already attains the row maximum are held fixed.
    fixed = [(x, pre_alignment[x]) for x in range(seq_length) if norm_scores[x]==1]
    #print(pre_alignment)
    if verbose:
        print("Score to beat:", round(alignment_score, 2), "/", seq_length)
    # NOTE(review): pop_size is never read; the call below passes pop_size=20.
    pop_size = 50
    
    # Generating starting set
    good, starts, ends = find_diverse_initial_population(match_matrix, pop_size=20, update_freq=50, 
                                                         max_rounds=100, verbose=verbose, fixed_sites=fixed)
    
    # Optimize weirdly
    paths = [g[2] for g in good.copy()]
    if verbose:
        print("(1/2) Finding best start-end pairs...")
    starts_new, ends_new = get_best_start_ends(paths, match_matrix, prev_dicts=(starts.copy(), ends.copy()))
    if verbose:
        print("(1/2) Generating optimal paths from start-end pairs...")
    new_pop = generate_new_pop_from_starts_ends(starts_new, ends_new, len(paths))
    paths = make_unique(paths)
    prev_len = 0
    if verbose:
        print("Best paired:", round(max(new_pop, key=lambda x:x[0])[0], 2))
    # Repeat the recombination until the count of unique paths is stable.
    while len(new_pop) > 1:
        if len(paths) == prev_len:
            break
        prev_len = len(paths)
        paths = [g[2] for g in new_pop.copy()]# + [g[2] for g in good.copy()] #+ [pre_alignment]
        paths = make_unique(paths)
        starts_new, ends_new = get_best_start_ends(paths, match_matrix, prev_dicts=(starts_new.copy(), ends_new.copy()))
        new_pop = generate_new_pop_from_starts_ends(starts_new, ends_new, len(good))
        if verbose:
            print("Best paired:", round(max(new_pop, key=lambda x:x[0])[0], 2))
    
    # Optimize randomly or something
    # The pre-alignment itself joins the pool that gets refined.
    full_pop = good.copy() + new_pop.copy() + [(np.sum(alignment_scores), alignment_scores, pre_alignment)]
    refined_pop = random_refine(match_matrix, make_unique(full_pop), verbose=verbose, max_generations=30, fixed_sites=fixed)

    # Optimize weirdly again
    paths = [g[2] for g in refined_pop.copy()]
    if verbose:
        print("(2/2) Finding best start-end pairs...")
    # This pass restarts from the initial starts/ends, not from starts_new/ends_new.
    starts_new, ends_new = get_best_start_ends(paths, match_matrix, prev_dicts=(starts.copy(), ends.copy()))
    if verbose:
        print("(2/2) Generating optimal paths from start-end pairs...")
    new_refined_pop = generate_new_pop_from_starts_ends(starts_new, ends_new, len(good))
    paths = [g[2] for g in new_refined_pop.copy()] + [g[2] for g in good.copy()] + [pre_alignment]
    paths = make_unique(paths)
    prev_len = 0
    if verbose:
        print("Best paired:", round(max(new_refined_pop, key=lambda x:x[0])[0], 2))
    while len(new_refined_pop) > 1:
        #print(len(paths))
        if len(paths) == prev_len:
            break
        prev_len = len(paths)
        paths = [g[2] for g in new_refined_pop.copy()]# + [g[2] for g in good.copy()] + [pre_alignment]
        paths = make_unique(paths)
        starts_new, ends_new = get_best_start_ends(paths, match_matrix, prev_dicts=(starts_new.copy(), ends_new.copy()))
        #print(len(good))
        new_refined_pop = generate_new_pop_from_starts_ends(starts_new, ends_new, len(good))
        if verbose:
            print("Best paired:", round(max(new_refined_pop, key=lambda x:x[0])[0], 2))
    new_refined_pop = make_unique(new_refined_pop)
    # Keep only candidates at least as good as the alignment we started from.
    new_refined_pop = [x for x in new_refined_pop if x[0] >= alignment_score]
    #print(alignment_score, [x[0] for x in new_refined_pop])
    return new_refined_pop

    

In [16]:
def get_mutually_compatible(match_matrix, candidates, population):
    """Randomly grow a set of (row, column) sites with compatible candidates.

    Candidates are examined in random order. One is accepted when its column
    lies in the window left by the already accepted sites on neighbouring
    rows, the column is not already used, and enough columns remain on each
    side for the rows in between. Accepted candidates immediately constrain
    the ones examined later.

    Args:
        match_matrix: array whose shape gives (number of rows, number of columns).
        candidates: list of ``(row, column)`` pairs to try.
        population: list of ``(row, column)`` pairs that are already fixed.

    Returns:
        List of the accepted candidates as tuples, in acceptance order.
    """
    population = [(c[0], c[1]) for c in population.copy()]
    candidates = [(c[0], c[1]) for c in candidates.copy() if c not in population]
    accepted = []
    while candidates:
        pick = candidates.pop(np.random.randint(len(candidates)))
        s, a = pick
        above = [x for x in population if x[0] > s] + [[match_matrix.shape[0], match_matrix.shape[1]]]
        below = [x for x in population if x[0] < s] + [[0, 0]]
        low_col = max([x[1] for x in below])
        high_col = min([x[1] for x in above])
        if (low_col > high_col) or (a not in list(range(low_col, high_col))) or (a in [x[1] for x in population]):
            # Outside the open window, or the column is already taken.
            continue
        if high_col - a < min(x[0] for x in above) - s or a - low_col < s - max([x[0] for x in below]):
            # Too few columns left for the rows between this site and a neighbour.
            continue
        population.append(pick)
        accepted.append(pick)
    return accepted

In [None]:
def check_monotonic(site_list):
    """Return True if both coordinates strictly increase along ``site_list``.

    Args:
        site_list: sequence of ``(row, column)`` pairs.

    Returns:
        False as soon as a pair fails to exceed its predecessor in either
        coordinate; True otherwise. An empty or single-element list is
        monotonic (an empty list used to raise IndexError).
    """
    for (prev_s, prev_a), (s, a) in zip(site_list, site_list[1:]):
        if s <= prev_s or a <= prev_a:
            return False
    return True