<a href="https://colab.research.google.com/github/theosanderson/SGDTimeTree/blob/main/TimeTree_working_well.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#!unzip metadata.zip
#!unzip named.tree_global.nwk.zip

In [2]:
#!pip install Bio

In [3]:
import datetime
import functools
import gzip

import jax
import jax.numpy as jnp
import numpy as onp
import pandas as pd
import tqdm
from Bio import Phylo
from matplotlib import pyplot as plt

In [4]:
#!unzip metadata.zip
#!unzip named.tree_global.nwk.zip

In [5]:
# Load the phylogeny; its branch lengths are later converted from
# substitutions per site into days using the clock rate below.
tree = Phylo.read("named.tree_global.nwk","newick")

# Assumed molecular clock rate (substitutions per site per year).
substitutions_per_site_per_year = 1e-3
metadata = pd.read_table("metadata.tsv")

# Map strain name -> sampling date. Rows without a complete YYYY-MM-DD date
# are skipped on purpose (best effort), but only the two errors strptime can
# raise for bad input are swallowed, so genuine problems (e.g. a missing
# column, or a keyboard interrupt) still surface.
lookup = {}
for strain, date_string in tqdm.tqdm(zip(metadata['strain'], metadata['date']),
                                     total=len(metadata)):
    try:
        lookup[strain] = datetime.datetime.strptime(date_string,'%Y-%m-%d')
    except (TypeError, ValueError):
        # TypeError: the date is missing (NaN, not a string).
        # ValueError: the date is incomplete or malformed.
        continue
# The full table is no longer needed once the lookup exists; free the memory.
del metadata

  if self.run_code(code, result):
506964it [00:51, 9870.41it/s]


In [6]:
def target_dates():
    """Return ``{tip name: sampling date in days}`` for every dated tip.

    Dates are expressed as whole days relative to the sampling date of
    'Wuhan/WH04/2020'. Tips of the module-level ``tree`` whose name is not a
    key of the module-level ``lookup`` are left out.
    """
    return {
        tip.name: (lookup[tip.name] - lookup['Wuhan/WH04/2020']).days
        for tip in tree.root.get_terminals()
        if tip.name in lookup
    }

In [7]:
# Observed sampling dates (days relative to the reference strain) for every
# dated tip, kept both as a plain list and as a JAX array. The dict's
# insertion order defines the tip ordering used everywhere below.
the_target_dates = target_dates()
terminal_names = the_target_dates.keys()
terminal_targets = [float(days) for days in the_target_dates.values()]
terminal_targets_array = jnp.asarray(terminal_targets)

In [8]:
# Tip name -> index of that tip in terminal_targets / terminal_targets_array.
terminal_name_to_pos = {name: position for position, name in enumerate(terminal_names)}

In [9]:
def assign_paths(tree):
  """Attach a ``path`` attribute to every non-root clade of ``tree``.

  ``clade.path`` is the list of ancestors of the clade, from the root down to
  its direct parent (the clade itself is not included). The root gets no
  ``path`` attribute. This relies on ``tree.get_nonterminals()`` yielding each
  parent before its descendants, because a child's path is built from its
  parent's path.
  """
  for parent in tqdm.tqdm(tree.get_nonterminals()):
      # The root has no path of its own, so its children start a fresh chain.
      lineage = [] if parent==tree.root else parent.path
      for child in parent.clades:
          # A new list per child: siblings never share the same list object.
          child.path = lineage + [parent]
# Annotate the loaded tree in place: every non-root clade gets a `path`
# attribute listing its ancestors (root first).
assign_paths(tree)

100%|██████████| 275440/275440 [00:03<00:00, 76758.06it/s]


In [10]:
# Clock rate used to turn genetic distance into time.
substitutions_per_site_per_year = 1e-3

# Starting branch length of every clade, converted to days and keyed by clade
# name (so clades sharing a name collapse into a single entry).
initial_branch_lengths = {
    clade.name: 365*clade.branch_length/substitutions_per_site_per_year
    for clade in tree.root.find_clades()
}
names_init = initial_branch_lengths.keys()
values_init = list(initial_branch_lengths.values())
# Clade name -> index of that clade's branch length in values_init.
name_to_pos = {name: position for position, name in enumerate(names_init)}

In [11]:
# Initial branch lengths (days) as a JAX array, in the order given by name_to_pos.
branch_lengths_array = jnp.array(values_init)

In [12]:
# Coordinate lists of a sparse tip-by-branch incidence matrix: entry
# (tip, branch) is present when the branch lies on the root-to-tip path of
# that tip (ancestors in `path`, plus the tip's own branch).
rows = []
cols = []

for tip in tree.root.get_terminals():
    # Only tips with a known sampling date take part in the fit.
    if tip.name not in terminal_name_to_pos:
        continue
    tip_row = terminal_name_to_pos[tip.name]
    lineage = tip.path + [tip]
    rows.extend([tip_row] * len(lineage))
    cols.extend(name_to_pos[clade.name] for clade in lineage)


In [None]:
# Convert the coordinate lists to JAX arrays. This rebinds `rows` / `cols`
# from Python lists to arrays under the same names.
rows = jnp.asarray(rows)
cols = jnp.asarray(cols)

In [None]:
# `jax.partial` was only ever a re-export of `functools.partial` and is not
# available in current JAX releases, so use the standard-library function.
@functools.partial(jax.jit, static_argnums=(2,))
def sp_matmul(A, B, shape):
    """Multiply a sparse matrix in coordinate form by a dense matrix.

    Arguments:
        A: (N, M) sparse matrix represented as a tuple (indexes, values),
           where indexes is itself a (rows, cols) pair of index arrays and
           values holds the matching non-zero entries.
        B: (M, K) dense matrix
        shape: value of N (number of rows of the result). Static under jit,
           so each distinct value triggers its own compilation.
    Returns:
        (N, K) dense matrix
    """
    assert B.ndim == 2
    indexes, values = A
    rows, cols = indexes
    # Gather the row of B selected by each non-zero entry's column index...
    in_ = B.take(cols, axis=0)
    # ...scale it by the entry's value...
    prod = in_*values[:, None]
    # ...and accumulate the contributions into the entry's output row.
    res = jax.ops.segment_sum(prod, rows, shape)
    return res

In [None]:
# Number of dated tips; passed to sp_matmul as the number of output rows.
num= len(terminal_name_to_pos)

@jax.jit
def calc_dates(branch_lengths_array):
  """Return the date implied for each dated tip by the given branch lengths.

  Each tip's date is the sum of the branch lengths selected for it by the
  module-level ``rows`` / ``cols`` index arrays (one output row per tip,
  ``num`` rows in total).

  Arguments:
      branch_lengths_array: 1-D array of branch lengths, indexed by ``cols``.
  Returns:
      1-D array of length ``num``.
  """
  # Every stored entry of the incidence matrix has weight 1.
  A= ((rows,cols),jnp.ones_like(cols))
  # Column vector of whatever length was passed in, so the function is not
  # tied to one specific tree size.
  B=branch_lengths_array.reshape((-1,1))
  dates = sp_matmul(A,B,num).squeeze()
  return dates



@jax.jit
def get_loss(branch_lengths_array):
  """Scalar objective minimised when fitting branch lengths.

  Sum of three terms:
    * 10 x squared error between the target tip dates and the dates implied
      by ``branch_lengths_array``;
    * squared deviation of the branch lengths from their initial values;
    * 70 x a clamped penalty on the branch lengths (see the note below).

  Reads the module-level ``terminal_targets_array`` and
  ``initial_branch_lengths_array``.
  """
  predicted_dates = calc_dates(branch_lengths_array)
  date_residuals = terminal_targets_array-predicted_dates
  loss = jnp.sum(date_residuals**2)*10

  drift = initial_branch_lengths_array-branch_lengths_array
  loss = loss + jnp.sum(drift**2)

  # Alternative idea for 'self calibration':
  #ratio = jnp.sum(branch_lengths_array) / jnp.sum(initial_branch_lengths_array)
  #loss+= 100 * jnp.var(initial_branch_lengths_array*ratio - branch_lengths_array )

  # Penalise negative branch lengths.
  # NOTE(review): because the clamp is at -2 rather than 0, this term is also
  # non-zero for positive lengths: it equals length**2 for lengths in [0, 2]
  # and a constant 4 above that. Confirm whether that is intended.
  clamped = jnp.maximum(float(-2), -branch_lengths_array)
  loss = loss + 70*jnp.sum(clamped**2)
  return loss


# Jitted gradient of the loss with respect to the branch lengths.
# NOTE(review): not referenced by the other cells shown here (the training
# step uses jax.value_and_grad instead) — confirm whether it is still needed.
grad_get_loss = jax.jit(jax.grad(get_loss))


In [None]:
# Fixed reference that get_loss pulls the fitted branch lengths towards.
initial_branch_lengths_array = branch_lengths_array
# Starting point handed to the optimiser.
cur_branch_lengths_array = jnp.array(initial_branch_lengths_array)


In [None]:
# Scratch record of loss values, presumably pasted from the training loop's
# printed output of an earlier run — TODO confirm which run/settings.
#8805934.0
21025694000.0
7669196300.0
7156918000.0
6866454000.0
6667853300.0
6519100400.0
6400666600.0
6302144000.0
6217548000.0

In [None]:
# The optimizers module lives in `jax.example_libraries` in newer JAX
# releases; fall back to the old `jax.experimental` location otherwise.
try:
    from jax.example_libraries import optimizers
except ImportError:
    from jax.experimental import optimizers

opt_init, opt_update, get_params = optimizers.rmsprop_momentum(1e-3)
opt_state = opt_init(cur_branch_lengths_array)

# Build these wrappers once instead of re-creating them on every step.
_loss_and_grad = jax.value_and_grad(get_loss)
_jitted_update = jax.jit(opt_update)

def step(step, opt_state):
  """Run one optimiser update.

  Arguments:
      step: index of the current iteration, forwarded to the optimiser update.
      opt_state: optimiser state holding the current branch lengths.
  Returns:
      (loss value before the update, updated optimiser state)
  """
  value, grads = _loss_and_grad(get_params(opt_state))
  opt_state = _jitted_update(step, grads, opt_state)
  return value, opt_state



In [None]:
# Optimise the branch lengths, printing the loss every 100 steps. The
# iteration count is large enough that in practice the cell is stopped by hand.
for i in range(10000000):
  value, opt_state = step(i, opt_state)
  if i % 100 == 0:
    print(value)

In [None]:

# Current branch lengths, extracted from the optimiser state.
params = get_params(opt_state)

In [None]:
# Tip dates implied by the fitted branch lengths.
calc_dates(params)

In [None]:
# Last five target tip dates (days relative to the reference strain).
terminal_targets[-5:]

In [None]:
# Fitted tip dates against the sampling dates they were fitted to. `plt` must
# come from the notebook's import cell: this cell runs before the later
# matplotlib import cell, so it failed on a fresh kernel.
plt.scatter(calc_dates(params),terminal_targets,alpha=0.002)
plt.xlabel("Fitted tip date (days since reference strain)")
plt.ylabel("Sampling date (days since reference strain)");

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Distribution of the fitted branch lengths below 80, in 150 bins.
plt.hist(params[params<80],bins=150)

In [None]:
# Smallest fitted branch length (a negative value means some branch went below zero).
params.min()

In [None]:
# Fitted branch lengths (y) against their initial values (x).
plt.scatter( initial_branch_lengths_array,params,alpha=0.002)