# Spectrogram segmentation with SAM 2

In the following, we will use SAM 2 (an image and video segmentation model) to segment the spectrogram of a noisy signal in order to denoise it. There are two ways to use SAM 2 for this purpose:
- Create a single spectrogram of the whole audio signal and segment it.
- Create multiple spectrograms from non-overlapping sections of the audio signal and segment them.

The former corresponds to setting `single_frame` to `True` and the latter to `False`.

At the moment, the model requires prompting the spectrogram with positive and negative clicks. 

After masking the spectrogram, we use the well-known DiffWave model to reconstruct the audio signal from the masked spectrogram.

<a target="_blank" href="https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/video_predictor_example.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## 1. Environment Set-up

If running locally using jupyter, first install `sam2` in your environment using the [installation instructions](https://github.com/facebookresearch/sam2#installation) in the repository.

If running from Google Colab, set `using_colab=True` below and run the cell. In Colab, be sure to select 'GPU' under 'Edit'->'Notebook Settings'->'Hardware accelerator'. Note that it's recommended to use **A100 or L4 GPUs when running in Colab** (T4 GPUs might also work, but could be slow and might run out of memory in some cases).

In [16]:
# ---- Notebook parameters ----

# Segmentation mode:
#   True  -> one spectrogram of the whole signal is segmented (done on the noisy audio
#            when add_noise is True).
#   False -> the signal is split into several spectrogram frames; the clicks go on the
#            first (clean) frame and the masks are propagated to the following noisy frames.
single_frame = True

# Reuse clicks stored on disk (when a saved file exists) instead of collecting new ones.
use_pre_loaded_clicks = True

# Noise settings: `sigma` scales the Gaussian noise added to the waveform when
# `add_noise` is True.
add_noise = True
sigma = 0.02

# Figure resolutions. NOTE: the names are misspelled ("deafult") but other cells
# read them as written, so they are left unchanged here.
deafult_saving_dpi = deafult_printing_dpi = 600

# Input audio file - replace with your own path.
audio_path = r'audio_example.wav'

# SAM 2 checkpoint and model configuration.
sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"




In [17]:
using_colab = False

if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/sam2.git'

    !mkdir -p videos
    !wget -P videos https://dl.fbaipublicfiles.com/segment_anything_2/assets/bedroom.zip
    !unzip -d videos videos/bedroom.zip

    !mkdir -p ../checkpoints/
    !wget -P ../checkpoints/ https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt

## Set-up

In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
import numpy as np
import shutil
import torch
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import torchaudio as T
import torchaudio.transforms as TT
from diffwave.inference import predict as diffwave_predict
from diffwave.params import params
import librosa.display
from sam2.build_sam import build_sam2_video_predictor
from pathlib import Path
from clicker import collect_clicks
import pickle
import utils

if sys.platform == "darwin":  # macOS
    # Enable PyTorch's fallback path for the MPS backend.
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Global matplotlib defaults for displaying figures.
# NOTE(review): the original author states that changing these also requires a
# matching change on the SAM 2 side - leave as is unless that is verified.
plt.rcParams['figure.figsize'] = [8, 8]
plt.rcParams['figure.dpi'] = deafult_printing_dpi

single_frame_str = "single_frame" if single_frame else "diffused"
noise_str = "with_noise" if add_noise else "only_clean"

# Key identifying the input audio: its path without the final extension.
# `str.split('.')[0]` was wrong for any path with an earlier dot ("./a.wav" -> "",
# "my.take.wav" -> "my"). An absolute path would also have replaced the
# "results"/"prompts" root when joined, so the anchor is stripped first.
audio_key = Path(audio_path).with_suffix("")
if audio_key.anchor:
    audio_key = audio_key.relative_to(audio_key.anchor)

# Sub-path shared by the results tree and the prompts tree for this configuration.
run_subdir = audio_key / f"sigma_{sigma}" / single_frame_str / noise_str

# Spectrogram outputs (images and numpy arrays) for the clean and noisy signals.
output_dir = Path("results") / run_subdir / "spectorgrams"
clean_output_dir = output_dir / "clean"
noisy_output_dir = output_dir / "noisy"

# Audio outputs live next to the spectrogram directory.
output_audio_dir = output_dir.parent / "output_audios"
noisy_audio_path = output_audio_dir / "noisy_audio.wav"

# File (despite the name) where the collected clicks are stored.
prompts_dir = Path("prompts") / run_subdir / "clicks.pkl"

# Start every run from empty output directories: remove leftovers, then recreate.
for stale_dir in (output_dir, output_audio_dir):
    if stale_dir.exists():
        shutil.rmtree(stale_dir)
    stale_dir.mkdir(parents=True, exist_ok=True)


In [None]:
# Choose the compute device: CUDA when available, otherwise the CPU.
# MPS selection is currently commented out:
#   elif torch.backends.mps.is_available():
#       device = torch.device("mps")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"using device: {device}")

if device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
elif device.type == "cuda":
    # Enter a bfloat16 autocast context that stays active for the rest of the notebook.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # Enable TF32 on devices with compute capability major >= 8 (Ampere and newer):
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    ampere_or_newer = torch.cuda.get_device_properties(0).major >= 8
    if ampere_or_newer:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True



# Create a noisy copy of the input audio (when requested) and report its SNR.
# Truthiness test rather than `is True`, matching the `add_noise and single_frame`
# checks used in later cells.
if add_noise:
    waveform, sample_rate = T.load(audio_path)
    signal_power = torch.mean(waveform ** 2)

    # Additive Gaussian noise scaled by `sigma`.
    noise = torch.randn_like(waveform) * sigma

    # Clip the noisy signal to [-1, 1]. This is hard clipping of out-of-range
    # samples, not a normalization.
    noisy_waveform = torch.clamp(waveform + noise, -1.0, 1.0)

    # Save the noisy audio.
    T.save(noisy_audio_path, noisy_waveform, sample_rate)

    # SNR in dB, computed from the generated (unclipped) noise.
    noise_power = torch.mean(noise ** 2)
    snr = 10 * torch.log10(signal_power / noise_power)
    print(f"SNR: {snr.item():.2f} dB")
    print(f"signal power: {signal_power.item():.4f}, noise power: {noise_power.item():.4f}")
    

## 2. Creating a spectrogram from an audio file
### There is no need to run this section if the files are already saved on the computer.
We will define a function that converts an audio file to spectrogram files and spectrogram images

Upload an audio file and create a spectrogram:

In [None]:
# Convert the clean audio to spectrogram arrays/images; `video_dir` is the frames directory.
sr, overlap, video_dir = utils.input2mel(audio_path, clean_output_dir, single_frame)

print(video_dir)

# With noise enabled, also convert the noisy audio and build a "mixed" set in which
# frame 0000 comes from the clean signal and every other frame from the noisy one.
# `video_dir` is then pointed at the frames SAM 2 should segment.
if add_noise:
    sr, overlap, noisy_video_dir = utils.input2mel(noisy_audio_path, noisy_output_dir, single_frame)

    # Source and destination directories.
    mixed_images_dir = output_dir / "mixed" / "images"
    mixed_np_dir = output_dir / "mixed" / "np_arrays"
    clean_images_dir = clean_output_dir / "images"
    clean_np_dir = clean_output_dir / "np_arrays"
    noisy_images_dir = noisy_output_dir / "images"
    noisy_np_dir = noisy_output_dir / "np_arrays"

    # Create the destination directories if they don't exist.
    mixed_images_dir.mkdir(exist_ok=True, parents=True)
    mixed_np_dir.mkdir(exist_ok=True, parents=True)

    # Images: clean frame 0000 followed by every other noisy frame. They are
    # re-saved through PIL (re-encoded as JPEG at quality 100), not byte-copied.
    # Each file is opened in a `with` block so its handle is released.
    image_sources = [clean_images_dir / "0000.jpg"]
    image_sources += [p for p in noisy_images_dir.glob("*.jpg") if p.name != "0000.jpg"]
    for image_file in image_sources:
        with Image.open(image_file) as img:
            img.save(mixed_images_dir / image_file.name, "JPEG", quality=100)

    # Numpy arrays: same selection, copied verbatim.
    np_sources = [clean_np_dir / "0000.npy"]
    np_sources += [p for p in noisy_np_dir.glob("*.npy") if p.name != "0000.npy"]
    for np_file in np_sources:
        shutil.copy2(np_file, mixed_np_dir / np_file.name)

    # Single-frame mode segments the noisy spectrogram directly; multi-frame mode
    # segments the mixed sequence (clean first frame, noisy remainder).
    video_dir = str(noisy_images_dir if single_frame else mixed_images_dir)



## 3. Segmentation using SAM2
### Loading the SAM 2 video predictor

In [21]:
# Build the SAM 2 video predictor from the model config and checkpoint chosen in
# the parameters cell, placing it on the selected `device`.
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)

#### Load the spectrogram images

We will display the first frame to make sure everything was loaded correctly

In [None]:
# Gather the JPEG frame file names in the frames directory, ordered by the
# integer value of the file stem (e.g. "0003.jpg" -> 3).
jpeg_suffixes = (".jpg", ".jpeg", ".JPG", ".JPEG")
frame_names = sorted(
    (name for name in os.listdir(video_dir) if os.path.splitext(name)[-1] in jpeg_suffixes),
    key=lambda name: int(os.path.splitext(name)[0]),
)

# Preview the first frame as a sanity check.
frame_idx = 0
first_frame_path = os.path.join(video_dir, frame_names[frame_idx])
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(first_frame_path))

#### Initialize the inference state

SAM 2 requires stateful inference for interactive video segmentation, so we need to initialize an **inference state** on this video.

During initialization, it loads all the JPEG frames in `video_path` and stores their pixels in `inference_state` (as shown in the progress bar below).

In [None]:
# Create the stateful inference object for the frames found in `video_dir`.
inference_state = predictor.init_state(video_path=video_dir)

Note: if you have run any previous tracking using this `inference_state`, please reset it first via `reset_state`.

(The cell below is just for illustration; it's not needed to call `reset_state` here as this `inference_state` is just freshly initialized above.)

In [24]:
# Reset the inference state before adding prompts. This is redundant right after
# `init_state`, but needed when this state was already used for tracking.
predictor.reset_state(inference_state)

### Segment multiple objects simultaneously

In [None]:
prompts = {}  # every click set we add, kept for visualization
ann_frame_idx = 0  # index of the frame we interact with

# Reuse the stored clicks when that is requested and the file exists; otherwise
# collect new clicks on the annotation frame via `collect_clicks` and store them.
have_saved_clicks = use_pre_loaded_clicks and prompts_dir.exists()
if not have_saved_clicks:
    img_path = os.path.join(video_dir, frame_names[ann_frame_idx])
    points_list, labels = collect_clicks(img_path)
    prompts_dir.parent.mkdir(exist_ok=True, parents=True)
    utils.save_lists(points_list, labels, prompts_dir)
else:
    points_list, labels = utils.load_lists(prompts_dir)

# One label array is expected per point array.
assert len(points_list) == len(labels), "Number of points and labels should be the same"

num_of_promt = len(points_list)

#### Step 1: Show the mask created on the first frame

In [None]:
# Manual alternative to the collected clicks (kept for reference). One array of
# (x, y) points and one label array per object; label 1 is a positive click,
# label 0 a negative click:
#   points_list = [np.array([[800, 2100]], dtype=np.float32),
#                  np.array([[1500, 2120]], dtype=np.float32)]
#   labels = [np.array([1], np.int32), np.array([1], np.int32)]
#   num_of_promt = len(points_list)

# Defaults so the visualization below still runs when there are no prompts.
out_obj_ids, out_mask_logits = [], None

# Register every prompt as its own object (obj_id = prompt index) on the
# annotation frame. The last call's outputs are used for the plot below.
for i, (obj_points, obj_labels) in enumerate(zip(points_list, labels)):
    prompts[i] = obj_points, obj_labels
    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,  # was hard-coded to 0; keep in sync with the annotated frame
        obj_id=i,
        points=obj_points,
        labels=obj_labels,
    )

img = Image.open(os.path.join(video_dir, frame_names[ann_frame_idx]))
w, h = img.size
aspect_ratio = w / h  # informational only; the figure below does not use it

# Full-bleed axes covering the whole figure.
fig = plt.figure()
ax = plt.Axes(fig, [0., 0., 1., 1.])
fig.add_axes(ax)

# Add title (optionally adjust its position if needed)
ax.set_title(f"frame {ann_frame_idx}", pad=20)

# Show the frame, then each object's clicks and predicted mask exactly once.
# The previous nested loops reused `i` and redrew every overlay once per prompt.
ax.imshow(img)
for mask_idx, out_obj_id in enumerate(out_obj_ids):
    utils.show_points(*prompts[out_obj_id], ax)
    utils.show_mask((out_mask_logits[mask_idx] > 0.0).cpu().numpy(), ax, obj_id=out_obj_id)

ax.axis('off')




#### Step 2: Propagate the prompts to get masklets across the video

Now, we propagate the prompts for both objects to get their masklets throughout the video.

Note: when there are multiple objects, the `propagate_in_video` API will return a list of masks for each object.

In [None]:

# Propagate the prompts through the video and keep, for every frame, one boolean
# mask (logit > 0) per object id.
video_segments = {}  # frame index -> {object id -> mask array}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    frame_masks = {}
    for mask_pos, out_obj_id in enumerate(out_obj_ids):
        frame_masks[out_obj_id] = (out_mask_logits[mask_pos] > 0.0).cpu().numpy()
    video_segments[out_frame_idx] = frame_masks

# Render every `vis_frame_stride`-th frame; increase it to draw fewer frames
# when the video is long.
vis_frame_stride = 1

plt.ioff()  # no interactive drawing while figures are produced in the loop

for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    # One fresh figure per rendered frame.
    fig = plt.figure(figsize=(8, 8))
    plt.title(f"frame {out_frame_idx}")
    frame_path = os.path.join(video_dir, frame_names[out_frame_idx])
    plt.imshow(Image.open(frame_path))

    # Overlay each object's mask on the current axes.
    current_axes = plt.gca()
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        utils.show_mask(out_mask, current_axes, obj_id=out_obj_id)

    plt.show()      # display this frame
    plt.close(fig)  # release the figure

plt.ion()  # restore interactive mode

## 4. Applying the masks to the original spectrograms

Now we will map the masks produced by SAM 2 onto the original spectrograms.


### We will load the original spectrograms and apply the masks to them:

In [None]:
### Apply the masks to each spectrogram frame

# Directory with the spectrogram arrays that correspond to the segmented frames.
if add_noise and single_frame:
    output_dir_spec = noisy_output_dir / "np_arrays"
elif add_noise:
    output_dir_spec = Path(output_dir) / "mixed" / "np_arrays"
else:
    output_dir_spec = Path(output_dir) / "clean" / "np_arrays"

# Spectrogram files ordered by the integer value of the file stem.
spectrogram_files = [
    p for p in os.listdir(output_dir_spec) if p.endswith(".npy")
]
spectrogram_files.sort(key=lambda p: int(os.path.splitext(p)[0]))

# Gain applied to bins outside every mask; bins inside a mask keep gain 1.
alpha = 1/2

# Masked spectrograms, one per frame (concatenated after the loop).
maskt_Sxx_full = []

for i, file_name in enumerate(spectrogram_files):
    # Load spectrogram data; rows are unpacked as frequency bins, columns as time bins.
    Sxx = np.load(os.path.join(output_dir_spec, file_name))
    freq_bins, time_bins = Sxx.shape

    # Start from the attenuation gain everywhere, then restore gain 1 wherever
    # any object's mask lands. Assigning per prompt also copes with zero prompts
    # (the previous `np.vstack` of an empty list raised).
    mask_to_spec = alpha * np.ones(Sxx.shape)
    for j in range(num_of_promt):
        # Mask of object j on frame i; index 0 drops the leading dimension.
        mask = video_segments[i][j][0]
        mask_height, mask_width = mask.shape

        # Pixel coordinates where the mask is True (np.where returns rows, then columns).
        mask_y, mask_x = np.where(mask)

        # Map pixel coordinates to spectrogram (time, frequency) bin indices.
        spec_time_bins, spec_freq_bins = utils.matrix_to_spectrogram(
            mask_x, mask_y, mask_height, mask_width, time_bins, freq_bins
        )
        mask_to_spec[spec_freq_bins, spec_time_bins] = 1

    # Apply the gain mask to the spectrogram.
    maskt_Sxx = np.multiply(Sxx, mask_to_spec)
    maskt_Sxx_full.append(maskt_Sxx)

    # Plot the masked spectrogram, then close the figure so figures do not
    # accumulate across frames.
    fig = plt.figure(figsize=(4, 4), dpi=300)
    librosa.display.specshow(
        maskt_Sxx,
        sr=sr,
        hop_length=overlap,
        x_axis="time",
        y_axis="mel",
        cmap="magma",
    )
    plt.axis("off")
    plt.show()
    plt.close(fig)

# Concatenate all masked spectrograms along the time axis.
maskt_Sxx_full = np.concatenate(maskt_Sxx_full, axis=1)
print("Done")


Let's look at the final masked spectrogram

In [None]:
# Draw the concatenated masked spectrogram on axes that fill the whole figure,
# with the axis decorations hidden.
fig = plt.figure()
ax = fig.add_axes([0., 0., 1., 1.])  # no margins
ax.axis("off")

specshow_options = dict(
    sr=sr,
    hop_length=overlap,
    x_axis="time",
    y_axis="mel",
    cmap="magma",  # 'jet' is another option to try
)
librosa.display.specshow(maskt_Sxx_full, **specshow_options)

## 5. Reconstructing the audio from the masked spectrograms using DiffWave

In [None]:
# Path to the pre-trained DiffWave checkpoint (change as needed). Download it from
# https://github.com/lmnt-com/diffwave/blob/master/README.md
model_dir = r'diffwave-ljspeech-22kHz-1000578.pt'

# Options shared by both DiffWave calls; `device` is the one selected earlier
# (e.g. torch.device('cuda') when a GPU is used).
vocoder_options = dict(device=device, fast_sampling=True)

# Output locations.
reconstructed_path = output_audio_dir / "proposed_method_reconst.wav"

print('Mel-Spectrogram to Audio')
mel_spectrogram = torch.tensor(maskt_Sxx_full, dtype=torch.float)
audio, sample_rate = diffwave_predict(mel_spectrogram, model_dir, **vocoder_options)

print('Save Reconstructed Audio')
T.save(reconstructed_path, audio.cpu(), sample_rate=sample_rate)

if single_frame:
    # Baseline: run DiffWave on the clean spectrogram of frame 0000.
    baseline_diff_wave_path = output_audio_dir / "baseline_diff_wave_reconst.wav"
    clean_Sxx = torch.tensor(
        np.load(clean_output_dir / "np_arrays" / "0000.npy"),
        dtype=torch.float,
    )
    reconstr_clean_audio, _ = diffwave_predict(clean_Sxx, model_dir, **vocoder_options)
    # Saved with the sample rate returned by the first DiffWave call above.
    T.save(baseline_diff_wave_path, reconstr_clean_audio.cpu(), sample_rate=sample_rate)



# Spectral-gating baseline and quality metrics. Runs only in single-frame mode with
# noise enabled, since it needs both the noisy audio and the DiffWave baseline file.
if single_frame and add_noise:
    spec_gate_recon_path = output_audio_dir / "spectral_gating_recon.wav"
    # Presumably writes the denoised result to `spec_gate_recon_path` - confirm in utils.
    utils.reduce_noise(noisy_audio_path, spec_gate_recon_path, decrease_factor=0.9)

    # Clean input, noisy input, proposed method, spectral gating, DiffWave baseline.
    metric_inputs = (
        audio_path,
        noisy_audio_path,
        reconstructed_path,
        spec_gate_recon_path,
        baseline_diff_wave_path,
    )
    utils.print_quality_metrics(*metric_inputs)