In [None]:
#| default_exp libs.loader.data_loader

# Data Loader

## `SeismicDataset(Dataset)`

### **Description:**
- Represents a PyTorch Dataset for seismic data.

### **Methods:**

 #### `__init__(self, data_path, labels_path, orientation, compute_weights=False, faulty_slices_list=None)`
   - **Description:**
     - Initializes the SeismicDataset.
   - **Parameters:**
     - `data_path` (str): Path to the seismic data file.
     - `labels_path` (str): Path to the corresponding labels file.
     - `orientation` (str): Orientation of the slices served by the dataset: `'in'` for inlines; any other value (e.g. `'crossline'`) selects crosslines.
     - `compute_weights` (bool): Whether to compute class weights based on frequency. Default is False.
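     - `faulty_slices_list` (str): Path to a JSON file describing the faulty slices to remove from the data, as an object with the keys `inlines`, `crosslines` and `time_slices`, each holding the indices to delete along that axis. Default is None.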
   - **Returns:**
     - None
---

 #### `__getitem__(self, index)`
   - **Description:**
     - Gets an item (data sample and its label) from the dataset.
   - **Parameters:**
     - `index` (int): Index of the item to retrieve.
   - **Returns:**
     - `tuple`: A tuple containing the data sample and its label.

---

 #### `__len__(self)`
   - **Description:**
     - Gets the length of the dataset.
   - **Returns:**
     - `int`: Length of the dataset.
     
---

 #### `get_class_weights(self)`
   - **Description:**
     - Gets the computed class weights.
   - **Returns:**
     - `numpy.ndarray` or `None`: Computed class weights, or `None` if the dataset was created with `compute_weights=False`.

---

 #### `get_n_classes(self)`
   - **Description:**
     - Gets the number of unique classes in the dataset.
   - **Returns:**
     - `int`: Number of unique classes.
     
---

 #### `__load_data(self, data_path, labels_path)`
   - **Description:**
     - Loads the seismic data (SEG-Y for `.segy`/`.sgy` files, `numpy.load` for anything else) and the labels (`numpy.load`) from files.
   - **Parameters:**
     - `data_path` (str): Path to the seismic data file.
     - `labels_path` (str): Path to the corresponding labels file.
   - **Returns:**
     - `tuple`: A tuple containing the loaded data and labels.

---

 #### `__compute_class_weights(self)`
   - **Description:**
     - Computes class weights based on frequency.
   - **Returns:**
     - `numpy.ndarray`: Computed class weights.

---

 #### `__remove_faulty_slices(self, faulty_slices_list)`
   - **Description:**
     - Removes faulty slices from the data.
   - **Parameters:**
     - `faulty_slices_list` (str): Path to a JSON file with the keys `inlines`, `crosslines` and `time_slices`, each holding the indices of the slices to remove along that axis.
   - **Returns:**
     - None
---

 #### `__process_class_labels(self)`
   - **Description:**
     - Remaps the class labels to the range `[0, number_of_classes)`, as required by the loss function.
   - **Returns:**
     - `int`: Number of unique classes.


In [None]:
#| export
import numpy as np
import segyio
import os
import json
from torch.utils.data import Dataset


class SeismicDataset(Dataset):
    """PyTorch ``Dataset`` serving 2-D sections of a 3-D seismic volume.

    The data volume is unpacked as ``(inlines, crosslines, time_slices)``.
    The label volume is indexed with the same indices, so it is expected to
    have the same shape (this is not validated here).
    """

    def __init__(self, data_path, labels_path, orientation, compute_weights=False, faulty_slices_list=None):
        """Load the volume, optionally drop faulty slices and remap labels.

        Args:
            data_path: Path to the seismic data file (``.segy``/``.sgy`` is
                read with ``segyio``, anything else with ``numpy.load``).
            labels_path: Path to the labels file (read with ``numpy.load``).
            orientation: ``'in'`` to serve inlines; any other value serves
                crosslines.
            compute_weights: Whether to compute per-class weights.
            faulty_slices_list: Optional path to a JSON file with the keys
                ``inlines``, ``crosslines`` and ``time_slices`` holding the
                indices to delete along each axis.

        Raises:
            FileNotFoundError: If ``data_path`` or ``labels_path`` is not a file.
            ValueError: If the data volume is not 3-D.
        """
        self.data, self.labels = self.__load_data(data_path, labels_path)
        self.orientation = orientation

        # Removing faulty slices from the data if specified
        if faulty_slices_list is not None:
            self.__remove_faulty_slices(faulty_slices_list)

        # Dimensions are taken after the removal so they match what is served.
        self.n_inlines, self.n_crosslines, self.n_time_slices = self.data.shape

        self.n_classes = self.__process_class_labels()
        self.weights = self.__compute_class_weights() if compute_weights else None

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair for the section at ``index``.

        Both arrays get a leading axis of length 1 added in front of the
        2-D section.
        """
        if self.orientation == 'in':
            image = self.data[index, :, :]
            label = self.labels[index, :, :]
        else:
            image = self.data[:, index, :]
            label = self.labels[:, index, :]

        # Reshaping to 3D image
        image = np.expand_dims(image, axis=0)
        label = np.expand_dims(label, axis=0)

        return image, label

    def __len__(self):
        """Return the number of sections along the configured orientation."""
        return self.n_inlines if self.orientation == 'in' else self.n_crosslines

    def __load_data(self, data_path, labels_path):
        """Load the seismic volume and its labels from disk.

        Args:
            data_path: Path to the seismic data file.
            labels_path: Path to the labels file.

        Returns:
            A ``(data, labels)`` tuple of arrays.

        Raises:
            FileNotFoundError: If either path is not an existing file.
        """
        if not os.path.isfile(data_path):
            raise FileNotFoundError(f'File {data_path} does not exist.')

        if not os.path.isfile(labels_path):
            raise FileNotFoundError(f'File {labels_path} does not exist.')

        # Compare case-insensitively so that e.g. '.SGY' is also read as SEG-Y
        # instead of being handed to np.load.
        data_extension = os.path.splitext(data_path)[1].lower()

        # Loading data
        if data_extension in ('.segy', '.sgy'):
            with segyio.open(data_path, 'r') as segyfile:
                segyfile.mmap()
                # One 2-D section per inline, stacked along a new first axis.
                data = np.array([segyfile.iline[inline] for inline in segyfile.ilines])
        else:
            data = np.load(data_path)

        # Loading labels
        labels = np.load(labels_path)

        return data, labels

    def __compute_class_weights(self):
        """Return one weight per class, inversely proportional to its frequency.

        The weight of a class is ``total_values / (class_count * n_classes)``.
        """
        total_n_values = self.n_inlines * self.n_crosslines * self.n_time_slices
        # Weights are inversely proportional to the frequency of the classes in the training set
        _, counts = np.unique(self.labels, return_counts=True)

        return total_n_values / (counts * self.n_classes)

    def __remove_faulty_slices(self, faulty_slices_list):
        """Delete the slices listed in a JSON file from data and labels.

        If the file cannot be found, a message is printed and the whole
        volume is kept.

        Args:
            faulty_slices_list: Path to a JSON file with the keys
                ``inlines``, ``crosslines`` and ``time_slices``.
        """
        # Only the file access is guarded: that is the sole failure this
        # method deliberately tolerates.
        try:
            with open(faulty_slices_list, 'r') as json_buffer:
                # File containing the list of slices to delete
                faulty_slices = json.load(json_buffer)
        except FileNotFoundError:
            print('Could not open the .json file containing the faulty slices.')
            print('Training with the whole volume instead.\n')
            return

        # Axis order matches the (inlines, crosslines, time_slices) layout.
        for axis, key in enumerate(('inlines', 'crosslines', 'time_slices')):
            indices = faulty_slices[key]
            self.data = np.delete(self.data, obj=indices, axis=axis)
            self.labels = np.delete(self.labels, obj=indices, axis=axis)

    def __process_class_labels(self):
        """Remap the labels to consecutive values and return the class count.

        Returns:
            The number of distinct label values.
        """
        # Labels must be in the range [0, number_of_classes) for the loss function to work properly
        #
        # The remapping is done in a single pass. Rewriting the array one
        # value at a time would merge classes whenever a new index equals a
        # not-yet-processed label (e.g. labels {-1, 0, 1}).
        label_values, inverse = np.unique(self.labels, return_inverse=True)

        # reshape: the inverse is flat on some NumPy versions.
        # astype: keep the dtype the labels were loaded with.
        self.labels = inverse.reshape(self.labels.shape).astype(self.labels.dtype, copy=False)

        return len(label_values)

    def get_class_weights(self):
        """Return the class weights, or ``None`` if they were not computed."""
        return self.weights

    def get_n_classes(self):
        """Return the number of distinct classes found in the labels."""
        return self.n_classes
