In [1]:
import os
# Show what is available in the current working directory (helper modules,
# parameter files, experiment folders) before anything else is imported.
print(os.listdir(os.curdir))

['.git', 'DiscreteVariationalParameterizationsDeepV2.py', 'GibbsSampling.py', 'README.md', 'Symmetric_Exclusion_Process _Simulator.ipynb', 'DiscreteVariationalParameterizationsDeepV3.py', 'Quantum_Transformer.ipynb', 'Quantum_Brickworks_Circuit_Generator.ipynb', 'QuantumSimulatorDataset.py', 'dense_small.param', 'Mutual_Information_Maximizing_Model.ipynb', 'Mutual_Information_Transformer.ipynb', 'quantum_experiments', '.ipynb_checkpoints']


In [2]:
!pip install torch
!pip install qiskit-aer
!pip install qiskit
!pip install pylatexenc

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [None]:
from re import X
from math import e
import torch
import torch.nn as nn
import torch.nn.functional as F
import DiscreteVariationalParameterizationsDeepV3 as DVP
from torch.autograd.functional import vjp
from torch.autograd.function import Function
from QuantumSimulatorDataset import QuantumSimulationDatasetFast, generate_circuit_params
from GibbsSampling import BatchedConditionalGibbsSampler, BatchedConditionalDoubleGibbsSampler

# Prefix prepended (by plain string concatenation) to data and model paths.
# Leave empty to use paths relative to the working directory; set it to a
# mounted folder such as a Google Drive directory, including the trailing
# path separator, e.g. '/content/drive/MyDrive/project/'.
directory_path = ''

class EmbeddingMI3(nn.Module):
    """Discrete embedding model trained by maximizing a mutual-information estimate.

    Bundles an encoder (x -> w), a decoder (w -> x), embedding dynamics over
    (z, w) pairs, and two Gibbs samplers that draw the auxiliary samples
    (w_tilde, x_tilde) consumed by `MutualInformationLossV3`.

    Args:
        batch_size: batch size the Gibbs samplers are constructed for.
        in_dim: dimension of the input states x / y.
        out_dim: dimension of the embeddings w / z.
        num_ones: forwarded to the decoder, the decoder sampler and the loss.
            NOTE(review): presumably the fixed number of ones per input state —
            confirm against DVP.EnergyBasedDecoder.
        embedding_num_samples: Gibbs sample count for w_tilde given z.
        embedding_mixing_time: Gibbs mixing time for w_tilde given z.
        decoder_num_samples: Gibbs sample count for x_tilde given w.
        decoder_mixing_time: Gibbs mixing time for x_tilde given w.

    The sampler defaults reproduce the previously hard-coded values; the sample
    counts were marked as needing tuning.
    """

    def __init__(self, batch_size, in_dim, out_dim, num_ones,
                 embedding_num_samples=256, embedding_mixing_time=5,
                 decoder_num_samples=256, decoder_mixing_time=24):
        super().__init__()
        # Construction order is kept stable so parameter initialisation (and any
        # RNG consumption it performs) is reproducible across versions.
        self.encoder = DVP.BoltzmannBasedEncoder(in_dim=in_dim, out_dim=out_dim)
        self.decoder = DVP.EnergyBasedDecoder(in_dim=out_dim, out_dim=in_dim, num_ones=num_ones)
        self.num_ones = num_ones
        self.embedding_dynamics = DVP.EnergyBasedModelEmbeddingDynamics(dim=out_dim)
        self.loss_func = MutualInformationLossV3.apply
        # A short mixing time was observed to still work for the embedding sampler.
        self.embedding_sampler = BatchedConditionalGibbsSampler(
            batch_size=batch_size,
            num_samples=embedding_num_samples,
            mixing_time=embedding_mixing_time,
            joint_distribution=self.embedding_dynamics,
        )
        self.decoder_sampler = BatchedConditionalDoubleGibbsSampler(
            batch_size=batch_size,
            num_samples=decoder_num_samples,
            mixing_time=decoder_mixing_time,
            joint_distribution=self.decoder,
            dim=in_dim,
            num_ones=self.num_ones,
        )

    def test_objective_function(self, x, y):
        """Return the negated mutual-information estimate for a batch of (x, y).

        Samples w from the encoder given x and z from the encoder given y, then
        draws w_tilde (conditioned on z) and x_tilde (conditioned on w) with the
        Gibbs samplers. All samples are detached, so gradients reach the model
        parameters only through the custom backward of the loss Function.

        Returns:
            A single tensor, the negation of `MutualInformationLossV3`'s output
            (presumably one value per sample — callers take `.mean()`).
            Minimizing it maximizes the estimate.
        """
        w = self.encoder.encoder_sample(x).detach()
        z = self.encoder.encoder_sample(y).detach()

        w_tilde = self.embedding_sampler.run_batched_gibbs(z).detach()
        x_tilde = self.decoder_sampler.run_batched_gibbs(w).detach()

        return -self.loss_func(
            self.num_ones,
            z, y, w, x, w_tilde, x_tilde,
            *self.encoder.params(),
            *self.decoder.params(),
            *self.embedding_dynamics.params(),
        )

class MutualInformationLossV3(Function):
    """Custom autograd Function producing a mutual-information estimate.

    Forward value (per element of the returned tensor):
        out = log p(x | w)   [decoder, estimated normalisation]
            - log p(w | x)   [encoder]
            + log r(w | z)   [embedding dynamics, estimated normalisation]

    The backward pass does NOT differentiate `out` directly. It assembles
    separate gradient estimators for the encoder, decoder and embedding
    parameters from vector-Jacobian products (see `backward`).

    Flat positional layout of `forward`'s inputs; it must stay in sync with the
    slicing in `backward` and with the caller:
        inputs[0]      num_ones (kept on ctx, gets no gradient)
        inputs[1:7]    z, y, w, x, w_tilde, x_tilde
        inputs[7:11]   encoder parameters (4 tensors)
        inputs[11:19]  decoder parameters (8 tensors)
        inputs[19:27]  embedding-dynamics parameters (8 tensors)
    """

    @staticmethod
    def forward(ctx, *inputs):
        """Compute log p(x|w) - log p(w|x) + log r(w|z).

        Args:
            ctx: autograd context; receives num_ones and the saved tensors.
            *inputs: flat argument list laid out as in the class docstring.
                w_tilde and x_tilde are unused here and are only saved for
                `backward`.

        Returns:
            The combined estimate `out`. It is saved together with the
            log r(w|z) term, because both weight gradients in `backward`.
        """
        num_ones = inputs[0]
        zywx_w_tilde_ins = inputs[1:7]
        encoder_params = inputs[7:11]
        decoder_params = inputs[11:19]
        embedding_params = inputs[19:27]

        z, y, w, x, _, _ = zywx_w_tilde_ins

        p_x_w_estimate = DVP.EnergyBasedDecoder.estimated_conditional_log_probability_a_given_b(x, w, num_ones, *decoder_params)
        p_w_x = DVP.BoltzmannBasedEncoder.conditional_log_probability_a_given_b_params(w, x, *encoder_params)
        r_w_z_estimate = DVP.EnergyBasedModelEmbeddingDynamics.estimated_normalized_log_probabilities_w_given_z_params(z, w, *embedding_params)

        out = p_x_w_estimate - p_w_x + r_w_z_estimate
        # num_ones is stored as a plain attribute rather than via save_for_backward
        # (presumably a Python int, not a tensor).
        ctx.num_ones = num_ones
        # Saved order: 6 data tensors, 4 encoder, 8 decoder, 8 embedding params,
        # then r(w|z) and out (indices 26 and 27 in backward).
        ctx.save_for_backward(*zywx_w_tilde_ins, *encoder_params, *decoder_params, *embedding_params, r_w_z_estimate, out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Return hand-assembled gradients for every input of `forward`.

        Each parameter group gets its own estimator, computed as a
        vector-Jacobian product against `grad_output`:
          * decoder: grad of the unnormalised log-prob at (x, w) minus grad of
            its expectation over the samples x_tilde. This matches the usual
            energy-based-model "data term minus sample term" pattern.
          * encoder: grad log p(w|x) weighted by (MI - 1), plus grad log p(z|y)
            weighted by log r(w|z). These are score-function-style terms.
          * embedding dynamics: grad of the unnormalised log r at (z, w) minus
            grad of its expectation over the samples w_tilde.

        Returns:
            A tuple with one entry per forward input. It holds seven Nones (for
            num_ones and z, y, w, x, w_tilde, x_tilde, which get no gradient),
            followed by the encoder, decoder and embedding parameter gradients
            in that order.
        """
        num_ones = ctx.num_ones
        z, y, w, x, w_tilde, x_tilde = ctx.saved_tensors[0:6]
        encoder_params = ctx.saved_tensors[6:10]
        decoder_params = ctx.saved_tensors[10:18]
        embedding_params = ctx.saved_tensors[18:26]
        r_w_z = ctx.saved_tensors[26]
        MI = ctx.saved_tensors[27]

        # Wrappers so vjp differentiates w.r.t. (x, w, *params) while num_ones is fixed.
        decoder_unnormalized_probs = lambda x, w, *params: DVP.EnergyBasedDecoder.unnormalized_log_probs_a_given_b_params(num_ones, x, w, *params)
        decoder_expected_unnormalized_probs = lambda x_tilde, w, *params: DVP.EnergyBasedDecoder.expected_unnormalized_log_probs_a_given_b(num_ones, x_tilde, w, *params)

        _, decoder_grad_1 = vjp(decoder_unnormalized_probs, (x, w, *decoder_params), grad_output, create_graph=False)
        # w is broadcast along x_tilde's leading (sample) dimension. The expand
        # implies w is 2-D here; presumably (batch, dim) — TODO confirm.
        _, decoder_grad_2 = vjp(decoder_expected_unnormalized_probs, (x_tilde, w.expand(x_tilde.shape[0], -1, -1), *decoder_params), grad_output, create_graph=False)

        # [2:] drops the gradients for the two data arguments, keeping only parameters.
        decoder_grad = tuple(map(lambda x, y: x - y, decoder_grad_1[2:], decoder_grad_2[2:]))

        # Score-function term for w ~ p(w|x). The "- 1" presumably accounts for
        # out's direct dependence on -log p(w|x).
        _, encoder_grad_term_1 = vjp(DVP.BoltzmannBasedEncoder.conditional_log_probability_a_given_b_params, (w, x, *encoder_params), grad_output * (MI - 1), create_graph=False)
        encoder_grad_term_1 = encoder_grad_term_1[2:]

        # Score-function term for z ~ p(z|y), weighted by the embedding log-probability.
        _, encoder_grad_term_2 = vjp(DVP.BoltzmannBasedEncoder.conditional_log_probability_a_given_b_params, (z, y, *encoder_params), grad_output * r_w_z, create_graph=False)
        encoder_grad_term_2 = encoder_grad_term_2[2:]

        encoder_grad = tuple(map(lambda x, y: x + y, encoder_grad_term_1, encoder_grad_term_2))

        _, embedding_grad_1 = vjp(DVP.EnergyBasedModelEmbeddingDynamics.unnormalized_log_probs_w_given_z_params, (z, w, *embedding_params), grad_output, create_graph=False)
        # z is broadcast along w_tilde's leading (sample) dimension, mirroring the decoder case.
        _, embedding_grad_2 = vjp(DVP.EnergyBasedModelEmbeddingDynamics.expected_unnormalized_log_probs_w_given_z, (z.expand(w_tilde.shape[0], -1, -1), w_tilde, *embedding_params), grad_output, create_graph=False)

        embedding_grad = tuple(map(lambda x, y: x - y, embedding_grad_1[2:], embedding_grad_2[2:]))

        return None, None, None, None, None, None, None, *encoder_grad, *decoder_grad, *embedding_grad

def run_dim_red_process(device, state_space, embedding_space_size, batch_size=256, num_steps=10000, lr=0.005):
    """Train one EmbeddingMI3 model on simulated quantum-circuit data.

    Args:
        device: torch device the model is moved to; it is also passed to the dataset.
        state_space: input dimension of the model (its `in_dim`).
        embedding_space_size: embedding dimension of the model (its `out_dim`).
        batch_size: batch size for the model's samplers and for the dataset.
        num_steps: exact number of optimizer steps to run.
        lr: Adam learning rate.

    Side effects:
        Prints progress every 10 steps. Writes a checkpoint every 1000 steps and
        a final model under ``directory_path + 'quantum_experiments/'``, creating
        that directory if needed.
    """
    import itertools
    import os

    model = EmbeddingMI3(batch_size, state_space, embedding_space_size, num_ones=4)

    # Warm-starting from a previously saved checkpoint used to happen here. It is
    # currently disabled, so every run starts from the model's fresh initialisation.

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    params = generate_circuit_params(file_name=directory_path + 'dense_small.param')
    # NOTE(review): the positional 4 presumably matches num_ones=4 above — confirm
    # against QuantumSimulationDatasetFast's signature.
    dataset = QuantumSimulationDatasetFast(params, batch_size, 4, device, inverse_density=3)

    # Checkpoints use the same directory_path prefix as the input parameter file.
    save_dir = directory_path + 'quantum_experiments'
    os.makedirs(save_dir, exist_ok=True)
    save_prefix = f'{save_dir}/experiment_{state_space}_{embedding_space_size}'

    # islice bounds the (possibly unbounded) dataset to exactly num_steps batches.
    for step, (final_state, initial_state) in enumerate(itertools.islice(dataset, num_steps)):
        optimizer.zero_grad()
        # The objective returns ONE tensor (the negated MI estimate); it must not
        # be unpacked into two values.
        loss = model.test_objective_function(initial_state, final_state).mean()
        loss.backward()
        if step % 10 == 0:
            print('| Iteration', step, 'I(W,Z) > ', f"{-loss.item():,.5f}")
        optimizer.step()
        if step % 1000 == 999:
            torch.save(model.state_dict(), f'{save_prefix}_{step}.model')
    print('Training Terminated')

    # Save the model at the end of training.
    final_save_path = f'{save_prefix}_final.model'
    torch.save(model.state_dict(), final_save_path)
    print(f"Final model saved to {final_save_path}")
    print(f"Final model saved to {final_save_path}")

if __name__ == '__main__':
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print('Device Running: ', device)

    # Each entry is (state_space, embedding_space_size, batch_size).
    # A full sweep over embedding sizes is [(12, k, 1024) for k in range(2, 13)].
    experiments = [
        (12, 12, 1024),
        (12, 2, 1024),
        (12, 4, 1024),
        (12, 8, 1024),
    ]

    for exp_idx, exp_cfg in enumerate(experiments):
        n_state, n_embed, bsz = exp_cfg
        print('Running Experiment', exp_idx, exp_cfg)
        print('State Space of', n_state, 'to Embedding Space of', n_embed)
        run_dim_red_process(device, n_state, n_embed, bsz)


Device Running:  cuda
Running Experiment 0 (12, 2, 1024)
State Space of 12 to Embedding Space of 2
| Successfully loaded initializer model quantum_experiments/experiment_12_2.model
| Iteration 0 I(W,Z) >  -6.21998  I(X,Y) >  -0.000000
| Iteration 10 I(W,Z) >  -6.22121  I(X,Y) >  -0.000000
| Iteration 20 I(W,Z) >  -6.20680  I(X,Y) >  -0.000000
| Iteration 30 I(W,Z) >  -6.21895  I(X,Y) >  -0.000000
| Iteration 40 I(W,Z) >  -6.20875  I(X,Y) >  -0.000000
| Iteration 50 I(W,Z) >  -6.20711  I(X,Y) >  -0.000000
| Iteration 60 I(W,Z) >  -6.20879  I(X,Y) >  -0.000000
| Iteration 70 I(W,Z) >  -6.20923  I(X,Y) >  -0.000000
| Iteration 80 I(W,Z) >  -6.20776  I(X,Y) >  -0.000000
| Iteration 90 I(W,Z) >  -6.20781  I(X,Y) >  -0.000000
| Iteration 100 I(W,Z) >  -6.20542  I(X,Y) >  -0.000000
| Iteration 110 I(W,Z) >  -6.20969  I(X,Y) >  -0.000000
| Iteration 120 I(W,Z) >  -6.20448  I(X,Y) >  -0.000000
| Iteration 130 I(W,Z) >  -6.20383  I(X,Y) >  -0.000000
| Iteration 140 I(W,Z) >  -6.20569  I(X,Y) >  