Skip to content

Commit

Permalink
Merge branch 'main' into pytorch-travis
Browse files Browse the repository at this point in the history
  • Loading branch information
turian committed Sep 11, 2022
2 parents 79cf952 + 42229fa commit 6ee96f8
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 14 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
repos:
- repo: https://github.com/kynan/nbstripout
rev: 0.3.9
rev: 0.6.0
hooks:
- id: nbstripout
- repo: https://github.com/mwouts/jupytext
rev: v1.11.2
rev: v1.14.1
hooks:
- id: jupytext
args: [--sync, --pipe, black]
additional_dependencies:
- black==21.5b0 # Matches hook
- black==22.6.0 # Matches hook
- repo: https://github.com/psf/black
rev: 21.5b0
rev: 22.6.0
hooks:
- id: black
language_version: python3
2 changes: 1 addition & 1 deletion examples/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def display(*args, **kwargs):
ADSR,
VCA,
ControlRateUpsample,
FmVCO,
MonophonicKeyboard,
Noise,
SineVCO,
FmVCO,
)
from torchsynth.parameter import ModuleParameterRange

Expand Down
7 changes: 3 additions & 4 deletions examples/simplesynth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,18 @@
# +
from typing import Optional

import torch
import IPython.display as ipd
import torch

from torchsynth.config import BASE_REPRODUCIBLE_BATCH_SIZE, SynthConfig
from torchsynth.module import (
ADSR,
VCA,
ControlRateUpsample,
MonophonicKeyboard,
SquareSawVCO,
VCA,
)
from torchsynth.synth import AbstractSynth
from torchsynth.config import SynthConfig, BASE_REPRODUCIBLE_BATCH_SIZE


# -

Expand Down
21 changes: 21 additions & 0 deletions tests/test_signal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
Tests for torch signals
"""

from copy import deepcopy

import torch

from torchsynth.signal import Signal


class TestSignal:
"""
Tests for Signal
"""

device = "cuda" if torch.cuda.is_available() else "cpu"

def test_deepcopy(self):
signal = torch.zeros(1, 1).as_subclass(Signal)
print(deepcopy(signal))
3 changes: 2 additions & 1 deletion tests/test_synth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
Tests for torch synths
"""

import os
import json
import os

import pytest
import torch.nn
from torch import tensor
Expand Down
4 changes: 2 additions & 2 deletions torchsynth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from torchsynth.__info__ import ( # noqa: F401
from torchsynth.__info__ import __author_email__ # noqa: F401
from torchsynth.__info__ import (
__author__,
__author_email__,
__copyright__,
__docs__,
__homepage__,
Expand Down
1 change: 0 additions & 1 deletion torchsynth/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import torch


#: This batch size is a nice trade-off between speed and memory consumption. On
#: a typical GPU this consumes ~2.3GB of memory for the default Voice.
#: Learn more about `batch processing <../performance/batch-processing.html>`_.
Expand Down
2 changes: 1 addition & 1 deletion torchsynth/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import tensor
from torch import Tensor as T
from torch import tensor

import torchsynth.util as util
from torchsynth.config import BASE_REPRODUCIBLE_BATCH_SIZE, SynthConfig
Expand Down
12 changes: 12 additions & 0 deletions torchsynth/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,15 @@ def batch_size(self):
def num_samples(self):
assert self.ndim == 2
return self.shape[1]

def new_empty(self, *args, **kwargs):
# noqa: E501
"""
Implement
[torch.Tensor.new_empty](https://pytorch.org/docs/stable/generated/torch.Tensor.new_empty.html)
so that
[`deepcopy`](https://docs.python.org/3/library/copy.html#copy.deepcopy)
can be run on Signal objects.
"""

return super().new_empty(*args, **kwargs).as_subclass(self.__class__)

0 comments on commit 6ee96f8

Please sign in to comment.