In [1]:
# # Example 06 - Modulation Classifier
# This notebook walks through a simple example of how to generate a small six-class modulation dataset with `torchsig`, instantiate a supported model, and evaluate the trained network's performance. Note that the experiment and its results should not be read as significant in themselves; they serve simply as a practical example of how the `torchsig` dataset and tools can be used and integrated within a typical [PyTorch](https://pytorch.org/) and/or [PyTorch Lightning](https://www.pytorchlightning.ai/) workflow.

# ----
# ### Import Libraries
# First, import all the necessary public libraries as well as a few classes from the `torchsig` toolkit. An additional import from the `cm_plotter.py` helper script is also done here to retrieve a function to streamline plotting of confusion matrices.

from torchsig.transforms.target_transforms import DescToClassIndex
from torchsig.models.iq_models.efficientnet.efficientnet import efficientnet_b4
from torchsig.transforms.transforms import (
    RandomPhaseShift,
    Normalize,
    ComplexTo2D,
    Compose,
)
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import LightningModule, Trainer
from sklearn.metrics import classification_report
from torchsig.utils.cm_plotter import plot_confusion_matrix
from torchsig.datasets.modulations import ModulationsDataset
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from torchsig.datasets import conf
from torch import optim
from tqdm import tqdm
import torch.nn.functional as F
import numpy as np
import torch
import os

from torch.utils.data import DataLoader
from utils import *
import pandas as pd
from load_datasets import load_sig 
from torchsig.utils.dataset import SignalFileDataset
from torchsig.datasets.modulations import ModulationsDataset
import torchsig.transforms as ST
from torchvision import transforms
import lightning.pytorch as pl
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.strategies.ddp import DDPStrategy
from torchsummary import summary
from scipy import interpolate
from scipy import signal as sp


EPOCHS = 4

# ----
# ### Instantiate Modulation Dataset
# Experiment configuration: the modulation classes to generate and the number
# of examples per class for each split.
classes = ["4ask", "8pam", "16psk", "32qam_cross", "2fsk", "ofdm-256"]
num_classes = len(classes)
training_samples_per_class = 4000
valid_samples_per_class = 1000
test_samples_per_class = 1000
num_workers = 15  # not passed to any DataLoader below; kept for later cells that may read it

# Values that were previously repeated inline as magic numbers.
NUM_IQ_SAMPLES = 4096      # IQ samples per generated example
TRAIN_BATCH_SIZE = 16
VAL_BATCH_SIZE = 16
TEST_BATCH_SIZE = 64
SEED = 1234567891

# Datasets and loaders are built with float32 as torch's default floating dtype.
torch.set_default_dtype(torch.float32)

# Per-example transform applied by every split: infinity-norm normalisation,
# then conversion of the complex signal to a 2D real representation.
data_transform = ST.Compose([
    ST.Normalize(norm=np.inf),
    ST.ComplexTo2D(),
])

# Seed the random number generators before any dataset is constructed.
pl.seed_everything(SEED)


def _make_dataset(samples_per_class):
    """Build one ModulationsDataset split.

    Args:
        samples_per_class: number of examples requested for each entry of
            `classes`; the dataset is asked for
            `num_classes * samples_per_class` examples in total.

    Returns:
        A ModulationsDataset configured identically for every split
        (class-index targets, level 0, no SNR in the target).
    """
    return ModulationsDataset(
        classes=classes,
        use_class_idx=True,
        level=0,
        num_iq_samples=NUM_IQ_SAMPLES,
        num_samples=int(num_classes * samples_per_class),
        include_snr=False,
        transform=data_transform,
    )


def _make_loader(dataset, batch_size, shuffle):
    """Wrap `dataset` in a DataLoader that drops the final incomplete batch.

    Args:
        dataset: the dataset split to iterate over.
        batch_size: number of examples per batch.
        shuffle: whether to reshuffle the examples at every epoch.

    Returns:
        A single-process DataLoader (no `num_workers` is passed).
    """
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=True,
    )


# Construction order (train, val, test) is kept so each split draws from the
# seeded random state in the same sequence as before.
ds_train = _make_dataset(training_samples_per_class)
ds_val = _make_dataset(valid_samples_per_class)
ds_test = _make_dataset(test_samples_per_class)

train_dataloader = _make_loader(ds_train, TRAIN_BATCH_SIZE, shuffle=True)
# Validation is an evaluation pass, so it is no longer shuffled: a fixed order
# keeps per-batch validation behaviour comparable between epochs.
val_dataloader = _make_loader(ds_val, VAL_BATCH_SIZE, shuffle=False)
# NOTE(review): drop_last=True also applies to the val/test loaders. If the
# test split really holds 6000 examples, a batch size of 64 silently discards
# the last 48 from evaluation. Left unchanged in case later cells assume
# full-sized batches; confirm and consider drop_last=False.
test_dataloader = _make_loader(ds_test, TEST_BATCH_SIZE, shuffle=False)

# NOTE(review): from here on torch's default floating dtype is float64; the
# reason is not visible in this cell, so confirm it is intended.
torch.set_default_dtype(torch.float64)

# Checkpoint file name encodes the class count and the epoch budget.
model_save_path = os.path.join("tb_logs", f"EfficientNet_Classes{num_classes}_e{EPOCHS}.pt")
   


ModuleNotFoundError: No module named 'numpy'

In [None]:
# Build an EfficientNet-B4 without loading pretrained weights.
# NOTE(review): no class count is passed here, so the size of the output head
# is whatever the factory defaults to; confirm it matches `num_classes`.
model = efficientnet_b4(
    pretrained=False
)

# Use the GPU when one is available and fall back to the CPU otherwise,
# instead of failing on a hard-coded 'cuda' device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Cast the parameters to float32 (the default dtype was switched to float64 in
# the setup cell) and move the model to the selected device.
model = model.float().to(device)

model

In [None]:
from torchsummary import summary

# Print a layer-by-layer summary of the model for a batch of 16 inputs of
# shape (2, 4096); 4096 matches the num_iq_samples used for the datasets.
summary(model, input_size=(2, 4096), batch_size=16)