-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test: add `test_all.py`, add `include_test` checkbox - add CI - add checks in generating zip tests ref #103
- Loading branch information
Showing
17 changed files
with
208 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,4 @@ pytorch-ignite>=0.4.2 | |
pyyaml | ||
albumentations | ||
image_dataset_viz | ||
pytest |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,3 +12,7 @@ pyyaml | |
#:::= it.logger :::# | ||
|
||
#::: } :::# | ||
|
||
#::: if (it.include_tests) { :::# | ||
pytest | ||
#::: } :::# |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,3 +12,7 @@ pyyaml | |
#:::= it.logger :::# | ||
|
||
#::: } :::# | ||
|
||
#::: if (it.include_tests) { :::# | ||
pytest | ||
#::: } :::# |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import os | ||
from argparse import Namespace | ||
from typing import Iterable | ||
|
||
import ignite.distributed as idist | ||
import pytest | ||
import torch | ||
from data import setup_data | ||
from torch import Tensor, nn, optim | ||
from torch.utils.data.dataloader import DataLoader | ||
from trainers import setup_evaluator | ||
|
||
|
||
def set_up(): | ||
model = nn.Linear(1, 1) | ||
optimizer = optim.Adam(model.parameters()) | ||
device = idist.device() | ||
loss_fn = nn.MSELoss() | ||
batch = [torch.tensor([1.0]), torch.tensor([1.0])] | ||
|
||
return model, optimizer, device, loss_fn, batch | ||
|
||
|
||
@pytest.mark.skipif( | ||
os.getenv("RUN_SLOW_TESTS", 0) == 0, reason="Skip slow tests" | ||
) | ||
def test_setup_data(): | ||
config = Namespace( | ||
data_path="~/data", train_batch_size=1, eval_batch_size=1, num_workers=0 | ||
) | ||
dataloader_train, dataloader_eval = setup_data(config) | ||
|
||
assert isinstance(dataloader_train, DataLoader) | ||
assert isinstance(dataloader_eval, DataLoader) | ||
train_batch = next(iter(dataloader_train)) | ||
assert isinstance(train_batch, Iterable) | ||
assert isinstance(train_batch[0], Tensor) | ||
assert isinstance(train_batch[1], Tensor) | ||
assert train_batch[0].ndim == 4 | ||
assert train_batch[1].ndim == 1 | ||
eval_batch = next(iter(dataloader_eval)) | ||
assert isinstance(eval_batch, Iterable) | ||
assert isinstance(eval_batch[0], Tensor) | ||
assert isinstance(eval_batch[1], Tensor) | ||
assert eval_batch[0].ndim == 4 | ||
assert eval_batch[1].ndim == 1 | ||
|
||
|
||
def test_setup_evaluator(): | ||
model, _, device, _, batch = set_up() | ||
config = Namespace(use_amp=False) | ||
evaluator = setup_evaluator(config, model, device) | ||
evaluator.run([batch, batch]) | ||
assert isinstance(evaluator.state.output, tuple) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,3 +12,7 @@ pyyaml | |
#:::= it.logger :::# | ||
|
||
#::: } :::# | ||
|
||
#::: if (it.include_tests) { :::# | ||
pytest | ||
#::: } :::# |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import os | ||
from argparse import Namespace | ||
from numbers import Number | ||
from typing import Iterable | ||
|
||
import ignite.distributed as idist | ||
import pytest | ||
import torch | ||
from data import setup_data | ||
from model import Discriminator, Generator | ||
from torch import Tensor, nn, optim | ||
from torch.utils.data.dataloader import DataLoader | ||
from trainers import setup_trainer | ||
|
||
|
||
def set_up(): | ||
model = nn.Linear(1, 1) | ||
optimizer = optim.Adam(model.parameters()) | ||
device = idist.device() | ||
loss_fn = nn.MSELoss() | ||
batch = [torch.tensor([1.0]), torch.tensor([1.0])] | ||
|
||
return model, optimizer, device, loss_fn, batch | ||
|
||
|
||
@pytest.mark.skipif( | ||
os.getenv("RUN_SLOW_TESTS", 0) == 0, reason="Skip slow tests" | ||
) | ||
def test_setup_data(): | ||
config = Namespace( | ||
data_path="~/data", train_batch_size=1, eval_batch_size=1, num_workers=0 | ||
) | ||
dataloader_train, dataloader_eval, _ = setup_data(config) | ||
|
||
assert isinstance(dataloader_train, DataLoader) | ||
assert isinstance(dataloader_eval, DataLoader) | ||
train_batch = next(iter(dataloader_train)) | ||
assert isinstance(train_batch, Iterable) | ||
assert isinstance(train_batch[0], Tensor) | ||
assert isinstance(train_batch[1], Tensor) | ||
assert train_batch[0].ndim == 4 | ||
assert train_batch[1].ndim == 1 | ||
eval_batch = next(iter(dataloader_eval)) | ||
assert isinstance(eval_batch, Iterable) | ||
assert isinstance(eval_batch[0], Tensor) | ||
assert isinstance(eval_batch[1], Tensor) | ||
assert eval_batch[0].ndim == 4 | ||
assert eval_batch[1].ndim == 1 | ||
|
||
|
||
def test_models(): | ||
model_G = Generator(100, 64, 3) | ||
model_D = Discriminator(3, 64) | ||
x = torch.rand([1, 100, 32, 32]) | ||
y = model_G(x) | ||
y.sum().backward() | ||
z = model_D(y) | ||
assert y.shape == torch.Size([1, 3, 560, 560]) | ||
assert z.shape == torch.Size([1024]) | ||
assert isinstance(model_D, nn.Module) | ||
assert isinstance(model_G, nn.Module) | ||
|
||
|
||
def test_setup_trainer(): | ||
model, optimizer, device, loss_fn, batch = set_up() | ||
config = Namespace(use_amp=False, train_batch_size=2, z_dim=100) | ||
trainer = setup_trainer( | ||
config, model, model, optimizer, optimizer, loss_fn, device | ||
) | ||
trainer.run([batch, batch]) | ||
assert isinstance(trainer.state.output, dict) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,3 +14,7 @@ image_dataset_viz | |
#:::= it.logger :::# | ||
|
||
#::: } :::# | ||
|
||
#::: if (it.include_tests) { :::# | ||
pytest | ||
#::: } :::# |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import os | ||
from argparse import Namespace | ||
|
||
import pytest | ||
from data import setup_data | ||
from torch import Tensor | ||
from torch.utils.data.dataloader import DataLoader | ||
|
||
|
||
@pytest.mark.skipif( | ||
os.getenv("RUN_SLOW_TESTS", 0) == 0, reason="Skip slow tests" | ||
) | ||
def test_setup_data(): | ||
config = Namespace( | ||
data_path="~/data", train_batch_size=1, eval_batch_size=1, num_workers=0 | ||
) | ||
dataloader_train, dataloader_eval = setup_data(config) | ||
|
||
assert isinstance(dataloader_train, DataLoader) | ||
assert isinstance(dataloader_eval, DataLoader) | ||
train_batch = next(iter(dataloader_train)) | ||
assert isinstance(train_batch, dict) | ||
assert isinstance(train_batch["image"], Tensor) | ||
assert isinstance(train_batch["mask"], Tensor) | ||
assert train_batch["image"].ndim == 4 | ||
assert train_batch["mask"].ndim == 3 | ||
eval_batch = next(iter(dataloader_eval)) | ||
assert isinstance(eval_batch, dict) | ||
assert isinstance(eval_batch["image"], Tensor) | ||
assert isinstance(eval_batch["mask"], Tensor) | ||
assert eval_batch["image"].ndim == 4 | ||
assert eval_batch["mask"].ndim == 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters