Checkpoint averaging #668
Conversation
tl;dr: I am in favor of the output_dir/output_path option (with the default being a subdirectory of the original experiment directory). IMO it depends on what files the Saver.save method creates. I think that aside from the variables.data.* files, a new checkpoint index file (containing the default model name) is also created, which might cause confusion if the averaged model is saved in the same directory as the original models (or at least that's the behavior in t2t, which I do not prefer).
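For illustration, a sketch of what saving into a subdirectory might look like (the directory name and the experiment_dir/sess/saver variables here are hypothetical, not the script's actual code):

import os

# Hypothetical default: a subdirectory keeps the averaged model's
# 'checkpoint' index file separate from the original experiment's.
output_dir = os.path.join(experiment_dir, "averaged")
os.makedirs(output_dir, exist_ok=True)
saver.save(sess, os.path.join(output_dir, "averaged.ckpt"))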
Just minor things.
scripts/avg_checkpoints.py
Outdated
reader = tf.contrib.framework.load_checkpoint(checkpoint)
for name in var_values:
    tensor = reader.get_tensor(name)
    var_dtypes[name] = tensor.dtype
I'm not sure whether to add a check for compatible dtypes here (though it's possible it will crash anyway when they are not compatible).
What do you mean exactly? If they are not compatible in the checkpoint file, it will crash.
If that's really the case, then it's fine.
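A possible form of the check discussed above, as a sketch (the loop body and names come from the script; whether TensorFlow would fail on its own without it is not verified here):

    tensor = reader.get_tensor(name)
    # Hypothetical addition: fail early if a later checkpoint disagrees
    # on the dtype of a variable seen in an earlier checkpoint.
    if name in var_dtypes and var_dtypes[name] != tensor.dtype:
        raise TypeError(
            "Variable {} has dtype {} in checkpoint {}, but {} was seen "
            "before.".format(name, tensor.dtype, checkpoint,
                             var_dtypes[name]))
    var_dtypes[name] = tensor.dtype
    var_values[name] += tensor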
placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
global_step = tf.Variable(
    0, name="global_step", trainable=False, dtype=tf.int64)
I don't know, it feels safer to me this way. I don't want this global step to behave like a global step anyway (i.e. to be incremented with every optimizer call), so it's better to have it as an explicit variable like this.
And is it necessary to have it there at all? (Assuming that Neural Monkey will call tf.train.get_or_create_global_step(), it doesn't need to be loaded, right?)
If we don't have it there, we have to explicitly remove it from the list of variables that get loaded, otherwise it will crash. (And so far we only have a separate saver and loader in the gaussian estimator branch.)
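A sketch of what that explicit exclusion could look like (hypothetical names; this is not the actual Neural Monkey loader):

# Hypothetical sketch: restore everything except the global step, letting
# tf.train.get_or_create_global_step() create a fresh one afterwards.
vars_to_load = [v for v in tf.global_variables()
                if not v.name.startswith("global_step")]
loader = tf.train.Saver(var_list=vars_to_load)
loader.restore(sess, checkpoint_path)  # checkpoint_path: assumed variable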
scripts/avg_checkpoints.py
Outdated
# Build a model only with variables, set them to the average values.
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
I'd happily squash this into one line (and drop init_op).
OK
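The squashed one-liner (as it in fact appears later in this PR):

sess.run(tf.global_variables_initializer())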
scripts/avg_checkpoints.py
Outdated
@@ -0,0 +1,78 @@
#!/usr/bin/env python3
"""Script to average values of variables in a list of checkpoint files.
s/Script to average values of/Compute the average of each/
s/variables/variable/
scripts/avg_checkpoints.py
Outdated
#!/usr/bin/env python3
"""Script to average values of variables in a list of checkpoint files.

Given a list of model checkpoints, it generates a new checkpoint wiht
s/ht/th/
scripts/avg_checkpoints.py
Outdated
"""Script to average values of variables in a list of checkpoint files. | ||
|
||
Given a list of model checkpoints, it generates a new checkpoint wiht | ||
parameters which are an arithmetic average of them. |
s/an/the/
scripts/avg_checkpoints.py
Outdated
Given a list of model checkpoints, it generates a new checkpoint wiht
parameters which are an arithmetic average of them.

This is based on a script from Tensor2Tensor:
s/This is b/B/
scripts/avg_checkpoints.py
Outdated
", ".join(non_existing_chckpoints))) | ||
|
||
# Read variables from all checkpoints and average them. | ||
log("Get list of variables:") |
s/Get/Getting/;s/:/./
scripts/avg_checkpoints.py
Outdated
var_list = tf.contrib.framework.list_variables(args.checkpoints[0])
var_values, var_dtypes = {}, {}
for (name, shape) in var_list:
    if not name.startswith("global_step"):
Here I would make a global variable IGNORED_PATTERNS and match it with regexes.
It seems unnecessary to me, but fine, if it makes you happy.
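The suggested variant could look roughly like this (IGNORED_PATTERNS and its contents are hypothetical; the current code only ignores global_step):

import re

# Hypothetical module-level list of variable-name patterns to skip.
IGNORED_PATTERNS = [re.compile(r"^global_step")]

for (name, shape) in var_list:
    if any(p.match(name) for p in IGNORED_PATTERNS):
        continue
    # ... accumulate the variable exactly as the script already does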
scripts/avg_checkpoints.py
Outdated
    tensor = reader.get_tensor(name)
    var_dtypes[name] = tensor.dtype
    var_values[name] += tensor
log("Read from checkpoint {}".format(checkpoint))
This should be logged before you do it, i.e. which checkpoint you are reading from. Afterwards I would just log that it is done, or log nothing at all.
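I.e., roughly (a sketch reusing the script's own lines; the exact log wording is illustrative):

for checkpoint in args.checkpoints:
    log("Reading checkpoint {}".format(checkpoint))
    reader = tf.contrib.framework.load_checkpoint(checkpoint)
    for name in var_values:
        tensor = reader.get_tensor(name)
        var_dtypes[name] = tensor.dtype
        var_values[name] += tensor
log("Done.")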
scripts/avg_checkpoints.py
Outdated
assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
global_step = tf.Variable(
    0, name="global_step", trainable=False, dtype=tf.int64)
saver = tf.train.Saver(tf.all_variables())
Is that parameter mandatory?
I don't know. Without the parameter it saves "all saveable objects"; I don't know whether that is the same as tf.all_variables.
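For reference, a sketch of the parameterless form (whether the default set of "all saveable objects" coincides with tf.all_variables in this graph is an assumption, not verified here):

# Assumption: with no var_list, the Saver collects all saveable objects,
# which in a graph containing only variables should be the same set.
saver = tf.train.Saver()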
scripts/avg_checkpoints.py
Outdated
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for p, assign_op, (name, value) in zip(placeholders, assign_ops,
                                           six.iteritems(var_values)):
Why six.iteritems when it could be var_values.items() here?
(plus then drop import six)
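The drop-in Python 3 replacement, using the script's own names:

for p, assign_op, (name, value) in zip(placeholders, assign_ops,
                                       var_values.items()):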
Because that's how they have it in t2t.
And if they jumped out of a window in t2t, would you jump too? :-)
Wait, they already jumped out of the window?! O.o
Yep!
This is a script for checkpoint averaging, adapted from Tensor2Tensor.
The question for discussion: should it stay as it is, or should it take only one argument (the experiment directory) and generate the averaged checkpoint there?
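For concreteness, the single-argument variant could look roughly like this (a hypothetical argparse sketch; neither the argument name nor the default subdirectory has been decided):

import argparse
import os

parser = argparse.ArgumentParser(description="Average model checkpoints.")
parser.add_argument("experiment_dir",
                    help="experiment directory containing the checkpoints")
args = parser.parse_args()

# Hypothetical default output location inside the experiment directory.
output_path = os.path.join(args.experiment_dir, "avg", "averaged.ckpt")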