-
Notifications
You must be signed in to change notification settings - Fork 18
/
experimental.py
104 lines (84 loc) · 3.68 KB
/
experimental.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
Experimental recipes whose function signatures might change significantly in the future. Use with caution.
"""
from bokeh.layouts import row, column
from bokeh.models import Button, Slider
from .subroutine import (
standard_annotator,
standard_finder,
standard_snorkel,
standard_softlabel,
)
from hover.utils.bokeh_helper import servable
from wasabi import msg as logger
import pandas as pd
@servable(title="Snorkel Crosscheck")
def snorkel_crosscheck(dataset, lf_list, height=600, width=600):
    """
    Pair a labeling-function explorer with an annotator on the same dataset.

    The dev set is used to check the labeling functions, and the labeling
    functions in turn hint at which points may be worth annotating.

    Layout:

    sidebar | [inspect LFs here] | [annotate here]

    dataset -- the dataset shared by the sidebar and both explorers.
    lf_list -- labeling functions to draw, plotted in the given order.
    height, width -- plot dimensions forwarded to both explorers.
    Returns a bokeh ``row`` holding the sidebar and the two explorer views.
    """
    plot_dims = dict(height=height, width=width)
    lf_explorer = standard_snorkel(dataset, **plot_dims)
    annotate_explorer = standard_annotator(dataset, **plot_dims)

    # draw every labeling function, then let legend clicks hide/show glyphs
    for labeling_function in lf_list:
        lf_explorer.plot_lf(labeling_function)
    lf_explorer.figure.legend.click_policy = "hide"

    # share the x/y viewport and mirror the "raw" selection into the annotator
    lf_explorer.link_xy_range(annotate_explorer)
    lf_explorer.link_selection("raw", annotate_explorer, "raw")

    # arguments evaluate left to right: sidebar first, then the two explorers
    return row(dataset.view(), lf_explorer.view(), annotate_explorer.view())
def _write_predictions(dataset, model, subsets=("raw", "train", "dev")):
    """
    Write ``model``'s predictions into each listed subset frame of ``dataset``.

    For every key in ``subsets``, ``model.predict_proba`` is called on the
    frame's ``"text"`` column. Two columns are then written:
      - ``"pred_label"``: the argmax class index, decoded through
        ``dataset.label_decoder``.
      - ``"pred_score"``: the row-wise max of the ``predict_proba`` output.

    Both columns are built on the frame's own index, so row alignment does
    not depend on the frame having a 0..n-1 index.
    """
    for _key in subsets:
        _df = dataset.dfs[_key]
        # assumes predict_proba returns a numpy-like (n_rows, n_classes)
        # array -- inferred from the argmax/max over the last axis; confirm
        _probs = model.predict_proba(_df["text"].tolist())
        _labels = [dataset.label_decoder[_val] for _val in _probs.argmax(axis=-1)]
        _scores = _probs.max(axis=-1).tolist()
        # A bare pd.Series(values) has a default RangeIndex. Column assignment
        # aligns on index labels, so values would be misplaced or NaN whenever
        # the frame's index is not exactly 0..n-1. Reuse the frame's index.
        _df["pred_label"] = pd.Series(_labels, index=_df.index)
        _df["pred_score"] = pd.Series(_scores, index=_df.index)


@servable(title="Active Learning")
def active_learning(dataset, vectorizer, vecnet_callback, height=600, width=600):
    """
    Place a VectorNet in the loop.

    Layout:

    sidebar | [inspect soft labels here] | [annotate here] | [search here]

    The sidebar holds a "Train model" button and an epoch slider above the
    dataset view. Clicking the button:
      - rebuilds the model via ``vecnet_callback(dataset, vectorizer)``;
      - trains it with loaders built from the "train" and "dev" subsets;
      - writes "pred_label"/"pred_score" into the "raw", "train" and "dev"
        frames;
      - refreshes the soft-label plot.

    dataset -- the dataset shared by all explorers and the training callback.
    vectorizer -- passed to ``vecnet_callback`` and to ``dataset.loader``.
    vecnet_callback -- ``(dataset, vectorizer) -> model``; the model must
        provide ``train`` and ``predict_proba``.
    height, width -- plot dimensions forwarded to each explorer.
    Returns a bokeh ``row`` layout.
    """
    # building-block subroutines
    softlabel = standard_softlabel(dataset, height=height, width=width)
    annotator = standard_annotator(dataset, height=height, width=width)
    finder = standard_finder(dataset, height=height, width=width)

    # link coordinates and selections
    softlabel.link_xy_range(annotator)
    softlabel.link_xy_range(finder)
    softlabel.link_selection("raw", annotator, "raw")
    softlabel.link_selection("raw", finder, "raw")

    # recipe-specific widget
    def setup_model_retrainer():
        """Create the retrain button and epoch slider, wiring the click callback."""
        model_retrainer = Button(label="Train model", button_type="primary")
        epochs_slider = Slider(start=1, end=20, value=1, step=1, title="# epochs")

        def retrain_model():
            """
            Callback function: retrain the model and refresh predictions/plots.
            """
            model_retrainer.disabled = True
            logger.info("Start training... button will be disabled temporarily.")
            # NOTE(review): in a Bokeh server session, property changes made in
            # a synchronous callback are likely only pushed to the browser once
            # the callback returns, so the disabled state may never be visible.
            # Confirm; a next-tick callback would be needed to show it.
            try:
                dataset.setup_label_coding()
                model = vecnet_callback(dataset, vectorizer)
                train_loader = dataset.loader("train", vectorizer, smoothing_coeff=0.2)
                dev_loader = dataset.loader("dev", vectorizer)
                model.train(train_loader, dev_loader, epochs=epochs_slider.value)
                logger.good("-- 1/2: retrained model")

                _write_predictions(dataset, model)
                softlabel._update_sources()
                softlabel.plot()
            finally:
                # Re-enable even when a step above raised; otherwise the
                # button would stay disabled for the rest of the session.
                # The exception itself still propagates.
                model_retrainer.disabled = False
            logger.good("-- 2/2: updated predictions. Training button is re-enabled.")

        model_retrainer.on_click(retrain_model)
        return model_retrainer, epochs_slider

    model_retrainer, epochs_slider = setup_model_retrainer()
    sidebar = column(model_retrainer, epochs_slider, dataset.view())
    layout = row(sidebar, *[_plot.view() for _plot in [softlabel, annotator, finder]])
    return layout