From 0d700deda58c3ff67191a418eae3fef1367bcbe2 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Sun, 22 Sep 2024 18:04:15 -0700 Subject: [PATCH] Fix support for ET repo generated pte by adding batch dim --- torchchat/model.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchchat/model.py b/torchchat/model.py index aaa72cb2a..79bd1f188 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -951,6 +951,11 @@ def forward(self, x, input_pos): # the first element to get the tensor assert len(logits) == 1 logits = logits[0] + + # Add a batch dimension, if it's missing (e.g. some pte's + # exported from the ExecuTorch repo) + if logits.dim() == 2: + logits = logits.unsqueeze(0) return logits def setup_caches(self, max_batch_size, max_seq_length):