diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 9c06b5e34f3..70e7b91c3af 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -5619,7 +5619,7 @@ def test_t5(self): "python", f"{self.executorch_root}/examples/qualcomm/oss_scripts/t5/t5.py", "--dataset", - self.sentence_dataset, + self.qa_dataset, "--artifact", self.artifact_dir, "--build_folder", @@ -6486,6 +6486,11 @@ def setup_environment(): help="Location for imagenet dataset", type=str, ) + parser.add_argument( + "--qa_dataset", + help="Location for QA dataset", + type=str, + ) parser.add_argument( "--sentence_dataset", help="Location for sentence dataset", @@ -6549,6 +6554,7 @@ def setup_environment(): TestQNN.executorch_root = args.executorch_root TestQNN.artifact_dir = args.artifact_dir TestQNN.image_dataset = args.image_dataset + TestQNN.qa_dataset = args.qa_dataset TestQNN.sentence_dataset = args.sentence_dataset TestQNN.pretrained_weight = args.pretrained_weight TestQNN.model_name = args.model_name diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 94ca38ff091..17d847a5507 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -637,7 +637,7 @@ def __len__(self): # prepare input data inputs, targets = [], [] data_loader = get_data_loader() - for _, data in enumerate(data_loader): + for data in data_loader: if len(inputs) >= data_size: break input_ids = data[0] @@ -729,9 +729,9 @@ def __getitem__(self, idx): dataset, batch_size=1, shuffle=shuffle, collate_fn=collator ) - inputs, targets, input_list = [], [], "" + inputs, targets = [], [] data_loader = get_data_loader(max_hidden_seq_length) - for idx, batch in enumerate(data_loader): + for batch in data_loader: if len(inputs) >= data_size: break input_ids = batch["input_ids"] @@ -750,9 +750,8 @@ def __getitem__(self, idx): ) ) targets.append(labels) - input_list += f"input_{idx}_0.raw input_{idx}_1.raw input_{idx}_2.raw\n" - return inputs, targets, input_list + return inputs, targets def setup_common_args_and_variables():