Skip to content

Commit

Permalink
typos fixed, plt.show() added
Browse files Browse the repository at this point in the history
  • Loading branch information
zotroneneis committed Mar 26, 2018
1 parent 8a9edc9 commit df7a8a4
Showing 1 changed file with 39 additions and 30 deletions.
69 changes: 39 additions & 30 deletions k_nearest_neighbour.ipynb
Expand Up @@ -6,7 +6,7 @@
"source": [
"## k-nearest-neighbor algorithm in plain Python\n",
"\n",
"The k-nn algorithm is a simple **supervised** machine learning algorithm that can be used both for classification and regregression. It's an **instance-based** algorithm. So instead of estimating a model, it stores all training examples in memory and makes predictions using a similarity measure. \n",
"The k-nn algorithm is a simple **supervised** machine learning algorithm that can be used both for classification and regression. It's an **instance-based** algorithm. So instead of estimating a model, it stores all training examples in memory and makes predictions using a similarity measure. \n",
"\n",
"Given an input example, the k-nn algorithm retrieves the k most similar instances from memory. Similarity is defined in terms of distance, that is, the training examples with the smallest (euclidean) distance to the input example are considered to be most similar.\n",
"\n",
Expand All @@ -25,11 +25,11 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2018-03-10T11:38:40.624018Z",
"start_time": "2018-03-10T11:38:39.809127Z"
"end_time": "2018-03-26T14:32:41.915819Z",
"start_time": "2018-03-26T14:32:41.094749Z"
}
},
"outputs": [],
Expand All @@ -52,11 +52,11 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2018-03-10T11:38:42.987256Z",
"start_time": "2018-03-10T11:38:41.882424Z"
"end_time": "2018-03-26T14:33:11.784085Z",
"start_time": "2018-03-26T14:33:10.849626Z"
}
},
"outputs": [
Expand Down Expand Up @@ -98,7 +98,9 @@
"fig = plt.figure(figsize=(10,8))\n",
"for i in range(10):\n",
" ax = fig.add_subplot(2, 5, i+1)\n",
" plt.imshow(X[i].reshape((8,8)), cmap='gray')"
" plt.imshow(X[i].reshape((8,8)), cmap='gray')\n",
" \n",
"plt.show()"
]
},
{
Expand All @@ -110,11 +112,11 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2018-03-10T11:39:08.798800Z",
"start_time": "2018-03-10T11:39:08.704275Z"
"end_time": "2018-03-26T14:33:23.932644Z",
"start_time": "2018-03-26T14:33:23.803735Z"
}
},
"outputs": [],
Expand Down Expand Up @@ -181,11 +183,11 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2018-03-10T11:42:35.794443Z",
"start_time": "2018-03-10T11:42:35.750277Z"
"end_time": "2018-03-26T14:33:34.324040Z",
"start_time": "2018-03-26T14:33:34.282266Z"
}
},
"outputs": [
Expand All @@ -194,20 +196,20 @@
"output_type": "stream",
"text": [
"Testing one datapoint, k=1\n",
"Predicted label: 3\n",
"True label: 3\n",
"Predicted label: 8\n",
"True label: 8\n",
"\n",
"Testing one datapoint, k=5\n",
"Predicted label: 9\n",
"True label: 9\n",
"Predicted label: 3\n",
"True label: 3\n",
"\n",
"Testing 10 datapoint, k=1\n",
"Predicted labels: [[3 1 0 7 4 0 0 5 1 6]]\n",
"True labels: [3 1 0 7 4 0 0 5 1 6]\n",
"Predicted labels: [[5 4 5 5 6 6 1 0 8 8]]\n",
"True labels: [5 4 5 5 6 6 1 0 8 8]\n",
"\n",
"Testing 10 datapoint, k=4\n",
"Predicted labels: [3, 1, 0, 7, 4, 0, 0, 5, 1, 6]\n",
"True labels: [3 1 0 7 4 0 0 5 1 6]\n",
"Predicted labels: [5, 4, 5, 5, 6, 6, 1, 0, 8, 8]\n",
"True labels: [5 4 5 5 6 6 1 0 8 8]\n",
"\n"
]
}
Expand Down Expand Up @@ -243,20 +245,20 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2018-03-10T11:44:40.376235Z",
"start_time": "2018-03-10T11:44:40.093108Z"
"end_time": "2018-03-26T14:33:36.781872Z",
"start_time": "2018-03-26T14:33:36.495726Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy with k = 1: 97.77777777777777\n",
"Test accuracy with k = 8: 97.55555555555556\n"
"Test accuracy with k = 1: 99.11111111111111\n",
"Test accuracy with k = 5: 98.66666666666667\n"
]
}
],
Expand All @@ -266,10 +268,17 @@
"test_acc1= np.sum(y_p_test1[0] == y_test)/len(y_p_test1[0]) * 100\n",
"print(f\"Test accuracy with k = 1: {format(test_acc1)}\")\n",
"\n",
"y_p_test8 = knn.predict(X_test, k=5)\n",
"test_acc8= np.sum(y_p_test8 == y_test)/len(y_p_test8) * 100\n",
"print(f\"Test accuracy with k = 8: {format(test_acc8)}\")"
"y_p_test5 = knn.predict(X_test, k=5)\n",
"test_acc5= np.sum(y_p_test5 == y_test)/len(y_p_test5) * 100\n",
"print(f\"Test accuracy with k = 5: {format(test_acc5)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down

0 comments on commit df7a8a4

Please sign in to comment.