From 198380ef3dcb3281c18f6fd7267cfcf66e039f42 Mon Sep 17 00:00:00 2001 From: seba-1511 Date: Mon, 10 Jun 2019 17:42:58 -0700 Subject: [PATCH] Open-source release. --- .gitignore | 104 ++++++++++++ LICENSE | 201 ++++++++++++++++++++++ Makefile | 13 ++ README.md | 34 ++++ derivations/exp.svg | 9 + derivations/heavyball_igt.svg | 17 ++ derivations/igt.svg | 16 ++ derivations/nesterov_igt.svg | 16 ++ setup.py | 21 +++ tests/test_adam_igt.py | 87 ++++++++++ tests/test_exp_wrapper.py | 49 ++++++ tests/test_heavyball_igt.py | 95 +++++++++++ tests/test_igt.py | 82 +++++++++ tests/test_igt_wrapper.py | 64 +++++++ tests/test_ncigt.py | 113 +++++++++++++ tests/test_nesterov_igt.py | 96 +++++++++++ tests/test_wrapper_cnn.py | 87 ++++++++++ torch_igt/__init__.py | 7 + torch_igt/adam_igt.py | 115 +++++++++++++ torch_igt/igt.py | 213 ++++++++++++++++++++++++ torch_igt/non_convex_wrapper.py | 94 +++++++++++ torch_igt/transported_exp_ata.py | 101 ++++++++++++ torch_igt/wrapper.py | 275 +++++++++++++++++++++++++++++++ 23 files changed, 1909 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 README.md create mode 100644 derivations/exp.svg create mode 100644 derivations/heavyball_igt.svg create mode 100644 derivations/igt.svg create mode 100644 derivations/nesterov_igt.svg create mode 100644 setup.py create mode 100644 tests/test_adam_igt.py create mode 100644 tests/test_exp_wrapper.py create mode 100644 tests/test_heavyball_igt.py create mode 100644 tests/test_igt.py create mode 100644 tests/test_igt_wrapper.py create mode 100644 tests/test_ncigt.py create mode 100644 tests/test_nesterov_igt.py create mode 100644 tests/test_wrapper_cnn.py create mode 100644 torch_igt/__init__.py create mode 100644 torch_igt/adam_igt.py create mode 100644 torch_igt/igt.py create mode 100644 torch_igt/non_convex_wrapper.py create mode 100644 torch_igt/transported_exp_ata.py create mode 100644 torch_igt/wrapper.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..894a44c --- /dev/null +++ b/.gitignore @@ -0,0 +1,104 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..b0a2e1a --- /dev/null +++ b/Makefile @@ -0,0 +1,13 @@ + +.PHONY: all + +all: + python tests/test_ncigt.py + +test: + python tests/test_igt.py + python tests/test_heavyball_igt.py + python tests/test_nesterov_igt.py + python tests/test_igt_wrapper.py + python tests/test_adam_igt.py + python tests/test_ncigt.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..8f71b08 --- /dev/null +++ b/README.md @@ -0,0 +1,34 @@ +# Implicit Gradient Transport in PyTorch + +# Installation + +``` +pip install -e . +``` + +# Usage + +See `tests/` folder for more examples. + +```python +import torch.optim as optim +from torch_igt import IGTransporter + +opt = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) +opt = IGTransporter(model.parameters(), opt) + +# Compute a single optimization step +opt.train() # Ensures parameters are set to the transported ones +loss = L(model(X_train), y_train) +opt.zero_grad() +loss.backward() +opt.step() + +# Reverts parameters to the true ones +opt.eval() +loss = L(model(X_test), y_test) +``` + +# Note + +The ITA family of algorithms (such as `Heavyball-ITA` in the paper) are implemented as `torch_igt.ITA(params, opt)`. diff --git a/derivations/exp.svg b/derivations/exp.svg new file mode 100644 index 0000000..ca9289b --- /dev/null +++ b/derivations/exp.svg @@ -0,0 +1,9 @@ + Atif, x. IT, f-t; to.i +Un-ta, th-↳ gin +XnF Xn-d. Un m +reset Vn=o it Fm: n=¥gzi=Z't I +vey!,3#KY +xi-foo:[rat +vi. artist +* Edition't:* +. \ No newline at end of file diff --git a/derivations/heavyball_igt.svg b/derivations/heavyball_igt.svg new file mode 100644 index 0000000..3fbed3a --- /dev/null +++ b/derivations/heavyball_igt.svg @@ -0,0 +1,17 @@ + Orft:], a-0.1, µ--0.5 +Heagb±=GF +a#¥ 2,] +↳ =-o.IE?3=Eooi4­o.=I3+EoMEoo:p +Y¥at#ti:#a:÷tt⇒ +4=0.5%+0. 5g,-[bit] +Has.tk#owteoK=E?Id +⇐ A + affiliate, ftp.tm +* 2¥ otE¥trEhN=aEhD±Eld +voted#E¥iE÷d +went:*-oieaooftoeoatfiioiffoi:D +ostztwz-Eo¥ftf:?µY HIV +Y¥aEo*itsE¥HtE¥p¥:p +us#⇒ HE:#too:iD +wrote:#tank:#Eisley +← Emilio:#fiiiiid +Ayt \ No newline at end of file diff --git a/derivations/igt.svg b/derivations/igt.svg new file mode 100644 index 0000000..0730acd --- /dev/null +++ b/derivations/igt.svg @@ -0,0 +1,16 @@ + QED, A-Eo!], Va igtwkity +T¥. oh±omYr;:D:b. +Vo=QQ=# +Q=Q-au.-Httottfoofif +HE A +v. ⇒ v. 
+ ltshgtatftgtoiad) +* all:¥e÷tEaHM +-all::H*t¥t +→ vi. as#+ AHHHH +ha-oivrfieaffoitiaffotebf +YEIE:#+ ¥st¥:D) +-oiled:D-8%3)=aEo:D +¥:* ÷¥h#Kaftans +at:*f8:atE¥D +attests.AE#t3f5kYsD +TEN fuxtptsrrhdj \ No newline at end of file diff --git a/derivations/nesterov_igt.svg b/derivations/nesterov_igt.svg new file mode 100644 index 0000000..550e41f --- /dev/null +++ b/derivations/nesterov_igt.svg @@ -0,0 +1,16 @@ + Nest.ua#:IIEtEoI:iQ=Egg, += * glotttlorot,) d=Q1 +WYETH.tt#Hs+f=o.5 +↳ =f¢=µ &"#tntttknlw +w.=-hoo-Ek +a#heat:p +Hate AIQHHHW.to#9f)f8.Ff­u.O.5LY+a5EFt=Ysst­w.=o.sEiztttKsst=Eo?Yf­Q=Q-nw.+H+hwi=Eo?ssfo.5fojpytt5EoFst­t::nsh e-Q(Q+2fnuo+H+Hw,)) +Atkin.tt#EoIttsEotMtaEsoEy +to:p +u⇐}kkt+5Esbnp=E:# +wrote:#t.ie?kaF=Eo?EzI­Q=A-mu+H+Hwz=EoHYtsta5Eo?sst+t5E?E;] +HEY'd i +test-Atfststuwrltfhwd) +oµ¥*H*¥.ms#Ei:nsDtaEia:fEiii59 +¥EK:#THE:D-Ee;D +wso'E0o?IFF.at#Ihrst=EoiYusaI­Q=O3tklHHw=EYFars3o.5EYFzsftt5E8YIfsasf­=&?hYoau5tI \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..7df007f --- /dev/null +++ b/setup.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +from setuptools import ( + setup, + find_packages, + ) + +VERSION = '0.0.1' + +setup( + name='torch_igt', + packages=find_packages(), + version=VERSION, + description='PyTorch implementation of Accelerated Implicit Gradient Transport.', + author='Anonymous', + author_email='ano@nymous.com', + url = 'https://github.com/', + license='License :: OSI Approved :: Apache Software License', + classifiers=[], + scripts=[], +) diff --git a/tests/test_adam_igt.py b/tests/test_adam_igt.py new file mode 100644 index 0000000..b280359 --- /dev/null +++ b/tests/test_adam_igt.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 + +import torch as th +import torch.nn as nn +from torch import Tensor as T +from torch.autograd import Variable as V +import torch.nn.functional as F + +import torch_igt + + +class Convnet(nn.Module): + def __init__(self): + super(Convnet, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2(x), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + +def dist(x, y): + return (x - y).pow(2).sum() + + +def close(x, y): + return dist(x, y) < 1e-8 + + +if __name__ == '__main__': + th.manual_seed(1234) + model1 = Convnet().double() + model2 = Convnet().double() + for p1, p2 in zip(model1.parameters(), model2.parameters()): + p1.data.copy_(p2.data) + + ref = torch_igt.AdamIGT(model1.parameters(), lr=0.001) + opt = th.optim.Adam(model2.parameters(), lr=0.001) + igt = torch_igt.IGTransporter(model2.parameters(), opt) + + x = V(th.randn(3, 1, 28, 28).double(), requires_grad=False) + + for i in range(100): + for p1, p2 in zip(model1.parameters(), model2.parameters()): + assert close(p1.data, p2.data) + + # Compute reference gradients + ref.train() + ref.zero_grad() + loss1 = model1.forward(x).pow(2).mean() + loss1.backward() + + # Compute wrapper gradients + igt.train() + igt.zero_grad() + loss2 = model2.forward(x).pow(2).mean() + loss2.backward() + + assert close(loss1.data, loss2.data) + + # Test identical gradients + for p1, p2 in zip(model1.parameters(), model2.parameters()): + assert close(p1.grad.data, p2.grad.data) + + # Take on step + ref.step() + igt.step() + + # Test identical parameters (train) + for p1, p2 in 
zip(model1.parameters(), model2.parameters()): + assert close(p1.data, p2.data) + + # Test identical parameters (eval) + ref.eval() + igt.eval() + for p1, p2 in zip(model1.parameters(), model2.parameters()): + assert close(p1.data, p2.data) + + ref.train() + igt.train() diff --git a/tests/test_exp_wrapper.py b/tests/test_exp_wrapper.py new file mode 100644 index 0000000..d2d7218 --- /dev/null +++ b/tests/test_exp_wrapper.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 + +import torch as th +import torch.nn as nn +from torch import Tensor as T +from torch.autograd import Variable as V + +import torch_igt + + +class Vector(nn.Module): + def __init__(self, init): + super(Vector, self).__init__() + self.vector = nn.Parameter(init) + + def forward(self): + return self.vector + + +def dist(x, y): + return (x - y).pow(2).sum() + + +def close(x, y): + return (x - y).pow(2).sum() < 1e-8 + + +if __name__ == '__main__': + H = V(T([[2, 0], [0, 1]])) + model = Vector(T([1, 1]).view(2, 1)) + opt = th.optim.SGD(model.parameters(), lr=0.1, + momentum=0.0, weight_decay=0.0) + opt = torch_igt.Exp(model.parameters(), opt=opt, delta=1) + + xs = [ + T([1, 1]).view(2, 1), + T([0.8, 0.9]).view(2, 1), + T([0.64, 0.81]).view(2, 1), + T([0.496, 0.7245]).view(2, 1), + T([0.3968, 0.65205]).view(2, 1), + T([0.30752, 0.5832225]).view(2, 1), + ] + for i in range(6): + # Compute one step on the reference + assert close(xs[i], model.vector.data) + opt.zero_grad() + loss = 0.5 * th.mm(model().t(), th.mm(H, model())) + loss.backward() + opt.step() diff --git a/tests/test_heavyball_igt.py b/tests/test_heavyball_igt.py new file mode 100644 index 0000000..e65b3c0 --- /dev/null +++ b/tests/test_heavyball_igt.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 + +import torch as th +import torch.nn as nn +from torch import Tensor as T +from torch.autograd import Variable as V + +import torch_igt + +""" +The reference values were hand-derived for this toy example, +following the algorithm from the paper. +(C.f. 
./derivations/heavyball_igt.svg) +""" + + +class Vector(nn.Module): + def __init__(self, init): + super(Vector, self).__init__() + self.vector = nn.Parameter(init) + + def forward(self): + return self.vector + + +def close(x, y): + return (x - y).pow(2).sum() < 1e-8 + + +reference = [ + + { + 'transported_grad': T([2, 1]).view(2, 1), + 'true_param': T([0.8, 0.9]).view(2, 1), + 'transported_param': T([0.6, 0.8]).view(2, 1), + 'igt_velocity': T([2, 1]).view(2, 1), + 'momentum_velocity': T([-0.2, -0.1]).view(2, 1), + }, + + { + 'transported_grad': T([1.2, 0.8]).view(2, 1), + 'true_param': T([0.54, 0.76]).view(2, 1), + 'transported_param': T([0.02, 0.48]).view(2, 1), + 'igt_velocity': T([1.6, 0.9]).view(2, 1), + 'momentum_velocity': T([-0.26, -0.14]).view(2, 1), + }, + { + 'transported_grad': T([0.04, 0.48]).view(2, 1), + 'true_param': T([0.302, 0.614]).view(2, 1), + 'transported_param': T([-0.412, 0.176]).view(2, 1), + 'igt_velocity': T([1.08, 0.76]).view(2, 1), + 'momentum_velocity': T([-0.238, -0.146]).view(2, 1), + }, + { + 'transported_grad': T([-0.824, 0.176]).view(2, 1), + 'true_param': T([0.1226, 0.4796]).view(2, 1), + 'transported_param': T([-0.595, -0.058]).view(2, 1), + 'igt_velocity': T([0.604, 0.614]).view(2, 1), + 'momentum_velocity': T([-0.1794, -0.1344]).view(2, 1), + }, + +] + + +if __name__ == '__main__': + model = Vector(th.ones(2, 1)) + opt = torch_igt.MomentumIGT(model.parameters(), lr=0.1, momentum=0.5) + + H = V(T([[2, 0], [0, 1]])) + for i in range(4): + opt.train() + opt.zero_grad() + loss = 0.5 * th.mm(model().t(), + th.mm(H, model())) + loss.backward() + assert(close(model.vector.grad.data, + reference[i]['transported_grad'])) + + opt.step() + params = opt.state[model.vector] + assert(close(params['igt_velocity'], + reference[i]['igt_velocity'])) + # Adjust for Pytorch's (equivalent) way of computing momentum + assert(close(params['momentum_velocity'] * -0.1, + reference[i]['momentum_velocity'])) + assert(close(model().data, + reference[i]['transported_param'])) + + opt.eval() + assert(close(model().data, + reference[i]['true_param'])) + + opt.train() + assert(close(model().data, + reference[i]['transported_param'])) diff --git a/tests/test_igt.py b/tests/test_igt.py new file mode 100644 index 0000000..da3c820 --- /dev/null +++ b/tests/test_igt.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 + +import torch as th +import torch.nn as nn +from torch import Tensor as T +from torch.autograd import Variable as V + +import torch_igt + +""" +The reference values were hand-derived for this toy example, +following the algorithm from the paper. +(C.f. 
./derivations/igt.svg) +""" + + +class Vector(nn.Module): + def __init__(self, init): + super(Vector, self).__init__() + self.vector = nn.Parameter(init) + + def forward(self): + return self.vector + + +def close(x, y): + return (x - y).pow(2).sum() < 1e-8 + + +reference = [ + + { + 'transported_grad': T([2, 1]).view(2, 1), + 'true_param': T([0.8, 0.9]).view(2, 1), + 'transported_param': T([0.6, 0.8]).view(2, 1), + 'igt_velocity': T([2, 1]).view(2, 1), + }, + + { + 'transported_grad': T([1.2, 0.8]).view(2, 1), + 'true_param': T([0.64, 0.81]).view(2, 1), + 'transported_param': T([0.32, 0.63]).view(2, 1), + 'igt_velocity': T([1.6, 0.9]).view(2, 1), + }, + { + 'transported_grad': T([0.64, 0.63]).view(2, 1), + 'true_param': T([0.512, 0.729]).view(2, 1), + 'transported_param': T([0.128, 0.486]).view(2, 1), + 'igt_velocity': T([1.28, 0.81]).view(2, 1), + }, + +] + + +if __name__ == '__main__': + model = Vector(th.ones(2, 1)) + opt = torch_igt.MomentumIGT(model.parameters(), lr=0.1, momentum=0.0) + + H = V(T([[2, 0], [0, 1]])) + for i in range(3): + opt.train() + opt.zero_grad() + loss = 0.5 * th.mm(model().t(), + th.mm(H, model())) + loss.backward() + assert(close(model.vector.grad.data, + reference[i]['transported_grad'])) + + opt.step() + params = opt.state[model.vector] + assert(close(params['igt_velocity'], + reference[i]['igt_velocity'])) + assert(close(model().data, + reference[i]['transported_param'])) + + opt.eval() + assert(close(model().data, + reference[i]['true_param'])) + + opt.train() + assert(close(model().data, + reference[i]['transported_param'])) diff --git a/tests/test_igt_wrapper.py b/tests/test_igt_wrapper.py new file mode 100644 index 0000000..2aeef87 --- /dev/null +++ b/tests/test_igt_wrapper.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 + +import torch as th +import torch.nn as nn +from torch import Tensor as T +from torch.autograd import Variable as V + +import torch_igt + + +class Vector(nn.Module): + def __init__(self, init): + super(Vector, self).__init__() + self.vector = nn.Parameter(init) + + def forward(self): + return self.vector + + +def dist(x, y): + return (x - y).pow(2).sum() + + +def close(x, y): + return (x - y).pow(2).sum() < 1e-8 + + +if __name__ == '__main__': + H = V(T([[2, 0], [0, 1]])) + model1 = Vector(th.ones(2, 1)) + model2 = Vector(th.ones(2, 1)) + + ref = torch_igt.MomentumIGT(model1.parameters(), lr=0.1, momentum=0.5) + opt = th.optim.SGD(model2.parameters(), lr=0.1, momentum=0.5) + igt = torch_igt.IGTransporter(model2.parameters(), opt) + + for i in range(100): + # Compute one step on the reference + ref.train() + ref.zero_grad() + loss1 = 0.5 * th.mm(model1().t(), + th.mm(H, model1())) + loss1.backward() + ref.step() + + # Compute 1 step on the wrapper + igt.train() + igt.zero_grad() + loss2 = 0.5 * th.mm(model2().t(), + th.mm(H, model2())) + loss2.backward() + igt.step() + + # Test identical parameters (train and eval) + for p1, p2 in zip(model1.parameters(), model2.parameters()): + assert close(p1.data, p2.data) + + ref.eval() + igt.eval() + for p1, p2 in zip(model1.parameters(), model2.parameters()): + assert close(p1.data, p2.data) + + ref.train() + igt.train() diff --git a/tests/test_ncigt.py b/tests/test_ncigt.py new file mode 100644 index 0000000..83f7fce --- /dev/null +++ b/tests/test_ncigt.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 + +import torch as th +import torch.nn as nn +import torch.optim as optim +from torch import Tensor as T +from torch.autograd import Variable as V + +import torch_igt + +""" +The reference values were 
hand-derived for this toy example, +following the algorithm from the paper. +(C.f. ./derivations/ncigt.pdf) +""" + + +class Vector(nn.Module): + def __init__(self, init): + super(Vector, self).__init__() + self.vector = nn.Parameter(init) + + def forward(self): + return self.vector + + +def close(x, y): + return (x - y).pow(2).sum() < 1e-8 + + +reference = [ + + {}, # In NC-IGT, t is init at 1 + { + 'theta': T([1, 1]).view(2, 1), + 'theta_tilde': T([1, 1]).view(2, 1), + 'g_tilde': T([2, 1]).view(2, 1), + 'g_hat': T([2, 1]).view(2, 1), + }, + { + 'theta': T([0.8, 0.9]).view(2, 1), + 'theta_tilde': T([0.8, 0.9]).view(2, 1), + 'g_tilde': T([1.6, 0.9]).view(2, 1), + 'g_hat': T([1.6, 0.9]).view(2, 1), + }, + { + 'theta': T([0.64, 0.81]).view(2, 1), + 'theta_tilde': T([0.56, 0.765]).view(2, 1), + 'g_tilde': T([1.12, 0.765]).view(2, 1), + 'g_hat': T([1.28, 0.81]).view(2, 1), + }, + { + 'theta': T([0.512, 0.729]).view(2, 1), + 'theta_tilde': T([0.464, 0.693]).view(2, 1), + 'g_tilde': T([0.928, 0.693]).view(2, 1), + 'g_hat': T([1.024, 0.729]).view(2, 1), + }, + { + 'theta': T([0.4096, 0.6561]).view(2, 1), + 'theta_tilde': T([0.256, 0.54675]).view(2, 1), + 'g_tilde': T([0.512, 0.54675]).view(2, 1), + 'g_hat': T([0.8192, 0.6561]).view(2, 1), + }, + { + 'theta': T([0.32768, 0.59049]).view(2, 1), + 'theta_tilde': T([0.21504, 0.49572]).view(2, 1), + 'g_tilde': T([0.43008, 0.49572]).view(2, 1), + 'g_hat': T([0.65536, 0.59049]).view(2, 1), + }, + { + 'theta': T([0.262144, 0.531441]).view(2, 1), + 'theta_tilde': T([0.1905, 0.4531]).view(2, 1), + 'g_tilde': None, + 'g_hat': None, + }, + + +] + + +if __name__ == '__main__': + model = Vector(th.ones(2, 1)) + opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.0) + opt = torch_igt.NCIGT(model.parameters(), opt=opt) + + H = V(T([[2, 0], [0, 1]])) + for t in range(1, 1+6): + opt.train() + assert close(model.vector.data, + reference[t]['theta_tilde']) + opt.zero_grad() + loss = 0.5 * th.mm(model().t(), + th.mm(H, model())) + loss.backward() + assert close(model.vector.data, + reference[t]['theta_tilde']) + assert close(model.vector.grad.data, + reference[t]['g_tilde']) + + opt.step() + assert close(model.vector.grad.data, + reference[t]['g_hat']) + + assert close(model.vector.data, + reference[t+1]['theta_tilde']) + opt.eval() + assert close(model.vector.data, + reference[t+1]['theta']) + + opt.train() + assert close(model.vector.data, + reference[t+1]['theta_tilde']) + diff --git a/tests/test_nesterov_igt.py b/tests/test_nesterov_igt.py new file mode 100644 index 0000000..bd7804a --- /dev/null +++ b/tests/test_nesterov_igt.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 + +import torch as th +import torch.nn as nn +from torch import Tensor as T +from torch.autograd import Variable as V + +import torch_igt + +""" +The reference values were hand-derived for this toy example, +following the algorithm from the paper. +(C.f. 
./derivations/nesterov_igt.svg) +""" + + +class Vector(nn.Module): + def __init__(self, init): + super(Vector, self).__init__() + self.vector = nn.Parameter(init) + + def forward(self): + return self.vector + + +def close(x, y): + return (x - y).pow(2).sum() < 1e-8 + + +reference = [ + + { + 'transported_grad': T([2, 1]).view(2, 1), + 'true_param': T([0.7, 0.85]).view(2, 1), + 'transported_param': T([0.4, 0.7]).view(2, 1), + 'igt_velocity': T([2, 1]).view(2, 1), + 'momentum_velocity': T([-0.2, -0.1]).view(2, 1), + }, + + { + 'transported_grad': T([0.8, 0.7]).view(2, 1), + 'true_param': T([0.44, 0.6975]).view(2, 1), + 'transported_param': T([-0.08, 0.3925]).view(2, 1), + 'igt_velocity': T([1.4, 0.85]).view(2, 1), + 'momentum_velocity': T([-0.24, -0.135]).view(2, 1), + }, + { + 'transported_grad': T([-0.16, 0.3925]).view(2, 1), + 'true_param': T([0.248, 0.559125]).view(2, 1), + 'transported_param': T([-0.328, 0.144]).view(2, 1), + 'igt_velocity': T([0.88, 0.6975]).view(2, 1), + 'momentum_velocity': T([-0.208, -0.13725]).view(2, 1), + }, + { + 'transported_grad': T([-0.656, 0.144]).view(2, 1), + 'true_param': T([0.1216, 0.44094375]).view(2, 1), + 'transported_param': T([-0.384, -0.0318]).view(2, 1), + 'igt_velocity': T([0.496, 0.559125]).view(2, 1), + 'momentum_velocity': T([-0.1536, -0.1245375]).view(2, 1), + }, + +] + + +if __name__ == '__main__': + model = Vector(th.ones(2, 1)) + opt = torch_igt.MomentumIGT(model.parameters(), lr=0.1, + momentum=0.5, nesterov=True) + + H = V(T([[2, 0], [0, 1]])) + for i in range(4): + opt.train() + opt.zero_grad() + loss = 0.5 * th.mm(model().t(), + th.mm(H, model())) + loss.backward() + assert(close(model.vector.grad.data, + reference[i]['transported_grad'])) + + opt.step() + params = opt.state[model.vector] + assert(close(params['igt_velocity'], + reference[i]['igt_velocity'])) + # Adjust for Pytorch's (equivalent) way of computing momentum + assert(close(params['momentum_velocity'] * -0.1, + reference[i]['momentum_velocity'])) + assert(close(model().data, + reference[i]['transported_param'])) + + opt.eval() + assert(close(model().data, + reference[i]['true_param'])) + + opt.train() + assert(close(model().data, + reference[i]['transported_param'])) diff --git a/tests/test_wrapper_cnn.py b/tests/test_wrapper_cnn.py new file mode 100644 index 0000000..885e2aa --- /dev/null +++ b/tests/test_wrapper_cnn.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 + +import torch as th +import torch.nn as nn +from torch import Tensor as T +from torch.autograd import Variable as V +import torch.nn.functional as F + +import torch_igt + + +class Convnet(nn.Module): + def __init__(self): + super(Convnet, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2(x), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + +def dist(x, y): + return (x - y).pow(2).sum() + + +def close(x, y): + return dist(x, y) < 1e-8 + + +if __name__ == '__main__': + th.manual_seed(1234) + model1 = Convnet().double() + model2 = Convnet().double() + for p1, p2 in zip(model1.parameters(), model2.parameters()): + p1.data.copy_(p2.data) + + ref = torch_igt.MomentumIGT(model1.parameters(), lr=0.1, momentum=0.5) + opt = th.optim.SGD(model2.parameters(), lr=0.1, momentum=0.5) + igt = torch_igt.IGTransporter(model2.parameters(), opt) + 
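+    # model1 is driven by the monolithic MomentumIGT optimizer, while model2
+    # wraps a plain SGD optimizer with IGTransporter; the loop below checks
+    # that both versions stay in lockstep (losses, gradients, parameters).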
+ x = V(th.randn(3, 1, 28, 28).double(), requires_grad=False) + + for i in range(100): + for p1, p2 in zip(model1.parameters(), model2.parameters()): + assert close(p1.data, p2.data) + + # Compute reference gradients + ref.train() + ref.zero_grad() + loss1 = model1.forward(x).pow(2).mean() + loss1.backward() + + # Compute wrapper gradients + igt.train() + igt.zero_grad() + loss2 = model2.forward(x).pow(2).mean() + loss2.backward() + + assert close(loss1.data, loss2.data) + + # Test identical gradients + for p1, p2 in zip(model1.parameters(), model2.parameters()): + assert close(p1.grad.data, p2.grad.data) + + # Take on step + ref.step() + igt.step() + + # Test identical parameters (train) + for p1, p2 in zip(model1.parameters(), model2.parameters()): + assert close(p1.data, p2.data) + + # Test identical parameters (eval) + ref.eval() + igt.eval() + for p1, p2 in zip(model1.parameters(), model2.parameters()): + assert close(p1.data, p2.data) + + ref.train() + igt.train() diff --git a/torch_igt/__init__.py b/torch_igt/__init__.py new file mode 100644 index 0000000..0bc25a8 --- /dev/null +++ b/torch_igt/__init__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +from .igt import MomentumIGT +from .adam_igt import AdamIGT +from .wrapper import IGTransporter, ExpIGT, SoftExpIGT,SoftResetExpIGT, Exp, ExpIGTCont +from .non_convex_wrapper import NCIGT +from .transported_exp_ata import ITA diff --git a/torch_igt/adam_igt.py b/torch_igt/adam_igt.py new file mode 100644 index 0000000..b48d12c --- /dev/null +++ b/torch_igt/adam_igt.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 + +import math +import torch as th +from torch.optim.optimizer import Optimizer, required + + +class AdamIGT(Optimizer): + + def __init__(self, + params=required, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + delta=1.0): + defaults = { + 'delta': delta, + 'num_steps': 0, + 'train': True, + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, + } + super(AdamIGT, self).__init__(params, defaults) + + def compute_update(self, p, param_state, group): + exp_avg = param_state['exp_avg'] + exp_avg_sq = param_state['exp_avg_sq'] + beta1, beta2 = group['betas'] + lr = group['lr'] + grad = p.grad.data + + if group['weight_decay'] != 0: + grad = grad.add(group['weight_decay'], p.data) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + + denom = exp_avg_sq.sqrt().add_(group['eps']) + + # NOTE: the + 1 is because IGT and Adam don't count steps the same way. 
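+        # Standard Adam bias correction with t = num_steps + 1:
+        #   m_hat = exp_avg / (1 - beta1^t), v_hat = exp_avg_sq / (1 - beta2^t)
+        # Both corrections are folded into step_size rather than rescaling
+        # the moment buffers in place.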
+ bias_correction1 = 1 - beta1 ** (group['num_steps'] + 1) + bias_correction2 = 1 - beta2 ** (group['num_steps'] + 1) + step_size = lr * math.sqrt(bias_correction2) / bias_correction1 + update = -step_size * (exp_avg / denom) + return update + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + for group in self.param_groups: + delta = group['delta'] + num_steps = group['num_steps'] + gamma = (num_steps) / (num_steps + delta) + future_gamma = (num_steps + 1) / (num_steps + 1 + delta) + future_transport = future_gamma / (1.0 - future_gamma) + for p in group['params']: + + if p.grad is None: + continue + + d_p = p.grad.data + param_state = self.state[p] + + # Compute the IGT estimate + if num_steps == 0: + param_state['igt_estimate'] = th.zeros_like(d_p) + param_state['igt_estimate'].add_(d_p) + param_state['true_p'] = th.zeros_like(p.data) + param_state['true_p'].add_(p.data) + param_state['exp_avg'] = th.zeros_like(p.data) + param_state['exp_avg_sq'] = th.zeros_like(p.data) + true_p = param_state['true_p'] + else: + igt_estimate = param_state['igt_estimate'] + true_p = param_state['true_p'] + igt_estimate.mul_(gamma).add_((1.0 - gamma), d_p) + # Sets gradients to the IGT estimate + d_p.copy_(igt_estimate) + p.data.copy_(true_p) # Revert to true params + + # Take the step according to opt + update = self.compute_update(p, param_state, group) + + # Transport to the next parameter point + true_p.copy_(p.data).add_(update) + p.data.add_(1.0 + future_transport, update) + group['num_steps'] += 1 + return loss + + def train(self): + for group in self.param_groups: + if not group['train']: + for p in group['params']: + param_state = self.state[p] + true_p = param_state['true_p'] + temp_p = p.data.clone() + p.data.copy_(true_p) + true_p.copy_(temp_p) + group['train'] = True + + def eval(self): + for group in self.param_groups: + if group['train']: + for p in group['params']: + param_state = self.state[p] + true_p = param_state['true_p'] + temp_p = p.data.clone() + p.data.copy_(true_p) + true_p.copy_(temp_p) + group['train'] = False diff --git a/torch_igt/igt.py b/torch_igt/igt.py new file mode 100644 index 0000000..5dee7ce --- /dev/null +++ b/torch_igt/igt.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 + +import torch as th +from torch.optim.optimizer import Optimizer, required + + +class MomentumIGT(Optimizer): + + """ + Implementation of Implicit Gradient Transport, + and its Heavyball and Nesterov versions. + + Arguments: + + * params -- parameters to optimize + * lr=required -- learning rate + * momentum=0.0 -- momentum factor + * dampening=0.0 -- momentum dampening factor + * weight_decay=0.0 -- weight decay factor + * nesterov=False -- wheter to use nesterov or not + * delta=1.0 -- IGT's delta displacement parameter + + Example: + + # Training + opt = MomentumIGT(model.parameters(), lr=0.1, momentum=0.9) + opt.train() # Optional if opt.eval() is never used + error = loss(model(X), y) + error.backward() + opt.step() + + # Evaluation + model.eval() + opt.eval() + error = loss(model(X), y) + + Notes: + + * This implementation requires 5 copies of the model's parameters. + (igt_velocity, momentum_velocity, true_params, gradients, params) + I think it's possible to have a version with only 4 copies, + but it would sacrifice some clarity. + * Heavyball and Nesterov are implemented as in PyTorch's SGD. 
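+    * Summary of the transport step below (plain momentum=0 case): after t
+      steps, with gamma = t / (t + delta), the model's parameters are set to
+      theta_true - lr * gamma / (1 - gamma) * igt_velocity, while `eval()`
+      restores theta_true.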
+ """ + + def __init__(self, params, lr=required, momentum=0.0, dampening=0.0, + weight_decay=0.0, nesterov=False, delta=1.0): + + if weight_decay < 0.0: + msg = "Invalid weight_decay value: {}".format(weight_decay) + raise ValueError(msg) + + if delta <= 0.0: + raise ValueError("Invalid delta value: {}".format(delta)) + + defaults = { + 'lr': lr, + 'momentum': momentum, + 'dampening': dampening, + 'weight_decay': weight_decay, + 'nesterov': nesterov, + 'delta': delta, + 'num_steps': 0, + 'transported': True, + } + if nesterov and (momentum <= 0 or dampening != 0): + msg = "Nesterov momentum requires a momentum and zero dampening" + raise ValueError(msg) + + super(MomentumIGT, self).__init__(params, defaults) + + + def __setstate__(self, state): + super(MomentumIGT, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('nesterov', False) + + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + lr = group['lr'] + momentum = group['momentum'] + dampening = group['dampening'] + weight_decay = group['weight_decay'] + nesterov = group['nesterov'] + delta = group['delta'] + num_steps = group['num_steps'] + + gamma = num_steps / (num_steps + delta) + future_gamma = (num_steps + 1) / (num_steps + 1 + delta) + future_transport = future_gamma / (1.0 - future_gamma) + + for p in group['params']: + + if p.grad is None: + continue + + d_p = p.grad.data # Transported gradients + param_state = self.state[p] + + # Apply weight decay + if weight_decay != 0: + d_p.add_(weight_decay, p.data) + + # Init buffers appropriately + if num_steps == 0: + + if 'igt_velocity' not in param_state: + param_state['igt_velocity'] = th.zeros_like(p.data) + if 'true_params' not in param_state: + param_state['true_params'] = p.data.clone() + if momentum != 0 and 'momentum_velocity' not in param_state: + param_state['momentum_velocity'] = th.zeros_like(p.data) + + igt_velocity = param_state['igt_velocity'] + true_params = param_state['true_params'] + + # Compute first step and initial values + igt_velocity.add_(d_p) + if momentum != 0: + param_state['momentum_velocity'].add_(igt_velocity) + + # Compute the IGT update + else: + igt_velocity = param_state['igt_velocity'] + true_params = param_state['true_params'] + + # Update IGT's velocity + igt_velocity.mul_(gamma).add_((1.0 - gamma), d_p) + + # Compute momentum if necessary + if momentum != 0: + momentum_velocity = param_state['momentum_velocity'] + momentum_velocity.mul_(momentum).add_(1.0 - dampening, + igt_velocity) + + # Update true and transported parameters + if momentum == 0: + # Update true parameters + true_params.add_(-lr, igt_velocity) + + # Set parameters to transported ones + p.data.copy_(true_params) + p.data.add_(-lr * future_transport, igt_velocity) + else: + momentum_velocity = param_state['momentum_velocity'] + if nesterov: + true_params.add_(-lr, igt_velocity) + true_params.add_(-lr * momentum, momentum_velocity) + + # Set parameters to transported ones + p.data.copy_(true_params) + p.data.add_(-lr * future_transport, igt_velocity) + p.data.add_(-lr * momentum * future_transport, + momentum_velocity) + else: + true_params.add_(-lr, momentum_velocity) + + # Set parameters to transported ones + p.data.copy_(true_params) + p.data.add_(-lr * future_transport, momentum_velocity) + + group['num_steps'] = num_steps + 1 + return loss + + def train(self): + """ + Swaps true and transported parameters. + + Useful for switching from inference to training. 
+ """ + for group in self.param_groups: + lr = group['lr'] + momentum = group['momentum'] + nesterov = group['nesterov'] + delta = group['delta'] + num_steps = group['num_steps'] + transported = group['transported'] + + gamma = (num_steps) / (num_steps + delta) + transport = gamma / (1.0 - gamma) + + if not transported and num_steps > 0: + for p in group['params']: + # Should compute the future transported params + param_state = self.state[p] + igt_velocity = param_state['igt_velocity'] + true_params = param_state['true_params'] + p.data.copy_(true_params) + if momentum == 0: + p.data.add_(-lr * transport, igt_velocity) + else: + momentum_velocity = param_state['momentum_velocity'] + if nesterov: + p.data.add_(-lr * transport, igt_velocity) + p.data.add_(-lr * momentum * transport, + momentum_velocity) + else: + p.data.add_(-lr * transport, momentum_velocity) + group['transported'] = True + + def eval(self): + for group in self.param_groups: + if group['transported']: + for p in group['params']: + # Copy true_params to the params + p.data.copy_(self.state[p]['true_params']) + group['transported'] = False diff --git a/torch_igt/non_convex_wrapper.py b/torch_igt/non_convex_wrapper.py new file mode 100644 index 0000000..59c6115 --- /dev/null +++ b/torch_igt/non_convex_wrapper.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 + +import torch as th +from torch.optim.optimizer import Optimizer, required +from torch_igt.wrapper import IGTransporter + + +class NCIGT(IGTransporter, Optimizer): + """ Implementation of Non-Convex IGT. """ + + def __init__(self, params=required, opt=required, interval=2.0): + self.opt = opt + defaults = { + 'num_steps': 1, + 'train': True, + 'interval': interval, + } + super(IGTransporter, self).__init__(params, defaults) + + def step(self, closure=None): + for group in self.param_groups: + t = group['num_steps'] + c = group['interval'] + for p in group['params']: + + if p.grad is None: + continue + + d_p = p.grad.data + param_state = self.state[p] + + # Compute the NC-IGT estimate + if t == 1: + param_state['true_p'] = th.zeros_like(p.data) + param_state['true_p'].add_(p.data) + param_state['g_hat_1'] = th.zeros_like(d_p) + param_state['g_hat_2'] = th.zeros_like(d_p) + param_state['theta_hat_1'] = th.zeros_like(p.data) + param_state['theta_hat_1'].add_(p.data) + param_state['theta_hat_2'] = th.zeros_like(p.data) + param_state['N_t'] = 0 + else: + # Compute the gradient estimate + true_p = param_state['true_p'] + g_hat_1 = param_state['g_hat_1'] + g_hat_2 = param_state['g_hat_2'] + theta_hat_1 = param_state['theta_hat_1'] + theta_hat_2 = param_state['theta_hat_2'] + N_t = param_state['N_t'] + g_hat_1.mul_((N_t-1)/N_t).add_(1.0/N_t, d_p) + d_p.mul_(0).add_(g_hat_2).add_(c*N_t/t, g_hat_1 - g_hat_2) + + # Perform a reset + if N_t >= t / c: + g_hat_2.copy_(g_hat_1) + theta_hat_2.copy_(theta_hat_1) + g_hat_1.mul_(0) + param_state['N_t'] = 0 + + # Revert to true params for update + p.data.copy_(true_p) + + # Take the optimization step + result = self.opt.step(closure) + + # Compute and set transported params theta_tilde + for group in self.param_groups: + group['num_steps'] += 1 + t = group['num_steps'] + c = group['interval'] + for p in group['params']: + + if p.grad is None: + continue + + param_state = self.state[p] + param_state['N_t'] += 1 + N_t = param_state['N_t'] + true_p = param_state['true_p'] + theta_hat_1 = param_state['theta_hat_1'] + theta_hat_2 = param_state['theta_hat_2'] + + # First, copy true params from after the update + true_p.copy_(p.data) + + # Then compute 
theta_tilde + theta_tilde = p.data + theta_tilde.mul_(0).add_(-(N_t - 1), theta_hat_1) + theta_hat_1.mul_(0) + theta_hat_1.add_(1/N_t, + t/c * true_p - (t/c - N_t) * theta_hat_2) + theta_tilde.add_(N_t, theta_hat_1) + + return result diff --git a/torch_igt/transported_exp_ata.py b/torch_igt/transported_exp_ata.py new file mode 100644 index 0000000..ba43216 --- /dev/null +++ b/torch_igt/transported_exp_ata.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 + +import math +import torch as th +from torch.optim.optimizer import required + +from torch_igt.wrapper import IGTransporter + + +def exp_ata_weight(c, num_steps): + if num_steps < 2: + return 0 + gamma = c * num_steps / (1.0 + c * num_steps) + gamma *= 1.0 - math.sqrt((1.0 - c) / (num_steps * (num_steps + 1))) / c + return gamma + + +class ITA(IGTransporter): + + """ + In this implementation, the following variables are defined as + + * igt_estimate: the IGT buffers that get updated with every new + stochastic gradient. + * true_p: keeps track of the unshifted parameters when training, + and the model parameters are set at the shifted position. Conversely, + when calling `eval()` the parameters of the model are set to the + unshifted values, and true_p becomes the shifted ones. You can switch + back using `train()`. + """ + + def __init__(self, params=required, opt=required, delta=1.0, interval=2.0): + self.opt = opt + defaults = { + 'delta': delta, + 'c': 1.0 / interval, + 'num_steps': 0.0, + 'train': True, + } + super(IGTransporter, self).__init__(params, defaults) + + def step(self, closure=None): + for group in self.param_groups: + num_steps = group['num_steps'] + c = group['c'] + + # Compute the exponential ATA weighting + gamma = exp_ata_weight(c, num_steps) + for p in group['params']: + + if p.grad is None: + continue + + d_p = p.grad.data + param_state = self.state[p] + + # Compute the IGT estimate + if num_steps == 0: + param_state['igt_estimate'] = th.zeros_like(d_p) + param_state['true_p'] = th.zeros_like(p.data) + param_state['igt_estimate'].add_(d_p) + param_state['true_p'].add_(p.data) + true_p = param_state['true_p'] + else: + igt_estimate = param_state['igt_estimate'] + true_p = param_state['true_p'] + igt_estimate.mul_(gamma).add_((1.0 - gamma), d_p) + # Sets gradients to the IGT estimate + d_p.copy_(igt_estimate) + p.data.copy_(true_p) # Revert to true params + + # Take the step according to opt + loss = self.opt.step(closure) + + # Transport to the next parameter point + for group in self.param_groups: + num_steps = group['num_steps'] + + # Compute the next exponential ATA weighting + c = group['c'] + future_gamma = exp_ata_weight(c, num_steps + 1.0) + future_transport = future_gamma / (1.0 - future_gamma) + for p in group['params']: + true_p = self.state[p]['true_p'] + temp_p = p.data.clone() + vector_change = p.data.add(-1.0, true_p) + """ + NOTE: The numerical problem is here. + Essentially, computing vector change involves a subtraction, + while computing the update with opt.step() is a multiplication. + + Subtraction is numerically unstable and hence the observed + differences in algorithms. + Note: this mainly depends on the loss computation. If it is + stable, then using the parameter difference doesn't greatly + diverge. 
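+                Concretely, `vector_change` below is the update the wrapped
+                optimizer just applied (theta_true_new - theta_true_old), and
+                the parameters are then displaced by `future_transport` times
+                that update to obtain the next transported point.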
+ """ + p.data.add_(future_transport, vector_change) + true_p.copy_(temp_p) + group['num_steps'] += 1.0 + return loss diff --git a/torch_igt/wrapper.py b/torch_igt/wrapper.py new file mode 100644 index 0000000..847a286 --- /dev/null +++ b/torch_igt/wrapper.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 + +import torch as th +from torch.optim.optimizer import Optimizer, required + + +class IGTransporter(Optimizer): + + """ + In this implementation, the following variables are defined as + + * igt_estimate: the IGT buffers that get updated with every new + stochastic gradient. + * true_p: keeps track of the unshifted parameters when training, + and the model parameters are set at the shifted position. Conversely, + when calling `eval()` the parameters of the model are set to the + unshifted values, and true_p becomes the shifted ones. You can switch + back using `train()`. + """ + + def __init__(self, params=required, opt=required, delta=1.0, interval=1.0): + self.opt = opt + defaults = { + 'delta': delta, + 'interval': 1.0 / interval, + 'num_steps': 0, + 'train': True, + } + super(IGTransporter, self).__init__(params, defaults) + + def step(self, closure=None): + for group in self.param_groups: + delta = group['delta'] + num_steps = group['num_steps'] + interval = group['interval'] + gamma = (num_steps) / (num_steps + delta) + gamma = gamma**(2.0 / interval - 1.0) + for p in group['params']: + + if p.grad is None: + continue + + d_p = p.grad.data + param_state = self.state[p] + + # Compute the IGT estimate + if num_steps == 0: + param_state['igt_estimate'] = th.zeros_like(d_p) + param_state['true_p'] = th.zeros_like(p.data) + param_state['igt_estimate'].add_(d_p) + param_state['true_p'].add_(p.data) + true_p = param_state['true_p'] + else: + igt_estimate = param_state['igt_estimate'] + true_p = param_state['true_p'] + igt_estimate.mul_(gamma).add_((1.0 - gamma), d_p) + # Sets gradients to the IGT estimate + d_p.copy_(igt_estimate) + p.data.copy_(true_p) # Revert to true params + + # Take the step according to opt + loss = self.opt.step(closure) + + # Transport to the next parameter point + for group in self.param_groups: + delta = group['delta'] + num_steps = group['num_steps'] + future_gamma = (num_steps + 1) / (num_steps + 1 + delta) + future_transport = future_gamma / (1.0 - future_gamma) + for p in group['params']: + true_p = self.state[p]['true_p'] + temp_p = p.data.clone() + vector_change = p.data.add(-1.0, true_p) + """ + NOTE: The numerical problem is here. + Essentially, computing vector change involves a subtraction, + while computing the update with opt.step() is a multiplication. + + Subtraction is numerically unstable and hence the observed + differences in algorithms. + Note: this mainly depends on the loss computation. If it is + stable, then using the parameter difference doesn't greatly + diverge. 
+ """ + p.data.add_(future_transport, vector_change) + true_p.copy_(temp_p) + group['num_steps'] += 1 + return loss + + def train(self): + for group in self.param_groups: + if not group['train']: + for p in group['params']: + param_state = self.state[p] + true_p = param_state['true_p'] + temp_p = p.data.clone() + p.data.copy_(true_p) + true_p.copy_(temp_p) + group['train'] = True + + def eval(self): + for group in self.param_groups: + if group['train']: + for p in group['params']: + param_state = self.state[p] + true_p = param_state['true_p'] + temp_p = p.data.clone() + p.data.copy_(true_p) + true_p.copy_(temp_p) + group['train'] = False + + +class ExpIGT(IGTransporter): + + def __init__(self, params=required, opt=required, delta=1.0): + self.opt = opt + defaults = { + 'delta': delta, + 'num_steps': 0, + 'train': True, + 'num_resets': 0, + 'exp_power': 2, + } + super(IGTransporter, self).__init__(params, defaults) + + def step(self, closure=None): + result = super(ExpIGT, self).step(closure) + for group in self.param_groups: + assert group['train'], 'Called step not in train mode.' + + num_steps = group['num_steps'] + num_resets = group['num_resets'] + exp_power = group['exp_power'] + if exp_power**num_resets == num_steps: + group['num_resets'] = num_resets + 1 + group['num_steps'] = 0 + # Then we perform a reset + for p in group['params']: + param_state = self.state[p] + # First, move the true params to shifted ones + true_p = param_state['true_p'] + true_p.copy_(p.data) + # Second, zero-out the IGT buffers + param_state['igt_estimate'].mul_(0) + return result + + +class SoftExpIGT(ExpIGT): + + def step(self, closure=None): + result = super(ExpIGT, self).step(closure) + for group in self.param_groups: + assert group['train'], 'Called step not in train mode.' + + num_steps = group['num_steps'] + num_resets = group['num_resets'] + exp_power = group['exp_power'] + if exp_power**num_resets == num_steps: + group['num_resets'] = num_resets + 1 + group['num_steps'] = 1 + # Then we perform a reset + for p in group['params']: + param_state = self.state[p] + # Only, move the true params to shifted ones + true_p = param_state['true_p'] + true_p.copy_(p.data) + return result + + +class SoftResetExpIGT(ExpIGT): + + def step(self, closure=None): + result = super(ExpIGT, self).step(closure) + for group in self.param_groups: + assert group['train'], 'Called step not in train mode.' + + num_steps = group['num_steps'] + num_resets = group['num_resets'] + exp_power = group['exp_power'] + if exp_power**num_resets == num_steps: + group['num_resets'] = num_resets + 1 + group['num_steps'] = group['num_resets'] + # Then we perform a reset + for p in group['params']: + param_state = self.state[p] + # Only, move the true params to shifted ones + true_p = param_state['true_p'] + true_p.copy_(p.data) + return result + + +class ExpIGTCont(IGTransporter): + + def __init__(self, params=required, opt=required, delta=1.0): + self.opt = opt + defaults = { + 'delta': delta, + 'num_steps': 0, + 'train': True, + 'num_resets': 0, + 'exp_power': 2, + } + super(IGTransporter, self).__init__(params, defaults) + + def step(self, closure=None): + result = super(ExpIGTCont, self).step(closure) + for group in self.param_groups: + assert group['train'], 'Called step not in train mode.' 
+ + num_steps = group['num_steps'] + num_resets = group['num_resets'] + exp_power = group['exp_power'] + if exp_power**num_resets == num_steps: + group['num_resets'] = num_resets + 1 + group['num_steps'] = 0 + # Then we perform a reset + for p in group['params']: + param_state = self.state[p] + # First, move the shifted params to true ones + true_p = param_state['true_p'] + p.data.copy_(true_p) + # Second, zero-out the IGT buffers + param_state['igt_estimate'].mul_(0) + return result + + +class Exp(Optimizer): + + def __init__(self, params=required, opt=required, delta=1): + self.opt = opt + defaults = { + 'delta': delta, + 'num_steps': 0, + 'num_resets': 0, + 'exp_power': 2, + } + super(Exp, self).__init__(params, defaults) + + def step(self, closure=None): + # Replace each gradient with exponential average + for group in self.param_groups: + num_steps = group['num_steps'] + num_resets = group['num_resets'] + exp_power = group['exp_power'] + delta = group['delta'] + gamma = (num_steps) / (num_steps + delta) + for p in group['params']: + param_state = self.state[p] + if 'exp_average' not in param_state: + # Init with stochastic gradient + param_state['exp_average'] = p.grad.data.clone() + else: + # Compute new exponential average + p.grad.data.mul_(1.0 - gamma).add_(gamma, + param_state['exp_average']) + param_state['exp_average'].copy_(p.grad.data) + group['num_steps'] += 1 + + # Take optimization step + result = self.opt.step(closure) + + # Reset buffers if necessary + for group in self.param_groups: + num_steps = group['num_steps'] + num_resets = group['num_resets'] + exp_power = group['exp_power'] + if exp_power**num_resets == num_steps: + group['num_resets'] = num_resets + 1 + group['num_steps'] = 0 + # Then we perform a reset + for p in group['params']: + param_state = self.state[p] + # Second, zero-out the averaging buffers + param_state['exp_average'].mul_(0) + return result