In [39]:
import torch
from torch import nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import imageio as iio
import matplotlib.pyplot as plt
from pmqd.torch import PMQD
from typing import Tuple
import itertools
import torchaudio.transforms as T
from pyramids import LaplacianPyramid
import torch.nn.functional as F
import conv as conv_utils
# from torchvision import datasets
# from torchvision.transforms import ToTensor

Load audio data.

In [28]:
dataset = PMQD(root="/Users/up20938/Coding/datasets/pmqd", download=False)
sample_rate = PMQD.SAMPLE_RATE

Define Laplacian Pyramids and global variables.

In [60]:
LAPLACIAN_FILTER = np.array([[0.0025, 0.0125, 0.0200, 0.0125, 0.0025],
                             [0.0125, 0.0625, 0.1000, 0.0625, 0.0125],
                             [0.0200, 0.1000, 0.1600, 0.1000, 0.0200],
                             [0.0125, 0.0625, 0.1000, 0.0625, 0.0125],
                             [0.0025, 0.0125, 0.0200, 0.0125, 0.0025]],
                            dtype=np.float32)
filt = torch.from_numpy(np.reshape(np.tile(LAPLACIAN_FILTER, (1, 1, 1)), (1, 1, 5, 5)))

In [None]:
FILTER_SIZE = 5
NUM_LAYERS = 6

In [446]:
class LaplacianPyramid(nn.Module):
    def __init__(self, k, dims=3, filt_size=3, filt=None, trainable=False):
        super(LaplacianPyramid, self).__init__()
        if filt is None:
            filt = np.reshape(np.tile(LAPLACIAN_FILTER, (dims, 1, 1)),
                              (dims, 1, 5, 5))
        self.k = k
        self.trainable = trainable
        self.dims = dims
        self.filt_size = filt_size
        self.filt = nn.Parameter(torch.Tensor(filt), requires_grad=False)
        self.dn_filts, self.sigmas = self.DN_filters()

    """
    Spectrogram Filters
    """

    def DN_filters(self):
        sigmas = torch.tensor([1.19, 1.2, 1.22, 1.08, 0.92, 0.83]).reshape(6,1,1,1)
        dn_filts = torch.zeros(self.k,1,1,(self.filt_size**2)-1)
        dn_filts = nn.Parameter(dn_filts, requires_grad=self.trainable)
        return dn_filts, sigmas

    def pyramid(self, im):
        abso = []
        J = im
        pyr = []
        for i in range(0, self.k):
            J_padding_amount = conv_utils.pad([J.size(2), J.size(3)],
                                              self.filt.size(3), stride=2)
            I = F.conv2d(F.pad(J, J_padding_amount, mode='reflect'), self.filt,
                         stride=2, padding=0, groups=self.dims)  # downsample and convolve
            I_up = F.interpolate(I, size=[J.size(2), J.size(3)],
                                 align_corners=True, mode='bilinear')  # upsample
            I_padding_amount = conv_utils.pad([I_up.size(2), I_up.size(3)],
                                              self.filt.size(3), stride=1)
            I_up_conv = F.conv2d(F.pad(I_up, I_padding_amount, mode='reflect'),  # convolve
                                 self.filt, stride=1, padding=0,
                                 groups=self.dims)
            out = J - I_up_conv  # z_k
            absolute = torch.abs(out)
            abso.append(absolute)
            out_padding_amount = conv_utils.pad(
                [out.size(2), out.size(3)], self.filt_size, stride=1)
            Lw =  np.insert(self.dn_filts[i],len(self.dn_filts[i])//2,0).reshape(1,1,self.filt_size,self.filt_size)
            out_conv = F.conv2d(
                F.pad(torch.abs(out), out_padding_amount, mode='reflect'),
               Lw,
                stride=1,
                groups=self.dims)
            out_norm = out / (self.sigmas[i]+out_conv) # y_k
            # print(out_norm.min(), out_norm.max())
            pyr.append(out_norm)

            J = I
        return (abso, pyr)

    def compare(self, x1, x2):
        y1 = self.pyramid(x1)
        y2 = self.pyramid(x2)
        total = []
        # Calculate difference in perceptual space (Tensors are stored
        # strangley to avoid needing to pad tensors)
        for z1, z2 in zip(y1, y2):
            diff = (z1 - z2) ** 2
            sqrt = torch.sqrt(torch.mean(diff, (1, 2, 3)))
            total.append(sqrt)
        return torch.norm(torch.stack(total), 0.6)

    def forward(self, neighbourhood, sigma):
        return (self.dn_filts.expand(-1, -1, neighbourhood.shape[1], -1) * neighbourhood).sum(-1).unsqueeze(-1) + sigma

In [400]:
def get_abs_value(im, k):
    sigmas = np.zeros((k))
    J = im
    for i in range(0, k):
        J_padding_amount = conv_utils.pad([J.size(2), J.size(3)], filt.size(3), stride=2)
        I = F.conv2d(F.pad(J, J_padding_amount, mode='reflect'), filt, stride=2, padding=0, groups=1)  # downsample and convolve
        I_up = F.interpolate(I, size=[J.size(2), J.size(3)], align_corners=True, mode='bilinear')  # upsample
        I_padding_amount = conv_utils.pad([I_up.size(2), I_up.size(3)], filt.size(3), stride=1)
        I_up_conv = F.conv2d(F.pad(I_up, I_padding_amount, mode='reflect'), filt, stride=1, padding=0, groups=1)  # convolve
        out = J - I_up_conv  # z_k
        abs_out = np.abs(out)
        sigmas[i] = abs_out.mean()
        J = I
    return sigmas

Define spectrogram parameters

In [381]:
reduced_sample_rate = 16050
window_size = 2048
hop_size = 64
num_mels = 512

In [382]:
transforms = nn.Sequential(
    T.Resample(orig_freq=sample_rate, new_freq=reduced_sample_rate),
    T.MelSpectrogram(
        n_mels=num_mels,
        n_fft=window_size,
        win_length=window_size,
        hop_length=hop_size,
        power=1,
        center=False,
        sample_rate=reduced_sample_rate,
        f_min=0,
        f_max=reduced_sample_rate / 2,
        window_fn=torch.hann_window
    ),
)



Define torch dataset for getting divisive normalisation sigma values

In [383]:
class SpectrogramDatasetSigmas(Dataset):

    def __init__(self, dataset, transforms):
        self.data = dataset
        self.transforms = transforms

    def __len__(self):
        return 195 # 195 images per distortion type
        # return len(self.data)

    def __getitem__(self, idx):
        # convert from audio to spectrogram
        audio = self.data[idx]["audio"]
        mono = audio.mean(axis=0)
        image = self.transforms(mono).unsqueeze(0).unsqueeze(0)

        # extract sigmas
        sigmas = get_abs_value(image, 6)

        return sigmas

In [384]:
batch_size = 1

# Create data loaders.
sigma_data = SpectrogramDatasetSigmas(dataset,transforms=transforms)
sigma_dataloader = DataLoader(sigma_data, batch_size=batch_size)

sigmas = np.zeros((195,6))
for batch, s in enumerate(sigma_dataloader):
    sigmas[batch] = s
print(sigmas.mean(axis=0))

[1.19 1.2  1.22 1.08 0.92 0.83]


Define functions to load pixel values and their neighbourhoods from spectrograms.

In [464]:
def get_pixel_neighbourhoods(image, filt_size):

    # image size
    h = image.shape[0]
    w = image.shape[1]

    # filter size MUST BE ODD
    fh = filt_size
    fw = filt_size

    # distance from edge of filter to centre of filter
    zh = ((fh+1)//2)-1
    zw = ((fw+1)//2)-1

    # border for non-padded images
    y=h-(2*zh)
    x=w-(2*zw)

    # pixel indices for all possible filter positions
    yy,xx=np.meshgrid(np.arange(y),np.arange(x))
    p = ((yy+zh)*(w)) + zw + (xx)
    p = p.T.flatten().reshape(-1,1)

    # neighbourhood indices around all possible pixels
    nh = np.tile(np.arange(-zh,fh-zh).reshape(-1,1),fw)*h
    nw = np.tile(np.arange(-zw,fw-zw).reshape(-1,1),fh).T
    h2 = (nh+nw).flatten()
    d = ((fw*fh)-1)//2
    mask = np.hstack((np.arange(d),np.arange(d)+d+1))
    hh = np.array([h2[mask]])
    n = p+hh

    # get pixel values from indices
    pixels = image.flatten()[p]
    neighbourhoods=image.flatten()[n]

    return neighbourhoods.unsqueeze(0), pixels.unsqueeze(0)

In [474]:
class SpectrogramDataset(Dataset):

    def __init__(self, dataset, transforms):
        self.data = dataset
        self.transforms = transforms

    def __len__(self):
        return 195 # 195 images per distortion type
        # return len(self.data)

    def __getitem__(self, idx):
        # convert from audio to spectrogram
        audio = self.data[idx]["audio"]
        mono = audio.mean(axis=0)
        image = self.transforms(mono)

        # pass therough Laplacian Pyramid
        ndims = NUM_LAYERS
        pmd = LaplacianPyramid(ndims,filt_size=FILTER_SIZE,dims=1)
        abso, pyr = pmd.pyramid(image.unsqueeze(0).unsqueeze(0).float())


        n_list = []
        p_list = []
        s_list = []

        for i in range(ndims):
            image = abso[i].squeeze(0).squeeze(0)

            # extract pixels and neighbourhoods
            neighbourhoods, pixels = get_pixel_neighbourhoods(image,filt_size=FILTER_SIZE)
            sigma = pmd.sigmas[i]
            n_list.append(neighbourhoods)
            p_list.append(pixels)
            s_list.append(sigma)

        return n_list, p_list, s_list

In [476]:
batch_size = 1

# Create data loaders.
train_data = SpectrogramDataset(dataset,transforms=transforms)
train_dataloader = DataLoader(train_data, batch_size=batch_size)
# test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y, s in train_dataloader:
    print(f"Shape of X [N, C, H, W]: {X[0].shape}")
    print(f"Shape of y: {y[0].shape} {y[0].dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([1, 1, 518160, 24])
Shape of y: torch.Size([1, 1, 518160, 1]) torch.float32


Define torch training procedure

In [477]:
device='cpu'
print(f"Using {device} device")

# model = Filter(5,5,batch_size=batch_size).to(device)
model = LaplacianPyramid(k=NUM_LAYERS,filt_size=FILTER_SIZE,trainable=True)
print(model)

Using cpu device
LaplacianPyramid()


In [478]:
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [479]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    # prev_filter = torch.zeros((batch_size,1,8),requires_grad=True)
    model.train()
    for batch, (X, y, s) in enumerate(dataloader):
        # X, y = X.to(device), y.to(device)
        # print(X,y)

        losses = torch.tensor(0).float()

        # Compute prediction error
        for i in range(len(X)):
            pred = model(X[i],s[i])

            # loss is the difference between the absolute value and the dot product of the absolute value and weights plus sigma
            loss = loss_fn(pred[i].unsqueeze(0), y[i]).sum()
            losses += loss

        loss = losses/len(X)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        for param in model.parameters():
            param.data.clamp_(0)

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1)  # * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            for i in range(len(X)):
                print(f'filter {i}: \n{model.dn_filts[i].data}')


In [480]:
torch.set_printoptions(precision=4, threshold=None, edgeitems=None, linewidth=200, profile=None, sci_mode=False)
epochs = 50
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    # test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 7.830526  [    1/  195]
filter 0: 
tensor([[[0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
          0.0010]]])
filter 1: 
tensor([[[0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
          0.0010]]])
filter 2: 
tensor([[[0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
          0.0010]]])
filter 3: 
tensor([[[0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
          0.0010]]])
filter 4: 
tensor([[[0.001

In [485]:
torch.set_printoptions(precision = 8)

In [486]:
model.dn_filts.data

tensor([[[[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 8.15232396e-02,
           3.54153365e-01, 3.54284495e-01, 8.20538402e-02, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
           0.00000000e+00, 0.00000000e+00]]],


        [[[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
           4.17492151e-01, 4.18497950e-01, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
           0.00000000e+00, 0.00000000e+00]]],


        [[[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0

In [None]:
"""
dn_filts = torch.tensor([[[[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
               0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
               8.15232396e-02, 3.54153365e-01, 0.00000000e+00, 3.54284495e-01, 8.20538402e-02,
               0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
               0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]],

            [[[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
               0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
               0.00000000e+00, 4.17492151e-01, 0.00000000e+00, 4.18497950e-01, 0.00000000e+00,
               0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
               0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]],

            [[[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
               0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
               0.00000000e+00, 4.03308600e-01, 0.00000000e+00, 4.04937327e-01, 0.00000000e+00,
               0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
               0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]],

            [[[0.00000000e+00, 0.00000000e+00, 1.80272770e-03, 0.00000000e+00, 0.00000000e+00,
               0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
               0.00000000e+00, 3.42956990e-01, 0.00000000e+00, 3.44613850e-01, 0.00000000e+00,
               0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
               0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]],

            [[[0.00000000e+00, 0.00000000e+00, 7.62450695e-02, 0.00000000e+00, 0.00000000e+00,
               0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
               4.72654792e-05, 1.49145573e-01, 0.00000000e+00, 1.53620824e-01, 3.48444119e-05,
               0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
               0.00000000e+00, 0.00000000e+00, 8.36522579e-02, 0.00000000e+00, 0.00000000e+00]]],

            [[[0.00000000e+00, 0.00000000e+00, 5.48025630e-02, 0.00000000e+00, 0.00000000e+00,
               0.00000000e+00, 8.64099056e-05, 1.37867581e-04, 0.00000000e+00, 6.32352458e-05,
               3.57608515e-04, 1.15041323e-02, 0.00000000e+00, 1.57447513e-02, 4.30504908e-04,
               5.56926061e-05, 5.07514887e-05, 1.43533718e-04, 6.86229760e-06, 0.00000000e+00,
               0.00000000e+00, 0.00000000e+00, 1.87745407e-01, 0.00000000e+00, 0.00000000e+00]]]])
"""