In [1]:
import gzip
import treeswift
import parsimony_pb2
import tqdm
def preorder_traversal(node):
    yield node
    for clade in node.children:
        yield from preorder_traversal(clade)


def preorder_traversal_iter(node):
    return iter(preorder_traversal(node))


class UsherMutationAnnotatedTree:
    def __init__(self, tree_file):
        self.data = parsimony_pb2.data()
        self.data.ParseFromString(tree_file.read())
        self.condensed_nodes_dict = self.get_condensed_nodes_dict(
            self.data.condensed_nodes)
        self.tree = treeswift.read_tree(self.data.newick, schema="newick")
        self.data.newick = ''

        self.annotate_mutations()
        self.expand_condensed_nodes()

    def annotate_mutations(self):
        for i, node in enumerate(preorder_traversal(self.tree.root)):
            node.nuc_mutations = self.data.node_mutations[i].mutation

    def set_branch_lengths(self):
        for i, node in enumerate(preorder_traversal(self.tree.root)):
            node.edge_length = len(node.nuc_mutations.mutation)

    def expand_condensed_nodes(self):
        for i, node in tqdm.tqdm(enumerate(self.tree.traverse_leaves()),
                                 desc="Expanding condensed nodes",
                                 miniters=100000,mininterval=10):

            if node.label and node.label in self.condensed_nodes_dict:

                for new_node_label in self.condensed_nodes_dict[node.label]:
                    new_node = treeswift.Node(label=new_node_label)
                    new_node.nuc_mutations = []
                    new_node.aa_subs = []
                    node.add_child(new_node)
                node.label = ""
            else:
                pass

    def get_condensed_nodes_dict(self, condensed_nodes_dict):
        output_dict = {}
        for condensed_node in tqdm.tqdm(condensed_nodes_dict,
                                        desc="Reading condensed nodes dict",
                                        miniters=100000,mininterval=10):
            output_dict[
                condensed_node.node_name] = condensed_node.condensed_leaves
        return output_dict

f = open("/mnt/data/gisaid_data/optimised_trimmed.pb", "rb")


mat = UsherMutationAnnotatedTree(f)



Reading condensed nodes dict: 100%|██████████| 588400/588400 [00:01<00:00, 379124.05it/s]
Expanding condensed nodes: 5509021it [00:29, 186935.68it/s]


In [21]:
leaves = mat.tree.traverse_leaves()
num_leaves = 10000
top_leaves = []
for i, leaf in enumerate(leaves):
    if i < num_leaves:
        top_leaves.append(leaf)
    else:
        break


In [61]:

url = "https://raw.githubusercontent.com/W-L/ProblematicSites_SARS-CoV2/master/archived_vcf/problematic_sites_sarsCov2.2021-10-14-11%3A49.vcf"
import urllib
import gzip
# iterate over lines
masked_positions = set()
with urllib.request.urlopen(url) as response:
        for line in response:
            line = line.decode('utf-8')

            if line.startswith('#'):
                continue
            # parse line
            print
            fields = line.split('\t')
            chrom = fields[0]
            pos = int(fields[1])
            ref = fields[3]
            alt = fields[4]
            qual = fields[5]
            filt = fields[6]
            if filt=="mask":
                masked_positions.add(pos)

masked_pos_to_real_pos = {}

masked_i = 0
for i in range(max(list(masked_positions))):
    masked_pos_to_real_pos[masked_i] = i
    if i+1 not in masked_positions:
        masked_i+=1


In [62]:
letters = "ACGT"
def get_info(leaf):
    print("---")
    positions_mutated = set()
    cur_node = leaf
    starting_nucs = {}
    ending_nucs = {}
    while not cur_node.is_root():
        
        print("a",cur_node.label)
       
        for mutation in cur_node.nuc_mutations:
            #print(mutation)
            positions_mutated.add(mutation.position)
            if mutation.position not in ending_nucs:
                ending_nucs[mutation.position] = mutation.mut_nuc[0]
            starting_nucs[mutation.position] = mutation.ref_nuc
        cur_node = cur_node.get_parent()
    print(positions_mutated)
    filtered_positions = [x for x in positions_mutated if starting_nucs[x] != ending_nucs[x]]
    
    full_muts = [f"{letters[starting_nucs[x]]}{masked_pos_to_real_pos[x]}{letters[ending_nucs[x]]}" for x in sorted(filtered_positions)]
    print(full_muts)

for leaf in top_leaves[9004:9005]:
    if True or leaf.label == "EPI_ISL_509302":
        get_info(leaf)

---
a EPI_ISL_515606
a None
a None
a None
a None
a None
a None
a None
a None
a None
{23264, 23781, 9769, 2973, 26384, 1329, 11825, 184, 22549, 27414, 117, 3576, 2875, 636, 21469, 14302}
['G174T', 'C241T', 'T694A', 'C1387T', 'C2939A', 'C3037T', 'A3643C', 'C9857T', 'T11920C', 'C14408T', 'G21600T', 'C22687T', 'A23403G', 'A23920G', 'A26530G', 'G27561T']


In [None]:

for leaf in tqdm.tqdm(mat.tree.traverse_leaves()):



        cur_node = cur_node.parent
    mutation_string = ",".join(str(pos) for pos in sorted(positions_mutated) if starting_nucs[pos] != ending_nucs[pos])
    print(f"{leaf.label}:{mutation_string}")