In [1]:
import json

# Which HotpotQA split to load; the 'test' split has no answer / supporting-facts columns.
split = 'dev'
# List of (q_id, fields): fields is (question,) for 'test', else (question, answer, sp_facts).
samples = []
# Explicit UTF-8 so decoding does not depend on the platform's default locale encoding.
with open(f"data/hotpot-{split}.tsv", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            # A blank (e.g. trailing) line would otherwise fail the tuple unpacking below.
            continue
        if split == 'test':
            q_id, question = line.split('\t')
            samples.append((q_id, (question,)))
        else:
            q_id, question, answer, sp_facts = line.split('\t')
            # The supporting-facts column is stored as a JSON string.
            sp_facts = json.loads(sp_facts)
            samples.append((q_id, (question, answer, sp_facts)))

In [2]:
from agent.client import Agent
from env.client import Env

# Both service clients are created against the same host:port.
# NOTE(review): agent and env share one endpoint here — confirm that is intended.
_SERVICE_ADDRESS = '10.60.1.79:17101'
agent = Agent(_SERVICE_ADDRESS)
env = Env(_SERVICE_ADDRESS)

In [3]:
import hashlib
from ipywidgets import Accordion, BoundedIntText, Button, Combobox, Dropdown, GridspecLayout, HBox, HTML, Label, Layout, Output, RadioButtons, Tab, ToggleButton, VBox
from IPython.display import display
import time

def md5_str(s):
    """Return the hexadecimal MD5 digest of the string ``s`` encoded as UTF-8."""
    hasher = hashlib.md5()
    hasher.update(s.encode('utf-8'))
    return hasher.hexdigest()

In [4]:
%%html
<style>
    /* Extra vertical spacing between the action radio-button labels. */
    .widget-radio-box label{
        /* margin-top: 5px; */
        margin-bottom: 12px !important;
        /* width: 80px; */
    }
    /* Make links inside rendered passages visibly blue. */
    a{
        color: blue;
    }
</style>

In [5]:
# Action names, in the order used for the radio buttons and the trajectory-table columns.
funcs = ['Sparse', 'Dense', 'Link', 'Answer']
# Button styles, index-aligned with `funcs`.
styles = ["success", "info", "warning", "danger"]
# Button colour per action.
colors = dict(Sparse="MediumSeaGreen", Dense="MediumTurquoise", Link="Goldenrod", Answer="Salmon")
# Action name -> implementation identifier used in commands and proposals.
implementations = dict(Sparse="BM25", Dense="MDR", Link="LINK", Answer="ANSWER")
# Reverse lookup: implementation identifier -> action name.
impl2func = {impl: func for func, impl in implementations.items()}
# The question text of every loaded sample, offered as suggestions in the question box.
question_examples = [fields[0] for _, fields in samples]
# p_id -> HTML widget with the rendered passage
passage_htmls = {}
# unescaped_title -> p_id
title2id = {}

In [6]:
# Output area that receives everything displayed by `execute` (via `@out.capture()`);
# it is cleared whenever the game is reset.
out = Output(layout={'border': '1px solid Silver'})

In [7]:
from html import unescape
import threading

# Flag set by `pause()` and polled by `act_auto()` before each autonomous step.
stop_thread = False

def confirm_game(change):
    """Observer for the Confirm/Reset toggle.

    When the toggle turns on, the question and horizon inputs are frozen, a new
    session id is derived from the current time and the question, and the
    action widgets of the current page are enabled.  When it turns off, all
    session state and cached passages are cleared and the tab view is rebuilt
    with a single fresh page.

    Args:
        change: traitlets change dict; only ``change['new']`` is read.
    """
    global horitzon, step, question, obs_id, session_id

    if not change['new']:
        # --- Reset: back to the "no game confirmed" state ---
        q_input.value = ''
        for setting in (q_input, t_input):
            setting.disabled = False
        confirm_button.description, confirm_button.icon = 'Confirm', 'user-check'
        run_button.disabled = True

        horitzon = step = -1
        question = ''
        obs_id = session_id = None
        for cache in (passage_htmls, title2id):
            cache.clear()

        # create_page() reads the globals reset above, so it must run after them.
        is_tab.children = [create_page()]
        is_tab._titles = {0: '0'}

        out.clear_output()
        return

    # --- Confirm: lock the settings and start a new session ---
    for setting in (q_input, t_input):
        setting.disabled = True
    confirm_button.description, confirm_button.icon = 'Reset', 'user-edit'
    run_button.disabled = False

    horitzon, step = t_input.value, 0
    question = q_input.value
    obs_id = None
    session_id = f"{int(time.time())}-{md5_str(question)}"

    # The first observation is the question itself.
    obs_html.value = question
    action_radio.disabled = False
    for arg_widget in arg_inputs.values():
        arg_widget.disabled = False
    arg_inputs['Sparse'].options = (question,)
    consult_button.disabled = False

def check_action(change):
    """Observer for the action radio buttons: decide whether Execute is allowed.

    The argument of the newly selected action is whitespace-normalised in
    place.  Execute is disabled when that argument is empty, or when the step
    counter has reached the horizon and the selected action is not 'Answer'.

    Args:
        change: traitlets change dict; ``change['new']`` is the selected action name.
    """
    selected = change['new']
    field = arg_inputs[selected]
    normalized = ' '.join(field.value.split())
    field.value = normalized

    # At the horizon only an answer may still be submitted.
    budget_exhausted = (step == horitzon) and selected != 'Answer'
    execute_button.disabled = (not normalized) or budget_exhausted
        
def update_evidence(change):
    """Observer for an evidence toggle: flip its look and refresh the Dense options.

    The toggle's current ``button_style`` selects the new tooltip, style and
    icon; a style outside the four known ones leaves the button untouched.
    Afterwards the Dense combobox offers 'Q' plus one 'Q + <title>' entry per
    toggle that is currently switched on.

    Args:
        change: traitlets change dict; ``change['owner']`` is the clicked toggle.
    """
    # current button_style -> (tooltip verb, next button_style, next icon)
    transitions = {
        'warning': ('exclude', 'info', 'check'),
        'info': ('include', 'warning', 'ban'),
        'danger': ('remove', 'success', 'check'),
        'success': ('retain', 'danger', 'times'),
    }

    button = change['owner']
    para_id = button.description_tooltip
    flipped = transitions.get(button.button_style)
    if flipped is not None:
        verb, next_style, next_icon = flipped
        button.tooltip = f"Click to {verb} {para_id}"
        button.button_style = next_style
        button.icon = next_icon

    # children[0] is the 'Evidence:' label, the rest are the toggles.
    selected_titles = [toggle.description for toggle in context_toggles.children[1:] if toggle.value]
    arg_inputs['Dense'].options = ['Q'] + [f"Q + {title}" for title in selected_titles]

def pause(b):
    """Click handler of the Pause button: raise the flag that `act_auto` polls.

    Args:
        b: the clicked button (unused).
    """
    global stop_thread
    stop_thread = True

def act_auto(b):
    """Click handler of the Run button: consult and execute repeatedly.

    Each iteration asks the agent for a suggestion (if Consult is still
    available on the current page), forces the 'Answer' action once the step
    counter reaches the horizon, and executes the selected action.  The loop
    ends when the horizon is passed, when `stop_thread` is set, or when Execute
    is disabled (nothing valid to run, or the game has ended).

    The button clean-up runs in a ``finally`` block so that a failure inside a
    step does not leave Run disabled and Pause enabled.

    NOTE(review): the loop runs synchronously inside this callback; a Pause
    click presumably cannot be handled until it returns — confirm.

    Args:
        b: the Run button; disabled while the loop is running.
    """
    global stop_thread
    b.disabled = True
    pause_button.disabled = False
    stop_thread = False
    try:
        # `step`, `consult_button` and `execute_button` are module globals that
        # execute()/create_page() rebind, so they are re-read on every iteration.
        while step <= horitzon:
            if stop_thread:
                break
            if not consult_button.disabled:
                act_step(consult_button)
            if step == horitzon:
                # Out of budget: only an answer may still be submitted.
                action_radio.value = 'Answer'
            if execute_button.disabled:
                break
            execute(execute_button)
    finally:
        pause_button.disabled = True
        # Re-arm Run only if there is budget left and something can still be done.
        if step < horitzon and not (consult_button.disabled and execute_button.disabled):
            b.disabled = False

# TODO: log output
def act_step(b):
    """Click handler of the Consult button: load the agent's suggestions into the UI.

    Asks the agent to act for the current session, reads back one proposal per
    action, sets each evidence toggle according to the agent's memory, writes
    every proposed argument into its input widget and selects the action with
    the highest probability.  If the agent returns no proposals the action
    selection is left unchanged instead of raising.

    Args:
        b: the Consult button; disabled so each page is consulted only once.
    """
    b.disabled = True

    # The returned command is not used; the proposals are read back below.
    # Presumably this call advances the agent's state for the session — confirm.
    agent.act([session_id], [question], [obs_id], review=(step % 10 == 0))  # disable_tqdm=True
    proposals = dict()
    for (impl, arg), arg_conf, action_prob in agent.proposals(session_id, -1):
        if impl == 'MDR':
            # Replace the passage id in the dense-query argument by its readable title.
            arg = (arg[0], None if arg[1] is None else unescape(env.get(arg[1])['title']))
        proposals[impl2func[impl]] = {"arg": arg, "conf": arg_conf, "prob": action_prob}
    evidence = agent.memory(session_id)

    print(f"*{step:<3d} {evidence}")
    if context_toggles:
        # children[0] is the 'Evidence:' label; description_tooltip holds the passage id.
        for x in context_toggles.children[1:]:
            x.value = x.description_tooltip in evidence

    for f, proposal in proposals.items():
        arg = proposal['arg']
        widget = arg_inputs[f]
        if f == 'Dense':
            widget.value = 'Q' if arg[1] is None else f"Q + {arg[1]}"
        elif f == 'Link':
            if arg == 'nolink':
                arg = 'None'
            # Link options end with ' → <target title>'; pick the one matching the proposal.
            for opt in widget.options:
                if opt.split(' → ')[-1] == arg:
                    widget.value = opt
                    break
            else:
                widget.value = 'None'
                print(f'cannot find corresponding option from {widget.options} for {arg}')
        elif f == 'Answer':
            if arg == 'noanswer':
                widget.value = 'None'
            elif arg in ['yes', 'no']:
                widget.value = arg.capitalize()
            else:
                widget.value = arg
        else:
            widget.value = arg
        print(f"{proposal['prob']:.3f}  {proposal['conf']:.2f} {f}({widget.value})")

    if proposals:
        action_radio.value = max(proposals, key=lambda f: proposals[f]['prob'])
    else:
        # max() over an empty dict would raise ValueError.
        print('agent returned no proposals')

@out.capture()
def execute(b):
    """Click handler of the Execute button: run the selected action for this step.

    Everything displayed here is captured into `out`.  The function
    1. freezes the current page (action widgets, buttons, evidence toggles),
    2. reports every evidence toggle to the agent (add or delete evidence),
    3. appends one row to the trajectory table (plus the header row at step 0),
    4. for any action other than 'Answer', sends the command to the
       environment, stores the returned observation id and opens a new tab
       page for the next step.
    After an 'Answer' action nothing is re-enabled and no page is added.

    Args:
        b: the clicked button (unused).
    """
    global step, question, obs_id, session_id, obs_html, context_accordion, context_toggles, action_radio, arg_inputs, consult_button, execute_button, is_tab
    
    # Freeze the current page so the executed step can no longer be edited.
    action_radio.disabled = True
    for w in arg_inputs.values():
        w.disabled = True
    consult_button.disabled = True
    execute_button.disabled = True
    if context_accordion:
        # Collapse all passages.
        context_accordion.selected_index = None
    if context_toggles:
        # children[0] is the 'Evidence:' label; description_tooltip holds the passage id.
        for x in context_toggles.children[1:]:
            x.disabled = True
            if x.value:
                agent.add_evidence(session_id, x.description_tooltip)
            else:
                agent.del_evidence(session_id, x.description_tooltip)
    
    func = action_radio.value
    # Whitespace-normalised argument of the chosen action.
    arg = ' '.join(arg_inputs[func].value.split())
    if step == 0:
        # First executed step: emit the header row of the trajectory table.
        trajectory_titles = [HTML('<center><b>Step</b></center>', layout=Layout(display="flex", justify_content='center', width='4%', margin='0px')),
                             HTML('<center><b>Observation</b></center>', layout=Layout(width='12%', margin='0px'))]
        for f in funcs:
            trajectory_titles.append(HTML(value=f"<center><b>{f}</b></center>", layout=Layout(width='21%', margin='0px')))
        display(HBox(trajectory_titles, layout=Layout(margin='0px')))
    if obs_id is None:
        obs_desc = 'Q' if step == 0 else ''
    else:
        obs_desc = obs_html.value
    # Observation cell: shows a check icon when the first toggle carries the same title and is switched on.
    obs_btn = Button(description=obs_desc, tooltip=f'show obs at step {step}', 
                     icon='check' if context_toggles and context_toggles.children[1].description == obs_desc and context_toggles.children[1].value else '', 
                     layout=Layout(width='12%'), style={"button_color": "Snow"})
    obs_btn.on_click(show_obs)
    row_items = [HTML(f"<center>{step}</center>", layout=Layout(width='4%', margin='0px')), obs_btn]
    # One cell per action; only the executed action's cell is enabled, coloured and clickable.
    for idx, f in enumerate(funcs):
        btn = Button(description=' '.join(arg_inputs[f].value.split()), tooltip=f'jump to step {step}', 
                     layout=Layout(width='21%', margin='0px'), button_style=styles[idx] if f == func else '', disabled=(f != func))
        if f == func:
            btn.on_click(show_step)
            btn.style = {"button_color": colors[f]}
        else:
            btn.style = {"button_color": "transparent"}
        row_items.append(btn)
    display(HBox(row_items, layout=Layout(margin='0px')))
    
    if func != 'Answer':
        func_impl = implementations[func]
        if func_impl == 'MDR':
            # Dense argument is 'Q' or 'Q + <title>'; arg[4:] strips the 'Q + ' prefix.
            cmd = (func_impl, (question, None if arg == 'Q' else title2id[arg[4:]]))
        elif func_impl == 'LINK':
            # Link options end with ' → <target title>'; only the target is sent.
            cmd = (func_impl, arg.split(' → ')[-1])
        else:
            cmd = (func_impl, arg)
        # Passage ids of the switched-on evidence toggles are passed as `exclusion`.
        obs_id = env.step(cmd, session_id, 
                          exclusion=[toggle.description_tooltip for toggle in context_toggles.children[1:] if toggle.value] if context_toggles else None)  # xxx observed
        
        # The new step number equals the index of the tab page appended for it.
        step = len(is_tab.children)
        is_tab.children = is_tab.children + (create_page(),)
        is_tab.set_title(step, str(step))
        is_tab.selected_index = step

def show_step(b):
    """Click handler of a trajectory cell: jump to the tab of that step.

    Args:
        b: the clicked button; the last word of its tooltip is the step number.
    """
    step_token = b.tooltip.rsplit(None, 1)[-1]
    is_tab.selected_index = int(step_token)

def show_obs(b):
    """Click handler of an observation cell: jump to that step and open its passage.

    Selects the tab whose index is the last word of the button's tooltip.  If
    that page's observation text is non-empty, the first accordion section is
    expanded.  Pages without an evidence accordion are skipped silently; only
    the lookup errors that such pages cause are suppressed, so interrupts and
    unrelated failures propagate.

    Args:
        b: the clicked button; the last word of its tooltip is the step number.
    """
    is_tab.selected_index = int(b.tooltip.split()[-1])
    try:
        # page -> belief panel -> [observation HTML, evidence toggles, accordion]
        belief_widgets = is_tab.children[is_tab.selected_index].children[0].children
        if belief_widgets[0].value:
            belief_widgets[2].selected_index = 0
    except (IndexError, AttributeError):
        # The page has no accordion (e.g. the question page); nothing to expand.
        pass

def create_page():
    """Build the tab page for the current step and rebind the per-page widget globals.

    The page is an HBox of two halves:
      * belief panel - the observation (question or passage title) and, when
        there is context, the evidence toggles and an accordion of passages;
      * policy panel - the action radio buttons, one argument input per
        action, and the Consult / Execute buttons.

    Side effects: the current observation is registered in `title2id` and its
    rendered HTML cached in `passage_htmls`; the globals `obs_html`,
    `context_toggles`, `context_accordion`, `action_radio`, `consult_button`
    and `execute_button` are rebound and `arg_inputs` is refilled, so callbacks
    always act on the newest page.

    NOTE(review): the radio and argument widgets use ``disabled=(step == 0)``,
    but every visible call site runs with step == -1 or step >= 1, so that
    condition looks like it is never true — confirm the intended pre-confirm
    state.

    Returns:
        HBox: the assembled page.
    """
    global step, question, obs_id, session_id, obs_html, context_accordion, context_toggles, action_radio, arg_inputs, consult_button, execute_button

    if obs_id is None:
        obs = question if step == 0 else ''
        context = []
    else:
        obs_para = env.get(obs_id)
        obs = unescape(obs_para['title'])
        title2id[obs] = obs_id
        context = [obs_para]
        if obs_id not in passage_htmls:
            # Render the passage once, turning every reference span into a Wikipedia link.
            segs = []
            offset = 0
            for span, tgt_title in sorted(obs_para['refs'], key=lambda x: x[0]):
                segs.append(obs_para['text'][offset:span[0]])
                segs.append(f'<a href="https://en.wikipedia.org/wiki/{tgt_title}" target="_blank">{obs_para["text"][span[0]:span[1]]}</a>')
                offset = span[1]
            segs.append(obs_para['text'][offset:])
            passage_htmls[obs_id] = HTML(value=''.join(segs))
    if session_id:
        # The passages in the agent's memory follow the current observation.
        context.extend(agent.memory(session_id, 'ordered_passages'))

    belief_widgets = []
    obs_html = HTML(value=obs, description='Observation:', description_tooltip='question' if step == 0 else 'passage title', style={"description_width": "initial"})
    belief_widgets.append(obs_html)
    if len(context) > 0:
        assert step > 0
        # The fresh observation (index 0 when obs_id is set) starts switched off; memory passages start on.
        context_toggles = HBox([Label(value=r'Evidence:')] + [
            ToggleButton(
                value=(i > 0 or obs_id is None), description=unescape(p['title']), description_tooltip=p['para_id'],
                tooltip=f"Click to {'add' if i == 0 and obs_id else 'remove'} {p['para_id']}",
                button_style='warning' if i == 0 and obs_id else 'success', icon='ban' if i == 0 and obs_id else 'check', layout=Layout(width='auto', max_width='30%')
            ) for i, p in enumerate(context)
        ], layout=Layout(display='flex', flex_flow='row wrap'))
        for x in context_toggles.children[1:]:
            x.observe(update_evidence, names='value')
        belief_widgets.append(context_toggles)
        context_accordion = Accordion(children=[passage_htmls[p['para_id']] for p in context], _titles={i: unescape(p['title']) for i, p in enumerate(context)})
        if obs_id is None:
            context_accordion.selected_index = None
        else:
            # Open the freshly observed passage.
            context_accordion.selected_index = 0
        belief_widgets.append(context_accordion)
    else:
        context_accordion, context_toggles = None, None
    belief_panel = VBox(belief_widgets, layout=Layout(width='50%'))

    # Each passage contributes its OWN reference spans, so the anchor text sliced from
    # p['text'] matches the link target.  Passages without 'refs' contribute nothing.
    # Assumes passages are dicts — TODO confirm for the agent's memory entries.
    link_options = ['None'] + [
        f"{unescape(p['title'])} ⤚ {p['text'][span[0]:span[1]]} → {tgt_title}"
        for p in context
        for span, tgt_title in sorted(p.get('refs', ()), key=lambda x: x[0])
    ]

    action_radio = RadioButtons(options=funcs, value=None, description=r'Action:', disabled=(step == 0),
                                style={"description_width": "initial"}, layout=Layout(width='max-content', min_width='max-content'))
    action_radio.observe(check_action, names='value')
    arg_inputs['Sparse'] = Combobox(value='', placeholder='Query for sparse', options=[question], ensure_option=False, disabled=(step == 0), layout=Layout(width='auto'))
    arg_inputs['Dense'] = Combobox(value='', placeholder='Query for dense', ensure_option=True, disabled=(step == 0), layout=Layout(width='auto'),
                                   options=['Q'] + [f"Q + {x.description}" for x in context_toggles.children[1:] if x.value] if context_toggles else ['Q'])
    arg_inputs['Link'] = Combobox(value='', placeholder='Anchor text', ensure_option=True, disabled=(step == 0), layout=Layout(width='auto'),
                                  options=link_options)
    arg_inputs['Answer'] = Combobox(value='', placeholder='Answer span', options=['None', 'Yes', 'No'], ensure_option=False, disabled=(step == 0), layout=Layout(width='auto'))
    action_inputs = HBox([action_radio, VBox([arg_inputs[f] for f in funcs], layout=Layout(width='100%'))], layout=Layout(width='100%'))
    consult_button = Button(description='Consult', disabled=(step < 1), button_style='', tooltip='Seek suggestion', icon='lightbulb', layout=Layout(width='20%'))  # lightbulb question
    consult_button.on_click(act_step)
    # Execute stays disabled until check_action sees a non-empty argument.
    execute_button = Button(description='Execute', disabled=True, button_style='', tooltip='Act step-by-step', icon='step-forward', layout=Layout(width='20%'))
    execute_button.on_click(execute)
    do_buttons = HBox([consult_button, execute_button], layout=Layout(display='flex', flex_flow='row', justify_content='center', width='100%'))
    policy_panel = VBox([action_inputs, do_buttons], layout=Layout(width='50%'))

    page = HBox([belief_panel, policy_panel], layout=Layout(width='100%'))

    return page

In [8]:
# Session state shared by all callbacks; these are the same "no game confirmed"
# values that confirm_game() restores on reset.
horitzon, step = -1, -1
question = ''
obs_id = None
session_id = None
passage_htmls.clear()
title2id.clear()

# Top bar: question box, horizon (maximum number of steps), Confirm/Reset toggle, Run and Pause.
q_input = Combobox(value='', placeholder='Ask a question', options=question_examples, description=r'Q:', tooltip='question', ensure_option=False, disabled=False, 
                   style={"description_width": "initial"}, layout=Layout(width='60%'))
t_input = BoundedIntText(value=10, min=5, max=40, step=5, description=r'T:', tooltip='horizon', disabled=False, style={"description_width": "initial"}, layout=Layout(width='10%'))
confirm_button = ToggleButton(value=False, description='Confirm', icon='user-check', button_style='', disabled=False, layout=Layout(width='10%'))
confirm_button.observe(confirm_game, names='value')
run_button = Button(description='Run', disabled=True, button_style='', tooltip='Act autonomously', icon='forward', layout=Layout(width='10%'))  # 'success', 'info', 'warning', 'danger' or ''
run_button.on_click(act_auto)
pause_button = Button(description='Pause', disabled=True, button_style='', tooltip='Pause autonomous execution', icon='pause', layout=Layout(width='10%'))
pause_button.on_click(pause)
game_setting = HBox([q_input, t_input, confirm_button, run_button, pause_button])

# Per-page widget globals; create_page() rebinds them for every new tab page.
obs_html, context_accordion, context_toggles, action_radio, consult_button, execute_button = None, None, None, None, None, None
arg_inputs = {f: None for f in funcs}
is_tab = Tab([create_page()], _titles={0: '0'})

out.clear_output()

display(game_setting, is_tab, out)

HBox(children=(Combobox(value='', description='Q:', layout=Layout(width='60%'), options=('Were Scott Derrickso…

Tab(children=(HBox(children=(VBox(children=(HTML(value='', description='Observation:', description_tooltip='pa…

Output(layout=Layout(border='1px solid Silver'))