In [1]:
import torch
from src.TorchDSP.dataloader import  get_signals, opticDataset
from src.TorchDSP.nneq import eqBiLSTM, eqMLP, eqCNNBiLSTM, eqID
# Dataset file (.pkl), relative to the notebook's working directory; passed as `path=`
# to opticDataset when the train/test datasets are built.
# NOTE(review): if opticDataset unpickles this file, only point it at trusted data — confirm in src.TorchDSP.dataloader.
path = 'data/lab/test_2_0_4_1.pkl'

In [13]:

Memory = 101
BATCH_SIZE = 5000

# Arguments shared by both splits; only the sample index range (idx) differs,
# and the test range starts where the training range ends.
dataset_kwargs = dict(Nch=7, Rs=36, M=Memory, path=path, power_fix=False)
trainset = opticDataset(idx=(100000, 1100000), **dataset_kwargs)
testset = opticDataset(idx=(1100000, 1200000), **dataset_kwargs)

# Loaders: only the training split is shuffled; each batch is consumed downstream as (x, y, z).
from torch.utils.data import DataLoader
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

In [10]:
net = eqMLP(M=Memory)
# Alternative model (disabled); note it uses M=201 rather than Memory:
# net = eqCNNBiLSTM(M=201, channels=10, kernel_size=11, hidden_size=20, res_net=True)

# Count only parameters the optimizer will update.
n_trainable = sum(param.numel() for param in net.parameters() if param.requires_grad)
print('Number of net parameters: ', n_trainable)

Number of net parameters:  160920


In [11]:
import torch, numpy as np
import scipy.constants as const, scipy.special as special
from src.TorchSimulation.receiver import  BER

def MSE(predict, truth):
    """Mean squared magnitude of the error ``predict - truth``.

    ``torch.abs`` is applied before squaring, so complex tensors give the mean of
    ``|error|**2``. ``predict`` may be a plain scalar: ``MSE(0, y)`` is the mean
    power of ``y``. Shapes noted by the author: predict, truth: [B, Nmodes].

    Returns:
        0-dim tensor.
    """
    error = predict - truth
    return torch.abs(error).pow(2).mean()

def SNR(predict, truth):
    """Signal-to-noise ratio in dB.

    Signal power is the mean of ``|truth|**2``; noise power is the mean of
    ``|predict - truth|**2``. Returns ``10 * log10(signal / noise)`` as a 0-dim tensor.
    """
    signal_power = torch.mean(torch.abs(truth) ** 2)
    noise_power = torch.mean(torch.abs(predict - truth) ** 2)
    return 10 * torch.log10(signal_power / noise_power)

def Qsq(ber):
    """Convert a bit error rate to a Q-factor expressed in dB.

    Computes ``Q = sqrt(2) * erfcinv(2 * ber)``, clips negative values to 0, and
    returns ``20 * log10(Q)``. Works element-wise on arrays. A BER in [0.5, 1]
    clips Q to 0, so the result is ``-inf`` dB (numpy emits a divide warning).
    """
    q_linear = np.sqrt(2) * np.maximum(special.erfcinv(2 * ber), 0.)
    return 20 * np.log10(q_linear)

def test_model(dataloader, model, loss_fn, device):
    """Evaluate ``model`` on every batch of ``dataloader`` without gradients.

    Each batch is unpacked as ``(x, y, z)``. ``x`` is fed to the model, ``y`` is the
    target, and ``z`` is unused here. Per-batch statistics are weighted by batch
    size (dim 0 of ``y``), so a smaller final batch does not skew the averages.

    Args:
        dataloader: iterable of ``(x, y, z)`` tensor batches with the batch on dim 0.
        model: module to evaluate; it is moved to ``device`` and left in eval mode.
        loss_fn: callable ``loss_fn(predict, truth)`` returning a scalar tensor.
        device: device (or device string) that batches and the model are moved to.

    Returns:
        dict with keys:
          'loss_fn': sample-weighted mean of ``loss_fn``,
          'BER':     mean of 'BER_XY',
          'SNR':     10*log10(mean power of y / 'loss_fn'), in dB,
          'Qsq':     Q-factor (dB) of 'BER',
          'BER_XY':  sample-weighted mean of ``BER(y, y_pred)['BER']``,
          'Qsq_XY':  Q-factor (dB) of each entry of 'BER_XY'.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model = model.to(device)
    model.eval()
    loss_sum = 0.0
    power_sum = 0.0
    ber_sum = 0.0
    n_samples = 0

    with torch.no_grad():
        for x, y, _ in dataloader:
            x, y = x.to(device), y.to(device)
            batch = y.shape[0]
            y_pred = model(x)
            loss_sum += loss_fn(y_pred, y).item() * batch
            power_sum += MSE(0, y).item() * batch  # mean |y|^2 of this batch
            # NOTE(review): assumes BER(...)['BER'] is a per-batch error *rate*;
            # weighting by batch size then yields the rate over all samples.
            ber_sum = ber_sum + BER(y, y_pred)['BER'] * batch
            n_samples += batch

    if n_samples == 0:
        raise ValueError('test_model: dataloader yielded no samples')

    loss = loss_sum / n_samples
    power = power_sum / n_samples
    ber_xy = ber_sum / n_samples
    ber_mean = np.mean(ber_xy)
    return {'loss_fn': loss, 'BER': ber_mean, 'SNR': 10 * np.log10(power / loss),
            'Qsq': Qsq(ber_mean), 'BER_XY': ber_xy, 'Qsq_XY': Qsq(ber_xy)}

In [14]:
# Baseline metrics on the test split using eqID (presumably a pass-through equalizer — confirm in src.TorchDSP.nneq).
# Fall back to CPU so the cell also runs on machines without a GPU.
eval_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
test_model(test_loader, eqID(M=Memory), loss_fn=MSE, device=eval_device)

{'loss_fn': 0.03156240377575159,
 'BER': 0.005119719,
 'SNR': 15.002040793148376,
 'Qsq': 8.190671790013035,
 'BER_XY': array([0.00365347, 0.00658597], dtype=float32),
 'Qsq_XY': array([8.570866 , 7.8858333], dtype=float32)}

In [15]:
import torch 



def train_model(net, train_loader, test_loader, optimizer, device, epochs, eval_every=1):
    """Train ``net`` with the MSE loss and periodically evaluate it with ``test_model``.

    Args:
        net: module to train; moved to ``device``.
        train_loader: iterable of ``(x, y, z)`` batches; only ``x`` and ``y`` are used.
        test_loader: loader passed to ``test_model`` for evaluation.
        optimizer: optimizer over ``net.parameters()``. If ``None``,
            ``torch.optim.Adam(net.parameters(), lr=1e-2)`` is created.
        device: target device. If ``None``, ``'cuda:0'`` when CUDA is available, else ``'cpu'``.
        epochs: number of passes over ``train_loader``.
        eval_every: run ``test_model`` on epochs where ``epoch % eval_every == 0``.

    Returns:
        list of ``(epoch, metrics)`` tuples, one per evaluation, where ``metrics``
        is the dict returned by ``test_model``.
    """
    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    net.to(device)
    if optimizer is None:
        optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)

    history = []
    for epoch in range(epochs):
        net.train()  # test_model switches the net to eval mode; restore training mode each epoch
        loss_sum, n_samples = 0.0, 0
        for x, y, _ in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = MSE(net(x), y)
            loss.backward()
            optimizer.step()
            # weight by batch size so a short final batch does not skew the epoch loss
            loss_sum += loss.item() * y.shape[0]
            n_samples += y.shape[0]
        train_loss = loss_sum / n_samples if n_samples else float('nan')
        print('Epoch: %d, Loss: %.5f' % (epoch, train_loss))
        if epoch % eval_every == 0:
            metric = test_model(test_loader, net, loss_fn=MSE, device=device)
            print(metric)
            history.append((epoch, metric))
    return history

Epoch: 0, Loss: 0.03149
{'loss_fn': 0.030497416295111178, 'BER': 0.005237934, 'SNR': 15.15111102754222, 'Qsq': 8.16382496005648, 'BER_XY': array([0.00379668, 0.00667918], dtype=float32), 'Qsq_XY': array([8.529039 , 7.8682384], dtype=float32)}
Epoch: 1, Loss: 0.03021
{'loss_fn': 0.030299969669431447, 'BER': 0.005131224, 'SNR': 15.17931956977259, 'Qsq': 8.18803854048072, 'BER_XY': array([0.00361122, 0.00665122], dtype=float32), 'Qsq_XY': array([8.583445, 7.873497], dtype=float32)}
Epoch: 2, Loss: 0.03010
{'loss_fn': 0.030168524663895368, 'BER': 0.0049586985, 'SNR': 15.198200784237644, 'Qsq': 8.228003519192175, 'BER_XY': array([0.00353127, 0.00638612], dtype=float32), 'Qsq_XY': array([8.607567, 7.924187], dtype=float32)}
Epoch: 3, Loss: 0.02992
{'loss_fn': 0.0300045314244926, 'BER': 0.00492625, 'SNR': 15.221873018924384, 'Qsq': 8.235637268516172, 'BER_XY': array([0.00343117, 0.00642133], dtype=float32), 'Qsq_XY': array([8.638371, 7.917367], dtype=float32)}
Epoch: 4, Loss: 0.02965
{'loss_f

KeyboardInterrupt: 

In [15]:
# NOTE(review): scratch cell. It rebinds the global `device` to 'cpu', and as visible here the loop
# only moves each training batch to the CPU. Unless more code follows inside the loop, it walks
# the entire train_loader without using the results and leaves x, y bound to the last batch.
device = 'cpu'
for x,y,z in train_loader:
    x, y = x.to(device), y.to(device)