From 17675dd142104f02c5d36a030db6052c57438eba Mon Sep 17 00:00:00 2001 From: Yanan Cao Date: Thu, 11 Mar 2021 22:17:00 -0800 Subject: [PATCH] Rewrite demucs to be TorchScript-able --- torchbenchmark/models/demucs/__init__.py | 26 +++++++---- .../models/demucs/demucs/augment.py | 9 ++-- torchbenchmark/models/demucs/demucs/model.py | 46 ++++++++++--------- torchbenchmark/models/demucs/demucs/utils.py | 12 ++--- 4 files changed, 53 insertions(+), 40 deletions(-) diff --git a/torchbenchmark/models/demucs/__init__.py b/torchbenchmark/models/demucs/__init__.py index 8d6a3512b6..4e38a8b249 100644 --- a/torchbenchmark/models/demucs/__init__.py +++ b/torchbenchmark/models/demucs/__init__.py @@ -10,6 +10,10 @@ from .demucs.utils import capture_init, center_trim from ...util.model import BenchmarkModel from torchbenchmark.tasks import OTHER +from torch import Tensor +from torch.nn.modules.container import Sequential +from torchbenchmark.models.demucs.demucs.model import Demucs +from typing import Optional, Tuple torch.manual_seed(1337) @@ -20,12 +24,12 @@ class DemucsWrapper(torch.nn.Module): - def __init__(self, model, augment): + def __init__(self, model: Demucs, augment: Sequential) -> None: super(DemucsWrapper, self).__init__() self.model = model self.augment = augment - def forward(self, streams): + def forward(self, streams: Tensor) -> Tuple[Tensor, Tensor]: sources = streams[:, 1:] sources = self.augment(sources) mix = sources.sum(dim=1) @@ -34,7 +38,7 @@ def forward(self, streams): class Model(BenchmarkModel): task = OTHER.OTHER_TASKS - def __init__(self, device=None, jit=False): + def __init__(self, device: Optional[str]=None, jit: bool=False) -> None: super().__init__() self.device = device self.jit = jit @@ -67,10 +71,13 @@ def __init__(self, device=None, jit=False): self.model = DemucsWrapper(self.model, self.augment) + if self.jit: + self.model = torch.jit.script(self.model) + def _set_mode(self, train): self.model.train(train) - def get_module(self): + def get_module(self) -> Tuple[DemucsWrapper, Tuple[Tensor]]: self.model.eval() return self.model, self.example_inputs @@ -92,8 +99,9 @@ def train(self, niter=1): if __name__ == '__main__': - m = Model(device='cuda', jit=False) - module, example_inputs = m.get_module() - module(*example_inputs) - m.train(niter=1) - m.eval(niter=1) + for jit in [True, False]: + m = Model(device='cuda', jit=jit) + module, example_inputs = m.get_module() + module(*example_inputs) + m.train(niter=1) + m.eval(niter=1) diff --git a/torchbenchmark/models/demucs/demucs/augment.py b/torchbenchmark/models/demucs/demucs/augment.py index b14a95e75e..8a9244c78c 100644 --- a/torchbenchmark/models/demucs/demucs/augment.py +++ b/torchbenchmark/models/demucs/demucs/augment.py @@ -23,7 +23,7 @@ def forward(self, wav): if not self.training: wav = wav[..., :length] else: - offsets = th.randint(self.shift, [batch, sources, 1, 1], device=wav.device) + offsets = th.randint(self.shift, [batch, sources, 1, 1], device=wav.device, dtype=th.int64) offsets = offsets.expand(-1, -1, channels, -1) indexes = th.arange(length, device=wav.device) wav = wav.gather(3, indexes + offsets) @@ -37,7 +37,7 @@ class FlipChannels(nn.Module): def forward(self, wav): batch, sources, channels, time = wav.size() if self.training and wav.size(2) == 2: - left = th.randint(2, (batch, sources, 1, 1), device=wav.device) + left = th.randint(2, (batch, sources, 1, 1), device=wav.device, dtype=th.int64) left = left.expand(-1, -1, -1, time) right = 1 - left wav = th.cat([wav.gather(2, left), wav.gather(2, right)], dim=2) @@ -77,7 +77,10 @@ def forward(self, wav): device = wav.device if self.training: - group_size = self.group_size or batch + if self.group_size is not None: + group_size = self.group_size + else: + group_size = batch if batch % group_size != 0: raise ValueError(f"Batch size {batch} must be divisible by group size {group_size}") groups = batch // group_size diff --git a/torchbenchmark/models/demucs/demucs/model.py b/torchbenchmark/models/demucs/demucs/model.py index 9aff668674..aedc9f5a5e 100644 --- a/torchbenchmark/models/demucs/demucs/model.py +++ b/torchbenchmark/models/demucs/demucs/model.py @@ -7,13 +7,15 @@ import math import torch as th -from torch import nn +from torch import Tensor, nn from .utils import capture_init, center_trim +from torch.nn.modules.conv import Conv1d, ConvTranspose1d +from typing import Union class BLSTM(nn.Module): - def __init__(self, dim, layers=1): + def __init__(self, dim: int, layers: int=1) -> None: super().__init__() self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) self.lstm.flatten_parameters() @@ -27,7 +29,7 @@ def forward(self, x): return x -def rescale_conv(conv, reference): +def rescale_conv(conv: Union[Conv1d, ConvTranspose1d], reference: float) -> None: std = conv.weight.std().detach() scale = (std / reference)**0.5 conv.weight.data /= scale @@ -35,13 +37,13 @@ def rescale_conv(conv, reference): conv.bias.data /= scale -def rescale_module(module, reference): +def rescale_module(module: th.nn.Module, reference: float) -> None: for sub in module.modules(): if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)): rescale_conv(sub, reference) -def upsample(x, stride): +def upsample(x, stride: int): """ Linear upsampling, the output will be `stride` times longer. """ @@ -52,7 +54,7 @@ def upsample(x, stride): return out.reshape(batch, channels, -1) -def downsample(x, stride): +def downsample(x, stride: int): """ Downsample x by decimation. """ @@ -64,19 +66,19 @@ class Demucs(th.jit.ScriptModule): @capture_init def __init__(self, - sources=4, - audio_channels=2, - channels=64, - depth=6, - rewrite=True, - glu=True, - upsample=False, - rescale=0.1, - kernel_size=8, - stride=4, - growth=2., - lstm_layers=2, - context=3): + sources: int=4, + audio_channels: int=2, + channels: int=64, + depth: int=6, + rewrite: bool=True, + glu: bool=True, + upsample: bool=False, + rescale: float=0.1, + kernel_size: int=8, + stride: int=4, + growth: float=2., + lstm_layers: int=2, + context: int=3) -> None: """ Args: sources (int): number of sources to separate @@ -194,7 +196,7 @@ def valid_length(self, length): return int(length) - def forward(self, mix): + def forward(self, mix: Tensor) -> Tensor: x = mix saved = [x] for encode in self.encoder: @@ -202,7 +204,7 @@ def forward(self, mix): saved.append(x) if self.upsample: x = downsample(x, self.stride) - if self.lstm: + if self.lstm is not None: x = self.lstm(x) for decode in self.decoder: if self.upsample: @@ -210,7 +212,7 @@ def forward(self, mix): skip = center_trim(saved.pop(-1), x) x = x + skip x = decode(x) - if self.final: + if self.final is not None: skip = center_trim(saved.pop(-1), x) x = th.cat([x, skip], dim=1) x = self.final(x) diff --git a/torchbenchmark/models/demucs/demucs/utils.py b/torchbenchmark/models/demucs/demucs/utils.py index 0c95636957..82050bfa49 100644 --- a/torchbenchmark/models/demucs/demucs/utils.py +++ b/torchbenchmark/models/demucs/demucs/utils.py @@ -16,19 +16,19 @@ import torch as th import tqdm -from torch import distributed +from torch import Tensor, distributed from torch.nn import functional as F +from typing import Callable, Any -def center_trim(tensor, reference): +def center_trim(tensor: Tensor, reference: Tensor) -> Tensor: """ Center trim `tensor` with respect to `reference`, along the last dimension. `reference` can also be a number, representing the length to trim to. If the size difference != 0 mod 2, the extra sample is removed on the right side. """ - if hasattr(reference, "size"): - reference = reference.size(-1) - delta = tensor.size(-1) - reference + reference_val: int = reference.size(-1) + delta = tensor.size(-1) - reference_val if delta < 0: raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.") if delta: @@ -177,7 +177,7 @@ def save_model(model, path): th.save((klass, args, kwargs, state), save_to) -def capture_init(init): +def capture_init(init: Callable) -> Callable: @functools.wraps(init) def __init__(self, *args, **kwargs): self._init_args_kwargs = (args, kwargs)