Skip to content

Commit

Permalink
Merge pull request #153 from shibing624/dev-round
Browse files Browse the repository at this point in the history
Dev round
  • Loading branch information
shibing624 committed Aug 7, 2023
2 parents 85d322e + 4549007 commit e60fb5b
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
2 changes: 0 additions & 2 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,6 @@ def main():
print(f"Output: {response}\n")
results.append({"Input": prompt, "Output": response})

dirname = os.path.dirname(args.predictions_file)
os.makedirs(dirname, exist_ok=True)
with open(args.predictions_file, 'w', encoding='utf-8') as f:
json.dump(results, f, ensure_ascii=False, indent=2)

Expand Down
4 changes: 2 additions & 2 deletions supervised_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,9 +661,9 @@ def get_dialog(examples):
source_ids = source_ids[:max_source_length]
if len(target_ids) > max_target_length - 1: # eos token
target_ids = target_ids[:max_target_length - 1]
if source_ids[0] == tokenizer.eos_token_id:
if len(source_ids) > 0 and source_ids[0] == tokenizer.eos_token_id:
source_ids = source_ids[1:]
if target_ids[-1] == tokenizer.eos_token_id:
if len(target_ids) > 0 and target_ids[-1] == tokenizer.eos_token_id:
target_ids = target_ids[:-1]
if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
break
Expand Down

0 comments on commit e60fb5b

Please sign in to comment.