
Commit

minor QOL changes
robogast committed Nov 15, 2021
1 parent 93cf218 commit 448854b
Showing 6 changed files with 57 additions and 73 deletions.
1 change: 0 additions & 1 deletion conf/model/vq_ae.yaml
@@ -3,7 +3,6 @@ defaults:
- loss_f@loss_f_conf: huber
- encoder@encoder_conf: default
- decoder@decoder_conf: default
- metrics: null

- optional optional_overrides/optim: ${model/optim@model.optim_conf}

96 changes: 44 additions & 52 deletions utils/conf_helpers.py
@@ -1,55 +1,47 @@
from typing import List, Union, Any
from collections.abc import Callable
from dataclasses import dataclass
from abc import ABC
from pathlib import Path
from functools import partial, reduce
from functools import reduce
from operator import add, mul, pow
from pathlib import Path
from typing import Any, List, Union

from hydra.utils import instantiate
from omegaconf import OmegaConf, DictConfig, ListConfig, MISSING


OmegaConf.register_new_resolver(
name="path.stem",
resolver= lambda path: Path(path).stem,
replace=True # need this for multirun
)
OmegaConf.register_new_resolver(
name="path.absolute",
resolver= lambda path: Path(path).absolute(),
replace=True # need this for multirun
)

OmegaConf.register_new_resolver(
name="len",
resolver=lambda iterable: len([elem for elem in iterable if not (isinstance(elem, str) and len(elem) > 0 and elem[0] == '_')]),
replace=True # need this for multirun
)

OmegaConf.register_new_resolver(
name="add",
resolver=lambda *x: reduce(add, x),
replace=True # need this for multirun
)
OmegaConf.register_new_resolver(
name="mul",
resolver=lambda *x: reduce(mul, x),
replace=True # need this for multirun
)
OmegaConf.register_new_resolver(
name="pow",
resolver=lambda x, y: pow(x, y),
replace=True # need this for multirun
)
from omegaconf import DictConfig, ListConfig, MISSING, OmegaConf


def add_resolvers() -> None:
"""Adds resolvers to the OmegaConf parsers"""

def add_resolver(name: str, resolver: Callable):
OmegaConf.register_new_resolver(
name=name,
resolver=resolver,
replace=True # need this for multirun
)

for name_, resolver_ in (
("path.stem", lambda path: Path(path).stem),
("path.absolute", lambda path: Path(path).absolute()),
# calculate length of any list or object,
# but skip list elements if they are a `str` which starts with '_'
("len", lambda iterable: len([
elem for elem in iterable
if not (isinstance(elem, str) and len(elem) > 0 and elem[0] == '_')
])),
("add", lambda *x: reduce(add, x)),
("mul", lambda *x: reduce(mul, x)),
("pow", lambda x, y: pow(x, y))
):
add_resolver(name_, resolver_)


@dataclass
class DatasetConf(ABC):
class DatasetConf(DictConfig):
_target_: str = 'torch.utils.data.Dataset'


@dataclass
class DataloaderConf:
class DataloaderConf(DictConfig):
_target_: str = 'torch.utils.data.DataLoader'

dataset: DatasetConf = MISSING
@@ -62,30 +54,32 @@ class DataloaderConf:
prefetch_factor: int = 2
persistent_workers: bool = True


@dataclass
class OptimizerConf(ABC):
class OptimizerConf(DictConfig):
_target_: str = 'torch.optim.Optimizer'


@dataclass
class ModuleConf(ABC):
class ModuleConf(DictConfig):
_target_: str = 'torch.nn.Module'


def instantiate_nested_dictconf(**nested_conf) -> Any:
def instantiate_nested_dictconf(**nested_conf: DictConfig) -> Any:
listified_obj = instantiate_dictified_listconf(**nested_conf)

assert len(listified_obj) == 1, f"more than one root object found in {listified_obj}"

return listified_obj[0]


def instantiate_dictified_listconf(**nested_conf) -> List:
'''
def instantiate_dictified_listconf(**nested_conf: DictConfig) -> List:
"""
Warning:
set `_recursive_: False`
at the same level as `_target_: utils.conf_helpers.instantiate_nested_conf`
inside your config!
'''
"""

de_nested = listify_nested_conf(nested_conf)

@@ -96,7 +90,7 @@ def instantiate_dictified_listconf(**nested_conf) -> List:


def listify_nested_conf(conf: Any) -> Union[DictConfig, ListConfig]:
'''
"""
Given the keys and values of a nested config,
removes keys and makes their corresponding value a ListConfig,
if the nest level doesn't contain the key `_target_`.
@@ -122,7 +116,7 @@ def listify_nested_conf(conf: Any) -> Union[DictConfig, ListConfig]:
- _target: albumentations.pytorch.transforms.ToTensorV2
_target_: albumentations.Compose
```
'''
"""

if isinstance(conf, (DictConfig, dict)):
return (
@@ -142,8 +136,6 @@ def listify_nested_conf(conf: Any) -> Union[DictConfig, ListConfig]:
return conf




if __name__ == '__main__':
conf = {
'compose': {
@@ -160,4 +152,4 @@ def listify_nested_conf(conf: Any) -> Union[DictConfig, ListConfig]:
'_target_': 'albumentations.Compose'
}
}
print(listify_nested_conf(conf))
print(listify_nested_conf(conf))
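
A minimal sketch of how the resolvers registered by `add_resolvers()` can be referenced through interpolations (assumes omegaconf >= 2.1 for nested interpolations in resolver arguments; the config keys and values below are hypothetical, not taken from this repository):

```python
from omegaconf import OmegaConf

from utils.conf_helpers import add_resolvers

add_resolvers()

# Hypothetical config: values are made up for illustration only.
conf = OmegaConf.create({
    "layers": ["conv1", "conv2", "_private"],
    "n_layers": "${len:${layers}}",      # '_private' starts with '_' -> skipped, gives 2
    "n_total": "${add:${n_layers},1}",   # 2 + 1 = 3
    "ckpt": "/tmp/run/last.ckpt",
    "run_name": "${path.stem:${ckpt}}",  # -> 'last'
})

print(conf.n_layers, conf.n_total, conf.run_name)  # 2 3 last
```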
10 changes: 0 additions & 10 deletions vq_ae/__init__.py
@@ -8,14 +8,4 @@

del Path

from pathlib import Path

__all__ = [
f.stem
for f in Path(__file__).parent.glob("*.py")
if "__" != f.stem[:2]
]

del Path

from . import *
13 changes: 7 additions & 6 deletions vq_ae/layers/conv_block.py
@@ -141,7 +141,7 @@ def __init__(
bottleneck_divisor: float,
activation: ModuleConf,
conv_conf: ModuleConf,
n_layers: int,
n_layers: Optional[int] = None,
):
super().__init__()

@@ -190,10 +190,11 @@ def __init__(
else:
self.skip_conv = None

self.initialize_weights(n_layers)
if n_layers is not None:
self.initialize_weights(n_layers)

def forward(self, input: torch.Tensor):
out = input
def forward(self, inp: torch.Tensor):
out = inp

out = self.activation(out + self.bias1a)
out = self.branch_conv1(out + self.bias1b)
@@ -207,9 +208,9 @@ def forward(self, input: torch.Tensor):
out = out * self.scale + self.bias4

out = out + (
self.skip_conv(input + self.bias1c) + self.bias1d
self.skip_conv(inp + self.bias1c) + self.bias1d
if self.skip_conv is not None
else input
else inp
)

return out
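
A hedged sketch of the pattern this hunk introduces, where depth-aware weight initialization only runs when `n_layers` is supplied (the class, layer names, and scaling rule below are illustrative, not the repository's actual block):

```python
from typing import Optional

import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self, channels: int, n_layers: Optional[int] = None):
        super().__init__()
        self.branch_conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.branch_conv2 = nn.Conv2d(channels, channels, 3, padding=1)

        if n_layers is not None:  # skip when the network depth is unknown
            self.initialize_weights(n_layers)

    def initialize_weights(self, n_layers: int) -> None:
        # FixUp-style scaling: shrink the residual branch with network depth.
        with torch.no_grad():
            self.branch_conv1.weight.mul_(n_layers ** -0.5)
            self.branch_conv2.weight.zero_()

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        out = self.branch_conv2(torch.relu(self.branch_conv1(inp)))
        return inp + out
```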
2 changes: 0 additions & 2 deletions vq_ae/model.py
@@ -20,7 +20,6 @@ def __init__(
loss_f_conf: ModuleConf,
encoder_conf: ModuleConf,
decoder_conf: ModuleConf,
metrics: Optional[Sequence[ModuleConf]],
**kwargs
):
super().__init__()
@@ -35,7 +34,6 @@
('loss_f', loss_f_conf),
('encoder', encoder_conf),
('decoder', decoder_conf),
('metrics', metrics)
):
setattr(self, attr_name, instantiate(attr_conf))

8 changes: 6 additions & 2 deletions vq_ae/train.py
@@ -2,14 +2,15 @@
import pytorch_lightning as pl
import torch
from hydra.utils import call, instantiate

import utils.conf_helpers # import adds parsers to hydra parser
from omegaconf import OmegaConf


@hydra.main(config_path="../conf", config_name="camelyon16_config")
def main(experiment):
torch.cuda.empty_cache()

OmegaConf.save(experiment, 'experiment.yml')

if 'utils' in experiment:
call(experiment.utils)

@@ -26,4 +27,7 @@ def main(experiment):


if __name__ == '__main__':
from utils.conf_helpers import add_resolvers
add_resolvers()

main()
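
The newly added `OmegaConf.save(experiment, 'experiment.yml')` writes the composed config into the run directory; a small sketch of reloading it later (the output path and the `model` node are assumptions, not confirmed by this commit):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

from utils.conf_helpers import add_resolvers

add_resolvers()  # re-register resolvers before any interpolation is resolved

# Hypothetical Hydra run directory; adjust to an actual output path.
experiment = OmegaConf.load("outputs/2021-11-15/12-00-00/experiment.yml")
model = instantiate(experiment.model)  # assumes a 'model' node with a _target_
```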
