# NOBODY TOUCH THIS

In [1]:
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm

from survae import SurVAE
from survae.data import *
from survae.layer import *



In [112]:
class MaxPoolingLayerWithHop(Layer):
    '''
    MaxPoolingLayer: Layer that performs max pooling on the input data.

    size: input size, i.e. flattened picture
    width: input width, i.e. the width of the picture, sqrt of size
    hop: defines the distance between the blocks of size stride for which the maxima is taken
        we require that hop >= 1

    out_width: width of the output picture, sqrt of out_size
    out_size: size of flattened output picture

    stride: the stride of the max pooling operation 
    index_probs: probability distribution for the indices of the local maxima in the "flattened" stride square
        we assume that the indices are uniformly distributed

    Distribution choices: 
        - standard half-normal distribution (default)
        - exponential distribution 

    '''
    def __init__(self, size: int, stride: int, hop: int, exponential_distribution: bool = False, learn_distribution_parameter: bool = False):
        super().__init__()

        self.size = size

        self.width = np.sqrt(size).astype(int) 

        assert stride <= self.width, "Stride must be smaller than the width of the picture!"
        assert 0 < hop and hop <= stride, "Hop must be smaller than the stride!"
        assert (self.width - stride) % hop == 0, "Stride and hop must be chosen such that the picture is fully covered!"

        self.stride = stride
        self.hop = hop

        self.out_width = (self.width - stride) // hop + 1 # = the number of blocks considered (possible overlap)

        
        if exponential_distribution:
            self.distribution = "exponential"
            lam = torch.tensor([0.1])
            if learn_distribution_parameter:
                lam = nn.Parameter(lam)
            self.lam = lam
        else:
            self.distribution = "half-normal"
            sigma = 1
            if learn_distribution_parameter:
                sigma = nn.Parameter(sigma)
            self.sigma = sigma


        self.index_probs = torch.tensor([1 / self.stride**2 for _ in range(self.stride**2)])


    def forward(self, X: torch.Tensor, condition: torch.Tensor | None = None, return_log_likelihood: bool = False):

        X = X.view(-1, self.width, self.width) # reshape to 2D
        
        l = []
        for i in range(self.stride):
            for j in range(self.stride):
                l.append(X[:, i:i+self.width-self.stride + 1:self.hop,j:j+self.width-self.stride + 1:self.hop])
                print(l[-1].shape)

        combined_tensor = torch.stack(l, dim=0)
        Z, _ = torch.max(combined_tensor, dim=0)

        return Z.flatten(start_dim=1)
    
    def backward(self, Z: torch.Tensor, condition: torch.Tensor | None = None):
        Z = Z.view(-1, self.out_width, self.out_width)

        _max = Z.max() + 1
        max_max = torch.full((len(Z), self.out_width * self.out_width, self.width, self.width), _max)

        batch_indices = torch.arange(len(Z))
        for i in range(self.out_width):
            for j in range(self.out_width):
                max_max[batch_indices, j + i * (self.out_width), (self.hop * i):(self.hop * i + self.stride), (self.hop * j):(self.hop * j + self.stride)] = Z[batch_indices, i, j].unsqueeze(1).unsqueeze(2)

        # This is the best possible name. I will not elaborate.
        min_max = max_max.min(dim=1)[0]

        block_mask = torch.isclose(min_max.unsqueeze(1), max_max)
        _rand = torch.rand(block_mask.shape)
        _rand[~block_mask] = -1
        arg_max = _rand.flatten(start_dim=2).argmax(dim=2)

        noise_mask = torch.ones((len(Z), self.size), dtype=torch.bool)
        for idx in arg_max.T:
            noise_mask[batch_indices, idx] = False

        # sample values in (- infty, 0]) with respective distribution
        if self.distribution == "half-normal":
            distr = torch.distributions.half_normal.HalfNormal(self.sigma)
        else:
            distr = torch.distributions.exponential.Exponential(self.lam)
        samples = -distr.sample(noise_mask.shape)

        X_hat = min_max.flatten(start_dim=1) + (samples * noise_mask)

        return X_hat

    def in_size(self) -> int | None:
        return self.size

    def out_size(self) -> int | None:
        return int(self.out_width ** 2)

In [116]:
stride = 3
hop = 2
width = 11
size = 11 ** 2

In [117]:
# X = torch.arange((size * 2)).double().reshape(2, size)
X = torch.randn(2, size)
X.view(-1, width, width)

tensor([[[-1.6169e-01, -4.9837e-01, -3.6823e-01, -1.0678e+00,  8.5087e-01,
          -1.4756e+00,  1.3415e+00,  9.5072e-01,  7.5384e-01,  4.8031e-01,
          -1.6446e+00],
         [-1.1954e+00, -2.8456e-01,  9.8753e-01, -2.6220e-02,  4.9732e-01,
           1.6501e+00,  7.6097e-01,  6.2538e-01,  1.7338e+00,  1.3753e-01,
          -7.4808e-02],
         [ 6.1283e-01, -1.3058e+00,  2.9694e-01,  6.0140e-01, -1.3080e+00,
           7.7402e-01,  6.8075e-01,  2.1717e+00, -1.6834e-01,  4.5260e-01,
          -1.1169e-01],
         [ 8.3983e-01, -7.6345e-01,  1.0826e+00, -3.1517e-01,  1.1299e+00,
           2.7441e+00, -1.0826e+00,  1.2842e-01, -2.7102e-01, -8.2161e-01,
           1.0545e+00],
         [ 4.0962e-01,  1.1643e-01, -9.8105e-01, -2.3844e+00, -1.2941e+00,
          -7.7962e-02,  9.7388e-02,  1.7634e+00, -2.6357e-01,  1.0084e+00,
          -1.7644e-01],
         [-9.6454e-01,  1.4662e-01,  1.2819e+00,  7.0117e-02,  1.3592e-01,
          -2.2104e-03,  2.5133e+00, -7.7248e-01, -1.443

In [118]:
M = MaxPoolingLayerWithHop(size, stride, hop)
Z = M.forward(X)
Z

torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])


tensor([[0.9875, 0.9875, 1.6501, 2.1717, 1.7338, 1.0826, 1.1299, 2.7441, 2.1717,
         1.0545, 1.2819, 1.2819, 2.5133, 2.5133, 1.0084, 1.1158, 2.0937, 1.5041,
         0.8947, 0.6577, 2.5549, 2.0937, 2.4980, 2.4980, 0.8659],
        [0.4215, 1.7125, 1.7125, 2.5798, 1.3273, 0.9471, 1.0000, 0.5154, 1.6228,
         1.6228, 1.9646, 0.8585, 1.1243, 1.6228, 1.6228, 0.8585, 0.8585, 0.9233,
         1.3971, 1.6894, 1.6660, 0.6067, 0.9233, 1.0012, 1.6894]],
       device='cuda:0')

In [119]:
X[1]

tensor([ 0.4215, -0.9531,  0.4025, -0.6514, -1.4438, -1.4014,  0.5517,  0.2779,
        -0.1973,  1.3273,  0.1701, -0.4731, -0.0702, -0.0190, -1.1951,  1.7125,
        -1.6587,  0.2956,  2.5798, -0.8385,  1.1867,  0.7996, -0.4012, -1.0554,
         0.1291,  1.0000, -1.3602,  0.2076, -1.2534, -1.1306, -0.5187,  1.2214,
        -0.4591, -0.9941, -0.1706, -1.4173, -1.7470,  0.5154, -1.1679, -1.1675,
        -0.5192, -0.0352, -0.7464,  0.3482, -0.3792,  0.9471, -0.8959, -0.9584,
        -0.4296, -0.8550, -0.3929,  0.7639,  1.6228,  0.2506, -0.7315,  0.1325,
         1.9646,  0.2501,  0.2555, -0.2589,  0.8957,  1.1243, -0.1054, -0.1848,
         0.4857, -2.8138,  0.7752,  0.1215,  0.8585,  0.3487,  0.3943,  0.0864,
        -0.6624, -0.0227,  1.3971, -0.3416,  0.0183, -0.8368, -0.4540,  0.1137,
        -0.6751, -0.2101, -0.9242, -0.4391, -0.7199, -0.9788,  0.2221, -0.5247,
         0.3683,  0.8214,  0.3009,  0.1045, -0.8271,  0.9233, -0.2710, -0.4134,
        -1.2654, -1.2939,  1.6894,  0.21

In [120]:
X_hat = M.backward(Z)
X_hat.view(2, 11, 11)[1]

tensor([[ 0.4215, -0.6940, -2.2752, -0.0152,  0.4294,  0.6356,  1.7125,  2.5464,
          0.0298,  1.2753,  0.8869],
        [ 0.2652,  0.0117,  0.1767,  1.4352,  1.7125, -0.1051,  0.7302,  2.5798,
          0.3721,  1.1857,  1.3273],
        [-0.0507, -0.8287, -0.1616, -0.2671, -0.0553, -0.1914,  0.5102,  1.2566,
          1.0282, -0.4295, -0.1043],
        [-0.7745, -0.1292,  0.9471,  1.0000, -0.5373, -1.0983,  0.5154,  0.2941,
          0.3589,  1.0686,  1.6228],
        [-0.3145,  0.8919,  0.2807,  0.7780, -0.9890, -0.3129, -1.8920,  1.6228,
          0.4688, -0.0866,  0.2354],
        [ 0.3986,  1.9646,  0.3941,  0.8585,  0.4519, -0.2497,  1.1243,  0.8579,
          0.9146,  1.6154,  1.6047],
        [ 0.6650,  0.8316,  0.6711, -0.0119,  0.4888,  0.4752,  0.6339,  1.3462,
          1.3805,  1.6228,  1.1591],
        [ 0.8585,  0.0646, -1.3792,  0.3981,  0.8585,  0.9233,  0.5131,  1.3971,
          1.0098,  1.0437,  0.1427],
        [ 0.7574,  0.2537, -0.1676, -0.9752,  0.2323,  0

In [121]:
Z_hat = M.forward(X_hat)
Z_hat

torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])


tensor([[0.9875, 0.9875, 1.6501, 2.1717, 1.7338, 1.0826, 1.1299, 2.7441, 2.1717,
         1.0545, 1.2819, 1.2819, 2.5133, 2.5133, 1.0084, 1.1158, 2.0937, 1.5041,
         0.8947, 0.6577, 2.5549, 2.0937, 2.4980, 2.4980, 0.8659],
        [0.4215, 1.7125, 1.7125, 2.5798, 1.3273, 0.9471, 1.0000, 0.5154, 1.6228,
         1.6228, 1.9646, 0.8585, 1.1243, 1.6228, 1.6228, 0.8585, 0.8585, 0.9233,
         1.3971, 1.6894, 1.6660, 0.6067, 0.9233, 1.0012, 1.6894]],
       device='cuda:0')

In [122]:
Z

tensor([[0.9875, 0.9875, 1.6501, 2.1717, 1.7338, 1.0826, 1.1299, 2.7441, 2.1717,
         1.0545, 1.2819, 1.2819, 2.5133, 2.5133, 1.0084, 1.1158, 2.0937, 1.5041,
         0.8947, 0.6577, 2.5549, 2.0937, 2.4980, 2.4980, 0.8659],
        [0.4215, 1.7125, 1.7125, 2.5798, 1.3273, 0.9471, 1.0000, 0.5154, 1.6228,
         1.6228, 1.9646, 0.8585, 1.1243, 1.6228, 1.6228, 0.8585, 0.8585, 0.9233,
         1.3971, 1.6894, 1.6660, 0.6067, 0.9233, 1.0012, 1.6894]],
       device='cuda:0')

In [105]:
X_hat_hat = M.backward(Z_hat)
X_hat_hat.view(2, 11, 11)[0]

tensor([[ 24.0000,  23.5045,  23.5058,  26.0000,  24.3403,  27.9868,  27.8526,
          28.7478,  29.9449,  31.4097,  31.2622],
        [ 23.6969,  23.9655,  22.4380,  25.7051,  25.5659,  28.0000,  26.5468,
          28.8559,  28.7481,  31.6039,  31.9620],
        [ 23.5923,  23.9763,  23.4521,  24.7231,  25.4507,  25.9755,  27.6326,
          29.3905,  30.0000,  31.7688,  32.0000],
        [ 45.0696,  45.8001,  45.6782,  48.0000,  47.9187,  50.0000,  48.7045,
          51.2083,  52.0000,  51.6875,  53.5254],
        [ 46.0000,  45.7557,  45.9206,  47.8226,  47.9991,  49.0587,  48.3122,
          50.9986,  51.3033,  53.6216,  54.0000],
        [ 67.9560,  67.0166,  67.4944,  69.5665,  69.8349,  71.9635,  69.9864,
          73.0992,  72.8426,  73.5723,  75.9999],
        [ 68.0000,  66.8994,  66.9654,  70.0000,  68.1620,  70.6817,  72.0000,
          73.2720,  74.0000,  76.0000,  75.0762],
        [ 87.2095,  89.9965,  89.9415,  92.0000,  91.2262,  94.0000,  93.2638,
          95.4544,

In [75]:
_max = Z.max() + 1

In [24]:
N = MaxPoolingLayer(18*18, 3)
Z = torch.rand(300, 6*6) * 3

In [25]:
X = N.backward(Z)
Z_tilde = N.forward(X)

In [26]:
torch.allclose(Z, Z_tilde)

True

In [27]:
X_tilde = Z.view(-1, 6, 6).repeat_interleave(3, dim=2).repeat_interleave(3, dim=1).flatten()

In [28]:
Z_tilde_tilde = N.forward(X_tilde)
Z_tilde_tilde

tensor([[2.0187, 1.5905, 1.4425,  ..., 0.1042, 0.5838, 0.9130],
        [1.8340, 0.8886, 2.1207,  ..., 0.9819, 1.3652, 1.3764],
        [2.1103, 2.0345, 1.8622,  ..., 1.8676, 0.1135, 1.0842],
        ...,
        [0.8945, 0.2138, 0.4180,  ..., 1.0716, 2.7263, 1.3650],
        [0.4388, 0.1131, 1.9491,  ..., 1.5017, 1.8399, 0.9055],
        [1.2034, 0.8644, 0.8090,  ..., 0.7842, 2.2281, 1.2368]],
       device='cuda:0')

In [29]:
torch.allclose(Z, Z_tilde_tilde)

True

In [77]:
stride = 3
hop = 2
width = 11
size = 11 ** 2
out_width = (width - stride) // hop + 1

In [78]:
Z = Z.view(-1, out_width, out_width)
Z

tensor([[[ 24.,  26.,  28.,  30.,  32.],
         [ 46.,  48.,  50.,  52.,  54.],
         [ 68.,  70.,  72.,  74.,  76.],
         [ 90.,  92.,  94.,  96.,  98.],
         [112., 114., 116., 118., 120.]],

        [[145., 147., 149., 151., 153.],
         [167., 169., 171., 173., 175.],
         [189., 191., 193., 195., 197.],
         [211., 213., 215., 217., 219.],
         [233., 235., 237., 239., 241.]]], device='cuda:0')

In [82]:
max_masks = torch.full((2, out_width * out_width, width, width), _max)

batch_indices = torch.arange(len(Z))
for i in range(out_width):
    for j in range(out_width):
        max_masks[batch_indices, j + i * (out_width), (hop * i):(hop * i + stride), (hop * j):(hop * j + stride)] = Z[batch_indices, i, j].unsqueeze(1).unsqueeze(2)

In [87]:
max_masks[0, 6]

tensor([[242., 242., 242., 242., 242., 242., 242., 242., 242., 242., 242.],
        [242., 242., 242., 242., 242., 242., 242., 242., 242., 242., 242.],
        [242., 242.,  48.,  48.,  48., 242., 242., 242., 242., 242., 242.],
        [242., 242.,  48.,  48.,  48., 242., 242., 242., 242., 242., 242.],
        [242., 242.,  48.,  48.,  48., 242., 242., 242., 242., 242., 242.],
        [242., 242., 242., 242., 242., 242., 242., 242., 242., 242., 242.],
        [242., 242., 242., 242., 242., 242., 242., 242., 242., 242., 242.],
        [242., 242., 242., 242., 242., 242., 242., 242., 242., 242., 242.],
        [242., 242., 242., 242., 242., 242., 242., 242., 242., 242., 242.],
        [242., 242., 242., 242., 242., 242., 242., 242., 242., 242., 242.],
        [242., 242., 242., 242., 242., 242., 242., 242., 242., 242., 242.]],
       device='cuda:0')

In [53]:
min_matrix = max_masks.min(dim=1)[0]
min_matrix

tensor([[[ 24.,  24.,  24.,  26.,  26.,  28.,  28.,  30.,  30.,  32.,  32.],
         [ 24.,  24.,  24.,  26.,  26.,  28.,  28.,  30.,  30.,  32.,  32.],
         [ 24.,  24.,  24.,  26.,  26.,  28.,  28.,  30.,  30.,  32.,  32.],
         [ 46.,  46.,  46.,  48.,  48.,  50.,  50.,  52.,  52.,  54.,  54.],
         [ 46.,  46.,  46.,  48.,  48.,  50.,  50.,  52.,  52.,  54.,  54.],
         [ 68.,  68.,  68.,  70.,  70.,  72.,  72.,  74.,  74.,  76.,  76.],
         [ 68.,  68.,  68.,  70.,  70.,  72.,  72.,  74.,  74.,  76.,  76.],
         [ 90.,  90.,  90.,  92.,  92.,  94.,  94.,  96.,  96.,  98.,  98.],
         [ 90.,  90.,  90.,  92.,  92.,  94.,  94.,  96.,  96.,  98.,  98.],
         [112., 112., 112., 114., 114., 116., 116., 118., 118., 120., 120.],
         [112., 112., 112., 114., 114., 116., 116., 118., 118., 120., 120.]],

        [[145., 145., 145., 147., 147., 149., 149., 151., 151., 153., 153.],
         [145., 145., 145., 147., 147., 149., 149., 151., 151., 153., 153.

In [12]:
max_masks.shape

torch.Size([2, 25, 11, 11])

In [13]:
min_matrix.shape

torch.Size([2, 11, 11])

In [14]:
block_mask = torch.isclose(min_matrix.unsqueeze(1), max_masks)

In [15]:
_rand = torch.rand(block_mask.shape)

In [16]:
_rand[~block_mask] = -1

In [17]:
_argmax = _rand.flatten(start_dim=2).argmax(dim=2)
_argmax

tensor([[ 12,  25,   6,  19,   9,  33,  37,  38,  40,  54,  68,  69,  71,  74,
          65,  78,  80,  82,  96,  86, 110, 103, 105, 107, 109],
        [  2,   3,  28,   7,  20,  46,  48,  39,  52,  53,  55,  70,  61,  63,
          76,  90,  80,  94,  95,  87, 101, 103, 115, 118, 108]],
       device='cuda:0')

In [42]:
noise_mask = torch.ones((2, 11 * 11), dtype=torch.bool)

In [44]:
batch_indices = torch.arange(len(Z))
for idx in _argmax.T:
    noise_mask[batch_indices, idx] = False

In [47]:
noise_mask.reshape(2, 11, 11)

tensor([[[ True,  True,  True,  True,  True,  True, False,  True,  True, False,
           True],
         [ True, False,  True,  True,  True,  True,  True,  True, False,  True,
           True],
         [ True,  True,  True, False,  True,  True,  True,  True,  True,  True,
           True],
         [False,  True,  True,  True, False, False,  True, False,  True,  True,
           True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          False],
         [ True,  True, False, False,  True, False,  True,  True, False,  True,
           True],
         [ True, False,  True, False,  True, False,  True,  True,  True, False,
           True],
         [ True,  True,  True,  True,  True,  True,  True,  True, False,  True,
           True],
         [ True,  True,  True,  True, False,  True, False,  True, False,  True,
          False],
         [False,  Tr

In [48]:
distr = torch.distributions.half_normal.HalfNormal(1)

In [51]:
noise = -distr.sample(noise_mask.shape)

In [58]:
min_matrix.flatten(start_dim=1)

tensor([[ 24.,  24.,  24.,  26.,  26.,  28.,  28.,  30.,  30.,  32.,  32.,  24.,
          24.,  24.,  26.,  26.,  28.,  28.,  30.,  30.,  32.,  32.,  24.,  24.,
          24.,  26.,  26.,  28.,  28.,  30.,  30.,  32.,  32.,  46.,  46.,  46.,
          48.,  48.,  50.,  50.,  52.,  52.,  54.,  54.,  46.,  46.,  46.,  48.,
          48.,  50.,  50.,  52.,  52.,  54.,  54.,  68.,  68.,  68.,  70.,  70.,
          72.,  72.,  74.,  74.,  76.,  76.,  68.,  68.,  68.,  70.,  70.,  72.,
          72.,  74.,  74.,  76.,  76.,  90.,  90.,  90.,  92.,  92.,  94.,  94.,
          96.,  96.,  98.,  98.,  90.,  90.,  90.,  92.,  92.,  94.,  94.,  96.,
          96.,  98.,  98., 112., 112., 112., 114., 114., 116., 116., 118., 118.,
         120., 120., 112., 112., 112., 114., 114., 116., 116., 118., 118., 120.,
         120.],
        [145., 145., 145., 147., 147., 149., 149., 151., 151., 153., 153., 145.,
         145., 145., 147., 147., 149., 149., 151., 151., 153., 153., 145., 145.,
         145

In [64]:
output:torch.Tensor = min_matrix.flatten(start_dim=1) + (noise * noise_mask)

In [67]:
output.view(2, 11, 11)


tensor([[[ 22.1889,  23.2967,  23.7507,  24.5962,  24.6887,  27.6577,  28.0000,
           29.2626,  28.4096,  32.0000,  30.7142],
         [ 23.2467,  24.0000,  23.9909,  25.7106,  24.5697,  26.4020,  27.4834,
           28.7218,  30.0000,  31.3557,  31.2244],
         [ 23.0809,  23.9219,  23.2633,  26.0000,  25.0351,  26.8057,  26.6938,
           29.5222,  28.4734,  31.9631,  31.4082],
         [ 46.0000,  45.9254,  45.8251,  46.0863,  48.0000,  50.0000,  49.7254,
           52.0000,  51.5926,  53.7272,  53.6193],
         [ 45.0906,  45.4238,  45.5665,  47.6253,  47.3361,  49.4593,  49.1894,
           50.3918,  50.9083,  53.8746,  54.0000],
         [ 66.4170,  67.4385,  65.6814,  69.5357,  68.5218,  71.3216,  70.1659,
           73.3019,  73.8194,  74.3874,  76.0000],
         [ 67.9787,  65.7458,  68.0000,  70.0000,  69.8597,  72.0000,  70.0402,
           73.2468,  74.0000,  75.2969,  75.7988],
         [ 89.3915,  90.0000,  89.3454,  92.0000,  91.2445,  94.0000,  93.5268,
   