Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5619,7 +5619,7 @@ def test_t5(self):
"python",
f"{self.executorch_root}/examples/qualcomm/oss_scripts/t5/t5.py",
"--dataset",
self.sentence_dataset,
self.qa_dataset,
"--artifact",
self.artifact_dir,
"--build_folder",
Expand Down Expand Up @@ -6486,6 +6486,11 @@ def setup_environment():
help="Location for imagenet dataset",
type=str,
)
parser.add_argument(
"--qa_dataset",
help="Location for QA dataset",
type=str,
)
parser.add_argument(
"--sentence_dataset",
help="Location for sentence dataset",
Expand Down Expand Up @@ -6549,6 +6554,7 @@ def setup_environment():
TestQNN.executorch_root = args.executorch_root
TestQNN.artifact_dir = args.artifact_dir
TestQNN.image_dataset = args.image_dataset
TestQNN.qa_dataset = args.qa_dataset
TestQNN.sentence_dataset = args.sentence_dataset
TestQNN.pretrained_weight = args.pretrained_weight
TestQNN.model_name = args.model_name
Expand Down
9 changes: 4 additions & 5 deletions examples/qualcomm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ def __len__(self):
# prepare input data
inputs, targets = [], []
data_loader = get_data_loader()
for _, data in enumerate(data_loader):
for data in data_loader:
if len(inputs) >= data_size:
break
input_ids = data[0]
Expand Down Expand Up @@ -729,9 +729,9 @@ def __getitem__(self, idx):
dataset, batch_size=1, shuffle=shuffle, collate_fn=collator
)

inputs, targets, input_list = [], [], ""
inputs, targets = [], []
data_loader = get_data_loader(max_hidden_seq_length)
for idx, batch in enumerate(data_loader):
for batch in data_loader:
if len(inputs) >= data_size:
break
input_ids = batch["input_ids"]
Expand All @@ -750,9 +750,8 @@ def __getitem__(self, idx):
)
)
targets.append(labels)
input_list += f"input_{idx}_0.raw input_{idx}_1.raw input_{idx}_2.raw\n"

return inputs, targets, input_list
return inputs, targets


def setup_common_args_and_variables():
Expand Down
Loading