![Redis](https://redis.io/wp-content/uploads/2024/04/Logotype.svg?auto=webp&quality=85,75&width=120)

# ML Model Serialization in Redis

The `ModelStore` class below implements the following logic:
- Builds a model metadata index for model version management
- Handles model chunking, serialization, and deserialization to/from Redis using Pickle

Below we test with various Python ML-native data types and models.



In [1]:
%pip install scikit-learn torch tensorflow

In [2]:
import os
import redis

from model_store import ModelStore

# Replace values below with your own if using Redis Cloud instance
REDIS_HOST = os.getenv("REDIS_HOST", "localhost") # ex: "redis-18374.c253.us-central1-1.gce.cloud.redislabs.com"
REDIS_PORT = os.getenv("REDIS_PORT", "6379")      # ex: 18374
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", "")  # ex: "1TNxTEdYRDgIDKM2gDfasupCADXXXX"

# If SSL is enabled on the endpoint, use rediss:// as the URL prefix
REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}"

# Initialize Redis client
redis_client = redis.Redis.from_url(REDIS_URL)

In [3]:
# Initialize the ModelStore
model_store = ModelStore(redis_client, shard_size=1012*100) # ~100Kb sized keys in Redis

## Test with simple Scikit-Learn model

In [4]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Load a simple dataset and train a RandomForest model
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)

# Train a RandomForestClassifier
model = RandomForestClassifier()
model.fit(X_train, y_train)

# Evaluate the model
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Model accuracy: {accuracy:.2f}")

Model accuracy: 1.00


In [5]:
# Save the trained model to Redis
model_name = "random_forest"
version = model_store.save_model(model, model_name)

2025-01-22 10:56:13.871 - model_store.store - INFO - Saving 'random_forest' model
2025-01-22 10:56:13.873 - model_store.store - INFO - Starting model serialization and storage
2025-01-22 10:56:13.880 - model_store.store - INFO - Stored model in 2 shards (0.0069s)
2025-01-22 10:56:13.883 - model_store.store - INFO - Save operation completed in 0.0121s


In [6]:
# Load the model from Redis
loaded_model = model_store.load_model(model_name)

# Verify that the loaded model works
y_pred_loaded = loaded_model.predict(X_test)
loaded_accuracy = accuracy_score(y_test, y_pred_loaded)
print(f"Loaded model accuracy: {loaded_accuracy:.2f}")

2025-01-22 10:56:13.887 - model_store.store - INFO - Loading 'random_forest' model
2025-01-22 10:56:13.889 - model_store.store - INFO - Starting model reconstruction from shards
2025-01-22 10:56:13.899 - model_store.store - INFO - Loaded model from 2 shards (0.0105s)
2025-01-22 10:56:13.899 - model_store.store - INFO - Load operation completed in 0.0130s


Loaded model accuracy: 1.00


## Test with 1Gb numpy array


In [7]:
import numpy as np

desired_size_bytes = 1024*1024*1024
num_elements = desired_size_bytes // 8
large_array = np.random.rand(num_elements).astype(np.float64)

model_name = "numpy_array"

In [8]:
version = model_store.save_model(large_array, model_name)

2025-01-22 10:56:14.435 - model_store.store - INFO - Saving 'numpy_array' model
2025-01-22 10:56:14.438 - model_store.store - INFO - Starting model serialization and storage
2025-01-22 10:56:21.996 - model_store.store - INFO - Stored model in 10611 shards (7.5575s)
2025-01-22 10:56:22.004 - model_store.store - INFO - Save operation completed in 7.5689s


In [9]:
# Load the model from Redis
loaded_model = model_store.load_model(model_name)

2025-01-22 10:56:22.008 - model_store.store - INFO - Loading 'numpy_array' model
2025-01-22 10:56:22.013 - model_store.store - INFO - Starting model reconstruction from shards
2025-01-22 10:56:25.661 - model_store.store - INFO - Loaded model from 10611 shards (3.6480s)
2025-01-22 10:56:25.702 - model_store.store - INFO - Load operation completed in 3.6941s


In [10]:
# check if all elements match in the array
sum(loaded_model == large_array) == num_elements

np.True_

## Test with pytorch model

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(1, 1)

    def forward(self, x):
        return self.fc(x)

# Create model, define loss and optimizer
model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Dummy data
x = torch.tensor([[1.0], [2.0], [3.0]], requires_grad=True)
y = torch.tensor([[2.0], [4.0], [6.0]])

# Training loop
for epoch in range(100):
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, y)
    loss.backward()
    optimizer.step()

In [12]:
# Make a simple prediction
x_test = torch.tensor([[4.0]])
prediction = model(x_test).item()
print(f"Prediction for input 4.0: {prediction}")

Prediction for input 4.0: 7.4401021003723145


In [13]:
# Save the trained model to Redis
model_name = "pytorch"
version = "1.0"
model_store.save_model(model, model_name, version=version)

2025-01-22 10:56:32.978 - model_store.store - INFO - Saving 'pytorch' model
2025-01-22 10:56:32.982 - model_store.store - INFO - Starting model serialization and storage
2025-01-22 10:56:32.983 - model_store.store - INFO - Stored model in 1 shards (0.0012s)
2025-01-22 10:56:32.984 - model_store.store - INFO - Save operation completed in 0.0056s


'1.0'

In [14]:
# Load the model from Redis
loaded_model = model_store.load_model(model_name)

prediction = loaded_model(x_test).item()
print(f"Prediction for input 4.0 with loaded model: {prediction}")

2025-01-22 10:56:32.987 - model_store.store - INFO - Loading 'pytorch' model
2025-01-22 10:56:32.988 - model_store.store - INFO - Starting model reconstruction from shards
2025-01-22 10:56:32.990 - model_store.store - INFO - Loaded model from 1 shards (0.0012s)
2025-01-22 10:56:32.990 - model_store.store - INFO - Load operation completed in 0.0027s


Prediction for input 4.0 with loaded model: 7.4401021003723145


## Test with tensorflow

In [15]:
import tensorflow as tf

# Define a simple model
inputs = tf.keras.Input(shape=(1,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Compile the model
model.compile(optimizer='sgd', loss='mse')

# Dummy data
x = tf.constant([[1.0], [2.0], [3.0]])
y = tf.constant([[2.0], [4.0], [6.0]])

# Train the model
model.fit(x, y, epochs=100, verbose=0)

<keras.src.callbacks.history.History at 0x16bd3a5d0>

In [16]:
# Make a simple prediction
x_test = tf.constant([[4.0]])
prediction = model(x_test).numpy()[0, 0]
print(f"Prediction for input 4.0: {prediction}")

Prediction for input 4.0: 7.760094165802002


In [17]:
# Save the trained model to Redis
model_name = "tensorflow"
version = "1.0"
model_store.save_model(model, model_name, version=version)

2025-01-22 10:56:36.920 - model_store.store - INFO - Saving 'tensorflow' model
2025-01-22 10:56:36.922 - model_store.store - INFO - Starting model serialization and storage
2025-01-22 10:56:36.932 - model_store.store - INFO - Stored model in 1 shards (0.0103s)
2025-01-22 10:56:36.934 - model_store.store - INFO - Save operation completed in 0.0138s


'1.0'

In [18]:
# Load the model from Redis
loaded_model = model_store.load_model(model_name)

prediction = loaded_model(x_test).numpy()[0, 0]
print(f"Prediction for input 4.0 with loaded model: {prediction}")

2025-01-22 10:56:36.941 - model_store.store - INFO - Loading 'tensorflow' model
2025-01-22 10:56:36.943 - model_store.store - INFO - Starting model reconstruction from shards
2025-01-22 10:56:36.954 - model_store.store - INFO - Loaded model from 1 shards (0.0106s)
2025-01-22 10:56:36.954 - model_store.store - INFO - Load operation completed in 0.0131s


Prediction for input 4.0 with loaded model: 7.760094165802002


# Model versioning 

In [19]:
# List all available models in the store
models = model_store.list_models()
models

['numpy_array', 'pytorch', 'random_forest', 'tensorflow']

In [20]:
# List model versions for a model
versions = model_store.get_all_versions(models[1])
versions

[ModelVersion(name='pytorch', description='', version='1.0', created_at=1737561392.98, shard_keys=['shard:pytorch:1.0:0'])]

In [21]:
# Delete a model version
model_version = versions[0]
model_store.delete_version(model_version.name, model_version.version)

2

In [22]:
# Clear all versions for all models
model_store.clear()

10617