[Android] Kotlin API improvements for vision LLM models #19820

@kirklandsign

🚀 The feature, motivation and pitch

The current Android vision API (LlmModule.prefillImages) works at a low level and leaves most of the work to developers: there is no Bitmap support, no preprocessing, and no Image type, only raw arrays with scattered dimension params. For vision LLMs (LLaVA, Gemma, Phi-3-Vision, etc.) to be accessible on Android, the API needs to be more ergonomic and idiomatic.

Current State

  • 4 prefillImages variants exist: int[], ByteBuffer (uint8), float[], ByteBuffer (float)
  • No Java/Kotlin Image wrapper type — dimensions passed as loose params
  • No android.graphics.Bitmap support
  • No preprocessing utilities (resize, normalize, crop, HWC↔CHW)
  • C++ layer assumes CHW (planar) format, but Android Bitmap pixels (ARGB_8888) are interleaved per pixel (HWC) and include alpha; this mismatch is undocumented
  • int[] used for uint8 pixel data (4x memory waste)
  • Even ByteBuffer paths copy data in JNI before use

Proposed Improvements

1. Add an Image wrapper type

Replace scattered raw arrays + dimension params with a proper type:

class LlmImage private constructor(
    val data: ByteBuffer,
    val width: Int,
    val height: Int,
    val channels: Int,
    val dtype: DType,  // UINT8 or FLOAT32
) {
    companion object {
        fun fromBitmap(bitmap: Bitmap): LlmImage  // handles ARGB→RGB + HWC→CHW
        fun fromRgb(data: ByteArray, width: Int, height: Int): LlmImage
        fun fromNormalized(data: FloatArray, width: Int, height: Int, channels: Int): LlmImage
        fun fromBuffer(buffer: ByteBuffer, width: Int, height: Int, channels: Int, dtype: DType): LlmImage
    }
}
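
As a rough illustration of how such a wrapper could keep data and shape together, here is a minimal sketch of one factory, fromNormalized. The DType enum, its bytesPerElement field, and the choice of a direct native-order buffer are assumptions made for this example; nothing here exists in the current API.

```kotlin
import java.nio.ByteBuffer
import java.nio.ByteOrder

// Hypothetical element types; the names mirror the proposal, not an existing API.
enum class DType(val bytesPerElement: Int) { UINT8(1), FLOAT32(4) }

// Sketch of the proposed wrapper: pixel data and its shape travel together.
class LlmImage private constructor(
    val data: ByteBuffer,
    val width: Int,
    val height: Int,
    val channels: Int,
    val dtype: DType,
) {
    companion object {
        // Packs already-normalized floats into a direct, native-order buffer.
        fun fromNormalized(data: FloatArray, width: Int, height: Int, channels: Int): LlmImage {
            require(width > 0 && height > 0 && channels > 0) { "Dimensions must be positive" }
            val expected = width.toLong() * height * channels
            require(data.size.toLong() == expected) {
                "Expected $expected floats for ${width}x${height}x$channels, got ${data.size}"
            }
            val buffer = ByteBuffer
                .allocateDirect(data.size * DType.FLOAT32.bytesPerElement)
                .order(ByteOrder.nativeOrder())
            // Writing through the float view leaves the byte buffer's position at 0.
            buffer.asFloatBuffer().put(data)
            return LlmImage(buffer, width, height, channels, DType.FLOAT32)
        }
    }
}
```

The private constructor means every instance passes through a factory, so the size check cannot be bypassed by callers.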

2. Native Bitmap support

This is the #1 ergonomic gap. Android developers work with Bitmap — forcing manual pixel extraction, alpha stripping, and HWC→CHW conversion is error-prone:

// Today (painful):
val pixels = IntArray(bitmap.width * bitmap.height)
bitmap.getPixels(pixels, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)
val rgb = ByteArray(pixels.size * 3)
for (i in pixels.indices) {
    rgb[i * 3]     = ((pixels[i] shr 16) and 0xFF).toByte()  // R
    rgb[i * 3 + 1] = ((pixels[i] shr 8) and 0xFF).toByte()   // G
    rgb[i * 3 + 2] = (pixels[i] and 0xFF).toByte()            // B
}
// ... then CHW reorder, then normalize, then call prefillImages

// Goal:
llmModule.prefillImage(LlmImage.fromBitmap(bitmap))
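
To make the hidden work concrete, the sketch below shows roughly what a fromBitmap implementation might do once it has the packed ARGB ints from Bitmap.getPixels: drop alpha and write the channels as three planes (CHW) in one pass. The function name argbToChwRgb is hypothetical, and it takes a plain IntArray so it runs without the Android SDK.

```kotlin
// Converts packed ARGB pixels (the layout Bitmap.getPixels produces) into
// planar CHW RGB bytes: all R values, then all G values, then all B values.
fun argbToChwRgb(pixels: IntArray, width: Int, height: Int): ByteArray {
    require(width > 0 && height > 0) { "Dimensions must be positive" }
    val planeSize = width * height
    require(pixels.size == planeSize) {
        "Expected $planeSize pixels for ${width}x$height, got ${pixels.size}"
    }
    val chw = ByteArray(planeSize * 3)
    for (i in 0 until planeSize) {
        val p = pixels[i]
        chw[i] = ((p shr 16) and 0xFF).toByte()                // R plane
        chw[planeSize + i] = ((p shr 8) and 0xFF).toByte()     // G plane
        chw[2 * planeSize + i] = (p and 0xFF).toByte()         // B plane
    }
    return chw
}
```

Doing the alpha strip and the reorder together avoids the intermediate interleaved RGB array that the manual approach allocates.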

3. Built-in image preprocessing utilities

Vision models have specific input requirements (resolution, normalization). Provide common transforms:

object ImagePreprocessor {
    fun resize(image: LlmImage, targetWidth: Int, targetHeight: Int): LlmImage
    fun centerCrop(image: LlmImage, cropSize: Int): LlmImage
    fun normalize(image: LlmImage, mean: FloatArray, std: FloatArray): LlmImage
    fun toChw(image: LlmImage): LlmImage  // HWC → CHW if needed
}

Or a pipeline builder:

val preprocessed = ImagePreprocessor.pipeline(bitmap)
    .resize(336, 336)
    .centerCrop(224)
    .normalize(mean = floatArrayOf(0.485f, 0.456f, 0.406f), std = floatArrayOf(0.229f, 0.224f, 0.225f))
    .build()
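
For reference, the normalize step could look something like the sketch below. It assumes CHW uint8 input and that pixel values are scaled to [0, 1] before the per-channel mean and std are applied, which is the convention the mean/std values above usually go with; the actual scaling a given model expects may differ. The function name normalizeChw is hypothetical.

```kotlin
// Normalizes planar (CHW) uint8 pixels to floats: (v / 255 - mean[c]) / std[c].
fun normalizeChw(chw: ByteArray, channels: Int, mean: FloatArray, std: FloatArray): FloatArray {
    require(channels > 0 && chw.size % channels == 0) {
        "Data size ${chw.size} is not divisible by $channels channels"
    }
    require(mean.size == channels && std.size == channels) {
        "mean and std must each have $channels entries"
    }
    require(std.all { it != 0f }) { "std must not contain zero" }
    val planeSize = chw.size / channels
    val out = FloatArray(chw.size)
    for (c in 0 until channels) {
        for (i in 0 until planeSize) {
            val idx = c * planeSize + i
            // Bytes are signed in Kotlin; mask to recover the 0..255 value.
            val v = (chw[idx].toInt() and 0xFF) / 255f
            out[idx] = (v - mean[c]) / std[c]
        }
    }
    return out
}
```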

4. Fix int[] → byte[] for uint8 pixel data

prefillImages(int[] image, ...) uses 32-bit ints for 8-bit pixel data, and the JNI layer narrows each element to a byte. Replace it with byte[] or ByteBuffer to avoid the 4x memory overhead.
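
For callers that already hold int-per-channel data, a migration helper might look like the sketch below. The name narrowToBytes is hypothetical, and rejecting out-of-range values instead of silently truncating them is a design choice of this example, not current behavior. For scale: a 1024 x 1024 x 3 image has 3,145,728 elements, which is 3 MiB of payload as byte[] and 12 MiB as int[].

```kotlin
// Narrows int-per-channel pixel data to one byte per channel value.
// Values outside 0..255 are rejected rather than silently truncated.
fun narrowToBytes(pixels: IntArray): ByteArray {
    val out = ByteArray(pixels.size)
    for (i in pixels.indices) {
        val v = pixels[i]
        require(v in 0..255) { "Value $v at index $i is outside the uint8 range" }
        // toByte() keeps the low 8 bits; 128..255 map to negative Kotlin bytes.
        out[i] = v.toByte()
    }
    return out
}
```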

5. Multi-image support in a single call

Current API requires one prefillImages() call per image. For multi-image conversations (e.g., "compare these two photos"), add batch support:

fun prefillImages(images: List<LlmImage>)

This would map to a single JNI call and a single runner_->prefill() with multiple MultimodalInput entries, avoiding repeated JNI overhead.

6. Reduce JNI copies for ByteBuffer paths

Even direct ByteBuffer variants currently memcpy into a std::vector in JNI before constructing Image. For large images (1024 x 1024 x 3 bytes = 3 MiB as uint8), this is wasteful. The C++ Image could hold a reference to the JNI buffer directly (with appropriate lifetime management).
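
On the Kotlin side, a zero-copy path would need the pixels to live in a direct buffer, since JNI's GetDirectBufferAddress returns NULL for heap-backed buffers. A minimal sketch of the caller-side preparation follows; the helper name toDirectBuffer is hypothetical, and whether the native layer can retain the buffer without copying is exactly what this item proposes, not something the code demonstrates.

```kotlin
import java.nio.ByteBuffer
import java.nio.ByteOrder

// Copies pixel bytes into a direct buffer once, on the JVM side, so native
// code could read the memory in place instead of copying it again.
fun toDirectBuffer(chw: ByteArray): ByteBuffer {
    val buffer = ByteBuffer.allocateDirect(chw.size).order(ByteOrder.nativeOrder())
    buffer.put(chw)
    buffer.flip() // position = 0, limit = chw.size: ready for reading
    return buffer
}
```

The caller would have to keep the returned buffer reachable for as long as native code uses it, which is the lifetime management the proposal refers to.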

7. Input validation at Java layer

Currently, passing wrong dimensions silently propagates to C++ where the error is opaque. Add early validation:

require(image.size == width * height * channels) {
    "Image data size (${image.size}) doesn't match dimensions ($width x $height x $channels = ${width * height * channels})"
}
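
For the ByteBuffer paths, the same check could account for element size and guard against Int overflow when dimensions are large. The sketch below is one possible shape; the function name validateImageBuffer and the bytesPerElement parameter are assumptions made for illustration.

```kotlin
import java.nio.ByteBuffer

// Checks that a buffer holds exactly width * height * channels elements
// of the given element size (1 for uint8, 4 for float32).
fun validateImageBuffer(
    buffer: ByteBuffer,
    width: Int,
    height: Int,
    channels: Int,
    bytesPerElement: Int,
) {
    require(width > 0 && height > 0 && channels > 0) {
        "Dimensions must be positive, got ${width}x${height}x$channels"
    }
    require(bytesPerElement == 1 || bytesPerElement == 4) {
        "Unsupported element size: $bytesPerElement"
    }
    // Long arithmetic so large dimensions cannot overflow Int and pass by accident.
    val expected = width.toLong() * height * channels * bytesPerElement
    require(buffer.remaining().toLong() == expected) {
        "Buffer has ${buffer.remaining()} bytes but ${width}x${height}x$channels " +
            "with $bytesPerElement byte(s) per element needs $expected"
    }
}
```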
