# Building, training and testing a neural receiver

### WORK IN PROGRESS

## Neural receiver

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Pre-activation residual block: (norm -> ReLU -> conv) twice, plus an identity skip.

    Inputs are channels-first feature maps ``[batch, num_channels, H, W]``, the two
    trailing axes being the time and frequency axes of the resource grid. Each
    normalization layer normalizes every example jointly over its channel, time and
    frequency dimensions, so the block works for any grid size.

    Args:
        num_channels: Number of convolution channels. The input must have this many
            channels because it is added back onto the output. Defaults to 128.
    """

    def __init__(self, num_channels=128):
        super().__init__()

        # Layer normalization is done over the last three dimensions: conv 'channels',
        # time, frequency. nn.LayerNorm would need the full [C, time, freq] shape at
        # construction time, which is not known here. GroupNorm with a single group
        # normalizes over exactly those dims for any spatial size. Its learnable
        # affine is per-channel rather than per-element.
        self.layer_norm_1 = nn.GroupNorm(num_groups=1, num_channels=num_channels)
        self.conv_1 = nn.Conv2d(in_channels=num_channels, out_channels=num_channels,
                                kernel_size=(3, 3), padding='same')

        self.layer_norm_2 = nn.GroupNorm(num_groups=1, num_channels=num_channels)
        self.conv_2 = nn.Conv2d(in_channels=num_channels, out_channels=num_channels,
                                kernel_size=(3, 3), padding='same')

    def forward(self, inputs):
        """Apply the block.

        Args:
            inputs: Tensor of shape ``[batch, num_channels, H, W]``.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        z = self.layer_norm_1(inputs)
        z = F.relu(z)
        z = self.conv_1(z)
        z = self.layer_norm_2(z)
        z = F.relu(z)
        z = self.conv_2(z)
        # Skip connection (out-of-place, so no tensor is modified in place)
        return z + inputs

class NeuralReceiver(nn.Module):
    """Fully convolutional neural receiver operating on an OFDM resource grid.

    The received samples of every antenna are split into real and imaginary parts
    and stacked as input channels together with a log10 noise-power map. The stack
    then goes through an input convolution, four residual blocks and an output
    convolution. The output has ``num_bits_per_symbol`` values per resource element
    (presumably LLRs -- confirm against the loss/demapper used downstream).

    Args:
        num_bits_per_symbol: Number of output values per resource element.
        num_rx_ant: Number of receive antennas in ``y``. The input convolution
            expects ``2 * num_rx_ant + 1`` channels. Defaults to 1.
    """

    def __init__(self, num_bits_per_symbol, num_rx_ant=1):
        super().__init__()

        # Real and imaginary part for every antenna, plus the single noise-power
        # channel that forward() appends.
        in_channels = 2 * num_rx_ant + 1
        self.input_conv = nn.Conv2d(in_channels=in_channels, out_channels=128,
                                    kernel_size=(3, 3), padding='same')
        self.res_block_1 = ResidualBlock()
        self.res_block_2 = ResidualBlock()
        self.res_block_3 = ResidualBlock()
        self.res_block_4 = ResidualBlock()
        self.output_conv = nn.Conv2d(in_channels=128, out_channels=num_bits_per_symbol,
                                     kernel_size=(3, 3), padding='same')

    def forward(self, y, no):
        """Run the receiver.

        Args:
            y: Complex received grid with a size-1 ``num_rx`` axis at dim 1, i.e.
                ``[batch, 1, num_rx_ant, time, frequency]``.
            no: Noise power (linear scale). It can be a scalar, a ``[batch]``
                tensor, or any tensor broadcastable to ``[batch, time, frequency]``.

        Returns:
            Real tensor of shape ``[batch, time, frequency, num_bits_per_symbol]``.
        """
        # Assuming a single receiver, remove the num_rx dimension
        # -> [batch, num_rx_ant, time, frequency]
        y = y.squeeze(dim=1)
        batch_size, _, num_time, num_freq = y.shape

        # Feeding the noise power in log10 scale helps with the performance
        no = torch.log10(torch.as_tensor(no, dtype=y.real.dtype, device=y.device))
        if no.dim() <= 1:
            # Scalar or per-example noise power: add singleton time/frequency axes
            no = no.reshape(-1, 1, 1)
        # One noise-power plane per example: [batch, 1, time, frequency]
        no_map = no.expand(batch_size, num_time, num_freq).unsqueeze(1)

        # nn.Conv2d is channels-first, so the antenna axis already is the channel
        # axis: stack real parts, imaginary parts and the noise map along dim 1.
        z = torch.cat([y.real, y.imag, no_map], dim=1)
        z = self.input_conv(z)
        z = self.res_block_1(z)
        z = self.res_block_2(z)
        z = self.res_block_3(z)
        z = self.res_block_4(z)
        z = self.output_conv(z)

        # Put the per-bit values last: [batch, time, frequency, num_bits_per_symbol]
        # NOTE(review): a resource-grid demapper may additionally expect singleton
        # num_rx / stream axes (e.g. via unsqueeze) -- confirm against the consumer.
        return z.permute(0, 2, 3, 1)