# Spectrogram Segmentation

In this example, we use PyTorch and PyTorch Lightning to train deep learning models to differentiate between 
5G NR and 4G LTE signals within wideband spectrograms.

## Outline

This example is divided into the following sections:

**[Background Information](#background):** Contextualize the problem and introduce the machine learning frameworks and tools used in this example.

**[Set-up](#set-up):** Install the necessary libraries.

**[Data Preprocessing](#data-preprocessing):** Load and analyze the Spectrum Sensing 5G dataset.

**[Model Training](#model-training):** Train a machine learning model on the dataset.

**[Model Verification](#model-verification):** Assess model performance using a suite of common machine learning metrics.

**[Challenge Data](#challenge-data):** Challenge the model with combined frames containing both LTE and NR signals.

**[Conclusions & Next Steps](#conclusion-&-next-steps):** Interpret results, summarize learnings, and suggest further steps to extend this example.

## Background

5G NR (New Radio) and 4G LTE (Long-Term Evolution) are both cellular network technologies, but they represent 
different generations of mobile network standards. Being able to distinguish between the two has significant 
applications in [spectrum sensing](https://iopscience.iop.org/article/10.1088/1742-6596/2261/1/012016#:~:text=In%20cognitive%20radio%2C%20spectrum%20sensing,user%20can%20use%20the%20spectrum.) and serves as a foundational example showcasing the near-term feasibility of 
[intelligent radio](https://www.qoherent.ai/intelligentradio/) technology.

A spectrogram, which depicts the frequency spectrum of a signal over time, is just an image. Consequently, we 
apply state-of-the-art [semantic segmentation](https://www.ibm.com/topics/semantic-segmentation) techniques from 
the field of computer vision to the problem of spectrogram analysis. Our task is simply to assign one of the 
following labels to each pixel: 'LTE', 'NR', 'Noise'. ('Noise' refers to the absence of signal, representing 
a vacant or empty spectrum, also known as whitespace.)
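
To make the per-pixel labelling concrete, here is a minimal sketch built on a hypothetical 4x6 toy mask (it is not taken from the dataset). It assumes the integer encoding 0 = Noise, 1 = NR, 2 = LTE that the later cells of this notebook use, with rows as time steps and columns as frequency bins.

```python
import numpy as np

# Assumed class encoding, matching the labels used later in this notebook.
CLASS_NAMES = {0: "Noise", 1: "NR", 2: "LTE"}

# Hypothetical 4x6 mask: rows are time steps, columns are frequency bins.
mask = np.zeros((4, 6), dtype=np.uint8)  # start with every pixel labelled Noise
mask[:, 1:3] = 1  # an NR signal occupying frequency bins 1 and 2
mask[:, 4] = 2  # an LTE signal occupying frequency bin 4

# Count how many pixels carry each label.
pixel_counts = {name: int((mask == label).sum()) for label, name in CLASS_NAMES.items()}
print(pixel_counts)
```

A segmentation model produces a mask of exactly this form for each input spectrogram: one class index per pixel.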

The machine learning models utilized in this example are DeepLabV3 models featuring ResNet-50 or MobileNet-V3 backbones. The DeepLabV3 framework was originally introduced by Chen _et al._ in their 2017 paper titled '[Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)'. For an accessible introduction to the DeepLabV3 framework, please check out Isaac Berrios' article: [DeepLabv3: Building Blocks for Robust Segmentation Models](https://medium.com/@itberrios6/deeplabv3-c0c8c93d25a4).

The dataset used in this example is the Spectrum Sensing 5G dataset, provided by MathWorks. This dataset contains 900 LTE frames, 900 NR frames, and 900 combined frames with both LTE and NR signal. In this example, we train exclusively on the individual LTE and NR examples, excluding the combined frames.

To ensure comparability with results obtained using MathWorks' AI-based network, we use the hyperparameter configuration from MathWorks' spectrum sensing example: [Deep Learning Toolbox](https://www.mathworks.com/products/deep-learning.html): [Spectrum Sensing with Deep Learning to Identify 5G and LTE Signals](https://www.mathworks.com/help/comm/ug/spectrum-sensing-with-deep-learning-to-identify-5g-and-lte-signals.html).

# Set-Up

In [None]:
import os
import sys

module_path = os.path.abspath(os.path.join(".."))  # Project root
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
%matplotlib inline

In [None]:
import os
import glob
import torch
import torchvision
import torchmetrics

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pytorch_lightning as pl

from PIL import Image
from osgeo import gdal

from torch import nn
from torch.utils.data import DataLoader
from torchvision.io import read_image
from torchvision.datasets import VisionDataset
from torchvision.transforms import v2
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.classification import MulticlassConfusionMatrix
from torchmetrics import JaccardIndex as jac_ind
from sklearn.metrics import ConfusionMatrixDisplay

# Data Preprocessing

For each frame in the dataset, there are three corresponding files:

`.png`: The spectrogram image.

`.hdf`: The target mask.

`.mat`: Frame metadata, such as the signal sample rate and the number of DFT points used for spectrogram computation. None of this metadata is necessary for this example, so we can safely ignore these files.

Because our dataset is specifically tailored for computer vision tasks, let's extend the [VisionDataset](https://pytorch.org/vision/main/generated/torchvision.datasets.VisionDataset.html) class.

In [None]:
class SpectrumSensing(VisionDataset):
    """MathWorks' Spectrum Sensing 5G dataset."""

    # Mapping of pixel labels for different waveforms.
    pixel_labels = {"Noise": 0, "NR": 127, "LTE": 255}  # noise is red, green is NR (5G), blue is LTE (4G)

    def __init__(self, root, transforms=None, transform=None, target_transform=None):
        """Initialize the Spectrum Sensing 5G dataset with the root directory of the dataset and
        any necessary functions/transforms.
        """
        super().__init__(root, transforms, transform, target_transform)
        self.root = root

        # Parse the dataset to extract the file names of individual LTE and NR frames.
        files = glob.glob(os.path.join(root, "*.png"))
        self.frames = [os.path.basename(frame).split(".")[0] for frame in files]

    def __len__(self) -> int:
        return len(self.frames)

    def __getitem__(self, idx: int) -> tuple[Image.Image, Image.Image]:
        """Get the image-mask pair at idx."""
        basename = self.frames[idx]

        image_file = os.path.join(self.root, f"{basename}.png")
        target_file = os.path.join(self.root, f"{basename}.hdf")

        print("image file:\t", image_file)
        print("pixel label file:\t", target_file)

        image = Image.open(image_file)

        mask_data = gdal.Open(target_file)
        mask_band = mask_data.GetRasterBand(1)
        mask = (Image.fromarray(mask_band.ReadAsArray())).convert(mode="P", palette=0)

        if self.transforms is not None:
            image, mask = self.transforms(image, mask)

        return image, mask

In [None]:
project_root = os.getcwd()
path_to_training_data = os.path.join(project_root, "SpectrumSensingDataset", "TrainingData")

mean = [0.485, 0.456, 0.406]  # Required by our model
std = [0.229, 0.224, 0.225]  # Required by our model

# transforms:
# 1. Convert the PIL image to a tensor (PILToTensor keeps the original uint8 values)
# 2. Cast the pixel values to float32 (ToDtype only rescales to [0.0, 1.0] when scale=True is passed)
# 3. Normalize using mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]

transforms = v2.Compose([v2.PILToTensor(), v2.ToDtype(dtype=torch.float32), v2.Normalize(mean=mean, std=std)])

# Initialize the dataset
dataset = SpectrumSensing(root=path_to_training_data, transforms=transforms)
print(f"The full dataset has {len(dataset)} training examples")

In [None]:
# Inspect a random example
random_index = np.random.randint(len(dataset))
random_image, random_mask = dataset[random_index]
print(f"Loading image at index {random_index}:\n{dataset[random_index]}")

# TODO: Show the example

# Analyze Dataset Statistics

In [None]:
labels = [0, 1, 2]  # 0->Noise, 1->NR, 2->LTE
freq = {label: 0 for label in labels}  # Frequency of pixels corresponding to each label
for _, y in dataset:
    for label in labels:
        out = np.sum(y.numpy() == label)
        freq[label] += out

In [None]:
normalized_heights = np.array(list(freq.values())) / sum(list(freq.values()))

plt.bar(freq.keys(), normalized_heights, tick_label=["Noise", "NR", "LTE"], color=["red", "green", "orange"])
plt.xlabel("signal names", color="blue")
plt.ylabel("Normalized\nFrequency of Pixels", color="blue")
plt.title("Distribution of Pixels Across Signals", color="blue")

__Note__: Our dataset seems to be imbalanced. We will need to take this into account when training our model.
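
One common remedy for class imbalance is median-frequency balancing, which is the scheme that the `balance_weights` function later in this notebook follows. The sketch below illustrates the arithmetic only; the pixel counts are hypothetical numbers chosen for readability, not measurements from the dataset.

```python
import numpy as np

# Hypothetical pixel statistics, for illustration only.
# class_pixels: number of pixels labelled with each class across the whole set.
# image_pixels: total number of pixels in the images where that class appears.
class_names = ["Noise", "NR", "LTE"]
class_pixels = np.array([600.0, 300.0, 100.0])
image_pixels = np.array([1000.0, 1000.0, 500.0])

# Relative frequency of each class within the images that contain it.
class_freq = class_pixels / image_pixels

# Median-frequency balancing: classes rarer than the median get weights above 1.
class_weights = np.median(class_freq) / class_freq

for name, weight in zip(class_names, class_weights):
    print(f"{name}: {weight:.2f}")
```

With these toy numbers the most frequent class is down-weighted and the rarest class is up-weighted, so errors on rare classes contribute more to the loss.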

# Prepare Training and Validation Set

The data is split into a training set (80%) and a validation set (20%).

In [None]:
full_size = len(dataset)
train_split = 0.80
train_size = round(len(dataset) * train_split)
val_size = len(dataset) - train_size

train_subset, val_subset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(0)
)

print("Full dataset has {} instances".format(full_size))
print("Training subset has {} instances".format(len(train_subset)))
print("Validation subset has {} instances".format(len(val_subset)))

# Train Deep Neural Network

Steps needed to train our model:

- Create loaders (training and validation)
- Download the model and use PyTorch Lightning to train it with the specified hyperparameter values and a loss function (hyperparameter values are taken from the MATLAB project).

These steps are implemented in great detail below.

_Training and Validation Dataloaders_

In [None]:
BATCH = 4

train_loader = DataLoader(train_subset, batch_size=BATCH, shuffle=True, num_workers=5)
val_loader = DataLoader(val_subset, batch_size=BATCH, shuffle=False, num_workers=2)

img_batch, target_batch = next(iter(train_loader))

print("Shape of image batch tensor: ", img_batch.shape, img_batch.dtype)
print("Shape of mask batch tensor: ", target_batch.shape, target_batch.dtype)

_Download the Model_

In [None]:
# Download the model, either ResNet-50 or MobileNetV3
NUM_CLASSES = 3
# model = torchvision.models.segmentation.deeplabv3_mobilenet_v3_large(num_classes=NUM_CLASSES)
model = torchvision.models.segmentation.deeplabv3_resnet50(num_classes=NUM_CLASSES)

_Loss function and Hyperparameter values_

- To deal with the imbalanced dataset, we create a function, `balance_weights`, that outputs the class weights that will be passed to the loss function. Balancing weights is just one technique we can use to address the imbalance in the training data.

- The loss function is a cross-entropy loss function.

- Optimizer: SGD with 
     - `momentum = 0.9`
     - `lr = 0.02`
     - `weight_decay = 1.0e-04` (L2 regularization)
- Learning Rate Scheduler:
  - `StepLR` with `step_size = 10` and `gamma = 0.1`
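
To see what a stepwise schedule does to the learning rate over a training run, the sketch below evaluates the closed-form step decay rule in plain Python. It assumes the values set in the code cells of this notebook (`initial_lr = 0.02`, `step_size = 10`, `gamma = 0.1`, 20 epochs); the helper `step_lr` is a hypothetical name introduced for illustration and is not part of PyTorch.

```python
def step_lr(initial_lr: float, gamma: float, step_size: int, epoch: int) -> float:
    """Learning rate at a given epoch when it is multiplied by gamma every step_size epochs."""
    return initial_lr * gamma ** (epoch // step_size)


# Values assumed from the training cells of this notebook.
INITIAL_LR = 0.02
GAMMA = 0.1
STEP_SIZE = 10
NUM_EPOCHS = 20

schedule = [step_lr(INITIAL_LR, GAMMA, STEP_SIZE, epoch) for epoch in range(NUM_EPOCHS)]

# Print the learning rate at the start, just before the drop, and just after it.
for epoch in (0, 9, 10, 19):
    print(f"epoch {epoch:2d}: lr = {schedule[epoch]:.4f}")
```

Under these assumptions, the learning rate stays at its initial value for the first ten epochs and is ten times smaller for the remaining ten.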

In [None]:
class_labels = {"Noise": 0, "NR": 1, "LTE": 2}


def balance_weights(dataset: VisionDataset):
    """
    This function outputs the weights for each class by computing
     the relative frequency for each class and then taking the inverse of
     these relative frequencies.
    """
    pixel_count = {k: (0, 0) for k in class_labels.keys()}

    # Count pixels of each class label
    for i in range(len(dataset)):
        _, mask = dataset[i]
        mask = mask.numpy()
        pixel_count_i = {k: (mask == v).sum() for k, v in class_labels.items()}

        assert (
            sum(pixel_count_i.values()) == mask.size
        ), "Sum of each pixel label count does not equal total pixels in mask"

        pixel_count = {
            k: (
                (pixel_count[k][0] + pixel_count_i[k], pixel_count[k][1] + mask.size)
                if (pixel_count_i[k] != 0)
                else (pixel_count[k][0] + pixel_count_i[k], pixel_count[k][1])
            )
            for k in pixel_count_i.keys()
        }
    # print("Pixel count of each label: \n{}".format(pixel_count))

    # make nx2 array of count values, col 0 for label pixel count, col 1 for image pixel count
    rows = len(pixel_count.values())
    cols = len(next(iter(pixel_count.values())))
    pixel_count_arr = np.zeros((rows, cols), dtype=int)
    for i, v in enumerate(pixel_count.values()):
        for j in range(len(v)):
            pixel_count_arr[i][j] = v[j]
    # print(pixel_count_arr)

    # calculate frequency of each label
    class_freq = pixel_count_arr[:, 0] / pixel_count_arr[:, 1]
    # calculate weight of each label
    class_weights = np.median(class_freq) / class_freq
    print(class_weights, class_weights.dtype)

    return torch.as_tensor(class_weights, dtype=torch.float)

In [None]:
# Compute the class weights
class_weights = balance_weights(train_subset)

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = "cpu"

print(f"Model will train on: {device}")
class_weights = class_weights.to(device)
print(type(device))

_Introduce the loss function_

In [None]:
# Loss function
# Pass the class_weights to the loss function


def criterion(results, target, weight=class_weights):
    # The output of the model is a dictionary. So, results is a dictionary
    losses = {}
    loss_fn = nn.CrossEntropyLoss(weight=weight)
    for i, result in results.items():
        losses[i] = loss_fn(result, target)

    return losses["out"]
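
To clarify how the class weights enter the loss, here is a small NumPy sketch of a weighted cross-entropy averaged over pixels. It is an illustration under stated assumptions rather than the code used for training: the function name `weighted_cross_entropy`, the toy logits, and the example weights are all hypothetical. The normalization by the sum of the applied weights follows the behaviour of `nn.CrossEntropyLoss` with a `weight` tensor and the default `reduction="mean"`.

```python
import numpy as np


def weighted_cross_entropy(logits: np.ndarray, target: np.ndarray, weight: np.ndarray) -> float:
    """Weighted mean cross-entropy.

    logits: (num_pixels, num_classes) raw scores
    target: (num_pixels,) integer class indices
    weight: (num_classes,) per-class weights
    """
    # Numerically stable log-softmax over the class dimension.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

    # Negative log-likelihood of the true class for each pixel.
    nll = -log_probs[np.arange(len(target)), target]

    # Each pixel is weighted by the weight of its true class.
    pixel_weights = weight[target]

    # Weighted mean: normalize by the sum of the weights that were applied.
    return float((pixel_weights * nll).sum() / pixel_weights.sum())


# Toy example: four pixels, three classes, hypothetical class weights.
target = np.array([0, 0, 1, 2])
weight = np.array([0.5, 1.0, 1.5])

uniform_logits = np.zeros((4, 3))  # the model has no preference between classes
confident_logits = np.full((4, 3), -5.0)
confident_logits[np.arange(4), target] = 5.0  # strongly favours the true class

uniform_loss = weighted_cross_entropy(uniform_logits, target, weight)
confident_loss = weighted_cross_entropy(confident_logits, target, weight)
print(uniform_loss, confident_loss)
```

With uniform logits every pixel has the same negative log-likelihood, log(3), so the weighted mean is also log(3) regardless of the weights. The weights only change the result when the per-pixel errors differ between classes.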

_Integrate the downloaded model with the PyTorch Lightning environment, including the desired optimizer, learning rate scheduler, and initial learning rate_

In [None]:
# Initialize hyperparameters needed for SGD (stochastic gradient descent)
initial_lr = 0.02  # Represents the initial learning rate, which is the step size used to update the model's parameters
# during each iteration of SGD
momentum = 0.9  # Accelerates SGD in the relevant direction and dampens oscillations
weight_decay = 1.0e-04  # A regularization term added to the loss function to penalize large weights in the model

# Initialize hyperparameters needed for learning scheduler
step_size = 10  # Represents the number of training epochs after which the learning rate will be adjusted
gamma = 0.1  # The multiplicative factor by which the learning rate is reduced after every step_size epochs

In [None]:
class SegmentationModelSGD(pl.LightningModule):
    def __init__(self, model, num_classes, lr, momentum, weight_decay, step_size, gamma, optimizer_name="SGD"):
        super().__init__()
        self.model = model
        self.num_classes = num_classes
        self.train_acc = MulticlassAccuracy(num_classes=self.num_classes)
        self.val_acc = MulticlassAccuracy(num_classes=self.num_classes)
        self.lr = lr
        self.optimizer = getattr(torch.optim, optimizer_name)
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.step_size = step_size
        self.gamma = gamma

    def forward(self, x):
        x = self.model(x)
        return x

    def training_step(self, batch, batch_idx):
        image, target = batch
        preds = self(image)  # Dictionary
        loss = criterion(preds, target)
        preds = preds["out"].argmax(dim=1)  # Our prediction
        self.train_acc.update(preds, target)
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        self.log("train_accuracy", self.train_acc.compute(), prog_bar=True)
        self.train_acc.reset()  # Start the next epoch with a fresh accumulator

    def validation_step(self, batch, batch_idx):
        image, target = batch
        preds = self(image)
        loss = criterion(preds, target)
        preds = preds["out"].argmax(dim=1)
        self.val_acc.update(preds, target)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        # Log the metric object itself so that Lightning computes and resets it once per epoch
        self.log("val_accuracy", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = self.optimizer(
            self.parameters(), lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay
        )
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.gamma)

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

In [None]:
segm_sgd_model = SegmentationModelSGD(
    model,
    lr=initial_lr,
    num_classes=NUM_CLASSES,
    momentum=momentum,
    weight_decay=weight_decay,
    step_size=step_size,
    gamma=gamma,
)

In [None]:
NUM_EPOCHS = 20
# NUM_EPOCHS = 3

if torch.cuda.is_available():
    trainer = pl.Trainer(accelerator="gpu", max_epochs=NUM_EPOCHS, logger=True)
else:
    trainer = pl.Trainer(max_epochs=NUM_EPOCHS, logger=True)

In [None]:
trainer.fit(segm_sgd_model, train_loader, val_loader)

# Test with signals from the validation set
 
We employ the following metrics to evaluate the performance of our model:

 - Confusion Matrix: Provides a comprehensive overview of the model's ability, highlighting areas of success and misclassification, with diagonal elements representing correct predictions and off-diagonal elements indicating errors.
 
 - Accuracy: The ratio of correctly predicted instances to the total number of instances.
 
 - Recall: The recall (Sensitivity) measures the ability of a model to identify all relevant instances.

 - Precision: The precision assesses the accuracy of positive predictions.

 - F1: The F1 Score combines recall and precision into a single value by taking their harmonic mean.

 - Intersection over Union (IoU) and Histogram of IoU: The IoU quantifies the overlap between the predicted bounding box or segmented region and the ground truth bounding box or annotated region from a dataset. A higher IoU value indicates a better alignment between the predicted and actual regions, reflecting a more accurate model.
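
All of the metrics listed above can be derived from a confusion matrix. The sketch below works through the arithmetic on a hypothetical 3x3 matrix (rows are true classes, columns are predicted classes); the counts are made up for illustration and are not results from this model.

```python
import numpy as np

# Hypothetical confusion matrix: rows are true classes, columns are predicted classes.
class_names = ["Noise", "NR", "LTE"]
cm = np.array(
    [
        [8, 1, 1],
        [2, 6, 2],
        [0, 2, 8],
    ],
    dtype=float,
)

tp = np.diag(cm)  # correctly classified pixels per class
fp = cm.sum(axis=0) - tp  # predicted as the class, but actually another class
fn = cm.sum(axis=1) - tp  # actually the class, but predicted as another class

accuracy = tp.sum() / cm.sum()
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
iou = tp / (tp + fp + fn)

print(f"accuracy: {accuracy:.3f}")
for i, name in enumerate(class_names):
    print(f"{name}: precision={precision[i]:.3f} recall={recall[i]:.3f} f1={f1[i]:.3f} IoU={iou[i]:.3f}")
```

Note that IoU is always less than or equal to F1 for the same class, because its denominator counts each error in full rather than halving it.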

### _Confusion Matrix_

In [None]:
def confusion_matrix(
    model, val_loader: DataLoader, num_classes: int, device: str, normalize: str | None = "true"
) -> MulticlassConfusionMatrix:
    """
    Computes and displays the confusion matrix for a given PyTorch model using a validation DataLoader.

    :param model: The PyTorch model to evaluate.
    :param val_loader: The validation DataLoader.
    :param num_classes: The number of classes in the classification problem.

    :param device: The device on which to perform the evaluation (e.g., 'cuda' for GPU or 'cpu').

    :param normalize: Type of normalization applied to the confusion matrix. Options: {'none', 'true', 'pred', 'all'}.

    :return: The confusion matrix.
    """
    conf_matrix = (MulticlassConfusionMatrix(num_classes=num_classes, normalize=normalize)).to(device)
    model.to(device)
    model.eval()

    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(device)
            y = y.to(device)
            pred = (model(x)["out"]).argmax(dim=1)
            # conf_matrix.update(pred.to('cpu'),y.to('cpu'))
            conf_matrix.update(pred, y)

    # Round the entries in the confusion matrix to two decimals
    conf_matrix_rounded = torch.round(conf_matrix.compute(), decimals=2)

    # ConfusionMatrixDisplay expects a NumPy array as input, not a tensor
    fig, ax = plt.subplots(1, figsize=(3, 3))
    displ = ConfusionMatrixDisplay(conf_matrix_rounded.cpu().numpy())
    displ.plot(ax=ax, colorbar=False)
    # Rows (y-axis) hold the true labels and columns (x-axis) hold the predictions
    ax.set_xlabel("Predicted label", color="blue")
    ax.set_ylabel("True label", color="blue")
    ax.set_title("Confusion Matrix", color="blue")

    return conf_matrix

In [None]:
confusion_matrix(segm_sgd_model, val_loader, NUM_CLASSES, device)

### _Validation Accuracy_

In [None]:
segm_sgd_model.eval()
scores = trainer.validate(segm_sgd_model, val_loader)

### _Histogram of Intersection over Union (IoU) Scores per Image_

In [None]:
def plot_hist(model, loader, device):
    iou_scores = []
    model.to(device)
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            pred = (model(x)["out"]).argmax(dim=1)
            jaccard = jac_ind(task="multiclass", num_classes=3).to(device)
            score = jaccard(pred.to(device), y)
            iou_scores.append(score.item())

    plt.hist(iou_scores, color="green", histtype="bar")
    plt.xlabel("IoU", color="blue")
    plt.ylabel("Number of Masks", color="blue")
    plt.title("Mean IoU", color="blue")

In [None]:
new_val_loader = DataLoader(val_subset, batch_size=1, shuffle=False)  # Changed the batch_size
plot_hist(segm_sgd_model, new_val_loader, device)

### _Recall, Precision, F1 Score, IoU_

In [None]:
def compute_metric(model, loader, metric_name, num_classes, device):
    metric = getattr(torchmetrics.classification, metric_name)
    metric_per_class = metric(num_classes=num_classes, average="none").to(device)
    average_metric = metric(num_classes=num_classes, average="macro").to(device)
    weighted_metric = metric(num_classes=num_classes, average="weighted").to(device)
    model.to(device)
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            pred = (model(x)["out"]).argmax(dim=1)
            pred = pred.to(device)
            metric_per_class.update(pred, y)
            average_metric.update(pred, y)
            weighted_metric.update(pred, y)

    value_per_class = metric_per_class.compute()
    average_value = average_metric.compute()
    weighted_value = weighted_metric.compute()
    return value_per_class, average_value.unsqueeze(0), weighted_value.unsqueeze(0)

In [None]:
metric_names = ["MulticlassRecall", "MulticlassPrecision", "MulticlassF1Score", "MulticlassJaccardIndex"]

metric_results = {
    metric_name: torch.hstack(compute_metric(segm_sgd_model, val_loader, metric_name, NUM_CLASSES, device))
    for metric_name in metric_names
}

In [None]:
metric_results_cpu = {key: (value.to("cpu")).numpy() for key, value in metric_results.items()}

classification_report = pd.DataFrame(metric_results_cpu)

index = ["Noise", "NR", "LTE", "macro avg", "weighted avg"]
columns = ["recall", "precision", "f1-score", "IoU"]

classification_report.columns = columns
classification_report.index = index

print(classification_report)

# Signal Identification in Spectrograms

In [None]:
def signal_label(mask) -> str:
    """
    :param mask: The mask image containing signal labels.

    :return: The signal label based on unique labels in the mask.
    """
    labels = {0: "Noise", 1: "NR", 2: "LTE"}
    unique_labels_in_mask = torch.unique(mask)  # Sorted in ascending order

    if len(unique_labels_in_mask) == 2:
        key = unique_labels_in_mask[1].item()
        return labels[key]

    elif len(unique_labels_in_mask) == 3:
        key_1 = unique_labels_in_mask[1].item()
        key_2 = unique_labels_in_mask[2].item()
        return labels[key_1] + "_" + labels[key_2]

    else:
        # Only one label is present in the mask
        key = unique_labels_in_mask[0].item()
        return labels[key]

In [None]:
def plot_spectrogram_mask(image: torch.Tensor, mask: torch.Tensor, target: torch.Tensor | None = None):

    if target is not None:
        fig, ax = plt.subplots(3, 1, figsize=(4, 10))

        ax[0].imshow(torch.permute(image, (1, 2, 0)))
        ax[0].set_xlabel("Frequency", fontsize=12, color="blue")
        ax[0].set_ylabel("Time", fontsize=12, color="blue")
        ax[0].set_title(f"Received Spectrogram ({signal_label(target)})", color="blue")

        ax[1].set_ylabel("Time", fontsize=12, color="blue")
        ax[1].set_xlabel("Frequency", fontsize=12, color="blue")
        ax[1].imshow(target)
        ax[1].set_title(f"True Signal Label ({signal_label(target)})", color="blue")

        ax[2].imshow(mask)
        # plt.imshow(predicted_image.permute(1,2,0)[:,:,0])
        # or equivalently
        # plt.imshow(pred['out'][0][0].to('cpu').detach())
        ax[2].set_xlabel("Frequency", fontsize=12, color="blue")
        ax[2].set_ylabel("Time", fontsize=12, color="blue")
        ax[2].set_title(f"Prediction ({signal_label(mask)})", color="blue")
        plt.tight_layout()

    else:
        fig, ax = plt.subplots(2, 1, figsize=(4, 10))

        ax[0].imshow(torch.permute(image, (1, 2, 0)))
        ax[0].set_xlabel("Frequency", fontsize=12, color="blue")
        ax[0].set_ylabel("Time", fontsize=12, color="blue")
        ax[0].set_title("Received Spectrogram", color="blue")

        ax[1].imshow(mask)
        # plt.imshow(predicted_image.permute(1,2,0)[:,:,0])
        # or equivalently
        # plt.imshow(pred['out'][0][0].to('cpu').detach())
        ax[1].set_xlabel("Frequency", fontsize=12, color="blue")
        ax[1].set_ylabel("Time", fontsize=12, color="blue")
        ax[1].set_title(f"Prediction ({signal_label(mask)})", color="blue")
        plt.tight_layout()

In [None]:
image, target = next(iter(val_loader))  # First batch of spectrograms

In [None]:
image, target = image.to(device), target.to(device)
segm_sgd_model.eval()
with torch.no_grad():
    predicted_masks = segm_sgd_model(image)["out"]
    first_mask_in_batch = predicted_masks[0].argmax(dim=0)

first_image_in_batch = image[0]
first_target_in_batch = target[0]

In [None]:
plot_spectrogram_mask(first_image_in_batch.to("cpu"), first_mask_in_batch.to("cpu"), first_target_in_batch.to("cpu"))

# Testing Model with Combined (NR_LTE) Signals

_Note_ : Combined NR_LTE signals were excluded from the training dataset.

In [None]:
# Grab the first NR_LTE signal
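# Assumption: the combined frames sit in an "LTE_NR" subfolder of the training data directory defined earlier
spec_path = os.path.join(path_to_training_data, "LTE_NR", "LTE_NR_frame_0.png")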
spectrogram = read_image(spec_path)  # Image has both NR and LTE signal
spectrogram = spectrogram.to(device)

segm_sgd_model.eval()
with torch.no_grad():
    pred = segm_sgd_model((spectrogram.to(torch.float)).unsqueeze(0))["out"]
    mask = pred[0].argmax(dim=0)

In [None]:
plot_spectrogram_mask(spectrogram.to("cpu"), mask.to("cpu"))

# _Conclusion_ 

We tested the model on a spectrogram that contains both LTE and NR signals (as well as noise).
However, the model's prediction for this spectrogram shows only an NR signal.