Skip to content

Commit

Permalink
Support for passing device parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
ashok-arjun committed Apr 5, 2024
1 parent 7454088 commit f90bf56
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions lag_llama/gluon/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def __init__(
track_loss_per_series: bool = False,
ckpt_path: Optional[str] = None,
nonnegative_pred_samples: bool = False,
device: torch.device = "cuda"
) -> None:
default_trainer_kwargs = {"max_epochs": 100}
if trainer_kwargs is not None:
Expand Down Expand Up @@ -225,6 +226,7 @@ def __init__(

self.use_cosine_annealing_lr = use_cosine_annealing_lr
self.cosine_annealing_lr_args = cosine_annealing_lr_args
self.device = device

@classmethod
def derive_auto_fields(cls, train_iter):
Expand Down Expand Up @@ -284,6 +286,7 @@ def create_lightning_module(self, use_kv_cache: bool = False) -> pl.LightningMod
if self.ckpt_path is not None:
return LagLlamaLightningModule.load_from_checkpoint(
checkpoint_path=self.ckpt_path,
map_location=self.device,
strict=False,
loss=self.loss,
lr=self.lr,
Expand Down

0 comments on commit f90bf56

Please sign in to comment.