In [146]:
import json
import pennylane as qml
import pennylane.numpy as np

def beam_splitter(r):
    """
    Returns the beam splitter matrix.

    Args:
        - r (float): The reflection coefficient of the beam splitter.
    Returns:
        - (np.array): 2 x 2 matrix that represents the beam
        splitter matrix.    
    """


    T = np.sqrt(1 - r**2)
    return np.array([[r, T], [T, -r]])


dev = qml.device('default.qubit',wires =1)

@qml.qnode(dev)
def mz_interferometer(r):
    """
    This QNode returns the probability that either A or C
    detect a photon, and the probability that D detects a photon.
    
    Args:
        - r (float): The reflection coefficient of the beam splitters.
    Returns: 
        - np.array(float): An array of shape (2,), where the first 
        element is the probability of detection at A or C,
        and the second element is the probability of detection at D.
    """

    qml.Hadamard(wires=0)  # Prepare photon in superposition

    qml.QubitUnitary(beam_splitter(r), wires=0)
    qml.Hadamard(wires=0)

    mA = qml.measure(wires = 0)
    
    def true_fn():
        qml.QubitStateVector([1, 0], wires=0)
        qml.Hadamard(wires=0)
        qml.QubitUnitary(beam_splitter(r), wires=0)
        qml.Hadamard(wires=0)

    
    def false_fn():
        pass
    
    qml.cond(mA, true_fn, false_fn)()

    return qml.probs([0])



# These functions are responsible for testing the solution.


def run(test_case_input: str) -> str:
    ins = json.loads(test_case_input)
    outs = mz_interferometer(ins).tolist()
    
    return str(outs)


def check(solution_output: str, expected_output: str) -> None:
    solution_output = json.loads(solution_output)
    print("🚀 ~ solution_output:", solution_output)
    print("🚀 ~ expected_output:", expected_output)
    expected_output = json.loads(expected_output)
    assert np.allclose(solution_output,expected_output), "Not the correct probabilities"


# These are the public test cases
test_cases = [
    ('0.5', '[0.8125, 0.1875]'),
    ('0.577350269', '[0.777778, 0.222222]')
]

# This will run the public test cases locally
for i, (input_, expected_output) in enumerate(test_cases):
    print(f"Running test case {i} with input '{input_}'...")

    try:
        output = run(input_)

    except Exception as exc:
        print(f"Runtime Error. {exc}")

    else:
        if message := check(output, expected_output):
            print(f"Wrong Answer. Have: '{output}'. Want: '{expected_output}'.")

        else:
            print("Correct!")

Running test case 0 with input '0.5'...
🚀 ~ solution_output: [0.8124999999999997, 0.1874999999999999]
🚀 ~ expected_output: [0.8125, 0.1875]
Correct!
Running test case 1 with input '0.577350269'...
🚀 ~ solution_output: [0.7777777778507643, 0.222222222149235]
🚀 ~ expected_output: [0.777778, 0.222222]
Correct!


https://faculty.csbsju.edu/frioux/q-intro/MZI-QuantumCircuit.pdf