Skip to content

Commit 2f2bed8

Browse files
committed
updated with codeflare pipeline invocation
1 parent 9e5fe26 commit 2f2bed8

File tree

1 file changed

+98
-26
lines changed

1 file changed

+98
-26
lines changed

notebooks/plot_semi_supervised_newsgroups.ipynb

Lines changed: 98 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -169,9 +169,51 @@
169169
},
170170
{
171171
"cell_type": "code",
172-
"execution_count": null,
172+
"execution_count": 5,
173173
"metadata": {},
174-
"outputs": [],
174+
"outputs": [
175+
{
176+
"name": "stdout",
177+
"output_type": "stream",
178+
"text": [
179+
"11314 documents\n",
180+
"20 categories\n",
181+
"\n"
182+
]
183+
},
184+
{
185+
"name": "stderr",
186+
"output_type": "stream",
187+
"text": [
188+
"2021-06-01 10:55:34,397\tINFO services.py:1267 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8266\u001b[39m\u001b[22m\n"
189+
]
190+
},
191+
{
192+
"name": "stdout",
193+
"output_type": "stream",
194+
"text": [
195+
"Supervised SGDClassifier on 100% of the data:\n",
196+
"Number of training samples: 8485\n",
197+
"Unlabeled samples in training set: 0\n"
198+
]
199+
},
200+
{
201+
"ename": "RayTaskError(ValueError)",
202+
"evalue": "\u001b[36mray::execute_or_node_remote()\u001b[39m (pid=5002, ip=192.168.1.230)\n File \"python/ray/_raylet.pyx\", line 505, in ray._raylet.execute_task\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/codeflare_pipelines-1.0.0-py3.8.egg/codeflare/pipelines/Runtime.py\", line 43, in execute_or_node_remote\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/sklearn/base.py\", line 702, in fit_transform\n return self.fit(X, y, **fit_params).transform(X)\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/sklearn/feature_extraction/text.py\", line 1477, in transform\n X = check_array(X, accept_sparse='csr', dtype=FLOAT_DTYPES, copy=copy)\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/sklearn/utils/validation.py\", line 63, in inner_f\n return f(*args, **kwargs)\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/sklearn/utils/validation.py\", line 593, in check_array\n array = _ensure_sparse_format(array, accept_sparse=accept_sparse,\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/sklearn/utils/validation.py\", line 381, in _ensure_sparse_format\n spmatrix = spmatrix.astype(dtype)\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/scipy/sparse/data.py\", line 72, in astype\n self._deduped_data().astype(dtype, casting=casting, copy=copy),\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/scipy/sparse/data.py\", line 32, in _deduped_data\n self.sum_duplicates()\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/scipy/sparse/compressed.py\", line 1098, in sum_duplicates\n self.sort_indices()\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/scipy/sparse/compressed.py\", line 1144, in sort_indices\n _sparsetools.csr_sort_indices(len(self.indptr) - 1, self.indptr,\nValueError: WRITEBACKIFCOPY base is read-only",
203+
"output_type": "error",
204+
"traceback": [
205+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
206+
"\u001b[0;31mRayTaskError(ValueError)\u001b[0m Traceback (most recent call last)",
207+
"\u001b[0;32m<ipython-input-5-3afb1332e047>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Supervised SGDClassifier on 100% of the data:\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 96\u001b[0;31m \u001b[0meval_and_print_metrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode_clf1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 97\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;31m# select a mask of 20% of the train dataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
208+
"\u001b[0;32m<ipython-input-5-3afb1332e047>\u001b[0m in \u001b[0;36meval_and_print_metrics\u001b[0;34m(clf, X_train, y_train, X_test, y_test)\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;31m# execute FIT\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 68\u001b[0;31m \u001b[0mpipeline_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecute_pipeline\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpipeline\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mExecutionType\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFIT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpipeline_input\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 69\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[0;31m# select a pipeline referenced by clf\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
209+
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/codeflare_pipelines-1.0.0-py3.8.egg/codeflare/pipelines/Runtime.py\u001b[0m in \u001b[0;36mexecute_pipeline\u001b[0;34m(pipeline, mode, pipeline_input)\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0mpost_edges\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpipeline\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_post_edges\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_node_input_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mdm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNodeInputType\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOR\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 151\u001b[0;31m \u001b[0mexecute_or_node\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpre_edges\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0medge_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpost_edges\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 152\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_node_input_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mdm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNodeInputType\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAND\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[0mexecute_and_node\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpre_edges\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0medge_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpost_edges\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
210+
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/codeflare_pipelines-1.0.0-py3.8.egg/codeflare/pipelines/Runtime.py\u001b[0m in \u001b[0;36mexecute_or_node\u001b[0;34m(node, pre_edges, edge_args, post_edges, mode)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0mexec_xyrefs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mxy_ref_ptr\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mXyref_ptrs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 75\u001b[0;31m \u001b[0mxy_ref\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mray\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxy_ref_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 76\u001b[0m \u001b[0minner_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mexecute_or_node_remote\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mremote\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxy_ref\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0mexec_xyrefs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minner_result\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
211+
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/ray/_private/client_mode_hook.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mclient_mode_should_convert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 48\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
212+
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/ray/worker.py\u001b[0m in \u001b[0;36mget\u001b[0;34m(object_refs, timeout)\u001b[0m\n\u001b[1;32m 1479\u001b[0m \u001b[0mworker\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcore_worker\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdump_object_store_memory_usage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1480\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mRayTaskError\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1481\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_instanceof_cause\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1482\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1483\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
213+
"\u001b[0;31mRayTaskError(ValueError)\u001b[0m: \u001b[36mray::execute_or_node_remote()\u001b[39m (pid=5002, ip=192.168.1.230)\n File \"python/ray/_raylet.pyx\", line 505, in ray._raylet.execute_task\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/codeflare_pipelines-1.0.0-py3.8.egg/codeflare/pipelines/Runtime.py\", line 43, in execute_or_node_remote\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/sklearn/base.py\", line 702, in fit_transform\n return self.fit(X, y, **fit_params).transform(X)\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/sklearn/feature_extraction/text.py\", line 1477, in transform\n X = check_array(X, accept_sparse='csr', dtype=FLOAT_DTYPES, copy=copy)\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/sklearn/utils/validation.py\", line 63, in inner_f\n return f(*args, **kwargs)\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/sklearn/utils/validation.py\", line 593, in check_array\n array = _ensure_sparse_format(array, accept_sparse=accept_sparse,\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/sklearn/utils/validation.py\", line 381, in _ensure_sparse_format\n spmatrix = spmatrix.astype(dtype)\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/scipy/sparse/data.py\", line 72, in astype\n self._deduped_data().astype(dtype, casting=casting, copy=copy),\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/scipy/sparse/data.py\", line 32, in _deduped_data\n self.sum_duplicates()\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/scipy/sparse/compressed.py\", line 1098, in sum_duplicates\n self.sort_indices()\n File \"/Users/yuanchi/anaconda3/lib/python3.8/site-packages/scipy/sparse/compressed.py\", line 1144, in sort_indices\n _sparsetools.csr_sort_indices(len(self.indptr) - 1, self.indptr,\nValueError: WRITEBACKIFCOPY base is read-only"
214+
]
215+
}
216+
],
175217
"source": [
176218
"import ray\n",
177219
"import codeflare.pipelines.Datamodel as dm\n",
@@ -205,33 +247,54 @@
205247
"vectorizer_params = dict(ngram_range=(1, 2), min_df=5, max_df=0.8)\n",
206248
"\n",
207249
"# Supervised Pipeline\n",
208-
"pipeline = Pipeline([\n",
209-
" ('vect', CountVectorizer(**vectorizer_params)),\n",
210-
" ('tfidf', TfidfTransformer()),\n",
211-
" ('clf', SGDClassifier(**sdg_params)),\n",
212-
"])\n",
250+
"pipeline = dm.Pipeline()\n",
251+
"\n",
252+
"vect = CountVectorizer(**vectorizer_params)\n",
253+
"tfidf = TfidfTransformer()\n",
254+
"clf1 = SGDClassifier(**sdg_params)\n",
255+
"clf2 = SelfTrainingClassifier(SGDClassifier(**sdg_params), verbose=True)\n",
256+
"todense = FunctionTransformer(lambda x: x.todense())\n",
257+
"clf3 = LabelSpreading() \n",
258+
"\n",
259+
"node_vect = dm.EstimatorNode('vect', CountVectorizer(**vectorizer_params))\n",
260+
"node_tfidf = dm.EstimatorNode('tfidf', TfidfTransformer())\n",
261+
"node_clf1 = dm.EstimatorNode('clf1', SGDClassifier(**sdg_params))\n",
262+
"node_clf2 = dm.EstimatorNode('clf2', SelfTrainingClassifier(SGDClassifier(**sdg_params), verbose=True))\n",
263+
"node_todense = dm.EstimatorNode('todense', FunctionTransformer(lambda x: x.todense()))\n",
264+
"node_clf3 = dm.EstimatorNode('clf3', LabelSpreading())\n",
265+
"\n",
266+
"pipeline.add_edge(node_vect, node_tfidf)\n",
267+
"# Supervised Pipeline\n",
268+
"pipeline.add_edge(node_tfidf, node_clf1)\n",
213269
"# SelfTraining Pipeline\n",
214-
"st_pipeline = Pipeline([\n",
215-
" ('vect', CountVectorizer(**vectorizer_params)),\n",
216-
" ('tfidf', TfidfTransformer()),\n",
217-
" ('clf', SelfTrainingClassifier(SGDClassifier(**sdg_params), verbose=True)),\n",
218-
"])\n",
270+
"pipeline.add_edge(node_tfidf, node_clf2)\n",
219271
"# LabelSpreading Pipeline\n",
220-
"ls_pipeline = Pipeline([\n",
221-
" ('vect', CountVectorizer(**vectorizer_params)),\n",
222-
" ('tfidf', TfidfTransformer()),\n",
223-
" # LabelSpreading does not support dense matrices\n",
224-
" ('todense', FunctionTransformer(lambda x: x.todense())),\n",
225-
" ('clf', LabelSpreading()),\n",
226-
"])\n",
272+
"pipeline.add_edge(node_tfidf, node_todense)\n",
273+
"pipeline.add_edge(node_todense, node_clf3)\n",
227274
"\n",
228275
"\n",
229276
"def eval_and_print_metrics(clf, X_train, y_train, X_test, y_test):\n",
230277
" print(\"Number of training samples:\", len(X_train))\n",
231278
" print(\"Unlabeled samples in training set:\",\n",
232279
" sum(1 for x in y_train if x == -1))\n",
233-
" clf.fit(X_train, y_train)\n",
234-
" y_pred = clf.predict(X_test)\n",
280+
" \n",
281+
" pipeline_input = dm.PipelineInput()\n",
282+
" pipeline_input.add_xy_arg(node_vect, dm.Xy(X_train, y_train))\n",
283+
" \n",
284+
" # execute FIT\n",
285+
" pipeline_output = rt.execute_pipeline(pipeline, ExecutionType.FIT, pipeline_input)\n",
286+
" \n",
287+
" # select a pipeline referenced by clf\n",
288+
" selected_pipeline = rt.select_pipeline(pipeline_output, clf[0])\n",
289+
" \n",
290+
" # execute PREDICT\n",
291+
" pipeline_input = dm.PipelineInput()\n",
292+
" pipeline_input.add_xy_arg(node_vect, dm.Xy(X_test, y_test))\n",
293+
" predict_output = rt.execute_pipeline(selected_pipeline, ExecutionType.PREDICT, pipeline_input)\n",
294+
" \n",
295+
" predict_clf_output = predict_output.get_xyrefs(clf)\n",
296+
" y_pred = ray.get(predict_clf_output[0].get_yref())\n",
297+
" \n",
235298
" print(\"Micro-averaged F1 score on test set: \"\n",
236299
" \"%0.3f\" % f1_score(y_test, y_pred, average='micro'))\n",
237300
" print(\"-\" * 10)\n",
@@ -245,9 +308,9 @@
245308
" \n",
246309
" X, y = data.data, data.target\n",
247310
" X_train, X_test, y_train, y_test = train_test_split(X, y)\n",
248-
"\n",
311+
" \n",
249312
" print(\"Supervised SGDClassifier on 100% of the data:\")\n",
250-
" eval_and_print_metrics(pipeline, X_train, y_train, X_test, y_test)\n",
313+
" eval_and_print_metrics(node_clf1, X_train, y_train, X_test, y_test)\n",
251314
"\n",
252315
" # select a mask of 20% of the train dataset\n",
253316
" y_mask = np.random.rand(len(y_train)) < 0.2\n",
@@ -256,21 +319,30 @@
256319
" X_20, y_20 = map(list, zip(*((x, y)\n",
257320
" for x, y, m in zip(X_train, y_train, y_mask) if m)))\n",
258321
" print(\"Supervised SGDClassifier on 20% of the training data:\")\n",
259-
" eval_and_print_metrics(pipeline, X_20, y_20, X_test, y_test)\n",
322+
" eval_and_print_metrics(node_clf1, X_20, y_20, X_test, y_test)\n",
260323
"\n",
261324
" # set the non-masked subset to be unlabeled\n",
262325
" y_train[~y_mask] = -1\n",
263326
" print(\"SelfTrainingClassifier on 20% of the training data (rest \"\n",
264327
" \"is unlabeled):\")\n",
265-
" eval_and_print_metrics(st_pipeline, X_train, y_train, X_test, y_test)\n",
328+
" eval_and_print_metrics(node_clf2, X_train, y_train, X_test, y_test)\n",
266329
"\n",
267330
" if 'CI' not in os.environ:\n",
268331
" # LabelSpreading takes too long to run in the online documentation\n",
269332
" print(\"LabelSpreading on 20% of the data (rest is unlabeled):\")\n",
270-
" eval_and_print_metrics(ls_pipeline, X_train, y_train, X_test, y_test)\n",
333+
" eval_and_print_metrics(node_clf3, X_train, y_train, X_test, y_test)\n",
271334
" \n",
272335
" ray.shutdown()"
273336
]
337+
},
338+
{
339+
"cell_type": "code",
340+
"execution_count": null,
341+
"metadata": {},
342+
"outputs": [],
343+
"source": [
344+
"\n"
345+
]
274346
}
275347
],
276348
"metadata": {

0 commit comments

Comments
 (0)