In [1]:
from IPython.display import display, HTML
# Widen the notebook container to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))


# Gumbel Softmax


* Paper: "Categorical Reparameterization with Gumbel-Softmax" (Jang, Gu, Poole), preprint, published at ICLR 2017: https://arxiv.org/pdf/1611.01144.pdf
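
For orientation, the sampling step behind the Gumbel-Softmax estimator can be sketched in NumPy. This is a minimal illustration, not the notebook's implementation: the helper names `sample_gumbel` and `gumbel_softmax_sample`, the `eps` constant, and the example probabilities are assumptions made for this sketch. It draws Gumbel(0, 1) noise, adds it to the logits, and applies a temperature-scaled softmax.

```python
import numpy as np

def sample_gumbel(shape, rng, eps=1e-20):
    """Draw standard Gumbel(0, 1) noise via g = -log(-log(u)), u ~ Uniform(0, 1)."""
    u = rng.uniform(0.0, 1.0, size=shape)
    # eps (an assumed small constant) guards against log(0).
    return -np.log(-np.log(u + eps) + eps)

def gumbel_softmax_sample(logits, temperature, rng):
    """Relaxed one-hot sample: softmax((logits + g) / temperature) over the last axis."""
    y = (logits + sample_gumbel(logits.shape, rng)) / temperature
    y = y - y.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(y)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
logits = np.log(np.array([[0.1, 0.6, 0.3]]))  # illustrative class log-probabilities
soft = gumbel_softmax_sample(logits, temperature=0.5, rng=rng)
hard = np.eye(logits.shape[-1])[soft.argmax(axis=-1)]  # argmax-discretized one-hot vector
```

Lower temperatures push `soft` closer to a one-hot vector, while higher temperatures make it more uniform. `hard` shows one way to discretize a relaxed sample.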


## Experiments
### 1) Stochastic gradient estimation
* **Dataset**: We use the MNIST dataset with fixed binarization for training and evaluation.
* **Tricks**: We also found that variance normalization was necessary.
* **Network**: We used sigmoid activation functions for binary (Bernoulli) neural networks and softmax activations for categorical variables.
* **Training**: Models were trained using stochastic gradient descent with momentum 0.9.
* **Learning rates**: Chosen from {3e-5, 1e-5, 3e-4, 1e-4, 3e-3, 1e-3}; we select the best learning rate for each estimator using the MNIST validation set, and report performance on the test set.
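
The training setup above (momentum SGD plus a learning-rate grid scored on validation data) can be sketched as follows. This is only an illustration under stated assumptions: `train_and_validate` is a hypothetical stand-in that minimizes a toy quadratic instead of training a stochastic network on MNIST, and the step count is arbitrary.

```python
import numpy as np

LEARNING_RATES = [3e-5, 1e-5, 3e-4, 1e-4, 3e-3, 1e-3]  # grid listed in the notes
MOMENTUM = 0.9

def train_and_validate(lr, steps=200):
    """Hypothetical stand-in for training one estimator; returns a 'validation' loss.

    Minimizes the toy objective f(w) = 0.5 * ||w - 1||^2 with momentum SGD.
    """
    w = np.zeros(3)
    velocity = np.zeros_like(w)
    for _ in range(steps):
        grad = w - 1.0  # gradient of the toy objective
        velocity = MOMENTUM * velocity - lr * grad
        w = w + velocity
    return 0.5 * float(np.sum((w - 1.0) ** 2))

# Score every learning rate, then keep the one with the lowest validation loss.
val_losses = {lr: train_and_validate(lr) for lr in LEARNING_RATES}
best_lr = min(val_losses, key=val_losses.get)
```

In a real run, `train_and_validate` would train the model with the given estimator and return the loss on the MNIST validation split. The test split would then be evaluated only once, using `best_lr`.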

### 2) Structured output prediction with stochastic binary networks

### 3) Generative modelling with variational autoencoders

### 4) Generative semi-supervised classification

# Requirements

In [1]:
import tensorflow as tf
tf.__version__

'2.1.0'

In [2]:
import pathlib
import os
import matplotlib.pyplot as plt
import numpy as np

np.set_printoptions(precision=4)

In [3]:
import pandas as pd

# Load and preprocess data

In [5]:
import tensorflow_datasets as tfds
mnist_data = tfds.load("binarized_mnist")
mnist_train, mnist_test = mnist_data["train"], mnist_data["test"]
assert isinstance(mnist_train, tf.data.Dataset)

DatasetNotFoundError: Dataset BinarizedMNIST not found. Available datasets:
	- abstract_reasoning
	- aeslc
	- aflw2k3d
	- amazon_us_reviews
	- arc
	- bair_robot_pushing_small
	- beans
	- big_patent
	- bigearthnet
	- billsum
	- binarized_mnist
	- binary_alpha_digits
	- blimp
	- c4
	- caltech101
	- caltech_birds2010
	- caltech_birds2011
	- cars196
	- cassava
	- cats_vs_dogs
	- celeb_a
	- celeb_a_hq
	- cfq
	- chexpert
	- cifar10
	- cifar100
	- cifar10_1
	- cifar10_corrupted
	- citrus_leaves
	- cityscapes
	- civil_comments
	- clevr
	- cmaterdb
	- cnn_dailymail
	- coco
	- coil100
	- colorectal_histology
	- colorectal_histology_large
	- common_voice
	- cos_e
	- crema_d
	- curated_breast_imaging_ddsm
	- cycle_gan
	- deep_weeds
	- definite_pronoun_resolution
	- dementiabank
	- diabetic_retinopathy_detection
	- div2k
	- dmlab
	- downsampled_imagenet
	- dsprites
	- dtd
	- duke_ultrasound
	- emnist
	- eraser_multi_rc
	- esnli
	- eurosat
	- fashion_mnist
	- flic
	- flores
	- food101
	- forest_fires
	- gap
	- geirhos_conflict_stimuli
	- german_credit_numeric
	- gigaword
	- glue
	- groove
	- higgs
	- horses_or_humans
	- i_naturalist2017
	- image_label_folder
	- imagenet2012
	- imagenet2012_corrupted
	- imagenet2012_subset
	- imagenet_resized
	- imagenette
	- imagewang
	- imdb_reviews
	- iris
	- kitti
	- kmnist
	- lfw
	- librispeech
	- librispeech_lm
	- libritts
	- ljspeech
	- lm1b
	- lost_and_found
	- lsun
	- malaria
	- math_dataset
	- mnist
	- mnist_corrupted
	- movie_rationales
	- moving_mnist
	- multi_news
	- multi_nli
	- multi_nli_mismatch
	- natural_questions
	- newsroom
	- nsynth
	- omniglot
	- open_images_challenge2019_detection
	- open_images_v4
	- opinosis
	- oxford_flowers102
	- oxford_iiit_pet
	- para_crawl
	- patch_camelyon
	- pet_finder
	- places365_small
	- plant_leaves
	- plant_village
	- plantae_k
	- qa4mre
	- quickdraw_bitmap
	- reddit
	- reddit_tifu
	- resisc45
	- robonet
	- rock_paper_scissors
	- rock_you
	- savee
	- scan
	- scene_parse150
	- scicite
	- scientific_papers
	- shapes3d
	- smallnorb
	- snli
	- so2sat
	- speech_commands
	- squad
	- stanford_dogs
	- stanford_online_products
	- starcraft_video
	- stl10
	- sun397
	- super_glue
	- svhn_cropped
	- ted_hrlr_translate
	- ted_multi_translate
	- tedlium
	- tf_flowers
	- the300w_lp
	- tiny_shakespeare
	- titanic
	- trivia_qa
	- uc_merced
	- ucf101
	- vgg_face2
	- visual_domain_decathlon
	- voc
	- voxceleb
	- waymo_open_dataset
	- web_questions
	- wider_face
	- wiki40b
	- wikihow
	- wikipedia
	- wmt14_translate
	- wmt15_translate
	- wmt16_translate
	- wmt17_translate
	- wmt18_translate
	- wmt19_translate
	- wmt_t2t_translate
	- wmt_translate
	- xnli
	- xsum
	- yelp_polarity_reviews
Check that:
    - if dataset was added recently, it may only be available
      in `tfds-nightly`
    - the dataset name is spelled correctly
    - dataset class defines all base class abstract methods
    - dataset class is not in development, i.e. if IN_DEVELOPMENT=True
    - the module defining the dataset class is imported
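
Since the cell above failed, one possible fallback is to load the plain `mnist` dataset (it appears in the list of available datasets) and binarize it manually. The sketch below shows only the binarization step, on synthetic data. The helper name `binarize_fixed` and the 0.5 threshold are assumptions for illustration, and a thresholded dataset is not guaranteed to match the `binarized_mnist` dataset, which may have been produced differently.

```python
import numpy as np

def binarize_fixed(images, threshold=0.5):
    """Map grayscale uint8 images to {0, 1} floats using a fixed threshold (an assumption)."""
    scaled = np.asarray(images, dtype=np.float32) / 255.0
    return (scaled > threshold).astype(np.float32)

# Synthetic stand-in for a batch of 28x28 MNIST digits.
rng = np.random.default_rng(0)
fake_batch = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
binary_batch = binarize_fixed(fake_batch)
```

Because the threshold is deterministic, every pass over the data sees the same binary images, which is the property that "fixed binarization" is meant to guarantee.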
