From b41532608d0f1648285d7794eda4331b1cfb297f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Thu, 14 Nov 2024 11:04:47 +0100 Subject: [PATCH] fix: do not print perf stat when NaN MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit If the chat is exited or interrupted, it will still print the stats with NaN values, which is unnecessary. Signed-off-by: Sébastien Han --- torchchat/generate.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index 4a67195fb..66f26ff9f 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1189,12 +1189,27 @@ def callback(x, *, done_generating=False): f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}" ) - print( - f"\n Average tokens/sec (total): {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f} \ - \nAverage tokens/sec (first token): {torch.mean(torch.tensor(aggregate_metrics['first_token_per_sec'])).item():.2f} \ - \nAverage tokens/sec (next tokens): {torch.mean(torch.tensor(aggregate_metrics['next_tokens_per_sec'])).item():.2f} \n\ + avg_tokens_sec = torch.mean( + torch.tensor(aggregate_metrics["tokens_per_sec"]) + ).item() + avg_first_token_sec = torch.mean( + torch.tensor(aggregate_metrics["first_token_per_sec"]) + ).item() + avg_next_tokens_sec = torch.mean( + torch.tensor(aggregate_metrics["next_tokens_per_sec"]) + ).item() + + if not ( + torch.isnan(torch.tensor(avg_tokens_sec)) + or torch.isnan(torch.tensor(avg_first_token_sec)) + or torch.isnan(torch.tensor(avg_next_tokens_sec)) + ): + print( + f"\n Average tokens/sec (total): {avg_tokens_sec:.2f} \ + \nAverage tokens/sec (first token): {avg_first_token_sec:.2f} \ + \nAverage tokens/sec (next tokens): {avg_next_tokens_sec:.2f} \n\ " - ) + ) if torch.cuda.is_available(): print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")