In [None]:
import pandas as pd
import matplotlib as mpl
import geopandas as gpd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import networkx as nx
from utils import timeseries_formatting, basic_formatting, dec_to_date, hpd, _toYearFraction

# Figure styling shared by every plot in this notebook: light-weight (300) Roboto text
# drawn in a dark grey on a white axes background.
prop = mpl.font_manager.FontProperties( 'Roboto' )

COLOR = '#343434'
mpl.rcParams.update( {
    'font.sans-serif' : prop.get_name(),
    'font.family' : 'sans-serif',
    'font.weight' : 300,
    'axes.labelweight' : 300,
    'font.size' : 16,
    'text.color' : COLOR,
    'axes.labelcolor' : COLOR,
    'xtick.color' : COLOR,
    'ytick.color' : COLOR,
    'axes.facecolor' : 'white',
} )

# Pastel fill colours for each region on the map.
cmap = dict(
    South = "#fbb4ae",
    Midwest = "#b3cde3",
    NorthEast = "#ccebc5",
    West = "#decbe4",
)

# Saturated counterparts of `cmap`, used for markers and intervals in the summary plots.
cmap_unsat = dict(
    South = "#e41a1c",
    Midwest = "#377eb8",
    NorthEast = "#4daf4a",
    West = "#984ea3",
)

### Figure X: Regional differences in West Nile virus transition rate
Using the transition rates estimated by the discrete-state reconstruction, we identify regional clustering of West Nile virus.

First we count the sequences sampled from each location, assign states to regions, and load the US map used for plotting.

In [None]:
# Read the discrete trait table once. It provides both the per-location sequence counts and
# the taxon names used below to find which states have any sequencing data.
traits = pd.read_csv( snakemake.input.traits, sep="\t" )
trait_counts = traits["state"].value_counts()

# Assignment of states (written without spaces) to regions.
regions = {
    "NorthEast": ["Connecticut", "Massachusetts", "RhodeIsland", "NewJersey", "NewYork", "Pennsylvania",
                  "NewHampshire", "Vermont", "Maine" ],
    "South": ["DistrictofColumbia", "Delaware", "Florida", "Louisiana", "Georgia", "Maryland", "NorthCarolina", "SouthCarolina", "Virginia",
              "Alabama", "Kentucky", "WestVirginia", "Mississippi", "Tennessee", "Arkansas", "Oklahoma", "Texas",
              "VirginIslands"],
    "Midwest": ["Illinois", "Michigan", "Ohio", "Wisconsin", "Iowa", "Kansas", "Minnesota", "Missouri",
                "Nebraska", "Indiana", "NorthDakota", "SouthDakota"],
    "West": ["Arizona", "Idaho", "Colorado", "Montana", "Nevada", "NewMexico", "California", "Oregon",
             "Washington", "Utah", "Wyoming",]
}
# Inverse lookup: state name -> region (a state listed twice keeps its last region).
state_to_region = {state: reg for reg, states in regions.items() for state in states}

us = gpd.read_file( snakemake.params.us_map )
# .copy() so the column assignments below act on an independent frame rather than a slice.
us = us[["NAME", "geometry"]].copy()
# Remove spaces so map names match the space-free names used in `regions`.
us["NAME"] = us["NAME"].str.replace( " ", "" )
us["region"] = us["NAME"].map( state_to_region )
# Keep only features assigned to a region, then project for plotting.
us = us.loc[us["region"].notna()]
us = us.to_crs( "EPSG:2163" )

# Identify centroids of each state and region for plotting purposes
state_centroids = us["geometry"].centroid
us["centroid.x"] = state_centroids.x
us["centroid.y"] = state_centroids.y
georeg = us.dissolve( "region" )
region_centroids = georeg["geometry"].centroid
georeg["centroid.x"] = region_centroids.x
georeg["centroid.y"] = region_centroids.y

# Create dictionary mapping each state and region to its centroid. Necessary for networkX plotting
# Manually specify the location of Other, Midwest, and NorthEast so they don't clash with states centroids.
pos_dict = {name: (x, y) for name, x, y in zip( us["NAME"], us["centroid.x"], us["centroid.y"] )}
pos_dict.update( {reg: (x, y) for reg, x, y in zip( georeg.index, georeg["centroid.x"], georeg["centroid.y"] )} )
pos_dict["Other"] = (-635174.467, -1611168.826)
pos_dict["Midwest"] =  (259909.909, -252712.408)
pos_dict["NorthEast"] = (2080618.217, -242742.572)

# To indicate states without any sequencing data: the state name is taken from the third
# "|"-separated field of each taxon label, up to the first "-".
present = traits["taxon"].apply( lambda x: x.split( "|" )[2].split( "-" )[0] ).unique()
# mask = 1 if the state has at least one sequence, 0 otherwise (hatched on the map).
us["mask"] = us["NAME"].isin( present ).astype( "int64" )

us.head()

We next load the posterior distribution of discrete transition rates, discarding the burn-in.

In [None]:
# Tab-separated MCMC trace log. header=3 uses the fourth line as the column header and skips
# the three lines before it. NOTE(review): confirm the preamble is always three lines for this log.
log = pd.read_csv( snakemake.input.log, header=3, sep="\t" )
# Discard burn-in. .copy() makes `log` an independent frame, so later `.loc` assignments into it
# (masking excluded rates) do not operate on a slice of the unfiltered table.
log = log.loc[log["state"] > snakemake.params.burnin].copy()
log.head()

We mask transition rates that were not included in the likelihood calculation for a particular draw of the MCMC; these are the rates whose indicator is 0. To determine which rates are significant for the final network, we calculate the Bayes factor for their inclusion.

In [None]:
# Rate columns are named "Location.rates.<start>.<end>", each with a matching
# "Location.indicators.<start>.<end>" column. slice(15) strips the 15-character "Location.rates." prefix.
pairs = log.columns[log.columns.str.startswith( "Location.rates" )].str.slice(15)

# In a draw whose indicator is 0 the rate was not part of the model, so its value is masked (NaN).
for pair in pairs:
    excluded = ~log[f"Location.indicators.{pair}"].astype( bool )
    log.loc[excluded, f"Location.rates.{pair}"] = None

# Posterior summary of each rate over the draws that included it; "count" is the number of such draws.
rates = log[log.columns[log.columns.str.startswith( "Location.rates" )]]
rates = rates.describe( percentiles=[0.025, 0.5, 0.975] )
rates = rates.T.reset_index()
rates[["start", "end"]] = rates["index"].str.slice(15).str.split( ".", n=1, expand=True )
rates = rates.drop( columns=["index", "mean", "std", "min", "max"] )

# Calculate inverse rate for closeness centrality.
rates["inv"] = 1 / rates["50%"]

# Bayes factor for inclusion of each rate = posterior odds / prior odds of inclusion.
# qk is the prior inclusion probability: (log(2) + k - 1) expected rates among k(k-1)/2 possible
# ones, which appears to follow the BSSVS setup of Lemey et al. (2009).
# NOTE(review): k(k-1)/2 assumes a symmetric rate matrix, and k counts the distinct locations in
# the traits file; confirm both match the model that produced the log.
n = np.log(2)
k = len( trait_counts )
qk = (n + k - 1)/(k * (k-1) / 2)
draws = log.shape[0]
posterior_p = rates["count"] / draws
# A rate included in every draw (posterior_p == 1) divides by zero and gets an infinite Bayes factor.
rates["bayes_factor"] = (posterior_p * (1-qk)) / ((1 - posterior_p) * qk )

# Display the ten most frequently included rates (previously a non-final, never-shown expression).
rates.sort_values( "count", ascending=False ).head(10)

We next generate a network in which regions and states are nodes, connected by edges representing transition rates. Only transition rates with a Bayes factor greater than the threshold `snakemake.params.minimum_BF` are included.

We plot this network on top of the map. This generates Figure X.

In [None]:
# Network of supported transitions: only rates whose Bayes factor exceeds the configured minimum.
g95 = nx.from_pandas_edgelist( rates.loc[rates["bayes_factor"]>snakemake.params.minimum_BF], source="start", target="end", edge_attr=["50%","inv"] )

fig, ax = plt.subplots( dpi=200, figsize=(8,6) )

# Base map: states filled by region colour; states without sequencing data (mask == 0) are hatched.
for reg, shapes in us.groupby( "region" ):
    shapes.plot( color=cmap[reg], edgecolor="white", ax=ax, zorder=1, linewidth=1 )
    shapes.loc[shapes["mask"]==0].plot( color=cmap[reg], edgecolor="white", hatch="//////", linewidth=0.5, zorder=2, ax=ax )

# Nodes: non-region locations as circles, region-level locations (keys of cmap) as squares.
# Both groups are restricted to nodes of g95, so a region without any supported transition is not
# drawn as an isolated node (previously every region in cmap was drawn unconditionally).
node_groups = [
    ( [i for i in g95.nodes if i not in cmap], "o" ),
    ( [i for i in g95.nodes if i in cmap], "s" ),
]
for nodelist, shape in node_groups:
    nodes = nx.draw_networkx_nodes( g95, pos_dict, nodelist=nodelist, node_shape=shape, node_size=50, node_color="white", edgecolors="black", linewidths=0.75, ax=ax )
    # Some networkx versions return None for an empty nodelist.
    if nodes is not None:
        nodes.set_zorder(10)

# Edges as arcs whose curvature alternates in sign so parallel-looking arcs are easier to separate.
# Line width = median rate + 1.
rad = -0.3
for source, target in g95.edges():
    ax.annotate(
        "",
        xy=pos_dict[source],
        xytext=pos_dict[target],
        zorder=9,
        arrowprops={
            "lw" : g95.edges[(source, target)]["50%"]+1,
            "arrowstyle" : "-",
            "color" : "black",
            "connectionstyle" : f"arc3,rad={rad}",
            "linestyle" : '-',
            "alpha" : 0.75
        }
    )
    rad *= -1

# Legend entries for rates 1-4, drawn with the same width rule as the edges (rate + 1).
# NOTE(review): confirm the Location.rates values are in transitions/year and not relative rates
# that still need scaling by an overall clock rate.
legend = [
    Line2D([0], [0], linestyle='-', marker=None, color="black", label=str( rate ), markersize=0, linewidth=rate + 1 )
    for rate in range( 1, 5 )
]

ax.legend(title="Transitions / year", handles=legend, loc="lower left", handletextpad=0.5, frameon=False, title_fontsize=8, fontsize=8, handlelength=1 )
for spine in ax.spines.values():
    spine.set_visible( False )

fig.tight_layout()
fig.savefig( snakemake.output.map_figure )
plt.show()

Next we calculate, for each location, its closeness centrality and total transition rate. Closeness centrality is the reciprocal of the mean shortest-path distance between a location and all other locations it can reach. This metric is useful because the graph is incomplete: closeness centrality accounts for indirect connections by using the shortest path between each pair of nodes. Distance in this case is the inverse of the transition rate (estimated number of years between transitions). Total transition rate is the sum of the rates of all edges involving the location that are included in the model for that draw. Both metrics are calculated for every draw of the posterior distribution.

In [None]:
# Column names are "Location.rates.<start>.<end>". The names are identical for every draw, so parse
# the endpoints once instead of re-filtering and re-splitting them inside the loop.
rate_columns = log.columns[log.columns.str.startswith( "Location.rates")]
endpoints = pd.Series( rate_columns ).str.slice(15).str.split( ".", n=1, expand=True )
endpoints.columns = ["start", "end"]

# For every posterior draw, build the network of rates included in that draw (masked rates are NaN
# and dropped). Then record each node's closeness centrality (distance = 1 / rate) and the summed
# rate of its incident edges.
cc_records = list()
rs_records = list()
for draw_rates in log[rate_columns].to_numpy():
    edges = endpoints.assign( value=draw_rates ).dropna()
    edges = edges.assign( inv=1 / edges["value"] )
    draw_graph = nx.from_pandas_edgelist( edges, source="start", target="end", edge_attr=["value","inv"] )

    cc_records.append( nx.closeness_centrality( draw_graph, distance="inv" ) )
    rs_records.append( {
        node: sum( attrs["value"] for _, _, attrs in draw_graph.edges( node, data=True ) )
        for node in draw_graph.nodes()
    } )

# One row per draw (indexed like `log`), one column per location. A location absent from a draw's
# graph is NaN for that draw.
cc = pd.DataFrame( cc_records, index=log.index )
rs = pd.DataFrame( rs_records, index=log.index )
cc.head()

First we plot each location's closeness centrality. This will be figure X.

In [None]:
# Posterior median and 95% interval of each location's closeness centrality, excluding "Other",
# ordered by median so the most central locations end up at the top.
plot_df = cc.describe( percentiles=[0.025, 0.5, 0.975] ).T
plot_df = plot_df.loc[plot_df.index!="Other"]
plot_df = plot_df.sort_values( "50%" ).reset_index()
for reg, states in regions.items():
    plot_df.loc[plot_df["index"].isin( states ),"region"] = reg
# Locations that are not a listed state (presumably region-level locations) are coloured by their
# own name, which must therefore be a key of cmap_unsat.
plot_df.loc[plot_df["region"].isna(),"region"] = plot_df.loc[plot_df["region"].isna(),"index"]

fig, ax = plt.subplots( dpi=200, figsize=(4,6) )
for reg, df in plot_df.groupby( "region" ):
    ax.scatter( df["50%"], df.index, color=cmap_unsat[reg], s=70, zorder=10, edgecolor="black", linewidth=1 )
    ln = ax.hlines( df.index, df["2.5%"], df["97.5%"], zorder=5, color=cmap_unsat[reg], linewidth=3, alpha=0.5 )
    ln.set_capstyle( "round" )
ax.set_yticks(plot_df.index)
ax.set_yticklabels( plot_df["index"])

# Shade every odd row. Iterating up to len(plot_df) includes the last row when the row count is even
# (range(1, max(index), 2) stopped one row short) and does not fail on an empty frame.
for i in range( 1, len( plot_df ), 2 ):
    ax.axhspan( i-0.5,i+0.5, color="black", alpha=0.04, edgecolor=None, linewidth=0 )

# NOTE(review): with distance = 1 / rate, weighted closeness exceeds 1 whenever typical rates exceed 1;
# check that xlims=(0,1) does not clip points or intervals.
basic_formatting( ax, spines=["bottom", "left"], which="x", xlabel="Closeness centrality", xlims=(0,1), ylims=(-0.5, plot_df.index.max() + 0.5 ), xsize=10, ysize=10 )

plt.tight_layout()
fig.savefig( snakemake.output.cc_figure )
plt.show()

Next we plot each location's total transition rate.

In [None]:
# Posterior median and 95% interval of each location's total transition rate, excluding "Other",
# ordered by median.
plot_df = rs.describe( percentiles=[0.025, 0.5, 0.975] ).T
plot_df = plot_df.loc[plot_df.index!="Other"]
plot_df = plot_df.sort_values( "50%" ).reset_index()
for reg, states in regions.items():
    plot_df.loc[plot_df["index"].isin( states ),"region"] = reg
# Locations that are not a listed state (presumably region-level locations) are coloured by their
# own name, which must therefore be a key of cmap_unsat.
plot_df.loc[plot_df["region"].isna(),"region"] = plot_df.loc[plot_df["region"].isna(),"index"]

fig, ax = plt.subplots( dpi=200, figsize=(4,6) )
for reg, df in plot_df.groupby( "region" ):
    ax.scatter( df["50%"], df.index, color=cmap_unsat[reg], s=70, zorder=10, edgecolor="black", linewidth=1 )
    ln = ax.hlines( df.index, df["2.5%"], df["97.5%"], zorder=5, color=cmap_unsat[reg], linewidth=3, alpha=0.5 )
    ln.set_capstyle( "round" )
ax.set_yticks(plot_df.index)
ax.set_yticklabels( plot_df["index"])

# Shade every odd row. Iterating up to len(plot_df) includes the last row when the row count is even
# (range(1, max(index), 2) stopped one row short) and does not fail on an empty frame.
for i in range( 1, len( plot_df ), 2 ):
    ax.axhspan( i-0.5,i+0.5, color="black", alpha=0.04, edgecolor=None, linewidth=0 )

basic_formatting( ax, spines=["bottom", "left"], which="x", xlabel="Total transition rate", ylims=(-0.5, plot_df.index.max() + 0.5 ), xsize=10, ysize=10 )
# Rates are non-negative; anchor the x-axis at zero and let matplotlib choose the upper limit.
ax.set_xlim(0)

plt.tight_layout()
fig.savefig( snakemake.output.rates_figure )
plt.show()