In [1]:
# Root folder of the project checkout. It is appended to sys.path in the next
# cell (so `my_code` can be imported) and is used further down as the prefix for
# the data loader and for the checkpoint paths.
initial_path = 'peptide-QML'
# initial_path = '..'

In [2]:
import sys, time, pickle
sys.path.append(initial_path)

%load_ext autoreload
%autoreload 2
from my_code import helper_classes as c
from my_code import quantum_nodes as q
# from my_code import pytorch_model as m

In [3]:
def time_left(time_start, n_epochs_total, n_batches_total, n_epochs_done, current_batch):
    """Estimate the remaining wall-clock time of a training run.

    The estimate extrapolates linearly from the time elapsed since
    ``time_start`` and the number of batches processed so far.

    Args:
        time_start: ``time.time()`` timestamp taken when training started.
        n_epochs_total: Number of epochs the run is scheduled for.
        n_batches_total: Number of batches in one epoch.
        n_epochs_done: Number of epochs that are already finished.
        current_batch: Number of batches finished in the current epoch.

    Returns:
        A 6-tuple of ints
        ``(epoch_hours, epoch_minutes, epoch_seconds,
        total_hours, total_minutes, total_seconds)``:
        the time left in the current epoch followed by the time left in the
        whole run. All six values are 0 when no batch has been processed yet
        (there is nothing to extrapolate from) and estimates are never negative.
    """
    def _split_hms(seconds):
        # Break a duration in seconds into whole hours / minutes / seconds.
        whole_seconds = max(int(seconds), 0)
        hours, remainder = divmod(whole_seconds, 3600)
        minutes, secs = divmod(remainder, 60)
        return hours, minutes, secs

    batches_done = n_epochs_done * n_batches_total + current_batch
    if batches_done <= 0:
        # Previously this case raised ZeroDivisionError.
        return 0, 0, 0, 0, 0, 0

    elapsed = time.time() - time_start

    # Remaining time for the whole run: elapsed * (total work / done work - 1).
    total_remaining = elapsed * (n_epochs_total * n_batches_total / batches_done - 1)
    # Remaining time for the current epoch: average batch time * batches left.
    epoch_remaining = elapsed / batches_done * (n_batches_total - current_batch)

    return (*_split_hms(epoch_remaining), *_split_hms(total_remaining))

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Define the Score Predictor model
class ScorePredictor(nn.Module):
    """Predict a scalar score from a latent vector.

    The latent vector is passed through a quantum layer built with
    ``q.circuit`` (one 'Z' measurement, hence one output value) and then
    through a 1-to-1 linear layer.

    Args:
        latent_dim: Size of the latent vectors fed to ``forward``.
    """

    def __init__(self, latent_dim:int):
        super().__init__()

        # Smallest register whose 2**n basis states can hold `latent_dim` values.
        n_qubits = int(q.np.ceil(q.np.log2(latent_dim)))

        layer_factory = q.circuit(
            n_qubits=n_qubits,
            device="default.qubit.torch",
            device_options={'shots': None},
            embedding=q.parts.AmplitudeEmbedding,
            block_ansatz=q.parts.Ansatz_11,
            final_ansatz=q.parts.Ansatz_11,
            measurement=q.parts.Measurement('Z', 1),
            block_n_layers=10,
        )
        # Classical stand-in for debugging: nn.Linear(latent_dim, 1)
        self.quantum_layer = layer_factory()
        self.post_quantum = nn.Linear(1, 1)

    @staticmethod
    def loss_function(predicted_socre, real_socre, reduction:str='mean'):
        """Mean-squared error between predicted and real scores.

        The real scores are cast to float first; ``reduction`` is forwarded to
        ``F.mse_loss``.
        """
        return F.mse_loss(predicted_socre, real_socre.float(), reduction=reduction)

    def forward(self, x):
        """Return one score per input row (the trailing size-1 axis is squeezed)."""
        measured = self.quantum_layer(x)
        return self.post_quantum(measured).squeeze(-1)

# helper function
def dense_sequence(in_dim, out_dim, n_dense_layers, activation):
    """
    Build an MLP of ``n_dense_layers`` Linear layers whose widths go linearly
    from ``in_dim`` to ``out_dim``. Every layer except the last one is followed
    by ``activation()``.

    dense_sequence(100, 53, 3, nn.ReLU) returns:
    [   Linear(in_features=100, out_features=84, bias=True), ReLU(),
        Linear(in_features=84,  out_features=68, bias=True), ReLU(),
        Linear(in_features=68,  out_features=53, bias=True)             ]

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.
        n_dense_layers: Number of Linear layers (must be >= 1).
        activation: Zero-argument callable returning an activation module,
            e.g. ``nn.ReLU``.

    Returns:
        nn.Sequential containing the layers described above.

    Raises:
        ValueError: if ``n_dense_layers`` is smaller than 1.
    """
    if n_dense_layers < 1:
        raise ValueError(f"n_dense_layers must be >= 1, got {n_dense_layers}")

    if in_dim != out_dim:
        # Cast to plain ints so nn.Linear does not store numpy scalars.
        widths = [int(w) for w in q.np.linspace(in_dim, out_dim, n_dense_layers + 1, dtype=int)]
    else:
        widths = [in_dim] * (n_dense_layers + 1)

    layers = []
    # Hidden layers: Linear + activation.
    for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
        layers += [nn.Linear(fan_in, fan_out), activation()]
    # Output layer: Linear only, no activation.
    layers.append(nn.Linear(widths[-2], widths[-1]))

    return nn.Sequential(*layers)

def dense_sequence_XL(in_dim, out_dim, n_dense_layers, activation):
    """
    Build an MLP of ``n_dense_layers`` Linear layers that first widens from
    ``in_dim`` to ``in_dim + out_dim`` and then narrows to ``out_dim``.
    Every layer except the last one is followed by ``activation()``.

    dense_sequence_XL(100, 53, 3, nn.ReLU) returns:
    [   Linear(in_features=100, out_features=153, bias=True), ReLU(),
        Linear(in_features=153, out_features=153, bias=True), ReLU(),
        Linear(in_features=153, out_features=53,  bias=True)             ]

    ``n_dense_layers // 2`` layers are used for each of the widening and
    narrowing halves; an odd ``n_dense_layers`` adds one extra layer at the
    peak width in between.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.
        n_dense_layers: Number of Linear layers (must be >= 2: one to widen,
            one to narrow).
        activation: Zero-argument callable returning an activation module,
            e.g. ``nn.ReLU``.

    Returns:
        nn.Sequential containing the layers described above.

    Raises:
        ValueError: if ``n_dense_layers`` is smaller than 2 (this used to
            surface as an UnboundLocalError / IndexError).
    """
    if n_dense_layers < 2:
        raise ValueError(f"n_dense_layers must be >= 2, got {n_dense_layers}")

    peak_dim = in_dim + out_dim
    layers_per_half = n_dense_layers // 2

    # The same widen/narrow schedule is used when in_dim == out_dim. The
    # previous special case for equal dimensions produced roughly twice as many
    # constant-width layers as requested.
    widening = [int(w) for w in q.np.linspace(in_dim, peak_dim, layers_per_half + 1, dtype=int)]
    narrowing = [int(w) for w in q.np.linspace(peak_dim, out_dim, layers_per_half + 1, dtype=int)]

    # Repeating the peak width yields the extra peak -> peak layer for odd counts.
    plateau = [int(peak_dim)] if n_dense_layers % 2 != 0 else []
    widths = widening + plateau + narrowing[1:]

    layers = []
    # Hidden layers: Linear + activation.
    for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
        layers += [nn.Linear(fan_in, fan_out), activation()]
    # Output layer: Linear only, no activation.
    layers.append(nn.Linear(widths[-2], widths[-1]))

    return nn.Sequential(*layers)

# Define the VAE model
class VAE(nn.Module):
    """Variational auto-encoder for fixed-length integer sequences.

    The encoder embeds the tokens, runs them through stacked recurrent layers
    and maps the last time step to a latent mean and log-variance. The decoder
    repeats the latent vector ``output_dim`` times, runs it through recurrent
    layers and produces either one value per position (``one_hot=False``) or a
    probability distribution over ``N_EMB`` classes per position
    (``one_hot=True``).

    Args:
        emb_dim: Size of the token embedding.
        latent_dim: Size of the latent space.
        output_dim: Length of the decoded sequence.
        n_dense_layers: Number of Linear layers in each dense head.
        RNN_type: Key of ``RNN_types_dict`` ('LSTM', 'GRU' or 'RNN').
        RNN_units: Hidden size of each stacked recurrent block; a single int
            is accepted and treated as a one-element list.
        bidirectional: Whether the recurrent layers are bidirectional.
        dropout: Dropout passed to the recurrent layers.
        num_layers: ``num_layers`` of every recurrent block.
        one_hot: Selects the decoder output mode described above.
    """

    # Number of rows of the embedding table and, in one-hot mode, the number
    # of classes predicted per position.
    N_EMB = 18
    RNN_types_dict = {'LSTM': nn.LSTM, 'GRU': nn.GRU, 'RNN': nn.RNN}

    def __init__(self, emb_dim:int, latent_dim:int, output_dim:int, n_dense_layers:int, RNN_type:str, RNN_units:list, bidirectional:bool=True, dropout:float=0, num_layers:int=2, one_hot:bool=False):
        super(VAE, self).__init__()

        # Define the hyperparameters
        self.emb_dim = emb_dim
        self.latent_dim = latent_dim
        self.output_dim = output_dim
        self.n_dense_layers = n_dense_layers
        self.RNN_type = RNN_type
        self.RNN_units = list(RNN_units) if isinstance(RNN_units, (list, tuple)) else [RNN_units]
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.num_layers = num_layers
        self.one_hot = one_hot
        # TODO: add temperature for softmax

        # Encoder and Decoder.
        # Both receive the *normalised* list; passing the raw argument crashed
        # on `RNN_units[::-1]` whenever a single int was given.
        self.encoder = VAE.Encoder(emb_dim, latent_dim, n_dense_layers, RNN_type, self.RNN_units, bidirectional, dropout, num_layers)
        self.decoder = VAE.Decoder(latent_dim, output_dim, n_dense_layers, RNN_type, self.RNN_units[::-1], bidirectional, dropout, num_layers, one_hot)

    # Define the Encoder
    class Encoder(nn.Module):
        """Embedding -> stacked recurrent blocks -> dense heads for mean and log-variance."""

        def __init__(self, emb_dim:int, latent_dim:int, n_dense_layers:int, RNN_type:str, RNN_units:list, bidirectional:bool, dropout:float, num_layers:int):
            super(VAE.Encoder, self).__init__()

            # Define the hyperparameters
            rnn_layer = VAE.RNN_types_dict[RNN_type]
            # A bidirectional layer concatenates both directions, doubling the width.
            lstm_out_dim = RNN_units[-1] * 2 if bidirectional else RNN_units[-1]
            self.RNN_units = RNN_units

            # layers
            self.embedding = nn.Embedding(VAE.N_EMB, emb_dim)

            # Recurrent blocks are registered as lstm_0, lstm_1, ... (whatever
            # RNN_type is) so that forward() can look them up by index.
            in_units = emb_dim
            for i, out_units in enumerate(RNN_units):
                setattr(self, f'lstm_{i}', rnn_layer(in_units, out_units, batch_first=True, bidirectional=bidirectional, dropout=dropout, num_layers=num_layers))
                in_units = out_units * 2 if bidirectional else out_units

            self.fc_post = nn.Sequential(nn.Linear(lstm_out_dim, lstm_out_dim), nn.ReLU())
            self.fc_mean = dense_sequence(lstm_out_dim, latent_dim, n_dense_layers, nn.ReLU)
            self.fc_log_var = dense_sequence(lstm_out_dim, latent_dim, n_dense_layers, nn.ReLU)

        def forward(self, x):
            """Return ``(z_mean, z_log_var)`` for a batch of token indices."""
            x = self.embedding(x)
            for i in range(len(self.RNN_units)):
                x, _ = getattr(self, f'lstm_{i}')(x)
            x = x[:, -1, :]  # Take the last time step output
            x = self.fc_post(x)
            z_mean = self.fc_mean(x)
            z_log_var = self.fc_log_var(x)
            return z_mean, z_log_var

    # Placeholder for a convolutional encoder
    class Encoder_conv(nn.Module):
        pass
        # TODO

    # Define the Decoder
    class Decoder(nn.Module):
        """Dense pre-net -> stacked recurrent blocks -> dense output head."""

        def __init__(self, latent_dim:int, output_dim:int, n_dense_layers:int, RNN_type:str, RNN_units:list, bidirectional:bool, dropout:float, num_layers:int, one_hot:bool):
            super(VAE.Decoder, self).__init__()

            # Define the hyperparameters
            self.output_dim = output_dim
            self.one_hot = one_hot
            rnn_layer = VAE.RNN_types_dict[RNN_type]
            lstm_out_dim = RNN_units[-1] * 2 if bidirectional else RNN_units[-1]
            self.RNN_units = RNN_units

            # layers
            self.fc_pre = dense_sequence(latent_dim, latent_dim, n_dense_layers, nn.ReLU)

            in_units = latent_dim
            for i, out_units in enumerate(RNN_units):
                setattr(self, f'lstm_{i}', rnn_layer(in_units, out_units, batch_first=True, bidirectional=bidirectional, dropout=dropout, num_layers=num_layers))
                in_units = out_units * 2 if bidirectional else out_units

            # One-hot mode predicts N_EMB classes at every time step; otherwise
            # the head maps the last time step to the whole sequence at once.
            output_dim = VAE.N_EMB if one_hot else output_dim
            self.fc_post = dense_sequence(lstm_out_dim, output_dim, n_dense_layers, nn.ReLU)

        def forward(self, x):
            """Decode latent vectors.

            Returns probabilities of shape (batch, output_dim, N_EMB) in
            one-hot mode, otherwise non-negative values of shape
            (batch, output_dim).
            """
            x = self.fc_pre(x)
            # Feed the same latent vector at each of the output_dim time steps.
            x = x.unsqueeze(1).repeat(1, self.output_dim, 1)

            for i in range(len(self.RNN_units)):
                x, _ = getattr(self, f'lstm_{i}')(x)

            if self.one_hot:
                # Apply the head to every time step, then restore the batch axis.
                out_reshape = x.contiguous().view(-1, x.size(-1))
                y0 = F.softmax(self.fc_post(out_reshape), dim=1)
                x = y0.contiguous().view(x.size(0), -1, y0.size(-1))
            else:
                x = self.fc_post(x[:, -1, :])
                x = F.relu(x)
            return x

    @staticmethod
    def reparameterize(z_mean, z_log_var):
        """Sample a latent vector around ``z_mean``.

        NOTE(review): the Gaussian noise is scaled by 1e-2, whereas the KL term
        in ``loss_function`` is the one for unit-variance noise. Confirm that
        the down-scaling is intentional.
        """
        epsilon = 1e-2*torch.randn_like(z_mean)
        return z_mean + torch.exp(0.5 * z_log_var) * epsilon

    @staticmethod
    def loss_function(reconstructed, x, mu=None, logvar=None, one_hot:bool=False, reduction:str='sum', kl_weight:float=1.0):
        """Reconstruction loss plus (optionally) the weighted KL divergence.

        Args:
            reconstructed: Decoder output. Real values when ``one_hot`` is
                False; per-class probabilities (the decoder's softmax output)
                when ``one_hot`` is True.
            x: Target token indices.
            mu, logvar: Latent mean and log-variance. If either is None, or
                ``kl_weight`` is 0, only the reconstruction loss is returned.
            one_hot: Selects MSE (False) or negative log-likelihood (True).
            reduction: Reduction of the reconstruction term. The KL term is
                always summed over all elements.
            kl_weight: Multiplier of the KL term.
        """
        if not one_hot:
            reconstruction_loss = F.mse_loss(reconstructed, x.float(), reduction=reduction)
        else:
            # The decoder already applied softmax. F.cross_entropy expects raw
            # logits and would apply a second softmax, so take the log of the
            # probabilities and use the negative log-likelihood instead.
            probabilities = reconstructed.reshape(-1, VAE.N_EMB)
            log_probabilities = torch.log(probabilities.clamp_min(1e-12))
            reconstruction_loss = F.nll_loss(log_probabilities, x.reshape(-1).long(), reduction=reduction)

        if mu is None or logvar is None or kl_weight == 0:
            return reconstruction_loss

        kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return reconstruction_loss + kl_divergence*kl_weight

    @staticmethod
    def process_decoded_outputs(reconstructed, one_hot:bool=False):
        """Turn decoder output into integer tokens (round, or argmax over classes)."""
        if not one_hot:
            return torch.round(reconstructed).int()
        else:
            return torch.argmax(reconstructed, dim=-1).int()

    def forward(self, x):
        """Return ``(decoded_sequence, z_mean, z_log_var)`` for a batch of token indices."""
        z_mean, z_log_var = self.encoder(x)
        z = self.reparameterize(z_mean, z_log_var)
        decoded_sequence = self.decoder(z)
        return decoded_sequence, z_mean, z_log_var


In [5]:
# Hyper-parameter sweep: one dict per training run. Only six fields differ
# between the runs, so they are listed as rows and expanded into full dicts.
sweep_points = [
    {
        'emb_dim': 200,
        'latent_dim': latent_dim,
        'n_dense_layers': n_dense_layers,
        'RNN_type': rnn_type,
        'RNN_units': [rnn_width],
        'bidirectional': True,
        'dropout': 0,
        'num_layers': num_layers,
        'one_hot': False,
        'n_epochs': 500,
        'lr': 3e-4,
        'save_name': save_name,
        'score_predictor': True,
    }
    for latent_dim, n_dense_layers, rnn_type, rnn_width, num_layers, save_name in (
        (16, 1, 'LSTM', 500, 2, 'vae_2_1.pickle'),
        (64, 1, 'LSTM', 1000, 2, 'vae_2_2.pickle'),
        (64, 2, 'LSTM', 1000, 2, 'vae_2_3.pickle'),
        (16, 1, 'GRU', 500, 3, 'vae_2_4.pickle'),
    )
]

In [6]:
# Early-stopping settings: a sweep point stops once the validation loss of the
# current epoch is no longer at least `min_delta` below the mean validation
# loss of the previous `patience` epochs.
patience, min_delta = 5, 0.0001

# Train one VAE (and its score predictor) per sweep point.
for sweep_point in sweep_points:

    print('\n\n\n ## ----- NEW SWEEP POINT ----- ## \n') 
    print(sweep_point, '\n')   

    # Create the VAE model and get data
    device = "cuda:1"
    # device = "cpu"

    data = c.Data.load(initial_path=initial_path, file_name='PET_SCORES_12')
    data.set_test_ptc(0.1)
    data.to(device)
    data_loader = data.get_loader(batch_size=64, shuffle=True)

    # Define models
    vae_model = VAE(
        emb_dim = sweep_point['emb_dim'],
        latent_dim = sweep_point['latent_dim'],
        output_dim = len(data.x_train[0]), #12
        n_dense_layers = sweep_point['n_dense_layers'],
        RNN_type = sweep_point['RNN_type'],
        RNN_units = sweep_point['RNN_units'],
        bidirectional = sweep_point['bidirectional'],
        dropout = sweep_point['dropout'],
        num_layers = sweep_point['num_layers'],
        one_hot = sweep_point['one_hot'],
    ).to(device)
    score_predictor = ScorePredictor(vae_model.latent_dim).to(device)

    # Define your optimizer
    optimizer_vae = optim.Adam(vae_model.parameters(), lr=sweep_point['lr'])
    optimizer_score = optim.Adam(score_predictor.parameters(), lr=sweep_point['lr'])

    # Training history, stored on the model so that it is pickled with it.
    vae_model.losses = {
        'batch': [],     # summed training loss of every batch
        'epoch': [],     # per-sample training loss of every epoch
        'test': [],      # per-sample validation loss of every epoch
        'score': [],     # validation loss of the score predictor (0 until it has been trained)
        'accuracy': [],  # fraction of validation tokens reconstructed exactly
    }

    # Checkpoints are overwritten after every epoch.
    vae_checkpoint = initial_path+'/saved/Pickle/VAE-'+sweep_point['save_name']
    score_checkpoint = initial_path+'/saved/Pickle/SP-'+sweep_point['save_name']

    time_start, n_epochs, n_batches = time.time(), sweep_point['n_epochs'], len(data_loader)
    for epoch in range(n_epochs):
        
        # The score predictor is always trained during the first three epochs
        # and afterwards in a random quarter of the epochs.
        if sweep_point['score_predictor']:
            train_score_predictor = torch.rand(1).item() < 1/4 if epoch > 2 else True
            print(f'Epoch {epoch+1}/{n_epochs}, \t train score predictor = {train_score_predictor}')
        else:
            train_score_predictor = False

        vae_model.train()
        epoch_loss_sum, epoch_samples = 0.0, 0
        for i, (x, y) in enumerate(data_loader):

            #train
            optimizer_vae.zero_grad()
            x_reconstructed, mu, logvar = vae_model(x)
            loss_vae = VAE.loss_function(x_reconstructed, x, mu, logvar, one_hot=vae_model.one_hot, reduction='sum')

            if train_score_predictor:
                optimizer_score.zero_grad()
                pred_score = score_predictor(VAE.reparameterize(mu, logvar)) 
                loss_score = ScorePredictor.loss_function(pred_score, y, reduction='sum')
                loss = loss_vae + loss_score
            else:
                loss = loss_vae

            loss.backward()
            optimizer_vae.step()
            if train_score_predictor: optimizer_score.step()

            #time and print
            h, m, s, th, tm, ts = time_left(time_start, n_epochs, n_batches, epoch, i+1)

            #loss
            loss = loss.item()
            vae_model.losses['batch'].append(loss)
            # Normalise by the real batch size: the last batch of an epoch may
            # hold fewer than data_loader.batch_size samples.
            batch_samples = len(x)
            epoch_loss_sum += loss
            epoch_samples += batch_samples
            
            #print
            print(f'Epoch {epoch+1}/{n_epochs}, batch {i+1}/{n_batches}, loss={loss/max(batch_samples, 1):.4f}, \t total time left = {th}h {tm}m {ts}s, \t epoch time left = {h}h {m}m {s}s                         ', end='\r')

        #validation
        vae_model.eval()
        with torch.no_grad():
            x_reconstructed, mu, logvar = vae_model(data.x_test_ptc)
            loss_test = VAE.loss_function(x_reconstructed, data.x_test_ptc, mu, logvar, one_hot=vae_model.one_hot, reduction='sum')
            vae_model.losses['test'].append(loss_test.item() / len(data.x_test_ptc))

            prediction = VAE.process_decoded_outputs(x_reconstructed, one_hot=vae_model.one_hot)
            accuracy = (prediction == data.x_test_ptc).float().mean().item()
            vae_model.losses['accuracy'].append(accuracy)

            # Evaluate the score predictor once it has been trained at least once.
            if train_score_predictor or (epoch>0 and vae_model.losses['score'][-1]>0):
                loss_score = ScorePredictor.loss_function(score_predictor(mu), data.y_test_ptc, reduction='mean')
                vae_model.losses['score'].append(loss_score.item())
            else:
                vae_model.losses['score'].append(0)

        # print loss on test set
        vae_model.losses['epoch'].append(epoch_loss_sum / max(epoch_samples, 1))
        print(f"Epoch {epoch+1}/{n_epochs}, \t loss={vae_model.losses['epoch'][-1]:.6f}, \t loss test={vae_model.losses['test'][-1]:.6f}, \t accuracy test={vae_model.losses['accuracy'][-1]:.6f}, \t loss score={vae_model.losses['score'][-1]:.6f},                                                       ", end='\n')
        
        #save (context managers make sure the file handles are closed)
        with open(vae_checkpoint, 'wb') as checkpoint_file:
            pickle.dump(vae_model, checkpoint_file)
        with open(score_checkpoint, 'wb') as checkpoint_file:
            pickle.dump(score_predictor.state_dict(), checkpoint_file)

        # Early stopping on the per-epoch validation loss. The previous check
        # looked at the last few *batch* losses, which says nothing about
        # progress across epochs. The training epoch loss is not used either,
        # because it includes the score loss only in some epochs and therefore
        # jumps between epochs.
        if epoch > patience:
            test_history = vae_model.losses['test']
            improvement = q.np.mean(test_history[-patience-1:-1]) - test_history[-1]
            if improvement < min_delta:
                print('Early stopping - epoch')
                break

        # cuda empty cache (skipped on machines without CUDA, e.g. device = "cpu")
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()




 ## ----- NEW SWEEP POINT ----- ## 

{'emb_dim': 200, 'latent_dim': 16, 'n_dense_layers': 1, 'RNN_type': 'LSTM', 'RNN_units': [500], 'bidirectional': True, 'dropout': 0, 'num_layers': 2, 'one_hot': False, 'n_epochs': 500, 'lr': 0.0003, 'save_name': 'vae_2_1.pickle', 'score_predictor': True} 





Epoch 1/500, 	 train score predictor = True
Epoch 1/500, 	 loss=1330.148796, 	 loss test=358.598327, 	 accuracy test=0.102362, 	 loss score=784.720886,                                                       


Epoch 2/500, 	 train score predictor = True
Epoch 2/500, 	 loss=1024.833877, 	 loss test=322.155840, 	 accuracy test=0.111702, 	 loss score=643.532593,                                                       


Epoch 3/500, 	 train score predictor = True
Epoch 3/500, 	 loss=873.446588, 	 loss test=297.394291, 	 accuracy test=0.118854, 	 loss score=522.468201,                                                       


Epoch 4/500, 	 train score predictor = False
Epoch 4/500, 	 loss=290.746878, 	 loss test=291.310269, 	 accuracy test=0.119554, 	 loss score=523.027405,                                                       


Epoch 5/500, 	 train score predictor = True
Epoch 5/500, 	 loss=749.059685, 	 loss test=292.430184, 	 accuracy test=0.118613, 	 loss score=421.453064,             

In [None]:
# Show the first ten validation inputs (x_test_ptc) next to their targets (y_test_ptc).
data.x_test_ptc[:10], data.y_test_ptc[:10]

In [None]:
# Raw decoder output for the first of ten validation inputs
# ([0] selects the decoded batch from the VAE's return tuple, the second [0] its first row).
vae_model(data.x_test_ptc[:10])[0][0]

In [None]:
# Put the model in evaluation mode before inspecting reconstructions.
vae_model.eval()
# Convert the decoder output of ten validation inputs into integer tokens.
# The conversion must follow the model's own output mode: with a hard-coded
# one_hot=True, a model built with one_hot=False would have its (batch, length)
# output argmax-ed along the sequence axis.
vae_model.process_decoded_outputs(vae_model(data.x_test_ptc[:10])[0], one_hot=vae_model.one_hot)

In [None]:
import matplotlib.pyplot as plt

# Per-epoch loss curves recorded by the training loop in `vae_model.losses`
# (the model has no `loss_list_epoch` / `loss_list_epoch_test` attributes).
# Note: the training curve includes the score-predictor loss in the epochs in
# which the score predictor was trained; the test curve never does.
plt.figure(figsize=(10,10))
plt.plot(vae_model.losses['epoch'], label='train')
plt.plot(vae_model.losses['test'], label='test')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# Move the VAE and the dataset to the CPU; the plotting cells below call .numpy() on tensors.
# NOTE(review): `score_predictor` is created on `device` in the training cell and is not moved here.
vae_model.to('cpu')
data.to('cpu')

In [None]:
# Encode data.x_test; the encoder returns the latent mean and log-variance.
mu, logvar = vae_model.encoder(data.x_test)

In [None]:
# Decode straight from the latent mean (no sampling noise), convert the output
# to integer tokens and compare the L2 norm of every reconstructed sequence
# with the norm of its original.
x_reconstructed = vae_model.decoder(mu)
x_reconstructed = VAE.process_decoded_outputs(x_reconstructed, one_hot=vae_model.one_hot)
x_reconstructed_norm = torch.norm(x_reconstructed.float(), dim=-1)
x_norm = torch.norm(data.x_test.float(), dim=-1)

In [None]:
# Reconstruction norm against original norm; the dashed line marks equal norms.
x_norm_values = x_norm.numpy()
norm_diagonal = [q.np.max(x_norm_values), q.np.min(x_norm_values)]

plt.figure(figsize=(10,10))
plt.scatter(x_norm, x_reconstructed_norm, alpha=0.5, s=10, color='orange')
plt.plot(norm_diagonal, norm_diagonal, color='black', linestyle='--')
plt.title('Norm of original vector vs norm of reconstructed vector (no noise)')
plt.xlabel('Norm of original vector')
plt.ylabel('Norm of reconstructed vector')
plt.show()

In [None]:
# Same as the noise-free reconstruction, but decoding a latent vector sampled
# with VAE.reparameterize instead of the latent mean itself.
x_reconstructed_noise = vae_model.decoder(vae_model.reparameterize(mu, logvar))
x_reconstructed_noise = VAE.process_decoded_outputs(x_reconstructed_noise, one_hot=vae_model.one_hot)
x_reconstructed_noise_norm = torch.norm(x_reconstructed_noise.float(), dim=-1)

In [None]:
# Noisy-reconstruction norm against original norm; the dashed line marks equal norms.
x_norm_values = x_norm.numpy()
norm_diagonal = [q.np.max(x_norm_values), q.np.min(x_norm_values)]

plt.figure(figsize=(10,10))
plt.scatter(x_norm, x_reconstructed_noise_norm, alpha=0.5, s=10, color='orange')
plt.plot(norm_diagonal, norm_diagonal, color='black', linestyle='--')
plt.title('Norm of original vector vs norm of reconstructed vector (with noise)')
plt.xlabel('Norm of original vector')
plt.ylabel('Norm of reconstructed vector (with noise)')
plt.show()

In [None]:
# The score predictor is a standalone module created in the training cell;
# `vae_model` has no `score_predictor` attribute. Make sure it lives on the
# same device as the latent vectors before calling it.
score_predictor.to(mu.device)
score_predicted = score_predictor(mu)

In [None]:
# Predicted score against real score; the dashed line marks perfect predictions.
real_scores = data.y_test.numpy()
score_diagonal = [q.np.max(real_scores), q.np.min(real_scores)]

plt.figure(figsize=(10,10))
plt.scatter(data.y_test.tolist(), score_predicted.tolist(), alpha=0.5, s=10)
plt.plot(score_diagonal, score_diagonal, color='black', linestyle='--')
plt.title('Predicted score vs real score')
plt.xlabel('Real score')
plt.ylabel('Predicted score')
plt.show()

In [None]:
# Score prediction from latent vectors sampled around the mean. The predictor is
# the standalone `score_predictor` module (`vae_model` has no such attribute);
# make sure it lives on the same device as the latent vectors.
score_predictor.to(mu.device)
score_predicted = score_predictor(vae_model.reparameterize(mu, logvar))

In [None]:
# Predicted score (noisy latent vectors) against real score; the dashed line
# marks perfect predictions.
real_scores = data.y_test.numpy()
score_diagonal = [q.np.max(real_scores), q.np.min(real_scores)]

plt.figure(figsize=(10,10))
plt.scatter(data.y_test.tolist(), score_predicted.tolist(), alpha=0.5, s=10)
plt.plot(score_diagonal, score_diagonal, color='black', linestyle='--')
plt.title('Predicted score vs real score (with noise)')
plt.xlabel('Real score')
plt.ylabel('Predicted score')
plt.show()

In [None]:
# Find the vector in latent space that gives the lowest predicted score by
# gradient descent on the vector itself (the networks are not updated).

# The predictor is the standalone `score_predictor` module (`vae_model` has no
# such attribute). Keep it and the search vector on the VAE's device so that
# the vector can be decoded afterwards.
search_device = next(vae_model.parameters()).device
score_predictor.to(search_device)
score_predictor.eval()

# Use the model's latent size instead of a hard-coded 16.
vector = torch.randn(1, vae_model.latent_dim, device=search_device)
vector.requires_grad = True
optimizer = optim.Adam([vector], lr=0.01)

n_steps = 1000
patience = 20
min_delta = 0.01
score_history = []

for i in range(n_steps):

    optimizer.zero_grad()
    score_predicted = score_predictor(vector)
    score_predicted.backward()
    optimizer.step()

    score_history.append(score_predicted.item())
    print(f'Epoch {i+1}/{n_steps}, score={score_history[-1]:.6f}                                                       ', end='\r')

    # Early stopping: stop once the current score is no longer at least
    # `min_delta` below the mean of the previous `patience` scores. (The former
    # expression reduced to `-score < min_delta` and ignored the history.)
    if i > patience:
        score_difference = q.np.mean(score_history[-patience-1:-1]) - score_history[-1]
        if score_difference < min_delta:
            print('Early stopping')
            break
print(vector, score_predicted.item())

In [None]:
# Decode the optimised latent vector into a token sequence, then re-encode that
# sequence and score it, to see which score the discrete sequence really gets.
vector.requires_grad = False
vae_model.eval()

new_sequence = vae_model.decoder(vector)
new_sequence = VAE.process_decoded_outputs(new_sequence, one_hot=vae_model.one_hot)

# The predictor is the standalone `score_predictor` module (`vae_model` has no
# such attribute); keep it on the same device as the latent vector.
score_predictor.to(vector.device)
score_new_sequence = score_predictor(vae_model.encoder(new_sequence)[0])

print(f'New sequence: {new_sequence.tolist()}, score={score_new_sequence.item():.6f}')

In [None]:
# def check_if_tensor_in_data(data, tensor):
#     in_test = torch.any(torch.all(data.x_test == tensor, dim=-1))
#     in_train = torch.any(torch.all(data.x_train == tensor, dim=-1))
#     return in_test or in_train

# def get_random_tensor(data, length=12, max_int=17):
#     tensor = torch.randint(0, max_int+1, (length,))
#     while check_if_tensor_in_data(data, tensor):
#         tensor = torch.randint(0, max_int+1, (length,))
#     return tensor

# def get_random_tensors(data, n_tensors, length=12, max_int=17):
#     tensors = []
#     for i in range(n_tensors):
#         print(i+1, '', end='\r')
#         tensors.append(get_random_tensor(data, length, max_int))
#     return torch.stack(tensors)

# data.to('cpu')
# random_x = get_random_tensors(data, n_tensors=100000, length=12, max_int=17)
# zeros_y = torch.zeros(len(random_x))

# #change type of tensors of random_x and zeros_y to old_x.dtype and old_y.dtype
# random_x = random_x.to(data.x_train.dtype)
# zeros_y = zeros_y.to(data.y_train.dtype)

# new_x = torch.cat((data.x_train.to('cpu'), random_x.to('cpu')), dim=0).to(device)
# new_y = torch.cat((data.y_train.to('cpu'), zeros_y.to('cpu')), dim=0).to(device)

# from torch.utils.data import DataLoader, TensorDataset
# new_data_loader = DataLoader(TensorDataset(new_x, new_y), batch_size=data_loader.batch_size, shuffle=True)

# data.to(device)

In [None]:
# First 100 rows of data.x_test, used by the evaluation cells below.
data.x_test[:100]

In [None]:
# Full VAE pass on the first 100 rows of data.x_test: decoder output, latent
# mean and latent log-variance. The raw decoder output is displayed.
x_reconstructed_test, mu_test, logvar_test = vae_model(data.x_test[:100])
x_reconstructed_test

In [None]:
# VAE loss on the 100-row subset. VAE.loss_function always sums the KL term, so
# with reduction='mean' only the reconstruction term is averaged.
loss = VAE.loss_function(x_reconstructed_test, data.x_test[:100], mu_test, logvar_test, one_hot=vae_model.one_hot, reduction='mean')
print(f'Loss on test set: {loss.item():.6f}')

In [None]:
# Integer tokens predicted for the 100-row subset (rounding or argmax, depending on the model's one_hot mode).
pred = VAE.process_decoded_outputs(x_reconstructed_test, one_hot=vae_model.one_hot)
pred

In [None]:
# Fraction of individual positions that were reconstructed exactly.
accuracy = (pred == data.x_test[:100]).float().mean().item()
accuracy

In [None]:
# Print every original sequence next to its predicted tokens. zip stops at the
# shorter iterable, so only the first len(pred) rows of data.x_test are shown.
for x, p in zip(data.x_test, pred):
    print(x.tolist(), p.tolist())

In [None]:
# Print every original sequence next to the raw decoder output. zip stops at the
# shorter iterable, so only the first len(x_reconstructed_test) rows are shown.
for x, p in zip(data.x_test, x_reconstructed_test):
    print(x.tolist(), p.tolist())