In [3]:
import h5py
import torch, numpy as np, matplotlib.pyplot as plt
from src.TorchSimulation.receiver import BER
from src.TorchDSP.loss import Qsq
from src.TorchSimulation.utils import show_symb
from src.TorchDSP.dataloader import MyDataset
from torch.utils.data import DataLoader
from src.JaxSimulation.dsp import BPS, bps, ddpll, cpr, mimoaf, MetaMIMO
import src.JaxSimulation.adaptive_filter as af, jax
from src.JaxSimulation.core import MySignal, SigTime
from src.JaxSimulation.MetaOptimizer import *

import os 
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]='false'

def get_grp(f, Nch, Rs, Pch, Nsymb, NF, SF, L=2000, tag=',method=frequency cut'):
    for key in f.keys():
        if f[key].attrs['Nch'] == Nch and f[key].attrs['Rs(GHz)'] == Rs and f[key].attrs['Pch(dBm)'] == Pch and f[key]['SymbTx'].shape[1] == Nsymb and f[key].attrs['NF(dB)'] == NF and f[key].attrs['freqspace(Hz)']/1e9 / f[key].attrs['Rs(GHz)'] == SF and f[key].attrs['distance(km)'] == L:
            return f[key][f'Rx(sps=2,chid=0{tag})']
        
def get_signal(f, Nch, Rs, Pch, Nsymb, NF, SF, L=2000):
    for key in f.keys():
        if f[key].attrs['Nch'] == Nch and f[key].attrs['Rs(GHz)'] == Rs and f[key].attrs['Pch(dBm)'] == Pch and f[key]['SymbTx'].shape[1] == Nsymb and f[key].attrs['NF(dB)'] == NF and f[key].attrs['freqspace(Hz)']/1e9 / f[key].attrs['Rs(GHz)'] == SF and f[key].attrs['distance(km)'] == L:
            return f[key]
        

def Q_path(Rx, Tx, Ntest=10000, stride=10000):
    Q = []
    for t in  np.arange(0, Rx.shape[-2] - Ntest, stride):
        Q.append(np.mean(BER(torch.tensor(Rx[t:t+Ntest]), torch.tensor(Tx[t:t+Ntest]))['Qsq']))
    return Q


train_data = MyDataset('dataset_A800/train.h5', Nch=[21], Rs=[40], Pch=[0, 1, 2], Nmodes=2,
                        window_size=400, strides=400-15, Nwindow=200, truncate=0,
                        Tx_window=True, pre_transform='Rx_DBP16')
train_loader = DataLoader(train_data, batch_size=20, shuffle=True)


test_data = MyDataset('dataset_A800/test.h5', Nch=[21], Rs=[40], Pch=[0], Nmodes=2,
                        window_size=200000, strides=1, Nwindow=1, truncate=0,
                        Tx_window=True, pre_transform='Rx_DBP16')
test_loader = DataLoader(test_data, batch_size=1, shuffle=False)

for Rx, Tx, info in test_loader:
    print(Rx.shape, Tx.shape, info.shape)
    break

const = np.unique(Tx)

signal = MySignal(val=Rx[0].numpy(), t=SigTime(0,0,2), Fs=0)
truth = MySignal(val=Tx[0].numpy(), t=SigTime(0,0,1), Fs=0)

torch.Size([1, 400000, 2]) torch.Size([1, 200000, 2]) torch.Size([1, 4])


In [4]:
model = MetaMIMO(taps=32, train=True, MetaOpt=MetaGRUOpt(hidden_dim=2))
z, params = model.init_with_output(jax.random.PRNGKey(0), signal, truth, True)

from src.JaxSimulation.dsp import  construct_update
update_step = construct_update(model, optax.adam(1e-4), device='cpu', loss_type='MSE')
opt_state = optax.adam(1e-3).init(params['params'])
state_init = params['state']
state = params['state']
param = params['params']

In [5]:
Ls = []

for epoch in range(10):
    N = len(train_loader)
    ls = []
    for i,(Rx, Tx, info) in enumerate(train_loader):
        sig_input = MySignal(val=Rx[0].numpy(), t=SigTime(0,0,2), Fs=0)
        sig_output = MySignal(val=Tx[0].numpy(), t=SigTime(0,0,1), Fs=0)
        param, state, opt_state,l = update_step(param, state, opt_state, sig_input, sig_output)
        ls.append(l)
        print(f'Batch {i}/{N} loss:',l)
    Ls = Ls + ls
    print(f'epoch {epoch} train loss: {np.mean(ls)}')

Batch 0/10 loss: 0.0279416
Batch 1/10 loss: 0.0353071
Batch 2/10 loss: 0.029496353
Batch 3/10 loss: 0.029809441
Batch 4/10 loss: 0.03249835
Batch 5/10 loss: 0.03148366
Batch 6/10 loss: 0.032804873
Batch 7/10 loss: 0.03562747
Batch 8/10 loss: 0.035662964
Batch 9/10 loss: 0.029397689
epoch 0 train loss: 0.032002951949834824
Batch 0/10 loss: 0.033270903
Batch 1/10 loss: 0.03084565
Batch 2/10 loss: 0.03142063
Batch 3/10 loss: 0.030986998
Batch 4/10 loss: 0.028030064
Batch 5/10 loss: 0.029269632
Batch 6/10 loss: 0.034228045
Batch 7/10 loss: 0.030301597
Batch 8/10 loss: 0.029152734
Batch 9/10 loss: 0.032346282
epoch 1 train loss: 0.03098525106906891
Batch 0/10 loss: 0.030316519
Batch 1/10 loss: 0.033585325
Batch 2/10 loss: 0.029688066
Batch 3/10 loss: 0.029683867
Batch 4/10 loss: 0.035278816
Batch 5/10 loss: 0.04561062
Batch 6/10 loss: 0.031227829
Batch 7/10 loss: 0.028927732
Batch 8/10 loss: 0.039922707
Batch 9/10 loss: 0.029442677
epoch 2 train loss: 0.033368416130542755
Batch 0/10 loss: 0

In [6]:
model_test = MetaMIMO(taps=32, train=False, MetaOpt=MetaGRUOpt())

from functools import partial
@partial(jax.jit, backend='cpu')
def apply_model(var, signal, truth):
    return model_test.apply(var, signal, truth, True, mutable='state')

In [7]:
z, _ = apply_model({'params':param, 'state':state_init}, signal, truth)

ScopeParamShapeError: Initializer expected to generate shape (2, 2) but got shape (2, 16) instead for parameter "kernel" in "/scan(MetaOpt)/linear_in/layers_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)