Commit

Pull request #1: ready to push to public
Merge in CSID/glaucus from feature/ieee to main

Squashed commit of the following:

commit 6bd21a4a162ffa0a51c062d2c2588bf6801b4357
Author: Kyle A Logue <kyle.a.logue@aero.org>
Date:   Wed Mar 1 14:40:49 2023 -0800

    better tests; copyright header; cleanup

commit cfa4ef362c74463ed519d977f63d7a1564e578c0
Author: Kyle A Logue <kal29868@dhcp-10-3-58-89.aero.org>
Date:   Wed Mar 1 11:03:57 2023 -0800

    fix repr

commit 1b052991d89953a65fae2e2275dbba092b56d669
Author: Kyle A Logue <kal29868@dhcp-10-3-58-89.aero.org>
Date:   Wed Mar 1 09:12:09 2023 -0800

    NL support and lightning upgrade

commit f72848db3f0f2e392c31aece3b68bffa18d26cff
Author: Kyle A Logue <kyle.a.logue@aero.org>
Date:   Tue Feb 28 16:11:38 2023 -0800

    add support for NL
Kyle A Logue committed Mar 2, 2023
1 parent 2fa0a9e commit acd2de1
Showing 13 changed files with 284 additions and 98 deletions.
30 changes: 18 additions & 12 deletions README.md
@@ -4,7 +4,12 @@

# Glaucus

The Aerospace Corporation is proud to present our complex-valued encoder, decoder, and loss for RF DSP in PyTorch.
The Aerospace Corporation is proud to present our complex-valued encoder,
decoder, and a new loss function for RF DSP in PyTorch.

## Video (click to play)

[<img src="https://i.vimeocdn.com/video/1583946742-851ad3621192f133ca667bc87f4050276e450fcc721f117bbcd93b67cb0535f8-d_1000">](https://vimeo.com/787670661/ce13da4cd9)

## Using

@@ -18,7 +23,7 @@ The Aerospace Corporation is proud to present our complex-valued encoder, decode
* `coverage run -a --source=glaucus -m pytest --doctest-modules; coverage html`
* `pytest .`

### Use our pre-trained model
### Use pre-trained model with SigMF data

Load quantized model and return compressed signal vector & reconstruction.
Our weights were trained & evaluated on a corpus of 200GB of RF waveforms with
@@ -30,18 +35,17 @@ import sigmf
from glaucus import GlaucusAE

# create model
model = GlaucusAE(bottleneck_quantize=True)
model = GlaucusAE(bottleneck_quantize=True, data_format='nl')
model = torch.quantization.prepare(model)
# get weights for quantized model
state_dict = torch.hub.load_state_dict_from_url('https://pending-torch-hub-submission/ae-quantized.pth')
model.load_state_dict(state_dict)
# prepare for prediction
model.eval()
torch.quantization.convert(model, inplace=True)
# get samples into NCL tensor
# get samples into NL tensor
x_sigmf = sigmf.sigmffile.fromfile('example.sigmf')
x_np = x_sigmf.read_samples()
x_tensor = torch.view_as_real(torch.from_numpy(x_np)).swapaxes(-1, -2).unsqueeze(0)
x_tensor = torch.from_numpy(x_sigmf.read_samples())
# create prediction & quint8 signal vector
y_tensor, y_encoded = model(x_tensor)
# get signal vector as uint8
@@ -55,7 +59,7 @@ import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from glaucus import GlaucusAE
model = GlaucusAE() # or FullyConnectedAE
model = GlaucusAE()
loader = DataModule() # Not provided
early_stopping_callback = EarlyStopping(monitor='val_loss', mode='min', patience=patience)
checkpoint_callback = ModelCheckpoint(monitor='val_loss', filename='glaucus-{epoch:03d}-{val_loss:05f}')
@@ -72,14 +76,10 @@ This code is documented by the following two IEEE publications.

### Glaucus: A Complex-Valued Radio Signal Autoencoder

[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5806615.svg)](https://doi.org/10.5281/zenodo.5806615)

A complex-valued autoencoder neural network capable of compressing \& denoising radio frequency (RF) signals with arbitrary model scaling is proposed. Complex-valued time samples received with various impairments are decoded into an embedding vector, then encoded back into complex-valued time samples. The embedding and the related latent space allow search, comparison, and clustering of signals. Traditional signal processing tasks like specific emitter identification, geolocation, or ambiguity estimation can utilize multiple compressed embeddings simultaneously. This paper demonstrates an autoencoder implementation capable of 64x compression hardened against RF channel impairments. The autoencoder allows separate or compound scaling of network depth, width, and resolution to target both embedded and data center deployment with differing resources. The common building block is inspired by the Fused Inverted Residual Block (Fused-MBConv), popularized by EfficientNetV2 \& MobileNetV3, with kernel sizes more appropriate for time-series signal processing
A complex-valued autoencoder neural network capable of compressing & denoising radio frequency (RF) signals with arbitrary model scaling is proposed. Complex-valued time samples received with various impairments are encoded into an embedding vector, then decoded back into complex-valued time samples. The embedding and the related latent space allow search, comparison, and clustering of signals. Traditional signal processing tasks like specific emitter identification, geolocation, or ambiguity estimation can utilize multiple compressed embeddings simultaneously. This paper demonstrates an autoencoder implementation capable of 64x compression hardened against RF channel impairments. The autoencoder allows separate or compound scaling of network depth, width, and resolution to target both embedded and data center deployment with differing resources. The common building block is inspired by the Fused Inverted Residual Block (Fused-MBConv), popularized by EfficientNetV2 & MobileNetV3, with kernel sizes more appropriate for time-series signal processing.

### Complex-Valued Radio Signal Loss for Neural Networks

[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5806615.svg)](https://doi.org/10.5281/zenodo.5806615)

A new optimized loss for training complex-valued neural networks that require reconstruction of radio signals is proposed. Given a complex-valued time series this method incorporates loss from spectrograms with multiple aspect ratios, cross-correlation loss, and loss from amplitude envelopes in the time \& frequency domains. When training a neural network an optimizer will observe batch loss and backpropagate this value through the network to determine how to update the model parameters. The proposed loss is robust to typical radio impairments and co-channel interference that would explode a naive mean-square-error approach. This robust loss enables higher quality steps along the loss surface which enables training of models specifically designed for impaired radio input. Loss vs channel impairment is shown in comparison to mean-squared error for an ensemble of common channel effects.

## Contributing
@@ -107,3 +107,9 @@ alternative license. An alternative license can allow you to create proprietary
applications around Aerospace products without being required to meet the
obligations of the GPL. To inquire about an alternative license, please get in
touch with us at [oss@aero.org](mailto:oss@aero.org).

## To-Do

* insert DOI links once the papers are assigned DOIs, e.g. [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5806615.svg)](https://doi.org/10.5281/zenodo.5806615)
* update this readme with published model weight path
* upload training notebook
6 changes: 5 additions & 1 deletion glaucus/__init__.py
@@ -1,4 +1,8 @@
__version__ = '1.0.0'
# Copyright 2023 The Aerospace Corporation
# This file is a part of Glaucus
# SPDX-License-Identifier: LGPL-3.0-or-later

__version__ = '1.1.0'

from .rfloss import *
from .layers import *
38 changes: 31 additions & 7 deletions glaucus/autoencoders.py
@@ -1,7 +1,11 @@
# Copyright 2023 The Aerospace Corporation
# This file is a part of Glaucus
# SPDX-License-Identifier: LGPL-3.0-or-later

import logging

import torch
import pytorch_lightning as pl
import lightning as pl
from madgrad import MADGRAD

from .gblocks import GlaucusNet, ENCODER_BLOCKS, DECODER_BLOCKS
@@ -11,12 +15,13 @@

log = logging.getLogger(__name__)


class GlaucusAE(pl.LightningModule):
'''RF Autoencoder constructed with network of GBlocks.'''
def __init__(self, encoder_blocks=ENCODER_BLOCKS, decoder_blocks=DECODER_BLOCKS,
domain:str='time', width_coef:float=1, depth_coef:float=1, spatial_size:int=4096,
bottleneck_in:int=512, bottleneck_latent:int=512, bottleneck_out:int=512,
bottleneck_steps:int=1, bottleneck_quantize:bool=False,
bottleneck_steps:int=1, bottleneck_quantize:bool=False, data_format:str='ncl',
drop_connect_rate:float=0.2, optimizer:str='madgrad', lr:float=1e-3,
) -> None:
'''
@@ -57,20 +62,25 @@ def __init__(self, encoder_blocks=ENCODER_BLOCKS, decoder_blocks=DECODER_BLOCKS,
Currently supports either `madgrad` or `adam` optimizers.
lr : float, default 1e-3
Learning Rate. Experiments from Dec 2021 to Mar 2022 yielded good values in range (1e-3, 1e-2).
data_format : str, default 'ncl'
Network normally consumes and produces complex-valued data represented as real-valued (NCL);
if data is complex-valued (NL), a transform layer is added during encode/decode.
'''
super().__init__()

self.save_hyperparameters()
self.lr = lr
assert domain in ['time', 'freq']
assert data_format in ['ncl', 'nl']
self.domain = domain
self.data_format = data_format

self._rms_norm = RMSNormalize(spatial_size=spatial_size)
self._noise_layer = GaussianNoise(spatial_size=spatial_size)
if self.domain == 'freq':
self._time2freq = TimeDomain2FreqDomain()
self._freq2time = FreqDomain2TimeDomain()
self.loss_function = RFLoss(spatial_size=spatial_size)
self.loss_function = RFLoss(spatial_size=spatial_size, data_format=data_format)
self.encoder = GlaucusNet(
encoder_blocks, mode='encoder', width_coef=width_coef, depth_coef=depth_coef, drop_connect_rate=drop_connect_rate)
self.fc_encoder = FullyConnected(
@@ -92,6 +102,8 @@ def forward(self, x):

def encode(self, x):
'''normalize, add noise if training, and reduce to latent domain'''
if self.data_format == 'nl':
x = torch.view_as_real(x).swapaxes(-1, -2)
x = self._rms_norm(x)
x, _ = self._noise_layer(x)
if self.domain == 'freq':
@@ -108,6 +120,8 @@ def decode(self, x_enc):
if self.domain == 'freq':
# convert back to time domain
x_hat = self._freq2time(x_hat)
if self.data_format == 'nl':
x_hat = torch.view_as_complex(x_hat.swapaxes(-1, -2).contiguous())
return x_hat

def step(self, batch, batch_idx):
@@ -118,15 +132,19 @@ def step(self, batch, batch_idx):

def training_step(self, batch, batch_idx):
loss, _ = self.step(batch, batch_idx)
self.log('train_loss', loss, on_step=True, on_epoch=False)
return loss

def validation_step(self, batch, batch_idx):
loss, metrics = self.step(batch, batch_idx)
self.log_dict({f"val_{k}": v for k, v in metrics.items()})
self.log_dict({f"val_{k}": v for k, v in metrics.items()}, sync_dist=True)
return loss

def test_step(self, batch, batch_idx):
loss, metrics = self.step(batch, batch_idx)
self.log_dict({f"test_{k}": v for k, v in metrics.items()}, sync_dist=True)
return metrics

def configure_optimizers(self):
optimizer = self.optimizer(self.parameters(), lr=self.lr)
return optimizer
@@ -136,12 +154,14 @@ class FullyConnectedAE(pl.LightningModule):
'''RF Autoencoder constructed with fully connected layers.'''
def __init__(self,
spatial_size:int=4096, latent_dim:int=512, lr:float=1e-3, steps:int=3, bottleneck_quantize:bool=False,
domain:str='time', optimizer:str='madgrad') -> None:
domain:str='time', data_format:str='ncl', optimizer:str='madgrad') -> None:
super().__init__()
self.save_hyperparameters()
self.lr = lr
assert domain in ['time', 'freq']
assert data_format in ['ncl', 'nl']
self.domain = domain
self.data_format = data_format

self.latent_dim = latent_dim
self.io_dim = spatial_size * 2
@@ -152,7 +172,7 @@ def __init__(self,
if self.domain == 'freq':
self._time2freq = TimeDomain2FreqDomain()
self._freq2time = FreqDomain2TimeDomain()
self.loss_function = RFLoss(spatial_size=spatial_size)
self.loss_function = RFLoss(spatial_size=spatial_size, data_format=data_format)

optimizer_map = {'adam': torch.optim.Adam, 'madgrad': MADGRAD}
self.optimizer = optimizer_map[optimizer]
@@ -173,6 +193,8 @@ def forward(self, x):

def encode(self, x):
'''normalize, add noise if training, and reduce to latent domain'''
if self.data_format == 'nl':
x = torch.view_as_real(x).swapaxes(-1, -2)
x = self._rms_norm(x)
x, _ = self._noise_layer(x)
if self.domain == 'freq':
@@ -191,6 +213,8 @@ def decode(self, x_enc):
if self.domain == 'freq':
# convert back to time domain
x_hat = self._freq2time(x_hat)
if self.data_format == 'nl':
x_hat = torch.view_as_complex(x_hat.swapaxes(-1, -2).contiguous())
return x_hat

def step(self, batch, batch_idx):
16 changes: 11 additions & 5 deletions glaucus/fcblocks.py
@@ -1,11 +1,16 @@
# Copyright 2023 The Aerospace Corporation
# This file is a part of Glaucus
# SPDX-License-Identifier: LGPL-3.0-or-later

import logging
import numpy as np

import torch
import pytorch_lightning as pl
import lightning as pl

log = logging.getLogger(__name__)


class FullyConnected(pl.LightningModule):
'''Sequential Layer Generator for 1D Fully Connected'''
def __init__(self, size_in:int=512, size_out:int=128, steps:int=3,
@@ -28,12 +33,13 @@ def __init__(self, size_in:int=512, size_out:int=128, steps:int=3,
self.bn_eps = 1e-3 # better than torch default

# deal with optional quantization
if self.quantize_in or self.quantize_out:
# 'fbgemm' is for servers, 'qnnpack' is for mobile
qconfig = torch.quantization.get_default_qconfig('fbgemm')
if self.quantize_in:
self._dequant_in = torch.quantization.DeQuantStub()
# not sure why this has to be set externally, but it does for convert() to work correctly
self._dequant_in.qconfig = torch.quantization.get_default_qconfig()
self._dequant_in = torch.quantization.DeQuantStub(qconfig=qconfig)
if self.quantize_out:
self._quant_out = torch.quantization.QuantStub(qconfig=torch.quantization.get_default_qconfig())
self._quant_out = torch.quantization.QuantStub(qconfig=qconfig)
if use_dropout:
self._dropout = torch.nn.Dropout(0.2)

18 changes: 13 additions & 5 deletions glaucus/gblocks.py
@@ -1,10 +1,14 @@
# Copyright 2023 The Aerospace Corporation
# This file is a part of Glaucus
# SPDX-License-Identifier: LGPL-3.0-or-later

import logging
import math
from collections import namedtuple
import numpy as np

import torch
import pytorch_lightning as pl
import lightning as pl
from torch import nn

from .layers import DropConnect
@@ -93,6 +97,7 @@ def blockgen(
blocks += [block]
return blocks


# defaults
ENCODER_BLOCKS = blockgen(steps=6, spatial_in=4096, spatial_out=8, filters_in=2, filters_out=64, mode='encoder')
DECODER_BLOCKS = blockgen(steps=6, spatial_in=8, spatial_out=4096, filters_in=64, filters_out=2, mode='decoder')
@@ -166,8 +171,9 @@ class GBlock(pl.LightningModule):
*DANGER* even kernel sizes not quite supported; padding nightmares
'''
def __init__(self,
filters_in, filters_out, mode='encoder', stride=1, drop_connect_rate=0.2,
expand_ratio=4, squeeze_ratio=4, kernel_size=7):
filters_in:int, filters_out:int, mode:str='encoder',
stride:int=1, drop_connect_rate:float=0.2,
expand_ratio:int=4, squeeze_ratio:int=4, kernel_size:int=7):
super().__init__()
assert mode in ['encoder', 'decoder']
self.filters_in = filters_in
@@ -252,12 +258,13 @@ def forward(self, x):
x = self._bn0(x)
x = self._activ(x)
# Squeeze-Excitation Phase
if self.squeeze_ratio != 1:
x_squeezed = self._avgpool(x).squeeze()
x_squeezed = self._se_reduce(x_squeezed)
x_squeezed = self._activ(x_squeezed)
x_squeezed = self._se_expand(x_squeezed).unsqueeze(-1)
x *= torch.sigmoid(x_squeezed)
# doing this in-place causes backprop err
x = x * torch.sigmoid(x_squeezed)
# Pointwise Convolution Phase
x = self._conv_tail(x)
x = self._bn1(x)
Expand All @@ -267,6 +274,7 @@ def forward(self, x):
x += identity
return x


class GlaucusNet(nn.Module):
def __init__(self,
blocks=ENCODER_BLOCKS, mode:str='encoder',
13 changes: 9 additions & 4 deletions glaucus/layers.py
@@ -1,7 +1,11 @@
'''custom layers for pytorch'''
# Copyright 2023 The Aerospace Corporation
# This file is a part of Glaucus
# SPDX-License-Identifier: LGPL-3.0-or-later

import typing
import torch
import pytorch_lightning as pl
import lightning as pl


def wrap_ncl(func):
@@ -19,7 +23,7 @@ def wrap(x_ncl, *args, **kwargs):

class RMSNormalizeIQ(pl.LightningModule):
'''
When consuming RF, assure the waveform has uniform scale to regularize input to our architecture.
When consuming RF, ensure the waveform has uniform scale to regularize input to our architecture.
Expects Complex-Valued NL format (batchsize, spatial_size).
Best Method
@@ -110,6 +114,7 @@ def forward(self, x):
x = torch.view_as_real(torch.fft.ifft(torch.fft.fftshift(x, dim=-1))).swapaxes(-1, -2)
return x


class DropConnect(pl.LightningModule):
'''
drop connections between blocks (alternative to dropout)
@@ -133,7 +138,7 @@ class DropConnect(pl.LightningModule):
[1] http://yann.lecun.com/exdb/publis/pdf/wan-icml-13.pdf
[2] https://github.com/tensorflow/tpu/blob/cd433314cc6f38c10a23f1d607a35ba422c8f967/models/official/efficientnet/utils.py#L146
'''
def __init__(self, drop_connect_rate: float = 0.2):
def __init__(self, drop_connect_rate:float=0.2):
super().__init__()
assert 0 <= drop_connect_rate <= 1, 'drop_connect_rate must be in range of [0, 1]'
self.survival_rate = 1 - drop_connect_rate
@@ -166,7 +171,7 @@ class GaussianNoise(torch.nn.Module):
Input should be RMS normalized.
Returns RMS normalized output.
'''
def __init__(self, spatial_size: int = 4096, min_snr_db: float = -3, max_snr_db: float = 20):
def __init__(self, spatial_size:int=4096, min_snr_db:float=-3, max_snr_db:float=20):
super().__init__()
self.min_snr_db = min_snr_db
self.max_snr_db = max_snr_db
