Checkpoint averaging #668

Merged 3 commits into master from ckpt_avg on Mar 8, 2018
Conversation

jlibovicky (Contributor):

This is a script for checkpoint averaging, adapted from Tensor2Tensor.

The question for discussion is: should it remain like this, or should it get only one argument - the experiment directory - and generate the averaged checkpoint there?
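
(For context, the averaging itself is just an element-wise arithmetic mean of each variable across the given checkpoints. A minimal sketch of the idea in TF 1.x; the function and variable names here are illustrative, not the script's actual ones:)

import numpy as np
import tensorflow as tf

def average_checkpoints(checkpoints):
    # Collect variable names and shapes from the first checkpoint.
    var_list = tf.contrib.framework.list_variables(checkpoints[0])
    sums = {name: np.zeros(shape) for name, shape in var_list}
    # Accumulate each variable's value over all checkpoints...
    for checkpoint in checkpoints:
        reader = tf.contrib.framework.load_checkpoint(checkpoint)
        for name in sums:
            sums[name] += reader.get_tensor(name)
    # ...and divide by the number of checkpoints to get the mean.
    return {name: total / len(checkpoints) for name, total in sums.items()}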

@jlibovicky self-assigned this Mar 5, 2018
varisd (Member) commented Mar 5, 2018:

tl;dr: I am for the output_dir/output_path option (with default as a subdir of the original experiment dir).

IMO, it depends on what files are created by the Saver.save method. I think that aside from the variables.data.* files, a new checkpoint file (containing the default model name) is also created, which might cause confusion if the averaged model is saved in the same directory as the original models (or at least that's the behavior in t2t, which I do not prefer).

varisd (Member) left a review:

Just minor things.

reader = tf.contrib.framework.load_checkpoint(checkpoint)
for name in var_values:
    tensor = reader.get_tensor(name)
    var_dtypes[name] = tensor.dtype
Member: I wonder whether we should add a check here for compatible dtypes (though it may well crash anyway when they are not compatible).
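
(Such a guard could be a small extension of the loop above; a sketch, assuming var_dtypes holds the dtype recorded from the first checkpoint:)

tensor = reader.get_tensor(name)
# Hypothetical guard: fail early with a clear message instead of
# relying on the accumulation to crash on incompatible dtypes.
if name in var_dtypes and var_dtypes[name] != tensor.dtype:
    raise ValueError(
        "Variable '{}' has dtype {} in checkpoint {}, expected {}".format(
            name, tensor.dtype, checkpoint, var_dtypes[name]))
var_dtypes[name] = tensor.dtype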

jlibovicky (Author): What exactly do you mean? If they are not compatible within the checkpoint files, it will crash.

Member: If that's really the case, then fine.

placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
global_step = tf.Variable(
    0, name="global_step", trainable=False, dtype=tf.int64)
Member: Wouldn't it be better to use tf.train.get_or_create_global_step(), as @cifkao recommended in #634? (I'm not sure it fits this case, though.)

jlibovicky (Author): I don't know, this way seems safer to me. I don't want this global step to behave like a global step anyway (i.e., to increment with every optimizer call), so it's better to have it as an explicit variable like this.

Member: Do we even need to have it there at all? (Assuming Neural Monkey will call tf.train.get_or_create_global_step(), it doesn't need to be loaded, right?)

jlibovicky (Author): If we don't have it there, we have to explicitly remove it from the list of variables that get loaded, otherwise it will crash. (And a separate saver and loader we so far have only in the gaussian estimator branch.)
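
(The explicit filtering in question might look like this; a sketch, not code from the PR:)

# Drop global_step from the variables restored from a checkpoint, so a
# global step created by tf.train.get_or_create_global_step() is not
# clobbered and restoring does not fail on the missing variable.
loaded_vars = [
    (name, shape)
    for name, shape in tf.contrib.framework.list_variables(checkpoint)
    if not name.startswith("global_step")]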

# Build a model only with variables, set them to the average values.
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
Member: I'd happily squash this into one line (drop init_op).

jlibovicky (Author): OK

varisd previously approved these changes on Mar 7, 2018
@@ -0,0 +1,78 @@
#!/usr/bin/env python3
"""Script to average values of variables in a list of checkpoint files.
Member:

s/Script to average values of/Compute the average of each/
s/variables/variable/

#!/usr/bin/env python3
"""Script to average values of variables in a list of checkpoint files.

Given a list of model checkpoints, it generates a new checkpoint wiht
Member:

s/ht/th/

"""Script to average values of variables in a list of checkpoint files.

Given a list of model checkpoints, it generates a new checkpoint wiht
parameters which are an arithmetic average of them.
Member:

s/an/the/

Given a list of model checkpoints, it generates a new checkpoint wiht
parameters which are an arithmetic average of them.

This is based on a script from Tensor2Tensor:
Member:

s/This is b/B/

", ".join(non_existing_chckpoints)))

# Read variables from all checkpoints and average them.
log("Get list of variables:")
Member:

s/Get/Getting/;s/:/./

var_list = tf.contrib.framework.list_variables(args.checkpoints[0])
var_values, var_dtypes = {}, {}
for (name, shape) in var_list:
    if not name.startswith("global_step"):
Member: Here I would make a global variable IGNORED_PATTERNS and match the names with regexes.
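
(The suggestion, sketched; IGNORED_PATTERNS and is_ignored are hypothetical names:)

import re

# Hypothetical module-level constant with variable-name patterns to
# skip when averaging; currently only the global step.
IGNORED_PATTERNS = (re.compile(r"^global_step"),)

def is_ignored(name):
    # True if the variable name matches any ignored pattern.
    return any(pattern.match(name) for pattern in IGNORED_PATTERNS)

The loop condition above would then become "if not is_ignored(name):".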

jlibovicky (Author): Seems unnecessary to me, but fine if it makes you happy.

    tensor = reader.get_tensor(name)
    var_dtypes[name] = tensor.dtype
    var_values[name] += tensor
log("Read from checkpoint {}".format(checkpoint))
Member: You should have logged which checkpoint you are reading from before reading it, not after. Then I would just log that it's done, or not log anything at all.

assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
global_step = tf.Variable(
    0, name="global_step", trainable=False, dtype=tf.int64)
saver = tf.train.Saver(tf.all_variables())
Member: Is that parameter required?

jlibovicky (Author) commented Mar 8, 2018: I don't know. Without the parameter it saves "all saveable objects"; I don't know whether that's the same as tf.all_variables.

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for p, assign_op, (name, value) in zip(placeholders, assign_ops,
                                           six.iteritems(var_values)):
Member: Why six.iteritems when it could be var_values.items() here?

Member: (Plus then drop the import six.)

jlibovicky (Author): Because that's how they have it in t2t.

Member: And if t2t jumped out of a window, would you jump too? :-)

Member: Wait, they've already jumped out of a window?! O.o

jlibovicky (Author): Yep!

@jindrahelcl merged commit dfe8e09 into master on Mar 8, 2018
@jindrahelcl deleted the ckpt_avg branch on March 8, 2018 at 23:02