Below is a Python class I developed to help manage machine-learning features locally on my PC.

The class is a key-value store with a little extra magic.

It automatically spills over to disk resources once the memory limit has been reached.

The features are stored in an SQLite-backed data file on each filesystem resource added during configuration.

In [None]:
# 3 example use cases
# NOTE: run the class-definition cell (TierKV / ResourceMemory / ResourceFilesystem)
# before this cell - the names used here are defined there.

# ###############################################################################################
# simple usage - memory only, no persistence
# ###############################################################################################
data1 = bytearray('DATA 1', 'utf-8')
data2 = bytearray('DATA 2', 'utf-8')

with TierKV(name='storename') as store1:
    store1.set('key1', data1)
    store1.set('key2', data2)
    # get() always returns a list of matching values
    x = store1.get('key1')
    y = store1.get('key2')
    z = store1.get('key3')  # unknown tag: returns an empty list
    l = store1.get(['key1', 'key2', 'key3'])
    print(x, y, z)
    print('l is', l)

# ###############################################################################################
# tagged usage - memory only, no persistence
# ###############################################################################################
with TierKV(name='sample_store') as store2:
    store2.set(['BLUE', '1'], bytearray('BLUE 1', 'utf-8'))
    store2.set(['BLUE', '2'], bytearray('BLUE 2', 'utf-8'))
    store2.set(['BLUE', '3'], bytearray('BLUE 3', 'utf-8'))
    store2.set(['BLUE', '4'], bytearray('BLUE 4', 'utf-8'))
    store2.set(['BLUE', '5'], bytearray('BLUE 5', 'utf-8'))
    store2.set(['BLUE', '6'], bytearray('BLUE 6', 'utf-8'))
    store2.set(['BLUE', '7'], bytearray('BLUE 7', 'utf-8'))
    store2.set(['RED', '1'], bytearray('RED 1', 'utf-8'))

    # get all data tagged with BLUE
    blues = store2.get('BLUE')
    print('blues', blues)

    # asking for more samples than exist returns everything tagged BLUE
    blues_all = store2.sample('BLUE', 30000)
    print('blues all', blues_all)

    # get 3 random data tagged with BLUE
    blues_3 = store2.sample('BLUE', 3)
    print('blues 3', blues_3)

    # get all data tagged with RED
    reds = store2.get('RED')
    print('reds', reds)

# ###############################################################################################
# more complex use case - with persistence via filesystem
# configure the tiers of storage resources to utilise
# all limits are 'soft' limits to target, no guarantees.
# ###############################################################################################
store3 = TierKV(name='mystore', resources=[
    # Tier 0 - use memory for speed, and limit to 90% available memory
    ResourceMemory(limit=0.90, compression=False),
    # Tier 1 - C: is a fast M2 drive, so use this next, enable compression on this resource and limit to 90% disk
    ResourceFilesystem(path='c:/store/', limit=0.90, compression=True),
    # Tier 2 - D: is a slower SSD, so use this last, enable compression and limit to 90% disk
    ResourceFilesystem(path='d:/store/', limit=0.90, compression=True),
])

with store3:
    store3.set(['BLUE', '1'], bytearray('BLUE 1', 'utf-8'))
    store3.set(['BLUE', '2'], bytearray('BLUE 2', 'utf-8'))
    store3.set(['BLUE', '3'], bytearray('BLUE 3', 'utf-8'))
    store3.set(['BLUE', '4'], bytearray('BLUE 4', 'utf-8'))
    store3.set(['BLUE', '5'], bytearray('BLUE 5', 'utf-8'))
    store3.set(['BLUE', '6'], bytearray('BLUE 6', 'utf-8'))
    store3.set(['BLUE', '7'], bytearray('BLUE 7', 'utf-8'))
    store3.set(['RED', '1'], bytearray('RED 1', 'utf-8'))

    # get all data tagged with BLUE
    blues = store3.get('BLUE')
    print('blues', blues)

    # asking for more samples than exist returns everything tagged BLUE
    blues_all = store3.sample('BLUE', 30000)
    print('blues all', blues_all)

    # get 3 random data tagged with BLUE
    blues_3 = store3.sample('BLUE', 3)
    print('blues 3', blues_3)

    # get all data tagged with RED
    reds = store3.get('RED')
    print('reds', reds)

In [None]:
import hashlib
import os
import blosc
import psutil
import threading
import sqlite3
import time
import logging
import random
import pickle


class Resource():
    """Base class for one storage tier managed by TierKV.

    Stores the tier's soft usage ``limit`` (a fraction such as 0.9) and whether
    values written to the tier are compressed. Concrete subclasses supply the
    actual storage (``initialise``, ``get``, ``set``, ...).
    """

    def __init__(self, limit = 0.9, compression = False):
        self.compression = compression
        self.limit = limit

    def on_enter(self):
        """Hook run when the owning store's ``with`` block starts; no-op here."""
        return None

    def on_exit(self):
        """Hook run when the owning store's ``with`` block ends; no-op here."""
        return None

    def set_name(self, name):
        """Remember the owning store's name for use by subclasses."""
        self.name = name

    def get_limit(self):
        """Return this tier's soft usage limit."""
        limit = self.limit
        return limit


class ResourceMemory(Resource):
    """In-process, dict-backed storage tier (fastest, not persistent).

    Usage is measured as *system-wide* RAM usage via psutil, not the size of
    this store, so ``limit`` is a target for the whole machine's memory.
    """

    def __init__(self, limit = 0.9, compression = False):
        super().__init__(limit, compression)

    def initialise(self):
        """Create an empty store and take an initial memory-usage reading."""
        self.store = {}
        self.update_percent_used()

    def update_percent_used(self):
        """Refresh the cached system memory usage (0.0 - 1.0)."""
        self._percent_used = psutil.virtual_memory().percent/100.0

    def percent_used(self):
        """Return the memory usage fraction cached by update_percent_used()."""
        return self._percent_used

    def keys(self):
        """Return a snapshot list of stored keys."""
        return list(self.store)

    def has_key(self, key):
        return key in self.store

    def delete_key(self, key):
        """Remove ``key`` if present; missing keys are ignored."""
        self.store.pop(key, None)

    # ###########################################################################
    # PUBLIC API
    # ###########################################################################
    def set(self, key, value, bypass=False):
        """Store ``value`` under ``key``.

        When the tier compresses and ``bypass`` is False the value is blosc/lz4
        compressed first; ``bypass=True`` stores ``value`` verbatim (used when
        copying already-encoded bytes between tiers).
        """
        if self.compression and not bypass:
            value = blosc.compress(value, typesize=4, cname='lz4')
        self.store[key] = value

    def get(self, key, bypass=False):
        """Return the value for ``key``, or None when the key is absent.

        With compression enabled and ``bypass`` False the stored bytes are
        decompressed; a decompression failure is logged and yields None.
        """
        if key not in self.store:
            return None
        value = self.store[key]
        if self.compression and not bypass:
            try:
                return blosc.decompress(value)
            except Exception:
                # Best-effort read: keep returning None as before, but no longer
                # hide the failure (or swallow KeyboardInterrupt/SystemExit).
                logging.getLogger(__name__).exception(
                    'TierKV: failed to decompress value for key %r', key)
                return None
        return value


class ResourceFilesystem(Resource):
    """SQLite-backed storage tier stored at ``<path>/<name>/TierKV.db``.

    Values live in a single ``data (k TEXT PRIMARY KEY, v BLOB)`` table. All
    queries share one cursor opened in on_enter(); changes are committed only
    in on_exit().
    """

    def __init__(self, path, limit = 0.90, compression = False):
        super().__init__(limit, compression)
        self.path = path

    def initialise(self):
        """Create the directory/database if needed and read disk usage.

        Uses ``self.name``, so set_name() must be called first.
        """
        directory = os.path.join(self.path, self.name)
        os.makedirs(directory, exist_ok=True)
        filename = os.path.join(directory, 'TierKV.db')
        first_time = not os.path.exists(filename)
        self.sqliteConnection = sqlite3.connect(filename, check_same_thread = False)
        self.cursor = self.sqliteConnection.cursor()
        # Always make sure the table exists: a db file left behind by an
        # interrupted first run would otherwise never get one.
        self.cursor.execute('CREATE TABLE IF NOT EXISTS data (k TEXT PRIMARY KEY, v BLOB);')
        if first_time:
            self.cursor.execute('PRAGMA auto_vacuum = NONE;')
            self.cursor.execute('VACUUM;')
        self.cursor.close()
        self.update_percent_used()

    def on_enter(self):
        """Open the cursor used by every subsequent query."""
        self.cursor = self.sqliteConnection.cursor()

    def on_exit(self):
        """Commit pending writes and close the cursor."""
        self.sqliteConnection.commit()
        self.cursor.close()

    def update_percent_used(self):
        """Refresh the cached usage fraction of the disk holding ``path``."""
        self._percent_used = psutil.disk_usage(self.path).percent/100.0

    def percent_used(self):
        """Return the disk usage fraction cached by update_percent_used()."""
        return self._percent_used

    def keys(self):
        """Return every key stored in the table."""
        self.cursor.execute('SELECT k FROM data')
        return [row[0] for row in self.cursor.fetchall()]

    def has_key(self, key):
        self.cursor.execute('SELECT k FROM data WHERE k=?', (key,))
        return self.cursor.fetchone() is not None

    def delete_key(self, key):
        """Delete ``key``; deleting a missing key is a no-op."""
        self.cursor.execute('DELETE FROM data WHERE k=?', (key,))

    # ###########################################################################
    # PUBLIC API
    # ###########################################################################
    def set(self, key, value, bypass=False):
        """Insert or replace ``value`` under ``key``.

        Compresses with blosc/lz4 when the tier compresses and ``bypass`` is
        False; ``bypass=True`` writes ``value`` verbatim.
        """
        if self.compression and not bypass:
            value = blosc.compress(value, typesize=4, cname='lz4')
        self.cursor.execute(
            'INSERT INTO data (k,v) VALUES (?,?) ON CONFLICT(k) DO UPDATE SET v=excluded.v;',
            (key, value))

    def get(self, key, bypass=False):
        """Return the value stored for ``key``, or None if there is no row.

        With compression enabled and ``bypass`` False the blob is decompressed;
        a decompression failure is logged and yields None.
        """
        self.cursor.execute('SELECT v FROM data WHERE k=?', (key,))
        row = self.cursor.fetchone()
        if row is None:
            return None
        value = row[0]
        if self.compression and not bypass:
            try:
                return blosc.decompress(value)
            except Exception:
                # Best-effort read (None as before) without silently hiding it.
                logging.getLogger(__name__).exception(
                    'TierKV: failed to decompress value for key %r', key)
                return None
        return value


class TierKV():
    """Tag-addressable key-value store spread over tiers of storage resources.

    A value is stored under a key derived from its list of tags; each tag also
    indexes the key so values can be fetched or sampled by any tag. New values
    go into the first tier whose usage is below its limit. While the store is
    used as a context manager, a background thread refreshes resource usage
    and migrates random keys out of tier 0 once it is within 5% of its limit.
    On exit, tier-0 contents are moved to lower tiers and the index maps are
    pickled into ``resources[1]`` (only when more than one tier exists).

    With ``cache=True``, reading a value from a lower tier promotes it to the
    nearest higher tier that has room.
    """

    # Seconds between background usage refreshes / rebalance steps.
    rebalance_interval = 0.01

    def __init__(self, name = 'default', resources=None, cache=False):
        """
        Args:
            name: store name; filesystem tiers use it as a sub-directory.
            resources: ordered list of tiers (fastest first). Defaults to a
                single fresh ``ResourceMemory(limit=0.90)`` for this instance.
            cache: promote values read from lower tiers to higher ones.
        """
        if resources is None:
            # Build the default per instance: a shared default list would make
            # every default store share (and re-initialise) one memory tier.
            resources = [ResourceMemory(limit = 0.90)]
        self.tag_keys_map =  {} # tag_keys_map  tag  --> set(keys)
        self.key_tier_map =  {} # key_tier_map  key  --> tier
        self.tier_keys_map = {} # tier_keys_map tier --> set(keys)

        self.cache = cache
        self.name = name
        self.resources = resources
        self.lock = threading.RLock()
        self.thread = None
        self.thread_run = False
        self._stop_event = threading.Event()

        with self.lock:
            for resource in self.resources:
                resource.set_name(self.name)
            for tier, resource in enumerate(self.resources):
                resource.initialise()
                self.tier_keys_map[tier] = set()

    def __enter__(self):
        for resource in self.resources:
            resource.on_enter()
        self.load()
        self._stop_event.clear()
        self.thread_run = True
        self.thread = threading.Thread(target=self.__worker_thread, daemon=True)
        self.thread.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.thread_run = False
        self._stop_event.set()
        if self.thread is not None:
            self.thread.join()
        try:
            self.save()
        finally:
            # do any final commit on resource, even if save() failed
            for resource in self.resources:
                resource.on_exit()

    def __worker_thread(self):
        """Periodically refresh usage figures and rebalance tiers until stopped."""
        # Event.wait sleeps between rounds instead of busy-spinning a core, and
        # returns immediately once __exit__ sets the event.
        while self.thread_run and not self._stop_event.wait(self.rebalance_interval):
            with self.lock:
                for resource in self.resources:
                    resource.update_percent_used()
                self.__rebalance_tiers()

    def __find_free_tier_above(self, tier):
        """Return the nearest tier above ``tier`` that is under its limit, or None."""
        with self.lock:
            for t in range(tier - 1, -1, -1):
                if self.resources[t].percent_used() < self.resources[t].get_limit():
                    return t
            return None

    def __find_free_tier_below(self, tier):
        """Return the nearest tier below ``tier`` that is under its limit, or None."""
        with self.lock:
            for t in range(tier + 1, len(self.resources)):
                if self.resources[t].percent_used() < self.resources[t].get_limit():
                    return t
            return None

    def __key(self, tags):
        """Derive the 20-byte storage key for a list of string tags."""
        # NOTE(review): tags are concatenated without a separator, so e.g.
        # ['AB', 'C'] and ['A', 'BC'] map to the same key. Changing the encoding
        # would orphan keys in already-persisted stores - TODO migrate.
        return hashlib.blake2s(bytes(''.join(tags), encoding='utf8'), digest_size=20).digest()

    def __move_up(self, key, tier, data):
        """Promote ``key`` (whose decoded value is ``data``) to a free higher tier."""
        with self.lock:
            above = self.__find_free_tier_above(tier)
            if above is None:
                return
            # `data` was already decoded by the caller, so it must go through
            # the destination's normal encode path. Bypassing here stored raw
            # bytes in a compressed tier, which then failed to decompress and
            # read back as None.
            self.resources[above].set(key, data)
            # The source copy is no longer referenced by key_tier_map; delete
            # it so it does not linger unreachable in the lower tier.
            self.resources[tier].delete_key(key)
            self.key_tier_map[key] = above
            self.tier_keys_map[tier].discard(key)
            self.tier_keys_map[above].add(key)

    def __move_down(self, key, tier):
        """Demote ``key`` from ``tier`` to the nearest free lower tier, if any."""
        with self.lock:
            below = self.__find_free_tier_below(tier)
            if below is None:
                return
            r_src = self.resources[tier]
            r_dst = self.resources[below]
            # Same compression setting on both sides: copy the stored bytes
            # verbatim; otherwise decode from src and let dst re-encode.
            bypass = r_src.compression == r_dst.compression
            r_dst.set(key, r_src.get(key, bypass), bypass)
            r_src.delete_key(key)
            self.key_tier_map[key] = below
            self.tier_keys_map[tier].discard(key)
            self.tier_keys_map[below].add(key)

    def __rebalance_tiers(self):
        """Move one random key out of tier 0 when it is within 5% of its limit."""
        with self.lock:
            if len(self.resources) <= 1:
                return
            tier = 0
            top = self.resources[tier]
            nxt = self.resources[tier + 1]
            # slowly migrate 1 record at a time to free up 5% head room
            if (len(self.tier_keys_map[tier]) > 1
                    and top.percent_used() > (top.get_limit() - 0.05)
                    and nxt.percent_used() < nxt.get_limit()):
                # random.sample/choice need a sequence; sets are rejected on 3.11+.
                key = random.choice(list(self.tier_keys_map[tier]))
                self.__move_down(key, tier)

    def __fetch_into(self, key, seen, result):
        """Append the value of ``key`` to ``result`` once, promoting it if caching."""
        if key in seen:
            return
        tier = self.key_tier_map[key]
        data = self.resources[tier].get(key)
        if data is not None:
            result.append(data)
            seen.add(key)
            if self.cache and (tier > 0):
                self.__move_up(key, tier, data)

    def set(self, tags, data):
        """Store ``data`` under ``tags`` (a tag or list of tags); return the key.

        Raises:
            Exception: 'out of resources' when every tier is at its limit; any
                previous value stored under the same tags is kept in that case.
        """
        # ensure we have a list
        if not isinstance(tags, list):
            tags = [tags]

        key = self.__key(tags)

        with self.lock:
            # pick the first tier that has space left before touching old data
            target = None
            for tier, resource in enumerate(self.resources):
                if resource.percent_used() < resource.get_limit():
                    target = tier
                    break
            if target is None:
                raise Exception('out of resources')

            # remove the previous value from whichever tier holds it
            old_tier = self.key_tier_map.get(key)
            if old_tier is not None:
                self.resources[old_tier].delete_key(key)
                self.tier_keys_map[old_tier].discard(key)

            self.resources[target].set(key, data)
            self.key_tier_map[key] = target
            self.tier_keys_map[target].add(key)
            for tag in tags:
                self.tag_keys_map.setdefault(tag, set()).add(key)
            return key

    def sample(self, tags, k):
        """Return up to ``k`` random values per tag in ``tags``, without duplicates."""
        # ensure we have a list
        if not isinstance(tags, list):
            tags = [tags]

        with self.lock:
            result = []
            seen = set()
            for tag in tags:
                keys = self.tag_keys_map.get(tag)
                if not keys:
                    continue
                # Cap per tag without overwriting k for the following tags, and
                # pass a sequence (random.sample rejects sets on 3.11+).
                for key in random.sample(list(keys), min(k, len(keys))):
                    self.__fetch_into(key, seen, result)
            return result

    def get(self, tags):
        """Return every value tagged with any tag in ``tags`` (each value once)."""
        # ensure we have a list
        if not isinstance(tags, list):
            tags = [tags]

        with self.lock:
            result = []
            seen = set()
            for tag in tags:
                for key in self.tag_keys_map.get(tag, ()):
                    self.__fetch_into(key, seen, result)
            return result

    def save(self):
        """Move tier-0 data to persistent tiers and pickle the index maps.

        Does nothing with a single tier. Keys that find no free lower tier stay
        in tier 0 (and are therefore not persisted).
        """
        with self.lock:
            if len(self.resources) <= 1:
                return

            # move everything down from tier 0 into a persistent tier
            tier = 0
            for key in list(self.resources[tier].keys()):
                self.__move_down(key, tier)

            # save all the dictionaries too
            meta = self.resources[1]
            meta.set('tag_keys_map.pkl', pickle.dumps(self.tag_keys_map))
            meta.set('key_tier_map.pkl', pickle.dumps(self.key_tier_map))
            meta.set('tier_keys_map.pkl', pickle.dumps(self.tier_keys_map))

    def load(self):
        """Restore the index maps pickled by save(), if present in ``resources[1]``."""
        with self.lock:
            if len(self.resources) <= 1:
                return
            meta = self.resources[1]
            tag_blob = meta.get('tag_keys_map.pkl')
            if tag_blob is None:
                return
            # NOTE: pickle.loads can execute code from crafted input; only load
            # stores written by a trusted process.
            self.tag_keys_map = pickle.loads(tag_blob)
            self.key_tier_map = pickle.loads(meta.get('key_tier_map.pkl'))
            self.tier_keys_map = pickle.loads(meta.get('tier_keys_map.pkl'))
            # tiers added since the last save need an (empty) key set
            for tier in range(len(self.resources)):
                self.tier_keys_map.setdefault(tier, set())

    def has_tags(self, tags):
        """Return True if any of ``tags`` (a tag or iterable of tags) is known."""
        # a bare string is one tag, not an iterable of single-character tags
        if isinstance(tags, str):
            tags = [tags]
        with self.lock:
            return any(tag in self.tag_keys_map for tag in tags)

    def keys(self):
        """Return a snapshot list of every stored key."""
        with self.lock:
            return list(self.key_tier_map)

    def get_by_key(self, key):
        """Return ``[value]`` for a raw storage key (``[]`` if it cannot be read).

        Raises:
            KeyError: if ``key`` is unknown (the lock is still released).
        """
        with self.lock:
            tier = self.key_tier_map[key]
            data = self.resources[tier].get(key)
            result = []
            if data is not None:
                result.append(data)
                if self.cache and (tier > 0):
                    self.__move_up(key, tier, data)
            return result