import sys
import os

# Pin the common BLAS / OpenMP / numexpr backends to a single thread.
# This is done before numpy, scipy and pygsti are imported further down,
# presumably so each MPI rank uses one core instead of spawning its own
# thread pool (TODO confirm intent).
for _thread_var in (
    'OPENBLAS_NUM_THREADS',
    'GOTO_NUM_THREADS',
    'OMP_NUM_THREADS',
    'NUMEXPR_NUM_THREADS',
    'VECLIB_MAXIMUM_THREADS',
    'MKL_NUM_THREADS',
):
    os.environ[_thread_var] = '1'
del _thread_var

import pygsti
from pygsti.processors import QubitProcessorSpec
from pygsti.protocols import StandardGST, ProtocolData
from pygsti.models.modelconstruction import create_explicit_model
import scipy
from pygsti.objectivefns import objectivefns as _objfns
import numpy as np
from pygsti.modelpacks import smq2Q_XY as std
import time
from pygsti.forwardsims import MapForwardSimulator, MatrixForwardSimulator
from pygsti.baseobjs import ResourceAllocation

# Set up MPI. If mpi4py is not installed, fall back to a serial run
# (comm=None, rank 0 of 1) so that the `comm is None` guards used later in
# this script actually take effect instead of the import crashing.
try:
    from mpi4py import MPI
except ImportError:
    MPI = None
    print('mpi4py not available; running serially', file=sys.stderr, flush=True)

comm = MPI.COMM_WORLD if MPI is not None else None
rank = comm.Get_rank() if comm is not None else 0
size = comm.Get_size() if comm is not None else 1
print(f'My Rank is: {rank}', flush=True)
def _balanced_processor_grid(nprocs):
    """Split ``nprocs`` ranks into a ``(circuit_procs, param_procs)`` grid.

    Returns the factor pair ``(a, b)`` with ``a * b == nprocs`` and ``a >= b``
    that is closest to square, so the grid size always matches the number of
    ranks. Examples: 1 -> (1, 1), 4 -> (2, 2), 8 -> (4, 2), prime p -> (p, 1).
    """
    param_procs = 1
    # Largest integer whose square does not exceed nprocs.
    while (param_procs + 1) * (param_procs + 1) <= nprocs:
        param_procs += 1
    # Walk down to the nearest divisor so the product is exactly nprocs.
    while nprocs % param_procs:
        param_procs -= 1
    return (nprocs // param_procs, param_procs)


# Rank 0 (or the only process, when there is no communicator) builds the data
# set and the GST protocol objects; they are broadcast to the other ranks
# afterwards.
if rank == 0 or comm is None:
    print(f'Number of Ranks is {size}')
    target_model = std.target_model('GLND')
    exp_design = std.create_gst_experiment_design(max_max_length=1)

    # Simulate data from a SPAM-depolarized copy of the target model; the
    # fixed seed keeps the simulated data set the same between runs.
    datagen_model = target_model.depolarize(spam_noise=.01)
    ds = pygsti.data.simulate_data(datagen_model, exp_design.all_circuits_needing_data,
                                   num_samples=10000, seed=20230217)
    data = ProtocolData(exp_design, ds)

    prob_clip = 1e-8
    tol = 1e-10
    maxiter = 1
    verbosity = 4

    # chi2 objective for the iterative stages, Poisson-picture delta-logL for
    # the final stage.
    chi2_builder = pygsti.objectivefns.Chi2Function.builder('chi2', regularization={'min_prob_clip_for_weighting': prob_clip}, penalties={'cptp_penalty_factor': 0.0})
    mle_builder = pygsti.objectivefns.PoissonPicDeltaLogLFunction.builder('logl', regularization={'min_prob_clip': prob_clip, 'radius': prob_clip})
    iteration_builders = [chi2_builder]
    final_builders = [mle_builder]
    builders = pygsti.protocols.GSTObjFnBuilders(iteration_builders, final_builders)

    # Reset the model parameters to zeros before the single-rank fit.
    target_model.from_vector(np.zeros(target_model.num_params))
    #target_model.sim = 'matrix'

    print('Fit Log Single Rank')
    print('-----------------------------')
    optimizer = pygsti.optimize.simplerlm.SimplerLMOptimizer(maxiter=maxiter, tol={'f': tol})
    protoOpt = pygsti.protocols.GateSetTomography(target_model, verbosity=verbosity, optimizer=optimizer, gaugeopt_suite=None, objfn_builders=builders)
    results_single_rank = protoOpt.run(data, disable_checkpointing=True)

    # Reset the parameters again so the distributed fit is configured the
    # same way as the single-rank one.
    target_model.from_vector(np.zeros(target_model.num_params))
    # processor_grid is (circuit ranks, parameter ranks). Alternatives:
    #   parallelize across parameters only: processor_grid=(1, size)
    #   parallelize across circuits only:   processor_grid=(size, 1)
    # Here we parallelize across both. The previous (int(size/2), int(size/2))
    # grid has product (size//2)**2, which equals `size` only for 4 ranks.
    target_model.sim = MapForwardSimulator(processor_grid=_balanced_processor_grid(size))
    optimizer = pygsti.optimize.simplerlm.SimplerLMOptimizer(maxiter=maxiter, tol={'f': tol})
    protoOpt = pygsti.protocols.GateSetTomography(target_model, verbosity=verbosity, optimizer=optimizer, gaugeopt_suite=None, objfn_builders=builders)

else:
    # Non-root ranks receive these via broadcast below.
    data = None
    protoOpt = None

    
# Share the protocol and data built on rank 0 with every rank (skipped when
# there is no communicator), then run the distributed fit on all ranks.
if comm is not None:
    protoOpt = comm.bcast(protoOpt, root=0)
    data = comm.bcast(data, root=0)

if rank == 0 or comm is None:
    # Report the actual rank count rather than a hard-coded "Four".
    print(f'Fit Log {size} Ranks')
    print('-----------------------------')
results = protoOpt.run(data, comm=comm, disable_checkpointing=True)

