Skip to content

Commit 2319a26

Browse files
Updating with comparison from sklearn/lale pipeline in timing for cross validation, on MacBook Pro 8 core, scale out goes from 433s to 72s
1 parent 78ce87b commit 2319a26

File tree

1 file changed

+50
-39
lines changed

1 file changed

+50
-39
lines changed

notebooks/lale_cross_val_score.ipynb

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,7 @@
4242
"execution_count": 5,
4343
"metadata": {},
4444
"outputs": [],
45-
"source": [
46-
"# from lale.lib.sklearn import PCA, Nystroem, SelectKBest, RandomForestClassifier\n",
47-
"# from lale.lib.lale import ConcatFeatures\n",
48-
"\n",
49-
"# pipeline = (PCA() & Nystroem() & SelectKBest(k=3)) >> ConcatFeatures() >> RandomForestClassifier(n_estimators=200)\n",
50-
"# # pipeline.visualize()"
51-
]
45+
"source": []
5246
},
5347
{
5448
"cell_type": "code",
@@ -77,7 +71,7 @@
7771
"name": "stderr",
7872
"output_type": "stream",
7973
"text": [
80-
"2021-05-25 15:55:50,487\tINFO services.py:1269 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265\u001b[39m\u001b[22m\n"
74+
"2021-05-26 08:32:21,221\tINFO services.py:1269 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265\u001b[39m\u001b[22m\n"
8175
]
8276
},
8377
{
@@ -86,12 +80,12 @@
8680
"{'node_ip_address': '9.163.5.112',\n",
8781
" 'raylet_ip_address': '9.163.5.112',\n",
8882
" 'redis_address': '9.163.5.112:6379',\n",
89-
" 'object_store_address': '/tmp/ray/session_2021-05-25_15-55-48_881861_17264/sockets/plasma_store',\n",
90-
" 'raylet_socket_name': '/tmp/ray/session_2021-05-25_15-55-48_881861_17264/sockets/raylet',\n",
83+
" 'object_store_address': '/tmp/ray/session_2021-05-26_08-32-19_645025_30302/sockets/plasma_store',\n",
84+
" 'raylet_socket_name': '/tmp/ray/session_2021-05-26_08-32-19_645025_30302/sockets/raylet',\n",
9185
" 'webui_url': '127.0.0.1:8265',\n",
92-
" 'session_dir': '/tmp/ray/session_2021-05-25_15-55-48_881861_17264',\n",
93-
" 'metrics_export_port': 63066,\n",
94-
" 'node_id': 'd067798c6a3d62df2148c17631b1a63b58e53a141542f7697fe29297'}"
86+
" 'session_dir': '/tmp/ray/session_2021-05-26_08-32-19_645025_30302',\n",
87+
" 'metrics_export_port': 65535,\n",
88+
" 'node_id': '1eb8277b22f236079beacdbd19b26d419304b495f01a69a7c4dc3a26'}"
9589
]
9690
},
9791
"execution_count": 8,
@@ -219,8 +213,18 @@
219213
"cell_type": "code",
220214
"execution_count": 20,
221215
"metadata": {},
222-
"outputs": [],
216+
"outputs": [
217+
{
218+
"name": "stdout",
219+
"output_type": "stream",
220+
"text": [
221+
"CPU times: user 2.61 s, sys: 1.34 s, total: 3.96 s\n",
222+
"Wall time: 1min 12s\n"
223+
]
224+
}
225+
],
223226
"source": [
227+
"%%time\n",
224228
"scores = rt.cross_validate(kf, pipeline, pipeline_input)"
225229
]
226230
},
@@ -232,16 +236,16 @@
232236
{
233237
"data": {
234238
"text/plain": [
235-
"[0.8145188145188145,\n",
236-
" 0.8115218115218116,\n",
237-
" 0.8161838161838162,\n",
238-
" 0.8168498168498168,\n",
239-
" 0.8171828171828172,\n",
239+
"[0.8185148185148186,\n",
240240
" 0.8175158175158175,\n",
241-
" 0.8105228105228105,\n",
242-
" 0.8211788211788211,\n",
243-
" 0.8031312458361093,\n",
244-
" 0.8154563624250499]"
241+
" 0.8128538128538129,\n",
242+
" 0.8195138195138195,\n",
243+
" 0.8228438228438228,\n",
244+
" 0.8168498168498168,\n",
245+
" 0.8131868131868132,\n",
246+
" 0.8161838161838162,\n",
247+
" 0.7991339107261826,\n",
248+
" 0.8077948034643571]"
245249
]
246250
},
247251
"execution_count": 21,
@@ -262,46 +266,53 @@
262266
},
263267
{
264268
"cell_type": "code",
265-
"execution_count": null,
269+
"execution_count": 24,
266270
"metadata": {},
267271
"outputs": [],
268-
"source": []
272+
"source": [
273+
"from lale.lib.sklearn import PCA, Nystroem, SelectKBest, RandomForestClassifier\n",
274+
"from lale.lib.lale import ConcatFeatures\n",
275+
"\n",
276+
"pipeline = (PCA() & Nystroem() & SelectKBest(k=3)) >> ConcatFeatures() >> RandomForestClassifier(n_estimators=200)\n",
277+
"# pipeline.visualize()"
278+
]
269279
},
270280
{
271281
"cell_type": "code",
272-
"execution_count": 6,
282+
"execution_count": 25,
273283
"metadata": {},
274284
"outputs": [
275285
{
276286
"name": "stdout",
277287
"output_type": "stream",
278288
"text": [
279-
"CPU times: user 7min 45s, sys: 12.4 s, total: 7min 57s\n",
280-
"Wall time: 7min 30s\n"
289+
"CPU times: user 7min 39s, sys: 13.1 s, total: 7min 52s\n",
290+
"Wall time: 7min 13s\n"
281291
]
282292
},
283293
{
284294
"data": {
285295
"text/plain": [
286-
"[0.8168498168498168,\n",
287-
" 0.8215118215118216,\n",
288-
" 0.8138528138528138,\n",
296+
"[0.8185148185148186,\n",
297+
" 0.8101898101898102,\n",
298+
" 0.8148518148518149,\n",
299+
" 0.8201798201798202,\n",
300+
" 0.8145188145188145,\n",
289301
" 0.8185148185148186,\n",
290-
" 0.8178488178488178,\n",
291-
" 0.8215118215118216,\n",
292-
" 0.8121878121878122,\n",
293-
" 0.8191808191808192,\n",
294-
" 0.8021319120586275,\n",
295-
" 0.8074616922051966]"
302+
" 0.8168498168498168,\n",
303+
" 0.8125208125208125,\n",
304+
" 0.7978014656895404,\n",
305+
" 0.8114590273151232]"
296306
]
297307
},
298-
"execution_count": 6,
308+
"execution_count": 25,
299309
"metadata": {},
300310
"output_type": "execute_result"
301311
}
302312
],
303313
"source": [
304-
"%%time \n",
314+
"%%time\n",
315+
"from lale.helpers import cross_val_score\n",
305316
"cross_val_score(pipeline, X_train, y_train, cv=10)"
306317
]
307318
},

0 commit comments

Comments
 (0)