Skip to content

Commit f3e3e54

Browse files
Kun-Lung WuKun-Lung Wu
authored andcommitted
example of a PREDICT after FIT
1 parent 2f2bed8 commit f3e3e54

File tree

1 file changed

+213
-0
lines changed

1 file changed

+213
-0
lines changed
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"%matplotlib inline"
10+
]
11+
},
12+
{
13+
"cell_type": "markdown",
14+
"metadata": {},
15+
"source": [
16+
"\n",
17+
"# Pipeline ANOVA SVM\n",
18+
"\n",
19+
"This example shows how a feature selection can be easily integrated within\n",
20+
"a machine learning pipeline.\n",
21+
"\n",
22+
"We also show that you can easily introspect part of the pipeline.\n"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": 9,
28+
"metadata": {},
29+
"outputs": [
30+
{
31+
"name": "stdout",
32+
"output_type": "stream",
33+
"text": [
34+
"Automatically created module for IPython interactive environment\n",
35+
" precision recall f1-score support\n",
36+
"\n",
37+
" 0 0.92 0.80 0.86 15\n",
38+
" 1 0.75 0.90 0.82 10\n",
39+
"\n",
40+
" accuracy 0.84 25\n",
41+
" macro avg 0.84 0.85 0.84 25\n",
42+
"weighted avg 0.85 0.84 0.84 25\n",
43+
"\n"
44+
]
45+
},
46+
{
47+
"data": {
48+
"text/plain": [
49+
"array([[0. , 0. , 0.75791043, 0. , 0. ,\n",
50+
" 0. , 0. , 0. , 0. , 0.27158921,\n",
51+
" 0. , 0. , 0. , 0. , 0. ,\n",
52+
" 0. , 0. , 0. , 0. , 0.26109702]])"
53+
]
54+
},
55+
"execution_count": 9,
56+
"metadata": {},
57+
"output_type": "execute_result"
58+
}
59+
],
60+
"source": [
61+
"print(__doc__)\n",
62+
"\n",
63+
"from sklearn import set_config\n",
64+
"set_config(display='diagram')\n",
65+
"from sklearn.datasets import make_classification\n",
66+
"from sklearn.model_selection import train_test_split\n",
67+
"\n",
68+
"X, y = make_classification(\n",
69+
" n_features=20, n_informative=3, n_redundant=0, n_classes=2,\n",
70+
" n_clusters_per_class=2, random_state=42)\n",
71+
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
72+
"\n",
73+
"from sklearn.feature_selection import SelectKBest, f_classif\n",
74+
"from sklearn.pipeline import make_pipeline\n",
75+
"from sklearn.svm import LinearSVC\n",
76+
"\n",
77+
"anova_filter = SelectKBest(f_classif, k=3)\n",
78+
"clf = LinearSVC()\n",
79+
"anova_svm = make_pipeline(anova_filter, clf)\n",
80+
"anova_svm.fit(X_train, y_train)\n",
81+
"\n",
82+
"from sklearn.metrics import classification_report\n",
83+
"\n",
84+
"y_pred = anova_svm.predict(X_test)\n",
85+
"print(classification_report(y_test, y_pred))\n",
86+
"\n",
87+
"anova_svm[-1].coef_\n",
88+
"\n",
89+
"anova_svm[:-1].inverse_transform(anova_svm[-1].coef_)\n"
90+
]
91+
},
92+
{
93+
"cell_type": "code",
94+
"execution_count": 11,
95+
"metadata": {},
96+
"outputs": [
97+
{
98+
"name": "stderr",
99+
"output_type": "stream",
100+
"text": [
101+
"2021-06-02 09:20:27,751\tINFO services.py:1267 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8266\u001b[39m\u001b[22m\n"
102+
]
103+
},
104+
{
105+
"ename": "RayTaskError(ValueError)",
106+
"evalue": "\u001b[36mray::execute_or_node_remote()\u001b[39m (pid=29747, ip=192.168.1.5)\n File \"python/ray/_raylet.pyx\", line 505, in ray._raylet.execute_task\n File \"/opt/anaconda3/lib/python3.8/site-packages/codeflare_pipelines-1.0.0-py3.8.egg/codeflare/pipelines/Runtime.py\", line 23, in execute_or_node_remote\n File \"/opt/anaconda3/lib/python3.8/site-packages/ray/_private/client_mode_hook.py\", line 47, in wrapper\n return func(*args, **kwargs)\nValueError: 'object_refs' must either be an object ref or a list of object refs.",
107+
"output_type": "error",
108+
"traceback": [
109+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
110+
"\u001b[0;31mRayTaskError(ValueError)\u001b[0m Traceback (most recent call last)",
111+
"\u001b[0;32m<ipython-input-11-5c12286d6a83>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0mpredict_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecute_pipeline\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mselected_pipeline\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mExecutionType\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPREDICT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpipeline_input\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m \u001b[0mpredict_clf_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_output\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_xyrefs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode_clf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 53\u001b[0m \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mray\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpredict_clf_output\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_yref\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
112+
"\u001b[0;32m/opt/anaconda3/lib/python3.8/site-packages/codeflare_pipelines-1.0.0-py3.8.egg/codeflare/pipelines/Datamodel.py\u001b[0m in \u001b[0;36mget_xyrefs\u001b[0;34m(self, node)\u001b[0m\n\u001b[1;32m 397\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPipelineNodeNotFoundException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Node \"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\" not found\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 398\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 399\u001b[0;31m \u001b[0mxyrefs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mray\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxyrefs_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 400\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mxyrefs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 401\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
113+
"\u001b[0;32m/opt/anaconda3/lib/python3.8/site-packages/ray/_private/client_mode_hook.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mclient_mode_should_convert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 48\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
114+
"\u001b[0;32m/opt/anaconda3/lib/python3.8/site-packages/ray/worker.py\u001b[0m in \u001b[0;36mget\u001b[0;34m(object_refs, timeout)\u001b[0m\n\u001b[1;32m 1479\u001b[0m \u001b[0mworker\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcore_worker\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdump_object_store_memory_usage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1480\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mRayTaskError\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1481\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_instanceof_cause\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1482\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1483\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
115+
"\u001b[0;31mRayTaskError(ValueError)\u001b[0m: \u001b[36mray::execute_or_node_remote()\u001b[39m (pid=29747, ip=192.168.1.5)\n File \"python/ray/_raylet.pyx\", line 505, in ray._raylet.execute_task\n File \"/opt/anaconda3/lib/python3.8/site-packages/codeflare_pipelines-1.0.0-py3.8.egg/codeflare/pipelines/Runtime.py\", line 23, in execute_or_node_remote\n File \"/opt/anaconda3/lib/python3.8/site-packages/ray/_private/client_mode_hook.py\", line 47, in wrapper\n return func(*args, **kwargs)\nValueError: 'object_refs' must either be an object ref or a list of object refs."
116+
]
117+
}
118+
],
119+
"source": [
120+
"import ray\n",
121+
"import codeflare.pipelines.Datamodel as dm\n",
122+
"import codeflare.pipelines.Runtime as rt\n",
123+
"from codeflare.pipelines.Datamodel import Xy\n",
124+
"from codeflare.pipelines.Datamodel import XYRef\n",
125+
"from codeflare.pipelines.Runtime import ExecutionType\n",
126+
"\n",
127+
"ray.shutdown()\n",
128+
"ray.init()\n",
129+
"\n",
130+
"from sklearn import set_config\n",
131+
"set_config(display='diagram')\n",
132+
"from sklearn.datasets import make_classification\n",
133+
"from sklearn.model_selection import train_test_split\n",
134+
"\n",
135+
"X, y = make_classification(\n",
136+
" n_features=20, n_informative=3, n_redundant=0, n_classes=2,\n",
137+
" n_clusters_per_class=2, random_state=42)\n",
138+
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
139+
"\n",
140+
"from sklearn.feature_selection import SelectKBest, f_classif\n",
141+
"from sklearn.pipeline import make_pipeline\n",
142+
"from sklearn.svm import LinearSVC\n",
143+
"\n",
144+
"anova_filter = SelectKBest(f_classif, k=3)\n",
145+
"clf = LinearSVC()\n",
146+
"\n",
147+
"pipeline = dm.Pipeline()\n",
148+
"node_anova_filter = dm.EstimatorNode('anova_filter', anova_filter)\n",
149+
"node_clf = dm.EstimatorNode('clf', clf)\n",
150+
"pipeline.add_edge(node_anova_filter, node_clf)\n",
151+
"\n",
152+
"pipeline_input = dm.PipelineInput()\n",
153+
"xy = dm.Xy(X_train, y_train)\n",
154+
"\n",
155+
"pipeline_input.add_xy_arg(node_anova_filter, xy)\n",
156+
"\n",
157+
"pipeline_output = rt.execute_pipeline(pipeline, ExecutionType.FIT, pipeline_input)\n",
158+
"\n",
159+
"node_clf_output = pipeline_output.get_xyrefs(node_clf)\n",
160+
"\n",
161+
"Xout = ray.get(node_clf_output[0].get_Xref())\n",
162+
"yout = ray.get(node_clf_output[0].get_yref())\n",
163+
"\n",
164+
"selected_pipeline = rt.select_pipeline(pipeline_output, node_clf_output[0])\n",
165+
"\n",
166+
"pipeline_input = dm.PipelineInput()\n",
167+
"pipeline_input.add_xy_arg(node_anova_filter, dm.Xy(X_test, y_test))\n",
168+
"\n",
169+
"predict_output = rt.execute_pipeline(selected_pipeline, ExecutionType.PREDICT, pipeline_input)\n",
170+
"\n",
171+
"predict_clf_output = predict_output.get_xyrefs(node_clf)\n",
172+
"y_pred = ray.get(predict_clf_output[0].get_yref())\n",
173+
"\n",
174+
"from sklearn.metrics import classification_report\n",
175+
"\n",
176+
"#y_pred = anova_svm.predict(X_test)\n",
177+
"print(classification_report(y_test, y_pred))\n",
178+
"\n",
179+
"#anova_svm[-1].coef_\n",
180+
"\n",
181+
"#anova_svm[:-1].inverse_transform(anova_svm[-1].coef_)\n"
182+
]
183+
},
184+
{
185+
"cell_type": "code",
186+
"execution_count": null,
187+
"metadata": {},
188+
"outputs": [],
189+
"source": []
190+
}
191+
],
192+
"metadata": {
193+
"kernelspec": {
194+
"display_name": "Python 3",
195+
"language": "python",
196+
"name": "python3"
197+
},
198+
"language_info": {
199+
"codemirror_mode": {
200+
"name": "ipython",
201+
"version": 3
202+
},
203+
"file_extension": ".py",
204+
"mimetype": "text/x-python",
205+
"name": "python",
206+
"nbconvert_exporter": "python",
207+
"pygments_lexer": "ipython3",
208+
"version": "3.8.8"
209+
}
210+
},
211+
"nbformat": 4,
212+
"nbformat_minor": 1
213+
}

0 commit comments

Comments
 (0)