forked from microsoft/torchgeo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Datamodule for SpaceNet1 (microsoft#965)
* Add SpaceNet1 datamodule * Running black and isort * version added * Fix docs * SpaceNet1 tests * Testing spacenet datamodule with trainers * no loveda * black * doc fix * Removing direct datamodule test * Make sure percent normalization doesn't divide by zero * Speed up preprocessing Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
- Loading branch information
1 parent
36529a2
commit cde1187
Showing
11 changed files
with
248 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Example training configuration for the SpaceNet1 experiment. | ||
program: | ||
overwrite: False | ||
seed: 0 | ||
trainer: | ||
gpus: [3] | ||
min_epochs: 50 | ||
max_epochs: 200 | ||
benchmark: True | ||
experiment: | ||
name: "spacenet-example" | ||
# Task key for this experiment; it has to agree with the experiment name and the data root below. | ||
task: "spacenet1" | ||
module: | ||
loss: "ce" | ||
model: "unet" | ||
backbone: "resnet18" | ||
weights: "imagenet" | ||
learning_rate: 1e-3 | ||
learning_rate_schedule_patience: 6 | ||
in_channels: 3 | ||
num_classes: 3 | ||
ignore_index: 0 | ||
datamodule: | ||
root: "data/spacenet" | ||
batch_size: 32 | ||
num_workers: 4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Configuration for the "spacenet1" task, reading its data from tests/data/spacenet. | ||
experiment: | ||
task: "spacenet1" | ||
# Model and optimisation settings (no pretrained weights, no ignored label). | ||
module: | ||
loss: "ce" | ||
model: "unet" | ||
backbone: "resnet18" | ||
weights: null | ||
learning_rate: 1e-3 | ||
learning_rate_schedule_patience: 6 | ||
verbose: false | ||
in_channels: 3 | ||
num_classes: 3 | ||
num_filters: 1 | ||
ignore_index: null | ||
# Data loading settings: one sample per batch and no worker processes. | ||
datamodule: | ||
root: "tests/data/spacenet" | ||
batch_size: 1 | ||
num_workers: 0 | ||
# Share of the dataset held out for validation and for testing. | ||
val_split_pct: 0.33 | ||
test_split_pct: 0.33 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
"""SpaceNet datamodules.""" | ||
|
||
from typing import Any, Dict, Optional | ||
|
||
import kornia.augmentation as K | ||
import matplotlib.pyplot as plt | ||
import pytorch_lightning as pl | ||
from torch.utils.data import DataLoader | ||
|
||
from ..datasets import SpaceNet1 | ||
from .utils import dataset_split | ||
|
||
# Override the module path that DataLoader reports for itself, so that it | ||
# appears under the public "torch.utils.data" package. | ||
# NOTE(review): presumably needed so documentation tooling can resolve the | ||
# type used in the return annotations below - confirm against the links. | ||
# https://github.com/pytorch/pytorch/issues/60979 | ||
# https://github.com/pytorch/pytorch/pull/61045 | ||
DataLoader.__module__ = "torch.utils.data" | ||
|
||
|
||
class SpaceNet1DataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the SpaceNet1 dataset.

    Randomly splits into train/val/test.

    .. versionadded:: 0.4
    """

    def __init__(
        self,
        batch_size: int = 64,
        num_workers: int = 0,
        val_split_pct: float = 0.1,
        test_split_pct: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for SpaceNet1.

        Args:
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: What percentage of the dataset to use as a validation set
            test_split_pct: What percentage of the dataset to use as a test set
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.SpaceNet1`
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split_pct = val_split_pct
        self.test_split_pct = test_split_pct
        self.kwargs = kwargs

        # Every sample is padded to a common 448x448 size in ``preprocess``.
        self.padto = K.PadTo((448, 448))

        # The training augmentation pipeline does not depend on the batch, so
        # it is built once here instead of on every training step.
        self.train_augmentations = K.AugmentationSequential(
            K.RandomRotation(p=0.5, degrees=90),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            K.RandomSharpness(p=0.5),
            K.ColorJitter(
                p=0.5,
                brightness=0.1,
                contrast=0.1,
                saturation=0.1,
                hue=0.1,
                silence_instantiation_warning=True,
            ),
            data_keys=["input", "mask"],
        )

    def on_after_batch_transfer(
        self, batch: Dict[str, Any], batch_idx: int
    ) -> Dict[str, Any]:
        """Apply batch augmentations after batch is transferred to the device.

        Augmentations are only applied while the attached trainer is in its
        training state; otherwise the batch is returned untouched.

        Args:
            batch: mini-batch of data
            batch_idx: batch index

        Returns:
            augmented mini-batch
        """
        # No trainer attached, or the trainer is not training: nothing to do.
        trainer = getattr(self, "trainer", None)
        if not getattr(trainer, "training", False):
            return batch

        # Kornia expects masks to be floats with a channel dimension
        image = batch["image"]
        mask = batch["mask"].float().unsqueeze(1)

        image, mask = self.train_augmentations(image, mask)

        # torchmetrics expects masks to be longs without a channel dimension
        batch["image"] = image
        batch["mask"] = mask.squeeze(1).long()

        return batch

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        The sample dictionary is modified in place and returned.

        Args:
            sample: dictionary containing image and mask

        Returns:
            preprocessed sample
        """
        # Scale the image by 1/255, pad it, and drop the size-1 dimensions
        # from the padded result.
        sample["image"] = sample["image"].float() / 255
        sample["image"] = self.padto(sample["image"]).squeeze()

        if "mask" in sample:
            # We add 1 to the mask to map the current {background, building} labels to
            # the values {1, 2}. This is necessary because we add 0 padding to the
            # mask that we want to ignore in the loss function.
            sample["mask"] = self.padto(sample["mask"].float() + 1).squeeze()
            sample["mask"] = sample["mask"].long()

        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        SpaceNet1(**self.kwargs)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        self.dataset = SpaceNet1(transforms=self.preprocess, **self.kwargs)
        self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
            self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
        )

    def _make_dataloader(self, dataset: Any, shuffle: bool) -> DataLoader[Any]:
        """Build a DataLoader over ``dataset`` with the configured batch settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return self._make_dataloader(self.test_dataset, shuffle=False)

    def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
        """Run :meth:`torchgeo.datasets.SpaceNet.plot`.

        ``setup`` must have been called first, because it creates the
        underlying dataset this method delegates to.
        """
        return self.dataset.plot(*args, **kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters