In [11]:
import pickle 
# NOTE(review): pickle.load can execute arbitrary code — only load these files if they come
# from a trusted source.
# Previously both payloads were bound to `a`, so the t0.pkl contents were read and then
# immediately discarded. Keep them under separate names; `a` still ends up holding the
# train-data payload so downstream cells are unaffected.
with open('data/t0.pkl', 'rb') as file:
    t0_data = pickle.load(file)

with open('data/train_data_afterCDCDSP.pkl', 'rb') as file:
    a = pickle.load(file)  # contents of train_data_afterCDCDSP.pkl

In [52]:
import torch, torch.nn as nn
from src.TorchDSP.layers import ComplexConv1d
from src.TorchDSP.core import TorchSignal, TorchTime
from typing import Union

# Shapes for the synthetic smoke test: batch size, signal length, number of modes (Nmodes).
B, L, M = 5, 1000, 1


class pbcBlock(nn.Module):
    """Learned nonlinear block: an XPM-style phase rotation plus an FWM-style additive term.

    For an input field ``x`` of shape [B, M, L] the output is

        E[m] = x_c[m] * exp(1j * phi[m]) + sum_h An[m, h] * sum_m' Am[m', h] * conj(Ak[m', h])

    where ``phi`` is a real ``nn.Conv1d`` over ``|x|**2``, and ``Am``, ``An``, ``Ak`` are
    complex convolutions of each mode with ``fwm_heads`` output channels. ``x_c`` is ``x``
    cropped by ``xpm_size // 2`` samples on each side so it aligns with the valid-convolution
    outputs. The output length is therefore ``L - xpm_size + 1``.

    Args:
        Nmodes: number of modes M (in/out channels of the XPM conv). Must equal
            ``signal.val.shape[1]``.
        xpm_size: kernel size shared by the XPM and FWM convolutions. Should be odd: for an
            even size the cropped input is one sample shorter than ``phi`` and ``forward``
            fails on broadcasting.
        fwm_heads: number of output channels of each FWM convolution.
    """

    def __init__(self, Nmodes=1, xpm_size=101, fwm_heads=16):
        super().__init__()
        self.Nmodes = Nmodes
        self.xpm_size = xpm_size
        self.fwm_size = xpm_size
        self.fwm_heads = fwm_heads
        # NOTE(review): not read inside this class; presumably consumed by callers — confirm.
        self.overlaps = self.xpm_size
        self.xpm_conv = nn.Conv1d(self.Nmodes, self.Nmodes, self.xpm_size)  # real convolution
        self.fwm_conv_m = ComplexConv1d(1, self.fwm_heads, self.fwm_size)  # complex convolution
        self.fwm_conv_n = ComplexConv1d(1, self.fwm_heads, self.fwm_size)  # complex convolution
        self.fwm_conv_k = ComplexConv1d(1, self.fwm_heads, self.fwm_size)  # complex convolution

    def forward(self, signal: TorchSignal, task_info: Union[torch.Tensor, None] = None) -> TorchSignal:
        """Apply the block to ``signal``.

        Args:
            signal: input whose ``val`` is a complex tensor laid out as [B, M, L]
                (the layout this cell uses — confirm against TorchSignal conventions).
            task_info: optional tensor whose column 0 is converted from dB via
                ``10**(v/10)`` into a power ``P``. ``P`` is currently not applied (see TODO).

        Returns:
            TorchSignal with ``val`` of shape [B, M, L - xpm_size + 1] and a time window
            shrunk by ``xpm_size // 2`` on each side.
        """
        P = torch.tensor(1) if task_info is None else 10 ** (task_info[:, 0] / 10) / signal.val.shape[-1]  # [batch] or ()
        P = P.to(signal.val.device)
        # TODO(review): P is computed but never used in E below. Also, shape[-1] is L under the
        # [B, M, L] layout; if a per-mode power was intended this should be the mode axis — confirm.

        x = signal.val  # [B, M, L]
        crop = self.xpm_size // 2  # samples removed from each end by the valid convolutions

        phi = self.xpm_conv(torch.abs(x) ** 2)  # [B, M, L - xpm_size + 1]

        # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, are accepted.
        x_ = x.reshape(-1, x.shape[-1]).unsqueeze(1)  # [B*M, 1, L]
        Am = self.fwm_conv_m(x_).reshape(x.shape[0], x.shape[1], self.fwm_heads, -1)  # [B, M, heads, L - fwm_size + 1]
        An = self.fwm_conv_n(x_).reshape(x.shape[0], x.shape[1], self.fwm_heads, -1)  # [B, M, heads, L - fwm_size + 1]
        Ak = self.fwm_conv_k(x_).reshape(x.shape[0], x.shape[1], self.fwm_heads, -1)  # [B, M, heads, L - fwm_size + 1]
        S = torch.sum(Am * Ak.conj(), dim=1)  # [B, heads, L - fwm_size + 1]

        # Explicit stop index: `crop:-crop` would be an empty slice when crop == 0 (xpm_size == 1).
        x_center = x[:, :, crop:x.shape[-1] - crop]  # [B, M, L - 2*crop]
        E = x_center * torch.exp(1j * phi) + torch.sum(An * S.unsqueeze(1), dim=2)  # [B, M, L - xpm_size + 1]
        return TorchSignal(val=E, t=TorchTime(signal.t.start + crop, signal.t.stop - crop, signal.t.sps))


# Smoke test on random complex input. Seeded so the input and the layer's initial weights
# are identical on every Restart & Run All.
torch.manual_seed(0)
x = torch.rand(B, M, L) + 1j * torch.rand(B, M, L)  # complex, [B, M, L]
E = TorchSignal(val=x)

pbc = pbcBlock(Nmodes=M)
out = pbc(E)
# The valid convolutions trim xpm_size - 1 samples in total.
assert out.val.shape == (B, M, L - pbc.xpm_size + 1)
out

TorchSignal(val: tensor with torch.Size([5, 1, 900]), cpu, t:TorchTime(start=50, stop=-50, sps=2))

In [53]:
x.size()  # torch.Size of the test input, laid out as [B, M, L]

torch.Size([5, 1, 1000])

In [55]:
x.transpose(1, 2).size()  # swap the mode and time axes: [B, M, L] -> [B, L, M]

torch.Size([5, 1000, 1])