<a href="https://colab.research.google.com/github/sean-halpin/diffusion_models/blob/main/dct_diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Imports

In [None]:
!pip freeze | xargs -IXX pip uninstall -y XX

In [None]:
!pip install Theano

In [None]:
!pip install numpy

In [None]:
!pip install picklable-itertools==0.1.1

In [None]:
!pip install progressbar2==3.10.0

In [None]:
!pip install pyyaml==3.11

In [None]:
!pip install six==1.15.0

In [None]:
!pip install toolz==0.9.0

In [None]:
# !pip install git+https://github.com/Theano/Theano.git#egg=theano

In [None]:
!pip install git+https://github.com/sean-halpin/fuel.git

In [None]:
!pip install --verbose git+https://github.com/mila-iqia/blocks.git

In [None]:
import argparse
import numpy as np
import os
import warnings

import theano
import theano.tensor as T

In [None]:
from theano.tensor.shared_randomstreams import RandomStreams
from blocks.algorithms import (RMSProp, GradientDescent, CompositeRule,RemoveNotFinite)
from blocks.extensions import FinishAfter, Timing, Printing
from blocks.extensions.monitoring import (TrainingDataMonitoring,DataStreamMonitoring)
from blocks.extensions.saveload import Checkpoint
from blocks.extensions.training import SharedVariableModifier
from blocks.filter import VariableFilter
from blocks.graph import ComputationGraph, apply_dropout
from blocks.main_loop import MainLoop
import blocks.model
from blocks.roles import INPUT, PARAMETER

from fuel.streams import DataStream
from fuel.schemes import ShuffledScheme
from fuel.transformers import Flatten, ScaleAndShift


### Helpers

#### Viz

In [None]:
"""
Tools for plotting / visualization
"""

import matplotlib
matplotlib.use('Agg')  # no displayed figures -- need to call before loading pylab
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import warnings

def is_square(shp, n_colors=1):
    """
    Report which entries of shp are perfect squares, either directly or after
    dividing out the number of color channels.

    An entry c counts as square when c == k**2 or c == n_colors * k**2 for
    some integer k. Works elementwise on array-likes and on a single number.
    """
    counts = np.asarray(shp)
    plain_square = counts == np.round(np.sqrt(counts)) ** 2
    per_color = counts / float(n_colors)
    color_square = counts == n_colors * np.round(np.sqrt(per_color)) ** 2
    return plain_square | color_square

def show_receptive_fields(theta, P=None, n_colors=None, max_display=100, grid_wa=None):
    """
    Display receptive fields in a grid on the current matplotlib figure.

    Tries to intelligently guess whether to treat the rows, the columns, or
    the last two axes together as containing the receptive fields. It does
    this by checking which axes are square numbers -- so you can get some
    unexpected plots if the wrong axis is a square number, or if multiple
    axes are. It also tries to handle the last axis containing color
    channels correctly.

    theta -- array holding the receptive fields.
    P -- optional matrix; when given, the fields are replaced by np.dot(P, theta)
        and the patch width is derived from P.shape[0].
    n_colors -- number of color channels; if None, 3 when the last axis has
        length 3, otherwise 1.
    max_display -- maximum number of fields (columns after reorientation) drawn.
    grid_wa -- number of subplot rows; defaults to ceil(sqrt(number of fields)).

    Returns True if the fields were drawn, False if neither axis looks like
    an image (no subplots are created in that case).
    """

    shp = np.array(theta.shape)
    if n_colors is None:
        n_colors = 1
        if shp[-1] == 3:
            n_colors = 3
    # multiply colors in as appropriate; needs a second-to-last axis to fold into
    if len(shp) > 1 and shp[-1] == n_colors:
        shp[-2] *= n_colors
        theta = theta.reshape(shp[:-1])
        shp = np.array(theta.shape)
    if len(shp) > 2:
        # merge last two axes
        shp[-2] *= shp[-1]
        theta = theta.reshape(shp[:-1])
        shp = np.array(theta.shape)
    if len(shp) > 2:
        # merge leading axes
        theta = theta.reshape((-1,shp[-1]))
        shp = np.array(theta.shape)
    if len(shp) == 1:
        theta = theta.reshape((-1,1))
        shp = np.array(theta.shape)

    # figure out the right orientation, by looking for the axis with a square
    # number of entries, up to number of colors. transpose if required
    is_sqr = is_square(shp, n_colors=n_colors)
    if is_sqr[0] and is_sqr[1]:
        warnings.warn("Unsure of correct matrix orientation. "
            "Assuming receptive fields along first dimension.")
    elif is_sqr[1]:
        theta = theta.T
    elif not is_sqr[0] and not is_sqr[1]:
        # neither direction corresponds well to an image
        # NOTE if you delete this next line, the code will work. The rfs just won't look very
        # image like
        return False

    theta = theta[:,:max_display].copy()

    if P is None:
        img_w = int(np.ceil(np.sqrt(theta.shape[0]/float(n_colors))))
    else:
        img_w = int(np.ceil(np.sqrt(P.shape[0]/float(n_colors))))
    nf = theta.shape[1]
    if grid_wa is None:
        grid_wa = int(np.ceil(np.sqrt(float(nf))))
    grid_wb = int(np.ceil(nf / float(grid_wa)))

    if P is not None:
        theta = np.dot(P, theta)

    vmin = np.min(theta)
    vmax = np.max(theta)
    # A constant array would otherwise divide by zero and fill every patch with NaN.
    vrange = vmax - vmin
    if vrange == 0:
        vrange = 1.

    for jj in range(nf):
        plt.subplot(grid_wa, grid_wb, jj+1)
        # zero-pad the field up to a full img_w x img_w patch (per color)
        ptch = np.zeros((n_colors*img_w**2,))
        ptch[:theta.shape[0]] = theta[:,jj]
        if n_colors==3:
            ptch = ptch.reshape((n_colors, img_w, img_w))
            ptch = ptch.transpose((1,2,0)) # move color channels to end
        else:
            ptch = ptch.reshape((img_w, img_w))
        # rescale to [0, 1] using the global min/max so patches are comparable
        ptch -= vmin
        ptch /= vrange
        plt.imshow(ptch, interpolation='nearest', cmap=cm.Greys_r )
        plt.axis('off')

    return True


def plot_parameter(theta_in, base_fname_part1, base_fname_part2="", title = '', n_colors=None):
    """
    Save a raw plot and, when possible, a receptive-field style plot of theta_in.

    theta_in -- array-like or scalar to plot.
    base_fname_part1 -- mandatory filename root. Files are written as
        base_fname_part1 + '_raw_' + base_fname_part2 + '.pdf' and, if a
        receptive-field view can be drawn,
        base_fname_part1 + '_rf_' + base_fname_part2 + '.pdf'.
    base_fname_part2 -- optional filename suffix.
    title -- figure title; also prefixes the printed summary statistics.
    n_colors -- number of color channels, forwarded to show_receptive_fields.
    """

    # np.array copies its input and also accepts plain Python scalars,
    # which have no .copy() method.
    theta = np.array(theta_in, copy=True)
    print("%s min %g median %g mean %g max %g shape"%(
        title, np.min(theta), np.median(theta), np.mean(theta), np.max(theta)), theta.shape)
    theta = np.squeeze(theta)
    if len(theta.shape) == 0:
        # it's a scalar -- make it a 1d array
        theta = np.array([theta])
    shp = theta.shape
    if len(shp) > 2:
        # flatten everything after the first axis into a matrix
        theta = theta.reshape((theta.shape[0], -1))
        shp = theta.shape

    ## display basic figure
    plt.figure(figsize=[8,8])
    if len(shp) == 1:
        plt.plot(theta, '.', alpha=0.5)
    elif len(shp) == 2:
        plt.imshow(theta, interpolation='nearest', aspect='auto', cmap=cm.Greys_r)
        plt.colorbar()

    plt.title(title)
    plt.savefig(base_fname_part1 + '_raw_' + base_fname_part2 + '.pdf')
    plt.close()

    ## also display it in basis function view if it's a matrix, or
    ## if it's a bias with a square number of entries
    if len(shp) >= 2 or is_square(shp[0]):
        if len(shp) == 1:
            theta = theta.reshape((-1,1))
        plt.figure(figsize=[8,8])
        if show_receptive_fields(theta, n_colors=n_colors):
            plt.suptitle(title + " receptive fields")
            plt.savefig(base_fname_part1 + '_rf_' + base_fname_part2 + '.pdf')
        plt.close()

def plot_images(X, fname):
    """
    Plot a batch of images as a grid and save the raw array next to it.

    X is expected to be a 4d tensor of dimensions [# images]x[# colors]x[height]x[width].
    Writes fname + '.pdf' when show_receptive_fields can lay the images out
    (otherwise warns), and always writes fname + '.npz' holding X.
    """
    n_images = X.shape[0]
    # one column per image, each column holding that image's flattened values
    columns = X.reshape((n_images, -1)).T
    plt.figure(figsize=[8, 8])
    plotted = show_receptive_fields(columns, n_colors=X.shape[1])
    if not plotted:
        warnings.warn('Images unexpected shape.')
    else:
        plt.savefig(fname + '.pdf')
    plt.close()

    # keep the raw values as well, so the figure can be regenerated later
    np.savez(fname + '.npz', X=X)


#### Sampler

In [None]:
import numpy as np

# import viz

def diffusion_step(Xmid, t, get_mu_sigma, denoise_sigma, mask, XT, rng):
    """
    Take one step of the reverse diffusion trajectory.

    Xmid -- current sample.
    t -- time step index; passed to get_mu_sigma as a (1, 1) array.
    get_mu_sigma -- callable (X, t_array) -> (mu, sigma).
    denoise_sigma -- if not None, (mu, sigma) is combined with a Gaussian of
        standard deviation denoise_sigma centred on XT (precision-weighted
        product of the two Gaussians).
    mask -- if not None, index into the flattened arrays; at those positions
        mu is set to XT and sigma to 0, pinning the sample to XT. The
        assignment is in place on the mu/sigma arrays.
    XT -- reference array used for denoising and masking.
    rng -- object with a numpy-style normal(size=...) method.

    Returns mu + sigma * noise, with noise drawn from rng.
    """
    time_arr = np.array([[t]])
    mu, sigma = get_mu_sigma(Xmid, time_arr)

    if denoise_sigma is not None:
        combined_sigma = (sigma**-2 + denoise_sigma**-2)**-0.5
        # mu is computed with the model's sigma before sigma is replaced
        mu = mu * combined_sigma**2 * sigma**-2 + XT * combined_sigma**2 * denoise_sigma**-2
        sigma = combined_sigma

    if mask is not None:
        mu.flat[mask] = XT.flat[mask]
        sigma.flat[mask] = 0.

    noise = rng.normal(size=Xmid.shape)
    return mu + sigma*noise


def generate_inpaint_mask(n_samples, n_colors, spatial_width):
    """
    Build a flattened boolean inpainting mask.

    The mask will be True where we keep the true image, and False where we're
    inpainting. For every sample and color, the last-axis columns from
    spatial_width // 2 onward are kept and the remaining columns are inpainted.

    Returns a 1-d bool array of length n_samples * n_colors * spatial_width**2.
    """
    mask = np.zeros((n_samples, n_colors, spatial_width, spatial_width), dtype=bool)
    # simple mask -- just mask out half the image. Integer division is required:
    # a float slice index (Python 3 `/`) raises TypeError.
    mask[:, :, :, spatial_width // 2:] = True
    return mask.ravel()


def generate_samples(model, get_mu_sigma,
            n_samples=36, inpaint=False, denoise_sigma=None, X_true=None,
            base_fname_part1="samples", base_fname_part2='',
            num_intermediate_plots=4, seed=12345):
    """
    Run the reverse diffusion process (generative model) and plot snapshots.

    model -- object providing spatial_width, n_colors and trajectory_length.
    get_mu_sigma -- numeric function (X, t_array) -> (mu, sigma); see diffusion_step.
    n_samples -- number of samples to draw.
    inpaint -- if True, the masked half of each image is pinned to X_true.
    denoise_sigma -- if not None, start from X_true plus Gaussian noise of this
        standard deviation and condition every step on that start.
    X_true -- reference images; should match (n_samples, n_colors, w, w).
        Required when inpaint or denoise_sigma is set; plotted when given.
    base_fname_part1, base_fname_part2 -- filename root/suffix for the plots.
    num_intermediate_plots -- approximate number of intermediate snapshots.
    seed -- RNG seed; fixed so samples are easy to compare across learning.

    Returns the final sample X0.

    Raises ValueError if inpaint or denoise_sigma is requested without X_true.
    """
    if X_true is None and (inpaint or denoise_sigma is not None):
        raise ValueError("X_true must be provided when inpaint=True or denoise_sigma is set")

    # use the same noise in the samples every time, so they're easier to
    # compare across learning
    rng = np.random.RandomState(seed)

    spatial_width = model.spatial_width
    n_colors = model.n_colors
    trajectory_length = model.trajectory_length

    # set the initial state X^T of the reverse trajectory
    XT = rng.normal(size=(n_samples,n_colors,spatial_width,spatial_width))
    if denoise_sigma is not None:
        XT = X_true + XT*denoise_sigma
        base_fname_part1 += '_denoise%g'%denoise_sigma
    if inpaint:
        mask = generate_inpaint_mask(n_samples, n_colors, spatial_width)
        XT.flat[mask] = X_true.flat[mask]
        base_fname_part1 += '_inpaint'
    else:
        mask = None

    if X_true is not None:
        plot_images(X_true, base_fname_part1 + '_true' + base_fname_part2)
    plot_images(XT, base_fname_part1 + '_t%04d'%trajectory_length + base_fname_part2)

    # number of steps between intermediate snapshots (loop invariant)
    plot_interval = int(np.ceil(trajectory_length/(num_intermediate_plots+2.)))

    Xmid = XT.copy()
    for t in range(trajectory_length-1, 0, -1):
        Xmid = diffusion_step(Xmid, t, get_mu_sigma, denoise_sigma, mask, XT, rng)
        if np.mod(trajectory_length-t, plot_interval) == 0:
            plot_images(Xmid, base_fname_part1 + '_t%04d'%t + base_fname_part2)

    X0 = Xmid
    plot_images(X0, base_fname_part1 + '_t%04d'%0 + base_fname_part2)
    return X0


#### Util

In [None]:
import numpy as np
import os
import theano
import theano.tensor as T
import time

def logit(u):
    """Symbolic (Theano) logit: log(u / (1 - u))."""
    return T.log(u / (1.-u))


def logit_np(u):
    """NumPy logit: log(u / (1 - u)), cast to theano.config.floatX."""
    odds = u / (1.-u)
    return np.log(odds).astype(theano.config.floatX)

def get_norms(model, gradients):
    """
    Build symbolic size-normalized norms of every parameter and its gradient.

    For each (name, param) in model.params.items(), the value is
    sqrt(sum(x**2)) divided by the number of elements of x, computed for the
    parameter and for gradients[param]. Results are named 'norm_<name>' and
    'grad_norm_<name>'.

    Returns (norms, grad_norms), two lists in model.params iteration order.
    """
    floatX = theano.config.floatX

    def _per_element_l2(var, label):
        # L2 norm scaled by element count so differently sized tensors compare
        element_count = T.prod(var.shape.astype(floatX))
        result = T.sqrt(T.sum(T.square(var))) / element_count
        result.name = label
        return result

    norms, grad_norms = [], []
    for param_name, param in model.params.items():
        norms.append(_per_element_l2(param, 'norm_' + param_name))
        grad_norms.append(_per_element_l2(gradients[param], 'grad_norm_' + param_name))
    return norms, grad_norms

def create_log_dir(args, model_id):
    """
    Create and return a fresh, timestamped run directory.

    The directory name is model_id + args.suffix + '-%y%m%dT%H%M%S', created
    under args.output_dir (with '~' expanded). os.makedirs raises if it
    already exists.
    """
    timestamp = time.strftime('-%y%m%dT%H%M%S')
    run_name = ''.join((model_id, args.suffix, timestamp))
    root = os.path.expanduser(args.output_dir)
    model_dir = os.path.join(root, run_name)
    os.makedirs(model_dir)
    return model_dir


#### Extensions

In [None]:
"""
Extensions called during training to generate samples and diagnostic plots and printouts.
"""

import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import numpy as np
import os
import theano.tensor as T
import theano

from blocks.extensions import SimpleExtension

# import viz
# import sampler


class PlotSamples(SimpleExtension):
    def __init__(self, model, algorithm, X, path, n_samples=49, **kwargs):
        """
        Generate samples from the model. The do() function is called as an extension during training.
        Generates 3 types of samples:
        - Sample from generative model
        - Sample from image denoising posterior distribution (default signal to noise of 1)
        - Sample from image inpainting posterior distribution (inpaint left half of image)

        model -- diffusion model providing n_colors, spatial_width and the
            symbolic get_mu_sigma that is compiled here.
        algorithm -- accepted for interface compatibility; not used.
        X -- reference data; its first n_samples rows are reshaped into images.
        path -- output directory for the sample plots.
        n_samples -- maximum number of samples (capped at X.shape[0]).
        """
        super(PlotSamples, self).__init__(**kwargs)
        self.model = model
        self.path = path

        n_samples = np.min([n_samples, X.shape[0]])
        self.n_samples = n_samples
        width = model.spatial_width
        self.X = X[:n_samples].reshape((n_samples, model.n_colors, width, width))

        # compile the symbolic (mu, sigma) graph into a numeric sampler function
        floatX = theano.config.floatX
        noisy_input = T.tensor4('X noisy samp', dtype=floatX)
        time_input = T.matrix('t samp', dtype=floatX)
        self.get_mu_sigma = theano.function(
            [noisy_input, time_input],
            model.get_mu_sigma(noisy_input, time_input),
            allow_input_downcast=True)

    def do(self, callback_name, *args):
        """Write unconditional, inpainting and denoising sample plots."""

        import sys
        # NOTE(review): very large recursion limit kept from the original;
        # presumably needed for deep Theano/pickling recursion -- confirm.
        sys.setrecursionlimit(10000000)

        print("generating samples")
        fname_root = self.path + '/samples-'
        fname_suffix = '_batch%06d'%self.main_loop.status['iterations_done']
        # (inpaint, denoise_sigma, X_true): unconditional, inpainting, denoising
        sample_configs = (
            (False, None, None),
            (True, None, self.X),
            (False, 1, self.X),
        )
        for inpaint, denoise_sigma, X_true in sample_configs:
            generate_samples(self.model, self.get_mu_sigma,
                n_samples=self.n_samples, inpaint=inpaint,
                denoise_sigma=denoise_sigma, X_true=X_true,
                base_fname_part1=fname_root, base_fname_part2=fname_suffix)


class PlotParameters(SimpleExtension):
    def __init__(self, model, blocks_model, path, **kwargs):
        """
        Extension that saves plots of every parameter of blocks_model.

        model -- diffusion model; only n_colors is read.
        blocks_model -- blocks Model whose parameters are plotted.
        path -- output directory.
        """
        super(PlotParameters, self).__init__(**kwargs)
        self.path = path
        self.model = model
        self.blocks_model = blocks_model

    def do(self, callback_name, *args):
        """Plot the current value of every parameter via plot_parameter."""

        import sys
        # NOTE(review): very large recursion limit kept from the original -- confirm need.
        sys.setrecursionlimit(10000000)

        print("plotting parameters")
        base_fname_part2 = '_batch%06d'%self.main_loop.status['iterations_done']
        # Use hierarchical names: a bare param.name (e.g. 'W') carries no '/'
        # components, so every parameter would map to the same filename.
        for param_name, param in self.blocks_model.get_parameter_dict().items():
            # drop the first two '/'-separated components (presumably the empty
            # root and the top-level brick -- confirm)
            filename_safe_name = '-'.join(param_name.split('/')[2:]).replace(' ', '_')
            base_fname_part1 = self.path + '/params-' + filename_safe_name
            # plot_parameter is the module-level helper; `viz` is not imported here.
            plot_parameter(param.get_value(), base_fname_part1, base_fname_part2,
                title=param_name, n_colors=self.model.n_colors)


class PlotGradients(SimpleExtension):
    def __init__(self, model, blocks_model, algorithm, X, path, **kwargs):
        """
        Extension that plots the gradient of the cost w.r.t. every parameter.

        model -- diffusion model; only n_colors is read.
        blocks_model -- blocks Model whose parameters' gradients are plotted.
        algorithm -- training algorithm providing .gradients (param -> gradient)
            and .inputs (symbolic inputs of the cost).
        X -- fixed input batch on which the gradients are evaluated.
        path -- output directory.
        """
        super(PlotGradients, self).__init__(**kwargs)
        self.path = path
        self.X = X
        self.model = model
        self.blocks_model = blocks_model
        # Model.parameters is a list of variables (no .keys()); get_parameter_dict()
        # maps hierarchical names to those variables.
        param_dict = self.blocks_model.get_parameter_dict()
        self.param_names = sorted(param_dict.keys())
        gradients = [algorithm.gradients[param_dict[name]] for name in self.param_names]
        self.grad_f = theano.function(algorithm.inputs, gradients, allow_input_downcast=True)

    def do(self, callback_name, *args):
        """Evaluate the gradients on self.X and plot each via plot_parameter."""
        print("plotting gradients")
        grad_vals = self.grad_f(self.X)
        base_fname_part2 = '_batch%06d'%self.main_loop.status['iterations_done']
        for param_name, val in zip(self.param_names, grad_vals):
            # drop the first two '/'-separated components (presumably the empty
            # root and the top-level brick -- confirm)
            filename_safe_name = '-'.join(param_name.split('/')[2:]).replace(' ', '_')
            base_fname_part1 = self.path + '/grads-' + filename_safe_name
            # plot_parameter is the module-level helper; `viz` is not imported here.
            plot_parameter(val, base_fname_part1, base_fname_part2,
                title="grad " + param_name, n_colors=self.model.n_colors)


class PlotInternalState(SimpleExtension):
    def __init__(self, model, blocks_model, state, features, X, path, **kwargs):
        """
        Extension that plots intermediate network values for a fixed batch.

        model -- diffusion model; only n_colors is read.
        blocks_model -- stored on the extension; not otherwise used here.
        state -- list of symbolic variables to evaluate and plot.
        features -- symbolic input variable the state is computed from.
        X -- numeric batch fed as `features`.
        path -- output directory.
        """
        super(PlotInternalState, self).__init__(**kwargs)
        self.path = path
        self.X = X
        self.model = model
        self.blocks_model = blocks_model
        self.internal_state_f = theano.function([features], state, allow_input_downcast=True)
        # An unnamed variable (name None) would break the string handling in do();
        # fall back to a positional name instead.
        self.internal_state_names = [
            var.name if var.name is not None else 'state%d' % ii
            for ii, var in enumerate(state)]

    def do(self, callback_name, *args):
        """Evaluate the state variables on self.X and plot each via plot_parameter."""
        print("plotting internal state of network")
        state = self.internal_state_f(self.X)
        base_fname_part2 = '_batch%06d'%self.main_loop.status['iterations_done']
        for param_name, val in zip(self.internal_state_names, state):
            filename_safe_name = param_name.replace(' ', '_').replace('/', '-')
            base_fname_part1 = self.path + '/state-' + filename_safe_name
            # plot_parameter is the module-level helper; `viz` is not imported here.
            plot_parameter(val, base_fname_part1, base_fname_part2,
                title="state " + param_name, n_colors=self.model.n_colors)


class PlotMonitors(SimpleExtension):
    def __init__(self, path, burn_in_iters=0, **kwargs):
        """
        Extension that plots the logged 'cost*', 'train*' and 'test*' channels.

        path -- output directory for the PNG files.
        burn_in_iters -- once the log extends past this index, rows before it
            are dropped from the plots.
        """
        super(PlotMonitors, self).__init__(**kwargs)
        self.path = path
        self.burn_in_iters = burn_in_iters

    def do(self, callback_name, *args):
        """Write a per-channel subplot figure and a combined figure as PNGs."""
        print("plotting monitors")
        try:
            df = self.main_loop.log.to_dataframe()
        except AttributeError:
            # This started breaking after a Blocks update.
            print("Failed to generate monitoring plots due to Blocks interface change.")
            return
        # Nothing logged yet -- nothing to plot.
        if len(df) == 0:
            return
        # Throw out the first burn_in values
        # as the objective is often much larger
        # in that period. Compare the last index as a scalar: the Index
        # returned by df.tail(1).index is array-valued in an `if`.
        if df.index[-1] > self.burn_in_iters:
            df = df.loc[self.burn_in_iters:]
        cols = [col for col in df.columns if col.startswith(('cost', 'train', 'test'))]
        # No matching channels -- nothing to plot.
        if not cols:
            return
        df = df[cols].interpolate(method='linear')

        # If we don't have any non-nan dataframes, don't plot
        if len(df) == 0:
            return
        try:
            axs = df.plot(subplots=True, legend=False, figsize=(5, len(cols)*2))
        except TypeError:
            # This started breaking after a different Blocks update.
            print("Failed to generate monitoring plots due to Blocks interface change.")
            plt.close('all')
            return

        for ax, cname in zip(axs, cols):
            ax.set_title(cname)
        fn = os.path.join(self.path,
            'monitors_subplots_batch%06d.png' % self.main_loop.status['iterations_done'])
        plt.savefig(fn, bbox_inches='tight')

        plt.clf()
        df.plot(subplots=False, figsize=(15,10))
        plt.gcf().tight_layout()
        fn = os.path.join(self.path,
            'monitors_batch%06d.png' % self.main_loop.status['iterations_done'])
        plt.savefig(fn, bbox_inches='tight')
        plt.close('all')


class LogLikelihood(SimpleExtension):
    def __init__(self, model, test_stream, rescale, num_eval_batches=10000, **kwargs):
        """
        Compute and print log likelihood lower bound on test dataset.
        The do() function is called as an extension during training.

        model -- diffusion model providing a symbolic cost(features).
        test_stream -- stream whose epoch iterator yields batches; the first
            element of each batch is fed to the cost.
        rescale -- Z-score rescale factor; only reported in the printout.
        num_eval_batches -- number of test batches evaluated per call; the
            stream is restarted whenever an epoch is exhausted.
        """
        super(LogLikelihood, self).__init__(**kwargs)
        self.model = model
        self.test_stream = test_stream
        self.rescale = rescale
        self.num_eval_batches = num_eval_batches

        features = T.matrix('features', dtype=theano.config.floatX)
        cost = self.model.cost(features)

        self.L_gap_func = theano.function([features,], cost,
            allow_input_downcast=True)

    def print_stats(self, L_gap):
        """Print the mean of L_gap and its standard error."""
        larr = np.array(L_gap)
        mn = np.mean(larr)
        # The sample standard deviation (ddof=1) is undefined for a single value;
        # report NaN directly instead of triggering numpy's ddof warning.
        sd = np.std(larr, ddof=1) if len(L_gap) > 1 else np.nan
        stderr = sd / np.sqrt(len(L_gap))

        # The log likelihood lower bound, K, is reported for the data after Z-scoring it.
        # Z-score rescale is the multiplicative factor by which the data was rescaled, to
        # give it standard deviation 1.
        print("eval batch=%05d  (K-L_null)=%g bits/pix  standard error=%g bits/pix  Z-score rescale %g"%(
            len(L_gap), mn, stderr, self.rescale))

    def _next_batch(self, Xiter):
        """
        Return (X, Xiter) for the next test batch, starting a new epoch
        iterator when Xiter is None or exhausted.
        """
        if Xiter is not None:
            try:
                return next(Xiter)[0], Xiter
            except StopIteration:
                pass
        Xiter = self.test_stream.get_epoch_iterator()
        return next(Xiter)[0], Xiter

    def do(self, callback_name, *args):
        """Evaluate the bound on num_eval_batches test batches, printing progress."""
        L_gap = []
        Xiter = None
        for kk in range(self.num_eval_batches):
            X, Xiter = self._next_batch(Xiter)

            lg = -self.L_gap_func(X)
            L_gap.append(lg)

            if np.mod(kk, 1000) == 999:
                self.print_stats(L_gap)
        if L_gap:
            self.print_stats(L_gap)


def decay_learning_rate(iteration, old_value):
    """
    Return the next learning rate under a slow exponential decay with a floor.

    iteration -- not used; accepted to match the (iteration, old_value)
        callback signature (presumably blocks' SharedVariableModifier -- confirm).
    old_value -- current learning rate.

    Each call multiplies the rate by 0.1**(1/1000) -- a 10x reduction per 1000
    calls -- and clamps the result at 1e-4. The new value is printed and
    returned as np.float32.
    """
    # TODO the numbers in this function should not be hard coded
    lower_bound = 1e-4
    per_call_factor = np.exp(np.log(0.1)/1000.)

    proposed = per_call_factor*old_value
    new_value = lower_bound if proposed < lower_bound else proposed
    print("learning rate %g"%new_value)
    return np.float32(new_value)


### Model

#### Regression

In [None]:
"""
Defines the function approximators
"""

import numpy as np
import theano.tensor as T

from blocks.bricks import Activation, MLP, Initializable, application, Identity
from blocks.bricks.conv import Convolutional
from blocks.initialization import IsotropicGaussian, Constant, Orthogonal

# TODO IsotropicGaussian init will be wrong scale for some layers

class LeakyRelu(Activation):
    """Leaky rectifier: positive inputs pass through, others are scaled by 0.05."""

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_):
        negative_slope = 0.05
        is_positive = input_ > 0
        return T.switch(is_positive, input_, negative_slope * input_)

# Module-level LeakyRelu instances shared by the network bricks below.
dense_nonlinearity = LeakyRelu()
conv_nonlinearity = LeakyRelu()

class MultiScaleConvolution(Initializable):
    def __init__(self, num_channels, num_filters, spatial_width, num_scales, filter_size, downsample_method='meanout', name=""):
        """
        A brick implementing a single layer in a multi-scale convolutional network.

        num_channels -- number of input feature maps.
        num_filters -- number of output feature maps.
        spatial_width -- width (and height) of the full-resolution input.
        num_scales -- number of scales; scale s convolves the input after mean
            pooling by a factor of 2**s.
        filter_size -- side length of the square convolution kernels.
        downsample_method -- stored on the brick; downsample() always mean-pools.
        name -- prefix for the names of the child Convolutional bricks.
        """
        super(MultiScaleConvolution, self).__init__()

        self.num_scales = num_scales
        self.filter_size = filter_size
        self.num_filters = num_filters
        self.spatial_width = spatial_width
        self.downsample_method = downsample_method
        self.children = []

        print("adding MultiScaleConvolution layer")

        for scale in range(self.num_scales):
            print("scale %d"%scale)
            # Integer division: Python 3 `/` would produce float image sizes.
            scaled_width = spatial_width // 2**scale
            # Built without an activation; the leaky ReLU is applied in apply().
            conv_layer = Convolutional(
                filter_size=(filter_size,filter_size), num_filters=num_filters,
                num_channels=num_channels, image_size=(scaled_width, scaled_width),
                # assume images are spatially smooth -- in which case output magnitude scales with
                # # filter pixels rather than square root of # filter pixels, so initialize
                # accordingly.
                weights_init=IsotropicGaussian(std=np.sqrt(1./(num_filters))/filter_size**2),
                biases_init=Constant(0), border_mode='full', name=name+"scale%d"%scale)
            self.children.append(conv_layer)
        # Register the shared nonlinearity as a child since apply() calls it.
        # Appended after the conv layers so self.children[scale] indexing is unchanged.
        self.children.append(conv_nonlinearity)

    def downsample(self, imgs_in, scale):
        """
        Downsample an image by a factor of 2**scale using mean pooling.

        Expects a 4d tensor (images, channels, height, width) whose height and
        width are divisible by 2**scale, as required by the reshape below.
        """
        imgs = imgs_in.copy()

        if scale == 0:
            return imgs

        num_imgs = imgs.shape[0].astype('int16')
        num_layers = imgs.shape[1].astype('int16')
        nlx0 = imgs.shape[2].astype('int16')
        nlx1 = imgs.shape[3].astype('int16')

        scalepow = np.int16(2**scale)

        # split each spatial axis into (blocks, block size) and average over block size;
        # `//` keeps the symbolic shape integral (true division would give floats).
        imgs = imgs.reshape((num_imgs, num_layers, nlx0 // scalepow, scalepow, nlx1 // scalepow, scalepow))
        imgs = T.mean(imgs, axis=5)
        imgs = T.mean(imgs, axis=3)
        return imgs

    @application
    def apply(self, X):
        """
        Convolve X at every scale and average the results at full resolution.

        Scales are processed coarsest first; after adding scale s's output, the
        running sum is upsampled by 2 (pixel repetition) to match scale s-1.
        """

        print("MultiScaleConvolution apply")

        nsamp = X.shape[0].astype('int16')

        # 'full' convolution grows each spatial axis by filter_size - 1; crop
        # `overshoot` pixels from the leading edge and keep layer_width pixels.
        overshoot = (self.filter_size - 1) // 2
        imgs_accum = 0 # accumulate the output image
        for scale in range(self.num_scales-1, -1, -1):
            layer_width = self.spatial_width // 2**scale
            # downsample image to appropriate scale
            imgs_down = self.downsample(X, scale)
            # do a convolutional transformation on it
            conv_layer = self.children[scale]
            # NOTE this is different than described in the paper, since each conv_layer
            # includes a nonlinearity -- it's not just one nonlinearity at the end
            imgs_down_conv = conv_nonlinearity.apply(conv_layer.apply(imgs_down))

            # crop the edge so it's the same size as the input at that scale. An explicit
            # stop index is used since `overshoot:-overshoot` is empty when overshoot == 0.
            imgs_down_conv_cropped = imgs_down_conv[:, :,
                overshoot:overshoot + layer_width, overshoot:overshoot + layer_width]
            imgs_accum += imgs_down_conv_cropped

            if scale > 0:
                # scale up by factor of 2
                imgs_accum = imgs_accum.reshape((nsamp, self.num_filters, layer_width, 1, layer_width, 1))
                imgs_accum = T.concatenate((imgs_accum, imgs_accum), axis=5)
                imgs_accum = T.concatenate((imgs_accum, imgs_accum), axis=3)
                imgs_accum = imgs_accum.reshape((nsamp, self.num_filters, layer_width*2, layer_width*2))

        return imgs_accum/self.num_scales


class MultiLayerConvolution(Initializable):
    def __init__(self, n_layers, n_hidden, spatial_width, n_colors, n_scales, filter_size=3):
        """
        A brick stacking n_layers MultiScaleConvolution layers.

        The first layer reads n_colors input channels; every layer produces
        n_hidden feature maps, which the following layer consumes.
        """
        super(MultiLayerConvolution, self).__init__()

        self.children = []
        for layer_idx in range(n_layers):
            in_channels = n_colors if layer_idx == 0 else n_hidden
            layer = MultiScaleConvolution(in_channels, n_hidden, spatial_width,
                n_scales, filter_size, name="layer%d_" % layer_idx)
            self.children.append(layer)

    @application
    def apply(self, X):
        """Pass X through every layer in order and return the final output."""
        output = X
        for layer in self.children:
            output = layer.apply(output)
        return output

class MLP_conv_dense(Initializable):
    def __init__(self, n_layers_conv, n_layers_dense_lower, n_layers_dense_upper,
        n_hidden_conv, n_hidden_dense_lower, n_hidden_dense_lower_output, n_hidden_dense_upper,
        spatial_width, n_colors, n_scales, n_temporal_basis):
        """
        The multilayer perceptron, that provides temporal weighting coefficients for mu and sigma
        images. This consists of a lower segment with a convolutional MLP, and optionally with a
        dense MLP in parallel. The upper segment then consists of a per-pixel dense MLP
        (convolutional MLP with 1x1 kernel).

        The dense lower MLP is built only when both n_hidden_dense_lower and
        n_layers_dense_lower are positive; it has n_layers_dense_lower layers.
        The upper MLP outputs n_colors*n_temporal_basis*2 values per pixel.
        """
        super(MLP_conv_dense, self).__init__()

        self.n_colors = n_colors
        self.spatial_width = spatial_width
        self.n_hidden_dense_lower = n_hidden_dense_lower
        self.n_hidden_dense_lower_output = n_hidden_dense_lower_output
        self.n_hidden_conv = n_hidden_conv
        # apply() must use the same condition as construction, otherwise it would
        # reference a dense lower MLP that was never built.
        self.use_dense_lower = n_hidden_dense_lower > 0 and n_layers_dense_lower > 0

        ## the lower layers
        self.mlp_conv = MultiLayerConvolution(n_layers_conv, n_hidden_conv, spatial_width, n_colors, n_scales)
        self.children = [self.mlp_conv]
        if self.use_dense_lower:
            n_input = n_colors*spatial_width**2
            n_output = n_hidden_dense_lower_output*spatial_width**2
            # depth is n_layers_dense_lower (not n_layers_conv)
            self.mlp_dense_lower = MLP([dense_nonlinearity] * n_layers_dense_lower,
                [n_input] + [n_hidden_dense_lower] * (n_layers_dense_lower-1) + [n_output],
                name='MLP dense lower', weights_init=Orthogonal(), biases_init=Constant(0))
            self.children.append(self.mlp_dense_lower)
        else:
            # no dense branch: the upper MLP only sees the conv features
            n_hidden_dense_lower_output = 0

        ## the upper layers (applied to each pixel independently)
        n_output = n_colors*n_temporal_basis*2 # "*2" for both mu and sigma
        self.mlp_dense_upper = MLP([dense_nonlinearity] * (n_layers_dense_upper-1) + [Identity()],
            [n_hidden_conv+n_hidden_dense_lower_output] +
            [n_hidden_dense_upper] * (n_layers_dense_upper-1) + [n_output],
            name='MLP dense upper', weights_init=Orthogonal(), biases_init=Constant(0))
        self.children.append(self.mlp_dense_upper)

    @application
    def apply(self, X):
        """
        Take in noisy input image and output temporal coefficients for mu and sigma.

        X is reshaped below to (n_images, n_colors*spatial_width**2), so it is
        assumed to be (n_images, n_colors, spatial_width, spatial_width).
        """
        Y = self.mlp_conv.apply(X)
        # move the feature axis last so the upper MLP acts on each pixel
        Y = Y.dimshuffle(0,2,3,1)
        if self.use_dense_lower:
            n_images = X.shape[0].astype('int16')
            X = X.reshape((n_images, self.n_colors*self.spatial_width**2))
            Y_dense = self.mlp_dense_lower.apply(X)
            Y_dense = Y_dense.reshape((n_images, self.spatial_width, self.spatial_width,
                self.n_hidden_dense_lower_output))
            # scale each branch by 1/sqrt(width) before concatenating features
            Y = T.concatenate([Y/T.sqrt(self.n_hidden_conv),
                Y_dense/T.sqrt(self.n_hidden_dense_lower_output)], axis=3)
        Z = self.mlp_dense_upper.apply(Y)
        return Z


#### Diffusion Model

In [None]:
"""
This is the heart of the algorithm. Implements the objective function and mu
and sigma estimators for a Gaussian diffusion probabilistic model
"""

import numpy as np
import theano
import theano.tensor as T

from blocks.bricks import application, Initializable, Random

# import regression
# import util

class DiffusionModel(Initializable):
    """
    Gaussian diffusion probabilistic model.

    Builds the Theano graph for the training objective of the model in
    Sohl-Dickstein et al. (ICML 2015). A fixed forward trajectory gradually
    mixes Gaussian noise into the data. A learned reverse trajectory has a
    per-pixel mean and variance, read out from ``self.mlp`` through a set of
    temporal bump basis functions.
    """

    def __init__(self,
            spatial_width,
            n_colors,
            trajectory_length=1000,
            n_temporal_basis=10,
            n_hidden_dense_lower=500,
            n_hidden_dense_lower_output=2,
            n_hidden_dense_upper=20,
            n_hidden_conv=20,
            n_layers_conv=4,
            n_layers_dense_lower=4,
            n_layers_dense_upper=2,
            n_t_per_minibatch=1,
            n_scales=1,
            step1_beta=0.001,
            uniform_noise=0,
            ):
        """
        Implements the objective function and mu and sigma estimators for a Gaussian diffusion
        probabilistic model, as described in the paper:
            Deep Unsupervised Learning using Nonequilibrium Thermodynamics
            Jascha Sohl-Dickstein, Eric A. Weiss, Niru Maheswaranathan, Surya Ganguli
            International Conference on Machine Learning, 2015

        Parameters are as follow:
        spatial_width - Width (and height) of the square training images.
        n_colors - Number of color channels in training data.
        trajectory_length - The number of time steps in the trajectory.
        n_temporal_basis - The number of temporal basis functions used to capture the
            time-step dependence of the model (any value >= 1).
        n_hidden_dense_lower - The number of hidden units in each layer of the dense network
            in the lower half of the MLP. Set to 0 to make a convolutional-only lower half.
        n_hidden_dense_lower_output - The number of outputs *per pixel* from the dense network
            in the lower half of the MLP. Total outputs are
            n_hidden_dense_lower_output*spatial_width**2.
        n_hidden_dense_upper - The number of hidden units per pixel in the upper half of the MLP.
        n_hidden_conv - The number of feature layers in the convolutional layers in the lower half
            of the MLP.
        n_layers_conv - How many convolutional layers to use in the lower half of the MLP.
        n_layers_dense_lower - How many dense layers to use in the lower half of the MLP.
        n_layers_dense_upper - How many dense layers to use in the upper half of the MLP.
        n_t_per_minibatch - When computing objective, how many random time-steps t to evaluate
            each minibatch at.
        n_scales - Passed unchanged to MLP_conv_dense (presumably the number of spatial
            scales used by its convolutional layers -- see that class).
        step1_beta - The lower bound on the noise variance of the first diffusion step. This is
            the minimum variance of the learned model.
        uniform_noise - Add uniform noise between [-uniform_noise/2, uniform_noise/2] to the input.
        """
        super(DiffusionModel, self).__init__()

        self.n_t_per_minibatch = n_t_per_minibatch
        self.spatial_width = np.int16(spatial_width)
        self.n_colors = np.int16(n_colors)
        self.n_temporal_basis = n_temporal_basis
        self.trajectory_length = trajectory_length
        self.uniform_noise = uniform_noise

        # network mapping a noisy image to per-pixel temporal-basis coefficients
        self.mlp = MLP_conv_dense(
            n_layers_conv, n_layers_dense_lower, n_layers_dense_upper,
            n_hidden_conv, n_hidden_dense_lower, n_hidden_dense_lower_output, n_hidden_dense_upper,
            spatial_width, n_colors, n_scales, n_temporal_basis)
        # temporal_basis must be built before beta_arr: generate_beta_arr reads it
        self.temporal_basis = self.generate_temporal_basis(trajectory_length, n_temporal_basis)
        self.beta_arr = self.generate_beta_arr(step1_beta)
        self.children = [self.mlp]


    def generate_beta_arr(self, step1_beta):
        """
        Generate the noise covariances, beta_t, for the forward trajectory.

        Returns a symbolic (trajectory_length, 1) column. Its entries lie between
        min_beta and 1 - 1e-5. min_beta is 1e-6 for every step, plus step1_beta
        for the first step.
        """
        # lower bound on beta
        min_beta_val = 1e-6
        min_beta_values = np.ones((self.trajectory_length,))*min_beta_val
        min_beta_values[0] += step1_beta
        min_beta = theano.shared(value=min_beta_values.astype(theano.config.floatX),
            name='min beta')
        # (potentially learned) function for how beta changes with timestep
        # TODO add beta_perturb_coefficients to the parameters to be learned
        # (they start at zero, so for now beta_perturb is zero)
        beta_perturb_coefficients_values = np.zeros((self.n_temporal_basis,))
        beta_perturb_coefficients = theano.shared(
            value=beta_perturb_coefficients_values.astype(theano.config.floatX),
            name='beta perturb coefficients')
        beta_perturb = T.dot(self.temporal_basis.T, beta_perturb_coefficients)
        # baseline behavior of beta with time -- destroy a constant fraction
        # of the original data variance each time step
        # NOTE 2 below means a fraction ~1/T of the variance will be left at the end of the
        # trajectory
        beta_baseline = 1./np.linspace(self.trajectory_length, 2., self.trajectory_length)
        beta_baseline_offset = logit_np(beta_baseline).astype(theano.config.floatX)
        # and the actual beta_t, restricted to be between min_beta and 1-[small value]
        beta_arr = T.nnet.sigmoid(beta_perturb + beta_baseline_offset)
        beta_arr = min_beta + beta_arr * (1 - min_beta - 1e-5)
        beta_arr = beta_arr.reshape((self.trajectory_length, 1))
        return beta_arr


    def get_t_weights(self, t):
        """
        Generate vector of weights allowing selection of current timestep.
        (if t is not an integer, the weights will linearly interpolate)

        t must hold a single (possibly fractional) time index; the reshape to
        (trajectory_length, 1) below fails otherwise. Returns a
        (trajectory_length, 1) column with entries max(1 - |t - k|, 0).
        """
        n_seg = self.trajectory_length
        t_compare = T.arange(n_seg, dtype=theano.config.floatX).reshape((1,n_seg))
        diff = abs(T.addbroadcast(t,1) - T.addbroadcast(t_compare,0))
        t_weights = T.max(T.join(1, (-diff+1).reshape((n_seg,1)), T.zeros((n_seg,1))), axis=1)
        return t_weights.reshape((-1,1))


    def get_beta_forward(self, t):
        """
        Get the covariance of the forward diffusion process at timestep
        t.

        Returns a (1, 1) symbolic value: beta_arr weighted by get_t_weights(t).
        """
        t_weights = self.get_t_weights(t)
        return T.dot(t_weights.T, self.beta_arr)


    def get_mu_sigma(self, X_noisy, t):
        """
        Generate mu and sigma for one step in the reverse trajectory,
        starting from a minibatch of images X_noisy, and at timestep t.

        Returns (mu, sigma): the per-pixel mean and standard deviation of the
        model's reverse transition.
        """
        Z = self.mlp.apply(X_noisy)
        mu_coeff, beta_coeff = self.temporal_readout(Z, t)
        # reverse variance is perturbation around forward variance
        beta_forward = self.get_beta_forward(t)
        # make impact of beta_coeff scaled appropriately with mu_coeff
        beta_coeff_scaled = beta_coeff / np.sqrt(self.trajectory_length).astype(theano.config.floatX)
        beta_reverse = T.nnet.sigmoid(beta_coeff_scaled + logit(beta_forward))
        # # reverse mean is decay towards mu_coeff
        # mu = (X_noisy - mu_coeff)*T.sqrt(1. - beta_reverse) + mu_coeff
        # reverse mean is a perturbation around the mean under forward
        # process

        # # DEBUG -- use these lines to test objective is 0 for isotropic Gaussian model
        # beta_reverse = beta_forward
        # mu_coeff = mu_coeff*0

        mu = X_noisy*T.sqrt(1. - beta_forward) + mu_coeff*T.sqrt(beta_forward)
        sigma = T.sqrt(beta_reverse)
        mu.name = 'mu p'
        sigma.name = 'sigma p'
        return mu, sigma


    def generate_forward_diffusion_sample(self, X_noiseless):
        """
        Corrupt a training image with t steps worth of Gaussian noise, and
        return the corrupted image, as well as the mean and covariance of the
        posterior q(x^{t-1}|x^t, x^0).

        X_noiseless is reshaped to (n_images, n_colors, spatial_width, spatial_width).
        A single timestep t, shared by the whole minibatch, is drawn from
        {1, ..., trajectory_length-1}.

        Returns (X_noisy, t, mu, sigma):
        - X_noisy: the corrupted images.
        - t: the (1, 1) timestep.
        - mu: the per-pixel posterior mean.
        - sigma: the posterior standard deviation, shaped (1, 1, 1, 1) and
          shared by every pixel.
        """

        X_noiseless = X_noiseless.reshape(
            (-1, self.n_colors, self.spatial_width, self.spatial_width))

        n_images = X_noiseless.shape[0].astype('int16')
        rng = Random().theano_rng
        # choose a timestep in [1, self.trajectory_length-1].
        # note the reverse process is fixed for the very
        # first timestep, so we skip it.
        # TODO for some reason random_integer is missing from the Blocks
        # theano random number generator.
        t = T.floor(rng.uniform(size=(1,1), low=1, high=self.trajectory_length,
            dtype=theano.config.floatX))
        t_weights = self.get_t_weights(t)
        N = rng.normal(size=(n_images, self.n_colors, self.spatial_width, self.spatial_width),
            dtype=theano.config.floatX)

        # noise added this time step
        beta_forward = self.get_beta_forward(t)
        # decay in noise variance due to original signal this step
        alpha_forward = 1. - beta_forward
        # compute total decay in the fraction of the variance due to X_noiseless
        alpha_arr = 1. - self.beta_arr
        alpha_cum_forward_arr = T.extra_ops.cumprod(alpha_arr).reshape((self.trajectory_length,1))
        alpha_cum_forward = T.dot(t_weights.T, alpha_cum_forward_arr)

        # generate the corrupted training data
        # (first add uniform noise in [-uniform_noise/2, uniform_noise/2])
        X_uniformnoise = X_noiseless + (rng.uniform(size=(n_images, self.n_colors, self.spatial_width, self.spatial_width),
            dtype=theano.config.floatX)-T.constant(0.5,dtype=theano.config.floatX))*T.constant(self.uniform_noise,dtype=theano.config.floatX)
        X_noisy = X_uniformnoise*T.sqrt(alpha_cum_forward) + N*T.sqrt(1. - alpha_cum_forward)

        # compute the mean and covariance of the posterior distribution.
        # It is the product of two Gaussians in x^{t-1}:
        #   - x^0 diffused t-1 steps, with variance cov1;
        #   - the single step x^{t-1} -> x^t inverted, with variance cov2.
        # lam is their summed precision.
        mu1_scl = T.sqrt(alpha_cum_forward / alpha_forward)
        mu2_scl = 1. / T.sqrt(alpha_forward)
        cov1 = 1. - alpha_cum_forward/alpha_forward
        cov2 = beta_forward / alpha_forward
        lam = 1./cov1 + 1./cov2
        mu = (
                X_uniformnoise * mu1_scl / cov1 +
                X_noisy * mu2_scl / cov2
            ) / lam
        sigma = T.sqrt(1./lam)
        sigma = sigma.reshape((1,1,1,1))

        mu.name = 'mu q posterior'
        sigma.name = 'sigma q posterior'
        X_noisy.name = 'X_noisy'
        t.name = 't'

        return X_noisy, t, mu, sigma


    def get_beta_full_trajectory(self):
        """
        Return the cumulative covariance from the entire forward trajectory.

        This is 1 - prod_t(1 - beta_t), computed in log space.
        """
        alpha_arr = 1. - self.beta_arr
        beta_full_trajectory = 1. - T.exp(T.sum(T.log(alpha_arr)))
        return beta_full_trajectory


    def get_negL_bound(self, mu, sigma, mu_posterior, sigma_posterior):
        """
        Compute the lower bound on the log likelihood, as a function of mu and
        sigma from the reverse diffusion process, and the posterior mu and
        sigma from the forward diffusion process.

        Returns the difference between this bound and the log likelihood
        under a unit norm isotropic Gaussian. So this function returns how
        much better the diffusion model is than an isotropic Gaussian.
        The result is in bits: the mean over all entries, multiplied by n_colors.
        """
        # differential entropy of a unit-variance Gaussian: 0.5*(1 + log(2*pi))
        gauss_entropy = (0.5*(1 + np.log(2.*np.pi))).astype(theano.config.floatX)

        # the KL divergence between model transition and posterior from data
        KL = (  T.log(sigma) - T.log(sigma_posterior)
                + (sigma_posterior**2 + (mu_posterior-mu)**2)/(2*sigma**2)
                - 0.5)
        # conditional entropies: H_startpoint is H_q(x^1|x^0),
        # H_endpoint is H_q(x^T|x^0)
        H_startpoint = gauss_entropy + 0.5*T.log(self.beta_arr[0])
        H_endpoint = gauss_entropy + 0.5*T.log(self.get_beta_full_trajectory())
        H_prior = gauss_entropy + 0.5*T.log(1.)
        # the KL of the single sampled timestep is scaled by trajectory_length
        # to stand in for the sum over all timesteps
        negL_bound = KL*self.trajectory_length + H_startpoint - H_endpoint + H_prior
        # the negL_bound if this was an isotropic Gaussian model of the data
        negL_gauss = gauss_entropy + 0.5*T.log(1.)
        negL_diff = negL_bound - negL_gauss
        L_diff_bits = negL_diff / T.log(2.)
        L_diff_bits_avg = L_diff_bits.mean()*self.n_colors
        return L_diff_bits_avg


    def cost_single_t(self, X_noiseless):
        """
        Compute the lower bound on the log likelihood, given a training minibatch, for a single
        randomly chosen timestep.
        """
        X_noisy, t, mu_posterior, sigma_posterior = \
            self.generate_forward_diffusion_sample(X_noiseless)
        mu, sigma = self.get_mu_sigma(X_noisy, t)
        negL_bound = self.get_negL_bound(mu, sigma, mu_posterior, sigma_posterior)
        return negL_bound


    def internal_state(self, X_noiseless):
        """
        Return a bunch of the internal state, for monitoring purposes during optimization.

        Returns [mu_diff, logratio, mu, sigma, mu_posterior, sigma_posterior,
        X_noiseless, X_noisy] for one freshly sampled timestep.
        """
        X_noisy, t, mu_posterior, sigma_posterior = \
            self.generate_forward_diffusion_sample(X_noiseless)
        mu, sigma = self.get_mu_sigma(X_noisy, t)
        mu_diff = mu-mu_posterior
        mu_diff.name = 'mu diff'
        logratio = T.log(sigma/sigma_posterior)
        logratio.name = 'log sigma ratio'
        return [mu_diff, logratio, mu, sigma, mu_posterior, sigma_posterior, X_noiseless, X_noisy]


    # \@application # TODO: Is this needed?
    def cost(self, X_noiseless):
        """
        Compute the lower bound on the log likelihood, given a training minibatch.

        Calls cost_single_t n_t_per_minibatch times. Each call samples its own
        random timestep. Returns the average of the results.
        """
        cost = 0.
        for ii in range(self.n_t_per_minibatch):
            cost += self.cost_single_t(X_noiseless)
        return cost/self.n_t_per_minibatch


    def temporal_readout(self, Z, t):
        """
        Go from the top layer of the multilayer perceptron to coefficients for
        mu and sigma for each pixel.
        Z contains coefficients for the temporal basis functions for each pixel for
        both mu and sigma.

        Z is reshaped to
        (n_images, spatial_width, spatial_width, n_colors, 2, n_temporal_basis).
        Its last axis is contracted with the basis functions evaluated at t.
        Returns (mu_coeff, beta_coeff), each arranged as
        (n_images, n_colors, spatial_width, spatial_width).
        """
        n_images = Z.shape[0].astype('int16')
        t_weights = self.get_t_weights(t)
        Z = Z.reshape((n_images, self.spatial_width, self.spatial_width,
            self.n_colors, 2, self.n_temporal_basis))
        # value of each basis function at timestep t: (n_temporal_basis, 1)
        coeff_weights = T.dot(self.temporal_basis, t_weights)
        concat_coeffs = T.dot(Z, coeff_weights)
        mu_coeff = concat_coeffs[:,:,:,:,0].dimshuffle(0,3,1,2)
        beta_coeff = concat_coeffs[:,:,:,:,1].dimshuffle(0,3,1,2)
        return mu_coeff, beta_coeff


    def generate_temporal_basis(self, trajectory_length, n_basis):
        """
        Generate the bump basis functions for temporal readout of mu and sigma.

        Returns a Theano shared variable of shape (n_basis, trajectory_length).
        Its rows are Gaussian bumps with evenly spaced centres over the
        trajectory, normalised so that at every timestep the n_basis values sum
        to one. With n_basis == 1 this is a single constant row of ones.
        """
        temporal_basis = np.zeros((trajectory_length, n_basis))
        xx = np.linspace(-1, 1, trajectory_length)
        x_centers = np.linspace(-1, 1, n_basis)
        if n_basis > 1:
            # bump width: half the spacing between neighbouring centres
            width = (x_centers[1] - x_centers[0])/2.
        else:
            # a single centre has no neighbour spacing (x_centers[1] would raise
            # IndexError). Any positive width works: the normalisation below
            # turns the lone bump into a constant 1.
            width = 1.
        for ii in range(n_basis):
            temporal_basis[:,ii] = np.exp(-(xx-x_centers[ii])**2 / (2*width**2))
        temporal_basis /= np.sum(temporal_basis, axis=1).reshape((-1,1))
        temporal_basis = temporal_basis.T

        temporal_basis_theano = theano.shared(value=temporal_basis.astype(theano.config.floatX),
            name="temporal basis")
        return temporal_basis_theano



### Run

#### Prep model run

In [None]:
# Create the directory that Fuel downloads MNIST into and reads it from.
!mkdir -p /content/mnist

In [None]:
# Sanity check: list the current working directory.
!ls

In [None]:
!export FUEL_DATA_PATH="/content/mnist/"
!export data_path="/content/mnist/"
# !export FUEL_DATA_PATH="/content/mnist/:/second/path/to/my/data"
# ~/.fuelrc
!echo "data_path: \"/content/mnist/\"" > ~/.fuelrc
!echo "floatX: int16" >> ~/.fuelrc

In [None]:
# Show the Fuel configuration file (~/.fuelrc).
!cat ~/.fuelrc

In [None]:
import theano
from fuel import config

config.data_path = "/content/mnist/"
# Fuel's MNIST default transformers rescale the uint8 pixels to [0, 1] and then
# cast them to `config.floatX`. An integer dtype such as 'int16' truncates
# nearly every pixel to 0, destroying the data. Use Theano's float type, which
# is also the dtype the `features` input is declared with
# (T.matrix('features', dtype=theano.config.floatX)).
config.floatX = theano.config.floatX

In [None]:
# Download the raw MNIST files into /content/mnist/.
!fuel-download mnist -d /content/mnist/

In [None]:
# Convert the raw files into Fuel's HDF5 dataset file, written to /content/mnist/.
!fuel-convert mnist -d /content/mnist/ -o /content/mnist/

In [None]:
# Print a summary of the converted dataset file.
!fuel-info /content/mnist/mnist.hdf5

#### Args

In [None]:
# Args to allow for easy conversion of python script to notebook
class Args:
    """
    Stand-in for the command-line arguments of the original training script.

    Every option is a keyword argument whose default is the value this notebook
    uses, so individual settings can be overridden, e.g. ``Args(lr=1e-4)``.
    """

    def __init__(self, batch_size=512, lr=1e-3, resume_file=None, suffix='',
                 output_dir='./', ext_every_n=25, model_args='', dropout_rate=0,
                 dataset='MNIST', plot_before_training=False):
        self.batch_size = batch_size
        self.lr = lr
        # checkpoint passed to blocks' continue_training; None disables resuming
        self.resume_file = resume_file
        self.suffix = suffix
        self.output_dir = output_dir
        # periodic extensions run every ext_every_n * batches_per_epoch batches
        self.ext_every_n = ext_every_n
        # extra DiffusionModel keyword arguments, e.g. "n_temporal_basis=10"
        self.model_args = model_args
        self.dropout_rate = dropout_rate
        self.dataset = dataset
        self.plot_before_training = plot_before_training

    def __str__(self):
        return str(self.__class__) + ": " + str(self.__dict__)

args = Args()
print(args)

#### Model Run

In [None]:

# TODO batches_per_epoch should not be hard coded
batches_per_epoch = 500
import sys
# NOTE(review): very high recursion limit, presumably so that deep Theano
# graphs can be pickled (e.g. by Checkpoint) -- confirm before lowering it.
sys.setrecursionlimit(10000000)

# NOTE(review): eval() executes arbitrary code. args.model_args is written in
# this notebook, so this is acceptable here, but never pass it untrusted text.
model_args = eval('dict(' + args.model_args + ')')
print(model_args)

if args.resume_file is not None:
    print("Resuming training from " + args.resume_file)
    from blocks.scripts import continue_training
    continue_training(args.resume_file)
    # NOTE(review): execution does not stop here. The code below still builds a
    # fresh model, and later cells train it. Confirm this is intended.

## load the training data
if args.dataset == 'MNIST':
    from fuel.datasets import MNIST
    dataset_train = MNIST(['train'], sources=('features',))
    dataset_test = MNIST(['test'], sources=('features',))
    n_colors = 1
    spatial_width = 28
# # elif args.dataset == 'CIFAR10':
# #     from fuel.datasets import CIFAR10
# #     dataset_train = CIFAR10(['train'], sources=('features',))
# #     dataset_test = CIFAR10(['test'], sources=('features',))
# #     n_colors = 3
# #     spatial_width = 32
# # elif args.dataset == 'IMAGENET':
# #     from imagenet_data import IMAGENET
# #     spatial_width = 128
# #     dataset_train = IMAGENET(['train'], width=spatial_width)
# #     dataset_test = IMAGENET(['test'], width=spatial_width)
# #     n_colors = 3
else:
    raise ValueError("Unknown dataset %s."%args.dataset)

train_stream = Flatten(DataStream.default_stream(dataset_train,
                          iteration_scheme=ShuffledScheme(
                              examples=dataset_train.num_examples,
                              batch_size=args.batch_size)))
test_stream = Flatten(DataStream.default_stream(dataset_test,
                          iteration_scheme=ShuffledScheme(
                              examples=dataset_test.num_examples,
                              batch_size=args.batch_size))
                          )

shp = next(train_stream.get_epoch_iterator())[0].shape

# make the training data 0 mean and variance 1
# TODO compute mean and variance on full dataset, not minibatch
Xbatch = next(train_stream.get_epoch_iterator())[0]
batch_std = np.sqrt(np.mean((Xbatch-np.mean(Xbatch))**2))
if not np.isfinite(batch_std) or batch_std == 0:
    # Otherwise scl would be inf and shft nan, silently poisoning training.
    raise ValueError("Cannot normalise the data: the sampled minibatch has zero or "
                     "non-finite variance (check the Fuel floatX dtype).")
scl = 1./batch_std
shft = -np.mean(Xbatch*scl)
# scale is applied before shift: x -> x*scl + shft
train_stream = ScaleAndShift(train_stream, scl, shft)
test_stream = ScaleAndShift(test_stream, scl, shft)
baseline_uniform_noise = 1./255. # appropriate for MNIST and CIFAR10 Fuel datasets, which are scaled [0,1]
# The model sees the data in normalised units x*scl + shft. One 1/255
# quantisation bin of the [0, 1] data therefore has width
# baseline_uniform_noise*scl there. Dividing by scl instead would shrink the
# dequantisation noise by a factor of scl**2.
uniform_noise = baseline_uniform_noise*scl

## initialize the model
dpm = DiffusionModel(spatial_width, n_colors, uniform_noise=uniform_noise, **model_args)
dpm.initialize()

## set up optimization
features = T.matrix('features', dtype=theano.config.floatX)
cost = dpm.cost(features)



In [None]:
blocks_model = blocks.model.Model(cost)
cg_nodropout = ComputationGraph(cost)
if args.dropout_rate > 0:
    # DEBUG this triggers an error on my machine
    # apply dropout to all the input variables
    inputs = VariableFilter(roles=[INPUT])(cg_nodropout.variables)
    # dropconnect
    # inputs = VariableFilter(roles=[PARAMETER])(cg_nodropout.variables)
    cg = apply_dropout(cg_nodropout, inputs, args.dropout_rate)
else:
    cg = cg_nodropout
# Optimise the output of `cg`, not the original `cost`. apply_dropout returns
# a new graph, so differentiating `cost` would silently ignore the dropout.
# Without dropout, cg is ComputationGraph(cost), whose only output is `cost`.
train_cost = cg.outputs[0]
step_compute = RMSProp(learning_rate=args.lr, max_scaling=1e10)
algorithm = GradientDescent(step_rule=CompositeRule([RemoveNotFinite(),
    step_compute]),
    parameters=cg.parameters, cost=train_cost)
extension_list = []
# apply decay_learning_rate to the RMSProp learning rate every
# batches_per_epoch batches
extension_list.append(
    SharedVariableModifier(step_compute.learning_rate,
        decay_learning_rate,
        after_batch=False,
        every_n_batches=batches_per_epoch))
extension_list.append(FinishAfter(after_n_epochs=100001))

## logging of test set performance
extension_list.append(LogLikelihood(dpm, test_stream, scl,
    every_n_batches=args.ext_every_n*batches_per_epoch, before_training=False))


In [None]:
## set up logging
# Every periodic extension below fires on the same schedule and shares
# the same plotting options.
ext_every = args.ext_every_n*batches_per_epoch
plot_kwargs = dict(every_n_batches=ext_every,
                   before_training=args.plot_before_training)

extension_list.extend([Timing(), Printing()])
model_dir = create_log_dir(args, dpm.name + '_' + args.dataset)
model_save_name = os.path.join(model_dir, 'model.pkl')
extension_list.append(Checkpoint(
    model_save_name,
    every_n_batches=ext_every,
    save_separately=['log'],
))

# generate plots
extension_list.append(PlotMonitors(model_dir, **plot_kwargs))
test_batch = next(test_stream.get_epoch_iterator())[0]
extension_list.append(
    PlotSamples(dpm, algorithm, test_batch, model_dir, **plot_kwargs))

# `internal_state` and `train_batch` are only consumed by the commented-out
# extensions below. Both calls are kept because they may have side effects
# (e.g. advancing the random streams/schemes).
internal_state = dpm.internal_state(features)
train_batch = next(train_stream.get_epoch_iterator())[0]
# extension_list.append(
#     extensions.PlotInternalState(dpm, blocks_model, internal_state, features, train_batch, model_dir,
#         **plot_kwargs))
extension_list.append(
    PlotParameters(dpm, blocks_model, model_dir, **plot_kwargs))
# extension_list.append(
#     extensions.PlotGradients(dpm, blocks_model, algorithm, train_batch, model_dir,
#         **plot_kwargs))

# # console monitors
# # DEBUG -- incorporating train_monitor or test_monitor triggers a large number of
# # float64 vs float32 GPU warnings, although monitoring still works. I think this is a Blocks
# # bug. Uncomment this code to have more information during debugging/development.
# train_monitor_vars = [cost]
# norms, grad_norms = util.get_norms(blocks_model, algorithm.gradients)
# train_monitor_vars.extend(norms + grad_norms)
# train_monitor = TrainingDataMonitoring(
#     train_monitor_vars, prefix='train', after_batch=True, before_training=True)
# extension_list.append(train_monitor)
# test_monitor_vars = [cost]
# test_monitor = DataStreamMonitoring(test_monitor_vars, test_stream, prefix='test', before_training=True)
# extension_list.append(test_monitor)

In [None]:
## train
# Same very high recursion limit as used when the model was built.
sys.setrecursionlimit(10000000)
main_loop = MainLoop(
    model=blocks_model,
    algorithm=algorithm,
    data_stream=train_stream,
    extensions=extension_list,
)
main_loop.run()