# Storing Embeddings & Metadata in Redis

Learn Redis data structures and client API best practices for storing embeddings and their metadata:

1. **Serial** - Basic operations
2. **Pipeline** - Batched operations  

Three storage approaches: **String** (serialized) | **Hash** (flattened) | **JSON** (nested)

In [None]:
import redis
import numpy as np
import json
import pickle
import time

# Connect to a local Redis server. Responses are kept as raw bytes
# (no decoding) because pickled blobs and float32 embedding buffers are binary.
r = redis.Redis(
    host='localhost',
    port=6379,
    decode_responses=False,
)
print(f"Connected to Redis: {r.ping()}")

Connected to Redis: True


In [2]:
# Sample data generator
def generate_user_data(user_id, dim=128, rng=None):
    """Build one synthetic user record.

    Args:
        user_id: Integer id. It also deterministically drives ``category``
            (round-robin over four labels) and ``active`` (every third id).
        dim: Length of the random embedding vector (default 128).
        rng: Optional random source exposing ``random(size=None)``, e.g. a
            ``numpy.random.Generator`` for reproducible benchmark data.
            Defaults to the global ``numpy.random`` state.

    Returns:
        dict with keys ``user_id``, ``embedding`` (float32 ndarray of shape
        ``(dim,)``), ``category``, ``score`` (float in [0.5, 1.0], rounded to
        3 decimals) and ``active`` (bool).
    """
    if rng is None:
        # np.random.random(n) draws the same stream as np.random.rand(n).
        rng = np.random
    return {
        'user_id': user_id,
        'embedding': rng.random(dim).astype(np.float32),
        'category': ['tech', 'sports', 'music', 'food'][user_id % 4],
        'score': round(0.5 + 0.5 * rng.random(), 3),
        'active': user_id % 3 == 0,
    }

# Generate test data
sample_data = [generate_user_data(i) for i in range(1000)]
print(f"Generated {len(sample_data)} user records")
# Preview the first record without its embedding vector to keep the output short.
sample_preview = {k: v for k, v in sample_data[0].items() if k != 'embedding'}
print(f"Sample: {sample_preview}")

Generated 1000 user records
Sample: {'user_id': 0, 'category': 'tech', 'score': 0.971, 'active': True}


# 1. Serial Operations (Basic Redis API)

In [3]:
def serial_write_string(data_list, limit=100):
    """STRING: Individual SET operations.

    Pickles the first ``limit`` records of ``data_list`` and writes each one
    with its own ``SET`` under ``user:<user_id>:str``.

    Returns:
        Elapsed seconds, including pickling, measured with a monotonic clock.
    """
    start = time.perf_counter()
    records = data_list[:limit]
    # Serialization is deliberately inside the timed region: it is part of
    # the cost of the STRING approach.
    blobs = [pickle.dumps(item) for item in records]
    for item, blob in zip(records, blobs):
        r.set(f"user:{item['user_id']}:str", blob)
    return time.perf_counter() - start

def serial_write_hash(data_list, limit=100):
    """HASH: Individual HSET operations.

    Flattens each of the first ``limit`` records into string/bytes fields and
    writes it with its own ``HSET`` under ``user:<user_id>:hash``.

    Returns:
        Elapsed seconds, measured with a monotonic clock.
    """
    start = time.perf_counter()
    for item in data_list[:limit]:
        key = f"user:{item['user_id']}:hash"
        # Flatten all fields; the embedding is stored as its raw buffer bytes.
        hash_data = {
            'user_id': str(item['user_id']),
            'embedding': item['embedding'].tobytes(),
            'category': item['category'],
            'score': str(item['score']),
            'active': str(item['active']),
        }
        r.hset(key, mapping=hash_data)
    return time.perf_counter() - start

def serial_write_json(data_list, limit=100):
    """JSON: Individual JSON.SET operations.

    Writes each of the first ``limit`` records as a JSON document under
    ``user:<user_id>:json`` with its own ``JSON.SET``.

    Returns:
        Elapsed seconds, measured with a monotonic clock.
    """
    start = time.perf_counter()
    # Build the JSON command wrapper once instead of once per record, so the
    # timing reflects Redis work rather than wrapper construction.
    json_client = r.json()
    for item in data_list[:limit]:
        # JSON has no binary type, so the embedding is stored as a list of floats.
        json_data = {**item, 'embedding': item['embedding'].tolist()}
        json_client.set(f"user:{item['user_id']}:json", "$", json_data)
    return time.perf_counter() - start

# Test serial writes.
# NOTE: FLUSHDB first wipes every key in the currently selected database.
r.flushdb()
n_serial_write = min(100, len(sample_data))
str_time = serial_write_string(sample_data, limit=n_serial_write)
hash_time = serial_write_hash(sample_data, limit=n_serial_write)
json_time = serial_write_json(sample_data, limit=n_serial_write)

# Derive the header from the count actually written instead of hard-coding it.
print(f"Serial Write ({n_serial_write} records):")
print(f"  String: {str_time:.3f}s")
print(f"  Hash:   {hash_time:.3f}s")
print(f"  JSON:   {json_time:.3f}s")

Serial Write (100 records):
  String: 0.038s
  Hash:   0.029s
  JSON:   0.036s


In [4]:
def serial_read_string(user_ids):
    """STRING: Individual GET operations.

    Fetches ``user:<id>:str`` for each id with its own ``GET`` and unpickles
    every value found.

    Returns:
        ``(elapsed_seconds, records_found)``; missing keys are skipped.
    """
    start = time.perf_counter()
    results = []
    for user_id in user_ids:
        data = r.get(f"user:{user_id}:str")
        if data:
            # NOTE: pickle.loads can execute arbitrary code. Only unpickle
            # values this application wrote itself, never untrusted data.
            results.append(pickle.loads(data))
    return time.perf_counter() - start, len(results)

def serial_read_hash(user_ids):
    """HASH: Individual HGETALL operations.

    Fetches ``user:<id>:hash`` for each id with its own ``HGETALL`` and
    rebuilds the record from the flattened fields.

    Returns:
        ``(elapsed_seconds, records_found)``; ids with no hash are skipped.
    """
    start = time.perf_counter()
    results = []
    for user_id in user_ids:
        hash_data = r.hgetall(f"user:{user_id}:hash")
        if not hash_data:
            continue
        # The client uses decode_responses=False, so fields and values are bytes.
        results.append({
            'user_id': int(hash_data[b'user_id'].decode()),
            # The embedding was written from a float32 array via .tobytes().
            # frombuffer defaults to float64, which would silently decode
            # half as many garbage values, so the dtype must be explicit.
            'embedding': np.frombuffer(hash_data[b'embedding'], dtype=np.float32),
            'category': hash_data[b'category'].decode(),
            'score': float(hash_data[b'score'].decode()),
            'active': hash_data[b'active'].decode() == 'True',
        })
    return time.perf_counter() - start, len(results)

def serial_read_json(user_ids):
    """JSON: Individual JSON.GET operations.

    Issues one ``JSON.GET user:<id>:json $`` per id. Each record is rebuilt
    with its embedding as a float32 ndarray, so the per-record work is
    comparable to the STRING and HASH readers.

    Returns:
        ``(elapsed_seconds, records_found)``; missing keys are skipped.
    """
    start = time.perf_counter()
    # Build the JSON command wrapper once instead of once per id.
    json_client = r.json()
    results = []
    for user_id in user_ids:
        data = json_client.get(f"user:{user_id}:json", "$")
        if data:
            # A "$" JSONPath query returns a list of matches; the root is the only one.
            record = data[0]
            record['embedding'] = np.asarray(record['embedding'], dtype=np.float32)
            results.append(record)
    return time.perf_counter() - start, len(results)

# Test serial reads
test_ids = list(range(50))
str_time, str_count = serial_read_string(test_ids)
hash_time, hash_count = serial_read_hash(test_ids)
json_time, json_count = serial_read_json(test_ids)

# Derive the header from the ids actually requested instead of hard-coding it.
print(f"\nSerial Read ({len(test_ids)} records):")
print(f"  String: {str_time:.3f}s ({str_count} found)")
print(f"  Hash:   {hash_time:.3f}s ({hash_count} found)")
print(f"  JSON:   {json_time:.3f}s ({json_count} found)")


Serial Read (50 records):
  String: 0.015s (50 found)
  Hash:   0.016s (50 found)
  JSON:   0.017s (50 found)


The extra cost of the String approach comes mostly from (de)serialization with `pickle`. It can be avoided by modeling the data separately (e.g., storing each field under its own key). For this use case, however, a HASH or a JSON document is usually the better fit when a record has more than one field.

# 2. Pipeline Operations (Batched Redis API)

In [5]:
# Number of queued commands sent to Redis per pipeline round trip.
PIPE_BATCH_SIZE = 250

def pipeline_write_string(data_list, limit=500):
    """STRING: Pipelined SET operations.

    Pickles the first ``limit`` records and queues one ``SET`` per record
    under ``user:<user_id>:str_pipe``, flushing every ``PIPE_BATCH_SIZE``
    commands.

    Returns:
        Elapsed seconds, including pickling, measured with a monotonic clock.
    """
    start = time.perf_counter()
    records = data_list[:limit]
    # Pickling is timed, just as the serial STRING writer and the HASH/JSON
    # writers time their own serialization.
    blobs = [pickle.dumps(item) for item in records]
    with r.pipeline(transaction=False) as pipe:
        for i, (item, blob) in enumerate(zip(records, blobs)):
            pipe.set(f"user:{item['user_id']}:str_pipe", blob)
            # Flush once a full batch is queued. Testing `i % N == 0` would
            # send a 1-command batch at i == 0 and shift every later batch.
            if (i + 1) % PIPE_BATCH_SIZE == 0:
                pipe.execute()
        # Send any remaining queued commands.
        pipe.execute()
    return time.perf_counter() - start

def pipeline_write_hash(data_list, limit=500):
    """HASH: Pipelined HSET operations.

    Flattens each of the first ``limit`` records and queues one ``HSET`` per
    record under ``user:<user_id>:hash_pipe``, flushing every
    ``PIPE_BATCH_SIZE`` commands.

    Returns:
        Elapsed seconds, measured with a monotonic clock.
    """
    start = time.perf_counter()
    with r.pipeline(transaction=False) as pipe:
        for i, item in enumerate(data_list[:limit]):
            key = f"user:{item['user_id']}:hash_pipe"
            # Flatten all fields; the embedding is stored as its raw buffer bytes.
            hash_data = {
                'user_id': str(item['user_id']),
                'embedding': item['embedding'].tobytes(),
                'category': item['category'],
                'score': str(item['score']),
                'active': str(item['active']),
            }
            pipe.hset(key, mapping=hash_data)
            # Flush once a full batch is queued (not a 1-command batch at i == 0).
            if (i + 1) % PIPE_BATCH_SIZE == 0:
                pipe.execute()
        # Send any remaining queued commands.
        pipe.execute()
    return time.perf_counter() - start

def pipeline_write_json(data_list, limit=500):
    """JSON: Pipelined JSON.SET operations.

    Queues one ``JSON.SET`` per record (first ``limit`` records) under
    ``user:<user_id>:json_pipe``, flushing every ``PIPE_BATCH_SIZE`` commands.

    Returns:
        Elapsed seconds, measured with a monotonic clock.
    """
    start = time.perf_counter()
    with r.pipeline(transaction=False) as pipe:
        # Build the JSON command wrapper once instead of once per record.
        json_pipe = pipe.json()
        for i, item in enumerate(data_list[:limit]):
            # JSON has no binary type, so the embedding is stored as a list of floats.
            json_data = {**item, 'embedding': item['embedding'].tolist()}
            json_pipe.set(f"user:{item['user_id']}:json_pipe", "$", json_data)
            # Flush once a full batch is queued (not a 1-command batch at i == 0).
            if (i + 1) % PIPE_BATCH_SIZE == 0:
                pipe.execute()
        # Send any remaining queued commands.
        pipe.execute()
    return time.perf_counter() - start

# Test pipeline writes.
# FLUSHDB (not FLUSHALL) so other logical databases on the same server are
# left untouched. It still wipes every key in the currently selected database.
r.flushdb()
n_pipe_write = min(500, len(sample_data))
str_time = pipeline_write_string(sample_data, limit=n_pipe_write)
hash_time = pipeline_write_hash(sample_data, limit=n_pipe_write)
json_time = pipeline_write_json(sample_data, limit=n_pipe_write)

# Derive the header from the count actually written instead of hard-coding it.
print(f"Pipeline Write ({n_pipe_write} records):")
print(f"  String: {str_time:.3f}s")
print(f"  Hash:   {hash_time:.3f}s")
print(f"  JSON:   {json_time:.3f}s")

Pipeline Write (500 records):
  String: 0.005s
  Hash:   0.008s
  JSON:   0.033s


In [6]:
def pipeline_read_string(user_ids):
    """STRING: Pipelined GET operations.

    Queues one ``GET user:<id>:str_pipe`` per id, flushing every
    ``PIPE_BATCH_SIZE`` commands, then unpickles every non-empty reply.

    Returns:
        ``(elapsed_seconds, records_found)``; missing keys are skipped.
    """
    start = time.perf_counter()
    replies = []
    with r.pipeline(transaction=False) as pipe:
        for i, user_id in enumerate(user_ids):
            pipe.get(f"user:{user_id}:str_pipe")
            # Flush once a full batch is queued (not a 1-command batch at i == 0).
            if (i + 1) % PIPE_BATCH_SIZE == 0:
                replies.extend(pipe.execute())
        # Execute any remaining commands.
        replies.extend(pipe.execute())
    # NOTE: pickle.loads can execute arbitrary code. Only unpickle values
    # this application wrote itself, never untrusted data.
    results = [pickle.loads(data) for data in replies if data]
    return time.perf_counter() - start, len(results)

def pipeline_read_hash(user_ids):
    """HASH: Pipelined HGETALL operations.

    Queues one ``HGETALL user:<id>:hash_pipe`` per id, flushing every
    ``PIPE_BATCH_SIZE`` commands, then rebuilds each record found from its
    flattened fields.

    Returns:
        ``(elapsed_seconds, records_found)``; ids with no hash are skipped.
    """
    start = time.perf_counter()
    replies = []
    with r.pipeline(transaction=False) as pipe:
        for i, user_id in enumerate(user_ids):
            pipe.hgetall(f"user:{user_id}:hash_pipe")
            # Flush once a full batch is queued (not a 1-command batch at i == 0).
            if (i + 1) % PIPE_BATCH_SIZE == 0:
                replies.extend(pipe.execute())
        # Execute any remaining commands.
        replies.extend(pipe.execute())
    results = []
    for hash_data in replies:
        if not hash_data:
            continue
        # The client uses decode_responses=False, so fields and values are bytes.
        results.append({
            'user_id': int(hash_data[b'user_id'].decode()),
            # The embedding was written from a float32 array via .tobytes().
            # frombuffer defaults to float64, so the dtype must be explicit.
            'embedding': np.frombuffer(hash_data[b'embedding'], dtype=np.float32),
            'category': hash_data[b'category'].decode(),
            'score': float(hash_data[b'score'].decode()),
            'active': hash_data[b'active'].decode() == 'True',
        })
    return time.perf_counter() - start, len(results)

def pipeline_read_json(user_ids):
    """JSON: Pipelined JSON.GET operations.

    Queues one ``JSON.GET user:<id>:json_pipe $`` per id, flushing every
    ``PIPE_BATCH_SIZE`` commands. Each record found is rebuilt with its
    embedding as a float32 ndarray, so the per-record work is comparable to
    the STRING and HASH readers.

    Returns:
        ``(elapsed_seconds, records_found)``; missing keys are skipped.
    """
    start = time.perf_counter()
    replies = []
    with r.pipeline(transaction=False) as pipe:
        # Build the JSON command wrapper once instead of once per id.
        json_pipe = pipe.json()
        for i, user_id in enumerate(user_ids):
            json_pipe.get(f"user:{user_id}:json_pipe", "$")
            # Flush once a full batch is queued (not a 1-command batch at i == 0).
            if (i + 1) % PIPE_BATCH_SIZE == 0:
                replies.extend(pipe.execute())
        # Execute any remaining commands.
        replies.extend(pipe.execute())
    results = []
    for data in replies:
        if data:
            # A "$" JSONPath query returns a list of matches; the root is the only one.
            record = data[0]
            record['embedding'] = np.asarray(record['embedding'], dtype=np.float32)
            results.append(record)
    return time.perf_counter() - start, len(results)

# Test pipeline reads
test_ids = list(range(200))
str_time, str_count = pipeline_read_string(test_ids)
hash_time, hash_count = pipeline_read_hash(test_ids)
json_time, json_count = pipeline_read_json(test_ids)

# Derive the header from the ids actually requested instead of hard-coding it.
print(f"\nPipeline Read ({len(test_ids)} records):")
print(f"  String: {str_time:.3f}s ({str_count} found)")
print(f"  Hash:   {hash_time:.3f}s ({hash_count} found)")
print(f"  JSON:   {json_time:.3f}s ({json_count} found)")


Pipeline Read (200 records):
  String: 0.003s (200 found)
  Hash:   0.003s (200 found)
  JSON:   0.009s (200 found)


# Summary

## API Progression:
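1. **Serial**: one command and one network round trip per record (`SET`/`GET`, `HSET`/`HGETALL`, `JSON.SET`/`JSON.GET`) - simple but slow
2. **Pipeline**: batched commands with far fewer round trips - typically several to tens of times faster per record (the exact gain depends on network latency and batch size)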

## Data Structure Trade-offs:
- **String**: Fastest, opaque blob (serialization overhead for complex data)
- **Hash**: Structured/flattened, fast and efficient
- **JSON**: Flexible, nested structure

## Redis Best Practices:
- Use **pipelines** for batch operations
- Choose data structure based on access patterns

In [7]:
# Cleanup
# NOTE: FLUSHDB deletes every key in the currently selected Redis database,
# not only the user:* keys this notebook created.
r.flushdb()
print("🧹 Cleaned up all test data")

🧹 Cleaned up all test data
