In [91]:
%load_ext autoreload
%autoreload 2

from src.datasets import CustomDirAudioDataset
from torch.utils.data import DataLoader
from src.collate_fn.collate import collate_fn
from pathlib import Path
from torch import nn
import torch

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [92]:
cfg = {"preprocessing": {"sr": 16000}}

In [93]:
ds = CustomDirAudioDataset("/Users/arturgimranov/CS/fourth_year/dla_course/ss/train_mixtures", cfg, None)

In [95]:
ds[0]

{'ref_audio': tensor([[ 8.5449e-04,  1.7395e-03,  2.3193e-03,  ..., -3.6621e-04,
          -2.7466e-04, -3.0518e-05]]),
 'ref_duration': 9.41,
 'ref_length': 150560,
 'ref_path': '/Users/arturgimranov/CS/fourth_year/dla_course/ss/train_mixtures/refs/4719_1638_004020_0-ref.wav',
 'mix_audio': tensor([[ 0.0013,  0.0007,  0.0011,  ..., -0.0010,  0.0060,  0.0132]]),
 'mix_duration': 3.0,
 'mix_length': 48000,
 'mix_path': '/Users/arturgimranov/CS/fourth_year/dla_course/ss/train_mixtures/mix/4719_1638_004020_0-mixed.wav',
 'target_audio': tensor([[0.0017, 0.0009, 0.0015,  ..., 0.0126, 0.0128, 0.0141]]),
 'target_duration': 3.0,
 'target_length': 48000,
 'target_path': '/Users/arturgimranov/CS/fourth_year/dla_course/ss/train_mixtures/targets/4719_1638_004020_0-target.wav',
 'ref_speaker_id': 4719,
 'ref_target': 0}

In [7]:
dl = DataLoader(
        ds,
        batch_size=4,
        collate_fn=collate_fn,
        shuffle=True,
        num_workers=2,
        drop_last=True,
    )

In [8]:
batch = next(iter(dl))

In [7]:
batch["mix_audio"].shape

torch.Size([4, 48000])

In [8]:
mix_audio = batch["mix_audio"]

# SpeechEncoder

In [24]:
L1 = 40
L2 = 160
L3 = 320
N = 256
speaker_dim = 256
n_classes = 10

In [10]:
class SpeechEncoder(nn.Module):
    def __init__(self, L1, L2, L3, N):
        super().__init__()
        self.L1 = L1
        self.L2 = L2
        self.L3 = L3
        self.N = N

        self.short_encoder = nn.Conv1d(1, N, L1, L1 // 2)
        self.middle_encoder = nn.Sequential(
            nn.ConstantPad1d((0, (L2 - L1)), 0), nn.Conv1d(1, N, L2, L1 // 2)
        )
        self.long_encoder = nn.Sequential(
            nn.ConstantPad1d((0, (L3 - L1)), 0), nn.Conv1d(1, N, L3, L1 // 2)
        )

    def forward(self, x, return_tuple=False):
        x1 = self.short_encoder(x)
        x2 = self.middle_encoder(x)
        x3 = self.long_encoder(x)

        assert x1.shape == x2.shape == x3.shape
        if return_tuple:
            return torch.cat([x1, x2, x3], dim=1), (x1, x2, x3)
        return torch.cat([x1, x2, x3], dim=1)

    def _length_after(self, length_before):
        return (
            torch.div(length_before - self.L1, self.L1 // 2, rounding_mode="floor")
            + 1
        )


In [11]:
class ChannelLayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        x = x.transpose(-1, -2)
        x = super().forward(x)
        x = x.transpose(-1, -2)
        return x

In [12]:
class ResNetBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv1d(dim, dim, 1),
            nn.BatchNorm1d(dim),
            nn.PReLU(),
            nn.Conv1d(dim, dim, 1),
            nn.BatchNorm1d(dim),
        )

        self.head = nn.Sequential(
            nn.PReLU(),
            nn.MaxPool1d(3)
        )


    def forward(self, x):
        return self.head(x + self.body(x))


In [38]:
class SpeakerEncoder(nn.Module):
    def __init__(
        self, N, speaker_dim, n_classes, speech_encoder, num_resnet_blocks=3
    ):
        super().__init__()
        self.speech_encoder = speech_encoder
        self.channel_layer_norm = ChannelLayerNorm(3 * N)
        self.conv1 = nn.Conv1d(3 * N, speaker_dim, 1)

        self.num_resnet_blocks = num_resnet_blocks
        self.resnet_blocks = nn.Sequential(
            *[ResNetBlock(speaker_dim) for _ in range(num_resnet_blocks)]
        )
        self.conv2 = nn.Conv1d(speaker_dim, speaker_dim, 1)
        self.linear = nn.Linear(speaker_dim, n_classes)

    def _length_after_resnet(self, length_before):
        return torch.div(
            length_before, (3**self.num_resnet_blocks), rounding_mode="floor"
        )

    def forward(self, x, x_lengths):
        x = self.conv1(self.channel_layer_norm(self.speech_encoder(x)))
        x = self.conv2(self.resnet_blocks(x))
        x = x.sum(dim=-1) / self._length_after_resnet(
            self.speech_encoder._length_after(x_lengths)
        )[:, None]
        logits = self.linear(x)
        return x, logits


In [81]:
class TCNBase(nn.Module):
    def __init__(self, in_channels, speaker_channels, out_channels, dilation, kernel_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(in_channels + speaker_channels, out_channels, 1),
            nn.PReLU(),
            nn.GroupNorm(1, out_channels),
            nn.Conv1d(
                out_channels,
                out_channels,
                kernel_size,
                dilation=dilation,
                padding="same",
            ),
            nn.PReLU(),
            nn.GroupNorm(1, out_channels),
            nn.Conv1d(out_channels, in_channels, 1),
        )

    def forward(self, x):
        return x + self.net(x)

class TCNBlock(TCNBase):
    def __init__(self, in_channels, out_channels, dilation, kernel_size):
        super().__init__(in_channels, 0, out_channels, dilation, kernel_size)

class FirstTCNBlock(TCNBase):
    def __init__(self, in_channels, speaker_dim, out_channels, dilation, kernel_size):
        super().__init__(in_channels, speaker_dim, out_channels, dilation, kernel_size)

    def forward(self, x, speaker_embedding):
        time_length = x.shape[-1]
        speaker_embedding = torch.unsqueeze(speaker_embedding, -1)
        speaker_embedding = speaker_embedding.repeat(1, 1, time_length)
        return x + self.net(torch.cat([x, speaker_embedding], dim=1))


class StackedTCN(nn.Module):
    def __init__(
        self, in_channels, speaker_dim, out_channels, kernel_size, num_blocks
    ):  
        super().__init__()
        self.first_block = FirstTCNBlock(
                in_channels, speaker_dim, out_channels, 1, kernel_size
            )

        self.rest_blocks = nn.Sequential(
            *[
                TCNBlock(out_channels, out_channels, 2 ** i, kernel_size)
                for i in range(1, num_blocks)
            ]
        )

    def forward(self, x, speaker_embedding):
        x = self.first_block(x, speaker_embedding)
        print(f'{x.shape=}')
        return self.rest_blocks(x)


In [16]:
speech_encoder = SpeechEncoder(L1, L2, L3, N)

In [22]:
mix_encode, (y1, y2, y3) = speech_encoder(mix_audio.unsqueeze(1), return_tuple=True)

In [25]:
ln = ChannelLayerNorm(3 * N)
conv1 = nn.Conv1d(3 * N, speaker_dim, 1)
mix_encode = conv1(ln(mix_encode))
mix_encode.shape

torch.Size([4, 256, 2399])

In [41]:
ref_audio = batch["ref_audio"]
# ref_audio = speech_encoder(ref_audio.unsqueeze(1))
# ref_audio.shape

In [39]:
speaker_encoder = SpeakerEncoder(N, speaker_dim, n_classes, speech_encoder)

In [43]:
ref_lengths = torch.tensor([duration * 16000 for duration in batch["ref_duration"]])

In [44]:
ref_audio, logits = speaker_encoder(ref_audio.unsqueeze(1), ref_lengths)

In [78]:
tcn = StackedTCN(N, speaker_dim, 256, 3, 3)

In [79]:
ref_audio = tcn(mix_encode, ref_audio)

speaker_embedding.shape=torch.Size([4, 256, 2399])
x.shape=torch.Size([4, 256, 2399])


In [80]:
ref_audio.shape

torch.Size([4, 256, 2399])

In [82]:
short_decoder = nn.ConvTranspose1d(N, 1, L1, L1 // 2)

In [83]:
short_decoder(ref_audio).shape

torch.Size([4, 1, 48000])

In [106]:
a = torch.randn((3, 4, 100))
c = nn.Conv1d(4, 6, 5, 2)

In [107]:
c1 = nn.Conv1d(6, 4, 5, 2)

In [108]:
print(a.shape)
print(c(a).shape)
print(c1(c(a)).shape)

torch.Size([3, 4, 100])
torch.Size([3, 6, 48])
torch.Size([3, 4, 22])


In [109]:
class SpeakerExtractor(nn.Module):
    def __init__(
        self,
        N,
        speech_encoder,
        speaker_dim,
        tcn_block_dim,
        tcn_kernel_size,
        tcn_num_blocks,
        tnc_num_stacks, 
    ):
        super().__init__()
        self.speech_encoder = speech_encoder
        self.channel_layer_norm = ChannelLayerNorm(3 * N)
        self.conv1 = nn.Conv1d(3 * N, speaker_dim, 1)
        self.tcn = nn.Sequential(
            *[
                StackedTCN(
                    speaker_dim,
                    speaker_dim,
                    tcn_block_dim,
                    tcn_kernel_size,
                    tcn_num_blocks,
                )
                for _ in range(tnc_num_stacks)
            ]
        )
        self.masks_head = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv1d(speaker_dim, N, 1),
                    nn.ReLU(),
                )
            ] for i in range(3)
        )

    def forward(self, mix, speaker_embedding):
        mix = self.conv1(self.channel_layer_norm(self.speech_encoder(mix)))
        mix_embedding = self.tcn(mix, speaker_embedding)
        masks = [mask_head(mix_embedding) for mask_head in self.masks_head]
        return [mask * mix for mask in masks]

In [111]:
class SpeechDecoder(nn.Module):
    def __init__(self, L1, L2, L3, N):
        super().__init__()
        self.short_decoder = nn.ConvTranspose1d(N, 1, L1, L1 // 2)
        self.middle_decoder = nn.Sequential(
            nn.ConvTranspose1d(N, 1, L2, L2 // 2),
        )
        self.long_decoder = nn.Sequential(
            nn.ConvTranspose1d(N, 1, L3, L3 // 2),
        )

    def forward(self, y1, y2, y3):
        y1 = self.short_decoder(y1)
        y2 = self.middle_decoder(y2)
        y3 = self.long_decoder(y3)

        return y1, y2, y3
        

In [112]:
class SpExPlus(nn.Module):
    def __init__(
        self,
        speaker_dim,
        tcn_block_dim,
        tcn_kernel_size,
        tcn_num_blocks,
        tnc_num_stacks,
        L1,
        L2,
        L3,
        N,
    ):
        super().__init__()
        self.speech_encoder = SpeechEncoder(L1, L2, L3, N)
        self.speaker_encoder = SpeakerEncoder(N, speaker_dim, n_classes, speech_encoder)
        self.speaker_extractor = SpeakerExtractor(
            N,
            speech_encoder,
            speaker_dim,
            tcn_block_dim,
            tcn_kernel_size,
            tcn_num_blocks,
            tnc_num_stacks,
        )
        self.speech_decoder = SpeechDecoder(L1, L2, L3, N)

    def forward(self, mix, ref, ref_lengths):
        mix_length = mix.shape[-1]
        mix_encode, (y1, y2, y3) = self.speech_encoder(mix.unsqueeze(1), return_tuple=True)
        ref_encode = self.speech_encoder(ref.unsqueeze(1))
        ref_audio, logits = self.speaker_encoder(ref_encode, ref_lengths)
        masks = self.speaker_extractor(mix_encode, ref_audio)

        for mask, y in zip(masks, [y1, y2, y3]):
            mask = torch.unsqueeze(mask, 1)
            mask = mask.repeat(1, N, 1)
            y = y * mask

        short, middle, long = self.speech_decoder(y1, y2, y3)

        return {
            "short": short[:, :, :mix_length],
            "middle": middle[:, :, :mix_length],
            "long": long[:, :, :mix_length],
        }


In [301]:
speach_encoder = SpeechEncoder(L1, L2, L3, N)

In [302]:
x = speach_encoder(mix_audio.unsqueeze(1))

In [303]:
tcn_block = TCNBlock(3 * N, 3 * N, 2, 3)

In [304]:
tcn_block(x).shape

torch.Size([4, 768, 2399])

In [305]:
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return x + y

In [276]:
s = nn.Sequential(MyModule(), nn.Linear(5, 10))

In [281]:
s(torch.randn((3, 4, 5)), torch.randn((3, 4, 5)))

TypeError: forward() takes 2 positional arguments but 3 were given

In [268]:
class Test:
    def __init__(self, a, *args, **kwargs):
        print(f'{a=}')
        print(f'{args=}')
        print(f'{kwargs=}')

In [271]:
a = Test(10, 20, 30, b=2)

a=10
args=(20, 30)
kwargs={'b': 2}


In [231]:
x = torch.randn(3, 4, 10)
aux = torch.randn(3, 4)

In [232]:
T = x.shape[-1]
aux = torch.unsqueeze(aux, -1)
aux = aux.repeat(1,1,T)

In [239]:
aux[:, :, 0]

tensor([[-1.4661, -0.8354, -1.2423, -0.2026],
        [-0.1847, -0.1209,  1.6513, -0.1893],
        [-0.2205,  0.3992,  0.1901,  2.6755]])

In [240]:
aux[:, :, 1]

tensor([[-1.4661, -0.8354, -1.2423, -0.2026],
        [-0.1847, -0.1209,  1.6513, -0.1893],
        [-0.2205,  0.3992,  0.1901,  2.6755]])

In [233]:
y = torch.cat([x, aux], 1)

In [235]:
y.shape

torch.Size([3, 8, 10])

In [259]:
gln = nn.GroupNorm(1, 8)

In [261]:
gln(y).std((1, 2))

tensor([1.0063, 1.0063, 1.0063], grad_fn=<StdBackward0>)

In [245]:
gln.mean()

AttributeError: 'LayerNorm' object has no attribute 'mean'

In [32]:
from src.model import SpExPlus

In [33]:
model = SpExPlus(
    speaker_dim=256,
    tcn_block_dim=256,
    tcn_kernel_size=3,
    tcn_num_blocks=3,
    tcn_num_stacks=3,
    L1=40,
    L2=160,
    L3=320,
    N=256,
    n_classes=10,
)

In [34]:
mix_audio = batch["mix_audio"]
ref_audio = batch["ref_audio"]
ref_lengths = torch.tensor([duration * 16000 for duration in batch["ref_duration"]])

In [35]:
a = model(mix_audio, ref_audio, ref_lengths)

x.shape=torch.Size([4, 256, 393])
x.shape=torch.Size([4, 256])


In [21]:
for key, value in a.items():
    print(key, value.shape)

short torch.Size([4, 1, 48000])
middle torch.Size([4, 1, 48000])
long torch.Size([4, 1, 48000])
logits torch.Size([4, 10])


In [13]:
ref_audio.shape

torch.Size([4, 212480])

In [65]:
def si_sdr(est, target, eps=1e-8):
    l2norm = lambda x, keepdim=False: torch.norm(x, dim=-1, keepdim=keepdim)
    alpha = (target * est).sum() / l2norm(target, True)**2
    print(f'{alpha.shape=}')
    return 20 * torch.log10(l2norm(alpha * target) / (l2norm(alpha * target - est) + eps) + eps)

In [66]:
a["short"].shape, mix_audio.shape

(torch.Size([4, 48000]), torch.Size([4, 48000]))

In [71]:
si_sdr(a["short"] - a["short"].mean(dim=-1, keepdim=True), mix_audio - mix_audio.mean(dim=-1, keepdim=True))

alpha.shape=torch.Size([4, 1])


tensor([-35.6926, -35.7528, -34.7097, -34.3358], grad_fn=<MulBackward0>)

In [72]:
def sisdr(x, s, eps=1e-8):
    """
    Arguments:
    x: separated signal, N x S tensor
    s: reference signal, N x S tensor
    Return:
    sisdr: N tensor
    """

    def l2norm(mat, keepdim=False):
        return torch.norm(mat, dim=-1, keepdim=keepdim)

    x_zm = x - torch.mean(x, dim=-1, keepdim=True)
    s_zm = s - torch.mean(s, dim=-1, keepdim=True)
    t = torch.sum(
        x_zm * s_zm, dim=-1,
        keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps)
    return 20 * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))

In [82]:
sisdr(a["short"], mix_audio)

tensor([-42.4574, -36.3025, -35.8277, -40.6908], grad_fn=<MulBackward0>)

In [89]:
from functools import partial

def si_sdr(estimated, target, eps=1e-8):
    l2norm = partial(torch.linalg.norm, dim=-1)
    alpha = (target * estimated).sum(dim=-1, keepdim=True) / l2norm(target, keepdim=True) ** 2
    return 20 * torch.log10(l2norm(alpha * target) / (l2norm(alpha * target - estimated) + eps) + eps)

In [90]:
si_sdr(a["short"] - a["short"].mean(-1, keepdim=True), mix_audio - mix_audio.mean(-1, keepdim=True))

tensor([-42.4574, -36.3025, -35.8277, -40.6908], grad_fn=<MulBackward0>)