Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit acf7b5d

Browse files
authored
tests: high level: Added tests for load and save
Fixes: #573
1 parent 9da89f3 commit acf7b5d

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1414
- Doctestable examples to `db` operations.
1515
- Source for parsing `.ini` file formats
1616
- Tests for noasync high level API.
17+
- Tests for load and save functions in high level API.
1718
### Changed
1819
- `Edit on Github` button now hidden for plugins.
1920
- Doctests now run via unittests

tests/test_high_level.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import importlib
55

66
from dffml.record import Record
7-
from dffml import train, accuracy, predict
7+
from dffml import train, accuracy, predict, save, load
88
from dffml.source.csv import CSVSource
99
from dffml.feature.feature import Features, DefFeature
1010
from dffml.util.asynctestcase import IntegrationCLITestCase
@@ -23,6 +23,8 @@ async def populate_source(self, source_cls, *records, **kwargs):
2323

2424
async def setUp(self):
2525
await super().setUp()
26+
save_and_load_file = self.mktempfile() + ".csv"
27+
setattr(self, "save_and_load", save_and_load_file)
2628
self.train_data = [
2729
[0, 1, 0.2, 10],
2830
[1, 3, 0.4, 20],
@@ -43,6 +45,61 @@ async def setUp(self):
4345
setattr(self, f"{use}_filename", filename)
4446
await self.populate_source(CSVSource, *records, filename=filename)
4547

48+
async def test_save_and_load(self):
49+
source = CSVSource(
50+
filename=self.save_and_load, allowempty=True, readwrite=True
51+
)
52+
await save(
53+
source,
54+
Record(
55+
"1",
56+
data={
57+
"features": {"A": 0, "B": 1},
58+
"prediction": {"C": {"value": 1, "confidence": 1.0}},
59+
},
60+
),
61+
Record(
62+
"2",
63+
data={
64+
"features": {"A": 3, "B": 4},
65+
"prediction": {"C": {"value": 2, "confidence": 1.0}},
66+
},
67+
),
68+
)
69+
# All records in source
70+
results = [record.export() async for record in load(source)]
71+
self.assertEqual(
72+
results,
73+
[
74+
{
75+
"key": "1",
76+
"features": {"A": 0, "B": 1},
77+
"prediction": {"C": {"confidence": 1.0, "value": "1"}},
78+
"extra": {},
79+
},
80+
{
81+
"key": "2",
82+
"features": {"A": 3, "B": 4},
83+
"prediction": {"C": {"confidence": 1.0, "value": "2"}},
84+
"extra": {},
85+
},
86+
],
87+
)
88+
89+
# For specific records in a source
90+
results = [record.export() async for record in load(source, "1")]
91+
self.assertEqual(
92+
results,
93+
[
94+
{
95+
"key": "1",
96+
"features": {"A": 0, "B": 1},
97+
"prediction": {"C": {"confidence": 1.0, "value": "1"}},
98+
"extra": {},
99+
}
100+
],
101+
)
102+
46103
async def test_predict(self):
47104
self.required_plugins("dffml-model-scikit")
48105
# Import SciKit modules

0 commit comments

Comments
 (0)