Skip to content

Commit

Permalink
Fix caffe2 predict (facebookresearch#1103)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#1103

Currently it doesn't correctly handle str and non-str input. This commit fix the issue

Differential Revision: D18286625

fbshipit-source-id: 91a4e33d535f854c679e913bc4946098f1f7e4c5
  • Loading branch information
seayoung1112 authored and facebook-github-bot committed Nov 5, 2019
1 parent 9d64d10 commit db0e0db
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 18 deletions.
10 changes: 8 additions & 2 deletions pytext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,15 @@ def _predict(workspace_id, predict_net, model, tensorizers, input):
}
model_inputs = model.arrange_model_inputs(tensor_dict)
model_input_names = model.get_export_input_names(tensorizers)
vocab_to_export = model.vocab_to_export(tensorizers)
for blob_name, model_input in zip(model_input_names, model_inputs):
converted_blob_name = convert_caffe2_blob_name(blob_name)
workspace.blobs[converted_blob_name] = np.array([model_input], dtype=str)
converted_blob_name = blob_name
dtype = np.float32
if blob_name in vocab_to_export:
converted_blob_name = convert_caffe2_blob_name(blob_name)
dtype = str

workspace.blobs[converted_blob_name] = np.array([model_input], dtype=dtype)
workspace.RunNet(predict_net)
return {
str(blob): workspace.blobs[blob][0] for blob in predict_net.external_outputs
Expand Down
8 changes: 4 additions & 4 deletions tests/data/test_data_tiny.tsv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
alarm/set_alarm 11:17:datetime reactivate weekly alarm
alarm/set_alarm 23:38:datetime Set alarm to ring only on the weekdays
alarm/time_left_on_alarm When will alarm go off
reminder/set_reminder 10:18:datetime,22:38:reminder/todo remind me tomorrow to call the groomer
alarm/set_alarm 11:17:datetime reactivate weekly alarm [1.0]
alarm/set_alarm 23:38:datetime Set alarm to ring only on the weekdays [1.0]
alarm/time_left_on_alarm When will alarm go off [1.0]
reminder/set_reminder 10:18:datetime,22:38:reminder/todo remind me tomorrow to call the groomer [1.0]
20 changes: 10 additions & 10 deletions tests/data/train_data_tiny.tsv
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
alarm/modify_alarm 16:24:datetime,39:57:datetime change my alarm tomorrow to wake me up 30 minutes earlier
alarm/set_alarm Turn on all my alarms
alarm/set_alarm 12:27:datetime sound alarm every 8 minutes
alarm/set_alarm 7:17:datetime repeat yesterdays alarm
alarm/snooze_alarm continue my alarm
alarm/time_left_on_alarm Do I have anymore time on the alarm
reminder/set_reminder 10:22:datetime,28:56:reminder/todo remind me Monday night that get doe with work 12 Tuesday
reminder/show_reminders 8:15:datetime,18:27:reminder/noun display Tuesday's reminders
weather/find 12:19:weather/noun,20:28:datetime what is the weather tomorrow
weather/find 13:17:weather/attribute When will it snow
alarm/modify_alarm 16:24:datetime,39:57:datetime change my alarm tomorrow to wake me up 30 minutes earlier [1.0]
alarm/set_alarm Turn on all my alarms [1.0]
alarm/set_alarm 12:27:datetime sound alarm every 8 minutes [1.0]
alarm/set_alarm 7:17:datetime repeat yesterdays alarm [1.0]
alarm/snooze_alarm continue my alarm [1.0]
alarm/time_left_on_alarm Do I have anymore time on the alarm [1.0]
reminder/set_reminder 10:22:datetime,28:56:reminder/todo remind me Monday night that get doe with work 12 Tuesday [1.0]
reminder/show_reminders 8:15:datetime,18:27:reminder/noun display Tuesday's reminders [1.0]
weather/find 12:19:weather/noun,20:28:datetime what is the weather tomorrow [1.0]
weather/find 13:17:weather/attribute When will it snow [1.0]
29 changes: 27 additions & 2 deletions tests/predictor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,17 @@
import tempfile
import unittest

import numpy as np
from pytext import batch_predict_caffe2_model
from pytext.config import LATEST_VERSION, PyTextConfig
from pytext.data import Data
from pytext.data.sources import TSVDataSource
from pytext.data.tensorizers import (
FloatListTensorizer,
LabelTensorizer,
TokenTensorizer,
)
from pytext.models.doc_model import DocModel
from pytext.task import create_task
from pytext.task.serialize import save
from pytext.task.tasks import DocumentClassificationTask
Expand All @@ -25,14 +32,23 @@ def test_batch_predict_caffe2_model(self):
eval_data = tests_module.test_file("test_data_tiny.tsv")
config = PyTextConfig(
task=DocumentClassificationTask.Config(
model=DocModel.Config(
inputs=DocModel.Config.ModelInput(
tokens=TokenTensorizer.Config(),
dense=FloatListTensorizer.Config(
column="dense", dim=1, error_check=True
),
labels=LabelTensorizer.Config(),
)
),
data=Data.Config(
source=TSVDataSource.Config(
train_filename=train_data,
eval_filename=eval_data,
test_filename=eval_data,
field_names=["label", "slots", "text"],
field_names=["label", "slots", "text", "dense"],
)
)
),
),
version=LATEST_VERSION,
save_snapshot_path=snapshot_file.name,
Expand All @@ -47,3 +63,12 @@ def test_batch_predict_caffe2_model(self):
snapshot_file.name, caffe2_model_file.name
)
self.assertEqual(4, len(results))

pt_results = task.predict(task.data.data_source.test)

for pt_res, res in zip(pt_results, results):
print(pt_res["score"])
print(res)
np.testing.assert_array_almost_equal(
pt_res["score"].tolist()[0], [score[0] for score in res.values()]
)

0 comments on commit db0e0db

Please sign in to comment.