<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Neural Networks gone wild! They can sample from discrete distributions now!</title>
<link href="http://anotherdatum.com/feeds/all.atom.xml" type="application/atom+xml" rel="alternate" title="Another Datum Full Atom Feed">
<!-- Bootstrap Core CSS -->
<link href="http://anotherdatum.com/theme/css/bootstrap.min.css" rel="stylesheet">
<!-- Custom CSS -->
<link href="http://anotherdatum.com/theme/css/clean-blog.min.css" rel="stylesheet">
<!-- Code highlight color scheme -->
<link href="http://anotherdatum.com/theme/css/code_blocks/tomorrow.css" rel="stylesheet">
<!-- CSS specified by the user -->
<link href="http://anotherdatum.com/css/overrides.css" rel="stylesheet">
<!-- Custom Fonts -->
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet">
<link href="https://fonts.googleapis.com/css?family=Lora:400,700,400italic,700italic" rel="stylesheet">
<link href="https://fonts.googleapis.com/css?family=Open+Sans:300italic,400italic,600italic,700italic,800italic,400,300,600,700,800" rel="stylesheet">
<!-- HTML5 Shim and Respond.js IE8 support of HTML5 elements and media queries -->
<!-- WARNING: Respond.js doesn't work if you view the page via file:// -->
<!--[if lt IE 9]>
<script src="https://oss.maxcdn.com/libs/html5shiv/3.7.0/html5shiv.js"></script>
<script src="https://oss.maxcdn.com/libs/respond.js/1.4.2/respond.min.js"></script>
<![endif]-->
<meta name="description" content="Learn how to use Gumbel distribution to form a NN containing a discrete random component.">
<meta name="author" content="Yoel Zeldes">
<meta name="tags" content="deep-learning">
<meta name="tags" content="GAN">
<!-- Open Graph: og:locale takes the language_TERRITORY form, and timestamps are ISO 8601 (date, "T", time, offset). -->
<meta property="og:locale" content="en_US">
<meta property="og:site_name" content="Another Datum">
<meta property="og:type" content="article">
<meta property="article:author" content="http://anotherdatum.com/author/yoel-zeldes.html">
<meta property="og:url" content="http://anotherdatum.com/gumbel-gan.html">
<meta property="og:title" content="Neural Networks gone wild! They can sample from discrete distributions now!">
<meta property="article:published_time" content="2018-07-16T23:00:00+03:00">
<meta property="og:description" content="Learn how to use Gumbel distribution to form a NN containing a discrete random component.">
<meta property="og:image" content="http://anotherdatum.com/images/gumbel-gan/cover.jpg">
<meta name="twitter:card" content="summary_large_image">
<meta name="twitter:site" content="@YZeldes">
<meta name="twitter:title" content="Neural Networks gone wild! They can sample from discrete distributions now!">
<meta name="twitter:image" content="http://anotherdatum.com/images/gumbel-gan/cover.jpg">
<meta name="twitter:description" content="Learn how to use Gumbel distribution to form a NN containing a discrete random component.">
</head>
<body>
<!-- Navigation -->
<nav class="navbar navbar-default navbar-custom navbar-fixed-top" aria-label="Primary">
<div class="container-fluid">
<!-- Brand and toggle get grouped for better mobile display -->
<div class="navbar-header page-scroll">
<!-- aria-controls names the panel this button toggles; Bootstrap 3.3+ collapse.js keeps aria-expanded in sync (TODO confirm theme's Bootstrap version) -->
<button type="button" class="navbar-toggle" data-toggle="collapse" data-target="#bs-example-navbar-collapse-1" aria-controls="bs-example-navbar-collapse-1" aria-expanded="false">
<span class="sr-only">Toggle navigation</span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
</button>
<a class="navbar-brand" href="http://anotherdatum.com/">Another Datum</a>
</div>
<!-- Collect the nav links, forms, and other content for toggling -->
<div class="collapse navbar-collapse" id="bs-example-navbar-collapse-1">
<ul class="nav navbar-nav navbar-right">
<li><a href="http://anotherdatum.com">Posts</a></li>
<li><a href="http://anotherdatum.com/pages/about.html">about me</a></li>
<li><a href="http://anotherdatum.com/pages/resources.html">Resources</a></li>
</ul>
</div>
<!-- /.navbar-collapse -->
</div>
<!-- /.container-fluid -->
</nav>
<!-- Page Header -->
<header class="intro-header" style="background-image: url('images/gumbel-gan/cover.jpg')">
<div class="container">
<div class="row">
<div class="col-lg-8 col-lg-offset-2 col-md-10 col-md-offset-1">
<div class="post-heading">
<h1>Neural Networks gone wild! They can sample from discrete distributions now!</h1>
<!-- <time datetime> exposes the publication date in machine-readable form; the visible text is unchanged -->
<span class="meta">Posted on <time datetime="2018-07-16">16 July 2018</time></span>
</div>
</div>
</div>
</div>
</header>
<!-- Main Content -->
<div class="container">
<div class="row">
<div class="col-lg-8 col-lg-offset-2 col-md-10 col-md-offset-1">
<!-- Post Content -->
<article>
<style type="text/css">/*!
 * IPython notebook stylesheet.
 */
/* ANSI text attributes and foreground colours. Darker shades are used for the
   foreground so terminal output stays legible on a white page. */
.ansibold { font-weight: bold; }
.ansiblack { color: black; }
.ansired { color: darkred; }
.ansigreen { color: darkgreen; }
.ansiyellow { color: #c4a000; }
.ansiblue { color: darkblue; }
.ansipurple { color: darkviolet; }
.ansicyan { color: steelblue; }
.ansigray { color: gray; }
/* ANSI background colours. */
.ansibgblack { background-color: black; }
.ansibgred { background-color: red; }
.ansibggreen { background-color: green; }
.ansibgyellow { background-color: yellow; }
.ansibgblue { background-color: blue; }
.ansibgpurple { background-color: magenta; }
.ansibgcyan { background-color: cyan; }
.ansibggray { background-color: gray; }
div.cell {
  /* Vertical flex column: legacy box-model syntaxes first, standard flex last so it wins where supported. */
  display: -webkit-box; -webkit-box-orient: vertical; -webkit-box-align: stretch;
  display: -moz-box; -moz-box-orient: vertical; -moz-box-align: stretch;
  display: box; box-orient: vertical; box-align: stretch;
  display: flex; flex-direction: column; align-items: stretch;
  border-radius: 2px;
  box-sizing: border-box; -moz-box-sizing: border-box; -webkit-box-sizing: border-box;
  border-width: 1px; border-style: solid; border-color: transparent;
  width: 100%;
  padding: 5px;
  margin: 0px; /* inter-cell spacing, outside the border */
  outline: none;
  border-left-width: 1px;
  padding-left: 5px;
  background: linear-gradient(to right, transparent -40px, transparent 1px, transparent 1px, transparent 100%);
}
div.cell.jupyter-soft-selected {
  border-left-color: #90CAF9;
  border-left-color: #E3F2FD; /* the effective value; overrides the line above */
  border-left-width: 1px;
  padding-left: 5px;
  border-right-color: #E3F2FD;
  border-right-width: 1px;
  background: #E3F2FD;
}
@media print { div.cell.jupyter-soft-selected { border-color: transparent; } }
/* Selected cell: grey border plus a 5px blue bar drawn by the gradient. */
div.cell.selected {
  border-color: #ababab;
  border-left-width: 0px;
  padding-left: 6px;
  background: linear-gradient(to right, #42A5F5 -40px, #42A5F5 5px, transparent 5px, transparent 100%);
}
@media print { div.cell.selected { border-color: transparent; } }
div.cell.selected.jupyter-soft-selected {
  border-left-width: 0;
  padding-left: 6px;
  background: linear-gradient(to right, #42A5F5 -40px, #42A5F5 7px, #E3F2FD 7px, #E3F2FD 100%);
}
/* In edit mode the selection bar and border turn green. */
.edit_mode div.cell.selected {
  border-color: #66BB6A;
  border-left-width: 0px;
  padding-left: 6px;
  background: linear-gradient(to right, #66BB6A -40px, #66BB6A 5px, transparent 5px, transparent 100%);
}
@media print { .edit_mode div.cell.selected { border-color: transparent; } }
.prompt {
  min-width: 14ex;        /* room for a three-digit prompt such as In[100]: */
  padding: 0.4em;         /* tuned to match the CodeMirror editor padding */
  margin: 0px;
  font-family: monospace;
  text-align: right;
  line-height: 1.21429em; /* must equal the CodeMirror line-height */
  /* prompt numbers are not selectable */
  -webkit-touch-callout: none;
  -webkit-user-select: none; -khtml-user-select: none; -moz-user-select: none; -ms-user-select: none; user-select: none;
  cursor: default;        /* plain arrow, not a text cursor */
}
@media (max-width: 540px) { .prompt { text-align: left; } }
div.inner_cell {
  min-width: 0;
  display: -webkit-box; -webkit-box-orient: vertical; -webkit-box-align: stretch;
  display: -moz-box; -moz-box-orient: vertical; -moz-box-align: stretch;
  display: box; box-orient: vertical; box-align: stretch;
  display: flex; flex-direction: column; align-items: stretch;
  /* fill the remaining row width: legacy box-flex first, standard flex last */
  -webkit-box-flex: 1; -moz-box-flex: 1; box-flex: 1;
  flex: 1;
}
/* Top border and margin must line up with div.input_prompt. */
div.input_area { border: 1px solid #cfcfcf; border-radius: 2px; background: #f7f7f7; line-height: 1.21429em; }
/* A prompt with no content collapses to zero height (e.g. for empty JavaScript output). */
div.prompt:empty { padding-top: 0; padding-bottom: 0; }
div.unrecognized_cell {
  padding: 5px 5px 5px 0px;
  display: -webkit-box; -webkit-box-orient: horizontal; -webkit-box-align: stretch;
  display: -moz-box; -moz-box-orient: horizontal; -moz-box-align: stretch;
  display: box; box-orient: horizontal; box-align: stretch;
  display: flex; flex-direction: row; align-items: stretch;
}
div.unrecognized_cell .inner_cell { border-radius: 2px; padding: 5px; font-weight: bold; color: red; border: 1px solid #cfcfcf; background: #eaeaea; }
div.unrecognized_cell .inner_cell a,
div.unrecognized_cell .inner_cell a:hover { color: inherit; text-decoration: none; }
@media (max-width: 540px) { div.unrecognized_cell > div.prompt { display: none; } }
/* Code cells: no screen styling; the print rule below keeps each cell on one page. */
div.code_cell { }
@media print { div.code_cell { page-break-inside: avoid; } }
/* Styling for code cells that are currently running would go here. */
div.input {
  page-break-inside: avoid;
  display: -webkit-box; -webkit-box-orient: horizontal; -webkit-box-align: stretch;
  display: -moz-box; -moz-box-orient: horizontal; -moz-box-align: stretch;
  display: box; box-orient: horizontal; box-align: stretch;
  display: flex; flex-direction: row; align-items: stretch;
}
@media (max-width: 540px) {
  /* narrow screens: stack the input's children vertically */
  div.input {
    display: -webkit-box; -webkit-box-orient: vertical; -webkit-box-align: stretch;
    display: -moz-box; -moz-box-orient: vertical; -moz-box-align: stretch;
    display: box; box-orient: vertical; box-align: stretch;
    display: flex; flex-direction: column; align-items: stretch;
  }
}
/* Top border must line up with div.input_area. */
div.input_prompt { color: #303F9F; border-top: 1px solid transparent; }
div.input_area > div.highlight { margin: 0.4em; border: none; padding: 0px; background-color: transparent; }
div.input_area > div.highlight > pre { margin: 0px; border: none; padding: 0px; background-color: transparent; }
/* If the user's monospace font has inconsistent normal/bold/italic heights,
 * keywords end up vertically offset from surrounding text. In that case
 * notebookmain.js adds this to <head>; a better font is the real fix
 * (https://github.com/ipython/ipython/issues/1503):
 *
 *   .CodeMirror span { vertical-align: bottom; }
 */
.CodeMirror {
  line-height: 1.21429em; /* notebook-wide line height instead of CodeMirror's 1em */
  font-size: 14px;
  height: auto;           /* grow with the content */
  background: none;       /* let the cell background show through instead of white */
}
/* overflow-y stays hidden: when visible, vertical scrollbars appeared on font-size changes. */
.CodeMirror-scroll { overflow-y: hidden; overflow-x: auto; }
/* em-based rather than CM3's 4px so the padding scales with the custom line-height. */
.CodeMirror-lines { padding: 0.4em; }
.CodeMirror-linenumber { padding: 0 8px 0 4px; }
.CodeMirror-gutters { border-bottom-left-radius: 2px; border-top-left-radius: 2px; }
/* CM3 gave lines 4px of padding; zeroed because .CodeMirror-lines sizes the padding. */
.CodeMirror pre { padding: 0; border: 0; border-radius: 0; }
/*
Original style from softwaremaniacs.org (c) Ivan Sagalaev <Maniac@SoftwareManiacs.Org>
Adapted from GitHub theme
*/
.highlight-base,
.highlight-variable { color: #000; }
.highlight-variable-2 { color: #1a1a1a; }
.highlight-variable-3 { color: #333333; }
.highlight-string { color: #BA2121; }
.highlight-comment { color: #408080; font-style: italic; }
.highlight-number { color: #080; }
.highlight-atom { color: #88F; }
.highlight-keyword { color: #008000; font-weight: bold; }
.highlight-builtin { color: #008000; }
.highlight-error { color: #f00; }
.highlight-operator { color: #AA22FF; font-weight: bold; }
.highlight-meta { color: #AA22FF; }
/* Not defined previously; values copied from CodeMirror's default theme. */
.highlight-def { color: #00f; }
.highlight-string-2 { color: #f50; }
.highlight-qualifier { color: #555; }
.highlight-bracket { color: #997; }
.highlight-tag { color: #170; }
.highlight-attribute { color: #00c; }
.highlight-header { color: blue; }
.highlight-quote { color: #090; }
.highlight-link { color: #00c; }
/* Same palette as the .highlight-* classes, applied to CodeMirror's "ipython" theme. */
.cm-s-ipython span.cm-keyword { color: #008000; font-weight: bold; }
.cm-s-ipython span.cm-atom { color: #88F; }
.cm-s-ipython span.cm-number { color: #080; }
.cm-s-ipython span.cm-def { color: #00f; }
.cm-s-ipython span.cm-variable { color: #000; }
.cm-s-ipython span.cm-operator { color: #AA22FF; font-weight: bold; }
.cm-s-ipython span.cm-variable-2 { color: #1a1a1a; }
.cm-s-ipython span.cm-variable-3 { color: #333333; }
.cm-s-ipython span.cm-comment { color: #408080; font-style: italic; }
.cm-s-ipython span.cm-string { color: #BA2121; }
.cm-s-ipython span.cm-string-2 { color: #f50; }
.cm-s-ipython span.cm-meta { color: #AA22FF; }
.cm-s-ipython span.cm-qualifier { color: #555; }
.cm-s-ipython span.cm-builtin { color: #008000; }
.cm-s-ipython span.cm-bracket { color: #997; }
.cm-s-ipython span.cm-tag { color: #170; }
.cm-s-ipython span.cm-attribute { color: #00c; }
.cm-s-ipython span.cm-header { color: blue; }
.cm-s-ipython span.cm-quote { color: #090; }
.cm-s-ipython span.cm-link { color: #00c; }
.cm-s-ipython span.cm-error { color: #f00; }
.cm-s-ipython span.cm-tab {
  /* This rule held an empty `url()` (presumably a stripped tab-marker data URI).
     Current CSS treats an empty url() as an invalid image, but some older browsers
     resolved it to the page's own URL and fetched it, so no image is used. */
  background: none;
  background-position: right;
  background-repeat: no-repeat;
}
div.output_wrapper {
  position: relative; /* containing block for absolutely positioned descendants */
  display: -webkit-box; -webkit-box-orient: vertical; -webkit-box-align: stretch;
  display: -moz-box; -moz-box-orient: vertical; -moz-box-align: stretch;
  display: box; box-orient: vertical; box-align: stretch;
  display: flex; flex-direction: column; align-items: stretch;
  z-index: 1;
}
/* Height-limited output area. A fixed height is used because Firefox mishandled max-height here. */
div.output_scroll {
  height: 24em;
  width: 100%; /* Firefox needs full width here and on the wrapper, otherwise it shrink-wraps */
  overflow: auto;
  border-radius: 2px;
  -webkit-box-shadow: inset 0 2px 8px rgba(0, 0, 0, 0.8);
  box-shadow: inset 0 2px 8px rgba(0, 0, 0, 0.8);
  display: block;
}
/* Output div while collapsed. */
div.output_collapsed {
  margin: 0px;
  padding: 0px;
  display: -webkit-box; -webkit-box-orient: vertical; -webkit-box-align: stretch;
  display: -moz-box; -moz-box-orient: vertical; -moz-box-align: stretch;
  display: box; box-orient: vertical; box-align: stretch;
  display: flex; flex-direction: column; align-items: stretch;
}
div.out_prompt_overlay { height: 100%; padding: 0px 0.4em; position: absolute; border-radius: 2px; }
div.out_prompt_overlay:hover {
  /* an inset shadow yields a border computed identically on WebKit and Firefox */
  -webkit-box-shadow: inset 0 0 1px #000;
  box-shadow: inset 0 0 1px #000;
  background: rgba(240, 240, 240, 0.5);
}
div.output_prompt { color: #D84315; }
/* Outer container of every output section; its children are laid out in a row. */
div.output_area {
  padding: 0px;
  page-break-inside: avoid;
  display: -webkit-box; -webkit-box-orient: horizontal; -webkit-box-align: stretch;
  display: -moz-box; -moz-box-orient: horizontal; -moz-box-align: stretch;
  display: box; box-orient: horizontal; box-align: stretch;
  display: flex; flex-direction: row; align-items: stretch;
}
/* Display math in outputs is left-aligned rather than centred. */
div.output_area .MathJax_Display {
  text-align: left !important;
}
/* Scale raster and SVG outputs down to fit the output column. (Two truncated
   "div.output_area" lines used to precede this list, turning its first selector
   into a descendant chain that matched almost nothing.) */
div.output_area img,
div.output_area svg {
  max-width: 100%;
  height: auto;
}
/* Opt-out: outputs marked .unconfined keep their natural width. */
div.output_area img.unconfined,
div.output_area svg.unconfined {
  max-width: none;
}
/* Flex column holding a cell's outputs. */
.output {
  display: -webkit-box; -webkit-box-orient: vertical; -webkit-box-align: stretch;
  display: -moz-box; -moz-box-orient: vertical; -moz-box-align: stretch;
  display: box; box-orient: vertical; box-align: stretch;
  display: flex; flex-direction: column; align-items: stretch;
}
@media (max-width: 540px) {
  /* narrow screens: stack output_area's children vertically */
  div.output_area {
    display: -webkit-box; -webkit-box-orient: vertical; -webkit-box-align: stretch;
    display: -moz-box; -moz-box-orient: vertical; -moz-box-align: stretch;
    display: box; box-orient: vertical; box-align: stretch;
    display: flex; flex-direction: column; align-items: stretch;
  }
}
/* Reset <pre> inside outputs so global styles (e.g. Bootstrap's) don't leak in. */
div.output_area pre {
  margin: 0; padding: 0; border: 0;
  vertical-align: baseline;
  color: black;
  background-color: transparent;
  border-radius: 0;
}
/* Subarea after the prompt inside output_area; grows to fill the row, minus the 14ex prompt width. */
div.output_subarea {
  overflow-x: auto;
  padding: 0.4em;
  -webkit-box-flex: 1; -moz-box-flex: 1; box-flex: 1;
  flex: 1;
  max-width: calc(100% - 14ex);
}
div.output_scroll div.output_subarea { overflow-x: visible; }
/* Remaining output_* classes style the individual output types. */
/* Shared by all text output; line-height must match CodeMirror's. */
div.output_text { text-align: left; color: #000; line-height: 1.21429em; }
/* stdout/stderr are both 'text' and 'stream'; execute_result/error are not streams. */
div.output_stderr { background: #fdd; /* very light red */ }
div.output_latex { text-align: left; }
/* Empty JavaScript outputs take no vertical space. */
div.output_javascript:empty { padding: 0; }
.js-error { color: darkred; }
/* raw_input prompt and field */
div.raw_input_container { line-height: 1.21429em; padding-top: 5px; }
pre.raw_input_prompt { /* intentionally empty */ }
input.raw_input {
  font-family: monospace;
  font-size: inherit;
  color: inherit;
  width: auto;
  vertical-align: baseline; /* keep the field on the prompt's baseline */
  /* padding + margin give 0.5em between the prompt and the cursor */
  padding: 0em 0.25em;
  margin: 0em 0.25em;
}
input.raw_input:focus { box-shadow: none; }
p.p-space { margin-bottom: 10px; }
div.output_unrecognized { padding: 5px; font-weight: bold; color: red; }
div.output_unrecognized a,
div.output_unrecognized a:hover { color: inherit; text-decoration: none; }
.rendered_html { color: #000; }
.rendered_html :link,
.rendered_html :visited { text-decoration: underline; }
/* First-child headings get a top margin tuned per heading level. */
.rendered_html h1:first-child { margin-top: 0.538em; }
.rendered_html h2:first-child { margin-top: 0.636em; }
.rendered_html h3:first-child { margin-top: 0.777em; }
.rendered_html h4:first-child,
.rendered_html h5:first-child,
.rendered_html h6:first-child { margin-top: 1em; }
/* Uniform 1em of space above block content in rendered markdown. */
.rendered_html * + ul,
.rendered_html * + ol,
.rendered_html pre,
.rendered_html tr,
.rendered_html th,
.rendered_html td,
.rendered_html * + table,
.rendered_html * + p,
.rendered_html * + img { margin-top: 1em; }
/* Images in rendered markdown: block-level, centred, scaled down to fit the column.
   (These selectors used to share the text-cell flex rule below by mistake.) */
.rendered_html img {
  display: block;
  margin-left: auto;
  margin-right: auto;
  max-width: 100%;
  height: auto;
}
/* Opt-out: images marked .unconfined keep their natural width. */
.rendered_html img.unconfined {
  max-width: none;
}
/* Text cell: prompt and rendered content side by side in a flex row. */
div.text_cell {
  /* legacy box-model syntaxes first; standard flex last so it wins where supported */
  display: -webkit-box;
  -webkit-box-orient: horizontal;
  -webkit-box-align: stretch;
  display: -moz-box;
  -moz-box-orient: horizontal;
  -moz-box-align: stretch;
  display: box;
  box-orient: horizontal;
  box-align: stretch;
  display: flex;
  flex-direction: row;
  align-items: stretch;
}
@media (max-width: 540px) {
  /* narrow screens: drop the prompt column */
  div.text_cell > div.prompt {
    display: none;
  }
}
/* Rendered markdown body of a text cell. */
div.text_cell_render {
  outline: none;
  resize: none;
  width: inherit;
  border-style: none;
  padding: 0.5em 0.5em 0.5em 0.4em;
  color: #000;
  box-sizing: border-box; -moz-box-sizing: border-box; -webkit-box-sizing: border-box;
}
/* Heading permalinks (the trailing ¶) stay hidden until their heading is hovered. */
a.anchor-link:link { text-decoration: none; padding: 0px 20px; visibility: hidden; }
h1:hover .anchor-link, h2:hover .anchor-link, h3:hover .anchor-link,
h4:hover .anchor-link, h5:hover .anchor-link, h6:hover .anchor-link { visibility: visible; }
/* A rendered text cell shows only its rendered output, not its source editor. */
.text_cell.rendered .input_area {
  display: none;
}
/* An unrendered text cell shows only its source. The selector used to be prefixed
   by a dangling ".text_cell.rendered" line, which made it never match. */
.text_cell.unrendered .text_cell_render {
  display: none;
}
/* CodeMirror header token levels: shared face, then per-level size. */
.cm-header-1, .cm-header-2, .cm-header-3,
.cm-header-4, .cm-header-5, .cm-header-6 {
  font-weight: bold;
  font-family: "Helvetica Neue", Helvetica, Arial, sans-serif;
}
.cm-header-1 { font-size: 185.7%; }
.cm-header-2 { font-size: 157.1%; }
.cm-header-3 { font-size: 128.6%; }
.cm-header-4 { font-size: 110%; }
.cm-header-5,
.cm-header-6 { font-size: 100%; font-style: italic; }
</style>
<style type="text/css">/* Token classes for the highlighted code blocks (token names in the trailing comments). */
.highlight .hll { background-color: #ffffcc; }
.highlight { background: #f8f8f8; }
.highlight .c { color: #408080; font-style: italic; } /* Comment */
.highlight .err { border: 1px solid #ff0000; } /* Error */
.highlight .k { color: #008000; font-weight: bold; } /* Keyword */
.highlight .o { color: #666666; } /* Operator */
.highlight .ch { color: #408080; font-style: italic; } /* Comment.Hashbang */
.highlight .cm { color: #408080; font-style: italic; } /* Comment.Multiline */
.highlight .cp { color: #bc7a00; } /* Comment.Preproc */
.highlight .cpf { color: #408080; font-style: italic; } /* Comment.PreprocFile */
.highlight .c1 { color: #408080; font-style: italic; } /* Comment.Single */
.highlight .cs { color: #408080; font-style: italic; } /* Comment.Special */
.highlight .gd { color: #a00000; } /* Generic.Deleted */
.highlight .ge { font-style: italic; } /* Generic.Emph */
.highlight .gr { color: #ff0000; } /* Generic.Error */
.highlight .gh { color: #000080; font-weight: bold; } /* Generic.Heading */
.highlight .gi { color: #00a000; } /* Generic.Inserted */
.highlight .go { color: #888888; } /* Generic.Output */
.highlight .gp { color: #000080; font-weight: bold; } /* Generic.Prompt */
.highlight .gs { font-weight: bold; } /* Generic.Strong */
.highlight .gu { color: #800080; font-weight: bold; } /* Generic.Subheading */
.highlight .gt { color: #0044dd; } /* Generic.Traceback */
.highlight .kc { color: #008000; font-weight: bold; } /* Keyword.Constant */
.highlight .kd { color: #008000; font-weight: bold; } /* Keyword.Declaration */
.highlight .kn { color: #008000; font-weight: bold; } /* Keyword.Namespace */
.highlight .kp { color: #008000; } /* Keyword.Pseudo */
.highlight .kr { color: #008000; font-weight: bold; } /* Keyword.Reserved */
.highlight .kt { color: #b00040; } /* Keyword.Type */
.highlight .m { color: #666666; } /* Literal.Number */
.highlight .s { color: #ba2121; } /* Literal.String */
.highlight .na { color: #7d9029; } /* Name.Attribute */
.highlight .nb { color: #008000; } /* Name.Builtin */
.highlight .nc { color: #0000ff; font-weight: bold; } /* Name.Class */
.highlight .no { color: #880000; } /* Name.Constant */
.highlight .nd { color: #aa22ff; } /* Name.Decorator */
.highlight .ni { color: #999999; font-weight: bold; } /* Name.Entity */
.highlight .ne { color: #d2413a; font-weight: bold; } /* Name.Exception */
.highlight .nf { color: #0000ff; } /* Name.Function */
.highlight .nl { color: #a0a000; } /* Name.Label */
.highlight .nn { color: #0000ff; font-weight: bold; } /* Name.Namespace */
.highlight .nt { color: #008000; font-weight: bold; } /* Name.Tag */
.highlight .nv { color: #19177c; } /* Name.Variable */
.highlight .ow { color: #aa22ff; font-weight: bold; } /* Operator.Word */
.highlight .w { color: #bbbbbb; } /* Text.Whitespace */
.highlight .mb { color: #666666; } /* Literal.Number.Bin */
.highlight .mf { color: #666666; } /* Literal.Number.Float */
.highlight .mh { color: #666666; } /* Literal.Number.Hex */
.highlight .mi { color: #666666; } /* Literal.Number.Integer */
.highlight .mo { color: #666666; } /* Literal.Number.Oct */
.highlight .sa { color: #ba2121; } /* Literal.String.Affix */
.highlight .sb { color: #ba2121; } /* Literal.String.Backtick */
.highlight .sc { color: #ba2121; } /* Literal.String.Char */
.highlight .dl { color: #ba2121; } /* Literal.String.Delimiter */
.highlight .sd { color: #ba2121; font-style: italic; } /* Literal.String.Doc */
.highlight .s2 { color: #ba2121; } /* Literal.String.Double */
.highlight .se { color: #bb6622; font-weight: bold; } /* Literal.String.Escape */
.highlight .sh { color: #ba2121; } /* Literal.String.Heredoc */
.highlight .si { color: #bb6688; font-weight: bold; } /* Literal.String.Interpol */
.highlight .sx { color: #008000; } /* Literal.String.Other */
.highlight .sr { color: #bb6688; } /* Literal.String.Regex */
.highlight .s1 { color: #ba2121; } /* Literal.String.Single */
.highlight .ss { color: #19177c; } /* Literal.String.Symbol */
.highlight .bp { color: #008000; } /* Name.Builtin.Pseudo */
.highlight .fm { color: #0000ff; } /* Name.Function.Magic */
.highlight .vc { color: #19177c; } /* Name.Variable.Class */
.highlight .vg { color: #19177c; } /* Name.Variable.Global */
.highlight .vi { color: #19177c; } /* Name.Variable.Instance */
.highlight .vm { color: #19177c; } /* Name.Variable.Magic */
.highlight .il { color: #666666; } /* Literal.Number.Integer.Long */</style>
<style type="text/css">
/* ANSI colour classes: foreground/background pairs, each with an "intense" variant.
   Marked as temporary definitions, to become obsolete with Notebook release 5.0. */
.ansi-black-fg { color: #3e424d; } .ansi-black-bg { background-color: #3e424d; }
.ansi-black-intense-fg { color: #282c36; } .ansi-black-intense-bg { background-color: #282c36; }
.ansi-red-fg { color: #e75c58; } .ansi-red-bg { background-color: #e75c58; }
.ansi-red-intense-fg { color: #b22b31; } .ansi-red-intense-bg { background-color: #b22b31; }
.ansi-green-fg { color: #00a250; } .ansi-green-bg { background-color: #00a250; }
.ansi-green-intense-fg { color: #007427; } .ansi-green-intense-bg { background-color: #007427; }
.ansi-yellow-fg { color: #ddb62b; } .ansi-yellow-bg { background-color: #ddb62b; }
.ansi-yellow-intense-fg { color: #b27d12; } .ansi-yellow-intense-bg { background-color: #b27d12; }
.ansi-blue-fg { color: #208ffb; } .ansi-blue-bg { background-color: #208ffb; }
.ansi-blue-intense-fg { color: #0065ca; } .ansi-blue-intense-bg { background-color: #0065ca; }
.ansi-magenta-fg { color: #d160c4; } .ansi-magenta-bg { background-color: #d160c4; }
.ansi-magenta-intense-fg { color: #a03196; } .ansi-magenta-intense-bg { background-color: #a03196; }
.ansi-cyan-fg { color: #60c6c8; } .ansi-cyan-bg { background-color: #60c6c8; }
.ansi-cyan-intense-fg { color: #258f8f; } .ansi-cyan-intense-bg { background-color: #258f8f; }
.ansi-white-fg { color: #c5c1b4; } .ansi-white-bg { background-color: #c5c1b4; }
.ansi-white-intense-fg { color: #a1a6b2; } .ansi-white-intense-bg { background-color: #a1a6b2; }
.ansi-bold { font-weight: bold; }
</style>
<div class="cell border-box-sizing text_cell rendered"><div class="prompt input_prompt">
</div>
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<h1 id="Neural-Networks-gone-wild!-They-can-sample-from-discrete-distributions-now!">Neural Networks gone wild! They can sample from discrete distributions now!<a class="anchor-link" href="#Neural-Networks-gone-wild!-They-can-sample-from-discrete-distributions-now!">¶</a></h1><p>Training deep neural networks usually boils down to defining your model's architecture and a loss function, and watching the gradients propagate.</p>
<p>However, sometimes it's not that simple: some architectures incorporate a random component. The forward pass is then no longer a deterministic function of the input and weights, because it involves sampling from that random component.</p>
<p>When would that happen, you ask? Whenever we want to approximate an intractable sum or integral by forming a Monte Carlo estimate. A good example is <a href="https://arxiv.org/abs/1312.6114">the variational autoencoder</a>. Basically, it's an autoencoder on steroids: the encoder's job is to learn a <em>distribution</em> over the latent space. The loss function contains an intractable expectation over that distribution, so we estimate it by sampling from the distribution.</p>
<p>As with any architecture, the gradients need to propagate to the weights of the model. Some of the weights are responsible for transforming the input into the parameters of the distribution from which we sample. Here we face a problem: the gradients can't propagate through random nodes! Hence, these weights won't be updated.</p>
<p>One solution to the problem is the reparameterization trick: you substitute the sampled random variable with a deterministic parameterized transformation of a parameterless random variable.</p>
<p>If you don't know this trick, I highly encourage you to <a href="http://blog.shakirm.com/2015/10/machine-learning-trick-of-the-day-4-reparameterisation-tricks/">read about it</a>. I'll demonstrate it with the Gaussian case:</p>
<p>Let $Z \sim \mathcal{N}(\mu(X), \sigma^2(X))$. The parameters of the Gaussian are a function of the input $X$ - e.g. the output of stacked dense layers. When sampling realizations of $Z$, gradients won't be able to propagate to the weights of the dense layers. We can replace $Z$ with the identically distributed random variable $Z' = \mu(X) + \sigma(X) \cdot \mathcal{E}$, where $\mathcal{E} \sim \mathcal{N}(0, 1)$. Now the sampling happens in $\mathcal{E}$, so gradients won't propagate along that path - which we don't care about, since $\mathcal{E}$ has no parameters to train. However, they will propagate through $\mu(X)$ and $\sigma(X)$, since that path is deterministic.</p>
<p>The reparameterization trick works for many types of continuous distributions. But what do you do if you need the distribution to be over a discrete set of values?</p>
<p>In the following sections you'll learn:</p>
<ul>
<li>what the Gumbel distribution is</li>
<li>how it is used for sampling from a discrete distribution</li>
<li>how the weights that affect the distribution's parameters can be trained</li>
<li>how to use all of that in a toy example (with code)</li>
</ul>
<h1 id="The-Gumbel-distribution">The Gumbel distribution<a class="anchor-link" href="#The-Gumbel-distribution">¶</a></h1><p>The Gumbel distribution has two parameters - a location $\mu$ and a scale $\beta$. The standard Gumbel distribution, where $\mu$ and $\beta$ are 0 and 1 respectively, has the PDF $e^{-(x + e^{-x})}$.</p>
<p>Why should you care about this distribution? Consider the setting where you have a discrete random variable whose logits are $\{\alpha_i\}_{i=1}^k$. The logits are a function of the input and of weights that need to be trained.</p>
<p>What I'm going to describe next is called the Gumbel-max trick. Using this trick, you can sample from the discrete distribution. The process is as follows:</p>
<ol>
<li>Sample i.i.d samples $\{z_i\}_{i=1}^k$ from the standard Gumbel distribution.</li>
<li>Add the samples to the logits: $\{\alpha_i + z_i\}_{i=1}^k$.</li>
<li>Take the index of the maximal value: $\text{argmax}_{i \in \{1, \dots, k\}} \, (\alpha_i + z_i)$.</li>
</ol>
<p>The resulting index will be a random sample from your original discrete distribution. You can read <a href="https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/">the proof of the Gumbel-max trick</a>.</p>
<p>Great! So we were able to substitute our distribution with a deterministic transformation of a parameterless distribution! So if we plug it into our model, the gradients will be able to propagate to the weights of the logits, right?</p>
<p>Well, not so fast! Gradients can't propagate through argmax...</p>
<h1 id="Gumbel-softmax-trick-to-the-rescue!">Gumbel-softmax trick to the rescue!<a class="anchor-link" href="#Gumbel-softmax-trick-to-the-rescue!">¶</a></h1><p><img alt="Photo by Radu Florin on Unsplash" src="images/gumbel-gan/gumbel-softmax.jpg"/></p>
<p>Taking the argmax is equivalent to producing a one-hot vector in which the entry corresponding to the maximal value is 1 and all other entries are 0.</p>
<p>So instead of using a hard one-hot vector, we can approximate it with a soft one: the softmax.</p>
<p>The process is the same as the process described above, except now you apply softmax instead of argmax.</p>
<p>And voila! Gradients can propagate to the weights of the logits.</p>
<p>There's one hyperparameter I didn't tell you about (yet) - the temperature:</p>
<p>$\frac{\exp((\alpha_i+z_i) \cdot \tau^{-1})}{\sum_{j=1}^k \exp((\alpha_j+z_j) \cdot \tau^{-1})}$</p>
<p>By dividing by a temperature $\tau > 0$, we can control how close the approximation is to argmax. When $\tau \to 0$, the entry corresponding to the maximal value tends to 1 and the other entries tend to 0. When $\tau \to \infty$, the result tends to uniform (every entry tends to $1/k$). The smaller $\tau$ is, the better the approximation gets. The problem with setting $\tau$ to small values is that the variance of the gradients becomes too high, which makes training difficult. A good practice is to start with a high temperature and then anneal it towards small values.</p>
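<p>You can read more about the Gumbel-softmax trick in <a href="https://arxiv.org/abs/1611.01144">Categorical Reparameterization with Gumbel-Softmax</a> and <a href="https://arxiv.org/abs/1611.00712">The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables</a>.</p>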
<h1 id="Enough-theory---it's-code-time!">Enough theory - it's code time!<a class="anchor-link" href="#Enough-theory---it's-code-time!">¶</a></h1><p><img alt="Photo by Blake Connally on Unsplash" src="images/gumbel-gan/code.jpg"/></p>
<p>To show that the theory works in real life, I'll use a toy problem. The data is a stream of numbers in the range 0 to 4. Each number has a different probability to come up in the stream. Your mission, should you choose to accept it, is to find out what the distribution over the 5 numbers is.</p>
<p>A simple solution would be to count, but we're going to do something much cooler (and ridiculously too complicated for this task): we'll train a <a href="https://arxiv.org/abs/1406.2661">GAN</a>. The generator will generate numbers from a distribution which should converge to the real one.</p>
<p>Here's an intuition of why it should work: let's say the true probability associated with the number 0 is 0. The discriminator will learn that 0's never come with the label REAL. Therefore, the generator will incur a big loss whenever it generates 0's. This will encourage the generator to stop generating 0's.</p>
</div>
</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">
<div class="prompt input_prompt">In [1]:</div>
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython2"><pre><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="kn">as</span> <span class="nn">tf</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="kn">as</span> <span class="nn">plt</span>
<span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="n">tf</span><span class="o">.</span><span class="n">set_random_seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="o">%</span><span class="k">matplotlib</span> inline
</pre></div>
</div>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered"><div class="prompt input_prompt">
</div>
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>These are all the hyperparameters we're going to use:</p>
</div>
</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">
<div class="prompt input_prompt">In [2]:</div>
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython2"><pre><span></span><span class="n">BATCHS_IN_EPOCH</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">EPOCHS</span> <span class="o">=</span> <span class="mi">200</span> <span class="c1"># the stream is infinite so one epoch will be defined as BATCHS_IN_EPOCH * BATCH_SIZE</span>
<span class="n">GENERATOR_TRAINING_FACTOR</span> <span class="o">=</span> <span class="mi">10</span> <span class="c1"># for every training step of the discriminator we'll train the generator 10 times</span>
<span class="n">LEARNING_RATE</span> <span class="o">=</span> <span class="mf">0.0007</span>
<span class="n">TEMPERATURE</span> <span class="o">=</span> <span class="mf">0.001</span> <span class="c1"># we use a constant, but for harder problems we should anneal it</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered"><div class="prompt input_prompt">
</div>
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>We'll define the stream of data using <code>tf.data.Dataset.from_generator</code>. The underlying generator continuously samples numbers according to a predefined distribution:</p>
</div>
</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">
<div class="prompt input_prompt">In [3]:</div>
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython2"><pre><span></span><span class="n">number_to_prob</span> <span class="o">=</span> <span class="p">{</span>
<span class="mi">0</span><span class="p">:</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="mi">1</span><span class="p">:</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="mi">2</span><span class="p">:</span> <span class="mf">0.1</span><span class="p">,</span>
<span class="mi">3</span><span class="p">:</span> <span class="mf">0.3</span><span class="p">,</span>
<span class="mi">4</span><span class="p">:</span> <span class="mf">0.6</span>
<span class="p">}</span>
<span class="k">def</span> <span class="nf">generate_text</span><span class="p">():</span>
<span class="k">while</span> <span class="bp">True</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">number_to_prob</span><span class="o">.</span><span class="n">keys</span><span class="p">(),</span> <span class="n">p</span><span class="o">=</span><span class="n">number_to_prob</span><span class="o">.</span><span class="n">values</span><span class="p">(),</span> <span class="n">size</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">dataset</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_generator</span><span class="p">(</span><span class="n">generate_text</span><span class="p">,</span>
<span class="n">output_types</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">output_shapes</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span>
<span class="n">value</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">make_one_shot_iterator</span><span class="p">()</span><span class="o">.</span><span class="n">get_next</span><span class="p">()</span>
<span class="n">value</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">number_to_prob</span><span class="p">))</span>
<span class="n">value</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered"><div class="prompt input_prompt">
</div>
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<h1 id="The-model-architecture">The model architecture<a class="anchor-link" href="#The-model-architecture">¶</a></h1><p><img alt="Photo by Dmitri Popov on Unsplash" src="images/gumbel-gan/architecture.jpg"/></p>
<p>The core of the generator is the Gumbel-softmax distribution. TensorFlow already has an implementation of it: <code>RelaxedOneHotCategorical</code>. It follows the same process described above: sample from the standard Gumbel distribution, add the result to the logits, and apply softmax with the given temperature. The logits are what the generator has to learn.</p>
</div>
</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">
<div class="prompt input_prompt">In [4]:</div>
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython2"><pre><span></span><span class="k">def</span> <span class="nf">generator</span><span class="p">():</span>
<span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">variable_scope</span><span class="p">(</span><span class="s1">'generator'</span><span class="p">):</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">get_variable</span><span class="p">(</span><span class="s1">'logits'</span><span class="p">,</span> <span class="n">initializer</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">ones</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="n">number_to_prob</span><span class="p">)]))</span>
<span class="n">gumbel_dist</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">RelaxedOneHotCategorical</span><span class="p">(</span><span class="n">TEMPERATURE</span><span class="p">,</span> <span class="n">logits</span><span class="o">=</span><span class="n">logits</span><span class="p">)</span>
<span class="n">probs</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span>
<span class="n">generated</span> <span class="o">=</span> <span class="n">gumbel_dist</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span>
<span class="k">return</span> <span class="n">generated</span><span class="p">,</span> <span class="n">probs</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered"><div class="prompt input_prompt">
</div>
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>The discriminator is a logistic regression model. Being simple, it allows us to focus in this demonstration on the Gumbel-softmax trick. If the discriminator were any smarter, we would have faced some challenges, such as:</p>
<ul>
<li>choosing a good schedule for the temperature annealing.</li>
<li>smoothing the real data from the stream, or otherwise the discriminator would be able to learn that smooth vectors mean FAKE, and hard vectors mean REAL.</li>
</ul>
</div>
</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">
<div class="prompt input_prompt">In [5]:</div>
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython2"><pre><span></span><span class="k">def</span> <span class="nf">discriminator</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">variable_scope</span><span class="p">(</span><span class="s1">'discriminator'</span><span class="p">,</span> <span class="n">reuse</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">AUTO_REUSE</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">fully_connected</span><span class="p">(</span><span class="n">x</span><span class="p">,</span>
<span class="n">num_outputs</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">activation_fn</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered"><div class="prompt input_prompt">
</div>
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>Next, we're going to define the usual GAN loss and training ops - one for the discriminator and one for the generator.</p>
</div>
</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">
<div class="prompt input_prompt">In [6]:</div>
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython2"><pre><span></span><span class="n">generated_outputs</span><span class="p">,</span> <span class="n">generated_probs</span> <span class="o">=</span> <span class="n">generator</span><span class="p">()</span>
<span class="n">discriminated_real</span> <span class="o">=</span> <span class="n">discriminator</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
<span class="n">discriminated_generated</span> <span class="o">=</span> <span class="n">discriminator</span><span class="p">(</span><span class="n">generated_outputs</span><span class="p">)</span>
<span class="n">d_loss_real</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span>
<span class="n">tf</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">sigmoid_cross_entropy_with_logits</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="n">discriminated_real</span><span class="p">,</span>
<span class="n">labels</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">discriminated_real</span><span class="p">)))</span>
<span class="n">d_loss_fake</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span>
<span class="n">tf</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">sigmoid_cross_entropy_with_logits</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="n">discriminated_generated</span><span class="p">,</span>
<span class="n">labels</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">discriminated_generated</span><span class="p">)))</span>
<span class="n">d_loss</span> <span class="o">=</span> <span class="n">d_loss_real</span> <span class="o">+</span> <span class="n">d_loss_fake</span>
<span class="n">g_loss</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span>
<span class="n">tf</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">sigmoid_cross_entropy_with_logits</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="n">discriminated_generated</span><span class="p">,</span>
<span class="n">labels</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">discriminated_generated</span><span class="p">)))</span>
<span class="n">all_vars</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">trainable_variables</span><span class="p">()</span>
<span class="n">g_vars</span> <span class="o">=</span> <span class="p">[</span><span class="n">var</span> <span class="k">for</span> <span class="n">var</span> <span class="ow">in</span> <span class="n">all_vars</span> <span class="k">if</span> <span class="n">var</span><span class="o">.</span><span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">'generator'</span><span class="p">)]</span>
<span class="n">d_vars</span> <span class="o">=</span> <span class="p">[</span><span class="n">var</span> <span class="k">for</span> <span class="n">var</span> <span class="ow">in</span> <span class="n">all_vars</span> <span class="k">if</span> <span class="n">var</span><span class="o">.</span><span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">'discriminator'</span><span class="p">)]</span>
<span class="n">d_train_opt</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">AdamOptimizer</span><span class="p">(</span><span class="n">LEARNING_RATE</span><span class="p">)</span><span class="o">.</span><span class="n">minimize</span><span class="p">(</span><span class="n">d_loss</span><span class="p">,</span> <span class="n">var_list</span><span class="o">=</span><span class="n">d_vars</span><span class="p">)</span>
<span class="n">g_train_opt</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">AdamOptimizer</span><span class="p">(</span><span class="n">LEARNING_RATE</span><span class="p">)</span><span class="o">.</span><span class="n">minimize</span><span class="p">(</span><span class="n">g_loss</span><span class="p">,</span> <span class="n">var_list</span><span class="o">=</span><span class="n">g_vars</span><span class="p">)</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered"><div class="prompt input_prompt">
</div>
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<h1 id="Train,-baby,-train!">Train, baby, train!<a class="anchor-link" href="#Train,-baby,-train!">¶</a></h1><p><img alt="Photo by Steven Lelham on Unsplash" src="images/gumbel-gan/training.jpg"/></p>
</div>
</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">
<div class="prompt input_prompt">In [7]:</div>
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython2"><pre><span></span><span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">Session</span><span class="p">()</span> <span class="k">as</span> <span class="n">sess</span><span class="p">:</span>
<span class="n">sess</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">global_variables_initializer</span><span class="p">())</span>
<span class="n">learned_probs</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">EPOCHS</span><span class="p">):</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">BATCHS_IN_EPOCH</span><span class="p">):</span>
<span class="n">sess</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">d_train_opt</span><span class="p">)</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">GENERATOR_TRAINING_FACTOR</span><span class="p">):</span>
<span class="n">sess</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">g_train_opt</span><span class="p">)</span>
<span class="n">learned_probs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">sess</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">generated_probs</span><span class="p">))</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered"><div class="prompt input_prompt">
</div>
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>After each epoch we evaluate <code>generated_probs</code>. At the beginning it's uniform, since all the logits are initialized to 1, and as training progresses it should (hopefully) converge towards the real distribution defined by <code>number_to_prob</code>.</p>
<p>To see if it happens, we can plot the difference between the two distributions:</p>
</div>
</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">
<div class="prompt input_prompt">In [8]:</div>
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython2"><pre><span></span><span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>
<span class="n">prob_errors</span> <span class="o">=</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">learned_prob</span><span class="p">)</span> <span class="o">-</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">number_to_prob</span><span class="o">.</span><span class="n">values</span><span class="p">())</span>
<span class="k">for</span> <span class="n">learned_prob</span> <span class="ow">in</span> <span class="n">learned_probs</span><span class="p">]</span>
<span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">prob_errors</span><span class="p">),</span>
<span class="n">cmap</span><span class="o">=</span><span class="s1">'bwr'</span><span class="p">,</span>
<span class="n">aspect</span><span class="o">=</span><span class="s1">'auto'</span><span class="p">,</span>
<span class="n">vmin</span><span class="o">=-</span><span class="mi">2</span><span class="p">,</span>
<span class="n">vmax</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s1">'epoch'</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s1">'number'</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">colorbar</span><span class="p">(</span><span class="n">aspect</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">ticks</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]);</span>
</pre></div>
</div>
</div>
</div>
<div class="output_wrapper">
<div class="output">
<div class="output_area">
<div class="prompt"></div>
<div class="output_png output_subarea ">
<img src="
AAALEgAACxIB0t1+/AAAFDRJREFUeJzt3X2QZFd53/Hvb0YSECQjEApRdiWvMLIrNgFJWUQqAls2
kSMcjIiDE4EtY+PKxhUwqOyUjSAJFH/ZxhAnFRKziZTIsSwZGylWURgkHCGicgTaVTaWV0LxIi/F
qhSEeNELRBKz8+SPvrPb29Pd0zPTPbe75/upmuru0+ee+5x7u2eeOee+pKqQJElq00LbAUiSJJmQ
SJKk1pmQSJKk1pmQSJKk1pmQSJKk1pmQSJKk1pmQSJKkdUlydpLbk9yX5GCSd266Ta9DIkmS1iPJ
WcBZVXVPktOA/cAbquq+jbbpCIkkSVqXqnq4qu5pnj8B3A/s2EybJiSSJGnDkuwCLgA+t5l2ThpH
MOPywjPOqF07dx4vSFZX6lc2yHrqttHetKxLktSKw4cP8+ijj27ZL/zLknp0xLr74SDwVFfR3qra
210nyanAx4CrqurxzcQ2VQnJrp072XfrrccLFhZWP1/oGdTpfb3We4PqD2tn3HVGMa52tsIsxSpJ
U2T3K16xpet7FNh30mh/+rO09FRV7R74fnIynWTk+qq6abOxTVVCIkmSJiiBERMSlpaGNJMA1wD3
V9WHxhGaCYkkSdvFehKS4S4GrgTuTXKgKXt3VX1iow2akEiStF0kcMopm26mqu4ExnrsiwmJJEnb
xfhGSMZuOqOSJEnjN8UJyURPj0hyWZIHkhxK8q5JrkuSJK1hZcpmlJ8tNrE0Kcki8GHgUuAIcHeS
WzZzWVlJkrQJUzxCMsmoLgIOVdWDAEluBC4HTEgkSWrDmA5qnYRJJiQ7gC93vT4CvHKC65MkScNs
0xGSkSTZA+wBOKf7svGSJGm8tmlC8hBwdtfrnU3ZCZrr4u8F2P3yl9cE45EkaXvbplM2dwPnJTmX
TiJyBfDmCa5PkiQNsx1HSKpqKcnbgU8Bi8C1VXVwUuuTJEkj2G4JCUBzTfsNX9dekiSN0cLCtpyy
kSRJ02Y7jpBIkqQpsh2PIZEkSVPGKRtJkjQVHCGRJEmtcspGkiS1bpteGG1jlpaOP19YGFyv33u9
ZeOus7y8dt2VOsNi72dQO5u13jg2YjOxbkV8kqQOR0gkSVLrTEgkSVLrnLKRJEmtc4REkiS1zhES
SZLUOkdIJElS67ZjQpLkWuB1wCNV9dJJrUeSJI1oiqdsJnkRiP8CXDbB9iVJ0nqsjJCM8rPFJrbG
qvpskl2Tal+SJK3TdpyykSRJU2aKp2xaT0iS7AH2AJyzY0fL0UiSNMemeISk9RuJVNXeqtpdVbvP
POOMtsORJGm+bbdjSCRJ0pRZWJjaKZuJjZAkuQH4n8D3JTmS5OcntS5JkjSCWT7LJski8Omq+uH1
NFxVb9pwVJIkaTKm9BiSNaOqqqNJlpM8r6oe24qgJEnSBEzxlM2oadKTwL1JbgO+tVJYVe+YSFSS
JGkyZnWEpHFT8yNJkmbVrI+QVNV1SZ4DnFNVD0w4JkmSNClTOkIy0lk2SX4cOAB8snl9fpJbJhmY
JEkas1k+y6bxPuAi4DMAVXUgyYsnFJMkSZqEObh0/Heq6rEk3WXLY49mYQFOPbVrDV2rWFo6se5y
n9X3lvWrs7Aw/DUMzgz71R1Up7vuKOvsUWTNOr36dXcrlx93O+sxyq6ZVDvrXSbU+leyYj0bd1jd
cW2wEdrbyGe5n1G6vpFubWp/jGiUbbCeXTuonyP1pXtFAxra7D4bFsdG2t7sPtpMf0b5Gq03vnF9
JzZkii8dP2pUB5O8GVhMch7wDuBPJxeWJEkauylOSEb9f+IXgR8AngZuAB4HrppUUJIkaQJWpmxG
+dlio55l823gPUl+vfOynphsWJIkaeymeIRkpKiSvAK4Fjitef0Y8Naq2j/B2CRJ0jjNekICXAP8
s6r6HwBJXgX8Z+BlkwpMkiSN2RycZXN0JRkBqKo7kywNW0CSJE2ZKR4hGXpQa5ILk1wI3JHkI0ku
SfJDSf49zTVJhix7dpLbk9yX5GCSd44xbkmStF5jvDBaksuSPJDkUJJ3bTa0tdb4wZ7X7+16vtaJ
10vAL1fVPUlOA/Ynua2q7ltvkJIkaQzGNGWTZBH4MHApcAS4O8ktm/kbPzQhqaof3mjDVfUw8HDz
/Ikk9wM7ABMSSZLaML4pm4uAQ1X1YKfZ3Ahczib+xo96ls3pwM8Au7qXqap3jLj8LuAC4HN93tsD
7AE45+yzR2lOkiRtxPoSkhcm2df1em9V7W2e7wC+3PXeEeCVmwlt1Kg+AdwF3Ms6Lxmf5FTgY8BV
VfV47/tN5/YC7L7ggslfw1mSpG3sKIujVn20qnZPMpZuoyYkz66qX1pv40lOppOMXF9VN613eUmS
ND5Vq28Nt0EPAd3TGjubsg0bNSH5r0n+CfBxOpePB6Cqvj5ogXTuxHcNcH9VfWgzQUqSpM2rgmee
GUtTdwPnJTmXTiJyBfDmzTQ4akLyDPAB4D0cP7umgBcPWeZi4Erg3iQHmrJ3V9UnNhKoJEnanHGN
kFTVUpK3A58CFoFrq+rgZtocNSH5ZeAlVfXoqA1X1Z3Q5j2WJUlStzFO2dAMMIxtkGHUhOQQ8O1x
rVSSJG29MU7ZjN2oCcm3gANJbufEY0hGOu1XkiS1b5wjJOM2akLy35ofSZI0o2Y+Iamq6yYdiCRJ
mqyZn7JJ8pf0uXdNVQ07y0aSJE2RmR8hAbqv1PZs4CeBF4w7mKMs8hjPO/b6hEvCNpEurxQOuU/x
8nquJduv7hRkjwtD+jfove7y3jppzndaXFxdZ9AjQFby0JWN2m/jDntvUPl6OtgvsOZ5NSdyrRXC
eq20073KfmWDZNi9J9cKcpTt1fV60DZYXl59Ncb19GFYCL27pG9/m5WtOtVuIysHFje22IbUGicI
rvdztmo79WlgZW/VwuCraA5a77Dt3/tdoav93s/DSjtZPrq6vd52hsR1bF/1CThD2llVd8j3aJTl
B7Yz7Iuw8tf62OXVs+qt1V/H43VW1jUsvqG/HyZs5hOSqvpaT9FvJdkP/KvxhyRJkiZhHqZsLux6
uUBnxGQstwuUJElbY+ZHSIAPcvwYkiXgMJ1pG0mSNCPmISF5LfAPgV1dy1wBvH8CMUmSpAmY+Skb
Otcg+SZwD/DU5MKRJEmTMg8jJDur6rKJRiJJkiZu1hOSP03yN6vq3lEbTvJs4LPAs5r1/GFVvXcD
MUqSpDFYXp79KZtXAT/bXCDtaTonZldVvWzIMk8DP1JVTyY5GbgzyR9X1V2bC1mSJG3UrI+QvHa9
DVdVAU82L09uftq7GowkSdvczI+QVNWXNtJ4kkVgP/AS4MNV9bmNtCNJkjZvmg9qnejFmKvqaFWd
D+wELkry0t46SfYk2Zdk39e+9tVJhiNJ0ra3tDTaz1bbkrtDVNU3gduBVWfqVNXeqtpdVbvPOOPM
rQhHkqRtaWXKZpSfrTaxhCTJmUlOb54/B7gU+MKk1idJktY2rSMkk7wfzVnAdc1xJAvAR6vq4xNc
nyRJGmKajyGZWEJSVX8GXDCp9iVJ0vrMw6XjJUnSjNuWIySSJGm6mJBIkqTWOWUjSZJa5wiJJElq
nQmJJElqnVM2I3r6aTh0CBb6XK5tpWzlMVldZ3Gx83jSSf2X6X6+Urd3mWHLn3LK6nWe1LMFFxea
+wcuLx8v7H4Ox9PT3vL16t1Q3e3124i95b0dbB6L4xt35fkyiyc8nrDaY+31X2V3L3vDGhRmv/eO
bVs41tc0jytRLQ7b7sNWdrzSCe11B7/YU6d3mW7HtlufXby8vHobnlh3cVXZsO02YDeSfveyHGWb
DPpcdtVd3b9+/T2xn8cWH+FjerzVte/H2f15HUXvNh32Neytc+z3x8LG7hPa+33qt64sHx0Y2OKx
gE5cqN/nbeVzturT2uezs+Lo8sq2XB3fYrMvju2TrpX11q6mZGklhu51No/92ukNbL37ttfKOla1
s9D/OwiwvHDyCXEuLn/n2HuLyytDCytftpNOfN1vXX1stl+b4QiJJElqnSMkkiSpdY6QSJKk1pmQ
SJKk1jllI0mSWucIiSRJat22TkiSLAL7gIeq6nWTXp8kSepvu0/ZvBO4H/iuLViXJEkaYJpHSEa5
UtSGJdkJ/H3gP01yPZIkaTRLS6P9bLVJj5D8FvArwGkTXo8kSVrD8vL0TtlMbIQkyeuAR6pq/xr1
9iTZl2TfN77x1UmFI0mSmN4RkklO2VwMvD7JYeBG4EeS/G5vparaW1W7q2r3859/5gTDkSRpe1s5
hmQaE5KJTdlU1dXA1QBJLgH+eVX99KTWJ0mShtvuZ9lIkqQpMM1n2WxJQlJVnwE+sxXrkiRJ/W37
hESSJLXPKRtJktQ6R0gkSVLrtmqEJMkHgB8HngG+CPxcVX1z2DITvVKrJEmaHlt42u9twEur6mXA
/6E563YYR0gkSdomtmrKpqpu7Xp5F/DGtZYxIZEkaZto6aDWtwK/v1alVNUWxDKaJF8FvgU82nYs
E/JC5rdvYP9mnf2bXfPcN5jv/n13VW3ZZcqTfJLO9hzFs4Gnul7vraq9XW19GvhrfZZ7T1X9UVPn
PcBu4CdqjYRjqhISgCT7qmp323FMwjz3DezfrLN/s2ue+wbz3795leRngX8KvKaqvr1WfadsJEnS
WCW5DPgV4IdGSUbAs2wkSdL4/TvgNOC2JAeS/PZaC0zjCMnetavMrHnuG9i/WWf/Ztc89w3mv39z
p6pest5lpu4YEkmStP04ZSNJklo3NQlJksuSPJDkUJJ3tR3PZiU5O8ntSe5LcjDJO5vy9yV5qJlT
O5Dkx9qOdaOSHE5yb9OPfU3ZC5LcluQvmsfntx3nRiT5vq59dCDJ40mumuX9l+TaJI8k+fOusr77
Kx3/tvk+/lmSC9uLfG0D+vaBJF9o4r85yelN+a4k/69rH645t922Af0b+FlMcnWz7x5I8vfaiXp0
A/r3+119O5zkQFM+c/tPo5mKKZski3QuLXspcAS4G3hTVd3XamCbkOQs4KyquifJacB+4A3APwKe
rKrfbDXAMUhyGNhdVY92lf0G8PWq+rUmsXx+Vf1qWzGOQ/P5fAh4JfBzzOj+S/KDwJPA71TVS5uy
vvur+eP2i8CP0en3v6mqV7YV+1oG9O1Hgf9eVUtJfh2g6dsu4OMr9WbBgP69jz6fxSTfD9wAXAT8
deDTwPdW1dEtDXod+vWv5/0PAo9V1ftncf9pNNMyQnIRcKiqHqyqZ4AbgctbjmlTqurhqrqnef4E
cD+wo92otsTlwHXN8+voJGGz7jXAF6vqS20HshlV9Vng6z3Fg/bX5XT+OFRV3QWc3iTZU6lf36rq
1qpauUj2XcDOLQ9sTAbsu0EuB26sqqer6i+BQ3R+x06tYf1LEjr/yN2wpUFpy01LQrID+HLX6yPM
0R/vJqO/APhcU/T2Zhj52lmd0mgUcGuS/Un2NGUvqqqHm+f/F3hRO6GN1RWc+MtwXvYfDN5f8/ad
fCvwx12vz03yv5LckeTVbQU1Bv0+i/O2714NfKWq/qKrbF72n7pMS0Iyt5KcCnwMuKqqHgf+A/A9
wPnAw8AHWwxvs15VVRcCrwXe1gy7HtNcJrj9OcFNSHIK8HrgD5qiedp/J5iH/dVPOpeuXgKub4oe
Bs6pqguAXwJ+L8l3tRXfJsztZ7HHmzjxH4J52X/qMS0JyUPA2V2vdzZlMy3JyXSSkeur6iaAqvpK
VR2tqmXgPzLlQ6nDVNVDzeMjwM10+vKVlaH95vGR9iIci9cC91TVV2C+9l9j0P6ai+9kOpeufh3w
Uyv30WimMr7WPN8PfBH43taC3KAhn8W52HcASU4CfoKuG7PNy/7TatOSkNwNnJfk3OY/0iuAW1qO
aVOaec9rgPur6kNd5d3z8P8A+PPeZWdBkuc2B+uS5LnAj9Lpyy3AW5pqbwH+qJ0Ix+aE/87mZf91
GbS/bgF+pjnb5m/TOaDw4X4NTKscv3T167svXZ3kzOZAZZK8GDgPeLCdKDduyGfxFuCKJM9Kci6d
/n1+q+Mbk78LfKGqjqwUzMv+02pTcaXW5ij4twOfAhaBa6vqYMthbdbFwJXAvSunqwHvBt6U5Hw6
Q+OH6dx4aBa9CLi5k3dxEvB7VfXJJHcDH03y88CX6ByMNpOaROtSTtxHvzGr+y/JDcAlwAuTHAHe
C/wa/ffXJ+icYXMI+Dads4um1oC+XQ08i86lqwHuqqpfAH4QeH+S7wDLwC9U1agHjLZiQP8u6fdZ
rKqDST4K3Ednqupt03yGDfTvX1Vdw+rjt2AG959GMxWn/UqSpO1tWqZsJEnSNmZCIkmSWmdCIkmS
WmdCIkmSWmdCIkmSWmdCImlNSS5J8vG245A0v0xIJElS60xIpDmS5KeTfD7JgSQfSbKY5Mkk/zrJ
wSR/kuTMpu75Se5qbs5288rN2ZK8JMmnk/zvJPck+Z6m+VOT/GGSLyS5vrkasSSNhQmJNCeS/A3g
HwMXV9X5wFHgp4DnAvuq6geAO+hc5RPgd4BfraqXAfd2lV8PfLiqXg78HTo3M4POHauvAr4feDGd
qxFL0lhMxaXjJY3Fa4C/BdzdDF48h87N8pY5fnOy3wVuSvI84PSquqMpvw74g+b+RDuq6maAqnoK
oGnv8yv3FGluh7ALuHPy3ZK0HZiQSPMjwHVVdfUJhcm/7Km30ftFPN31/Cj+/pA0Rk7ZSPPjT4A3
JvmrAElekOS76XzP39jUeTNwZ1U9Bnwjyaub8iuBO6rqCeBIkjc0bTwryV/Z0l5I2pb8D0eaE1V1
X5J/AdyaZAH4DvA24FvARc17j9A5zgTgLcBvNwnHgxy/o++VwEeSvL9p4ye3sBuStinv9ivNuSRP
VtWpbcchScM4ZSNJklrnCIkkSWqdIySSJKl1JiSSJKl1JiSSJKl1JiSSJKl1JiSSJKl1JiSSJKl1
/x+nMDxo9lsN6QAAAABJRU5ErkJggg==
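" alt="Heatmap of the learned probability minus the true probability for each number (rows) across training epochs (columns)"/>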
</div>
</div>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered"><div class="prompt input_prompt">
</div>
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>Red means the generator erroneously output too high a probability, and blue means too low a probability. White means the entry has converged.</p>
<p>We can see that by the end of the training all the entries have pretty much converged. Sweet!</p>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered"><div class="prompt input_prompt">
</div>
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<h1 id="Conclusion">Conclusion<a class="anchor-link" href="#Conclusion">¶</a></h1><p>In this post you learned what the Gumbel-softmax trick is.</p>
<p>Using this trick, you can sample from a discrete distribution and let the gradients propagate to the weights that affect the distribution's parameters.</p>
<p>This trick opens doors to many interesting applications. For a start, you can find an example of text generation in the paper <a href="https://arxiv.org/abs/1611.04051">GANS for Sequences of Discrete Elements with the Gumbel-softmax Distribution</a>.</p>
<p>Feel free to drop a line in the comments if you find more interesting use cases!</p>
</div>
</div>
</div>
<script>
// Inject the MathJax loader once (the element id acts as a guard against
// double insertion) so the $...$ and $$...$$ LaTeX in this post is typeset.
// cdn.mathjax.org was retired in 2017, so MathJax 2.7 is loaded from cdnjs
// over HTTPS instead of the old protocol-relative URL.
if (!document.getElementById('mathjaxscript_pelican_#%@#$@#')) {
    var mathjaxscript = document.createElement('script');
    mathjaxscript.id = 'mathjaxscript_pelican_#%@#$@#';
    mathjaxscript.type = 'text/javascript';
    mathjaxscript.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.9/MathJax.js?config=TeX-AMS-MML_HTMLorMML';
    // MathJax 2 reads an in-line configuration from the body of the script
    // element that loads it; old Opera only honours innerHTML for that.
    mathjaxscript[(window.opera ? "innerHTML" : "text")] =
        "MathJax.Hub.Config({" +
        "    config: ['MMLorHTML.js']," +
        "    TeX: { extensions: ['AMSmath.js','AMSsymbols.js','noErrors.js','noUndefined.js'], equationNumbers: { autoNumber: 'AMS' } }," +
        "    jax: ['input/TeX','input/MathML','output/HTML-CSS']," +
        "    extensions: ['tex2jax.js','mml2jax.js','MathMenu.js','MathZoom.js']," +
        "    displayAlign: 'center'," +
        "    displayIndent: '0em'," +
        "    showMathMenu: true," +
        "    tex2jax: { " +
        "        inlineMath: [ ['$','$'] ], " +
        "        displayMath: [ ['$$','$$'] ]," +
        "        processEscapes: true," +
        "        preview: 'TeX'," +
        "    }, " +
        "    'HTML-CSS': { " +
        "        styles: { '.MathJax_Display, .MathJax .mo, .MathJax .mi, .MathJax .mn': {color: 'black ! important'} }" +
        "    } " +
        "}); ";
    // Prefer <body>; fall back to <head> if the body is not available yet.
    (document.body || document.getElementsByTagName('head')[0]).appendChild(mathjaxscript);
}
</script>
</article>
<div class="tags">
<p>tags: <a href="http://anotherdatum.com/tag/deep-learning.html">deep-learning</a>, <a href="http://anotherdatum.com/tag/gan.html">GAN</a></p>
</div>
<hr>
<!-- Begin MailChimp Signup Form -->
<!-- Third-party assets are requested over HTTPS explicitly rather than via
     protocol-relative URLs that would fall back to plain http. -->
<link href="https://cdn-images.mailchimp.com/embedcode/classic-10_7.css" rel="stylesheet" type="text/css">
<style type="text/css">
#mc_embed_signup{background:#fff; clear:left; font:14px Helvetica,Arial,sans-serif; width:300px;}
#mc_embed_signup form{padding: 0;}
/* Add your own MailChimp form style overrides in your site stylesheet or in this style block.
We recommend moving this block and the preceding CSS link to the HEAD of your HTML file. */
</style>
<div id="mc_embed_signup">
<form action="https://anotherdatum.us14.list-manage.com/subscribe/post?u=6894d7badcfb253606fa3fb54&amp;id=c6f34ad6b7" method="post" id="mc-embedded-subscribe-form" name="mc-embedded-subscribe-form" class="validate" target="_blank" novalidate>
<div id="mc_embed_signup_scroll">
<h2>Get notified of new posts</h2>
<div class="mc-field-group">
<label for="mce-EMAIL">Email Address </label>
<!-- autocomplete lets browsers offer the user's saved address. -->
<input type="email" value="" name="EMAIL" class="required email" id="mce-EMAIL" autocomplete="email">
</div>
<div id="mce-responses" class="clear">
<div class="response" id="mce-error-response" style="display:none"></div>
<div class="response" id="mce-success-response" style="display:none"></div>
</div> <!-- real people should not fill this in and expect good things - do not remove this or risk form bot signups-->
<div style="position: absolute; left: -5000px;" aria-hidden="true"><input type="text" name="b_6894d7badcfb253606fa3fb54_c6f34ad6b7" tabindex="-1" value=""></div>
<div class="clear"><input type="submit" value="Subscribe" name="subscribe" id="mc-embedded-subscribe" class="button"></div>
</div>
</form>
</div>
<script type='text/javascript' src='https://s3.amazonaws.com/downloads.mailchimp.com/js/mc-validate.js'></script>
<script type='text/javascript'>
// Register the MailChimp merge fields for mc-validate.js.
// NOTE(review): FNAME/LNAME are declared here but this form has no such inputs.
(function($) {
    window.fnames = new Array();
    window.ftypes = new Array();
    fnames[0] = 'EMAIL'; ftypes[0] = 'email';
    fnames[1] = 'FNAME'; ftypes[1] = 'text';
    fnames[2] = 'LNAME'; ftypes[2] = 'text';
}(jQuery));
// The theme's jquery.js loads later in the page, so the jQuery used here is
// presumably the one bundled by mc-validate.js. noConflict(true) hands the
// globals back so the theme's copy can take over.
var $mcj = jQuery.noConflict(true);
</script>
<!--End mc_embed_signup-->
<hr />
<div class="comments">
<h2>Comments</h2>
<div id="disqus_thread"></div>
<script type="text/javascript">
// Disqus thread configuration (legacy global-variable style). The identifier
// and URL select which comment thread is shown on this page.
var disqus_shortname = 'anotherdatum';
var disqus_identifier = 'gumbel-gan.html';
var disqus_url = 'http://anotherdatum.com/gumbel-gan.html';
(function() {
    // Inject the Disqus embed script asynchronously into <head> (or <body> as
    // a fallback). It is requested over HTTPS explicitly instead of inheriting
    // the page protocol through a protocol-relative URL.
    var dsq = document.createElement('script');
    dsq.type = 'text/javascript';
    dsq.async = true;
    dsq.src = 'https://' + disqus_shortname + '.disqus.com/embed.js';
    (document.getElementsByTagName('head')[0] || document.getElementsByTagName('body')[0]).appendChild(dsq);
})();
</script>
<noscript>Please enable JavaScript to view the comments.</noscript>
</div>
</div>
</div>
</div>
<hr>
<!-- Footer -->
<footer>
  <div class="container">
    <div class="row">
      <div class="col-lg-8 col-lg-offset-2 col-md-10 col-md-offset-1">
        <!-- The social links are icon-only: aria-label gives each link an
             accessible name, and the decorative Font Awesome glyph stacks are
             hidden from assistive technology. -->
        <ul class="list-inline text-center">
          <li>
            <a href="https://il.linkedin.com/in/yoelzeldes" aria-label="LinkedIn">
              <span class="fa-stack fa-lg" aria-hidden="true">
                <i class="fa fa-circle fa-stack-2x"></i>
                <i class="fa fa-linkedin fa-stack-1x fa-inverse"></i>
              </span>
            </a>
          </li>
          <li>
            <a href="https://github.com/yoel-zeldes" aria-label="GitHub">
              <span class="fa-stack fa-lg" aria-hidden="true">
                <i class="fa fa-circle fa-stack-2x"></i>
                <i class="fa fa-github fa-stack-1x fa-inverse"></i>
              </span>
            </a>
          </li>
          <li>
            <a href="https://www.facebook.com/yoel.zeldes" aria-label="Facebook">
              <span class="fa-stack fa-lg" aria-hidden="true">
                <i class="fa fa-circle fa-stack-2x"></i>
                <i class="fa fa-facebook fa-stack-1x fa-inverse"></i>
              </span>
            </a>
          </li>
          <li>
            <a href="https://twitter.com/YZeldes" aria-label="Twitter">
              <span class="fa-stack fa-lg" aria-hidden="true">
                <i class="fa fa-circle fa-stack-2x"></i>
                <i class="fa fa-twitter fa-stack-1x fa-inverse"></i>
              </span>
            </a>
          </li>
        </ul>
        <p class="copyright text-muted">
          Blog powered by <a href="https://getpelican.com">Pelican</a>,
          which takes great advantage of <a href="https://www.python.org">Python</a>.
          <br>
          Blog sources can be found <a href="https://github.com/yoel-zeldes/yoel-zeldes.github.io">on GitHub</a>.
        </p>
      </div>
    </div>
  </div>
</footer>
<!-- jQuery -->
<script src="http://anotherdatum.com/theme/js/jquery.js"></script>
<!-- Bootstrap Core JavaScript -->
<script src="http://anotherdatum.com/theme/js/bootstrap.min.js"></script>
<!-- Custom Theme JavaScript -->
<script src="http://anotherdatum.com/theme/js/clean-blog.min.js"></script>
<script type="text/javascript">
// Classic Google Analytics (ga.js) async snippet: queue the account and a
// pageview on _gaq, then inject ga.js before the first script on the page.
// NOTE(review): ga.js with a UA-* property is legacy Google Analytics;
// confirm this property still collects data, or migrate to gtag.js/GA4.
var _gaq = _gaq || [];
_gaq.push(['_setAccount', 'UA-83684090-1']);
_gaq.push(['_trackPageview']);
(function() {
    var ga = document.createElement('script');
    ga.type = 'text/javascript';
    ga.async = true;
    // Always fetch the tracker over HTTPS. The old protocol switch
    // downloaded it over plain http whenever the page itself was http.
    ga.src = 'https://ssl.google-analytics.com/ga.js';
    var s = document.getElementsByTagName('script')[0];
    s.parentNode.insertBefore(ga, s);
})();
</script>
<script type="text/javascript">
// Asynchronously load the Disqus count.js script for this site's shortname.
var disqus_shortname = 'anotherdatum';
(function () {
    var s = document.createElement('script');
    s.async = true;
    s.type = 'text/javascript';
    // Explicit HTTPS instead of a protocol-relative URL, so the script is
    // never fetched over plain http.
    s.src = 'https://' + disqus_shortname + '.disqus.com/count.js';
    (document.getElementsByTagName('HEAD')[0] || document.getElementsByTagName('BODY')[0]).appendChild(s);
}());
</script>
</body>
</html>