From 8d5d801adbe89c514125a3db314ddf4340fce4c5 Mon Sep 17 00:00:00 2001
From: Eric Lam
Date: Sun, 29 Jun 2025 22:46:20 +0800
Subject: [PATCH] Add optional HuggingFace dataset support

---
 tfkit/utility/dataset.py | 43 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/tfkit/utility/dataset.py b/tfkit/utility/dataset.py
index 0384002..87fa1ca 100644
--- a/tfkit/utility/dataset.py
+++ b/tfkit/utility/dataset.py
@@ -7,6 +7,11 @@
 from torch.utils import data
 from tqdm.contrib.concurrent import process_map
 
+try:
+    from datasets import load_dataset
+except Exception:  # pragma: no cover - optional dependency
+    load_dataset = None
+
 
 def get_dataset(file_path, task_class, tokenizer, parameter):
     panel = nlp2.Panel()
@@ -18,6 +23,14 @@ def get_dataset(file_path, task_class, tokenizer, parameter):
     #     panel.add_element(k=missarg, v=all_arg[missarg], msg=missarg, default=all_arg[missarg])
     # filled_arg = panel.get_result_dict()
     # parameter.update(filled_arg)
+    if load_dataset is not None and not os.path.isfile(file_path):
+        try:
+            hf_ds = load_dataset(file_path, split=parameter.get('split', 'train'))
+            return HFDataset(hf_ds, tokenizer=tokenizer,
+                             preprocessor=task_class.Preprocessor,
+                             preprocessing_arg=parameter)
+        except Exception:
+            pass
     ds = TFKitDataset(fpath=file_path, tokenizer=tokenizer,
                       preprocessor=task_class.Preprocessor,
                       preprocessing_arg=parameter)
@@ -76,3 +89,33 @@ def __getitem__(self, idx):
             {**{'task_dict': self.task_dict}, **{key: self.sample[key][idx] for key in self.sample.keys()}},
             self.tokenizer,
             maxlen=self.preprocessor.parameters['maxlen'])
+
+
+class HFDataset(data.Dataset):
+    """Dataset wrapper for the HuggingFace datasets library."""
+
+    def __init__(self, hf_dataset, tokenizer, preprocessor, preprocessing_arg=None):
+        preprocessing_arg = preprocessing_arg or {}
+        self.task_dict = {}
+        self.sample = defaultdict(list)
+        self.preprocessor = preprocessor(tokenizer, kwargs=preprocessing_arg)
+        self.tokenizer = tokenizer
+
+        print("Start preprocessing with HuggingFace dataset...")
+        length = 0
+        for raw_item in hf_dataset:
+            for items in self.preprocessor.preprocess(raw_item):
+                length += 1
+                for k, v in items.items():
+                    self.sample[k].append(v)
+        self.length = length
+        self.task = self.task_dict
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        return self.preprocessor.postprocess(
+            {**{'task_dict': self.task_dict}, **{key: self.sample[key][idx] for key in self.sample.keys()}},
+            self.tokenizer,
+            maxlen=self.preprocessor.parameters['maxlen'])
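
Usage note (commentary, not part of the patch): with this change, get_dataset()
accepts a HuggingFace Hub dataset name in place of a local file path, provided
the optional `datasets` package is installed; any failure on the Hub path
silently falls back to TFKitDataset, and the new branch also assumes `os` is
already imported in dataset.py. A minimal sketch of both call paths, assuming
a standard tfkit task module (`clas` is illustrative; the dataset name, model
name, and parameter values below are placeholders, not taken from the patch):

    from transformers import AutoTokenizer

    from tfkit.task import clas  # illustrative task; any tfkit task module with a Preprocessor works
    from tfkit.utility.dataset import get_dataset

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    # Local file: handled by TFKitDataset exactly as before.
    local_ds = get_dataset('train.csv', clas, tokenizer, {'maxlen': 128})

    # Non-file path: tried first as a HuggingFace dataset ('imdb' from the Hub),
    # with the split read from the parameter dict ('train' by default).
    hub_ds = get_dataset('imdb', clas, tokenizer, {'maxlen': 128, 'split': 'train'})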