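First, we define two Variational Autoencoder (VAE) classes, `trmVAE1` and `trmVAE2`. `trmVAE1` uses BERT as the encoder and a small feed-forward network as the decoder, while `trmVAE2` uses a Transformer encoder stack for both the encoder and the decoder. Both models apply the reparameterization trick in the forward pass, sampling $z = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, which keeps the sampling step differentiable and is a key component of Variational Autoencoders.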

Next, we define the function `train_model`, which trains the model. It takes the model, dataloader, optimizer, learning-rate scheduler, and number of epochs as arguments. For each batch it runs the forward pass, computes the loss (reconstruction error plus the weighted KL divergence), runs the backward pass and the optimizer step, steps the learning-rate scheduler, and increases the weight of the KL divergence term.

Finally, the `main` function loads and preprocesses the data, creates a DataLoader, initializes the model, creates the optimizer, defines the learning-rate scheduler, and trains the model.

# Step 3: Build the Model
## Solution 3
* Transformer encoder + VAE model

In [None]:
import torch
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader, TensorDataset
from transformers import BertModel, BertConfig, get_linear_schedule_with_warmup
import torch.nn.functional as F

In [None]:
class MyDataset(Dataset):
    """Thin ``Dataset`` wrapper around an indexable collection of points.

    The wrapped object is stored as-is in ``self.data``; no copy or
    conversion is performed.
    """

    def __init__(self, data) -> None:
        # Keep a reference to the caller's container.
        self.data = data

    def __len__(self) -> int:
        """Return the number of items in the wrapped collection."""
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        sample = self.data[idx]
        return sample

class trmVAE1(nn.Module):
    """VAE that encodes with a BERT model and decodes with a small MLP.

    Args:
        latent_dim: Size of the latent vector produced per encoder position.
        output_dim: Size of each reconstructed point. Defaults to 5.
    """

    def __init__(self, latent_dim, output_dim=5):
        super().__init__()
        # The encoder is constructed from a default ``BertConfig`` (it is not
        # loaded via ``from_pretrained``).
        self.encoder = BertModel(BertConfig())
        hidden_size = self.encoder.config.hidden_size
        # Two heads project the encoder states to the posterior mean and
        # log-variance respectively.
        self.fc_mu = nn.Linear(hidden_size, latent_dim)
        self.fc_var = nn.Linear(hidden_size, latent_dim)
        # Latent -> 64 -> output_dim feed-forward decoder.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim),
        )

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        ``x`` is passed positionally to the BERT encoder.
        NOTE(review): presumably token ids of shape (batch, seq) - confirm
        against the caller.

        Returns:
            A tuple ``(reconstruction, mu, log_var)``. Only the decoder output
            at the last position along dim 1 is returned as the
            reconstruction; ``mu`` and ``log_var`` cover every position.
        """
        hidden = self.encoder(x).last_hidden_state
        mu, log_var = self.fc_mu(hidden), self.fc_var(hidden)
        latent = self.reparameterize(mu, log_var)
        recon = self.decoder(latent)
        return recon[:, -1, :], mu, log_var

    def reparameterize(self, mu, log_var):
        """Sample ``mu + sigma * eps`` with ``sigma = exp(0.5 * log_var)``.

        ``eps`` comes from ``torch.randn_like``, i.e. it is drawn from a
        standard normal distribution (mean 0, variance 1) with the same shape
        as ``sigma``.
        """
        sigma = (log_var * 0.5).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

class trmVAE2(nn.Module):
    """Variational Autoencoder with Transformer stacks as encoder and decoder.

    Both halves are ``nn.TransformerEncoder`` stacks: a 2-layer encoder that
    works in ``input_dim`` and a 6-layer decoder that works in ``latent_dim``.
    Linear heads map encoder states to the posterior mean / log-variance and
    map decoder states back to ``input_dim``.

    Args:
        input_dim: Feature size of the input points. Defaults to 5.
        latent_dim: Feature size of the latent space. Defaults to 5.
    """

    def __init__(self, input_dim=5, latent_dim=5):
        super().__init__()

        # Encoder: d_model=input_dim, a single attention head, feed-forward
        # width 32, stacked twice.
        encoder_layers = TransformerEncoderLayer(input_dim, 1, dim_feedforward=32)
        self.encoder = TransformerEncoder(encoder_layers, 2)

        # Multi-head attention requires d_model to be divisible by the number
        # of heads. Two heads are kept whenever latent_dim allows it; odd sizes
        # (including the default latent_dim=5) fall back to one head instead of
        # failing at construction time.
        decoder_heads = 2 if latent_dim % 2 == 0 else 1
        decoder_layers = TransformerEncoderLayer(
            latent_dim, decoder_heads, dim_feedforward=128, activation='gelu'
        )
        self.decoder = TransformerEncoder(decoder_layers, 6)

        # Posterior heads: encoder features -> mean / log-variance.
        self.fc_mu = nn.Linear(input_dim, latent_dim)
        self.fc_var = nn.Linear(input_dim, latent_dim)

        # Output projection back to the input feature size; needed when the
        # latent size differs from the input size.
        self.fc3 = nn.Linear(latent_dim, input_dim)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + std * eps`` where ``eps`` is standard-normal noise.

        ``std`` is ``exp(0.5 * logvar)``; ``eps`` has the same shape as ``std``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        The layers are built with the default ``batch_first=False``.
        NOTE(review): with that setting a 3-D input is read as
        (seq_len, batch, input_dim) and a 2-D input as one unbatched sequence,
        so a (batch, input_dim) tensor is treated as a single sequence whose
        items attend to each other - confirm this is intended.

        Returns:
            A tuple ``(reconstruction, mu, logvar)``; the reconstruction has
            ``input_dim`` features, ``mu`` and ``logvar`` have ``latent_dim``.
        """
        x = self.encoder(x)
        # The statistics are cast to float32 before sampling.
        mu = self.fc_mu(x).to(torch.float32)
        logvar = self.fc_var(x).to(torch.float32)
        z = self.reparameterize(mu, logvar)
        return self.fc3(self.decoder(z)), mu, logvar


# Step 4: Train the Model

In [None]:
def train_model(model, dataloader, optimizer, scheduler, epochs=100):
    """Train the model.

    Each batch is moved to the device the model's parameters live on and cast
    to float32 before the forward pass. The loss is the MSE reconstruction
    error plus the KL divergence (summed over all elements) scaled by a weight
    that starts at 4.0 and grows by 0.001 per batch up to a cap of 5.
    After every epoch that processed at least one batch, the loss of the last
    batch is printed.

    Args:
        model (nn.Module): The model to be trained. Its forward pass must
            return ``(reconstruction, mu, log_var)``.
        dataloader (DataLoader): The data loader; must yield tensors.
        optimizer (Optimizer): The optimizer.
        scheduler (lr_scheduler): The learning rate scheduler, stepped once
            per batch.
        epochs (int, optional): The number of epochs. Defaults to 100.
    """
    # Initialize the weight of the KL divergence term
    kld_weight = 4.0
    # Define the rate at which the weight is increased
    annealing_rate = 0.001

    # Follow the model's device instead of assuming "cuda": the batch and the
    # parameters must be on the same device, and CUDA may not be available.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    # Make sure dropout and similar layers are in training mode.
    model.train()

    for epoch in range(epochs):
        loss = None
        for batch in dataloader:
            # Cast before the forward pass so float64 data (e.g. from NumPy)
            # does not clash with the model's float32 weights.
            batch = batch.to(torch.float32)
            if device is not None:
                batch = batch.to(device)

            # Clear gradients first so nothing left over from before this call
            # (or from the previous batch) leaks into this update.
            optimizer.zero_grad()

            # Forward pass
            recon_batch, mu, log_var = model(batch)
            recon_batch = recon_batch.to(torch.float32)

            # Compute the loss: reconstruction + weighted closed-form KL term
            recon_loss = F.mse_loss(recon_batch, batch)
            kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
            loss = recon_loss + kld_loss * kld_weight

            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            # Step the learning rate scheduler
            scheduler.step()
            # Increase the weight of the KL divergence term
            kld_weight = min(5, kld_weight + annealing_rate)

        # An empty dataloader produces no loss to report.
        if loss is not None:
            print(f'Epoch {epoch}, Loss {loss.item()}')

In [None]:
def main():
    """Main function to build and train the model.

    Loads and normalizes the data, wraps it in a DataLoader with batch size
    32, builds a ``trmVAE2`` model, and trains it for 100 epochs with AdamW
    (lr=1e-3) and a linear warmup/decay schedule whose warmup covers the first
    10% of the optimizer steps.
    """
    # Load and preprocess the data
    data = load_data()  # Step 1&2
    data = normalize_data(data)

    # Create a DataLoader
    dataloader = DataLoader(MyDataset(data), batch_size=32)

    # Initialize the model on the GPU when one is available, else on the CPU.
    # Alternative model: trmVAE1(latent_dim=...) (BERT encoder; latent_dim is
    # a required argument).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = trmVAE2().to(device)

    # Create an optimizer
    optimizer = AdamW(model.parameters(), lr=1e-3)

    # Define the learning rate scheduler. It is stepped once per batch, so the
    # schedule length is the number of batches per epoch (len(dataloader))
    # times the number of epochs, not the number of samples.
    epochs = 100
    total_steps = len(dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=total_steps // 10,
        num_training_steps=total_steps,
    )

    # Train the model
    train_model(model, dataloader, optimizer, scheduler, epochs)


if __name__ == "__main__":
    main()