Skip to content

Commit

Permalink
Merge pull request #25 from sooftware/tpu
Browse files Browse the repository at this point in the history
Fix name gub
  • Loading branch information
sooftware committed May 8, 2021
2 parents 3799581 + f192f12 commit a6fe05c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
21 changes: 4 additions & 17 deletions bin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,35 +23,22 @@
import os
import hydra
import pytorch_lightning as pl
import logging
from omegaconf import OmegaConf, DictConfig
from pytorch_lightning.loggers import TensorBoardLogger
from omegaconf import DictConfig

from lightning_asr.data.librispeech.lit_data_module import LightningLibriSpeechDataModule
from lightning_asr.metric import WordErrorRate
from lightning_asr.model import LightningASRModel
from lightning_asr.utilities import check_environment
from lightning_asr.utilities import parse_configs


@hydra.main(config_path=os.path.join('..', "configs"), config_name="train")
def hydra_entry(configs: DictConfig) -> None:
pl.seed_everything(configs.seed)

logger = logging.getLogger(__name__)
num_devices = check_environment(configs.use_cuda, logger)
logger.info(OmegaConf.to_yaml(configs))

if configs.use_tensorboard:
logger = TensorBoardLogger("tensorboard", name="Lightning Automatic Speech Recognition")
else:
logger = True

if configs.use_cuda and configs.use_tpu:
raise ValueError("configs.use_cuda and configs.use_tpu both are True, Please choose between GPU and TPU.")
logger, num_devices = parse_configs(configs)

data_module = LightningLibriSpeechDataModule(configs)
vocab = data_module.prepare_data(configs.dataset_download, configs.vocab_size)
data_module.setup(vacab=vocab)
data_module.setup(vocab=vocab)

model = LightningASRModel(
configs=configs,
Expand Down
22 changes: 20 additions & 2 deletions lightning_asr/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import logging
import torch
import platform
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.loggers import TensorBoardLogger


def check_environment(use_cuda: bool, logger) -> int:
def _check_environment(use_cuda: bool, logger) -> int:
"""
Check execution envirionment.
OS, Processor, CUDA version, Pytorch version, ... etc.
Expand All @@ -49,3 +51,19 @@ def check_environment(use_cuda: bool, logger) -> int:
logger.info(f"PyTorch version : {torch.__version__}")

return num_devices


def parse_configs(configs: DictConfig):
logger = logging.getLogger(__name__)
logger.info(OmegaConf.to_yaml(configs))
num_devices = _check_environment(configs.use_cuda, logger)

if configs.use_tensorboard:
logger = TensorBoardLogger("tensorboard", name="Lightning Automatic Speech Recognition")
else:
logger = True

if configs.use_cuda and configs.use_tpu:
raise ValueError("configs.use_cuda and configs.use_tpu both are True, Please choose between GPU and TPU.")

return logger, num_devices

0 comments on commit a6fe05c

Please sign in to comment.