# Model training

In progress...

In [7]:
from config import *
import torch
import torch.nn as nn
import torch.optim as optim
from OFDM_SDR_Functions_torch import *
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as tFunc # usually F, but that is reserved for other use

Create the dataset (defined in config file), and load data from disk.

In [8]:
dataset = CustomDataset()
dataset = torch.load('data/ofdm_dataset.pth')

Create a torch model for the receiver. The structure follows the DeepRX (https://arxiv.org/abs/2005.01494) structure, but is simplified and lighter, probably resulting in worse performance on higher modulation orders.

NOTE THAT THIS IS NOT YET FUNCTIONAL.

In [9]:
class ResidualBlock(nn.Module):
    def __init__(self):
        super(ResidualBlock, self).__init__()

        # Layer normalization is done over the last three dimensions: time, frequency, conv 'channels'
        self.layer_norm_1 = nn.LayerNorm(normalized_shape=(128, 14, 72))
        self.conv_1 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=(1, 1))
        #torch.nn.init.xavier_uniform(self.conv_1.weight)

        # Layer normalization is done over the last three dimensions: time, frequency, conv 'channels'
        self.layer_norm_2 = nn.LayerNorm(normalized_shape=(128, 14, 72))
        self.conv_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, inputs):
        z = self.layer_norm_1(inputs)
        z = tFunc.relu(z)
        z = self.conv_1(z)
        z = self.layer_norm_2(z)
        z = tFunc.relu(z)
        z = self.conv_2(z)
        
        # Skip connection
        z = z + inputs

        return z

In [10]:
###################### Simple DeepRX-type Neural Network Receiver, PyTorch model #############################

class RXModel(nn.Module):

    def __init__(self, num_bits_per_symbol):
        super(RXModel, self).__init__()

        # Input convolution
        self.input_conv = nn.Conv2d(in_channels=4, out_channels=128, kernel_size=(3, 3), padding=(1, 1), bias=False)

        # Residual blocks
        self.res_block_1 = ResidualBlock()
        self.res_block_2 = ResidualBlock()
        self.res_block_3 = ResidualBlock()
        self.res_block_4 = ResidualBlock()

        # Output conv
        self.output_conv = nn.Conv2d(in_channels=128, out_channels=6, kernel_size=(3, 3), padding=(1, 1), bias=False)

    def forward(self, inputs):
        y, pilots = inputs
   
        # Stack the tensors along a new dimension (axis 0)
        z = torch.stack([y.real, y.imag, pilots.real, pilots.imag], dim=0)
        z = z.permute(1, 0, 2, 3)
        # Input conv
        print(1, z.shape)
        z = self.input_conv(z)
        print(2, z.shape)
        # Residual blocks
        z = self.res_block_1(z)
        print(3, z.shape)
        z = self.res_block_2(z)
        print(4, z.shape)
        z = self.res_block_3(z)
        print(5, z.shape)
        z = self.res_block_4(z)
        print(6, z.shape)
        # Output conv
        z = self.output_conv(z)
        print(7, z.shape)
        # Reshape the input to fit what the resource grid demapper is expected
        z = z.permute(0,2, 3, 1)
        print(8, z.shape)
        return z
    

model = RXModel(6)

# Print the model architecture
print(model)






RXModel(
  (input_conv): Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (res_block_1): ResidualBlock(
    (layer_norm_1): LayerNorm((128, 14, 72), eps=1e-05, elementwise_affine=True)
    (conv_1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (layer_norm_2): LayerNorm((128, 14, 72), eps=1e-05, elementwise_affine=True)
    (conv_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (res_block_2): ResidualBlock(
    (layer_norm_1): LayerNorm((128, 14, 72), eps=1e-05, elementwise_affine=True)
    (conv_1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (layer_norm_2): LayerNorm((128, 14, 72), eps=1e-05, elementwise_affine=True)
    (conv_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (res_block_3): ResidualBlock(
    (layer_norm_1): LayerNorm((128, 14, 72), eps=1e-05, elementwise_affine=True)
    (conv_1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1

Load the torch dataloader, set the optimizer and loss functions.

In [11]:
data_loader = DataLoader(dataset, batch_size=5, shuffle=True)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCELoss()

In [12]:
#make a training loop
num_epochs = 100

# Training loop
for epoch in range(num_epochs):
    total_loss = 0.0

    for pdsch_iq, pilot_iq, labels in data_loader:

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model((pdsch_iq, pilot_iq))
        
        # Compute the loss
        print(outputs, labels)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()

        # Update the weights
        optimizer.step()

        # Accumulate the total loss
        total_loss += loss.item()

    # Print average loss for the epoch
    average_loss = total_loss / len(data_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {average_loss:.4f}")

# Save the trained model if needed
torch.save(model.state_dict(), 'data/rx_model.pth')


1 torch.Size([5, 4, 14, 72])
2 torch.Size([5, 128, 14, 72])
3 torch.Size([5, 128, 14, 72])
4 torch.Size([5, 128, 14, 72])
5 torch.Size([5, 128, 14, 72])
6 torch.Size([5, 128, 14, 72])
7 torch.Size([5, 6, 14, 72])
8 torch.Size([5, 14, 72, 6])


RuntimeError: all elements of input should be between 0 and 1