In [None]:
import os
import sys
from datetime import datetime

import numpy as np
import sympy as sp
from typing import Union
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from numba import jit

%matplotlib inline

## Parameter Setup

In [None]:
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Training hyper-parameters.
RANDOM_SEED = 20210701
LEARNING_RATE = 1e-3
BATCH_SIZE = 32
N_EPOCHS = 5

# Input image side length (pixels) and number of target classes.
IMG_SIZE = 32
N_CLASSES = 10

## Model Architecture

In [None]:
class LeNet5(nn.Module):
    '''
    LeNet-5 style network: three 5x5 convolutions with tanh activations
    (average pooling after the first two) followed by a two-layer classifier.

    Args:
        n_classes: number of output units of the last linear layer.

    The forward pass returns a pair ``(logits, probs)`` where ``probs`` is the
    softmax of ``logits`` over dimension 1.
    '''

    def __init__(self, n_classes):
        super().__init__()

        # Channel progression 1 -> 6 -> 16 -> 120; the classifier below expects
        # exactly 120 features per sample after flattening.
        conv_layers = [
            nn.Conv2d(1, 6, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(6, 16, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(16, 120, kernel_size=5, stride=1),
            nn.Tanh(),
        ]
        self.feature_extractor = nn.Sequential(*conv_layers)

        dense_layers = [
            nn.Linear(120, 84),
            nn.Tanh(),
            nn.Linear(84, n_classes),
        ]
        self.classifier = nn.Sequential(*dense_layers)

    def forward(self, x):
        features = self.feature_extractor(x)
        # Keep the batch dimension, collapse everything else.
        flat = features.flatten(start_dim=1)
        logits = self.classifier(flat)
        return logits, F.softmax(logits, dim=1)

## Helper Functions

In [None]:
# define transform: resize every image to IMG_SIZE x IMG_SIZE and convert it to a tensor
transform = transforms.Compose([transforms.Resize((IMG_SIZE, IMG_SIZE)),
                                 transforms.ToTensor()])

# download and create datasets
# True only when the data directory is absent. (The previous expression
# `False if ... else False` could never be True.)
download_flag = not os.path.exists('data/MNIST')

# download=True is kept for the training split so that a partially populated
# directory is still completed; the validation split reuses the same files.
# NOTE(review): this relies on torchvision skipping files that already exist - confirm.
train_dataset = datasets.MNIST(root='data/MNIST', 
                               train=True, 
                               transform=transform,
                               download=True)

valid_dataset = datasets.MNIST(root='data/MNIST', 
                               train=False, 
                               transform=transform)

# define the data loaders (only the training data is shuffled)
train_loader = DataLoader(dataset=train_dataset, 
                          batch_size=BATCH_SIZE, 
                          shuffle=True)

valid_loader = DataLoader(dataset=valid_dataset, 
                          batch_size=BATCH_SIZE, 
                          shuffle=False)

In [None]:
def get_accuracy(model, data_loader, device):
    '''
    Function for computing the accuracy of the predictions over the entire data_loader

    Args:
        model: module whose forward pass returns ``(logits, probs)``.
        data_loader: iterable of ``(X, y_true)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        0-dim float tensor with the fraction of correctly predicted samples,
        or NaN when ``data_loader`` yields no samples.

    Note:
        The model is switched to eval mode and left in that mode.
    '''
    
    correct_pred = 0 
    n = 0
    
    with torch.no_grad():
        model.eval()
        for X, y_true in data_loader:

            X = X.to(device)
            y_true = y_true.to(device)

            _, y_prob = model(X)
            _, predicted_labels = torch.max(y_prob, 1)

            n += y_true.size(0)
            correct_pred += (predicted_labels == y_true).sum()

    if n == 0:
        # Nothing was evaluated: correct_pred is still a plain int here, so the
        # tensor arithmetic below would fail. Report "undefined" explicitly.
        return torch.tensor(float('nan'), device=device)

    return correct_pred.float() / n

def plot_losses(train_losses, valid_losses):
    '''
    Function for plotting training and validation losses

    Args:
        train_losses: sequence with one training loss per epoch.
        valid_losses: sequence with one validation loss per epoch.

    The figure is drawn with the 'ggplot' style applied only for the duration
    of this call; the rcParams in effect before the call are restored afterwards.
    '''

    train_losses = np.array(train_losses) 
    valid_losses = np.array(valid_losses)

    # style.context restores the caller's settings on exit, unlike a trailing
    # plt.style.use('default'), which resets every rcParam to the library default.
    with plt.style.context('ggplot'):
        fig, ax = plt.subplots(figsize = (8, 4.5))

        ax.plot(train_losses, color='blue', label='Training loss') 
        ax.plot(valid_losses, color='red', label='Validation loss')
        ax.set(title="Loss over epochs", 
                xlabel='Epoch',
                ylabel='Loss') 
        ax.legend()

        # plt.show() goes through the active backend; fig.show() only works
        # with GUI backends and emits a warning otherwise.
        plt.show()

In [None]:
def train(train_loader, model, criterion, optimizer, device):
    '''
    Run one optimisation pass over ``train_loader``.

    Args:
        train_loader: loader yielding ``(X, y_true)`` batches; must expose ``.dataset``.
        model: module whose forward pass returns ``(logits, probs)``.
        criterion: loss callable applied to ``(logits, y_true)``.
        optimizer: optimiser stepping the model parameters.
        device: device each batch is moved to.

    Returns:
        ``(model, optimizer, epoch_loss)`` where ``epoch_loss`` is the sum of
        per-batch losses weighted by batch size, divided by ``len(train_loader.dataset)``.
    '''

    model.train()
    loss_sum = 0

    for batch_x, batch_y in train_loader:

        optimizer.zero_grad()

        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        # Forward pass: only the logits feed the loss
        logits, _ = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        loss_sum += batch_loss.item() * batch_x.size(0)

        # Backward pass and parameter update
        batch_loss.backward()
        optimizer.step()

    return model, optimizer, loss_sum / len(train_loader.dataset)

In [None]:
def validate(valid_loader, model, criterion, device):
    '''
    Function for the validation step of the training loop

    Args:
        valid_loader: loader yielding ``(X, y_true)`` batches; must expose ``.dataset``.
        model: module whose forward pass returns ``(logits, probs)``.
        criterion: loss callable applied to ``(logits, y_true)``.
        device: device each batch is moved to.

    Returns:
        ``(model, epoch_loss)`` where ``epoch_loss`` is the batch-size-weighted
        sum of the batch losses divided by ``len(valid_loader.dataset)``.

    Note:
        The model is switched to eval mode and left in that mode.
    '''
   
    model.eval()
    running_loss = 0
    
    # Disable autograd here so the function never builds a graph, whether or
    # not the caller wraps the call in torch.no_grad() itself.
    with torch.no_grad():
        for X, y_true in valid_loader:
        
            X = X.to(device)
            y_true = y_true.to(device)

            # Forward pass and record loss
            y_hat, _ = model(X) 
            loss = criterion(y_hat, y_true) 
            running_loss += loss.item() * X.size(0)

    epoch_loss = running_loss / len(valid_loader.dataset)
        
    return model, epoch_loss

In [None]:
def training_loop(model, criterion, optimizer, train_loader, valid_loader, epochs, device, print_every=1):
    '''
    Function defining the entire training loop

    Args:
        model, criterion, optimizer: forwarded to ``train`` / ``validate``.
        train_loader, valid_loader: loaders for the two data splits.
        epochs: number of passes over ``train_loader``.
        device: device the batches are moved to.
        print_every: a progress line (losses and accuracies) is printed on every
            ``print_every``-th epoch.

    Returns:
        ``(model, optimizer, (train_losses, valid_losses))`` with one loss per epoch
        in each list. The loss curves are also plotted before returning.
    '''
    
    # set objects for storing metrics
    train_losses = []
    valid_losses = []
 
    # Train model
    for epoch in range(0, epochs):

        # training
        model, optimizer, train_loss = train(train_loader, model, criterion, optimizer, device)
        train_losses.append(train_loss)

        # validation
        with torch.no_grad():
            model, valid_loss = validate(valid_loader, model, criterion, device)
            valid_losses.append(valid_loss)

        if epoch % print_every == (print_every - 1):
            
            # Accuracy needs a full extra pass over both loaders, so it is only
            # computed on epochs that are actually reported.
            train_acc = get_accuracy(model, train_loader, device=device)
            valid_acc = get_accuracy(model, valid_loader, device=device)
                
            print(f'{datetime.now().time().replace(microsecond=0)} --- '
                  f'Epoch: {epoch}\t'
                  f'Train loss: {train_loss:.4f}\t'
                  f'Valid loss: {valid_loss:.4f}\t'
                  f'Train accuracy: {100 * train_acc:.2f}\t'
                  f'Valid accuracy: {100 * valid_acc:.2f}')

    plot_losses(train_losses, valid_losses)
    
    return model, optimizer, (train_losses, valid_losses)

## Training

In [None]:
# Seed PyTorch's RNG before the model is built so the initial weights are repeatable.
torch.manual_seed(RANDOM_SEED)

# Network on the selected device, cross-entropy loss on the logits, Adam optimiser.
model = LeNet5(n_classes=N_CLASSES)
model = model.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# Run the full training; the returned loss histories are discarded here.
model, optimizer, _ = training_loop(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    valid_loader=valid_loader,
    epochs=N_EPOCHS,
    device=DEVICE,
)

# Unitary Decomposition

## Helper Functions

In [None]:
@jit(nopython=True)
def atan2f(y, x, tolerance=1e-6, to_degree=False):
    """
    Tolerance-aware two-argument arctangent.

    Inputs whose magnitude is at most ``tolerance`` are treated as exactly zero:
    both zero gives 0; only ``x`` zero gives +/- pi/2 by the sign of ``y``;
    only ``y`` zero gives 0 for positive ``x`` and pi otherwise. All other
    inputs are passed to ``np.arctan2``. The result is in radians unless
    ``to_degree`` is set.
    """
    y_is_zero = np.abs(y) <= tolerance
    x_is_zero = np.abs(x) <= tolerance
    if x_is_zero:
        if y_is_zero:
            rad = 0
        else:
            rad = np.pi/2 if y > tolerance else -np.pi/2
    elif y_is_zero:
        rad = 0 if x > tolerance else np.pi
    else:
        rad = np.arctan2(y, x)
    if to_degree:
        return np.rad2deg(rad)
    return rad
    

@jit(nopython=True)
def angle_diff(comp_src, comp_dst, offset=0, tolerance=1e-6, wrap=True, to_degree=False):
    """
    Phase of ``comp_dst`` minus phase of ``comp_src``, plus ``offset``.

    A value whose magnitude is at most ``tolerance`` contributes a phase of
    zero. With ``wrap`` the result is reduced modulo 2*pi; with ``to_degree``
    it is converted from radians to degrees.
    """
    src_is_zero = np.abs(comp_src) <= tolerance
    dst_is_zero = np.abs(comp_dst) <= tolerance
    if src_is_zero:
        if dst_is_zero:
            rad = 0
        else:
            rad = np.angle(comp_dst)
    elif dst_is_zero:
        rad = -np.angle(comp_src)
    else:
        rad = np.angle(comp_dst) - np.angle(comp_src)
    rad += offset
    if wrap:
        rad = np.mod(rad, 2 * np.pi)
    if to_degree:
        return np.rad2deg(rad)
    return rad

## Types of Unitary Blocks

In [None]:
def U2BS(dim, m, n, phi, theta, use_sym=False, Lp=1, Lc=1):
    """
    Build a ``dim x dim`` identity with a 2x2 beam-splitter style block
    embedded at rows/columns ``m`` and ``n`` (requires ``m < n < dim``).

    The block is
        [[sqrt(Lp) e^{i phi} cos(theta),   sqrt(Lc) i sin(theta)          ],
         [sqrt(Lc) i e^{i phi} sin(theta), sqrt(Lp) cos(theta)            ]]

    With ``use_sym`` a SymPy matrix is returned, otherwise a complex128 array.
    """
    assert m < n < dim

    if not use_sym:
        amp_p = np.sqrt(Lp)
        amp_c = np.sqrt(Lc)
        phase = np.exp(1j * phi)
        cos_t = np.cos(theta)
        sin_t = np.sin(theta)

        mat = np.eye(dim, dtype=np.complex128)
        mat[m, m] = amp_p * phase * cos_t
        mat[m, n] = amp_c * 1j * sin_t
        mat[n, m] = amp_c * 1j * phase * sin_t
        mat[n, n] = amp_p * cos_t
        return mat

    amp_p = sp.sqrt(Lp)
    amp_c = sp.sqrt(Lc)
    phase = sp.exp(sp.I * phi)

    mat = sp.eye(dim)
    mat[m, m] = amp_p * phase * sp.cos(theta)
    mat[m, n] = amp_c * sp.I * sp.sin(theta)
    mat[n, m] = amp_c * sp.I * phase * sp.sin(theta)
    mat[n, n] = amp_p * sp.cos(theta)
    return mat


def U2MZI(dim, m, n, phi, theta, use_sym=False, Lp=1, Lc=1):
    """
    Build a ``dim x dim`` identity with a 2x2 MZI style block embedded at
    rows/columns ``m`` and ``n`` (requires ``m < n < dim``).

    The block is
        [[sqrt(Lp) i e^{i phi} sin(theta), sqrt(Lc) i cos(theta)   ],
         [sqrt(Lc) i e^{i phi} cos(theta), -sqrt(Lp) i sin(theta)  ]]

    With ``use_sym`` a SymPy matrix is returned, otherwise a complex128 array.
    """
    assert m < n < dim

    if not use_sym:
        amp_p = np.sqrt(Lp)
        amp_c = np.sqrt(Lc)
        phase = np.exp(1j * phi)
        cos_t = np.cos(theta)
        sin_t = np.sin(theta)

        mat = np.eye(dim, dtype=np.complex128)
        mat[m, m] = amp_p * 1j * phase * sin_t
        mat[m, n] = amp_c * 1j * cos_t
        mat[n, m] = amp_c * 1j * phase * cos_t
        mat[n, n] = -amp_p * 1j * sin_t
        return mat

    amp_p = sp.sqrt(Lp)
    amp_c = sp.sqrt(Lc)
    phase = sp.exp(sp.I * phi)

    mat = sp.eye(dim)
    mat[m, m] = amp_p * sp.I * phase * sp.sin(theta)
    mat[m, n] = amp_c * sp.I * sp.cos(theta)
    mat[n, m] = amp_c * sp.I * phase * sp.cos(theta)
    mat[n, n] = -amp_p * sp.I * sp.sin(theta)
    return mat

## Decomposition Methods

### Reck's Encoding

In [None]:
def decompose_reck(u, block='bs'):
    """
    Factor a square matrix into ``dim*(dim-1)/2`` two-mode blocks plus output phases.

    The working matrix is repeatedly right-multiplied by the conjugate
    transpose of a two-mode block (``U2BS`` or ``U2MZI``) whose angles are
    computed from the elements ``mat[x, y]`` and ``mat[x, x]``.
    NOTE(review): presumably each step nulls ``mat[x, y]`` so that a unitary
    input ends up diagonal - the code does not check that ``u`` is unitary.

    Args:
        u: 2-D square ``np.ndarray``.
        block: 'bs' or 'mzi' (case and surrounding whitespace are ignored).

    Returns:
        ``(phis, thetas, alphas)``: block angles in the order they were applied
        (each of length ``dim*(dim-1)/2``) and the phases of the final diagonal
        (length ``dim``).

    Raises:
        ValueError: if ``u`` is not 2-D or not square.
    """
    assert isinstance(u, np.ndarray)
    assert isinstance(block, str) and block.strip().lower() in ['bs', 'mzi']
    # The check above accepts e.g. 'BS' or ' mzi '; normalise once so the
    # comparisons below select a branch for every accepted spelling.
    block = block.strip().lower()
    if len(u.shape) != 2:
        raise ValueError("U(N) should be 2-dimension matrix.")
        
    if u.shape[0] != u.shape[1]:
        raise ValueError("U(N) should be a square matrix.")
        
    mat = u.copy().astype(np.complex128)
    dim = mat.shape[0]
    num = dim * (dim - 1) // 2
    phis = np.zeros(num)
    thetas = np.zeros(num)
    alphas = np.zeros(dim)
    U2block = U2BS if block == 'bs' else U2MZI
    index = 0
    for p in range(1, dim):
        x = dim - p
        for q in range(dim-p, 0, -1):
            y = q - 1
            if block == 'bs':
                thetas[index] = atan2f(np.abs(mat[x,y]), np.abs(mat[x,x]))
                phis[index] = angle_diff(mat[x,x], mat[x,y], offset=-np.pi/2)
            else:
                thetas[index] = np.pi/2 - atan2f(np.abs(mat[x,y]), np.abs(mat[x,x]))
                phis[index] = angle_diff(mat[x,x], mat[x,y], offset=np.pi)
            mat = mat @ U2block(dim, y, x, phis[index], thetas[index]).conj().T
            index += 1
    # Whatever phase is left on the diagonal is returned as the output phase screen.
    for i in range(dim):
        alphas[i] = np.angle(mat[i, i])
    return phis, thetas, alphas


def reconstruct_reck(phis, thetas, alphas, block='bs', Lp_dB=0, Lc_dB=0):
    """
    Rebuild a matrix from the angles produced by ``decompose_reck``.

    Starts from ``diag(exp(1j * alphas))`` and right-multiplies the two-mode
    blocks in the reverse of the order in which ``phis``/``thetas`` are stored.

    Args:
        phis, thetas: 1-D angle arrays of equal length ``dim*(dim-1)/2``
            (singleton dimensions are squeezed away).
        alphas: 1-D array of ``dim`` output phases.
        block: 'bs' or 'mzi' (case and surrounding whitespace are ignored).
        Lp_dB, Lc_dB: values in dB converted with ``10 ** (dB / 10)`` and
            passed to the block as ``Lp`` / ``Lc``.

    Returns:
        ``dim x dim`` complex array.
    """
    assert isinstance(block, str) and block.strip().lower() in ['bs', 'mzi']
    # Normalise so every spelling accepted by the check selects a block type.
    block = block.strip().lower()
    if block == 'bs':
        U2block = U2BS
    else:
        U2block = U2MZI

    # Squeeze once and use the squeezed arrays below; atleast_1d keeps a
    # single-element array (dim == 2) one-dimensional instead of 0-d.
    phis = np.atleast_1d(np.squeeze(phis))
    thetas = np.atleast_1d(np.squeeze(thetas))
    alphas = np.atleast_1d(np.squeeze(alphas))
    assert phis.ndim == 1
    assert thetas.ndim == 1
    assert alphas.ndim == 1
    assert phis.shape[0] == thetas.shape[0]
    
    num = thetas.shape[0]
    # Invert num = dim * (dim - 1) / 2.
    dim = int((1 + np.sqrt(1 + 8 * num))/ 2)
    assert alphas.shape[0] == dim
    
    Lp = 10 ** (Lp_dB / 10)
    Lc = 10 ** (Lc_dB / 10)
    
    mat = np.diag(np.exp(1j * alphas))
    index = num
    for p in range(1, dim):
        for q in range(p):
            index -= 1
            mat = mat @ U2block(dim, q, p, phis[index], thetas[index], Lp=Lp, Lc=Lc)
    return mat

### Clements' Encoding

In [None]:
def decompose_clements(u, block='bs'):
    """
    Factor a square matrix into a rectangular grid of two-mode blocks plus
    output phases.

    For even ``p`` the working matrix is right-multiplied by the conjugate
    transpose of a block on columns ``(y, y+1)``; for odd ``p`` it is
    left-multiplied by a block on rows ``(x-1, x)``. The angles are stored in
    ``(dim-1) x ceil(dim/2)`` arrays: right-applied blocks fill each row from
    the front (``cnt_fore``), left-applied blocks from the back (``cnt_back``).
    A final pass rewrites the ``phi`` of every left-applied block and adjusts
    the diagonal of the working matrix.
    NOTE(review): presumably that last pass commutes the residual diagonal
    through the left-applied blocks so a single phase screen remains - confirm
    against the reference derivation.

    Args:
        u: 2-D square ``np.ndarray``.
        block: 'bs' or 'mzi' (case and surrounding whitespace are ignored).

    Returns:
        ``(phis, thetas, alphas)`` with ``phis``/``thetas`` of shape
        ``(dim-1, ceil(dim/2))`` and ``alphas`` of length ``dim``.

    Raises:
        ValueError: if ``u`` is not 2-D or not square.
    """
    assert isinstance(u, np.ndarray)
    assert isinstance(block, str) and block.strip().lower() in ['bs', 'mzi']
    # The check above accepts e.g. 'BS' or ' mzi '; normalise once so the
    # comparisons below select a branch for every accepted spelling.
    block = block.strip().lower()
    if len(u.shape) != 2:
        raise ValueError("U(N) should be 2-dimension matrix.")
        
    if u.shape[0] != u.shape[1]:
        raise ValueError("U(N) should be a square matrix.")
        
    mat = u.copy().astype(np.complex128)
    dim = mat.shape[0]
    U2block = U2BS if block == 'bs' else U2MZI
    
    row = dim - 1
    col = int(np.ceil(dim / 2))
    
    # Next free slot per row, counted from the front and from the back.
    cnt_fore = np.zeros(row, dtype=int)
    cnt_back = np.ones(row, dtype=int) * (col - 1)
    if dim % 2 == 1:
        # For odd dim the odd-indexed rows get their back pointer one slot earlier.
        cnt_back[1::2] = col - 2
    
    phis = np.zeros((row, col))
    thetas = np.zeros((row, col))
    alphas = np.zeros(dim)
    
    for p in range(dim-1):
        for q in range(p+1):
            if p % 2 == 0:
                # Right-multiplication acting on columns (y, y+1).
                x = dim - 1 - q
                y = p - q
                if block == 'bs':
                    theta = atan2f(np.abs(mat[x,y]), np.abs(mat[x,y+1]))
                    phi = angle_diff(mat[x,y+1], mat[x,y], offset=-np.pi/2)
                else:
                    theta = np.pi/2 - atan2f(np.abs(mat[x,y]), np.abs(mat[x,y+1]))
                    phi = angle_diff(mat[x,y+1], mat[x,y], offset=np.pi)
                mat = mat @ U2block(dim, y, y+1, phi, theta).conj().T
                thetas[y, cnt_fore[y]] = theta
                phis[y, cnt_fore[y]] = phi
                cnt_fore[y] += 1
            else:
                # Left-multiplication acting on rows (x-1, x).
                x = dim - 1 - p + q
                y = q
                if block == 'bs':
                    theta = atan2f(np.abs(mat[x,y]), np.abs(mat[x-1,y]))
                    phi = angle_diff(mat[x-1,y], mat[x,y], offset=np.pi/2)
                else:
                    theta = np.pi/2 - atan2f(np.abs(mat[x,y]), np.abs(mat[x-1,y]))
                    phi = angle_diff(mat[x-1,y], mat[x,y], offset=0)
                mat = U2block(dim, x-1, x, phi, theta) @ mat
                thetas[x-1, cnt_back[x-1]] = theta
                phis[x-1, cnt_back[x-1]] = phi
                cnt_back[x-1] -= 1
    # Revisit the left-applied blocks in reverse order and replace their phi.
    for p in range(dim-2, -1, -1):
        for q in range(p, -1, -1):
            if p % 2 == 0:
                continue
            x = dim - 1 - p + q
            y = q
            cnt_back[x-1] += 1
            theta = thetas[x-1, cnt_back[x-1]]
            phi = phis[x-1, cnt_back[x-1]]
            eta1 = mat[x-1, x-1]
            eta2 = mat[x, x]
            if block == 'bs':
                phi_new = angle_diff(eta2, -eta1, offset=0)
                mat[x-1, x-1] = eta1 * np.exp(-1j * (phi+phi_new))
            else:
                phi_new = angle_diff(eta2, eta1, offset=0)
                mat[x-1, x-1] = -eta1 * np.exp(-1j * (phi+phi_new))
                mat[x, x] = -eta2
            phis[x-1, cnt_back[x-1]] = phi_new
    for i in range(dim):
        alphas[i] = np.angle(mat[i, i])
    return phis, thetas, alphas


def reconstruct_clements(phis, thetas, alphas, block='bs', Lp_dB=0, Lc_dB=0):
    """
    Rebuild a matrix from the angle grids produced by ``decompose_clements``.

    Column by column, the blocks on even rows and then on odd rows are
    left-multiplied onto an identity; for odd ``dim`` the odd rows of the last
    column are skipped. The result is finally left-multiplied by
    ``diag(exp(1j * alphas))``.

    Args:
        phis, thetas: 2-D angle arrays of identical shape ``(dim-1, col)``
            (singleton dimensions are squeezed away).
        alphas: 1-D array of ``dim`` output phases.
        block: 'bs' or 'mzi' (case and surrounding whitespace are ignored).
        Lp_dB, Lc_dB: values in dB converted with ``10 ** (dB / 10)`` and
            passed to the block as ``Lp`` / ``Lc``.

    Returns:
        ``dim x dim`` complex array.
    """
    assert isinstance(block, str) and block.strip().lower() in ['bs', 'mzi']
    # Normalise so every spelling accepted by the check selects a block type.
    block = block.strip().lower()
    if block == 'bs':
        U2block = U2BS
    else:
        U2block = U2MZI

    # Squeeze once and use the squeezed arrays below; atleast_2d keeps the
    # 1 x 1 grid of the dim == 2 case two-dimensional instead of 0-d.
    phis = np.atleast_2d(np.squeeze(phis))
    thetas = np.atleast_2d(np.squeeze(thetas))
    alphas = np.atleast_1d(np.squeeze(alphas))
    assert phis.ndim == 2
    assert thetas.ndim == 2
    assert alphas.ndim == 1
    assert phis.shape == thetas.shape
    
    row, col = thetas.shape
    dim = row + 1
    assert alphas.shape[0] == dim
    
    Lp = 10 ** (Lp_dB / 10)
    Lc = 10 ** (Lc_dB / 10)
    
    sft = np.diag(np.exp(1j * alphas))
    mat = np.eye(dim)
    for p in range(col):
        for q in range(0, row, 2):
            mat = U2block(dim, q, q+1, phis[q,p], thetas[q,p], Lp=Lp, Lc=Lc) @ mat
        if p >= col - 1 and dim % 2 == 1:
            continue
        for q in range(1, row, 2):
            mat = U2block(dim, q, q+1, phis[q,p], thetas[q,p], Lp=Lp, Lc=Lc) @ mat
    mat = sft @ mat
    return mat

### Unit Test

In [None]:
# Parameter To Test
BOUND = 100
MAT_ROW = 40
MAT_COL = 50
print('==== Unit Test ====')

# Seeded generator so that the test matrix (and therefore the printed results)
# is identical on every Restart & Run All.
unit_test_rng = np.random.default_rng(RANDOM_SEED)

# Singular Value Decomposition
# np.linalg.svd returns (U, S, V^H): `v` below already is the conjugate transpose.
mat = unit_test_rng.integers(-np.abs(BOUND), np.abs(BOUND), size=(MAT_ROW, MAT_COL))
[u, s, v] = np.linalg.svd(mat, full_matrices=True)

# Recovery from SVD
print(f'Recovery from SVD:\t {np.allclose(mat, u[:, :MAT_COL] @ np.diag(s) @ v[:MAT_ROW, :])}')

# Reck BS
[p, t, a] = decompose_reck(u, block='bs')
reck_test = reconstruct_reck(p, t, a, block='bs')
print(f'Reck[BS] Test:\t\t {np.allclose(reck_test, u)}')

# Reck MZI
[p, t, a] = decompose_reck(u, block='mzi')
reck_test = reconstruct_reck(p, t, a, block='mzi')
print(f'Reck[MZI] Test:\t\t {np.allclose(reck_test, u)}')

# Clements BS
[p, t, a] = decompose_clements(u, block='bs')
clements_test = reconstruct_clements(p, t, a, block='bs')
print(f'Clements[BS] Test:\t {np.allclose(clements_test, u)}')

# Clements MZI
[p, t, a] = decompose_clements(u, block='mzi')
clements_test = reconstruct_clements(p, t, a, block='mzi')
print(f'Clements[MZI] Test:\t {np.allclose(clements_test, u)}')

# Photonic Neural Networks

## Lossy and Crosstalk-aware Unitary Blocks

### Symbolic-enabled Unitary Blocks in consideration of Loss and Crosstalk

In [None]:
def UB2BS(dim, m, n, phi, theta, E_in, P_x, use_sym=False, Lp=1, Lc=1, K1=0, K2=0):
    """
    Apply one beam-splitter style block to ports ``m`` and ``n`` of a field
    vector and update the accumulated crosstalk power on those ports.

    Field: ports ``(m, n)`` of ``E_in`` are multiplied by the 2x2 matrix
        [[sqrt(Lp) e^{i phi} cos(theta),   sqrt(Lc) i sin(theta)],
         [sqrt(Lc) i e^{i phi} sin(theta), sqrt(Lp) cos(theta)  ]]
    Crosstalk: the new power on ports ``(m, n)`` is
        P_crosstalk @ (|E_in|^2 + P_x)  +  |E_signal|^2 @ P_x
    where ``P_crosstalk = [[K1 sin^2, K2 cos^2], [K2 cos^2, K1 sin^2]]``.
    All other ports are passed through unchanged.

    Args:
        dim: number of ports; requires ``m < n < dim``.
        phi, theta: block angles.
        E_in, P_x: SymPy column matrices when ``use_sym`` is set, otherwise
            NumPy arrays that squeeze to length ``dim``.
        Lp, Lc, K1, K2: coefficients; in the numeric branch each must lie in [0, 1].

    Returns:
        ``(E_out, P_out)``; in the numeric branch both have the shape of the
        corresponding input. The inputs are not modified.
    """
    assert m < n < dim
    
    if use_sym:
        E_signal = sp.eye(2)
        E_signal[0, 0] = sp.sqrt(Lp) * sp.exp(sp.I * phi) * sp.cos(theta)
        E_signal[0, 1] = sp.sqrt(Lc) * sp.I * sp.sin(theta)
        E_signal[1, 0] = sp.sqrt(Lc) * sp.I * sp.exp(sp.I * phi) * sp.sin(theta)
        E_signal[1, 1] = sp.sqrt(Lp) * sp.cos(theta)
        E_port = E_signal @ E_in.extract([m, n], [0])
        E_out = E_in.copy()
        E_out[m] = E_port[0]
        E_out[n] = E_port[1]
        
        P_crosstalk = sp.eye(2)
        P_crosstalk[0, 0] = K1 * (sp.sin(theta) ** 2)
        P_crosstalk[0, 1] = K2 * (sp.cos(theta) ** 2)
        P_crosstalk[1, 0] = K2 * (sp.cos(theta) ** 2)
        P_crosstalk[1, 1] = K1 * (sp.sin(theta) ** 2)
        P_port = P_crosstalk @ (sp.Abs(E_in.extract([m, n], [0])).applyfunc(lambda x: x ** 2) + P_x.extract([m, n], [0]))
        P_port += (sp.Abs(E_signal).applyfunc(lambda x: x ** 2)) @ P_x.extract([m, n], [0])
        P_out = P_x.copy()
        P_out[m] = P_port[0]
        P_out[n] = P_port[1]
        
    else:
        assert 0 <= Lp <= 1 and 0 <= Lc <= 1
        assert 0 <= K1 <= 1 and 0 <= K2 <= 1
        assert np.ndim(E_in.squeeze()) == 1 and np.ndim(P_x.squeeze()) == 1
        assert E_in.squeeze().shape[0] == dim and P_x.squeeze().shape[0] == dim
        
        E_signal = np.eye(2, dtype=np.complex128)
        E_signal[0, 0] = np.sqrt(Lp) * np.exp(1j * phi) * np.cos(theta)
        E_signal[0, 1] = np.sqrt(Lc) * 1j * np.sin(theta)
        E_signal[1, 0] = np.sqrt(Lc) * 1j * np.exp(1j * phi) * np.sin(theta)
        E_signal[1, 1] = np.sqrt(Lp) * np.cos(theta)
        E_port = E_signal @ E_in.squeeze()[np.ix_([m, n])]
        E_out = E_in.squeeze().copy()
        E_out[m] = E_port[0]
        E_out[n] = E_port[1]
        E_out = E_out.reshape(E_in.shape)
        
        P_crosstalk = np.eye(2)
        P_crosstalk[0, 0] = K1 * (np.sin(theta) ** 2)
        P_crosstalk[0, 1] = K2 * (np.cos(theta) ** 2)
        P_crosstalk[1, 0] = K2 * (np.cos(theta) ** 2)
        P_crosstalk[1, 1] = K1 * (np.sin(theta) ** 2)
        P_port = P_crosstalk @ (np.square(np.abs(E_in.squeeze()[np.ix_([m,n])])) + P_x.squeeze()[np.ix_([m,n])])
        # Carry the crosstalk already present on the inputs through the block.
        # The symbolic branch above and UB2MZI both include this term; without
        # it, incoming crosstalk was dropped at every beam splitter.
        P_port += np.square(np.abs(E_signal)) @ P_x.squeeze()[np.ix_([m,n])]
        P_out = P_x.squeeze().copy()
        P_out[m] = P_port[0]
        P_out[n] = P_port[1]
        P_out = P_out.reshape(P_x.shape)
        
    return E_out, P_out


def UB2MZI(dim, m, n, phi, theta, E_in, P_x, use_sym=False, Lp=1, Lc=1, K1=0, K2=0):
    """
    Apply one MZI style block to ports ``m`` and ``n`` of a field vector and
    update the accumulated crosstalk power on those ports.

    Field: ports ``(m, n)`` of ``E_in`` are multiplied by the 2x2 matrix
        [[sqrt(Lp) i e^{i phi} sin(theta), sqrt(Lc) i cos(theta) ],
         [sqrt(Lc) i e^{i phi} cos(theta), -sqrt(Lp) i sin(theta)]]
    Crosstalk: the new power on ports ``(m, n)`` is
        P_crosstalk @ (|E_in|^2 + P_x)  +  |E_signal|^2 @ P_x
    where ``P_crosstalk = [[K1 cos^2, K2 sin^2], [K2 sin^2, K1 cos^2]]``.
    All other ports are passed through unchanged.

    Returns ``(E_out, P_out)``; in the numeric branch both have the shape of
    the corresponding input, and the inputs are not modified.
    """
    assert m < n < dim
    ports = [m, n]

    if use_sym:
        transfer = sp.eye(2)
        transfer[0, 0] = sp.sqrt(Lp) * sp.I * sp.exp(sp.I * phi) * sp.sin(theta)
        transfer[0, 1] = sp.sqrt(Lc) * sp.I * sp.cos(theta)
        transfer[1, 0] = sp.sqrt(Lc) * sp.I * sp.exp(sp.I * phi) * sp.cos(theta)
        transfer[1, 1] = -sp.sqrt(Lp) * sp.I * sp.sin(theta)
        field_ports = transfer @ E_in.extract(ports, [0])
        E_out = E_in.copy()
        E_out[m] = field_ports[0]
        E_out[n] = field_ports[1]

        leak = sp.eye(2)
        leak[0, 0] = K1 * (sp.cos(theta) ** 2)
        leak[0, 1] = K2 * (sp.sin(theta) ** 2)
        leak[1, 0] = K2 * (sp.sin(theta) ** 2)
        leak[1, 1] = K1 * (sp.cos(theta) ** 2)
        power_ports = leak @ (sp.Abs(E_in.extract(ports, [0])).applyfunc(lambda x: x ** 2) + P_x.extract(ports, [0]))
        power_ports += (sp.Abs(transfer).applyfunc(lambda x: x ** 2)) @ P_x.extract(ports, [0])
        P_out = P_x.copy()
        P_out[m] = power_ports[0]
        P_out[n] = power_ports[1]
        return E_out, P_out

    assert 0 <= Lp <= 1 and 0 <= Lc <= 1
    assert 0 <= K1 <= 1 and 0 <= K2 <= 1
    field_flat = E_in.squeeze()
    power_flat = P_x.squeeze()
    assert np.ndim(field_flat) == 1 and np.ndim(power_flat) == 1
    assert field_flat.shape[0] == dim and power_flat.shape[0] == dim

    amp_p = np.sqrt(Lp)
    amp_c = np.sqrt(Lc)
    phase = np.exp(1j * phi)
    sin_t = np.sin(theta)
    cos_t = np.cos(theta)

    # 2x2 field transfer matrix of the block
    transfer = np.eye(2, dtype=np.complex128)
    transfer[0, 0] = amp_p * 1j * phase * sin_t
    transfer[0, 1] = amp_c * 1j * cos_t
    transfer[1, 0] = amp_c * 1j * phase * cos_t
    transfer[1, 1] = -amp_p * 1j * sin_t

    field_in = field_flat[ports]
    field_ports = transfer @ field_in
    E_out = field_flat.copy()
    E_out[m] = field_ports[0]
    E_out[n] = field_ports[1]
    E_out = E_out.reshape(E_in.shape)

    # 2x2 power leakage matrix of the block
    leak = np.eye(2)
    leak[0, 0] = K1 * (cos_t ** 2)
    leak[0, 1] = K2 * (sin_t ** 2)
    leak[1, 0] = K2 * (sin_t ** 2)
    leak[1, 1] = K1 * (cos_t ** 2)

    power_in = power_flat[ports]
    # Newly generated crosstalk plus the incoming crosstalk carried through the block
    power_ports = leak @ (np.square(np.abs(field_in)) + power_in)
    power_ports = power_ports + np.square(np.abs(transfer)) @ power_in
    P_out = power_flat.copy()
    P_out[m] = power_ports[0]
    P_out[n] = power_ports[1]
    P_out = P_out.reshape(P_x.shape)

    return E_out, P_out

### Override Reconstruction Functions

In [None]:
def reconstruct_reck_pnn(phis, thetas, alphas, E_in, P_x, block='bs', Lp_dB=0, Lc_dB=0, K1_dB=-10000, K2_dB=-10000):
    """
    Propagate a field vector and its crosstalk power through the blocks of a
    Reck-style decomposition, one two-mode block at a time.

    The blocks are applied in the order in which ``phis``/``thetas`` are
    stored, and the field is finally multiplied by ``diag(exp(1j * alphas))``.
    The crosstalk power is not modified by that last phase step.

    Args:
        phis, thetas: 1-D angle arrays of equal length ``dim*(dim-1)/2``
            (singleton dimensions are squeezed away).
        alphas: 1-D array of ``dim`` output phases.
        E_in, P_x: input field and input crosstalk power; copied, not modified.
        block: 'bs' or 'mzi' (case and surrounding whitespace are ignored).
        Lp_dB, Lc_dB, K1_dB, K2_dB: values in dB, each converted with
            ``10 ** (dB / 10)`` before being handed to the block.
            NOTE(review): the -10000 dB defaults presumably stand for
            "no crosstalk" (the conversion underflows to 0.0) - confirm.

    Returns:
        ``(E_signal, P_crosstalk)`` after the last block.
    """
    assert isinstance(block, str) and block.strip().lower() in ['bs', 'mzi']
    # Normalise so every spelling accepted by the check selects a block type.
    block = block.strip().lower()
    if block == 'bs':
        U2block = UB2BS
    else:
        U2block = UB2MZI

    # Squeeze once and use the squeezed arrays below; atleast_1d keeps a
    # single-element array (dim == 2) one-dimensional instead of 0-d.
    phis = np.atleast_1d(np.squeeze(phis))
    thetas = np.atleast_1d(np.squeeze(thetas))
    alphas = np.atleast_1d(np.squeeze(alphas))
    assert phis.ndim == 1
    assert thetas.ndim == 1
    assert alphas.ndim == 1
    assert phis.shape[0] == thetas.shape[0]
    
    num = thetas.shape[0]
    # Invert num = dim * (dim - 1) / 2.
    dim = int((1 + np.sqrt(1 + 8 * num))/ 2)
    assert alphas.shape[0] == dim
    
    Lp = 10 ** (Lp_dB / 10)
    Lc = 10 ** (Lc_dB / 10)
    K1 = 10 ** (K1_dB / 10)
    K2 = 10 ** (K2_dB / 10)
    
    E_signal = E_in.copy()
    P_crosstalk = P_x.copy()
    
    sft = np.diag(np.exp(1j * alphas))
    index = 0
    for p in range(dim-1, 0, -1):
        for q in range(p-1, -1, -1):
            E_signal, P_crosstalk = U2block(dim, q, p, phis[index], thetas[index], E_signal, P_crosstalk, Lp=Lp, Lc=Lc, K1=K1, K2=K2)
            index += 1
            
    E_signal = sft @ E_signal
    return E_signal, P_crosstalk


def reconstruct_clements_pnn(phis, thetas, alphas, E_in, P_x, block='bs', Lp_dB=0, Lc_dB=0, K1_dB=-10000, K2_dB=-10000):
    """
    Propagate a field vector and its crosstalk power through the blocks of a
    Clements-style decomposition, one two-mode block at a time.

    Column by column, the blocks on even rows and then on odd rows are applied;
    for odd ``dim`` the odd rows of the last column are skipped. The field is
    finally multiplied by ``diag(exp(1j * alphas))``. The crosstalk power is
    not modified by that last phase step.

    Args:
        phis, thetas: 2-D angle arrays of identical shape ``(dim-1, col)``
            (singleton dimensions are squeezed away).
        alphas: 1-D array of ``dim`` output phases.
        E_in, P_x: input field and input crosstalk power; copied, not modified.
        block: 'bs' or 'mzi' (case and surrounding whitespace are ignored).
        Lp_dB, Lc_dB, K1_dB, K2_dB: values in dB, each converted with
            ``10 ** (dB / 10)`` before being handed to the block.
            NOTE(review): the -10000 dB defaults presumably stand for
            "no crosstalk" (the conversion underflows to 0.0) - confirm.

    Returns:
        ``(E_signal, P_crosstalk)`` after the last block.
    """
    assert isinstance(block, str) and block.strip().lower() in ['bs', 'mzi']
    # Normalise so every spelling accepted by the check selects a block type.
    block = block.strip().lower()
    if block == 'bs':
        U2block = UB2BS
    else:
        U2block = UB2MZI

    # Squeeze once and use the squeezed arrays below; atleast_2d keeps the
    # 1 x 1 grid of the dim == 2 case two-dimensional instead of 0-d.
    phis = np.atleast_2d(np.squeeze(phis))
    thetas = np.atleast_2d(np.squeeze(thetas))
    alphas = np.atleast_1d(np.squeeze(alphas))
    assert phis.ndim == 2
    assert thetas.ndim == 2
    assert alphas.ndim == 1
    assert phis.shape == thetas.shape
    
    row, col = thetas.shape
    dim = row + 1
    assert alphas.shape[0] == dim
    
    Lp = 10 ** (Lp_dB / 10)
    Lc = 10 ** (Lc_dB / 10)
    K1 = 10 ** (K1_dB / 10)
    K2 = 10 ** (K2_dB / 10)
    
    E_signal = E_in.copy()
    P_crosstalk = P_x.copy()
    
    sft = np.diag(np.exp(1j * alphas))
    for p in range(col):
        for q in range(0, row, 2):
            E_signal, P_crosstalk = U2block(dim, q, q+1, phis[q,p], thetas[q,p], E_signal, P_crosstalk, Lp=Lp, Lc=Lc, K1=K1, K2=K2)
        if p >= col - 1 and dim % 2 == 1:
            continue
        for q in range(1, row, 2):
            E_signal, P_crosstalk = U2block(dim, q, q+1, phis[q,p], thetas[q,p], E_signal, P_crosstalk, Lp=Lp, Lc=Lc, K1=K1, K2=K2)
    E_signal = sft @ E_signal
    return E_signal, P_crosstalk

### Unit Test for Lossless Unitary Blocks

In [None]:
# Obtain Weight Parameters
# .cpu() is required before .numpy() when the model was trained on a GPU.
pnn_mat = model.classifier[0].weight.detach().cpu().numpy().T
[u, s, v] = np.linalg.svd(pnn_mat, full_matrices=True)
print('==== Unit Test ====')

# Test vectors sized from the matrix under test instead of a hard-coded 120:
# an all-ones input field and zero input crosstalk. The reconstruct_*_pnn
# functions copy their inputs, so the same arrays are reused for every case.
n_modes = u.shape[0]
E_ones = np.ones((n_modes, 1), dtype=np.complex128)
P_zeros = np.zeros((n_modes, 1), dtype=np.complex128)

# Ideal Lossless Unitary Blocks (Reck-BS)
[p, t, a] = decompose_reck(u, block='bs')
[E_signal, P_crosstalk] = reconstruct_reck_pnn(p, t, a, 
                     E_ones, 
                     P_zeros,
                     block='bs')
print(f'Reck[BS] Test:\t\t{np.allclose(u @ E_ones, E_signal)}')

# Ideal Lossless Unitary Blocks (Reck-MZI)
[p, t, a] = decompose_reck(u, block='mzi')
[E_signal, P_crosstalk] = reconstruct_reck_pnn(p, t, a, 
                     E_ones, 
                     P_zeros,
                     block='mzi')
print(f'Reck[MZI] Test:\t\t{np.allclose(u @ E_ones, E_signal)}')

# Ideal Lossless Unitary Blocks (Clements-BS)
[p, t, a] = decompose_clements(u, block='bs')
[E_signal, P_crosstalk] = reconstruct_clements_pnn(p, t, a, 
                     E_ones, 
                     P_zeros,
                     block='bs')
print(f'Clements[BS] Test:\t{np.allclose(u @ E_ones, E_signal)}')

# Ideal Lossless Unitary Blocks (Clements-MZI)
[p, t, a] = decompose_clements(u, block='mzi')
[E_signal, P_crosstalk] = reconstruct_clements_pnn(p, t, a, 
                     E_ones, 
                     P_zeros,
                     block='mzi')
print(f'Clements[MZI] Test:\t{np.allclose(u @ E_ones, E_signal)}')

### Unit Test for Symbolic Derivation

In [None]:
from IPython.display import display

# Block angles
t, p = sp.symbols('theta phi')

# Crosstalk coefficient (used for both K1 and K2) and the two loss coefficients
k = sp.Symbol('K')
Lc, Lp = sp.symbols('L_c L_p')

# Three-port input field and input crosstalk power
E1, E2, E3 = sp.symbols('E_1 E_2 E_3')
Px1, Px2, Px3 = sp.symbols('P_x1 P_x2 P_x3')

Ei = sp.Matrix([E1, E2, E3])
Pi = sp.Matrix([Px1, Px2, Px3])

# One symbolic MZI block acting on ports 0 and 1; port 2 passes through.
sym_signal, sym_crosstalk = UB2MZI(3, 0, 1, p, t, Ei, Pi, use_sym=True, Lp=Lp, Lc=Lc, K1=k, K2=k)

print('\nLoss = ')
display(sym_signal)

print('\nCrosstalk = ')
display(sym_crosstalk)

## Theoretical Analysis

### Worst-case Mode-wise SNR

In [None]:
# System Configurations and Parameters
# NUM_MODE is the exclusive upper bound of the sweep: mode counts 3 .. NUM_MODE-1.
NUM_MODE = 101
# Loss figures in dB; reconstruct_clements_pnn converts them with 10 ** (dB / 10).
Lp_dB = -0.01
Lc_dB = -0.1
# Swept values in dB; each one is negated below and used for both K1_dB and K2_dB.
ER_dB = [20, 25, 30, 35, 40]

# Buffers to store the data for drawings
# Indexed by mode count; entries 0..2 stay zero and are sliced away before plotting.
# The same buffers are overwritten for every entry of ER_dB.
signals = np.zeros((NUM_MODE, 1))
noises = np.zeros((NUM_MODE, 1))

# Setup for Matplotlib
# NOTE(review): these rcParams changes are global and persist for later cells.
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams["font.weight"] = "normal"
plt.rcParams["font.size"] = 12
plt.figure(dpi=600)
x = range(3, NUM_MODE)

# Sweep Parameters
for er in ER_dB:
    K_dB = -er
    for i in range(3, NUM_MODE):
        # Anti-diagonal identity of size i; the left singular-vector matrix of
        # its SVD is what gets decomposed into MZI blocks.
        mat = np.fliplr(np.eye(i))
        [u, s, v] = np.linalg.svd(mat, full_matrices=True)
        [p, t, a] = decompose_clements(u, block='mzi')
        # All-ones input field, zero input crosstalk.
        [E_signal, P_crosstalk] = reconstruct_clements_pnn(p, t, a, 
                         np.ones((i, 1), dtype=np.complex128), 
                         np.zeros((i, 1), dtype=np.complex128),
                         block='mzi', Lp_dB=Lp_dB, Lc_dB=Lc_dB, K1_dB=K_dB, K2_dB=K_dB)
        # NOTE(review): np.square squares the complex amplitudes (phase included)
        # before summing; if total optical power is intended, |E|^2 per port
        # (np.square(np.abs(...))) may have been meant - confirm.
        signals[i] = np.abs(np.sum(np.square(E_signal).squeeze()))
        noises[i] = np.abs(np.sum(P_crosstalk.squeeze()))
        
    y1 = signals[3:NUM_MODE]
    y2 = noises[3:NUM_MODE]
    snr = y1 / y2
    snr_dB = 10 * np.log10(snr)
    # One marker series per entry of ER_dB, drawn on the current figure.
    plt.plot(x, snr_dB, '^', markersize=2)
    
# Annotate Setup
# Horizontal reference line at 10 dB, labelled 'Min-Required' in the legend.
baseline = np.ones_like(x) * 10
plt.plot(x, baseline, '--')
    
plt.xlabel('Number of modes', fontsize=14)
plt.ylabel('Worst-case Mode-wise SNR (dB)', fontsize=14)

# Legend entries follow plot order: one per ER_dB value, then the baseline.
notes = [f"K = -{e}dB" for e in ER_dB]
notes.append('Min-Required')
plt.legend(notes, ncol=2, markerscale=2.0)
plt.show()

### Comparison with Fidelity

In [None]:
# System Configurations and Parameters
# NUM_MODE is the exclusive upper bound of the sweep: mode counts 3 .. NUM_MODE-1.
NUM_MODE = 101
# Loss figures in dB; the reconstruct_* functions convert them with 10 ** (dB / 10).
Lp_dB = -0.01
Lc_dB = -0.1
# Number of random matrices averaged per mode count.
COUNT = 10 # Author's note: setting this to 100 needs about 1.89 hours to run

# Buffers to store the data for drawings
# Indexed by mode count; reused (overwritten) for the Reck and the Clements sweep.
fidelity = np.zeros((NUM_MODE, 1))

# Setup for Matplotlib
# NOTE(review): these rcParams changes are global and persist for later cells.
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams["font.weight"] = "normal"
plt.rcParams["font.size"] = 12
plt.figure()
x = range(3, NUM_MODE)

# Sweep Parameters [Reck]
# NOTE(review): np.random is not seeded in this cell, so the curves differ between runs.
for i in range(3, NUM_MODE):
    statistic = 0.0
    for _ in range(COUNT):
        # Left singular-vector matrix of a random Gaussian matrix is the target.
        mat = np.random.randn(i, i)
        [u, s, v] = np.linalg.svd(mat, full_matrices=True)
        [p, t, a] = decompose_reck(u, block='mzi')
        # Ue: the target rebuilt from its angles with the loss figures applied.
        Ue = reconstruct_reck(p, t, a, block='mzi', Lp_dB=Lp_dB, Lc_dB=Lc_dB)
        # |tr(U^H Ue) / sqrt(N * tr(Ue^H Ue))|^2
        statistic += np.abs(np.trace(u.conj().T @ Ue) / np.sqrt(i * np.trace(Ue.conj().T @ Ue))) ** 2
    fidelity[i] = statistic / COUNT

# Annotate Setup
y = fidelity[3:NUM_MODE]
plt.plot(x, y, 's', markersize=2)

# Sweep Parameters [Clements]
# Same procedure with the Clements decomposition; fresh random matrices are drawn.
for i in range(3, NUM_MODE):
    statistic = 0.0
    for _ in range(COUNT):
        mat = np.random.randn(i, i)
        [u, s, v] = np.linalg.svd(mat, full_matrices=True)
        [p, t, a] = decompose_clements(u, block='mzi')
        Ue = reconstruct_clements(p, t, a, block='mzi', Lp_dB=Lp_dB, Lc_dB=Lc_dB)
        statistic += np.abs(np.trace(u.conj().T @ Ue) / np.sqrt(i * np.trace(Ue.conj().T @ Ue))) ** 2
    fidelity[i] = statistic / COUNT

# Annotate Setup
y = fidelity[3:NUM_MODE]
plt.plot(x, y, '^', markersize=2)

plt.xlabel('Number of modes', fontsize=14)
plt.ylabel('Fidelity', fontsize=14)
# Legend entries follow plot order: Reck first, Clements second.
plt.legend(['Reck', 'Clements'])

plt.show()

## Experimental Verification

### Photo-Detectors & Attenuators & Nonlinear-Activations

In [None]:
def photo_detector(E_signal, P_crosstalk, 
                   dynamic_ranges=[-16, 16], precision_bitdepth=16):
    """
    Add the crosstalk as a field with one random phase and quantise the real part.

    The crosstalk amplitude ``sqrt(P_crosstalk)`` is multiplied by
    ``exp(1j * 2*pi*r)`` with a single uniform draw ``r`` shared by all
    entries, added to ``E_signal``, and the real part of the sum is mapped
    onto ``2**precision_bitdepth`` evenly spaced levels spanning
    ``dynamic_ranges``. Each value is replaced by the level that
    ``np.digitize`` selects, with indices past the last level clamped to it.

    Returns an array of levels with the shape of the inputs.
    """
    level_count = 2**precision_bitdepth
    levels = np.linspace(*dynamic_ranges, num=level_count)

    random_phase = np.exp(1j * np.random.random() * 2 * np.pi)
    field = E_signal + np.sqrt(P_crosstalk) * random_phase

    level_idx = np.digitize(np.real(field), levels)
    # np.digitize returns level_count for values at or above the top level.
    level_idx[level_idx>=level_count] = level_count - 1
    return levels[level_idx]
    

def diagonal_attenuator(E_signal, diagonal_factor, num_output=None):
    """
    Scale a field vector element-wise by ``diagonal_factor`` and return it as a
    complex128 column.

    The longer of the two operands is truncated to the length of the shorter
    one before multiplying. With ``num_output`` the result is then truncated
    or zero-padded to exactly that many rows.
    """
    flat = E_signal.squeeze()
    if E_signal.size >= diagonal_factor.size:
        scaled = flat[:diagonal_factor.size] * diagonal_factor
    else:
        scaled = flat * diagonal_factor[:E_signal.size]
    scaled = scaled.astype(np.complex128)

    if num_output is None:
        return scaled[:, np.newaxis]

    if scaled.size >= num_output:
        return scaled[:num_output][:, np.newaxis]

    # Fewer values than requested outputs: fill the remainder with zeros.
    padded = np.zeros(num_output, dtype=np.complex128)
    padded[:scaled.size] = scaled
    return padded[:, np.newaxis]


def nonlinear_tanh(x):
    """Element-wise hyperbolic tangent, returned with dtype complex128."""
    activated = np.tanh(x)
    return activated.astype(np.complex128)


def softmax(x):
    """Return the softmax of ``x`` normalised over all of its elements.

    The largest real part of ``x`` is subtracted before exponentiating. This
    leaves the result mathematically unchanged but prevents ``exp`` from
    overflowing (which would give ``inf / inf = nan``) for large inputs.

    Parameters
    ----------
    x : array_like
        Scores; the normalisation runs over every element, not per axis.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Values with the same shape as ``x`` that sum to one. An empty input
        yields an empty result.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; keep the empty result of the unshifted formula.
        return np.exp(x)

    exps = np.exp(x - np.max(np.real(x)))
    return exps / np.sum(exps)

### Weight Extraction from LeNet-5 

In [None]:
# Extract the weight matrices of the two Linear layers of the classifier and
# decompose each one for the photonic implementation.
#
# * `.cpu()` is applied before `.numpy()` so the extraction also works when
#   the model lives on a CUDA device (it is a no-op on CPU).
# * Each weight is transposed, so `pnn_mat*` has shape (in_features, out_features).
# * np.linalg.svd returns (u, s, vh); the third factor is stored in `v1` / `v2`.
# * Only `.weight` is read here; the bias terms of the Linear layers are not
#   extracted.

# First Linear layer (classifier[0])
pnn_mat1 = model.classifier[0].weight.detach().cpu().numpy().T
[u1, s1, v1] = np.linalg.svd(pnn_mat1, full_matrices=True)
[p1u, t1u, a1u] = decompose_clements(u1.conj().T, block='mzi')
[p1v, t1v, a1v] = decompose_clements(v1.conj().T, block='mzi')

# Second Linear layer (classifier[2])
pnn_mat2 = model.classifier[2].weight.detach().cpu().numpy().T
[u2, s2, v2] = np.linalg.svd(pnn_mat2, full_matrices=True)
[p2u, t2u, a2u] = decompose_clements(u2.conj().T, block='mzi')
[p2v, t2v, a2v] = decompose_clements(v2.conj().T, block='mzi')

### One Iteration of Inference

In [None]:
# Parameters: loss settings (Lp_dB, Lc_dB) and crosstalk coefficients (K1_dB, K2_dB)
Lp_dB = -0.01
Lc_dB = -0.1
K1_dB = -25
K2_dB = -25

# Shared keyword arguments for every mesh stage below.
mesh_kwargs = dict(block='mzi', Lp_dB=Lp_dB, Lc_dB=Lc_dB, K1_dB=K1_dB, K2_dB=K2_dB)

# All-ones dummy image. It is created on the same device as the model and the
# features are moved back to the CPU before the NumPy conversion, so this cell
# also runs when the model lives on a CUDA device.
model_device = next(model.parameters()).device
X_samples = torch.ones(1, 1, 32, 32, device=model_device)
Z_features = model.feature_extractor(X_samples).detach().cpu().numpy()
pnn_inputs = Z_features.squeeze()[:,np.newaxis].astype(np.complex128)

# NOTE(review): in this walkthrough every mesh stage starts from a zero
# crosstalk vector, whereas PNN.__call__ forwards the attenuated P_xtalk of the
# first mesh into the second one -- confirm which variant is intended.

# Mat 1 - U
[E_sig, P_xtalk] = reconstruct_clements_pnn(p1u, t1u, a1u,
                         pnn_inputs,
                         np.zeros_like(pnn_inputs, dtype=np.complex128),
                         **mesh_kwargs)
E_sig = photo_detector(E_sig, P_xtalk)

# Mat 1 - S (scale by the singular values, resize to the input size of the V mesh)
E_sig = diagonal_attenuator(E_sig, s1, num_output=v1.shape[0])

# Mat 1 - V
[E_sig, P_xtalk] = reconstruct_clements_pnn(p1v, t1v, a1v,
                         E_sig,
                         np.zeros_like(E_sig, dtype=np.complex128),
                         **mesh_kwargs)
E_sig = photo_detector(E_sig, P_xtalk)

# Activation - Tanh
E_sig = nonlinear_tanh(E_sig)

# Mat 2 - U
[E_sig, P_xtalk] = reconstruct_clements_pnn(p2u, t2u, a2u,
                         E_sig,
                         np.zeros_like(E_sig, dtype=np.complex128),
                         **mesh_kwargs)
E_sig = photo_detector(E_sig, P_xtalk)

# Mat 2 - S
E_sig = diagonal_attenuator(E_sig, s2, num_output=v2.shape[0])

# Mat 2 - V
[E_sig, P_xtalk] = reconstruct_clements_pnn(p2v, t2v, a2v,
                         E_sig,
                         np.zeros_like(E_sig, dtype=np.complex128),
                         **mesh_kwargs)
E_sig = photo_detector(E_sig, P_xtalk)

# Softmax (prob: probability, pred: prediction)
prob = softmax(E_sig.astype(np.double))
pred = prob.argmax()

### Realization for PNN Class

In [None]:
class PNN:
    """Photonic implementation of one matrix multiplication.

    The matrix is factorised with ``np.linalg.svd``; both unitary factors are
    decomposed with ``decompose_clements`` and the singular values are applied
    by ``diagonal_attenuator``. The notebook's unit test compares the output
    of a call against ``matrix.T @ inputs``.

    Parameters
    ----------
    matrix : numpy.ndarray
        Matrix to implement.
    Lp_dB, Lc_dB : float, optional
        Loss parameters forwarded to ``reconstruct_clements_pnn``.
    K1_dB, K2_dB : float, optional
        Crosstalk coefficients forwarded to ``reconstruct_clements_pnn``.
    """

    def __init__(self, matrix, Lp_dB=0, Lc_dB=0, K1_dB=-10000, K2_dB=-10000):
        # np.linalg.svd returns (u, s, vh); `self.v` therefore holds the vh factor.
        self.u, self.s, self.v = np.linalg.svd(matrix, full_matrices=True)
        self.p_u, self.t_u, self.a_u = decompose_clements(self.u.conj().T, block='mzi')
        self.p_v, self.t_v, self.a_v = decompose_clements(self.v.conj().T, block='mzi')
        self.Lp_dB = Lp_dB
        self.Lc_dB = Lc_dB
        self.K1_dB = K1_dB
        self.K2_dB = K2_dB

    def _run_mesh(self, p, t, a, field, crosstalk):
        """Run one mesh with this instance's loss and crosstalk settings."""
        return reconstruct_clements_pnn(p, t, a,
                                        field,
                                        crosstalk,
                                        block='mzi',
                                        Lp_dB=self.Lp_dB, Lc_dB=self.Lc_dB,
                                        K1_dB=self.K1_dB, K2_dB=self.K2_dB)

    def __call__(self, inputs):
        """Propagate ``inputs`` through both meshes and return the detected output.

        NOTE(review): the crosstalk of the first mesh is folded into the signal
        by ``photo_detector`` and is also forwarded (attenuated) into the
        second mesh -- confirm this double use is intended.
        """
        n_mid = self.v.shape[0]
        field = inputs.squeeze()[:, np.newaxis].astype(np.complex128)

        # First mesh, starting with no crosstalk, followed by detection.
        field, xtalk = self._run_mesh(self.p_u, self.t_u, self.a_u,
                                      field,
                                      np.zeros_like(field, dtype=np.complex128))
        field = photo_detector(field, xtalk)

        # Singular-value stage, applied to both the signal and the crosstalk.
        field = diagonal_attenuator(field, self.s, num_output=n_mid)
        xtalk = diagonal_attenuator(xtalk, self.s, num_output=n_mid)

        # Second mesh followed by the final detection.
        field, xtalk = self._run_mesh(self.p_v, self.t_v, self.a_v, field, xtalk)
        return photo_detector(field, xtalk)

### Unit Test for PNN class

In [None]:
MAT_ROW = 20
MAT_COL = 20
REL_TOL = 1e-4
ABS_TOL = 1e-3

mat_test = np.random.rand(MAT_ROW, MAT_COL)
inputs = np.random.rand(MAT_ROW, 1)
pnn_test = PNN(mat_test)

# Evaluate the PNN once and reuse the result: photo_detector draws a random
# phase on every call, so the printed output must be the very same one that
# was compared against the reference.
pnn_output = pnn_test(inputs)
ref_output = mat_test.T @ inputs
is_close = np.allclose(pnn_output, ref_output, rtol=REL_TOL, atol=ABS_TOL)

print(f'PNN Test:\t{is_close}\n')
print('[INFO] if OUTPUT is False, it might be caused by the dynamic_ranges of photo_detector.')
print(f'\n* PNN Output = \n{pnn_output}\n\n* Reference Output = \n{ref_output}\n')

### Batch Evaluation

In [None]:
def pnn_evaluate(data_loader, model, device, pnn1, pnn2):
    """Count correct predictions when the classifier is replaced by two PNNs.

    For every sample the features produced by ``model.feature_extractor`` are
    passed through ``pnn1``, ``nonlinear_tanh``, ``pnn2`` and ``softmax``; the
    arg-max of the result is compared with the true label.

    Parameters
    ----------
    data_loader : iterable
        Yields ``(X, Y_true)`` batches of tensors and supports ``len()``.
    model : torch.nn.Module
        Network providing ``feature_extractor``; put into eval mode here.
    device : str or torch.device
        Device the input batches are moved to before feature extraction.
    pnn1, pnn2 : callable
        The two photonic layers, each mapping a column vector to a column vector.

    Returns
    -------
    tuple of int
        ``(num_right, num_total)``.
    """
    model.eval()
    num_total = 0
    num_right = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for cnt, (X, Y_true) in enumerate(data_loader, start=1):
            X = X.to(device)

            # Move results to the CPU before converting to NumPy so that the
            # function also works when `device` is a CUDA device. Only the
            # feature extractor is run; the full model output is not needed.
            Z_features = model.feature_extractor(X).detach().cpu().numpy()
            labels = Y_true.detach().cpu().numpy()

            for i in range(len(Z_features)):
                pnn_inputs = Z_features[i].squeeze()[:, np.newaxis]
                Z = pnn1(pnn_inputs)
                Z = nonlinear_tanh(Z)
                Z = pnn2(Z)
                pred = softmax(Z).argmax()

                if labels[i] == pred:
                    num_right += 1
                num_total += 1

            print(f'[INFO] Progress: {cnt} / {len(data_loader)}; Accumulate: Correct: {num_right}, Total: {num_total}')

    return num_right, num_total

In [None]:
# Baseline: both layers use the default PNN parameters
# (Lp_dB=0, Lc_dB=0, K1_dB=K2_dB=-10000).
pnn1, pnn2 = PNN(pnn_mat1), PNN(pnn_mat2)

num_true, num_all = pnn_evaluate(valid_loader, model, DEVICE, pnn1, pnn2)
accuracy_percent = num_true/num_all*100
print('[SUMMARY] Test Accuracy:\t {} / {} = {:.3f}'.format(num_true, num_all, accuracy_percent))

### Sweep Parameters to Investigate the Impacts of Crosstalk on PNN Inference

In [None]:
# Sweep configuration: crosstalk coefficients (dB), one Lp_dB value and three
# Lc_dB values. Note that Lc_dB is a list from here on, not a scalar.
crosstalk_coefficients = [-40, -35, -30, -25, -20]
Lp_dB = -0.01
Lc_dB = [-0.1, -0.2, -0.3]

# One row per Lc_dB value, one column per crosstalk coefficient.
test_accuracies = np.zeros((len(Lc_dB), len(crosstalk_coefficients)), dtype=np.double)

# Reference run with losses only (first Lc_dB value, default crosstalk settings).
pnn1 = PNN(pnn_mat1, Lp_dB=Lp_dB, Lc_dB=Lc_dB[0])
pnn2 = PNN(pnn_mat2, Lp_dB=Lp_dB, Lc_dB=Lc_dB[0])
num_true, num_all = pnn_evaluate(valid_loader, model, DEVICE, pnn1, pnn2)
acc_lossonly = num_true / num_all * 100.0

# Crosstalk-aware runs: visit every (Lc_dB, crosstalk) pair in row-major order;
# the same coefficient is used for K1_dB and K2_dB.
for i, j in np.ndindex(*test_accuracies.shape):
    loss_coef = Lc_dB[i]
    xtalk_coef = crosstalk_coefficients[j]
    impairments = dict(Lp_dB=Lp_dB, Lc_dB=loss_coef, K1_dB=xtalk_coef, K2_dB=xtalk_coef)
    pnn1 = PNN(pnn_mat1, **impairments)
    pnn2 = PNN(pnn_mat2, **impairments)
    num_true, num_all = pnn_evaluate(valid_loader, model, DEVICE, pnn1, pnn2)
    test_accuracies[i, j] = num_true / num_all * 100.0

# Variables consumed by the plotting cell.
x = crosstalk_coefficients
y = test_accuracies
y_lossonly = np.full(len(crosstalk_coefficients), acc_lossonly)

In [None]:
# Test accuracy versus crosstalk coefficient: the loss-only reference first,
# then one curve per row of `y` (in the order of the legend entries).
plt.figure(figsize=(6.4, 3))
plt.plot(x, y_lossonly, '-p', color='cornflowerblue')
for row, shade in enumerate(('indigo', 'darkorchid', 'plum')):
    plt.plot(x, y[row], '-d', color=shade)
plt.xticks(x)
plt.xlabel('Crosstalk Coefficient (dB)')
plt.ylabel('Test Accuracy (%)')
plt.legend(['Loss Only','Lc = -0.1dB', 'Lc = -0.2dB', 'Lc = -0.3dB'])
plt.show()

# Full DNN-Architecture Evaluation

In [None]:
class PNN:
    """Photonic matrix multiplier with optional access to the raw crosstalk.

    This definition replaces the earlier ``PNN`` class of the notebook and adds
    the ``return_crosstalk`` switch to ``__call__``.

    Parameters
    ----------
    matrix : numpy.ndarray
        Matrix to implement; it is factorised with ``np.linalg.svd``.
    Lp_dB, Lc_dB : float, optional
        Loss parameters forwarded to ``reconstruct_clements_pnn``.
    K1_dB, K2_dB : float, optional
        Crosstalk coefficients forwarded to ``reconstruct_clements_pnn``.
    """

    def __init__(self, matrix, Lp_dB=0, Lc_dB=0, K1_dB=-10000, K2_dB=-10000):
        # np.linalg.svd returns (u, s, vh); `self.v` therefore holds the vh factor.
        self.u, self.s, self.v = np.linalg.svd(matrix, full_matrices=True)
        self.p_u, self.t_u, self.a_u = decompose_clements(self.u.conj().T, block='mzi')
        self.p_v, self.t_v, self.a_v = decompose_clements(self.v.conj().T, block='mzi')
        self.Lp_dB, self.Lc_dB = Lp_dB, Lc_dB
        self.K1_dB, self.K2_dB = K1_dB, K2_dB

    def __call__(self, inputs, return_crosstalk=False):
        """Propagate ``inputs`` through both meshes.

        With ``return_crosstalk=False`` (default) ``photo_detector`` is applied
        after each mesh and only the detected signal is returned. With
        ``return_crosstalk=True`` no detection takes place and the tuple
        ``(E_sig, P_xtalk)`` of the second mesh is returned.
        """
        detect = not return_crosstalk
        n_mid = self.v.shape[0]
        mesh_settings = dict(block='mzi',
                             Lp_dB=self.Lp_dB, Lc_dB=self.Lc_dB,
                             K1_dB=self.K1_dB, K2_dB=self.K2_dB)

        # First mesh, starting with no crosstalk.
        E_in = inputs.squeeze()[:, np.newaxis].astype(np.complex128)
        E_sig, P_xtalk = reconstruct_clements_pnn(self.p_u, self.t_u, self.a_u,
                                                  E_in,
                                                  np.zeros_like(E_in, dtype=np.complex128),
                                                  **mesh_settings)
        if detect:
            E_sig = photo_detector(E_sig, P_xtalk)

        # Singular-value stage, applied to both the signal and the crosstalk.
        E_sig = diagonal_attenuator(E_sig, self.s, num_output=n_mid)
        P_xtalk = diagonal_attenuator(P_xtalk, self.s, num_output=n_mid)

        # Second mesh.
        E_sig, P_xtalk = reconstruct_clements_pnn(self.p_v, self.t_v, self.a_v,
                                                  E_sig,
                                                  P_xtalk,
                                                  **mesh_settings)
        if return_crosstalk:
            return E_sig, P_xtalk
        return photo_detector(E_sig, P_xtalk)

## Measurement of Post-Rectification Factor

In [None]:
# Loss and crosstalk settings used for the measurement.
Lp_dB = -0.05
Lc_dB = -0.1
K1_dB = -30
K2_dB = -30

# Send the feature vector through a PNN that implements the 120x120 identity
# and keep the raw (undetected) signal and crosstalk.
# `pnn_inputs` is the vector built in the "One Iteration of Inference" cell,
# so that cell must have been run first.
pnn_test = PNN(np.eye(120), Lp_dB=Lp_dB, Lc_dB=Lc_dB, K1_dB=K1_dB, K2_dB=K2_dB)
[E_sig, P_xtalk] = pnn_test(pnn_inputs, return_crosstalk=True)

# Mean output intensity (signal power plus crosstalk term); eta is its reciprocal.
avg_intensity = np.mean(np.abs(E_sig) ** 2 + P_xtalk)
eta_factor = 1 / avg_intensity

print(avg_intensity.real, np.mean(np.abs(E_sig) ** 2).real, np.mean(P_xtalk).real)
# A linear factor is converted to decibels with 10 * log10(eta); the previous
# expression 10 ** (eta / 10) is the inverse (dB -> linear) conversion.
print(f'Post Rectification Factor: eta = {np.real(eta_factor)} = {10 * np.log10(np.real(eta_factor))} dB')