In [None]:
from __future__ import annotations

import math
import pathlib
from typing import Callable, ClassVar, Sequence, Type

import pytorch_lightning as lightning
import torch
import torch.utils.data
from torch import nn
from torch.distributions import LowRankMultivariateNormal

# The src.vak.models prefix has to be removed in the actual implementation
from src.vak.models.registry import model_family
from src.vak.models import base
from src.vak.models.decorator import model
from src.vak.models.definition import ModelDefinition

In [2]:
# vak.nn.loss.vae
def vae_loss(
    x: torch.Tensor,
    z: torch.Tensor,
    x_rec: torch.Tensor,
    latent_dist: torch.distributions.Distribution,
    model_precision: float,
    z_dim: int
):
    """Compute the evidence lower bound (ELBO) for a batch.

    The ELBO is assembled from three terms, each summed over the batch:
    the log prior of the sampled latents under a standard normal, the
    log-likelihood of ``x`` under a spherical Gaussian centred on
    ``x_rec`` with precision ``model_precision``, and the entropy of
    the approximate posterior ``latent_dist``.

    Parameters
    ----------
    x : torch.Tensor
        Input batch; the first dimension is the batch dimension and all
        remaining dimensions are flattened per sample.
    z : torch.Tensor
        Latent samples drawn from ``latent_dist``.
    x_rec : torch.Tensor
        Reconstruction; must be subtractable from ``x`` flattened to
        ``(x.shape[0], -1)``.
    latent_dist : torch.distributions.Distribution
        Approximate posterior; only its ``entropy()`` method is used.
    model_precision : float
        Precision (inverse variance) of the observation model.
    z_dim : int
        Dimensionality of the latent space.

    Returns
    -------
    elbo : torch.Tensor
        Scalar ELBO (higher is better). Negate it to obtain a loss.
    """
    import math  # local import so this function works even without a file-level import

    # Number of elements in ONE sample (everything but the batch axis).
    x_dim = math.prod(x.shape[1:])

    # E_{q(z|x)} p(z)
    # NOTE: the constant normalisation terms are added once per call, not
    # once per sample; they do not influence gradients.
    elbo = -0.5 * (torch.sum(torch.pow(z, 2)) + z_dim * math.log(2 * math.pi))
    # E_{q(z|x)} p(x|z)
    pxz_term = -0.5 * x_dim * math.log(2 * math.pi / model_precision)
    l2s = torch.sum(torch.pow(x.reshape(x.shape[0], -1) - x_rec, 2), dim=1)
    pxz_term = pxz_term - 0.5 * model_precision * torch.sum(l2s)
    elbo = elbo + pxz_term
    # H[q(z|x)]
    elbo = elbo + torch.sum(latent_dist.entropy())
    return elbo

class VaeLoss(torch.nn.Module):
    """Negative-ELBO loss for a variational autoencoder.

    Wraps :func:`vae_loss` so it can be used as a ``torch.nn.Module``
    criterion; the value returned by ``forward`` is ``-ELBO``.

    Parameters
    ----------
    return_latent_rec : bool
        If True, ``forward`` additionally returns the latents and the
        reconstructions as NumPy arrays. Default is False.
    model_precision : float
        Precision of the observation model, forwarded to ``vae_loss``.
        Default is 10.0.
    z_dim : int
        Latent dimensionality, forwarded to ``vae_loss``. Default is 32.
    """

    def __init__(
        self,
        return_latent_rec: bool = False,
        model_precision: float = 10.0,
        z_dim: int = 32
    ):
        super().__init__()
        self.return_latent_rec = return_latent_rec
        self.model_precision = model_precision
        self.z_dim = z_dim

    def forward(
        self,
        x: torch.Tensor,
        z: torch.Tensor,
        x_rec: torch.Tensor,
        latent_dist: torch.distributions.Distribution,
    ):
        """Return ``-ELBO`` for the batch.

        Parameters
        ----------
        x : torch.Tensor
            Input batch, batch dimension first; the last two dimensions
            are used as the image height and width when reshaping
            reconstructions.
        z : torch.Tensor
            Latent samples.
        x_rec : torch.Tensor
            Flattened reconstructions.
        latent_dist : torch.distributions.Distribution
            Approximate posterior the latents were drawn from.

        Returns
        -------
        loss : torch.Tensor
            Negative ELBO.
        z_np, x_rec_np : numpy.ndarray
            Only when ``return_latent_rec`` is True: detached latents and
            reconstructions reshaped to ``(-1, height, width)``.
        """
        elbo = vae_loss(
            x=x,
            z=z,
            x_rec=x_rec,
            latent_dist=latent_dist,
            model_precision=self.model_precision,
            z_dim=self.z_dim,
        )
        if not self.return_latent_rec:
            return -elbo

        # x.shape[0] is the batch size, so the image shape must come from
        # the trailing two dimensions, not the leading two.
        height, width = x.shape[-2], x.shape[-1]
        latents = z.detach().cpu().numpy()
        reconstructions = (
            x_rec.detach().reshape(-1, height, width).cpu().numpy()
        )
        return -elbo, latents, reconstructions


In [1]:
# vak.models.vae_model.VAEModel
@model_family
class VAEModel(base.Model):
    """Model family for variational autoencoders.

    Subclasses ``base.Model`` and additionally exposes the ``'encode'``
    and ``'decode'`` entries of the network as ``self.encoder`` and
    ``self.decoder``.
    """
    definition: ClassVar[ModelDefinition]

    def __init__(
        self,
        network: dict | None = None,
        loss: torch.nn.Module | Callable | None = None,
        optimizer: torch.optim.Optimizer | None = None,
        metrics: dict[str, Type] | None = None,
    ):
        """Initialise the model.

        Parameters
        ----------
        network : dict, optional
            Mapping that must provide ``'encode'`` and ``'decode'`` entries.
        loss : torch.nn.Module or Callable, optional
        optimizer : torch.optim.Optimizer, optional
        metrics : dict, optional
            Maps metric names to metric types.

        All four arguments are forwarded unchanged to ``base.Model``.
        """
        super().__init__(
            network=network, loss=loss, optimizer=optimizer, metrics=metrics
        )
        # ``network`` defaults to None, which cannot be subscripted. In that
        # case use whatever the base class stored on ``self.network``
        # (``forward`` relies on that attribute too).
        # NOTE(review): confirm base.Model populates ``self.network`` with a
        # mapping-like object when ``network`` is None.
        resolved = network if network is not None else self.network
        self.encoder = resolved['encode']
        self.decoder = resolved['decode']

    def forward(self, x):
        """Run ``x`` through the full network."""
        return self.network(x)

    def encode(self, x):
        """Run ``x`` through the encoder only."""
        return self.encoder(x)

    def decode(self, x):
        """Run ``x`` through the decoder only."""
        return self.decoder(x)

    def configure_optimizers(self):
        """Return the optimizer stored on the model."""
        return self.optimizer

    @classmethod
    def from_config(
        cls, config: dict
    ):
        """Build an instance from a configuration dict.

        The dict is turned into network, loss, optimizer and metrics by
        ``cls.attributes_from_config``.
        """
        network, loss, optimizer, metrics = cls.attributes_from_config(config)
        return cls(
            network=network,
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
        )


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# nets.Ava
class Ava(nn.Module):
    """Convolutional variational autoencoder network.

    The encoder is a stack of ``BatchNorm2d -> Conv2d -> ReLU`` blocks
    followed by fully connected layers and three heads (``mu``, ``u``,
    ``d``) that parameterise a ``LowRankMultivariateNormal`` posterior.
    The decoder mirrors the encoder with fully connected layers followed
    by ``BatchNorm2d -> ConvTranspose2d`` blocks.

    A convolution whose output channel count equals its input channel
    count uses stride 2 (halving / doubling the spatial size); all other
    convolutions use stride 1.

    Parameters
    ----------
    hidden_dims : sequence of int
        Output channels of each encoder convolution, in order.
    fc_dims : sequence of int
        Widths of the fully connected layers; needs at least three
        entries. The last entry is the latent dimensionality, the
        second-to-last is the hidden width of the posterior heads.
    in_channels : int
        Number of channels of the input. Default is 1.
    in_fc : int
        Number of features produced by the convolutional encoder, i.e.
        ``hidden_dims[-1] * H' * W'`` where ``H'``/``W'`` are the
        down-sampled spatial sizes. Default is 8192.
    x_shape : tuple
        ``(height, width)`` of one input. Default is ``(128, 128)``.

    Raises
    ------
    ValueError
        If ``hidden_dims`` is empty, ``fc_dims`` has fewer than three
        entries, ``x_shape`` is not divisible by the total down-sampling
        factor, or ``in_fc`` does not match the encoder output size.
    """
    def __init__(
        self,
        hidden_dims: Sequence[int] = (8, 8, 16, 16, 24, 24, 32),
        fc_dims: Sequence[int] = (1024, 256, 64, 32),
        in_channels: int = 1,
        in_fc: int = 8192,
        x_shape: tuple = (128, 128),
    ):
        super().__init__()
        # Work on private copies so neither the caller's sequences nor the
        # defaults are ever mutated.
        hidden_dims = list(hidden_dims)
        fc_dims = list(fc_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one entry")
        if len(fc_dims) < 3:
            raise ValueError(
                f"fc_dims must contain at least 3 entries, got {len(fc_dims)}"
            )

        self.in_fc = in_fc
        self.in_channels = in_channels
        self.x_shape = tuple(x_shape)
        self.x_dim = math.prod(self.x_shape)

        # ---- convolutional encoder -------------------------------------
        encoder_blocks = []
        channels = in_channels
        n_downsample = 0
        for h_dim in hidden_dims:
            stride = 2 if h_dim == channels else 1
            if stride == 2:
                n_downsample += 1
            encoder_blocks.append(
                nn.Sequential(
                    nn.BatchNorm2d(channels),
                    nn.Conv2d(channels, out_channels=h_dim,
                              kernel_size=3, stride=stride, padding=1),
                    nn.ReLU(),
                )
            )
            channels = h_dim
        self.encoder = nn.Sequential(*encoder_blocks)

        # Shape of the encoder feature map; the decoder reshapes to it.
        height, width = self.x_shape
        factor = 2 ** n_downsample
        if height % factor or width % factor:
            raise ValueError(
                f"x_shape {self.x_shape} must be divisible by {factor} "
                "so that the decoder restores the input size"
            )
        self.feature_shape = (hidden_dims[-1], height // factor, width // factor)
        n_features = math.prod(self.feature_shape)
        if in_fc != n_features:
            raise ValueError(
                f"in_fc={in_fc} does not match the encoder output size "
                f"{n_features} (feature map {self.feature_shape})"
            )

        # ---- fully connected encoder -----------------------------------
        fc_blocks = []
        fc_in = in_fc
        for fc_dim in fc_dims[:-2]:
            fc_blocks.append(nn.Sequential(nn.Linear(fc_in, fc_dim), nn.ReLU()))
            fc_in = fc_dim
        self.encoder_bottleneck = nn.Sequential(*fc_blocks)

        def posterior_head():
            # Two-layer head mapping bottleneck features to a latent-sized vector.
            return nn.Sequential(
                nn.Linear(fc_dims[-3], fc_dims[-2]),
                nn.ReLU(),
                nn.Linear(fc_dims[-2], fc_dims[-1]),
            )

        self.mu_layer = posterior_head()
        self.u_layer = posterior_head()
        self.d_layer = posterior_head()

        # ---- fully connected decoder: latent -> ... -> in_fc ------------
        decoder_widths = fc_dims[::-1] + [in_fc]
        self.decoder_bottleneck = nn.Sequential(*[
            nn.Sequential(nn.Linear(w_in, w_out), nn.ReLU())
            for w_in, w_out in zip(decoder_widths[:-1], decoder_widths[1:])
        ])

        # ---- convolutional decoder: mirror of the encoder ---------------
        decoder_channels = hidden_dims[::-1] + [in_channels]
        last = len(decoder_channels) - 2
        decoder_blocks = []
        for i, (c_in, c_out) in enumerate(
            zip(decoder_channels[:-1], decoder_channels[1:])
        ):
            upsample = c_in == c_out
            layers = [
                nn.BatchNorm2d(c_in),
                nn.ConvTranspose2d(c_in, out_channels=c_out,
                                   kernel_size=3,
                                   stride=2 if upsample else 1,
                                   padding=1,
                                   output_padding=1 if upsample else 0),
            ]
            # No non-linearity after the final layer.
            if i != last:
                layers.append(nn.ReLU())
            decoder_blocks.append(nn.Sequential(*layers))
        self.decoder = nn.Sequential(*decoder_blocks)

    def encode(self, x):
        """Map inputs to the posterior parameters.

        A 3-D input ``(batch, height, width)`` gets a channel axis
        inserted at position 1; a 4-D input is used as is.

        Returns
        -------
        mu : torch.Tensor
            Posterior mean.
        u : torch.Tensor
            Low-rank covariance factor with a trailing axis of size 1.
        d : torch.Tensor
            Diagonal covariance term, made positive with ``exp``.
        """
        if x.dim() == 3:
            x = x.unsqueeze(1)
        x = self.encoder(x).reshape(-1, self.in_fc)
        x = self.encoder_bottleneck(x)
        mu = self.mu_layer(x)
        u = self.u_layer(x).unsqueeze(-1)
        d = torch.exp(self.d_layer(x))
        return mu, u, d

    def decode(self, z):
        """Map latents to flattened reconstructions, one row per sample."""
        z = self.decoder_bottleneck(z).reshape(-1, *self.feature_shape)
        z = self.decoder(z).reshape(-1, self.in_channels * self.x_dim)
        return z

    def reparametrize(self, mu, u, d):
        """Build the posterior and draw a differentiable sample from it.

        Returns
        -------
        z : torch.Tensor
            Sample obtained with ``rsample``.
        latent_dist : LowRankMultivariateNormal
            The posterior distribution.
        """
        latent_dist = LowRankMultivariateNormal(mu, u, d)
        z = latent_dist.rsample()
        return z, latent_dist

    def forward(self, x, return_latent_rec=False):
        """Encode, sample, and decode.

        ``return_latent_rec`` is accepted for interface compatibility and
        is currently unused.

        Returns
        -------
        x_rec : torch.Tensor
            Flattened reconstructions.
        extras : dict
            Keys ``'z'``, ``'mu'``, ``'latent_dist'``, ``'u'``, ``'d'``.
        """
        mu, u, d = self.encode(x)
        z, latent_dist = self.reparametrize(mu, u, d)
        x_rec = self.decode(z)
        return x_rec, {'z': z, 'mu': mu, 'latent_dist': latent_dist, 'u': u, 'd': d}

In [None]:
@model(family=VAEModel)
class AvaNet: # this will be renamed to Ava in implementation, just to avoid naming conflicts.
    """Model definition combining the ``Ava`` network with ``VaeLoss``.

    The class only declares attributes; it is passed to the ``model``
    decorator with ``family=VAEModel``.

    Attributes
    ----------
    network
        The ``Ava`` network class.
    loss
        The ``VaeLoss`` class.
    optimizer
        ``torch.optim.Adam``.
    metrics
        Maps the name ``"loss"`` to ``VaeLoss``.
    default_config
        Default optimizer settings: learning rate ``0.003``.
    """
    network = Ava
    loss = VaeLoss
    optimizer = torch.optim.Adam
    metrics = {
        "loss": VaeLoss,
    }
    default_config = {"optimizer": {"lr": 0.003}}