diff --git a/generic/png.c b/generic/png.c index 56132366..75782bb 100755 --- a/generic/png.c +++ b/generic/png.c @@ -198,11 +198,17 @@ static int libpng_(Main_load)(lua_State *L) return 2; } + static int libpng_(Main_save)(lua_State *L) { THTensor *tensor = luaT_checkudata(L, 2, torch_Tensor); const char *file_name = luaL_checkstring(L, 1); - + const int save_to_file = luaL_checkint(L, 3); + + struct libpng_inmem_write_struct _inmem; + + THByteTensor* tensor_dest = NULL; + int width=0, height=0; png_byte color_type = 0; png_byte bit_depth = 8; @@ -211,11 +217,17 @@ static int libpng_(Main_save)(lua_State *L) png_infop info_ptr; png_bytep * row_pointers; libpng_errmsg errmsg; + FILE *fp=NULL; /* get dims and contiguous tensor */ THTensor *tensorc = THTensor_(newContiguous)(tensor); real *tensor_data = THTensor_(data)(tensorc); long depth=0; + + if (save_to_file == 0) { + tensor_dest = luaT_checkudata(L, 4, "torch.ByteTensor"); + } + if (tensorc->nDimension == 3) { depth = tensorc->size[0]; height = tensorc->size[1]; @@ -234,27 +246,34 @@ static int libpng_(Main_save)(lua_State *L) else if (depth == 3) color_type = PNG_COLOR_TYPE_RGB; else if (depth == 1) color_type = PNG_COLOR_TYPE_GRAY; - /* create file */ - FILE *fp = fopen(file_name, "wb"); - if (!fp) - luaL_error(L, "[write_png_file] File %s could not be opened for writing", file_name); - /* initialize stuff */ png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL); if (!png_ptr) luaL_error(L, "[write_png_file] png_create_write_struct failed"); - + png_set_error_fn(png_ptr, &errmsg, libpng_error_fn, NULL); info_ptr = png_create_info_struct(png_ptr); if (!info_ptr) luaL_error(L, "[write_png_file] png_create_info_struct failed"); - - if (setjmp(png_jmpbuf(png_ptr))) - luaL_error(L, "[write_png_file] Error during init_io: %s", errmsg.str); - - png_init_io(png_ptr, fp); + + + /* create file */ + if(save_to_file) + { + fp = fopen(file_name, "wb"); + if (!fp) + luaL_error(L, "[write_png_file] File %s could not be opened for writing", file_name); + + if (setjmp(png_jmpbuf(png_ptr))) + luaL_error(L, "[write_png_file] Error during init_io: %s", errmsg.str); + png_init_io(png_ptr, fp); + } else { + _inmem.inmem=NULL; + _inmem.inmem_size=0; + png_set_write_fn(png_ptr, &_inmem, libpng_userWriteData, NULL); + } /* write header */ if (setjmp(png_jmpbuf(png_ptr))) @@ -304,8 +323,16 @@ static int libpng_(Main_save)(lua_State *L) free(row_pointers); /* cleanup */ - fclose(fp); + if(fp) fclose(fp); THTensor_(free)(tensorc); + + if (save_to_file == 0) { + + THByteTensor_resize1d(tensor_dest, _inmem.inmem_size); /* will fail if it's not a Byte Tensor */ + unsigned char* tensor_dest_data = THByteTensor_data(tensor_dest); + memcpy(tensor_dest_data, _inmem.inmem, _inmem.inmem_size); + free(_inmem.inmem); + } return 0; } diff --git a/init.lua b/init.lua index 24748cf..d1c5f2b 100644 --- a/init.lua +++ b/init.lua @@ -173,7 +173,8 @@ local function savePNG(filename, tensor) dok.error('libpng package not found, please install libpng','image.savePNG') end tensor = clampImage(tensor) - tensor.libpng.save(filename, tensor) + local save_to_file = 1 + tensor.libpng.save(filename, tensor, save_to_file) end rawset(image, 'savePNG', savePNG) @@ -203,6 +204,20 @@ function image.getPNGsize(filename) return torch.Tensor().libpng.size(filename) end +local function compressPNG(tensor) + if not xlua.require 'libpng' then + dok.error('libpng package not found, please install libpng', + 'image.compressPNG') + end + tensor = clampImage(tensor) + local b = torch.ByteTensor() + local save_to_file = 0 + tensor.libpng.save("", tensor, save_to_file, b) + return b +end +rawset(image, 'compressPNG', compressPNG) + + local function processJPG(img, depth, tensortype) local MAXVAL = 255 if tensortype ~= 'byte' then diff --git a/png.c b/png.c index e7a9d83..1e3f864 100644 --- a/png.c +++ b/png.c @@ -38,6 +38,25 @@ libpng_userReadData(png_structp pngPtrSrc, png_bytep dest, png_size_t length) src->offset += length; } + +struct libpng_inmem_write_struct +{ + unsigned char *inmem; /* destination memory (if saving to memory) */ + unsigned long inmem_size; /* destination memory size (bytes) */ +}; + +/* + * Call back for writing png data to memory + */ +static void libpng_userWriteData(png_structp png_ptr, png_bytep data, png_size_t length) { + struct libpng_inmem_write_struct *p = (struct libpng_inmem_write_struct*)png_get_io_ptr(png_ptr); + p->inmem=realloc(p->inmem,p->inmem_size+length); + memmove(p->inmem+p->inmem_size,data,length); + p->inmem_size+=length; +} + + + /* * Error message wrapper (single member struct to preserve `str` size info) */ diff --git a/test/test.lua b/test/test.lua index 0a25cd9..839900a 100644 --- a/test/test.lua +++ b/test/test.lua @@ -410,8 +410,21 @@ function test.CompressAndDecompress() img_compressed = image.compressJPG(img, quality) local size_25 = img_compressed:size(1) tester:assertlt(size_25, size_100, 'compressJPG quality setting error! ') + end +function test.CompressAndDecompressPNG() + local img = image.lena() + + local img_compressed_png = image.compressPNG(img) + local size_png = img_compressed_png:size(1) + local img_decompressed_png = image.decompressPNG(img_compressed_png) + local err_png = img_decompressed_png - img + local mean_err_png = err_png:mean() + local std_err_png = err_png:std() + tester:assertlt(mean_err_png, precision_mean, 'compressPNG error is too high! ') + tester:assertlt(std_err_png, precision_std, 'compressPNG error is too high! ') +end ---------------------------------------------------------------------- -- Lab conversion test -- These tests break if someone removes lena from the repo