In [2]:
import pandas as pd
# Read just the three columns needed below (strain, country, date),
# keeping every value as a string.
metadata = pd.read_csv(
    "~/Desktop/metadata.tsv.gz",
    sep="\t",
    usecols=["strain", "country", "date"],
    dtype=str,
)


In [3]:
def date_to_month(date):
    """Reduce a hyphen-separated date string to its first two components.

    Parameters
    ----------
    date : str
        A date such as "2021-03-15" or "2021-03".

    Returns
    -------
    str
        The first two hyphen-separated components joined by "-"
        (e.g. "2021-03"), or "?" when ``date`` is not a string (for example
        None or NaN standing in for a missing value) or has fewer than two
        components (e.g. "2021").
    """
    # Missing values may arrive as None/NaN rather than text; report them as
    # unknown instead of failing on .split().
    if not isinstance(date, str):
        return "?"
    components = date.split("-")
    if len(components) < 2:
        return "?"
    return "-".join(components[:2])
# Make every date a string so date_to_month can call .split() on each value.
metadata["date"] = metadata["date"].astype(str)
metadata["month"] = metadata["date"].map(date_to_month)
# Drop the rows for which no month could be derived ("?").
metadata = metadata.loc[metadata["month"].ne("?")]

In [4]:
# Index the table by strain name so later lookups are keyed by strain.
metadata = metadata.set_index("strain")

In [9]:
# Map each strain (the index of `metadata`) to a "country|month" label.
# Built with one vectorised string concatenation instead of iterrows(),
# which constructs a Series for every row and is slow on large tables.
# Rows without a country are skipped: no label can be formed for them.
# If a strain appears more than once, the last row wins, as in a plain loop.
_metadata_with_country = metadata.dropna(subset=["country"])
strain_to_countrymonth = (
    _metadata_with_country["country"] + "|" + _metadata_with_country["month"]
).to_dict()

In [20]:
# Path to the zstd-compressed FASTA alignment that the scanning functions below open.
alignment_file =  "/Users/theosanderson/Desktop/aligned.fasta.zst"
import zstandard as zstd
from Bio import SeqIO

import lzma, tqdm, random

from collections import defaultdict, Counter
# Estimated number of records, passed to tqdm as the progress-bar total.
# NOTE(review): the recorded runs report 14518753 iterations, so this
# estimate looks low — confirm and update if the bar should be accurate.
total = 7e6

def identify_changing_sites(sample_fraction=0.001, majority_threshold=0.95):
    """Find alignment positions where no single character clearly dominates.

    Streams the FASTA records from ``alignment_file`` (opened through
    ``zstd.open`` in text mode), keeps each record with probability
    ``sample_fraction`` to speed things up, and tallies the character seen at
    every position of the kept records.

    Parameters
    ----------
    sample_fraction : float, optional
        Probability with which each record is included in the tally
        (default 0.001, i.e. roughly 0.1% of the records).
    majority_threshold : float, optional
        A site is reported when its most common non-"N" character accounts
        for less than this fraction of the non-"N" observations
        (default 0.95).

    Returns
    -------
    list of int
        0-based positions of the changing sites, in the order the positions
        were first tallied.
    """
    site_counts = defaultdict(Counter)
    with zstd.open(alignment_file, "rt") as handle:
        for record in tqdm.tqdm(SeqIO.parse(handle, "fasta"), total=total):
            if random.random() < sample_fraction:
                # Convert to a plain string once instead of indexing the
                # sequence object position by position.
                for i, character in enumerate(str(record.seq)):
                    site_counts[i][character] += 1

    # For each site, ignore Ns, then check whether the most common character
    # makes up less than `majority_threshold` of what remains.
    changing_sites = []
    for site, counts in site_counts.items():
        called = [n for character, n in counts.items() if character != "N"]
        # Deliberately NOT named `total`: assigning to `total` anywhere in
        # this function would make it a local variable and the tqdm call
        # above would fail with UnboundLocalError.
        called_total = sum(called)
        top_count = max(called, default=0)
        if top_count < called_total * majority_threshold:
            changing_sites.append(site)
    return changing_sites


# Run the sampled scan of the alignment; the resulting site list is reused by later cells.
changing_sites = identify_changing_sites()

14518753it [11:05, 21803.06it/s]


In [22]:
# Keep only sites strictly between 134 and 29732,
# i.e. drop everything <= 134 and everything >= 29732.
changing_sites = [site for site in changing_sites if 134 < site < 29732]

In [25]:

def do_it():
    """Tally characters at the changing sites, grouped by country and month.

    Streams the FASTA records from ``alignment_file``. A record id is
    processed only the first time it is seen, and only if it has an entry in
    ``strain_to_countrymonth``. For each such record, the character at every
    position listed in ``changing_sites`` is counted.

    Returns
    -------
    defaultdict
        Nested mapping ``{country|month label: {site: Counter of characters}}``.
    """
    seen_ids = set()
    tallies = defaultdict(lambda: defaultdict(Counter))

    with zstd.open(alignment_file, "rt") as handle:
        records = SeqIO.parse(handle, "fasta")
        for record in tqdm.tqdm(records, total=7e6):
            strain = record.id
            # Later records with an already-seen id are ignored.
            if strain in seen_ids:
                continue
            seen_ids.add(strain)
            # Records without a known country/month label are not counted.
            if strain not in strain_to_countrymonth:
                continue
            group = strain_to_countrymonth[strain]
            for site in changing_sites:
                base = record.seq[site]
                tallies[group][site][base] += 1
    return tallies
                
# Full pass over the alignment: per country|month label, per-site character counts.
by_countrymonth = do_it()

14518753it [17:11, 14076.08it/s]                               


In [38]:
# All country|month labels, and the United Kingdom ones sorted as strings.
keys = list(by_countrymonth)
uk_keys = sorted(key for key in keys if key.startswith("United Kingdom"))

from collections import defaultdict

# Module-level containers; `states` is rebuilt for every month in the loop below.
states = dict()
top_value = dict()
num_transitions = defaultdict(int)

def prop_to_state(prop):
    """Bucket a proportion into "low", "medium" or "high".

    Values below 0.4 are "low", values below 0.6 are "medium", and anything
    else is "high".
    """
    # Upper bounds are exclusive and checked in ascending order.
    for upper_bound, label in ((0.4, "low"), (0.6, "medium")):
        if prop < upper_bound:
            return label
    return "high"


# Per-site memory carried from one month to the next: the last state and the
# last dominant character. A site seen for the first time compares against "".
cur_states = defaultdict(str)
cur_tops = defaultdict(str)
# Per site: how many times it entered the "high" state, and the details of each entry.
total_transitions = defaultdict(int)
transitions = defaultdict(list)
for key in uk_keys:
    counts = by_countrymonth[key]

    # Character proportions for every site, largest proportion first.
    props = {}
    for site, site_counter in counts.items():
        site_total = sum(site_counter.values())
        shares = [(char, count / site_total) for char, count in site_counter.items()]
        shares.sort(key=lambda pair: pair[1], reverse=True)
        props[site] = shares

    # Dominant character, its proportion, and the resulting state, per site.
    top_chars = {site: ranked[0][0] for site, ranked in props.items()}
    top_props = {site: ranked[0][1] for site, ranked in props.items()}
    states = {site: prop_to_state(top_prop) for site, top_prop in top_props.items()}

    for site, state in states.items():
        top_char = top_chars[site]
        # Nothing to do when both the state and the dominant character are
        # the same as in the previous month.
        if state == cur_states[site] and top_char == cur_tops[site]:
            continue
        # Only changes that end in the "high" state are recorded. The first
        # month a site is seen always differs from the "" defaults, so it is
        # recorded too when its state is "high".
        if state == "high":
            total_transitions[site] += 1
            transitions[site].append((key, state, top_char))
        cur_states[site] = state
        cur_tops[site] = top_char
            




In [39]:
# Rank the sites from most to fewest recorded transitions.
sorted(total_transitions.items(), key=lambda site_and_count: site_and_count[1], reverse=True)

[(21764, 6),
 (21765, 6),
 (21766, 6),
 (21767, 6),
 (21768, 6),
 (21769, 6),
 (28880, 6),
 (28881, 6),
 (28882, 6),
 (21991, 5),
 (21992, 5),
 (21993, 5),
 (11287, 4),
 (11288, 4),
 (11289, 4),
 (11290, 4),
 (11291, 4),
 (11292, 4),
 (11293, 4),
 (11294, 4),
 (11295, 4),
 (21617, 4),
 (21845, 4),
 (21986, 4),
 (22916, 4),
 (23062, 4),
 (23603, 4),
 (209, 3),
 (912, 3),
 (2831, 3),
 (3266, 3),
 (4180, 3),
 (5385, 3),
 (5387, 3),
 (5985, 3),
 (6401, 3),
 (6512, 3),
 (6513, 3),
 (6514, 3),
 (6953, 3),
 (7123, 3),
 (7850, 3),
 (8392, 3),
 (8985, 3),
 (9052, 3),
 (9865, 3),
 (11200, 3),
 (11284, 3),
 (11285, 3),
 (11286, 3),
 (11331, 3),
 (11536, 3),
 (13194, 3),
 (14675, 3),
 (15239, 3),
 (15278, 3),
 (15450, 3),
 (16175, 3),
 (16465, 3),
 (19219, 3),
 (21761, 3),
 (21987, 3),
 (21988, 3),
 (21989, 3),
 (21990, 3),
 (21994, 3),
 (22028, 3),
 (22029, 3),
 (22030, 3),
 (22031, 3),
 (22032, 3),
 (22033, 3),
 (22193, 3),
 (22194, 3),
 (22195, 3),
 (22672, 3),
 (22785, 3),
 (22897, 3),
 (23039

In [34]:
# Recorded entries into the "high" state for site 21769: (country|month label, state, top character).
transitions[21769]

[('United Kingdom|2020-01', 'high', 'G'),
 ('United Kingdom|2021-01', 'high', '-'),
 ('United Kingdom|2021-06', 'high', 'G'),
 ('United Kingdom|2022-01', 'high', '-'),
 ('United Kingdom|2022-03', 'high', 'G'),
 ('United Kingdom|2022-06', 'high', '-')]

In [None]:
28880

In [35]:
# Recorded entries into the "high" state for site 28880: (country|month label, state, top character).
transitions[28880]

[('United Kingdom|2020-01', 'high', 'G'),
 ('United Kingdom|2020-05', 'high', 'A'),
 ('United Kingdom|2020-09', 'high', 'G'),
 ('United Kingdom|2021-01', 'high', 'A'),
 ('United Kingdom|2021-06', 'high', 'T'),
 ('United Kingdom|2022-01', 'high', 'A')]