Skip to content

Commit

Permalink
Merge pull request #17 from sankhaMukherjee/dev
Browse files Browse the repository at this point in the history
created the tutorial for qNetwork
  • Loading branch information
sankhaMukherjee committed May 30, 2019
2 parents d387efc + 85fdd74 commit a6fd561
Show file tree
Hide file tree
Showing 5 changed files with 352 additions and 4 deletions.
174 changes: 174 additions & 0 deletions notebooks/.ipynb_checkpoints/NN-qNetwork-checkpoint.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# `qNetworkDiscrete`\n",
"\n",
"This network is a simple sequential network that takes a 1D vector and is able to learn a multi-valued function. This is useful and can act as a discrete Q-Network because, one can think of it as something that takes a 1D state, and returns a Q-value, one for each discrete action. So, lets see this in action:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/sankha/Documents/programs/ML/RLalgos/src\n"
]
}
],
"source": [
"cd ../src"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"from lib.agents import qNetwork as qN\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"from tqdm import tqdm_notebook as tqdm\n",
"\n",
"if torch.cuda.is_available():\n",
" device = 'cuda:0'\n",
"else:\n",
" device = 'cpu'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let us create some dummy data, and see whether our network is able to detect it ..."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1000, 2) (1000, 2)\n"
]
}
],
"source": [
"X = np.random.rand(1000, 2) - 0.5\n",
"Y = np.array([\n",
" X[:,0]*2 + X[:,1]*3,\n",
" X[:,0]*5 + X[:,1]*6\n",
"]).T\n",
"\n",
"print(X.shape, Y.shape)\n",
"Xt = torch.as_tensor(X.astype(np.float32)).to(device)\n",
"Yt = torch.as_tensor(X.astype(np.float32)).to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us create a Q-network and see wheter we are able to represent this function."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"network = qN.qNetworkDiscrete(2, 2, layers=[10, 5], activations=[F.tanh, F.tanh], batchNormalization = False, lr=0.01).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "55788024a40a498799d6bee7b7d0a09f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAEKCAYAAAAFJbKyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xt0nHd95/H3V9JIM7pfratl+YYdOzGJY3KFEi6BQHPhQLbgbhcoKTlpS4Htbrew7Tndnp427SlbCksWMGBSOCUpUKC5QWAhF8jVCSSOE8d32ZIsW5as+9WSfvvH80gey7LlsWb0zDzzeZ0zRzPPPJr5Pn6cfPx7fr/n9zPnHCIiIhcqJ+gCREQksyg4REQkIQoOERFJiIJDREQSouAQEZGEKDhERCQhCg4REUmIgkNERBKi4BARkYTkBV1AKlRXV7uWlpagyxARySgvvvhit3OuZqH9QhkcLS0tvPDCC0GXISKSUczs8IXsp0tVIiKSEAWHiIgkJFTBYWa3mNm2/v7+oEsREQmtUAWHc+5B59ydZWVlQZciIhJaoQoOERFJPQWHiIgkRMEhIiIJUXDE+cmuY3ztyYNBlyEiktYUHHEe39PFtl8qOEREzkfBEaeyKJ/e4Qmmp13QpYiIpC0FR5yq4gImpx0DY6eCLkVEJG2FKjgWewNgdXE+AN1DE8ksS0QkVEIVHIu9AbCqqACAk8MKDhGRcwlVcCxWZZHX4ugZGg+4EhGR9KXgiDN7qUotDhGRc1JwxKnwWxwn1cchInJOCo44kdwcymIReoZ1qUpE5FwUHHNUFefToxaHiMg5KTjmqC4qUItDROQ8FBxzqMUhInJ+Co45Kovy6dGoKhGRc1JwzFFVXEDvyARTmq9KRGReCo45qovzcQ56R9TqEBGZT9oHh5mtMrNvmNn3l+L7Tt89ruAQEZlPSoPDzLabWZeZ7Zqz/SYz22Nm+83sM+f7DOfcQefcHamsM97MfFWadkREZH55Kf78e4EvAd+a2WBmucA9wI1AO7DDzB4AcoG75/z+x5xzXSmu8Qwz046og1xEZH4pDQ7n3JNm1jJn81XAfufcQQAzux+4zTl3N3DzxX6Xmd0J3AnQ3Nx8sR9DdbHX4jgxqBaHiMh8gujjaATa4l63+9vmZWZVZvYV4Aoz++y59nPObXPObXHObampqbno4soLI+Tn5XB8YOyiP0NEJMxSfalq0ZxzPcBdS/V9ZkZdaZTOfgWHiMh8gmhxdADL4143+dsWbbErAM6oK4tyTC0OEZF5BREcO4C1ZrbSzPKBDwEPJOODF7sC4Iy60qguVYmInEOqh+PeBzwDrDOzdjO7wzk3CXwCeBTYDXzXOfdqkr4vaS2Ozv4xnNPd4yIic6V6VNXWc2x/BHgkBd/3IPDgli1bPr6Yz2muLGRicpqj/WM0lseSVJ2ISDik/Z3jQXhDbQkAe48PBlyJiEj6CVVwJOtS1RtqiwHYp+AQETlLqIIjWZ3j5YX51JYW8L0X2nnuYE+SqhMRCYdQBUcyvX9zE/u6hvjgtmd5ua0v6HJERNKGguMc/uxd6/j3P7wWM3ho59GgyxERSRuhCo5k9XEA5OQYV66oZMuKCp5v7U1CdSIi4RCq4EhWH0e8N7VU8mpHP6MTU0n7TBGRTBaq4EiFTU3lTE479miElYgIoOBY0MaGUgBeOzoQcCUiIukhVMGRzD6OGU0VMUoK8nitM3mfKSKSyUIVHKno4zAzLmkoZXenLlWJiEDIgiNVNtSXsrtzgOlpTXooIqLguAAb6ksZmZji8MmRoEsREQmcguMCbFAHuYjIrFAFRyo6xwHWLCsmL8fYdVQd5CIioQqOVHSOA0QjuWxeUcEvdncl9XNFRDJRqIIjlW7eVM+e44Ps6lCrQ0Sym4LjAt12eSMl0Tz+8dE9WlJWRLKaguMClcUi/OmNb+CJvSfY9uTBoMsREQmMgiMBH72uhfdcWsc/PrqH3Z0aYSUi2SlUwZGqUVVxn8/fv38TkdwcvvVMa0q+Q0Qk3YUqOFI1qipeWWGE915Wz0MvdzI5NZ2y7xERSVehCo6l8tZ1NQyOT/KaLleJSBZScFyEa1ZWAvDswZ6AKxERWXoKjouwrDRKU0WMl9t1T4eIZB8Fx0Xa2FCquatEJCspOC7SxoYyDnUPMzQ+GXQpIiJLSsFxkWaWlNX9HCKSbRQcF2ljgzfk91XNXSUiWSZUwZHqGwDj1ZYWUFWUz6vq5xCRLBOq4FiKGwBnmBkbGkp5/ZjWIheR7BKq4FhqK6uLaO0e1my5IpJVFByL0FJVxOD4JD3DE0GXIiKyZBQci7CypgiAQ93DAVciIrJ0FByLsLJKwSEi2UfBsQhNFTHycoxWBYeIZBEFxyLk5eawvLKQ1h4Fh4hkDwXHIjWWx+jsHwu6DBGRJaPgWKS6siidfQoOEckeCo5Fqi+L0jU4ptUARSRrKDgWqb4sxrSDE0PjQZciIrIkMiI4zOx9ZvY1M/s3M3tX0PXEqy+LAqifQ0SyRsqDw8y2m1mXme2as/0mM9tjZvvN7DPn+wzn3I+ccx8H7gI+mMp6E7WstACA4woOEckSeUvwHfcCXwK+NbPBzHKBe4AbgXZgh5k9AOQCd8/5/Y8557r853/p/17aqC72gkPTjohItkh5cDjnnjSzljmbrwL2O+cOApjZ/cBtzrm7gZvnfoaZGfD3wI+dc7+e73vM7E7gToDm5uak1b+QyqJ8AHqGFBwikh2C6uNoBNriXrf7287lT4B3Areb2V3z7eCc2+ac2+Kc21JTU5O8ShcQyc2hvDBCz7A6x0UkOyzFpapFc859Efhi0HWcS1VRPt0aVSUiWSKoFkcHsDzudZO/bVGWcgXAeFXFBXTrUpWIZImggmMHsNbMVppZPvAh4IHFfuhSrgAYr7o4nx61OEQkSyzFcNz7gGeAdWbWbmZ3OOcmgU8AjwK7ge86515NwncF0uKoLMrnpEZViUiWWIpRVVvPsf0R4JEkf9eDwINbtmz5eDI/dyFlsQgDY5M45/AGgImIhFdG3Dme7spiEaamHUPjk0GXIiKScqEKjqAuVZXFIgD0j55a0u8VEQlCqIIjqM5xBYeIZJNQBUdQymLe3eMKDhHJBgqOJJhpcQwoOEQkC4QqOALr4yj0gqNvRMEhIuEXquBQH4eISOqFKjiCUpSfS26OKThEJCsoOJLAzCiPRRQcIpIVQhUcQfVxgHe5SsEhItkgVMERVB8HQKmCQ0SyRKiCI0hqcYhItlBwJElpLMLgmOaqEpHwU3AkSUk0j8ExtThEJPxCFRxBdo6XRPMYUItDRLLAeYPDzH4v7vn1c977RKqKuliBdo5HI0xMTjM+ObXk3y0ispQWanH8adzz/zPnvY8luZaMVhL11sRSP4eIhN1CwWHneD7f66ym4BCRbLFQcLhzPJ/vdVYrKfDmq1IHuYiE3UJrjq83s514rYvV/nP816tSWlmGUYtDRLLFQsFxyZJUkSRmdgtwy5o1a5b8u0uiWpNDRLLDeS9VOecOxz+AIWAzUO2/TitBjqpSi0NEssVCw3EfMrNL/ef1wC680VTfNrNPL0F9GaN0psWhPg4RCbmFOsdXOud2+c9/H/iZc+4W4Go0HPcMxWpxiEiWWCg44v/5/A7gEQDn3CAwnaqiMlFujlGUn6vgEJHQW6hzvM3M/gRox+vb+AmAmcWASIpryzgl0YiG44pI6C3U4rgD2Ah8FPigc67P334N8M0U1pWRvIkO1eIQkXA7b4vDOdcF3DXP9seAx1JVVKYqieYxOK4Wh4iE23mDw8weON/7zrlbk1tOZiuJRugdmQi6DBGRlFqoj+NaoA24D3iONJ+fKsgbAMFrcRw5ORLId4uILJWF+jjqgP8JXAp8AbgR6HbOPeGceyLVxSUqyBsAQZ3jIpIdFrpzfMo59xPn3EfwOsT3A4+n41oc6aBUizmJSBZY6FIVZlYA/DawFWgBvgj8MLVlZaaSaN7sYk4FeblBlyMikhILdY5/C+8y1SPAX8fdRS7zmJnocHBskoJiBYeIhNNCfRy/B6wFPgU8bWYD/mPQzAZSX15miZ/o0DnHF3++j3ufOhRwVSIiybXQfRwLBYvEOd3iOMXuzkH+6Wd7AbhuTTVvqC0JsjQRkaRRMCRRaVyL49mDPbPbf/ba8aBKEhFJOgVHEsW3OPafGKK8MMLGhlKe2Hsi4MpERJJHwZFEM30cA2OTHO4ZpqWqiKtXVvFyWx9T01qiXUTCQcGRRKVxo6q6BsapLS1gfV0J45PTtOmOchEJCQVHEp1ezOkU3UPj1JQUsLa2GIA9xweDLE1EJGnSPjjM7BIz+4qZfd/M/jDoes5nZjGn3uEJekdOUV1cwFp/NNU+BYeIhERKg8PMtptZl5ntmrP9JjPbY2b7zewz5/sM59xu59xdwO8A16ey3mQoiUZo7fEuS1UXF1BckEd1cT4dfaMBVyYikhypbnHcC9wUv8HMcoF7gPcAG4CtZrbBzC4zs4fmPJb5v3Mr8DD+0rXprCSax6HuYcALDoD6shgdfWNBliUikjQLzlW1GM65J82sZc7mq4D9zrmDAGZ2P3Cbc+5u4OZzfM4DwANm9jDwnfn2MbM7gTsBmpubk1L/xSiJ5rHviLdQYk1JPgAN5VEOnhgOrCYRkWQKoo+jEW+Njxnt/rZ5mdkNZvZFM/sq52lxOOe2Oee2OOe21NTUJK/aBM3cywGnWxwN5TGO9o3inIbkikjmS2mLIxmcc48DjwdcxgWbuZcD4oKjLMbwxBQDY5OUxSLn+lURkYwQRIujA1ge97rJ37ZoZnaLmW3r7+9PxsddlJkWRyySS1GBFyJ1ZVEAjg+on0NEMl8QwbEDWGtmK80sH/gQcN61zS9U0CsAAlQVef0alf5POB0cx/oVHCKS+VI9HPc+4BlgnZm1m9kdzrlJ4BPAo8Bu4LvOuVeT9H2BtzhmQqIgcvqPtrZELQ4RCY9Uj6raeo7tj5CCobXOuQeBB7ds2fLxZH/2hWoo90JieUXh7LZlpV5fh4JDRMIg7TvHM801q6q4662ree9ldbPbopFcygsjHB8YD7AyEZHkCFVwmNktwC1r1qwJrIbC/Dw+8571Z22vLYmqxSEioZD2c1UlIh06x89lWWmBgkNEQiFUwZHO6kqjulQlIqGg4FgitaVRTgyNMzk1HXQpIiKLEqrgSIfhuOfSXFXI1LSjrVez5IpIZgtVcKRzH8c6f12OPce8dTmcc0xrOVkRyUChCo50NrMS4N7jg3QPjfPmf3iMt37uMY70aElZEcksCo4lUpifx8rqIna0nuSex/bT0TfK0b4x/vg7v1a/h4hklFAFRzr3cQDcsqmeX+7r5ptPtfKhNy3n8x+8nFc6+nlw59GgSxMRuWChCo507uMA2Hp1M5VF+ZQXRvjUO9dy82X1rKst4cuPH1B/h4hkjFAFR7qrL4vx+J/dwFN//nbqy2Lk5Bh33bCKvceH+PnrXUGXJyJyQRQcS6w0GpldpwPglk0NLK+M8aXH9qvVISIZQcERsLzcHD759rW83NbH3zz8GsPjk0GXJCJyXprkMA3cfmUTuzr6+eZTrdz/fBvXra7i9iubuOnSOsws6PJERM5gzoXv8siWLVvcCy+8EHQZCXvxcC8/+k0Hj+3por13lHesX8bfvf8yakujQZcmIlnAzF50zm1ZcD8FR/qZnJrm3qdb+dxP91CQl8tf37qR2y5vUOtDRFLqQoNDfRxpKC83hz94yyoe+eRbWF1TxKf/7SU+vP15XmrrC7o0EREFRzpbVVPM9+66jr/87UvY1dHP++55ij/4lx3s6kjPGxxFJDvoUlWGGBqf5N6nDrHtyYMMjE3y7o21fPqdb+CS+tKgSxORkMjKPo64UVUf37dvX9DlpMTA2Cm2/+oQ3/jlIQbHJ/nty+r59DvXstaffVdE5GJlZXDMCGOLY66+kQm+/stDfPOpQ4ycmuLWNzbwyXesZXVNcdCliUiGUnCEPDhmnByeYNuTB/mXp1sZn5zifVc08sm3r6Wluijo0kQkwyg4siQ4ZnQPjfOVxw/w7WcPMzntuH1zE594+xqWVxYGXZqIZAgFR5YFx4yugTH+7+MH+M5zR5h2jvdd0chdb13NmmW6hCUi56fgyNLgmNHZP8pXnzjI/TuOMD45zU0b6/ijG9ZwWVN6TjkvIsFTcGR5cMzoGRrnm0+18i/PtDI4Nslb1lbzx29bw9UrK3UnuoicQcGh4DjDwNgp/vXZI3zjVwfpHppgc3M5f3TDGt6+fhk5OQoQEVFwKDjOYezUFN97oY2vPHGQjr5RVtUU8fvXr+QDmxspzA/VZMkikqCsDI5suAEwWU5NTfPwzk62P3WIne39lEbz2Hp1Mx+5toWG8ljQ5YlIALIyOGaoxXHhnHP8+kgv23/Vyo93dWJm3HRpHR+7voXNzRXqBxHJIhcaHLo2keXMjCtXVHLlikrae0f49jOHue/5Izy8s5P1dSVsvaqZ913RSFksEnSpIpIm1OKQswyPT/LAy0f5znNHeKWjn2gkh5s3NbD1qmY2N5erFSISUrpUpeBIilfa+/nO80d44KUOhiemWF9Xwu1XNnHr5Q0sK9HKhCJhouBQcCTV0PgkD758lPueP8LO9n5yDN68tob3X9HIuzbWakSWSAgoOBQcKbO/a4gf/qadH/3mKB19oxTl5/LuS+u47fJGrltdRSRX64OJZCIFh4Ij5aanHTtaT/LD33Tw8CudDI5NUhaLcOOGWt5zaR1vXltNQV5u0GWKyAVScCg4ltTYqSl+ua+bH+/q5GevHWdwbJLigjzecckyP0RqKC7Q5SyRdKbhuLKkopFcbtxQy40bapmYnObpA938+JVj/PS1Y/zHS0eJ5BpXr6zibeuX8bZ1NazSglMiGUstDkmpyalpdrT28tieLh57vYt9XUMAtFQVcsO6Zbxt/TKuaqkklq9LWiJB06UqBUdaajs5wuN7uvjF6108faCH8clp8nNzuKK5nOtWV3P9mio2NZWTn6cOdpGlFqrgMLMi4AngfznnHlpofwVHZhg7NcWzB3t45kAPTx3o5tWjAzgHhfm5vKmlkutWV3Ht6iouqS/VSC2RJZAWfRxmth24Gehyzl0at/0m4AtALvB159zfL/BRfw58N2WFSiCikVxuWLeMG9YtA6BvZIJnD57k6QPdPH2gh7t//DoAsUguly8vZ0tLBVeuqGDzigpKo5oCRSQoKW1xmNlvAUPAt2aCw8xygb3AjUA7sAPYihcid8/5iI8BbwSqgCjQrRZH9jg+MMaO1pO80NrLC4dPsrtzkKlphxmsqy3hyhUVbGmpYFNTOSurirSuiMgipUWLwzn3pJm1zNl8FbDfOXcQwMzuB25zzt2N1zo5g5ndABQBG4BRM3vEOTedyrolPdSWRrl5UwM3b2oAvDm0Xmrrmw2S/3jpKP/63BEASgryuLSxjE3Ly3hjUzmXNZbRVBHTvFoiKRDEcNxGoC3udTtw9bl2ds79BYCZfRSvxTFvaJjZncCdAM3NzcmqVdJIUUEe16+p5vo11QBMTTv2dQ2ys72fne197GzvZ/uvDnFqymtFVxbls6mpjE2NZWxsLGNDfanCRCQJMuY+DufcvQu8vw3YBt6lqqWoSYKVm2OsrytlfV0pv7NlOQDjk1PsOTbIy+397GzzwuTJvSeY9v9GlBTksb6+hEvqS2cf62pLNBxYJAFBBEcHsDzudZO/bdHiVgBMxsdJBirIy2VTUzmbmsrhmhUAjExMsufYILs7B9ndOcDuzgF+8OsOhsYPA5Bj0FJdxIa4IFlbW8zyikL1m4jMI+XDcf0+jofiOsfz8DrH34EXGDuA33XOvZqs71TnuCxketrR3jvKa50DvOaHye7OAdp7R2f3iUZyWF1TzBtqS1izzPu5dlkxyysLyVWgSAilxX0cZnYfcANQDRwH/so59w0zey/wz3gjqbY75/42Sd+nNcdlUQbGTrG/a4h9xwfZe3yIff7zzv6x2X0K8mYCpZi1fqisrimmubJQNy5KRkuL4AiKWhySbINjp9jXNcT+40PsPT44GyhH4wIlN8dYXhFjVU0xq6qLWFlTxKrqYlbXFFFTUqBOeUl7aTEcVyQsSqIRNjdXsLm54oztg34L5VD3MAdPDHOwe4iDJ4Z5an8345OnBwAWF+SxqqaIldVemKyqKZp9rUWwJNOEqsWhS1WSLqanHUf7R70wOeEHix8uHX2jZ+xbXxZlRVUhLVVFNPs/V1QVsqKqSFPRy5LSpSpdqpI0NToxRWvP8BmhcvjkCId7hukemjhj3+rifFb4QRIfKC1VhZQX5gd0BBJWulQlkqZi+bmz95DMNTQ+yeGeYQ73jNDaM8zh7hEOnxzmmQM9/ODXZ45aL43m0VJdNBskzZWF3uvKQvWpSEopOETSSHFBHhsbytjYUHbWe2Onpmg7OUJrj9c6afUD5uW2Ph7eeXT2JkfwhhI3VXhhsrwixvLKQu9RUcjyyhglmiRSFiFUwaEbACXMopFc1taWsLa25Kz3Jian6egbnW2ttJ0c4cjJEdp6R3n+0EmGxifP2L+iMEJzZSFNfpg0V3qB0lxZSEN5TNPYy3mpj0Mk5Jxz9I2coq3XD5OTo7T1euHSdnKE9t5RJuOaKzkG9WUxllfG4kLFC5bllYXUFOsyWFipj0NEADAzKoryqSjK96ZimWNq2nFsYGy2ldLut1SOnBzhib0n6BocP2P/mctgTRUxGstjNFUU0lgRo6kiRlN5jOriAk3VEnIKDpEsl5tjNJZ7IXDNqqqz3h87NUV7r9dSOeK3Utp6R+joG+Wltj76Rk6dsX9+Xo4fKLHTPyv8gCmPUVsa1ZQtGS5UwaE+DpHki0ZyWbOshDXLzu5bAW8kWEfvKB193mWvjt5R2ntHae8bZffu42cNMc7LMerLo6dbK3HhsryikLqyqPpY0pz6OEQkpcZOTdHRNxoXKiNnvD4+OEb8/4ZyDOpKo2e0UmaCpbE8RkN5jGhE0+Cngvo4RCQtRCO5rK7xJoKcz8TkNJ39ccHS54eLPyLs2MAYU9Nn/gO3uriAxvIoDeUx6stiNPgtmAb/UV2crw78FFJwiEig8vNy/Lvji+Z9f3JqmmMDY7OXwDr6RunsH6Wjb4x9XUM8sfcEIxNTZ31mQ1l0NkgaymNzXkc1R9gi6E9ORNJaXu7MKK7CedeYds7RP3qKjr5RjvaN+aHiPT/aN8pT+7s5PjDGnEYLFYWR2SBp9MPkdMjEqCkpUCf+OYQqONQ5LpJ9zIzywnzKC/PnveMe4NTUNMcHxmbDZKbVcrTPG4b87MEeBsfOvEkyL8eo81spc4OlsTxGfVk0a+/AV+e4iAjeIl6dccFydObR72071j92xo2SACXRvLi+lehsf0t9mddqqSuLZtTiXuocFxFJQGk0QmldhHV18w87npp2nBgcPzNU+ry+ls7+UX59pPese1rA68j3wmROsPhhs6wk8+5rUXCIiFyAXP/SVV1ZlCtXVMy7z8jEJJ39Y17LpX+UTj9UjvaPceDEML/a183wnI783ByjtqSAev/yV4P/Mz5kqory0+pufAWHiEiSFObnnXfosXOOgbFJOv1QOTrn566Ofn762nEm4laPBMjPzaGuLHpmsPgjxWYCpiwWWbIhyAoOEZElYmaUxSKUxSKsrzt7PRbwwuXk8ASdft9KZ78XLEf7xujsO/e9LbFILvXlUf73f3ojVzTP3yJKllAFh0ZViUimMzOqiguoKi7g0sb5R4nN9LeccTnM/7kUK0NqVJWIiAAXPqoqc8aJiYhIWlBwiIhIQhQcIiKSEAWHiIgkRMEhIiIJUXCIiEhCFBwiIpKQUAWHmd1iZtv6+/uDLkVEJLRCeQOgmZ0ADl/kr1cD3UksJxPomLODjjk7LOaYVzjnahbaKZTBsRhm9sKF3DkZJjrm7KBjzg5LccyhulQlIiKpp+AQEZGEKDjOti3oAgKgY84OOubskPJjVh+HiIgkRC0OERFJiIIjjpndZGZ7zGy/mX0m6HqSwcyWm9ljZvaamb1qZp/yt1ea2c/MbJ//s8Lfbmb2Rf/PYKeZbQ72CC6emeWa2W/M7CH/9Uoze84/tn8zs3x/e4H/er//fkuQdV8sMys3s++b2etmttvMrg37eTaz/+r/vd5lZveZWTRs59nMtptZl5ntituW8Hk1s4/4++8zs48spiYFh8/McoF7gPcAG4CtZrYh2KqSYhL4b865DcA1wB/7x/UZ4OfOubXAz/3X4B3/Wv9xJ/DlpS85aT4F7I57/Q/A551za4Be4A5/+x1Ar7/98/5+megLwE+cc+uBN+Ide2jPs5k1Ap8EtjjnLgVygQ8RvvN8L3DTnG0JnVczqwT+CrgauAr4q5mwuSjOOT28fp5rgUfjXn8W+GzQdaXgOP8DuBHYA9T72+qBPf7zrwJb4/af3S+THkCT/x/U24GHAMO7KSpv7vkGHgWu9Z/n+ftZ0MeQ4PGWAYfm1h3m8ww0Am1ApX/eHgLeHcbzDLQAuy72vAJbga/GbT9jv0QfanGcNvOXcEa7vy00/Kb5FcBzQK1zrtN/6xhQ6z8Py5/DPwP/A5j2X1cBfc65Sf91/HHNHrP/fr+/fyZZCZwAvulfnvu6mRUR4vPsnOsAPgccATrxztuLhPs8z0j0vCb1fCs4soSZFQP/DnzaOTcQ/57z/gkSmuF1ZnYz0OWcezHoWpZQHrAZ+LJz7gpgmNOXL4BQnucK4Da80GwAijj7kk7oBXFeFRyndQDL4143+dsynplF8ELjX51zP/A3Hzezev/9eqDL3x6GP4frgVvNrBW4H+9y1ReAcjPL8/eJP67ZY/bfLwN6lrLgJGgH2p1zz/mvv48XJGE+z+8EDjnnTjjnTgE/wDv3YT7PMxI9r0k93wqO03YAa/0RGfl4nWwPBFzTopmZAd8Adjvn/inurQeAmZEVH8Hr+5jZ/mF/dMY1QH9ckzgjOOc+65xrcs614J3HXzjn/jPwGHC7v9vcY575s7jd3z+j/mXunDsGtJnZOn/TO4DXCPF5xrtEdY2ZFfp/z2eOObTnOU6i5/VR4F1mVuG31N7lb7tFRXYSAAACrUlEQVQ4QXf6pNMDeC+wFzgA/EXQ9STpmN6M14zdCbzkP96Ld23358A+4P8Blf7+hje67ADwCt6IlcCPYxHHfwPwkP98FfA8sB/4HlDgb4/6r/f7768Kuu6LPNbLgRf8c/0joCLs5xn4a+B1YBfwbaAgbOcZuA+vD+cUXsvyjos5r8DH/GPfD/z+YmrSneMiIpIQXaoSEZGEKDhERCQhCg4REUmIgkNERBKi4BARkYQoOESWmJk9bmZZtQ62hIuCQySDxN0RLRIYBYfIOZhZi7+uxdf8NR9+amax+BaDmVX7U5tgZh81sx/56yO0mtknzOxP/UkHn/Wntp7xX8zsJX8diav83y/y11543v+d2+I+9wEz+wXwczOrN7Mn437/LUv8RyNZTsEhcn5rgXuccxuBPuADC+x/KfB+4E3A3wIjzpt08Bngw3H7FTrnLgf+CNjub/sLvGkwrgLeBvyjP8MtePNO3e6ceyvwu3hThV+Ot+7GS4s8RpGEqNkrcn6HnHMz/2N+EW9dhPN5zDk3CAyaWT/woL/9FWBT3H73ATjnnjSzUjMrx5s/6FYz++/+PlGg2X/+M+fcSf/5DmC7P3nlj+LqE1kSanGInN943PMpvH9sTXL6v53oefafjns9zZn/UJs714/Dm2foA865y/1Hs3NuZgXD4dkdnXsS+C282U3vNbMPI7KEFBwiiWsFrvSf336e/c7ngwBm9ma8GUz78WYr/RN/plfM7Ir5ftHMVgDHnXNfA76OdxlLZMnoUpVI4j4HfNfM7gQevsjPGDOz3wARvFlLAf4Gb+XCnWaWg7cU7M3z/O4NwJ+Z2SlgiDP7TkRSTrPjiohIQnSpSkREEqLgEBGRhCg4REQkIQoOERFJiIJDREQSouAQEZGEKDhERCQhCg4REUnI/we26gEkuDFOKAAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"errors = []\n",
"for i in tqdm(range(1000)):\n",
" y = network.forward( Xt )\n",
" network.step(Yt, y)\n",
" e = ((y - Yt)**2).mean()\n",
" errors.append(e.cpu().detach().numpy())\n",
" \n",
"errors = np.array(errors)\n",
"plt.plot(errors)\n",
"plt.yscale('log')\n",
"plt.xlabel('numbers')\n",
"plt.ylabel('MSE')\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit a6fd561

Please sign in to comment.