In [22]:
import torch
import torch.nn.functional as F
import torch.nn as nn


class CNNAutoencoder(nn.Module):
    """CNNAutoencoder

    1-D convolutional autoencoder.

    Three stride-2 ``Conv1d`` layers halve the sequence axis each time (8x
    total) while mapping ``input_size`` channels to ``latent_dim`` channels.
    Three mirrored ``ConvTranspose1d`` layers double it back up to
    ``input_size`` channels.

    Ensure that the model is always batched: inputs are expected as
    ``(batch_size, input_size, sequence_length)``.

    Length contract: each encoder layer maps length ``L`` to ``ceil(L / 2)`` and
    each decoder layer maps ``L`` to ``2 * L``. The reconstruction therefore has
    length ``8 * ceil(sequence_length / 8)``, which equals the input length only
    when ``sequence_length`` is a multiple of 8.

    Args:
        input_size (int): Number of input channels (2 for IQ).
        latent_dim (int): Number of channels in the latent representation.
    """

    def __init__(
        self,
        input_size,  # 2 for IQ
        latent_dim,
    ):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv1d(input_size, 16, kernel_size=3, stride=2, padding=1),  # (batch_size, 16, sequence_length/2)
            nn.ReLU(),
            nn.Conv1d(16, 32, kernel_size=3, stride=2, padding=1),  # (batch_size, 32, sequence_length/4)
            nn.ReLU(),
            nn.Conv1d(32, latent_dim, kernel_size=3, stride=2, padding=1),  # (batch_size, latent_dim, sequence_length/8)
            nn.ReLU()
        )
        # output_padding=1 makes each ConvTranspose1d exactly double the length,
        # mirroring the stride-2 encoder for lengths that are multiples of 8.
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(latent_dim, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # (batch_size, 32, sequence_length/4)
            nn.ReLU(),
            nn.ConvTranspose1d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),  # (batch_size, 16, sequence_length/2)
            nn.ReLU(),
            nn.ConvTranspose1d(16, input_size, kernel_size=3, stride=2, padding=1, output_padding=1),  # (batch_size, input_size, sequence_length)
            # NOTE(review): Sigmoid bounds the reconstruction to [0, 1]. IQ samples
            # are often signed; confirm the inputs are scaled into [0, 1].
            nn.Sigmoid()
        )

    def forward(self, x):
        """Encode ``x`` into the latent representation and decode it back.

        Args:
            x (torch.Tensor): Input of shape ``(batch_size, input_size, sequence_length)``.

        Returns:
            torch.Tensor: Reconstruction of shape
            ``(batch_size, input_size, 8 * ceil(sequence_length / 8))``, squashed
            into [0, 1] by the final Sigmoid.
        """
        latent = self.encoder(x)
        return self.decoder(latent)


In [24]:
from torchsummary import summary

# Fall back to CPU so this cell also runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
N_CHANNELS = 2   # I and Q
LATENT_DIM = 64  # channels in the latent representation
SEQ_LEN = 64     # multiple of 8, so the decoder reproduces this length exactly

model = CNNAutoencoder(N_CHANNELS, LATENT_DIM).to(DEVICE)

# torchsummary's `summary` defaults to device="cuda"; pass the real device explicitly.
summary(model, (N_CHANNELS, SEQ_LEN), device=DEVICE)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1               [-1, 16, 32]             112
              ReLU-2               [-1, 16, 32]               0
            Conv1d-3               [-1, 32, 16]           1,568
              ReLU-4               [-1, 32, 16]               0
            Conv1d-5                [-1, 64, 8]           6,208
              ReLU-6                [-1, 64, 8]               0
   ConvTranspose1d-7               [-1, 32, 16]           6,176
              ReLU-8               [-1, 32, 16]               0
   ConvTranspose1d-9               [-1, 16, 32]           1,552
             ReLU-10               [-1, 16, 32]               0
  ConvTranspose1d-11                [-1, 2, 64]              98
          Sigmoid-12                [-1, 2, 64]               0
Total params: 15,714
Trainable params: 15,714
Non-trainable params: 0
---------------------------------