Skip to content

Commit

Permalink
Added progress bar feature (#105)
Browse files Browse the repository at this point in the history
* added progress bar feature

* solved linting issues
  • Loading branch information
namanbiyani authored and avik-pal committed Jul 3, 2019
1 parent 65ec5d4 commit 8c2bc99
Show file tree
Hide file tree
Showing 22 changed files with 140 additions and 372 deletions.
2 changes: 1 addition & 1 deletion .appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ install:
# Install torch with pip and other dependencies
- pip install "https://download.pytorch.org/whl/cpu/torch-1.1.0-%PY_TAG%-%PY_TAG%m-win_amd64.whl"
- pip install "https://download.pytorch.org/whl/cpu/torchvision-0.3.0-%PY_TAG%-%PY_TAG%m-win_amd64.whl"
- pip install scipy
- pip install scipy

# Now install spiceypy
- IF "%ARCH%"=="32" (call "C:\Program Files (x86)\Microsoft Visual Studio 12.0\VC\vcvarsall.bat" x86) ELSE (ECHO "probably a 64bit build")
Expand Down
2 changes: 1 addition & 1 deletion examples/CycleGAN Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
"version": "3.7.3"
}
},
"nbformat": 4,
Expand Down
418 changes: 85 additions & 333 deletions examples/Introduction To TorchGAN.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
numpy
pillow==5.3.0
fastprogress
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def find_version(*file_paths):

VERSION = find_version("torchgan", "__init__.py")

requirements = ["numpy", "pillow==5.3.0"]
requirements = ["numpy", "pillow==5.3.0", "fastprogress"]

setup(
# Metadata
Expand Down
7 changes: 4 additions & 3 deletions tests/torchgan/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import unittest

import torch

from torchgan.layers import *

sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
Expand Down Expand Up @@ -104,7 +105,7 @@ def test_dense_block2d(self):
layer = DenseBlock2d(5, 3, 16, BottleneckBlock2d, 3, padding=1)

self.match_layer_outputs(layer, input, (16, 83, 10, 10))

def test_self_attention2d(self):
input = torch.rand(16, 88, 10, 10)

Expand All @@ -116,7 +117,7 @@ def test_spectral_norm2d(self):
input = torch.rand(16, 3, 10, 10)

layer = SpectralNorm2d(
torch.nn.Conv2d(3, 10, 3, padding=1),
power_iterations=10)
torch.nn.Conv2d(3, 10, 3, padding=1), power_iterations=10
)

self.match_layer_outputs(layer, input, (16, 10, 10, 10))
1 change: 1 addition & 0 deletions tests/torchgan/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
import torch.distributions as ds

from torchgan.losses import *

sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
Expand Down
1 change: 1 addition & 0 deletions tests/torchgan/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import unittest

import torch

from torchgan.metrics import *

sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
Expand Down
1 change: 1 addition & 0 deletions tests/torchgan/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
import torch.distributions as distributions

from torchgan.models import *

sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
Expand Down
3 changes: 2 additions & 1 deletion tests/torchgan/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.optim import Adam

from torchgan import *
from torchgan.losses import *
from torchgan.metrics import *
Expand All @@ -25,7 +26,7 @@ def mnist_dataloader():
[
transforms.Pad((2, 2)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, ), std=(0.5, )),
transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
),
download=True,
Expand Down
7 changes: 1 addition & 6 deletions torchgan/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
from torchgan import losses
from torchgan import models
from torchgan import trainer
from torchgan import metrics
from torchgan import logging
from torchgan import layers
from torchgan import layers, logging, losses, metrics, models, trainer

__version__ = "v0.0.3-alpha"

Expand Down
6 changes: 3 additions & 3 deletions torchgan/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .residual import *
from .denseblock import *
from .spectralnorm import *
from .selfattention import *
from .minibatchdiscrimination import *
from .residual import *
from .selfattention import *
from .spectralnorm import *
from .virtualbatchnorm import *
2 changes: 1 addition & 1 deletion torchgan/logging/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .visualize import *
from .logger import *
from .visualize import *
3 changes: 2 additions & 1 deletion torchgan/logging/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,15 @@ def run_mid_epoch(self, trainer, *args):
else:
logger(*args, lock_console=True)

def run_end_epoch(self, trainer, epoch, *args):
def run_end_epoch(self, trainer, epoch, time_duration, *args):
r"""Runs the Visualizers at the end of one epoch.
Args:
trainer (torchgan.trainer.Trainer): The base trainer used for training.
epoch (int): The epoch number which was completed.
"""
print("Epoch {} Summary".format(epoch + 1))
print("Epoch time duration : {}".format(time_duration))
for logger in self.logger_mid_epoch:
if type(logger).__name__ == "LossVisualize":
logger(trainer)
Expand Down
4 changes: 3 additions & 1 deletion torchgan/logging/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,9 @@ def __call__(self, trainer, **kwargs):
generator = getattr(trainer, model)
with torch.no_grad():
image = generator(*self.test_noise[pos])
image = torchvision.utils.make_grid(image, nrow=self.nrow, normalize=True, range=(-1, 1))
image = torchvision.utils.make_grid(
image, nrow=self.nrow, normalize=True, range=(-1, 1)
)
super(ImageVisualize, self).__call__(
trainer, image, model, **kwargs
)
Expand Down
16 changes: 8 additions & 8 deletions torchgan/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from .loss import *
from .minimax import *
from .wasserstein import *
from .leastsquares import *
from .auxclassifier import *
from .boundaryequilibrium import *
from .mutualinfo import *
from .energybased import *
from .draganpenalty import *
from .auxclassifier import *
from .historical import *
from .energybased import *
from .featurematching import *
from .functional import *
from .historical import *
from .leastsquares import *
from .loss import *
from .minimax import *
from .mutualinfo import *
from .wasserstein import *
2 changes: 1 addition & 1 deletion torchgan/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .metric import *
from .classifierscore import *
from .metric import *
8 changes: 4 additions & 4 deletions torchgan/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .model import *
from .dcgan import *
from .conditional import *
from .acgan import *
from .autoencoding import *
from .conditional import *
from .dcgan import *
from .infogan import *
from .acgan import *
from .model import *
2 changes: 1 addition & 1 deletion torchgan/trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .base_trainer import *
from .trainer import *
from .parallel_trainer import *
from .trainer import *
20 changes: 16 additions & 4 deletions torchgan/trainer/base_trainer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import time
from inspect import _empty, signature
from warnings import warn

import torch
import torchvision
from fastprogress import master_bar, progress_bar

from ..logging.logger import Logger
from ..losses.loss import DiscriminatorLoss, GeneratorLoss
Expand Down Expand Up @@ -74,7 +76,7 @@ def __init__(
log_dir=None,
test_noise=None,
nrow=8,
**kwargs
**kwargs,
):
self.device = device
self.losses = {}
Expand Down Expand Up @@ -391,12 +393,22 @@ def train(self, data_loader, **kwargs):
for name in self.optimizer_names:
getattr(self, name).zero_grad()

for epoch in range(self.start_epoch, self.epochs):
master_bar_iter = master_bar(range(self.start_epoch, self.epochs))
for epoch in master_bar_iter:

start_time = time.time()
master_bar_iter.first_bar.comment = f"Training Progress"

for model in self.model_names:
getattr(self, model).train()

for data in data_loader:
for progress_bar_iter, data in zip(
progress_bar(range(len(data_loader)), parent=master_bar_iter),
data_loader,
):

master_bar_iter.child.comment = f"Epoch {epoch+1} Progress"

if type(data) is tuple or type(data) is list:
self.real_inputs = data[0].to(self.device)
self.labels = data[1].to(self.device)
Expand All @@ -422,7 +434,7 @@ def train(self, data_loader, **kwargs):
getattr(self, model).eval()

self.eval_ops(**kwargs)
self.logger.run_end_epoch(self, epoch)
self.logger.run_end_epoch(self, epoch, time.time() - start_time)
self.optim_ops()

print("Training of the Model is Complete")
Expand Down
2 changes: 1 addition & 1 deletion torchgan/trainer/parallel_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(
metrics_list,
log_dir=log_dir,
nrow=nrow,
test_noise=test_noise
test_noise=test_noise,
)

self._store_loss_maps()
Expand Down
2 changes: 1 addition & 1 deletion torchgan/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __init__(
metrics_list,
log_dir=log_dir,
nrow=nrow,
test_noise=test_noise
test_noise=test_noise,
)

self._store_loss_maps()
Expand Down

0 comments on commit 8c2bc99

Please sign in to comment.