In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

# Evaluation grid: 100 evenly spaced samples on [-3, 3] along each axis.
x = np.linspace(start=-3, stop=3, num=100)
y = np.linspace(start=-3, stop=3, num=100)
# 2-D coordinate arrays covering every (x, y) combination of the grid.
X, Y = np.meshgrid(x, y)

# Per-item coefficients, one row per item. Each row holds four (a, c) pairs
# laid out as a1, c1, a2, c2, a3, c3, a4, c4.
items = [
    [2.66, 1.37, 2.62, -0.71, 2.62, -2.44, 2.62, -4.21],
    [4.09, -2.57, 3.91, -1.19, 3.91, -4.12, 3.91, -6.98],
    [1.97, -1.88, 1.64, -0.67, 1.64, -1.89, 1.64, -3.48],
    [1.45, -1.47, 1.05, -0.31, 1.05, -1.20, 1.05, -1.88],
    [4.63, -3.13, 3.62, -1.60, 3.62, -4.02, 3.62, -6.82],
    [2.92, -2.88, 2.24, -1.30, 2.24, -3.00, 2.24, -4.44],
    [1.39, -2.28, 1.05, -0.34, 1.05, -1.18, 1.05, -1.44],
]

# Human-readable wording of each item, in the same order as the rows above.
item_labels = [
    "Lack of strength (energy)",
    "An overall sense of fatigue",
    "Sleepiness or drowsiness",
    "Difficulty concentrating",
    "Increased fatigue during the course of the day",
    "Fatigue in the morning when getting up",
    "Difficulty staying attentive",
]

def _category_surfaces(coefficients, presence, severity):
    """Return the five response-category probability surfaces for one item.

    ``coefficients`` is one row of ``items``: a1, c1, a2, c2, a3, c3, a4, c4.
    Each pair defines ``q_k = exp(-(a_k * t + c_k))``, where ``t`` is
    ``presence`` for the first pair and ``severity`` for the other three.
    ``presence`` and ``severity`` may be scalars or numpy arrays; the result
    is a list of five values of the broadcast shape that sum to 1 pointwise.
    """
    a1, c1, a2, c2, a3, c3, a4, c4 = coefficients
    q1 = np.exp(-(a1 * presence + c1))
    q2 = np.exp(-(a2 * severity + c2))
    q3 = np.exp(-(a3 * severity + c3))
    q4 = np.exp(-(a4 * severity + c4))

    # Running product of the (1 + q_k) factors; category k is divided by the
    # product of the first k factors, the last category by all four.
    d1 = 1 + q1
    d2 = d1 * (1 + q2)
    d3 = d2 * (1 + q3)
    d4 = d3 * (1 + q4)
    return [q1 / d1, q2 / d2, q3 / d3, q4 / d4, 1 / d4]


# Number of items to draw (taken from the front of `items`) and number of
# subplots per row. Grid size, titles and axis labels all follow from these.
N_PLOTS = 2
N_COLS = 2
n_rows = -(-N_PLOTS // N_COLS)  # ceiling division

# One flat colour per response category, in category order.
SURFACE_COLORS = ["red", "green", "blue", "magenta", "cyan"]

fig = make_subplots(
    rows=n_rows,
    cols=N_COLS,
    # 3D scene for every used cell; leftover cells in the last row stay blank.
    specs=[
        [{'is_3d': True} if r * N_COLS + c < N_PLOTS else None for c in range(N_COLS)]
        for r in range(n_rows)
    ],
    # Only pass titles for the subplots that actually exist.
    subplot_titles=item_labels[:N_PLOTS],
)

for index, coefficients in enumerate(items[:N_PLOTS]):
    row, col = divmod(index, N_COLS)
    surfaces = _category_surfaces(coefficients, X, Y)
    for z, color in zip(surfaces, SURFACE_COLORS):
        # A two-stop scale with the same colour renders a uniformly coloured
        # surface; showscale=False removes the automatic colour bar.
        fig.add_trace(
            go.Surface(z=z, x=X, y=Y, colorscale=[color, color], showscale=False),
            row=row + 1,
            col=col + 1,
        )

# Apply the same axis titles to every 3D scene in the figure.
axis_titles = dict(xaxis_title='Presence', yaxis_title='Severity', zaxis_title='Probability')
fig.update_scenes(axis_titles)

# display figure
fig.show()