|
153 | 153 | "from sklearn.pipeline import Pipeline\n", |
154 | 154 | "from sklearn.model_selection import GridSearchCV\n", |
155 | 155 | "\n", |
156 | | - "\n", |
157 | 156 | "# Define a pipeline to search for the best combination of PCA truncation\n", |
158 | 157 | "# and classifier regularization.\n", |
159 | | - "pca = PCA()\n", |
| 158 | + "pca_5 = PCA(n_components=5)\n", |
| 159 | + "pca_15 = PCA(n_components=15)\n", |
| 160 | + "pca_30 = PCA(n_components=30)\n", |
| 161 | + "pca_45 = PCA(n_components=45)\n", |
| 162 | + "pca_64 = PCA(n_components=64)\n", |
| 163 | + "\n", |
160 | 164 | "# set the tolerance to a large value to make the example faster\n", |
161 | | - "logistic = LogisticRegression(max_iter=10000, tol=0.1)\n", |
| 165 | + "logistic_1 = LogisticRegression(max_iter=10000, tol=0.1, C=1.00000000e-04)\n", |
| 166 | + "logistic_2 = LogisticRegression(max_iter=10000, tol=0.1, C=4.64158883e-02)\n", |
| 167 | + "logistic_3 = LogisticRegression(max_iter=10000, tol=0.1, C=2.15443469e+01)\n", |
| 168 | + "logistic_4 = LogisticRegression(max_iter=10000, tol=0.1, C=1.00000000e+04)\n", |
162 | 169 | "\n", |
163 | 170 | "## initialize codeflare pipeline by first creating the nodes\n", |
164 | 171 | "pipeline = dm.Pipeline()\n", |
165 | | - "node_pca = dm.EstimatorNode('pca', pca)\n", |
166 | | - "node_logistic = dm.EstimatorNode('logistic', logistic)\n", |
| 172 | + "node_pca_5 = dm.EstimatorNode('pca_5', pca_5)\n", |
| 173 | + "node_pca_15 = dm.EstimatorNode('pca_15', pca_15)\n", |
| 174 | + "node_pca_30 = dm.EstimatorNode('pca_30', pca_30)\n", |
| 175 | + "node_pca_45 = dm.EstimatorNode('pca_45', pca_45)\n", |
| 176 | + "node_pca_64 = dm.EstimatorNode('pca_64', pca_64)\n", |
| 177 | + "\n", |
| 178 | + "node_logistic_1 = dm.EstimatorNode('logistic_1', logistic_1)\n", |
| 179 | + "node_logistic_2 = dm.EstimatorNode('logistic_2', logistic_2)\n", |
| 180 | + "node_logistic_3 = dm.EstimatorNode('logistic_3', logistic_3)\n", |
| 181 | + "node_logistic_4 = dm.EstimatorNode('logistic_4', logistic_4)\n", |
167 | 182 | "\n", |
168 | 183 | "## codeflare nodes are then connected by edges\n", |
169 | | - "pipeline.add_edge(node_pca, node_logistic)\n", |
| 184 | + "pipeline.add_edge(node_pca_5, node_logistic_1)\n", |
| 185 | + "pipeline.add_edge(node_pca_15, node_logistic_1)\n", |
| 186 | + "pipeline.add_edge(node_pca_30, node_logistic_1)\n", |
| 187 | + "pipeline.add_edge(node_pca_45, node_logistic_1)\n", |
| 188 | + "pipeline.add_edge(node_pca_64, node_logistic_1)\n", |
| 189 | + "\n", |
| 190 | + "pipeline.add_edge(node_pca_5, node_logistic_2)\n", |
| 191 | + "pipeline.add_edge(node_pca_15, node_logistic_2)\n", |
| 192 | + "pipeline.add_edge(node_pca_30, node_logistic_2)\n", |
| 193 | + "pipeline.add_edge(node_pca_45, node_logistic_2)\n", |
| 194 | + "pipeline.add_edge(node_pca_64, node_logistic_2)\n", |
| 195 | + "\n", |
| 196 | + "pipeline.add_edge(node_pca_5, node_logistic_3)\n", |
| 197 | + "pipeline.add_edge(node_pca_15, node_logistic_3)\n", |
| 198 | + "pipeline.add_edge(node_pca_30, node_logistic_3)\n", |
| 199 | + "pipeline.add_edge(node_pca_45, node_logistic_3)\n", |
| 200 | + "pipeline.add_edge(node_pca_64, node_logistic_3)\n", |
| 201 | + "\n", |
| 202 | + "pipeline.add_edge(node_pca_5, node_logistic_4)\n", |
| 203 | + "pipeline.add_edge(node_pca_15, node_logistic_4)\n", |
| 204 | + "pipeline.add_edge(node_pca_30, node_logistic_4)\n", |
| 205 | + "pipeline.add_edge(node_pca_45, node_logistic_4)\n", |
| 206 | + "pipeline.add_edge(node_pca_64, node_logistic_4)\n", |
170 | 207 | "\n", |
171 | 208 | "X_digits, y_digits = datasets.load_digits(return_X_y=True)\n", |
172 | 209 | "\n", |
173 | 210 | "# execute FIT\n", |
174 | 211 | "pipeline_input = dm.PipelineInput()\n", |
175 | 212 | "xy = dm.Xy(X_digits, y_digits)\n", |
176 | | - "pipeline_input.add_xy_arg(node_pca, xy)\n", |
| 213 | + "pipeline_input.add_xy_arg(node_pca_5, xy)\n", |
| 214 | + "pipeline_input.add_xy_arg(node_pca_15, xy)\n", |
| 215 | + "pipeline_input.add_xy_arg(node_pca_30, xy)\n", |
| 216 | + "pipeline_input.add_xy_arg(node_pca_45, xy)\n", |
| 217 | + "pipeline_input.add_xy_arg(node_pca_64, xy)\n", |
177 | 218 | "\n", |
178 | 219 | "# Parameters of pipelines can be set using ‘__’ separated parameter names:\n", |
179 | 220 | "param_grid = {\n", |
|
215 | 256 | "plt.tight_layout()\n", |
216 | 257 | "plt.show()" |
217 | 258 | ] |
| 259 | + }, |
| 260 | + { |
| 261 | + "cell_type": "code", |
| 262 | + "execution_count": 3, |
| 263 | + "metadata": {}, |
| 264 | + "outputs": [ |
| 265 | + { |
| 266 | + "data": { |
| 267 | + "text/plain": [ |
| 268 | + "array([1.00000000e-04, 4.64158883e-02, 2.15443469e+01, 1.00000000e+04])" |
| 269 | + ] |
| 270 | + }, |
| 271 | + "execution_count": 3, |
| 272 | + "metadata": {}, |
| 273 | + "output_type": "execute_result" |
| 274 | + } |
| 275 | + ], |
| 276 | + "source": [ |
| 277 | + "np.logspace(-4, 4, 4)" |
| 278 | + ] |
| 279 | + }, |
| 280 | + { |
| 281 | + "cell_type": "code", |
| 282 | + "execution_count": null, |
| 283 | + "metadata": {}, |
| 284 | + "outputs": [], |
| 285 | + "source": [] |
218 | 286 | } |
219 | 287 | ], |
220 | 288 | "metadata": { |
|
0 commit comments