In [None]:
import sys
import json
import copy
import logging
import numpy as np
import matplotlib.pyplot as plt

from qiskit import Aer, QuantumCircuit, QuantumRegister, transpile
from qiskit.ignis.mitigation.measurement import complete_meas_cal, MeasurementFilter
from qiskit.providers.ibmq.runtime.utils import RuntimeEncoder, RuntimeDecoder
#from qiskit.providers.ibmq.runtime import UserMessenger

#logging.basicConfig(level=logging.WARNING, format='%(asctime)s: %(message)s')
#logging.getLogger('schwinger_rqd').setLevel(logging.INFO)

sys.path.append('..')
from main import make_step_circuits, run_forward_circuits, rqd_step, main
from pnp_ansatze import make_pnp_ansatz
from observables import plot_counts_with_curve
from trotter import trotter_step_circuits

## Runtime job input

In [None]:
# Number of lattice sites (= qubits). The measurement error matrix has to be
# (2**num_sites x 2**num_sites), so it is derived from this value instead of
# being hard-coded as 16.
NUM_SITES = 4

inputs = {
    'num_sites': NUM_SITES,
    'aJ': 1.,
    'am': 0.5,
    'omegadt': 0.2,
    'num_tsteps': 4,
    'tsteps_per_rqd': 2,
    # Identity calibration matrix: presumably a placeholder meaning "no readout
    # error" (it is fed to MeasurementFilter below).
    'error_matrix': np.eye(2 ** NUM_SITES, dtype='f8'),
    'physical_qubits': None,
    'minimizer_shots_per_job': 1000,
    'minimizer_jobs': 2,
    'forward_shots': 2 * 8192,
    'max_sweeps': 100
}

# Without any handler, the logging module's last-resort handler only emits
# WARNING and above, so the INFO level set below would have no visible effect.
logging.basicConfig(level=logging.WARNING, format='%(asctime)s: %(message)s')
# NOTE(review): the commented-out config in the imports cell targets
# 'schwinger_rqd'; confirm which logger name main.py actually uses.
logging.getLogger('four_qubit_schwinger').setLevel(logging.INFO)

## Testing just the forward steps

In [None]:
# Forward-only sanity check: evolve for num_tsteps Trotter steps on the
# statevector simulator, then overlay the resulting counts on the curve drawn
# by plot_counts_with_curve.
backend = Aer.get_backend('statevector_simulator')

(num_sites, aJ, am, omegadt, num_tsteps,
 physical_qubits, error_matrix, forward_shots) = (
    inputs[name] for name in (
        'num_sites', 'aJ', 'am', 'omegadt', 'num_tsteps',
        'physical_qubits', 'error_matrix', 'forward_shots',
    )
)

# Only the state labels are needed from the calibration helper; the
# calibration circuits themselves are discarded.
_, state_labels = complete_meas_cal(
    qubit_list=list(range(num_sites)),
    qr=QuantumRegister(num_sites),
    circlabel='mcal',
)
error_mitigation_filter = MeasurementFilter(error_matrix, state_labels)

forward_circuits = make_step_circuits(num_sites, aJ, am, omegadt, backend, physical_qubits)
target_circuits = trotter_step_circuits(num_tsteps, forward_circuits, initial_state=None, measure=False)

counts_list = run_forward_circuits(
    target_circuits,
    backend,
    initial_layout=physical_qubits,
    shots=forward_shots,
    error_mitigation_filter=error_mitigation_filter,
)

plot_counts_with_curve(counts_list, num_sites, aJ, am, omegadt, num_tsteps,
                       initial_state=None, num_toys=0)

## A UserMessenger class that saves all published results

In [None]:
class UserMessenger:
    """Local stand-in for the Qiskit Runtime ``UserMessenger`` that records messages.

    Every published message is stored, as a deep copy, in ``self.results`` in
    publication order, so cells can inspect interim and final results after a
    run. The real ``publish`` serializes the message at call time. Copying here
    gives the same snapshot semantics: later in-place changes made by the
    publisher to the same object do not alter what was recorded.
    """

    def __init__(self):
        # Snapshots of every published message, oldest first.
        self.results = []

    def publish(
            self,
            message,
            encoder=None,
            final=False
    ) -> None:
        """Record a deep copy of ``message``.

        Args:
            message: The published payload (presumably a dict; later cells index
                it by key).
            encoder: Accepted for signature compatibility; ignored.
            final: Accepted for signature compatibility; ignored.
        """
        self.results.append(copy.deepcopy(message))

## Testing one RQD step

In [None]:
# Run a single RQD step (step index 0) on the shot-based simulator. Every
# message it publishes is kept in `user_messenger` for the plots below.
backend = Aer.get_backend('qasm_simulator')

num_sites = inputs['num_sites']
aJ = inputs['aJ']
am = inputs['am']
omegadt = inputs['omegadt']
tsteps_per_rqd = inputs['tsteps_per_rqd']
physical_qubits = inputs['physical_qubits']
error_matrix = inputs['error_matrix']

_, state_labels = complete_meas_cal(qubit_list=list(range(num_sites)), qr=QuantumRegister(num_sites), circlabel='mcal')
error_mitigation_filter = MeasurementFilter(error_matrix, state_labels)

forward_step_circuits = make_step_circuits(num_sites, aJ, am, omegadt, backend, physical_qubits)

# PnP ansatz used as the compressed approximation of the evolved state; only
# the 2- and 4-site layouts are configured.
if num_sites == 2:
    approximator = make_pnp_ansatz(
        num_qubits=num_sites,
        num_layers=num_sites // 2,
        initial_x_positions=[0])
elif num_sites == 4:
    approximator = make_pnp_ansatz(
        num_qubits=num_sites,
        num_layers=num_sites // 2,
        initial_x_positions=[1, 2],
        structure=[(1, 2), (0, 1), (2, 3)],
        first_layer_structure=[(0, 1), (2, 3)])
else:
    # Fail loudly. Otherwise `approximator` is either undefined (NameError
    # below) or silently a stale ansatz left in the kernel from an earlier run.
    raise ValueError(
        f'No PnP ansatz configured for num_sites={num_sites}; only 2 and 4 are supported')

user_messenger = UserMessenger()

optimal_params = rqd_step(
    0,
    inputs,
    backend,
    forward_step_circuits,
    approximator,
    error_mitigation_filter=error_mitigation_filter,
    user_messenger=user_messenger
)

In [None]:
# Cost value versus shots from the last message published by the step-0 run.
# Explicit axes with labels keep the figure readable on its own; the trailing
# `;` suppresses the Line2D repr.
step0_last_message = user_messenger.results[-1]
fig, ax = plt.subplots()
ax.plot(step0_last_message['shots_values'], step0_last_message['cost_values'])
ax.set(xlabel='shots', ylabel='cost', title='RQD step 0: cost vs. shots');

In [None]:
# Forward-step counts from the first message published during step 0, plotted
# over tsteps_per_rqd steps starting from the default initial state.
step0_forward_counts = user_messenger.results[0]['forward_counts']
plot_counts_with_curve(
    step0_forward_counts,
    num_sites, aJ, am, omegadt, tsteps_per_rqd,
    initial_state=None,
    num_toys=0,
)

In [None]:
# Run RQD step 1, warm-started from the step-0 optimum in `optimal_params`.
# The result is bound to a new name so this cell is idempotent. Writing back
# into `optimal_params` would feed step 1's own output in as its warm start on
# every re-run.
user_messenger_2 = UserMessenger()

optimal_params_step1 = rqd_step(
    1,
    inputs,
    backend,
    forward_step_circuits,
    approximator,
    optimal_params=optimal_params,
    error_mitigation_filter=error_mitigation_filter,
    user_messenger=user_messenger_2
)

In [None]:
# Cost value versus shots from the last message published by the step-1 run.
# Explicit axes with labels keep the figure readable on its own; the trailing
# `;` suppresses the Line2D repr.
step1_last_message = user_messenger_2.results[-1]
fig, ax = plt.subplots()
ax.plot(step1_last_message['shots_values'], step1_last_message['cost_values'])
ax.set(xlabel='shots', ylabel='cost', title='RQD step 1: cost vs. shots');

In [None]:
from hamiltonian import schwinger_model, diagonalized_evolution

# The vacuum is the computational-basis state with every even-indexed bit
# (0, 2, 4, ...) set.
vacuum_state_index = sum(1 << site for site in range(0, num_sites, 2))
vacuum_state = np.zeros(2 ** num_sites, dtype=np.complex128)
vacuum_state[vacuum_state_index] = 1.

# Exact evolution over one RQD step (omegadt * tsteps_per_rqd). The last column
# of `statevectors` is presumably the state at the final time, which is where
# step 1 starts.
hamiltonian = schwinger_model(num_sites, aJ, am)
_, statevectors = diagonalized_evolution(hamiltonian, vacuum_state, omegadt * tsteps_per_rqd)
step1_initial_state = statevectors[:, -1]

plot_counts_with_curve(
    user_messenger_2.results[0]['forward_counts'],
    num_sites, aJ, am, omegadt, tsteps_per_rqd,
    initial_state=step1_initial_state,
    num_toys=0,
)

In [None]:
# Step-0 and step-1 forward counts joined into one sequence and plotted over
# both RQD steps (2 * tsteps_per_rqd time steps) from the default initial state.
two_step_forward_counts = (
    user_messenger.results[0]['forward_counts']
    + user_messenger_2.results[0]['forward_counts']
)
plot_counts_with_curve(
    two_step_forward_counts,
    num_sites, aJ, am, omegadt, 2 * tsteps_per_rqd,
    initial_state=None,
    num_toys=0,
)

## Testing the main function

In [None]:
class UserMessengerForwardOnly:
    """Local stand-in for the Qiskit Runtime ``UserMessenger`` that keeps forward-step messages only.

    Messages without a ``'forward_counts'`` entry are dropped. The rest are
    stored, as deep copies, in ``self.results`` in publication order. The real
    ``publish`` serializes the message at call time. Copying here gives the same
    snapshot semantics: later in-place changes made by the publisher do not
    alter what was recorded.
    """

    def __init__(self):
        # Snapshots of the retained (forward-step) messages, oldest first.
        self.results = []

    def publish(
            self,
            message,
            encoder=None,
            final=False
    ) -> None:
        """Record a deep copy of ``message`` if it carries ``'forward_counts'``.

        Args:
            message: The published payload (presumably a dict; later cells index
                it by key).
            encoder: Accepted for signature compatibility; ignored.
            final: Accepted for signature compatibility; ignored.
        """
        if 'forward_counts' not in message:
            return

        self.results.append(copy.deepcopy(message))

In [None]:
# End-to-end run of main(). The inputs are first round-tripped through the
# Runtime JSON encoder/decoder, presumably mirroring how a Runtime job receives
# its inputs. Only forward-step messages are recorded.
user_messenger_fw = UserMessengerForwardOnly()
backend = Aer.get_backend('qasm_simulator')

serialized_inputs = json.dumps(inputs, cls=RuntimeEncoder)
deserialized_inputs = json.loads(serialized_inputs, cls=RuntimeDecoder)

main(backend, user_messenger_fw, **deserialized_inputs)

In [None]:
# Flatten the forward counts of every recorded main() message, in order.
forward_counts = [
    step_counts
    for message in user_messenger_fw.results
    for step_counts in message['forward_counts']
]

plot_counts_with_curve(
    forward_counts,
    inputs['num_sites'], inputs['aJ'], inputs['am'], inputs['omegadt'], inputs['num_tsteps'],
    initial_state=None,
    num_toys=0,
)

In [None]:
# Checkpoint to resume from; displayed with the rich repr rather than print().
# NOTE(review): this reads the third-to-last message of `user_messenger`, i.e.
# the single-step rqd_step(0, ...) run above, not the main() run. The main()
# runs use UserMessengerForwardOnly, which drops every message lacking
# 'forward_counts'. Confirm this is the intended source and that the message
# carries 'rqd_step', 'state' and 'minimizer_state'.
interim_result = user_messenger.results[-3]
interim_result

In [None]:
# Resume main() from the checkpoint selected above. Build a separate inputs
# dict instead of writing 'resume_from' into the shared `inputs`. Mutating it
# would make a re-run of the earlier main() cell (or any later use of `inputs`)
# silently resume as well.
resume_inputs = {
    **inputs,
    'resume_from': {
        'rqd_step': interim_result['rqd_step'],
        'state': interim_result['state'],
        'minimizer_state': interim_result['minimizer_state']
    }
}

serialized_inputs = json.dumps(resume_inputs, cls=RuntimeEncoder)
deserialized_inputs = json.loads(serialized_inputs, cls=RuntimeDecoder)

user_messenger_fw_2 = UserMessengerForwardOnly()

main(backend, user_messenger_fw_2, **deserialized_inputs)

In [None]:
# Counts from the first forward-step message of the original main() run,
# followed by everything the resumed run recorded. The first part is copied so
# extending it does not touch the stored message.
forward_counts = copy.deepcopy(user_messenger_fw.results[0]['forward_counts'])
for resumed_message in user_messenger_fw_2.results:
    forward_counts.extend(resumed_message['forward_counts'])

plot_counts_with_curve(
    forward_counts,
    inputs['num_sites'], inputs['aJ'], inputs['am'], inputs['omegadt'], inputs['num_tsteps'],
    initial_state=None,
    num_toys=0,
)