In [None]:
import os

import torch
import torch.optim as optim
import numpy as np
import matplotlib
#matplotlib.use('TkAgg')
import matplotlib.pyplot as plt

import config
from Transmitter import Tx
from Receiver import Rx
from Fiber import Fiber, Amplifier
from train_model import train, test_model, MSE, Acc

# Initialize the WDM system: transmitter, receiver and the forward fiber channel.
# Seed the RNGs first so the sampled symbols and fiber noise are reproducible
# across Restart & Run All (assumes Tx/Fiber draw from torch or numpy RNGs -- TODO confirm).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

power = 60                  # per-channel power written to tx/rx (units as Tx/Rx expect -- TODO confirm)
k = (config.Nch - 1) // 2   # index of the centre channel (Nch presumably odd)
tx = Tx()
rx = Rx()
tx.power = torch.ones(config.Nch) * power
rx.power = torch.ones(config.Nch) * power
# Forward channel with noise enabled; the compensators defined in the next section
# reuse these length/alphaB/n2/disp values with flipped signs.
fiber = Fiber(tx.lam_set, length=1e5, alphaB=0.2, n2=2.7e-20, disp=17, dz=100,
              Nch=config.Nch, generate_noise=True)

# Load the compensators to compare.
# Checkpoint directory of the trained models; switch to the commented path to
# evaluate the other width/depth configuration.
show_path = 'ckpt-W120-D3/'
#show_path = 'ckpt-P50-W60-D2/'
# Output directory for saved figures. Set the WDM_FIG_PATH environment variable to
# override this machine-specific default; os.path.join(..., '') guarantees a
# trailing separator because file names are appended by string concatenation.
fig_path = os.path.join(
    os.environ.get('WDM_FIG_PATH', '/Users/xinyu/Desktop/科研/WDM_Code/report/img/expriment/'), '')

comp = {}
# There are currently three trained models.
train_model_names = ['normal', 'plus', 'shared']          # 3 trained models
model_names = ['disp', 'single'] + train_model_names      # 5 single-channel compensators
all_names = ['no comp', 'full'] + model_names             # 7 constellation panels

# Reference compensators: inverse fibers built by negating alphaB, n2 and disp of
# the forward channel. 'disp' inverts dispersion only (n2=0); 'single' and 'full'
# also invert the nonlinearity, on the centre channel only vs. on all channels.
comp['disp'] = Fiber(tx.lam_set[k:k+1], length=1e5, alphaB=-0.2, n2=0, disp=-17, dz=5e3, Nch=1)
comp['single'] = Fiber(tx.lam_set[k:k+1], length=1e5, alphaB=-0.2, n2=-2.7e-20, disp=-17, dz=1e3, Nch=1)
comp['full'] = Fiber(tx.lam_set, length=1e5, alphaB=-0.2, n2=-2.7e-20, disp=-17, dz=1e3, Nch=config.Nch)

for name in train_model_names:
    # torch.load unpickles arbitrary objects -- only load trusted checkpoints.
    # NOTE(review): on torch >= 2.6 the default weights_only=True rejects a pickled
    # module; pass weights_only=False there if this load fails.
    model = torch.load(show_path + name + '_best.pt', map_location=torch.device('cpu'))['model']
    # These models are only evaluated here: switch dropout/batch-norm layers
    # (if any) to inference mode.
    if isinstance(model, torch.nn.Module):
        model.eval()
    comp[name] = model

########################### Fiber #############################
# Sample one WDM transmission and send it through the forward fiber. The first
# axis of y indexes channels: y[k] is the centre channel, while y[k:k+1] keeps a
# singleton channel axis for the single-channel compensators.
x, symbol_stream, bit_stream = tx.wdm_signal_sample()

# Everything below is inference only. no_grad avoids building an autograd graph
# through the fiber and model steps, which saves memory and keeps grad-tracking
# tensors from reaching numpy/matplotlib in the receiver and plots.
with torch.no_grad():
    y = fiber(x)

    # Compensation: 'full' processes all channels, every other compensator sees
    # only the centre channel.
    z = {'full': comp['full'](y)}
    for name in model_names:
        z[name] = comp[name](y[k:k+1])

    # Receiver side: filter the centre channel of each (un)compensated signal.
    # NOTE(review): Nch=k is passed as in the original -- confirm its meaning in Rx.filter.
    I = {
        'no comp': rx.filter(y[k], Nch=k),
        'full': rx.filter(z['full'][k], Nch=k),
    }
    for name in model_names:
        I[name] = rx.filter(z[name][0], Nch=k)

# Show the received constellations of every compensator on a 4-column grid.
# The figure tag is derived from the checkpoint directory
# (e.g. 'ckpt-W120-D3/' -> 'W120-D3') so saved files stay correctly labelled
# when show_path is switched to another configuration.
run_tag = show_path.rstrip('/').split('/')[-1]
if run_tag.startswith('ckpt-'):
    run_tag = run_tag[len('ckpt-'):]

ncols = 4
nrows = -(-len(all_names) // ncols)   # ceiling division: enough rows for every panel
fig = plt.figure(figsize=(16, 8))
for i, name in enumerate(all_names):
    # (nrows, ncols, index) form works for any number of panels, unlike 241+i.
    plt.subplot(nrows, ncols, i + 1)
    rx.show_symbol(I[name], symbol_stream[k], size=7)
    plt.title(name)
fig.savefig(fig_path + f'{run_tag}-P{power}star.png')

In [None]:
# Compute the BER (per the original note; the dict is named acc -- confirm which
# metric test_model returns) for every single-channel compensator over N=10 runs,
# then print the results as LaTeX table rows: "name & value \\".
# NOTE(review): evaluation uses power=52 while the constellations above use
# `power` (60) -- confirm the mismatch is intentional.
eval_power = 52
acc = {name: test_model(fiber, comp[name], tx, rx, N=10, power=eval_power)
       for name in model_names}

for key, value in acc.items():
    print('%10s  &   %g \\\\' % (key, value))



In [None]:
# Plot the training-loss curve of every trained model on a fresh figure, so the
# lines never land on a previously active pyplot figure.
# The figure tag is derived from the checkpoint directory
# (e.g. 'ckpt-W120-D3/' -> 'W120-D3') so the file name follows show_path.
run_tag = show_path.rstrip('/').split('/')[-1]
if run_tag.startswith('ckpt-'):
    run_tag = run_tag[len('ckpt-'):]

fig, ax = plt.subplots()
loss = {}
for name in train_model_names:
    # map_location matches the model-checkpoint loads, so GPU-saved loss
    # histories still load on a CPU-only machine.
    loss[name] = torch.load(show_path + name + '_losspath.pt',
                            map_location=torch.device('cpu'))['train loss']
    ax.plot(loss[name], label=name)

ax.legend(loc='best')
ax.set_title('loss curve')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
fig.savefig(fig_path + f'{run_tag}-P{power}loss.png')

In [None]:
## TODO
# Kept as comments (not a bare string literal) so this cell produces no output.
# 1. Visualize the parameters to find out which part actually does the work.
# 2. Make the network wider/deeper, or change the architecture (e.g. Transformer).
# 3. Data augmentation; check generalization.
# 4. Learning-rate schedule: 0.001.
#
# Hyper-parameter sweep:
#   width: 20, 60, 100
#   depth: 2, 3, 4
#   power: [50, 50], [50, 60]