Permalink
Browse files

Make torch.Generator serializable

  • Loading branch information...
1 parent 1812606 commit b33715ccec92f1a9fb7318534a589e2ff20cb108 @kosklain kosklain committed Apr 19, 2016
Showing with 39 additions and 4 deletions.
  1. +25 −4 Generator.c
  2. +14 −0 test/test.lua
View
@@ -1,9 +1,5 @@
#include <general.h>
-static const struct luaL_Reg torch_Generator_table_ [] = {
- {NULL, NULL}
-};
-
int torch_Generator_new(lua_State *L)
{
THGenerator *gen = THGenerator_new();
@@ -18,6 +14,31 @@ int torch_Generator_free(lua_State *L)
return 0;
}
+static int torch_Generator_write(lua_State *L)
+{
+ THGenerator *gen = luaT_checkudata(L, 1, torch_Generator);
+ THFile *file = luaT_checkudata(L, 2, "torch.File");
+
+ THFile_writeByteRaw(file, (unsigned char *)gen, sizeof(THGenerator));
+ return 0;
+}
+
+static int torch_Generator_read(lua_State *L)
+{
+ THGenerator *gen = luaT_checkudata(L, 1, torch_Generator);
+ THFile *file = luaT_checkudata(L, 2, "torch.File");
+
+ THFile_readByteRaw(file, (unsigned char *)gen, sizeof(THGenerator));
+ return 0;
+}
+
+
+static const struct luaL_Reg torch_Generator_table_ [] = {
+ {"write", torch_Generator_write},
+ {"read", torch_Generator_read},
+ {NULL, NULL}
+};
+
#define torch_Generator_factory torch_Generator_new
void torch_Generator_init(lua_State *L)
View
@@ -2483,6 +2483,20 @@ function torchtest.RNGStateAliasing()
mytester:assertTensorEq(target_value, forked_value, 1e-16, "RNG has not forked correctly.")
end
+function torchtest.serializeGenerator()
+ local generator = torch.Generator()
+ torch.manualSeed(generator, 123)
+ local differentGenerator = torch.Generator()
+ torch.manualSeed(differentGenerator, 124)
+ local serializedGenerator = torch.serialize(generator)
+ local deserializedGenerator = torch.deserialize(serializedGenerator)
+ local generated = torch.random(generator)
+ local differentGenerated = torch.random(differentGenerator)
+ local deserializedGenerated = torch.random(deserializedGenerator)
+ mytester:asserteq(generated, deserializedGenerated, 'torch.Generator changed internal state after being serialized')
+ mytester:assertne(generated, differentGenerated, 'Generators with different random seed should not produce the same output')
+end
+
function torchtest.testBoxMullerState()
torch.manualSeed(123)
local odd_number = 101

0 comments on commit b33715c

Please sign in to comment.