Datamodule for SpaceNet1 (microsoft#965)
* Add SpaceNet1 datamodule

* Running black and isort

* version added

* Fix docs

* SpaceNet1 tests

* Testing spacenet datamodule with trainers

* no loveda

* black

* doc fix

* Removing direct datamodule test

* Make sure percent normalization doesn't divide by zero

* Speed up preprocessing

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
calebrob6 and adamjstewart authored Dec 24, 2022
1 parent 36529a2 commit cde1187
Showing 11 changed files with 248 additions and 7 deletions.
25 changes: 25 additions & 0 deletions conf/spacenet1.yaml
@@ -0,0 +1,25 @@
program:
  overwrite: False
  seed: 0
trainer:
  gpus: [3]
  min_epochs: 50
  max_epochs: 200
  benchmark: True
experiment:
  name: "spacenet-example"
  task: "spacenet1"
  module:
    loss: "ce"
    model: "unet"
    backbone: "resnet18"
    weights: "imagenet"
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
    in_channels: 3
    num_classes: 3
    ignore_index: 0
  datamodule:
    root: "data/spacenet"
    batch_size: 32
    num_workers: 4
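
Note: train.py reads this file with omegaconf and, as far as the mapping below in train.py shows, forwards the experiment.module and experiment.datamodule sections to the task and datamodule constructors. A minimal sketch of inspecting the config, assuming only the omegaconf package:

from omegaconf import OmegaConf

conf = OmegaConf.load("conf/spacenet1.yaml")
print(conf.experiment.task)             # "spacenet1"
print(conf.experiment.datamodule.root)  # "data/spacenet"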
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
@@ -94,6 +94,11 @@ So2Sat
 
 .. autoclass:: So2SatDataModule
 
+SpaceNet
+^^^^^^^^
+
+.. autoclass:: SpaceNet1DataModule
+
 Tropical Cyclone
 ^^^^^^^^^^^^^^^^
 
20 changes: 20 additions & 0 deletions tests/conf/spacenet1.yaml
@@ -0,0 +1,20 @@
experiment:
  task: "spacenet1"
  module:
    loss: "ce"
    model: "unet"
    backbone: "resnet18"
    weights: null
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
    verbose: false
    in_channels: 3
    num_classes: 3
    num_filters: 1
    ignore_index: null
  datamodule:
    root: "tests/data/spacenet"
    batch_size: 1
    num_workers: 0
    val_split_pct: 0.33
    test_split_pct: 0.33
5 changes: 2 additions & 3 deletions tests/data/spacenet/data.py
@@ -174,8 +174,6 @@ def create_test_label(
 def main() -> None:
     ROOT_DIR = os.path.dirname(os.path.realpath(__file__))
 
-    num_samples = 2
-
     for dataset in datasets:
 
         collections = list(dataset.collection_md5_dict.keys())
@@ -187,11 +185,12 @@ def main() -> None:
             "sn5_AOI_8_Mumbai",
             "sn7_test_source",
         ]:
-            num_samples = 2
+            num_samples = 3
         elif collection == "sn5_AOI_8_Mumbai":
             num_samples = 3
         else:
             num_samples = 1
+
         for sample in range(num_samples):
             out_dir = os.path.join(ROOT_DIR, collection)
             if collection == "sn6_AOI_11_Rotterdam":
Expand Down
Binary file modified tests/data/spacenet/sn1_AOI_1_RIO.tar.gz
4 changes: 2 additions & 2 deletions tests/datasets/test_spacenet.py
@@ -63,7 +63,7 @@ def dataset(
         self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
     ) -> SpaceNet1:
         monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_collection)
-        test_md5 = {"sn1_AOI_1_RIO": "246e27fcd7ae73496212a7f585a43dbb"}
+        test_md5 = {"sn1_AOI_1_RIO": "127a523561987110f008e8c9815ce807"}
 
         # Refer https://github.com/python/mypy/issues/1032
         monkeypatch.setattr(SpaceNet1, "collection_md5_dict", test_md5)
@@ -85,7 +85,7 @@ def test_getitem(self, dataset: SpaceNet1) -> None:
         assert x["image"].shape[0] == 8
 
     def test_len(self, dataset: SpaceNet1) -> None:
-        assert len(dataset) == 2
+        assert len(dataset) == 3
 
     def test_already_downloaded(self, dataset: SpaceNet1) -> None:
         SpaceNet1(root=dataset.root, download=True)
2 changes: 2 additions & 0 deletions tests/trainers/test_segmentation.py
@@ -20,6 +20,7 @@
     NAIPChesapeakeDataModule,
     OSCDDataModule,
     SEN12MSDataModule,
+    SpaceNet1DataModule,
 )
 from torchgeo.datasets import LandCoverAI
 from torchgeo.trainers import SemanticSegmentationTask
@@ -48,6 +49,7 @@ class TestSemanticSegmentationTask:
             ("sen12ms_s1", SEN12MSDataModule),
             ("sen12ms_s2_all", SEN12MSDataModule),
             ("sen12ms_s2_reduced", SEN12MSDataModule),
+            ("spacenet1", SpaceNet1DataModule),
         ],
     )
     def test_trainer(
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
@@ -21,6 +21,7 @@
 from .resisc45 import RESISC45DataModule
 from .sen12ms import SEN12MSDataModule
 from .so2sat import So2SatDataModule
+from .spacenet import SpaceNet1DataModule
 from .ucmerced import UCMercedDataModule
 from .usavars import USAVarsDataModule
 from .vaihingen import Vaihingen2DDataModule
@@ -46,6 +47,7 @@
     "RESISC45DataModule",
     "SEN12MSDataModule",
     "So2SatDataModule",
+    "SpaceNet1DataModule",
     "TropicalCycloneDataModule",
     "UCMercedDataModule",
     "USAVarsDataModule",
182 changes: 182 additions & 0 deletions torchgeo/datamodules/spacenet.py
@@ -0,0 +1,182 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""SpaceNet datamodules."""

from typing import Any, Dict, Optional

import kornia.augmentation as K
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch.utils.data import DataLoader

from ..datasets import SpaceNet1
from .utils import dataset_split

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class SpaceNet1DataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the SpaceNet1 dataset.

    Randomly splits into train/val/test.

    .. versionadded:: 0.4
    """

    def __init__(
        self,
        batch_size: int = 64,
        num_workers: int = 0,
        val_split_pct: float = 0.1,
        test_split_pct: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for SpaceNet1.

        Args:
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: What percentage of the dataset to use as a validation set
            test_split_pct: What percentage of the dataset to use as a test set
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.SpaceNet1`
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split_pct = val_split_pct
        self.test_split_pct = test_split_pct
        self.kwargs = kwargs

        self.padto = K.PadTo((448, 448))

    def on_after_batch_transfer(
        self, batch: Dict[str, Any], batch_idx: int
    ) -> Dict[str, Any]:
        """Apply batch augmentations after batch is transferred to the device.

        Args:
            batch: mini-batch of data
            batch_idx: batch index

        Returns:
            augmented mini-batch
        """
        if (
            hasattr(self, "trainer")
            and self.trainer is not None
            and hasattr(self.trainer, "training")
            and self.trainer.training
        ):
            # Kornia expects masks to be floats with a channel dimension
            x = batch["image"]
            y = batch["mask"].float().unsqueeze(1)

            train_augmentations = K.AugmentationSequential(
                K.RandomRotation(p=0.5, degrees=90),
                K.RandomHorizontalFlip(p=0.5),
                K.RandomVerticalFlip(p=0.5),
                K.RandomSharpness(p=0.5),
                K.ColorJitter(
                    p=0.5,
                    brightness=0.1,
                    contrast=0.1,
                    saturation=0.1,
                    hue=0.1,
                    silence_instantiation_warning=True,
                ),
                data_keys=["input", "mask"],
            )
            x, y = train_augmentations(x, y)

            # torchmetrics expects masks to be longs without a channel dimension
            batch["image"] = x
            batch["mask"] = y.squeeze(1).long()

        return batch

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        Args:
            sample: dictionary containing image and mask

        Returns:
            preprocessed sample
        """
        sample["image"] = sample["image"].float() / 255
        sample["image"] = self.padto(sample["image"]).squeeze()

        if "mask" in sample:
            # We add 1 to the mask to map the current {background, building} labels
            # to the values {1, 2}. This is necessary because we add 0 padding to
            # the mask that we want to ignore in the loss function.
            sample["mask"] = self.padto(sample["mask"].float() + 1).squeeze()
            sample["mask"] = sample["mask"].long()
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        SpaceNet1(**self.kwargs)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        self.dataset = SpaceNet1(transforms=self.preprocess, **self.kwargs)
        self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
            self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
        """Run :meth:`torchgeo.datasets.SpaceNet.plot`."""
        return self.dataset.plot(*args, **kwargs)
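
Note: a hypothetical end-to-end sketch of the new datamodule, assuming SpaceNet1 imagery is already present under the illustrative root below (otherwise prepare_data would attempt a Radiant MLHub download):

from torchgeo.datamodules import SpaceNet1DataModule

dm = SpaceNet1DataModule(root="data/spacenet", batch_size=32, num_workers=4)
dm.prepare_data()
dm.setup()
batch = next(iter(dm.train_dataloader()))
print(batch["image"].shape)    # expected (32, 3, 448, 448) after PadTo
print(batch["mask"].unique())  # expected values in {0, 1, 2}; 0 marks padding

The zero padding introduced by PadTo is why preprocess shifts the labels to {1, 2} and why conf/spacenet1.yaml sets ignore_index: 0.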
2 changes: 1 addition & 1 deletion torchgeo/datasets/utils.py
@@ -704,6 +704,6 @@ def percentile_normalization(
     lower_percentile = np.percentile(img, lower, axis=axis)
     upper_percentile = np.percentile(img, upper, axis=axis)
     img_normalized: "np.typing.NDArray[np.int_]" = np.clip(
-        (img - lower_percentile) / (upper_percentile - lower_percentile), 0, 1
+        (img - lower_percentile) / (upper_percentile - lower_percentile + 1e-5), 0, 1
     )
     return img_normalized
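
Note: a quick check of the epsilon guard, assuming a constant-valued image whose lower and upper percentiles coincide; before this change the division produced NaNs:

import numpy as np

img = np.full((4, 4), 7, dtype=np.int32)
lower, upper = np.percentile(img, 2), np.percentile(img, 98)
out = np.clip((img - lower) / (upper - lower + 1e-5), 0, 1)
print(out.max())  # 0.0 rather than nan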
8 changes: 7 additions & 1 deletion train.py
@@ -27,6 +27,7 @@
     RESISC45DataModule,
     SEN12MSDataModule,
     So2SatDataModule,
+    SpaceNet1DataModule,
     TropicalCycloneDataModule,
     UCMercedDataModule,
 )
@@ -57,6 +58,7 @@
     "resisc45": (ClassificationTask, RESISC45DataModule),
     "sen12ms": (SemanticSegmentationTask, SEN12MSDataModule),
     "so2sat": (ClassificationTask, So2SatDataModule),
+    "spacenet1": (SemanticSegmentationTask, SpaceNet1DataModule),
     "ucmerced": (ClassificationTask, UCMercedDataModule),
 }

@@ -179,7 +181,11 @@ def main(conf: DictConfig) -> None:
         mode = "min"
 
     checkpoint_callback = ModelCheckpoint(
-        monitor=monitor_metric, dirpath=experiment_dir, save_top_k=1, save_last=True
+        monitor=monitor_metric,
+        filename="checkpoint-epoch{epoch:02d}-val_loss{val_loss:.2f}",
+        dirpath=experiment_dir,
+        save_top_k=1,
+        save_last=True,
     )
     early_stopping_callback = EarlyStopping(
         monitor=monitor_metric, min_delta=0.00, patience=18, mode=mode
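
Note: pytorch_lightning fills the named fields of the filename template from logged metrics and appends ".ckpt", so the best checkpoint now gets a self-describing name. An illustration with made-up values, using plain string formatting to mimic the result:

template = "checkpoint-epoch{epoch:02d}-val_loss{val_loss:.2f}"
print(template.format(epoch=7, val_loss=0.1234) + ".ckpt")
# checkpoint-epoch07-val_loss0.12.ckpt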
