In [13]:
import os
from pathlib import Path
import ipywidgets as widgets
import traitlets
import pandas as pd
import polars as pl
from IPython.display import display, clear_output
from ipyaggrid import Grid
from src.post_processing import PathWrangler

In [14]:
# Resolve the data locations from the Hydra "filepaths" config, then build the
# PathWrangler that every later cell queries.
from hydra import compose, initialize

with initialize(version_base=None, config_path="../conf/filepaths"):
    cfg = compose(config_name="filepaths")

# The study directory lives under the configured results root.
study = Path(cfg.results).joinpath("bottle25")
known = Path(cfg.known)
pw = PathWrangler(study=study, known=known)

In [15]:
# UI widgets and layout

# Input
starter_options, target_options = pw.starters, pw.targets
evidence_options = [elt.value for elt in pw.enzyme_existence]
sort_options = [
    ("Mean Rxn Sim", 'mean_max_rxn_sim'),
    ("Min Rxn Sim", 'min_max_rxn_sim'),
    # ("Max-min driving force", 'mdf'),
    ("Feasibility", 'feasibility_frac')
]

# Keyword arguments that every selection widget below accepts.
# Widget-specific traits are layered on top per widget type: traitlets only
# tolerates unrecognized constructor arguments with a DeprecationWarning, so
# each widget must receive nothing but the traits it declares.
kwargs_selector_common = dict(
    disabled=False,
    layout=widgets.Layout(
        height='100px',
        width='auto',
        flex='1 1 auto',
    ),
    style=dict(
        description_width='100px',
    ),
)

# SelectMultiple additionally takes `rows`.
kwargs_input_selector = dict(
    **kwargs_selector_common,
    rows=len(starter_options),
)

input_selectors = widgets.HBox(
    children=[
        starter_selector := widgets.SelectMultiple(
            description='Starters:',
            options=starter_options,
            value=starter_options[:1],
            **kwargs_input_selector,
        ),
        target_selector := widgets.SelectMultiple(
            description='Targets:',
            options=target_options,
            value=target_options[:1],
            **kwargs_input_selector,
        ),
        evidence_selector := widgets.SelectMultiple(
            description='Enzyme LOE:',
            options=evidence_options,
            value=evidence_options[:1],
            **kwargs_input_selector,
        ),
        # RadioButtons has no `rows` trait, so it only gets the common kwargs.
        sort_by_radio_buttons := widgets.RadioButtons(
            description='Sort paths by:',
            options=sort_options,
            value='mean_max_rxn_sim',
            **kwargs_selector_common,
        ),
    ],
    layout=widgets.Layout(
        flex_flow='row wrap',
    ),
)

# Search button
search_button = widgets.Button(
    description='Load paths',
    disabled=False,
    button_style='success', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Load paths',
    icon='flask' # (FontAwesome names without the `fa-` prefix)
)

# output

# Same common kwargs, but with a layout that does not force a fixed height and
# does not grow to fill the row.
kwargs_output_selector = {
    **kwargs_selector_common,
    'layout': widgets.Layout(
        width='auto',
        flex='0 1 auto',
    ),
}

output_paths = widgets.VBox(
    children=[
        paths_label := widgets.Label(
            value='Make sorting & filtering selections above',
        ),
        paths_selector := widgets.Dropdown(
            description='Path:',
            **kwargs_output_selector,
        ),
        paths_viewer := widgets.Output(),
    ],
)

# UI overall

ui = widgets.VBox(
    children=[
        input_selectors,
        search_button,
        output_paths,
    ],
    layout=widgets.Layout(
        width='99%',
        justify_content='flex-start',
    ),
)

In [16]:
# JavaScript cell renderer handed to the grid: turns a non-empty cell value into
# a link to the matching UniProt entry, and renders an empty string otherwise.
js_uniprot_id_renderer = '''
    function(params){
        if (params.value !== undefined && params.value !== null){
            return `<a href="https://www.uniprot.org/uniprotkb/${params.value}/entry" target="_blank" rel="noopener noreferrer">${params.value}</a>`;
        }
        return ""
    }
'''

def build_enzymes(enzymes: pl.DataFrame) -> Grid:
    """Wrap an enzyme table in an ipyaggrid ``Grid``.

    Args:
        enzymes: Frame containing at least the columns 'id', 'ec', 'organism',
            'name', 'existence', 'reviewed' and 'sequence'.

    Returns:
        A Grid showing those columns behind a 1-based row counter, with the
        'id' column rendered through ``js_uniprot_id_renderer``.
    """
    shown_fields = ['id', 'ec', 'organism', 'name', 'existence', 'reviewed', 'sequence']
    table = enzymes.to_pandas().loc[:, shown_fields]
    # Replace the default index with a 1-based counter named 'idx'.
    table = table.set_index(pd.Index(range(1, len(table) + 1), name='idx'))

    def column(header: str, field: str, width: int, **extra) -> dict:
        """Assemble one ag-grid column definition."""
        return {'headerName': header, 'field': field, 'width': width, **extra}

    column_defs = [
        column('', table.index.name, 40),
        column('UniProt ID ⤴', 'id', 120, cellRenderer=js_uniprot_id_renderer),
        column('EC', 'ec', 100),
        column('Organism', 'organism', 250),
        column('Name', 'name', 300),
        column('Existence', 'existence', 200),
        column('Reviewed', 'reviewed', 100),
        column('Sequence', 'sequence', 500),
    ]
    grid_options = {
        'columnDefs': column_defs,
        'enableSorting': True,
        'enableFilter': True,
        'enableColResize': True,
        'enableRangeSelection': True,
    }

    enzyme_grid = Grid(
        grid_data=table,
        grid_options=grid_options,
        index=True,
        theme='ag-theme-balham',
        quick_filter=True,
        height=190,
        width=900,
    )
    # NOTE(review): `unsync` is set after construction; its effect is defined by
    # ipyaggrid -- confirm it is still honoured by the installed version.
    enzyme_grid.unsync = True
    return enzyme_grid

def display_predicted_reaction(rxn_step: int, img: Path, feasibility: int):
    """Build the widget for one predicted reaction step.

    Args:
        rxn_step: Zero-based position of the reaction in the path; shown 1-based.
        img: Image file of the reaction, loaded via ``widgets.Image.from_file``.
        feasibility: Feasibility label; displayed as its truth value.

    Returns:
        A VBox with a caption above the reaction image.
    """
    caption = widgets.HTML(
        f'<b><u>Step #{rxn_step + 1} | Reaction feasibility: {bool(feasibility)}</u></b>'
    )
    image = widgets.Image.from_file(img)
    return widgets.VBox([caption, image])

def widget_path_view(batch: dict[str, pl.DataFrame], idx: int, top_k_analogues: int = 10):
    """Build the detail view for one path of a loaded batch.

    Args:
        batch: Frames keyed by "paths", "predicted_reactions", "known_reactions"
            and "enzymes".
        idx: Row position of the path within ``batch["paths"]``.
        top_k_analogues: Maximum number of known analogues kept per reaction
            step, ordered by descending reaction similarity.

    Returns:
        A VBox with a summary header followed by one bordered row per predicted
        reaction: the reaction image on the left, and its analogues and their
        enzymes on the right.
    """
    path = batch["paths"].row(idx, named=True)

    # Test against None explicitly: a driving force of exactly 0 is a real
    # value and must not be reported as missing.
    mdf = path["mdf"]
    mdf_text = round(mdf, 2) if mdf is not None else 'N/A'

    header = widgets.HTML(f"""
    <h3>{len(path["reactions"])}-step path from {path["starter"].upper()} to {path["target"].upper()}<br>
    Max-min driving force: {mdf_text} kJ/mol<br>
    Path feasibility: {path["feasibility_frac"]:.2f}<br>
    ID: {path["id"]}
    </h3>
    """)
    rows = [header]

    # Order the predicted reactions by their position in the path.
    # NOTE(review): relies on `Expr.replace` producing values that sort in step
    # order -- verify against the installed polars version.
    step_of = {rxn_id: step for step, rxn_id in enumerate(path["reactions"])}
    pred_rxns = batch["predicted_reactions"].filter(
        pl.col("id").is_in(path["reactions"])
    ).sort(
        pl.col("id").replace(step_of)
    ).select(
        pl.col("id"),
        pl.col("dxgb_label"),
        pl.col("rxn_sims"),
        pl.col("analogue_ids"),
        pl.col("image"),
    )

    for step, rxn in enumerate(pred_rxns.iter_rows(named=True)):
        pr_elt = display_predicted_reaction(
            rxn_step=step,
            img=Path(rxn["image"]),
            feasibility=rxn["dxgb_label"]
        )

        # Attach each analogue's similarity to this predicted reaction, then
        # keep only the most similar ones.
        krid_to_sim = dict(zip(rxn['analogue_ids'], rxn["rxn_sims"]))
        analogues = batch["known_reactions"].filter(
            pl.col("id").is_in(rxn["analogue_ids"])
        ).with_columns(
            pl.col("id").replace(krid_to_sim).alias("rxn_sim").cast(pl.Float32),
        ).sort(
            pl.col("rxn_sim"),
            descending=True
        ).slice(0, top_k_analogues)

        # Enzymes referenced by any of the retained analogues.
        enzymes = batch['enzymes'].filter(
            pl.col("id").is_in(set(analogues["enzymes"].explode()))
        )

        kr_elt = widget_analogues_enzymes(analogues, enzymes)

        step_box = widgets.GridBox(
            children=[pr_elt,  kr_elt],
            layout=widgets.Layout(
                border='1px solid black',
                height='280px',
                grid_template_rows='auto',
                grid_template_columns='55% 45%',
            )
        )
        rows.append(step_box)

    return widgets.VBox(rows)

def display_analogue(img: str, rxn_sim: float):
    """Build the widget for one known analogue reaction.

    Args:
        img: Image file of the analogue, loaded via ``widgets.Image.from_file``.
        rxn_sim: Similarity to the predicted reaction; shown as a percentage.

    Returns:
        A VBox with the similarity caption above the analogue image.
    """
    caption = widgets.HTML(f'<b><u>{rxn_sim * 100:.2f}% similar to predicted reaction</u></b>')
    picture = widgets.Image.from_file(img)
    return widgets.VBox(children=[caption, picture])


def widget_analogues_enzymes(analogues: pl.DataFrame, enzymes: pl.DataFrame):
    """Build the two-tab panel of known analogues and their enzymes.

    Args:
        analogues: One row per analogue, with 'image', 'rxn_sim' and 'enzymes'.
        enzymes: Enzyme rows; each analogue shows those whose 'id' appears in
            its 'enzymes' list.

    Returns:
        A Tab whose first page is an analogue dropdown over a stack of analogue
        views and whose second page is a stack of enzyme grids. The dropdown
        drives the selected page of both stacks.
    """
    # Build the analogue view and its enzyme grid together, one pair per row.
    pairs = [
        (
            display_analogue(rec["image"], rec["rxn_sim"]),
            build_enzymes(enzymes.filter(pl.col("id").is_in(rec["enzymes"]))),
        )
        for rec in analogues.iter_rows(named=True)
    ]
    analogue_views = [view for view, _ in pairs]
    enzyme_grids = [grid for _, grid in pairs]

    # Both lists have the same length, so one default index serves both stacks.
    first_idx = 0 if pairs else None

    selector = widgets.Dropdown(
        options=[(pos, pos - 1) for pos in range(1, len(analogue_views) + 1)],
        value=first_idx,
        description="Analogue: "
    )
    analogue_stack = widgets.Stack(analogue_views, selected_index=first_idx)
    enzyme_stack = widgets.Stack(enzyme_grids, selected_index=first_idx)

    # Links are resolved in the browser, so switching needs no kernel round trip.
    widgets.jslink((selector, 'index'), (analogue_stack, 'selected_index'))
    widgets.jslink((selector, 'index'), (enzyme_stack, 'selected_index'))

    analogue_page = widgets.VBox([selector, analogue_stack])
    return widgets.Tab(
        titles=['Known Analogues', 'Enzymes'],
        children=[analogue_page, enzyme_stack],
    )

In [17]:
# event handlers and wiring

def update_paths_selector(_change: widgets.Button):
    """Load paths for the current selections and repopulate the path dropdown.

    Registered as a click handler, so the argument is the clicked button; it is
    not used. The label always ends in a final state (loaded, empty, or failed)
    rather than staying on the loading message.
    """
    try:
        batch = pw.get_paths(
            starters=starter_selector.value,
            targets=target_selector.value,
            filter_by_enzymes={'existence': evidence_selector.value},
            sort_by=sort_by_radio_buttons.value,
        )
        path_stack = [
            (idx + 1, widget_path_view(batch, idx))
            for idx in range(len(batch["paths"]))
        ]
    except Exception:
        # Surface the failure in the UI, then let the error propagate.
        paths_label.value = 'Failed to load paths'
        raise

    paths_selector.disabled = not path_stack
    paths_selector.options = path_stack
    if path_stack:
        # avoid setting `value` when the options are empty
        # otherwise, traitlet validations go berserk
        paths_selector.value = path_stack[0][1]
        paths_label.value = f'Loaded {len(path_stack)} paths'
    else:
        paths_label.value = 'No paths found for the current selections'

def loading_paths(change_: widgets.Button):
    """Show a loading message; registered before the loader so it runs first."""
    paths_label.value = 'Loading paths...'

def render_paths(change: traitlets.Bunch):
    """Show the newly selected path view, or clear the viewer if there is none."""
    with paths_viewer:
        clear_output()
        # The dropdown's value becomes None when its options are emptied;
        # displaying that would print "None" in the viewer.
        if change.new is not None:
            display(change.new)

search_button.on_click(loading_paths)
search_button.on_click(update_paths_selector)
paths_selector.observe(render_paths, names=['value'])

In [18]:
# display UI & preselect the first two starters
# NOTE(review): no handler in the visible cells observes `starter_selector`, so
# this assignment only changes the highlighted starters; paths are loaded when
# the "Load paths" button is clicked.
display(ui)
starter_selector.value = starter_options[:2]

VBox(children=(HBox(children=(SelectMultiple(description='Starters:', index=(0,), layout=Layout(flex='1 1 auto…