In [2]:
cd ..

/home/bvandelft/Projects/Audio/clearaudio


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchaudio
from clearaudio.utils.wavenet_utils import (
    overlap_and_add_samples,
    cut_track_stack,
    linear_pcm,
    generate_waveform
)

In [63]:
def cut_track_stack(audio_input_tensor, window_length=8192*2**4, overlap=0.5):
    """
    Cuts a given track in overlapping windows and stacks them along a new axis
    :param track_path: path to .wav track to apply the function on
    :param window_length: number of samples per window (scalar int)
    :param overlap: ratio of overlapping samples for consecutive samples (scalar int in [0, 1))
    :return: processed track as a numpy array with dimension [window_number, 1, window_length], sampling frequency
    """
    track = audio_input_tensor
    # Get number of windows and prepare empty array
    window_number = compute_window_number(track_length=track.size(-1), window_length=window_length,overlap=overlap)
    print(window_number)
    bsz = track.size(0)
    #cut_track = np.zeros((window_number, window_length))
    cut_track = torch.zeros((bsz,window_number, window_length))
    # Cut the tracks in smaller windows
    for music_index in range(bsz):
        for i in range(window_number):
            window_start = int(i * (1 - overlap) * window_length)
            window = track[:,window_start: window_start + window_length]

            # Check if last window needs padding
            if window.size(-1) != window_length:
                padding = window_length - window.size(1)
                window = torch.cat([window, torch.zeros([bsz,padding])], axis = 1)
            cut_track[music_index,i] = window
    return cut_track


def compute_window_number(track_length: int, window_length: int = 8192*2**5, overlap: float = 0.5):
    """
    Computes the number of overlapping window for a specific track.
    :param track_length: total number of samples in the track (scalar int).
    :param window_length: number of samples per window (scalar int).
    :param overlap: ratio of overlapping samples for consecutive samples (scalar int in [0, 1))
    :return: number of windows in the track
    """
    num = track_length - window_length
    den = window_length * (1 - overlap)
    return int(num // den + 2)


# def overlap_and_add_samples(clean_tracks, equalized_tracks, generated_tracks, overlap, window_length, use_windowing=True):
#     """
#     Re-construct a full sample from its sub-parts using the OLA algorithm.
#     :param batch: input signal previously split in overlapping windows torch tensor of shape [B, 1, WINDOW_LENGTH].
#     :return: reconstructed sample (torch tensor).
#     """
#     assert(clean_tracks.size() == generated_tracks.size())
#     # Compute the size of the full sample

#     bsz, window_number, single_sample_size = equalized_tracks.size()
#     full_sample_size = int(single_sample_size * (1 + (window_number - 1) * (1 - overlap)))

#     # Initialize the full sample
    
#     clean_full = torch.zeros((bsz,full_sample_size))
#     equalized_full = torch.zeros((bsz,full_sample_size))
#     generated_full = torch.zeros((bsz,full_sample_size))
#     # print(clean_full.size())
#     if use_windowing:
#         hanning = torch.from_numpy(np.hanning(window_length))

#     for batch in range(bsz):
#         for window_index in range(window_number):
#             window_start = int(window_index * (1 - overlap) * window_length)
#             window_end = window_start + window_length
           
#             clean_sample= clean_tracks[batch,window_index].squeeze()
#             equalized_sample = equalized_tracks[batch,window_index].squeeze()
#             generated_sample = generated_tracks[batch,window_index].squeeze()

#             if use_windowing:               
#                 generated_sample *= hanning
#                 clean_sample *= hanning
#                 equalized_sample *= hanning
#             clean_full[batch,window_start: window_end] += clean_sample
#             equalized_full[batch,window_start: window_end] += equalized_sample
#             generated_full[batch,window_start: window_end] += generated_sample
            
#         return clean_full, equalized_full,generated_full


def overlap_and_add_samples(samples_tensor: torch.Tensor, overlap: float, window_length: int, use_windowing: bool = True) -> torch.Tensor:
    """
    Re-construct a full sample from its sub-parts using the OLA algorithm.
    :param batch: input signal previously split in overlapping windows torch tensor of shape [B, 1, WINDOW_LENGTH].
    :return: reconstructed sample (torch tensor).
    """
    # Compute the size of the full sample
    bsz, window_number, single_sample_size = samples_tensor.size()
    full_sample_size = int(single_sample_size * (1 + (window_number - 1) * (1 - overlap)))

    # Initialize the full sample
    full = torch.zeros((bsz, full_sample_size))

    if use_windowing:
        hanning = torch.from_numpy(np.hanning(window_length))

    for batch in range(bsz):
        for window_index in range(window_number):
            window_start = int(window_index * (1 - overlap) * window_length)
            window_end = window_start + window_length
           
            sample = samples_tensor[batch, window_index].squeeze()
            if use_windowing:               
                sample *= hanning

            full[batch,window_start: window_end] += sample
        return full


In [37]:
sample

tensor([[ 0.0000,  0.0000,  0.0000,  ..., -0.2199, -0.1982, -0.2037]])

In [39]:
sample, sr = torchaudio.load('samples/hq/starwars.wav')

In [68]:
cts = cut_track_stack(sample[:, :2**8+3], window_length=2**2)

129


In [69]:
overlap_and_add_samples(cts, window_length=2**2, overlap=0.5).size()

torch.Size([1, 260])

In [None]:
number_window = int

In [24]:
2**5

32

In [23]:
2**8


256

In [35]:
cts.size()

torch.Size([1, 7, 4])

In [36]:
overlap_and_add_samples(cts,window_length=2**2, overlap=0.5).size()

torch.Size([1, 16])

In [20]:
sample[:, :2**8]

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [2]:
class ResConv1DBlock(nn.Module):
    def __init__(self, n_in, n_state, dilation=1, zero_out=False, res_scale=1.0):
        super().__init__()
        padding = dilation
        self.model = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(n_in, n_state, 3, 1, padding='same', dilation=dilation),
            nn.ReLU(),
            nn.Conv1d(n_state, n_in, 1, 1, 0),
        )
        if zero_out:
            out = self.model[-1]
            nn.init.zeros_(out.weight)
            nn.init.zeros_(out.bias)
        self.res_scale = res_scale

    def forward(self, x):
        y = x + self.res_scale * self.model(x)       
        return y

class Resnet1D(nn.Module):
    def __init__(self, n_in, n_depth, m_conv=1.0, dilation_growth_rate=1, dilation_cycle=None, zero_out=False, res_scale=False):
        super().__init__()
        def _get_depth(depth):
            if dilation_cycle is None:
                return depth
            else:
                return depth % dilation_cycle
        blocks = [ResConv1DBlock(n_in, int(m_conv * n_in),
                                 dilation=dilation_growth_rate ** _get_depth(depth),
                                 zero_out=zero_out,
                                 res_scale=1.0 if not res_scale else 1.0 / math.sqrt(n_depth))
                  for depth in range(n_depth)]

        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        y = self.model(x)
        print("ResNet : ", y.size())
        return y #self.model(x)

In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.block = nn.Sequential(
            GroupNorm(in_channels),
            Swish(),
            nn.Conv1d(in_channels, out_channels, 3, 1, 1),
            GroupNorm(out_channels),
            Swish(),
            nn.Conv1d(out_channels, out_channels, 3, 1, 1)
        )
        if in_channels != out_channels:
            self.channel_up = nn.Conv1d(in_channels, out_channels, 1, 1, 0)

    def forward(self, x):
        if self.in_channels != self.out_channels:
            return self.block(x) + self.channel_up(x)
        else:
            return x + self.block(x)


class UpSampleBlock(nn.Module):
    def __init__(self, channels_in, channels_out, kernel_size = 3, stride = 2, pad = 1, mode="pixel_shuffle"):
        super(UpSampleBlock, self).__init__()
        self.scale_factor = stride
        self.mode = mode
        if self.mode == 'interpolate':
            self.conv = nn.Conv1d(channels_in, channels_out, kernel_size, padding = 'same')
        elif self.mode == 'transpose_conv':
            self.conv = nn.ConvTranspose1d(channels_in, channels_out, kernel_size, stride, pad)
        elif self.mode == "pixel_shuffle":
            self.conv = nn.Conv1d(channels_in, channels_in*self.scale_factor, kernel_size, padding = 'same')
            self.shuffle = nn.PixelShuffle(self.scale_factor)
        else:
            raise NotImplementedError()

    def forward(self, x):
        print(self.mode)
        if self.mode == 'interpolate':
            x = F.interpolate(x, scale_factor=self.scale_factor)
            return self.conv(x)
        elif self.mode == "pixel_shuffle":
            zeros_size = x.size(1)*self.scale_factor**2-x.size(1)*self.scale_factor ### Fill the channels with enough 0 to make a 2-D pixel shuffling
            zeros = torch.zeros(x.size(0), zeros_size, x.size(2)).detach()
            x_prime = torch.cat((self.conv(x), zeros), dim=1).unsqueeze(-1)
            y = self.shuffle(x_prime)[:,:,::self.scale_factor,:].contiguous().view(1,x.size(1),-1)
            return y
        elif self.mode == "transpose_conv":
            return self.conv(x)
        else:
            raise NotImplementedError()


class DownSampleBlock(nn.Module):
    def __init__(self, channels, kernel_size = 3, stride_t = 2, pad_t = 0):
        super(DownSampleBlock, self).__init__()
        self.conv = nn.Conv1d(channels, channels, kernel_size, stride_t, pad_t)        
    def forward(self, x):
#         pad = (0, 1, 0, 1)
#         x = F.pad(x, pad, mode="constant", value=0)
        print(x.size())
        return self.conv(x)


class NonLocalBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = GroupNorm(in_channels)
        self.q = torch.nn.Conv1d(in_channels, in_channels, 1, 1, 0)
        self.k = torch.nn.Conv1d(in_channels, in_channels, 1, 1, 0)
        self.v = torch.nn.Conv1d(in_channels, in_channels, 1, 1, 0)
        self.proj_out = torch.nn.Conv1d(in_channels, in_channels, 1, 1, 0)

    def forward(self, x):
        h_ = self.norm(x)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape

        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b, c, h * w)
        v = v.reshape(b, c, h * w)

        attn = torch.bmm(q, k)
        attn = attn * (int(c) ** (-0.5))
        attn = F.softmax(attn, dim=2)

        attn = attn.permute(0, 2, 1)
        A = torch.bmm(v, attn)
        A = A.reshape(b, c, h, w)

        A = self.proj_out(A)

        return x + A


class GroupNorm(nn.Module):
    def __init__(self, in_channels):
        super(GroupNorm, self).__init__()
        self.gn = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

    def forward(self, x):
        return self.gn(x)


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        
        self.in_channels = 1
        self.width = 128
        self.output_channels = 64
        self.depth = 2
        self.m_conv = 1.0
        self.dilation_growth_rate = 3
        self.dilation_cycle = None
        self.zero_out = False
        self.res_scale = False
        blocks = []
        stride_t = 2
        blocks.append(nn.Conv1d(self.in_channels, self.width, 31, 1, padding='same'))
        for i in range(3):
            kernel_size, pad_t = stride_t * 2, stride_t // 2
            block = nn.Sequential(
                    DownSampleBlock(self.width, kernel_size, stride_t, pad_t),
                    Resnet1D(self.width, self.depth, self.m_conv, self.dilation_growth_rate, self.dilation_cycle, self.zero_out, self.res_scale),
                )
            blocks.append(block)
        block = nn.Conv1d(self.width, self.output_channels, 31, 1, padding='same')
        blocks.append(block)

        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)


class Decoder(nn.Module):
    def __init__(self, kernel_size = None, output_classes = None):
        super(Decoder, self).__init__()
        self.first_kernel_size = 32
        self.in_channels = 64
        self.width = 128
        self.stride_t = 2
        self.depth = 2
        self.m_conv = 1.0
        self.dilation_growth_rate = 3
        self.dilation_cycle = None
        self.zero_out = False
        self.res_scale = False

        blocks = []
        if kernel_size is None:
            self.kernel_size, pad_t = self.stride_t * 2, self.stride_t // 2
        else:
            self.kernel_size = kernel_size
        
        block = nn.Conv1d(self.in_channels, self.width, self.first_kernel_size, 1, padding='same')
        blocks.append(block)
        for i in range(3):
            block = nn.Sequential(
                    Resnet1D(self.width, self.depth, self.m_conv, self.dilation_growth_rate, self.dilation_cycle, zero_out=self.zero_out, res_scale=self.res_scale),
                    UpSampleBlock(self.width, self.width, self.kernel_size, self.stride_t, pad_t)
                )
            blocks.append(block)
        
        blocks.append(GroupNorm(self.width))
        blocks.append(Swish())
        if output_classes is None:
            self.output_classes = 1
        else:
            self.output_classes = output_classes
        blocks.append(nn.Conv1d(self.width, self.output_classes, kernel_size=64, stride=1, padding=1))
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)

In [4]:
class BottleneckBlock(nn.Module):
    def __init__(self, k_bins, emb_width, mu):
        super().__init__()
        self.k_bins = k_bins
        self.emb_width = emb_width
        self.mu = mu
        self.reset_k()
        self.threshold = 1.0

    def reset_k(self):
        self.init = False
        self.k_sum = None
        self.k_elem = None
        self.register_buffer('k', torch.zeros(self.k_bins, self.emb_width))

    def _tile(self, x):
        d, ew = x.shape
        if d < self.k_bins:
            n_repeats = (self.k_bins + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + torch.randn_like(x) * std
        return x

    def init_k(self, x):
        mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins
        self.init = True
        # init k_w using random vectors from x
        y = self._tile(x)
        _k_rand = y[torch.randperm(y.shape[0])][:k_bins]
#         dist.broadcast(_k_rand, 0)
        self.k = _k_rand
        assert self.k.shape == (k_bins, emb_width)
        self.k_sum = self.k
        self.k_elem = torch.ones(k_bins, device=self.k.device)

    def restore_k(self, num_tokens=None, threshold=1.0):
        mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins
        self.init = True
        assert self.k.shape == (k_bins, emb_width)
        self.k_sum = self.k.clone()
        self.k_elem = torch.ones(k_bins, device=self.k.device)
        if num_tokens is not None:
            expected_usage = num_tokens / k_bins
            self.k_elem.data.mul_(expected_usage)
            self.k_sum.data.mul_(expected_usage)
        self.threshold = threshold

    def update_k(self, x, x_l):
        mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins
        with torch.no_grad():
            # Calculate new centres
            x_l_onehot = torch.zeros(k_bins, x.shape[0], device=x.device)  # k_bins, N * L
            x_l_onehot.scatter_(0, x_l.view(1, x.shape[0]), 1)
            print(f"xl : {x_l}, {x_l.size()}")
            print(f"x_l_onehot : {x_l_onehot},{x_l_onehot.sum(0)} {x_l_onehot.size()}")
            _k_sum = torch.matmul(x_l_onehot, x)  # k_bins, w
            _k_elem = x_l_onehot.sum(dim=-1)  # k_bins
            y = self._tile(x)
            _k_rand = y[torch.randperm(y.shape[0])][:k_bins]

#             dist.broadcast(_k_rand, 0)
#             dist.all_reduce(_k_sum)
#             dist.all_reduce(_k_elem)

            # Update centres
            old_k = self.k
            self.k_sum = mu * self.k_sum + (1. - mu) * _k_sum  # w, k_bins
            print(f"k_sum : {self.k_sum}, {self.k_sum.size()}")
            self.k_elem = mu * self.k_elem + (1. - mu) * _k_elem  # k_bins
            print(f"k_elem : {self.k_elem}, {self.k_elem.size()}")
            usage = (self.k_elem.view(k_bins, 1) >= self.threshold).float()
            print(f"usage 1st : {usage}, {usage.size()}")
            self.k = usage * (self.k_sum.view(k_bins, emb_width) / self.k_elem.view(k_bins, 1)) \
                     + (1 - usage) * _k_rand
            _k_prob = _k_elem / torch.sum(_k_elem)  # x_l_onehot.mean(dim=-1)  # prob of each bin
            entropy = -torch.sum(_k_prob * torch.log(_k_prob + 1e-8))  # entropy ie how diverse
            used_curr = (_k_elem >= self.threshold).sum()
            usage = torch.sum(usage)
            print(f"usage 2nd : {usage}, {usage.size()}")
            dk = torch.norm(self.k - old_k) / np.sqrt(np.prod(old_k.shape))
        return dict(entropy=entropy,
                    used_curr=used_curr,
                    usage=usage,
                    dk=dk)

    def preprocess(self, x):
        # NCT -> NTC -> [NT, C]
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(-1, x.shape[-1])  # x_en = (N * L, w), k_j = (w, k_bins)

        if x.shape[-1] == self.emb_width:
            prenorm = torch.norm(x - torch.mean(x)) / np.sqrt(np.prod(x.shape))
#         elif x.shape[-1] == 2 * self.emb_width:
#             x1, x2 = x[...,:self.emb_width], x[...,self.emb_width:]
#             prenorm = (torch.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + (torch.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape)))

#             # Normalise
#             x = x1 + x2
        else:
            assert False, f"Expected {x.shape[-1]} to be (1 or 2) * {self.emb_width}"
        return x, prenorm

    def postprocess(self, x_l, x_d, x_shape):
        # [NT, C] -> NTC -> NCT
        N, T = x_shape
        x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()
        x_l = x_l.view(N, T)
        return x_l, x_d

    def quantise(self, x):
        # Calculate latent code x_l
        k_w = self.k.t()
        print("x",x.size())
        print("k_w",k_w.size())
        print("x**2", torch.matmul(x, k_w).size())
        distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
                                                                                            keepdim=True)  # (N * L, b)
        print("Distance", distance, distance.size())
        min_distance, x_l = torch.min(distance, dim=-1)
        fit = torch.mean(min_distance)
        return x_l, fit

    def dequantise(self, x_l):
        x = F.embedding(x_l, self.k)
        return x

    def encode(self, x):
        N, width, T = x.shape

        # Preprocess.
        x, prenorm = self.preprocess(x)

        # Quantise
        x_l, fit = self.quantise(x)

        # Postprocess.
        x_l = x_l.view(N, T)
        return x_l

    def decode(self, x_l):
        N, T = x_l.shape
        width = self.emb_width

        # Dequantise
        x_d = self.dequantise(x_l)

        # Postprocess
        x_d = x_d.view(N, T, width).permute(0, 2, 1).contiguous()
        return x_d

    def forward(self, x, update_k=True):
        N, width, T = x.shape
        print("begin",x.shape)
        # Preprocess
        x, prenorm = self.preprocess(x)
        print("preprocessed", x.shape)
        # Init k if not inited
        if update_k and not self.init:
            self.init_k(x)

        # Quantise and dequantise through bottleneck
        x_l, fit = self.quantise(x)
        print("fit", fit)
        print('x_l', x_l, x_l.size())
        x_d = self.dequantise(x_l)
        print('x_d', x_d, x_d.size())
        # Update embeddings
        if update_k:
            update_metrics = self.update_k(x, x_l)
        else:
            update_metrics = {}

        # Loss
        commit_loss = torch.norm(x_d.detach() - x) ** 2 / np.prod(x.shape)

        # Passthrough
        x_d = x + (x_d - x).detach()

        # Postprocess
        x_l, x_d = self.postprocess(x_l, x_d, (N,T))
        print('x_l', x_l, x_l.size())
        print('x_d', x_d, x_d.size())
        return x_l, x_d, commit_loss, dict(fit=fit,
                                           pn=prenorm,
                                           **update_metrics)

In [5]:
class Bottleneck(nn.Module):
    def __init__(self, l_bins = 1024, emb_width=64, mu=0.99):
        super().__init__()
        self.l_bins = l_bins #cfg.trainer.codebook.nb_bins
        self.emb_channels = emb_width #cfg.trainer.vqvae.emb_channels
        self.mu = mu
        self.model = BottleneckBlock(self.l_bins, self.emb_channels, self.mu)
    
    def encode(self, xs):
        zs = self.model(xs)
        return zs

    def decode(self, zs):
        xs_quantised = self.model.decode(zs)
        return xs_quantised

    def forward(self, xs):
        zs, x_quantised, commit_loss, metric = self.model(xs, update_k=self.training)
        return zs, x_quantised, commit_loss, metric


In [6]:
encoder = Encoder()

In [7]:
decoder = Decoder()

In [8]:
codebook = Bottleneck()

In [2]:

xs = encoder(sample[0][:,:1024].unsqueeze(0))

NameError: name 'encoder' is not defined

In [14]:
metric

{'fit': tensor(5.9056e-05, grad_fn=<MeanBackward0>),
 'pn': tensor(0.0528, grad_fn=<DivBackward0>),
 'entropy': tensor(1.6020),
 'used_curr': tensor(37),
 'usage': tensor(37.),
 'dk': tensor(0.0171)}

In [10]:
zs, x_quantised, commit_loss, metric = codebook(xs)

begin torch.Size([1, 64, 128])
preprocessed torch.Size([128, 64])
x torch.Size([128, 64])
k_w torch.Size([64, 1024])
x**2 torch.Size([128, 1024])
Distance tensor([[7.9988e-02, 8.0459e-02, 7.8982e-02,  ..., 7.8980e-02, 8.0792e-02,
         2.2465e-01],
        [7.4519e-02, 7.5183e-02, 7.3220e-02,  ..., 7.3568e-02, 7.5496e-02,
         2.1687e-01],
        [6.3151e-02, 6.3156e-02, 6.1709e-02,  ..., 6.2160e-02, 6.3529e-02,
         2.1286e-01],
        ...,
        [8.7827e-02, 8.9004e-02, 8.6687e-02,  ..., 8.8687e-02, 8.8546e-02,
         9.4470e-03],
        [9.2914e-02, 9.4100e-02, 9.1937e-02,  ..., 9.3783e-02, 9.3338e-02,
         6.3155e-03],
        [1.1348e-01, 1.1484e-01, 1.1233e-01,  ..., 1.1458e-01, 1.1405e-01,
         1.3255e-04]], grad_fn=<AddBackward0>) torch.Size([128, 1024])
fit tensor(5.9056e-05, grad_fn=<MeanBackward0>)
x_l tensor([ 892,  717,  575,  619,  246,  517,  431,  554,   98,  715,  391,  666,
         743,   88,  660,  923,    1,  953,  123,  123,  123,  123,  

begin torch.Size([1, 64, 128])
preprocessed torch.Size([128, 64])
x torch.Size([128, 64])
k_w torch.Size([64, 1024])
x**2 torch.Size([128, 1024])
Distance tensor([[0.0983, 0.0974, 0.0962,  ..., 0.0625, 0.0993, 0.0633],
        [0.0868, 0.0862, 0.0851,  ..., 0.0494, 0.0880, 0.0542],
        [0.0838, 0.0835, 0.0824,  ..., 0.0412, 0.0852, 0.0511],
        ...,
        [0.1011, 0.1011, 0.1015,  ..., 0.1361, 0.1014, 0.1195],
        [0.1163, 0.1162, 0.1163,  ..., 0.1507, 0.1165, 0.1331],
        [0.1328, 0.1324, 0.1326,  ..., 0.1637, 0.1330, 0.1425]],
       grad_fn=<AddBackward0>) torch.Size([128, 1024])
fit tensor(5.6990e-05, grad_fn=<MeanBackward0>)
x_l tensor([ 949,  655,  174,  731,  945,  405,  351,  660,  907, 1021,  524,  409,
         441,   22,  821,  733,  785,  594,  563,  563,  563,  563,  563,  563,
         563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,
         563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,
         563,  563, 

(tensor([[ 949,  655,  174,  731,  945,  405,  351,  660,  907, 1021,  524,  409,
           441,   22,  821,  733,  785,  594,  563,  563,  563,  563,  563,  563,
           563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,
           563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,
           563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,
           563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,
           563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,
           563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,
           563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,  563,
           563,  563,  941,  782,  288,   75,  452,  688,  825,  491,  569,  948,
           116,  391,  858,  150,  685,  938,  743, 1010]]),
 tensor([[[-0.0453, -0.0447, -0.0489,  ...,  0.0710,  0.0568,  0.0404],
          [-0.0150,  0.0003,  0.0227,  ...,  0.

In [56]:
decoder(xs)

ResNet :  torch.Size([1, 128, 128])
pixel_shuffle
torch.Size([1, 256, 128])
torch.Size([1, 512, 128, 1])
torch.Size([1, 128, 256])
ResNet :  torch.Size([1, 128, 256])
pixel_shuffle
torch.Size([1, 256, 256])
torch.Size([1, 512, 256, 1])
torch.Size([1, 128, 512])
ResNet :  torch.Size([1, 128, 512])
pixel_shuffle
torch.Size([1, 256, 512])
torch.Size([1, 512, 512, 1])
torch.Size([1, 128, 1024])


tensor([[[-0.1304, -0.0941, -0.0625, -0.4133, -0.0776, -0.2118,  0.0662,
          -0.3879, -0.0948, -0.1593, -0.0539, -0.4336, -0.0723, -0.2110,
           0.0654, -0.3920, -0.0945, -0.1633, -0.0449, -0.4312, -0.0704,
          -0.2152,  0.0656, -0.3896, -0.0881, -0.1687, -0.0452, -0.4329,
          -0.0705, -0.2144,  0.0677, -0.3904, -0.0889, -0.1738, -0.0457,
          -0.4284, -0.0735, -0.2157,  0.0720, -0.3879, -0.0937, -0.1796,
          -0.0448, -0.4319, -0.0718, -0.2156,  0.0647, -0.3867, -0.0943,
          -0.1785, -0.0481, -0.4319, -0.0700, -0.2198,  0.0660, -0.3883,
          -0.0964, -0.1785, -0.0522, -0.4341, -0.0704, -0.2225,  0.0664,
          -0.3935, -0.0967, -0.1836, -0.0549, -0.4337, -0.0701, -0.2239,
           0.0693, -0.3930, -0.0957, -0.1903, -0.0557, -0.4344, -0.0742,
          -0.2224,  0.0694, -0.3911, -0.0995, -0.1891, -0.0485, -0.4358,
          -0.0744, -0.2180,  0.0699, -0.3934, -0.1016, -0.1894, -0.0517,
          -0.4408, -0.0739, -0.2161,  0.0706, -0.39

In [59]:
xs.size()

torch.Size([1, 64, 128])

<built-in method size of Tensor object at 0x2b9012e4bd10>
ResNet :  torch.Size([1, 128, 497])
<built-in method size of Tensor object at 0x2b9012e4bd70>


RuntimeError: Given groups=1, weight of size [128, 128, 4], expected input[1, 64, 468] to have 128 channels, but got 64 channels instead