# Configurable NMformer Demo
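This notebook demonstrates training the `ConfigurableNMformer` model to classify digital modulation types. It then evaluates the model with confusion matrices and an accuracy-vs-SNR sweep.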

In [None]:

import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from tqdm import notebook

from dieselwolf.data.DigitalModulations import DigitalModulationDataset
from dieselwolf.data.TransformsRF import (
    Random_Fading,
    RandomCarrierFrequency,
    RandomAWGN,
    AWGN,
    Normalize_Amplitude_Range,
    Fix_Dtype,
)

from dieselwolf.models import ConfigurableNMformer

### Create training and validation datasets

In [None]:

batch_sz = 128

# Channel impairments applied to every generated example (train and val alike):
# fading, carrier-frequency offset, random-level AWGN, then a dtype fix.
train_channel = torchvision.transforms.Compose(
    [
        Random_Fading(0.1, 1.0),
        RandomCarrierFrequency(0.01),
        RandomAWGN(0, 30),
        Fix_Dtype(),
    ]
)
# Post-channel normalisation, applied to both the received ("data") and the
# transmitted ("data_Tx") signal.
train_norm = torchvision.transforms.Compose(
    [
        Normalize_Amplitude_Range(data_keys=["data", "data_Tx"]),
        Fix_Dtype(data_keys=["data", "data_Tx"]),
    ]
)

# Settings shared by the training and validation sets; only the size differs.
dataset_kwargs = dict(
    num_samples=512,
    transform=train_channel,
    normalize_transform=train_norm,
    min_samp=8,
    max_samp=16,
    need_tx=True,
)
# Loader settings shared by both splits; only shuffling differs.
loader_kwargs = dict(
    batch_size=batch_sz,
    num_workers=16,
    pin_memory=True,
    drop_last=False,
)

train_dataset = DigitalModulationDataset(2**12, **dataset_kwargs)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

val_dataset = DigitalModulationDataset(2**8, **dataset_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)


### Train model

In [None]:

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
num_classes = len(train_dataset.classes)
model = ConfigurableNMformer(
    seq_len=512,
    num_classes=num_classes,
    conv_channels=[32, 64],
    kernel_sizes=[3, 3],
    nhead=4,
    num_layers=2,
    num_noise_tokens=2,
)
model.to(device)
model = nn.DataParallel(model)

# Confusion matrices: row = true label, column = predicted label.
# Both are filled only during the final epoch. The training matrix comes from
# the training batches themselves (weights still updating); the validation
# matrix comes from the validation loader.
conf_matrix = np.zeros((num_classes, num_classes))
conf_matrix_2 = np.zeros((len(val_dataset.classes), len(val_dataset.classes)))
epoch_nums = 16
best_accuracy = -1000

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in notebook.tqdm(range(epoch_nums), desc='Epoch'):
    last_epoch = epoch == epoch_nums - 1

    model.train()
    for item in notebook.tqdm(train_loader, desc='Training', leave=False):
        inputs = item['data'].to(device)
        labels = item['label'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        if last_epoch:
            # Predictions are only needed for the confusion matrix, so only
            # compute them in the final epoch.
            predicted = outputs.detach().argmax(dim=1)
            # Vectorised count update; np.add.at accumulates repeated
            # (label, pred) pairs correctly, unlike fancy-index "+=".
            np.add.at(conf_matrix, (labels.cpu().numpy(), predicted.cpu().numpy()), 1)

    correct = 0
    model.eval()
    with torch.no_grad():
        for item in notebook.tqdm(val_loader, desc='Validation', leave=False):
            inputs = item['data'].to(device)
            labels = item['label'].to(device)
            predicted = model(inputs).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            if last_epoch:
                np.add.at(conf_matrix_2, (labels.cpu().numpy(), predicted.cpu().numpy()), 1)
    # Divide by the real number of validation examples: with drop_last=False
    # the final batch may be smaller than batch_sz.
    accuracy = correct / len(val_loader.dataset)
    if accuracy > best_accuracy:
        best_accuracy = accuracy

# Turn counts into per-class (row) fractions in place. Rows for classes that
# never occurred stay 0 instead of becoming NaN from a 0/0 division.
for matrix in (conf_matrix, conf_matrix_2):
    row_sums = matrix.sum(axis=1, keepdims=True)
    np.divide(matrix, row_sums, out=matrix, where=row_sums > 0)


### Confusion matrices

In [None]:

def conf_color(value):
    """Return the annotation text colour for a confusion-matrix cell.

    Cells whose value exceeds 0.7 are drawn dark by the 'Blues' colormap, so
    they get white text; every other cell gets black text.
    """
    if value > 0.7:
        return 'white'
    return 'black'


fig, ax = plt.subplots(1, 2, figsize=(20, 10))
classes = val_dataset.classes
n_classes = len(classes)
tick_positions = np.arange(0, n_classes)

# One (axis, matrix, title prefix) entry per panel; both panels are drawn the
# same way.
panels = (
    (ax[0], conf_matrix, 'Training Confusion Matrix: Avg = '),
    (ax[1], conf_matrix_2, 'Validation Confusion Matrix: Avg = '),
)
for axis, matrix, title_prefix in panels:
    axis.imshow(matrix, cmap='Blues')
    # Write each cell's value as a percentage; the offsets nudge the text
    # towards the cell centre.
    for row in range(n_classes):
        for col in range(n_classes):
            cell = matrix[row, col]
            pct = cell * 100
            axis.text(col - 0.43, row + 0.1, f'{pct:2.1f}%', color=conf_color(cell), fontsize=12)
    axis.set_xticks(tick_positions)
    axis.set_xticklabels(classes, fontsize=12)
    axis.set_yticks(tick_positions)
    axis.set_yticklabels(classes, fontsize=12)
    axis.set_ylabel('Actual', fontsize=14)
    axis.set_xlabel('Predicted', fontsize=14)
    # Mean of the diagonal; for row-normalised matrices this is the mean
    # per-class accuracy.
    axis.set_title(title_prefix + f'{100 * matrix.diagonal().mean():2.2f}%', fontsize=18)

plt.tight_layout()
plt.show()


### Accuracy vs SNR

In [None]:

# Evaluate on fresh datasets at each fixed SNR (30 dB down to -30 dB in 1 dB
# steps) and record a row-normalised confusion matrix per SNR.
snrs = np.arange(-30, 31, 1)[::-1]
results = []
model.eval()
for sss in notebook.tqdm(snrs):
    channel = torchvision.transforms.Compose([
        Random_Fading(0.1, 1.0),
        RandomCarrierFrequency(0.01),
        AWGN(sss),
        Fix_Dtype(),
    ])
    norm = torchvision.transforms.Compose([
        Normalize_Amplitude_Range(data_keys=['data', 'data_Tx']),
        Fix_Dtype(data_keys=['data', 'data_Tx']),
    ])
    dataset = DigitalModulationDataset(
        2**12,
        num_samples=512,
        min_samp=8,
        max_samp=8,
        transform=channel,
        normalize_transform=norm,
        need_tx=True,
    )
    # Evaluation only, so the iteration order is irrelevant: no shuffling.
    loader = DataLoader(dataset, batch_size=batch_sz, shuffle=False, num_workers=16, pin_memory=True, drop_last=False)
    # Use a distinct name: reusing `conf_matrix` here would overwrite the
    # training confusion matrix produced by the training cell.
    snr_conf = np.zeros((len(dataset.classes), len(dataset.classes)))
    with torch.no_grad():
        for item in loader:
            rx_inputs = item['data'].to(device)
            labels = item['label'].cpu().numpy()
            predicted = model(rx_inputs).argmax(dim=1).cpu().numpy()
            # Vectorised count update (row = true label, column = prediction).
            np.add.at(snr_conf, (labels, predicted), 1)
    # Row-normalise in place; classes with no samples stay 0 instead of NaN.
    row_sums = snr_conf.sum(axis=1, keepdims=True)
    np.divide(snr_conf, row_sums, out=snr_conf, where=row_sums > 0)
    results.append(snr_conf)

results = np.array(results)
# Mean of each matrix's diagonal -> shape (len(snrs),).
acc_snr = results.diagonal(axis1=1, axis2=2).mean(axis=1)

fig, ax = plt.subplots(figsize=(10, 9))
ax.plot(snrs, acc_snr, linewidth=3)
# NOTE(review): the upper limit 48 is well beyond the largest SNR (30 dB);
# confirm the extra margin is intended.
ax.set_xlim(-34, 48)
ax.grid()
ax.set_title('Modulation Classification Accuracy', fontsize=20)
ax.set_ylabel('Accuracy', fontsize=18)
ax.set_xlabel('SNR (dB)', fontsize=18)
ax.tick_params(axis='both', which='major', labelsize=14)
plt.tight_layout()
plt.show()
