Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions torchbenchmark/models/demucs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from .demucs.utils import capture_init, center_trim
from ...util.model import BenchmarkModel
from torchbenchmark.tasks import OTHER
from torch import Tensor
from torch.nn.modules.container import Sequential
from torchbenchmark.models.demucs.demucs.model import Demucs
from typing import Optional, Tuple


torch.manual_seed(1337)
Expand All @@ -20,12 +24,12 @@


class DemucsWrapper(torch.nn.Module):
def __init__(self, model: Demucs, augment: Sequential) -> None:
    """Bundle a separation model with its augmentation pipeline.

    Args:
        model: network stored as ``self.model``.
        augment: module stored as ``self.augment``; ``forward`` applies it
            to the source stems before summing them into the mixture.
    """
    # Explicit two-argument super() is kept exactly as in the merged change.
    super(DemucsWrapper, self).__init__()
    self.model = model
    self.augment = augment

def forward(self, streams):
def forward(self, streams: Tensor) -> Tuple[Tensor, Tensor]:
sources = streams[:, 1:]
sources = self.augment(sources)
mix = sources.sum(dim=1)
Expand All @@ -34,7 +38,7 @@ def forward(self, streams):

class Model(BenchmarkModel):
task = OTHER.OTHER_TASKS
def __init__(self, device=None, jit=False):
def __init__(self, device: Optional[str]=None, jit: bool=False) -> None:
super().__init__()
self.device = device
self.jit = jit
Expand Down Expand Up @@ -67,10 +71,13 @@ def __init__(self, device=None, jit=False):

self.model = DemucsWrapper(self.model, self.augment)

if self.jit:
self.model = torch.jit.script(self.model)

def _set_mode(self, train):
self.model.train(train)

def get_module(self) -> Tuple[DemucsWrapper, Tuple[Tensor]]:
    """Return the benchmark module and the inputs to call it with.

    The module is put into eval mode before it is handed out. When the
    benchmark was built with ``jit=True`` the object stored in
    ``self.model`` is the result of ``torch.jit.script`` on the wrapper,
    not the eager ``DemucsWrapper`` itself.

    Returns:
        ``(self.model, self.example_inputs)``.
    """
    self.model.eval()
    return self.model, self.example_inputs

Expand All @@ -92,8 +99,9 @@ def train(self, niter=1):


if __name__ == '__main__':
    # Smoke test: exercise both the TorchScript-compiled and the eager model.
    # NOTE: the device is hard-coded to CUDA, so this needs a GPU to run.
    for jit in [True, False]:
        m = Model(device='cuda', jit=jit)
        module, example_inputs = m.get_module()
        # One direct forward pass, then one benchmark iteration of each kind.
        module(*example_inputs)
        m.train(niter=1)
        m.eval(niter=1)
9 changes: 6 additions & 3 deletions torchbenchmark/models/demucs/demucs/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def forward(self, wav):
if not self.training:
wav = wav[..., :length]
else:
offsets = th.randint(self.shift, [batch, sources, 1, 1], device=wav.device)
offsets = th.randint(self.shift, [batch, sources, 1, 1], device=wav.device, dtype=th.int64)
offsets = offsets.expand(-1, -1, channels, -1)
indexes = th.arange(length, device=wav.device)
wav = wav.gather(3, indexes + offsets)
Expand All @@ -37,7 +37,7 @@ class FlipChannels(nn.Module):
def forward(self, wav):
batch, sources, channels, time = wav.size()
if self.training and wav.size(2) == 2:
left = th.randint(2, (batch, sources, 1, 1), device=wav.device)
left = th.randint(2, (batch, sources, 1, 1), device=wav.device, dtype=th.int64)
left = left.expand(-1, -1, -1, time)
right = 1 - left
wav = th.cat([wav.gather(2, left), wav.gather(2, right)], dim=2)
Expand Down Expand Up @@ -77,7 +77,10 @@ def forward(self, wav):
device = wav.device

if self.training:
group_size = self.group_size or batch
if self.group_size is not None:
group_size = self.group_size
else:
group_size = batch
if batch % group_size != 0:
raise ValueError(f"Batch size {batch} must be divisible by group size {group_size}")
groups = batch // group_size
Expand Down
46 changes: 24 additions & 22 deletions torchbenchmark/models/demucs/demucs/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
import math

import torch as th
from torch import nn
from torch import Tensor, nn

from .utils import capture_init, center_trim
from torch.nn.modules.conv import Conv1d, ConvTranspose1d
from typing import Union


class BLSTM(nn.Module):
def __init__(self, dim, layers=1):
def __init__(self, dim: int, layers: int=1) -> None:
super().__init__()
self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
self.lstm.flatten_parameters()
Expand All @@ -27,21 +29,21 @@ def forward(self, x):
return x


def rescale_conv(conv: Union[Conv1d, ConvTranspose1d], reference: float) -> None:
    """Rescale a convolution's parameters in place toward a reference std.

    Weight and bias (when present) are both divided by
    ``sqrt(std / reference)``, where ``std`` is the current standard
    deviation of ``conv.weight``. The resulting weight std is therefore
    ``sqrt(std * reference)``, the geometric mean of the old std and the
    reference.

    Args:
        conv: convolution layer whose parameters are modified in place.
        reference: target standard deviation used to compute the scale.
    """
    # detach() keeps the scale factor out of the autograd graph.
    std = conv.weight.std().detach()
    scale = (std / reference)**0.5
    # Writing through .data updates the parameters without recording the op.
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale


def rescale_module(module: th.nn.Module, reference: float) -> None:
    """Rescale every 1-D convolution found inside ``module``.

    Iterates over ``module.modules()`` and passes each ``nn.Conv1d`` or
    ``nn.ConvTranspose1d`` instance to ``rescale_conv``; all other
    submodules are left untouched.

    Args:
        module: root module to traverse.
        reference: reference standard deviation handed to ``rescale_conv``.
    """
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
            rescale_conv(sub, reference)


def upsample(x, stride):
def upsample(x, stride: int):
"""
Linear upsampling, the output will be `stride` times longer.
"""
Expand All @@ -52,7 +54,7 @@ def upsample(x, stride):
return out.reshape(batch, channels, -1)


def downsample(x, stride):
def downsample(x, stride: int):
"""
Downsample x by decimation.
"""
Expand All @@ -64,19 +66,19 @@ class Demucs(th.jit.ScriptModule):

@capture_init
def __init__(self,
sources=4,
audio_channels=2,
channels=64,
depth=6,
rewrite=True,
glu=True,
upsample=False,
rescale=0.1,
kernel_size=8,
stride=4,
growth=2.,
lstm_layers=2,
context=3):
sources: int=4,
audio_channels: int=2,
channels: int=64,
depth: int=6,
rewrite: bool=True,
glu: bool=True,
upsample: bool=False,
rescale: float=0.1,
kernel_size: int=8,
stride: int=4,
growth: float=2.,
lstm_layers: int=2,
context: int=3) -> None:
"""
Args:
sources (int): number of sources to separate
Expand Down Expand Up @@ -194,23 +196,23 @@ def valid_length(self, length):

return int(length)

def forward(self, mix):
def forward(self, mix: Tensor) -> Tensor:
x = mix
saved = [x]
for encode in self.encoder:
x = encode(x)
saved.append(x)
if self.upsample:
x = downsample(x, self.stride)
if self.lstm:
if self.lstm is not None:
x = self.lstm(x)
for decode in self.decoder:
if self.upsample:
x = upsample(x, stride=self.stride)
skip = center_trim(saved.pop(-1), x)
x = x + skip
x = decode(x)
if self.final:
if self.final is not None:
skip = center_trim(saved.pop(-1), x)
x = th.cat([x, skip], dim=1)
x = self.final(x)
Expand Down
12 changes: 6 additions & 6 deletions torchbenchmark/models/demucs/demucs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@

import torch as th
import tqdm
from torch import distributed
from torch import Tensor, distributed
from torch.nn import functional as F
from typing import Callable, Any


def center_trim(tensor, reference):
def center_trim(tensor: Tensor, reference: Tensor) -> Tensor:
"""
Center trim `tensor` with respect to `reference`, along the last dimension.
`reference` can also be a number, representing the length to trim to.
If the size difference != 0 mod 2, the extra sample is removed on the right side.
"""
if hasattr(reference, "size"):
reference = reference.size(-1)
delta = tensor.size(-1) - reference
reference_val: int = reference.size(-1)
delta = tensor.size(-1) - reference_val
if delta < 0:
raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
if delta:
Expand Down Expand Up @@ -177,7 +177,7 @@ def save_model(model, path):
th.save((klass, args, kwargs, state), save_to)


def capture_init(init):
def capture_init(init: Callable) -> Callable:
@functools.wraps(init)
def __init__(self, *args, **kwargs):
self._init_args_kwargs = (args, kwargs)
Expand Down