Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions extension/audio/mel_spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,23 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
r"""
Args:
waveform (`torch.Tensor`): Mono waveform input, tensor of (dynamic) shape [num_samples],
where num_samples < n_samples. n_samples is 480000 for 16kHz and chunk length 30

Returns:
torch.Tensor: Output of fixed shape [1, feature_size, nb_max_frames]
[1, 80, 3000] with default options
torch.Tensor: Output of shape [1, feature_size, nb_max_frames * n_chunks]
n_chunks is the number of chunks of `n_samples` samples needed to cover the input waveform (rounded up).
[1, 80, 3000] with default options and 1 chunk
"""
# TODO: pad up to multiples of chunk_length (currently 1 chunk of 30 sec)
n_chunks = (waveform.shape[0] - 1) // self.n_samples + 1
waveform = F.pad(
waveform,
(0, self.n_samples - waveform.shape[0] - 1),
(0, self.n_samples * n_chunks - waveform.shape[0]),
mode="constant",
value=self.padding_value,
)
# Ideally we should do:
# window = torch.hann_window(self.n_fft)
# but this is not currently supported when lowering.
# torch.hann_window has slightly better numerics (worst discrepancy is <1e-5 instead of 1e-4)
window = 0.5 * (
1
- torch.cos(
Expand All @@ -145,10 +149,6 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
/ self.n_fft
)
)
# Ideally we should do instead
# window = torch.hann_window(self.n_fft)
# but this is not currently supported when lowering
# torch.hann_window has slightly better numerics (worst discrepancy is <1e-5 instead of 1e-4)
stft = torch.stft(
waveform,
n_fft=self.n_fft,
Expand All @@ -157,7 +157,7 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
center=True,
return_complex=True,
)
magnitudes = torch.abs(stft) ** 2 # pyre-ignore[58]
magnitudes = torch.abs(stft)[..., :-1] ** 2 # pyre-ignore[58]

mel_spec = self.mel_filters @ magnitudes

Expand All @@ -173,8 +173,7 @@ def export_processor():
audio_tensor = torch.randn(480000)
chunk_tensor = audio_tensor[:93680]
with torch.no_grad():
# export. What is the min of waveforms?
dim = Dim("waveform", min=1600, max=audio_tensor.size(0))
dim = Dim("waveform", min=1600, max=audio_tensor.size(0) * 10) # 10 chunks max
ep: ExportedProgram = export(
model, (chunk_tensor,), dynamic_shapes={"waveform": {0: dim}}, strict=True
)
Expand Down
Loading