In [1]:
from tqdm.notebook import tqdm
from time import sleep

# Sanity check that the tqdm notebook progress-bar widget renders:
# 20 iterations, each sleeping 0.1 s (about 2 s in total).
for i in tqdm(range(20)):
    sleep(0.1)

  0%|          | 0/20 [00:00<?, ?it/s]

In [1]:
## simple DSP: load the cached SML dataset (data_sml) from disk
import matplotlib.pyplot as plt
import numpy as np
from commpy.modulation import QAMModem
import jax
import jax.numpy as jnp
from jax import device_put, device_get
import jax.random as random

import pickle
from collections import namedtuple
# Container for one cached dataset. Judging from the (commented-out) code that wrote
# it, the fields are: y = received samples, x = transmitted symbols,
# w0 = 2*pi*FO/Rs (normalised frequency offset), a = auxiliary simulation parameters
# (consumed by fdbp_init) — TODO confirm against sml_dataset.
# The typename matches the bound name so that Input instances can themselves be pickled
# (pickle looks classes up by their __name__).
Input = namedtuple('Input', ['y', 'x', 'w0', 'a'])

SML_DATASET_PATH = 'sml_data/dataset'
# NOTE: pickle.load can execute arbitrary code — only load files produced by this project.
with open(SML_DATASET_PATH, 'rb') as file:
    b = pickle.load(file)
data_sml = Input._make(b)  # raises TypeError if the stored tuple does not have 4 fields

In [None]:
## Transmitter
import matplotlib.pyplot as plt
import numpy as np
from commpy.modulation import QAMModem
import jax
import jax.numpy as jnp
from jax import device_put, device_get
import jax.random as random

from commplax import plot as cplt
from commplax.module import core
from optical_flax.dsp import firFilter, edc, fourthPowerFOE, dbp, cpr2, downsampling, simple_cpr, test_result
from optical_flax.models import manakov_ssf, cssfm
from optical_flax.tx import simpleWDMTx
from optical_flax.core import parameters

from scipy import signal
import scipy.constants as const
# Transmitter configuration.
param = parameters()
param.M   = 16           # modulation format: order of the QAM constellation (16-QAM)
param.Rs  = 36e9         # symbol rate [baud]
param.SpS = 16           # samples/symb
param.Nbits = 400000     # number of bits
param.pulse_type = 'rrc' # pulse shape
param.Ntaps = 4096       # number of RRC filter taps
param.alphaRRC = 0.1     # RRC filter roll-off
param.Pch_dBm = 0        # average power per WDM channel [dBm]
param.Nch     = 7        # number of WDM channels
param.Fc      = const.c/1550E-9  # centre frequency of the WDM spectrum (1550 nm) [Hz]
param.freqSpac = 50e9    # frequency spacing of the WDM channel grid [Hz]
param.Nmodes = 2         # number of polarization modes
param.mod = QAMModem(m=param.M)  # modulation
param.equation = 'NLSE'

# Generate (not load) the WDM transmit signal and the transmitted symbols;
# simpleWDMTx also returns an updated parameter object.
sigWDM_Tx, symbTx_, param = simpleWDMTx(param)
print(f'signal shape: {sigWDM_Tx.shape}, symb shape: {symbTx_.shape}')

# Cache the transmitter output so it can be reused without re-running the simulation.
import os
import pickle
path = 'sml_data/Tx'
os.makedirs(os.path.dirname(path), exist_ok=True)  # a fresh checkout may lack sml_data/
with open(path,'wb') as file:
    pickle.dump((sigWDM_Tx, symbTx_, param), file)
print('data is saved!')

In [None]:
## optical fiber
from tqdm.notebook import tqdm
# Fibre link: 1125 km total in 75 km spans (15 spans), amplifier model 'edfa'.
linearChannel = True   # True: set gamma = 0, i.e. simulate a purely linear (dispersive) channel
paramCh = parameters()
paramCh.Ltotal = 1125   # km
paramCh.Lspan  = 75     # km
paramCh.alpha = 0.2     # dB/km
paramCh.D = 16.5        # ps/nm/km
paramCh.Fc = const.c/1550E-9  # Hz (1550 nm carrier, same value as the transmitter)
paramCh.hz =  15        # km, presumably the split-step size
paramCh.gamma = 1.3174420805376552    # 1/(W.km)
paramCh.amp = 'edfa'
if linearChannel:
    # With gamma = 0 the dispersion step is exact for any length, so one step per span suffices.
    paramCh.hz = paramCh.Lspan  # km
    paramCh.gamma = 0           # 1/(W.km)

Fs = param.Rs*param.SpS  # sampling rate of the transmitted signal [Hz]
sigWDM, paramCh = manakov_ssf(sigWDM_Tx, Fs, paramCh)

# Compare the Tx and Rx spectra; column 0 is presumably the first polarization.
PSD_NFFT = 4*1024
fig, ax = plt.subplots(figsize=(8, 3))
ax.set_xlim(paramCh.Fc - Fs/2, paramCh.Fc + Fs/2)
# Both spectra use the same sampling rate Fs (the original passed it in two different forms).
ax.psd(sigWDM_Tx[:, 0], Fs=Fs, Fc=paramCh.Fc, NFFT=PSD_NFFT, sides='twosided', label='WDM spectrum - Tx')
ax.psd(sigWDM[:, 0], Fs=Fs, Fc=paramCh.Fc, NFFT=PSD_NFFT, sides='twosided', label='WDM spectrum - Rx')
ax.legend(loc='lower left')
ax.set_title('optical WDM spectrum')
plt.show()


In [None]:
### Receiver
from optical_flax.rx import simpleRx, sml_dataset
# Fixed seed: simpleRx presumably draws its random impairments (e.g. laser phase noise)
# from NumPy's global RNG — TODO confirm.
np.random.seed(123)

paramRx = parameters()
paramRx.chid = param.Nch // 2         # index of the centre WDM channel (3 for Nch = 7)
paramRx.sps = 2                       # output samples/symbol (edc below samples at Rs * sps)
paramRx.FO = 64e6                     # frequency offset [Hz]
paramRx.lw = 100e3                    # linewidth [Hz, presumably of the LO/laser]
paramRx.Rs = param.Rs                 # symbol rate [baud]

# Quantities inherited from the transmitter configuration.
paramRx.tx_sps = param.SpS
paramRx.pulse = param.pulse           # presumably set by simpleWDMTx (not assigned in the Tx cell)
paramRx.freq = param.freqGrid[paramRx.chid]
paramRx.Ta = 1/(param.SpS*param.Rs)   # sampling period of the transmitted signal [s]

sigRx, paramRx = simpleRx(sigWDM, paramRx)
data_sml = sml_dataset(sigRx, symbTx_, param, paramCh, paramRx)
# To avoid re-simulating, the dataset can be pickled and reloaded in the data-loading
# cell ('sml_data/dataset'), which expects a 4-tuple (y, x, w0, a).


# Chromatic-dispersion compensation (CDC): comparison of implementations

In [None]:
## Chromatic-dispersion compensation (CDC): comparison of implementations
from optical_flax.layers import fdbp
from commplax.module import core
from optical_flax.initializers import fdbp_init
from optical_flax.utils import MSE

# Learned CDC: a single-step, time-domain FDBP layer with a 1001-tap dispersion filter,
# initialised with xi=0.0 (presumably no nonlinear compensation — TODO confirm in fdbp_init).
d_init, n_init = fdbp_init(data_sml.a, xi=0.0, steps=1, domain='time')
cdc = fdbp(steps=1, dtaps=1001, ntaps=1, d_init=d_init, n_init=n_init)
key = random.PRNGKey(0)
cdc_param = cdc.init(key, core.Signal(data_sml.y))
y = cdc.apply(cdc_param, core.Signal(data_sml.y))

# Reference CDC: `edc` on the same received samples, sampled at Rs * sps.
# NOTE(review): the carrier passed is `param.Fc - paramRx.freq`; for the centre channel
# selected above paramRx.freq is presumably 0, so the sign is irrelevant here — verify the
# sign convention of `param.freqGrid` before using an outer channel.
y1, H = edc(data_sml.y, paramCh.Ltotal, paramCh.D, param.Fc - paramRx.freq, param.Rs * paramRx.sps)

# Align the reference with the samples the FDBP layer kept. Assumes commplax's SigTime
# convention (start >= 0 samples trimmed at the front, stop <= 0 offset from the end —
# TODO confirm). `len + stop` also handles stop == 0, where y1[start:stop] would be empty.
y1_aligned = y1[y.t.start:y1.shape[0] + y.t.stop]
assert y1_aligned.shape == y.val.shape, (y1_aligned.shape, y.val.shape)

# Normalised error between the two CDC outputs (presumably MSE(a, 0) is the mean power of a).
print(MSE(y.val, y1_aligned) / MSE(y.val, 0))