Skip to content

Commit

Permalink
Fix merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
rvlb committed Apr 29, 2021
2 parents ecdf736 + e853c46 commit 6cbde0f
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 35 deletions.
13 changes: 5 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Cache installed requirements
uses: actions/cache@v2
with:
path: ${{ env.pythonLocation }}
key: ${{ matrix.os }}-${{ env.pythonLocation }}-${{ hashFiles('**/requirements.txt', '**/requirements-dev.txt') }}
- name: Upgrade pip
run: pip install -U pip
#- name: Cache installed requirements
# uses: actions/cache@v2
# with:
# path: ${{ env.pythonLocation }}
# key: ${{ matrix.os }}-${{ env.pythonLocation }}-${{ hashFiles('**/requirements.txt', '**/requirements-dev.txt') }}
- name: Install tox, tox-gh-actions and coveralls
run: pip install tox==3.21.3 tox-gh-actions==2.4.0 coveralls==3.0.0
- name: Lint using flake8
Expand All @@ -38,7 +36,6 @@ jobs:
OS: ${{ matrix.os }}
PYTHON: ${{ matrix.python-version }}
PYTORCH: ${{ matrix.pytorch-version }}
TEST_DEVICE: "cpu"
run: tox
- name: Upload coverage.xml
if: ${{ matrix.os == 'ubuntu-latest' && matrix.python-version == '3.9' && matrix.pytorch-version == '1.8' }}
Expand Down
12 changes: 8 additions & 4 deletions docs/guide/cli.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ Deduplication Train
--ef_construction 150 \
--ef_search -1 \
--random_seed 42 \
--model_save_dir trained-models/er/
--model_save_dir trained-models/er/ \
--use_gpu 1
Deduplication Predict
~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -121,7 +122,8 @@ Deduplication Predict
--ef_construction 150 \
--ef_search -1 \
--random_seed 42 \
--output_json example-data/er-prediction.json
--output_json example-data/er-prediction.json \
--use_gpu 1
Record Linkage Train
~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -162,7 +164,8 @@ Record Linkage Train
--ef_construction 150 \
--ef_search -1 \
--random_seed 42 \
--model_save_dir trained-models/rl/
--model_save_dir trained-models/rl/ \
--use_gpu 1
Record Linkage Predict
~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -185,4 +188,5 @@ Record Linkage Predict
--ef_construction 150 \
--ef_search -1 \
--random_seed 42 \
--output_json example-data/rl-prediction.json
--output_json example-data/rl-prediction.json \
--use_gpu 1
5 changes: 4 additions & 1 deletion entity_embed/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def _fit_model(model, datamodule, kwargs):
model_save_verbose=True,
tb_save_dir=kwargs["tb_save_dir"],
tb_name=kwargs["tb_name"],
use_gpu=kwargs["use_gpu"],
)


Expand Down Expand Up @@ -318,6 +319,7 @@ def _fit_model(model, datamodule, kwargs):
help="Directory path where to save the best validation model checkpoint"
" using PyTorch Lightning",
)
@click.option("--use_gpu", type=bool, default=True, help="Use GPU for training")
def train(**kwargs):
"""
Transform entities like companies, products, etc. into vectors
Expand Down Expand Up @@ -366,7 +368,7 @@ def _load_model(kwargs):
model_cls = EntityEmbed

model = model_cls.load_from_checkpoint(kwargs["model_save_filepath"], datamodule=None)
if torch.cuda.is_available():
if kwargs["use_gpu"]:
model = model.to(torch.device("cuda"))
else:
model = model.to(torch.device("cpu"))
Expand Down Expand Up @@ -504,6 +506,7 @@ def _write_json(found_pairs, kwargs):
"Remember Entity Embed is focused on recall. "
"You must use some classifier to filter these and find the best matching pairs.",
)
@click.option("--use_gpu", type=bool, default=True, help="Use GPU for predicting pairs")
def predict(**kwargs):
_fix_workers_kwargs(kwargs)
_set_random_seeds(kwargs)
Expand Down
4 changes: 3 additions & 1 deletion entity_embed/entity_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def fit(
model_save_verbose=False,
tb_save_dir=None,
tb_name=None,
use_gpu=True,
):
if early_stop_mode is None:
if "pair_entity_ratio_at" in early_stop_monitor:
Expand All @@ -184,13 +185,14 @@ def fit(
verbose=model_save_verbose,
)
trainer_args = {
"gpus": 1,
"min_epochs": min_epochs,
"max_epochs": max_epochs,
"check_val_every_n_epoch": check_val_every_n_epoch,
"callbacks": [early_stop_callback, checkpoint_callback],
"reload_dataloaders_every_epoch": True, # for shuffling ClusterDataset every epoch
}
if use_gpu:
trainer_args["gpus"] = 1

if tb_name and tb_save_dir:
trainer_args["logger"] = TensorBoardLogger(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,8 @@
" --ef_construction 150 \\\n",
" --ef_search -1 \\\n",
" --random_seed 42 \\\n",
" --model_save_dir trained-models/er/\n",
" --model_save_dir trained-models/er/ \\\n",
" --use_gpu 1\n",
"```"
]
},
Expand All @@ -448,7 +449,8 @@
" --ef_construction 150 \\\n",
" --ef_search -1 \\\n",
" --random_seed 42 \\\n",
" --output_json example-data/er-prediction.json\n",
" --output_json example-data/er-prediction.json \\\n",
" --use_gpu 1\n",
"```"
]
},
Expand Down Expand Up @@ -485,9 +487,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.6 64-bit ('venv': venv)",
"display_name": "Python 3",
"language": "python",
"name": "python38664bitvenvvenvc78c7844fbe149cea84f2ec734d1a56b"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,8 @@
" --ef_construction 150 \\\n",
" --ef_search -1 \\\n",
" --random_seed 42 \\\n",
" --model_save_dir trained-models/rl/\n",
" --model_save_dir trained-models/rl/ \\\n",
" --use_gpu 1\n",
"```"
]
},
Expand All @@ -513,7 +514,8 @@
" --ef_construction 150 \\\n",
" --ef_search -1 \\\n",
" --random_seed 42 \\\n",
" --output_json example-data/rl-prediction.json\n",
" --output_json example-data/rl-prediction.json \\\n",
" --use_gpu 1\n",
"```"
]
},
Expand Down
11 changes: 0 additions & 11 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +0,0 @@
import os

# libgomp issue, must import n2 before torch. See: https://github.com/kakao/n2/issues/42
import n2 # noqa: F401
import torch

# Test-suite configuration read from the environment at import time.
#
# TEST_DTYPES: comma-separated names of torch dtype attributes
#   (default "float32,float64"). Whitespace around each name and empty
#   entries (e.g. from a trailing comma) are ignored, so a value such as
#   "float32, float64" no longer fails the attribute lookup on ``torch``.
# TEST_DEVICE: device string passed to ``torch.device`` (default "cuda");
#   surrounding whitespace is stripped before use.
dtypes_from_environ = [
    name.strip()
    for name in os.environ.get("TEST_DTYPES", "float32,float64").split(",")
    if name.strip()
]
device_from_environ = os.environ.get("TEST_DEVICE", "cuda").strip()

# Resolve the names to the actual torch objects used by the tests.
TEST_DTYPES = [getattr(torch, x) for x in dtypes_from_environ]
TEST_DEVICE = torch.device(device_from_environ)
12 changes: 9 additions & 3 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ def test_cli_train(
42,
"--model_save_dir",
"trained-models",
"--use_gpu",
False,
],
)

Expand Down Expand Up @@ -293,6 +295,7 @@ def test_cli_train(
"random_seed": 42,
"model_save_dir": "trained-models",
"n_threads": 16, # assigned
"use_gpu": False,
}
expected_field_config_name_dict = {
"key": "name",
Expand Down Expand Up @@ -351,6 +354,7 @@ def test_cli_train(
model_save_verbose=True,
tb_save_dir=expected_args_dict["tb_save_dir"],
tb_name=expected_args_dict["tb_name"],
use_gpu=expected_args_dict["use_gpu"],
)
datamodule = mock_model.return_value.fit.call_args[0][0]

Expand Down Expand Up @@ -396,9 +400,9 @@ def test_cli_train(
@mock.patch("torch.manual_seed")
@mock.patch("numpy.random.seed")
@mock.patch("random.seed")
@mock.patch("torch.cuda.is_available", return_value=False)
@mock.patch("torch.device")
def test_cli_predict(
mock_cuda_is_available,
mock_torch_device,
mock_random_seed,
mock_np_random_seed,
mock_torch_random_seed,
Expand Down Expand Up @@ -461,6 +465,8 @@ def test_cli_predict(
42,
"--output_json",
expected_output_json,
"--use_gpu",
False,
],
)

Expand Down Expand Up @@ -494,7 +500,7 @@ def test_cli_predict(
mock_torch_random_seed.assert_called_once_with(expected_args_dict["random_seed"])

# cuda asserts
mock_cuda_is_available.assert_called_once()
mock_torch_device.assert_called_once_with("cpu")

# predict_pairs asserts
expected_record_dict = {record["id"]: record for record in UNLABELED_RECORD_DICT_VALUES}
Expand Down
1 change: 0 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ deps = flake8
commands = flake8 entity_embed tests

[testenv]
passenv = TEST_DEVICE
setenv =
PYTHONPATH = {toxinidir}
deps =
Expand Down

0 comments on commit 6cbde0f

Please sign in to comment.