In [None]:
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.cm as cm
import matplotlib.colors as colors
from utils import setup_plotting_standards, dec_to_date, get_black
import baltic as bt
import geopandas as gpd

# Presumably applies the project's shared matplotlib styling (helper lives in utils).
setup_plotting_standards()

# Project-standard "black" from utils; used as the fill for the posterior-support
# node markers in the tree figure.
COLOR = get_black()

### Figure X. Maximum clade credibility tree
Visualize the maximum clade credibility (MCC) tree produced by our Bayesian phylogeographic analysis.

Load the US shapefile. We'll color each taxon by the east–west position (longitude) of its state's centroid.

In [None]:
# Load the US state boundaries and assign every state a colour from its
# centroid longitude (west -> east along viridis). `state_dict` maps the
# space-stripped state name to a hex colour; anything unmatched uses "Other".
us = gpd.read_file( snakemake.params.us_map )[["NAME", "geometry"]]
# Contiguous US only; taxa from these regions fall through to "Other".
us = us.loc[~us["NAME"].isin( ["Alaska", "Hawaii", "Puerto Rico"] )]
# US National Atlas Equal Area projection, used for drawing the legend map.
us = us.to_crs( "EPSG:2163" )
# Strip spaces ("New York" -> "NewYork"), presumably to match the state token
# embedded in taxon names -- TODO confirm against the tip naming scheme.
us["NAME"] = us["NAME"].str.replace( " ", "" )

# Compute centroids in the projected (planar) CRS, where they are geometrically
# meaningful, then reproject them to WGS84. The value is then a true longitude in
# degrees rather than a projected easting in metres.
us["longitude"] = us["geometry"].centroid.to_crs( "EPSG:4326" ).x
cNorm = colors.Normalize( vmin=us["longitude"].min(), vmax=us["longitude"].max() )
smap_state = cm.ScalarMappable( norm=cNorm, cmap=cm.viridis )
us["color"] = us["longitude"].apply( smap_state.to_rgba )
# to_hex rounds each channel to the nearest 0-255 value and drops alpha.
# The previous int() truncation biased every channel slightly darker.
us["color_hex"] = us["color"].apply( colors.to_hex )

state_dict = us.set_index( "NAME" )["color_hex"].to_dict()
state_dict["Other"] = "#A8A8A8"

Load the tree from file. I use baltic (h/t @evogytis) because dendropy tends to mangle this tree. The tree was generated with TreeAnnotator, discarding the first 100 trees as burn-in (see rule `beast_analysis.construct_mcc_tree`).

In [None]:
# Parse the MCC tree (NEXUS file from the Snakemake input) with baltic.
t = bt.loadNexus( snakemake.input.tree )
# Report baltic's summary statistics as a quick sanity check of the parsed tree.
t.treeStats()

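Plot the tree, coloring each tip by the longitude of its state (parsed from the taxon name) and marking internal nodes with posterior support > 0.5 with small dots.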

In [None]:
# Draw the MCC tree on a calendar-date x-axis. Tips are coloured by state and
# well-supported internal nodes are marked.
fig, ax = plt.subplots( dpi=200, figsize=(6.5,14.5) )

# Node times go through dec_to_date (presumably decimal years) and then to
# matplotlib date numbers, so the date locators/formatters below apply.
x_attr = lambda k: mdates.date2num( dec_to_date( k.absoluteTime ) )
# State token = third "|"-separated field of the taxon name, up to the first "-".
# Unknown states fall back to the "Other" grey.
# NOTE(review): raises IndexError if a tip name has fewer than three "|" fields.
c_func = lambda k: state_dict.get( k.name.split( "|" )[2].split( "-" )[0], state_dict["Other"] )

t.plotTree( ax, x_attr=x_attr, colour=state_dict["Other"], linewidth=1 )
t.plotPoints( ax, x_attr=x_attr, size=15, colour=c_func, zorder=100 )
# Internal nodes with posterior > 0.5. The fill must go through baltic's `colour`
# argument. A matplotlib-style `color=` kwarg is forwarded to scatter, where
# baltic's explicit facecolor overrides it, so COLOR was silently ignored.
t.plotPoints(
    ax,
    x_attr=x_attr,
    target=lambda k: k.is_node() and k.traits.get( "posterior", 0 ) > 0.5,
    colour=lambda k: COLOR,
    size=5,
)

# No y-axis: vertical position only orders the tips. Yearly x ticks.
ax.set_yticks( [] )
ax.tick_params( axis="x", bottom=False, which="both", labelbottom=True, rotation=90, labelsize=10 )
ax.xaxis.set_major_locator( mdates.YearLocator() )
ax.xaxis.set_major_formatter( mdates.DateFormatter( '%Y' ) )

ax.grid( which="both", axis="x", linewidth=1, color="#F1F1F1", zorder=1 )
for spine in ax.spines.values():
    spine.set_visible( False )
# Fixed padding around the tip span (values presumably tuned by eye).
ax.set_ylim( -20, t.ySpan+5 )

fig.tight_layout()
fig.savefig( snakemake.output.tree_figure )
plt.show()

Plot the map legend: each state is filled with the color used for its tips in the tree above.

In [None]:
# Render the contiguous-US map with each state filled by its `color_hex`.
# This map serves as the colour key for the tree figure.
fig, ax = plt.subplots( dpi=200, figsize=(5,4) )
us.plot( ax=ax, color=us["color_hex"], zorder=1, edgecolor="white", linewidth=1 )

# Strip both axes and every spine so only the state polygons remain.
for axis in ( ax.get_xaxis(), ax.get_yaxis() ):
    axis.set_visible( False )
for spine in ax.spines.values():
    spine.set_visible( False )

fig.tight_layout()
fig.savefig( snakemake.output.tree_legend )
plt.show()