In [None]:
import pickle
import numpy as np 

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def edic_features(image, window_size=5, transform=np.log):
    """
    Compute EDIC features F1 and F2 on a SAR image using a sliding window approach.

    For each n x n window G(i,j) (stride 1, "valid" positions only) whose top-left
    corner is at pixel (i,j) of the input:
      1. Compute the singular values {s1, s2, ..., s_n} via SVD (descending order).
      2. Feature F1 is defined as:
             F1(i,j) = s1 - s2
      3. Feature F2 is defined as:
             F2(i,j) = f( s1 / (sum(s2...s_n) + ε) )
         where f(·) is a transformation function (e.g. logarithm) and ε = 1e-8 is a
         small constant to avoid division by zero.

    Because only windows that fit entirely inside the image are used, output pixel
    (i,j) describes the neighbourhood centred on input pixel
    (i + window_size//2, j + window_size//2) when window_size is odd.

    Parameters:
        image (array-like): 2D SAR image (real or complex).
        window_size (int): Size of the square sliding window; must be >= 2 (F1 needs
            two singular values) and no larger than either image dimension (default: 5).
        transform (function): Transformation applied to each scalar ratio in F2
            (default: np.log). With np.log, an all-zero window gives F2 = -inf.

    Returns:
        F1 (np.ndarray): Feature image computed from the difference s1-s2.
        F2 (np.ndarray): Feature image computed from the transformed ratio.
                         Both output images have shape (rows - window_size + 1, cols - window_size + 1).

    Raises:
        ValueError: If image is not 2D, window_size < 2, or the window is larger than the image.
    """
    image = np.asarray(image)
    if image.ndim != 2:
        raise ValueError(f"image must be 2D, got shape {image.shape}")
    if window_size < 2:
        raise ValueError(f"window_size must be >= 2, got {window_size}")
    rows, cols = image.shape
    if window_size > rows or window_size > cols:
        raise ValueError(
            f"window_size={window_size} does not fit in an image of shape {image.shape}"
        )

    out_rows = rows - window_size + 1
    out_cols = cols - window_size + 1

    F1 = np.zeros((out_rows, out_cols))
    F2 = np.zeros((out_rows, out_cols))
    eps = 1e-8  # small constant to prevent division by zero

    # Slide over the image with a stride of 1 (valid windows only)
    for i in range(out_rows):
        for j in range(out_cols):
            # Extract the local n×n window whose top-left corner is (i, j)
            window = image[i:i + window_size, j:j + window_size]
            # Singular values are returned in descending order (always real, even for complex input)
            singular_values = np.linalg.svd(window, compute_uv=False)

            # F1: Difference between the first and second singular values
            F1[i, j] = singular_values[0] - singular_values[1]

            # F2: Apply the transformation function to the ratio s1/(sum(s2...s_n))
            ratio = singular_values[0] / (np.sum(singular_values[1:]) + eps)
            F2[i, j] = transform(ratio)

    return F1, F2

In [None]:
import torch
import torch.nn.functional as F

def edic_features_torch(image: torch.Tensor, window_size: int = 5, transform: callable = torch.log,
                        batch_size: int = 1 << 18) -> tuple:
    """
    Compute EDIC features F1 and F2 on a SAR image using a sliding window approach with PyTorch.

    For every valid window (stride 1, top-left corner at output pixel (i,j)) with
    singular values s1 >= s2 >= ... >= s_n:
        F1 = s1 - s2
        F2 = transform(s1 / (sum(s2...s_n) + 1e-8))
    All windows are extracted at once with F.unfold; their singular values are then
    computed in batches, on the GPU when one is available, otherwise on the CPU.

    Parameters:
        image (torch.Tensor): 2D SAR image (tensor of shape [H, W]). Integer tensors
            are converted to the default floating dtype.
        window_size (int): Size of the square sliding window; must be >= 2 and fit
            inside the image (default: 5).
        transform (callable): Transformation function applied element-wise to the
            ratio tensor in F2 (default: torch.log).
        batch_size (int): Maximum number of windows passed to one SVD call; lower it
            to reduce peak memory (default: 262144).

    Returns:
        F1 (np.ndarray): Feature image computed from the difference s1-s2.
        F2 (np.ndarray): Feature image computed from the transformed ratio.
                         Both have shape (H - window_size + 1, W - window_size + 1).

    Raises:
        ValueError: If image is not 2D, window_size < 2, the window does not fit
            in the image, or batch_size < 1.
    """
    if image.dim() != 2:
        raise ValueError(f"image must be 2D, got shape {tuple(image.shape)}")
    H, W = image.shape
    if window_size < 2:
        raise ValueError(f"window_size must be >= 2, got {window_size}")
    if window_size > H or window_size > W:
        raise ValueError(f"window_size={window_size} does not fit in an image of shape {(H, W)}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image = image.to(device)
    # F.unfold is not implemented for integer dtypes.
    if not (image.is_floating_point() or image.is_complex()):
        image = image.to(torch.get_default_dtype())

    eps = 1e-8  # Small constant to prevent division by zero

    # (1, 1, H, W) -> (window_size**2, num_patches). Column k holds the window whose
    # top-left corner is (k // out_W, k % out_W), flattened row-major.
    unfolded = F.unfold(image.unsqueeze(0).unsqueeze(0), kernel_size=window_size).squeeze(0)
    num_patches = unfolded.shape[1]
    patches = unfolded.T.reshape(num_patches, window_size, window_size)

    # Only singular values are needed: svdvals avoids computing and storing U and V
    # (torch.svd is also deprecated). Batching bounds the SVD workspace.
    S = torch.cat([torch.linalg.svdvals(chunk) for chunk in patches.split(batch_size)])

    # Compute F1 and F2 (singular values are in descending order)
    F1 = S[:, 0] - S[:, 1]
    ratio = S[:, 0] / (S[:, 1:].sum(dim=1) + eps)
    F2 = transform(ratio)

    # Reshape to output spatial dimensions
    out_H, out_W = H - window_size + 1, W - window_size + 1
    F1 = F1.reshape(out_H, out_W)
    F2 = F2.reshape(out_H, out_W)

    # Return as numpy arrays; detach() so this also works for tensors that track gradients.
    return F1.detach().cpu().numpy(), F2.detach().cpu().numpy()

In [None]:
def plot_complex_image_and_features(image, window_size=5, transform=np.log):
    """
    Show |image| next to its EDIC feature maps F1 and F2 (one row, three panels).

    Parameters:
        image (np.ndarray): 2D complex image; its magnitude is displayed.
        window_size (int): Sliding-window size forwarded to edic_features (default: 5).
        transform (function): Ratio transform forwarded to edic_features (default: np.log).
    """
    feat1, feat2 = edic_features(image, window_size, transform)

    # (data, colormap, title) for each panel, left to right
    panels = (
        (np.abs(image), 'gray', 'Complex Image'),
        (feat1, 'jet', 'Feature F1'),
        (feat2, 'jet', 'Feature F2'),
    )

    _, axes = plt.subplots(1, 3, figsize=(18, 6))
    for ax, (data, cmap, title) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    plt.show()

Load the decoded product from a pickle file

In [None]:
# Load the decoded raw product from a pickle file. iw_extraction expects the loaded
# object to be a dict with 'echo' and 'metadata' entries.
# SECURITY: pickle.load can execute arbitrary code; only load files you created or trust.
# NOTE(review): machine-specific absolute path; adjust DATA_DIR for your environment.
DATA_DIR = '/Data_large/marine/PythonProjects/RFI_project/Data/Data_raw_decoded'
PRODUCT_FILE = 'S1B_IW_RAW__0SDV_20200116T054832_20200116T054904_019838_02582D_2BD3.pkl'
file_path = f'{DATA_DIR}/{PRODUCT_FILE}'
with open(file_path, 'rb') as file:
    array = pickle.load(file)


Unpack the decoded IWs

In [None]:
from collections import defaultdict
import pandas as pd 
import numpy as np
from tqdm import tqdm

def iw_extraction(decoded_prod, verbose=False):
    """
    Split the decoded echo lines into groups (IWs) keyed by range length.

    Lines are grouped by ``len(line)``. The matching metadata rows are kept alongside
    them, in the original order and with the original column names and dtypes.

    Parameters:
        decoded_prod (dict): Must contain 'echo' (sequence of 1D echo lines) and
            'metadata' (pandas DataFrame with one row per echo line).
        verbose (bool): Print the number of lines for each rg_len and the metadata
            columns (default: False).

    Returns:
        dict: ``{rg_len: {'echo': np.ndarray of shape (n_lines, rg_len),
                          'metadata': pd.DataFrame with n_lines rows (RangeIndex)}}``.
    """
    # Ensure the input dictionary contains the required keys
    assert 'echo' in decoded_prod, "The input dictionary must contain the key 'echo'."
    assert 'metadata' in decoded_prod, "The input dictionary must contain the key 'metadata'."

    echo = decoded_prod["echo"]
    metadata = decoded_prod["metadata"]

    # Ensure echo and metadata have the same length
    assert len(echo) == len(metadata), "The length of 'echo' and 'metadata' must be the same."

    # Group line *positions* by range length. Selecting the metadata rows by position
    # (instead of rebuilding them via itertuples()._asdict()) keeps the original column
    # names and dtypes: itertuples renames columns that are not valid Python identifiers
    # (e.g. names containing spaces) to positional names.
    line_idx_by_len = defaultdict(list)
    for idx, line in enumerate(tqdm(echo, total=len(echo), desc="Extracting IWs")):
        line_idx_by_len[len(line)].append(idx)

    if verbose:
        # Print the grouped lines
        for rg_len, idxs in line_idx_by_len.items():
            print(f"rg_len: {rg_len}, number of lines: {len(idxs)}")
        # Print the keys of the metadata
        print(metadata.columns)

    # Collect into a unique dictionary
    Bursts_dict = {}
    for rg_len, idxs in tqdm(line_idx_by_len.items(), desc="Converting to arrays:"):
        Bursts_dict[rg_len] = dict(
            echo=np.array([echo[i] for i in idxs]),
            metadata=metadata.iloc[idxs].reset_index(drop=True),
        )

    return Bursts_dict


def plot_feature(feature, title, figsize=(40, 40), cmap='jet'):
    """
    Plot a 2D feature matrix (e.g. F1 or F2) with a colorbar.

    Parameters:
        feature (np.ndarray): 2D feature matrix to plot.
        title (str): Title of the plot.
        figsize (tuple): Figure size in inches (default: (40, 40)).
        cmap (str): Matplotlib colormap name (default: 'jet').
    """
    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(feature, cmap=cmap)
    ax.set_title(title)
    # Without a colorbar the feature values cannot be read off the color map.
    fig.colorbar(im, ax=ax)
    plt.show()


In [None]:
# Group the decoded echo lines by range length (verbose output off).
bursts = iw_extraction(decoded_prod=array)

In [None]:
# Range lengths (rg_len) present in the product; each key maps to one group of echo lines.
bursts.keys()

In [None]:
# Echo lines whose range length is 19722 samples (must be one of bursts.keys()).
img = bursts[19722]["echo"]

# First test patch: echo lines 500-999, first 10000 samples of each line.
sub_img = img[500:1000, :10000]
print(sub_img.shape)

F1, F2 = edic_features_torch(torch.tensor(sub_img), window_size=5, transform=torch.log)

plot_feature(F1, 'Feature F1')

In [None]:
plot_feature(title='Feature F2', feature=F2)

In [None]:
from scipy.signal import welch

# NOTE(review): assumed sampling frequency; confirm it matches this product's range sampling rate.
sampling_frequency = 300e6  # 300 MHz

filtered_data = sub_img

# Welch PSD of every line along the range axis (axis=1), averaged over lines.
# Flattening the patch would let segments straddle the junction between consecutive
# lines. nperseg keeps the previous choice (number of lines) but is clamped to the
# line length, because welch requires nfft >= nperseg.
n_lines, n_samples = filtered_data.shape
frequencies, psd_per_line = welch(
    filtered_data,
    fs=sampling_frequency,
    nperseg=min(n_lines, n_samples),
    noverlap=None,
    nfft=n_samples,
    return_onesided=False,
    axis=1,
)
psd = psd_per_line.mean(axis=0)

# Shift the zero frequency component to the center
frequencies = np.fft.fftshift(frequencies)
psd = np.fft.fftshift(psd)

# Plot the power spectral density
plt.figure(figsize=(14, 14), dpi=120)
plt.plot(frequencies, 10 * np.log10(psd))
plt.title('Power Spectral Density of Filtered Data (dB)')
plt.xlabel('Frequency (Hz)')
plt.ylabel('dB')
plt.grid(True)
plt.show()

In [None]:
# Range spectrum of every line. np.fft.fft orders bins as [0, +fs/2) then [-fs/2, 0),
# so fftshift along the frequency axis to match the centred `frequencies` used in extent.
fft_filtered_data = np.fft.fftshift(np.fft.fft(filtered_data, axis=1), axes=1)

plt.figure(figsize=(15, 5), dpi=120)
# extent = (left, right, bottom, top). With imshow's default origin='upper', line 0 is
# drawn at the top, so top must be 0 for the line-index axis to read correctly.
plt.imshow(np.abs(fft_filtered_data), aspect='auto', cmap='jet',
           extent=[frequencies[0], frequencies[-1], fft_filtered_data.shape[0], 0])
plt.title('FFT of Filtered Data')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Line index')
plt.colorbar(label='Magnitude')
plt.show()


# New patch containing RFI

In [None]:
# Second patch: echo lines 10000-10999 of the rg_len == 19722 group, at full line length.
img = bursts[19722]["echo"]
sub_img = img[10000:11000]

plot_feature(np.abs(sub_img), 'img')

F1, F2 = edic_features_torch(torch.tensor(sub_img), transform=torch.log, window_size=5)

# plot_feature(F1, 'Feature F1')

In [None]:
from scipy.signal import welch

# NOTE(review): assumed sampling frequency; confirm it matches this product's range sampling rate.
sampling_frequency = 300e6  # 300 MHz

filtered_data = sub_img

# Welch PSD of every line along the range axis (axis=1), averaged over lines.
# Flattening the patch would let segments straddle the junction between consecutive
# lines. nperseg keeps the previous choice (number of lines) but is clamped to the
# line length, because welch requires nfft >= nperseg.
n_lines, n_samples = filtered_data.shape
frequencies, psd_per_line = welch(
    filtered_data,
    fs=sampling_frequency,
    nperseg=min(n_lines, n_samples),
    noverlap=None,
    nfft=n_samples,
    return_onesided=False,
    axis=1,
)
psd = psd_per_line.mean(axis=0)

# Shift the zero frequency component to the center
frequencies = np.fft.fftshift(frequencies)
psd = np.fft.fftshift(psd)

# Plot the power spectral density
plt.figure(figsize=(14, 4), dpi=120)
plt.plot(frequencies, 10 * np.log10(psd))
plt.title('Power Spectral Density of Filtered Data (dB)')
plt.xlabel('Frequency (Hz)')
plt.ylabel('dB')
plt.xticks(rotation=45)  # Rotate x-axis labels for better readability
plt.gca().xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:.0f}'))
plt.grid(True)
plt.show()

In [None]:
# Range spectrum of every line. np.fft.fft orders bins as [0, +fs/2) then [-fs/2, 0),
# so fftshift along the frequency axis to match the centred `frequencies` used in extent.
fft_filtered_data = np.fft.fftshift(np.fft.fft(filtered_data, axis=1), axes=1)

plt.figure(figsize=(15, 5), dpi=120)
# extent = (left, right, bottom, top). With imshow's default origin='upper', line 0 is
# drawn at the top, so top must be 0 for the line-index axis to read correctly.
plt.imshow(np.abs(fft_filtered_data), aspect='auto', cmap='jet',
           extent=[frequencies[0], frequencies[-1], fft_filtered_data.shape[0], 0])
plt.title('FFT of Filtered Data')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Line index')
plt.colorbar(label='Magnitude')
plt.show()


# Define the frequency range for the plot
# (fft_filtered_data is fftshifted above, so a mask on `frequencies` selects matching columns.)
# freq_range = (frequencies >= 10e6) & (frequencies <= 50e6)

# # Plot the FFT of filtered data within the specified frequency range
# plt.figure(figsize=(15, 5), dpi=120)
# plt.imshow(np.abs(fft_filtered_data[:, freq_range]), aspect='auto', cmap='jet', extent=[frequencies[freq_range][0], frequencies[freq_range][-1], 0, fft_filtered_data.shape[0]])
# plt.title('FFT of Filtered Data (10 MHz to 50 MHz)')
# plt.xlabel('Frequency (Hz)')
# plt.ylabel('Magnitude')
# plt.colorbar()
# plt.show()


In [None]:
# plot_complex_image_and_features(sub_img, window_size=5, transform=np.log)

In [None]:
# Flag pixels whose feature value exceeds mean + N_SIGMA * std of that feature map.
N_SIGMA = 3

# Statistics use finite values only. With transform=torch.log, an all-zero window
# gives F2 = log(0) = -inf, which would make the mean/std (and therefore the
# threshold) -inf/nan and silently empty the mask.
finite_F1 = F1[np.isfinite(F1)]
finite_F2 = F2[np.isfinite(F2)]
threshold_value1 = np.mean(finite_F1) + N_SIGMA * np.std(finite_F1)
threshold_value2 = np.mean(finite_F2) + N_SIGMA * np.std(finite_F2)

# Create one segmentation mask per feature based on its threshold
segmentation_mask1 = F1 > threshold_value1
segmentation_mask2 = F2 > threshold_value2

# A pixel belongs to the final mask only when both F1 and F2 are above their thresholds
segmentation_mask = segmentation_mask1 & segmentation_mask2

# Plot the individual and combined masks
for mask, title in (
    (segmentation_mask1, 'Segmentation Mask1'),
    (segmentation_mask2, 'Segmentation Mask2'),
    (segmentation_mask, 'Segmentation Mask'),
):
    plt.figure(figsize=(10, 8))
    plt.imshow(mask, cmap='gray')
    plt.title(title)
    plt.axis('off')
    plt.show()



In [None]:
from skimage.morphology import remove_small_objects

# Connected components smaller than this many pixels are removed from the combined mask.
MIN_OBJECT_SIZE = 75
# Only the first PLOT_COLS columns of the cleaned mask are displayed.
PLOT_COLS = 4000

# Remove small objects from the segmentation mask
cleaned_segmentation_mask = remove_small_objects(segmentation_mask, min_size=MIN_OBJECT_SIZE)

# Plot the (cropped) cleaned segmentation mask
plt.figure(figsize=(10, 8))
plt.imshow(cleaned_segmentation_mask[:, :PLOT_COLS], cmap='gray')
plt.title(f'Cleaned Segmentation Mask (first {PLOT_COLS} columns)')
plt.axis('off')
plt.show()