Standardizing Image Data implementation

zsogitbe committed Apr 24, 2024
1 parent ccc49eb commit b2423fe
Showing 3 changed files with 65 additions and 20 deletions.
31 changes: 17 additions & 14 deletions LLama.Examples/Examples/LlavaInteractiveModeExecute.cs
@@ -1,7 +1,8 @@
using System.Text.RegularExpressions;
using LLama.Batched;
using LLama.Common;
using Spectre.Console;
using LLama.Native;
using LLama.Abstractions;

namespace LLama.Examples.Examples
{
@@ -18,8 +19,12 @@ public static async Task Run()

var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n";

var parameters = new ModelParams(modelPath);

var parameters = new ModelParams(modelPath)
{
ContextSize = 4096,
Seed = 1337,
GpuLayerCount = 10
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);

@@ -42,16 +47,16 @@ public static async Task Run()
var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
var imageCount = imageMatches.Count();
var hasImages = imageCount > 0;
byte[][] imageBytes = null;

if (hasImages)
{
var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList();
var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value);

List<byte[]> imageBytes;
try
{
imageBytes = imagePaths.Select(File.ReadAllBytes).ToList();
imageBytes = imagePaths.Select(File.ReadAllBytes).ToArray();
}
catch (IOException exception)
{
@@ -64,17 +69,15 @@ public static async Task Run()
break;
}

// Each prompt with images we clear cache
// When the prompt contains images we clear KV_CACHE to restart conversation
// See:
// https://github.com/ggerganov/llama.cpp/discussions/3620
ex.Context.NativeHandle.KvCacheRemove( LLamaSeqId.Zero, -1, -1 );

int index = 0;
foreach (var path in imagePathsWithCurlyBraces)
{
// First image replace to tag <image, the rest of the images delete the tag
prompt = prompt.Replace(path, index++ == 0 ? "<image>" : "");
if (index++ == 0)
prompt = prompt.Replace(path, "<image>");
else
prompt = prompt.Replace(path, "");
}


@@ -99,7 +102,7 @@ public static async Task Run()
//
foreach (var image in imagePaths)
{
ex.Images.Add(await File.ReadAllBytesAsync(image));
ex.Images.Add(new ImageData(ImageData.DataType.ImagePath, image));
}
}

@@ -115,7 +118,7 @@ await foreach (var text in ex.InferAsync(prompt, inferenceParams))

// let the user finish with exit
//
if (prompt != null && prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
if (prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
break;

}
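Taken together, the example change above amounts to: find the {image path} tags in the prompt, keep a single <image> placeholder, and queue each path on the executor as an ImageData entry instead of reading the bytes up front. A minimal, self-contained sketch of that pattern follows; LlavaPromptHelper and AttachImages are made-up names, and the tag-replacement behaviour mirrors the loop shown in the diff.

using System.Linq;
using System.Text.RegularExpressions;
using LLama.Abstractions;

internal static class LlavaPromptHelper
{
    // Sketch only: strip "{path}" tags from the prompt, keep one <image> placeholder,
    // and queue each referenced file as an ImageData entry on the executor.
    public static string AttachImages(string prompt, ILLamaExecutor executor)
    {
        var tags = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value).ToList();
        var paths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList();

        var index = 0;
        foreach (var tag in tags)
        {
            // Only the first image keeps a placeholder tag; the rest are removed.
            prompt = prompt.Replace(tag, index++ == 0 ? "<image>" : "");
        }

        foreach (var path in paths)
        {
            // The executor now receives a typed ImageData instead of raw bytes.
            executor.Images.Add(new ImageData(ImageData.DataType.ImagePath, path));
        }

        return prompt;
    }
}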
46 changes: 44 additions & 2 deletions LLama/Abstractions/ILLamaExecutor.cs
@@ -25,9 +25,9 @@ public interface ILLamaExecutor
public LLavaWeights? ClipModel { get; }

/// <summary>
/// List of images: List of images in byte array format.
/// List of images: Image file path, URI or image byte array. See ImageData.
/// </summary>
public List<byte[]> Images { get; }
public List<ImageData> Images { get; }

/// <summary>
/// Asynchronously infers a response from the model.
@@ -38,4 +38,46 @@ public interface ILLamaExecutor
/// <returns></returns>
IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
}

/// <summary>
/// Holds image data
/// </summary>
public class ImageData
{
/// <summary>
/// constructor
/// </summary>
/// <param name="type"></param>
/// <param name="data"></param>
public ImageData(DataType type, object data) { Type = type; Data = data; }

/// <summary>
/// the possible types of image data
/// </summary>
public enum DataType
{
/// <summary>
/// file path
/// </summary>
ImagePath,
/// <summary>
/// byte array
/// </summary>
ImageBytes,
/// <summary>
/// uri
/// </summary>
ImageURL
}

/// <summary>
/// the type of this image data
/// </summary>
public DataType Type { get; set; }

/// <summary>
/// the image data (string, byte array or uri)
/// </summary>
public object? Data { get; set; }
}
}
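Beyond the file-path case used in the example, the DataType enum also covers raw bytes and URLs. A hypothetical usage sketch is shown below, where executor stands for any ILLamaExecutor with its ClipModel loaded, and the paths and URL are placeholders.

// Queue one image per variant; File.ReadAllBytes requires System.IO.
executor.Images.Add(new ImageData(ImageData.DataType.ImagePath, "images/cat.jpg"));                          // file on disk
executor.Images.Add(new ImageData(ImageData.DataType.ImageBytes, File.ReadAllBytes("images/dog.jpg")));      // pre-loaded bytes
executor.Images.Add(new ImageData(ImageData.DataType.ImageURL, "https://example.com/bird.jpg"));             // remote image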
8 changes: 4 additions & 4 deletions LLama/LLamaStatelessExecutor.cs
@@ -34,7 +34,7 @@ public class StatelessExecutor
public LLavaWeights? ClipModel { get; }

/// <inheritdoc />
public List<byte[]> Images { get; set; }
public List<ImageData> Images { get; set; }

/// <summary>
/// The context used by the executor when running the inference.
@@ -49,7 +49,7 @@ public class StatelessExecutor
/// <param name="logger"></param>
public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{
Images = new List<byte[]>();
Images = new List<ImageData>();
_weights = weights;
_params = @params;
_logger = logger;
@@ -90,7 +90,7 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
lastTokens.Add(0);

// Tokenize the prompt
var tokens = Context.Tokenize(prompt, special: true).ToList();
var tokens = Context.Tokenize(prompt).ToList();
lastTokens.AddRange(tokens);

// Evaluate the prompt, in chunks smaller than the max batch size
@@ -124,7 +124,7 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
}

// Check if this is the EOS token
if (id == _weights.Tokens.EOS)
if (id == _weights.EndOfSentenceToken)
break;

// Decode this token into text
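The StatelessExecutor diff above only swaps the Images property and constructor over to ImageData; how an executor turns each entry back into raw bytes for the clip model is not shown in this commit. Purely as an assumption about how the three DataType cases could be handled, a helper might look like the sketch below (ImageDataResolver and ResolveAsync are made-up names, not part of the commit).

using System;
using System.IO;
using System.Net.Http;
using System.Threading.Tasks;
using LLama.Abstractions;

internal static class ImageDataResolver
{
    private static readonly HttpClient Http = new();

    // Hypothetical helper: resolve an ImageData entry into raw bytes.
    public static Task<byte[]> ResolveAsync(ImageData image) => image.Type switch
    {
        // Raw bytes are passed through unchanged.
        ImageData.DataType.ImageBytes => Task.FromResult((byte[])image.Data!),
        // A file path is read from disk.
        ImageData.DataType.ImagePath => File.ReadAllBytesAsync((string)image.Data!),
        // A URL is downloaded over HTTP.
        ImageData.DataType.ImageURL => Http.GetByteArrayAsync((string)image.Data!),
        _ => throw new ArgumentOutOfRangeException(nameof(image), "Unknown image data type")
    };
}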
