Skip to content

Commit

Permalink
Add text style transfer (#6)
Browse files Browse the repository at this point in the history
* initial commit

* linting
  • Loading branch information
swapnull7 authored Dec 9, 2019
1 parent adca7fa commit 6055075
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions texar/torch/utils/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
Utility functions related to variables.
"""

from typing import Any, List, Tuple, Union
from typing import List, Tuple, Union
import torch.nn as nn

from texar.torch.module_base import ModuleBase

__all__ = [
"add_variable",
Expand All @@ -24,8 +27,8 @@


def add_variable(
variable: Union[List[Any], Tuple[Any]],
var_list: List[Any]):
variable: Union[List[nn.Parameter], Tuple[nn.Parameter]],
var_list: List[nn.Parameter]):
r"""Adds variable to a given list.
Args:
Expand All @@ -40,7 +43,7 @@ def add_variable(


def collect_trainable_variables(
modules: Union[Any, List[Any]]
modules: Union[ModuleBase, List[ModuleBase]]
):
r"""Collects all trainable variables of modules.
Expand All @@ -57,7 +60,7 @@ def collect_trainable_variables(
if not isinstance(modules, (list, tuple)):
modules = [modules]

var_list: List[Any] = []
var_list: List[nn.Parameter] = []
for mod in modules:
add_variable(mod.trainable_variables, var_list)

Expand Down

0 comments on commit 6055075

Please sign in to comment.