 import tarfile
 import urllib.request
 from dataclasses import dataclass, field
-from typing import Dict, Optional
-
+from typing import Dict, List, Optional, Union
+import boto3
 from benchmark import DATASETS_DIR
 from dataset_reader.ann_compound_reader import AnnCompoundReader
 from dataset_reader.ann_h5_reader import AnnH5Reader
+from dataset_reader.ann_h5_multi_reader import AnnH5MultiReader
 from dataset_reader.base_reader import BaseReader
 from dataset_reader.json_reader import JSONReader
+from tqdm import tqdm
+from pathlib import Path
 
 
 @dataclass
@@ -18,67 +21,231 @@ class DatasetConfig:
     distance: str
     name: str
     type: str
-    path: str
-    link: Optional[str] = None
+    path: Union[
+        str, Dict[str, List[Dict[str, str]]]
+    ]  # either a single file path, or the multi-file layout used by h5-multi datasets
+    link: Optional[Union[str, Dict[str, List[Dict[str, str]]]]] = None
     schema: Optional[Dict[str, str]] = field(default_factory=dict)
 
 
-READER_TYPE = {"h5": AnnH5Reader, "jsonl": JSONReader, "tar": AnnCompoundReader}
+READER_TYPE = {
+    "h5": AnnH5Reader,
+    "h5-multi": AnnH5MultiReader,
+    "jsonl": JSONReader,
+    "tar": AnnCompoundReader,
+}
 
 
-# prepare progressbar
+# Progress bar for urllib downloads
 def show_progress(block_num, block_size, total_size):
     percent = round(block_num * block_size / total_size * 100, 2)
     print(f"{percent} %", end="\r")
 
 
+# Progress handler for S3 downloads
+class S3Progress(tqdm):
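+    """tqdm progress bar driven by boto3's transfer callback (called with the bytes transferred per chunk)."""
+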
+    def __init__(self, total_size):
+        super().__init__(
+            total=total_size, unit="B", unit_scale=True, desc="Downloading from S3"
+        )
+
+    def __call__(self, bytes_amount):
+        self.update(bytes_amount)
+
+
 class Dataset:
-    def __init__(self, config: dict):
+    def __init__(
+        self,
+        config: dict,
+        skip_upload: bool,
+        skip_search: bool,
+        upload_start_idx: int,
+        upload_end_idx: int,
+    ):
         self.config = DatasetConfig(**config)
+        self.skip_upload = skip_upload
+        self.skip_search = skip_search
+        self.upload_start_idx = upload_start_idx
+        self.upload_end_idx = upload_end_idx
 
     def download(self):
-        target_path = DATASETS_DIR / self.config.path
-
+        if isinstance(self.config.path, dict):  # Handle multi-file datasets
+            if self.skip_search is False:
+                # Download query files
+                for query in self.config.path.get("queries", []):
+                    self._download_file(query["path"], query["link"])
+            else:
+                print(
+                    f"skipping query file download since skip_search={self.skip_search}"
+                )
+            if self.skip_upload is False:
+                # Download data files
+                for data in self.config.path.get("data", []):
+                    start_idx = data["start_idx"]
+                    end_idx = data["end_idx"]
+                    data_path = data["path"]
+                    data_link = data["link"]
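+                    # Download only the parts whose [start_idx, end_idx) range overlaps the requested upload window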
+                    if self.upload_start_idx >= end_idx:
+                        print(
+                            f"skipping download of {data_path} from {data_link} since {self.upload_start_idx} >= {end_idx}"
+                        )
+                        continue
+                    if self.upload_end_idx < start_idx:
+                        print(
+                            f"skipping download of {data_path} from {data_link} since {self.upload_end_idx} < {start_idx}"
+                        )
+                        continue
+                    self._download_file(data["path"], data["link"])
+            else:
+                print(
+                    f"skipping data file downloads since skip_upload={self.skip_upload}"
+                )
+
+        else:  # Handle single-file datasets
+            target_path = DATASETS_DIR / self.config.path
+
+            if target_path.exists():
+                print(f"{target_path} already exists")
+                return
+
+            if self.config.link:
+                if is_s3_link(self.config.link):
+                    print("Use boto3 to download from S3. Faster!")
+                    self._download_from_s3(self.config.link, target_path)
+                else:
+                    print(f"Downloading from URL {self.config.link}...")
+                    tmp_path, _ = urllib.request.urlretrieve(
+                        self.config.link, None, show_progress
+                    )
+                    self._extract_or_move_file(tmp_path, target_path)
+
+    def _download_file(self, relative_path: str, url: str):
+        target_path = DATASETS_DIR / relative_path
         if target_path.exists():
             print(f"{target_path} already exists")
             return
 
-        if self.config.link:
-            print(f"Downloading {self.config.link}...")
-            tmp_path, _ = urllib.request.urlretrieve(
-                self.config.link, None, show_progress
-            )
+        print(f"Downloading from {url} to {target_path}")
+        tmp_path, _ = urllib.request.urlretrieve(url, None, show_progress)
+        self._extract_or_move_file(tmp_path, target_path)
 
-            if self.config.link.endswith(".tgz") or self.config.link.endswith(
-                ".tar.gz"
-            ):
-                print(f"Extracting: {tmp_path} -> {target_path}")
-                (DATASETS_DIR / self.config.path).mkdir(exist_ok=True, parents=True)
-                file = tarfile.open(tmp_path)
+    def _extract_or_move_file(self, tmp_path, target_path):
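+        # Archives (.tgz / .tar.gz) are extracted into target_path; any other file is moved into place as-is.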
+        if tmp_path.endswith(".tgz") or tmp_path.endswith(".tar.gz"):
+            print(f"Extracting: {tmp_path} -> {target_path}")
+            Path(target_path).mkdir(exist_ok=True, parents=True)
+            with tarfile.open(tmp_path) as file:
                 file.extractall(target_path)
-                file.close()
-                os.remove(tmp_path)
-            else:
-                print(f"Moving: {tmp_path} -> {target_path}")
-                (DATASETS_DIR / self.config.path).parent.mkdir(exist_ok=True)
-                shutil.copy2(tmp_path, target_path)
-                os.remove(tmp_path)
+            os.remove(tmp_path)
+        else:
+            print(f"Moving: {tmp_path} -> {target_path}")
+            Path(target_path).parent.mkdir(exist_ok=True, parents=True)
+            shutil.copy2(tmp_path, target_path)
+            os.remove(tmp_path)
+
+    def _download_from_s3(self, link, target_path):
+        s3 = boto3.client("s3")
+        bucket_name, s3_key = parse_s3_url(link)
+        tmp_path = f"/tmp/{os.path.basename(s3_key)}"
+
+        print(
+            f"Downloading from S3: {link}... bucket_name={bucket_name}, s3_key={s3_key}"
+        )
+        object_info = s3.head_object(Bucket=bucket_name, Key=s3_key)
+        total_size = object_info["ContentLength"]
+
+        with open(tmp_path, "wb") as f:
+            progress = S3Progress(total_size)
+            s3.download_fileobj(bucket_name, s3_key, f, Callback=progress)
+
+        self._extract_or_move_file(tmp_path, target_path)
 
     def get_reader(self, normalize: bool) -> BaseReader:
         reader_class = READER_TYPE[self.config.type]
-        return reader_class(DATASETS_DIR / self.config.path, normalize=normalize)
+
+        if self.config.type == "h5-multi":
+            # For h5-multi, we need to pass both data files and query file
+            data_files = self.config.path["data"]
+            for data_file_dict in data_files:
+                data_file_dict["path"] = DATASETS_DIR / data_file_dict["path"]
+            query_file = DATASETS_DIR / self.config.path["queries"][0]["path"]
+            return reader_class(
+                data_files=data_files,
+                query_file=query_file,
+                normalize=normalize,
+                skip_upload=self.skip_upload,
+                skip_search=self.skip_search,
+            )
+        else:
+            # For single-file datasets
+            return reader_class(DATASETS_DIR / self.config.path, normalize=normalize)
+
+
+def is_s3_link(link):
+    return link.startswith("s3://") or "s3.amazonaws.com" in link
+
+
+def parse_s3_url(s3_url):
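+    """Split an S3 location into (bucket_name, s3_key).
+
+    Handles both s3://bucket/key URIs and virtual-hosted-style URLs such as
+    http://bucket.s3.amazonaws.com/key.
+    """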
+    if s3_url.startswith("s3://"):
+        s3_parts = s3_url.replace("s3://", "").split("/", 1)
+        bucket_name = s3_parts[0]
+        s3_key = s3_parts[1] if len(s3_parts) > 1 else ""
+    else:
+        s3_parts = s3_url.replace("http://", "").replace("https://", "").split("/", 1)
+
+        if ".s3.amazonaws.com" in s3_parts[0]:
+            bucket_name = s3_parts[0].split(".s3.amazonaws.com")[0]
+            s3_key = s3_parts[1] if len(s3_parts) > 1 else ""
+        else:
+            bucket_name = s3_parts[0]
+            s3_key = s3_parts[1] if len(s3_parts) > 1 else ""
+
+    return bucket_name, s3_key
 
 
 if __name__ == "__main__":
-    dataset = Dataset(
+    dataset_s3_split = Dataset(
         {
-            "name": "glove-25-angular",
-            "vector_size": 25,
-            "distance": "Cosine",
-            "type": "h5",
-            "path": "glove-25-angular/glove-25-angular.hdf5",
-            "link": "http://ann-benchmarks.com/glove-25-angular.hdf5",
-        }
+            "name": "laion-img-emb-768d-1Billion-cosine",
+            "vector_size": 768,
+            "distance": "cosine",
+            "type": "h5-multi",
+            "path": {
+                "data": [
+                    {
+                        "file_number": 1,
+                        "path": "laion-1b/data/laion-img-emb-768d-1Billion-cosine-data-part1-0_to_10000000.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-data-part1-0_to_10000000.hdf5",
+                        "vector_range": "0-10000000",
+                        "file_size": "30.7 GB",
+                    },
+                    {
+                        "file_number": 2,
+                        "path": "laion-1b/data/laion-img-emb-768d-1Billion-cosine-data-part10-90000000_to_100000000.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-data-part10-90000000_to_100000000.hdf5",
+                        "vector_range": "90000000-100000000",
+                        "file_size": "30.7 GB",
+                    },
+                    {
+                        "file_number": 3,
+                        "path": "laion-1b/data/laion-img-emb-768d-1Billion-cosine-data-part100-990000000_to_1000000000.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-data-part100-990000000_to_1000000000.hdf5",
+                        "vector_range": "990000000-1000000000",
+                        "file_size": "30.7 GB",
+                    },
+                ],
+                "queries": [
+                    {
+                        "path": "laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5",
+                        "file_size": "38.7 MB",
+                    },
+                ],
+            },
+        },
+        skip_upload=True,
+        skip_search=False,
+        upload_start_idx=0,  # required by __init__; unused here since skip_upload=True
+        upload_end_idx=0,
     )
 
-    dataset.download()
+    dataset_s3_split.download()
+    reader = dataset_s3_split.get_reader(normalize=False)
+    print(reader)  # Outputs the AnnH5MultiReader instance
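
For reference, a small illustrative check (not part of the diff itself) of how the S3 helpers added above resolve the virtual-hosted-style links used in the example config:

    link = "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5"
    assert is_s3_link(link)  # matched via the "s3.amazonaws.com" substring
    bucket, key = parse_s3_url(link)
    # bucket == "benchmarks.redislabs"
    # key == "vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5"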