# Fault Segmentation Prediction Notebook

# General Explanation

![UNet++ Architecture](image/unet_plus_plus_archi.png)

*The above figure is from the [UNet++ paper](https://arxiv.org/pdf/1807.10165).*

I used a **2D UNet++** model from the `segmentation_models_pytorch` library, with an **EfficientNet-B8** encoder pretrained on ImageNet.

This approach yielded strong performance thanks to the following key factors:

### 2.5D Architecture
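I predict each frame from a stack of 7 consecutive frames: the 3 previous frames, the current frame, and the 3 following frames, fed to the 2D network as input channels (hence `nchans7` in the checkpoint names).  
Giving the 2D model this local 3D context improves the spatial consistency of predictions between neighboring frames and makes the model more robust.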
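A minimal NumPy sketch of how such a 7-frame input could be assembled for one frame. The function name `make_25d_input` is hypothetical, and clamping out-of-range neighbors at the volume borders (repeating the first or last frame) is an assumption; the actual pipeline may pad the edges differently.

```python
import numpy as np


def make_25d_input(volume: np.ndarray, index: int, axis: int = 0, context: int = 3) -> np.ndarray:
    """Stack frame `index` and its `context` neighbors on each side as channels.

    Hypothetical helper. Assumes `volume` is a 3D array, frames are taken along
    `axis`, and missing neighbors at the borders are replaced by the nearest
    valid frame (edge clamping).
    """
    n_frames = volume.shape[axis]
    # Neighbor indices index-3 .. index+3, clamped to [0, n_frames - 1]
    offsets = np.arange(-context, context + 1)
    indices = np.clip(index + offsets, 0, n_frames - 1)
    # np.take stacks the selected frames along `axis`; move that axis to the front
    frames = np.take(volume, indices, axis=axis)
    return np.moveaxis(frames, axis, 0)  # shape: (2 * context + 1, H, W)
```

With the default `context=3`, the result has 7 channels, one per frame, with the current frame in the middle channel.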

### Multi-axis Training (x and y)
For each fold, I trained a single model on frames sliced along two different axes:
- The **x** axis, with frame shape (300, 1259)  
- The **y** axis, also with frame shape (300, 1259)

At inference time, the same model predicts along both axes and the two prediction volumes are averaged, which gives the ensemble two complementary views of every voxel.
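The sketch below illustrates the two slicing directions and how per-frame predictions can be put back into a volume and averaged. It assumes volumes are stored as `(x, y, z)` arrays, so that a `(300, 300, 1259)` volume yields 300 frames of shape `(300, 1259)` along each axis. The helpers `slice_along`, `restack`, and `predict_both_axes` are hypothetical, and for brevity they feed single frames to `predict_frame` instead of 7-frame stacks.

```python
import numpy as np

# Assumed axis order of the stored volumes: (x, y, z)
AXIS_INDEX = {"x": 0, "y": 1}


def slice_along(volume: np.ndarray, axis: str) -> list:
    """Return the 2D frames of `volume` taken along 'x' (axis 0) or 'y' (axis 1)."""
    ax = AXIS_INDEX[axis]
    return [np.take(volume, i, axis=ax) for i in range(volume.shape[ax])]


def restack(frames: list, axis: str) -> np.ndarray:
    """Reassemble per-frame 2D predictions into a 3D volume along the same axis."""
    return np.stack(frames, axis=AXIS_INDEX[axis])


def predict_both_axes(volume: np.ndarray, predict_frame) -> np.ndarray:
    """Average the x-axis and y-axis reconstructions of a per-frame predictor."""
    px = restack([predict_frame(f) for f in slice_along(volume, "x")], "x")
    py = restack([predict_frame(f) for f in slice_along(volume, "y")], "y")
    return (px + py) / 2.0
```

Because each axis is restacked along the same axis it was sliced from, both reconstructions share the volume's original shape and can be averaged voxel by voxel.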

### Data Augmentation
I applied transformations such as a 180° rotation and horizontal/vertical flips to help the model generalize better.
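A minimal sketch of paired augmentation, assuming the input is a `(C, H, W)` stack of frames and the mask is `(H, W)`. The `augment` function and the 50% probabilities are illustrative, not the exact training settings. The key point is that the image and the mask must receive the same transformation.

```python
import numpy as np


def augment(image: np.ndarray, mask: np.ndarray, rng: np.random.Generator):
    """Apply the same random flips / 180-degree rotation to a (C, H, W) input and its (H, W) mask.

    Hypothetical helper; the probabilities are illustrative assumptions.
    """
    if rng.random() < 0.5:  # vertical flip (along H)
        image, mask = image[:, ::-1, :], mask[::-1, :]
    if rng.random() < 0.5:  # horizontal flip (along W)
        image, mask = image[:, :, ::-1], mask[:, ::-1]
    if rng.random() < 0.5:  # 180-degree rotation in the (H, W) plane
        image, mask = np.rot90(image, 2, axes=(1, 2)), np.rot90(mask, 2)
    # Copy to drop negative strides (useful before converting to tensors)
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)
```

Note that a 180° rotation equals a horizontal flip combined with a vertical flip, so these three transforms together produce at most four distinct orientations.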

### Model Ensembling
I split the dataset into 6 folds, each using 325 volumes for training and 65 for validation (390 volumes in total).  
Due to time constraints, my final solution uses only the models from the first 3 folds (folds 0 to 2), each trained for at most 12 epochs.
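For reference, a 6-fold split of this kind can be sketched as follows. `make_folds` is a hypothetical helper, and the shuffling and the seed are assumptions; with 390 volumes, `np.array_split` gives every fold exactly 65 validation and 325 training volumes.

```python
import numpy as np


def make_folds(sample_ids, n_folds: int = 6, seed: int = 42):
    """Split sample ids into `n_folds` (train_ids, val_ids) pairs after a seeded shuffle.

    Hypothetical helper: every id appears in exactly one validation fold.
    """
    rng = np.random.default_rng(seed)
    ids = np.array(sample_ids)
    rng.shuffle(ids)
    val_chunks = np.array_split(ids, n_folds)
    folds = []
    for k, val_ids in enumerate(val_chunks):
        # Training set = all the other validation chunks
        train_ids = np.concatenate([c for j, c in enumerate(val_chunks) if j != k])
        folds.append((train_ids, val_ids))
    return folds
```

Only `folds[0]` to `folds[2]` would then be needed to reproduce the 3-model ensemble.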

Together, these techniques (a **2D UNet++** model, **2.5D** prediction, **multi-axis training**, **data augmentation**, and **model ensembling**) contributed significantly to the strong fault segmentation results.

# How to use this notebook?

This notebook demonstrates how to generate fault predictions for 3D volumes using one or more model checkpoints. 
Each checkpoint is paired with an axis specification: `x`, `y`, or `xy` (the same checkpoint run along both the x and y axes). 

Below is an overview of the steps:

0. **Convert 3D Volumes to 2D Slices**: Slice each 3D volume along the x and y axes and write the 2D frames to disk.
1. **Configuration**: Define the parameters for:
   - Model checkpoints and corresponding axes.
   - Input dataset path and optional volume filtering.
   - Batch size, number of workers, etc.
   - Output settings (where to save the final submission, probability volumes, thresholded volumes).
   - Performance optimizations such as model compilation or forcing CPU usage.
2. **Dataset Loading**: Load the dataset index (e.g., a parquet file) containing volume information.
3. **Volume Filtering**: Optionally filter the dataset to process only specific volumes.
4. **Prediction**: For each checkpoint and axis specification, generate the prediction volumes.
5. **Ensembling**: Average all prediction volumes.
6. **Thresholding**: Apply a probability threshold to the ensembled predictions to generate a final binary mask.
7. **Confidence-Based Zeroing**: Optionally zero out an entire volume if its mean confidence 
   (across the predicted mask) is below a specified threshold (`min_mean_conf`), on the assumption 
   that such a volume contains no faults at all.
8. **Submission**: Create the final `.npz` file (and optionally save intermediate probability and thresholded volumes).
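Steps 5 to 7 can be sketched for a single volume as below. `ensemble_threshold_zero` is a hypothetical stand-in for part of `ensemble_volumes_and_save`; it follows the wording of step 7 literally, computing the mean confidence as the mean ensembled probability over the voxels of the predicted mask, and it uses a strict `>` threshold. Both choices are assumptions.

```python
import numpy as np


def ensemble_threshold_zero(probas, save_threshold=0.5, min_mean_conf=None):
    """Average probability volumes, threshold them, and optionally zero a low-confidence volume.

    Hypothetical sketch. `probas` is a list of same-shaped probability volumes for ONE
    sample (one entry per checkpoint/axis combination).
    """
    mean_proba = np.mean(np.stack(probas, axis=0), axis=0)  # step 5: ensembling
    mask = mean_proba > save_threshold                        # step 6: thresholding
    if min_mean_conf is not None and mask.any():
        # Step 7 (assumed definition): mean ensembled probability inside the predicted mask
        mean_conf = float(mean_proba[mask].mean())
        if mean_conf < min_mean_conf:
            mask[:] = False  # treat the whole volume as fault-free
    return mask.astype(np.uint8)
```

Under this literal definition, every voxel inside the mask already exceeds `save_threshold`, so `min_mean_conf` only has an effect when set above it. Since the default configuration uses `min_mean_conf = 0.1` with `save_threshold = 0.5`, the actual implementation presumably measures confidence differently; check `src/predict_utils.py` before relying on this sketch.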

Let's begin!

## 0. Convert 3D Volumes to 2D Slices
Before running the prediction pipeline, we need to convert the 3D volumes into 2D slices, because the models take 2D inputs (stacks of neighboring frames) rather than full 3D volumes.

### Function Overview
The `write_2d_slices` function converts 3D volumes into 2D slices. It takes the following parameters:

- `root_dir`: The directory containing your 3D volume data.
- `output_dir`: The directory where the 2D slices will be saved. It is later used as `root_dir` in the prediction configuration, which expects it to contain `dataset.parquet` alongside the slices.
- `axes`: The axes along which to slice (here `["x", "y"]`, the axes the models were trained on).
- `num_workers`: The number of workers used for the conversion (adjust to your machine).
- `mode`: (Optional) If `root_dir` contains multiple data parts (e.g., both training and test data), set `mode="test"` to convert only the test volumes.

### Usage
First, make sure the `write_2d_slices` function can be imported from your `src/write_2d_slices.py` module.

```python
# Define the directories
ROOT_PATH_TO_ALL_TEST_PARTS = "/data/datasets/dark-size-test"      # Directory with the 3D volumes
PATH_TO_2D_SLICES_OUTPUT = "/data/datasets/darkside-test-data-2d"  # Directory where the 2D slices are saved
```

```python
# Import the write_2d_slices function
from src.write_2d_slices import write_2d_slices

# (Optional) Specify the mode if root_dir contains both train and test data parts
CONVERSION_MODE = "test"  # Use "test" to convert only the test volumes

# Convert the 3D volumes to 2D slices
write_2d_slices(
    root_dir=ROOT_PATH_TO_ALL_TEST_PARTS,
    output_dir=PATH_TO_2D_SLICES_OUTPUT,
    axes=["x", "y"],       # Axes along which to slice
    num_workers=25,        # Adjust to your machine's capabilities
    mode=CONVERSION_MODE,  # Optional: restrict the conversion to one data part
)

print("3D volumes have been successfully converted to 2D slices.")
```


## 1. Imports

We'll start by importing the necessary libraries. 
The `predict_single_checkpoint`, `ensemble_volumes_and_save`, and `build_final_volumes_dir_name` functions 
are assumed to be defined in the `src.predict_utils` module.

```python
import os
from pathlib import Path
from typing import List

import pandas as pd
import torch

# Assuming these are defined in a local Python file at src/predict_utils.py
from src.predict_utils import (
    build_final_volumes_dir_name,
    predict_single_checkpoint,
    ensemble_volumes_and_save,
)

print("Imports done.")
```


## 2. Configuration

In the cell below, we define a configuration dictionary `args` that simulates
what would typically come from command-line arguments. 
The important keys in `args` are:

- **checkpoints**: List of paths to your `.ckpt` model checkpoints.
- **axes**: List of axes for each checkpoint (`'x'`, `'y'`, or `'xy'`). 
  - For example, if `axes[i] == 'xy'`, that means that particular checkpoint should be run on the `x` axis and the `y` axis.
  - The length of `checkpoints` must match the length of `axes`.
- **root_dir**: The directory where your 2D slices (and `dataset.parquet`) are stored.
- **vol_filter**: (Optional) If not `None`, a list of volume IDs (`sample_id`) to process.
- **batch_size**: Batch size for inference.
- **num_workers**: Number of workers for the DataLoader.
- **save_threshold**: Probability threshold to apply to the averaged predictions.
- **min_mean_conf**: Minimum mean confidence required to keep the prediction (if below, the volume is zeroed).
- **submission_path**: Path where the final `.npz` submission file will be saved.
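- **save_probas**: If `True`, save raw probability volumes in `predictions_probas/{model_name}/{axis}` next to each checkpoint. Keep this `True`: the ensembling step reads the predictions back from these directories.
- **save_final**: If `True`, also save the final thresholded volumes in a directory whose name encodes the checkpoints and axes used.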
- **compile**: If `True`, compile the model for optimized performance (PyTorch 2.x feature).
- **cpu**: If `True`, run inference on CPU (otherwise use CUDA if available).
- **dtype**: The data type for inference (`float16`, `float32`, or `bf16`).
- **force_prediction**: If `True`, recompute predictions even if saved probability volumes from a previous run already exist.

Feel free to modify the values below to match your environment.


```python
# Configuration dictionary
args = {
    "checkpoints": [
        # Example: "path/to/checkpoint_model_a.ckpt",
        #          "path/to/checkpoint_model_b.ckpt"
        "checkpoints/checkpoints_unetpp_timm-efficientnet-b8_nchans7_val_axisxy_20250102_152812/fold0-best-model-epoch=09-val_dice_3d=0.8959.ckpt",
        "checkpoints/checkpoints_unetpp_timm-efficientnet-b8_nchans7_val_axisxy_20250102_152812/fold1-best-model-epoch=10-val_dice_3d=0.8988.ckpt",
        "checkpoints/checkpoints_unetpp_timm-efficientnet-b8_nchans7_val_axisxy_20250102_152812/fold2-best-model-epoch=10-val_dice_3d=0.8870.ckpt",
    ],
    "axes": [
        # Must be 'x', 'y', or 'xy' and match the number of checkpoints.
        # If a checkpoint covers 2 axes ("xy"), that single checkpoint
        # is used to predict on axis 'x' and on axis 'y'.
        "xy",
        "xy",
        "xy",
    ],
    "root_dir": PATH_TO_2D_SLICES_OUTPUT,  # Contains dataset.parquet and the 2D slices
    "vol_filter": None,        # List of sample_ids to keep, or None for no filter
    "batch_size": 8,
    "num_workers": 8,
    "save_threshold": 0.5,
    "min_mean_conf": 0.1,      # If not None, a volume is zeroed if its mean confidence < this value
    "submission_path": "submission.npz",
    "save_probas": True,       # Keep True: the ensembling step reads these probability volumes back
    "save_final": False,       # Whether to save the thresholded volumes (just before they are encoded in the submission)
    "compile": False,
    "cpu": False,
    "dtype": "bf16",
    "force_prediction": False, # Recompute predictions even if they already exist from a previous run with save_probas=True
}

print("Configuration set.")
```

print("Configuration set.")


## 3. Setting Up the Device

We'll determine whether to use the CPU or GPU (CUDA) for predictions.

```python
device = 'cpu' if args["cpu"] else ('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
```

## 4. Basic Checks and Axis Counting

We check that the number of checkpoints matches the number of axis specifications. 
We also count the total number of axes, so that we can warn the user when only one 
checkpoint/axis combination is used: in that case, `min_mean_conf` is ignored.

```python
if len(args["checkpoints"]) != len(args["axes"]):
    raise ValueError("Number of checkpoints must match number of axes specifications.")

# Count total axes across all specs
total_axes = 0
for axis_spec in args["axes"]:
    if axis_spec.lower() == 'xy':
        total_axes += 2
    else:
        total_axes += 1

# If there's only one model/axis combination but min_mean_conf is set, warn the user
if len(args["checkpoints"]) == 1 and total_axes == 1 and args["min_mean_conf"] is not None:
    print(
        f"Warning: You set a min_mean_conf ({args['min_mean_conf']}) but there's only one model and one axis. "
        "Confidence-based filtering will be ignored."
    )
```
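The `'xy'` to `['x', 'y']` expansion in the cell above is repeated in sections 7 and 8. A hypothetical helper such as `expand_axis_spec` could centralize it and also reject invalid specifications, which the notebook currently does not check:

```python
# Hypothetical mapping from an axis specification to the axes it runs on
VALID_AXIS_SPECS = {"x": ["x"], "y": ["y"], "xy": ["x", "y"]}


def expand_axis_spec(axis_spec: str) -> list:
    """Map an axis spec ('x', 'y' or 'xy', case-insensitive) to the list of axes to predict on."""
    key = axis_spec.lower()
    if key not in VALID_AXIS_SPECS:
        raise ValueError(f"Invalid axis spec {axis_spec!r}; expected 'x', 'y' or 'xy'.")
    return list(VALID_AXIS_SPECS[key])


def count_total_axes(axis_specs) -> int:
    """Total number of checkpoint/axis combinations, e.g. 6 for ['xy', 'xy', 'xy']."""
    return sum(len(expand_axis_spec(spec)) for spec in axis_specs)
```

With such a helper, the loops in sections 7 and 8 could use `for ax in expand_axis_spec(axis_spec):` directly.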


## 5. Loading the Dataset

We assume there's a `dataset.parquet` file in `args["root_dir"]` which contains 
information about each sample (volume). We'll load it into a Pandas DataFrame.

```python
df_path = os.path.join(args["root_dir"], 'dataset.parquet')
full_df = pd.read_parquet(df_path)
print("Dataset index loaded.")
print(f"Total samples in dataset: {len(full_df)}")
```


## 6. Optional Volume Filtering (for quick debugging only; do not use this option in production)

If `args["vol_filter"]` is provided, we'll keep only the specified volume IDs. 
Otherwise, we'll process all samples.

```python
if args["vol_filter"]:
    print(f"Applying volume filter with {len(args['vol_filter'])} sample_id(s).")
    initial_count = len(full_df)
    filtered_df = full_df[full_df['sample_id'].isin(args["vol_filter"])].copy()
    final_count = len(filtered_df)
    missing_samples = set(args["vol_filter"]) - set(filtered_df['sample_id'].unique())

    print(f"Number of samples after filtering: {final_count} (filtered out {initial_count - final_count} samples).")
    if missing_samples:
        # map(str, ...) so that non-string sample_ids can be joined too
        print(
            "Warning: The following sample_id(s) were not found in the dataset and will be ignored: "
            f"{', '.join(sorted(map(str, missing_samples)))}"
        )
    full_df = filtered_df
else:
    print("No volume filter applied. Processing all samples.")

print(f"Total samples to process: {len(full_df)}")
```

## 7. Generating Predictions for Each Checkpoint & Axis

For each entry in `args["checkpoints"]` and its corresponding `args["axes"]`, we call 
the `predict_single_checkpoint()` function, which reads the 2D slices listed in the DataFrame, 
loads the model, and generates the predictions.

- If `axes[i] == 'xy'`, we will run predictions on both `x` and `y` axes using the same checkpoint.
- We also pass:
  - `batch_size`, `num_workers`
  - `save_probas`: whether to save probability volumes
  - `force_prediction`: whether to overwrite existing predictions
  - `root_dir`: path to slices
  - `dtype`, `compile_model`, `cpu`, `device`: for inference configuration.

```python
for ckpt_path_str, axis_spec in zip(args["checkpoints"], args["axes"]):
    checkpoint_path = Path(ckpt_path_str)

    # For 'xy', we do x, then y
    if axis_spec.lower() == 'xy':
        axis_list = ['x', 'y']
    else:
        axis_list = [axis_spec.lower()]

    for ax in axis_list:
        print(f"Predicting for Checkpoint: {checkpoint_path.name}, Axis: {ax}")
        predict_single_checkpoint(
            checkpoint_path=checkpoint_path,
            axis=ax,
            full_df=full_df,
            batch_size=args["batch_size"],
            num_workers=args["num_workers"],
            save_probas=args["save_probas"],
            force_prediction=args["force_prediction"],
            root_dir=args["root_dir"],
            dtype=args["dtype"],
            compile_model=args["compile"],
            cpu=args["cpu"],
            device=device
        )

print("All checkpoints processed.")
```

## 8. Gathering Predictions and Ensembling

Now that we've generated predictions for each checkpoint and axis, we collect 
the corresponding directories in order to perform an ensembling step 
(e.g., averaging predictions).

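We look for sub-folders of the form `predictions_probas/{model_name}/{axis}` inside each 
checkpoint's directory, where `model_name` is the checkpoint file name without its `.ckpt` extension.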

Afterwards, we can apply a final threshold (and optionally `min_mean_conf`) 
to the ensembled probability volumes. 


```python
# Collect the prediction directories for ensembling
prediction_dirs: List[Path] = []
for ckpt_path_str, axis_spec in zip(args["checkpoints"], args["axes"]):
    checkpoint_path = Path(ckpt_path_str)
    model_name = checkpoint_path.stem  # checkpoint file name without the .ckpt extension

    # If 'xy', we expect sub-axes x and y
    if axis_spec.lower() == 'xy':
        sub_axes = ['x', 'y']
    else:
        sub_axes = [axis_spec.lower()]

    for sub_ax in sub_axes:
        pred_dir = checkpoint_path.parent / 'predictions_probas' / model_name / sub_ax
        print(f"Looking for predictions in {pred_dir}")
        if pred_dir.exists():
            prediction_dirs.append(pred_dir)
        else:
            print(f"Warning: {pred_dir} not found; it will be left out of the ensemble.")

print(f"Found {len(prediction_dirs)} relevant prediction directories for ensembling.")
if not prediction_dirs:
    raise RuntimeError(
        "No prediction directories found. Make sure save_probas was True when running the predictions."
    )
```

## 9. Building Final Volumes and Creating Submission

If `args["save_final"]` is `True`, we first build a folder name that encodes which 
checkpoints/axes were used. We then call `ensemble_volumes_and_save()`, which averages the 
predictions, applies the thresholds, and writes the submission (and, optionally, the final thresholded volumes).

```python
final_volumes_dir = None
if args["save_final"]:
    final_dir_name = build_final_volumes_dir_name(args["checkpoints"], args["axes"])
    final_volumes_dir = Path(final_dir_name)
    print(f"Final volumes directory will be '{final_dir_name}'")

ensemble_volumes_and_save(
    all_predictions=prediction_dirs,
    dataset_index=full_df,
    output_path=Path(args["submission_path"]),
    save_threshold=args["save_threshold"],
    device=device,
    min_mean_conf=args["min_mean_conf"],
    save_final_volumes=args["save_final"],
    final_volumes_dir=final_volumes_dir
)

print("Ensembling and submission creation completed successfully.")
print(f"Submission file: {args['submission_path']}")
```
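The encoding that `ensemble_volumes_and_save` applies when writing the submission is not shown in this notebook. Purely to illustrate the `.npz` container, the sketch below stores one `uint8` mask per `sample_id` with `np.savez_compressed` and reads it back; the helper names and the raw-mask layout are assumptions, and the real submission format may encode the masks differently.

```python
from pathlib import Path

import numpy as np


def save_masks_npz(masks: dict, output_path: Path) -> None:
    """Store one binary mask per sample_id as uint8 arrays in a compressed .npz archive.

    Hypothetical helper: keys are the sample_ids converted to strings.
    """
    arrays = {str(sample_id): np.asarray(mask, dtype=np.uint8) for sample_id, mask in masks.items()}
    np.savez_compressed(output_path, **arrays)


def load_masks_npz(path: Path) -> dict:
    """Load every array of an .npz archive into memory, keyed by sample_id."""
    with np.load(path) as data:
        return {key: data[key] for key in data.files}
```

Reading the archive back this way is a quick sanity check that every expected `sample_id` is present before uploading.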
