Commit 776270c

update

qywu committed Feb 9, 2022
1 parent adb7486 commit 776270c

Showing 9 changed files with 235 additions and 151 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -15,7 +15,7 @@
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/qywu/TorchFly",
packages=find_packages(),
packages=['torchfly'],
install_requires=required_packages,
classifiers=[
'Programming Language :: Python :: 3',
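Note that switching from `find_packages()` to an explicit `packages=['torchfly']` declares only the top-level package. A minimal sketch of the behavioral difference, for reference only (not part of the commit):

```python
# Sketch only: contrasts find_packages() with an explicit package list.
from setuptools import find_packages

# find_packages() walks the source tree and returns every package that has an
# __init__.py, including subpackages, e.g.
# ['torchfly', 'torchfly.training', 'torchfly.flyconfig', ...]
print(find_packages())

# packages=['torchfly'] declares only the name(s) listed; subpackages such as
# torchfly.training must be added explicitly (or discovered via find_packages()).
```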
2 changes: 1 addition & 1 deletion torchfly/flyconfig/config/training/default.yaml
@@ -1,5 +1,5 @@
#
# This file contains the default configuration for TrainerLoop
# This file contains the default configuration for Trainer
#
training:
random_seed: 123
2 changes: 1 addition & 1 deletion torchfly/flylogger/train_logger.py
@@ -96,7 +96,7 @@ def setup_timer(self, trainer: Trainer):
@handle_event(Events.TRAIN_BEGIN, priority=199)
def setup_tensorboard(self, trainer: Trainer):
# Setup tensorboard
log_dir = os.path.join(os.getcwd(), f"{trainer.name}_tensorboard")
log_dir = os.path.join(os.getcwd(), f"{trainer.trainer_name}_{trainer.stage_name}/tensorboard")
os.makedirs(log_dir, exist_ok=True)
self.tensorboard = SummaryWriter(log_dir=log_dir, purge_step=trainer.global_step_count)
trainer.tensorboard = self.tensorboard
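The TensorBoard log directory now follows a `<trainer_name>_<stage_name>/tensorboard` layout under the working directory. A minimal sketch of the resulting writer setup, with hypothetical stand-ins for the trainer attributes (not values from the commit):

```python
# Illustrative sketch only -- mirrors the new log-dir layout in train_logger.py.
import os
from torch.utils.tensorboard import SummaryWriter

trainer_name, stage_name = "gpt2", "pretrain"  # stand-ins for trainer.trainer_name / trainer.stage_name
global_step_count = 0                          # stand-in for trainer.global_step_count

log_dir = os.path.join(os.getcwd(), f"{trainer_name}_{stage_name}/tensorboard")
os.makedirs(log_dir, exist_ok=True)

# purge_step hides any previously logged events at or after this step, so a
# resumed run does not leave duplicate points in TensorBoard.
writer = SummaryWriter(log_dir=log_dir, purge_step=global_step_count)
```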
36 changes: 22 additions & 14 deletions torchfly/training/callbacks/checkpoint/checkpoint.py
@@ -28,35 +28,37 @@ def __init__(self, config: DictConfig) -> None:
# Check distributed
if get_rank() != 0:
raise NotImplementedError("Checkpoint callback can only be called for rank 0!")
self.checkpointer = None
self.storage_dir = None

@handle_event(Events.TRAIN_BEGIN, priority=199)
def setup_checkpointer(self, trainer: Trainer):
# Initialize Checkpointer
self.storage_dir = os.path.join(f"{trainer.trainer_name}_{trainer.stage_name}", trainer.config.checkpointing.directory)
self.checkpointer = Checkpointer(
sync_every_save=True,
async_save=self.config.checkpointing.async_save,
num_checkpoints_to_keep=self.config.checkpointing.num_checkpoints_to_keep,
keep_checkpoint_every_num_seconds=(self.config.checkpointing.keep_checkpoint_every_num_seconds),
storage_dir=self.config.checkpointing.directory)

self.last_save_time = time.time()
async_save=trainer.config.checkpointing.async_save,
num_checkpoints_to_keep=trainer.config.checkpointing.num_checkpoints_to_keep,
keep_checkpoint_every_num_seconds=(trainer.config.checkpointing.keep_checkpoint_every_num_seconds),
storage_dir=self.storage_dir)

@handle_event(Events.INITIALIZE, priority=199)
def setup_checkpointer(self, trainer: Trainer):
# Checkpoint in epochs or steps
if self.config.checkpointing.steps_interval < 0 and self.config.checkpointing.seconds_interval < 0:
if trainer.config.checkpointing.steps_interval < 0 and trainer.config.checkpointing.seconds_interval < 0:
self.checkpoint_in_epoch = True
else:
self.checkpoint_in_epoch = False

# Checkpoint in seconds or steps
if self.config.checkpointing.steps_interval > 0 and self.config.checkpointing.seconds_interval > 0:
if trainer.config.checkpointing.steps_interval > 0 and trainer.config.checkpointing.seconds_interval > 0:
raise ValueError(
"Either `checkpointing.steps_interval` or `checkpointing.seconds_interval` can be set greater than 0!")
elif self.config.checkpointing.steps_interval < 0 and self.config.checkpointing.seconds_interval > 0:
elif trainer.config.checkpointing.steps_interval < 0 and trainer.config.checkpointing.seconds_interval > 0:
self.checkpoint_in_seconds = True
elif self.config.checkpointing.steps_interval > 0 and self.config.checkpointing.seconds_interval < 0:
elif trainer.config.checkpointing.steps_interval > 0 and trainer.config.checkpointing.seconds_interval < 0:
self.checkpoint_in_seconds = False
else:
self.checkpoint_in_seconds = False
self.last_save_time = time.time()

@handle_event(Events.BATCH_END)
def save_checkpoint(self, trainer: Trainer):
@@ -65,11 +67,11 @@ def save_checkpoint(self, trainer: Trainer):
if self.checkpoint_in_seconds:
current_time = time.time()
# the elapsed time is longer than the seconds
if (current_time - self.last_save_time) > self.config.checkpointing.seconds_interval:
if (current_time - self.last_save_time) > trainer.config.checkpointing.seconds_interval:
self._save_trainer_state(trainer)
self.last_save_time = current_time
else:
if (trainer.global_step_count + 1) % self.config.checkpointing.steps_interval == 0:
if (trainer.global_step_count + 1) % trainer.config.checkpointing.steps_interval == 0:
self._save_trainer_state(trainer)

@handle_event(Events.EPOCH_END)
@@ -78,6 +80,12 @@ def save_checkpoint_epoch(self, trainer: Trainer):
if self.checkpoint_in_epoch:
self._save_trainer_state(trainer)

@handle_event(Events.TRAIN_END)
def save_checkpoint_stage(self, trainer: Trainer):
trainer_state_dict = trainer.get_trainer_state()
self.checkpointer.save_checkpoint("stage_end", trainer.get_model_state(),
trainer_state_dict)

def _save_trainer_state(self, trainer: Trainer):

trainer_state_dict = trainer.get_trainer_state()
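For reference, the interval-selection rules in `setup_checkpointer` above reduce to the following standalone sketch; the sample interval values are hypothetical and not taken from this repository's config:

```python
# Illustrative sketch of the checkpoint-interval rules used above.
def resolve_checkpoint_mode(steps_interval: int, seconds_interval: int):
    """Return (checkpoint_in_epoch, checkpoint_in_seconds) per the rules above."""
    # Both intervals disabled -> checkpoint once per epoch.
    checkpoint_in_epoch = steps_interval < 0 and seconds_interval < 0

    if steps_interval > 0 and seconds_interval > 0:
        raise ValueError(
            "Either `checkpointing.steps_interval` or `checkpointing.seconds_interval` "
            "can be set greater than 0!"
        )
    elif steps_interval < 0 and seconds_interval > 0:
        checkpoint_in_seconds = True
    else:
        checkpoint_in_seconds = False
    return checkpoint_in_epoch, checkpoint_in_seconds


print(resolve_checkpoint_mode(-1, -1))    # (True, False): checkpoint per epoch
print(resolve_checkpoint_mode(1000, -1))  # (False, False): checkpoint every 1000 steps
print(resolve_checkpoint_mode(-1, 3600))  # (False, True): checkpoint every hour
```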
17 changes: 8 additions & 9 deletions torchfly/training/callbacks/evaluation.py
@@ -72,7 +72,7 @@ def test_on_train_end(self, trainer):
# Start validation at the begining
if trainer.test_dataloader is not None and self.enabled:
try:
state_dict = torch.load(f"evaluation/{trainer.name}_model_weights/best.pth")
state_dict = torch.load(f"{trainer.trainer_name}_{trainer.stage_name}/evaluation/model_weights/best.pth")
trainer.model.load_state_dict(state_dict["model_weights"])
logger.info(
f"Loading the `best.pth` epoch {state_dict['epochs_trained']} step {state_dict['global_step_count']} with score {state_dict['score']}!"
@@ -107,8 +107,7 @@ def info_valid_begin(self, trainer: Trainer):
logger.info(f"Validation starts at epoch {trainer.epochs_trained + 1} steps {trainer.global_step_count}")
self.eval_start_time = time.time()
# save model here
os.makedirs("evaluation", exist_ok=True)
os.makedirs(f"evaluation/{trainer.name}_model_weights", exist_ok=True)
os.makedirs(f"{trainer.trainer_name}_{trainer.stage_name}/evaluation/model_weights", exist_ok=True)

@handle_event(Events.VALIDATE_END)
def record_validation_metrics(self, trainer: Trainer):
@@ -140,22 +139,22 @@ def record_validation_metrics(self, trainer: Trainer):
self.saved_top_k_models = sorted(self.saved_top_k_models, key=lambda item: item["score"])

if len(self.saved_top_k_models) > 0 and model_score > self.saved_top_k_models[-1]["score"]:
model_path = os.path.join(f"evaluation/{trainer.name}_model_weights", f"best.pth")
model_path = os.path.join(f"{trainer.trainer_name}_{trainer.stage_name}/evaluation/model_weights", f"best.pth")
torch.save(self.get_model_weights_stamp(trainer, model_score), model_path)
is_best = True

if len(self.saved_top_k_models) < self.config.save_top_k_models:
# save the model
model_path = "epoch_" + str(trainer.epochs_trained) + "_step_" + str(trainer.global_step_count) + ".pth"
model_path = os.path.join(f"evaluation/{trainer.name}_model_weights", model_path)
model_path = os.path.join(f"{trainer.trainer_name}_{trainer.stage_name}/evaluation/model_weights", model_path)
torch.save(self.get_model_weights_stamp(trainer, model_score), model_path)
self.saved_top_k_models.append({"path": model_path, "score": model_score})
else:
model_path = self.saved_top_k_models.pop(0)["path"]
os.remove(model_path)

with open("evaluation/results.txt", "a") as f:
f.write(f"{trainer.name} validation @epoch {trainer.epochs_trained} @step {trainer.global_step_count} | ")
with open(f"{trainer.trainer_name}_{trainer.stage_name}/evaluation/results.txt", "a") as f:
f.write(f"validation @epoch {trainer.epochs_trained} @step {trainer.global_step_count} | ")
f.write(json.dumps(metrics_dict))
f.write(" | ")
if is_best:
@@ -190,8 +189,8 @@ def record_test_metrics(self, trainer: Trainer):
metrics_dict[metric_name] = value
logger.info(log_string)

with open("evaluation/results.txt", "a") as f:
f.write(f"{trainer.name} test: ")
with open(f"{trainer.trainer_name}_{trainer.stage_name}/evaluation/results.txt", "a") as f:
f.write(f"test: ")
f.write(json.dumps(metrics_dict))
f.write("\n")

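Evaluation artifacts now land under a per-stage `<trainer_name>_<stage_name>/evaluation/` directory instead of a shared `evaluation/` directory. A small path sketch with hypothetical stand-ins for `trainer.trainer_name` and `trainer.stage_name`:

```python
# Illustrative sketch only -- shows where the callback above writes its files.
import os

trainer_name, stage_name = "gpt2", "finetune"  # hypothetical values
root = f"{trainer_name}_{stage_name}/evaluation"

best_model_path = os.path.join(root, "model_weights", "best.pth")
topk_model_path = os.path.join(root, "model_weights", "epoch_3_step_1200.pth")
results_path = os.path.join(root, "results.txt")

print(best_model_path)  # gpt2_finetune/evaluation/model_weights/best.pth
print(results_path)     # gpt2_finetune/evaluation/results.txt
```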
78 changes: 51 additions & 27 deletions torchfly/training/flymodel.py
@@ -3,22 +3,27 @@
import torch
import torch.nn as nn
import numpy as np

# import apex
# from apex.parallel import DistributedDataParallel, Reducer
# from torch.nn.parallel import DistributedDataParallel

from torchfly.utilities import move_to_device
from torchfly.metrics import CategoricalAccuracy, Average, MovingAverage, Speed
from torchfly.training.schedulers import ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule, \
WarmupLinearSchedule, WarmupCosineWithHardRestartsSchedule
from torchfly.training.schedulers import (
ConstantLRSchedule,
WarmupConstantSchedule,
WarmupCosineSchedule,
WarmupLinearSchedule,
WarmupCosineWithHardRestartsSchedule,
)

import logging

logger = logging.getLogger(__name__)


class FlyModel(nn.Module):

def __init__(self, config, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config = config
@@ -36,69 +41,86 @@ def predict_step(self, batch_idx, dataloder_idx=0, *args, **kwargs):
def get_metrics(self, reset):
return {}

def get_optimizer_parameters(self):
def get_optimizer_parameters(self, config=None):
"""
This function is used to set parameters with different weight decays
"""
try:
weight_decay = self.config.training.optimization.weight_decay
weight_decay = config.optimization.weight_decay
except:
weight_decay = 0.01

# default groups
no_decay = ["bias", "Norm"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)],
"params": [
p
for n, p in self.named_parameters()
if not any(nd in n for nd in no_decay)
],
"weight_decay": weight_decay,
},
{
"params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0
"params": [
p
for n, p in self.named_parameters()
if any(nd in n for nd in no_decay)
],
"weight_decay": 0.0,
},
]
return optimizer_grouped_parameters

def configure_optimizers(self, total_num_update_steps) -> [List, List]:
optimizer_grouped_parameters = self.get_optimizer_parameters()
lr = self.config.training.optimization.learning_rate
optimizer_name = self.config.training.optimization.optimizer_name
max_gradient_norm = self.config.training.optimization.max_gradient_norm
betas = self.config.training.optimization.betas if self.config.training.optimization.get("betas") else (0.9,
0.999)
def configure_optimizers(self, config, total_num_update_steps) -> [List, List]:
optimizer_grouped_parameters = self.get_optimizer_parameters(config)
lr = config.optimization.learning_rate
optimizer_name = config.optimization.optimizer_name
max_gradient_norm = config.optimization.max_gradient_norm
betas = (
config.optimization.betas
if config.optimization.get("betas")
else (0.9, 0.999)
)

if optimizer_name == "AdamW":
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr, betas=betas)
optimizer = torch.optim.AdamW(
optimizer_grouped_parameters, lr=lr, betas=betas
)
elif optimizer_name == "Adafactor":
raise NotImplementedError
elif optimizer_name == "Adadelta":
optimizer = torch.optim.Adadelta(optimizer_grouped_parameters, lr=lr)
else:
raise NotImplementedError(
f"{optimizer_name} is not implemented! Override FlyModel's configure optimizer to continue!")
f"{optimizer_name} is not implemented! Override FlyModel's configure optimizer to continue!"
)

scheduler_name = self.config.training.scheduler.scheduler_name
warmup_steps = self.config.training.scheduler.get("warmup_steps", None)
warmup_cycle = self.config.training.scheduler.get("warmup_cosine_cycle", None)
scheduler_name = config.scheduler.scheduler_name
warmup_steps = config.scheduler.get("warmup_steps", None)
warmup_cycle = config.scheduler.get("warmup_cosine_cycle", None)

if scheduler_name == "Constant":
scheduler = ConstantLRSchedule(optimizer)
elif scheduler_name == "WarmupConstant":
scheduler = WarmupConstantSchedule(optimizer, warmup_steps)
elif scheduler_name == "WarmupLinear":
scheduler = WarmupLinearSchedule(optimizer, warmup_steps, total_num_update_steps)
scheduler = WarmupLinearSchedule(
optimizer, warmup_steps, total_num_update_steps
)
elif scheduler_name == "WarmupCosine":
if warmup_cycle is None:
warmup_cycle = 0.5
scheduler = WarmupCosineSchedule(optimizer, warmup_steps, total_num_update_steps, cycles=warmup_cycle)
scheduler = WarmupCosineSchedule(
optimizer, warmup_steps, total_num_update_steps, cycles=warmup_cycle
)
elif scheduler_name == "WarmupCosineWithHardRestartsSchedule":
if warmup_cycle is None:
warmup_cycle = 0.5

scheduler = WarmupCosineWithHardRestartsSchedule(optimizer,
warmup_steps,
total_num_update_steps,
cycles=warmup_cycle)
scheduler = WarmupCosineWithHardRestartsSchedule(
optimizer, warmup_steps, total_num_update_steps, cycles=warmup_cycle
)
else:
logger.error("Write your own version of `configure_scheduler`!")
raise NotImplementedError
@@ -141,7 +163,9 @@ def test_loop(self, dataloader):
self.test_step(batch, batch_idx)

def get_last_lr(self):
raise NotImplementedError("Please hook this function to the `scheduler.get_last_lr`!")
raise NotImplementedError(
"Please hook this function to the `scheduler.get_last_lr`!"
)

def get_training_metrics(self) -> Dict[str, str]:
raise NotImplementedError
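With this change, `configure_optimizers` and `get_optimizer_parameters` receive the optimization/scheduler config explicitly instead of reading `self.config.training`. A hedged sketch of what a per-stage config passed by the caller might look like, using OmegaConf and made-up hyperparameter values (the actual call site in the Trainer is not shown in this diff):

```python
# Illustrative sketch only -- hypothetical values, not the Trainer's actual call site.
from omegaconf import OmegaConf

stage_config = OmegaConf.create({
    "optimization": {
        "learning_rate": 5e-5,
        "optimizer_name": "AdamW",
        "weight_decay": 0.01,
        "max_gradient_norm": 1.0,
    },
    "scheduler": {
        "scheduler_name": "WarmupLinear",
        "warmup_steps": 1000,
    },
})

# Previously the settings were read from self.config.training.*; now the
# per-stage config is passed in explicitly, e.g. (model is a FlyModel instance):
# model.configure_optimizers(stage_config, total_num_update_steps=100_000)
```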
