In [4]:
import os
from scipy.io import loadmat
import numpy as np

In [1]:
from my_radar_dataset import MyRadarDataset

# Create a folder for dummy .mat files
test_folder = 'mat_files'

# Initialize MyRadarDataset
radar_dataset = MyRadarDataset()

# Load the dataset
print("\nLoading dataset...")
try:
    radar_dataset.load(folder=test_folder)
    print(f"Total samples loaded: {radar_dataset.num_samples}")
except Exception as e:
    print(f"Error loading data: {e}")


Loading dataset...
Total samples loaded: 7552


In [5]:

# Get a sample
print("\nRetrieving a sample...")
try:
    sample_index = 5
    x, y, file_name = radar_dataset.get(sample_index)
    assert x.shape == (1, 8), f"Incorrect shape for x: {x.shape}"
    assert y.shape == (1,), f"Incorrect shape for y: {y.shape}"
    print(f"Sample {sample_index} retrieved successfully.")
    print(f"  x: {x}")
    print(f"  y: {y}")
    print(f"  file_name: {file_name}")
except (ValueError, IndexError) as e:
    print(f"Error retrieving sample {sample_index}: {e}")

# Test bad sample index
print("\nTesting bad sample indices...")
try:
    bad_sample_index = -1
    radar_dataset.get(bad_sample_index)
except Exception as e:
    print(f"Correctly handled bad sample index {bad_sample_index}: {e}")

try:
    bad_sample_index = radar_dataset.num_samples + 1
    radar_dataset.get(bad_sample_index)
except Exception as e:
    print(f"Correctly handled bad sample index {bad_sample_index}: {e}")



Retrieving a sample...
Sample 5 retrieved successfully.
  x: [[-1.59432210e+01  1.35796134e+02  1.26786904e+02  1.53996342e+02
   1.30250409e+01  5.76189520e-04 -1.00000000e+00  1.48063430e+01]]
  y: [4]
  file_name: file1.mat

Testing bad sample indices...
Correctly handled bad sample index -1: Sample index out of range
Correctly handled bad sample index 7553: Sample index out of range


In [6]:

# Split the dataset
print("\nSplitting the dataset...")
try:
    split_index = np.arange(radar_dataset.num_samples // 2)
    dataset1, dataset2 = radar_dataset.split(split_index)
    print(f"Dataset 1 samples: {dataset1.num_samples}")
    print(f"Dataset 2 samples: {dataset2.num_samples}")
    assert dataset1.data.shape == (dataset1.num_samples, 1, 8), "Incorrect shape for dataset1 data"
    assert dataset1.labels.shape == (dataset1.num_samples, 1), "Incorrect shape for dataset1 labels"
    assert dataset2.data.shape == (dataset2.num_samples, 1, 8), "Incorrect shape for dataset2 data"
    assert dataset2.labels.shape == (dataset2.num_samples, 1), "Incorrect shape for dataset2 labels"
except Exception as e:
    print(f"Error splitting dataset: {e}")

# Test bad split indexes
print("\nTesting bad split indexes...")
try:
    bad_split_index = "not a list"
    radar_dataset.split(bad_split_index)
except Exception as e:
    print(f"Correctly handled bad split indexes: {e}")



Splitting the dataset...
Dataset 1 samples: 3776
Dataset 2 samples: 3776

Testing bad split indexes...
Correctly handled bad split indexes: slice_indexes must be a list, tuple, or numpy array


In [7]:

# Random split the dataset
print("\nRandomly splitting the dataset...")
try:
    dataset3, dataset4 = radar_dataset.random_split() # default is 0.7
    print(f"Random split - Dataset 3 samples: {dataset3.num_samples}")
    print(f"Random split - Dataset 4 samples: {dataset4.num_samples}")
    assert dataset3.data.shape == (dataset3.num_samples, 1, 8), "Incorrect shape for dataset3 data"
    assert dataset3.labels.shape == (dataset3.num_samples, 1), "Incorrect shape for dataset3 labels"
    assert dataset4.data.shape == (dataset4.num_samples, 1, 8), "Incorrect shape for dataset4 data"
    assert dataset4.labels.shape == (dataset4.num_samples, 1), "Incorrect shape for dataset4 labels"
except Exception as e:
    print(f"Error in random split: {e}")

# Test bad ratio for random split
print("\nTesting bad ratios for random split...")
try:
    bad_ratio = -0.5
    radar_dataset.random_split(ratio=bad_ratio)
except Exception as e:
    print(f"Correctly handled bad ratio {bad_ratio}: {e}")

try:
    bad_ratio = 1.5
    radar_dataset.random_split(ratio=bad_ratio)
except Exception as e:
    print(f"Correctly handled bad ratio {bad_ratio}: {e}")



Randomly splitting the dataset...
Random split - Dataset 3 samples: 5286
Random split - Dataset 4 samples: 2266

Testing bad ratios for random split...
Correctly handled bad ratio -0.5: ratio must be between 0 and 1
Correctly handled bad ratio 1.5: ratio must be between 0 and 1


In [8]:

# Generate batches
print("\nGenerating batches...")
try:
    batch_size = 3
    generator = radar_dataset.get_generator(batch_size=batch_size, shuffle=True)
    print(f"Generating batches of size {batch_size}:")
    for i, (x_batch, y_batch) in enumerate(generator):
        assert x_batch.shape[1] == 8, f"Incorrect shape for x_batch: {x_batch.shape}"
        assert y_batch.shape[0] == x_batch.shape[0], f"y_batch and x_batch size mismatch: {y_batch.shape}, {x_batch.shape}"
        print(f"  Batch {i} retrieved successfully.")
        print(f"    x_batch: {x_batch}")
        print(f"    y_batch: {y_batch}")
        if i >= 1:  # Limit the number of batches printed
            break
except Exception as e:
    print(f"Error generating batches: {e}")


Generating batches...
Generating batches of size 3:
  Batch 0 retrieved successfully.
    x_batch: [[ 1.05588594e+02 -5.46621973e+01  3.01084878e+02  1.13513804e+02
   1.26109534e+01  1.70257965e+02  5.44571020e-02  1.30124504e+01]
 [-1.08912179e+01 -3.30121989e+01  3.29953221e+01  4.58332195e+01
   2.10018154e-01  1.89529508e-02  3.21247667e-01  2.96171631e+00]
 [-7.61048851e+01  8.97550396e+01  5.06968142e+02  2.17921185e+02
  -2.13948448e+01  2.52211833e+00  4.96793874e-02  2.15597775e+01]]
    y_batch: [2 4 2]
  Batch 1 retrieved successfully.
    x_batch: [[ 1.07410382e+01  5.94914292e+00  2.21692734e+01  2.17670524e+01
  -3.13257716e-01  1.16170084e+00  5.73766343e-02  3.42043933e+00]
 [-2.80090953e+01  1.29036246e+01  2.59569837e+01  4.53956008e+01
  -6.16810017e-01  3.40265930e+02  1.69767424e-01  4.78017440e+00]
 [-6.51163659e+01  1.24496005e+02  3.22226049e+01  2.19605398e+01
   1.54709907e+01  1.17663750e-02  1.50067255e-01  1.90719144e+01]]
    y_batch: [2 4 4]
