In [1]:
# define a state class such that it can be used to measure some or all qubits, evolve a random brick work circuit
# compute expectation value of a given Pauli string
# apply a specific gate on it, could be supplied as a full unitary or 
# state attributes : norm, full array
# state.measure([Z Z Z])
# state.bwm_evolve(n_steps):
# state.apply(gate)
# state.reset()
# state.compute_ee( ) # entanglement entropy

In [133]:
import numpy as np
from itertools import product

X=np.array([[0,1],[1,0]],dtype=complex)
Y=np.array([[0,-1j],[1j,0]],dtype=complex)
Z=np.array([[1,0],[0,-1]],dtype=complex)

class qstate:
    def __init__(self,n=1,random=False,arr=[]):
        
        self.nqubits=n
        
        if(random==True):
            self.arr=np.random.rand(2**n)+1j*np.random.rand(2**n)
            self.arr=self.arr/np.linalg.norm(self.arr)
            
        elif(random==False and len(arr)==0):
            self.arr=np.zeros(2**n,dtype=complex)
            self.arr[0]=1
            
        elif(len(arr)!=0):
            self.arr=np.array(arr,dtype=complex)/np.linalg.norm(arr)
            self.nqubits=int(np.log2(len(arr)))
            
        self.norm=np.linalg.norm(self.arr)
    
    
    def apply_1gate(self,gate,i=1,inplace=True):

        #apply gate on qubit i
        
        Ntot=2**self.nqubits # total dimesnions of the Hilbert space

        arr=self.arr.copy()
    
        # say if m ranges from 0 to Ntot-1, mp ranges from 0 to Ntot/2 -1 since ith position bit string is fixed

        L=self.nqubits
        
        alpha,beta=0,1 # two states which can be generalized for qudits later
        
        for mp in range(Ntot//2):

            # suppose mp = m1'+m2', where m1'=b1x2**(L-2)+b2x2**(L-3)+... +b_{i-1}x2**(L-i), m2'=b_{i+1}2**(L-i-1)+...+b_{L-1}2**(1)+b_{L}q**(0)
            # then by definition m1' = (2**(L-i))*mp//2**(L-i)
            # and m2'=mp-m1'
            m1p=(2**(L-i))*(mp//(2**(L-i)))
            m2p=mp-m1p

            # now the two states which will be modified are ones which have alpha and beta at ith position

            alpha_ind=2*m1p+alpha*(2**(L-i))+m2p

            beta_ind=2*m1p+beta*(2**(L-i))+m2p

            # modifying the state at the corresponding indices

            arr[alpha_ind]=self.arr[alpha_ind]*gate[0,0]+gate[0,1]*self.arr[beta_ind]
            arr[beta_ind]=self.arr[beta_ind]*gate[1,1]+gate[1,0]*self.arr[alpha_ind]

        if(inplace==True):
            self.arr=arr
            self.norm=np.linalg.norm(self.arr)
            return None
        else:
            return arr
    
    #def compute_ee(self,)
        
    def measure(self,basis_list=['Z'],only_subsystem=False):
        

        if(basis_list==['I']*self.nqubits): #if no qubits are being measured
            return None
        
        
        # if non-standard, assume all qubits are being measured in Z basis
        if(len(basis_list)!=self.nqubits):
            basis_list=['Z']*self.nqubits
        
        
        # if basis element 'I', qubit is NOT measured
        meas_qubits=[]

        for i in range(self.nqubits):
            
            if(basis_list[i]=='X'):
                self.arr.apply_1gate(X,i)
                meas_qubits.append(i)
            elif(basis_list[i]=='Y'):
                self.arr.apply_1gate(Y,i)
                meas_qubits.append(i+1)
                
            elif(basis_list[i]=='Z'):
                meas_qubits.append(i)
                
        all_qubits=[i for i in range(self.nqubits)]
        
        
        n_meas=len(meas_qubits)

        traced_qubits=np.setdiff1d(np.arange(self.nqubits),meas_qubits)
        print("traced",traced_qubits)
        
        shape_arr=np.repeat(2,self.nqubits)
        psi=np.reshape(self.arr,shape_arr)

        if(n_meas!=self.nqubits):
            prob_arr=np.sum(np.abs(psi)**2,axis=tuple(traced_qubits)) #tracing out qubits that won't be measured
        
        else:
            prob_arr=np.abs(psi)**2

        # reshape to sample integer
        prob_arr=np.reshape(prob_arr,2**n_meas)

        meas_int=np.random.choice(range(2**n_meas),p=prob_arr) #choosing the computational basis state randomly after measurement based on 
        meas_str = format(meas_int,'0'+str(n_meas)+'b')

        
        # initialising collapsed state on the unmeasured system
        if(only_subsystem):
            psi_b=np.zeros(2**(self.nqubits-n_meas),dtype="complex")

        else:
            psi_b=np.zeros(2**self.nqubits,dtype="complex")
            

        # find the full collapsed state including the unmeasured qubits
        for i in range(2**(self.nqubits-n_meas)):
            
            # find the full binary string inserting the measured string
            bin_i=format(i,'0'+str(self.nqubits-n_meas)+'b')
            bin_full=['0']*self.nqubits
            
            for j,m in enumerate(meas_qubits):
                bin_full[m]=meas_str[j]#str(meas_ind[j])
                
            for j,m in enumerate(traced_qubits):
                bin_full[m]=str(bin_i[j])

            new_i=int("".join(bin_full),2)

            if(only_subsystem):
                psi_b[i]=self.arr[new_i]
                
            else:
                psi_b[new_i]=self.arr[new_i]
                
            
        
        # update number of qubits and state array
        if(only_subsystem):
            self.nqubits=self.nqubits-n_meas
                             
        self.arr=psi_b/np.linalg.norm(psi_b) # normalize                    
        
        return meas_str

In [136]:
state1=qstate(n=10,random=True)
print(state1.arr)
print("number of qubits",state1.nqubits)

X=np.array([[0,1],[1,0]])
#state1.apply_1gate(X,inplace=False)
#print('state after')
#print(state1.arr)
print("measured state from return ",state1.measure(['Z','Z','Z','Z','Z'],only_subsystem=False))
print("state after collapse",state1.arr)

[0.00140177+0.03231905j 0.0161872 +0.00603294j 0.01490798+0.01961563j ...
 0.03447461+0.00575293j 0.03425638+0.03502216j 0.02838612+0.0233663j ]
number of qubits 10
traced []
measured state from return  1110001011
state after collapse [0.+0.j 0.+0.j 0.+0.j ... 0.+0.j 0.+0.j 0.+0.j]
