
Commit 4d1b39c

Added h5-multi file reader
1 parent 4a566f1 commit 4d1b39c

10 files changed: 3856 additions, 637 deletions


benchmark/dataset.py

Lines changed: 204 additions & 37 deletions
@@ -3,13 +3,16 @@
 import tarfile
 import urllib.request
 from dataclasses import dataclass, field
-from typing import Dict, Optional
-
+from typing import Dict, List, Optional, Union
+import boto3
 from benchmark import DATASETS_DIR
 from dataset_reader.ann_compound_reader import AnnCompoundReader
 from dataset_reader.ann_h5_reader import AnnH5Reader
+from dataset_reader.ann_h5_multi_reader import AnnH5MultiReader
 from dataset_reader.base_reader import BaseReader
 from dataset_reader.json_reader import JSONReader
+from tqdm import tqdm
+from pathlib import Path


 @dataclass
@@ -18,67 +21,231 @@ class DatasetConfig:
     distance: str
     name: str
     type: str
-    path: str
-    link: Optional[str] = None
+    path: Dict[
+        str, List[Dict[str, str]]
+    ]  # Now path is expected to handle multi-file structure for h5-multi
+    link: Optional[Dict[str, List[Dict[str, str]]]] = None
     schema: Optional[Dict[str, str]] = field(default_factory=dict)


-READER_TYPE = {"h5": AnnH5Reader, "jsonl": JSONReader, "tar": AnnCompoundReader}
+READER_TYPE = {
+    "h5": AnnH5Reader,
+    "h5-multi": AnnH5MultiReader,
+    "jsonl": JSONReader,
+    "tar": AnnCompoundReader,
+}


-# prepare progressbar
+# Progress bar for urllib downloads
 def show_progress(block_num, block_size, total_size):
     percent = round(block_num * block_size / total_size * 100, 2)
     print(f"{percent} %", end="\r")


+# Progress handler for S3 downloads
+class S3Progress(tqdm):
+    def __init__(self, total_size):
+        super().__init__(
+            total=total_size, unit="B", unit_scale=True, desc="Downloading from S3"
+        )
+
+    def __call__(self, bytes_amount):
+        self.update(bytes_amount)
+
+
 class Dataset:
-    def __init__(self, config: dict):
+    def __init__(
+        self,
+        config: dict,
+        skip_upload: bool,
+        skip_search: bool,
+        upload_start_idx: int,
+        upload_end_idx: int,
+    ):
         self.config = DatasetConfig(**config)
+        self.skip_upload = skip_upload
+        self.skip_search = skip_search
+        self.upload_start_idx = upload_start_idx
+        self.upload_end_idx = upload_end_idx

     def download(self):
-        target_path = DATASETS_DIR / self.config.path
-
+        if isinstance(self.config.path, dict):  # Handle multi-file datasets
+            if self.skip_search is False:
+                # Download query files
+                for query in self.config.path.get("queries", []):
+                    self._download_file(query["path"], query["link"])
+            else:
+                print(
+                    f"skipping to download query file given skip_search={self.skip_search}"
+                )
+            if self.skip_upload is False:
+                # Download data files
+                for data in self.config.path.get("data", []):
+                    start_idx = data["start_idx"]
+                    end_idx = data["end_idx"]
+                    data_path = data["path"]
+                    data_link = data["link"]
+                    if self.upload_start_idx >= end_idx:
+                        print(
+                            f"skipping downloading {data_path} from {data_link} given {self.upload_start_idx}>{end_idx}"
+                        )
+                        continue
+                    if self.upload_end_idx < start_idx:
+                        print(
+                            f"skipping downloading {data_path} from {data_link} given {self.upload_end_idx}<{start_idx}"
+                        )
+                        continue
+                    self._download_file(data["path"], data["link"])
+            else:
+                print(
+                    f"skipping to download data/upload files given skip_upload={self.skip_upload}"
+                )
+
+        else:  # Handle single-file datasets
+            target_path = DATASETS_DIR / self.config.path
+
+            if target_path.exists():
+                print(f"{target_path} already exists")
+                return
+
+            if self.config.link:
+                if is_s3_link(self.config.link):
+                    print("Use boto3 to download from S3. Faster!")
+                    self._download_from_s3(self.config.link, target_path)
+                else:
+                    print(f"Downloading from URL {self.config.link}...")
+                    tmp_path, _ = urllib.request.urlretrieve(
+                        self.config.link, None, show_progress
+                    )
+                    self._extract_or_move_file(tmp_path, target_path)
+
+    def _download_file(self, relative_path: str, url: str):
+        target_path = DATASETS_DIR / relative_path
         if target_path.exists():
             print(f"{target_path} already exists")
             return

-        if self.config.link:
-            print(f"Downloading {self.config.link}...")
-            tmp_path, _ = urllib.request.urlretrieve(
-                self.config.link, None, show_progress
-            )
+        print(f"Downloading from {url} to {target_path}")
+        tmp_path, _ = urllib.request.urlretrieve(url, None, show_progress)
+        self._extract_or_move_file(tmp_path, target_path)

-            if self.config.link.endswith(".tgz") or self.config.link.endswith(
-                ".tar.gz"
-            ):
-                print(f"Extracting: {tmp_path} -> {target_path}")
-                (DATASETS_DIR / self.config.path).mkdir(exist_ok=True, parents=True)
-                file = tarfile.open(tmp_path)
+    def _extract_or_move_file(self, tmp_path, target_path):
+        if tmp_path.endswith(".tgz") or tmp_path.endswith(".tar.gz"):
+            print(f"Extracting: {tmp_path} -> {target_path}")
+            (DATASETS_DIR / self.config.path).mkdir(exist_ok=True, parents=True)
+            with tarfile.open(tmp_path) as file:
                 file.extractall(target_path)
-                file.close()
-                os.remove(tmp_path)
-            else:
-                print(f"Moving: {tmp_path} -> {target_path}")
-                (DATASETS_DIR / self.config.path).parent.mkdir(exist_ok=True)
-                shutil.copy2(tmp_path, target_path)
-                os.remove(tmp_path)
+            os.remove(tmp_path)
+        else:
+            print(f"Moving: {tmp_path} -> {target_path}")
+            Path(target_path).parent.mkdir(exist_ok=True)
+            shutil.copy2(tmp_path, target_path)
+            os.remove(tmp_path)
+
+    def _download_from_s3(self, link, target_path):
+        s3 = boto3.client("s3")
+        bucket_name, s3_key = parse_s3_url(link)
+        tmp_path = f"/tmp/{os.path.basename(s3_key)}"
+
+        print(
+            f"Downloading from S3: {link}... bucket_name={bucket_name}, s3_key={s3_key}"
+        )
+        object_info = s3.head_object(Bucket=bucket_name, Key=s3_key)
+        total_size = object_info["ContentLength"]
+
+        with open(tmp_path, "wb") as f:
+            progress = S3Progress(total_size)
+            s3.download_fileobj(bucket_name, s3_key, f, Callback=progress)
+
+        self._extract_or_move_file(tmp_path, target_path)

     def get_reader(self, normalize: bool) -> BaseReader:
         reader_class = READER_TYPE[self.config.type]
-        return reader_class(DATASETS_DIR / self.config.path, normalize=normalize)
+
+        if self.config.type == "h5-multi":
+            # For h5-multi, we need to pass both data files and query file
+            data_files = self.config.path["data"]
+            for data_file_dict in data_files:
+                data_file_dict["path"] = DATASETS_DIR / data_file_dict["path"]
+            query_file = DATASETS_DIR / self.config.path["queries"][0]["path"]
+            return reader_class(
+                data_files=data_files,
+                query_file=query_file,
+                normalize=normalize,
+                skip_upload=self.skip_upload,
+                skip_search=self.skip_search,
+            )
+        else:
+            # For single-file datasets
+            return reader_class(DATASETS_DIR / self.config.path, normalize=normalize)
+
+
+def is_s3_link(link):
+    return link.startswith("s3://") or "s3.amazonaws.com" in link
+
+
+def parse_s3_url(s3_url):
+    if s3_url.startswith("s3://"):
+        s3_parts = s3_url.replace("s3://", "").split("/", 1)
+        bucket_name = s3_parts[0]
+        s3_key = s3_parts[1] if len(s3_parts) > 1 else ""
+    else:
+        s3_parts = s3_url.replace("http://", "").replace("https://", "").split("/", 1)
+
+        if ".s3.amazonaws.com" in s3_parts[0]:
+            bucket_name = s3_parts[0].split(".s3.amazonaws.com")[0]
+            s3_key = s3_parts[1] if len(s3_parts) > 1 else ""
+        else:
+            bucket_name = s3_parts[0]
+            s3_key = s3_parts[1] if len(s3_parts) > 1 else ""
+
+    return bucket_name, s3_key


 if __name__ == "__main__":
-    dataset = Dataset(
+    dataset_s3_split = Dataset(
         {
-            "name": "glove-25-angular",
-            "vector_size": 25,
-            "distance": "Cosine",
-            "type": "h5",
-            "path": "glove-25-angular/glove-25-angular.hdf5",
-            "link": "http://ann-benchmarks.com/glove-25-angular.hdf5",
-        }
+            "name": "laion-img-emb-768d-1Billion-cosine",
+            "vector_size": 768,
+            "distance": "cosine",
+            "type": "h5-multi",
+            "path": {
+                "data": [
+                    {
+                        "file_number": 1,
+                        "path": "laion-1b/data/laion-img-emb-768d-1Billion-cosine-data-part1-0_to_10000000.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-data-part1-0_to_10000000.hdf5",
+                        "vector_range": "0-10000000",
+                        "file_size": "30.7 GB",
+                    },
+                    {
+                        "file_number": 2,
+                        "path": "laion-1b/data/laion-img-emb-768d-1Billion-cosine-data-part10-90000000_to_100000000.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-data-part10-90000000_to_100000000.hdf5",
+                        "vector_range": "90000000-100000000",
+                        "file_size": "30.7 GB",
+                    },
+                    {
+                        "file_number": 3,
+                        "path": "laion-1b/data/laion-img-emb-768d-1Billion-cosine-data-part100-990000000_to_1000000000.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-data-part100-990000000_to_1000000000.hdf5",
+                        "vector_range": "990000000-1000000000",
+                        "file_size": "30.7 GB",
+                    },
+                ],
+                "queries": [
+                    {
+                        "path": "laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5",
+                        "file_size": "38.7 MB",
+                    },
+                ],
+            },
+        },
+        skip_upload=True,
+        skip_search=False,
     )

-    dataset.download()
+    dataset_s3_split.download()
+    reader = dataset_s3_split.get_reader(normalize=False)
+    print(reader)  # Outputs the AnnH5MultiReader instance
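
Note: the new `dataset_reader/ann_h5_multi_reader.py` is added elsewhere in this commit and is not part of this file's diff. For orientation only, the sketch below shows the constructor shape that the `get_reader` call above implies for a compatible multi-file reader; the keyword arguments come from that call site, while the class name, the `read_data` method, and the HDF5 dataset key are assumptions, not the committed implementation.

# Illustrative stand-in for AnnH5MultiReader (not the committed class).
# Only the __init__ keywords mirror the call in Dataset.get_reader(); the
# "train" dataset key and the streaming strategy are assumed for illustration.
from typing import Any, Dict, Iterator, List

import h5py
import numpy as np


class MultiFileH5ReaderSketch:
    def __init__(
        self,
        data_files: List[Dict[str, Any]],  # entries from the config's "data" list, "path" resolved under DATASETS_DIR
        query_file,                        # single queries HDF5 file
        normalize: bool,
        skip_upload: bool,
        skip_search: bool,
    ):
        self.data_files = data_files
        self.query_file = query_file
        self.normalize = normalize
        self.skip_upload = skip_upload
        self.skip_search = skip_search

    def read_data(self) -> Iterator[np.ndarray]:
        # Stream vectors one part at a time so 30+ GB files are never held in memory at once.
        for part in self.data_files:
            with h5py.File(part["path"], "r") as f:
                for vector in f["train"]:  # HDF5 dataset name assumed (ann-benchmarks style)
                    if self.normalize:
                        vector = vector / np.linalg.norm(vector)
                    yield vector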

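The multi-file branch of `download()` above skips any data part whose index range lies entirely outside the requested upload window (note that it reads `start_idx` and `end_idx` keys from each `data` entry). A small illustrative walk-through of that comparison, with hypothetical part boundaries:

# Illustrative check of the part-skipping comparisons in Dataset.download() above.
# The part boundaries are hypothetical; the two comparisons mirror the committed code.
upload_start_idx, upload_end_idx = 5_000_000, 15_000_000

parts = [
    {"path": "part1", "start_idx": 0, "end_idx": 10_000_000},
    {"path": "part2", "start_idx": 10_000_000, "end_idx": 20_000_000},
    {"path": "part3", "start_idx": 90_000_000, "end_idx": 100_000_000},
]

for part in parts:
    skipped = (
        upload_start_idx >= part["end_idx"] or upload_end_idx < part["start_idx"]
    )
    print(part["path"], "skipped" if skipped else "downloaded")
# part1 downloaded, part2 downloaded, part3 skipped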
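
The S3 handling above accepts both `s3://` URLs and virtual-hosted-style `*.s3.amazonaws.com` links. A quick illustrative check of `is_s3_link` / `parse_s3_url`, assuming the module is importable as `benchmark.dataset` (the second URL below uses a hypothetical bucket and key):

# Illustrative check of the URL helpers added above.
from benchmark.dataset import is_s3_link, parse_s3_url

# Virtual-hosted-style link, as used in the __main__ example:
url = "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5"
assert is_s3_link(url)
assert parse_s3_url(url) == (
    "benchmarks.redislabs",
    "vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5",
)

# The s3:// form is parsed into bucket + key as well (hypothetical values):
assert is_s3_link("s3://my-bucket/some/prefix/file.hdf5")
assert parse_s3_url("s3://my-bucket/some/prefix/file.hdf5") == (
    "my-bucket",
    "some/prefix/file.hdf5",
)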