In [11]:
#!pip install transformers==4.41.1
#!pip install datasets==2.19.1

In [28]:
#from transformers.trainer_pt_utils import IterableDatasetShard
from datasets import Dataset, DatasetDict,IterableDatasetDict
from datasets import load_dataset
import pyarrow.parquet as pq
from abc import *
from typing import Union, Dict, Set, List, Callable, Tuple, Any
from torch.utils.data import IterableDataset, Dataset, DataLoader
import psutil
import os

In [21]:
class AbstractPreprocessor(metaclass=ABCMeta):
    def __init__(self, args, **kwargs):
        self.args = args
    

    @classmethod
    def clean_cache_dir(cls, cache_dir):
        shutil.rmtree(cache_dir)

    @classmethod
    def load(cls, file_path:Union[str, List], split:str=None, stream:bool=True, keep_in_memory:bool=True, is_cache:bool=True, cache_dir:str='./.cache') -> IterableDataset:
        
        _, before_mem_usage_gb, before_mem_avail_gb = cls.get_ram_usage_percent()

        if isinstance(file_path, List):
            dataset = load_dataset("parquet", 
                                   data_files=file_path, 
                                   split=split, 
                                   keep_in_memory=keep_in_memory, 
                                   streaming=stream,
                                   cache_dir=cache_dir)

        elif isinstance(file_path, str):
            if os.path.isfile(file_path):
                dataset = load_dataset("parquet", 
                                       data_files=file_path, 
                                       split=split, 
                                       keep_in_memory=keep_in_memory, 
                                       streaming=stream,
                                       cache_dir=cache_dir)
            
            elif os.path.isdir(file_path):
                dataset = load_dataset(file_path, 
                                       streaming=stream,
                                       cache_dir=cache_dir)
            else:
                raise FileNotFoundError(f'{file_path} should be dir_path | file_path')
        else:
            raise TypeError(f'file_path should be in type str | List')
        
        if split is None:
            dataset = dataset['train']
        
        mem_usage_percent, mem_usage_gb, mem_avail_gb = cls.get_ram_usage_percent()

        ## 추후 Logger로 바꾸기 ## 
        print(f'mem usage percent: {mem_usage_percent}%')
        print(f'before_mem_avail_gb: {before_mem_avail_gb} gb')
        print(f'mem usage gb: {mem_usage_gb} gb')
        print(f'mem avail gb: {mem_avail_gb} gb')

        print(f'only file memory usage: {mem_usage_gb - before_mem_usage_gb}' )


        if not is_cache:
            cls.clean_cache()

        return dataset
            
    @classmethod
    def clean_cache(cls, dataset:Union[IterableDataset, Dataset]):
        dataset.cleanup_cache_files()


    @staticmethod
    def get_ram_usage_percent():
        """Returns the current system-wide RAM usage as a percentage."""

        mem = psutil.virtual_memory()

        return mem.percent, mem.used / (1024 ** 3), mem.available / (1024 ** 3)        
    
    @staticmethod
    def select_columns(dataset: Union[IterableDataset, Dataset, IterableDatasetDict], column_names:list=[]):
        if isinstance(dataset, IterableDataset) or isinstance(dataset, IterableDatasetDict):
            if isinstance(column_names, str):
                 column_names = column_names.split(',')
            elif isinstance(column_names, list):
                pass
            
            else:
                raise TypeError('column_names should be list or string type')
            
            dataset = dataset.select_columns(column_names)
                
        elif isinstance(dataset, Dataset):
            dataset = dataset['train'].select_columns([column_names])
            
        else:
            raise TypeError(f'dataset should be IterableDataset | Dataset')

        return dataset
    
    
    @abstractmethod
    def preprocess(self, item):
        f"""code for apply to map function"""

In [22]:
class StreamPreprocessor(AbstractPreprocessor):
    def __init__(self, args, **kwargs):        
        super().__init__(args)

    @classmethod
    def load(cls, file_path:Union[str, List], split:str=None, keep_in_memory:bool=True, is_cache:bool=True) -> IterableDataset:        
        stream = True
        dataset = super(StreamPreprocessor, cls).load(
                file_path=file_path, 
                split=split, 
                stream=stream, 
                keep_in_memory=keep_in_memory,
                is_cache=is_cache
        )
        return dataset
    
    
    @classmethod
    def shuffle(cls, dataset:IterableDataset, seed:int=777, buffer_size:int=1000) -> IterableDataset:
        return dataset.shuffle(seed=seed, buffer_size=buffer_size)

    @abstractmethod
    def preprocess(self, item):
        """ Implement preprocessing logic"""
    
    @classmethod
    def apply_maps(cls, dataset:Dataset, functions_list: List[Tuple[Callable[..., Any], bool]]) -> Dataset:
        """ instance method for apply list of functions"""
        for func, with_indices in functions_list:
            dataset = cls.apply_map(dataset=dataset, func=func, with_indices=with_indices)
        
        return dataset
    
    @classmethod
    def apply_map(cls, dataset: Dataset, func:Callable, with_indices: bool = True) -> Dataset:
        """ instance method for apply only one function"""
        dataset = dataset.map(func, with_indices=with_indices)
        return dataset

In [31]:
#embedding_root_path = "/data/temp/one_model_2024-05-11/one_model_v3_result_adot_20240511_0_emb.parquet.gzip"

In [32]:
local_path = "/home/x1112436/onemodel/data/opensearch/indexing_data"

In [34]:
dataset = StreamPreprocessor.load(file_path=local_path)

mem usage percent: 13.0%
before_mem_avail_gb: 438.0115394592285 gb
mem usage gb: 59.92805099487305 gb
mem avail gb: 438.00575256347656 gb
only file memory usage: 0.006267547607421875


In [51]:
def preprocess_input(item, index_name:str):
    user_vector =  [float(x) for x in item['user_vector']]
    svc_mgmt_num = str(item.get("svc_mgmt_num", "temp"))  
    luna_id = item.get("luna_id", "temp")
    if luna_id is None:
        is_adot = False
    else:
        is_adot= True
                               
    mno_profile = item.get("mno_profile", "")
    adot_profile = item.get("adot_profile", "")
    behavior_profiles = item.get("behavior_profiles", "")
                               
    doc = {
        "_index": index_name,
        "_id": svc_mgmt_num,
        "svc_mgmt_num": svc_mgmt_num,
        "luna_id": item.get("luna_id", "temp"),
        "user_embedding":user_vector,
        "mno_profile": mno_profile,
        "adot_profile": adot_profile,
        "behavior_profile": behavior_profiles,
        "is_adot": is_adot
    }
    return doc

In [52]:
dataset = dataset.map(preprocess_input)

In [54]:
next(iter(dataset))    
    

TypeError: preprocess_input() missing 1 required positional argument: 'index_name'