Permalink
455 lines (424 sloc) 15.7 KB
local File = torch.getmetatable('torch.File')
function File:writeBool(value)
if value then
self:writeInt(1)
else
self:writeInt(0)
end
end
function File:readBool()
return (self:readInt() == 1)
end
local TYPE_NIL = 0
local TYPE_NUMBER = 1
local TYPE_STRING = 2
local TYPE_TABLE = 3
local TYPE_TORCH = 4
local TYPE_BOOLEAN = 5
local TYPE_FUNCTION = 6
local TYPE_RECUR_FUNCTION = 8
local LEGACY_TYPE_RECUR_FUNCTION = 7
-- Lua 5.2 compatibility
local loadstring = loadstring or load
function File:isWritableObject(object)
local typename = type(object)
local typeidx
if type(object) ~= 'boolean' and not object then
typeidx = TYPE_NIL
elseif torch.typename(object) and torch.factory(torch.typename(object)) then
typeidx = TYPE_TORCH
elseif typename == 'table' then
typeidx = TYPE_TABLE
elseif typename == 'number' then
typeidx = TYPE_NUMBER
elseif typename == 'string' then
typeidx = TYPE_STRING
elseif typename == 'boolean' then
typeidx = TYPE_BOOLEAN
elseif typename == 'function' and pcall(string.dump, object) then
typeidx = TYPE_RECUR_FUNCTION
end
return typeidx
end
function File:referenced(ref)
-- we use an environment to keep a record of written objects
if not torch.getenv(self).writeObjects then
torch.setenv(self, {
writeObjects={}, writeObjectsRef={},
readObjects={},
objectNameStack={},
upvalueRefToId={}, upvalueIdToClosure={},
})
end
local env = torch.getenv(self)
env.force = not ref
torch.setenv(self,env)
return self
end
function File:isReferenced()
-- if no environment, then no forcing setup yet
if not torch.getenv(self).writeObjects then
return true
end
local env = torch.getenv(self)
return not env.force
end
local function getmetamethod(obj, name)
local func
local status
-- check getmetatable(obj).__name or
-- check getmetatable(obj).name
status, func = pcall(
function()
-- note that sometimes the metatable is hidden
-- we get it for sure through the torch type system
local mt = torch.getmetatable(torch.typename(obj))
if mt then
return mt['__' .. name] or mt[name]
end
end
)
if status and type(func) == 'function' then
return func
end
end
local UPVALUES_TOKEN = {} -- unique object
local function formatStack(objectNameStack)
-- Format object name stack skipping UPVALUES_TOKEN and upvalue index
local parts = {}
for i, v in ipairs(objectNameStack) do
if v ~= UPVALUES_TOKEN and objectNameStack[i-1] ~= UPVALUES_TOKEN then
table.insert(parts, v)
end
end
return table.concat(parts, '.')
end
function File:writeObject(object, debugname, hook)
-- define a default hook function if not provided
hook = hook or function(object) return object end
-- we use an environment to keep a record of written objects
if not torch.getenv(self).writeObjects then
torch.setenv(self, {
writeObjects={}, writeObjectsRef={},
readObjects={},
objectNameStack={},
upvalueRefToId={}, upvalueIdToClosure={},
})
end
-- That guy is used for references' book-keeping
local sobject = object
-- That guy is the object that is actually persisted
-- hook(object) can be used to modify the object before writing it to the file.
-- Useful for serializing objects under a config
-- that we want to deserialize safely under another config.
-- (e.g. Cuda to Float tensors, cudnn to nn, ...)
object = hook(object)
local force = torch.getenv(self).force
-- if nil object, only write the type and return
if type(object) ~= 'boolean' and not object then
self:writeInt(TYPE_NIL)
return
end
local objectNameStack = torch.getenv(self).objectNameStack
table.insert(objectNameStack, debugname or '<?>')
-- check the type we are dealing with
local typeidx = self:isWritableObject(object)
if not typeidx then
error(string.format('Unwritable object <%s> at %s', type(object), formatStack(objectNameStack)))
end
self:writeInt(typeidx)
if typeidx == TYPE_NUMBER then
self:writeDouble(object)
elseif typeidx == TYPE_BOOLEAN then
self:writeBool(object)
elseif typeidx == TYPE_STRING then
local stringStorage = torch.CharStorage():string(object)
self:writeInt(#stringStorage)
self:writeChar(stringStorage)
elseif typeidx == TYPE_TORCH or typeidx == TYPE_TABLE or typeidx == TYPE_RECUR_FUNCTION then
-- check it exists already (we look at the pointer!)
local objects = torch.getenv(self).writeObjects
local objectsRef = torch.getenv(self).writeObjectsRef
local index = objects[torch.pointer(sobject)]
if index and (not force) then
-- if already exists, write only its index
self:writeInt(index)
else
-- else write the object itself
index = objects.nWriteObject or 0
index = index + 1
if not force then
objects[torch.pointer(sobject)] = index
objectsRef[object] = index -- we make sure the object is not going to disappear
end
self:writeInt(index)
objects.nWriteObject = index
if typeidx == TYPE_RECUR_FUNCTION then
local upvalueRefToId = torch.getenv(self).upvalueRefToId
-- Unique ID for each ref since lightuserdata are not serializable
local nextId = 1
for _ in pairs(upvalueRefToId) do nextId=nextId+1 end
local upvalues = {}
local counter = 0
while true do
counter = counter + 1
local name,value = debug.getupvalue(object, counter)
if not name then break end
if name == '_ENV' then value = nil end
local id=nil
-- debug.upvalueid exists only for lua>=5.2 and luajit
if debug.upvalueid then
local upvalueRef = debug.upvalueid(object, counter)
if not upvalueRefToId[upvalueRef] then
upvalueRefToId[upvalueRef] = nextId
nextId = nextId + 1
end
id = upvalueRefToId[upvalueRef]
end
table.insert(upvalues, {name=name, id=id, value=value})
end
local dumped = string.dump(object)
local stringStorage = torch.CharStorage():string(dumped)
self:writeInt(#stringStorage)
self:writeChar(stringStorage)
self:writeObject(upvalues, UPVALUES_TOKEN, hook)
elseif typeidx == TYPE_TORCH then
local version = torch.CharStorage():string('V ' .. torch.version(object))
local className = torch.CharStorage():string(torch.typename(object))
self:writeInt(#version)
self:writeChar(version)
self:writeInt(#className)
self:writeChar(className)
local write = getmetamethod(object, 'write')
if write then
write(object, self)
elseif type(object) == 'table' then
local var = {}
for k,v in pairs(object) do
if self:isWritableObject(v) then
var[k] = v
else
print(string.format('$ Warning: cannot write object field <%s> of <%s> %s', k, torch.typename(object), formatStack(objectNameStack)))
end
end
self:writeObject(var, torch.typename(object), hook)
else
error(string.format('<%s> is a non-serializable Torch object %s', torch.typename(object), formatStack(objectNameStack)))
end
else -- it is a table
local size = 0; for k,v in pairs(object) do size = size + 1 end
self:writeInt(size)
for k,v in pairs(object) do
self:writeObject(k, nil, hook)
local name = (type(k) == 'string' or type(k) == 'number') and tostring(k) or nil
-- special case name for upvalues
if objectNameStack[#objectNameStack-1] == UPVALUES_TOKEN and
name == 'value' and type(object.name) == 'string' then
name = object.name
end
self:writeObject(v, name, hook)
end
end
end
else
error('Unwritable object')
end
table.remove(objectNameStack)
end
function File:readObject()
-- we use an environment to keep a record of read objects
if not torch.getenv(self).writeObjects then
torch.setenv(self, {
writeObjects={}, writeObjectsRef={},
readObjects={},
objectNameStack={},
upvalueRefToId={}, upvalueIdToClosure={},
})
end
local force = torch.getenv(self).force
-- read the typeidx
local typeidx = self:readInt()
-- is it nil?
if typeidx == TYPE_NIL then
return nil
end
if typeidx == TYPE_NUMBER then
return self:readDouble()
elseif typeidx == TYPE_BOOLEAN then
return self:readBool()
elseif typeidx == TYPE_STRING then
local size = self:readInt()
return self:readChar(size):string()
elseif typeidx == TYPE_FUNCTION then
local size = self:readInt()
local dumped = self:readChar(size):string()
local func, err = loadstring(dumped)
if not func then
io.stderr:write(string.format('Warning: Failed to load function from bytecode: %s', err))
end
local upvalues = self:readObject()
for index,upvalue in ipairs(upvalues) do
debug.setupvalue(func, index, upvalue)
end
return func
elseif typeidx == TYPE_TABLE or typeidx == TYPE_TORCH or typeidx == TYPE_RECUR_FUNCTION or typeidx == LEGACY_TYPE_RECUR_FUNCTION then
-- read the index
local index = self:readInt()
-- check it is loaded already
local objects = torch.getenv(self).readObjects
if objects[index] and not force then
return objects[index]
end
-- otherwise read it
if typeidx == TYPE_RECUR_FUNCTION or typeidx == LEGACY_TYPE_RECUR_FUNCTION then
local size = self:readInt()
local dumped = self:readChar(size):string()
local func, err = loadstring(dumped)
if not func then
io.stderr:write(string.format('Warning: Failed to load function from bytecode: %s', err))
end
if not force then
objects[index] = func
end
local upvalueIdToClosure = torch.getenv(self).upvalueIdToClosure
local upvalues = self:readObject()
for index,upvalue in ipairs(upvalues) do
if typeidx == LEGACY_TYPE_RECUR_FUNCTION then
debug.setupvalue(func, index, upvalue)
elseif upvalue.name == '_ENV' then
debug.setupvalue(func, index, _ENV)
else
debug.setupvalue(func, index, upvalue.value)
-- debug.upvaluejoin exists only for lua>=5.2 and luajit
if debug.upvaluejoin and upvalue.id then
if upvalueIdToClosure[upvalue.id] then
-- This upvalue is linked to another one
local otherClosure = upvalueIdToClosure[upvalue.id]
debug.upvaluejoin(func, index, otherClosure.func, otherClosure.index)
else
-- Save this closure for next time
upvalueIdToClosure[upvalue.id] = {
func = func,
index = index,
}
end
end
end
end
return func
elseif typeidx == TYPE_TORCH then
local version, className, versionNumber
version = self:readChar(self:readInt()):string()
versionNumber = tonumber(string.match(version, '^V (.*)$'))
if not versionNumber then
className = version
versionNumber = 0 -- file created before existence of versioning system
else
className = self:readChar(self:readInt()):string()
end
if not torch.factory(className) then
error(string.format('unknown Torch class <%s>', tostring(className)))
end
local object = torch.factory(className)(self)
if not force then
objects[index] = object
end
local read = getmetamethod(object, 'read')
if read then
read(object, self, versionNumber)
elseif type(object) == 'table' then
local var = self:readObject()
for k,v in pairs(var) do
object[k] = v
end
else
error(string.format('Cannot load object class <%s>', tostring(className)))
end
return object
else -- it is a table
local size = self:readInt()
local object = {}
if not force then
objects[index] = object
end
for i = 1,size do
local k = self:readObject()
local v = self:readObject()
object[k] = v
end
return object
end
else
error('unknown object')
end
end
-- simple helpers to save/load arbitrary objects/tables
function torch.save(filename, object, mode, referenced)
assert(mode == nil or mode == 'binary' or mode == 'ascii', '"binary" or "ascii" (or nil) expected for mode')
assert(referenced == nil or referenced == true or referenced == false, 'true or false (or nil) expected for referenced')
mode = mode or 'binary'
referenced = referenced == nil and true or referenced
local file = torch.DiskFile(filename, 'w')
file[mode](file)
file:referenced(referenced)
file:writeObject(object)
file:close()
end
function torch.load(filename, mode, referenced)
assert(mode == 'binary' or mode == 'b32' or mode == 'b64' or
mode == nil or mode == 'ascii',
'"binary", "b32", "b64" or "ascii" (or nil) expected for mode')
assert(referenced == nil or referenced == true or referenced == false,
'true or false (or nil) expected for referenced')
local longSize
if mode == 'b32' or mode == 'b64' then
longSize = tonumber(mode:match('%d+')) / 8
mode = 'binary'
end
mode = mode or 'binary'
referenced = referenced == nil and true or referenced
local file = torch.DiskFile(filename, 'r')
file[mode](file)
file:referenced(referenced)
if longSize then file:longSize(longSize) end
local object = file:readObject()
file:close()
return object
end
-- simple helpers to serialize/deserialize arbitrary objects/tables
function torch.serialize(object, mode)
local storage = torch.serializeToStorage(object, mode)
return storage:string()
end
-- Serialize to a CharStorage, not a lua string. This avoids
function torch.serializeToStorage(object, mode)
mode = mode or 'binary'
local f = torch.MemoryFile()
f = f[mode](f)
f:writeObject(object)
local storage = f:storage()
-- the storage includes an extra NULL character: get rid of it
storage:resize(storage:size()-1)
f:close()
return storage
end
function torch.deserializeFromStorage(storage, mode)
mode = mode or 'binary'
local tx = torch.CharTensor(storage)
local xp = torch.CharStorage(tx:size(1)+1)
local txp = torch.CharTensor(xp)
txp:narrow(1,1,tx:size(1)):copy(tx)
txp[tx:size(1)+1] = 0
local f = torch.MemoryFile(xp)
f = f[mode](f)
local object = f:readObject()
f:close()
return object
end
function torch.deserialize(str, mode)
local storage = torch.CharStorage():string(str)
return torch.deserializeFromStorage(storage, mode)
end
-- public API (saveobj/loadobj are safe for global import)
torch.saveobj = torch.save
torch.loadobj = torch.load