In [2]:
import os

import matplotlib.pyplot as plt
plt.style.use('science')
import numpy as np

from commpy.modulation import QAMModem

from optic.dsp import pulseShape, firFilter, decimate, symbolSync, resample
from optic.models import phaseNoise, linFiberCh, KramersKronigRx, photodiode

from optic.tx import simpleWDMTx
from optic.core import parameters
from optic.equalization import edc, mimoAdaptEqualizer
from optic.carrierRecovery import cpr
from optic.metrics import fastBERcalc, monteCarloGMI, monteCarloMI, signal_power
from optic.plot import pconst

import scipy.constants as const
from tqdm.notebook import tqdm

from tensorflow.keras.layers import Dense, BatchNormalization, Conv1DTranspose, Conv1D, Flatten, Add
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import EarlyStopping
import tensorflow as tf
from tensorflow.keras import backend as K

In [6]:
# Root of the local project checkout; edit this single line when the notebook
# is run on another machine.
basePath = 'C:/Users/optic/Documents/PIVIC-PIBIC-Comunicacoes-Opticas/'

# Every location below ends with a trailing '/' because the cells that use them
# build file names by plain string concatenation (path + name). path_data used
# to lack that separator, so results were written next to the Data folder as
# '.../resultados/DataBER_...' instead of inside it.
figurePath = basePath + 'resultados/Figuras/'   # figures
path_data  = basePath + 'resultados/Data/'      # saved simulation results
path_mlp   = basePath + 'setups/NN_models/'     # pre-trained MLP models
path_conv  = basePath + 'setups/Conv_models/'   # pre-trained ConvNet models

N = 256 # number of input amplitude samples to the NN

## Simulation of a single polarization optical signal transmission

In [14]:
## Transmitter setup
paramTx = parameters()
paramTx.Rs = 32e9              # baud rate [symbols/s]
paramTx.M = 16                 # constellation size of the modulation format
paramTx.Nbits = 400000         # bits transmitted per polarization
paramTx.SpS = 4                # oversampling factor [samples/symbol]
paramTx.pulse = "rrc"          # type of pulse-shaping filter
paramTx.Ntaps = 1024           # length of the pulse-shaping filter [coefficients]
# paramTx.alphaRRC (RRC roll-off) is deliberately not set here:
# the roll-off sweep assigns it on every iteration.
paramTx.Nch = 1                # a single WDM channel
paramTx.freqSpac = 37.5e9      # spacing of the WDM grid
paramTx.Fc = 193.1e12          # center frequency of the optical spectrum
paramTx.Pch_dBm = 0            # optical signal power [dBm]

## Simulation-wide settings
Fs = paramTx.Rs * paramTx.SpS  # sampling rate used by the simulation
Ts = 1 / Fs                    # corresponding sampling period
chIndex = 0                    # which channel is demodulated
plotPSD = True

## Fiber link
Fc = paramTx.Fc  # center optical frequency of the WDM spectrum [Hz]
D = 16           # dispersion parameter of the fiber [ps/nm/km]
alpha = 0        # attenuation of the fiber [dB/km]
Ltotal = 100     # length of the link [km]

## Receiver

# local oscillator (LO)
Plo_dBm = 12             # LO power [dBm]
ϕ_lo = 0                 # LO initial phase [rad]
lw = 0*200e3             # LO linewidth (the leading 0 switches it off)
FO = paramTx.Rs/2 + 1e9  # LO frequency offset

# ADC: input/output oversampling of the resampling stage
paramADC = parameters()
paramADC.SpS_out = 4
paramADC.SpS_in = paramTx.SpS
paramADC.Rs = paramTx.Rs

# photodiode
paramPD = parameters()
paramPD.Fs = Fs
paramPD.B  = 1.1*paramTx.Rs

## Run BER vs rolloff vs FO for all cases

In [None]:
# Sweep grid: RRC roll-off factors and LO frequency offsets (added to Rs/2)
Rolloff = np.arange(0.05, 0.95, 0.05)
FO_Values = np.arange(0e9, 6e9, 1e9)

# Phase-retrieval schemes under test. The two neural-network schemes share the
# same processing and differ only in the directory their model is loaded from.
algorithms = ['KK', 'MLP', 'CONVNET']
modelDirs = {'MLP': path_mlp, 'CONVNET': path_conv}

# Result arrays, indexed as [algorithm, roll-off, frequency offset]
resShape = (len(algorithms), len(Rolloff), len(FO_Values))
BER = np.zeros(resShape)
SER = np.zeros(resShape)
GMI = np.zeros(resShape)
MI  = np.zeros(resShape)
SNR = np.zeros(resShape)
SIR = np.zeros(resShape)

# Loop-invariant quantities
π = np.pi
Plo = 10 ** (Plo_dBm / 10) * 1e-3  # LO power in W

for indAlg, alg in enumerate(algorithms):
    for indFO, FOfreq in enumerate(tqdm(FO_Values)):
        for indRolloff, rollOff in enumerate(tqdm(Rolloff)):
            res = (indAlg, indRolloff, indFO)  # slot of this run in the result arrays

            paramTx.alphaRRC = rollOff
            FO = paramTx.Rs/2 + FOfreq

            # generate optical signal
            sigTx, symbTx_, paramTx = simpleWDMTx(paramTx)

            # simulate linear signal propagation
            sigCh = linFiberCh(sigTx, Ltotal, alpha, D, Fc, Fs)

            symbTx = symbTx_[:, :, chIndex]

            # generate LO field (only used for the CSPR printout below)
            t = np.arange(0, len(sigCh))*Ts
            ϕ_pn_lo = phaseNoise(lw, len(sigCh), Ts)

            sigLO = np.sqrt(Plo) * np.exp(-1j * (2 * π * FO * t + ϕ_lo + ϕ_pn_lo))

            # Add LO to the received signal: the signal is shifted by FO and a
            # constant carrier of amplitude sqrt(Plo) is added
            sigRx = np.sqrt(Plo) + sigCh* np.exp(1j * (2 * π * FO * t + ϕ_lo + ϕ_pn_lo))
            sfm   = sigRx.copy()

            print('CSPR = %.2f dB'%(10*np.log10(signal_power(sigLO)/signal_power(sigCh))))

            # simulate ideal direct-detection optical receiver
            Ipd = photodiode(sigRx, paramPD)
            Amp = np.sqrt(Ipd.real)
            Amp = resample(Amp, paramADC).real

            # resampling to ADC sampling rate
            sigCh = resample(sigCh, paramADC)
            sfm = resample(sfm, paramADC)
            newFs = paramADC.SpS_out*paramTx.Rs

            sfm = sfm/np.sqrt(signal_power(sfm))

            if alg == 'KK':
                # Kramers-Kronig phase-retrieval
                phiTime = KramersKronigRx(Amp, newFs)
                # optical field reconstruction
                sigRx = Amp*np.exp(1j*phiTime)

            else:
                # Neural-network (MLP or ConvNet) field reconstruction.
                # NOTE(review): FO/10e9 is FO in units of 10 GHz although the
                # tag reads 'GHz'; left unchanged because the saved model
                # folders presumably follow this naming - confirm against the
                # training notebook.
                modelName = ('testModel_SpS_'+str(paramADC.SpS_out)
                             +'_FO_'+str(FO/10e9)
                             +'GHz_Rolloff_'+str(round(paramTx.alphaRRC, 2)))
                model = tf.keras.models.load_model(modelDirs[alg]+modelName)

                # zero-pad N/2 amplitude samples on each side so that every
                # sample gets a full window of N samples
                sigAmp  = np.pad(Amp, (int(N/2), int(N/2)), 'constant') # (L+N,)

                # create set of input features: one sliding window per sample
                X_input = np.zeros((len(sfm), N)) #(L,N)

                for indPhase in range(len(sfm)):
                    X_input[indPhase] = sigAmp[indPhase:N+indPhase]

                sigRx_NN = model.predict(X_input)

                # optical field reconstruction: column 0 is used as the real
                # part and column 1 as the imaginary part
                sigRx = sigRx_NN[:,0]+1j*sigRx_NN[:,1]

            # remove DC level
            sigRx -= np.mean(sigRx) # np.sqrt(Plo)

            # downshift to baseband
            t = np.arange(0, len(sigRx))*(1/newFs)
            sigRx *= np.exp(-1j * (2 * π * FO * t))

            # Matched filtering
            if paramTx.pulse == "nrz":
                pulse = pulseShape("nrz", paramADC.SpS_out)
            elif paramTx.pulse == "rrc":
                pulse = pulseShape(
                    "rrc", paramADC.SpS_out, N=paramTx.Ntaps, alpha=paramTx.alphaRRC, Ts=1 / paramTx.Rs
                )

            pulse = pulse / np.max(np.abs(pulse))
            sigRx = firFilter(pulse, sigRx)
            sigCh = firFilter(pulse, sigCh)

            # correct for (possible) phase ambiguity
            rot = np.mean(sigCh/sigRx)
            sigRx = rot * sigRx
            sigRx = sigRx / np.sqrt(signal_power(sigRx))

            # residual between the power-normalized reconstructed field and the
            # power-normalized reference field; SIR is the inverse of its power
            intf = sigRx/np.sqrt(signal_power(sigRx))-sigCh/np.sqrt(signal_power(sigCh))

            SIR[res] = 1/signal_power(intf)

            # resample to 2 samples/symbol:
            paramRes = parameters()
            paramRes.Rs = paramTx.Rs
            paramRes.SpS_in  = paramADC.SpS_out
            paramRes.SpS_out = 2

            sigRx = resample(sigRx, paramRes)

            # CD compensation
            sigRx = edc(sigRx, Ltotal, D, Fc, paramRes.SpS_out*paramTx.Rs)

            # re-synchronization with the transmitted sequence
            sigRx = sigRx.reshape(-1, 1)

            symbRx = symbolSync(sigRx, symbTx, 2)

            # Power normalization
            x = sigRx
            d = symbRx

            x = x.reshape(len(x), 1) / np.sqrt(signal_power(x))
            d = d.reshape(len(d), 1) / np.sqrt(signal_power(d))

            # Adaptive equalization
            paramEq = parameters()
            paramEq.nTaps = 15
            paramEq.SpS = 2
            paramEq.mu = [1e-3, 5e-4]
            paramEq.numIter = 5
            paramEq.storeCoeff = False
            paramEq.alg = ["da-rde", "rde"]
            paramEq.M = paramTx.M
            paramEq.L = [20000, 80000]
            paramEq.prgsBar = False

            y_EQ, H, errSq, Hiter = mimoAdaptEqualizer(x, dx=d, paramEq=paramEq)

            # Carrier phase recovery
            paramCPR = parameters()
            paramCPR.alg = "bps"
            paramCPR.M = paramTx.M
            paramCPR.N = 85
            paramCPR.B = 64
            paramCPR.pilotInd = np.arange(0, len(y_EQ), 20)

            y_CPR, θ = cpr(y_EQ, symbTx=d, paramCPR=paramCPR)

            y_CPR = y_CPR / np.sqrt(signal_power(y_CPR))

            # correct for (possible) phase ambiguity
            for k in range(y_CPR.shape[1]):
                rot = np.mean(d[:, k] / y_CPR[:, k])
                y_CPR[:, k] = rot * y_CPR[:, k]

            y_CPR = y_CPR / np.sqrt(signal_power(y_CPR))

            # drop half of the first equalizer stage length (paramEq.L[0]) at
            # both ends before computing the metrics
            discard = int(paramEq.L[0]/2)

            ind = np.arange(discard, d.shape[0] - discard)
            BER[res], SER[res], SNR[res] = fastBERcalc(y_CPR[ind, :], d[ind, :], paramTx.M, 'qam')
            GMI[res], _ = monteCarloGMI(y_CPR[ind, :], d[ind, :],  paramTx.M, 'qam')
            MI[res] = monteCarloMI(y_CPR[ind, :], d[ind, :],  paramTx.M, 'qam')

            print("Results:")
            print("BER: %.2e" %(BER[res]))
            print("SNR: %.2f dB" %(SNR[res]))
            print('SIR = ', round(10*np.log10(SIR[res]), 2), ' dB')
            print("GMI: %.2f bits\n" %(GMI[res]))

    # Save the simulation data each time an algorithm finishes, so partial
    # results survive an interrupted run. os.path.join is used because
    # path_data may or may not end with a path separator; plain concatenation
    # would otherwise write '.../DataBER_...' next to the Data folder.
    fileTag = '_SpS_'+str(paramADC.SpS_out)+'_'+str(N)+'_Sample'
    np.save(os.path.join(path_data, 'BER'+fileTag), BER)
    np.save(os.path.join(path_data, 'SNR'+fileTag), SNR)
    np.save(os.path.join(path_data, 'SIR'+fileTag), SIR)

## Simulation of a single polarization optical signal transmission

In [4]:
## Transmitter setup
paramTx = parameters()
paramTx.Rs = 32e9              # baud rate [symbols/s]
paramTx.M = 16                 # constellation size of the modulation format
paramTx.Nbits = 400000         # bits transmitted per polarization
paramTx.SpS = 4                # oversampling factor [samples/symbol]
paramTx.pulse = "rrc"          # type of pulse-shaping filter
paramTx.Ntaps = 1024           # length of the pulse-shaping filter [coefficients]
paramTx.alphaRRC = 0.5         # roll-off factor of the RRC filter
paramTx.Nch = 1                # a single WDM channel
paramTx.freqSpac = 37.5e9      # spacing of the WDM grid
paramTx.Fc = 193.1e12          # center frequency of the optical spectrum
paramTx.Pch_dBm = 0            # optical signal power [dBm]

## Simulation-wide settings
Fs = paramTx.Rs * paramTx.SpS  # sampling rate used by the simulation
Ts = 1 / Fs                    # corresponding sampling period
chIndex = 0                    # which channel is demodulated
plotPSD = True

## Fiber link
Fc = paramTx.Fc  # center optical frequency of the WDM spectrum [Hz]
D = 16           # dispersion parameter of the fiber [ps/nm/km]
alpha = 0        # attenuation of the fiber [dB/km]
Ltotal = 100     # length of the link [km]

## Receiver

# local oscillator (LO)
# Plo_dBm (LO power in dBm) is not fixed here: the CSPR sweep iterates over it.
ϕ_lo = 0           # LO initial phase [rad]
lw = 0*200e3       # LO linewidth (the leading 0 switches it off)
FO = paramTx.Rs/2  # LO frequency offset

# ADC: input/output oversampling of the resampling stage
paramADC = parameters()
paramADC.SpS_out = 4
paramADC.SpS_in = paramTx.SpS
paramADC.Rs = paramTx.Rs

# photodiode
paramPD = parameters()
paramPD.Fs = Fs
paramPD.B  = 1.1*paramTx.Rs

## Run all CSPR variations

In [3]:
# LO power sweep [dBm]; the signal power is fixed, so each value is a different CSPR
loPower = np.arange(6,16,1)

# Phase-retrieval schemes under test. The two neural-network schemes share the
# same processing and differ only in the directory their model is loaded from.
algorithms = ['KK', 'MLP', 'CONVNET']
modelDirs = {'MLP': path_mlp, 'CONVNET': path_conv}

# Result arrays, indexed as [algorithm, LO power]
resShape = (len(algorithms), len(loPower))
BER = np.zeros(resShape)
SER = np.zeros(resShape)
GMI = np.zeros(resShape)
MI  = np.zeros(resShape)
SNR = np.zeros(resShape)
SIR = np.zeros(resShape)

π = np.pi

for indAlg, alg in enumerate(algorithms):
    # The NN model path does not depend on the LO power, so the model is loaded
    # once per algorithm (on first use) and reused for every power step.
    model = None

    for indPower, Plo_dBm in enumerate(tqdm(loPower)):
        res = (indAlg, indPower)  # slot of this run in the result arrays

        # generate optical signal
        sigTx, symbTx_, paramTx = simpleWDMTx(paramTx)

        # simulate linear signal propagation
        sigCh = linFiberCh(sigTx, Ltotal, alpha, D, Fc, Fs)

        symbTx = symbTx_[:, :, chIndex]
        Plo = 10 ** (Plo_dBm / 10) * 1e-3  # power in W

        # generate LO field (only used for the CSPR printout below)
        t = np.arange(0, len(sigCh))*Ts
        ϕ_pn_lo = phaseNoise(lw, len(sigCh), Ts)

        sigLO = np.sqrt(Plo) * np.exp(-1j * (2 * π * FO * t + ϕ_lo + ϕ_pn_lo))

        # Add LO to the received signal: the signal is shifted by FO and a
        # constant carrier of amplitude sqrt(Plo) is added
        sigRx = np.sqrt(Plo) + sigCh* np.exp(1j * (2 * π * FO * t + ϕ_lo + ϕ_pn_lo))
        sfm   = sigRx.copy()

        print('CSPR = %.2f dB'%(10*np.log10(signal_power(sigLO)/signal_power(sigCh))))

        # simulate ideal direct-detection optical receiver
        Ipd = photodiode(sigRx, paramPD)
        Amp = np.sqrt(Ipd.real)
        Amp = resample(Amp, paramADC).real

        # resampling to ADC sampling rate
        sigCh = resample(sigCh, paramADC)
        sfm = resample(sfm, paramADC)
        newFs = paramADC.SpS_out*paramTx.Rs

        sfm = sfm/np.sqrt(signal_power(sfm))

        if alg == 'KK':
            # Kramers-Kronig phase-retrieval
            phiTime = KramersKronigRx(Amp, newFs)
            # optical field reconstruction
            sigRx = Amp*np.exp(1j*phiTime)

        else:
            # Neural-network (MLP or ConvNet) field reconstruction.
            if model is None:
                # tf.keras.load_model does not exist (AttributeError); the
                # loader lives in tf.keras.models.
                # NOTE(review): FO/10e9 is FO in units of 10 GHz although the
                # tag reads 'GHz'; left unchanged because the saved model
                # folders presumably follow this naming - confirm against the
                # training notebook.
                modelName = ('testModel_SpS_'+str(paramADC.SpS_out)
                             +'_FO_'+str(FO/10e9)
                             +'GHz_Rolloff_'+str(round(paramTx.alphaRRC, 2)))
                model = tf.keras.models.load_model(modelDirs[alg]+modelName)

            # zero-pad N/2 amplitude samples on each side so that every sample
            # gets a full window of N samples
            sigAmp  = np.pad(Amp, (int(N/2), int(N/2)), 'constant') # (L+N,)

            # create set of input features: one sliding window per sample
            X_input = np.zeros((len(sfm), N)) #(L,N)

            for indPhase in range(len(sfm)):
                X_input[indPhase] = sigAmp[indPhase:N+indPhase]

            sigRx_NN = model.predict(X_input)

            # optical field reconstruction: column 0 is used as the real part
            # and column 1 as the imaginary part
            sigRx = sigRx_NN[:,0]+1j*sigRx_NN[:,1]

        # remove DC level
        sigRx -= np.mean(sigRx) # np.sqrt(Plo)

        # downshift to baseband
        t = np.arange(0, len(sigRx))*(1/newFs)
        sigRx *= np.exp(-1j * (2 * π * FO * t))

        # Matched filtering
        if paramTx.pulse == "nrz":
            pulse = pulseShape("nrz", paramADC.SpS_out)
        elif paramTx.pulse == "rrc":
            pulse = pulseShape(
                "rrc", paramADC.SpS_out, N=paramTx.Ntaps, alpha=paramTx.alphaRRC, Ts=1 / paramTx.Rs
            )

        pulse = pulse / np.max(np.abs(pulse))
        sigRx = firFilter(pulse, sigRx)
        sigCh = firFilter(pulse, sigCh)

        # correct for (possible) phase ambiguity
        rot = np.mean(sigCh/sigRx)
        sigRx = rot * sigRx
        sigRx = sigRx / np.sqrt(signal_power(sigRx))

        # residual between the power-normalized reconstructed field and the
        # power-normalized reference field; SIR is the inverse of its power
        intf = sigRx/np.sqrt(signal_power(sigRx))-sigCh/np.sqrt(signal_power(sigCh))

        SIR[res] = 1/signal_power(intf)

        # resample to 2 samples/symbol:
        paramRes = parameters()
        paramRes.Rs = paramTx.Rs
        paramRes.SpS_in  = paramADC.SpS_out
        paramRes.SpS_out = 2

        sigRx = resample(sigRx, paramRes)

        # CD compensation
        sigRx = edc(sigRx, Ltotal, D, Fc, paramRes.SpS_out*paramTx.Rs)

        # re-synchronization with the transmitted sequence
        sigRx = sigRx.reshape(-1, 1)

        symbRx = symbolSync(sigRx, symbTx, 2)

        # Power normalization
        x = sigRx
        d = symbRx

        x = x.reshape(len(x), 1) / np.sqrt(signal_power(x))
        d = d.reshape(len(d), 1) / np.sqrt(signal_power(d))

        # Adaptive equalization
        paramEq = parameters()
        paramEq.nTaps = 15
        paramEq.SpS = 2
        paramEq.mu = [1e-3, 5e-4]
        paramEq.numIter = 5
        paramEq.storeCoeff = False
        paramEq.alg = ["da-rde", "rde"]
        paramEq.M = paramTx.M
        paramEq.L = [20000, 80000]
        paramEq.prgsBar = False

        y_EQ, H, errSq, Hiter = mimoAdaptEqualizer(x, dx=d, paramEq=paramEq)

        # Carrier phase recovery
        paramCPR = parameters()
        paramCPR.alg = "bps"
        paramCPR.M = paramTx.M
        paramCPR.N = 85
        paramCPR.B = 64
        paramCPR.pilotInd = np.arange(0, len(y_EQ), 20)

        y_CPR, θ = cpr(y_EQ, symbTx=d, paramCPR=paramCPR)

        y_CPR = y_CPR / np.sqrt(signal_power(y_CPR))

        # correct for (possible) phase ambiguity
        for k in range(y_CPR.shape[1]):
            rot = np.mean(d[:, k] / y_CPR[:, k])
            y_CPR[:, k] = rot * y_CPR[:, k]

        y_CPR = y_CPR / np.sqrt(signal_power(y_CPR))

        # drop half of the first equalizer stage length (paramEq.L[0]) at both
        # ends before computing the metrics
        discard = int(paramEq.L[0]/2)

        ind = np.arange(discard, d.shape[0] - discard)
        BER[res], SER[res], SNR[res] = fastBERcalc(y_CPR[ind, :], d[ind, :], paramTx.M, 'qam')
        GMI[res], _ = monteCarloGMI(y_CPR[ind, :], d[ind, :],  paramTx.M, 'qam')
        MI[res] = monteCarloMI(y_CPR[ind, :], d[ind, :],  paramTx.M, 'qam')

        print("Results:")
        print("BER: %.2e" %(BER[res]))
        print("SNR: %.2f dB" %(SNR[res]))
        print('SIR = ', round(10*np.log10(SIR[res]), 2), ' dB')
        print("GMI: %.2f bits\n" %(GMI[res]))

array([ 6,  7,  8,  9, 10, 11, 12, 13, 14, 15])