Commit
adapt llama2 recipe to the latest modifications
BenoitWang committed May 16, 2024
1 parent 2a79792 commit 1b6e295
Showing 2 changed files with 11 additions and 3 deletions.
@@ -79,6 +79,7 @@ llama2_model: !new:speechbrain.lobes.models.huggingface_transformers.llama2.LLAMA2
     top_k: !ref <top_k>
     top_p: !ref <top_p>
     with_peft: True
+    use_4bit: True
 
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
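For context on the hparams change above: use_4bit: True presumably asks the SpeechBrain LLAMA2 wrapper to load the checkpoint in 4-bit precision (typically through bitsandbytes), which pairs naturally with with_peft: True for memory-efficient, QLoRA-style fine-tuning. Below is a minimal, illustrative sketch of how such a flag is commonly wired to Hugging Face model loading; the wrapper's actual internals and option names may differ.

# Illustrative sketch only (not the SpeechBrain implementation): mapping a
# use_4bit flag onto a bitsandbytes 4-bit quantization config at load time.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

def load_causal_lm(source, use_4bit=False):
    quant_config = None
    if use_4bit:
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    # quantization_config=None falls back to a regular full/half-precision load.
    return AutoModelForCausalLM.from_pretrained(
        source, quantization_config=quant_config
    )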
13 changes: 10 additions & 3 deletions recipes/MultiWOZ/response_generation/llama2/train_with_llama2.py
@@ -35,7 +35,11 @@ def compute_forward(self, batch, stage):
         padding_mask = ~self.hparams.padding_mask(
             input_ids, pad_idx=tokenizer.pad_token_id
         )
-        outputs = self.modules.llama2_model(input_ids, padding_mask).logits
+
+        outputs = self.modules.llama2_model(
+            input_ids=input_ids,
+            attention_mask=padding_mask,
+        ).logits
 
         return outputs

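The compute_forward change above replaces a positional call with keyword arguments, so the wrapper receives the mask explicitly as attention_mask, matching a Hugging Face-style forward signature. The ~ inversion is there because the padding_mask helper from the hparams presumably returns True at padded positions, while an attention mask should be truthy on real tokens. A small self-contained sketch of that convention (the helper name here is illustrative, not the recipe's):

import torch

def key_padding_mask(input_ids, pad_idx):
    # Illustrative helper: True where the token equals the padding id.
    return input_ids.eq(pad_idx)

input_ids = torch.tensor([[11, 12, 13, 0, 0]])           # 0 acts as the pad id here
attention_mask = ~key_padding_mask(input_ids, pad_idx=0)
# attention_mask -> tensor([[ True,  True,  True, False, False]])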
@@ -57,14 +61,17 @@ def compute_objectives(self, predictions, batch, stage):
                 prompt_bos, pad_idx=tokenizer.pad_token_id
             )
             hyps = self.modules.llama2_model.generate(
-                prompt_bos.detach(), padding_mask.detach()
+                input_ids=prompt_bos.detach(),
+                attention_mask=padding_mask.detach(),
             )
         elif stage == sb.Stage.TEST:
             padding_mask = ~self.hparams.padding_mask(
                 prompt_bos, pad_idx=tokenizer.pad_token_id
             )
             hyps = self.modules.llama2_model.generate(
-                prompt_bos.detach(), padding_mask.detach(), "beam"
+                input_ids=prompt_bos.detach(),
+                attention_mask=padding_mask.detach(),
+                decoder_type="beam",
             )
 
         if stage != sb.Stage.TRAIN:
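The generate calls follow the same keyword-argument pattern, and the TEST branch now selects beam search through decoder_type="beam" rather than a bare positional string. As a rough, hedged sketch (not the actual SpeechBrain wrapper), such a decoder_type switch is typically mapped onto Hugging Face generate() options along these lines, with num_beams and max_new_tokens assumed to come from the hparams:

# Illustrative dispatch only; the real wrapper's options and defaults may differ.
def decode(model, input_ids, attention_mask, decoder_type="greedy",
           num_beams=4, max_new_tokens=50):
    if decoder_type == "beam":
        # Beam search, as requested at TEST time in the recipe above.
        return model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
        )
    # Default branch: greedy decoding.
    return model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        do_sample=False,
        max_new_tokens=max_new_tokens,
    )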
