diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py index 0d1dd306091..ea4e6b37243 100644 --- a/examples/models/llama/rope.py +++ b/examples/models/llama/rope.py @@ -240,7 +240,7 @@ def __init__(self, params: ModelArgs): self.precompute_freqs_cis = partial( hf_precompute_freqs_cis, partial_rotary_factor=self.params.partial_rotary_factor, - device=self.params.device, + device=getattr(self.params, "device", "cpu"), ) self.apply_rotary_emb = hf_apply_rotary_emb else: @@ -249,7 +249,7 @@ def __init__(self, params: ModelArgs): use_scaled=self.params.use_scaled_rope, scale_factor=self.params.rope_scale_factor, high_freq_factor=self.params.high_freq_factor, - device=self.params.device, + device=getattr(self.params, "device", "cpu"), ) self.apply_rotary_emb = RotaryEmbedding()