{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "STAT 453: Deep Learning (Spring 2021)  \n",
    "Instructor: Sebastian Raschka (sraschka@wisc.edu)  \n",
    "\n",
    "Course website: http://pages.stat.wisc.edu/~sraschka/teaching/stat453-ss2021/  \n",
    "GitHub repository: https://github.com/rasbt/stat453-deep-learning-ss21"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Author: Sebastian Raschka\n",
      "\n",
      "Python implementation: CPython\n",
      "Python version       : 3.9.2\n",
      "IPython version      : 7.20.0\n",
      "\n",
      "torch: 1.9.0a0+d819a21\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Understanding One-Hot Encoding and Cross Entropy in PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## One-Hot Encoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "one-hot encoding:\n",
      " tensor([[1., 0., 0.],\n",
      "        [0., 1., 0.],\n",
      "        [0., 0., 1.],\n",
      "        [0., 0., 1.]])\n"
     ]
    }
   ],
   "source": [
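    "def to_onehot(y, num_classes):\n",
    "    # Start from an all-zero float matrix of shape (n_examples, num_classes)\n",
    "    y_onehot = torch.zeros(y.size(0), num_classes)\n",
    "    # In place: write a 1 into the column given by each row's class label\n",
    "    y_onehot.scatter_(1, y.view(-1, 1).long(), 1)\n",
    "    return y_onehot\n",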
    "\n",
    "y = torch.tensor([0, 1, 2, 2])\n",
    "\n",
    "y_enc = to_onehot(y, 3)\n",
    "\n",
    "print('one-hot encoding:\\n', y_enc)"
   ]
  },
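  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, PyTorch also provides a built-in `torch.nn.functional.one_hot`. The cell below is a minimal sketch, assuming the same class labels as above (redefined so that the cell runs on its own). The built-in returns an integer (`int64`) tensor, so the `.float()` cast and the name `y_enc_builtin` are additions made here to mirror the output of `to_onehot`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "y = torch.tensor([0, 1, 2, 2])\n",
    "\n",
    "# F.one_hot expects integer (int64) class labels and returns an int64 tensor;\n",
    "# the cast to float is only there to match the to_onehot output above\n",
    "y_enc_builtin = F.one_hot(y, num_classes=3).float()\n",
    "\n",
    "print('one-hot encoding (built-in):\\n', y_enc_builtin)"
   ]
  },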
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Softmax"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Suppose we have some net inputs (logits) `Z`, where each row is one training example and each column corresponds to one class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.3000, -0.5000, -0.5000],\n",
       "        [-0.4000, -0.1000, -0.5000],\n",
       "        [-0.3000, -0.9400, -0.5000],\n",
       "        [-0.9900, -0.8800, -0.5000]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Z = torch.tensor( [[-0.3,  -0.5, -0.5],\n",
    "                   [-0.4,  -0.1, -0.5],\n",
    "                   [-0.3,  -0.94, -0.5],\n",
    "                   [-0.99, -0.88, -0.5]])\n",
    "\n",
    "Z"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we convert them to \"probabilities\" via softmax, which normalizes each row over the $k$ classes:\n",
    "\n",
    "$$P(y=j \\mid z^{(i)}) = \\sigma_{\\text{softmax}}(z^{(i)})_j = \\frac{e^{z_j^{(i)}}}{\\sum_{l=1}^{k} e^{z_l^{(i)}}}.$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "softmax:\n",
      " tensor([[0.3792, 0.3104, 0.3104],\n",
      "        [0.3072, 0.4147, 0.2780],\n",
      "        [0.4263, 0.2248, 0.3490],\n",
      "        [0.2668, 0.2978, 0.4354]])\n"
     ]
    }
   ],
   "source": [
    "def softmax(z):\n",
    "    # Exponentiate, then normalize each row so that it sums to 1\n",
    "    exp_z = torch.exp(z)\n",
    "    return exp_z / torch.sum(exp_z, dim=1, keepdim=True)\n",
    "\n",
    "smax = softmax(Z)\n",
    "print('softmax:\\n', smax)"
   ]
  },
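  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example (values rounded, so the results are approximate), consider the first row of `Z`, $z^{(1)} = (-0.3, -0.5, -0.5)$, where the subscript on $\\sigma_{\\text{softmax}}$ denotes the class (column). With $e^{-0.3} \\approx 0.7408$ and $e^{-0.5} \\approx 0.6065$:\n",
    "\n",
    "$$\\sigma_{\\text{softmax}}(z^{(1)})_1 \\approx \\frac{0.7408}{0.7408 + 0.6065 + 0.6065} \\approx 0.379, \\qquad \\sigma_{\\text{softmax}}(z^{(1)})_2 = \\sigma_{\\text{softmax}}(z^{(1)})_3 \\approx \\frac{0.6065}{1.9538} \\approx 0.310.$$\n",
    "\n",
    "This agrees with the first row of the softmax output above, and the three values sum to 1 up to rounding."
   ]
  },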
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The probabilities can then be converted back to class labels based on the largest probability in each row:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predicted class labels:  tensor([0, 1, 0, 2])\n",
      "true class labels:  tensor([0, 1, 2, 2])\n"
     ]
    }
   ],
   "source": [
    "def to_classlabel(z):\n",
    "    return torch.argmax(z, dim=1)\n",
    "\n",
    "print('predicted class labels: ', to_classlabel(smax))\n",
    "print('true class labels: ', to_classlabel(y_enc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cross Entropy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
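    "Next, we compute the cross entropy $H(T_i, O_i)$ for each training example $i$, where $T_i$ is the one-hot encoded target and $O_i$ is the softmax output; the loss $\\mathcal{L}$ is the average over all $n$ examples:"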
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$\\mathcal{L}(\\mathbf{W}; \\mathbf{b}) = \\frac{1}{n} \\sum_{i=1}^{n} H(T_i, O_i),$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$H(T_i, O_i) = -\\sum_m T_{i,m} \\cdot \\log(O_{i,m}),$$\n",
    "\n",
    "where $m$ runs over the classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cross Entropy: tensor([0.9698, 0.8801, 1.0527, 0.8314])\n"
     ]
    }
   ],
   "source": [
    "def cross_entropy(softmax, y_target):\n",
    "    return - torch.sum(torch.log(softmax) * (y_target), dim=1)\n",
    "\n",
    "xent = cross_entropy(smax, y_enc)\n",
    "print('Cross Entropy:', xent)"
   ]
  },
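  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because each target row is one-hot, only the term of the true class survives in the sum, so the cross entropy of an example reduces to the negative (natural) log of the probability assigned to its true class. As a worked example using the rounded softmax outputs from above: the first example has true class 0 with predicted probability 0.3792, and the third example has true class 2 with predicted probability 0.3490:\n",
    "\n",
    "$$-\\log(0.3792) \\approx 0.970, \\qquad -\\log(0.3490) \\approx 1.053.$$\n",
    "\n",
    "These agree with the first and third entries of the output above up to rounding. The third example, which is misclassified (its largest probability falls on class 0), also has the largest loss."
   ]
  },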
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## In PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `nll_loss` takes log-probabilities, i.e., log(softmax), and integer class labels (not one-hot vectors) as input:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.9698, 0.8801, 1.0527, 0.8314])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "F.nll_loss(torch.log(smax), y, reduction='none')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `cross_entropy` takes the raw logits as input and applies log-softmax internally:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.9698, 0.8801, 1.0527, 0.8314])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "F.cross_entropy(Z, y, reduction='none')"
   ]
  },
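  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the relationship between the two functions concrete, the cell below is a minimal sketch (assuming the same `Z` and `y` as above, redefined so that the cell runs on its own) that applies `F.log_softmax` followed by `F.nll_loss` and compares the result with `F.cross_entropy`. The variable names `loss_two_step` and `loss_one_step` are illustrative. Computing the log-softmax in one call is generally preferable to `torch.log(softmax(...))` for numerical stability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "Z = torch.tensor([[-0.3,  -0.5,  -0.5],\n",
    "                  [-0.4,  -0.1,  -0.5],\n",
    "                  [-0.3,  -0.94, -0.5],\n",
    "                  [-0.99, -0.88, -0.5]])\n",
    "y = torch.tensor([0, 1, 2, 2])\n",
    "\n",
    "# Two steps: log-softmax over the class dimension, then negative log-likelihood\n",
    "loss_two_step = F.nll_loss(F.log_softmax(Z, dim=1), y, reduction='none')\n",
    "# One step: cross_entropy works directly on the logits\n",
    "loss_one_step = F.cross_entropy(Z, y, reduction='none')\n",
    "\n",
    "print(loss_two_step)\n",
    "print('equal up to floating-point tolerance:', torch.allclose(loss_two_step, loss_one_step))"
   ]
  },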
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defaults"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default (`reduction='mean'`), `nll_loss` and `cross_entropy` return the average over the training examples. This keeps the scale of the loss independent of the batch size, which is useful for stability during optimization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.9335)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "F.cross_entropy(Z, y)"
   ]
  },
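  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check by hand, this value is the average of the four per-example losses obtained earlier with `reduction='none'`:\n",
    "\n",
    "$$\\frac{0.9698 + 0.8801 + 1.0527 + 0.8314}{4} = \\frac{3.7340}{4} \\approx 0.9335.$$"
   ]
  },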
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.9335)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.mean(cross_entropy(smax, y_enc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}