This repository was archived by the owner on Jun 30, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathdataset_client_test.py
152 lines (133 loc) · 5.56 KB
/
dataset_client_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
'''
DatasetClient tests.
'''
import os
import unittest
import pandas as pd
import numpy as np
from mljar.client.project import ProjectClient
from mljar.client.dataset import DatasetClient
from .project_based_test import ProjectBasedTest, get_postfix
class DatasetClientTest(ProjectBasedTest):
    """Integration tests for DatasetClient.

    Each test runs against a freshly created MLJAR project which is deleted
    again in tearDown, so tests do not leak server-side state. These tests
    require valid MLJAR credentials and network access.
    """

    def setUp(self):
        """Create a scratch binary-classification project and load sample data."""
        proj_title = 'Test project-01' + get_postfix()
        proj_task = 'bin_class'
        # setup project
        self.project_client = ProjectClient()
        self.project = self.project_client.create_project(title=proj_title, task=proj_task)
        # load data (iris-style CSV shipped with the test suite)
        df = pd.read_csv('tests/data/test_1.csv')
        cols = ['sepal length', 'sepal width', 'petal length', 'petal width']
        target = 'class'
        self.X = df.loc[:, cols]
        self.y = df[target]

    def tearDown(self):
        """Delete the scratch project created in setUp."""
        self.project_client.delete_project(self.project.hid)

    def test_get_datasests(self):
        """A brand new project contains no datasets."""
        datasets = DatasetClient(self.project.hid).get_datasets()
        self.assertEqual(datasets, [])

    def test_prepare_data(self):
        """_prepare_data on numpy input builds named columns and a string hash."""
        dc = DatasetClient(self.project.hid)
        samples = 100
        columns = 10
        X = np.random.rand(samples, columns)
        y = np.random.choice([0, 1], samples, replace=True)
        data, data_hash = dc._prepare_data(X, y)
        self.assertIsNotNone(data)
        self.assertIsNotNone(data_hash)
        self.assertIsInstance(data_hash, str)
        # 10 attribute columns + 1 target column
        self.assertEqual(11, len(data.columns))
        self.assertIn('target', data.columns)
        self.assertIn('attribute_1', data.columns)
        self.assertIn('attribute_10', data.columns)

    def test_get_dataset_for_wrong_hid(self):
        """Fetching a dataset by an unknown hid returns None."""
        dc = DatasetClient(self.project.hid)
        dataset = dc.get_dataset('some-wrong-hid')
        self.assertIsNone(dataset)

    def test_add_dataset_for_training(self):
        """A dataset with a target can be added and read back by hid."""
        # setup dataset client
        dc = DatasetClient(self.project.hid)
        self.assertIsNotNone(dc)
        # get datasets, there should be none
        datasets = dc.get_datasets()
        self.assertEqual(len(datasets), 0)
        # add dataset
        my_dataset = dc.add_dataset_if_not_exists(self.X, self.y)
        self.assertIsNotNone(my_dataset)
        # get datasets
        datasets = dc.get_datasets()
        self.assertEqual(len(datasets), 1)
        my_dataset_2 = dc.get_dataset(my_dataset.hid)
        self.assertEqual(my_dataset.hid, my_dataset_2.hid)
        self.assertEqual(my_dataset.title, my_dataset_2.title)
        # test __str__ method
        self.assertIn('id', str(my_dataset_2))
        self.assertIn('title', str(my_dataset_2))
        self.assertIn('file', str(my_dataset_2))

    def test_add_dataset_for_prediction(self):
        """A dataset without a target (prediction input) can be added and read back."""
        # setup dataset client
        dc = DatasetClient(self.project.hid)
        self.assertIsNotNone(dc)
        # get datasets, there should be none
        datasets = dc.get_datasets()
        self.assertEqual(len(datasets), 0)
        # add dataset
        my_dataset = dc.add_dataset_if_not_exists(self.X, None)
        self.assertIsNotNone(my_dataset)
        # get datasets
        datasets = dc.get_datasets()
        self.assertEqual(len(datasets), 1)
        my_dataset_2 = dc.get_dataset(my_dataset.hid)
        self.assertEqual(my_dataset.hid, my_dataset_2.hid)
        self.assertEqual(my_dataset.title, my_dataset_2.title)

    def test_add_existing_dataset(self):
        """Adding the same data twice must not create a duplicate dataset."""
        # setup dataset client
        dc = DatasetClient(self.project.hid)
        self.assertIsNotNone(dc)
        # get initial number of datasets
        init_datasets_cnt = len(dc.get_datasets())
        # add dataset
        dc.add_dataset_if_not_exists(self.X, self.y)
        # get datasets
        datasets = dc.get_datasets()
        self.assertEqual(len(datasets), init_datasets_cnt + 1)
        # add the same dataset - it shouldn't be added again
        dc.add_dataset_if_not_exists(self.X, self.y)
        # dataset count must be unchanged
        datasets = dc.get_datasets()
        self.assertEqual(len(datasets), init_datasets_cnt + 1)

    def test_prepare_data_two_sources(self):
        """The same X with and without y must hash differently (pandas input)."""
        dc = DatasetClient(self.project.hid)
        data_1, data_hash_1 = dc._prepare_data(self.X, self.y)
        data_2, data_hash_2 = dc._prepare_data(self.X, None)
        self.assertNotEqual(data_hash_1, data_hash_2)

    def test_prepare_data_two_sources_numpy(self):
        """The same X with and without y must hash differently (numpy input)."""
        dc = DatasetClient(self.project.hid)
        data_1, data_hash_1 = dc._prepare_data(np.array(self.X), np.array(self.y))
        data_2, data_hash_2 = dc._prepare_data(np.array(self.X), None)
        self.assertNotEqual(data_hash_1, data_hash_2)

    def test_create_and_delete(self):
        """Deleting a dataset removes exactly that dataset from the project."""
        # setup dataset client
        dc = DatasetClient(self.project.hid)
        self.assertIsNotNone(dc)
        # get initial number of datasets
        init_datasets_cnt = len(dc.get_datasets())
        # add two datasets (with and without target, so hashes differ)
        my_dataset_1 = dc.add_dataset_if_not_exists(self.X, self.y)
        my_dataset_2 = dc.add_dataset_if_not_exists(self.X, y=None)
        # get datasets
        datasets = dc.get_datasets()
        self.assertEqual(len(datasets), init_datasets_cnt + 2)
        # delete added dataset
        dc.delete_dataset(my_dataset_1.hid)
        # check number of datasets
        datasets = dc.get_datasets()
        self.assertEqual(len(datasets), init_datasets_cnt + 1)
# Allow running this test module directly: `python dataset_client_test.py`.
if __name__ == "__main__":
    unittest.main()