# Install dependencies
### (might need to reload Jupyter for widgets to show properly)

In [None]:
!pip install py3Dmol
!pip install ipywidgets

## Helper functions

In [None]:
import py3Dmol
import json
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML

# Dropdown labels in the form "<index>: <description>". The index prefix is
# generated from each description's position in the tuple below.
labels = [
    f"{index}: {description}"
    for index, description in enumerate((
        "Nanobody",
        "Heavy Chain + Light Chain",
        "Riboswitch + nucleoside precursor",
        "tRNA",
        "Enzyme catalytic domain + Nanobody",
        "Enzyme + small molecule inhibitor",
        "Bacterial enzyme",
        "Nanobody + Protein",
        "Fab + antibody-binding epitope",
    ))
]

# Display
def display_complex_cif(file_path, width=800, height=600, chain_colors=None):
    """Render a CIF structure file with py3Dmol and return it as HTML.

    Args:
        file_path: Path of the CIF file to read.
        width: Viewer width, forwarded to ``py3Dmol.view``.
        height: Viewer height, forwarded to ``py3Dmol.view``.
        chain_colors: Optional mapping of chain ID -> cartoon color. Styles are
            applied in the mapping's iteration order. When omitted, chains
            'A0', 'B0' and 'C0' get the notebook's default palette.

    Returns:
        The HTML string produced by the view's ``_make_html()``, suitable for
        wrapping in ``IPython.display.HTML``.

    Raises:
        OSError: If ``file_path`` cannot be opened.
    """
    if chain_colors is None:
        chain_colors = {
            'A0': '#f5ae4c',
            'B0': '#58db9a',
            'C0': '#a887de',
        }

    # Decode explicitly so the result does not depend on the platform's
    # locale default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        cif_data = f.read()

    view = py3Dmol.view(width=width, height=height)
    view.addModel(cif_data, "cif")

    # One cartoon style per configured chain
    for chain_id, color in chain_colors.items():
        view.setStyle({'chain': chain_id}, {'cartoon': {'color': color}})

    # Highlight ligands / small molecules
    view.setStyle({'hetflag': True}, {'stick': {'colorscheme': 'greenCarbon'}, 'sphere': {'radius': 0.3}})

    view.zoomTo()
    # NOTE(review): _make_html() is a private py3Dmol method; kept because the
    # callers wrap the returned string in IPython.display.HTML.
    return view._make_html()

# Widget callback
def on_select(change):
    """Dropdown observer: redraw the viewer for the newly selected label.

    Args:
        change: Change dict delivered by the observed widget; only
            ``change['new']`` (the selected label) is read.
    """
    with viewer_output:
        # Replace whatever the output area currently shows.
        clear_output(wait=True)
        selected_label = change['new']
        cif_path = label_to_cif[selected_label]
        viewer_markup = display_complex_cif(cif_path)
        display(HTML(viewer_markup))

# Output area that hosts the rendered structure viewer.
viewer_output = widgets.Output()

# Structure picker; the first label is selected initially.
dropdown = widgets.Dropdown(
    description='Structure:',
    options=labels,
    value=labels[0],
    layout=widgets.Layout(width='80%'),
)

# Call on_select whenever the dropdown's selected value changes.
dropdown.observe(on_select, names='value')

## Load inference results

In [None]:
# JSON text is UTF-8, so decode it explicitly instead of relying on the
# platform's locale default encoding (which is not UTF-8 everywhere).
with open('examples/example.json', 'r', encoding='utf-8') as f:
    jobs = json.load(f)

# Job names, in the order they appear in the JSON file
results = [job['name'] for job in jobs]

# Map labels to file paths.
# Labels and job names are paired by position. zip() stops at the shorter of
# the two sequences, so any surplus labels or jobs are left out of the mapping.
label_to_cif = {
    label: f'scripts/outputs/example-1234/{name}/seed_1234/predictions/{name}_seed_1234_sample_0.cif'
    for label, name in zip(labels, results)
}

## Display results

In [None]:
# 0: Nanobody
# 1: Heavy Chain + Light Chain
# 2: Riboswitch + nucleoside precursor
# 3: tRNA
# 4: Enzyme catalytic domain + Nanobody
# 5: Enzyme + small molecule inhibitor
# 6: Bacterial enzyme
# 7: Nanobody + Protein
# 8: Fab + antibody-binding epitope

display(dropdown)
with viewer_output:
    # viewer_output persists in the kernel, so clear it first: otherwise
    # re-running this cell stacks another viewer under the existing one.
    clear_output(wait=True)
    # Render the dropdown's current selection (the first label on a fresh run)
    # so the viewer and the dropdown agree when this cell is re-run.
    html = display_complex_cif(label_to_cif[dropdown.value])
    display(HTML(html))
display(viewer_output)
# If you see the warning "3Dmol.js failed to load",
# just select another structure, then come back to the first one.