In [1]:
import numpy as np
from scipy import stats

import plotly.offline as py
import plotly.graph_objs as go

In [2]:
# Plot layout
layout = go.Layout(
	scene = dict(
    	camera = dict(
        	up=dict(x=0, y=0, z=1),
        	center=dict(x=0, y=0, z=0),
        	eye=dict(x=1.8, y=0, z=0.1)
    	),
    	xaxis = dict(
        	title='',
        	showgrid=False,
        	zeroline=False,
        	mirror=False,
        	showline=False,
        	ticks='',
        	showticklabels=False
    	),
    	yaxis = dict(
        	title='',
        	showgrid=False,
        	zeroline=False,
        	mirror=False,
        	showline=False,
        	ticks='',
        	showticklabels=False
    	),
    	zaxis = dict(
        	title='',
        	showgrid=False,
        	zeroline=False,
        	showline=False,
        	ticks='',
        	showticklabels=False
    	)
	)
)

In [68]:
n = 100

r = 10
x = np.linspace(-1.8, 1.8, n)
y = np.linspace(-1.8, 1.8, n)

X, Y = np.meshgrid(x, y)

XY = np.empty((n * n, 2))
XY[:, 0] = X.flatten()
XY[:, 1] = Y.flatten()

#------------------------------------------------------------------------

# Z_native
cov = np.array([[0.08, 0.002],
               [0.005, 0.08]])
dist = stats.multivariate_normal(np.array([0.5, -1.3]), cov)
Z_native = dist.pdf(XY).reshape((n, n)) * 0.8

cov = np.eye(2) * 0.2
dist = stats.multivariate_normal(np.array([-0.6, -0.6]), cov)
Z_native += dist.pdf(XY).reshape((n, n)) * 0.3

cov = np.array([[0.05, 0.002],
               [0.003, 0.05]])
dist = stats.multivariate_normal(np.array([-0.5, -1.0]), cov)
Z_native += dist.pdf(XY).reshape((n, n)) * 0.15

#-------------------------------------------------------------------------

# Z_new
cov = np.eye(2) * 0.1
dist = stats.multivariate_normal(np.array([-1.2, 1.2]), cov)
Z_new = dist.pdf(XY).reshape((n, n)) * 1

cov = np.eye(2) * 0.1
dist = stats.multivariate_normal(np.array([-0.3, 0.7]), cov)
Z_new += dist.pdf(XY).reshape((n, n)) * 0.6

cov = np.eye(2) * 0.10
dist = stats.multivariate_normal(np.array([0.52, 0.34]), cov)
Z_new += dist.pdf(XY).reshape((n, n)) * 0.4

cov = np.eye(2) * 0.06
dist = stats.multivariate_normal(np.array([0.5, -0.4]), cov)
Z_new += dist.pdf(XY).reshape((n, n)) * 0.12

cov = np.eye(2) * 0.09
dist = stats.multivariate_normal(np.array([1.0, -0.8]), cov)
Z_new += dist.pdf(XY).reshape((n, n)) * 0.15

# If spatially separated
#Z_native += 2

In [78]:
# Some simple black to color gradients:

# Blue used in native protein: #3273BB
blues = [[0, '#111118'], [0.02, '#111133'], [1.0, '#0090FF']]
# Red used in evolved protein: ##A01F1D
reds = [[0, '#181111'], [0.02, '#331111'], [1.0, '#B01F1D']]

#greens = [[0, '#111811'], [0.02, '#113311'], [1.0, '#70CC00']]
#oranges = [[0, '#222222'], [0.02, '#666666'], [1.0, '#DC6600']]


# Check the plot
data = [
    go.Surface(z=Z_native, opacity=1, colorscale=blues),
    go.Surface(z=Z_new, opacity=1, colorscale=reds)
]

fig = go.Figure(data=data, layout=layout)

py.plot(fig, filename='final_landscape.html')

'file:///Users/Patrick/git/fitness_landscape/final_landscape.html'

In [85]:
# Try to find intersection

intersect = np.empty_like(Z_native)

for x in range(len(Z_native[0])):
    for value in Z_native[x, :]:
        if value >= 0.01:
            thresh_up = value * 1.05
            thresh_low = value * 0.95
            
            if (value <= thresh_up) and (value >= thresh_low):
                intersect[x] = value # this does the whole column... need to adjust
            else:
                intersect[x] = 0
                
intersect

array([[1.37592766e-02, 1.37592766e-02, 1.37592766e-02, ...,
        1.37592766e-02, 1.37592766e-02, 1.37592766e-02],
       [1.25330637e-02, 1.25330637e-02, 1.25330637e-02, ...,
        1.25330637e-02, 1.25330637e-02, 1.25330637e-02],
       [1.10664385e-02, 1.10664385e-02, 1.10664385e-02, ...,
        1.10664385e-02, 1.10664385e-02, 1.10664385e-02],
       ...,
       [8.58793242e-09, 1.06465303e-08, 1.31116154e-08, ...,
        4.13795835e-13, 2.70136239e-13, 1.75189557e-13],
       [5.60641641e-09, 6.95032043e-09, 8.55958943e-09, ...,
        2.70136239e-13, 1.76351672e-13, 1.14368111e-13],
       [3.63588984e-09, 4.50744246e-09, 5.55109039e-09, ...,
        1.75189557e-13, 1.14368111e-13, 7.41703472e-14]])

In [87]:
# test
test = [[0, '#000000'], [0.001, '#DC6600'], [1.0, '#DC6600']]

data = [
    go.Surface(z=Z_native, opacity=0.5, colorscale=blues),
    go.Surface(z=Z_new, opacity=0.5, colorscale=reds),
    go.Surface(z=intersect, opacity=1, colorscale=test)
]

fig = go.Figure(data=data, layout=layout)

py.plot(fig, filename='test.html')

'file:///Users/Patrick/git/fitness_landscape/test.html'

In [None]:
# For screenshot: 400, 260; 620, 570