### HOMEWORK 6 - DENSITY MATRICES ###

# Density Matrices

Consider a quantum system composed of $ N $ subsystems (spins, atoms, particles, etc.) each described by a wave function $ \psi_i \in \mathcal{H}_D $, where $ \mathcal{H}_D $ is a $ D $-dimensional Hilbert space. How do you write the total wave function of the system $ \Psi(\psi_1, \psi_2, \ldots, \psi_N) $?
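For reference, the two cases distinguished in the tasks below can be written in standard notation. Here $ \{|j\rangle\}_{j=1}^{D} $ is an orthonormal basis of $ \mathcal{H}_D $, chosen for illustration. A separable state is the tensor product of the single-subsystem wave functions. A general state is an arbitrary superposition of all $ D^N $ product basis states:

$$
\Psi_{\text{sep}} = \psi_1 \otimes \psi_2 \otimes \cdots \otimes \psi_N ,
\qquad
\Psi = \sum_{j_1, \ldots, j_N = 1}^{D} c_{j_1 \ldots j_N}\, |j_1\rangle \otimes |j_2\rangle \otimes \cdots \otimes |j_N\rangle .
$$

The separable form is fixed by the $ N \cdot D $ components of the $ \psi_i $. The general form is fixed by the $ D^N $ coefficients $ c_{j_1 \ldots j_N} $, with $ \sum |c_{j_1 \ldots j_N}|^2 = 1 $.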

## Tasks

1. **Write Code**  
   (a) Write code (in Fortran or Python) describing the composite system in the case of an $ N $-body non-interacting, separable pure state.  
   (b) Write code for the case of a general $ N $-body pure wave function $ \Psi \in \mathcal{H}_{D^N} $.  

2. **Efficiency**  
   (c) Comment on and compare the efficiency of the implementations for parts (a) and (b).  

3. **Density Matrix**  
   (d) Given $ N = 2 $, write the density matrix of a general pure state $ \Psi $:  
   $$
   \rho = |\Psi\rangle\langle\Psi|
   $$  

4. **Reduced Density Matrix**  
   (e) Given a generic density matrix of dimension $ D^N \times D^N $, compute the reduced density matrix of either the left or the right subsystem, e.g.,  
   $$
   \rho_1 = \text{Tr}_2 \rho
   $$  

5. **Testing**  
   (f) Test the functions from parts (a)–(e) (and any helper functions they need) on a system of two spin-1/2 particles (qubits) prepared in different states.
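For part (c), one concrete way to compare the two representations is to count the memory each needs. The sketch below makes two assumptions. Amplitudes are stored as `complex128`, the dtype both functions work with. The helper `storage` is a hypothetical name and is not part of the notebook.

```python
import numpy as np

def storage(N: int, D: int) -> tuple[int, int]:
    """Bytes used by the separable and by the general representation (complex128)."""
    # separable: N single-subsystem vectors of length D -> N*D amplitudes
    separable = np.zeros((N, D), dtype=complex).nbytes
    # general: one amplitude per basis state of H_{D^N} -> D**N amplitudes
    general = np.zeros(D**N, dtype=complex).nbytes
    return separable, general

# compare the two representations for a growing number of qubits (D = 2)
for N in (2, 5, 10, 20):
    sep, gen = storage(N, D=2)
    print(f'N = {N:2d}: separable {sep:5d} B, general {gen:9d} B')
```

The factorized description grows linearly, as $ N \cdot D $. The general one grows exponentially, as $ D^N $: for 20 qubits it already takes 16 MiB, against 640 bytes for the factors. The saving only holds while the state is kept in factorized form. Once the factors are multiplied together with `np.kron`, as `separable_state` does, the result again has $ D^N $ amplitudes.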


In [4]:
## IMPORTS
import aux  # helper module providing aux.checkpoint (used below)
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


In [None]:
# a) non-interacting separable states

def separable_state(N: int, D: int, verb: int = 0, input_state: np.ndarray = None):
    """Build the separable (product) state of N non-interacting subsystems.

    Args:
        N (int): number of subsystems.
        D (int): dimension of each subsystem (all subsystems have the same dimension).
        verb (int, optional): verbosity, 0 (silent) or 1 (print complexity analysis). Defaults to 0.
        input_state (np.ndarray, optional): array of shape (N, D) whose rows are the wavefunctions
            of the subsystems. If None, random states are generated.

    Raises:
        TypeError: if N, D or verb is not an int.
        TypeError: if input_state is given and is not an ndarray.
        ValueError: if N or D is not strictly positive.
        ValueError: if verb is not 0 or 1.
        ValueError: if input_state does not have shape (N, D).

    Returns:
        composite_state (np.ndarray): normalized wavefunction of the separable composite state,
            a complex vector of length D**N.
    """
    if not isinstance(N, int):
        raise TypeError(f'N must be an int and not {type(N)}.')

    if not isinstance(D, int):
        raise TypeError(f'D must be an int and not {type(D)}.')

    if not isinstance(verb, int):
        raise TypeError(f'verb should be an int, not a {type(verb)}.')

    # input_state is optional, so None must be accepted
    if input_state is not None and not isinstance(input_state, np.ndarray):
        raise TypeError(f'input_state must be an ndarray, not a {type(input_state)}.')

    if N < 1 or D < 1:
        raise ValueError(f'N and D must be strictly positive, but N = {N} and D = {D} were given.')

    if verb != 0 and verb != 1:
        raise ValueError(f'verb values supported are only 0 and 1, but {verb} was given.')

    if input_state is not None and input_state.shape != (N, D):
        raise ValueError(f'input_state must have shape {(N, D)}, but has shape {input_state.shape}.')
    
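    if input_state is not None:
        # one wavefunction per row (not necessarily normalized)
        states = np.asarray(input_state, dtype=complex)
    else:
        # random complex amplitudes for each of the N subsystems
        states = np.random.rand(N, D) + 1j * np.random.rand(N, D)

    # normalize each subsystem wavefunction separately (row-wise norm)
    states = states / np.linalg.norm(states, axis=1, keepdims=True)

    # tensor (Kronecker) product psi_1 x psi_2 x ... x psi_N -> vector of length D**N
    composite_state = states[0]
    for state in states[1:]:
        composite_state = np.kron(composite_state, state)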
            
    # check norm
    if not np.isclose(np.linalg.norm(composite_state),1.0,1.e-6):
        aux.checkpoint(True, verb=3, msg='The wavefunction is not normalized', var=np.linalg.norm(composite_state)) 
    
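    if verb == 1:
        print('--------------------')
        print('COMPLEXITY ANALYSIS')
        print('--------------------')
        print(f'D = {D}')
        print(f'N = {N}')
        # real parameters: each factor has 2D real numbers, minus normalization and global phase
        print(f'COMPLEXITY = {N*(2*D-2)}')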
    
    return composite_state

# b) general state of dimension D**N

def general_state(N: int, D: int, verb: int = 0, input_state: np.ndarray = None):
    """Return a normalized pure state (given or random) of a general composite system of N subsystems of dimension D.

    Args:
        N (int): number of subsystems.
        D (int): dimension of each subsystem (all subsystems have the same dimension).
        verb (int, optional): verbosity, 0 (silent) or 1 (print complexity analysis). Defaults to 0.
        input_state (np.ndarray, optional): the D**N amplitudes of the composite state.
            If None, a random state is generated.

    Raises:
        TypeError: if N, D or verb is not an int.
        TypeError: if input_state is given and is not an ndarray.
        ValueError: if N or D is not strictly positive.
        ValueError: if verb is not 0 or 1.
        ValueError: if input_state does not contain exactly D**N amplitudes.

    Returns:
        psi_normalized (np.ndarray): normalized wavefunction of the general composite state,
            a complex vector of length D**N.
    """

    if not isinstance(N, int):
        raise TypeError(f'N must be an int and not {type(N)}.')

    if not isinstance(D, int):
        raise TypeError(f'D must be an int and not {type(D)}.')

    if not isinstance(verb, int):
        raise TypeError(f'verb should be an int, not a {type(verb)}.')

    # input_state is optional, so None must be accepted
    if input_state is not None and not isinstance(input_state, np.ndarray):
        raise TypeError(f'input_state must be an ndarray and not {type(input_state)}.')

    if N < 1:
        raise ValueError('N must be strictly positive.')

    if D < 1:
        raise ValueError('D must be strictly positive.')

    if verb != 0 and verb != 1:
        raise ValueError(f'verb values supported are only 0 and 1, but {verb} was given.')

    if input_state is not None and input_state.size != D**N:
        raise ValueError(f'input_state must contain D**N = {D**N} amplitudes, but has {input_state.size}.')
    
    total_dim = D**N
    
    # use the given amplitudes (flattened to a vector) or draw random complex ones
    if input_state is not None:
        psi = np.asarray(input_state, dtype=complex).ravel()

    else:
        psi = np.random.rand(total_dim) + 1j * np.random.rand(total_dim)
        
    psi_normalized = psi / np.linalg.norm(psi)
            
    # check norm
    if not np.isclose(np.linalg.norm(psi_normalized),1.0,1.e-6):
        aux.checkpoint(True, verb=3, msg='The wavefunction is not normalized', var=np.linalg.norm(psi_normalized))
        
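    if verb == 1:
        print('--------------------')
        print('COMPLEXITY ANALYSIS')
        print('--------------------')
        print(f'D = {D}')
        print(f'N = {N}')
        # real parameters: D**N complex amplitudes = 2*D**N real numbers, minus normalization and global phase
        print(f'COMPLEXITY = {2*D**N - 2}')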

    return psi_normalized
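The functions above cover parts (a) and (b). Below is a minimal sketch of parts (d) and (e) for $ N = 2 $. It assumes the index ordering produced by `np.kron(left, right)`, the convention used in `separable_state`. The density matrix $ \rho = |\Psi\rangle\langle\Psi| $ is then an outer product. The reduced matrix $ \rho_1 = \text{Tr}_2 \rho $ follows from reshaping $ \rho $ into a four-index tensor $ \rho_{ij,kl} $ and summing over $ j = l $. The names `density_matrix` and `reduced_density_matrix`, and the `keep` argument, are hypothetical and are not taken from the notebook.

```python
import numpy as np

def density_matrix(psi: np.ndarray) -> np.ndarray:
    """Density matrix rho = |psi><psi| of a pure state (psi is normalized first)."""
    psi = np.asarray(psi, dtype=complex).ravel()
    psi = psi / np.linalg.norm(psi)
    return np.outer(psi, psi.conj())

def reduced_density_matrix(rho: np.ndarray, D: int, keep: str = 'left') -> np.ndarray:
    """Partial trace of a (D**2 x D**2) density matrix of two D-dimensional subsystems.

    keep='left' returns rho_1 = Tr_2 rho, keep='right' returns rho_2 = Tr_1 rho.
    """
    # row index i*D + j, column index k*D + l  ->  rho4[i, j, k, l]
    # (i, k label the left subsystem and j, l the right one, as in np.kron(left, right))
    rho4 = np.asarray(rho).reshape(D, D, D, D)
    if keep == 'left':
        return np.einsum('ijkj->ik', rho4)  # sum over j = l
    if keep == 'right':
        return np.einsum('ijil->jl', rho4)  # sum over i = k
    raise ValueError(f"keep must be 'left' or 'right', not {keep!r}")

# two qubits (D = 2): computational basis states |0> and |1>
zero = np.array([1, 0], dtype=complex)
one = np.array([0, 1], dtype=complex)

# product state |0>|1>: both reduced states are pure
rho_prod = density_matrix(np.kron(zero, one))
print(np.allclose(reduced_density_matrix(rho_prod, 2, 'left'), np.outer(zero, zero)))
print(np.allclose(reduced_density_matrix(rho_prod, 2, 'right'), np.outer(one, one)))

# Bell state (|00> + |11>)/sqrt(2): the reduced state is maximally mixed
rho_bell = density_matrix(np.kron(zero, zero) + np.kron(one, one))
rho_1 = reduced_density_matrix(rho_bell, 2)
print(np.allclose(rho_1, np.eye(2) / 2))
print(np.trace(rho_bell @ rho_bell).real, np.trace(rho_1 @ rho_1).real)  # purities Tr(rho^2)
```

The purity $ \text{Tr}\rho^2 $ equals 1 for the global pure state, while the reduced state of the Bell pair has purity 1/2, the signature of entanglement that the qubit tests in (f) can check. `np.einsum` with a repeated index is one way to take the partial trace. `np.trace(rho4, axis1=1, axis2=3)` gives the same $ \rho_1 $.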