From 2a852ddef6192543ca9ab1eb550d44f0041857e5 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Fri, 22 Jul 2022 16:42:02 -0700 Subject: [PATCH] Add test_make_fx_model_train example (#82011) Summary: X-link: https://github.com/pytorch/pytorch/pull/82011 Pull Request resolved: https://github.com/pytorch/functorch/pull/980 Test Plan: CI should pass Reviewed By: benoitsteiner Differential Revision: D38078694 Pulled By: mostafaelhoushi fbshipit-source-id: ccfbeb8531d49d0d503e728f997a6003c87f9eb1 --- test/test_pythonkey.py | 49 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/test/test_pythonkey.py b/test/test_pythonkey.py index 6727659f3..dd3a06844 100644 --- a/test/test_pythonkey.py +++ b/test/test_pythonkey.py @@ -11,6 +11,8 @@ import unittest import warnings import itertools +import torch.nn.utils._stateless as stateless +from collections.abc import Iterable from functools import partial from torch.testing._internal.common_device_type import instantiate_device_type_tests from functorch import ( @@ -74,6 +76,53 @@ def f(x): new_inp = torch.randn(3) self.assertEqual(fx_f(new_inp), f(new_inp)) + def test_make_fx_model_fwd_bwd(self, device): + class Foo(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x).relu() + + model = Foo() + + def f(x, params): + out = stateless.functional_call(model, params, x).sum() + out.backward() + return list(params.values()) + input = torch.randn(3, 5, requires_grad=True) + params = dict(model.named_parameters()) + fx_f = make_fx(f)(input, params) + # fx may change the order of parameters in list, so using set() to compare + self.assertEqual(set(fx_f(input, params)), set(f(input, params))) + + def test_make_fx_model_fwd_bwd_wgtupdate(self, device): + class Foo(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x).relu() + + model = Foo() + + def f(args, params, buffers): + if not isinstance(args, Iterable): + args = [args] + params_and_buffers = {**params, **buffers} + out = stateless.functional_call(model, params_and_buffers, args) + out.sum().backward() + return [p - 1e-4 * p.grad for p in params.values()] + + input = torch.randn(3, 5, requires_grad=True) + params = dict(model.named_parameters()) + buffers = dict(model.named_buffers()) + fx_f = make_fx(f)(input, params, buffers) + # fx may change the order of parameters in list, so using set() to compare + self.assertEqual(set(fx_f(input, params, buffers)), set(f(input, params, buffers))) + def test_scalar_device(self, device): def f(a, b): return a + b