# NOBODY TOUCH THIS

In [1]:
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm

from survae import SurVAE
from survae.data import *
from survae.layer import *



In [24]:
class MaxPoolingLayerWithHop(Layer):
    '''
    MaxPoolingLayer: Layer that performs max pooling on the input data.

    size: input size, i.e. flattened picture
    width: input width, i.e. the width of the picture, sqrt of size
    hop: defines the distance between the blocks of size stride for which the maxima is taken
        we require that hop >= 1

    out_width: width of the output picture, sqrt of out_size
    out_size: size of flattened output picture

    stride: the stride of the max pooling operation 
    index_probs: probability distribution for the indices of the local maxima in the "flattened" stride square
        we assume that the indices are uniformly distributed

    Distribution choices: 
        - standard half-normal distribution (default)
        - exponential distribution 

    '''
    def __init__(self, size: int, stride: int, hop: int, exponential_distribution: bool = False, learn_distribution_parameter: bool = False):
        super().__init__()

        self.size = size

        self.width = np.sqrt(size).astype(int) 

        assert stride <= self.width, "Stride must be smaller than the width of the picture!"
        assert 0 < hop and hop <= stride, "Hop must be smaller than the stride!"
        assert (self.width - stride) % hop == 0, "Stride and hop must be chosen such that the picture is fully covered!"

        self.stride = stride
        self.hop = hop

        self.out_width = (self.width - stride) // hop + 1 # = the number of blocks considered (possible overlap)

        
        if exponential_distribution:
            self.distribution = "exponential"
            lam = torch.tensor([0.1])
            if learn_distribution_parameter:
                lam = nn.Parameter(lam)
            self.lam = lam
        else:
            self.distribution = "half-normal"
            sigma = 1
            if learn_distribution_parameter:
                sigma = nn.Parameter(sigma)
            self.sigma = sigma


        self.index_probs = torch.tensor([1 / self.stride**2 for _ in range(self.stride**2)])


    def forward(self, X: torch.Tensor, condition: torch.Tensor | None = None, return_log_likelihood: bool = False):

        X = X.view(-1, self.width, self.width) # reshape to 2D
        
        l = []
        for i in range(self.out_width):
            for j in range(self.out_width):
                l.append(X[:, i * self.hop:i * self.hop + self.stride:,j * self.hop:j * self.hop +self.stride])
        print(l)

        combined_tensor = torch.stack(l, dim=0)
        Z, _ = torch.max(combined_tensor, dim=0)

        return Z.flatten(start_dim=1)

    def backward(self, Z: torch.Tensor, condition: torch.Tensor | None = None):
        Z = Z.view(-1, self.out_width, self.out_width)
        
        # expand matrix containing local maxima (by repeating local max)
        X_hat = Z.repeat_interleave(self.stride,dim=2).repeat_interleave(self.stride,dim=1)

        # mask for the indices of the local maxima
        k = torch.distributions.categorical.Categorical(self.index_probs) 
        indices = k.sample(Z.shape)

        indices_repeated = indices.repeat_interleave(self.stride, dim=2).repeat_interleave(self.stride, dim=1)
        index_places = torch.arange(self.stride**2).reshape(self.stride, self.stride).repeat(self.out_size(), self.out_size())

        index_mask = (index_places == indices_repeated)

        # sample values in (- infty, 0]) with respective distribution
        if self.distribution == "half-normal":
            distr = torch.distributions.half_normal.HalfNormal(self.sigma)
        else:
            distr = torch.distributions.exponential.Exponential(self.lam)
        samples = -distr.sample(X_hat.shape)

        X_hat = X_hat + samples * ~index_mask
        
        return X_hat.flatten(start_dim=1)

    def in_size(self) -> int | None:
        return self.size

    def out_size(self) -> int | None:
        return int(self.out_width ** 2)

In [28]:
stride = 3
hop = 2
width = 11
size = 11 ** 2
X = torch.arange((size * 2)).float().reshape(2, size)
X = X.view(-1, width, width)
X


tensor([[[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.],
         [ 11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.],
         [ 22.,  23.,  24.,  25.,  26.,  27.,  28.,  29.,  30.,  31.,  32.],
         [ 33.,  34.,  35.,  36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.],
         [ 44.,  45.,  46.,  47.,  48.,  49.,  50.,  51.,  52.,  53.,  54.],
         [ 55.,  56.,  57.,  58.,  59.,  60.,  61.,  62.,  63.,  64.,  65.],
         [ 66.,  67.,  68.,  69.,  70.,  71.,  72.,  73.,  74.,  75.,  76.],
         [ 77.,  78.,  79.,  80.,  81.,  82.,  83.,  84.,  85.,  86.,  87.],
         [ 88.,  89.,  90.,  91.,  92.,  93.,  94.,  95.,  96.,  97.,  98.],
         [ 99., 100., 101., 102., 103., 104., 105., 106., 107., 108., 109.],
         [110., 111., 112., 113., 114., 115., 116., 117., 118., 119., 120.]],

        [[121., 122., 123., 124., 125., 126., 127., 128., 129., 130., 131.],
         [132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142.

In [26]:
M = MaxPoolingLayerWithHop(size, stride, hop)
Z = M.forward(X)
Z

tensor([[ 96.,  97.,  98., 107., 108., 109., 118., 119., 120.],
        [217., 218., 219., 228., 229., 230., 239., 240., 241.]],
       dtype=torch.float32)

In [23]:
out_width

9

In [35]:
out_width = (width - stride) // hop + 1
l = []
for i in range(stride):
    for j in range(stride):
        l.append(X[:, i:i+width-stride + 2:hop,j:j+width-stride + 2:hop])

combined_tensor = torch.stack(l, dim=0)
Z, _ = torch.max(combined_tensor, dim=0)
Z


torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([2, 5, 5])


tensor([[[ 24.,  26.,  28.,  30.,  32.],
         [ 46.,  48.,  50.,  52.,  54.],
         [ 68.,  70.,  72.,  74.,  76.],
         [ 90.,  92.,  94.,  96.,  98.],
         [112., 114., 116., 118., 120.]],

        [[145., 147., 149., 151., 153.],
         [167., 169., 171., 173., 175.],
         [189., 191., 193., 195., 197.],
         [211., 213., 215., 217., 219.],
         [233., 235., 237., 239., 241.]]], dtype=torch.float32)