# Week 13: Implementing the UPGMA algorithm

```python
import numpy as np
```

Some functions that we might want to use: 

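```python
def LCS(seq1, seq2):
    """Return the length of the longest common subsequence of seq1 and seq2."""

    # base case 1: nothing left in seq1
    if seq1 == "":
        return 0

    # base case 2: nothing left in seq2
    elif seq2 == "":
        return 0

    # make a match while we can
    elif seq1[0] == seq2[0]:
        return 1 + LCS(seq1[1:], seq2[1:])

    # recurse on the two possible sub-problems
    # (without memoization this takes exponential time in the worst case)
    else:
        lose1 = LCS(seq1, seq2[1:])
        lose2 = LCS(seq1[1:], seq2)

        return max(lose1, lose2)
```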
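The recursive `LCS` above recomputes the same sub-problems many times, so on sequences as long as the ones below it can be very slow. As a sketch, the same suffix-based recurrence can be filled in bottom-up with a table; `lcs_length` is a hypothetical name and is not used elsewhere in this notebook.

```python
def lcs_length(seq1: str, seq2: str) -> int:
    """Length of the longest common subsequence, via bottom-up dynamic programming."""
    n, m = len(seq1), len(seq2)
    # table[i][j] holds the LCS length of the suffixes seq1[i:] and seq2[j:]
    table = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n - 1, -1, -1):
        for j in range(m - 1, -1, -1):
            if seq1[i] == seq2[j]:
                # match the first characters, as in the recursive version
                table[i][j] = 1 + table[i + 1][j + 1]
            else:
                # drop a character from one sequence or the other
                table[i][j] = max(table[i + 1][j], table[i][j + 1])
    return table[0][0]


print(lcs_length("ABCBDAB", "BDCABA"))  # 4
```

This runs in time proportional to `len(seq1) * len(seq2)` and avoids deep recursion entirely.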

```python
def differences(seq1, seq2):
    """Count the positions at which two aligned (equal-length) sequences differ."""

    count = 0

    for i in range(len(seq1)):
        if seq1[i] != seq2[i]:
            count += 1

    return count
```
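`differences` counts mismatched positions (the Hamming distance), which only makes sense for aligned sequences of equal length: if `seq2` is longer, its extra characters are silently ignored, and if it is shorter, the loop fails with an `IndexError`. A hypothetical `hamming_distance` helper that makes the equal-length assumption explicit might look like this:

```python
def hamming_distance(seq1, seq2):
    """Count mismatched positions between two aligned sequences of equal length."""
    # zip() silently stops at the shorter sequence, so check the lengths first
    if len(seq1) != len(seq2):
        raise ValueError("sequences must be aligned to the same length")
    # each True (mismatch) counts as 1 in the sum
    return sum(a != b for a, b in zip(seq1, seq2))


print(hamming_distance("MVLSPADK", "MVLSAADK"))  # 1
```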

The sequences we want to compare:

```python
human =   "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHF.DLSHGSAQVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR"
chimp =   "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHF.DLSHGSAQVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR"
gorilla = ".VLSPADKTNVKAAWGKVGAHAGDYGAEALERMFLSFPTTKTYFPHF.DLSHGSAQVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR"
cow =     "MVLSAADKGNVKAAWGKVGGHAAEYGAEALERMFLSFPTTKTYFPHF.DLSHGSAQVKGHGAKVAAALTKAVEHLDDLPGALSELSDLHAHKLRVDPVNFKLLSHSLLVTLASHLPSDFTPAVHASLDKFLANVSTVLTSKYR"
horse =   "MVLSAADKTNVKAAWSKVGGHAGEYGAEALERMFLGFPTTKTYFPHF.DLSHGSAQVKAHGKKVGDALTLAVGHLDDLPGALSNLSDLHAHKLRVDPVNFKLLSHCLLSTLAVHLPNDFTPAVHASLDKFLSSVSTVLTSKYR"
donkey =  "MVLSAADKTNVKAAWSKVGGNAGEFGAEALERMFLGFPTTKTYFPHF.DLSHGSAQVKAHGKKVGDALTLAVGHLDDLPGALSNLSDLHAHKLRVDPVNFKLLSHCLLSTLAVHLPNDFTPAVHASLDKFLSTVSTVLTSKYR"
rabbit =  ".VLSPADKTNIKTAWEKIGSHGGEYGAEAVERMFLGFPTTKTYFPHF.DFTHGSZQIKAHGKKVSEALTKAVGHLDDLPGALSTLSDLHAHKLRVDPVNFKLLSHCLLVTLANHHPSEFTPAVHASLDKFLANVSTVLTSKYR"
carp =    "MSLSDKDKAAVKGLWAKISPKADDIGAEALGRMLTVYPQTKTYFAHWADLSPGSGPVKKHGKVIMGAVGDAVSKIDDLVGGLAALSELHAFKLRVDPANFKILAHNVIVVIGMLYPGDFPPEVHMSVDKFFQNLALALSEKYR"
```

```python
animals = [carp, cow, donkey, horse, human, gorilla, rabbit]
nodes = ["carp", "cow", "donkey", "horse", "human", "gorilla", "rabbit"]
```

## Step 1: Make a dictionary of differences

Note that I've modified this code to include each combination only once. That is, `('human', 'gorilla')` appears in the dictionary, but `('gorilla', 'human')` does not.

```python
diff_dict = {}

# visit each unordered pair (i, j) with i < j exactly once
for i in range(len(animals)):
    for j in range(i + 1, len(animals)):
        diff = differences(animals[i], animals[j])
        diff_dict[(nodes[i], nodes[j])] = diff

diff_dict
```

```
{('carp', 'cow'): 70,
 ('carp', 'donkey'): 71,
 ('carp', 'horse'): 71,
 ('carp', 'human'): 72,
 ('carp', 'gorilla'): 72,
 ('carp', 'rabbit'): 76,
 ('cow', 'donkey'): 19,
 ('cow', 'horse'): 17,
 ('cow', 'human'): 17,
 ('cow', 'gorilla'): 19,
 ('cow', 'rabbit'): 26,
 ('donkey', 'horse'): 3,
 ('donkey', 'human'): 20,
 ('donkey', 'gorilla'): 22,
 ('donkey', 'rabbit'): 26,
 ('horse', 'human'): 17,
 ('horse', 'gorilla'): 19,
 ('horse', 'rabbit'): 24,
 ('human', 'gorilla'): 2,
 ('human', 'rabbit'): 26,
 ('gorilla', 'rabbit'): 26}
```
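The nested loop above pairs each index `i` only with later indices `j`, so every unordered pair appears once. As a sketch, the standard-library `itertools.combinations` produces the same pairs in the same order. `build_diff_dict` is a hypothetical name, and the inline mismatch count assumes equal-length aligned sequences, just as `differences` does.

```python
from itertools import combinations


def build_diff_dict(names, seqs):
    """Map each unordered pair of names (in list order) to its number of differing positions."""
    diff_dict = {}
    # combinations yields each pair (i, j) with i < j exactly once, in list order
    for (name1, seq1), (name2, seq2) in combinations(zip(names, seqs), 2):
        diff_dict[(name1, name2)] = sum(a != b for a, b in zip(seq1, seq2))
    return diff_dict


toy = build_diff_dict(["a", "b", "c"], ["AAAA", "AAAT", "TTTT"])
print(toy)  # {('a', 'b'): 1, ('a', 'c'): 4, ('b', 'c'): 3}
```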

## Step 2: Write a function that finds the two most closely related organisms (i.e. a pair group)


Your function should return a tuple with the most closely related organisms from the dictionary.

```python
def pair_group(diff_dict):
    """Given a dictionary of pairwise differences, return the key (a pair of nodes)
    with the smallest difference; on ties, the first such key is kept."""

    min_diff = np.inf

    for key in diff_dict.keys():
        if diff_dict[key] < min_diff:
            min_diff = diff_dict[key]
            min_key = key

    return min_key
```

If this function works correctly, it should return the pair group `('human', 'gorilla')`.

```python
pg = pair_group(diff_dict)
pg
```

```
('human', 'gorilla')
```
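A more compact equivalent uses the built-in `min` with `key=diff_dict.get`, which compares the keys by their stored differences. Like the loop, `min` returns the first minimal key it encounters on ties. `pair_group_min` is a hypothetical name used here so it does not shadow `pair_group`.

```python
def pair_group_min(diff_dict):
    """Return the key of diff_dict with the smallest value (the first one on ties)."""
    return min(diff_dict, key=diff_dict.get)


example = {("cow", "horse"): 17, ("donkey", "horse"): 3, ("human", "gorilla"): 2}
print(pair_group_min(example))  # ('human', 'gorilla')
```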

## Step 3: With a pair group in hand, let's update our differences dictionary and nodes

Below, we will work through each step needed to update our dictionary and list of nodes, accounting for the fact that we are now making the human/gorilla pair a node.

### 3a: Remove members of the pair group from the list of nodes

We no longer want to consider the nodes `human` or `gorilla` since we have now determined that these two nodes are most closely related. So we will remove these nodes from the node list, leaving just `['carp', 'cow', 'donkey', 'horse', 'rabbit']`.

```python
# remove the nodes of the pair group from the nodes list
nodes.remove(pg[0])
nodes.remove(pg[1])

# check out the list of nodes
nodes
```

```
['carp', 'cow', 'donkey', 'horse', 'rabbit']
```

### 3b: Update the distances in the dictionary relative to this new node

For each of the remaining organisms (carp, cow, donkey, horse, rabbit), we want to find their distance to this new `('human', 'gorilla')` node and add this to the dictionary. This amounts to averaging the organism's existing distance to human and its distance to gorilla. *Note that a more careful treatment would use the **weighted** average based on the sizes (number of organisms) of the two nodes being merged, but for now we will just take the simple average.*
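For instance, writing $d(x, y)$ for the number of differences between $x$ and $y$ (notation introduced just for this example), cow's distance to the new node follows directly from the Step 1 values:

$$
d\big(\text{cow}, (\text{human}, \text{gorilla})\big) = \frac{d(\text{cow}, \text{human}) + d(\text{cow}, \text{gorilla})}{2} = \frac{17 + 19}{2} = 18
$$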

A wrinkle here is that you might not know if the relevant entry in the dictionary is `('gorilla', 'rabbit')` or `('rabbit', 'gorilla')` for example. So you'll want to search for both keys to find the distance between rabbit and gorilla. You can ask the boolean question `('gorilla', 'rabbit') in diff_dict` to see if a certain key exists in the dictionary.
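A minimal sketch of that two-way lookup, wrapped in a hypothetical helper `lookup_distance` (not part of the original notebook), which tries both key orders and raises a `KeyError` if neither is stored:

```python
def lookup_distance(diff_dict, a, b):
    """Return the stored difference between nodes a and b, whichever order the key uses."""
    if (a, b) in diff_dict:
        return diff_dict[(a, b)]
    if (b, a) in diff_dict:
        return diff_dict[(b, a)]
    raise KeyError(f"no distance stored for {a!r} and {b!r}")


d = {("gorilla", "rabbit"): 26}
print(lookup_distance(d, "rabbit", "gorilla"))  # 26
```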

If you've done this part correctly, you should see the following entries at the bottom of your dictionary now:

```
('carp', ('human', 'gorilla')): 72.0,
('cow', ('human', 'gorilla')): 18.0,
('donkey', ('human', 'gorilla')): 21.0,
('horse', ('human', 'gorilla')): 18.0,
('rabbit', ('human', 'gorilla')): 26.0
```


```python
# for each remaining node, store its distance to the new pair-group node as the
# simple (unweighted) average of its distances to the two members of the pair group
for node in nodes:

    # find distance between node and first member of pair group
    if (node, pg[0]) in diff_dict:
        dist1 = diff_dict[(node, pg[0])]
    else:
        dist1 = diff_dict[(pg[0], node)]

    # find distance between node and second member of pair group
    if (node, pg[1]) in diff_dict:
        dist2 = diff_dict[(node, pg[1])]
    else:
        dist2 = diff_dict[(pg[1], node)]

    # add average distance to the newly created node
    diff_dict[(node, pg)] = (dist1 + dist2) / 2

# check out updated dictionary
diff_dict
```

```
{('carp', 'cow'): 70,
 ('carp', 'donkey'): 71,
 ('carp', 'horse'): 71,
 ('carp', 'human'): 72,
 ('carp', 'gorilla'): 72,
 ('carp', 'rabbit'): 76,
 ('cow', 'donkey'): 19,
 ('cow', 'horse'): 17,
 ('cow', 'human'): 17,
 ('cow', 'gorilla'): 19,
 ('cow', 'rabbit'): 26,
 ('donkey', 'horse'): 3,
 ('donkey', 'human'): 20,
 ('donkey', 'gorilla'): 22,
 ('donkey', 'rabbit'): 26,
 ('horse', 'human'): 17,
 ('horse', 'gorilla'): 19,
 ('horse', 'rabbit'): 24,
 ('human', 'gorilla'): 2,
 ('human', 'rabbit'): 26,
 ('gorilla', 'rabbit'): 26,
 ('carp', ('human', 'gorilla')): 72.0,
 ('cow', ('human', 'gorilla')): 18.0,
 ('donkey', ('human', 'gorilla')): 21.0,
 ('horse', ('human', 'gorilla')): 18.0,
 ('rabbit', ('human', 'gorilla')): 26.0}
```

### 3c: Remove the original members of the newly created pair group from the dictionary

Now that we've calculated the new distances, we want to remove all references to just `human` (i.e. `pg[0]`) and just `gorilla` (i.e. `pg[1]`) from the dictionary. That is, now that we've created a new pair group, we no longer want any references to human or gorilla on their own, as we are moving forward with the human/gorilla unit as a node.

Another wrinkle is that you can't add or remove dictionary keys while looping over the dictionary (Python raises a `RuntimeError` because the dictionary changed size during iteration), so I recommend making a list of keys to remove as you loop through the dictionary. Once you have this list, you can then remove these keys in turn.

To make sure you've written your code correctly, look at your `diff_dict` and make sure that there are no entries remaining that contain `'human'` or `'gorilla'` as nodes.

```python
keys_to_remove = []

# find keys to remove
for key in diff_dict.keys():
    if pg[0] in key or pg[1] in key:
        keys_to_remove.append(key)

# remove keys
for key in keys_to_remove:
    diff_dict.pop(key)

# check out updated dictionary
diff_dict
```

```
{('carp', 'cow'): 70,
 ('carp', 'donkey'): 71,
 ('carp', 'horse'): 71,
 ('carp', 'rabbit'): 76,
 ('cow', 'donkey'): 19,
 ('cow', 'horse'): 17,
 ('cow', 'rabbit'): 26,
 ('donkey', 'horse'): 3,
 ('donkey', 'rabbit'): 26,
 ('horse', 'rabbit'): 24,
 ('carp', ('human', 'gorilla')): 72.0,
 ('cow', ('human', 'gorilla')): 18.0,
 ('donkey', ('human', 'gorilla')): 21.0,
 ('horse', ('human', 'gorilla')): 18.0,
 ('rabbit', ('human', 'gorilla')): 26.0}
```
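As an alternative to collecting `keys_to_remove` and popping them, a dictionary comprehension can build a filtered copy in one pass. This sidesteps the modify-while-iterating problem because the original dictionary is never changed during the loop. `drop_pair_members` is a hypothetical helper; it returns a new dictionary, so the result has to be assigned back (for example `diff_dict = drop_pair_members(diff_dict, pg)`).

```python
def drop_pair_members(diff_dict, pg):
    """Return a new dict without any key that directly contains pg[0] or pg[1]."""
    # `x in key` checks only the two elements of the key tuple, so a key such as
    # ('cow', ('human', 'gorilla')) is kept when pg = ('human', 'gorilla')
    return {key: dist for key, dist in diff_dict.items()
            if pg[0] not in key and pg[1] not in key}


d = {
    ("cow", "human"): 17,
    ("cow", "gorilla"): 19,
    ("cow", "horse"): 17,
    ("human", "gorilla"): 2,
    ("cow", ("human", "gorilla")): 18.0,
}
print(drop_pair_members(d, ("human", "gorilla")))
# {('cow', 'horse'): 17, ('cow', ('human', 'gorilla')): 18.0}
```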

### 3d: Add pair group to the nodes list

Lastly, we want to add the new pair group, `('human', 'gorilla')`, to our list of nodes, so that we can proceed with treating the human/gorilla node as a unit. This is just a single line that I've written for you below:

```python
# add new pair group to the nodes
nodes.append(pg)

nodes
```

```
['carp', 'cow', 'donkey', 'horse', 'rabbit', ('human', 'gorilla')]
```

## Part 4: Writing an `update_diff_dict` function

Above, we've worked through one full round of finding a new pair group and updating the nodes and dictionary. Let's put this all in a function so that we can repeatedly do this process to build up our phylogenetic tree, layer by layer. The function will take in a `diff_dict`, a list of `nodes`, and the pair group (`pg`) to incorporate. It will return an updated version of the `diff_dict` and `nodes` as a result of incorporating the pair group.

```python
def update_diff_dict(diff_dict, nodes, pg):
    """Merge pair group pg into a single node, updating diff_dict and nodes in place."""

    # remove the nodes of the pair group from the nodes list
    nodes.remove(pg[0])
    nodes.remove(pg[1])

    # for each remaining node, store its distance to the new pair-group node as the
    # simple (unweighted) average of its distances to the two members of the pair group
    for node in nodes:

        # find distance between node and first member of pair group
        if (node, pg[0]) in diff_dict:
            dist1 = diff_dict[(node, pg[0])]
        else:
            dist1 = diff_dict[(pg[0], node)]

        # find distance between node and second member of pair group
        if (node, pg[1]) in diff_dict:
            dist2 = diff_dict[(node, pg[1])]
        else:
            dist2 = diff_dict[(pg[1], node)]

        # add average distance to the newly created node
        diff_dict[(node, pg)] = (dist1 + dist2) / 2

    keys_to_remove = []

    # find keys to remove
    for key in diff_dict.keys():
        if pg[0] in key or pg[1] in key:
            keys_to_remove.append(key)

    # remove keys
    for key in keys_to_remove:
        diff_dict.pop(key)

    # add new pair group to the nodes
    nodes.append(pg)

    # return updated versions (the same objects that were passed in, now modified)
    return diff_dict, nodes
```

To check that this works, I re-create the original nodes list and dictionary below, and run this function to incorporate the `('human', 'gorilla')` node. Look at the outputs of the `update_diff_dict` function: have the dictionary and list of nodes been updated correctly?

```python
nodes = ["carp", "cow", "donkey", "horse", "human", "gorilla", "rabbit"]

# remake the original dictionary
diff_dict = {}
for i in range(len(animals)):
    for j in range(i + 1, len(animals)):
        diff = differences(animals[i], animals[j])
        diff_dict[(nodes[i], nodes[j])] = diff

# testing the update function
update_diff_dict(diff_dict, nodes, ('human', 'gorilla'))
```

```
({('carp', 'cow'): 70,
  ('carp', 'donkey'): 71,
  ('carp', 'horse'): 71,
  ('carp', 'rabbit'): 76,
  ('cow', 'donkey'): 19,
  ('cow', 'horse'): 17,
  ('cow', 'rabbit'): 26,
  ('donkey', 'horse'): 3,
  ('donkey', 'rabbit'): 26,
  ('horse', 'rabbit'): 24,
  ('carp', ('human', 'gorilla')): 72.0,
  ('cow', ('human', 'gorilla')): 18.0,
  ('donkey', ('human', 'gorilla')): 21.0,
  ('horse', ('human', 'gorilla')): 18.0,
  ('rabbit', ('human', 'gorilla')): 26.0},
 ['carp', 'cow', 'donkey', 'horse', 'rabbit', ('human', 'gorilla')])
```

## Part 5: Iterating to make a tree!

We'll use the following approach to repeatedly build up our tree:

- Use our `pair_group` function to find the two most closely related nodes
- With the pair group in hand, use our `update_diff_dict` function to get updated versions of our dictionary and our list of nodes
- Repeat until the `nodes` list has a length of 1, meaning that all organisms have been added to the tree

```python
# keep merging the closest pair until a single (nested) node remains
while len(nodes) > 1:
    pg = pair_group(diff_dict)
    diff_dict, nodes = update_diff_dict(diff_dict, nodes, pg)
```
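To watch the loop run end to end on something small enough to check by hand, here is a self-contained sketch on four hypothetical organisms `A` to `D` with made-up distances. The helpers `closest_pair`, `lookup` and `merge_pair` are hypothetical compact stand-ins for `pair_group` and `update_diff_dict` that use the same simple (unweighted) average; unlike `update_diff_dict`, `merge_pair` returns new objects instead of modifying its arguments.

```python
def closest_pair(diff_dict):
    """Key with the smallest difference (the first one on ties)."""
    return min(diff_dict, key=diff_dict.get)


def lookup(diff_dict, a, b):
    """Distance between a and b, whichever order the key is stored in."""
    return diff_dict[(a, b)] if (a, b) in diff_dict else diff_dict[(b, a)]


def merge_pair(diff_dict, nodes, pg):
    """Merge pair group pg into one node using the simple (unweighted) average."""
    remaining = [n for n in nodes if n not in pg]
    # keep only distances that do not involve either member of the pair group
    new_dict = {key: d for key, d in diff_dict.items()
                if pg[0] not in key and pg[1] not in key}
    # distances to the new node are read from the old dictionary
    for node in remaining:
        new_dict[(node, pg)] = (lookup(diff_dict, node, pg[0])
                                + lookup(diff_dict, node, pg[1])) / 2
    return new_dict, remaining + [pg]


nodes = ["A", "B", "C", "D"]
diff_dict = {("A", "B"): 2, ("A", "C"): 6, ("A", "D"): 10,
             ("B", "C"): 6, ("B", "D"): 10, ("C", "D"): 8}

while len(nodes) > 1:
    pg = closest_pair(diff_dict)
    diff_dict, nodes = merge_pair(diff_dict, nodes, pg)

print(nodes)  # [('D', ('C', ('A', 'B')))]
```

Tracing by hand: `A` and `B` merge first (distance 2), `C` then joins them at distance 6.0, and `D` joins last at distance 9.0, which gives exactly the nesting printed.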

## Part 6: Interpreting the results

Take a look at your final `nodes` list. How does the tree you created match what we saw in class (Lesson 22)? Specifically, does the cow get placed correctly?

```python
nodes
```

```
[('carp', ('rabbit', (('donkey', 'horse'), ('cow', ('human', 'gorilla')))))]
```

## Extra credit: Taking a weighted average

The way we've written the code so far is not fully correct: we should take a *weighted average* when updating the differences, weighting by the number of organisms in each of the two nodes being merged. This is actually a fairly difficult problem, since a node like `('cow', ('human', 'gorilla'))` has three organisms in it but has a length of two. To figure out the true number of organisms in a given node, you may want to implement a *recursive approach* that repeatedly peels back the layers.

For extra credit, implement these changes to correctly update the differences based on a weighted average. 
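One possible sketch, assuming (as in the `nodes` lists above) that every organism is a string and every merged node is a 2-tuple. When merging nodes $i$ and $j$ with $n_i$ and $n_j$ organisms, the size-weighted distance from another node $k$ is $d(k, (i, j)) = \dfrac{n_i\, d(k, i) + n_j\, d(k, j)}{n_i + n_j}$. `count_leaves` and `weighted_distance` are hypothetical helper names.

```python
def count_leaves(node):
    """Number of organisms in a node: 1 for a name, otherwise the sum over its children."""
    # assumption: leaves are strings and internal nodes are tuples of children
    if isinstance(node, str):
        return 1
    return sum(count_leaves(child) for child in node)


def weighted_distance(dist1, dist2, pg):
    """Size-weighted average of a node's distances to pg[0] (dist1) and pg[1] (dist2)."""
    n1, n2 = count_leaves(pg[0]), count_leaves(pg[1])
    return (n1 * dist1 + n2 * dist2) / (n1 + n2)


print(count_leaves(("cow", ("human", "gorilla"))))  # 3
print(weighted_distance(10, 4, ("a", ("b", "c"))))  # 6.0
```

Inside `update_diff_dict`, the averaging line `diff_dict[(node, pg)] = (dist1 + dist2) / 2` could then become `diff_dict[(node, pg)] = weighted_distance(dist1, dist2, pg)`. When both members of the pair group are single organisms the two weights are equal, so the first merge (human and gorilla) comes out the same as before.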