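## Summary

Interactive ProteinSolver dashboard meant to be rendered with Voila (generated files are written to the `voila/templates/mytemplate/static` directory of the Jupyter data dir and linked from there). It walks through three steps:

1. load a reference protein structure (PDB / mmCIF upload);
2. optionally fix target amino acids at individual positions;
3. run ProteinSolver and download the designed sequences as a gzipped FASTA file.

**Note:** the generation step is still a placeholder. The worker thread only prints dummy output, and the exported sequences are random.

---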

## Imports

In [None]:
import codecs
import gzip
import io
import subprocess
import threading
import time
import uuid
from enum import Enum
from pathlib import Path

import ipywidgets as widgets
import numpy as np
from IPython.display import HTML, Javascript, display
from ipywidgets import Layout

from kmbio import PDB

## Parameters

In [None]:
# Ask the `jupyter` CLI for its data directory; strict=True raises if that directory does not exist.
JUPYTER_DATA_DIR = Path(
    subprocess.check_output(["jupyter", "--data-dir"], universal_newlines=True).strip()
).resolve(strict=True)

In [None]:
# Show the resolved Jupyter data directory (sanity check for the path computed above).
JUPYTER_DATA_DIR

In [None]:
# Static-asset directory of the custom Voila template; save_sequences writes its FASTA files
# here and the download link points at them. strict=True fails fast if the template is missing.
STATIC_DATA_DIR = (JUPYTER_DATA_DIR / "voila" / "templates" / "mytemplate" / "static").resolve(strict=True)

In [None]:
# Show the resolved static directory where generated FASTA files will be written.
STATIC_DATA_DIR

## Global variables

In [None]:
# Module-level state shared by the widget callbacks further down.
# Three distinct, initially empty lists: reference residues, target residues, generated records.
reference_sequence, target_sequence, generated_sequences = [], [], []
# Current background worker; (re)assigned by the run button's click handler.
proteinsolver_thread = None

## Helper functions

In [None]:
# Choices offered by each per-residue dropdown: "-" (position left open for design)
# followed by every uppercase letter A..Z, in alphabetical order.
AMINO_ACIDS = ["-", *"ABCDEFGHIJKLMNOPQRSTUVWXYZ"]

In [None]:
class ProteinSolverThread(threading.Thread):
    """Background worker that (for now) simulates a ProteinSolver run.

    The run performs ``progress_bar.max`` steps. Each step waits ``step_seconds``,
    appends ``value`` to ``msa_view`` and advances ``progress_bar`` by one.
    Calling :meth:`stop` cancels the run: the bar turns "danger" and no further
    output is written. On normal completion the bar turns "success" and
    ``enable_download_button(download_button)`` is called.

    Parameters
    ----------
    value : str
        Text appended (plus a newline) to ``msa_view`` on every step.
    msa_view
        Output widget; it is cleared when the worker is constructed.
    progress_bar
        Progress widget whose ``max`` sets the number of steps.
    download_button
        Button passed to ``enable_download_button`` when the run completes.
    step_seconds : float, optional
        Delay per step (default 1.0 second).
    """

    def __init__(self, value, msa_view, progress_bar, download_button, step_seconds=1.0):
        # daemon=True: a still-running worker never blocks interpreter shutdown.
        super().__init__(daemon=True)
        self.value = value
        self.msa_view = msa_view
        # Clear the previous run's output as soon as a new worker is created.
        self.msa_view.clear_output()
        self.progress_bar = progress_bar
        self.download_button = download_button
        self.step_seconds = step_seconds
        self._stop_event = threading.Event()

    def run(self):
        self.progress_bar.value = 0
        self.progress_bar.bar_style = ""
        for _ in range(self.progress_bar.max):
            # Event.wait returns True immediately once stop() has been called (or was
            # already called), so cancellation takes effect right away. The former
            # time.sleep() let a cancelled run append one more line and bump the bar
            # once more, possibly into the next run's view.
            if self._stop_event.wait(self.step_seconds):
                self.progress_bar.bar_style = "danger"
                return
            self.msa_view.append_stdout(self.value + "\n")
            self.progress_bar.value += 1
        self.progress_bar.bar_style = "success"
        enable_download_button(self.download_button)

    def stop(self):
        """Request cancellation; the worker exits at its next wait."""
        self._stop_event.set()

    def stopped(self):
        """Return True once :meth:`stop` has been called."""
        return self._stop_event.is_set()

In [None]:
def generate_random_sequence(length=80, seed=None):
    """Return a random string of ``length`` residues drawn from the 20 standard amino acids.

    With ``seed=None`` numpy's global random state is used. Otherwise a fresh
    ``RandomState(seed)`` makes the result reproducible.
    """
    alphabet = np.array(list("GVALICMFWPDESTYQNKRH"))
    pick = np.random.choice if seed is None else np.random.RandomState(seed).choice
    residues = pick(alphabet, length)
    return "".join(residues)

In [None]:
# Preview: a reproducible 80-residue random sequence (seed 42).
generate_random_sequence(80, 42)

In [None]:
def sequences_to_fasta(sequences, line_width=80):
    """Render sequence records as FASTA text.

    Each record is a mapping with "id", "name", "proba" and "seq" keys. It becomes
    a header line ``>id|name|proba`` followed by the sequence wrapped to
    ``line_width`` characters per line. Every emitted line, including the last,
    ends with a newline; an empty input yields "".
    """
    lines = []
    for record in sequences:
        lines.append(f">{record['id']}|{record['name']}|{record['proba']}")
        residues = record["seq"]
        lines.extend(residues[pos : pos + line_width] for pos in range(0, len(residues), line_width))
    return "".join(line + "\n" for line in lines)

In [None]:
# Preview the FASTA layout: a ">id|name|proba" header, then the sequence wrapped at 80 columns.
print(sequences_to_fasta([{"id": 1, "name": "reference", "proba": 1.0, "seq": generate_random_sequence(160, 42)}]))

In [None]:
def populate_generated_sequences(n_sequences=20_000, length=162):
    """Replace the global ``generated_sequences`` with placeholder random designs.

    Each record uses the {"id", "name", "proba", "seq"} layout consumed by
    ``sequences_to_fasta``: ids start at 1, names are "gen-00000", "gen-00001", ...,
    proba is 1.0 and seq is a random sequence of ``length`` residues.

    Parameters
    ----------
    n_sequences : int, optional
        Number of records to create (default 20,000).
    length : int, optional
        Residues per generated sequence (default 162).
    """
    global generated_sequences

    # Build the records and bind them to the global. The previous version built the
    # list but never assigned it, so generated_sequences stayed empty and the
    # exported FASTA never contained any designs.
    # NOTE(review): ids restart at 1, duplicating the ids 1/2 that save_sequences gives
    # the reference/target records; confirm whether they should be offset.
    generated_sequences = [
        {"id": i + 1, "name": f"gen-{i:05d}", "proba": 1.0, "seq": generate_random_sequence(length)}
        for i in range(n_sequences)
    ]

In [None]:
# Create the placeholder random designs (intended to populate `generated_sequences`,
# which save_sequences exports alongside the reference and target).
populate_generated_sequences()

In [None]:
def save_sequences():
    """Write the reference, target and generated sequences to a new gzipped FASTA file.

    The reference and target (joined from their residue lists) come first, with ids
    1 and 2 and no probability, followed by every record in ``generated_sequences``.
    The file is named ``<uuid4>.fasta.gz`` inside ``STATIC_DATA_DIR``.

    Returns
    -------
    pathlib.Path
        Path of the file that was written.
    """
    fixed_records = [
        {"id": 1, "name": "reference", "proba": None, "seq": "".join(reference_sequence)},
        {"id": 2, "name": "target", "proba": None, "seq": "".join(target_sequence)},
    ]
    fasta_text = sequences_to_fasta(fixed_records + list(generated_sequences))
    payload = gzip.compress(fasta_text.encode("utf-8"))

    output_file = STATIC_DATA_DIR / f"{uuid.uuid4()}.fasta.gz"
    output_file.write_bytes(payload)
    return output_file

In [None]:
# Output area used by the upload and example callbacks (they run inside `with out:` / `out.capture()`).
out = widgets.Output()

In [None]:
# Show the callback output area; anything printed by the upload/example handlers appears here.
out

In [None]:
structure = None


def handle_upload(change):
    """Parse the most recently uploaded structure file into the global ``structure``.

    Registered below as the observer of ``uploader.value``. The parser is chosen
    with ``PDB.get_parser`` from the file extension, and the structure id is the
    filename up to its first dot. Output and exceptions are routed to ``out``.

    Assumes the ipywidgets 7 ``FileUpload.value`` layout: a dict mapping filenames
    to ``{"metadata": {"name": ...}, "content": <bytes>}``. TODO confirm if
    ipywidgets is upgraded.
    """
    # Without this, the assignment below bound a local and the parsed structure was lost.
    global structure

    with out:
        uploaded = change["new"]
        # The value can be reset to an empty dict; there is nothing to parse then.
        if not uploaded:
            return

        # keep only the last file
        # TODO: check if this should be fixed in FileUpload widget
        # when multiple=False
        last_item = list(uploaded.values())[-1]

        filename = last_item["metadata"]["name"]
        structure_id = filename.split(".")[0]
        suffix = filename.split(".")[-1]

        data = codecs.decode(last_item["content"], encoding="utf-8")
        parser = PDB.get_parser(suffix)
        structure = parser.get_structure(io.StringIO(data), structure_id=structure_id)


uploader = widgets.FileUpload(accept=".pdb,.cif,.mmcif", multiple=False)
uploader.observe(handle_upload, names="value")

In [None]:
# Show the structure upload widget; each new upload is parsed by handle_upload.
uploader

In [None]:
# `show_uploader` is not defined anywhere in this notebook, so this cell raised NameError
# on Restart & Run All; display the upload widget directly instead.
display(uploader)

In [None]:
def show_examples(example_folder="./examples", suffixes=(".pdb", ".cif", ".mmcif")):
    """Display a row of buttons, one per example structure file in ``example_folder``.

    Clicking a button parses that file the same way ``handle_upload`` parses an
    upload: the parser comes from ``PDB.get_parser(<extension>)``, the structure id
    is the filename up to its first dot, and the result is stored in the global
    ``structure``. Output and errors from a click go to ``out``, which is cleared
    on each click.

    Parameters
    ----------
    example_folder : str or Path, optional
        Directory scanned (non-recursively) for examples. Nothing is shown if it
        does not exist.
    suffixes : tuple of str, optional
        File extensions (with leading dot) treated as example structures.
    """
    # NOTE: the previous version listed `.gpx` files and used `os`, `Button`, `HBox` and
    # `plot_gpx`, none of which exist in this notebook, so every call raised NameError.
    example_dir = Path(example_folder)
    if not example_dir.is_dir():
        return
    # Sorted so the button order is stable (directory listing order is arbitrary).
    examples = sorted(path for path in example_dir.iterdir() if path.suffix in suffixes)

    def create_example(path):
        @out.capture(clear_output=True)
        def on_example_clicked(_button):
            global structure
            data = path.read_text(encoding="utf-8")
            parser = PDB.get_parser(path.suffix[1:])
            structure = parser.get_structure(io.StringIO(data), structure_id=path.name.split(".")[0])

        button = widgets.Button(description=path.stem)
        button.on_click(on_example_clicked)
        return button

    buttons = [create_example(path) for path in examples]
    display(widgets.HBox(buttons, layout=Layout(flex_flow="row", align_items="center")))

In [None]:
# Placeholder data until a real structure is loaded: a random 120-residue reference
# and an all-open ("-") target of the same length.
reference_sequence = [*generate_random_sequence(120)]
target_sequence = ["-" for _ in reference_sequence]

## CSS

In [None]:
%%html
<style>
/* Section headings of the final dashboard (<p class="myheading">). */
.myheading {
    font-size: large;
    margin-bottom: 1rem
}

/* Per-residue dropdowns (add_class("mytext")): small monospace description label with a
   fixed 60px width, presumably so the "index (reference aa)" labels line up. */
.mytext > .widget-label {
    font-family: monospace;
    font-size: small;
    width: 60px;
}

/* Reference / target sequence text areas (add_class("mysequence")): tighter label line height. */
.mysequence > .widget-label {
    line-height: 1rem;
}
</style>

## Widgets

### Target sequence preference

In [None]:
# Read-only text area showing the full reference sequence as one string.
reference_sequence_ta = widgets.Textarea(
    value="".join(reference_sequence),
    placeholder="AAAAA...",
    description="<em>Reference</em><br>sequence:",
    layout=widgets.Layout(width="auto"),
    disabled=True,
)
# Hook up the ".mysequence" CSS rule; bind the returned widget to `_` to keep the cell silent.
_ = reference_sequence_ta.add_class("mysequence")

In [None]:
# Read-only text area mirroring the target sequence; update_target_sequence rewrites it
# whenever a per-residue dropdown changes.
target_sequence_ta = widgets.Textarea(
    value="".join(target_sequence),
    placeholder="AAAAA...",
    description="<em>Target</em><br>sequence:",
    layout=widgets.Layout(width="auto"),
    disabled=True,
)
# Hook up the ".mysequence" CSS rule; bind the returned widget to `_` to keep the cell silent.
_ = target_sequence_ta.add_class("mysequence")

In [None]:
def update_target_sequence(change):
    """Store the amino acid picked in one dropdown into ``target_sequence``.

    The residue index is read back from the dropdown's description, which is
    built below as ``"<index> (<reference aa>)"``. The target text area is then
    refreshed to show the whole updated sequence.
    """
    dropdown = change["owner"]
    position = int(dropdown.description.split(" ")[0])
    target_sequence[position] = change["new"]
    target_sequence_ta.value = "".join(target_sequence)


# One dropdown per residue, labelled with its index and reference amino acid and
# preset to the current target residue.
target_sequence_selections = [
    widgets.Dropdown(
        options=AMINO_ACIDS,
        value=aa_target,
        description=f"{i} ({aa_ref})",
        layout=widgets.Layout(width="120px"),
        style={"font_family": "monospace", "font_weight": "bold"},
    )
    for i, (aa_ref, aa_target) in enumerate(zip(reference_sequence, target_sequence))
]
for button in target_sequence_selections:
    button.add_class("mytext")
    button.observe(update_target_sequence, names="value")
    button.add_class("mytext")

In [None]:
# Lay the per-residue dropdowns out in rows that wrap across the full width.
target_sequence_selection_box = widgets.HBox(
    children=target_sequence_selections,
    layout=widgets.Layout(
        width="100%",
        flex_direction="row",
        flex_wrap="wrap",
        flex_flow="row wrap",
    ),
)

In [None]:
# Development preview of the target-sequence editor: dropdowns above, then the reference and
# target text areas. The same child widgets are displayed again in the final dashboard below.
widgets.VBox(
    [
        target_sequence_selection_box,
        widgets.VBox([reference_sequence_ta, target_sequence_ta], layout=widgets.Layout(margin="20px 0px 0px 0px")),
    ]
)

### MSA view

In [None]:
# Progress of the current ProteinSolver run. `max` is reset to the requested number of
# sequences when a run starts (see disable_download_button).
progress_bar = widgets.IntProgress(
    value=0,
    min=0,
    max=100,
    step=1,
    orientation="horizontal",
    bar_style="",  # 'success', 'info', 'warning', 'danger' or ''
    layout=widgets.Layout(height="10px", width="auto"),
)

In [None]:
# Output area the worker thread appends its per-step text to.
msa_alignment_view = widgets.Output(layout=widgets.Layout(width="auto"))

In [None]:
# Development preview of the run progress bar and worker output area.
widgets.VBox([progress_bar, msa_alignment_view])

### Run ProteinSolver

In [None]:
# How many sequences to generate (1..20,000). It becomes progress_bar.max when a run starts
# and is locked while a run is in progress.
number_of_sequences_input = widgets.BoundedIntText(
    value=100,
    min=1,
    max=20_000,
    step=1,
    description="Number of sequences:",
    style={"description_width": "initial"},
    layout=widgets.Layout(width="200px"),
    disabled=False,
)

In [None]:
class State(Enum):
    """States of the run button; each member's value is the button's description text."""

    GENERATE = "Run ProteinSolver!"
    CANCEL = "Cancel"


# Icon / style / tooltip of the run button for each State.
button_meta = {
    State.GENERATE: dict(icon="check", button_style="", tooltip="Generate new sequences!"),
    State.CANCEL: dict(icon="ban", button_style="danger", tooltip="Cancel!"),
}


def disable_download_button(b):
    """Put the run button ``b`` into its Cancel state and prepare the UI for a new run.

    Resets the progress bar to 0 with ``max`` set to the requested number of
    sequences, and locks the number-of-sequences input.
    """
    cancel_meta = button_meta[State.CANCEL]
    b.description = State.CANCEL.value
    for attr in ("icon", "button_style", "tooltip"):
        setattr(b, attr, cancel_meta[attr])

    # Change value before changing max to prevent jitter in progress bar
    progress_bar.value = 0
    progress_bar.max = number_of_sequences_input.value
    number_of_sequences_input.disabled = True


def enable_download_button(b):
    """Return the run button ``b`` to its Generate state and unlock the number-of-sequences input."""
    generate_meta = button_meta[State.GENERATE]
    b.description = State.GENERATE.value
    for attr in ("icon", "button_style", "tooltip"):
        setattr(b, attr, generate_meta[attr])

    number_of_sequences_input.disabled = False

def on_button_clicked(b):
    """Click handler of the run button: start a new run or cancel the current one.

    The button's description encodes its state (see ``State``). In the Generate
    state it starts a new ``ProteinSolverThread``; in the Cancel state it restores
    the button and stops the running worker.

    Raises
    ------
    ValueError
        If the description matches neither state.
    """
    global proteinsolver_thread

    if b.description == State.GENERATE.value:
        # Stop any previous worker and wait for it to exit *before* the shared progress
        # bar and output are reset. Otherwise its final "danger" style or output line
        # can land on the new run (e.g. after Cancel immediately followed by Run).
        if proteinsolver_thread is not None and proteinsolver_thread.is_alive():
            proteinsolver_thread.stop()
            proteinsolver_thread.join()

        disable_download_button(b)
        proteinsolver_thread = ProteinSolverThread("hello world", msa_alignment_view, progress_bar, b)
        proteinsolver_thread.start()
    elif b.description == State.CANCEL.value:
        enable_download_button(b)

        if proteinsolver_thread is not None:
            proteinsolver_thread.stop()
    else:
        raise ValueError(f"Unexpected run button description: {b.description!r}")


# The run button starts in the Generate state; icon / button_style / tooltip are taken
# directly from the matching button_meta entry.
run_proteinsolver_button = widgets.Button(
    description=State.GENERATE.value,
    **button_meta[State.GENERATE],
    disabled=False,
    layout=widgets.Layout(width="200px"),
)
run_proteinsolver_button.on_click(on_button_clicked)

In [None]:
# Development preview of the run controls (sequence count + run/cancel button).
widgets.VBox([number_of_sequences_input, run_proteinsolver_button])

### Generate download link

In [None]:
# Output area that holds the most recently generated download link.
generate_download_link_output = widgets.Output(layout=Layout(width="200px"))

In [None]:
def generate_download_link(b):
    """Click handler: save all sequences to a new file and show a download link for it.

    While saving, the button ``b`` is disabled and shows "Generating...". Afterwards
    it is re-enabled even if saving failed, so the user can retry. On failure it
    shows the "danger" style and the exception propagates to the widget's error log.
    """
    b.description = "Generating..."
    b.icon = "running"
    b.button_style = "info"  # 'success', 'info', 'warning', 'danger' or ''
    b.disabled = True

    generate_download_link_output.clear_output()
    succeeded = False
    try:
        output_file = save_sequences()
        # output_file is "<uuid4>.fasta.gz"; offer it as "<first 8 uuid chars>.fasta.gz".
        # (stem[:8] + suffix produced "<8 chars>.gz", dropping the ".fasta" part.)
        download_name = output_file.name.split(".")[0][:8] + "".join(output_file.suffixes)
        with generate_download_link_output:
            display(
                HTML(
                    f'<a href="./voila/static/{output_file.name}" download="{download_name}">'
                    '<i class="fa fa-download"></i> Download sequences</a>'
                )
            )
        succeeded = True
    finally:
        b.description = "Update download link" if succeeded else "Generate download link"
        b.icon = ""  # check
        b.button_style = "success" if succeeded else "danger"
        b.disabled = False


generate_download_link_button = widgets.Button(
    description="Generate download link",
    tooltip="Generate download link",
    button_style="success",
    disabled=False,
    layout=widgets.Layout(width="200px"),
)

generate_download_link_button.on_click(generate_download_link)

In [None]:
# Development preview of the download-link button and the area where the link appears.
widgets.VBox([generate_download_link_button, generate_download_link_output])

## Final dashboard

In [None]:
section1_heading = """
<p class=myheading>
1. Load a reference protein structure, or use one of the provided examples.
</p>
"""

display(HTML(section1_heading))
# Sections 2 and 3 show their controls under the heading; do the same here, otherwise
# the dashboard asks for a structure but offers no way to upload one.
display(uploader)

In [None]:
# Section 2 of the dashboard: heading, then the per-residue dropdowns above the
# reference / target sequence text areas. A single display call renders both objects in order.
display(
    HTML(
        """
<p class="myheading">
2. (Optional) Specify target amino acids at specific positions (or enter '-' to leave the position open for design).
</p>
"""
    ),
    widgets.VBox(
        [
            target_sequence_selection_box,
            widgets.VBox(
                [reference_sequence_ta, target_sequence_ta],
                layout=widgets.Layout(margin="20px 0px 0px 0px"),
            ),
        ]
    ),
)

In [None]:
display(
    HTML(
        """
<p class="myheading">
3. Run ProteinSolver to generate new designs.
</p>
"""
    )
)
# Left column: fixed-width (230px) run and download controls. Right column: the progress
# bar and worker output, which grow to take the remaining width.
display(
    widgets.HBox(
        [
            widgets.VBox(
                [
                    widgets.VBox([number_of_sequences_input, run_proteinsolver_button]),
                    widgets.VBox(
                        [generate_download_link_button, generate_download_link_output],
                        layout=Layout(margin="30px 0px 0px 0px"),
                    ),
                ],
                layout=Layout(width="230px", flex="0 0 auto"),
            ),
            widgets.VBox([progress_bar, msa_alignment_view], layout=Layout(align_items="stretch", flex="1 1 auto")),
        ],
        # "flex-grow" is not a valid value for the CSS `flex` shorthand and was silently
        # ignored; "1 1 auto" expresses the intended "grow to fill the available space".
        layout=Layout(align_items="stretch", flex="1 1 auto"),
    )
)