Implement fft torchop #2141
Conversation
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #2141      +/-   ##
==========================================
+ Coverage   74.25%   74.27%   +0.02%
==========================================
  Files         226      226
  Lines       29442    29454      +12
  Branches     3424     3432       +8
==========================================
+ Hits        21862    21878      +16
+ Misses       6424     6419       -5
- Partials     1156     1157       +1

☔ View full report in Codecov by Sentry.
The error looks like a mismatch of complex (torch) vs real representation (ONNX). Maybe explore
Or using
The output doesn't need to be converted. We always use the real representation for complex values in ONNX; the torch exporter will keep track of this information.
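For reference, a minimal sketch of what the real representation means on the torch side (illustration only, not part of this PR): a complex tensor is carried as a real tensor with a trailing dimension of size 2 holding (real, imaginary) pairs, which is what torch.view_as_real / torch.view_as_complex convert between.

import torch

# Complex tensor and its real representation: the real view has shape (..., 2)
# with (real, imaginary) pairs in the last dimension.
c = torch.randn(4, dtype=torch.complex64)
r = torch.view_as_real(c)   # shape (4, 2), dtype float32
assert r.shape == (4, 2)
assert torch.equal(torch.view_as_complex(r), c)  # lossless round trip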
For the errors
@bmehta001 not sure why the CLA bot is complaining even though you have joined the org. Can you comment following the second (company) option?
Could you test (pytorch/pytorch#119360)

import torch
from torch import nn


class iRFFTModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.fft.irfft(torch.view_as_complex(x), n=960)


example_tensor = torch.rand(481, 2)
irfft_model = iRFFTModel()
irfft_model_onnx_path = "irfft_model.onnx"
out_torch = irfft_model(example_tensor)
print("torch shape:", out_torch.shape)
torch.onnx.export(
    irfft_model,
    (example_tensor,),
    dynamo=True,
    verify=True
)

import torch
from torch import nn
import torchaudio


class DataCov(nn.Module):
    def __init__(self):
        super(DataCov, self).__init__()
        self.transform = nn.Sequential(
            torchaudio.transforms.MelSpectrogram(
                sample_rate=48000, n_fft=1536, hop_length=768, f_min=20, f_max=20000
            )
        )

    def forward(self, x1):
        return self.transform(x1)


def load_data_cov():
    module = DataCov().to(torch.float32).to("cpu")
    module.eval()
    return module


data_cov = load_data_cov()
x = torch.randn((1, 1, 12 * 48000), dtype=torch.float32, device="cpu")
y = data_cov(x)
input_names = ["x"]
output_names = ["output"]
torch.onnx.export(
    data_cov,
    (x,),
    dynamo=True,
    verify=True,
)

import torch
import torch.nn as nn


def fftconv(u, k, D):
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen
    k_f = torch.fft.rfft(k, n=fft_size) / fft_size
    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
    if len(u.shape) > 3:
        k_f = k_f.unsqueeze(1)
    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]
    out = y + u * D.unsqueeze(-1)
    return out.to(dtype=u.dtype)


class Filter(nn.Module):
    def forward(self, x, k=None, bias=None):
        y = fftconv(x, k, bias)
        return y


filter = Filter().eval()
x_input, k_input, bias_input = torch.rand(1, 512, 1024), torch.rand(512, 1024), torch.rand(512)
export_output = torch.onnx.export(filter, (x_input, k_input, bias_input), dynamo=True, verify=True)
@microsoft-github-policy-service agree company="Microsoft"
I am going to take the liberty to merge this PR so that the main improvements can be made available. Please feel free to create follow-ups when you get a chance to look at the pytorch examples. Thanks for the great work!
(1, 2),
(0, 1),
(0, 1, 2),
Note to self: this should be added back in a follow-up
WIP

- r2c = forwards, could be one-sided
- c2r = backwards/inverse, never one-sided
- c2c could be either forwards or backwards, never one-sided
- Must respect the normalization method provided; however, op.DFT applies "backwards" normalization when 'inverse' is set to True, so we need to account for the normalization already done by op.DFT.
- When running the above functions across multiple axes, the FFT needs to be run through op.DFT one axis at a time, in reverse order (see the sketch below).
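To illustrate the last two points, here is a small NumPy sketch (illustration only, not the onnxscript implementation): a multi-axis real FFT decomposes into one r2c transform on the last axis followed by c2c transforms on the remaining axes, applied one axis at a time in reverse order, and the "forward"/"backward" normalization conventions differ only in where the 1/n factor goes.

import numpy as np

x = np.random.rand(4, 8, 16)

# Multi-axis real FFT == r2c on the last axis (may be one-sided), then c2c on
# the remaining axes, applied one axis at a time, last axis first.
ref = np.fft.rfftn(x, axes=(1, 2))
step = np.fft.rfft(x, axis=2)    # r2c, one-sided: length 16 -> 16 // 2 + 1 = 9
step = np.fft.fft(step, axis=1)  # c2c on the remaining axis, never one-sided
assert np.allclose(ref, step)

# Normalization: "backward" puts the 1/n factor on the inverse transform only
# (the default), "forward" moves it onto the forward transform instead.
sig = np.random.rand(16)
assert np.allclose(np.fft.fft(sig, norm="forward"), np.fft.fft(sig) / sig.size)
assert np.allclose(np.fft.ifft(np.fft.fft(sig)), sig)  # default round trip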
Currently have issues with:
#1271