This repository contains data preprocessing scripts for handling scroll dataset images and segmentation masks. The pipeline includes dataset preparation, normalization, augmentation, and model training using PyTorch Lightning and TimeSformer for temporal segmentation.
├── data_preprocessing/
│ ├── dataset.py # Custom PyTorch Dataset class for scroll images
│ ├── main.py # Main script for training the segmentation model
│ ├── model.py # Implementation of TimeSformer segmentation model
│ ├── preprocessing.py # Prepares datasets from fragments and segments
│ ├── scheduler.py # Learning rate scheduler with warm-up
│ ├── utils.py # Utility functions for seed setting and directory creation
│ |── config/
│ ├── config.yaml # Configuration file for Hydra
│── setup.py # Installation script for pip
│── vesuvius.yaml # Conda environment file
│── README.md # Project documentation
To create a Conda environment with all dependencies:
conda env create -f vesuvius.yaml
conda activate vesuvius
Then run the following to use the setup.py file.
pip install -e .
Before running the preprocessing or training scripts, ensure your data is structured as follows:
<comp_dataset_path>/ # Root data directory (specified in config.yaml)
├── Fragments/
│ ├── Fragment1/
│ │ ├── cache/
│ │ │ ├── data.zarr/ # Preprocessed fragment data
│ ├── Fragment2/
│ │ ├── cache/
│ │ │ ├── data.zarr/
│ ├── ...
│
├── Scroll1/
│ ├── segments/
│ │ ├── 20230503225234/
│ │ │ ├── cache/
│ │ │ │ ├── data.zarr/ # Preprocessed segment data
│ │ ├── 20230504093154/
│ │ │ ├── cache/
│ │ │ │ ├── data.zarr/
│ │ ├── 20230504094316/
│ │ │ ├── cache/
│ │ │ │ ├── data.zarr/
│ │ ├── ...
│
├── Scroll2/
│ ├── segments/
│ │ ├── <segment_number>/
│ │ │ ├── cache/
│ │ │ │ ├── data.zarr/
│ │ ├── ...
Each data.zarr/
directory contains processed fragment or segment data. This is the dataset that will be available for download from our submission (more details on where/how to access this will be added later).
If your dataset is structured exactly as shown above, you do not need to preprocess the data manually. Instead, simply run the training script:
python data_preprocessing/main.py
To skip preprocessing and load cached .zarr/
files instead, ensure use_cache
is enabled in config.yaml:
comp_dataset_path: "/path/to/your/data" # Update this path to your dataset location
use_cache: True # Load preprocessed data instead of recomputing
This ensures that main.py will automatically read from the existing cache/data.zarr/
directories, speeding up the pipeline.
Before training, preprocess the dataset using:
python data_preprocessing/preprocessing.py
This script processes image fragments and segmentation masks, splitting them into train, validation, and test sets.
You can modify dataset parameters in config/config.yaml before running the script. Some key parameters you might want to tweak:
# Fragments to process
fragment_ids: [1, 2, 3] # List of fragment numbers to include
# Specify a single fragment for validation
validation_fragment_id: 1 # Set to "-1" to disable validation
# Segments to process
scrolls:
- scroll_id: 2
num_segments: -1 # "-1" means process all segments from this scroll
# - scroll_id: 3
# num_segments: 10 # Reads 10 random segments from this scroll
# Validation segment frequency
validation_segments_mod: 5 # Every 5th segment is used for validation (set "-1" to disable)
fragment_ids
: List of fragment numbers to process.validation_fragment_id
: Select a single fragment for validation (-1 disables validation).scrolls
: Specifies which scrolls to read and how many random segments to include.num_segments
:- -1 → Read all segments from this scroll
- n → Select n random segments from this scroll (Replace n with an actual value)
validation_segments_mod
: Controls how often a segment is assigned to validation.
🚨 Important: During data collection, no explicit validation data was specified so by default, everything is saved as training data. If needed, adjust validation_fragment_id
and validation_segments_mod
to designate validation samples.
To train the segmentation model, run:
python data_preprocessing/main.py
This will:
- Load the dataset using
ScrollDataset
- Train the
TimeSformerModel
with PyTorch Lightning - Log training metrics and model checkpoints using Weights & Biases (wandb)
The model is based on TimeSformer, a video transformer adapted for scroll image segmentation. The model structure is implemented in data_preprocessing/model.py
.
Defines ScrollDataset
, a custom PyTorch Dataset class that:
- Loads image and mask patches
- Applies temporal augmentations
- Returns training/validation data
Implements TimeSformerModel
, a spatiotemporal transformer that:
- Uses
TimeSformer
for feature extraction - Optimizes with Dice Loss + BCE Loss
- Implements gradient clipping for stability
Handles dataset preparation:
- Loads images/masks from fragments & segments
- Applies tiling and augmentation
- Saves preprocessed data as NumPy arrays
Main script that:
- Loads configuration via Hydra
- Initializes dataset and dataloaders
- Runs the training loop using PyTorch Lightning
- Logs to Weights & Biases (wandb)
Helper functions for:
- Setting random seed for reproducibility
- Creating directories dynamically
Implements a warm-up learning rate scheduler to stabilize training.
The pipeline is configured using Hydra (config/config.yaml
). Modify hyperparameters such as:
train_batch_size: 8
valid_batch_size: 4
epochs: 50
lr: 0.0001
tile_size: 256
stride: 128
Training logs are stored in Weights & Biases (wandb). Ensure wandb
is properly configured:
wandb login
To visualize results:
wandb online
This project is licensed under the MIT License.
- Aryan Gulati
- Aditya Kumar
- Jaiv Doshi
- Leslie Moreno