<a href="https://colab.research.google.com/github/pandemic-tracking/viz-gen/blob/main/omicron2_seqtimeline_asof20211216.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd
import numpy as np

from datetime import datetime, timedelta, date
import pytz

import altair as alt
from altair import datum
alt.data_transformers.disable_max_rows()

from pathlib import Path

pd.set_option("display.precision", 4)

# Read the clock once, as a timezone-aware UTC datetime, and derive every
# timestamp from that single instant so the Eastern and UTC stamps always agree.
# (datetime.utcnow() returns a naive datetime and is deprecated since Python 3.12.)
now_utc = datetime.now(pytz.utc)
now_est = now_utc.astimezone(pytz.timezone("US/Eastern"))

# Human-readable and filename-friendly renderings of the run time.
now_est_time = now_est.strftime("%Y-%m-%d, %H:%M:%S ET")
now_est_date = now_est.strftime("%Y-%m-%d")
now_est_timestamp = now_est.strftime("%Y%m%d_%H%M%S")
now_utc_timestamp = now_utc.strftime("%Y%m%d_%H%M%S")
print(now_est_time, now_est_date, now_est_timestamp, now_utc_timestamp)

In [None]:
# adapting from https://towardsdatascience.com/consistently-beautiful-visualizations-with-altair-themes-c7f9f889602

def ptc_theme():
    """Build the Altair theme dictionary used by the charts in this notebook.

    Returns:
        dict: ``{"config": {...}}`` with defaults for the chart title, text,
        headers, the X and Y axes, the legend and the view. No "range"
        (color palette) entry is set.
    """
    axis_gray = "#808080"
    grid_gray = "#DEDDDD"
    base_font = 'Arial'
    label_font = 'Arial'

    title_cfg = {
        "fontSize": 20,
        "font": base_font,
        "anchor": "start",  # equivalent of left-aligned
        "fontColor": "#000000",
        "fontWeight": "bold",
    }

    text_cfg = {
        "font": base_font,
        "labelFont": label_font,
    }

    header_cfg = {
        "font": base_font,
        "labelFont": label_font,
        "titleFont": base_font,
    }

    x_axis_cfg = {
        "domain": False,
        "domainColor": axis_gray,
        "labelColor": axis_gray,
        "domainWidth": 1,
        "grid": False,
        "labelFont": label_font,
        "labelFontSize": 5,
        "labelAngle": 0,
        "labelPadding": 5,
        "tickColor": axis_gray,
        "tickSize": 5,  # the default, listed only to show it can be changed
        "titleFont": base_font,
        "titleFontSize": 12,
        "titlePadding": 10,  # a guess, not specified in the styleguide
        "title": "X Axis Title (units)",
    }

    y_axis_cfg = {
        "domain": False,
        "grid": True,
        "gridColor": grid_gray,
        "gridWidth": 1,
        "labelFont": label_font,
        "labelColor": axis_gray,
        "labelFontSize": 12,
        "labelAngle": 0,
        "labelAnchor": "end",
        "labelAlign": "right",
        "ticks": False,  # ticks must be turned off even when there is no domain line
        "titleFont": base_font,
        "titleFontSize": 12,
        "titlePadding": 10,  # a guess, not specified in the styleguide
        "title": "Y Axis Title (units)",
        # Y titles default to vertical, left of the axis; the three settings
        # below move the title to a horizontal position above the labels.
        "titleAngle": 0,
        "titleY": -10,
        "titleX": 18,
    }

    legend_cfg = {
        "labelFont": label_font,
        "labelFontSize": 12,
        "symbolType": "stroke",
        "symbolSize": 100,  # the default
        "symbolStrokeWidth": 5,
        "titleFont": base_font,
        "titleFontSize": 12,
        "title": "",  # no legend title by default
        "orient": "right",
        "offset": 30,
    }

    # A transparent stroke removes the box drawn around the plotting area.
    view_cfg = {"stroke": "transparent"}

    return {
        "config": {
            "title": title_cfg,
            "text": text_cfg,
            "header": header_cfg,
            "axisX": x_axis_cfg,
            "axisY": y_axis_cfg,
            "legend": legend_cfg,
            "view": view_cfg,
        }
    }

# Register the theme function under a name, then make it the active Altair theme.
alt.themes.register("my_custom_theme", ptc_theme)
alt.themes.enable("my_custom_theme")

In [None]:
# Authenticate the current Colab user.
from google.colab import auth
auth.authenticate_user()

import gspread
from oauth2client.client import GoogleCredentials

# gspread client authorized with the application-default credentials;
# used further down to read a Google Sheet.
gc = gspread.authorize(GoogleCredentials.get_application_default())

In [None]:
# Suffix of the GISAID export file to load.
# NOTE(review): the trailing "_17" looks like the download hour (17 UTC = 12PM EST,
# matching update_date below) — confirm.
date_str = '2021_12_16_17'
# df1 = pd.read_csv(f'/content/gisaid_hcov-19_2021_12_11_00.tsv', sep='\t')
# df1['lineage'] = 'BA.1'
# df2 = pd.read_csv(f'/content/gisaid_hcov-19_{date_str}_17_BA2.tsv', sep='\t')
# df2['lineage'] = 'BA.2'
# Tab-separated GISAID export; every row is labelled with the same combined lineage string.
df3 = pd.read_csv(f'/content/gisaid_hcov-19_{date_str}.tsv', sep='\t')
df3['lineage'] = 'B.1.1.529+BA.*'

# Human-readable "as of" time shown in the chart subtitle.
update_date = 'December 16, 2021 12PM EST'

In [None]:
# Use the single combined export; the per-lineage concat is disabled.
df = df3#pd.concat([df1, df2, df3])

In [None]:
# Number of downloaded sequences (rows); reported in the chart subtitle.
seq_count = df.shape[0]
df.shape

In [None]:
df.columns

In [None]:
df.Location.value_counts()

In [None]:
# Parse the raw date strings into datetime columns.
df['collect_date'] = pd.to_datetime(df['Collection date'])
df['submit_date'] = pd.to_datetime(df['Submission date'])

In [None]:
def get_weekstartdate(dt_value):
    """Return the Monday of the week that contains ``dt_value``.

    ``weekday()`` is 0 on Monday, so subtracting that many days lands on the
    week's Monday. The time-of-day component is left unchanged.
    """
    days_since_monday = dt_value.weekday()
    return dt_value - timedelta(days=days_since_monday)

def titlecase_location(location_name, exceptions=('and', 'or', 'the', 'a', 'of', 'in', "d'Ivoire")):
    """Capitalize each space-separated word of a location name.

    Words listed in ``exceptions`` are kept exactly as written. Every other
    word goes through ``str.capitalize()``, which upper-cases the first
    character and lower-cases the rest (so "USA" becomes "Usa").

    Args:
        location_name: The location string to normalize.
        exceptions: Words to leave untouched (matched case-sensitively). The
            default is an immutable tuple so it cannot be mutated between calls;
            any container supporting ``in`` may be passed.

    Returns:
        The words re-joined with single spaces.
    """
    return ' '.join(
        word if word in exceptions else word.capitalize()
        for word in location_name.split(' ')
    )

def correct_location_names(gisaid_df):
    """Normalize the ``country`` column of ``gisaid_df`` in place and return the frame.

    Country names are title-cased first, then anything containing "USA"
    (case-insensitive) becomes 'United States', and finally a fixed table of
    exact-match renames is applied (US territories, misspellings, alternative
    country names).
    """
    gisaid_df.loc[:,'country'] = gisaid_df['country'].apply(titlecase_location)

    # Substring match, not exact: catches every spelling that contains "usa".
    contains_usa = gisaid_df['country'].fillna('').str.contains('USA', case=False)
    gisaid_df.loc[contains_usa, 'country'] = 'United States'

    # Exact-match renames, applied in this order. Keys are written the way
    # titlecase_location leaves them (hence 'U.s. Virgin Islands').
    # 'Czech Republic' is deliberately not remapped.
    renames = {
        'Puerto Rico': 'United States',
        'Guam': 'United States',
        'Northern Mariana Islands': 'United States',
        'U.s. Virgin Islands': 'United States',
        'Antigua': 'Antigua and Barbuda',
        'Democratic Republic of the Congo': 'Democratic Republic of Congo',
        'Republic of the Congo': 'Congo',
        'Faroe Islands': 'Faeroe Islands',
        'Guinea Bissau': 'Guinea-Bissau',
        'Niogeria': 'Nigeria',
        'Bosni and Herzegovina': 'Bosnia and Herzegovina',
        'England': 'United Kingdom',
        'The Bahamas': 'Bahamas',
        'Hong Kong': 'Hong Kong SAR (China)',
        'Reunion': 'Reunion (France)',
    }
    for current_name, corrected_name in renames.items():
        gisaid_df.loc[gisaid_df['country'] == current_name, 'country'] = corrected_name

    return gisaid_df

def annotate_sequences(gisaid_df):
    """Add location, date, lag and week columns to a GISAID metadata frame.

    Reads ``Location``, ``Collection date`` and ``Submission date`` and adds
    ``region``, ``country``, ``division``, ``collect_date``, ``submit_date``,
    ``lag_days``, ``collect_yearweek``/``submit_yearweek`` and
    ``collect_weekstartdate``/``submit_weekstartdate``. Columns are written onto
    the frame that was passed in, which is also returned.
    """
    # Location is a '/'-separated string: region / country [/ division ...].
    location_parts = gisaid_df.Location.apply(
        lambda location: [piece.strip() for piece in location.split('/')])
    gisaid_df['region'] = location_parts.apply(lambda pieces: pieces[0])
    gisaid_df['country'] = location_parts.apply(lambda pieces: pieces[1])
    gisaid_df['division'] = location_parts.apply(
        lambda pieces: pieces[2] if len(pieces) > 2 else '')

    # replace 'USA' string with 'United States' etc in location, to match OWID location name
    gisaid_df = correct_location_names(gisaid_df)

    # column prefix -> raw GISAID column it is parsed from
    date_sources = {'collect': 'Collection date', 'submit': 'Submission date'}

    for prefix, raw_column in date_sources.items():
        gisaid_df[f'{prefix}_date'] = pd.to_datetime(gisaid_df[raw_column])

    # whole days between sample collection and submission
    submission_lag = gisaid_df['submit_date'] - gisaid_df['collect_date']
    gisaid_df['lag_days'] = submission_lag
    gisaid_df['lag_days'] = gisaid_df['lag_days'].dt.days.astype('int')

    # using ISO 8601 year and week (Monday as the first day of the week. Week 01 is the week containing Jan 4)
    for prefix in date_sources:
        gisaid_df[f'{prefix}_yearweek'] = gisaid_df[f'{prefix}_date'].apply(
            lambda stamp: datetime.strftime(stamp, "%G-W%V"))

    for prefix in date_sources:
        gisaid_df[f'{prefix}_weekstartdate'] = gisaid_df[f'{prefix}_date'].apply(get_weekstartdate)

    return gisaid_df

In [None]:
# Add the region/country/division, date, lag and week columns.
df = annotate_sequences(df)

In [None]:
df.Location.value_counts()

In [None]:
df

In [None]:
# Today's date as a YYYY-MM-DD string; upper bound for acceptable collection dates.
today_str = date.today().isoformat()

def is_legit_date(collection_date):
    """Tell whether a raw collection-date string is complete and plausible.

    The string must be exactly 10 characters long (a full date such as
    YYYY-MM-DD, not a year-only or year-month value), sort after '2019-12-01'
    and sort no later than ``today_str`` (no future samples allowed).
    Comparisons are plain string comparisons.
    """
    if len(collection_date) != 10:
        return False
    return '2019-12-01' < collection_date <= today_str

# True for rows whose collection date passes the check above.
df['date_filter'] = df['Collection date'].apply(is_legit_date)

In [None]:
# Label every row as "Country (n)", where n is the number of sequences from that
# country. The per-country totals are computed once up front rather than inside
# the per-row function, which would redo the full count for every row.
country_totals = df['country'].value_counts()
country_labels = {name: f"{name} ({total})" for name, total in country_totals.items()}
df['country_count'] = df['country'].map(country_labels)

# Number of distinct countries; reported in the chart subtitle.
countries_count = len(df['country'].unique())
countries_count

In [None]:
df[df['Accession ID'].isin(exclude_seqs_list)]['country'].unique()

In [None]:
# exclude Senegal and other anomalous seqs
exclude_seqs_list = ['EPI_ISL_7400617', # Senegal sample colleced 11/9
                # 'EPI_ISL_7543999', # S Africa sample collected 6/17
                'EPI_ISL_7547731', # Nigeria sample collected 10/17
                'EPI_ISL_7605742', # S Africa Eastern Cape sample collected 10/24
]

exclude_seqs_countries = list(df[df['Accession ID'].isin(exclude_seqs_list)]['country'].unique())
print(len(exclude_seqs_list), exclude_seqs_countries)

df = df[~df['Accession ID'].isin(exclude_seqs_list)]

In [None]:
first_ids = df[df['date_filter']==True][['country','Collection date','Submission date','Accession ID']].sort_values(['Collection date','Submission date']).drop_duplicates(subset='country')['Accession ID'].to_list()
first_ids

In [None]:
first_ids_sub = df[df['date_filter']==True][['country','Collection date','Submission date','Accession ID']].sort_values(['Submission date','Collection date']).drop_duplicates(subset='country')['Accession ID'].to_list()
first_ids_sub

In [None]:
first_date_dict = df[df['date_filter']==True][['country_count','Collection date']].sort_values('Collection date').drop_duplicates(subset='country_count').set_index('country_count').to_dict()['Collection date']
first_date_dict

In [None]:
df['country_count_firstdate'] = df[df['date_filter']==True]['country_count'].apply(lambda x: first_date_dict[x].replace('2021-','').replace('-','/')+' - '+x)
df['country_count_firstdate']

In [None]:
sorted_legend = df[df['date_filter']==True][['collect_date','country_count_firstdate']].sort_values('collect_date').drop_duplicates(subset='country_count_firstdate')['country_count_firstdate'].to_list()
sorted_legend

In [None]:
# Combined "country / division" label for each sequence.
df['country_division'] = df['country'] + ' / ' + df['division']

In [None]:
df['country_division']

In [None]:
# Open the Google Sheet by key and read all values of 'Sheet1' into a DataFrame.
sh = gc.open_by_key('12XHSXZiyriSAotB6bBVCn9tbQfUzyMwLvIudXEIg5Pw')
rows = sh.worksheet('Sheet1').get_all_values()
print(rows)
news_df = pd.DataFrame.from_records(rows)

In [None]:
# Promote the first row to column headers and drop it from the data rows.
news_df.columns = news_df.iloc[0]
news_df = news_df.iloc[1:]
# Parse the 'date' column into collect_date (the column the sequence frame uses)
# and make event_order numeric.
news_df['collect_date'] = pd.to_datetime(news_df['date'])
news_df['event_order'] = pd.to_numeric(news_df['event_order'])

In [None]:
news_df

In [None]:
# Drop every row that has a missing value in any column.
news_df.dropna(inplace=True)

In [None]:
# Stack the valid-date sequences and the news rows into one frame for plotting.
combined_df = pd.concat([df[df['date_filter']==True], news_df])

In [None]:
# filter == 1 for rows without an Accession ID, 0 for rows that have one.
# NOTE(review): presumably the rows without an ID are the news rows — confirm.
combined_df['filter'] = combined_df['Accession ID'].isna().astype('int')

In [None]:
combined_df.sort_values('collect_date', ascending=True)

In [None]:
combined_df.collect_date.min()

In [None]:
# 1/0 flags marking the rows that are a country's first-collected sequence
# (first_sample) or first-submitted sequence (first_seqsub). isin() checks the
# whole column at once instead of scanning the ID list once per row; rows
# without an Accession ID get 0.
combined_df['first_sample'] = combined_df['Accession ID'].isin(first_ids).astype('int')
combined_df['first_seqsub'] = combined_df['Accession ID'].isin(first_ids_sub).astype('int')

In [None]:
combined_df['first_sample']

In [None]:
# Build a country -> continent lookup from the gisaid_country / owid_continent
# columns of a published weekly CSV.
legend_df = pd.read_csv('https://github.com/covid-tracking-collab/gisaid-variants/raw/main/data/gisaid_owid_country_lineage_cases_2021_07_15_weekly.csv')
continents_dict = legend_df[['gisaid_country','owid_continent']].drop_duplicates().set_index('gisaid_country').to_dict()['owid_continent']

In [None]:
# Manually add (or override) continent entries for these location names.
continents_dict['Hong Kong SAR (China)'] = 'Asia'
continents_dict['Czech Republic'] = 'Europe'
continents_dict['Reunion (France)'] = 'Africa'

In [None]:
continents_dict

In [None]:
# Plain dict lookup: raises KeyError for any country value missing from continents_dict.
combined_df['continent'] = combined_df.country.apply(lambda x: continents_dict[x])

In [None]:
combined_df[combined_df['Location'].fillna('').str.contains('Senegal')]

In [None]:
# Order rows chronologically by collection date (in place).
combined_df.sort_values('collect_date', inplace=True)

In [None]:
# Stamp every row with its country's overall span: the earliest collection date
# and the latest submission date among that country's rows. The chart draws this
# span as a horizontal rule.
for country in combined_df['country'].unique():
    country_rows = combined_df['country'] == country
    combined_df.loc[country_rows, 'earliest_date'] = combined_df.loc[country_rows, 'collect_date'].min()
    combined_df.loc[country_rows, 'latest_date'] = combined_df.loc[country_rows, 'submit_date'].max()

In [None]:
# Countries of the sequences that failed the collection-date check; these rows
# are not in combined_df and are only counted in the chart subtitle.
unplotted_seqs = df[df['date_filter']==False][['country']]
unplotted_seqs

In [None]:
# Shared chart height in pixels.
viz_height = 550

# Base encoding shared by all layers: one row per "MM/DD - Country (n)" label
# (sorted by the region field, with the Y axis hidden), collection date on a
# top-oriented X axis, color by region.
base_viz = alt.Chart(combined_df).encode(
    y = alt.Y('country_count_firstdate', axis=None,
                sort = alt.EncodingSortField(
                      field='region',  
                      order="ascending"
              )),
    # X domain runs from the earliest collection date to the latest submission date.
    x = alt.X('collect_date:T', axis=alt.Axis(orient='top', labelFontSize=10, tickMinStep=5, format='%m/%-d'),
                scale = alt.Scale(domain=[combined_df.collect_date.min(),combined_df.submit_date.max()])
              ),
    color=alt.Color('region', scale=alt.Scale(scheme='dark2')),
)

# Layer 1: one point per country/date, sized by the number of sequences;
# only rows with filter == 0 (rows that have an Accession ID).
base_points = base_viz.mark_point(filled=True, opacity=0.9).encode(
    size=alt.Size('count(Accession ID)', legend=alt.Legend(direction='horizontal', symbolType='circle'), scale=alt.Scale(domain=[1,500], range=[10,120])),
    color=alt.Color('region', scale=alt.Scale(scheme='dark2'),
                    ),
).transform_filter(alt.datum.filter == 0)

# Layer 2: a thin rule per country from earliest_date to latest_date, drawn
# once per country by keeping only the first-submitted sequence row.
points_rule = base_viz.mark_rule(opacity=1).encode(
    x = 'earliest_date:T',
    x2 = 'latest_date:T',
    size = alt.value(0.5),
).properties(height=viz_height).transform_filter(alt.datum.first_seqsub==1)


# Layer 3: the "Country (n)" text, right-aligned and shifted left of the
# country's first-collected sequence.
points_location_text = base_viz.mark_text(size=10, dx=-15, fontWeight=600, align='right', baseline='middle').encode(
    text = 'country_count',
  ).properties(height=viz_height).transform_filter(alt.datum.filter == 0).transform_filter(alt.datum.first_sample==1)

# Layer the three charts and add the title; the subtitle reports the country and
# sequence totals plus how many sequences (and from which countries) are not shown.
global_timeline_viz = (base_points+points_rule+points_location_text
                       ).properties(width=1000, height=viz_height,
                                    title={
                                      "text": ["Omicron Sequences Shared Via GISAID Per Country By Sample Collection Date"], 
                                      "subtitle": [f"{countries_count} countries have submitted {seq_count} Omicron sequences via GISAID as of {update_date}",
                                                  f"{unplotted_seqs.shape[0]+len(exclude_seqs_list)} sequences from {', '.join(sorted(set(list(unplotted_seqs.country.unique())+exclude_seqs_countries)))} have incomplete dates or data quality issues and are not shown.",
                                                  ""],
                                      "subtitleFontSize": 14
                                    },
                                   ).configure_axisY(grid=False, domain=False, ticks=False, labels=False).configure_axisX(grid=True, ticks=False)
global_timeline_viz

In [None]:
# Per-country count of the sequences left out because of their collection date.
unplotted_seqs.country.value_counts()

In [None]:
# Distinct countries with unplotted or excluded sequences (as listed in the subtitle).
set(list(unplotted_seqs.country.unique())+exclude_seqs_countries)

In [None]:
# Plotted + unplotted + excluded sequence counts.
# NOTE(review): presumably meant to be compared against seq_count — confirm.
combined_df[combined_df['date_filter']==True].shape[0] + unplotted_seqs.shape[0] + len(exclude_seqs_list)

In [None]:
# First collection date per country label; show the countries whose first sample is dated 2021-11-26.
first_df = df[df['date_filter']==True][['country_count','Collection date','Submission date']].sort_values('Collection date').drop_duplicates(subset='country_count')#df[['collect_date','country_count']].sort_values('collect_date').drop_duplicates(subset='country_count')
first_df[first_df['Collection date']=='2021-11-26']

In [None]:
# The 15 rows with the earliest collection dates.
combined_df[['Accession ID','Location','Submission date','Collection date']].sort_values('Collection date').head(15)

In [None]:
# Rows submitted on or before 2021-11-23, earliest submissions first.
combined_df[combined_df['submit_date']<='2021-11-23'][['Location','Submission date','Collection date']].sort_values('Submission date').head(15)

# Save to Drive

In [None]:
# this is for saving altair charts to png and svg, based on https://colab.research.google.com/github/altair-viz/altair_saver/blob/master/AltairSaver.ipynb#scrollTo=ZiTDBCAM_Ni8
!pip install -q altair_saver
!npm install --silent vega-lite vega-cli canvas

In [None]:
from pathlib import Path
from altair_saver import save

# Local folder that collects every asset to upload; created if it does not exist.
SAVE_PATH = Path('assets')
SAVE_PATH.mkdir(exist_ok=True)

# Import PyDrive and associated libraries.
# This only needs to be done once per notebook.
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate and create the PyDrive client.
# This only needs to be done once per notebook.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
# Drive client used by the upload and snapshot cells below.
gdrive = GoogleDrive(gauth)

In [None]:
def assets_to_gdrive(folder_name, localdir_path = SAVE_PATH, parentdir_id='17Kx2uZbQv1r5U1M9x_OXS4lpMU5c6Ym8'):
  """Upload every file in a local directory to a Google Drive snapshot folder.

  Looks for a non-trashed folder titled ``folder_name`` directly under the
  Drive folder ``parentdir_id`` and creates it when none exists, then uploads
  each regular file found in ``localdir_path`` into it.

  Args:
    folder_name: Title of the Drive folder that receives the files.
    localdir_path: Local directory whose files are uploaded (str or Path).
    parentdir_id: Drive id of the folder that contains the snapshot folder.

  Uses the module-level ``gdrive`` client.
  """
  # search gdrive for snapshot folder and save assets there if it already exists.
  folder_id = ''
  file_list = gdrive.ListFile({'q': f"'{parentdir_id}' in parents and mimeType = 'application/vnd.google-apps.folder' and trashed=false"}).GetList()
  for file1 in file_list:
    if file1['title'] == folder_name:
      # If several folders share the title, the last one listed is used.
      folder_id = file1['id']
      print(f'Found pre-existing gdrive folder named "{folder_name}" at',folder_id)
  # if not, create new folder
  if folder_id == '':
    folder = gdrive.CreateFile(metadata={'title': folder_name,
                                      'parents':[{'id': parentdir_id}],
                                      "mimeType": "application/vnd.google-apps.folder"
                                      })
    folder.Upload()
    folder_id = folder.get('id')
    print(f'Created new gdrive folder named "{folder_name}" at',folder_id)

  # upload all files within localdir_path to snapshot folder
  for asset_file in Path(localdir_path).iterdir():
    # Only regular files can be uploaded as file content; skip sub-directories.
    if not asset_file.is_file():
      continue
    file1 = gdrive.CreateFile(metadata={'title':asset_file.name,
                                        'parents':[{'id': folder_id}],
                                        })
    # SetContentFile is documented to take a filename string, so pass str, not Path.
    file1.SetContentFile(str(asset_file))
    file1.Upload()
    print('Saved file: ',asset_file.name)


In [None]:
def save_vizassets(chart, save_path, filename, fmts=('html', 'png', 'svg')):
  """Save ``chart`` once per requested format as ``<save_path>/<filename>.<fmt>``.

  Args:
    chart: The chart object handed to ``altair_saver.save``.
    save_path: Directory to write into (str or Path).
    filename: File name without extension.
    fmts: Iterable of format extensions. The default is an immutable tuple so
      it cannot be mutated between calls.
      NOTE(review): 'json' and 'pdf' appeared as commented-out options in the
      original default — confirm the saver supports them before passing them.
  """
  for fmt in fmts:
    save(chart, f'{save_path}/{filename}.{fmt}')

In [None]:
# put your stuff (i.e. dataframes, altair charts, input data files) to save here

In [None]:
!cp gisaid_hcov-19_2021_12_16_17.tsv assets/

In [None]:
# Write both data frames and the chart (in every default format) into SAVE_PATH;
# the chart file name carries the UTC run timestamp.
df.to_csv(SAVE_PATH/'df.csv')
combined_df.to_csv(SAVE_PATH/'combined_df.csv')
save_vizassets(global_timeline_viz, SAVE_PATH, f'global_timeline_viz_{now_utc_timestamp}')

In [None]:
# get the colab filename
from requests import get

# Ask the Colab-internal sessions endpoint once and read both the notebook name
# and its file id from the same response (first listed session).
session_info = get('http://172.28.0.2:9000/api/sessions').json()[0]
nb_name = session_info['name'].replace('.ipynb','')
nb_id = session_info['notebook']['path'].replace('fileId=','')

print(SAVE_PATH, nb_name, now_utc_timestamp, nb_id)

# create a snapshot of this currently running notebook and save to SAVE_PATH
downloaded_nb = gdrive.CreateFile({'id':nb_id})   # replace the id with id of file you want to access
# GetContentFile is documented to take a filename string, so pass str, not Path.
downloaded_nb.GetContentFile(str(SAVE_PATH/f'{nb_name}_{now_utc_timestamp}.ipynb'))

In [None]:
# upload everything to gdrive
# The Drive folder is named after the notebook plus the UTC run timestamp.
assets_to_gdrive(folder_name=f'{nb_name}_{now_utc_timestamp}')