In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

from survae import SurVAE, DEVICE
from survae.data import MNIST_784
from survae.layer import *
from survae.calibrate import *

In [2]:
# Instantiate the MNIST_784 dataset helper imported from survae.data.
# NOTE(review): the recorded run of this cell was interrupted (KeyboardInterrupt),
# so `mnist` was not bound in that session; let it finish before running the next cell.
mnist = MNIST_784()

KeyboardInterrupt: 

In [4]:
# Request 5 samples from the dataset helper (presumably 5 flattened images — TODO confirm).
# Depends on `mnist` from the previous cell; the recorded run raised NameError because
# that cell never completed.
X = mnist.sample(5)

NameError: name 'mnist' is not defined

In [5]:
# Sanity check of the batch layout; the recorded output is torch.Size([5, 784]).
X.shape

torch.Size([5, 784])

In [7]:
# Forward pass of the model on the sampled batch; the recorded output shape is [5, 49].
# NOTE(review): `sv_maxpool` is only defined further down in this notebook, so this cell
# fails on a fresh top-to-bottom run — move it below the model definition.
sv_maxpool(X).shape

torch.Size([5, 49])

In [5]:
# Deterministic test input: two flattened 28x28 "pictures" holding the values 0..1567.
# For the small pooling experiments further down, a (5, 2, 2) tensor built the same way
# was used instead:
#   torch.arange(5*2*2, dtype=torch.double).reshape(5, 2, 2)
Y = torch.arange(0, 2 * 28 * 28, dtype=torch.double).view(2, 28 * 28)

In [20]:
class MaxPoolingLayer(Layer):
    '''
    MaxPoolingLayer: surjective layer that max-pools a flattened square picture.

    forward:  reshapes each row of the input to (width, width), takes the maximum over
              every non-overlapping stride x stride window and returns the flattened result.
    backward: stochastic right-inverse. Every window of the output picture is filled with
              its maximum; one uniformly chosen cell per window keeps that value and every
              other cell gets the maximum minus a non-negative noise sample, so that
              forward(backward(Z)) reproduces Z.

    Attributes:
        size:      input size, i.e. length of the flattened picture (must be a perfect square)
        width:     input width, i.e. sqrt of size
        stride:    edge length of a pooling window (must divide width)
        out_width: width of the output picture, width // stride
        distribution: "half-normal" or "exponential", the noise family used in backward
        sigma / lam:  scale of the half-normal / rate of the exponential, stored as a
                      tensor of shape [1] (an nn.Parameter when learn_distribution_parameter)
        index_probs:  uniform probabilities over the stride**2 cells of one window

    Distribution choices:
        - standard half-normal distribution (default)
        - exponential distribution (rate 0.1)

    '''
    def __init__(self, size: int, stride: int, exponential_distribution: bool = False, learn_distribution_parameter: bool = False):
        super().__init__()

        # Keep the flattened input size around; in_size() reports it.
        self.size = int(size)

        # Recover the picture width and refuse sizes that are not perfect squares instead
        # of silently truncating the square root.
        self.width = int(round(float(np.sqrt(size))))
        assert self.width * self.width == self.size, "Size must be a perfect square!"

        assert stride > 0 and self.width % stride == 0, "Stride must be a divisor of the picture width!"
        self.stride = int(stride)

        self.out_width = self.width // self.stride

        if exponential_distribution:
            self.distribution = "exponential"
            lam = torch.tensor([0.1])
            if learn_distribution_parameter:
                lam = torch.nn.Parameter(lam)
            self.lam = lam
        else:
            self.distribution = "half-normal"
            # Must be a float tensor (not a Python int) so it can be wrapped in a Parameter.
            sigma = torch.tensor([1.0])
            if learn_distribution_parameter:
                sigma = torch.nn.Parameter(sigma)
            self.sigma = sigma

        # Uniform distribution over the cells of a pooling window.
        n_cells = self.stride ** 2
        self.index_probs = torch.full((n_cells,), 1 / n_cells)

    def forward(self, X: torch.Tensor, condition: torch.Tensor | None = None, return_log_likelihood: bool = False):
        '''
        Max-pool a batch of flattened pictures.

        X: tensor whose elements can be arranged as (batch, width, width)
        condition: unused by this layer
        return_log_likelihood: currently ignored, only the pooled tensor is returned.
            NOTE(review): confirm what the surrounding framework expects when this is True.

        Returns a tensor of shape (batch, out_width ** 2).
        '''
        X = X.reshape(-1, self.width, self.width)  # back to 2D pictures

        # One strided view per position inside a window; stacking them puts all members
        # of every window along dim 0, so the window maximum is a max over that dim.
        window_members = [
            X[:, i::self.stride, j::self.stride]
            for i in range(self.stride)
            for j in range(self.stride)
        ]
        Z, _ = torch.max(torch.stack(window_members, dim=0), dim=0)
        return Z.flatten(start_dim=1)

    def backward(self, Z: torch.Tensor, condition: torch.Tensor | None = None):
        '''
        Sample a picture whose max-pooling equals Z.

        Z: tensor whose elements can be arranged as (batch, out_width, out_width)
        condition: unused by this layer

        Returns a tensor of shape (batch, width ** 2).
        '''
        Z = Z.reshape(-1, self.out_width, self.out_width)
        device = Z.device

        # expand matrix containing local maxima (by repeating local max)
        X_hat = Z.repeat_interleave(self.stride, dim=2).repeat_interleave(self.stride, dim=1)

        # per window, draw which cell holds the maximum
        k = torch.distributions.categorical.Categorical(self.index_probs.to(device))
        indices = k.sample(Z.shape)

        # mask that is True exactly at the drawn cell of every window
        indices_repeated = indices.repeat_interleave(self.stride, dim=2).repeat_interleave(self.stride, dim=1)
        index_places = (
            torch.arange(self.stride ** 2, device=device)
            .reshape(self.stride, self.stride)
            .repeat(self.out_width, self.out_width)
        )
        index_mask = (index_places == indices_repeated)

        # sample offsets in (-infty, 0] with the respective distribution
        if self.distribution == "half-normal":
            distr = torch.distributions.half_normal.HalfNormal(self.sigma)
        else:
            distr = torch.distributions.exponential.Exponential(self.lam)
        # The parameter has shape [1], so sample() appends a trailing singleton dimension;
        # fold it away so the noise lines up element-wise with X_hat.
        samples = -distr.sample(X_hat.shape).reshape(X_hat.shape).to(device)

        # keep the maximum at the drawn cell, lower every other cell by its noise sample
        X_hat = torch.where(index_mask, X_hat, X_hat + samples)

        return X_hat.flatten(start_dim=1)

    def in_size(self) -> int | None:
        '''Length of one flattened input picture.'''
        return self.size

    def out_size(self) -> int | None:
        '''Length of one flattened output picture.'''
        return int(self.out_width ** 2)

In [21]:
# Model: dequantization, then three stages of four [bijective, orthonormal] pairs at
# 784, 196 and 49 dimensions, with a stride-2 max-pooling layer between consecutive stages.
# The pairs are passed as nested two-element lists, exactly as before.
sv_maxpool = SurVAE([
    DequantizationLayer(),
    *[[BijectiveLayer(784, [200, 200]), OrthonormalLayer(784)] for _ in range(4)],
    MaxPoolingLayer(784, 2),
    *[[BijectiveLayer(196, [200, 200]), OrthonormalLayer(196)] for _ in range(4)],
    MaxPoolingLayer(196, 2),
    *[[BijectiveLayer(49, [200, 200]), OrthonormalLayer(49)] for _ in range(4)],
])

In [24]:
# Round trip: forward pass of Y through the model, then the backward pass, and show the
# first resulting sample as a 28x28 picture.
# NOTE(review): the recorded values are on the order of 1e4 — presumably because the model
# is untrained at this point; confirm before relying on this output.
sv_maxpool.backward(sv_maxpool(Y))[0].reshape(28, 28)

tensor([[ 5.4760e+03,  1.5419e+04, -3.3234e+04,  2.4778e+04, -4.7010e+03,
          8.0860e+03, -2.7783e+04,  1.0110e+04, -2.4530e+04,  2.1184e+04,
          2.4026e+04,  6.5950e+03,  1.3030e+03, -6.2290e+03, -1.9951e+04,
         -1.4337e+04, -1.4210e+03,  3.6335e+04, -5.9600e+03,  3.4960e+03,
          2.1983e+04, -1.6537e+04,  1.7545e+04,  1.2338e+04, -4.1140e+03,
         -3.6878e+04,  1.5547e+04,  6.0040e+03],
        [-2.5227e+04, -1.8820e+03,  1.0649e+04,  2.0700e+03,  1.6230e+04,
          1.8970e+04, -3.8660e+03,  1.1942e+04,  2.3059e+04,  2.1246e+04,
         -6.0300e+02,  1.5166e+04,  2.1356e+04,  1.0899e+04, -2.8766e+04,
          6.8800e+02, -3.1908e+04, -2.9285e+04, -2.5784e+04, -2.9797e+04,
         -1.8754e+04,  5.3270e+03,  3.0624e+04, -2.0924e+04,  1.1736e+04,
         -1.2643e+04,  1.0156e+04, -6.1740e+03],
        [ 5.0310e+03,  2.4070e+04,  1.2751e+04, -6.0871e+04, -1.5415e+04,
          1.9605e+04,  8.5150e+03, -4.1980e+03, -1.8255e+04,  2.4185e+04,
         -4.14

In [41]:
# Toy setup for the index mask: uniform categorical over the stride**2 cells of one window.
stride = 3
k = torch.distributions.Categorical(probs=torch.full((stride**2,), 1 / stride**2))

In [39]:
# Five random 2x2 "pooled pictures" with entries in [0, 1).
X = torch.rand(5, 2, 2)

In [40]:
# Show the random 2x2 pictures created in the previous cell.
X

tensor([[[0.0113, 0.1925],
         [0.9855, 0.0136]],

        [[0.1857, 0.6647],
         [0.8401, 0.3900]],

        [[0.7625, 0.2858],
         [0.3856, 0.3626]],

        [[0.1868, 0.5256],
         [0.1865, 0.8363]],

        [[0.6340, 0.4767],
         [0.7391, 0.4692]]], device='cuda:0')

In [42]:
# Draw one cell index per output pixel (5 pictures of 2x2) from the categorical `k`;
# with stride = 3 the indices range over 0..8.
i_x = k.sample((5, 2, 2))
i_x

tensor([[[8, 7],
         [8, 8]],

        [[2, 1],
         [7, 3]],

        [[5, 1],
         [1, 1]],

        [[7, 1],
         [6, 8]],

        [[0, 1],
         [6, 4]]], device='cuda:0')

In [43]:
# Blow every sampled index up to a stride x stride window so it lines up with the
# un-pooled grid. Uses `stride` instead of a literal 3 so this stays consistent with
# `k` and `m` if the stride is changed.
j = i_x.repeat_interleave(stride, dim=2).repeat_interleave(stride, dim=1)
j

tensor([[[8, 8, 8, 7, 7, 7],
         [8, 8, 8, 7, 7, 7],
         [8, 8, 8, 7, 7, 7],
         [8, 8, 8, 8, 8, 8],
         [8, 8, 8, 8, 8, 8],
         [8, 8, 8, 8, 8, 8]],

        [[2, 2, 2, 1, 1, 1],
         [2, 2, 2, 1, 1, 1],
         [2, 2, 2, 1, 1, 1],
         [7, 7, 7, 3, 3, 3],
         [7, 7, 7, 3, 3, 3],
         [7, 7, 7, 3, 3, 3]],

        [[5, 5, 5, 1, 1, 1],
         [5, 5, 5, 1, 1, 1],
         [5, 5, 5, 1, 1, 1],
         [1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1]],

        [[7, 7, 7, 1, 1, 1],
         [7, 7, 7, 1, 1, 1],
         [7, 7, 7, 1, 1, 1],
         [6, 6, 6, 8, 8, 8],
         [6, 6, 6, 8, 8, 8],
         [6, 6, 6, 8, 8, 8]],

        [[0, 0, 0, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [6, 6, 6, 4, 4, 4],
         [6, 6, 6, 4, 4, 4],
         [6, 6, 6, 4, 4, 4]]], device='cuda:0')

In [70]:
# Window-local cell ids (0 .. stride**2 - 1, row-major), tiled 2x2 to cover the full grid.
m = torch.arange(stride * stride).view(stride, stride).repeat(2, 2)
m

tensor([[0, 1, 2, 0, 1, 2],
        [3, 4, 5, 3, 4, 5],
        [6, 7, 8, 6, 7, 8],
        [0, 1, 2, 0, 1, 2],
        [3, 4, 5, 3, 4, 5],
        [6, 7, 8, 6, 7, 8]], device='cuda:0')

In [71]:
# True exactly where a cell's window-local id equals the index drawn for its window
# (m broadcasts over the batch dimension of j).
mask = torch.eq(m, j)

In [72]:
# Inspect the mask: the recorded output shows one True per 3x3 window.
mask

tensor([[[False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False,  True, False,  True, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False,  True, False, False,  True]],

        [[False, False,  True, False,  True, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False,  True, False, False],
         [False,  True, False, False, False, False]],

        [[False, False, False, False,  True, False],
         [False, False,  True, False, False, False],
         [False, False, False, False, False, False],
         [False,  True, False, False,  True, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False]],

        [[False, False, False, False,  T

In [49]:
# Standard half-normal (scale 1): source of the non-negative noise.
distr = torch.distributions.HalfNormal(scale=1)

In [53]:
# One non-negative noise value per cell of the un-pooled grid (5 pictures of 6x6,
# i.e. the 2x2 outputs expanded with stride 3).
noise = distr.sample((5, 6, 6))

In [54]:
# Zero the noise at the drawn cells so those cells keep the window value unchanged.
# Note: this mutates `noise` in place.
noise[mask] = 0

In [64]:
# NOTE(review): the recorded output has shape (5, 2, 2) with values 0..19, i.e. it comes
# from the commented-out alternative definition of Y, not from the (2, 784) one.
Y

tensor([[[ 0.,  1.],
         [ 2.,  3.]],

        [[ 4.,  5.],
         [ 6.,  7.]],

        [[ 8.,  9.],
         [10., 11.]],

        [[12., 13.],
         [14., 15.]],

        [[16., 17.],
         [18., 19.]]], device='cuda:0')

In [65]:
# Show the sampled per-window indices again for comparison with Y.
i_x

tensor([[[8, 7],
         [8, 8]],

        [[2, 1],
         [7, 3]],

        [[5, 1],
         [1, 1]],

        [[7, 1],
         [6, 8]],

        [[0, 1],
         [6, 4]]], device='cuda:0')

In [68]:
# Expand every value of Y to a stride x stride window; subtracting `noise` (currently
# disabled by the trailing comment) would lower every cell except the drawn one.
Y.repeat_interleave(stride, dim=1).repeat_interleave(stride, dim=2)# - noise

tensor([[[ 0.,  0.,  0.,  1.,  1.,  1.],
         [ 0.,  0.,  0.,  1.,  1.,  1.],
         [ 0.,  0.,  0.,  1.,  1.,  1.],
         [ 2.,  2.,  2.,  3.,  3.,  3.],
         [ 2.,  2.,  2.,  3.,  3.,  3.],
         [ 2.,  2.,  2.,  3.,  3.,  3.]],

        [[ 4.,  4.,  4.,  5.,  5.,  5.],
         [ 4.,  4.,  4.,  5.,  5.,  5.],
         [ 4.,  4.,  4.,  5.,  5.,  5.],
         [ 6.,  6.,  6.,  7.,  7.,  7.],
         [ 6.,  6.,  6.,  7.,  7.,  7.],
         [ 6.,  6.,  6.,  7.,  7.,  7.]],

        [[ 8.,  8.,  8.,  9.,  9.,  9.],
         [ 8.,  8.,  8.,  9.,  9.,  9.],
         [ 8.,  8.,  8.,  9.,  9.,  9.],
         [10., 10., 10., 11., 11., 11.],
         [10., 10., 10., 11., 11., 11.],
         [10., 10., 10., 11., 11., 11.]],

        [[12., 12., 12., 13., 13., 13.],
         [12., 12., 12., 13., 13., 13.],
         [12., 12., 12., 13., 13., 13.],
         [14., 14., 14., 15., 15., 15.],
         [14., 14., 14., 15., 15., 15.],
         [14., 14., 14., 15., 15., 15.]],

        

In [56]:
# Inspect the noise after masking: the drawn cells show up as exact zeros.
noise

tensor([[[1.3792e-01, 4.3330e-01, 1.5246e+00, 2.7716e-02, 5.4281e-01,
          8.7602e-01],
         [6.9744e-01, 6.2943e-01, 8.0855e-01, 1.0384e+00, 9.8370e-01,
          7.1186e-01],
         [8.5196e-02, 1.7355e+00, 0.0000e+00, 4.0134e-01, 0.0000e+00,
          8.3930e-02],
         [1.0733e+00, 5.2208e-01, 1.1773e+00, 7.3479e-01, 1.1327e+00,
          1.5764e+00],
         [1.2669e+00, 5.9731e-02, 1.9881e+00, 5.4116e-01, 1.2755e+00,
          1.0582e+00],
         [4.7391e-01, 4.9395e-01, 0.0000e+00, 1.2666e+00, 2.1236e-01,
          0.0000e+00]],

        [[1.2257e+00, 3.3878e-03, 0.0000e+00, 9.3340e-02, 0.0000e+00,
          9.5253e-01],
         [3.2024e-01, 2.0815e-01, 6.5682e-01, 2.6940e-01, 5.8574e-01,
          1.2978e+00],
         [7.0754e-01, 1.7845e-01, 7.8741e-01, 7.6911e-01, 5.8098e-01,
          1.2552e+00],
         [2.2400e-01, 4.8800e-02, 6.4940e-01, 8.8261e-01, 8.9923e-02,
          3.3365e-01],
         [5.1991e-01, 7.1833e-01, 6.3866e-01, 0.0000e+00, 1.6480e+00