-
Notifications
You must be signed in to change notification settings - Fork 143
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: Unable to download some files from AliyunOSS
- Loading branch information
1 parent
271a65a
commit 9db9e0a
Showing
6 changed files
with
149 additions
and
164 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,78 +1,28 @@ | ||
import logging | ||
import pathlib | ||
import pytest | ||
from vectordb_bench.backend.data_source import AliyunOSSReader, AwsS3Reader | ||
from vectordb_bench.backend.dataset import Dataset, DatasetManager | ||
from vectordb_bench.backend.data_source import DatasetSource | ||
from vectordb_bench.backend.cases import type2case | ||
|
||
log = logging.getLogger(__name__) | ||
log = logging.getLogger("vectordb_bench") | ||
|
||
class TestReader: | ||
@pytest.mark.parametrize("size", [ | ||
100_000, | ||
1_000_000, | ||
10_000_000, | ||
@pytest.mark.parametrize("type_case", [ | ||
(k, v) for k, v in type2case.items() | ||
]) | ||
def test_cohere(self, size): | ||
cohere = Dataset.COHERE.manager(size) | ||
self.per_dataset_test(cohere) | ||
def test_type_cases(self, type_case): | ||
self.per_case_test(type_case) | ||
|
||
@pytest.mark.parametrize("size", [ | ||
100_000, | ||
1_000_000, | ||
]) | ||
def test_gist(self, size): | ||
gist = Dataset.GIST.manager(size) | ||
self.per_dataset_test(gist) | ||
|
||
@pytest.mark.parametrize("size", [ | ||
1_000_000, | ||
]) | ||
def test_glove(self, size): | ||
glove = Dataset.GLOVE.manager(size) | ||
self.per_dataset_test(glove) | ||
|
||
@pytest.mark.parametrize("size", [ | ||
500_000, | ||
5_000_000, | ||
# 50_000_000, | ||
]) | ||
def test_sift(self, size): | ||
sift = Dataset.SIFT.manager(size) | ||
self.per_dataset_test(sift) | ||
|
||
@pytest.mark.parametrize("size", [ | ||
50_000, | ||
500_000, | ||
5_000_000, | ||
]) | ||
def test_openai(self, size): | ||
openai = Dataset.OPENAI.manager(size) | ||
self.per_dataset_test(openai) | ||
|
||
|
||
def per_dataset_test(self, dataset: DatasetManager): | ||
s3_reader = AwsS3Reader() | ||
all_files = s3_reader.ls_all(dataset.data.dir_name) | ||
|
||
|
||
remote_f_names = [] | ||
for file in all_files: | ||
remote_f = pathlib.Path(file).name | ||
if dataset.data.use_shuffled and remote_f.startswith("train"): | ||
continue | ||
|
||
elif (not dataset.data.use_shuffled) and remote_f.startswith("shuffle"): | ||
continue | ||
|
||
remote_f_names.append(remote_f) | ||
|
||
def per_case_test(self, type_case): | ||
t, ca_cls = type_case | ||
ca = ca_cls() | ||
log.info(f"test case: {t.name}, {ca.name}") | ||
|
||
assert set(dataset.data.files) == set(remote_f_names) | ||
filters = ca.filter_rate | ||
ca.dataset.prepare(source=DatasetSource.AliyunOSS, check=False, filters=filters) | ||
ali_trains = ca.dataset.train_files | ||
|
||
aliyun_reader = AliyunOSSReader() | ||
for fname in dataset.data.files: | ||
p = pathlib.Path("benchmark", dataset.data.dir_name, fname) | ||
assert aliyun_reader.bucket.object_exists(p.as_posix()) | ||
ca.dataset.prepare(check=False, filters=filters) | ||
s3_trains = ca.dataset.train_files | ||
|
||
log.info(f"downloading to {dataset.data_dir}") | ||
aliyun_reader.read(dataset.data.dir_name.lower(), dataset.data.files, dataset.data_dir) | ||
assert ali_trains == s3_trains |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.