In [None]:
# Length of each received slice fed to the decoder, counted in real values
# (interleaved real/imag pairs, i.e. SLICED_Y_LENGTH // 2 complex samples).
SLICED_Y_LENGTH = 16
BATCH_SIZE = 512

# Number of features F produced by the feature-extractor path (design parameter).
# Experiments reported for this design show that even a small number of
# features (e.g. F = 4) significantly improves performance.
N_FEATURES_EXTRACTED = 8

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class InternalSlicer(nn.Module):
    """Keep a window of complex samples around the middle of each row.

    Inputs use an interleaved real layout: a real tensor of shape
    (batch, 2 * complex_length) whose consecutive pairs are the (real, imag)
    parts of one complex sample. With ``mid = complex_length // 2`` the window
    keeps complex indices ``mid - l1`` .. ``mid + l2`` (inclusive), i.e.
    ``l1 + l2 + 1`` samples, and returns them in the same interleaved layout:
    shape (batch, 2 * (l1 + l2 + 1)).

    Args:
        l1: number of complex samples kept before the centre sample.
        l2: number of complex samples kept after the centre sample.
        complex_length: number of complex samples per input row.

    Raises:
        ValueError: if the window does not fit inside ``complex_length``.
    """

    def __init__(self, l1, l2, complex_length):
        super(InternalSlicer, self).__init__()

        mid = complex_length // 2
        self.start = mid - l1
        self.end = mid + l2 + 1
        # A negative start would silently wrap around in Python slicing, and an
        # end past the signal would return fewer than l1 + l2 + 1 samples.
        if self.start < 0 or self.end > complex_length:
            raise ValueError(
                f"window [{self.start}, {self.end}) does not fit in "
                f"{complex_length} complex samples (l1={l1}, l2={l2})"
            )

    def forward(self, sliced_y):
        """Return the central window of ``sliced_y`` in interleaved real layout."""
        # R2C yields a 2-D (batch, n) complex tensor, so the window is cut on dim 1.
        return self.C2R(self.R2C(sliced_y)[:, self.start:self.end])

    def R2C(self, a):
        """Interleaved real (batch, 2n) -> complex (batch, n) built from float32 parts."""
        # Batch size comes from the input so partial/inference batches work too.
        pairs = a.reshape(a.shape[0], -1, 2).to(torch.float32)
        return torch.complex(pairs[..., 0], pairs[..., 1])

    def C2R(self, a):
        """Complex (batch, n) -> interleaved real (batch, 2n) as (real, imag) pairs."""
        return torch.stack((a.real, a.imag), dim=-1).reshape(a.shape[0], -1)


def phase_multiply(internally_sliced_y, estimated_phase):
    """Multiply each complex sample of ``internally_sliced_y`` by ``estimated_phase``.

    Both tensors use an interleaved real layout: consecutive (real, imag)
    pairs along the last axis.

    Args:
        internally_sliced_y: real tensor of shape (batch, 2 * N) holding N
            complex samples per row.
        estimated_phase: real tensor of shape (batch, 2) holding one complex
            factor per row, which is broadcast over the N samples. A
            (batch, 2 * N) tensor (one factor per sample) also works.

    Returns:
        float32 tensor of shape (batch, 2 * N): the complex products, again
        interleaved as (real, imag).
    """
    batch = internally_sliced_y.shape[0]
    y = internally_sliced_y.reshape(batch, -1, 2).to(torch.float32)
    phase = estimated_phase.reshape(batch, -1, 2).to(torch.float32)

    # Explicit complex product (a+bi)(c+di) = (ac-bd) + (ad+bc)i.
    # An element-wise product of the pairs would give (ac, bd), which is not
    # a phase rotation.
    y_re, y_im = y[..., 0], y[..., 1]
    ph_re, ph_im = phase[..., 0], phase[..., 1]
    out_re = y_re * ph_re - y_im * ph_im
    out_im = y_re * ph_im + y_im * ph_re
    return torch.stack((out_re, out_im), dim=-1).reshape(batch, -1)


In [None]:
class FeatureExtractor(nn.Module):
    """Feature-extractor path: a small MLP producing N_FEATURES_EXTRACTED features.

    Layers: Linear(SLICED_Y_LENGTH -> 256) -> ReLU -> Linear(256 -> N_FEATURES_EXTRACTED).
    The final layer has no activation.
    """

    def __init__(self):
        super().__init__()
        # The cf1/cf2 attribute names define the state_dict keys; keep them stable.
        self.cf1 = nn.Linear(SLICED_Y_LENGTH, 256)
        self.cf2 = nn.Linear(256, N_FEATURES_EXTRACTED)

    def forward(self, sliced_y):
        """Map ``sliced_y`` (last dim SLICED_Y_LENGTH) to (..., N_FEATURES_EXTRACTED)."""
        return self.cf2(F.relu(self.cf1(sliced_y)))


class PhaseEstimator(nn.Module):
    """Phase-estimator path: a small MLP producing two real outputs per row.

    Layers: Linear(SLICED_Y_LENGTH -> 256) -> ReLU -> Linear(256 -> 2).
    SequenceDecoder passes the output to ``phase_multiply``, which reads it as
    one interleaved (real, imag) pair per row.
    """

    def __init__(self):
        super().__init__()
        # Keep the cf1/cf2 attribute names so state_dict keys are unchanged.
        self.cf1 = nn.Linear(SLICED_Y_LENGTH, 256)
        self.cf2 = nn.Linear(256, 2)

    def forward(self, sliced_y):
        """Map ``sliced_y`` (last dim SLICED_Y_LENGTH) to (..., 2)."""
        hidden = self.cf1(sliced_y)
        hidden = F.relu(hidden)
        phase = self.cf2(hidden)
        return phase


class Rx_Decoder(nn.Module):
    """Three-layer MLP mapping the concatenated decoder input to class logits.

    Args:
        in_features: width of the concatenated input vector. Defaults to 256
            (the previously hard-coded value) for backward compatibility.
        n_classes: number of output logits (default 16).
    """

    def __init__(self, in_features=256, n_classes=16):
        super(Rx_Decoder, self).__init__()

        self.cf1 = nn.Linear(in_features, 256)
        self.cf2 = nn.Linear(256, 256)
        self.cf3 = nn.Linear(256, n_classes)

    def forward(self, concat):
        """Return unnormalised logits of shape (..., n_classes)."""
        concat = F.relu(self.cf1(concat))
        concat = F.relu(self.cf2(concat))
        return self.cf3(concat)


class SequenceDecoder(nn.Module):
    """Decode a sliced received sequence into class logits.

    Per row of ``sliced_y``, three paths run in parallel:

    * a feature-extractor MLP (N_FEATURES_EXTRACTED outputs);
    * a phase-estimator MLP (2 outputs);
    * an internal slicer keeping the central l1 + l2 + 1 complex samples.

    The sliced samples are multiplied by the estimated phase via
    ``phase_multiply``. The result is concatenated with the extracted features
    (and, in the recurrent variant, the previous phase state), and the
    concatenation is mapped to logits by ``Rx_Decoder``.

    Args:
        take_prev_phase_state: if True, ``forward`` requires ``prev_phase_state``
            and appends it to the decoder input.
        prev_phase_state_dim: width of ``prev_phase_state`` along dim 1; only
            used when ``take_prev_phase_state`` is True. The default of 2
            assumes it has the phase estimator's output width (TODO: confirm).
    """

    def __init__(self, take_prev_phase_state=False, prev_phase_state_dim=2):
        super(SequenceDecoder, self).__init__()

        self.take_prev_phase_state = take_prev_phase_state

        # Complex samples kept before / after the centre by the internal slicer.
        l1, l2 = 3, 3

        self.feature_extractor = FeatureExtractor()
        self.phase_estimator = PhaseEstimator()
        self.internal_slicer = InternalSlicer(l1=l1, l2=l2, complex_length=SLICED_Y_LENGTH // 2)

        # Decoder input = concatenation built in forward():
        # extracted features + phase-corrected window (l1 + l2 + 1 complex
        # samples, each stored as a real/imag pair) [+ previous phase state].
        decoder_in = N_FEATURES_EXTRACTED + 2 * (l1 + l2 + 1)
        if take_prev_phase_state:
            decoder_in += prev_phase_state_dim
        self.rx_decoder = Rx_Decoder(in_features=decoder_in)

    def forward(self, sliced_y, prev_phase_state=None):
        """Return logits for ``sliced_y`` (last dim SLICED_Y_LENGTH).

        ``prev_phase_state`` is required when ``take_prev_phase_state`` is True
        and ignored otherwise.
        """
        if self.take_prev_phase_state:
            assert prev_phase_state is not None, "RNN need the previous phase state as an input"

        extracted_features = self.feature_extractor(sliced_y)
        estimated_phase = self.phase_estimator(sliced_y)
        internally_sliced_y = self.internal_slicer(sliced_y)

        phase_corrected_ = phase_multiply(internally_sliced_y, estimated_phase)

        # torch.cat rejects None, so the previous phase state is appended only
        # when this decoder was built to consume it.
        parts = [extracted_features, phase_corrected_]
        if self.take_prev_phase_state:
            parts.append(prev_phase_state)
        concat = torch.cat(parts, dim=1)

        st_hat = self.rx_decoder(concat)

        return st_hat


In [None]:
import torch.optim as optim
import numpy as np

# NOTE: X and Y are created by the fake-data cell at the end of this
# notebook. Run that cell first (better: move it above this one) so that
# Restart & Run All works.

# Convert data to torch tensors
X_tensor = torch.tensor(X.numpy(), dtype=torch.float32)
Y_tensor = torch.tensor(Y.numpy(), dtype=torch.int64)

# Seed for reproducible weight initialisation and shuffling.
torch.manual_seed(0)

# Initialize the model
mySD = SequenceDecoder()

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mySD.parameters(), lr=1e-2)

# Training loop
epochs = 2
batch_size = BATCH_SIZE

# Only full batches are used (a trailing partial batch is dropped), so every
# forward pass sees exactly BATCH_SIZE rows, which the model's reshapes may
# assume. This also makes the loss average below divide by the number of
# batches actually processed.
num_batches = len(X_tensor) // batch_size
if num_batches == 0:
    raise ValueError(f"need at least {batch_size} samples, got {len(X_tensor)}")

mySD.train()
for epoch in range(epochs):
    running_loss = 0.0

    # Shuffle the data
    permutation = torch.randperm(len(X_tensor))
    X_tensor_shuffled = X_tensor[permutation]
    Y_tensor_shuffled = Y_tensor[permutation]

    for batch_idx in range(num_batches):
        start = batch_idx * batch_size
        inputs = X_tensor_shuffled[start:start + batch_size]
        # reshape(-1) rather than squeeze(): always yields a 1-D label vector,
        # even for a batch of one.
        labels = Y_tensor_shuffled[start:start + batch_size].reshape(-1)

        optimizer.zero_grad()
        outputs = mySD(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/num_batches:.4f}")


In [None]:
# Generate fake data with PyTorch (the model is PyTorch; no TensorFlow needed).
# NOTE: the training cell above reads X and Y, so this cell must be executed
# before it.
#   X: standard-normal received slices, shape (m, SLICED_Y_LENGTH)
#   Y: uniformly random integer labels in [0, 16), shape (m, 1). They stay as
#      integer indices because the training cell converts Y to int64 class
#      indices for nn.CrossEntropyLoss (no one-hot encoding).
data_rng = torch.Generator().manual_seed(0)  # reproducible fake data
m = 512 * 2**2  # 2048 samples
X = torch.randn(m, SLICED_Y_LENGTH, generator=data_rng)
Y = torch.randint(0, 16, (m, 1), dtype=torch.int32, generator=data_rng)
