Skip to content

Commit

Permalink
workflow implementation and other updates
Browse files Browse the repository at this point in the history
  • Loading branch information
jbkoh committed Jun 7, 2018
1 parent 954046e commit 0b863e0
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 23 deletions.
3 changes: 0 additions & 3 deletions plastering/rdflib_wrapper.py
Expand Up @@ -37,9 +37,6 @@
def adder(x, y):
return x + y

def adder(x, y):
return x + y


def init_graph(empty=False):
if empty:
Expand Down
35 changes: 15 additions & 20 deletions plastering/workflow.py
Expand Up @@ -92,15 +92,19 @@ def init_node(self, f_name, prev, f_graph_configs):
curr_node.nexts = nexts
return curr_node

def _merge_srcids(self, srcids_list, num):
# Dummy but reasonable solution for now.
return srcids_list[-1]

def select_informative_samples(self, sample_num):
params = {
'sample_num': sample_num
}
res_g = self._traverse_wrapper(self.f_head,
['select_informative_samples'],
[params])
pdb.set_trace()
#TODO: Post processing the colleted result
merged = self._merge_srcids(list(res_g.values()), sample_num)
return merged

def predict_proba(self, target_srcids=None):
params = {
Expand All @@ -109,7 +113,9 @@ def predict_proba(self, target_srcids=None):
res_g = self._traverse_wrapper(self.f_head, 'predict_proba', params)
# TODO: Post processing res_g to merge different results

def predict(self, target_srcids):
def predict(self, target_srcids=None):
if not target_srcids:
target_srcids = self.target_srcids
params = {
'target_srcids': target_srcids
}
Expand Down Expand Up @@ -141,17 +147,11 @@ def _traverse_wrapper(self, node, func_names, params, prev_attrs=[[]]):
param[attr] = getattr(node.prev.f, attr)
else:
param[attr] = None
try:
func = getattr(node.f, func_name)
except:
pdb.set_trace()
func = getattr(node.f, func_name)
try:
res_dict[(str(node), func_name)] = func(**param)
except EmptyTrainingSamples as e:
print(e.msg)
except Exception as e:
print(e)
pdb.set_trace()

for next_node in node.nexts:
res_dict.update(self._traverse_wrapper(next_node, func_names,
Expand All @@ -171,17 +171,12 @@ def update_model(self, srcids):
- srcids (list(str)): list of srcids to add.
"""
params = [
{
'new_srcids': srcids,
},
{
},
{
'target_srcids': self.target_srcids
}
{},
{'new_srcids': srcids},
{'target_srcids': self.target_srcids}
]
func_names = ['update_model', 'update_prior', 'predict']
prev_attrs = [[], ['pred_g'], []]
func_names = ['update_prior', 'update_model', 'predict']
prev_attrs = [['pred_g', 'pred_confidences'], [], []]
self._traverse_wrapper(self.f_head, func_names, params, prev_attrs)

def update_model_deprectaed(self, srcids):
Expand Down
2 changes: 2 additions & 0 deletions test_zodiac.py
Expand Up @@ -12,6 +12,8 @@

zodiac = ZodiacInterface(target_building=target_building,
target_srcids=target_srcids)
#zodiac.update_model([])
pred = zodiac.predict()
zodiac.learn_auto()
pred = zodiac.predict()
proba = zodiac.predict_proba()
48 changes: 48 additions & 0 deletions test_zodiac_scrabble.py
@@ -0,0 +1,48 @@
#from plastering.inferencers.quiver import DummyQuiver
from plastering.inferencers.zodiac_new import ZodiacInterface
from plastering.inferencers.scrabble_new import ScrabbleInterface
#from plastering.inferencers.zodiac_interface import ZodiacInterface
from plastering.metadata_interface import *
from plastering.workflow import *
from plastering.helper.common import *
import pdb

# construct a framework dict for referneces.
f_class_dict = {
'zodiac': ZodiacInterface,
'quiver': DummyQuiver
}

target_building = 'ebu3b'
source_buildings = ['ap_m', 'ebu3b']
building_sentence_dict, building_label_dict, building_tagsets_dict = \
data_loader(target_building, source_buildings)
target_srcids = list(building_tagsets_dict[target_building].keys())

base_config = {
'target_building': target_building,
'target_srcids': target_srcids,
'source_buildings': source_buildings
}
zodiac_config = deepcopy(base_config)
quiver_config = deepcopy(base_config)
building_ttl = 'groundtruth/{0}_brick.ttl'.format(target_building)
quiver_config['config'] = {
'ground_truth_ttl': building_ttl
}

f_graph = {
'zodiac': (zodiac_config, {
'scrabble': (scrabble_config, {
})
})
}

# init workflow
workflow = Workflow(target_srcids, f_class_dict, f_graph)
# random srcids to update
new_srcids = random.sample(target_srcids, 10)
for i in range(0, 20):
workflow.update_model(new_srcids)
workflow.predict(new_srcids)
new_srcids = workflow.select_informative_samples(5)

0 comments on commit 0b863e0

Please sign in to comment.