Skip to content

Commit

Permalink
Update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
zuoxingdong committed Sep 4, 2018
1 parent 0cc5c1f commit b672e6e
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 36 deletions.
131 changes: 95 additions & 36 deletions lagom/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,29 @@

from collections import OrderedDict

import numpy as np
from operator import itemgetter # get list elements with arbitrary indices

import pickle


class Logger(logging.Logger):
"""Logging information during experiment.
r"""Log the information of the experiment.
It supports iterative logging and dumping. That is, when same key is logged more than once,
the values for this key will be appended successively. During dumping, the user can also
choose to dump either the entire list of logged values or the values with specific index.
.. note::
It uses pickle to serialize the data. Empirically, pickle is 2x faster than ``numpy.save``;
other alternatives such as YAML are too slow, and JSON does not support numpy arrays.
.. warning::
It is highly discouraged to use hierarchical logging, e.g. list of dict of list of ndarray.
This is because pickling such a complex, large data structure is extremely slow. It is recommended
to use a dictionary at the topmost level.
Note that we do not support hierarchical logging, e.g. list of dict of list of dict of ndarray;
this is because pickling is extremely slow for such a hierarchical data structure with a mixture
of dict and ndarray. Thus, we keep dict always at the top, if hierarchical logging is really
Expand All @@ -22,12 +35,41 @@ class Logger(logging.Logger):
log policy loss, the hierarchical key can be combined into one string with ':' to separate each
level. For example, if we want to log the policy loss with goal number 34 and internal training iteration
20, the key can be 'goal_34:train:iter_20:policy_loss'.
Example::
>> logger = Logger(name='logger')
>> logger.log('iteration', 1)
>> logger.log('training_loss', 0.12)
>> logger.log('iteration', 2)
>> logger.log('training_loss', 0.11)
>> logger.log('iteration', 3)
>> logger.log('training_loss', 0.09)
>> logger.dump()
Iteration: [1, 2, 3]
Training Loss: [0.12, 0.11, 0.09]
>> logger.dump(keys=None, index=None, indent=1)
Iteration: [1, 2, 3]
Training Loss: [0.12, 0.11, 0.09]
>> logger.dump(keys=['iteration'], index=None, indent=0)
Iteration: [1, 2, 3]
>> logger.dump(keys=None, index=0, indent=0)
Iteration: 1
Training Loss: 0.12
>> logger.dump(keys=None, index=[0, 2], indent=0)
Iteration: [1, 3]
Training Loss: [0.12, 0.09]
"""
def __init__(self, name='logger'):
"""Initialize the Logger.
r"""Initialize the Logger.
Args:
name (str): name of the logger
name (str): name of the Logger
"""
super().__init__(name)

Expand All @@ -36,83 +78,100 @@ def __init__(self, name='logger'):
# Create logging dictionary, we use OrderedDict to keep insert ordering of the keys
self.logs = OrderedDict()

def log(self, key, value):
    r"""Log a value under the given key.

    Each key is associated with a list. The list is created the first time the
    key is used, and every later value logged under the same key is appended to it.

    It is recommended to name the key semantically with words separated by '_',
    because :meth:`dump` replaces every '_' with a whitespace and capitalizes each
    word with ``str.title()``.

    Args:
        key (str): key of the information
        value (object): value to be logged
    """
    # setdefault creates the list on first use and returns the existing one afterwards
    self.logs.setdefault(key, []).append(value)

def dump(self, keys=None, index=None, indent=0):
    r"""Dump the loggings to the screen.

    Each selected key is printed on its own line as ``'<Title>: <data>'``, where the
    title is the key with '_' replaced by a whitespace and each word capitalized.

    Args:
        keys (list, optional): the list of selected keys to dump. If ``None``, then all
            keys will be used. Default: ``None``
        index (int/list, optional): the index in the list of each logged key to dump.
            If an ``int``, then dumps the single element at that index for every key;
            it can also be -1 to indicate the last element in the list. If a ``list``,
            then dumps a list with the elements at the given indices for every key.
            If ``None``, then dumps everything for all given keys. Default: ``None``
        indent (int, optional): the number of tab indentation before dumping the
            information. Default: 0

    Raises:
        TypeError: if ``index`` is neither ``None``, an ``int`` nor a ``list``.
    """
    if keys is None:  # dump all keys
        keys = list(self.logs.keys())
    assert isinstance(keys, list), f'expected list dtype, got {type(keys)}'

    # A non-positive indent yields an empty prefix, i.e. no indentation at all
    prefix = '\t'*indent

    for key in keys:
        values = self.logs[key]

        # Select the logged data according to the requested index
        if index is None:  # everything
            log_data = values
        elif isinstance(index, int):  # single index
            log_data = values[index]
        elif isinstance(index, list):  # specific indices
            # Plain indexing keeps the result a list even for one or zero indices,
            # unlike operator.itemgetter which returns a bare item for a single index
            log_data = [values[i] for i in index]
        else:
            raise TypeError(f'expected index to be None, int or list, got {type(index)}')

        # Polish key string and make it visually beautiful
        title = key.strip().replace('_', ' ').title()

        print(f'{prefix}{title}: {log_data}')

def save(self, file):
    r"""Save loggings to a file using pickling.

    The whole logging dictionary is serialized with the highest pickle protocol
    available in the running interpreter.

    Args:
        file (str): path to save the logged information.
    """
    with open(file, 'wb') as f:
        pickle.dump(obj=self.logs, file=f, protocol=pickle.HIGHEST_PROTOCOL)

@staticmethod
def load(file):
r"""Load loggings from a file using pickling.
Returns
-------
logging : OrderedDict
Loaded logging dictionary
"""
with open(file, 'rb') as f:
logging = pickle.load(f)

return logging

def clear(self):
    r"""Remove all loggings in the dictionary.

    The dictionary is emptied in place, so ``self.logs`` stays the same object.
    """
    self.logs.clear()

def __repr__(self):
Expand Down
54 changes: 54 additions & 0 deletions test/test_lagom.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

import numpy as np

import os

from lagom import Seeder
from lagom import Logger


class TestLagom(object):
Expand All @@ -20,3 +23,54 @@ def test_utils(self):
assert np.alltrue(np.array(seeds).shape == (1, 3))
seeds = seeder(size=[2, 3])
assert np.alltrue(np.array(seeds).shape == (2, 3))

def test_logger(self):
    """Log three iterations, then check dumping and the save/load round trip."""
    import tempfile  # function-scope import: only this test needs it

    logger = Logger(name='logger')

    # One tuple per iteration: (iteration, learning_rate, training_loss, evaluation_loss)
    records = [(1, 1e-3, 0.12, 0.14),
               (2, 5e-4, 0.11, 0.13),
               (3, 1e-4, 0.09, 0.10)]
    for iteration, learning_rate, training_loss, evaluation_loss in records:
        logger.log('iteration', iteration)
        logger.log('learning_rate', learning_rate)
        logger.log('training_loss', training_loss)
        logger.log('evaluation_loss', evaluation_loss)

    # dump() only prints, so these calls just check that nothing raises
    logger.dump()
    logger.dump(keys=None, index=None, indent=1)
    logger.dump(keys=None, index=None, indent=2)
    logger.dump(keys=['iteration', 'evaluation_loss'], index=None, indent=0)
    logger.dump(keys=None, index=0, indent=0)
    logger.dump(keys=None, index=2, indent=0)
    logger.dump(keys=None, index=[0, 2], indent=0)
    logger.dump(keys=['iteration', 'training_loss'], index=[0, 2], indent=0)

    # Save into a temporary directory so the file is removed even if an assertion
    # fails, and nothing is written into the current working directory
    with tempfile.TemporaryDirectory() as tmp_dir:
        file = os.path.join(tmp_dir, 'test_logger_file')
        logger.save(file=file)

        assert os.path.exists(file)

        loaded = Logger.load(file)

    assert len(loaded) == 4
    assert 'iteration' in loaded
    assert 'learning_rate' in loaded
    assert 'training_loss' in loaded
    assert 'evaluation_loss' in loaded

    assert np.allclose(loaded['iteration'], [1, 2, 3])
    assert np.allclose(loaded['learning_rate'], [1e-3, 5e-4, 1e-4])
    assert np.allclose(loaded['training_loss'], [0.12, 0.11, 0.09])
    assert np.allclose(loaded['evaluation_loss'], [0.14, 0.13, 0.1])

0 comments on commit b672e6e

Please sign in to comment.