warp-cache

Motion-adaptive caching for DiT-based video model inference — reuse, warp, or recompute attention features per denoising step.

What it does

WarpCache sits between your application and any Diffusion Transformer (DiT) video model serving endpoint. It intercepts attention feature computation and decides, per layer and per step, whether to reuse a cached result, warp it with a motion vector, or recompute it from scratch. The policy combines the WorldCache insight (arXiv:2603.22286) that the tolerable drift ceiling scales inversely with motion magnitude, saliency-weighted drift estimation from Cache-DiT (arXiv:2405.08748), and a three-phase denoising schedule multiplier. The result is a 2–3× speedup with less than 1% quality degradation, and no per-model retraining is required.
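
The core rule can be sketched as a self-contained illustration. The constants are the documented config defaults (see the Config fields section); the exact three-band logic is an assumption, not the library's actual Engine:

```go
package main

import (
	"fmt"
	"math"
)

// Documented config defaults (see the Config fields section).
const (
	baseThreshold = 0.12 // policy.base_threshold
	warpThreshold = 0.06 // policy.warp_threshold
	minThreshold  = 0.02 // policy.min_threshold
	motionScale   = 3.0  // policy.motion_scale
)

// Threshold shrinks exponentially with motion magnitude, floored at
// min_threshold: fast-moving content tolerates less drift before a
// recompute is forced.
func Threshold(motionMag float64) float64 {
	return math.Max(baseThreshold*math.Exp(-motionScale*motionMag), minThreshold)
}

type Decision int

const (
	UseCache Decision = iota
	WarpAndUse
	Recompute
)

// Decide maps observed drift into three bands: small drift reuses the
// cached feature as-is, moderate drift warps it first, large drift
// recomputes. The banding here is illustrative.
func Decide(drift, motionMag float64) Decision {
	t := Threshold(motionMag)
	switch {
	case drift <= math.Min(warpThreshold, t):
		return UseCache
	case drift <= t:
		return WarpAndUse
	default:
		return Recompute
	}
}

func main() {
	fmt.Println(Threshold(0.0), Decide(0.10, 0.0)) // slow scene: warp band
	fmt.Println(Threshold(1.0), Decide(0.10, 1.0)) // fast scene: recompute
}
```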

Install

Build from source:

git clone https://github.com/timholm/warp-cache.git
cd warp-cache
make build

Or import the library in your Go project (note that the packages documented below sit under internal/, which the Go toolchain only allows importing from inside this module; external projects need to vendor or fork):

go get github.com/timholm/warp-cache

Usage

Integrate the policy engine into a DiT inference loop:

import (
    "github.com/timholm/warp-cache/internal/cache"
    "github.com/timholm/warp-cache/internal/policy"
    "github.com/timholm/warp-cache/internal/tensor"
)

store := cache.NewStore(2048)
engine := policy.NewEngine(store)

totalSteps := 50
for step := 0; step < totalSteps; step++ {
    key := cache.LayerKey{RequestID: "req-1", LayerID: 4, Step: step}
    stepFraction := float64(step) / float64(totalSteps)
    motionMag := 0.3 // normalised optical flow magnitude in [0, 1]

    decision, cached := engine.Decide(key, currentFeature, motionMag, stepFraction)
    switch decision {
    case policy.UseCache:
        // reuse cached.Tensor directly
    case policy.WarpAndUse:
        // warp cached.Tensor with a motion field before reuse
    case policy.Recompute:
        // run the attention layer, then record result
        engine.Record(key, newFeature, motionMag)
    }
}

Warp cached features using a motion field:

import (
    "github.com/timholm/warp-cache/internal/warp"
)

// Estimate motion from adjacent feature maps
mf, err := warp.EstimateFromTensors(prev.Data, curr.Data, h, w, c)
if err != nil {
    // handle
}

blender := warp.NewBlender()
result, err := blender.Blend(cached.Tensor, current, mf, h, w, c)
// result.Tensor holds the motion-warped blend
// result.MeanAlpha reports the average blend weight applied

Load a named preset and validate config:

import "github.com/timholm/warp-cache/internal/config"

cfg, err := config.Preset("quality") // "fast" | "balanced" | "quality"
// or load from JSON:
cfg, err = config.Load("warpcache.json")

API

internal/cache

// Store — thread-safe LRU feature cache
func NewStore(capacity int) *Store
func (s *Store) Put(key LayerKey, t *tensor.Tensor, motionMag, drift float64)
func (s *Store) Get(key LayerKey) (*Entry, bool)
func (s *Store) GetPrev(key LayerKey) (*Entry, bool)   // step - 1 shortcut
func (s *Store) Invalidate(requestID string)
func (s *Store) Stats() Stats
func (s *Store) Len() int

// LRU — generic typed LRU
func NewLRU[K comparable, V any](capacity int, onEvict func(K, V)) *LRU[K, V]
func (c *LRU[K, V]) Set(key K, value V) bool
func (c *LRU[K, V]) Get(key K) (V, bool)
func (c *LRU[K, V]) Peek(key K) (V, bool)
func (c *LRU[K, V]) Delete(key K)
func (c *LRU[K, V]) Len() int
func (c *LRU[K, V]) Keys() []K
func (c *LRU[K, V]) Purge()
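
The generic LRU with an eviction callback follows a standard pattern; below is a minimal, unsynchronized sketch of that pattern (the real implementation is thread-safe and richer, and the Set return semantics here are assumed: true when a new key was inserted):

```go
package main

import (
	"container/list"
	"fmt"
)

type entry[K comparable, V any] struct {
	key K
	val V
}

// LRU keeps a doubly linked list in recency order plus a map from key to
// list element; evictions fire the optional callback.
type LRU[K comparable, V any] struct {
	cap     int
	ll      *list.List
	items   map[K]*list.Element
	onEvict func(K, V)
}

func NewLRU[K comparable, V any](capacity int, onEvict func(K, V)) *LRU[K, V] {
	return &LRU[K, V]{cap: capacity, ll: list.New(), items: map[K]*list.Element{}, onEvict: onEvict}
}

// Set inserts or updates; returns true when a new key was inserted.
func (c *LRU[K, V]) Set(key K, value V) bool {
	if el, ok := c.items[key]; ok {
		c.ll.MoveToFront(el)
		el.Value.(*entry[K, V]).val = value
		return false
	}
	c.items[key] = c.ll.PushFront(&entry[K, V]{key, value})
	if c.ll.Len() > c.cap {
		oldest := c.ll.Back()
		c.ll.Remove(oldest)
		e := oldest.Value.(*entry[K, V])
		delete(c.items, e.key)
		if c.onEvict != nil {
			c.onEvict(e.key, e.val)
		}
	}
	return true
}

// Get returns the value and marks the key most-recently used.
func (c *LRU[K, V]) Get(key K) (V, bool) {
	if el, ok := c.items[key]; ok {
		c.ll.MoveToFront(el)
		return el.Value.(*entry[K, V]).val, true
	}
	var zero V
	return zero, false
}

func main() {
	lru := NewLRU[string, int](2, func(k string, v int) {
		fmt.Printf("evicted %s=%d\n", k, v)
	})
	lru.Set("a", 1)
	lru.Set("b", 2)
	lru.Get("a")    // touch "a" so "b" becomes least-recent
	lru.Set("c", 3) // prints "evicted b=2"
}
```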

internal/policy

// Engine — integrates drift, motion, and phase scheduling
func NewEngine(store *cache.Store) *Engine
func (e *Engine) Decide(key cache.LayerKey, current *tensor.Tensor, motionMag, stepFraction float64) (Decision, *cache.Entry)
func (e *Engine) Record(key cache.LayerKey, t *tensor.Tensor, motionMag float64)

// Decision values: Recompute | UseCache | WarpAndUse

// DriftEstimator — saliency-weighted cosine drift
func NewDriftEstimator() *DriftEstimator
func (d *DriftEstimator) Estimate(prev, curr *tensor.Tensor) (float64, error)
func (d *DriftEstimator) EstimateL2(prev, curr *tensor.Tensor) (float64, error)

// MotionAdaptiveThreshold — threshold = base * exp(-scale * motionMag)
func DefaultMotionAdaptiveThreshold() *MotionAdaptiveThreshold
func (m *MotionAdaptiveThreshold) Threshold(motionMag float64) float64

// PhaseMultiplier — [0.5, 1.5] ramp across early/mid/late denoising phases
func PhaseMultiplier(stepFraction float64) float64
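
One plausible realization of the documented [0.5, 1.5] ramp (the 0.3/0.7 phase boundaries are an assumption): hold 0.5 through the early structure-forming steps, ramp linearly through the mid phase, and hold 1.5 in the late refinement phase where caching is safest:

```go
package main

import "fmt"

// PhaseMultiplier scales the drift threshold by denoising phase: early
// steps (structure formation) cache conservatively, late steps
// (refinement) cache aggressively. Boundaries are illustrative.
func PhaseMultiplier(stepFraction float64) float64 {
	switch {
	case stepFraction < 0.3:
		return 0.5
	case stepFraction > 0.7:
		return 1.5
	default:
		return 0.5 + (stepFraction-0.3)/0.4 // linear ramp 0.5 -> 1.5
	}
}

func main() {
	for _, f := range []float64{0.0, 0.5, 1.0} {
		fmt.Println(f, PhaseMultiplier(f))
	}
}
```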

internal/warp

// MotionField — dense 2-D optical flow grid
func NewMotionField(w, h int) *MotionField
func EstimateFromTensors(prevData, currData []float32, h, w, c int) (*MotionField, error)
func (mf *MotionField) MeanMagnitude() float64
func (mf *MotionField) MaxMagnitude() float64
func (mf *MotionField) Interpolate(fx, fy float64) Vec2

// Blender — bilinear inverse-warp + motion-adaptive blend
func NewBlender() *Blender
func (b *Blender) Blend(cached, current *tensor.Tensor, mf *MotionField, h, w, c int) (*WarpedFeature, error)
func SimpleBlend(cached, current *tensor.Tensor, alpha float32) (*tensor.Tensor, error)
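
The bilinear inverse-warp at the heart of the Blender can be sketched as follows (self-contained; the real Blender additionally applies a motion-adaptive blend against the current feature, and this layout and clamping behavior are assumptions):

```go
package main

import (
	"fmt"
	"math"
)

// clampAt reads src at (x, y, ch) with edge clamping, for an h×w×c
// feature map stored row-major.
func clampAt(src []float32, x, y, ch, h, w, c int) float64 {
	if x < 0 {
		x = 0
	} else if x >= w {
		x = w - 1
	}
	if y < 0 {
		y = 0
	} else if y >= h {
		y = h - 1
	}
	return float64(src[(y*w+x)*c+ch])
}

// InverseWarp samples the cached feature at (x-u, y-v) for each output
// position, bilinearly interpolating between the four nearest cells.
func InverseWarp(src, flowU, flowV []float32, h, w, c int) []float32 {
	out := make([]float32, len(src))
	for y := 0; y < h; y++ {
		for x := 0; x < w; x++ {
			i := y*w + x
			sx := float64(x) - float64(flowU[i])
			sy := float64(y) - float64(flowV[i])
			x0, y0 := int(math.Floor(sx)), int(math.Floor(sy))
			fx, fy := sx-float64(x0), sy-float64(y0)
			for ch := 0; ch < c; ch++ {
				top := clampAt(src, x0, y0, ch, h, w, c)*(1-fx) + clampAt(src, x0+1, y0, ch, h, w, c)*fx
				bot := clampAt(src, x0, y0+1, ch, h, w, c)*(1-fx) + clampAt(src, x0+1, y0+1, ch, h, w, c)*fx
				out[i*c+ch] = float32(top*(1-fy) + bot*fy)
			}
		}
	}
	return out
}

func main() {
	// 1×3 single-channel map, content shifted right by one cell.
	src := []float32{1, 2, 3}
	u := []float32{1, 1, 1}
	v := []float32{0, 0, 0}
	fmt.Println(InverseWarp(src, u, v, 1, 3, 1)) // [1 1 2]
}
```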

internal/tensor

func New(shape ...int) *Tensor
func From(data []float32, shape ...int) *Tensor
func CosineSimilarity(a, b *Tensor) (float64, error)
func L2Distance(a, b *Tensor) (float64, error)
func Add(a, b *Tensor) error
func Scale(t *Tensor, s float32)
func Lerp(a, b *Tensor, alpha float32) (*Tensor, error)
func Mean(t *Tensor) float64
func Abs(t *Tensor) *Tensor
func (t *Tensor) Norm2() float64
func (t *Tensor) Clone() *Tensor
func (t *Tensor) Size() int

internal/config

func Default() *Config
func Load(path string) (*Config, error)   // JSON; missing fields fall back to Default()
func Preset(name string) (*Config, error) // "fast" | "balanced" | "quality"
func (c *Config) Validate() error
func (c *Config) Save(path string) error

Config fields:
  upstream, listen_addr
  cache.capacity, cache.ttl
  policy.base_threshold (default 0.12), policy.warp_threshold (0.06),
    policy.min_threshold (0.02), policy.motion_scale (3.0),
    policy.saliency_power (1.0), policy.warm_up_steps (2)
  metrics.enabled, metrics.path
  profiler.auto_profile, profiler.profile_requests, profiler.sensitivity_path
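
A warpcache.json using these fields might look like the following. The nesting is inferred from the dotted names, and all values other than the documented policy defaults are illustrative:

```json
{
  "upstream": "localhost:50051",
  "listen_addr": ":8080",
  "cache": { "capacity": 2048, "ttl": "5m" },
  "policy": {
    "base_threshold": 0.12,
    "warp_threshold": 0.06,
    "min_threshold": 0.02,
    "motion_scale": 3.0,
    "saliency_power": 1.0,
    "warm_up_steps": 2
  },
  "metrics": { "enabled": true, "path": "/metrics" },
  "profiler": {
    "auto_profile": false,
    "profile_requests": 8,
    "sensitivity_path": "sensitivity.json"
  }
}
```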

internal/metrics

func NewRegistry() *Registry
func (r *Registry) NewCounter(name, help string) *Counter
func (r *Registry) NewGauge(name, help string) *Gauge
func (r *Registry) NewHistogram(name, help string, buckets []float64) *Histogram
func (r *Registry) WriteText(w io.Writer)  // Prometheus exposition format

func NewWarpCacheMetrics(reg *Registry) *WarpCacheMetrics
func (m *WarpCacheMetrics) UpdateHitRate()
// Exposes: warpcache_requests_total, warpcache_cache_hits_total,
//          warpcache_cache_misses_total, warpcache_warp_operations_total,
//          warpcache_recomputes_total, warpcache_cache_size,
//          warpcache_hit_rate, warpcache_request_latency_seconds,
//          warpcache_drift_observed
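
WriteText emits the standard Prometheus text exposition format, so a scrape would resemble the following (help strings and values are illustrative):

```
# HELP warpcache_requests_total Total requests seen
# TYPE warpcache_requests_total counter
warpcache_requests_total 1024
# HELP warpcache_hit_rate Fraction of requests served from cache
# TYPE warpcache_hit_rate gauge
warpcache_hit_rate 0.82
```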

Architecture

internal/
  tensor/       ops.go          — float32 Tensor type; cosine similarity, L2, lerp, add, scale
  cache/        lru.go          — generic thread-safe LRU[K, V] with eviction callback
                store.go        — LayerKey-indexed feature Store wrapping LRU; hit/miss stats
  policy/       drift_estimator.go     — saliency-weighted cosine drift between successive steps
                motion_adaptive.go     — exponential threshold curve + phase multiplier
                engine.go              — Decide() / Record() integrating all policy components
  warp/         motion_vector.go       — MotionField, Vec2, block-matching EstimateFromTensors
                blend.go               — bilinear inverse-warp Blender; SimpleBlend fallback
  config/       config.go       — JSON config load/save, Validate(), named Presets
  metrics/      metrics.go      — zero-dependency Counter/Gauge/Histogram + Prometheus text output

Data flows from the calling inference loop through policy.Engine.Decide(), which consults the cache.Store for the previous step's entry, calls DriftEstimator to measure cosine drift, applies MotionAdaptiveThreshold scaled by PhaseMultiplier, and returns one of Recompute / UseCache / WarpAndUse. The caller optionally runs warp.Blender.Blend() on a WarpAndUse decision, then calls engine.Record() to store the fresh feature for the next step.


Market Analysis

The buyer is ML infrastructure teams at video AI companies (Runway, Pika, Luma, enterprise media companies) who are spending $50K–$500K/month on GPU inference for DiT-based video generation and want to cut costs without retraining models or degrading quality. Revenue model is open-core: the CLI profiler and basic LRU caching are free (driving adoption), while the production gRPC proxy with auto-tuning, motion-adaptive warping, Prometheus metrics, multi-GPU coordination, and SLA-grade quality guarantees require a commercial license at $2K–$10K/month per cluster. The moat is the auto-profiling engine that eliminates per-model tuning — competitors require manual threshold configuration per architecture, while WarpCache learns optimal caching policies automatically, making it the only zero-config solution that works across Cosmos, HunyuanVideo, Wan, and other DiT variants.

License

MIT

About

Video AI companies running diffusion transformer pipelines (world models, video generation) spend 60-80% of GPU compute on redundant attention calculations, but existing caching solutions require per-model tuning, cause visual artifacts, and don't integrate into production serving infrastructure.
