Skip to content

Commit

Permalink
uupdate
Browse files Browse the repository at this point in the history
  • Loading branch information
qywu committed Jan 24, 2022
1 parent 80f031e commit adb7486
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 132 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ regex
overrides
omegaconf>=2.1
pyarrow
colorlog
colorlog==5.0.1
hydra-core
128 changes: 0 additions & 128 deletions tests/test-flyconfig/Untitled.ipynb

This file was deleted.

1 change: 1 addition & 0 deletions torchfly/training/callbacks/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def record_validation_metrics(self, trainer: Trainer):
@handle_event(Events.TEST_BEGIN)
def info_test_begin(self, trainer: Trainer):
logger.info(f"Test starts! ")
os.makedirs("evaluation", exist_ok=True)
self.eval_start_time = time.time()

@handle_event(Events.TEST_END)
Expand Down
5 changes: 5 additions & 0 deletions torchfly/utilities/save_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from omegaconf import OmegaConf

def save_config(config, filename):
with open(filename, "w") as f:
OmegaConf.save(config, f)
2 changes: 1 addition & 1 deletion tutorials/MNIST/config/training/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ training:
evaluation:
save_top_k_models: 3
total_num:
epochs: 20
epochs: 5
update_steps: -1 # disabled when total_num.epochs < 0
14 changes: 12 additions & 2 deletions tutorials/MNIST/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from torchfly.flylogger import FlyLogger
from torchfly.flyconfig import FlyConfig
from torchfly.training import Trainer
from torchfly.utils import distributed
import torchfly.distributed as distributed
from torchfly.utilities import set_random_seed
from omegaconf import OmegaConf

from model import CNNFlyModel

Expand All @@ -23,6 +25,7 @@ def train_loader_fn(self):
}

with distributed.mutex() as rank:

dataset = datasets.MNIST(os.path.join(self.config.task.datadir, 'MNIST'),
train=True,
download=True,
Expand Down Expand Up @@ -57,6 +60,8 @@ def main():
torch.distributed.init_process_group(backend='nccl', init_method='env://')

config = FlyConfig.load("config/config.yaml")
set_random_seed(config.training.random_seed)

data_helper = DataLoaderHelper(config)
train_dataloader = data_helper.train_loader_fn()
valid_dataloader = data_helper.valid_loader_fn()
Expand All @@ -65,9 +70,14 @@ def main():

trainer = Trainer(config.training, model)

with FlyLogger(config.flylogger, overwrite=True) as flylogger:
model.configure_metrics()

with FlyLogger(config.flylogger) as flylogger:
with open("config.yaml", "w") as f:
OmegaConf.save(config, f)
trainer.train(train_dataloader, valid_dataloader)



if __name__ == "__main__":
main()

0 comments on commit adb7486

Please sign in to comment.