# Figuring Out How to Plot in 3D

This is the notebook I used to organize my thoughts and figure out how to make a 3D scatterplot of the coalescent.

Sections:
1. Import & Simplify
2. 2D Scatterplot
3. How is data stored?
4. Experimenting with Creating Arrays
5. Making lists of parents and children
6. 3D Scatterplot

## Import & Simplify

In [1]:
import pyslim, tskit, msprime
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import SVG, display

ts = pyslim.load("/Users/ARIADNA/Desktop/Attempt3.trees")

extant = ts.individuals_alive_at(0)

extant_nodes = []
for i in extant:
    extant_nodes.extend(ts.individual(i).nodes)
    
sts = ts.simplify(extant_nodes)

FileNotFoundError: [Errno 2] No such file or directory: '/Users/ARIADNA/Desktop/Attempt3.trees'

This simplifies the tree sequence down to the extant nodes in Generation 10 (the last generation) and the ancestral nodes needed to describe their genealogy.

NOTE: In my SLiM code, I saved every individual from Generation 1 through 10, which later allows me to get the individual for each node in the tree sequence, and that individual's location.

Not mutating this time around.

In [None]:
print(sts)

# SVG was already imported in the first cell
SVG(sts.draw_svg())

## 2D Scatterplot

In [None]:
%matplotlib inline
## ^ with this 'magic', the plot prints in the notebook but isn't interactive
## (interactivity matters more for the 3D plot)

## time counts backwards from the end, so 1 generation ago is Generation 9
gen9 = ts.individuals_alive_at(1)
gen9_locs = ts.individual_locations[gen9, :]

x = gen9_locs[:, 0]
y = gen9_locs[:, 1]

plt.scatter(x, y, color='pink')

plt.title("Generation 9")
plt.show()

^ This code could be reused to plot the locations of individuals in each generation, either as separate plots or all in one.
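
Here is a rough sketch of the all-in-one version. It uses made-up coordinates in a hypothetical dict, `locs_by_gen` (generation number mapped to an array of x, y columns), instead of the real tree sequence. In the notebook, each array would presumably come from the same `ts.individual_locations[...]` lookup used above. The `matplotlib.use("Agg")` line is only there so the sketch runs outside a notebook; drop it when using `%matplotlib inline`.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # off-screen backend; not needed inside a notebook
import matplotlib.pyplot as plt

# Hypothetical stand-in data: generation -> (n_individuals, 2) array of x, y
rng = np.random.default_rng(1)
locs_by_gen = {gen: rng.random((5, 2)) for gen in (8, 9, 10)}

fig, ax = plt.subplots()
for gen, locs in locs_by_gen.items():
    # one scatter call per generation, so each gets its own color and legend entry
    ax.scatter(locs[:, 0], locs[:, 1], label=f"Generation {gen}")

ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("All Generations")
ax.legend()
```

For separate plots, the same loop could create a new figure per generation instead of sharing one `ax`.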

## How is data stored?

The following code chunks are just me figuring out how data is stored, and what it looks like by printing it. Could be useful to look through.

In [None]:
final_nodes = sts.nodes()

for node in sts.nodes():
    t = node.time
    print (int(t))

In [None]:
for individual in sts.individuals():
    coordinates = individual.location
    print (coordinates)

In [None]:
for node in final_nodes:
    print (node)

In [None]:
for node in sts.nodes():
    folk = node.individual
    print (folk)

## Experimenting with Creating Arrays

In [None]:
##Create array of individuals

ind_array = []

for node in sts.nodes():
    ind_array.append(node.individual)
    
print (ind_array)

In [None]:
## Create array of node times

time_array = []

for node in sts.nodes():
    time_array.append(int(node.time))

print (time_array)

In [None]:
## This iterates through every individual in the simplified tree sequence and prints their info.

for individual in sts.individuals():
    print(individual)

In [None]:
for individual in sts.individuals():
    if individual.id in ind_array:
        print(individual.location)

In [None]:
for individual in sts.individuals():
    x = individual.location[0]
    y = individual.location[1]
    plt.scatter(x,y, color = 'pink')
    
plt.title("All Individuals")
plt.show()

In [None]:
for edge in sts.edges():
    print(edge)
    
## Look at the Edge Table for clarity.

Parent is node 67, child is node 64.

## Making lists of parents and children

In [None]:
parents = []
children = []

for edge in sts.edges():
    parents.append(edge.parent)
    children.append(edge.child)

print (parents)

print (children)

In [None]:
print (len(parents))

print (len(children))

Now to find which individuals these nodes belong to...

In [None]:
## look up each parent node by its ID to find the individual to which it belongs
p_ind = [sts.node(i).individual for i in parents]

print(p_ind)

In [None]:
## same lookup for each child node
ch_ind = [sts.node(i).individual for i in children]

print(ch_ind)
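
The node-to-individual lookups above could also be done in one step with NumPy indexing. This is a sketch with invented IDs: `node_individual` is a hypothetical stand-in for the `sts.tables.nodes.individual` column, where entry k is the individual that node k belongs to.

```python
import numpy as np

# Hypothetical stand-in for sts.tables.nodes.individual:
# entry k is the individual ID of node k (-1 would mean "no individual").
node_individual = np.array([0, 0, 1, 1, 2, 2, 3])

# Hypothetical parent and child node IDs, one pair per edge.
parents = np.array([4, 4, 6, 6])
children = np.array([0, 2, 4, 5])

# Indexing with an array of node IDs does the whole lookup without a loop.
p_ind = node_individual[parents]
ch_ind = node_individual[children]

print(p_ind)
print(ch_ind)
```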

## 3D Scatterplot

### Node Array

In [None]:
p_ch_nodes = []

for edge in sts.edges():
    p_ch_nodes.append(edge.parent)
    p_ch_nodes.append(edge.child)

print (p_ch_nodes)
len(p_ch_nodes)

# each parent node is followed by its child node

### Individual Array

In [None]:
p_ch_inds = [sts.node(i).individual for i in p_ch_nodes]

print(p_ch_inds)
len(p_ch_inds)

# each parent individual is followed by its child individual

### [Parent, Child] Array

In [None]:
np.column_stack((p_ind, ch_ind))

##this is an array of [parent, child]
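
One thing to watch for: each edge covers a stretch of the genome, so if the simulation includes recombination the same parent-child pair can show up in more than one edge. The sketch below uses invented IDs to show how the [parent, child] array is put together and how repeated rows could be dropped before plotting.

```python
import numpy as np

# Hypothetical individual IDs; the pair (2, 0) appears twice on purpose.
p_ind = [2, 2, 3, 2]
ch_ind = [0, 1, 2, 0]

# One row per edge: [parent individual, child individual]
pairs = np.column_stack((p_ind, ch_ind))

# Keep each distinct [parent, child] row once (rows come back sorted).
unique_pairs = np.unique(pairs, axis=0)

print(pairs.shape)
print(unique_pairs)
```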

### Plot

In [None]:
%matplotlib widget
fig = plt.figure()
ax = plt.axes(projection='3d')

## x and y are each individual's location; z is the individual's time
p_ch_x = []
p_ch_y = []
p_ch_z = []

for i in p_ch_inds:
    p_ch_x.append(sts.individual_locations[i][0])
    p_ch_y.append(sts.individual_locations[i][1])
    p_ch_z.append(sts.individual_times[i])

ax.scatter3D(p_ch_x, p_ch_y, p_ch_z, color='r')

## entries come in (parent, child) pairs, so step by 2 to draw one line per edge
for i in range(0, len(p_ch_inds), 2):
    ax.plot(p_ch_x[i:i+2], p_ch_y[i:i+2], p_ch_z[i:i+2], linewidth=0.25, color='black')
    ## label the parent end of each edge with its node ID
    ax.text(p_ch_x[i], p_ch_y[i], p_ch_z[i], str(p_ch_nodes[i]), size=10, zorder=1, color='black')

ax.set_xlabel('x')
ax.set_ylabel('y')
## individual_times counts backwards from the end of the simulation (0 = last generation)
ax.set_zlabel('Generations ago')
##ax.grid(False)

plt.show()
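
A possible alternative to building three coordinate lists is to build one array of line segments. The sketch below uses invented data: `locations`, `times`, and `pairs` are hypothetical stand-ins for `sts.individual_locations`, `sts.individual_times`, and the [parent, child] individual array. It also drops any pair containing -1 (tskit's "no individual" value), since indexing with -1 would silently pick the last row.

```python
import numpy as np

# Hypothetical stand-ins for sts.individual_locations and sts.individual_times
locations = np.array([[0.0, 0.0, 0.0],
                      [1.0, 2.0, 0.0],
                      [3.0, 1.0, 0.0]])
times = np.array([0.0, 1.0, 2.0])

# Hypothetical [parent, child] individual IDs; -1 marks a node with no individual
pairs = np.array([[1, 0], [2, 1], [-1, 2]])

# Drop pairs with a missing individual
pairs = pairs[(pairs >= 0).all(axis=1)]

# One (x, y, time) point per individual
points = np.column_stack((locations[:, 0], locations[:, 1], times))

# segments[k] holds the two endpoints of edge k: shape (n_edges, 2, 3)
segments = points[pairs]
print(segments.shape)
```

An array of this shape is what `mpl_toolkits.mplot3d.art3d.Line3DCollection` accepts, so all the edges could be drawn with one `ax.add_collection3d(...)` call rather than one `ax.plot` call per edge.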

In [None]:
len(p_ch_x)