## Install torchaudio
torchaudio is not installed in Colab by default, so install it first.

In [None]:
!pip install torchaudio==0.7.0

## Imports

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchaudio.datasets.utils import download_url
from IPython import display

# Select the "sox_io" backend for torchaudio's audio I/O.
torchaudio.set_audio_backend("sox_io")

## Network

In [None]:
class FourierLayer(nn.Module):
    """Project inputs through a fixed random matrix and return sin/cos features.

    The projection matrix ``B`` is drawn once from a standard normal
    distribution, multiplied by ``scale``, and stored as a buffer, so it is
    part of the module state but is not a trainable parameter.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Number of projected frequencies; the layer output has
            ``2 * out_features`` values in its last dimension.
        scale: Multiplier applied to the random projection matrix.
    """

    def __init__(self, in_features, out_features, scale):
        super().__init__()
        gaussian = torch.randn(in_features, out_features)
        self.register_buffer("B", gaussian * scale)

    def forward(self, x):
        """Return ``[sin(2*pi*x @ B), cos(2*pi*x @ B)]`` joined on the last dim."""
        angles = (2 * math.pi * x) @ self.B
        sines = torch.sin(angles)
        cosines = torch.cos(angles)
        return torch.cat((sines, cosines), dim=-1)

In [None]:
class SignalRegressor(nn.Module):
    """Fully connected ReLU network with a Tanh output and optional Fourier input layer.

    Args:
        in_features: Size of the last dimension of the input coordinates.
        fourier_features: Number of frequencies for the leading
            ``FourierLayer``; pass ``None`` to feed the raw input straight
            into the first linear layer.
        hidden_features: Width of every hidden linear layer.
        hidden_layers: Number of hidden linear layers. The first one is always
            created, so values below 1 still yield a single hidden layer.
        out_features: Size of the last dimension of the output.
        scale: Scale passed to ``FourierLayer``; unused when
            ``fourier_features`` is ``None``.
    """

    def __init__(self, in_features, fourier_features,
                 hidden_features, hidden_layers, out_features, scale):
        super().__init__()

        # Modules are created in the order they are applied, so parameter
        # initialisation consumes the random stream in that same order.
        layers = []
        if fourier_features is None:
            layers.append(nn.Linear(in_features, hidden_features))
        else:
            layers.append(FourierLayer(in_features, fourier_features, scale))
            # The Fourier layer emits a sin and a cos value per frequency.
            layers.append(nn.Linear(2 * fourier_features, hidden_features))
        layers.append(nn.ReLU())

        # Remaining hidden layers beyond the first one.
        for _ in range(hidden_layers - 1):
            layers.extend((nn.Linear(hidden_features, hidden_features), nn.ReLU()))

        # Output head; Tanh keeps predictions inside [-1, 1].
        layers.append(nn.Linear(hidden_features, out_features))
        layers.append(nn.Tanh())

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.net(x)

## Dataset

In [None]:
class AudioDataset(Dataset):
    """Single-item dataset that exposes one audio file as (times, frames).

    Args:
        audio_path: Path of the audio file. Its metadata is read once at
            construction and kept in ``self.metadata``.
    """

    def __init__(self, audio_path):
        super().__init__()
        self.metadata = torchaudio.info(audio_path)
        self.audio_path = audio_path

    def __len__(self):
        """Return 1: the whole clip is the only item."""
        return 1

    def __getitem__(self, idx):
        """Load the clip and return ``(times, frames)``.

        ``idx`` is not consulted; every call reloads the same file.
        ``frames`` is loaded with ``channels_first=False`` and ``times`` is an
        evenly spaced grid over [0, 1] with one entry per frame.
        """
        frames, _ = torchaudio.load(self.audio_path, channels_first=False)
        num_frames = frames.shape[0]
        times = torch.linspace(0, 1, steps=num_frames)
        return times, frames

## Play Audio

In [None]:
# Fetch the sample clip and store it in the current directory as "piano.wav".
web_url = "https://upload.wikimedia.org/wikipedia/commons/7/70/Emotional_piano.wav"
download_url(web_url, ".", "piano.wav")

In [None]:
audio_path = "piano.wav"
# Trailing expression: the notebook renders it as an inline audio player.
display.Audio(audio_path)

## Dataloader

In [None]:
# The dataset holds exactly one item (the whole clip), so a batch size of 1
# delivers it in a single batch.
audio_data = AudioDataset(audio_path)
audio_loader = DataLoader(audio_data, batch_size=1)

## Train and Evaluate

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
total_steps = 2000
summary_interval = 100

# The loader yields one batch containing the whole clip.
times, frames = next(iter(audio_loader))
# Drop only the leading batch axis. A bare squeeze() would also remove the
# channel axis of single-channel audio, leaving targets of shape (N,) that
# F.mse_loss would broadcast against the (N, 1) network output and so
# compute the wrong loss.
times, frames = times.squeeze(0).to(device), frames.squeeze(0).to(device)
train_coords, train_values = times[::2].reshape(-1, 1), frames[::2]   # use every other frame for training
test_coords, test_values = times.reshape(-1, 1), frames    # use all the frames for evaluation

# One output per audio channel; the input is the scalar time coordinate.
audio_regressor = SignalRegressor(in_features=1, fourier_features=256,
                                  hidden_features=256, hidden_layers=4,
                                  out_features=audio_data.metadata.num_channels,
                                  scale=5000).to(device)
optim = torch.optim.Adam(lr=1e-4, params=audio_regressor.parameters())

for step in range(1, total_steps+1):
    # Full-batch gradient step on the training frames.
    audio_regressor.train()
    optim.zero_grad()
    output = audio_regressor(train_coords)
    train_loss = F.mse_loss(output, train_values)
    train_loss.backward()
    optim.step()

    # Every summary_interval steps, score the model on all frames.
    if step % summary_interval == 0:
        audio_regressor.eval()
        with torch.no_grad():
            # `prediction` is kept at module level; the final one is saved
            # to disk by the following cell.
            prediction = audio_regressor(test_coords)
            test_loss = F.mse_loss(prediction, test_values)
            # PSNR computed as -10*log10(MSE), i.e. with a peak value of 1.
            test_psnr = -10*torch.log10(test_loss)
            print(f"Step: {step}, Test PSNR: {test_psnr.item():.6f}")

## Temporal SuperResolution result
The model is trained on every other frame and evaluated on all frames, i.e. it increases the audio frame rate by 2x.

In [None]:
# Write the most recent evaluation prediction to disk at the source clip's
# sample rate. channels_first=False matches the layout the clip was loaded in.
super_path = "piano_super.wav"
torchaudio.save(super_path, src=prediction.cpu(),
                sample_rate=audio_data.metadata.sample_rate,
                channels_first=False)

# Trailing expression: the notebook renders it as an inline audio player.
display.Audio(super_path)