In [2]:
import torch
from matplotlib import pyplot as plt

import sys
sys.path.insert(0, '..')

from survae import SurVAE
from survae.data import Dataset
from survae.layer import BijectiveLayer, AbsoluteUnit, OrthonormalLayer, SortingLayer, MaxPoolingLayer, MaxTheLayer


In [3]:
from abc import ABC, abstractmethod

import torch.nn as nn
from torch import exp, tanh, log
import numpy as np


In [4]:


class Layer(nn.Module, ABC):
    """
    Abstract class used as the framework for all types of layers used in the SurVAE-Flow architecture.

    All layers are defined in the inference direction, i.e. the 'forward' method sends elements of the
    data space X to the latent space Z, whereas the 'backward' method goes from Z to X.
    """

    @abstractmethod
    def forward(self, X: torch.Tensor, condition: torch.Tensor | None = None, return_log_likelihood: bool = False):
        """
        Computes the forward pass of the layer and, optionally, the log likelihood contribution as a scalar (i.e. already summed).
        """
        pass

    @abstractmethod
    def backward(self, Z: torch.Tensor, condition: torch.Tensor | None = None):
        """
        Computes the backward pass of the layer.
        """
        pass

    @abstractmethod
    def in_size(self) -> int | None:
        pass

    @abstractmethod
    def out_size(self) -> int | None:
        pass

    def make_conditional(self, size: int):
        pass



In [17]:


class MaxPoolingLayer(Layer):
    # TODO now: implement ohne Überlappungen
    def __init__(self, size: int, stride: int, lam = 0.1):
        super().__init__()

        self.size = np.sqrt(size).astype(int) 

        assert self.size % stride == 0, "Stride must be a divisor of size!"
        self.stride = stride

        self.lam = lam

        self.index_probs = torch.tensor([1 / self.stride for _ in range(self.stride)])


    def forward(self, X: torch.Tensor, condition: torch.Tensor | None = None, return_log_likelihood: bool = False):

        X = X.view(self.size, self.size) # reshape to 2D
        
        l = []
        for i in range(self.stride):
            for j in range(self.stride):
                l.append(X[i::self.stride,j::self.stride])

        combined_tensor = torch.stack(l, dim=0)
        Z, _ = torch.max(combined_tensor, dim=0)
        return Z.view(-1)

    def backward(self, Z: torch.Tensor, condition: torch.Tensor | None = None):
        exp_distr = torch.distributions.exponential.Exponential(self.lam)
        Z = Z.view((self.out_size(), self.out_size()))

        # expand matrix containing local maxima 
        X_hat = Z.repeat_interleave(self.stride,dim=0).repeat_interleave(self.stride,dim=1)

        # sample values in (- infty, 0]) with exponential distribution
        exp_distr = torch.distributions.exponential.Exponential(self.lam)
        samples = -exp_distr.sample(X_hat.shape)


        # mask for the indices of the local maxima

        k = torch.distributions.categorical.Categorical(self.index_probs) 
        i_indices = k.sample((self.out_size()**2,))
        j_indices = k.sample((self.out_size()**2,))

        index_mask = torch.ones_like(X_hat)

        for I in range(self.out_size()):
            for J in range(self.out_size()):
                index_mask[I*self.stride + i_indices[I*self.out_size()+J], J*self.stride + j_indices[I*self.out_size()+J]] = 0

        X_hat = X_hat + samples * index_mask
        
        return X_hat.view(-1)

    def in_size(self) -> int | None:
        return self.size

    def out_size(self) -> int | None:
        return int(self.size / self.stride)

In [18]:
# Smoke test: pool a flattened 28 x 28 input and sample a reconstruction from the result.
torch.manual_seed(0)  # 'backward' draws random numbers; seed so re-running the cell is reproducible

in_size = 28
flattenedsize = in_size * in_size
size = np.sqrt(flattenedsize).astype(int)
stride = 4
out_size = int((in_size / stride))


MP = MaxPoolingLayer(flattenedsize, stride)
X = torch.arange(flattenedsize)

Z = MP.forward(X)
assert Z.shape == (out_size * out_size,), "pooled output has an unexpected shape"

X_hat = MP.backward(Z)
assert X_hat.shape == X.shape, "reconstruction must have the input's shape"
# The noise is non-positive and one cell per window keeps the maximum, so pooling the
# reconstruction has to give back exactly the same maxima.
assert bool((MP.forward(X_hat) == Z).all()), "max pooling of the reconstruction must equal Z"