In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from crism_classifier.VAE_classifier_248 import EncoderX, EncoderY, DecoderX, Classifier, Prior

In [2]:
class VAEClassifier(nn.Module):
    """Hybrid VAE and classifier network.

    Two encoders map the input into separate latent spaces (``zx`` and ``zy``).
    The decoder reconstructs the input from samples of both latents, and the
    classifier predicts class probabilities from the mean of ``zy``.
    """

    def __init__(
        self, n_blocks=3, n_conv_layers=1, zx_dim=16, zy_dim=16, n_classes=38
    ):
        """Build the encoders, decoder, classifier and latent priors.

        Parameters
        ----------
        n_blocks : int
            Number of convolutional up and downsampling blocks
            in the encoder and decoder.
            Must be at least 1, and less than 6.
            Default = 3.
        n_conv_layers : int
            Number of convolutional blocks between each downsampling block.
            Default is 1.
        zx_dim : int
            Dimension of the noisy/entangled latent space.
            Default is 16.
        zy_dim : int
            Dimension of the clean/disentangled latent space.
            Default is 16.
        n_classes : int
            Number of classes to predict.
            Default is 38.
        """
        super().__init__()
        self.n_blocks = n_blocks
        self.n_conv_layers = n_conv_layers
        self.zx_dim = zx_dim
        self.zy_dim = zy_dim
        self.n_classes = n_classes

        # The submodule attribute names below are the state_dict key prefixes;
        # renaming any of them breaks loading of existing checkpoints.
        self.encoder_x = EncoderX(
            self.n_blocks, self.n_conv_layers, self.zx_dim
        )
        self.encoder_y = EncoderY(
            self.n_blocks, self.n_conv_layers, self.zy_dim
        )
        # The decoder consumes both latents, hence the summed latent size.
        self.decoder = DecoderX(
            self.n_blocks, self.n_conv_layers, self.zx_dim + self.zy_dim
        )
        self.classifier = Classifier(self.n_classes, self.zy_dim)
        # The priors are not used by forward(); presumably they are consumed by
        # the training loss, which is not visible here.
        self.prior_x = Prior(self.zx_dim)
        self.prior_y = Prior(self.zy_dim)

    @staticmethod
    def _reparameterize(mu, log_var):
        """Sample ``mu + exp(0.5 * log_var) * eps`` with ``eps ~ N(0, 1)``.

        Parameters
        ----------
        mu : torch.Tensor
            Mean of the latent distribution.
        log_var : torch.Tensor
            Log-variance of the latent distribution (same shape as ``mu``).

        Returns
        -------
        torch.Tensor
            A sample with the same shape as ``mu``.
        """
        std = torch.exp(0.5 * log_var)
        return mu + std * torch.randn_like(mu)

    def forward(self, x):
        """Run a full forward pass on input x.

        Sampling of ``zx`` and ``zy`` happens regardless of ``train()``/``eval()``
        mode, so ``x_recon`` is stochastic even after ``model.eval()``.
        ``y_pred`` is computed from the mean of ``zy`` and does not depend on the
        sampled noise.

        Parameters
        ----------
        x : torch.Tensor
            Input spectra to run through the network.

        Returns
        -------
        x_recon : torch.Tensor
            Reconstructed input spectra.
        y_pred : torch.Tensor
            Predicted mineral class probabilities (softmax over dim 1).
        """
        mu_e_x, log_var_e_x = self.encoder_x(x)
        mu_e_y, log_var_e_y = self.encoder_y(x)
        # zx is sampled before zy so random draws happen in a fixed order.
        zx = self._reparameterize(mu_e_x, log_var_e_x)
        zy = self._reparameterize(mu_e_y, log_var_e_y)
        x_recon = self.decoder(zx, zy)
        # Classify from the latent mean rather than a sample so predictions are
        # deterministic; softmax turns the raw scores into probabilities.
        class_scores = self.classifier(mu_e_y)
        y_pred = F.softmax(class_scores, dim=1)
        return x_recon, y_pred

In [3]:
# Dummy batch used both for a smoke-test forward pass and as the ONNX tracing
# example: 1024 single-channel inputs of length 248.
torch.manual_seed(0)  # reproducible dummy batch and VAE sampling noise
rand_x = torch.randn(1024, 1, 248)

# Hyperparameters must match the checkpoint; its file name appears to encode
# them (248 bands, 1 block, 1 conv layer, zx16, zy16).
# TODO: make this relative to a configurable data directory instead of an
# absolute user-specific path.
WEIGHTS_PATH = "/home/rob_platt/CRISM_classifier_application/data/v3_248_1_1_zx16_zy16_beta20_epoch100_weights.pth"

model = VAEClassifier(n_blocks=1, n_conv_layers=1, zx_dim=16, zy_dim=16, n_classes=38)
# weights_only=True restricts unpickling to tensors and plain containers, which is
# all a state_dict needs, instead of allowing arbitrary pickled objects.
state_dict = torch.load(WEIGHTS_PATH, map_location=torch.device("cpu"), weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# Inference-only smoke test: no autograd graph is needed.
with torch.no_grad():
    x_recon, y_pred = model(rand_x)

# forward() applies softmax over dim 1, so every row of y_pred should sum to 1.
row_sums = y_pred.sum(dim=1)
assert torch.allclose(row_sums, torch.ones_like(row_sums)), "class probabilities do not sum to 1"

In [4]:
# Export the eval-mode model to ONNX, using the dummy batch as the tracing example.
# No dynamic_axes are passed, so the exported input shape is fixed to rand_x's
# shape (batch of 1024), presumably the reason for "1024" in the file name.
# NOTE(review): on the non-dynamo (TorchScript) export path torch.onnx.export
# returns None, so onnx_program is likely None here; confirm before relying on it.
# NOTE(review): forward() calls torch.randn_like unconditionally, so the exported
# reconstruction branch presumably contains random-sampling ops.
ONNX_PATH = "vae_classifier_1024.onnx"
onnx_program = torch.onnx.export(model, (rand_x,), ONNX_PATH)

verbose: False, log level: Level.ERROR

