Skip to content

Commit f0b6dd7

Browse files
committed
plot_digits_pipe unwind gridsearchCV
1 parent dabbd33 commit f0b6dd7

File tree

1 file changed

+75
-7
lines changed

1 file changed

+75
-7
lines changed

notebooks/plot_digits_pipe.ipynb

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -153,27 +153,68 @@
153153
"from sklearn.pipeline import Pipeline\n",
154154
"from sklearn.model_selection import GridSearchCV\n",
155155
"\n",
156-
"\n",
157156
"# Define a pipeline to search for the best combination of PCA truncation\n",
158157
"# and classifier regularization.\n",
159-
"pca = PCA()\n",
158+
"pca_5 = PCA(n_components=5)\n",
159+
"pca_15 = PCA(n_components=15)\n",
160+
"pca_30 = PCA(n_components=30)\n",
161+
"pca_45 = PCA(n_components=45)\n",
162+
"pca_64 = PCA(n_components=64)\n",
163+
"\n",
160164
"# set the tolerance to a large value to make the example faster\n",
161-
"logistic = LogisticRegression(max_iter=10000, tol=0.1)\n",
165+
"logistic_1 = LogisticRegression(max_iter=10000, tol=0.1, C=1.00000000e-04)\n",
166+
"logistic_2 = LogisticRegression(max_iter=10000, tol=0.1, C=4.64158883e-02)\n",
167+
"logistic_3 = LogisticRegression(max_iter=10000, tol=0.1, C=2.15443469e+01)\n",
168+
"logistic_4 = LogisticRegression(max_iter=10000, tol=0.1, C=1.00000000e+04)\n",
162169
"\n",
163170
"## initialize codeflare pipeline by first creating the nodes\n",
164171
"pipeline = dm.Pipeline()\n",
165-
"node_pca = dm.EstimatorNode('pca', pca)\n",
166-
"node_logistic = dm.EstimatorNode('logistic', logistic)\n",
172+
"node_pca_5 = dm.EstimatorNode('pca_5', pca_5)\n",
173+
"node_pca_15 = dm.EstimatorNode('pca_15', pca_15)\n",
174+
"node_pca_30 = dm.EstimatorNode('pca_30', pca_30)\n",
175+
"node_pca_45 = dm.EstimatorNode('pca_45', pca_45)\n",
176+
"node_pca_64 = dm.EstimatorNode('pca_64', pca_64)\n",
177+
"\n",
178+
"node_logistic_1 = dm.EstimatorNode('logistic_1', logistic_1)\n",
179+
"node_logistic_2 = dm.EstimatorNode('logistic_2', logistic_2)\n",
180+
"node_logistic_3 = dm.EstimatorNode('logistic_3', logistic_3)\n",
181+
"node_logistic_4 = dm.EstimatorNode('logistic_4', logistic_4)\n",
167182
"\n",
168183
"## codeflare nodes are then connected by edges\n",
169-
"pipeline.add_edge(node_pca, node_logistic)\n",
184+
"pipeline.add_edge(node_pca_5, node_logistic_1)\n",
185+
"pipeline.add_edge(node_pca_15, node_logistic_1)\n",
186+
"pipeline.add_edge(node_pca_30, node_logistic_1)\n",
187+
"pipeline.add_edge(node_pca_45, node_logistic_1)\n",
188+
"pipeline.add_edge(node_pca_64, node_logistic_1)\n",
189+
"\n",
190+
"pipeline.add_edge(node_pca_5, node_logistic_2)\n",
191+
"pipeline.add_edge(node_pca_15, node_logistic_2)\n",
192+
"pipeline.add_edge(node_pca_30, node_logistic_2)\n",
193+
"pipeline.add_edge(node_pca_45, node_logistic_2)\n",
194+
"pipeline.add_edge(node_pca_64, node_logistic_2)\n",
195+
"\n",
196+
"pipeline.add_edge(node_pca_5, node_logistic_3)\n",
197+
"pipeline.add_edge(node_pca_15, node_logistic_3)\n",
198+
"pipeline.add_edge(node_pca_30, node_logistic_3)\n",
199+
"pipeline.add_edge(node_pca_45, node_logistic_3)\n",
200+
"pipeline.add_edge(node_pca_64, node_logistic_3)\n",
201+
"\n",
202+
"pipeline.add_edge(node_pca_5, node_logistic_4)\n",
203+
"pipeline.add_edge(node_pca_15, node_logistic_4)\n",
204+
"pipeline.add_edge(node_pca_30, node_logistic_4)\n",
205+
"pipeline.add_edge(node_pca_45, node_logistic_4)\n",
206+
"pipeline.add_edge(node_pca_64, node_logistic_4)\n",
170207
"\n",
171208
"X_digits, y_digits = datasets.load_digits(return_X_y=True)\n",
172209
"\n",
173210
"# execute FIT\n",
174211
"pipeline_input = dm.PipelineInput()\n",
175212
"xy = dm.Xy(X_digits, y_digits)\n",
176-
"pipeline_input.add_xy_arg(node_pca, xy)\n",
213+
"pipeline_input.add_xy_arg(node_pca_5, xy)\n",
214+
"pipeline_input.add_xy_arg(node_pca_15, xy)\n",
215+
"pipeline_input.add_xy_arg(node_pca_30, xy)\n",
216+
"pipeline_input.add_xy_arg(node_pca_45, xy)\n",
217+
"pipeline_input.add_xy_arg(node_pca_64, xy)\n",
177218
"\n",
178219
"# Parameters of pipelines can be set using ‘__’ separated parameter names:\n",
179220
"param_grid = {\n",
@@ -215,6 +256,33 @@
215256
"plt.tight_layout()\n",
216257
"plt.show()"
217258
]
259+
},
260+
{
261+
"cell_type": "code",
262+
"execution_count": 3,
263+
"metadata": {},
264+
"outputs": [
265+
{
266+
"data": {
267+
"text/plain": [
268+
"array([1.00000000e-04, 4.64158883e-02, 2.15443469e+01, 1.00000000e+04])"
269+
]
270+
},
271+
"execution_count": 3,
272+
"metadata": {},
273+
"output_type": "execute_result"
274+
}
275+
],
276+
"source": [
277+
"np.logspace(-4, 4, 4)"
278+
]
279+
},
280+
{
281+
"cell_type": "code",
282+
"execution_count": null,
283+
"metadata": {},
284+
"outputs": [],
285+
"source": []
218286
}
219287
],
220288
"metadata": {

0 commit comments

Comments
 (0)