In [1]:
import torch
import torch.nn as nn

In [2]:
# Seed the global RNG so that a fresh "Restart & Run All" regenerates exactly
# the same sample tensors (and the same random nn.Linear initialisations in
# the cells that follow).
torch.manual_seed(0)

# Two random sample inputs of identical shape (4, 2, 4), used to smoke-test
# the fusion modules defined in the following cells.
speaker2 = torch.rand(4, 2, 4)
mixtures = torch.rand(4, 2, 4)

In [3]:
class Sub(nn.Module):
    """Fuse two tensors by element-wise subtraction, then apply a linear layer.

    Args:
        embed_dim: Input and output feature size of the linear layer ``fc``.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.fc = nn.Linear(embed_dim, embed_dim)

    def forward(self, mixtures, speaker2):
        """Return ``fc(mixtures - speaker2)``."""
        difference = torch.sub(mixtures, speaker2)
        return self.fc(difference)

# Smoke test: build the module with embed_dim=4 and display the output shape
# produced for the sample tensors.
sub = Sub(embed_dim=4)
sub(mixtures, speaker2).shape

torch.Size([4, 2, 4])

In [4]:
class Mul(nn.Module):
    """Fuse two tensors by element-wise multiplication, then apply a linear layer.

    Args:
        embed_dim: Input and output feature size of the linear layer ``fc``.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.fc = nn.Linear(embed_dim, embed_dim)

    def forward(self, mixtures, speaker2):
        """Return ``fc(mixtures * speaker2)``."""
        product = torch.multiply(mixtures, speaker2)
        return self.fc(product)

# Smoke test: build the module with embed_dim=4 and display the output shape
# produced for the sample tensors.
mul = Mul(embed_dim=4)
mul(mixtures, speaker2).shape

torch.Size([4, 2, 4])

In [5]:
class Concat1(nn.Module):
    """Fuse two tensors by concatenating their features, then project back.

    The inputs are concatenated along the last dimension (giving
    ``2 * embed_dim`` features) and passed through one linear layer that maps
    them back to ``embed_dim`` features.

    Args:
        embed_dim: Size of the last dimension of each input and of the output.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.fc = nn.Linear(2 * embed_dim, embed_dim)

    def forward(self, mixtures, speaker2):
        """Return ``fc(cat(mixtures, speaker2))``.

        Args:
            mixtures: Tensor whose last dimension has size ``embed_dim``.
            speaker2: Tensor with the same shape as ``mixtures``.

        Returns:
            Tensor with the same shape as the inputs.
        """
        # Concatenate on the last dimension (the one nn.Linear operates on)
        # instead of a hard-coded index 2, so inputs of any rank work; for
        # 3-D inputs this is the same dimension as before.
        fused = torch.cat((mixtures, speaker2), dim=-1)
        return self.fc(fused)

# Smoke test: build the module with embed_dim=4 and display the output shape
# produced for the sample tensors.
concat1 = Concat1(embed_dim=4)
concat1(mixtures, speaker2).shape

torch.Size([4, 2, 4])

In [6]:
class Concat2(nn.Module):
    """Fuse two tensors by concatenation followed by a two-layer MLP.

    The inputs are concatenated along the last dimension, projected from
    ``2 * embed_dim`` to ``embed_dim`` features by ``fc1``, passed through a
    ReLU, and finally through ``fc2``.

    Args:
        embed_dim: Size of the last dimension of each input and of the output.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.fc1 = nn.Linear(2 * embed_dim, embed_dim)
        self.fc2 = nn.Linear(embed_dim, embed_dim)

    def forward(self, mixtures, speaker2):
        """Return ``fc2(relu(fc1(cat(mixtures, speaker2))))``.

        Args:
            mixtures: Tensor whose last dimension has size ``embed_dim``.
            speaker2: Tensor with the same shape as ``mixtures``.

        Returns:
            Tensor with the same shape as the inputs.
        """
        # Concatenate on the last dimension (the one nn.Linear operates on)
        # instead of a hard-coded index 2, so inputs of any rank work.
        fused = torch.cat((mixtures, speaker2), dim=-1)
        # Functional ReLU avoids constructing a new nn.ReLU module per call.
        hidden = torch.relu(self.fc1(fused))
        return self.fc2(hidden)

# Smoke test: build the module with embed_dim=4 and display the output shape
# produced for the sample tensors.
concat2 = Concat2(embed_dim=4)
concat2(mixtures, speaker2).shape

torch.Size([4, 2, 4])

In [7]:
class ShareConcat(nn.Module):
    """Encode both inputs with one shared layer, concatenate, then project.

    The same linear layer ``fc1`` (followed by ReLU) is applied to each input.
    The two encodings are concatenated along the last dimension and mapped
    from ``2 * embed_dim`` back to ``embed_dim`` features by ``fc2``.

    Args:
        embed_dim: Size of the last dimension of each input and of the output.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, embed_dim)
        self.fc2 = nn.Linear(2 * embed_dim, embed_dim)

    def forward(self, mixtures, speaker2):
        """Return ``fc2(cat(relu(fc1(mixtures)), relu(fc1(speaker2))))``.

        Args:
            mixtures: Tensor whose last dimension has size ``embed_dim``.
            speaker2: Tensor with the same shape as ``mixtures``.

        Returns:
            Tensor with the same shape as the inputs.
        """
        # Functional ReLU avoids constructing a new nn.ReLU module per call.
        mixtures_enc = torch.relu(self.fc1(mixtures))
        speaker2_enc = torch.relu(self.fc1(speaker2))
        # Concatenate on the last dimension (the one nn.Linear operates on)
        # instead of a hard-coded index 2, so inputs of any rank work.
        fused = torch.cat((mixtures_enc, speaker2_enc), dim=-1)
        # No further ReLU here: both halves are already ReLU outputs, so a
        # second ReLU on the concatenation would leave every value unchanged.
        return self.fc2(fused)

# Smoke test: build the module with embed_dim=4 and display the output shape
# produced for the sample tensors.
share_concat = ShareConcat(embed_dim=4)
share_concat(mixtures, speaker2).shape

torch.Size([4, 2, 4])

In [8]:
class SeparateConcat(nn.Module):
    """Encode each input with its own layer, concatenate, then project.

    ``fc1`` (followed by ReLU) encodes ``mixtures`` and ``fc2`` (followed by
    ReLU) encodes ``speaker2``. The two encodings are concatenated along the
    last dimension and mapped from ``2 * embed_dim`` back to ``embed_dim``
    features by ``fc3``.

    Args:
        embed_dim: Size of the last dimension of each input and of the output.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, embed_dim)
        self.fc2 = nn.Linear(embed_dim, embed_dim)
        self.fc3 = nn.Linear(2 * embed_dim, embed_dim)

    def forward(self, mixtures, speaker2):
        """Return ``fc3(cat(relu(fc1(mixtures)), relu(fc2(speaker2))))``.

        Args:
            mixtures: Tensor whose last dimension has size ``embed_dim``.
            speaker2: Tensor with the same shape as ``mixtures``.

        Returns:
            Tensor with the same shape as the inputs.
        """
        # Functional ReLU avoids constructing a new nn.ReLU module per call.
        mixtures_enc = torch.relu(self.fc1(mixtures))
        speaker2_enc = torch.relu(self.fc2(speaker2))
        # Concatenate on the last dimension (the one nn.Linear operates on)
        # instead of a hard-coded index 2, so inputs of any rank work.
        fused = torch.cat((mixtures_enc, speaker2_enc), dim=-1)
        # No further ReLU here: both halves are already ReLU outputs, so a
        # second ReLU on the concatenation would leave every value unchanged.
        return self.fc3(fused)

# Smoke test: build the module with embed_dim=4 and display the output shape
# produced for the sample tensors.
separate_concat = SeparateConcat(embed_dim=4)
separate_concat(mixtures, speaker2).shape

torch.Size([4, 2, 4])