In [None]:
import pickle
import glob
import re
import os
import copy
import subprocess
import time
import datetime as dt
from datetime import datetime, timedelta, date
from collections import defaultdict, Counter, OrderedDict
from functools import reduce

import vcf
import pysam
import numpy as np
import pandas as pd
import dask.dataframe as dd
from Bio import SeqIO
from Bio.SeqUtils import seq1
from Bio.Seq import Seq

import random

In [None]:
import json
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.ticker as mtick
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

In [None]:
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
import matplotlib.colors as mcolors

In [None]:
import matplotlib.gridspec as gridspec

In [None]:
import matplotlib.dates as mdates

In [None]:
import vcf
from operator import add

In [None]:
from scipy.stats import fisher_exact

# Helper Functions

## Remove Problematic Sites

In [None]:
def filter_crykey_output(population_df):
    """
    Filter the merged Crykey output down to the cryptic lineages kept for analysis.

    Steps performed on ``merged_df.csv`` (read from the hard-coded Crykey output folder):
    1. upper-case the 'Site' column and keep only sites listed in population_df['WWTP']
       (NOTE(review): presumably this is what removes control samples -- confirm);
    2. parse the 'Date' column, stored as MMDDYYYY with the leading zero possibly
       dropped, into datetimes;
    3. drop every cryptic lineage with at least one mutation at a listed problematic
       site or within the first/last 100 bp of the 29,903 bp reference genome.

    The function only returns the filtered DataFrame; it does not write any file.
    NOTE(review): an earlier version of this docstring said the result is stored at
    fixed_results/crykey_wastewater_houston.csv -- presumably the caller saves it.

    Parameters
    ----------
    population_df : DataFrame with a 'WWTP' column of site identifiers.

    Returns
    -------
    DataFrame with the same columns (and original row labels) as the merged table.
    """
    output = "/home/Users/yl181/wastewater/quaid/output_incl_recombinant"

    sites = population_df['WWTP'].unique()
    merged_df = pd.read_csv(os.path.join(output, 'merged_df.csv'))

    merged_df['Site'] = merged_df['Site'].str.upper()
    # .copy() so the column assignment below does not write into a slice of the raw table.
    merged_df = merged_df[merged_df['Site'].isin(sites)].copy()

    merged_df['Date'] = pd.to_datetime(merged_df['Date'].astype(str).str.zfill(8), format='%m%d%Y')

    genome_length = 29903
    exclude_end_bp = 100
    # A set gives O(1) membership tests inside the per-mutation check.
    exclude_sites = {187, 1059, 2094, 3037, 3130, 6990, 8022, 10323, 10741, 11074, 13408,
                     14786, 19684, 20148, 21137, 24034, 24378, 25563, 26144, 26461, 26681,
                     28077, 28826, 28854, 29700}
    exclude_sites.update(range(1, exclude_end_bp + 1))
    exclude_sites.update(range(genome_length - exclude_end_bp + 1, genome_length + 1))

    def _is_clean(nt_mutations):
        # Mutations are written <ref><position><alt>, e.g. "C3037T".
        return not any(int(mut[1:-1]) in exclude_sites for mut in nt_mutations.split(";"))

    # Boolean masking keeps columns and dtypes even when no record survives.
    keep_mask = merged_df['Nt Mutations'].map(_is_clean).astype(bool)
    return merged_df[keep_mask]

## Coverage Calculation

In [None]:
def coverage_date_dataframe(coverage_dir):
    """
    Map every sample collection date found in ``coverage_dir`` to its week start.

    Each directory entry name is parsed as a three-character prefix followed by the
    collection date in MMDDYYYY form (the builders in this notebook use "HHD").
    Weeks are 'W-SUN' periods, so the week start is the Monday of that week.

    If two or more collection dates fall in the same week, that week and its dates
    are printed; nothing is removed from the result.

    Returns
    -------
    DataFrame with columns 'Collection Date' (sorted ascending) and 'Week Start'.
    """
    coverage_dates = sorted(
        pd.to_datetime(entry[3:], format='%m%d%Y') for entry in os.listdir(coverage_dir)
    )
    coverage_df = pd.DataFrame(coverage_dates, columns=['Collection Date'])
    coverage_df['Week Start'] = coverage_df['Collection Date'].dt.to_period('W-SUN').apply(lambda r: r.start_time)

    # Plain comparison instead of assert/except: the report must not depend on
    # assertions being enabled.
    if len(coverage_df['Collection Date'].unique()) != len(coverage_df['Week Start'].unique()):
        weekstart2collection = defaultdict(list)
        for week_start, collection_date in zip(coverage_df['Week Start'], coverage_df['Collection Date']):
            weekstart2collection[week_start].append(collection_date)
        for week_start, collection_dates in weekstart2collection.items():
            if len(collection_dates) != 1:
                print(week_start, collection_dates)
    return coverage_df

In [None]:
def get_missing_weeks(coverage_dir):
    """
    Find the weeks without any sample between the first and the last sampled week.

    Prints the total number of weeks in that span and how many of them are missing.

    Returns
    -------
    (missing_weeks, coverage_df)
        missing_weeks : list of week-start timestamps that have no collection date.
        coverage_df   : the collection-date / week-start table produced by
                        coverage_date_dataframe(coverage_dir).
    """
    coverage_df = coverage_date_dataframe(coverage_dir)

    start_date = coverage_df['Week Start'].min()
    # Built once so the loop below does not recompute the unique weeks on each pass.
    sampled_weeks = set(coverage_df['Week Start'])

    total_week_count = int((coverage_df['Week Start'].max() - start_date).days/7) + 1
    missing_week_count = total_week_count - len(sampled_weeks)
    print("Total Weeks:\t", total_week_count)
    print("Missing Weeks:\t", missing_week_count)

    days_in_a_week = timedelta(days = 7)
    candidate_weeks = (start_date + i*days_in_a_week for i in range(total_week_count))
    missing_weeks = [week for week in candidate_weeks if week not in sampled_weeks]

    return missing_weeks, coverage_df

In [None]:
def build_coverage_dataframe(coverage_date_df, population_df, missing_weeks, coverage_dir=None):
    """
    Build one wide table of per-position sequencing depth for every (date, site) sample.

    For each collection date and each site in population_df['WWTP'], the file
    <coverage_dir>/HHD<MMDDYYYY>/<site>-1/<site>-1.clean.sorted.coverage.txt is read
    (tab-separated, no header; the second and third columns are used as position and
    depth). Positions become columns; positions absent from a sample are filled with 0.
    Each missing week contributes one placeholder row for site 'SS' with depth 0.

    Parameters
    ----------
    coverage_date_df : DataFrame with a 'Collection Date' datetime column.
    population_df    : DataFrame with a 'WWTP' column of site identifiers.
    missing_weeks    : iterable of week-start timestamps without any sample.
    coverage_dir     : root folder of the coverage files. When omitted, the
                       notebook-level variable ``coverage_dir`` is used (legacy behaviour).

    Returns
    -------
    DataFrame with columns 'Collection Date', 'Site' and one integer-named column per
    genome position, sorted by date and site.
    """
    if coverage_dir is None:
        # Backward compatibility: this function used to read a global of the same name.
        if 'coverage_dir' not in globals():
            raise NameError("coverage_dir is not defined; pass coverage_dir=... or set the notebook-level variable")
        coverage_dir = globals()['coverage_dir']

    sites = list(population_df['WWTP'].unique())
    record_list = []
    index_list = []
    count = 0
    for collection_date in coverage_date_df['Collection Date']:
        date_dir = "HHD" + collection_date.strftime('%m%d%Y')
        for site in sites:
            coverage_f = os.path.join(coverage_dir, date_dir, f"{site}-1", f"{site}-1.clean.sorted.coverage.txt")
            if not os.path.exists(coverage_f):
                count += 1
                continue
            try:
                temp_df = pd.read_csv(coverage_f, header=None, sep='\t', usecols = [1,2], names=['POS', 'DEP'])
                if temp_df.empty:
                    count += 1
                    continue
                temp_df = pd.pivot_table(temp_df, index=['POS'], values=['DEP']).transpose()
                record = temp_df.iloc[0].to_dict()
            except Exception:
                # Unreadable files are reported and skipped (best effort), but
                # KeyboardInterrupt / SystemExit are no longer swallowed.
                print(coverage_f)
                continue
            # Append to both lists together so they can never get out of step.
            index_list.append((collection_date, site))
            record_list.append(record)
    # A missing week lacks a sample for every site (was a hard-coded 39).
    print("Missing Samples:", count + len(missing_weeks)*len(sites))

    for missing_week in missing_weeks:
        record_list.append({0:0})
        index_list.append((missing_week, 'SS'))

    index = pd.MultiIndex.from_tuples(index_list, names=["Collection Date", "Site"])
    coverage_df = pd.DataFrame(record_list, index=index, dtype=pd.Int64Dtype())
    positions = sorted(column for column in coverage_df.columns.to_list() if isinstance(column, int))
    coverage_df = coverage_df[positions].fillna(0)
    coverage_df = coverage_df.reset_index()
    coverage_df = coverage_df.sort_values(['Collection Date', 'Site'], axis=0)
    return coverage_df

In [None]:
def build_has_sample_dataframe(coverage_date_df, population_df, missing_weeks, coverage_dir=None):
    """
    Record, for every (collection date, site) pair, whether a coverage file exists.

    A sample counts as present when
    <coverage_dir>/HHD<MMDDYYYY>/<site>-1/<site>-1.clean.sorted.coverage.txt exists and
    is not empty. Every site is marked as missing for each week in ``missing_weeks``.

    NOTE: a function with the same name but a four-argument signature is defined again
    further down in this notebook and replaces this one once that cell has run.

    Parameters
    ----------
    coverage_date_df : DataFrame with a 'Collection Date' datetime column.
    population_df    : DataFrame with a 'WWTP' column of site identifiers.
    missing_weeks    : iterable of week-start timestamps without any sample.
    coverage_dir     : root folder of the coverage files. When omitted, the
                       notebook-level variable ``coverage_dir`` is used (legacy behaviour).

    Returns
    -------
    DataFrame with columns 'Collection Date', 'Site', 'Has Sample' (nullable boolean),
    sorted by date and site.
    """
    if coverage_dir is None:
        # Backward compatibility: this function used to read a global of the same name.
        if 'coverage_dir' not in globals():
            raise NameError("coverage_dir is not defined; pass coverage_dir=... or set the notebook-level variable")
        coverage_dir = globals()['coverage_dir']

    sites = list(population_df['WWTP'].unique())
    record_list = []
    index_list = []
    for collection_date in coverage_date_df['Collection Date']:
        date_dir = "HHD" + collection_date.strftime('%m%d%Y')
        for site in sites:
            coverage_f = os.path.join(coverage_dir, date_dir, f"{site}-1", f"{site}-1.clean.sorted.coverage.txt")
            has_sample = os.path.exists(coverage_f) and os.path.getsize(coverage_f) > 0
            index_list.append((collection_date, site))
            record_list.append({'Has Sample': has_sample})

    for missing_week in missing_weeks:
        for site in sites:
            index_list.append((missing_week, site))
            record_list.append({'Has Sample': False})

    index = pd.MultiIndex.from_tuples(index_list, names=["Collection Date", "Site"])
    has_sample_df = pd.DataFrame(record_list, index=index, dtype='boolean')
    has_sample_df = has_sample_df.reset_index()
    has_sample_df = has_sample_df.sort_values(['Collection Date', 'Site'], axis=0)

    return has_sample_df

In [None]:
def run_build_coverage_dataframe():
    """
    Rebuild the cached coverage tables and write them into ``fixed_results_dir``.

    Writes 'houston_coverage.csv' (per-position depth) and 'houston_has_sample_df.csv'
    (sample presence). Relies on the notebook-level ``population_df`` and
    ``fixed_results_dir`` being defined before it is called.
    """
    # The builder functions look up `coverage_dir` in the notebook namespace, so it
    # has to be a global here; as a plain local it was invisible to them.
    global coverage_dir
    coverage_dir = '/home/Users/yl181/wastewater/processed_data/Coverage'
    missing_weeks, coverage_date_df = get_missing_weeks(coverage_dir)
    coverage_df = build_coverage_dataframe(coverage_date_df, population_df, missing_weeks)
    # The four-argument build_has_sample_dataframe accepts `vcf_dir` but never reads it,
    # and no `vcf_dir` is defined in the visible part of this notebook; fall back to
    # None instead of raising NameError when it is absent.
    has_sample_df = build_has_sample_dataframe(coverage_date_df, globals().get('vcf_dir'), population_df, missing_weeks)
    coverage_df.to_csv(os.path.join(fixed_results_dir, 'houston_coverage.csv'))
    has_sample_df.to_csv(os.path.join(fixed_results_dir, 'houston_has_sample_df.csv'))

In [None]:
# Denominator for all GISAID rarity computations in this notebook.
# NOTE(review): presumably the number of sequences in the GISAID snapshot that was
# queried -- confirm and record the snapshot date.
gisaid_total_count = 12988494

In [None]:
# Rarity cut-off: 0.01% of gisaid_total_count. filter_cryptic_lineage_df() computes
# the same value again internally.
max_gisaid_occurance = 0.0001 * gisaid_total_count
max_gisaid_occurance

# Load Crykey Output

In [None]:
def filter_cryptic_lineage_df(mutation_rarity_df, merged_df, gisaid_total_count, min_site=2, max_gisaid_fraction=0.0001):
    """
    Select rare, repeatedly detected cryptic lineages and summarise their detections.

    A lineage is kept when its 'Site' count is at least ``min_site`` and its
    'GISAID Count' is at most ``max_gisaid_fraction * gisaid_total_count``.

    Parameters
    ----------
    mutation_rarity_df  : DataFrame indexed by the 'Nt Mutations' string, with columns
                          'Site' (detection count) and 'GISAID Count'.
    merged_df           : detection table with columns 'Nt Mutations', 'Date' (datetime),
                          'Site', 'Combined Freq' and 'AA Mutations'.
    gisaid_total_count  : total number of GISAID sequences (rarity denominator).
    min_site            : minimum detection count (default 2, the previous constant).
    max_gisaid_fraction : maximum share of GISAID sequences (default 0.0001).

    Returns
    -------
    Copy of the selected rows with 'Site' renamed to 'Mutation Occurance' plus the
    columns 'AA_Mutation', 'Present Weeks', 'Present Sites', 'Mean Site Occurance',
    'Mean Allele Freq', 'Max Allele Freq' and 'Max Durations' (days between first and
    last detection, inclusive), sorted by occurrence and GISAID count, descending.

    Raises
    ------
    KeyError if a selected lineage has no row in ``merged_df``.
    """
    max_gisaid_occurance = max_gisaid_fraction * gisaid_total_count
    print('Minumum Detected Site:', "\t", min_site)
    print('Max GISAID Occurance:', "\t", max_gisaid_occurance)

    is_selected = (mutation_rarity_df['Site'] >= min_site) & (mutation_rarity_df['GISAID Count'] <= max_gisaid_occurance)
    selected_data = mutation_rarity_df[is_selected].copy()
    selected_data = selected_data.rename({'Site': 'Mutation Occurance'}, axis=1)
    print('# the cryptic lineage after filtering:', "\t", len(selected_data))

    # Split the detection table once instead of re-scanning it for every lineage.
    relevant = merged_df[merged_df['Nt Mutations'].isin(selected_data.index)]
    detections = {key: frame for key, frame in relevant.groupby('Nt Mutations')}

    present_weeks = []
    present_sites = []
    mean_freqs = []
    max_freqs = []
    max_durations = []
    aa_mutations = []

    for nt_mutations in selected_data.index:
        if nt_mutations not in detections:
            raise KeyError(f"no detection records in merged_df for {nt_mutations!r}")
        mutation_df = detections[nt_mutations]
        present_weeks.append(len(mutation_df['Date'].unique()))
        present_sites.append(len(mutation_df['Site'].unique()))
        mean_freqs.append(mutation_df['Combined Freq'].mean())
        max_freqs.append(mutation_df['Combined Freq'].max())
        # Timedelta.days does not depend on whether .unique() returns a numpy array or
        # a pandas DatetimeArray (pandas >= 2), unlike the earlier .astype() call.
        max_durations.append((mutation_df['Date'].max() - mutation_df['Date'].min()).days + 1)
        aa_mutations.append(mutation_df['AA Mutations'].unique()[0])

    selected_data['AA_Mutation'] = aa_mutations
    selected_data['Present Weeks'] = present_weeks
    selected_data['Present Sites'] = present_sites
    selected_data['Mean Site Occurance'] = selected_data['Mutation Occurance']/selected_data['Present Weeks']
    selected_data['Mean Allele Freq'] = mean_freqs
    selected_data['Max Allele Freq'] = max_freqs
    selected_data['Max Durations'] = max_durations

    selected_data = selected_data.sort_values(by=['Mutation Occurance', 'GISAID Count'], ascending=False)

    return selected_data

In [None]:
def build_gisaid_rarity_dataframe(fixed_results_dir, site_counts=None):
    """
    Join the per-lineage detection counts with their total GISAID occurrence counts.

    Reads ``query_result.pkl`` from ``fixed_results_dir``: a dict that maps each
    lineage key to a dict of counts, whose values are summed into 'GISAID Count'.

    Parameters
    ----------
    fixed_results_dir : folder containing 'query_result.pkl'.
    site_counts       : DataFrame indexed by lineage key. When omitted, the
                        notebook-level ``site_count_df`` is used (legacy behaviour).

    Returns
    -------
    DataFrame: ``site_counts`` merged one-to-one (on the index) with 'GISAID Count'.
    """
    if site_counts is None:
        site_counts = site_count_df

    # SECURITY: pickle.load can execute arbitrary code; only load files produced by
    # this project.
    with open(os.path.join(fixed_results_dir, "query_result.pkl"), "rb") as file_to_read:
        query_result = pickle.load(file_to_read)

    count_rarity_dict = {key: sum(counts.values()) for key, counts in query_result.items()}
    gisaid_count_df = pd.DataFrame.from_dict(count_rarity_dict, columns=['GISAID Count'], orient='index')
    mutation_rarity_df = site_counts.merge(gisaid_count_df, left_index=True, right_index=True, validate='one_to_one')

    return mutation_rarity_df

In [None]:
# Folder holding the cached intermediate tables (CSV / pickle) read by the cells below.
# NOTE(review): absolute, user-specific path -- adjust when running elsewhere.
fixed_results_dir = "/home/Users/yl181/wastewater/quarc_figures/fixed_results"

## Samtools Coverage DataFrame

In [None]:
%%time
# Reload the cached per-position depth table ('houston_coverage.csv', as written by
# run_build_coverage_dataframe). The file was saved with its row index, which comes
# back as the 'Unnamed: 0' column and is discarded here.
coverage_df = (
    pd.read_csv(os.path.join(fixed_results_dir, 'houston_coverage.csv'), index_col=False)
    .drop(['Unnamed: 0'], axis=1)
)
coverage_df['Collection Date'] = pd.to_datetime(
    coverage_df['Collection Date'].apply(str), format='%Y-%m-%d'
)

## Has Sample DataFrame

In [None]:
def build_has_sample_dataframe(coverage_date_df, vcf_dir, population_df, missing_weeks, coverage_dir=None):
    """
    Record, for every (collection date, site) pair, whether a coverage file exists.

    A sample counts as present when
    <coverage_dir>/HHD<MMDDYYYY>/<site>-1/<site>-1.clean.sorted.coverage.txt exists and
    is not empty. Every site is marked as missing for each week in ``missing_weeks``.

    NOTE: this redefines (and replaces) the three-argument function of the same name
    defined earlier in this notebook.

    Parameters
    ----------
    coverage_date_df : DataFrame with a 'Collection Date' datetime column.
    vcf_dir          : unused; kept only so existing positional calls keep working.
    population_df    : DataFrame with a 'WWTP' column of site identifiers.
    missing_weeks    : iterable of week-start timestamps without any sample.
    coverage_dir     : root folder of the coverage files. When omitted, the
                       notebook-level variable ``coverage_dir`` is used (legacy behaviour).

    Returns
    -------
    DataFrame with columns 'Collection Date', 'Site', 'Has Sample' (nullable boolean),
    sorted by date and site.
    """
    if coverage_dir is None:
        # Backward compatibility: this function used to read a global of the same name.
        if 'coverage_dir' not in globals():
            raise NameError("coverage_dir is not defined; pass coverage_dir=... or set the notebook-level variable")
        coverage_dir = globals()['coverage_dir']

    sites = list(population_df['WWTP'].unique())
    record_list = []
    index_list = []
    for collection_date in coverage_date_df['Collection Date']:
        date_dir = "HHD" + collection_date.strftime('%m%d%Y')
        for site in sites:
            coverage_f = os.path.join(coverage_dir, date_dir, f"{site}-1", f"{site}-1.clean.sorted.coverage.txt")
            has_sample = os.path.exists(coverage_f) and os.path.getsize(coverage_f) > 0
            index_list.append((collection_date, site))
            record_list.append({'Has Sample': has_sample})

    for missing_week in missing_weeks:
        for site in sites:
            index_list.append((missing_week, site))
            record_list.append({'Has Sample': False})

    index = pd.MultiIndex.from_tuples(index_list, names=["Collection Date", "Site"])
    has_sample_df = pd.DataFrame(record_list, index=index, dtype='boolean')
    has_sample_df = has_sample_df.reset_index()
    has_sample_df = has_sample_df.sort_values(['Collection Date', 'Site'], axis=0)

    return has_sample_df

In [None]:
# Reload the cached sample-presence table ('houston_has_sample_df.csv', as written by
# run_build_coverage_dataframe); the saved row index comes back as 'Unnamed: 0' and is
# discarded here.
has_sample_df = (
    pd.read_csv(os.path.join(fixed_results_dir, 'houston_has_sample_df.csv'), index_col=False)
    .drop(['Unnamed: 0'], axis=1)
)
has_sample_df['Collection Date'] = pd.to_datetime(
    has_sample_df['Collection Date'].apply(str), format='%Y-%m-%d'
)

## Merged DataFrame

In [None]:
# Load the filtered Crykey detections; the first CSV column is the saved row index.
merged_df = pd.read_csv(
    os.path.join(fixed_results_dir, 'crykey_wastewater_houston.csv'),
    index_col=0,
)
merged_df['Date'] = pd.to_datetime(merged_df['Date'].apply(str), format='%Y-%m-%d')
# Monday of the week each sample was collected in ('W-SUN' weeks end on Sunday).
merged_df['Week Start'] = merged_df['Date'].dt.to_period('W-SUN').apply(lambda period: period.start_time)
print("# the cryptic lineage before filtering:", merged_df['Nt Mutations'].unique().size)

In [None]:
# Number of detection records (non-null 'Site' entries) per lineage. This counts
# rows, not distinct sites. Read as a global by build_gisaid_rarity_dataframe().
site_count_df = pd.DataFrame(merged_df.groupby('Nt Mutations')['Site'].count())

In [None]:
"""
Used by
    Supplementary Figure 1
    Supplementary Figure 5
"""

# Per-lineage detection count joined with its GISAID count; needs `site_count_df`
# from the previous cell.
mutation_rarity_df = build_gisaid_rarity_dataframe(fixed_results_dir)

In [None]:
# Inspect the joined detection-count / GISAID-count table.
mutation_rarity_df

In [None]:
# Rare, repeatedly detected lineages with their per-lineage summary statistics;
# the input table for Figures 3 and 4.
filtered_cryptic_df = filter_cryptic_lineage_df(mutation_rarity_df, merged_df, gisaid_total_count)

# Load Viral Load and Population

In [None]:
# Viral load time series; rows with any missing value are dropped before parsing dates.
viral_load_df = pd.read_csv(os.path.join(fixed_results_dir, 'houston_viral_load.csv')).dropna()
viral_load_df['date']= pd.to_datetime(viral_load_df['date'])

In [None]:
# Site table: keep only the site identifier ('WWTP') and its 'pop' column.
# Rows with any missing value are dropped first. Several helper functions above read
# `population_df` as a global, so this cell must run before they are called.
population_df = pd.read_csv(os.path.join(fixed_results_dir, 'population.csv')).dropna()
population_df = population_df[['WWTP', 'pop']]

In [None]:
# Site identifiers in file order (not de-duplicated, unlike the .unique() calls in the
# helper functions).
valid_WWTPs = list(population_df['WWTP'].values)

# Figure 1 - Workflow

In [None]:
# Figure 1 is not generated here; the cell only echoes a link.
# NOTE(review): presumably the linked file is the workflow diagram -- confirm.
'https://drive.google.com/file/d/1zxmsrMhb2vXBJgy-yQiQkBlpO-9FjK8J/view?usp=sharing'

# Figure 3 - Genomic Distribution of CLs in Wastewater Samples

In [None]:
def add_colorbar(mappable):
    """
    Attach a slim colorbar to the right of the axes that owns ``mappable``.

    The colorbar axes is carved out of the host axes (3% of its width), labelled
    'Rarity in GISAID' using the notebook-level ``fontsize``, and the axes that was
    current on entry is made current again before returning.

    Returns the created colorbar.
    """
    previously_active = plt.gca()
    host_ax = mappable.axes
    colorbar_ax = make_axes_locatable(host_ax).append_axes("right", size="3%", pad=0.05)
    cbar = host_ax.figure.colorbar(mappable, cax=colorbar_ax)
    cbar.ax.get_yaxis().labelpad = 5
    cbar.ax.set_ylabel('Rarity in GISAID', rotation=90, fontsize=fontsize)
    plt.sca(previously_active)
    return cbar

In [None]:
def synonymous_mutations(filtered_cryptic_df):
    """
    Flag lineages whose amino-acid changes are all non-synonymous.

    Works on a copy of the input in which 'AA_Mutation' is renamed to 'AA Mutation'.
    Each entry is a ';'-separated list of '<gene>:<change>' items; a change whose first
    and last characters are equal is counted as synonymous. The added boolean column
    'Non-synonyms Only' is True when no item of the lineage is synonymous.
    """
    def _has_no_synonymous_change(aa_mutation_set):
        changes = [aa_mut.split(":")[-1] for aa_mut in aa_mutation_set.split(';')]
        # Every change is inspected (no short-circuit), as in a plain loop.
        unchanged = sum(1 for change in changes if change[0] == change[-1])
        return unchanged == 0

    selected_data = filtered_cryptic_df.copy().rename({'AA_Mutation': 'AA Mutation'}, axis=1)
    selected_data['Non-synonyms Only'] = [
        _has_no_synonymous_change(aa_mutation_set)
        for aa_mutation_set in selected_data['AA Mutation']
    ]

    return selected_data

In [None]:
def get_gene_label(selected_data):
    """
    Annotate each cryptic lineage with its gene(s) and mean genome position.

    ``selected_data`` must be indexed by the ';'-separated nucleotide mutation string
    (each item written <ref><position><alt>) and contain the columns 'AA Mutation'
    (';'-separated '<gene>:<change>' items) and 'Mean Allele Freq'.

    Adds two columns to ``selected_data`` IN PLACE and returns the same object:
    - 'Gene': the distinct genes of the lineage joined with ';', in alphabetical order
      so the label does not depend on set iteration order;
    - 'Nt Location': the mean of the mutation positions, truncated to an integer.

    Also prints summary counts and fractions (S gene, N gene, mean allele frequency
    below / at least 0.5). Fractions are printed as nan for an empty table.
    """
    gene_list = []
    snp_locations = []
    s_gene_count = 0
    n_gene_count = 0

    for nt_mutations, aa_mutations in zip(selected_data.index, selected_data['AA Mutation']):
        positions = [int(snp[1:-1]) for snp in nt_mutations.split(";")]
        snp_locations.append(int(np.array(positions).mean()))

        gene_set = {aa_mut.split(':')[0] for aa_mut in aa_mutations.split(';')}
        gene_list.append(";".join(sorted(gene_set)))

        if 'S' in gene_set:
            s_gene_count += 1
        if 'N' in gene_set:
            n_gene_count += 1

    selected_data['Gene'] = gene_list
    selected_data['Nt Location'] = snp_locations

    total = selected_data.shape[0]

    def _fraction(count):
        # Avoid ZeroDivisionError when no lineage passed the filters.
        return count/total if total else float('nan')

    minor_count = selected_data[selected_data['Mean Allele Freq']<0.5].shape[0]
    consensus_count = selected_data[selected_data['Mean Allele Freq']>=0.5].shape[0]

    print('Total CL #:', '\t', total)

    print("S Gene CL #:", '\t', s_gene_count)
    print("S Gene CL %:", '\t', _fraction(s_gene_count))
    print("N Gene CL #:", '\t', n_gene_count)
    print("N Gene CL %:", '\t', _fraction(n_gene_count))

    print("Minor AF (<0.5) CL #:", '\t', minor_count)
    print("Minor AF (<0.5) CL %:", '\t', _fraction(minor_count))

    print("Consensus AF (>=0.5) CL %:", '\t', _fraction(consensus_count))

    return selected_data

In [None]:
# GenBank record of the reference genome, read from the current working directory
# (relative path). Its 'gene' features are drawn as the gene map in Figure 3.
reference = SeqIO.read("SARS-CoV-2-reference.gb", "genbank")

In [None]:
# Figure 3 input: add the 'Non-synonyms Only' flag, then the 'Gene' and 'Nt Location'
# columns (get_gene_label also prints summary statistics).
selected_data = synonymous_mutations(filtered_cryptic_df)
selected_data = get_gene_label(selected_data)

In [None]:
# Fraction of the selected lineages whose mean allele frequency is below 0.2.
len(selected_data[selected_data['Mean Allele Freq'] < 0.2])/selected_data.shape[0]

In [None]:
# Figure 3: three stacked panels sharing the genome-position x axis:
#   axes[0] gene map, axes[1] scatter of mean allele frequency, axes[2] binned counts.
fig, axes  = plt.subplots(3, 1, figsize=(14, 10),
                          sharex=True,
                          gridspec_kw={'height_ratios': [1, 2.5, 1], 'hspace':0.1})

# `fontsize` is also read as a global by add_colorbar().
fontsize = 12
alpha = 0.5

cmap = mpl.cm.inferno_r

# Width, in nucleotides, of the histogram bins of the bottom panel.
bin_width = 400

# Replace each lineage's mean position by the centre of the bin it falls in.
# NOTE: this adds a 'Bin Index' column to `selected_data` in place.
selected_data['Bin Index'] = selected_data['Nt Location']/bin_width
selected_data['Bin Index'] = selected_data['Bin Index'].astype(int)
selected_data['Bin Index'] = selected_data['Bin Index'] * bin_width + bin_width/2

# Lineages per bin (every column holds the group size; 'AA Mutation' is used).
hist_df = pd.DataFrame(selected_data.groupby(['Bin Index']).count())

bar_x = hist_df.index.to_list()
bar_y1 = hist_df['AA Mutation'].values

# Lineages per bin whose changes are all non-synonymous (sum of the boolean flag).
hist_df = pd.DataFrame(selected_data.groupby(['Bin Index'])['Non-synonyms Only'].sum())
bar_y2 = hist_df['Non-synonyms Only'].values

# ---- Middle panel: mean allele frequency along the genome ----
ax = axes[1]

# Colour encodes -log10 of the GISAID share; +1 avoids log10(0) for unseen lineages.
rarity = -np.log10((selected_data['GISAID Count'].values+1)/gisaid_total_count)
# Marker area is 10x the number of weeks in which the lineage was detected.
main_plot = ax.scatter(x=selected_data['Nt Location'].values, 
                       y=selected_data['Mean Allele Freq'].values, 
                       c=rarity, 
                       s=selected_data['Present Weeks'].values*10,
                       cmap=cmap,
                       alpha=alpha)

ax.set_xlim(0, 29903+1)
#ax.set_xticks(np.arange(0, 29903+1, 5000))

ax.set_ylim(0,1.1)
ax.set_yticks(np.arange(0, 11, 2)/10)
ax.set_ylabel('Mean AF of CRs', fontsize=fontsize)

ax.xaxis.set_tick_params(labelsize=fontsize)
ax.yaxis.set_tick_params(labelsize=fontsize)

add_colorbar(main_plot)

# Size legend: `func` undoes the x10 scaling so labels read as weeks.
handles, labels = main_plot.legend_elements(prop="sizes", alpha=alpha, num=4, func = lambda x: x/10)
legend = ax.legend(handles, labels, loc="upper left", title="Weeks of detection", ncol=4, framealpha=1, borderpad=0.6)#, bbox_to_anchor=(0.5,1))
legend.get_frame().set_facecolor('none')

# Panel letter 'a' is attached to this scatter panel.
ax.text(-0.07, 1.05, 'a', transform=ax.transAxes,
            fontsize=20, fontweight='bold', va='top')

# ---- Top panel: gene map drawn from the GenBank 'gene' features ----
ax = axes[0]
colors = plt.get_cmap('Set3').colors

gene_label_h = 0.5
gene_label_y = -gene_label_h/2
# Signed vertical offset of the gene label; its sign is flipped after every gene so
# consecutive labels alternate between the two sides of the gene track.
label_top = gene_label_h/2*1.5
counter = 0

for seq_feature in reference.features:
    if seq_feature.type == 'gene':
        
        gene_name = seq_feature.qualifiers['gene'][0]
        
        # One coloured box per gene, spanning its start..end on the genome axis.
        rectangle = Rectangle((seq_feature.location.start, gene_label_y), seq_feature.location.end-seq_feature.location.start, gene_label_h,
                              fill=True,
                              facecolor=colors[counter],
                              edgecolor='black')
        ax.add_patch(rectangle)
        
        if counter != 0:
            rx, ry = rectangle.get_xy()
            cx = rx + rectangle.get_width()/2.0
            cy = ry + rectangle.get_height()/2.0 + label_top
            if label_top > 0:
                va = 'top'
            else:
                va = 'bottom'
            ax.annotate(gene_name, (cx, cy), color='black', weight='normal', fontsize=fontsize-2, ha='center', va=va, rotation=90)
        else:
            # The first gene is special-cased: offset subtracted, always va='bottom'.
            rx, ry = rectangle.get_xy()
            cx = rx + rectangle.get_width()/2.0
            cy = ry + rectangle.get_height()/2.0 - label_top
            ax.annotate(gene_name, (cx, cy), color='black', weight='normal', fontsize=fontsize-2, ha='center', va='bottom', rotation=90)
        
        counter += 1
        label_top = label_top * -1
        

# NOTE(review): `label_top` keeps the sign left by the loop, so with an odd number of
# gene features the limits are passed in descending order, which inverts the y axis.
# Confirm that this is intended.
ax.set_ylim(-label_top*3,label_top*3)
# NOTE(review): `comb_pos` is not used anywhere in this cell.
comb_pos = -2.5

# An invisible side axes is appended here too; presumably to keep this panel the same
# width as the scatter panel, which carries a colorbar -- confirm.
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="3%", pad=0.05)

ax.axis('off')
cax.axis('off')

# ---- Bottom panel: number of lineages per genome bin ----
ax = axes[2]

ax.bar(bar_x, bar_y1, width=bin_width*0.85, color="black", alpha=0.6, label='Including Synonymous Mutations')
ax.bar(bar_x, bar_y2, width=bin_width*0.85, color="C1", alpha=1, label='Non-synonymous Mutations Only')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="3%", pad=0.05)

#ax.axis('off')
cax.axis('off')
ax.set_xticks(np.arange(0, 29903+1, 5000))
ax.set_xlabel('Position on Reference Genome', fontsize=fontsize)
ax.set_ylim(0,100)
ax.set_ylabel('CR Count', fontsize=fontsize)

ax.xaxis.set_tick_params(labelsize=fontsize)
ax.yaxis.set_tick_params(labelsize=fontsize)

ax.legend(loc='upper left', borderpad=0.6)
# Panel letter 'b' is attached to this histogram panel.
ax.text(-0.07, 1.05, 'b', transform=ax.transAxes,
            fontsize=20, fontweight='bold', va='top')
#fig.show()

fig.savefig('/home/Users/yl181/wastewater/quarc_figures/pdf/figure_3.pdf', transparent=False, facecolor='white', dpi=300, bbox_inches = "tight")

# Figure 4 - Longitudinal Distribution of CLs in Wastewater Samples

In [None]:
def get_first_occ_df(merged_df, filtered_cryptic_df):
    """
    Attach the earliest detection date to every filtered cryptic lineage.

    ``merged_df`` is sorted by 'Date' and reduced to the first row per 'Nt Mutations';
    that date is joined onto ``filtered_cryptic_df`` (indexed by the mutation string).
    The result is indexed by 'Nt Mutations' and gains a 'Date' column.
    """
    chronological = merged_df.sort_values('Date')
    first_detection = chronological[['Date', 'Nt Mutations']].drop_duplicates('Nt Mutations')
    joined = filtered_cryptic_df.merge(first_detection, left_index=True, right_on='Nt Mutations')
    return joined.set_index('Nt Mutations')

In [None]:
def load_breadth_coverage_df():
    """
    Load the breadth-of-coverage table and derive a relative bar width per row.

    Reads 'houston_breadth_coverage_df_v2.csv' from the notebook-level
    ``fixed_results_dir``. 'Width' is 'Breadth_Coverage' min-max scaled to [0.1, 1]
    using the minimum and maximum observed on or after 2021-05-01 (both are printed),
    so rows before that date may fall outside this range.

    Returns the DataFrame with 'Date' parsed as datetime and the added 'Width' column.
    """
    coverage_df = pd.read_csv(os.path.join(fixed_results_dir, 'houston_breadth_coverage_df_v2.csv'))
    coverage_df['Date'] = pd.to_datetime(coverage_df['Date'])

    # Scaling reference: only the period shown in the figure; selected once.
    in_study_period = coverage_df['Date'] >= pd.Timestamp('2021-05-01')
    study_breadth = coverage_df.loc[in_study_period, 'Breadth_Coverage']
    min_coverage = study_breadth.min()
    max_coverage = study_breadth.max()
    print(min_coverage, max_coverage)

    coverage_df['Width'] = (coverage_df['Breadth_Coverage']-min_coverage)/(max_coverage-min_coverage)*0.9 + 0.1

    return coverage_df

In [None]:
# Table of Texas GISAID records; the Figure 4 cell counts 'Accession ID' per 'Week'
# and 'VOC'.
tx_voc_df = pd.read_csv(os.path.join(fixed_results_dir, 'texas_gisaid_voc.csv'), index_col=False)
tx_voc_df['Week'] = pd.to_datetime(tx_voc_df['Week'])

In [None]:
# One row per filtered lineage with the date of its first detection.
first_occ_df = get_first_occ_df(merged_df, filtered_cryptic_df)

In [None]:
# Adds the 'Non-synonyms Only' flag and renames 'AA_Mutation' to 'AA Mutation'.
first_occ_df = synonymous_mutations(first_occ_df)

In [None]:
# Breadth-of-coverage table with the relative bar 'Width' used in Figure 4.
breadth_coverage_df = load_breadth_coverage_df()

In [None]:
# Sanity check: within the plotted period the widest bar must have relative width 1.
assert breadth_coverage_df[breadth_coverage_df['Date'] >= pd.Timestamp('2021-05-01')]['Width'].max() == 1

In [None]:
# Figure 4: (a) lineages by date of first detection with the viral-load curve,
#           (b) weekly counts of Texas GISAID sequences by variant.

# Number of lineages first detected on each date.
time_df = pd.DataFrame(first_occ_df.groupby(['Date'])['AA Mutation'].count())

bar_x = time_df.index.to_list()
bar_y1 = time_df['AA Mutation'].values
time_min = time_df.index.min()
time_max = time_df.index.max()

# Same grouping restricted to lineages with only non-synonymous changes, joined with
# the breadth-of-coverage table to obtain a bar width per date.
# NOTE(review): bar_x / bar_y1 come from the un-merged table while bar_y2 / bar_w come
# from the merged one; they only line up if breadth_coverage_df has exactly one row
# for every first-detection date -- confirm.
time_df = pd.DataFrame(first_occ_df.groupby(['Date'])['Non-synonyms Only'].sum())
time_df = time_df.merge(breadth_coverage_df, left_index=True, right_on=['Date'])
bar_y2 = time_df['Non-synonyms Only'].values
bar_w = time_df['Width'].values

fontsize = 14
fig, axes  = plt.subplots(2, 1, figsize=(14, 8), sharex=True,
                         gridspec_kw={'height_ratios': [2, 1], 'hspace':0.15})

# ---- Panel a: newly detected lineages ----
ax = axes[0]

# Bar width is up to 6 days, scaled by the relative breadth of coverage.
ax.bar(bar_x, bar_y1, width=6*bar_w, label='Including Synonymous Mutations', color="black", alpha=0.7)
ax.bar(bar_x, bar_y2, width=6*bar_w, label='Non-synonymous Mutations Only', color='C1')

ax.set_xlim(pd.Timestamp('2021-05-01')-timedelta(days=3), time_max+timedelta(days=3))
ax.set_ylim(0,70)
# Secondary y axis for the viral-load curve (dashed black line).
ax_load = ax.twinx()
ax_load.plot(viral_load_df['date'], viral_load_df['Spline_WW_Percent_10'], 'k--')
ax_load.set_ylim(0,1000)

ax.tick_params(axis='both', which='major', labelsize=fontsize)

ax.set_ylabel('Count of CRs', fontsize=fontsize)
ax_load.yaxis.set_major_formatter(mtick.PercentFormatter())
ax_load.set_ylabel('Viral Load', fontsize=fontsize)
ax_load.tick_params(axis='y', which='major', labelsize=fontsize)
ax.legend(fontsize=fontsize)
ax.text(-0.12, 1.05, 'a', transform=ax.transAxes,
            fontsize=20, fontweight='bold', va='top')

ax.set_title('Newly detected CRs in Houston wastewater', fontsize=fontsize)

# ---- Panel b: Texas GISAID sequences per week, stacked by variant ----
ax = axes[1]

# Count accession IDs per (week, variant); combinations without records become 0.
voc_df = pd.pivot_table(tx_voc_df, values='Accession ID', index='Week', columns=['VOC'], aggfunc='count', fill_value=0)
ax.stackplot(voc_df.index, voc_df['Delta'], voc_df['Omicron'], voc_df['BA.2'], voc_df['BA.5'], voc_df['Other'],
            labels = ['Delta', 'Omicron', 'BA.2', 'BA.5', 'Others'])


# NOTE(review): this legend is replaced by the second ax.legend(...) call below.
ax.legend()
ax.tick_params(axis='both', which='major', labelsize=fontsize)
ax.set_title('Texas sequences in GISAID EpiCoV', fontsize=fontsize)

ax.set_ylabel('Count of sequences', fontsize=fontsize)
ax.set_ylim(0,6000)
# Final legend, anchored outside the axes at the upper right.
ax.legend(fontsize=fontsize, ncol=1, loc=2, bbox_to_anchor=(1,1))
ax.text(-0.12, 1.05, 'b', transform=ax.transAxes,
            fontsize=20, fontweight='bold', va='top')

ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))

fig.savefig('/home/Users/yl181/wastewater/quarc_figures/pdf/figure_4.pdf', transparent=False, facecolor='white', dpi=300, bbox_inches = "tight")

# Figure 6 - Long lasting CR12

In [None]:
def select_single_cryptic(target_mutation, merged_df, population=None):
    """
    Extract all detections of one cryptic lineage, joined with the site population table.

    Parameters
    ----------
    target_mutation : exact 'Nt Mutations' string of the lineage.
    merged_df       : detection table with columns 'Nt Mutations', 'Date' (datetime),
                      'Site' and 'AA Mutations'.
    population      : DataFrame with a 'WWTP' column to join on 'Site'. When omitted,
                      the notebook-level ``population_df`` is used (legacy behaviour).

    Returns
    -------
    (selected_data, date_site_count, aa_mut_label)
        selected_data   : the lineage's rows (inner-joined with ``population``) plus a
                          'Week Start' column (Monday of the collection week);
        date_site_count : number of rows per 'Date';
        aa_mut_label    : the first distinct 'AA Mutations' value of the lineage.

    Raises
    ------
    AssertionError if two different dates of the lineage fall in the same week.
    """
    if population is None:
        population = population_df

    selected_data = merged_df[merged_df['Nt Mutations'] == target_mutation].copy()
    selected_data = selected_data.merge(population, left_on='Site', right_on='WWTP')
    date_site_count = selected_data.groupby(['Date'])['Site'].count()
    aa_mut_label = selected_data['AA Mutations'].unique()[0]

    selected_data['Week Start'] = selected_data['Date'].dt.to_period('W-SUN').apply(lambda r: r.start_time)
    # Explicit raise so the check also runs when assertions are disabled.
    if len(selected_data['Date'].unique()) != len(selected_data['Week Start'].unique()):
        raise AssertionError("more than one collection date falls in the same week")

    return selected_data, date_site_count, aa_mut_label

In [None]:
def build_select_coverage_df(coverage_df, target_mutation, start_date, end_date, min_dp=10):
    """
    Flag, per sample, whether all positions of a lineage are covered deeply enough.

    Parameters
    ----------
    coverage_df     : depth table with 'Collection Date' (datetime), 'Site' and one
                      column per genome position, named by the position as a string.
    target_mutation : ';'-separated mutations, each written <ref><position><alt>.
    start_date,
    end_date        : first and last week start to keep. Samples are kept when their
                      week start equals start_date + k*7 days, so ``start_date`` must
                      itself be a Monday for anything to match.
    min_dp          : minimum depth required at every position of the lineage.

    Returns
    -------
    DataFrame with 'Collection Date', 'Site', the position columns, 'Week Start' and
    'Not Valid' (True when at least one position has depth below ``min_dp``).
    """
    positions = [mut[1:-1] for mut in target_mutation.split(";")]

    selected_coverage_df = coverage_df[['Collection Date', 'Site'] + positions].copy()
    selected_coverage_df['Week Start'] = selected_coverage_df['Collection Date'].dt.to_period('W-SUN').apply(lambda r: r.start_time)

    days_in_a_week = timedelta(days = 7)
    total_week_count = int((end_date - start_date).days/7) + 1
    # Compare datetimes with datetimes: matching a datetime column against strings in
    # isin() is deprecated in recent pandas.
    selected_dates = [pd.Timestamp(start_date + i*days_in_a_week) for i in range(total_week_count)]
    selected_coverage_df = selected_coverage_df[selected_coverage_df['Week Start'].isin(selected_dates)].copy()

    # Row-wise "every position reaches min_dp"; also works for an empty selection.
    valid_coverage = (selected_coverage_df[positions] >= min_dp).all(axis=1).to_numpy(dtype=bool)
    selected_coverage_df['Not Valid'] = ~valid_coverage

    return selected_coverage_df

In [None]:
def build_select_missing_sample_df(has_sample_df, start_date, end_date):
    """
    Build a site-by-week table that marks missing samples.

    Parameters
    ----------
    has_sample_df : DataFrame with 'Collection Date' (datetime), 'Site' and the
                    boolean 'Has Sample'.
    start_date,
    end_date      : first and last week start to keep. Rows are kept when their week
                    start equals start_date + k*7 days, so ``start_date`` must itself
                    be a Monday for anything to match.

    Returns
    -------
    DataFrame indexed by 'Site' with one column per week start; each cell is the mean
    of the inverted 'Has Sample' flag over that site and week (pivot_table default
    aggregation), i.e. 1 when every record is missing and 0 when none is.
    """
    selected_has_sample_df = has_sample_df[['Collection Date', 'Site', 'Has Sample']].copy()
    selected_has_sample_df['Missing Sample'] = np.invert(selected_has_sample_df['Has Sample'])
    selected_has_sample_df['Week Start'] = selected_has_sample_df['Collection Date'].dt.to_period('W-SUN').apply(lambda r: r.start_time)

    days_in_a_week = timedelta(days = 7)
    total_week_count = int((end_date - start_date).days/7) + 1
    # Compare datetimes with datetimes: matching a datetime column against strings in
    # isin() is deprecated in recent pandas.
    selected_dates = [pd.Timestamp(start_date + i*days_in_a_week) for i in range(total_week_count)]
    selected_has_sample_df = selected_has_sample_df[selected_has_sample_df['Week Start'].isin(selected_dates)]

    selected_has_sample_df = selected_has_sample_df.pivot_table(values=['Missing Sample'], index=['Site'], columns=['Week Start'])
    # Drop the outer 'Missing Sample' column level so columns are the week starts.
    selected_has_sample_df.columns = selected_has_sample_df.columns.droplevel()
    return selected_has_sample_df

In [None]:
def build_figure5_data(target_mutation, merged_df, coverage_df, has_sample_df, start_date=date(2021, 4, 15), end_date=date(2022, 10, 1)):
    """Assemble the heatmap, mask and bar-plot inputs for one cryptic lineage.

    Parameters
    ----------
    target_mutation : str
        ';'-separated nucleotide mutations identifying the cryptic lineage.
    merged_df : pandas.DataFrame
        Passed through to ``select_single_cryptic``.
    coverage_df : pandas.DataFrame
        Passed through to ``build_select_coverage_df``.
    has_sample_df : pandas.DataFrame
        Passed through to ``build_select_missing_sample_df``.
    start_date, end_date : datetime.date
        Window of interest; both are snapped back to the Monday of their week.

    Returns
    -------
    tuple
        ``(selected_cr_df, date_site_count, aa_mut_label, sorted_heatmap_array,
        sorted_mask, missing_sample_df, bar_plot_df, selected_coverage_df)``.
        The heatmap array, mask and missing-sample table share the same row order.
    """
    selected_cr_df, date_site_count, aa_mut_label = select_single_cryptic(target_mutation, merged_df)

    # weekday() is 0 on Monday, so this moves each bound to the start of its week.
    start_date = start_date - timedelta(days=start_date.weekday())
    end_date = end_date - timedelta(days=end_date.weekday())

    selected_coverage_df = build_select_coverage_df(coverage_df, target_mutation, start_date, end_date)
    # Site/week cells without any coverage record are treated as masked (True).
    coverage_mask = selected_coverage_df.pivot_table(values=['Not Valid'], index=['Site'], columns=['Week Start']).fillna(True)
    coverage_mask.columns = coverage_mask.columns.droplevel()
    selected_cr_heatmap_array = selected_cr_df.pivot_table(values=['Combined Freq'], index=['Site'], columns=['Week Start'])
    selected_cr_heatmap_array.columns = selected_cr_heatmap_array.columns.droplevel()
    # Weeks with coverage data but no detection get an explicit all-zero column.
    for i in list(set(coverage_mask.columns) - set(selected_cr_heatmap_array.columns)):
        selected_cr_heatmap_array[i] = 0
    selected_cr_heatmap_array = selected_cr_heatmap_array[coverage_mask.columns].fillna(0)

    # Two-key ordering: first by number of weeks with a detection, then by the
    # first week with a detection. Both passes use a stable sort so that the
    # first ordering actually survives as the tie-breaker of the second one.
    detection_counts = selected_cr_heatmap_array.astype(bool).sum(axis=1)
    sorted_index = detection_counts.sort_values(ascending=False, kind='mergesort').index
    first_detection = selected_cr_heatmap_array.loc[sorted_index].ne(0).idxmax(axis=1)
    sorted_index = first_detection.sort_values(ascending=False, kind='mergesort').index
    bar_plot_df = pd.DataFrame(selected_cr_heatmap_array.astype(bool).sum(axis=0), 
                                   columns=['Count'])
    
    missing_sample_df = build_select_missing_sample_df(has_sample_df, start_date, end_date)
    
    missing_sample_df = missing_sample_df.loc[sorted_index]
    sorted_heatmap_array = selected_cr_heatmap_array.loc[sorted_index]
    sorted_mask = coverage_mask.loc[sorted_index]

    # Sanity check: a cell must not be both masked for coverage and carry a
    # non-zero frequency. Reported, not raised, so the figure is still produced.
    masked_detections = np.logical_and(sorted_mask.values, sorted_heatmap_array.astype(bool).values).sum()
    if masked_detections != 0:
        print("Inconsistency between coverage data between cryptic lineage data.")
        
    return selected_cr_df, date_site_count, aa_mut_label, sorted_heatmap_array, sorted_mask, missing_sample_df, bar_plot_df, selected_coverage_df

In [None]:
# The 12 cryptic lineages (CRs) examined below. Each entry is a ';'-separated
# set of nucleotide mutations written as <ref><position><alt>. Later cells
# address them by 1-based CR number, i.e. selected_cryptic_mutations[cr_no-1].
selected_cryptic_mutations = ['A26530G;C26577G;G26634A',
                              'C6402T;G6456A',
                              'A26530G;C26533T;C26542A;T26545G',
                              'A26530G;T26545G',
                              'A27259C;C27335T;A27344T;A27345T',
                              'T29029C;A29039T',
                              'A26530G;C26577G;C26625A',
                              'C10449A;T10459C',
                              'T15682A;T15685A',
                              'A24966T;C25000T',
                              'A27344T;A27345T;A27354G',
                              'A29039T;G29049A']

In [None]:
# Display the cryptic lineages whose 'Present Weeks' value exceeds 10.
filtered_cryptic_df[filtered_cryptic_df['Present Weeks'] > 10]

In [None]:
# CR12: 1-based CR number 12 maps to list index 11 (the last entry).
target_mutation = selected_cryptic_mutations[12-1]

In [None]:
selected_cr_df, date_site_count, aa_mut_label, selected_cr_heatmap_array, sorted_mask, missing_sample_df, bar_plot_df, selected_coverage_df = \
build_figure5_data(target_mutation, merged_df, coverage_df, has_sample_df)
# Coverage columns are named after the genome position of each mutation, so
# derive them from target_mutation instead of hard-coding one lineage's positions.
target_positions = [mut[1:-1] for mut in target_mutation.split(';')]
# Weekly mean depth per position, using only samples with valid coverage.
depth_df = selected_coverage_df[~selected_coverage_df['Not Valid'].astype(bool)].groupby('Week Start')[target_positions].mean()
# Right-merge keeps every week of bar_plot_df, leaving NaN where no valid sample exists.
depth_df = depth_df.merge(bar_plot_df, how='right', left_index=True, right_index=True)
# skipna=False keeps the original semantics: the average is NaN if any position is NaN.
depth_df['depth_total'] = depth_df[target_positions].mean(axis=1, skipna=False)

In [None]:
# Customized colormap: values from 0 up to v_lower_threshold (0.01) are drawn
# in light blue, the rest of the [vmin, vmax] range uses the 'YlOrRd' ramp.
vmin = 0
vmax = 0.15
v_lower_threshold = 0.01
total_colors = 300
# Number of the total_colors entries that fall below the lower threshold.
v_lower_num_colors = int((v_lower_threshold/vmax)*total_colors)
# Fraction of the pale low end of 'YlOrRd' that is meant to be skipped.
skip_min_colors = 0.1

cmap_reds = plt.get_cmap('YlOrRd')
# NOTE(review): `x / total_colors/(1-skip_min_colors)` evaluates left to right,
# i.e. x / (total_colors*(1-skip_min_colors)), which exceeds 1.0 for the largest
# i and therefore saturates the top of the ramp. If the intent was
# x / (total_colors/(1-skip_min_colors)), parentheses are missing - confirm
# before changing, since this alters the rendered colors.
colors = ['lightblue']*v_lower_num_colors +\
[cmap_reds((i+total_colors/(1-skip_min_colors)*skip_min_colors+v_lower_num_colors) / total_colors/(1-skip_min_colors)) for i in range(0, total_colors-v_lower_num_colors)]
cmap = LinearSegmentedColormap.from_list('', colors, total_colors)

# Master Layout: 2x2 grid - heatmap (top-left), its colorbar (top-right),
# weekly bar plot (bottom-left); the bottom-right axes is hidden at the end.
fontsize = 12
fig, axes  = plt.subplots(2, 2, figsize=(15, 10),
                          sharex=False,
                          gridspec_kw={'height_ratios': [4, 1],
                                       'width_ratios': [37, 1],
                                       'wspace': 0.04,
                                       'hspace': 0.15})


# Heatmap of allele frequency per site/week; cells flagged in sorted_mask are
# left undrawn so the axes face color (light grey, set below) shows through.
ax = axes[0][0]
ax1 = sns.heatmap(selected_cr_heatmap_array, 
                  mask=sorted_mask,
                  ax=axes[0,0],
                  cbar_ax=axes[0,1],
                  linewidths=0.2,
                  cmap=cmap,
                  cbar_kws={'label': 'Allele Frequency'
                            },
                  square=False,
                  vmin=vmin,
                  vmax=vmax)

# Second pass on the same axes: paint white every cell where missing_sample_df == 1.
sns.heatmap(missing_sample_df, 
            cmap=ListedColormap(['white']), 
            linecolor='white', 
            linewidths=0.2,
            cbar=False, mask=(missing_sample_df != 1),
            ax=axes[0,0])

# Black frame around the heatmap.
ax1.axhline(y = 0, color='k',linewidth = 2)
ax1.axhline(y = selected_cr_heatmap_array.shape[0], color = 'k', linewidth = 2)
ax1.axvline(x = 0, color = 'k',linewidth = 2)
ax1.axvline(x = selected_cr_heatmap_array.shape[1], color = 'k', linewidth = 2)

ax.set_title(f'{target_mutation} ({aa_mut_label}) Detection in Houston Wastewater', fontsize=fontsize)

axes[0,1].set_ylabel("Allele Frequency",size=12);

# Light grey background = cells masked for insufficient coverage (see legend).
axes[0,0].set_facecolor(mcolors.CSS4_COLORS['lightgrey'])
axes[0,0].set_ylabel("WWTP")
axes[0,0].set_xlabel("")
# Week labels are omitted here; the bar plot underneath carries the dates.
axes[0,0].set(xticklabels=[],  xlabel=None)
ax1.tick_params(left=False, bottom=False)

# Colorbar axes styling
axes[0,1].set_facecolor("w")
axes[0,1].xaxis.label.set_size(fontsize)

# Panel letter
ax.text(-0.07, 1.05, 'a', transform=ax.transAxes,
            fontsize=20, fontweight='bold', va='top')

# Customized legend explaining the three non-frequency cell states.
legend_elements = [Line2D([0], [0], marker='s', color='w', label='Not Detected',
                          markerfacecolor=cmap(0), markersize=8, markeredgecolor=cmap(0)),
                   Line2D([0], [0], marker='s', color='w', label='No Coverage',
                          markerfacecolor=mcolors.CSS4_COLORS['lightgrey'], markersize=8, markeredgecolor=mcolors.CSS4_COLORS['lightgrey']),
                   Line2D([0], [0], marker='s', color='w', label='Missing Sample',
                          markerfacecolor='w', markersize=8, markeredgecolor='black')]
legend = axes[0,0].legend(handles=legend_elements, fontsize=fontsize, loc='upper center', bbox_to_anchor=(0.5, 0), ncol=len(legend_elements), frameon=False)
frame = legend.get_frame()
frame.set_facecolor('w')

# Bar plot: number of sites with a detection per week, plus mean coverage line.
ax = axes[1][0]

x_ticklabels = bar_plot_df.index.strftime('%Y-%m-%d').to_list()
# +0.5 centers each bar under the corresponding heatmap column.
x_tick_pos = [i + 0.5 for i in range(len(bar_plot_df.index))]

ax.bar(x_tick_pos, bar_plot_df['Count'].values, width=0.8, color='g')

# Secondary y-axis: weekly mean coverage (weeks without data are drawn as 0).
ax_twinx = ax.twinx()
ax_twinx.plot(x_tick_pos, depth_df['depth_total'].fillna(0).values, color='b', linestyle='dashed')
ax_twinx.set_ylim(0, depth_df['depth_total'].fillna(0).values.max()*1.2)
ax_twinx.set_ylabel('Mean Coverage', fontsize=fontsize)

ax.set_xlim([0, len(bar_plot_df.index)])
# Label every 4th week only.
ax.set_xticks(x_tick_pos[::4])
ax.set_xticklabels(x_ticklabels[::4], rotation=90, fontsize=fontsize)

# NOTE(review): ticks/labels are fixed at 0-15 while the y-limit below scales
# with the data; counts above 15 would have no tick - confirm for other CRs.
ax.set_yticks(np.arange(0,20,5))
ax.set_yticklabels(['   0', '   5', '  10', '  15'], fontsize=fontsize)
ax.set_ylabel('Site Count', fontsize=fontsize)
ax.set_ylim(0, bar_plot_df['Count'].values.max()*1.2)

# Panel letter
ax.text(-0.07, 1.15, 'b', transform=ax.transAxes,
            fontsize=20, fontweight='bold', va='top')

# Hide the unused bottom-right axes.
axes[1,1].set_visible(False)

# NOTE(review): the data comes from build_figure5_data but the file is saved as
# figure_6.pdf - confirm the intended figure number.
fig.savefig('/home/Users/yl181/wastewater/quarc_figures/pdf/figure_6.pdf', transparent=False, facecolor='white', dpi=300, bbox_inches = "tight")

# Figure 7 - CR3 & CR 8

In [None]:
import warnings
# Suppresses every warning category from this cell onwards.
# NOTE(review): this also hides pandas/numpy warnings raised by the functions
# defined below - consider narrowing the filter.
warnings.filterwarnings("ignore")

In [None]:
from itertools import chain

In [None]:
def padding_date(date_string):
    """Return ``date_string`` as text, left-padded with zeros to 8 characters.

    Callers parse the result with the '%m%d%Y' format, so this restores a
    leading zero that was lost when the date was read as a number.
    """
    as_text = str(date_string)
    return as_text.zfill(8)

## VCFs

In [None]:
def parse_wastewater_vcf(has_sample_df, valid_WWTPs, metadata_path='/home/Users/yl181/crykey_bu/input_metadata.tsv'):
    """Collect single-nucleotide variant calls from the wastewater VCF files.

    Parameters
    ----------
    has_sample_df : pandas.DataFrame
        Needs 'Collection Date', 'Site' and 'Has Sample' columns; a VCF is only
        read when the first matching row for its date/site has a truthy flag.
    valid_WWTPs : container
        Sites to include (tested with ``in``).
    metadata_path : str, optional
        Tab-separated metadata table with 'Sample_Collection_Date' (MMDDYYYY,
        possibly missing its leading zero), 'WWTP' and 'VCF' columns.

    Returns
    -------
    pandas.DataFrame
        Columns 'Week Start', 'WWTP', 'Nt Mutation', 'Depth', 'AF'. Every VCF
        that was read to the end also contributes one sentinel row whose
        'Nt Mutation' is 'Valid_VCF' (Depth 0, AF 0).
    """
    vcf_metadata = pd.read_csv(metadata_path, sep='\t')
    vcf_metadata['Date'] = vcf_metadata['Sample_Collection_Date'].map(padding_date)
    vcf_metadata['Date'] = pd.to_datetime(vcf_metadata['Date'], format='%m%d%Y')
    vcf_metadata['Week Start'] = vcf_metadata['Date'].dt.to_period('W-SUN').apply(lambda r: r.start_time)

    record_columns = ['Week Start', 'WWTP', 'Nt Mutation', 'Depth', 'AF']
    records = []
    for _, row in vcf_metadata.iterrows():
        if row['WWTP'] not in valid_WWTPs:
            continue

        # Samples without an entry in has_sample_df are skipped explicitly
        # (previously this surfaced as an IndexError swallowed by a bare except).
        sample_flags = has_sample_df.loc[(has_sample_df['Collection Date'] == row['Date']) &
                                         (has_sample_df['Site'] == row['WWTP']), 'Has Sample'].values
        if len(sample_flags) == 0 or not sample_flags[0]:
            continue

        vcf_path = row['VCF']
        # A missing path is read back as NaN, which os.path.exists cannot handle.
        if not isinstance(vcf_path, str) or not os.path.exists(vcf_path):
            continue

        try:
            # Context manager so the handle is closed even if parsing fails.
            with open(vcf_path, 'r') as vcf_handle:
                for record in vcf.Reader(vcf_handle):
                    ref = str(record.REF)
                    pos = str(record.POS)
                    alt = str(record.ALT[0])

                    # Keep single-nucleotide substitutions only.
                    if len(ref) == 1 and len(alt) == 1:
                        records.append({'Week Start': row['Week Start'],
                                        'WWTP': row['WWTP'],
                                        'Nt Mutation': ref+pos+alt,
                                        'Depth': int(record.INFO['DP']),
                                        'AF': float(record.INFO['AF'])})
        except Exception as err:
            # Best effort, as before: skip the rest of this file (records parsed
            # so far are kept, no 'Valid_VCF' sentinel is added) but say so, and
            # let KeyboardInterrupt/SystemExit propagate.
            print(f'Skipping remainder of {vcf_path}: {err!r}')
            continue

        records.append({'Week Start': row['Week Start'],
                        'WWTP': row['WWTP'],
                        'Nt Mutation': 'Valid_VCF',
                        'Depth': 0,
                        'AF': 0})

    # Explicit columns keep the schema stable even when nothing was parsed.
    vcf_df = pd.DataFrame(records, columns=record_columns)
    return vcf_df

In [None]:
# Per-sample SNV calls from the wastewater VCFs of the selected WWTPs.
wastewater_vcf_df = parse_wastewater_vcf(has_sample_df, valid_WWTPs)

In [None]:
def parse_clinical_vcf():
    """Collect single-nucleotide variant calls from the Houston clinical VCFs.

    Returns
    -------
    vcf_df : pandas.DataFrame
        Columns 'Week Start', 'WWTP', 'Nt Mutation', 'Depth', 'AF'. The 'WWTP'
        column holds the clinical run id so the frame has the same layout as
        the wastewater one. Every VCF found on disk also contributes one
        sentinel row with 'Nt Mutation' == 'Valid_VCF' (Depth 0, AF 0).
    clinical_metadata_df : pandas.DataFrame
        The metadata table with added 'Date', 'Week Start' and 'Valid_VCF'
        (True when the run's VCF file exists) columns.
    """
    clinical_metadata_df = pd.read_csv('/home/Users/yl181/wastewater/quarc_clinical_sampling/PRJNA764181.filtered.csv')
    clinical_metadata_df['Date'] = pd.to_datetime(clinical_metadata_df['Collection_Date'], format='%Y-%m-%d')
    clinical_metadata_df['Week Start'] = clinical_metadata_df['Date'].dt.to_period('W-SUN').apply(lambda r: r.start_time)
    
    record_columns = ['Week Start', 'WWTP', 'Nt Mutation', 'Depth', 'AF']
    valid_samples = []
    records = []
    for idx, row in clinical_metadata_df.iterrows():
        run_id = row['Run']
        # The output directory is named after the metadata row index, the file after the run id.
        vcf_path = f'/home/Users/yl181/wastewater/quarc_clinical_sampling/Harvest_Variant_Outputs_Houston/{idx}_out/vcf_files_filtered/{run_id}.vcf'
        if not os.path.exists(vcf_path):
            valid_samples.append(False)
            continue

        valid_samples.append(True)
        # Context manager so each of the many VCF handles is closed after use.
        with open(vcf_path, 'r') as vcf_handle:
            for record in vcf.Reader(vcf_handle):
                ref = str(record.REF)
                pos = str(record.POS)
                alt = str(record.ALT[0])

                # Keep single-nucleotide substitutions only.
                if len(ref) == 1 and len(alt) == 1:
                    records.append({'Week Start': row['Week Start'],
                                    'WWTP': run_id,
                                    'Nt Mutation': ref+pos+alt,
                                    'Depth': int(record.INFO['DP']),
                                    'AF': float(record.INFO['AF'])})
        records.append({'Week Start': row['Week Start'],
                        'WWTP': run_id,
                        'Nt Mutation': 'Valid_VCF',
                        'Depth': 0,
                        'AF': 0})

    # Explicit columns keep the schema stable even when nothing was parsed.
    vcf_df = pd.DataFrame(records, columns=record_columns)
    clinical_metadata_df['Valid_VCF'] = valid_samples
    return vcf_df, clinical_metadata_df

In [None]:
%%time
# Per-run SNV calls plus the metadata table annotated with VCF availability.
clinical_vcf_df, clinical_metadata_df = parse_clinical_vcf()

## Individual SNV DataFrame CR12

In [None]:
def filtering_clinical_samples(clinical_cryptic_df, strand_bias=False):
    """Re-evaluate the 'Crykey' flag of every row using depth (and strand) filters.

    A row keeps ``Crykey == True`` only if it was already flagged, has
    'Support DP' >= 5 and 'Support DP' / 'Total DP' >= 0.02. With
    ``strand_bias=True`` the 'DP1'..'DP4' counts are additionally screened:
    when DP3 makes up more than 85% or less than 15% of DP3+DP4, a Fisher
    exact p-value and a strand-bias score are computed, and the row must have
    score < 1 and p-value > 0.05 to stay flagged.

    The frame is modified in place ('Crykey' is overwritten; 'SB_p_value' and
    'SB' are added when ``strand_bias`` is set) and also returned.
    """
    updated_calls = []
    fisher_p_values = []
    bias_scores = []

    for _, sample in clinical_cryptic_df.iterrows():
        support_depth = sample['Support DP']
        total_depth = sample['Total DP']

        if strand_bias:
            d1, d2, d3, d4 = sample['DP1'], sample['DP2'], sample['DP3'], sample['DP4']
            # Defaults for rows that are not screened further.
            fisher_p = 1
            bias_score = 0
            if d3+d4 > 0:
                d3_share = d3/(d3+d4)
                if d3_share > 0.85 or d3_share < 1-0.85:
                    fisher_p = fisher_exact([[d1, d2], [d3, d4]])[1]
                    try:
                        bias_score = abs((d3/(d1+d3)) - (d4/(d2+d4)))/((d3+d4)/(d1+d2+d3+d4))
                    except ZeroDivisionError:
                        bias_score = 0
            fisher_p_values.append(fisher_p)
            bias_scores.append(bias_score)

        # Rows that were not flagged before can never become flagged here.
        keep_call = False
        if sample['Crykey'] == True:
            keep_call = (support_depth >= 5) and (support_depth/total_depth >= 0.02)
            if strand_bias:
                keep_call = keep_call and (bias_score < 1) and (fisher_p > 0.05)
        updated_calls.append(keep_call)

    clinical_cryptic_df['Crykey'] = updated_calls
    if strand_bias:
        clinical_cryptic_df['SB_p_value'] = fisher_p_values
        clinical_cryptic_df['SB'] = bias_scores

    return clinical_cryptic_df

In [None]:
def parse_wastewater_independent_snv_df(merged_df, valid_WWTPs, from_source=False):
    """Load the per-sample frequency table of the selected cryptic lineages (wastewater).

    With ``from_source=False`` (default) the cached CSV in ``fixed_results_dir``
    is read, its 'Date' / 'Week Start' columns are parsed and the stored
    'Unnamed: 0' column is dropped.

    With ``from_source=True`` the raw split-AF table is read, restricted to
    ``valid_WWTPs`` and given a 'Crykey' column that is True when exactly one
    row of ``merged_df`` matches on 'Nt Mutations', 'Week Start' and 'Site'.
    """
    if not from_source:
        cached_df = pd.read_csv(os.path.join(fixed_results_dir, 'wastewater_12crs_independent_snvs.csv'), index_col=False)
        for date_column in ('Date', 'Week Start'):
            cached_df[date_column] = pd.to_datetime(cached_df[date_column], format='%Y-%m-%d')
        return cached_df.drop(['Unnamed: 0'], axis=1)

    source_df = pd.read_csv('/home/Users/yl181/wastewater/crykey_wastewater_split/ww_20_split_af_all_dates.csv')
    # Dates are stored as MMDDYYYY numbers; restore the leading zero before parsing.
    source_df['Date'] = pd.to_datetime(source_df['Date'].map(padding_date), format='%m%d%Y')
    source_df['Week Start'] = source_df['Date'].dt.to_period('W-SUN').apply(lambda r: r.start_time)

    source_df = source_df[source_df.Site.isin(valid_WWTPs)]

    def has_single_match(sample):
        # True only when merged_df holds exactly one row for this lineage/week/site.
        same_call = ((merged_df['Nt Mutations'] == sample['Nt Mutations'])
                     & (merged_df['Week Start'] == sample['Week Start'])
                     & (merged_df['Site'] == sample['Site']))
        return merged_df[same_call].shape[0] == 1

    source_df['Crykey'] = [has_single_match(sample) for _, sample in source_df.iterrows()]
    return source_df

In [None]:
def parse_clinical_independent_snv_df(from_source=False):
    """Load the per-sample frequency table of the selected cryptic lineages (clinical).

    With ``from_source=False`` (default) the cached CSV in ``fixed_results_dir``
    is read, its 'Date' / 'Week Start' columns are parsed and the stored
    'Unnamed: 0' column is dropped.

    With ``from_source=True`` the raw split-AF table is read, dated through the
    clinical metadata ('Site' holds the run id) and given a 'Crykey' column. A
    row is flagged when 'Support DP' >= 5, 'Combined Freq' >= 0.01 and every
    mutation of its 'Nt Mutations' set was called in that run's VCF. This
    branch reads the module-level ``clinical_vcf_df`` built by
    ``parse_clinical_vcf``.
    """
    if from_source:
        clinical_metadata_df = pd.read_csv('/home/Users/yl181/wastewater/quarc_clinical_sampling/PRJNA764181.filtered.csv')
        clinical_metadata_df = clinical_metadata_df[['Run', 'Collection_Date']].copy().set_index('Run')

        clinical_independent_snv_df = pd.read_csv('/home/Users/yl181/wastewater/crykey_wastewater_split/houston_20_split_af.csv')
        clinical_independent_snv_df['Date'] = clinical_independent_snv_df['Site'].map(clinical_metadata_df['Collection_Date'].to_dict())
        clinical_independent_snv_df['Date'] = pd.to_datetime(clinical_independent_snv_df['Date'], format='%Y-%m-%d')
        clinical_independent_snv_df['Week Start'] = clinical_independent_snv_df['Date'].dt.to_period('W-SUN').apply(lambda r: r.start_time)

        # parse_clinical_vcf stores the run id in the 'WWTP' column (there is no
        # 'Run' column in clinical_vcf_df). Build the (run, mutation) pairs once
        # so each membership test is O(1) instead of a full-frame scan.
        called_snvs = set(zip(clinical_vcf_df['WWTP'], clinical_vcf_df['Nt Mutation']))

        crykey_calls = []
        for _, row in clinical_independent_snv_df.iterrows():
            passes_depth_filter = row['Support DP'] >= 5 and row['Combined Freq'] >= 0.01
            all_snvs_called = passes_depth_filter and all(
                (row['Site'], nt) in called_snvs for nt in row['Nt Mutations'].split(';'))
            crykey_calls.append(bool(all_snvs_called))

        clinical_independent_snv_df['Crykey'] = crykey_calls
    else:
        clinical_independent_snv_df = pd.read_csv(os.path.join(fixed_results_dir, 'clinical_12crs_independent_snvs.csv'), index_col=False)
        clinical_independent_snv_df['Date'] = pd.to_datetime(clinical_independent_snv_df['Date'], format='%Y-%m-%d')
        clinical_independent_snv_df['Week Start'] = pd.to_datetime(clinical_independent_snv_df['Week Start'], format='%Y-%m-%d')
        clinical_independent_snv_df = clinical_independent_snv_df.drop(['Unnamed: 0'], axis=1)
        
    return clinical_independent_snv_df

In [None]:
# Loaded from the cached CSV (from_source defaults to False).
wastewater_isnv_df = parse_wastewater_independent_snv_df(merged_df, valid_WWTPs)

In [None]:
# Loaded from the cached CSV (from_source defaults to False).
clinical_isnv_df = parse_clinical_independent_snv_df()

In [None]:
%%time
# Re-derive 'Crykey' with the depth and strand-bias filters; also adds the 'SB_p_value' and 'SB' columns.
clinical_isnv_df = filtering_clinical_samples(clinical_isnv_df, strand_bias=True)

## Lineage

In [None]:
def group_lineage(idx):
    """Map a lineage name to the display group used in the lineage bar.

    'AY.*' -> 'Delta'; BA.1.15 / BA.1.17 / BA.1.18 / BA.1.20 / BA.1.1 and their
    descendants -> that sub-lineage name; any other name starting with 'BA' or
    equal to 'B.1.1.529' -> 'Omicron'.

    Anything else - including non-string input such as the NaN produced by
    ``Series.map`` for a sample without a lineage - is reported with
    ``print('Error', idx)`` and mapped to ``None``.
    """
    if not isinstance(idx, str):
        print('Error', idx)
        return None

    if idx.startswith('AY.'):
        return 'Delta'

    # The trailing '.' in the prefix test keeps e.g. 'BA.1.10' out of 'BA.1.1'.
    for sublineage in ('BA.1.15', 'BA.1.17', 'BA.1.18', 'BA.1.20', 'BA.1.1'):
        if idx == sublineage or idx.startswith(sublineage + '.'):
            return sublineage

    if idx.startswith('BA') or idx == 'B.1.1.529':
        return 'Omicron'

    print('Error', idx)
    return None

In [None]:
# Lineage assignment of the Houston clinical samples: keep only the sample id
# ('taxon') and its lineage, then expose the pair as a lookup dict.
houston_lineage_df = pd.read_csv(
    os.path.join(fixed_results_dir, 'houston_lineage_report.csv')
)[['taxon', 'lineage']]
taxon_indexed_lineages = houston_lineage_df.set_index('taxon')['lineage']
houston_lineage_dict = taxon_indexed_lineages.to_dict()

In [None]:
def lineage_plot(clinical_isnv_df, target_mutation, ax):
    """Draw a single stacked horizontal bar of lineage groups for one cryptic lineage.

    Parameters
    ----------
    clinical_isnv_df : pandas.DataFrame
        Needs 'Crykey', 'Site' and 'Nt Mutations' columns; 'Site' values are
        looked up in the module-level ``houston_lineage_dict``.
    target_mutation : str
        Value of 'Nt Mutations' selecting the cryptic lineage to summarise.
    ax : matplotlib.axes.Axes
        Axes to draw on; its x/y axes are hidden.

    Returns
    -------
    collections.Counter
        Number of supporting samples per lineage group (as returned by
        ``group_lineage``) for ``target_mutation``.
    """
    category_names = ['Delta', 'BA.1.1', 'BA.1.15', 'BA.1.17', 'BA.1.18', 'BA.1.20', 'Omicron']
    category_colors = plt.get_cmap('RdYlGn')(np.linspace(0.15, 0.85, len(category_names)))

    # .copy() so the new columns are added to an independent frame instead of a
    # slice of the caller's DataFrame.
    clinical_support_df = clinical_isnv_df[clinical_isnv_df['Crykey'] == True].copy()
    clinical_support_df['Lineage'] = clinical_support_df['Site'].map(houston_lineage_dict)
    clinical_support_df['Lineage Label'] = clinical_support_df['Lineage'].map(group_lineage)
    lineage_counter = Counter(clinical_support_df[clinical_support_df['Nt Mutations'] == target_mutation]['Lineage Label'].values)
    
    ax.invert_yaxis()
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    ax.set_xlim(0, sum(lineage_counter.values()))
    # Stack the segments left to right in the fixed category order; a Counter
    # returns 0 for groups that do not occur.
    starts = 0
    for colname, color in zip(category_names, category_colors):
        widths = lineage_counter[colname]
        ax.barh(0, widths, left=starts, height=0.5,
                label=colname, color=color)
        starts = starts+widths

    ax.legend(title='Lineages of CR supported clinical samples', ncol=len(category_names), bbox_to_anchor=(0.5, 1),
              loc='lower center')
    ax.set_ylim(-0.2,0.2)
    return lineage_counter

## Generate Plot Data

In [None]:
def get_isnv_plot_data(target_mutation, vcf_df, valid_samples, min_depth=10, start_date=date(2021, 3, 1), end_date=date(2022, 11, 3)):
    """
    For individual SNV plot

    Computes weekly allele-frequency statistics for every single mutation of
    ``target_mutation`` (';'-separated).

    Parameters
    ----------
    target_mutation : str
        ';'-separated nucleotide mutations.
    vcf_df : pandas.DataFrame
        Needs 'Week Start', 'WWTP', 'Nt Mutation' and 'AF' columns.
    valid_samples : sequence
        ``valid_samples[i]`` lists the sample ids counted for week ``i``; it
        must cover the same weekly window as ``start_date``/``end_date``.
    min_depth : int
        Accepted for signature compatibility; not used by this function.
    start_date, end_date : datetime.date
        Window bounds; both are snapped back to the Monday of their week.

    Returns
    -------
    tuple of five defaultdict(list)
        Keyed by mutation, one entry per week: mean AF, distance from the mean
        to the minimum AF, distance from the mean to the maximum AF, AF
        standard deviation, and prevalence (detections / valid samples). Weeks
        without a detection yield NaN statistics; weeks without valid samples
        yield NaN prevalence.
    """
    def get_weekly_stats(isnv, weekly_vcf_df, valid_samples, af_min=0.02):
        detected = weekly_vcf_df[(weekly_vcf_df['Nt Mutation'] == isnv) & (weekly_vcf_df['AF'] >= af_min) & (weekly_vcf_df['WWTP'].isin(valid_samples))]
        afs = np.asarray(detected['AF'].values)
        valid_sample_count = len(valid_samples)
        if valid_sample_count > 0:
            prevalence = len(afs)/valid_sample_count
            # More detections than valid samples points at duplicated records.
            if prevalence>1:
                print(isnv, len(afs), valid_sample_count)
        else:
            prevalence = np.nan

        # No detection this week: return NaN directly instead of letting numpy
        # warn about the mean/std of an empty array.
        if len(afs) == 0:
            return np.nan, np.nan, np.nan, np.nan, prevalence

        mean_af = afs.mean()
        min_af = abs(np.nanmin(afs) - mean_af)
        max_af = abs(np.nanmax(afs) - mean_af)
        return mean_af, min_af, max_af, afs.std(), prevalence
    
    # weekday() is 0 on Monday, so this moves each bound to the start of its week.
    start_date = start_date - timedelta(days=start_date.weekday())
    end_date = end_date - timedelta(days=end_date.weekday())
    total_week_count = int((end_date - start_date).days/7) + 1
    week_offset = timedelta(days = 7)
    
    isnv_mean_dict = defaultdict(list)
    isnv_min_dict = defaultdict(list)
    isnv_max_dict = defaultdict(list)
    isnv_std_dict = defaultdict(list)
    isnv_prev_dict = defaultdict(list)
    for i in range(0, total_week_count):
        week_start = start_date + i*week_offset        
        weekly_vcf_df = vcf_df[vcf_df['Week Start'] == pd.to_datetime(week_start)]
        for isnv in target_mutation.split(";"):
            mean, min_af, max_af, std, prevalence = get_weekly_stats(isnv, weekly_vcf_df, valid_samples[i], af_min=0.02)
            isnv_mean_dict[isnv].append(mean)
            isnv_min_dict[isnv].append(min_af)
            isnv_max_dict[isnv].append(max_af)
            isnv_std_dict[isnv].append(std)
            isnv_prev_dict[isnv].append(prevalence)
    return isnv_mean_dict, isnv_min_dict, isnv_max_dict, isnv_std_dict, isnv_prev_dict

In [None]:
def get_cr_plot_data(target_mutation, isnv_df, min_depth=10, start_date=date(2021, 3, 1), end_date=date(2022, 11, 3)):
    """
    For CR plots

    Computes weekly statistics of one cryptic lineage across samples.

    Parameters
    ----------
    target_mutation : str
        Value of 'Nt Mutations' selecting the cryptic lineage.
    isnv_df : pandas.DataFrame
        Needs 'Nt Mutations', 'Week Start', 'Site', 'Total DP', 'Crykey' and
        'Combined Freq' columns.
    min_depth : int
        Minimum 'Total DP' for a sample to count as valid.
    start_date, end_date : datetime.date
        Window bounds; both are snapped back to the Monday of their week.

    Returns
    -------
    tuple of ten lists, one entry per week
        means, mins, maxs, stds (of 'Combined Freq' over flagged valid samples;
        mins/maxs are distances from the mean), prevalences (flagged / valid),
        coverages (mean 'Total DP' of valid samples), total_sample_counts,
        valid_sample_counts, detected_sample_counts, valid_samples (arrays of
        'Site' values). Weeks without data yield NaN statistics.
    """
    # weekday() is 0 on Monday, so this moves each bound to the start of its week.
    start_date = start_date - timedelta(days=start_date.weekday())
    end_date = end_date - timedelta(days=end_date.weekday())
    total_week_count = int((end_date - start_date).days/7) + 1
    week_offset = timedelta(days = 7)
    
    selected_isnv_df = isnv_df[isnv_df['Nt Mutations'] == target_mutation].copy()

    means = []
    mins = []
    maxs = []
    stds = []
    coverages = []
    prevalences = []
    
    total_sample_counts = []
    valid_sample_counts = []
    valid_samples = []
    detected_sample_counts = []
    
    for i in range(0, total_week_count):
        week_start = start_date + i*week_offset 
        
        weekly_selected_isnv_df = selected_isnv_df[selected_isnv_df['Week Start'] == pd.to_datetime(week_start)]
        total_sample_count = weekly_selected_isnv_df.shape[0]
        # Build the depth mask from the weekly frame itself; a mask taken from
        # the full isnv_df has a different index and relies on implicit reindexing.
        weekly_selected_isnv_df = weekly_selected_isnv_df[weekly_selected_isnv_df['Total DP'] >= min_depth]
        valid_sample_count = weekly_selected_isnv_df.shape[0]
        valid_sample = weekly_selected_isnv_df['Site'].values
        
        afs = list(weekly_selected_isnv_df[(weekly_selected_isnv_df['Crykey'] == True)]['Combined Freq'].values)
        if len(afs) > 0:
            mean = np.nanmean(afs)
            std = np.nanstd(afs)
            min_offset = abs(np.nanmin(afs)-mean)
            max_offset = abs(np.nanmax(afs)-mean)
        else:
            # No detection this week: NaN directly, without numpy empty-slice warnings.
            mean = std = min_offset = max_offset = np.nan

        if valid_sample_count > 0:
            prevalence = len(afs)/valid_sample_count
        else:
            prevalence = np.nan
        coverage = weekly_selected_isnv_df['Total DP'].mean()
        
        means.append(mean)
        mins.append(min_offset)
        maxs.append(max_offset)
        stds.append(std)
        prevalences.append(prevalence)
        coverages.append(coverage)
        
        total_sample_counts.append(total_sample_count)
        valid_sample_counts.append(valid_sample_count)
        valid_samples.append(valid_sample)
        detected_sample_counts.append(len(afs))
        
    return means, mins, maxs, stds, prevalences, coverages, total_sample_counts, valid_sample_counts, detected_sample_counts, valid_samples

## Figure

In [None]:
def date_labels(start_date=date(2021, 3, 1), end_date=date(2022, 11, 3)):
    """Return the 'YYYY-MM-DD' label of every week start in the window.

    Both bounds are first moved back to the Monday of their week; the result
    then lists each Monday from the first to the last one, inclusive.
    """
    first_week = start_date - timedelta(days=start_date.weekday())
    last_week = end_date - timedelta(days=end_date.weekday())
    week_count = int((last_week - first_week).days/7) + 1
    one_week = timedelta(days = 7)

    return [(first_week + week_no*one_week).strftime('%Y-%m-%d')
            for week_no in range(week_count)]

In [None]:
# This rebinds the name `date_labels` from the function to its result (the
# plotting functions below read the list under this name). The guard keeps the
# cell re-runnable: on a second run the name already holds the list, and
# calling it again would raise TypeError.
if callable(date_labels):
    date_labels = date_labels()

In [None]:
# Last entry of the list, i.e. CR12.
target_mutation = selected_cryptic_mutations[-1]

In [None]:
def plot3axes(axes, cr_no, vcf_df, isnv_df, data_type, width, total_weeks):
    """Fill three stacked axes for one cryptic lineage and one data source.

    ``axes[0]``: weekly AF (error bars) and prevalence (bars on a twin axis) of
    every individual SNV of the lineage; ``axes[1]``: the same for the lineage
    as a whole; ``axes[2]``: mean coverage per week. ``data_type`` is only used
    in the axis labels, ``width`` is the per-SNV bar width and ``total_weeks``
    the number of weekly positions.
    """
    total_width = 0.8
    cr_mutation = selected_cryptic_mutations[cr_no-1]
    week_positions = np.arange(total_weeks)
    marker_style = dict(marker='o', markersize=4, linestyle='dotted', linewidth=2, alpha=1)

    cr_stats = get_cr_plot_data(cr_mutation, isnv_df)
    means, mins, maxs = cr_stats[0], cr_stats[1], cr_stats[2]
    prevalences, coverages = cr_stats[4], cr_stats[5]
    valid_samples = cr_stats[9]

    # --- Middle axes: whole-lineage AF with min/max whiskers ---
    cr_af_ax = axes[1]
    cr_af_ax.errorbar(week_positions, means, [mins, maxs],
                      color='C0', label=f'CR{cr_no} AF', **marker_style)
    cr_af_ax.legend(loc='upper left', framealpha=0.4)
    cr_af_ax.set_ylim(0,1.05)
    cr_af_ax.set_ylabel(f'AF\n({data_type})')

    snv_stats = get_isnv_plot_data(cr_mutation, vcf_df, valid_samples)
    isnv_mean_dict, isnv_min_dict, isnv_max_dict, isnv_prev_dict = snv_stats[0], snv_stats[1], snv_stats[2], snv_stats[4]

    # --- Middle axes, twin: whole-lineage prevalence ---
    cr_prev_ax = cr_af_ax.twinx()
    cr_prev_ax.bar(x=week_positions,
                   height=prevalences,
                   alpha=0.3, color='C0',
                   label=f'CR{cr_no} Prevalence')
    cr_prev_ax.legend(loc='upper right', framealpha=0.4)
    cr_prev_ax.set_ylim(0,1.05)
    cr_prev_ax.set_ylabel(f'Prevalence\n({data_type})')

    # --- Top axes: one AF series per SNV, shifted sideways so they sit side by side ---
    snv_af_ax = axes[0]
    for i, isnv in enumerate(isnv_prev_dict):
        shifted_positions = np.arange(total_weeks)-total_width/2+(i)*width+0.5*width
        snv_af_ax.errorbar(shifted_positions,
                           isnv_mean_dict[isnv],
                           [isnv_min_dict[isnv], isnv_max_dict[isnv]],
                           color=f'C{i+1}', label=f'{isnv} AF', **marker_style)
    snv_af_ax.legend(loc='upper left', framealpha=0.4)
    snv_af_ax.set_ylim(0,1.05)
    snv_af_ax.set_ylabel(f'AF\n({data_type})')

    # --- Top axes, twin: per-SNV prevalence bars at the same shifted positions ---
    snv_prev_ax = snv_af_ax.twinx()
    for i, isnv in enumerate(isnv_prev_dict):
        shifted_positions = np.arange(total_weeks)-total_width/2+(i)*width+0.5*width
        snv_prev_ax.bar(x=shifted_positions,
                        height=isnv_prev_dict[isnv],
                        alpha=0.5,
                        width=width, color=f'C{i+1}',
                        label=f"{isnv} Prevalence")
    snv_prev_ax.legend(loc='upper right', framealpha=0.4)
    snv_prev_ax.set_ylim(0,1.05)
    snv_prev_ax.set_ylabel(f'Prevalence\n({data_type})')

    # --- Bottom axes: mean coverage ---
    coverage_ax = axes[2]
    coverage_ax.bar(x=np.arange(total_weeks),
                    height=coverages,
                    alpha=0.8, color='black',
                    label=f'Mean Coverage at CR{cr_no} Positions')
    coverage_ax.legend(loc='upper right', framealpha=0.4)
    coverage_ax.set_ylabel(f'Coverage\n({data_type})')

In [None]:
def cr_plot(cr_no, cr_label, save_fig=False):
    """Draw the stand-alone 7-row overview figure for one cryptic lineage.

    Rows 0-2 show the wastewater panels and rows 3-5 the clinical panels (both
    drawn by ``plot3axes``); row 6 shows, per week, how many wastewater samples
    exist, have sufficient coverage, and carry the lineage.

    Parameters
    ----------
    cr_no : int
        1-based index into the module-level ``selected_cryptic_mutations``.
    cr_label : str
        Prefix of the output file name (only used when ``save_fig`` is True).
    save_fig : bool
        Save a PNG instead of showing the figure.

    Reads the module-level ``date_labels`` (list), ``cmap`` and the
    wastewater/clinical VCF and iSNV frames.
    """
    fontsize = 12  # not used below
    fig, axes  = plt.subplots(7, 1, figsize=(15, 10), sharex=True, constrained_layout=True, gridspec_kw={'height_ratios': [2, 2, 1, 2, 2, 1, 1], 'wspace':0.2})

    total_weeks = len(date_labels)
    total_width = 0.8
    # Split the total bar width evenly between the SNVs of this lineage.
    width = total_width*1/(len(selected_cryptic_mutations[cr_no-1].split(";")))
    linewidth = 2  # not used below

    plot3axes(axes[0:3], cr_no, wastewater_vcf_df, wastewater_isnv_df, "Wastewater", width, total_weeks)
    plot3axes(axes[3:6], cr_no, clinical_vcf_df, clinical_isnv_df, "Clinical", width, total_weeks)

    ax = axes[-1]
    positions = np.arange(total_weeks)

    _, _, _, _, _, _, total_sample_counts, valid_sample_counts, detected_sample_counts, _ = get_cr_plot_data(selected_cryptic_mutations[cr_no-1], wastewater_isnv_df)

    # Bars are overdrawn from largest to smallest count, so the visible grey
    # part is total - valid and the visible "not detected" part is valid - detected.
    ax.bar(positions, total_sample_counts, color=mcolors.CSS4_COLORS['lightgrey'], alpha=1, label='Insufficient Coverage')
    ax.bar(positions, valid_sample_counts, color=cmap(0), label=f'CR{cr_no} Not Detected')
    ax.bar(positions, detected_sample_counts, color=cmap(100), label=f'CR{cr_no} Detected')
    ax.legend()
    # Label every 4th week only.
    xticks = []
    xticklabels = []
    for i in range(total_weeks):
        if i%4 == 0:
            xticks.append(i)
            xticklabels.append(date_labels[i])

    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels, rotation=90)

    ax.set_xlim(-0.6, total_weeks-0.4)
    # NOTE(review): fixed upper limit of 39 - presumably the number of WWTPs; confirm.
    ax.set_ylim(0, 39)
    ax.set_ylabel('WWTP\nCount')

    # Only the bottom row keeps its x tick labels.
    for ax in axes[:-1]:
        ax.xaxis.set_tick_params(labelbottom=False)
        ax.tick_params(bottom=False)

    if save_fig:
        fig.savefig(f'/home/Users/yl181/wastewater/quarc_figures/high_dpi/{cr_label}_{cr_no}.png', transparent=False, facecolor='white', dpi=300, bbox_inches = "tight")
    else:
        fig.show()

In [None]:
def inner(cr_no, subfigure, fig, fig_label):
    """Draw the full multi-row panel of one cryptic lineage inside ``subfigure``.

    Layout (10 rows): 0-2 wastewater panels, 3 weekly WWTP sample counts,
    4 hidden spacer, 5-7 clinical panels, 8 hidden spacer, 9 lineage bar.

    Parameters
    ----------
    cr_no : int
        1-based index into the module-level ``selected_cryptic_mutations``.
    subfigure : matplotlib SubplotSpec
        Cell of the outer GridSpec to subdivide.
    fig : matplotlib.figure.Figure
        Figure the axes are added to.
    fig_label : str
        Panel letter drawn above the first row; the special value 'x' letters
        every visible row individually (a-h) instead.

    Reads the module-level ``date_labels`` (list), ``cmap`` and the
    wastewater/clinical VCF and iSNV frames.
    """
    # Local name shadows the function name inside this body only.
    inner = gridspec.GridSpecFromSubplotSpec(10, 1,
                                             height_ratios=[2, 2, 1.5, 1.5, 1.5, 2, 2, 1.5, 0.6, 0.5], 
                                             wspace=0.2,
                                             subplot_spec=subfigure)
    axes = []
    ax0 = fig.add_subplot(inner[0])
    axes.append(ax0)
    # Rows 1-8 share the weekly x-axis of row 0; the last row (lineage bar)
    # has its own x-axis.
    for j in range(1,10,1):
        if j != 9:
            ax = fig.add_subplot(inner[j], sharex=ax0)
        else:
            ax = fig.add_subplot(inner[j])
        axes.append(ax)
    
    # Rows 4 and 8 only provide vertical spacing.
    axes[4].set_visible(False)
    axes[-2].set_visible(False)

    total_weeks = len(date_labels)
    total_width = 0.8
    # Split the total bar width evenly between the SNVs of this lineage.
    width = total_width*1/(len(selected_cryptic_mutations[cr_no-1].split(";")))
    linewidth = 2  # not used below

    plot3axes(axes[0:3], cr_no, wastewater_vcf_df, wastewater_isnv_df, "Wastewater", width, total_weeks)
    plot3axes(axes[5:8], cr_no, clinical_vcf_df, clinical_isnv_df, "Clinical", width, total_weeks)
    
    lineage_plot(clinical_isnv_df, selected_cryptic_mutations[cr_no-1], axes[-1])
    
    ax = axes[3]
    positions = np.arange(total_weeks)

    _, _, _, _, _, _, total_sample_counts, valid_sample_counts, detected_sample_counts, _ = get_cr_plot_data(selected_cryptic_mutations[cr_no-1], wastewater_isnv_df)

    # Bars are overdrawn from largest to smallest count, so the visible grey
    # part is total - valid and the visible "not detected" part is valid - detected.
    ax.bar(positions, total_sample_counts, color=mcolors.CSS4_COLORS['lightgrey'], alpha=1, label='Insufficient Coverage')
    ax.bar(positions, valid_sample_counts, color=cmap(0), label=f'CR{cr_no} Not Detected')
    ax.bar(positions, detected_sample_counts, color=cmap(100), label=f'CR{cr_no} Detected')
    ax.legend(loc='upper left', framealpha=0.4)
    # Label every 4th week only.
    xticks = []
    xticklabels = []
    for i in range(total_weeks):
        if i%4 == 0:
            xticks.append(i)
            xticklabels.append(date_labels[i])

    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels, rotation=90)

    ax.set_xlim(-0.6, total_weeks-0.4)
    # NOTE(review): fixed upper limit of 39 - presumably the number of WWTPs; confirm.
    ax.set_ylim(0, 39)
    ax.set_ylabel('WWTP\nCount')

    # Only row 3 keeps its date labels; row 5 (first clinical row) shows tick
    # marks on its top edge instead.
    for idx, ax in enumerate(axes):
        if idx != 3:
            ax.xaxis.set_tick_params(labelbottom=False, bottom=False)
        if idx == 5:
            ax.xaxis.set_tick_params(labelbottom=False, bottom=False, top=True)
            ax.set_xticks(xticks)
            
    # Align all y-axis labels at the same horizontal offset.
    for idx, ax in enumerate(axes):
        ax.yaxis.set_label_coords(-0.032, 0.5)
    
    if fig_label != 'x':
        ax0.text(-0.06, 1.05, fig_label, transform=ax0.transAxes,
                fontsize=20, fontweight='bold', va='top')
    else:
        # Empty strings correspond to the two hidden spacer rows.
        labels = ['a', 'b', 'c', 'd', '', 'e', 'f', 'g', '', 'h']
        for idx, ax in enumerate(axes):
            ax.text(-0.065, 1.05, labels[idx], transform=ax.transAxes,
                fontsize=16, fontweight='bold', va='top')

# Figure 8

In [None]:
# One lineage per figure: a 1x1 outer grid whose single cell is filled by inner().
fig = plt.figure(figsize=(20, 12))
outer = gridspec.GridSpec(1, 1, wspace=0.15, hspace=0.05)

cr_nos = [8]
# 'x' makes inner() letter every visible row (a-h) instead of one panel label.
fig_labels = ['x']
for i in range(len(cr_nos)):
    inner(cr_nos[i], outer[i], fig, fig_labels[i])
fig.savefig(f'/home/Users/yl181/wastewater/quarc_figures/pdf/Figure_8.pdf', transparent=False, facecolor='white', dpi=300, bbox_inches = "tight")

# Supplementary 3

In [None]:
# One lineage per figure: a 1x1 outer grid whose single cell is filled by inner().
fig = plt.figure(figsize=(20, 12))
outer = gridspec.GridSpec(1, 1, wspace=0.15, hspace=0.05)

cr_nos = [3]
# 'x' makes inner() letter every visible row (a-h) instead of one panel label.
fig_labels = ['x']
for i in range(len(cr_nos)):
    inner(cr_nos[i], outer[i], fig, fig_labels[i])
fig.savefig(f'/home/Users/yl181/wastewater/quarc_figures/pdf/supp_figure_3.pdf', transparent=False, facecolor='white', dpi=300, bbox_inches = "tight")

# Supplementary 4

In [None]:
# One lineage per figure: a 1x1 outer grid whose single cell is filled by inner().
fig = plt.figure(figsize=(20, 12))
outer = gridspec.GridSpec(1, 1, wspace=0.15, hspace=0.05)

cr_nos = [5]
# 'x' makes inner() letter every visible row (a-h) instead of one panel label.
fig_labels = ['x']
for i in range(len(cr_nos)):
    inner(cr_nos[i], outer[i], fig, fig_labels[i])
fig.savefig(f'/home/Users/yl181/wastewater/quarc_figures/pdf/supp_figure_4.pdf', transparent=False, facecolor='white', dpi=300, bbox_inches = "tight")

# Supplementary 5

In [None]:
# One lineage per figure: a 1x1 outer grid whose single cell is filled by inner().
fig = plt.figure(figsize=(20, 12))
outer = gridspec.GridSpec(1, 1, wspace=0.15, hspace=0.05)

cr_nos = [12]
# 'x' makes inner() letter every visible row (a-h) instead of one panel label.
fig_labels = ['x']
for i in range(len(cr_nos)):
    inner(cr_nos[i], outer[i], fig, fig_labels[i])
fig.savefig(f'/home/Users/yl181/wastewater/quarc_figures/pdf/supp_figure_5.pdf', transparent=False, facecolor='white', dpi=300, bbox_inches = "tight")

In [None]:
# Preview of the stand-alone overview for CR1; nothing is written to disk (save_fig=False).
cr_plot(1, '_', save_fig=False)

# Figure 7 - Clinical Results

In [None]:
def parse_us_data(from_source=False):
    """
    Load the Crykey results for the US clinical samples.

    Args:
        from_source: if True, rebuild the table from the raw metadata CSV and the
            per-run filtered VCF files and recompute the boolean 'Crykey' column.
            If False (default), read the precomputed 'us_crykey.csv' from the
            notebook-level `fixed_results_dir`.

    Returns:
        A DataFrame with one row per metadata record.
        - from_source=True: the metadata plus 'Date', 'Week Start' and 'Crykey';
          the accession column keeps its name 'SRA Accession ID'.
        - from_source=False: the cached table with 'Collection_Date' and
          'Week Start' parsed as datetimes and 'SRA Accession ID' renamed to 'Site'.

    Raises:
        AssertionError: (from_source=True) if any run has no VCF file on disk.

    NOTE(review): the two branches return different column names for the accession
    ('SRA Accession ID' vs 'Site'); downstream helpers in this notebook read 'Site'.
    """
    if not from_source:
        clinical_metadata_df = pd.read_csv(os.path.join(fixed_results_dir, 'us_crykey.csv'), index_col=0)
        clinical_metadata_df['Collection_Date'] = pd.to_datetime(clinical_metadata_df['Date'], format='%Y-%m-%d')
        clinical_metadata_df['Week Start'] = pd.to_datetime(clinical_metadata_df['Week Start'], format='%Y-%m-%d')
        clinical_metadata_df = clinical_metadata_df.rename({'SRA Accession ID': 'Site'}, axis=1)
        return clinical_metadata_df

    clinical_metadata_df = pd.read_csv('/home/Users/yl181/wastewater/quarc_clinical_sampling/filtered_us_df.csv', index_col=0)
    clinical_metadata_df['Date'] = pd.to_datetime(clinical_metadata_df['Collection_Date'], format='%Y-%m-%d')
    # 'W-SUN' periods end on Sunday, so start_time is the Monday of the collection week.
    clinical_metadata_df['Week Start'] = clinical_metadata_df['Date'].dt.to_period('W-SUN').apply(lambda r: r.start_time)

    runidx2accession = clinical_metadata_df[['Run Index', 'SRA Accession ID']].drop_duplicates().set_index('Run Index')['SRA Accession ID'].to_dict()

    # Every (accession, SNV) pair called in a filtered VCF. A set gives O(1) lookups
    # below instead of re-filtering a DataFrame for each mutation of each row.
    called_snvs = set()
    valid_sample_dict = dict()
    for run_id in clinical_metadata_df['Run Index'].unique():
        accession = runidx2accession[run_id]
        vcf_path = f'/home/Users/yl181/wastewater/quarc_clinical_sampling/Harvest_Variant_Outputs_US/{run_id}_out/vcf_files_filtered/{accession}.vcf'
        if not os.path.exists(vcf_path):
            valid_sample_dict[run_id] = False
            continue

        valid_sample_dict[run_id] = True
        # Context manager so the handle is closed even if parsing fails.
        with open(vcf_path, 'r') as vcf_handle:
            for record in vcf.Reader(vcf_handle):
                ref = str(record.REF)
                pos = str(record.POS)
                alt = str(record.ALT[0])

                # Only single-nucleotide substitutions are encoded as <ref><pos><alt>.
                if len(ref) == 1 and len(alt) == 1:
                    called_snvs.add((accession, ref + pos + alt))

    # Every run listed in the metadata must have a VCF file.
    assert sum(valid_sample_dict.values()) == len(runidx2accession), \
        'missing VCF file for at least one run'

    crykey_calls = []
    for _, row in clinical_metadata_df.iterrows():
        # Minimum supporting read depth and combined frequency for a call.
        if row['Support DP'] >= 5 and row['Combined Freq'] >= 0.01:
            accession = row['SRA Accession ID']
            # All SNVs of the set must also be present in the sample's VCF.
            crykey_calls.append(all((accession, nt) in called_snvs
                                    for nt in row['Nt Mutations'].split(';')))
        else:
            crykey_calls.append(False)

    clinical_metadata_df['Crykey'] = crykey_calls

    return clinical_metadata_df

In [None]:
def get_heatmap_data(target_mutation, isnv_df, regions, min_depth=10, start_date=date(2021, 12, 6), end_date=date(2022, 1, 30), exclude=False):
    """
    Weekly summary of one cryptic mutation set, restricted by region (for CR plots).

    Args:
        target_mutation: value of the 'Nt Mutations' column to select.
        isnv_df: DataFrame with columns 'Nt Mutations', 'Region', 'Week Start',
            'Total DP', 'Crykey', 'Combined Freq' and 'Site'.
        regions: list of 'Region' values. With exclude=False rows from ANY of these
            regions are kept; with exclude=True rows from all of them are dropped.
            An empty list applies no region filter.
        min_depth: minimum 'Total DP' for a sample to count as valid.
        start_date, end_date: date range; both are snapped back to the Monday of
            their week and every week in between (inclusive) gets one entry.
        exclude: see `regions`.

    Returns:
        A 10-tuple of per-week lists:
        (means, mins, maxs, stds, prevalences, coverages, total_sample_counts,
        valid_sample_counts, detected_sample_counts, valid_samples).
        means/stds summarise 'Combined Freq' over valid samples with Crykey == True.
        mins/maxs are absolute distances of the min/max from the mean
        (error-bar lengths). Entries are NaN for weeks without detections.

    NOTE(review): prevalence is detections / ALL samples of the week, whereas
    get_boxplot_data divides by the depth-filtered count - confirm this is intended.
    """
    start_date = start_date - timedelta(days=start_date.weekday())
    end_date = end_date - timedelta(days=end_date.weekday())
    total_week_count = int((end_date - start_date).days/7) + 1
    week_offset = timedelta(days=7)

    selected_isnv_df = isnv_df[isnv_df['Nt Mutations'] == target_mutation].copy()

    # Chaining `== region` filters ANDs them and yields nothing for more than one
    # region, so membership is tested in a single step.
    if len(regions) > 0:
        in_regions = selected_isnv_df['Region'].isin(list(regions))
        selected_isnv_df = selected_isnv_df[~in_regions] if exclude else selected_isnv_df[in_regions]

    means = []
    mins = []
    maxs = []
    stds = []
    coverages = []
    prevalences = []

    total_sample_counts = []
    valid_sample_counts = []
    valid_samples = []
    detected_sample_counts = []

    for i in range(0, total_week_count):
        week_start = start_date + i*week_offset

        weekly_df = selected_isnv_df[selected_isnv_df['Week Start'] == pd.to_datetime(week_start)]
        total_sample_count = weekly_df.shape[0]
        # Build the depth mask from the weekly frame itself so it is index-aligned.
        weekly_df = weekly_df[weekly_df['Total DP'] >= min_depth]
        valid_sample_count = weekly_df.shape[0]
        valid_sample = weekly_df['Site'].values

        afs = list(weekly_df[weekly_df['Crykey'] == True]['Combined Freq'].values)
        if afs:
            mean = np.nanmean(afs)
            std = np.nanstd(afs)
            lower = abs(np.nanmin(afs) - mean)
            upper = abs(np.nanmax(afs) - mean)
        else:
            # No detections this week: report NaN without numpy empty-slice warnings.
            mean = std = lower = upper = np.nan

        prevalence = len(afs)/total_sample_count if total_sample_count else np.nan
        coverage = weekly_df['Total DP'].mean()

        means.append(mean)
        mins.append(lower)
        maxs.append(upper)
        stds.append(std)
        prevalences.append(prevalence)
        coverages.append(coverage)

        total_sample_counts.append(total_sample_count)
        valid_sample_counts.append(valid_sample_count)
        valid_samples.append(valid_sample)
        detected_sample_counts.append(len(afs))

    return means, mins, maxs, stds, prevalences, coverages, total_sample_counts, valid_sample_counts, detected_sample_counts, valid_samples

In [None]:
def get_heatmap_array(mut_label, us_clinical_isnv_df):
    """
    Build the rows of the state-by-week prevalence heatmap for one mutation set.

    One row of weekly prevalences is produced for each of the five listed states,
    followed by a final row covering every sample outside those states.

    Returns:
        (heatmap_array, states): six lists of weekly prevalences and the list of
        the five state names (the caller adds the label for the last row).
    """
    prevalence_slot = 4  # position of `prevalences` in get_heatmap_data's return tuple
    states = ['Maryland', 'Massachusetts', 'California', 'Colorado', 'Utah']

    heatmap_array = [
        get_heatmap_data(mut_label, us_clinical_isnv_df, [state])[prevalence_slot]
        for state in states
    ]

    # Last row: all remaining samples, i.e. every region except the named states.
    remainder = get_heatmap_data(mut_label, us_clinical_isnv_df, states, exclude=True)
    heatmap_array.append(remainder[prevalence_slot])

    return heatmap_array, states

In [None]:
def get_boxplot_data(target_mutation, isnv_df, min_depth=10, start_date=date(2021, 12, 6), end_date=date(2022, 1, 30)):
    """
    Weekly summary of one cryptic mutation set across all samples (for CR plots).

    Args:
        target_mutation: value of the 'Nt Mutations' column to select.
        isnv_df: DataFrame with columns 'Nt Mutations', 'Week Start', 'Total DP',
            'Crykey', 'Combined Freq' and 'Site'.
        min_depth: minimum 'Total DP' for a sample to count as valid.
        start_date, end_date: date range; both are snapped back to the Monday of
            their week and every week in between (inclusive) gets one entry.

    Returns:
        An 11-tuple of per-week lists:
        (means, mins, maxs, stds, prevalences, coverages, total_sample_counts,
        valid_sample_counts, detected_sample_counts, valid_samples, weekly_afs).
        prevalence is detections / valid (depth-filtered) samples. mins/maxs are
        absolute distances of the min/max from the mean. weekly_afs holds the raw
        'Combined Freq' values of the detections. Summary entries are NaN for
        weeks without detections.
    """
    start_date = start_date - timedelta(days=start_date.weekday())
    end_date = end_date - timedelta(days=end_date.weekday())
    total_week_count = int((end_date - start_date).days/7) + 1
    week_offset = timedelta(days=7)

    selected_isnv_df = isnv_df[isnv_df['Nt Mutations'] == target_mutation].copy()

    means = []
    mins = []
    maxs = []
    stds = []
    coverages = []
    prevalences = []
    weekly_afs = []

    total_sample_counts = []
    valid_sample_counts = []
    valid_samples = []
    detected_sample_counts = []

    for i in range(0, total_week_count):
        week_start = start_date + i*week_offset

        weekly_df = selected_isnv_df[selected_isnv_df['Week Start'] == pd.to_datetime(week_start)]
        total_sample_count = weekly_df.shape[0]
        # Build the depth mask from the weekly frame itself: a mask taken from the
        # full `isnv_df` has to be re-aligned and fails on a non-unique index.
        weekly_df = weekly_df[weekly_df['Total DP'] >= min_depth]
        valid_sample_count = weekly_df.shape[0]
        valid_sample = weekly_df['Site'].values

        afs = list(weekly_df[weekly_df['Crykey'] == True]['Combined Freq'].values)
        weekly_afs.append(afs)
        if afs:
            mean = np.nanmean(afs)
            std = np.nanstd(afs)
            lower = abs(np.nanmin(afs) - mean)
            upper = abs(np.nanmax(afs) - mean)
        else:
            # No detections this week: report NaN without numpy empty-slice warnings.
            mean = std = lower = upper = np.nan

        prevalence = len(afs)/valid_sample_count if valid_sample_count else np.nan
        coverage = weekly_df['Total DP'].mean()

        means.append(mean)
        mins.append(lower)
        maxs.append(upper)
        stds.append(std)
        prevalences.append(prevalence)
        coverages.append(coverage)

        total_sample_counts.append(total_sample_count)
        valid_sample_counts.append(valid_sample_count)
        valid_samples.append(valid_sample)
        detected_sample_counts.append(len(afs))

    return means, mins, maxs, stds, prevalences, coverages, total_sample_counts, valid_sample_counts, detected_sample_counts, valid_samples, weekly_afs

In [None]:
def get_xtick_wastewater_data(target_mutation, merged_df, start_date=date(2021, 12, 6), end_date=date(2022, 1, 30)):
    """
    Weekly x-axis labels and wastewater-detection flags for one mutation set.

    Args:
        target_mutation: value of the 'Nt Mutations' column to select.
        merged_df: DataFrame with 'Nt Mutations' and 'Week Start' columns.
        start_date, end_date: date range; both are snapped back to the Monday of
            their week.

    Returns:
        (x_ticklabels, ww_counts)
        x_ticklabels: 'YYYY-MM-DD' week boundaries, one more entry than there are
            weeks (the start of every week plus the end of the last one).
        ww_counts: one entry per week, 2 if `merged_df` has at least one row for
            that week and NaN otherwise. The plotting code in this notebook draws
            these as bars, so 2 acts as a bar height rather than a count.
    """
    start_date = start_date - timedelta(days=start_date.weekday())
    end_date = end_date - timedelta(days=end_date.weekday())
    # Clamp at zero so an inverted range yields no weeks instead of a NameError
    # from a loop variable that was never assigned.
    total_week_count = max(int((end_date - start_date).days/7) + 1, 0)
    week_offset = timedelta(days=7)

    ww_df = merged_df[merged_df['Nt Mutations'] == target_mutation]

    x_ticklabels = []
    ww_counts = []

    for i in range(total_week_count):
        week_start = start_date + i*week_offset
        x_ticklabels.append(week_start.strftime('%Y-%m-%d'))

        detected = (ww_df['Week Start'] == pd.to_datetime(week_start)).any()
        ww_counts.append(2 if detected else np.nan)

    # Closing boundary: the day after the last plotted week ends.
    closing_edge = start_date + total_week_count*week_offset
    x_ticklabels.append(closing_edge.strftime('%Y-%m-%d'))

    return x_ticklabels, ww_counts

In [None]:
def get_lineage_values(df, target_mutation, state):
    """
    Count samples per lineage group for the mutation sets present in `df`.

    Lineages are read from the 'lineage' column and collapsed into groups:
    'AY.*' -> 'Delta'; BA.1.15 / BA.1.17 / BA.1.18 / BA.1.20 / BA.1.1 together
    with their sub-lineages -> the parent name; everything else -> 'Omicron'.

    Args:
        df: DataFrame with 'Site', 'lineage' and 'Nt Mutations' columns.
        target_mutation: mutation-set label whose row is renamed to `state`.
        state: name given to the `target_mutation` row in the result.

    Returns:
        DataFrame with one row per mutation set and one column per lineage group,
        holding the number of 'Site' entries.
    """
    def _group_of(lineage):
        # Map a Pango lineage name onto its plotting group.
        if lineage.startswith('AY.'):
            return 'Delta'
        for parent in ('BA.1.15', 'BA.1.17', 'BA.1.18', 'BA.1.20', 'BA.1.1'):
            # Exact match or a true sub-lineage; the trailing dot keeps
            # e.g. BA.1.15 from being counted under BA.1.1.
            if lineage == parent or lineage.startswith(parent + '.'):
                return parent
        return 'Omicron'

    counts = pd.pivot_table(
        df,
        values='Site',
        index='lineage',
        columns='Nt Mutations',
        aggfunc='count',
    ).fillna(0)

    counts['Label'] = [_group_of(lineage) for lineage in counts.index]
    grouped = counts.groupby(counts['Label']).sum()

    return grouped.rename({target_mutation: state}, axis=1).transpose()
    
    

In [None]:
def get_lineage_plot_data(target_mutation, houston_clinical_isnv_df, us_clinical_isnv_df):
    """
    Collect per-region lineage-group counts of the samples supporting a mutation set.

    Regions are processed in the order Utah, Colorado, California, Massachusetts,
    Maryland, Houston. Regions without any supporting sample are left out.

    Side effect: a 'lineage' column is written onto `houston_clinical_isnv_df`,
    mapped from 'Site' through the notebook-level `houston_lineage_dict`.

    Returns:
        List of one-row DataFrames as produced by get_lineage_values, one per
        region that has Crykey-supported samples.
    """
    lineage_array = []
    for state in ('Utah', 'Colorado', 'California', 'Massachusetts', 'Maryland', 'Houston'):
        if state != 'Houston':
            keep = (
                (us_clinical_isnv_df['Region'] == state)
                & (us_clinical_isnv_df['Crykey'] == True)
                & (us_clinical_isnv_df['Nt Mutations'] == target_mutation)
            )
            supported = us_clinical_isnv_df[keep]
        else:
            # The Houston table has no lineage column yet; derive it from the site id.
            houston_clinical_isnv_df['lineage'] = houston_clinical_isnv_df['Site'].map(houston_lineage_dict)
            keep = (
                (houston_clinical_isnv_df['Crykey'] == True)
                & (houston_clinical_isnv_df['Nt Mutations'] == target_mutation)
            )
            supported = houston_clinical_isnv_df[keep].copy()

        per_state = get_lineage_values(supported, target_mutation, state).copy()
        if not per_state.empty:
            lineage_array.append(per_state)

    return lineage_array

In [None]:
def clinical_plot(cr_no, cr_label, save_fig=False):
    """
    Draw the three-panel clinical summary of one cryptic lineage (CR).

    Panels, top to bottom:
      1. weekly prevalence among Houston clinical samples (bars), boxplots of the
         intra-host allele frequency (twin axis), and grey bands for weeks in
         which the mutation set appears in `merged_df`;
      2. heatmap of weekly prevalence per US state;
      3. stacked bar of the consensus lineage groups of the supporting samples.

    Reads the notebook-level `selected_cryptic_mutations`, `merged_df`,
    `houston_clinical_isnv_df` and `us_clinical_isnv_df`.

    Args:
        cr_no: 1-based index into `selected_cryptic_mutations`.
        cr_label: panel letter drawn on the figure and used in the output file name.
        save_fig: if True, save 'supp_figure_6<cr_label>.pdf'; otherwise show the figure.
    """
    fig, axes  = plt.subplots(3, 1, figsize=(8, 10), sharex=False, constrained_layout=True,
                             gridspec_kw={'height_ratios': [5, 3, 3], 'hspace':0.35})
    target_mutation = selected_cryptic_mutations[cr_no-1]

    fontsize = 12
    # Barplot Prevalence Rate in Houston

    x_ticklabels, ww_counts = get_xtick_wastewater_data(target_mutation, merged_df)
    _, _, _, _, p_rates, _, _, _, _, _, freqs = get_boxplot_data(target_mutation, houston_clinical_isnv_df)

    ax = axes[0]
    # The helpers' default date range (2021-12-06 .. 2022-01-30) spans 8 weeks.
    x = np.arange(8)
    # ww_counts is 2 or NaN, so with ylim (0, 1) the grey bars fill the full height.
    ax.bar(x, ww_counts, color='gray', alpha=0.3, width=1)
    ax.bar(x, p_rates, color='c')

    ax.set_ylim(0,1)
    ax.set_ylabel('Prevalence Rate in\nHouston Samples', fontsize=fontsize)
    ax.set_title(", ".join(target_mutation.split(';')))

    # Boxplot of AF in Houston
    ax_twin = ax.twinx()
    ax_twin.boxplot(freqs,
                    positions=x,
                    showmeans=False)
    ax_twin.set_ylim(0, 0.6)
    ax_twin.set_ylabel('Intra-host\nAllele Frequency', fontsize=fontsize)

    # Ticks sit on the week boundaries (bar edges), hence 9 ticks for 8 bars.
    ax.set_xlim(-0.5,7.5)
    ax.set_xticks(np.arange(9)-0.5)
    ax.set_xticklabels(x_ticklabels, rotation=90, fontsize=12)
    ax.xaxis.set_ticks_position('none')
    ax.set_xlabel(' ', fontsize=fontsize)

    ax.text(-0.42, 0.98, f'CR{cr_no}',
                fontsize=20, fontweight='bold', va='top')

    ax.text(-0.25, 1.1, f'{cr_label}', transform=ax.transAxes,
            fontsize=20, fontweight='bold', va='top')
    # Heatmap

    heatmap_array, state_labels = get_heatmap_array(target_mutation, us_clinical_isnv_df)
    state_labels.append('Other States')

    sns.heatmap(heatmap_array, linewidth=2,
                ax=axes[1],
                vmin=0,
                vmax=1,
                cmap='flare',
                yticklabels=state_labels,
                xticklabels=[],
                annot=True, fmt=".3f",
                cbar_kws={'label': 'Prevalence Rate\nin non-Texas Samples'},)

    # The most recently added axes - presumably the heatmap colorbar.
    axes[1].figure.axes[-1].yaxis.label.set_size(fontsize)

    axes[1].xaxis.set_ticks_position('none')
    axes[1].yaxis.set_ticks_position('none')
    axes[1].set_yticklabels(state_labels, fontsize=fontsize)
    axes[1].set_xlabel(' ', fontsize=fontsize)

    # bar
    df = pd.concat(get_lineage_plot_data(target_mutation, houston_clinical_isnv_df, us_clinical_isnv_df), axis=0).fillna(0)
    # Normalise each region's counts to fractions.
    df = df.div(df.sum(axis=1), axis=0)
    # reindex (not df[[...]]) so a lineage group with no supporting sample becomes
    # a zero-width segment instead of raising KeyError. 'Delta' is left out, as before.
    df = df.reindex(columns=['BA.1.1', 'BA.1.15', 'BA.1.17', 'BA.1.18', 'BA.1.20', 'Omicron'], fill_value=0)

    ax = axes[2]
    y = df.index.to_list()
    lineages = df.columns.to_list()

    category_names = ['Delta', 'BA.1.1', 'BA.1.15', 'BA.1.17', 'BA.1.18', 'BA.1.20', 'Omicron']
    category_colors = plt.get_cmap('RdYlGn')(np.linspace(0.15, 0.85, len(category_names)))

    # Stack the segments left to right; sum_array is the running left edge.
    for idx, lineage in enumerate(lineages):
        color_idx = category_names.index(lineage)
        if idx == 0:
            ax.barh(y, df[lineage].to_list(), align='center', height=.5, label=lineage, edgecolor='black', color=category_colors[color_idx])
            sum_array = df[lineage].to_list()
        else:
            ax.barh(y, df[lineage].to_list(), align='center', height=.5, label=lineage, left=sum_array, edgecolor='black', color=category_colors[color_idx])
            sum_array = list(map(add, sum_array, df[lineage].to_list()))
    ax.set_yticks(y)
    ax.set_yticklabels(y, fontsize=fontsize)
    ax.set_xlabel('Consensus Pango Lineage of CR Supported Samples', fontsize=fontsize)

    ax.set_xticks([])

    ax.spines["top"].set_visible(False)
    ax.spines["bottom"].set_visible(False)

    ax.set_xlim(0,1)
    ax.legend(loc=4, bbox_to_anchor=(1.25, 0, 0, 0))

    if save_fig:
        fig.savefig(f'/home/Users/yl181/wastewater/quarc_figures/pdf/supp_figure_6{cr_label}.pdf', transparent=False, facecolor='white', dpi=300, bbox_inches = "tight")
    else:
        fig.show()

## Results

In [None]:
# Work on a copy: get_lineage_plot_data adds a 'lineage' column to the Houston frame
# it is given, and `clinical_isnv_df` (defined earlier) should stay untouched.
houston_clinical_isnv_df = clinical_isnv_df.copy()
# from_source defaults to False, so this reads the cached 'us_crykey.csv'.
us_clinical_isnv_df = parse_us_data()

In [None]:
%%time
# Filter the US table with `filtering_clinical_samples` (defined elsewhere in the notebook).
# NOTE(review): the cell overwrites its own input, so re-running it filters twice.
us_clinical_isnv_df = filtering_clinical_samples(us_clinical_isnv_df, strand_bias=True)

## Figure

In [None]:
# Figure 7: one row per cryptic mutation set, laid out in two columns of `num` rows.
# Each row has a wide panel (weekly prevalence bars in Houston clinical samples,
# wastewater-detection markers, allele-frequency mean with min/max whiskers on a twin
# axis) and a narrow panel (stacked lineage composition of the supporting Houston samples).
category_names = ['Delta', 'BA.1.1', 'BA.1.15', 'BA.1.17', 'BA.1.18', 'BA.1.20', 'Omicron']
category_colors = plt.get_cmap('RdYlGn')(np.linspace(0.15, 0.85, len(category_names)))
    
fontsize = 14
# Half of the mutation sets go in each column.
# NOTE(review): with an odd number of sets the middle one is plotted in neither column.
num = int(len(selected_cryptic_mutations)/2)
fig, axes  = plt.subplots(num, 4, figsize=(12, num*2), sharex=False, constrained_layout=True, gridspec_kw={'width_ratios': [5, 1, 5, 1], 'wspace':0.35})

# The helpers' default date range (2021-12-06 .. 2022-01-30) spans 8 weeks.
x = np.arange(8)

plotted_mut = []

# Right-hand column pair (axes columns 2 and 3): the last `num` mutation sets.
for ax_idx, mut_label in enumerate(selected_cryptic_mutations[-num:]):
    ax = axes[ax_idx][2]
    target_mutation = mut_label
    x_ticklabels, ww_counts = get_xtick_wastewater_data(target_mutation, merged_df)
    means, mins, maxs, stds, p_rates, _, _, valid_sample_counts, _, _, freqs = get_boxplot_data(target_mutation, houston_clinical_isnv_df)

    # ww_counts is 2 or NaN: the grey bars span the full height in detection weeks.
    ax.bar(x, ww_counts, color='gray', alpha=0.2, width=1)
    bar = ax.bar(x, p_rates, color='c')
    # NOTE(review): these markers sit at y=2, above the y-limit of 1.3 set on the next line.
    ax.scatter(x, ww_counts, marker="P", s=60, color='black')
    ax.set_ylim(0,1.3)
    ax.set_title(", ".join(mut_label.split(';')))
    ax.xaxis.set_ticks_position('none')
    plotted_mut.append(mut_label)

    # NOTE(review): the offset 7 assumes num == 6 (12 mutation sets).
    ax.text(0.01, 0.95, f'CR{ax_idx+7}', transform=ax.transAxes,
        fontsize=14, fontweight='bold', va='top')
    ax_twin = ax.twinx()
    # mins/maxs are distances from the mean, i.e. asymmetric error-bar lengths.
    ax_twin.errorbar(x, means, [mins, maxs], linestyle='dotted', marker='o', markersize=4)
    ax_twin.set_ylim(0, 0.6)
    ax_twin.set_yticks([0,0.25,0.5])
    ax_twin.set_yticklabels(['0','.25','.50'])
    ax.set_xlim(-0.5,7.5)
    ax.set_xticks(np.arange(9)-0.5)
    ax.set_xticklabels([])
    
    ax2 = axes[ax_idx][3]
    bottom = 0
    df = pd.concat(get_lineage_plot_data(target_mutation, houston_clinical_isnv_df, us_clinical_isnv_df), axis=0).fillna(0)
    # Normalise each region's counts to fractions.
    df = df.div(df.sum(axis=1), axis=0)
    # NOTE(review): raises KeyError when Houston has no supporting sample, because
    # get_lineage_plot_data omits regions whose result is empty.
    temp_lineage_dict = df.loc['Houston'].to_dict()
    # Stack one segment per lineage group; `bottom` is the running height.
    for idx in temp_lineage_dict:
        color_idx = category_names.index(idx)
        p = ax2.bar(0, temp_lineage_dict[idx], 1, label=idx, bottom=bottom, edgecolor='black', color=category_colors[color_idx])
        bottom += temp_lineage_dict[idx]
        
    
    ax2.set_ylim(0,1)
    ax2.set_xlim(-0.2,0.2)
    ax2.xaxis.set_ticks_position('none')
    ax2.set_yticks([])
    ax2.set_xticklabels([])
    
# After the loop `ax` / `ax2` still point at the last row, so only the bottom row of
# this column gets x tick labels.
ax2.set_xticks([0])
ax2.set_xticklabels(['Lineage'], rotation=90, fontsize=12)    
#x_ticklabels.append((sampling_date_start+8*week_offset).strftime('%Y-%m-%d'))
ax.set_xlim(-0.5,7.5)
ax.set_xticks(np.arange(9)-0.5)
ax.set_xticklabels(x_ticklabels, rotation=90, fontsize=12)

# Shared legend, built by hand in reversed category order and anchored to the
# bottom-right lineage panel.
patches = []
for j, facecolor in enumerate(category_colors[::-1]):
    patches.append(Patch(facecolor=facecolor, edgecolor='black',
                         label=category_names[::-1][j]))

ax2.legend(handles=patches, loc=4, bbox_to_anchor=(2.75, 0, 0, 0))

# Left-hand column pair (axes columns 0 and 1): the first `num` mutation sets.
# Same drawing steps as the loop above, with CR labels starting at 1.
for ax_idx, mut_label in enumerate(selected_cryptic_mutations[:num]):
    ax = axes[ax_idx][0]
    target_mutation = mut_label
    x_ticklabels, ww_counts = get_xtick_wastewater_data(target_mutation, merged_df)
    means, mins, maxs, stds, p_rates, _, _, valid_sample_counts, _, _, freqs = get_boxplot_data(target_mutation, houston_clinical_isnv_df)

    ax.bar(x, ww_counts, color='gray', alpha=0.2, width=1)
    bar = ax.bar(x, p_rates, color='c')
    ax.scatter(x, ww_counts, marker="P", s=60, color='black')
    ax.set_ylim(0,1.3)
    ax.set_title(", ".join(mut_label.split(';')))
    ax.xaxis.set_ticks_position('none')
    plotted_mut.append(mut_label)

    ax.text(0.01, 0.95, f'CR{ax_idx+1}', transform=ax.transAxes,
        fontsize=14, fontweight='bold', va='top')
    ax_twin = ax.twinx()
    ax_twin.errorbar(x, means, [mins, maxs], linestyle='dotted', marker='o', markersize=4,)
    ax_twin.set_ylim(0, 0.6)
    ax_twin.set_yticks([0,0.25,0.5])
    ax_twin.set_yticklabels(['0','.25','.50'])
    ax.set_xlim(-0.5,7.5)
    ax.set_xticks(np.arange(9)-0.5)
    ax.set_xticklabels([])
    
    ax2 = axes[ax_idx][1]
    bottom = 0
    df = pd.concat(get_lineage_plot_data(target_mutation, houston_clinical_isnv_df, us_clinical_isnv_df), axis=0).fillna(0)
    df = df.div(df.sum(axis=1), axis=0)
    # NOTE(review): same KeyError risk as above when Houston has no supporting sample.
    temp_lineage_dict = df.loc['Houston'].to_dict()
    for idx in temp_lineage_dict:
        color_idx = category_names.index(idx)
        p = ax2.bar(0, temp_lineage_dict[idx], 1, label=idx, bottom=bottom, edgecolor='black', color=category_colors[color_idx])
        bottom += temp_lineage_dict[idx]
    
    ax2.set_ylim(0,1)
    ax2.set_xlim(-0.2,0.2)
    ax2.xaxis.set_ticks_position('none')
    ax2.set_yticks([])
    ax2.set_xticklabels([])
    
# Tick labels for the bottom row of the left-hand column.
ax2.set_xticks([0])
ax2.set_xticklabels(['Lineage'], rotation=90, fontsize=12)
ax.set_xlim(-0.5,7.5)
ax.set_xticks(np.arange(9)-0.5)
ax.set_xticklabels(x_ticklabels, rotation=90, fontsize=12)

# Figure-level axis titles shared by all rows.
fig.text(-0.01, 0.5, 'Prevalence Rate in Clinical Samples', va='center', rotation='vertical', fontsize=12)
fig.text(.815, 0.5, 'Mean Allele Frequency', va='center', rotation='vertical', fontsize=12)

# NOTE(review): absolute, user-specific output path.
fig.savefig(f'/home/Users/yl181/wastewater/quarc_figures/pdf/Figure_7.pdf', transparent=False, facecolor='white', dpi=300, bbox_inches = "tight")

# Supplementary Figure 6

In [None]:
def supp_figure_clinical_plot(cr_no, cr_label, fig):
    """
    Draw the three-panel clinical summary of one cryptic lineage (CR) into `fig`.

    Same panels as clinical_plot (Houston prevalence + allele-frequency boxplots,
    per-state prevalence heatmap, lineage composition), but drawn on a figure
    supplied by the caller and never saved or shown here.

    Reads the notebook-level `selected_cryptic_mutations`, `merged_df`,
    `houston_clinical_isnv_df` and `us_clinical_isnv_df`.

    Args:
        cr_no: 1-based index into `selected_cryptic_mutations`.
        cr_label: panel letter drawn above the first panel.
        fig: matplotlib figure to draw into; it is switched to constrained layout
            when it offers a setter for that.
    """
    # Figure.subplots() has no `constrained_layout` keyword (only pyplot.subplots()
    # forwards it to the figure it creates), so passing it raised TypeError.
    # The layout is enabled on the figure itself instead.
    if hasattr(fig, 'set_layout_engine'):
        fig.set_layout_engine('constrained')
    elif hasattr(fig, 'set_constrained_layout'):
        fig.set_constrained_layout(True)
    axes = fig.subplots(3, 1, sharex=False,
                        gridspec_kw={'height_ratios': [5, 3, 3], 'hspace':0.35})
    target_mutation = selected_cryptic_mutations[cr_no-1]

    fontsize = 12
    # Barplot Prevalence Rate in Houston

    x_ticklabels, ww_counts = get_xtick_wastewater_data(target_mutation, merged_df)
    _, _, _, _, p_rates, _, _, _, _, _, freqs = get_boxplot_data(target_mutation, houston_clinical_isnv_df)

    ax = axes[0]
    # The helpers' default date range (2021-12-06 .. 2022-01-30) spans 8 weeks.
    x = np.arange(8)
    # ww_counts is 2 or NaN, so with ylim (0, 1) the grey bars fill the full height.
    ax.bar(x, ww_counts, color='gray', alpha=0.3, width=1)
    ax.bar(x, p_rates, color='c')

    ax.set_ylim(0,1)
    ax.set_ylabel('Prevalence Rate in\nHouston Samples', fontsize=fontsize)
    ax.set_title(", ".join(target_mutation.split(';')))

    # Boxplot of AF in Houston
    ax_twin = ax.twinx()
    ax_twin.boxplot(freqs,
                    positions=x,
                    showmeans=False)
    ax_twin.set_ylim(0, 0.6)
    ax_twin.set_ylabel('Intra-host\nAllele Frequency', fontsize=fontsize)

    # Ticks sit on the week boundaries (bar edges), hence 9 ticks for 8 bars.
    ax.set_xlim(-0.5,7.5)
    ax.set_xticks(np.arange(9)-0.5)
    ax.set_xticklabels(x_ticklabels, rotation=90, fontsize=12)
    ax.xaxis.set_ticks_position('none')
    ax.set_xlabel(' ', fontsize=fontsize)

    ax.text(-0.42, 0.98, f'CR{cr_no}',
                fontsize=20, fontweight='bold', va='top')

    ax.text(-0.25, 1.1, f'{cr_label}', transform=ax.transAxes,
            fontsize=20, fontweight='bold', va='top')
    # Heatmap

    heatmap_array, state_labels = get_heatmap_array(target_mutation, us_clinical_isnv_df)
    state_labels.append('Other States')

    sns.heatmap(heatmap_array, linewidth=2,
                ax=axes[1],
                vmin=0,
                vmax=1,
                cmap='flare',
                yticklabels=state_labels,
                xticklabels=[],
                annot=True, fmt=".3f",
                cbar_kws={'label': 'Prevalence Rate\nin non-Texas Samples'},)

    # The most recently added axes - presumably the heatmap colorbar.
    axes[1].figure.axes[-1].yaxis.label.set_size(fontsize)

    axes[1].xaxis.set_ticks_position('none')
    axes[1].yaxis.set_ticks_position('none')
    axes[1].set_yticklabels(state_labels, fontsize=fontsize)
    axes[1].set_xlabel(' ', fontsize=fontsize)

    # bar
    df = pd.concat(get_lineage_plot_data(target_mutation, houston_clinical_isnv_df, us_clinical_isnv_df), axis=0).fillna(0)
    # Normalise each region's counts to fractions.
    df = df.div(df.sum(axis=1), axis=0)
    # reindex (not df[[...]]) so a lineage group with no supporting sample becomes
    # a zero-width segment instead of raising KeyError. 'Delta' is left out, as before.
    df = df.reindex(columns=['BA.1.1', 'BA.1.15', 'BA.1.17', 'BA.1.18', 'BA.1.20', 'Omicron'], fill_value=0)

    ax = axes[2]
    y = df.index.to_list()
    lineages = df.columns.to_list()

    category_names = ['Delta', 'BA.1.1', 'BA.1.15', 'BA.1.17', 'BA.1.18', 'BA.1.20', 'Omicron']
    category_colors = plt.get_cmap('RdYlGn')(np.linspace(0.15, 0.85, len(category_names)))

    # Stack the segments left to right; sum_array is the running left edge.
    for idx, lineage in enumerate(lineages):
        color_idx = category_names.index(lineage)
        if idx == 0:
            ax.barh(y, df[lineage].to_list(), align='center', height=.5, label=lineage, edgecolor='black', color=category_colors[color_idx])
            sum_array = df[lineage].to_list()
        else:
            ax.barh(y, df[lineage].to_list(), align='center', height=.5, label=lineage, left=sum_array, edgecolor='black', color=category_colors[color_idx])
            sum_array = list(map(add, sum_array, df[lineage].to_list()))
    ax.set_yticks(y)
    ax.set_yticklabels(y, fontsize=fontsize)
    ax.set_xlabel('Consensus Pango Lineage of CR Supported Samples', fontsize=fontsize)

    ax.set_xticks([])

    ax.spines["top"].set_visible(False)
    ax.spines["bottom"].set_visible(False)

    ax.set_xlim(0,1)
    ax.legend(loc=4, bbox_to_anchor=(1.25, 0, 0, 0))

# Supplementary 6a

In [None]:
# CR 5 as panel 'a'; save_fig=True writes supp_figure_6a.pdf.
clinical_plot(5, 'a', True)

# Supplementary 6b

In [None]:
# CR 8 as panel 'b'; save_fig=True writes supp_figure_6b.pdf.
clinical_plot(8, 'b', True)

# Figure 5 - Duration and Rarity of Cryptic Lineages Found in Houston Wastewater

In [None]:
def create_supp_figure1_data(selected_data, step=100):
    """
    Prepare the marginal bar-chart data for the duration/rarity figure.

    Args:
        selected_data: DataFrame with 'Present Weeks', 'Mean Site Occurance' and
            'GISAID Count' columns.
        step: width of the 'GISAID Count' bins.

    Returns:
        (week_count, site_count_per_week, week_count_height, occ_list, occ_bar_height)
        week_count: sorted distinct 'Present Weeks' values.
        site_count_per_week: for each of those values, the array of
            'Mean Site Occurance' of the matching rows.
        week_count_height: number of matching rows per value.
        occ_list: lower edges of 14 consecutive bins, 0, step, ..., 13*step.
        occ_bar_height: number of rows with 'GISAID Count' in [edge, edge + step).
    """
    week_count = sorted(selected_data['Present Weeks'].unique())

    site_count_per_week = [
        selected_data[selected_data['Present Weeks'] == n_weeks]['Mean Site Occurance'].values
        for n_weeks in week_count
    ]
    week_count_height = [len(values) for values in site_count_per_week]

    occ_list = np.arange(0, 14, 1) * step
    occ_bar_height = [
        selected_data[
            (selected_data['GISAID Count'] >= lower_edge)
            & (selected_data['GISAID Count'] < lower_edge + step)
        ].shape[0]
        for lower_edge in occ_list
    ]

    return week_count, site_count_per_week, week_count_height, occ_list, occ_bar_height

In [None]:
# Bin width for the 'GISAID Count' histogram; reused as the bar height in the plotting cell.
step = 100
week_count, site_count_per_week, week_count_height, occ_list, occ_bar_height = create_supp_figure1_data(filtered_cryptic_df, step)

In [None]:
# Inspection only: list the columns available for the figure below.
filtered_cryptic_df.columns

In [None]:
# Figure 5: scatter of detection duration ('Present Weeks') against 'GISAID Count' for
# each cryptic lineage, with marginal bar charts below (counts per duration) and to the
# right (counts per GISAID bin). The bottom-right cell of the 2x2 grid is hidden.
fontsize = 12
fig, axes  = plt.subplots(2, 2, figsize=(9, 9),
                          gridspec_kw={'height_ratios': [3, 1], 'width_ratios': [3, 1], 'hspace':0.08, 'wspace':0.08})

# Main panel: marker area encodes 'Mean Site Occurance' (x40), colour encodes 'Mean Allele Freq'.
ax = axes[0][0]
scatter = ax.scatter(filtered_cryptic_df['Present Weeks'], filtered_cryptic_df['GISAID Count'],
           s=filtered_cryptic_df['Mean Site Occurance']*40,
           c=filtered_cryptic_df['Mean Allele Freq'],
           cmap='viridis',
           alpha=0.4)

ax.tick_params(axis='both', which='major', labelsize=fontsize)
ax.set_ylabel('Occurrence of CR in GISAID EpiCoV', fontsize=fontsize)
ax.set_ylim(-10,1400)
ax.set_xlim(-2,38)
# Tick labels are hidden here; the bar chart below uses the same x-limits and shows them.
ax.set_xticklabels([])

# Colour legend. It is re-added as an artist because the second ax.legend() call
# below would otherwise replace it.
legend1 = ax.legend(*scatter.legend_elements(num=7),
                    loc="upper right", title="Mean Allele Freq")
ax.add_artist(legend1)

# Size legend: scatter `s` is an area (40 * n) while Line2D `markersize` is a length,
# hence sqrt(40 * n) for n = 2, 4, 6.
legend_elements = [Line2D([0], [0], marker='o', color='w', label='2',
                          markerfacecolor='gray', markersize=np.sqrt(80)),
                   Line2D([0], [0], marker='o', color='w', label='4',
                          markerfacecolor='gray', markersize=np.sqrt(160)),
                   Line2D([0], [0], marker='o', color='w', label='6',
                          markerfacecolor='gray', markersize=np.sqrt(240))]
ax.legend(handles=legend_elements, loc='upper center', title="Site Count\nper Week", bbox_to_anchor=(0.64, 1))
                   
# Bottom marginal: number of CRs per detection duration.
ax = axes[1][0]
bars = ax.bar(week_count, week_count_height, log=False)
ax.tick_params(axis='both', which='major', labelsize=fontsize)
ax.set_ylabel('CR Count', fontsize=fontsize)
ax.set_xlabel('Number of Weeks being Detected', fontsize=fontsize)
ax.set_xlim(-2,38)

# Right marginal: number of CRs per GISAID-count bin, using the same y-limits as the scatter.
ax = axes[0][1]
hbars = ax.barh(occ_list, occ_bar_height, align='edge', height=step*0.9, log=False)
ax.set_ylim(-10,1400)
ax.set_xlim(0,500)
ax.set_yticklabels([])
ax.tick_params(axis='both', which='major', labelsize=fontsize)
ax.invert_xaxis()
ax.set_xlabel('CR Count', fontsize=fontsize)

axes[1][1].set_visible(False)

# NOTE(review): absolute, user-specific output path.
fig.savefig('/home/Users/yl181/wastewater/quarc_figures/pdf/figure_5.pdf', transparent=False, facecolor='white', dpi=300, bbox_inches = "tight")

# Figure 2 - Performance Benchmark

In [None]:
def create_supp_figure5_boxplot_data(fixed_results_dir):
    """
    Load the query-time benchmark and build the boxplot inputs.

    'benchmark_result_50.txt' in `fixed_results_dir` holds one line per bin, each a
    comma-separated list of query times. The benchmark samples 50 queries per bin,
    with bins defined by the occurrence of the query in GISAID:
    =0, 1-5, 6-10, 11-25, 26-50, 51-100, 101-250, 251-500, >500.

    Returns:
        (process_time, mean_query_times): the list of per-bin time lists and the
        mean of each bin.
    """
    benchmark_path = os.path.join(fixed_results_dir, 'benchmark_result_50.txt')
    with open(benchmark_path, 'r') as performance_benchmark_f:
        process_time = [
            [float(field) for field in row.strip().split(",")]
            for row in performance_benchmark_f
        ]

    mean_query_times = [np.mean(time_bin) for time_bin in process_time]

    return process_time, mean_query_times

In [None]:
def create_supp_figure5_hist_data(gisaid_count_range, mutation_rarity_df):
    """
    Generate histogram data by creating bins defined by the occurrence of the
    query in GISAID.

    Parameters
    ----------
    gisaid_count_range : list of int
        Non-empty list of inclusive upper bounds of the bins, in ascending
        order. The first bin covers every count from 0 up to the first bound;
        one extra open-ended bin ("> last bound") is appended at the end.
    mutation_rarity_df : pandas.DataFrame
        Table with a numeric 'GISAID Count' column; each row is counted once.

    Returns
    -------
    hist_bars : list of int
        Number of rows falling in each bin.
    labels : list of str
        Text label of each bin ("0", "1-5", ..., "> 500").
    total_set_count : int
        Sum of `hist_bars`.
    cum_precentages : list of float
        Cumulative percentage of rows up to and including each bin
        (all 0.0 when there are no rows at all).
    """
    gisaid_counts = mutation_rarity_df['GISAID Count']

    hist_bars = []
    labels = []
    # -1 as the exclusive lower bound makes the first bin include a count of 0.
    lower_range = -1
    for idx, upper_range in enumerate(gisaid_count_range):
        if idx == 0:
            # Identify the first bin by position instead of relying on the
            # wrap-around of gisaid_count_range[-1], which breaks for a
            # single-element range; label it by what it really covers.
            labels.append("0" if upper_range == 0 else f"0-{upper_range}")
        else:
            labels.append(f"{lower_range+1}-{upper_range}")
        in_bin = (gisaid_counts > lower_range) & (gisaid_counts <= upper_range)
        hist_bars.append(int(in_bin.sum()))
        lower_range = upper_range

    # open-ended last bin
    hist_bars.append(int((gisaid_counts > gisaid_count_range[-1]).sum()))
    labels.append(f"> {gisaid_count_range[-1]}")

    # cumulative percentage
    total_set_count = sum(hist_bars)

    cum_precentages = []
    running_count = 0
    for bar_height in hist_bars:
        running_count += bar_height
        if total_set_count:
            cum_precentages.append(running_count/total_set_count*100)
        else:
            # no rows at all: avoid dividing by zero
            cum_precentages.append(0.0)

    return hist_bars, labels, total_set_count, cum_precentages

In [None]:
# Number of queries per bin; only used in the title of the process-time panel.
sampling_size = 50
# Inclusive upper bounds of the GISAID-count bins; create_supp_figure5_hist_data
# appends one more open-ended bin ("> 500") after the last bound.
gisaid_count_range = [0, 5, 10, 25, 50, 100, 250, 500]

# Histogram data (needs mutation_rarity_df) and benchmark timings (read from
# fixed_results_dir); both names are defined earlier in the notebook.
hist_bars, labels, total_set_count, cum_precentages = create_supp_figure5_hist_data(gisaid_count_range, mutation_rarity_df)
process_time, mean_query_times = create_supp_figure5_boxplot_data(fixed_results_dir)

In [None]:
# Fraction of all counted rows that fall in the first bin (label "0").
hist_bars[0]/total_set_count

In [None]:
# Fraction of all counted rows in the first five bins, i.e. with a
# GISAID Count of at most 50 for the bin bounds defined above.
sum(hist_bars[0:5])/total_set_count

In [None]:
# Figure 2: two stacked panels with independent x-axes.
# Top (a): histogram of the GISAID-count bins with a cumulative-percentage line.
# Bottom (b): boxplot of the benchmark process times per bin with the bin means.
fontsize = 14
fig, axes  = plt.subplots(2, 1, figsize=(12, 9),
                          sharex=False,
                          gridspec_kw={'height_ratios': [3, 2], 'hspace':0.6})

# --- Panel a: one bar per bin, placed at integer positions 0..len(hist_bars)-1
ax = axes[0]
ax.bar(np.arange(len(hist_bars)), hist_bars,
      linewidth = 2,
      edgecolor = 'black')

ax.set_xticks(np.arange(len(hist_bars)))
ax.set_xticklabels(labels, rotation=90)
ax.tick_params(axis='both', which='major', labelsize=fontsize)
ax.set_ylabel('Count of CR candidates', fontsize=fontsize)


# Second y-axis sharing the same x positions: cumulative percentage drawn as a
# cyan line with circle markers ('o-c').
ax_twinx = ax.twinx()
ax_twinx.plot(np.arange(len(hist_bars)), cum_precentages, 
              'o-c',
              linewidth=2, markersize=10)
ax_twinx.tick_params(axis='both', which='major', labelsize=fontsize)
ax_twinx.set_ylabel('Cumulative percentage (%)', fontsize=fontsize)
# Upper limit of 110 leaves headroom above the 100% point.
ax_twinx.set_ylim(0,110)

ax.set_title(f'Occurrence distribution in GISAID EpiCoV for CR candidates', fontsize=fontsize)

# Panel letter, positioned in axes coordinates just outside the top-left corner.
ax.text(-0.07, 1.08, 'a', transform=ax.transAxes,
            fontsize=20, fontweight='bold', va='top')

# --- Panel b: one box per entry of process_time, at the same integer
# positions as the bars in panel a so that both panels can use the same labels.
ax = axes[1]
bp = ax.boxplot(process_time,
                positions=np.arange(len(hist_bars)),
                showfliers=True,
                meanline=False,
                widths=0.8)

ax.set_xticks(np.arange(len(hist_bars)))
ax.set_xticklabels(labels, rotation=90)
ax.tick_params(axis='both', which='major', labelsize=fontsize)
ax.set_ylabel('Process time (s)', fontsize=fontsize)
ax.set_xlabel(f'Occurrence in GISAID EpiCoV', fontsize=fontsize)
ax.set_title(f'Process time per queries in each bin (sampling size n={sampling_size})', fontsize=fontsize)
# NOTE: the y-range is fixed to 0-4, so any value above 4 is not visible.
ax.set_ylim(0,4)

# Overlay the per-bin mean as a blue line with circle markers ('o-b').
ax.plot(np.arange(len(hist_bars)), mean_query_times, 
              'o-b',
              linewidth=2, markersize=10)

ax.text(-0.07, 1.08, 'b', transform=ax.transAxes,
            fontsize=20, fontweight='bold', va='top')


fig.tight_layout()

# Write the figure to PDF with an opaque white background.
fig.savefig('/home/Users/yl181/wastewater/quarc_figures/pdf/figure_2.pdf', transparent=False, facecolor='white', dpi=300, bbox_inches = "tight")

# Supplementary Figure 1

In [None]:
# Load the filtered PRJNA764181 metadata table from fixed_results_dir.
houston_clinical_metadata = pd.read_csv(os.path.join(fixed_results_dir, 'PRJNA764181.filtered.csv'))

In [None]:
# Parse Collection_Date (expected as YYYY-MM-DD) into datetimes, in place, so
# the column can be compared against timestamps.
houston_clinical_metadata['Collection_Date'] = pd.to_datetime(houston_clinical_metadata['Collection_Date'], format='%Y-%m-%d')

In [None]:
# Count the clinical samples collected in each of eight consecutive one-week
# windows that start at sampling_date_start.
sampling_date_start = pd.to_datetime('12/06/2021')
week_offset = timedelta(days = 7)

# Nine boundaries delimit the eight windows [edge_k, edge_k+1).
week_edges = [sampling_date_start + k*week_offset for k in range(9)]
collection_dates = houston_clinical_metadata['Collection_Date']

sample_counts = []
for start, end in zip(week_edges[:-1], week_edges[1:]):
    # window start is inclusive, window end is exclusive
    in_week = (collection_dates >= start) & (collection_dates < end)
    sample_counts.append(int(in_week.sum()))

# One label per boundary: the eight window starts plus the end of the last one.
x_ticklabels = [edge.strftime('%Y-%m-%d') for edge in week_edges]

In [None]:
# Supplementary Figure 1: bar chart of the weekly sample counts computed above.
# NOTE: fontsize is assigned here but not used by any call in this cell.
fontsize = 14
fig, ax  = plt.subplots(1, 1, figsize=(6, 3), sharex=True, constrained_layout=True)

# One bar per week, centred on the integers 0..7.
x = np.arange(8)

ax.bar(x, sample_counts, color='black', alpha=1, width=0.8)

ax.set_ylim(0,2000)
ax.set_xlim(-0.5,7.5)
# Nine ticks at -0.5, 0.5, ..., 7.5 sit between the bars, so the nine date
# labels in x_ticklabels mark the week boundaries rather than the bar centres.
ax.set_xticks(np.arange(9)-0.5)
ax.set_xticklabels(x_ticklabels, rotation=90)

ax.set_title('Houston Clinical Sample Count')

# Write the figure to PDF with an opaque white background.
fig.savefig('/home/Users/yl181/wastewater/quarc_figures/pdf/supp_figure_1.pdf', transparent=False, facecolor='white', dpi=300, bbox_inches = "tight")

# Supplementary Figure 2

In [None]:
def plot_supp_figure2(min_count=70):
    """
    Plot Supplementary Figure 2: number of metadata records per USA region,
    excluding Texas, and save it as supp_figure_2.pdf.

    Reads 'meta_non_tx.tsv' from the notebook-level `fixed_results_dir`.
    Regions with at least `min_count` records get their own bar; all remaining
    regions are pooled into a final "Other N States" bar, where N is the
    number of pooled regions found in the data.

    Parameters
    ----------
    min_count : int, optional
        Minimum number of records a region needs to be shown on its own
        (default 70).

    Returns
    -------
    None
        The figure is written to disk; nothing is returned.
    """
    non_tx_metadata = pd.read_csv(os.path.join(fixed_results_dir, 'meta_non_tx.tsv'), sep='\t', low_memory=False)
    # Missing values become the string 'None' so they can be matched below and
    # so that every remaining row is counted by the pivot table.
    non_tx_metadata = non_tx_metadata.fillna('None')

    # Keep USA records whose region is known and is not Texas.
    is_usa = non_tx_metadata['Country'] == 'USA'
    is_other_known_region = ~non_tx_metadata['Region'].isin(['Texas', 'None'])
    non_tx_metadata = non_tx_metadata[is_usa & is_other_known_region]

    # Record count per region (rows of the pivot table are sorted by region name).
    region_df = pd.pivot_table(non_tx_metadata, values='Repository', index='Region', aggfunc='count')
    region_df = region_df.rename({'Repository': 'Count'}, axis=1)

    selected_region_df = region_df[region_df['Count'] >= min_count]
    other_count = region_df['Count'].sum() - selected_region_df['Count'].sum()
    # Number of regions pooled into the last bar, derived from the data rather
    # than hard-coded, so the label stays correct if the input file changes.
    other_region_count = region_df.shape[0] - selected_region_df.shape[0]

    sample_counts = selected_region_df['Count'].to_list() + [other_count]
    x_ticklabels = selected_region_df.index.to_list() + [f'Other {other_region_count} States']
    x = np.arange(len(sample_counts))

    fig, ax  = plt.subplots(1, 1, figsize=(6, 4), sharex=True, constrained_layout=True)
    ax.bar(x, sample_counts, color='black', alpha=1, width=0.8)

    ax.set_xlim(-0.5,len(sample_counts)-0.5)
    ax.set_xticks(x)
    ax.set_xticklabels(x_ticklabels, rotation=90, fontsize=12)

    fig.savefig('/home/Users/yl181/wastewater/quarc_figures/pdf/supp_figure_2.pdf', dpi=300, facecolor='white', transparent=False)

In [None]:
# Draw Supplementary Figure 2 and write it to supp_figure_2.pdf.
plot_supp_figure2()