Skip to content

Commit

Permalink
Return Uint8List runPix2Pix functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jfoutts committed Apr 10, 2019
1 parent 02924ad commit 848d618
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 72 deletions.
62 changes: 30 additions & 32 deletions android/src/main/java/sq/flutter/tflite/TflitePlugin.java
Expand Up @@ -553,8 +553,9 @@ void detectObjectOnFrame(HashMap args, Result result) throws IOException {
}

private class RunPix2PixOnImage extends TfliteTask {
String path;
String path, outputType;
float IMAGE_MEAN, IMAGE_STD;
long startTime;
ByteBuffer input, output;

RunPix2PixOnImage(HashMap args, Result result) throws IOException {
Expand All @@ -565,7 +566,8 @@ private class RunPix2PixOnImage extends TfliteTask {
double std = (double)(args.get("imageStd"));
IMAGE_STD = (float)std;

long startTime = SystemClock.uptimeMillis();
outputType = args.get("outputType").toString();
startTime = SystemClock.uptimeMillis();
input = feedInputTensorImage(path, IMAGE_MEAN, IMAGE_STD);
output = ByteBuffer.allocateDirect(input.limit());
output.order(ByteOrder.nativeOrder());
Expand All @@ -576,34 +578,29 @@ private class RunPix2PixOnImage extends TfliteTask {
protected void runTflite() { tfLite.run(input, output); }

protected void onRunTfliteDone() {
Log.v("time", "Generating took " + (SystemClock.uptimeMillis() - startTime));
if (output.position() != input.limit()) { result.error("Mismatching input/output position", null, null); return; }

output.flip();
Bitmap bitmapRaw = feedOutput(output, IMAGE_MEAN, IMAGE_STD);
String fileExt = path.substring(path.lastIndexOf('.')+1);
String outputFilename = path.substring(0, path.lastIndexOf('.')) + "_pix2pix." + fileExt;
try (FileOutputStream out = new FileOutputStream(outputFilename, false)) {
bitmapRaw.compress(Bitmap.CompressFormat.PNG, 100, out);
} catch (IOException e) {
e.printStackTrace();
outputFilename = "";
}

final ArrayList<Map<String, Object>> ret = new ArrayList<>();
Map<String, Object> res = new HashMap<>();
res.put("filename", outputFilename);
ret.add(res);
result.success(ret);
if (outputType.equals("png")) {
result.success(compressPNG(bitmapRaw));
} else {
result.success(bitmapRaw);
}
}
}

private class RunPix2PixOnBinary extends TfliteTask {
long startTime;
String outputType;
ByteBuffer input, output;

RunPix2PixOnBinary(HashMap args, Result result) throws IOException {
super(args, result);
byte[] binary = (byte[])args.get("binary");
outputType = args.get("outputType").toString();
startTime = SystemClock.uptimeMillis();
input = ByteBuffer.wrap(binary);
output = ByteBuffer.allocateDirect(input.limit());
Expand All @@ -619,16 +616,14 @@ protected void onRunTfliteDone() {
Log.v("time", "Generating took " + (SystemClock.uptimeMillis() - startTime));
if (output.position() != input.limit()) { result.error("Mismatching input/output position", null, null); return; }

final ArrayList<Map<String, Object>> ret = new ArrayList<>();
Map<String, Object> res = new HashMap<>();
res.put("binary", output.array());
ret.add(res);
result.success(ret);
output.flip();
result.success(output.array());
}
}

private class RunPix2PixOnFrame extends TfliteTask {
long startTime;
String outputType;
float IMAGE_MEAN, IMAGE_STD;
ByteBuffer input, output;

Expand All @@ -643,6 +638,7 @@ private class RunPix2PixOnFrame extends TfliteTask {
int imageWidth = (int)(args.get("imageWidth"));
int rotation = (int)(args.get("rotation"));

outputType = args.get("outputType").toString();
startTime = SystemClock.uptimeMillis();
input = feedInputTensorFrame(bytesList, imageHeight, imageWidth, IMAGE_MEAN, IMAGE_STD, rotation);
output = ByteBuffer.allocateDirect(input.limit());
Expand All @@ -658,11 +654,14 @@ protected void onRunTfliteDone() {
Log.v("time", "Generating took " + (SystemClock.uptimeMillis() - startTime));
if (output.position() != input.limit()) { result.error("Mismatching input/output position", null, null); return; }

final ArrayList<Map<String, Object>> ret = new ArrayList<>();
Map<String, Object> res = new HashMap<>();
res.put("binary", output.array());
ret.add(res);
result.success(ret);
output.flip();
Bitmap bitmapRaw = feedOutput(output, IMAGE_MEAN, IMAGE_STD);

if (outputType.equals("png")) {
result.success(compressPNG(bitmapRaw));
} else {
result.success(bitmapRaw);
}
}
}

Expand Down Expand Up @@ -897,8 +896,8 @@ private class RunSegmentationOnImage extends TfliteTask {
protected void onRunTfliteDone() {
Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime));

if (input.limit() == 0) result.error("Unexpected input position, bad file?", null, null);
if (output.position() != output.limit()) result.error("Unexpected output position", null, null);
if (input.limit() == 0) { result.error("Unexpected input position, bad file?", null, null); return; }
if (output.position() != output.limit()) { result.error("Unexpected output position", null, null); return; }
output.flip();

result.success(fetchArgmax(output, labelColors, outputType));
Expand Down Expand Up @@ -929,8 +928,8 @@ private class RunSegmentationOnBinary extends TfliteTask {
protected void onRunTfliteDone() {
Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime));

if (input.limit() == 0) result.error("Unexpected input position, bad file?", null, null);
if (output.position() != output.limit()) result.error("Unexpected output position", null, null);
if (input.limit() == 0) { result.error("Unexpected input position, bad file?", null, null); return; }
if (output.position() != output.limit()) { result.error("Unexpected output position", null, null); return; }
output.flip();

result.success(fetchArgmax(output, labelColors, outputType));
Expand Down Expand Up @@ -968,8 +967,8 @@ private class RunSegmentationOnFrame extends TfliteTask {
protected void onRunTfliteDone() {
Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime));

if (input.limit() == 0) result.error("Unexpected input position, bad file?", null, null);
if (output.position() != output.limit()) result.error("Unexpected output position", null, null);
if (input.limit() == 0) { result.error("Unexpected input position, bad file?", null, null); return; }
if (output.position() != output.limit()) { result.error("Unexpected output position", null, null); return; }
output.flip();

result.success(fetchArgmax(output, labelColors, outputType));
Expand Down Expand Up @@ -1029,7 +1028,6 @@ byte[] fetchArgmax(ByteBuffer output, List<Long> labelColors, String outputType)
setPixel(outputBytes, i * outputWidth + j, labelColor);
}
}
results.add(ret);
}
}
if (outputType.equals("png")) {
Expand Down
37 changes: 13 additions & 24 deletions ios/Classes/TflitePlugin.mm
Expand Up @@ -709,7 +709,7 @@ void runPix2PixOnImage(NSDictionary* args, FlutterResult result) {
const NSString* image_path = args[@"path"];
const float input_mean = [args[@"imageMean"] floatValue];
const float input_std = [args[@"imageStd"] floatValue];

const NSString* outputType = args[@"outputType"];
NSMutableArray* empty = [@[] mutableCopy];

if (!interpreter || interpreter_busy) {
Expand All @@ -731,17 +731,11 @@ void runPix2PixOnImage(NSDictionary* args, FlutterResult result) {
if (output == NULL)
return result(empty);

NSString *ext = image_path.pathExtension, *out_path = image_path.stringByDeletingPathExtension;
out_path = [NSString stringWithFormat:@"%@_pix2pix.%@", out_path, ext];
if (SaveImageToFile(output, [out_path UTF8String], width, height, 1)) {
NSMutableArray* results = [NSMutableArray array];
NSMutableDictionary* res = [NSMutableDictionary dictionary];
[res setObject:out_path forKey:@"filename"];
[results addObject:res];
return result(results);
if ([outputType isEqual: @"png"]) {
return result(CompressImage(output, width, height, 1));
} else {
return result(output);
}

return result(empty);
});
}

Expand All @@ -768,12 +762,7 @@ void runPix2PixOnBinary(NSDictionary* args, FlutterResult result) {
if (output == NULL)
return result(empty);

FlutterStandardTypedData* ret = [FlutterStandardTypedData typedDataWithBytes: output];
NSMutableArray* results = [NSMutableArray array];
NSMutableDictionary* res = [NSMutableDictionary dictionary];
[res setObject:ret forKey:@"binary"];
[results addObject:res];
return result(results);
return result(output);
});
}

Expand All @@ -783,6 +772,7 @@ void runPix2PixOnFrame(NSDictionary* args, FlutterResult result) {
const int image_width = [args[@"imageWidth"] intValue];
const float input_mean = [args[@"imageMean"] floatValue];
const float input_std = [args[@"imageStd"] floatValue];
const NSString* outputType = args[@"outputType"];
NSMutableArray* empty = [@[] mutableCopy];

if (!interpreter || interpreter_busy) {
Expand All @@ -801,16 +791,15 @@ void runPix2PixOnFrame(NSDictionary* args, FlutterResult result) {
}

int width = 0, height = 0;
NSMutableData* output = feedOutputTensor(0, 0, 1, false, &width, &height);
NSMutableData* output = feedOutputTensor(image_channels, input_mean, input_std, true, &width, &height);
if (output == NULL)
return result(empty);

FlutterStandardTypedData* ret = [FlutterStandardTypedData typedDataWithBytes: output];
NSMutableArray* results = [NSMutableArray array];
NSMutableDictionary* res = [NSMutableDictionary dictionary];
[res setObject:ret forKey:@"binary"];
[results addObject:res];
return result(results);
if ([outputType isEqual: @"png"]) {
return result(CompressImage(output, width, height, 1));
} else {
return result(output);
}
});
}

Expand Down
3 changes: 1 addition & 2 deletions ios/Classes/ios_image_load.h
Expand Up @@ -5,8 +5,7 @@ std::vector<uint8_t> LoadImageFromFile(const char* file_name,
int* out_height,
int* out_channels);

BOOL SaveImageToFile(NSMutableData*,
const char* file_name,
NSData *CompressImage(NSMutableData*,
int width,
int height,
int bytesPerPixel);
Expand Down
16 changes: 5 additions & 11 deletions ios/Classes/ios_image_load.mm
Expand Up @@ -72,28 +72,22 @@
return result;
}

BOOL SaveImageToFile(NSMutableData *image, const char* file_name, int width, int height, int bytesPerPixel) {
NSData *CompressImage(NSMutableData *image, int width, int height, int bytesPerPixel) {
const int channels = 4;
CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB();
CGContextRef context = CGBitmapContextCreate([image mutableBytes], width, height,
bytesPerPixel*8, width*channels*bytesPerPixel, color_space,
kCGImageAlphaPremultipliedLast | (bytesPerPixel == 4 ? kCGBitmapFloatComponents : kCGBitmapByteOrder32Big));
CGColorSpaceRelease(color_space);
if (context == nil) return NO;
if (context == nil) return nil;

CGImageRef imgRef = CGBitmapContextCreateImage(context);
CGContextRelease(context);
if (imgRef == nil) return NO;
if (imgRef == nil) return nil;

UIImage* img = [UIImage imageWithCGImage:imgRef];
CGImageRelease(imgRef);
if (img == nil) return NO;
if (img == nil) return nil;

NSData *data = UIImagePNGRepresentation(img);
if (data == nil) return NO;

FILE* file_handle = fopen(file_name, "wb");
BOOL ret = data.length == fwrite([data bytes], 1, data.length, file_handle);
fclose(file_handle);
return ret;
return UIImagePNGRepresentation(img);
}
12 changes: 9 additions & 3 deletions lib/tflite.dart
Expand Up @@ -189,10 +189,11 @@ class Tflite {
return await _channel.invokeMethod('close');
}

static Future<List> runPix2PixOnImage(
static Future<Uint8List> runPix2PixOnImage(
{@required String path,
double imageMean = 0,
double imageStd = 255.0,
String outputType = "png",
bool asynch = true}) async {
return await _channel.invokeMethod(
'runPix2PixOnImage',
Expand All @@ -201,29 +202,33 @@ class Tflite {
"imageMean": imageMean,
"imageStd": imageStd,
"asynch": asynch,
"outputType": outputType,
},
);
}

static Future<List> runPix2PixOnBinary(
static Future<Uint8List> runPix2PixOnBinary(
{@required Uint8List binary,
String outputType = "png",
bool asynch = true}) async {
return await _channel.invokeMethod(
'runPix2PixOnBinary',
{
"binary": binary,
"asynch": asynch,
"outputType": outputType,
},
);
}

static Future<List> runPix2PixOnFrame({
static Future<Uint8List> runPix2PixOnFrame({
@required List<Uint8List> bytesList,
int imageHeight = 1280,
int imageWidth = 720,
double imageMean = 0,
double imageStd = 255.0,
int rotation: 90, // Android only
String outputType = "png",
bool asynch = true,
}) async {
return await _channel.invokeMethod(
Expand All @@ -236,6 +241,7 @@ class Tflite {
"imageStd": imageStd,
"rotation": rotation,
"asynch": asynch,
"outputType": outputType,
},
);
}
Expand Down

0 comments on commit 848d618

Please sign in to comment.