In [1]:
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

# Cassiopeia: install this fork: https://github.com/pjb7687/Cassiopeia
# The original version is available at https://cassiopeia-lineage.readthedocs.io/en/latest/
from cassiopeia.TreeSolver.Node import Node
from cassiopeia.TreeSolver.lineage_solver import lineage_solver as ls

# ForceAtlas2, available at https://pypi.org/project/ForceAtlas2
from fa2 import ForceAtlas2

  import pandas.util.testing as pdt


In [None]:
# Multiple sequence alignment of all genomes.
# Assumes that the "gisaid_cov2020_sequences.fasta" file exists in the working directory.
# According to the GISAID EULA, this file must be downloaded directly from the GISAID website.

# FAMSA, available at https://github.com/refresh-bio/FAMSA
# GPU acceleration doesn't seem to work, but it is still quite fast.
# The alignment is written to aligned.fasta, which the next cell reads.
!famsa-gpu-1.3.2-linux gisaid_cov2020_sequences.fasta aligned.fasta

FAMSA (Fast and Accurate Multiple Sequence Alignment) ver. 1.3.2 CPU and GPU
  by S. Deorowicz, A. Debudaj-Grabysz, A. Gudys (2020-02-21)



In [None]:
# Rewrite the alignment with spaces in FASTA header lines ('>'-prefixed) replaced by
# underscores; sequence lines are copied through untouched. Presumably this keeps each
# strain name a single token for snp-sites / the VCF sample columns.
with open("aligned.fasta") as src, open("aligned_nospaces.fasta", "w") as dst:
    for record_line in src:
        is_header = record_line.startswith('>')
        dst.write(record_line.replace(" ", "_") if is_header else record_line)

In [None]:
# SNP-sites, available at https://github.com/sanger-pathogens/snp-sites
# Extracts the variable sites of the header-sanitized alignment into variants.vcf (parsed by the next cell).
# NOTE(review): flags presumably mean -v = VCF output, -c = keep only columns made of A/C/G/T,
# -o = output file; confirm with `snp-sites -h`.
!snp-sites -v -c -o variants.vcf aligned_nospaces.fasta

In [None]:
# Parse variants.vcf into:
#   headers - strain names taken from the "#CHROM" line (columns 10+), '|' replaced by '@'
#   var_ids - one "<POS>_<REF>_<ALT>" id per kept site
#   var_mat - genotype matrix, rows = strains, columns = sites (transpose of the per-site lists)
# Reset `headers` so a value left over from an earlier run can never be silently reused.
headers = None
var_vecs = []
var_ids = []
with open("variants.vcf") as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines instead of failing on entries[4] below
        if line.startswith("#"):
            if not line.startswith("##"):
                headers = np.array(line.strip().replace("|", "@").split('\t')[9:]) # bar sign can confuse Cassiopeia
            continue
        entries = line.strip().split('\t')
        if "," in entries[4]:
            print("Found multiple SNVs at the same site!")
            continue # ignore multiple SNVs (so far this has never happened yet)
        var_ids.append("_".join([entries[1], entries[3], entries[4]]))
        var_vecs.append([int(i) for i in entries[9:]])
assert headers is not None, "variants.vcf has no #CHROM header line"
var_mat = np.array(var_vecs).T

In [None]:
# Deduplicate identical strains: strains with the same SNV vector collapse into one entry.
# Each row of var_mat is stringified so np.unique can compare whole vectors.
vecs_stringified = np.array(['|'.join(map(str, v)) for v in var_mat])
assert len(headers) == len(vecs_stringified), "header count does not match genotype rows"

# uniq_vecs is sorted; uniq_indices is the first strain carrying each vector;
# vec_group maps every strain to the position of its vector within uniq_vecs.
uniq_vecs, uniq_indices, vec_group = np.unique(
    vecs_stringified, return_index=True, return_inverse=True
)

# Group strain headers per unique vector in one pass (original strain order kept within
# each group), instead of rescanning all strains once per unique vector.
grouped_headers = [[] for _ in range(len(uniq_vecs))]
for header, group_idx in zip(headers, vec_group):
    grouped_headers[group_idx].append(header)
uniq_headers = ['\n'.join(names) for names in grouped_headers]

var_mat_uniq = var_mat[uniq_indices, :]

In [None]:
# Reconstruct the lineage tree with Cassiopeia's ILP solver. Each unique SNV vector becomes
# a Node named after its row index in var_mat_uniq; the first element of the solver's
# return value is taken as the tree, whose network is then made undirected.
nodes = [Node(str(row_idx), row_vec) for row_idx, row_vec in enumerate(var_mat_uniq)]
tree = ls.solve_lineage_instance(nodes, method="ilp")[0]
directed_net = tree.get_network()
net = directed_net.to_undirected()

In [None]:
# Relabel tree nodes. A node whose name (up to the first '_') matches a deduplicated SNV
# vector string becomes "Node_<i>", where i indexes uniq_headers (same sorted order as
# vecs_stringified[uniq_indices]). All other nodes become "Unknown_1", "Unknown_2", ...
# NOTE(review): presumably Cassiopeia names nodes by their '|'-joined character string;
# confirm against the fork. Also confirm Node hashing does not depend on `name`, since
# names are mutated while the nodes live inside the graph.
uniq_vec_position = {vec: pos for pos, vec in enumerate(vecs_stringified[uniq_indices])}
unknown_idx = 1
for n in net:
    pos = uniq_vec_position.get(n.name.split('_')[0])
    if pos is None:
        n.name = "Unknown_" + str(unknown_idx)
        unknown_idx += 1
    else:
        n.name = "Node_" + str(pos)

In [None]:
# Quick look at the relabeled tree using networkx's default node placement.
plt.figure(figsize=(20, 20))
nx.draw(net, with_labels=True);

In [None]:
# ForceAtlas2 layout engine used to place the tree nodes for the final figure.
forceatlas2 = ForceAtlas2(
    # --- behaviour ---
    outboundAttractionDistribution=True,  # dissuade hubs
    linLogMode=False,                     # NOT IMPLEMENTED
    adjustSizes=False,                    # prevent overlap (NOT IMPLEMENTED)
    edgeWeightInfluence=1.0,
    # --- performance ---
    jitterTolerance=1.0,                  # tolerance
    barnesHutOptimize=True,
    barnesHutTheta=1.2,
    multiThreaded=False,                  # NOT IMPLEMENTED
    # --- tuning ---
    scalingRatio=0.5,
    strongGravityMode=False,
    gravity=2,
    # --- logging ---
    verbose=True,
)

In [None]:
# Run ForceAtlas2 from a seeded spring layout so the starting positions, and hence the
# figure, are reproducible across Restart & Run All.
# NOTE(review): ForceAtlas2 is presumably deterministic for a fixed start; confirm.
LAYOUT_SEED = 0          # seed for the initial spring layout
FA2_ITERATIONS = 10000   # ForceAtlas2 iterations; lower for a quicker preview
pos_fa2 = forceatlas2.forceatlas2_networkx_layout(
    net, pos=nx.spring_layout(net, seed=LAYOUT_SEED), iterations=FA2_ITERATIONS
)

In [None]:
# Redraw the relabeled tree at the ForceAtlas2 positions stored in pos_fa2.
plt.figure(figsize=(20, 20))
nx.draw(net, pos=pos_fa2, with_labels=True);

In [None]:
#!wget https://datahub.io/core/world-cities/r/world-cities.csv
import pandas as pd
class div_to_country:
    """Map a location token (country, city or province name) to a country name.

    Use as ``dtc[token]``. Lookups are resolved in this order:
      1. ``token`` is already a country listed in the city table;
      2. ``token`` is one of the hand-curated overrides in ``remaining_dict``;
      3. ``token`` equals a city name (first matching row wins);
      4. ``token`` or ``token + ' Sheng'`` equals a subcountry (first matching row wins);
    otherwise ``token`` itself is returned unchanged.
    """

    def __init__(self, path="world-cities.csv"):
        """Load the city table (columns ``name``, ``country``, ``subcountry``) from ``path``."""
        # dropna() leaves gaps in the row labels, so lookups must be positional (.iloc),
        # never label-based.
        self.df = pd.read_csv(path).dropna()
        self.countries = np.unique(self.df['country'])
        # Hand-curated overrides, consulted before the city/subcountry scans.
        self.remaining_dict = {
            "NetherlandsL": "Netherlands",
            "United States": "USA",
            "Korea": "South Korea",
            "Tianmen": "China",
            "Wuhan-Hu-1": "China",
            "England": "United Kingdom",
            "Wales": "United Kingdom",
        }

    def __getitem__(self, div):
        """Return the country for location token ``div`` (``div`` itself if unresolved)."""
        if div in self.countries:
            return div

        if div in self.remaining_dict:
            return self.remaining_dict[div]

        country = self.df['country']

        by_city = country[self.df['name'] == div]
        if not by_city.empty:
            return by_city.iloc[0]

        # ' Sheng' suffix lets bare Chinese province names match "<Province> Sheng".
        subcountry = self.df['subcountry']
        by_region = country[(subcountry == div) | (subcountry == div + ' Sheng')]
        if not by_region.empty:
            return by_region.iloc[0]

        return div

In [None]:
# Export to Json for visualization with vis.js
import json
from datetime import datetime

dtc = div_to_country()

# Memo of location token -> country; the same token recurs for many strains.
country_of_division = {}


def _strain_country(strain_name):
    """Return the country for a strain name of the form "<prefix>/<location>/...".

    The second '/'-separated field is the location token. Underscores in it (introduced
    when spaces in FASTA headers were replaced) are turned back into spaces before lookup.
    """
    division = strain_name.split('/')[1].replace('_', ' ')
    if division not in country_of_division:
        country_of_division[division] = dtc[division]
    return country_of_division[division]


unique_countries = list(np.unique([_strain_country(h) for h in headers]))

# vis.js node/edge records. NOTE: this rebinds `nodes`, which earlier held the
# Cassiopeia Node list passed to the tree solver.
nodes = []
node_details = {}    # node id -> header fields of its strains, sorted by (field 0, field 2)
node_countries = {}  # node id -> sorted unique countries of its strains
for k in net.nodes:
    if "Unknown" not in k.name:
        # "Node_<i>": i indexes uniq_headers ('\n'-joined headers per unique SNV vector);
        # header fields were '|'-separated and the '|' was replaced by '@' when parsing.
        strains = [strain.split("@") for strain in uniq_headers[int(k.name.split("_")[-1])].split("\n")]
        node_details[k.name] = sorted(strains, key=lambda e: (e[0], e[2]))
        node_countries[k.name] = list(np.unique([_strain_country(e[0]) for e in strains]))
    nodes.append({"id": k.name})
edges = [{"from": i.name, "to": j.name} for i, j in net.edges]

# Label with the machine's actual zone abbreviation rather than a hard-coded "CET",
# which is wrong during summer time (CEST) or when run outside Central Europe.
current_datetime = datetime.now().astimezone().strftime("%d %b %Y, %H:%M (%Z)")
s = """
var updated_datetime = "%s";

var unique_countries = %s;

var node_countries = %s;

var nodes = new vis.DataSet(%s);

var edges = new vis.DataSet(%s);

var node_details = %s;
"""%(current_datetime, json.dumps(unique_countries), json.dumps(node_countries), json.dumps(nodes), json.dumps(edges), json.dumps(node_details))
with open("network_data.js", "w", encoding="utf-8") as f:
    f.write(s)

In [None]:
# Publish the refreshed notebook, the vis.js page and the exported network data.
# NOTE(review): `git commit` exits non-zero when nothing changed; the push then has nothing new to send.
# NOTE(review): network_data.js embeds GISAID strain header fields; confirm publishing them complies with the GISAID EULA.
!git pull
!git add covid19.ipynb network-visjs.html network_data.js
!git commit -m "update data"
!git push