# T distillation

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt

import matplotlib
# Notebook-wide matplotlib defaults: font sizes for titles, axis labels, ticks
# and legends, plus the default figure size.
_rc_overrides = {
    'axes.titlesize': 14,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
    'legend.title_fontsize': 12,
    'figure.figsize': (7, 5),
}
matplotlib.rcParams.update(_rc_overrides)

In [None]:
from cirq_qubitization.quantum_graph.composite_bloq import \
    get_soquets, CompositeBloqBuilder, FancyRegisters

In [None]:
from cirq_qubitization.quantum_graph.graphviz import PrettyGraphDrawer
from IPython.display import SVG

def showb(cbloq):
    """Display `cbloq` inline as an SVG drawing.

    The object is handed to `PrettyGraphDrawer`; the graph it returns is
    converted with `create_svg()` and shown through IPython's `display`.
    """
    graph = PrettyGraphDrawer(cbloq).get_graph()
    svg_data = graph.create_svg()
    display(SVG(svg_data))

In [None]:
import numpy as np
import networkx as nx

import quimb
import quimb.tensor as qtn
# Tag names kept as a module-level list. NOTE(review): not referenced by any
# cell visible in this notebook (the draw call passes its own list) - confirm
# whether it is still needed.
COLORS = [
    '+',
    '0',
    'COPY',
    'XOR',
    'CNOT',
    'Join',
]

In [None]:
from cirq_qubitization.quantum_graph.quimb_sim import *
from cirq_qubitization.quantum_graph.basic_gates import *

In [None]:
from cirq_qubitization.surface_code.gosc2 import *

In [None]:
# Decompose the TIdentity bloq and draw the result with `showb`.
# Alternative kept for reference (disabled): tfact = TFactory().decompose_bloq()
cbloq = TIdentity().decompose_bloq()
# Disabled textual dump: print(cbloq.debug_text())
showb(cbloq)

In [None]:
from cirq_qubitization.quantum_graph.composite_bloq import _create_binst_graph, BloqInstance
# Build a graph from the composite bloq's connections (using the private helper
# `_create_binst_graph`) and show its repr as the cell output.
bg = _create_binst_graph(cbloq.connections)
bg

# Score

In [None]:
# Turn the graph into a "musical score" and print each entry as returned.
score = binstgraph_to_musical_score(bg)
for s in score:
    print(s)

# Longest entry; every shorter one is padded with None up to this length.
maxwidth = max(map(len, score))


def pad_gen(gen):
    """Return `gen` extended with trailing `None`s so its length is `maxwidth`."""
    missing = maxwidth - len(gen)
    return gen + (None,) * missing


score = np.array(list(map(pad_gen, score)))

# Print the transposed score: one text line per column of the padded array,
# with empty (None) cells rendered as three spaces.
for line in score.T:
    cells = [' '*3 if x is None else f'{x:3d}' for x in line]
    print(' '.join(cells))

In [None]:
# Build a graph with one node per occupied (ti, xi) cell of the padded score.
sg = nx.Graph()

# xi -> most recent node seen at that xi, carried across all moments.
prev_xi = {}
for ti, moment in enumerate(score):
    # cell value -> most recent node holding that value within this moment.
    seen = {}
    occupied = ((xi, cell) for xi, cell in enumerate(moment) if cell is not None)
    for xi, cell in occupied:
        me = (ti, xi)
        sg.add_node(me, i=cell)

        # Link nodes that carry the same value in the same moment.
        same_value_node = seen.get(cell)
        if same_value_node is not None:
            sg.add_edge(same_value_node, me, etype='binst')

        # Link to the previous occupied node at the same xi.
        earlier_node = prev_xi.get(xi)
        if earlier_node is not None:
            sg.add_edge(earlier_node, me, etype='qubit')

        seen[cell] = me
        prev_xi[xi] = me

# Place node (t, x) at plot coordinates (t, -x) and label it with its 'i' value.
pos = {node: (node[0], -node[1]) for node in sg.nodes}
nx.draw_networkx(sg, pos=pos, labels=dict(sg.nodes.data('i')))

# Quimb

In [None]:
import networkx as nx

# Lay the graph out with Graphviz 'dot' and draw it. BloqInstance nodes are
# labelled with their bloq's pretty name; every other node is labelled 'dang'.
pos = nx.nx_agraph.graphviz_layout(bg, 'dot')

node_labels = {}
for node in bg.nodes:
    if isinstance(node, BloqInstance):
        node_labels[node] = node.bloq.pretty_name()
    else:
        node_labels[node] = 'dang'

nx.draw_networkx(bg, pos=pos, labels=node_labels)

In [None]:
# Convert the graph to a quimb tensor network, reusing the layout positions,
# then draw it with the returned `fix` mapping and tags hidden.
tn, fix = cbloq_to_quimb(bg, pos=pos)
tn.draw(
    color=['COPY', '+', 'Mx', 'Z'],
    show_tags=False,
    fix=fix,
)

In [None]:
# Indices of the network that are not inner indices: all indices minus the
# inner ones, as a list (set-derived, so the order is not meaningful).
every_ind = set(tn.all_inds())
outs = list(every_ind - set(tn.inner_inds()))
outs

In [None]:
# `LeftDangle` is needed in this cell, so import it here rather than relying on
# a later cell (or a star import) having put it in the namespace already.
from cirq_qubitization.quantum_graph.quantum_graph import LeftDangle

# Expand the composite bloq's left registers against LeftDangle and keep the
# entry stored under the 'qs' key.
left_soqs = blow_up_soquets(cbloq.registers.lefts(), LeftDangle)['qs']
left_soqs

In [None]:
from cirq_qubitization.quantum_graph.quantum_graph import LeftDangle, RightDangle

# Expand the composite bloq's right registers against RightDangle and keep the
# entry stored under the 'qs' key.
right_soqs = blow_up_soquets(cbloq.registers.rights(), RightDangle)['qs']
right_soqs

In [None]:
# Convert the tensor network to a dense array, passing the right soquets and
# then the left soquets as the two index groups, and print it rounded to
# 3 decimals.
unitary = tn.to_dense(right_soqs, left_soqs)
print(np.round(unitary, 3))

In [None]:
# Indices of the entries of `unitary` whose magnitude exceeds 1e-4.
np.where(np.abs(unitary) > 1e-4)

In [None]:
# Values of the entries of `unitary` whose magnitude exceeds 1e-8, rounded to
# 4 decimals.
np.round(unitary[np.abs(unitary)>1e-8], 4)

In [None]:
import cirq
# Reference matrix: Cirq's unitary for an Rz gate with rads = pi/8.
cirq.unitary(cirq.Rz(rads=np.pi/8))

In [None]:
from scipy.linalg import expm

# Pauli-Z as a complex 2x2 matrix.
Z = np.diag([1, -1]).astype(np.complex128)

# Single-qubit rotation exp(-i * pi/16 * Z), printed rounded to 3 decimals.
cunitary = expm((-1.j * np.pi / (8 * 2)) * Z)
print(np.round(cunitary, 3))

In [None]:
from scipy.linalg import expm

# Pauli-Z as a complex 2x2 matrix.
Z = np.diag([1, -1]).astype(np.complex128)

# Two-qubit rotation exp(-i * pi/16 * Z (x) Z), printed rounded to 3 decimals.
ZZ = np.kron(Z, Z)
cunitary = expm((-1.j * np.pi / (8 * 2)) * ZZ)
print(np.round(cunitary, 3))

In [None]:
# `itertools` is not imported by any visible cell, so import it where it is used.
import itertools

# All sign combinations (+1/-1) for two factors, one row per combination.
z_signs = np.array(list(itertools.product([1, -1], repeat=2)))
# Product of the signs in each row. `np.prod` replaces the deprecated
# `np.product` alias, which no longer exists in NumPy 2.0.
zz_signs = np.prod(z_signs, axis=1)
# Arrange as a 2x2 array and turn each sign s into the phase exp(-i * s * pi/16).
zdata = zz_signs.reshape((2,) * 2)
zdata = np.exp(-1.j * zdata * np.pi / (8 * 2))
zdata

## Fine, just use Cirq

In [None]:
import cirq

In [None]:
from cirq.contrib.svg import SVGCircuit

# Five line qubits, kept in a numpy array so they can be fancy-indexed below.
qs = np.array(cirq.LineQubit.range(5))
c = cirq.Circuit()

# Matrix gates exp(-i * pi/16 * P) for P = Z, Z(x)Z(x)Z and Z(x)Z(x)Z(x)Z(x)Z.
Zg = cirq.MatrixGate(expm(-1.j * np.pi/(8*2) * Z), qid_shape=(2,))
ZZZ = cirq.MatrixGate(
    expm(-1.j * np.pi/(8*2) * cirq.kron(Z, Z, Z)),
    qid_shape=(2, 2, 2),
)
ZZZZZ = cirq.MatrixGate(expm(-1.j * np.pi/(8*2) * cirq.kron(Z, Z, Z, Z, Z)))

# Qubit index triples for the three-qubit gates, in application order; the
# five-qubit gate is applied between the two groups.
triples_before = [
    (1, 2, 3),
    (0, 1, 2),
    (0, 1, 3),
    (0, 2, 3),
    (0, 3, 4),
    (0, 1, 4),
    (0, 2, 4),
]
triples_after = [
    (2, 3, 4),
    (1, 3, 4),
    (1, 2, 4),
]

# One single-qubit gate on every qubit first.
for qubit in qs:
    c += Zg.on(qubit)

for triple in triples_before:
    c += ZZZ.on(*qs[list(triple)])

c += ZZZZZ.on(*qs)  # 12 (label kept from the original; its meaning is not evident here)

for triple in triples_after:
    c += ZZZ.on(*qs[list(triple)])

SVGCircuit(c)

In [None]:
# Unitary matrix of the circuit `c`, followed by the indices of its entries
# whose magnitude exceeds 1e-3.
cunitary = cirq.unitary(c)
np.where(np.abs(cunitary)>1e-3)

In [None]:
# Entries of `cunitary` with magnitude above 1e-3, rounded to 8 decimals.
significant = np.abs(cunitary) > 1e-3
np.round(cunitary[significant], 8)

In [None]:
# Check that the circuit unitary equals the 32x32 identity.
# `assert_allclose` takes (actual, desired) and scales `rtol` by `desired`, so
# the computed matrix goes first and the reference second. The reference has
# exact zeros off the diagonal, so an absolute tolerance is needed for those
# entries.
# NOTE(review): counting eigenvalue signs over the 16 rotations in the circuit
# suggests the all-zeros and all-ones basis states pick up a factor of -1, so
# this check may fail as written - confirm the intended target matrix.
np.testing.assert_allclose(cunitary, np.eye(2**5), atol=1e-8)