Skip to content

Commit

Permalink
Added code to compress tensor into PNG in-memory similar to compressJ…
Browse files Browse the repository at this point in the history
…PG, also added a test (#211)
  • Loading branch information
vfonov authored and soumith committed Mar 23, 2017
1 parent fc214c0 commit 705393f
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 14 deletions.
53 changes: 40 additions & 13 deletions generic/png.c
Expand Up @@ -198,11 +198,17 @@ static int libpng_(Main_load)(lua_State *L)
return 2;
}


static int libpng_(Main_save)(lua_State *L)
{
THTensor *tensor = luaT_checkudata(L, 2, torch_Tensor);
const char *file_name = luaL_checkstring(L, 1);

const int save_to_file = luaL_checkint(L, 3);

struct libpng_inmem_write_struct _inmem;

THByteTensor* tensor_dest = NULL;

int width=0, height=0;
png_byte color_type = 0;
png_byte bit_depth = 8;
Expand All @@ -211,11 +217,17 @@ static int libpng_(Main_save)(lua_State *L)
png_infop info_ptr;
png_bytep * row_pointers;
libpng_errmsg errmsg;
FILE *fp=NULL;

/* get dims and contiguous tensor */
THTensor *tensorc = THTensor_(newContiguous)(tensor);
real *tensor_data = THTensor_(data)(tensorc);
long depth=0;

if (save_to_file == 0) {
tensor_dest = luaT_checkudata(L, 4, "torch.ByteTensor");
}

if (tensorc->nDimension == 3) {
depth = tensorc->size[0];
height = tensorc->size[1];
Expand All @@ -234,27 +246,34 @@ static int libpng_(Main_save)(lua_State *L)
else if (depth == 3) color_type = PNG_COLOR_TYPE_RGB;
else if (depth == 1) color_type = PNG_COLOR_TYPE_GRAY;

/* create file */
FILE *fp = fopen(file_name, "wb");
if (!fp)
luaL_error(L, "[write_png_file] File %s could not be opened for writing", file_name);

/* initialize stuff */
png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL);

if (!png_ptr)
luaL_error(L, "[write_png_file] png_create_write_struct failed");

png_set_error_fn(png_ptr, &errmsg, libpng_error_fn, NULL);

info_ptr = png_create_info_struct(png_ptr);
if (!info_ptr)
luaL_error(L, "[write_png_file] png_create_info_struct failed");

if (setjmp(png_jmpbuf(png_ptr)))
luaL_error(L, "[write_png_file] Error during init_io: %s", errmsg.str);

png_init_io(png_ptr, fp);


/* create file */
if(save_to_file)
{
fp = fopen(file_name, "wb");
if (!fp)
luaL_error(L, "[write_png_file] File %s could not be opened for writing", file_name);

if (setjmp(png_jmpbuf(png_ptr)))
luaL_error(L, "[write_png_file] Error during init_io: %s", errmsg.str);
png_init_io(png_ptr, fp);
} else {
_inmem.inmem=NULL;
_inmem.inmem_size=0;
png_set_write_fn(png_ptr, &_inmem, libpng_userWriteData, NULL);
}

/* write header */
if (setjmp(png_jmpbuf(png_ptr)))
Expand Down Expand Up @@ -304,8 +323,16 @@ static int libpng_(Main_save)(lua_State *L)
free(row_pointers);

/* cleanup */
fclose(fp);
if(fp) fclose(fp);
THTensor_(free)(tensorc);

if (save_to_file == 0) {

THByteTensor_resize1d(tensor_dest, _inmem.inmem_size); /* will fail if it's not a Byte Tensor */
unsigned char* tensor_dest_data = THByteTensor_data(tensor_dest);
memcpy(tensor_dest_data, _inmem.inmem, _inmem.inmem_size);
free(_inmem.inmem);
}
return 0;
}

Expand Down
17 changes: 16 additions & 1 deletion init.lua
Expand Up @@ -173,7 +173,8 @@ local function savePNG(filename, tensor)
dok.error('libpng package not found, please install libpng','image.savePNG')
end
tensor = clampImage(tensor)
tensor.libpng.save(filename, tensor)
local save_to_file = 1
tensor.libpng.save(filename, tensor, save_to_file)
end
rawset(image, 'savePNG', savePNG)

Expand Down Expand Up @@ -203,6 +204,20 @@ function image.getPNGsize(filename)
return torch.Tensor().libpng.size(filename)
end

local function compressPNG(tensor)
if not xlua.require 'libpng' then
dok.error('libpng package not found, please install libpng',
'image.compressPNG')
end
tensor = clampImage(tensor)
local b = torch.ByteTensor()
local save_to_file = 0
tensor.libpng.save("", tensor, save_to_file, b)
return b
end
rawset(image, 'compressPNG', compressPNG)


local function processJPG(img, depth, tensortype)
local MAXVAL = 255
if tensortype ~= 'byte' then
Expand Down
19 changes: 19 additions & 0 deletions png.c
Expand Up @@ -38,6 +38,25 @@ libpng_userReadData(png_structp pngPtrSrc, png_bytep dest, png_size_t length)
src->offset += length;
}


struct libpng_inmem_write_struct
{
unsigned char *inmem; /* destination memory (if saving to memory) */
unsigned long inmem_size; /* destination memory size (bytes) */
};

/*
* Call back for writing png data to memory
*/
static void libpng_userWriteData(png_structp png_ptr, png_bytep data, png_size_t length) {
struct libpng_inmem_write_struct *p = (struct libpng_inmem_write_struct*)png_get_io_ptr(png_ptr);
p->inmem=realloc(p->inmem,p->inmem_size+length);
memmove(p->inmem+p->inmem_size,data,length);
p->inmem_size+=length;
}



/*
* Error message wrapper (single member struct to preserve `str` size info)
*/
Expand Down
13 changes: 13 additions & 0 deletions test/test.lua
Expand Up @@ -410,8 +410,21 @@ function test.CompressAndDecompress()
img_compressed = image.compressJPG(img, quality)
local size_25 = img_compressed:size(1)
tester:assertlt(size_25, size_100, 'compressJPG quality setting error! ')

end

function test.CompressAndDecompressPNG()
local img = image.lena()

local img_compressed_png = image.compressPNG(img)
local size_png = img_compressed_png:size(1)
local img_decompressed_png = image.decompressPNG(img_compressed_png)
local err_png = img_decompressed_png - img
local mean_err_png = err_png:mean()
local std_err_png = err_png:std()
tester:assertlt(mean_err_png, precision_mean, 'compressPNG error is too high! ')
tester:assertlt(std_err_png, precision_std, 'compressPNG error is too high! ')
end
----------------------------------------------------------------------
-- Lab conversion test
-- These tests break if someone removes lena from the repo
Expand Down

0 comments on commit 705393f

Please sign in to comment.