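# GISAID and INSDC data

Builds per-node summary tables for the GISAID + public (INSDC) SARS-CoV-2 mutation-annotated tree (`gisaidAndPublic…pb.gz`). The outputs are:

- descendant counts
- consensus country and year
- Chronumental-inferred dates
- mean sample age
- substitution-class counts
- amino-acid mutations
- the child→parent edge list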

In [1]:
import pandas as pd
# Load only the metadata columns used below, indexed by GISAID accession (EPI_ISL_*).
md_file = "/mnt/d/metadata_2023-05-09_00-46 (1).tsv.gz"
cols_needed = ['strain', 'date', 'country', 'length', 'gisaid_epi_isl', 'age']
df = pd.read_csv(md_file, sep='\t', usecols=cols_needed, index_col='gisaid_epi_isl')

def get_year_from_date(date):
    """Return the text before the first '-' of a date string (the year), or None if the date is missing."""
    if not pd.isna(date):
        year, _, _ = date.partition('-')
        return year
    return None

# Year of each sample (None when the date is missing), used for the per-node consensus year.
df['year'] = df['date'].map(get_year_from_date)



In [2]:
import bte
# Load the mutation-annotated tree from its protobuf file (takes ~80 s, see the output below).
# Leaf ids look like 'strain|EPI_ISL_xxx|date' (see the root's children listed further down).
tree_file = "/mnt/d/gisaidAndPublic.2023-05-11.masked.gisaidNames.pb.gz"
tree = bte.MATree(tree_file)

Finished 'from_pb' in 81.5333 seconds


In [3]:
from tqdm import tqdm

def to_newick(node, pbar=None):
    """
    Convert the subtree rooted at ``node`` to a Newick format string.

    Each node is written as ``id:branch_length``, where the branch length is
    ``len(node.mutations)``. An internal node's label is preceded by its
    parenthesised children, joined with ", ". No terminating ";" is added.

    The traversal is iterative. Trees deeper than Python's recursion limit
    therefore do not raise ``RecursionError``.

    Args:
        node: subtree root exposing ``id``, ``children`` (a sequence) and ``mutations``.
        pbar: optional progress bar; ``pbar.update()`` is called once per node,
            in pre-order.

    Returns:
        The Newick string for the subtree, without a trailing ";".
    """
    if pbar is not None:
        pbar.update()
    # Each stack entry holds a node and the Newick strings of its children
    # converted so far. A node is emitted once all of its children are done.
    stack = [(node, [])]
    while stack:
        current, converted = stack[-1]
        n_children = len(current.children)
        if len(converted) < n_children:
            # Descend into the next unconverted child (same visit order as recursion).
            child = current.children[len(converted)]
            if pbar is not None:
                pbar.update()
            stack.append((child, []))
            continue
        stack.pop()
        label = f"{current.id}:{len(current.mutations)}"
        text = label if n_children == 0 else f"({', '.join(converted)}){label}"
        if not stack:
            return text
        stack[-1][1].append(text)

def to_newick_with_progress(root, total_nodes=None):
    """
    Convert a tree to a Newick format string, with a tqdm progress bar when the node count is known.

    Args:
        root: root node passed to ``to_newick``.
        total_nodes: expected number of nodes; when None, no progress bar is shown.

    Returns:
        The Newick string produced by ``to_newick`` (no trailing ";").
    """
    # Import locally. Later cells run ``import tqdm``, which rebinds the global
    # name ``tqdm`` to the module, so a call to the global would then fail.
    from tqdm import tqdm as progress_bar

    if total_nodes is None:
        return to_newick(root)
    pbar = progress_bar(total=total_nodes)
    try:
        return to_newick(root, pbar)
    finally:
        # Close the bar even if the conversion raises.
        pbar.close()


In [4]:
# Inspect the root node: its id, level, parent (None for the root) and child ids.
tree.root

id: node_1
level: 1
parent: None
children: ['Yunnan/0306-466/2020|EPI_ISL_429239|2020-03-06', 'Japan/DP0803/2020|EPI_ISL_416630|2020-02-17', 'node_2', 'node_3', 'node_11', 'node_17', 'node_21', 'node_23', 'node_56', 'node_57', 'node_59', 'Japan/DP0687/2020|EPI_ISL_416614|2020-02-17', 'node_64', 'USA/CA-QDX-2597/2020|EPI_ISL_604886|2020-03-16', 'Greece/213_34051/2020|EPI_ISL_437882|2020-03-14', 'node_67', 'node_78', 'node_95', 'node_110', 'node_111', 'node_112', 'node_113', 'node_116', 'Japan/DP0290/2020|EPI_ISL_416591|2020-02-16', 'Japan/DP0724/2020|EPI_ISL_416620|2020-02-17', 'USA/CA-CZB-1757/2020|EPI_ISL_476774|2020-03-19', 'USA/WV-QDX-1284/2020|EPI_ISL_572053|2020-03-17', 'Japan/DP0158/2020|EPI_ISL_416579|2020-02-15', 'Japan/DP0027/2020|EPI_ISL_416566|2020-02-15', 'node_121', 'node_123', 'Taiwan/CGMH-CGU-22/2020|EPI_ISL_444275|2020-03-18', 'England/BRIS-125699/2020|EPI_ISL_443694|2020-03-28', 'node_124', 'node_140', 'node_142', 'node_145', 'Japan/Hu_DP_Kng_19-031/2020|EPI_ISL_420889

In [5]:
# Map each leaf id to its GISAID accession: the first '|'-separated field that
# starts with 'EPI_ISL_'. Leaves without such a field are left out.
id_to_epi = {}
for node in tree.depth_first_expansion(reverse=True):
    if not node.is_leaf():
        continue
    accession = next(
        (field for field in node.id.split('|') if field.startswith('EPI_ISL_')),
        None,
    )
    if accession:
        id_to_epi[node.id] = accession

In [6]:
# Invert id_to_epi into a frame indexed by accession ('gisaid_epi_isl'),
# with the tree leaf id in its single column 'strain'.
id_epi_df = (
    pd.Series(id_to_epi, name='gisaid_epi_isl')
    .rename_axis('strain')
    .reset_index()
    .set_index('gisaid_epi_isl')
)

In [7]:
# Set to True to (re)export the inputs Chronumental needs; False reuses the outputs of an earlier run.
prepare_chronumental = False

# A complete collection date is exactly YYYY-MM-DD. Partial dates such as
# "2020-03-XX" also have 10 characters, so a length check alone would keep them.
FULL_DATE_PATTERN = r"\d{4}-\d{2}-\d{2}"


def keep_full_dates(dated):
    """Return the rows of ``dated`` whose 'date' column is a complete YYYY-MM-DD date (missing dates are dropped)."""
    return dated[dated['date'].str.fullmatch(FULL_DATE_PATTERN, na=False)]


if prepare_chronumental:
    import gzip

    # Tree written without its original branch lengths (retain_original_branch_len=False);
    # presumably mutation-count distances, hence the "dist" file name.
    nwk_format = tree.write_newick(retain_original_branch_len=False)
    with gzip.open('dist.nwk.gz', 'wb') as f:
        f.write(nwk_format.encode('utf-8'))

    # Dates table for Chronumental: metadata dates joined to tree leaf ids on the
    # EPI_ISL index, restricted to complete dates. Columns: date, strain.
    df_chron = keep_full_dates(df[['date']].merge(id_epi_df, left_index=True, right_index=True))
    print(df_chron.shape)
    df_chron.to_csv('dist.csv.gz', index=False, compression='gzip')

    # Now run
    # chronumental --tree dist.nwk.gz --dates dist.csv.gz --dates_out dates.tsv.gz --tree_out chronumental_timetree_dist.nwk.gz --steps 5000




In [8]:
if not prepare_chronumental:
    # Sanity check: the distance tree given to Chronumental must equal the one
    # generated from the current tree, so the Chronumental outputs read below
    # correspond to this tree's nodes.
    import gzip
    with gzip.open('dist.nwk.gz', 'rb') as f:
        nwk_format_2 = f.read().decode('utf-8')
    nwk_format_1 = tree.write_newick(retain_original_branch_len=False)
    assert nwk_format_1 == nwk_format_2
    del nwk_format_1, nwk_format_2
        

In [9]:
import treeswift
if not prepare_chronumental:
    # Branch length of every node in the Chronumental time tree, keyed by node label
    # (if a label repeats, the last one visited wins).
    chronumental_file = "chronumental_timetree_dist.nwk.gz"
    chronumental_tree = treeswift.read_tree_newick(chronumental_file)
    node_to_length = {
        tree_node.label: tree_node.edge_length
        for tree_node in chronumental_tree.traverse_preorder()
    }
    del chronumental_tree
   


In [10]:
# Chronumental's inferred dates, indexed by node label ('strain').
dates = "dates.tsv.gz"
df_dates = pd.read_csv(dates, sep='\t', index_col='strain')

# predicted_date is a "date time" string; keep only the date part. The vectorised
# .str accessor leaves missing values as NaN instead of raising on them.
df_dates['predicted_date'] = df_dates['predicted_date'].str.split(' ', n=1).str[0]



In [11]:
# Nested dicts {column: {strain: value}}; the frame itself is then released.
dates_dict = df_dates.to_dict(orient='dict')
del df_dates

In [12]:
# Number of leaves below each node (1 for a leaf). This relies on the reversed
# depth-first expansion visiting every child before its parent, so child
# counts already exist when the parent is summed.
num_descendants = {}
for node in tree.depth_first_expansion(reverse=True):
    num_descendants[node.id] = (
        1 if node.is_leaf()
        else sum(num_descendants[child.id] for child in node.children)
    )


In [13]:


import re, tqdm
from collections import Counter, defaultdict
import random

# Probability of printing a leaf that has no metadata match (debug sampling); 0.0 disables it.
MISSING_METADATA_PRINT_RATE = 0.0

# Per-node Counter of sample countries, aggregated bottom-up. This relies on the
# reversed depth-first expansion visiting every child before its parent.
# NOTE: `tqdm` here is the module (rebound by `import re, tqdm` above).
country_counts = defaultdict(Counter)
n_leaves_without_country = 0
for node in tqdm.tqdm(tree.depth_first_expansion(reverse=True)):
    # Create the entry up front so every visited node (including the root) has one.
    node_counts = country_counts[node.id]
    if node.is_leaf():
        try:
            country = df.loc[id_to_epi[node.id], 'country']
        except KeyError:
            # Leaf id has no EPI_ISL accession, or the accession is absent from the metadata.
            n_leaves_without_country += 1
            if random.random() < MISSING_METADATA_PRINT_RATE:
                print(node.id, 'not found in metadata')
            continue
        node_counts[country] = 1
    else:
        for child in node.children:
            # Counter.update adds counts (it does not overwrite them).
            node_counts.update(country_counts[child.id])
print(f"{n_leaves_without_country} leaves had no metadata match")

100%|██████████| 17849624/17849624 [03:14<00:00, 91887.64it/s] 


In [14]:
# Consensus country per node: the most common country when it accounts for more
# than 90% of the node's samples. Otherwise "?", which is also used for nodes
# with no country data.
consensus_countries = {}
for node_id, counts in tqdm.tqdm(country_counts.items()):
    n_samples = sum(counts.values())
    label = "?"
    if n_samples:
        top_country, top_count = counts.most_common(1)[0]
        if top_count / n_samples > 0.9:
            label = top_country
    consensus_countries[node_id] = label

100%|██████████| 17849624/17849624 [00:29<00:00, 604578.60it/s]


In [15]:
# Ten most frequent consensus-country labels across all nodes
# ("?" = no >90% majority country, or no country data).
most_common_countries = Counter(consensus_countries.values()).most_common(n=10)
most_common_countries

[('USA', 5198531),
 ('United Kingdom', 3052804),
 ('?', 1644992),
 ('Germany', 1130459),
 ('Japan', 687789),
 ('Denmark', 682638),
 ('Canada', 597568),
 ('France', 558948),
 ('India', 272267),
 ('Brazil', 267783)]

In [16]:
# Free the per-node country Counters; only consensus_countries is used from here on.
del country_counts

In [17]:
# The country column is no longer needed once consensus countries are computed.
df.drop(columns='country', inplace=True)

In [18]:
# A year is a node's consensus only if more than this fraction of its dated samples share it.
YEAR_CONSENSUS_THRESHOLD = 0.99

# Per-node Counter of sample years, aggregated bottom-up. This relies on the
# reversed depth-first expansion visiting every child before its parent.
# Leaves with a missing year (no date, or no metadata match) contribute nothing,
# so a missing value can never be reported as a node's consensus year.
year_counts = defaultdict(Counter)
for node in tqdm.tqdm(tree.depth_first_expansion(reverse=True)):
    # Create the entry up front so every visited node (including the root) has one.
    node_counts = year_counts[node.id]
    if node.is_leaf():
        try:
            year = df.loc[id_to_epi[node.id], 'year']
        except KeyError:
            continue  # no EPI_ISL accession in the leaf id, or accession absent from the metadata
        if not pd.isna(year):
            node_counts[year] = 1
    else:
        for child in node.children:
            node_counts.update(year_counts[child.id])

# Consensus year: the most common year above the threshold, otherwise "?".
consensus_years = {}
for node_id, counts in tqdm.tqdm(year_counts.items()):
    total = sum(counts.values())
    if total == 0:
        consensus_years[node_id] = "?"
        continue
    most_common_year, count = counts.most_common(1)[0]
    if count / total > YEAR_CONSENSUS_THRESHOLD:
        consensus_years[node_id] = most_common_year
    else:
        consensus_years[node_id] = "?"

100%|██████████| 17849624/17849624 [03:35<00:00, 82781.41it/s] 
100%|██████████| 17849624/17849624 [00:32<00:00, 555230.33it/s]


In [19]:
# Free the per-node year Counters; only consensus_years is used from here on.
del year_counts


In [20]:
import tqdm
from collections import defaultdict
import random

# Probability of printing a leaf that has no metadata match (debug sampling); 0.0 disables it.
MISSING_AGE_PRINT_RATE = 0.0

# Per node: sum of the valid ages of its descendant leaves, and how many there are.
# Aggregated bottom-up; this relies on the reversed depth-first expansion
# visiting every child before its parent.
age_info = defaultdict(lambda: {'sum_ages': 0, 'count': 0})
n_unparsable_ages = 0

for node in tqdm.tqdm(tree.depth_first_expansion(reverse=True)):
    info = age_info[node.id]
    if node.is_leaf():
        try:
            raw_age = df.loc[id_to_epi[node.id], 'age']
        except KeyError:
            # No EPI_ISL accession in the leaf id, or accession absent from the metadata.
            if random.random() < MISSING_AGE_PRINT_RATE:
                print(node.id, 'not found in metadata')
            continue
        # '?' and empty cells (read as NaN) mean "unknown". A NaN must not be
        # summed, or it would turn every ancestor's average into NaN.
        if pd.isna(raw_age) or raw_age == "?":
            continue
        try:
            age = float(raw_age)
        except ValueError:
            n_unparsable_ages += 1
            continue
        if pd.isna(age):  # e.g. the literal string "nan"
            continue
        info['sum_ages'] = age
        info['count'] = 1
    else:
        for child in node.children:
            child_info = age_info[child.id]
            info['sum_ages'] += child_info['sum_ages']
            info['count'] += child_info['count']

if n_unparsable_ages:
    print(f"{n_unparsable_ages} leaves had a non-numeric age and were ignored")

# Mean age of descendant samples with a valid age; None when there are none.
average_age = {
    node_id: (info['sum_ages'] / info['count'] if info['count'] > 0 else None)
    for node_id, info in age_info.items()
}
del age_info

100%|██████████| 17849624/17849624 [02:18<00:00, 129161.84it/s]


In [21]:
# Free the metadata frame; all per-node metadata summaries have been computed above.
del df
# Translation is deferred until after all_nodes.tsv.gz is written:
#translations = tree.translate("/mnt/d/cov.gtf", "/mnt/d/ref.fasta")


In [22]:
def mut_to_class(s):
    """Collapse a substitution such as 'A123G' to its class 'A>G'.

    Every <char><digits><char> run in ``s`` is rewritten as '<char>><char>';
    text without such a run is returned unchanged.
    """
    return re.sub(r"(\w)(\d+)(\w)", lambda m: m.group(1) + ">" + m.group(3), s)


def get_mut_classes(muts):
    """Count substitution classes (e.g. 'C>T') among the mutation strings in ``muts``.

    Returns a Counter mapping each class produced by ``mut_to_class`` to its
    number of occurrences (absent classes read as 0).
    """
    return Counter(map(mut_to_class, muts))


def get_transition_percentage(mut_classes):
    """Return the fraction (0-1, not a percentage) of substitutions that are transitions.

    Transitions are the classes G>A, A>G, C>T and T>C.

    Args:
        mut_classes: mapping from class string (e.g. 'C>T') to count, such as
            the Counter returned by ``get_mut_classes``.

    Returns:
        transitions / total substitutions, or NaN when there are no
        substitutions (the ratio is undefined).
    """
    transitions = sum(mut_classes.get(k, 0) for k in ("G>A", "A>G", "C>T", "T>C"))
    total = sum(mut_classes.values())
    if total == 0:
        return float("nan")
    return transitions / total

# The 12 single-nucleotide substitution classes, ordered by reference base and
# then alternative base: A>C, A>G, A>T, C>A, ...
all_possible_classes = [
    f"{ref}>{alt}" for ref in "ACGT" for alt in "ACGT" if ref != alt
]

In [23]:
# First residue of each nsp within the orf1ab polyprotein (UniProt P0DTD1 numbering).
nsps = {
  'nsp1': 1,
  'nsp2': 181,
  'nsp3': 819,
  'nsp4': 2764,
  'nsp5 (Mpro)': 3264,
  'nsp6': 3570,
  'nsp7': 3860,
  'nsp8': 3943,
  'nsp9': 4141,
  'nsp10': 4254,
  'nsp12 (RdRp)': 4393,
  'nsp13': 5325,
  'nsp14': 5926,
  'nsp15': 6453,
  'nsp16': 6799
}

# nsp names ordered by start position, latest first: the first nsp whose start
# is <= a residue position is the nsp containing that residue.
nsp_names = sorted(nsps, key=nsps.get, reverse=True)


def orf1ab_to_nsp_str(aa_mut):
    """Re-express an orf1ab amino-acid change in nsp-local numbering.

    Example: orf1ab residue 181 (the first residue of nsp2) becomes 'nsp2:<orig>1<alt>'.

    Args:
        aa_mut: object with ``aa_index`` (assumed to be a 1-based orf1ab residue
            number; TODO confirm for bte translations), ``original_aa`` and
            ``alternative_aa``.

    Returns:
        '<nsp name>:<original_aa><index within nsp><alternative_aa>', or None
        when ``aa_index`` is below 1.
    """
    orf1ab_index = aa_mut.aa_index
    for nsp_name in nsp_names:
        nsp_start = nsps[nsp_name]
        # ">=" so that an nsp's first residue maps to that nsp (position 1)
        # rather than to the end of the previous one.
        if orf1ab_index >= nsp_start:
            new_index = orf1ab_index - nsp_start + 1
            return f"{nsp_name}:{aa_mut.original_aa}{new_index}{aa_mut.alternative_aa}"
    return None
    

In [24]:
# Unwrap the {column: {strain: value}} dict into {node label: predicted date}.
# The guard makes the cell safe to re-run; otherwise a second run raises KeyError.
if 'predicted_date' in dates_dict:
    dates_dict = dates_dict['predicted_date']

In [25]:
import xopen


# One row per node with:
#   - descendant count
#   - consensus country and year
#   - Chronumental date and branch length
#   - mean age
#   - counts of each substitution class among node.mutations
# Missing values are written as "?", except date_length, which uses -1.
headers = ["node_id", "num_descendants", "consensus_country", "consensus_year",
           "date", "date_length", "age"] + all_possible_classes

# The context manager closes (and flushes) the gzip file even if a lookup below fails.
with xopen.xopen("all_nodes.tsv.gz", "wt") as output:
    output.write("\t".join(headers) + "\n")
    for node in tqdm.tqdm(tree.depth_first_expansion()):
        date = dates_dict.get(node.id, "?")
        date_length = node_to_length.get(node.id, -1)
        age = average_age.get(node.id)
        if age is None:
            # Nodes with no valid descendant age get the same marker as other missing fields.
            age = "?"
        classes = get_mut_classes(node.mutations)
        row = [node.id, num_descendants[node.id], consensus_countries[node.id],
               consensus_years[node.id], date, date_length, age]
        row += [classes[k] for k in all_possible_classes]
        output.write("\t".join(str(x) for x in row) + "\n")

    



100%|██████████| 17849624/17849624 [05:27<00:00, 54537.24it/s]


In [26]:
# Free the per-node tables already written to all_nodes.tsv.gz before translating.
del dates_dict, node_to_length, consensus_countries, consensus_years, average_age
# Amino-acid translation of node mutations, from the gene annotation (GTF) and reference (FASTA).
# Used below as a mapping from node id to a list of objects with nt/aa mutation fields.
translations = tree.translate("/mnt/d/cov.gtf", "/mnt/d/ref.fasta")

In [27]:
# One row per translated mutation of every node that has any.
# The context manager closes (and flushes) the gzip file even if an attribute access fails.
with xopen.xopen("all_node_muts.tsv.gz", "wt") as output_mut:
    output_mut.write("\t".join(["node_id", "nt_index", "original_nt", "alternative_nt", "gene",
                                "aa_index", "original_aa", "alternative_aa", "is_synonymous",
                                "mutation_type", "aa_string"]) + "\n")
    for node in tqdm.tqdm(tree.depth_first_expansion()):
        aa_muts = translations[node.id] if node.id in translations else []
        for aa_mut in aa_muts:
            fields = [node.id, aa_mut.nt_index, aa_mut.original_nt, aa_mut.alternative_nt,
                      aa_mut.gene, aa_mut.aa_index, aa_mut.original_aa, aa_mut.alternative_aa,
                      aa_mut.is_synonymous(), aa_mut.mutation_type, aa_mut.aa_string()]
            output_mut.write("\t".join(str(x) for x in fields) + "\n")

100%|██████████| 17849624/17849624 [00:54<00:00, 326802.03it/s]


: 

In [28]:

# Edge list of the tree: one "child<TAB>parent" row per node that has a parent (i.e. every non-root node).
with xopen.xopen("parenthood.tsv.gz", "wt") as output:
    output.write("child\tparent\n")
    output.writelines(
        f"{node.id}\t{node.parent.id}\n"
        for node in tree.depth_first_expansion()
        if node.parent
    )
            
                 