In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

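# Speech features arrive as [ BATCH_SIZE x NUM_TIMESTAMPS x MFCC_FEATURE_DIM ].
# nn.Conv1d treats dim 1 as channels, so the last two axes must be swapped
# (transposed, not merely reshaped) to
# [ BATCH_SIZE x MFCC_FEATURE_DIM x NUM_TIMESTAMPS ] before calling the model.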

class GradReverse(torch.autograd.Function):
    """
    Gradient reversal layer.

    The forward pass is the identity. The backward pass negates the incoming
    gradient and scales it by ``constant``, so that layers placed before this
    function are updated to *increase* the loss computed after it.

    Use ``GradReverse.grad_reverse(x, constant)`` (or ``GradReverse.apply``)
    rather than instantiating the class.
    """

    @staticmethod
    def forward(ctx, inputs, constant=1.0):
        """Return ``inputs`` unchanged and remember ``constant`` for backward.

        Args:
            ctx: autograd context object supplied by ``apply``.
            inputs: tensor to pass through.
            constant: scale applied to the reversed gradient (default 1.0).
        """
        ctx.constant = constant
        # view_as yields a new tensor object sharing storage, so autograd
        # records this node instead of treating the output as the raw input.
        return inputs.view_as(inputs)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``-constant * grad_output`` for ``inputs``; ``constant`` gets no gradient."""
        return grad_output.neg() * ctx.constant, None

    @staticmethod
    def grad_reverse(inputs, constant=1.0):
        """Functional entry point: apply gradient reversal to ``inputs``."""
        return GradReverse.apply(inputs, constant)

class SexClassifier(nn.Module):
    """
    Two-layer classification head placed behind a gradient reversal layer.

    Expects a 1024-dimensional feature vector per example, projects it to
    512 units with ReLU, then to ``num_classes`` log-probabilities.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, input, constant):
        """Return per-class log-probabilities (log-softmax over dim 1).

        ``constant`` is forwarded to ``GradReverse.grad_reverse`` together
        with ``input`` before the linear layers are applied.
        """
        reversed_features = GradReverse.grad_reverse(input, constant)
        hidden = F.relu(self.fc1(reversed_features))
        class_scores = self.fc2(hidden)
        return F.log_softmax(class_scores, dim=1)

# Gated Linear Units
class GLU(nn.Module):
    """
    Shape-preserving gated activation: ``input * sigmoid(input)``.

    Custom implementation because the Voice Conversion Cycle GAN paper
    assumes the GLU does not halve the tensor dimension, so the input is
    gated by its own sigmoid instead of by a split-off half.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return ``input`` multiplied element-wise by ``sigmoid(input)``."""
        gate = torch.sigmoid(input)
        return torch.mul(input, gate)

class PixelShuffle(nn.Module):
    """
    1-D sub-pixel shuffle for 3-D tensors.

    Custom implementation because PyTorch's PixelShuffle requires a 4-D
    input, whereas here the tensors are 3-D: ``(N, C, W)``.

    Rearranges ``(N, C, W)`` into ``(N, C // r, W * r)`` where
    ``r = upscale_factor``. Each group of ``r`` consecutive channels is
    interleaved along the width, i.e.
    ``out[n, c, w * r + i] == input[n, c * r + i, w]``.
    """

    def __init__(self, upscale_factor):
        super(PixelShuffle, self).__init__()
        self.upscale_factor = upscale_factor

    def forward(self, input):
        """Return the shuffled tensor of shape ``(N, C // r, W * r)``.

        Raises:
            ValueError: if the channel count is not divisible by
                ``upscale_factor`` (or the input is not 3-D).
        """
        r = self.upscale_factor
        n, channels, width = input.shape
        if channels % r != 0:
            raise ValueError(
                "PixelShuffle expects the channel dimension ({}) to be divisible "
                "by upscale_factor ({})".format(channels, r)
            )
        c_out = channels // r
        # (N, C_out, r, W) -> (N, C_out, W, r) -> (N, C_out, W * r):
        # moving the sub-channel axis behind the width axis is what
        # interleaves the r channels; a plain view would only concatenate them.
        # reshape (instead of view) also accepts non-contiguous input.
        grouped = input.reshape(n, c_out, r, width)
        return grouped.permute(0, 1, 3, 2).reshape(n, c_out, width * r)

# define the NN architecture
class ConvAutoencoder(nn.Module):
    """
    1-D convolutional autoencoder with an adversarial sex-classification head.

    Input layout is ``(batch, mfcc_feature_dim, num_timestamps)``.

    * The encoder maps the features to 512 channels and downsamples time
      with two stride-2 convolutions.
    * Mean and standard deviation over time of the encoded sequence are
      concatenated (2 * 512 = 1024 values) and fed to ``SexClassifier``.
    * The decoder upsamples time with two ``PixelShuffle(2)`` stages and
      projects back to ``mfcc_feature_dim`` channels.

    With the kernel/stride/padding values used here, the output has the same
    number of timestamps as the input when that number is a multiple of 4
    (e.g. 200 -> 200); other lengths come back slightly shorter or longer.
    """

    def __init__(self, mfcc_feature_dim):
        super(ConvAutoencoder, self).__init__()

        ## model parameters ##
        self.mfcc_feature_dim = mfcc_feature_dim

        ## encoder layers ##
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels=self.mfcc_feature_dim, out_channels=128, kernel_size=15, stride=1, padding=7),
            GLU(),
            nn.Conv1d(in_channels=128, out_channels=256, kernel_size=5, stride=2, padding=1),
            nn.InstanceNorm1d(num_features=256, affine=True),
            GLU(),
            nn.Conv1d(in_channels=256, out_channels=512, kernel_size=5, stride=2, padding=2),
            nn.InstanceNorm1d(num_features=512, affine=True),
            GLU()
        )

        ## decoder layers ##
        # Each Conv1d emits twice the channels the following norm expects;
        # PixelShuffle(2) trades that factor of 2 for doubled time resolution.
        self.decoder = nn.Sequential(
            nn.Conv1d(in_channels=512, out_channels=1024, kernel_size=5, stride=1, padding=2),
            PixelShuffle(upscale_factor=2),
            nn.InstanceNorm1d(num_features=1024 // 2, affine=True),
            GLU(),
            nn.Conv1d(in_channels=1024 // 2, out_channels=512, kernel_size=5, stride=1, padding=2),
            PixelShuffle(upscale_factor=2),
            nn.InstanceNorm1d(num_features=512 // 2, affine=True),
            GLU(),
            nn.Conv1d(in_channels=512 // 2, out_channels=self.mfcc_feature_dim, kernel_size=15, stride=1, padding=7),
        )

        ## Sex classifier: num_classes = 2 ##
        self.sex_classifier = SexClassifier(2)

    def forward(self, input, constant=1.0):
        """Encode, classify and reconstruct a batch of speech features.

        Args:
            input: tensor of shape ``(batch, mfcc_feature_dim, num_timestamps)``.
            constant: value handed to the sex classifier's gradient reversal
                layer. Defaults to 1.0 so the model can be called with the
                features alone, as the training and evaluation loops do.

        Returns:
            ``(reconstruction, sex_classifier_logits)``: the reconstructed
            features for the reconstruction loss and the classifier output
            of shape ``(batch, 2)`` for the classification loss.
        """
        ## encode ##
        encoded = self.encoder(input)

        ## statistics pooling over the time axis ##
        mean = torch.mean(encoded, 2)
        std = torch.std(encoded, 2)
        stat_pooling = torch.cat((mean, std), 1)

        ## sex classifier ##
        sex_classifier_logits = self.sex_classifier(stat_pooling, constant)

        ## decode ##
        reconstructed = self.decoder(encoded)

        return reconstructed, sex_classifier_logits


In [24]:
# Smoke test: push one random batch through the model and print output shapes.
# Features are generated as [ BATCH_SIZE x NUM_TIMESTAMPS x MFCC_FEATURE_DIM ]
# and must be swapped to
# [ BATCH_SIZE x MFCC_FEATURE_DIM x NUM_TIMESTAMPS ] for Conv1d.

input = torch.rand(40, 200, 20)
# transpose, not view: view(40, 20, 200) keeps the memory order and would mix
# values from different timestamps into each "feature" row instead of
# swapping the two axes.
input = input.transpose(1, 2)
model = ConvAutoencoder(20)

output, sex_logits = model(input, 1)

print(output.shape)
print(sex_logits.shape)


torch.Size([40, 20, 200])
torch.Size([40, 2])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import argparse
from tqdm import tqdm 
from tqdm import trange
from sklearn.metrics import accuracy_score

def main(argv=None):
    """Parse hyper-parameters, build the autoencoder and start training.

    Args:
        argv: optional list of command-line strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]`` as before.

    NOTE(review): ``train_loader`` and ``dev_loader`` are not created here;
    they are read from module scope and must be defined before calling
    ``main``.
    """
    # Instantiate the parser
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=1e-6, help='learning rate')
    parser.add_argument('--epochs', type=int, default=20, help='number of epochs')
    parser.add_argument('--r_weight', type=float, default=0.5, help='reconstruction loss weight')
    parser.add_argument('--s_weight', type=float, default=0.5, help='sex classification loss weight')
    #parser.add_argument('--a_weight', type=float, default=0.5, help='asr loss weight')
    parser.add_argument('--feature_dim', type=int, default=20, help='input feature dim')
    parser.add_argument('--output_model_file', type=str, default="/checkpoints/model.bin", help='output path for model')
    parser.add_argument('--patience', type=int, default=2)

    # parse_known_args instead of parse_args: inside a Jupyter kernel
    # sys.argv carries launcher arguments (e.g. "-f <connection file>")
    # that would otherwise make argparse exit with "unrecognized arguments".
    args, _unknown = parser.parse_known_args(argv)

    # Loss modules travel on ``args`` so train/evaluate/test can share them.
    args.reconstruction_loss = nn.L1Loss(reduction='mean')
    args.sex_classification_loss = nn.CrossEntropyLoss()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Move the model to the same device the batches are sent to in train().
    model = ConvAutoencoder(args.feature_dim).to(device)

    ## train the model ##
    train(model, train_loader, dev_loader, args, device)



def train(model, train_loader, dev_loader, args, device):
    """Train ``model`` with early stopping on the development loss.

    Args:
        model: network returning ``(reconstructed_speech, sex_logits)``.
        train_loader: iterable of ``(speech, transcription, sex_label)``
            tensor batches.
        dev_loader: iterable of batches passed to ``evaluate`` after every epoch.
        args: namespace providing ``lr``, ``epochs``, ``patience``,
            ``r_weight``, ``s_weight``, ``output_model_file``,
            ``reconstruction_loss`` and ``sex_classification_loss``.
        device: torch device used for the model and the batches.

    The weights are saved to ``args.output_model_file`` whenever the dev loss
    improves on every previous epoch; training stops after ``args.patience``
    consecutive epochs without improvement.
    """
    # Batches are moved to ``device`` below, so the model must live there too.
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    train_losses = []
    loss_history = []
    no_improvement = 0
    for _ in trange(args.epochs, desc="Epoch"):
        model.train()
        train_loss = 0.0
        nb_tr_steps = 0
        for step, batch in enumerate(tqdm(train_loader, desc="Training iteration")):

            # TODO: NEED TO CHANGE THIS ACCORDING TO DATA LOADER
            batch = tuple(t.to(device) for t in batch)
            speech, transcription, sex_label = batch

            optimizer.zero_grad()

            # forward (second argument is the gradient-reversal constant)
            reconstructed_speech, sex_logits = model(speech, 1.0)
            recon_loss = args.reconstruction_loss(reconstructed_speech, speech)
            sex_loss = args.sex_classification_loss(sex_logits, sex_label)

            # backward: a single pass over the weighted sum. Two separate
            # backward() calls would traverse the shared encoder graph twice
            # (an error once its buffers are freed) and would leave
            # r_weight / s_weight out of the gradients.
            loss = args.r_weight * recon_loss + args.s_weight * sex_loss
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            nb_tr_steps += 1

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        mean_train_loss = train_loss / max(nb_tr_steps, 1)
        train_losses.append(mean_train_loss)

        # print out losses
        print("Loss history:", train_losses)
        print("Train loss:", mean_train_loss)

        dev_loss, _, _ = evaluate(model, dev_loader, args, device)
        print("Dev loss:", dev_loss)

        # Compare against previous epochs only, then record this epoch.
        if len(loss_history) == 0 or dev_loss < min(loss_history):
            no_improvement = 0
            # Unwrap wrappers exposing ``.module`` so the checkpoint holds the bare model.
            model_to_save = model.module if hasattr(model, 'module') else model
            torch.save(model_to_save.state_dict(), args.output_model_file)
        else:
            no_improvement += 1
        loss_history.append(dev_loss)

        if no_improvement >= args.patience:
            print("No improvement on development set. Finish training.")
            break
        

def evaluate(model, dev_loader, args, device):
    """Compute the weighted loss and sex-classification accuracy on a loader.

    Args:
        model: network returning ``(reconstructed_speech, sex_logits)``.
        dev_loader: iterable of ``(speech, transcription, sex_label)``
            tensor batches.
        args: namespace providing ``reconstruction_loss``,
            ``sex_classification_loss``, ``r_weight`` and ``s_weight``.
        device: torch device used for the model and the batches.

    Returns:
        ``(eval_loss, correct_labels, predicted_labels)``: the mean weighted
        loss per batch and two numpy arrays of labels.
    """
    model.eval()

    eval_loss = 0
    nb_eval_steps = 0
    predicted_labels, correct_labels = [], []

    model.to(device)
    for step, batch in enumerate(tqdm(dev_loader, desc="Evaluation iteration")):
        batch = tuple(t.to(device) for t in batch)
        speech, transcription, sex_label = batch

        with torch.no_grad():
            # The gradient-reversal constant only affects backward; pass 1.0
            # to satisfy the model's forward signature.
            reconstructed_speech, sex_logits = model(speech, 1.0)
            recon_loss = args.reconstruction_loss(reconstructed_speech, speech)
            sex_loss = args.sex_classification_loss(sex_logits, sex_label)

        # argmax on the tensor itself, then convert once to plain Python ints
        # (avoids calling numpy functions on torch tensors).
        predicted_labels.extend(sex_logits.argmax(dim=1).cpu().tolist())
        correct_labels.extend(sex_label.cpu().tolist())

        eval_loss += args.r_weight * recon_loss.item() + args.s_weight * sex_loss.item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps

    correct_labels = np.array(correct_labels)
    predicted_labels = np.array(predicted_labels)
    print("Accuracy on testset: "+str(accuracy_score(correct_labels, predicted_labels)))

    return eval_loss, correct_labels, predicted_labels


def test(model, test_loader, args, device):
    """Compute reconstruction loss and sex-classification accuracy on a loader.

    Args:
        model: network returning ``(reconstructed_speech, sex_logits)``.
        test_loader: iterable of ``(speech, transcription, sex_label)``
            tensor batches.
        args: namespace providing ``reconstruction_loss``.
        device: torch device used for the model and the batches.

    Returns:
        ``(eval_loss, accuracy)``: the mean reconstruction loss per batch and
        the classification accuracy as a percentage.
    """
    model.eval()

    eval_loss = 0
    nb_eval_steps = 0
    predicted_labels, correct_labels = [], []
    model.to(device)
    for step, batch in enumerate(tqdm(test_loader, desc="Testing iteration")):
        batch = tuple(t.to(device) for t in batch)
        speech, transcription, sex_label = batch

        with torch.no_grad():
            # The gradient-reversal constant only affects backward; pass 1.0
            # to satisfy the model's forward signature.
            reconstructed_speech, sex_logits = model(speech, 1.0)
            recon_loss = args.reconstruction_loss(reconstructed_speech, speech)

        # argmax on the tensor itself, then convert once to plain Python ints
        # (avoids calling numpy functions on torch tensors).
        predicted_labels.extend(sex_logits.argmax(dim=1).cpu().tolist())
        correct_labels.extend(sex_label.cpu().tolist())

        eval_loss += recon_loss.item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps

    correct_labels = np.array(correct_labels)
    predicted_labels = np.array(predicted_labels)
    accuracy = accuracy_score(correct_labels, predicted_labels)*100
    print("Test Accuracy: "+str(accuracy))

    return eval_loss, accuracy
        

    

In [4]:
import speechbrain as sb
from speechbrain import Brain

def main(argv=None):
    """Parse hyper-parameters and wrap the autoencoder in a SpeechBrain Brain.

    NOTE(review): this redefines the ``main`` declared in the earlier
    training cell; whichever cell runs last wins.

    Args:
        argv: optional list of command-line strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]`` as before.

    Returns:
        The constructed ``ConvAutoEncoderBrain`` (previously the instance was
        built and then discarded).
    """
    # Instantiate the parser
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=1e-6, help='learning rate')
    parser.add_argument('--epochs', type=int, default=20, help='number of epochs')
    parser.add_argument('--r_weight', type=float, default=0.5, help='reconstruction loss weight')
    parser.add_argument('--s_weight', type=float, default=0.5, help='sex classification loss weight')
    #parser.add_argument('--a_weight', type=float, default=0.5, help='asr loss weight')
    parser.add_argument('--feature_dim', type=int, default=20, help='input feature dim')
    parser.add_argument('--output_model_file', type=str, default="/checkpoints/model.bin", help='output path for model')
    parser.add_argument('--patience', type=int, default=2)

    # parse_known_args instead of parse_args: inside a Jupyter kernel
    # sys.argv carries launcher arguments (e.g. "-f <connection file>")
    # that would otherwise make argparse exit with "unrecognized arguments".
    args, _unknown = parser.parse_known_args(argv)

    args.reconstruction_loss = nn.L1Loss(reduction='mean')
    args.sex_classification_loss = nn.CrossEntropyLoss()

    # NOTE(review): ``device`` is computed but not handed to the Brain;
    # confirm how the Brain is expected to select its device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = ConvAutoencoder(args.feature_dim)

    brain = ConvAutoEncoderBrain({"model": model})
    return brain


class ConvAutoEncoderBrain(Brain):
    """SpeechBrain ``Brain`` that trains the autoencoder registered as
    ``modules.model`` on a reconstruction loss plus a sex-classification loss.

    Optional hyper-parameters read from ``self.hparams`` (with defaults):
    ``r_weight`` (0.5), ``s_weight`` (0.5) and ``grl_constant`` (1.0).
    """

    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batch to the model outputs.

        Returns:
            ``(reconstructed_speech, sex_logits, feats)`` where ``feats`` are
            the channel-first input features the model saw; they are returned
            so ``compute_objectives`` uses exactly the same target (including
            any augmented copies) without recomputing them.
        """
        batch = batch.to(self.device)
        wavs, wav_lens = batch.sig

        # Add augmentation if specified
        if stage == sb.Stage.TRAIN:
            if hasattr(self.modules, "env_corrupt"):
                wavs_noise = self.modules.env_corrupt(wavs, wav_lens)
                # Clean and corrupted copies are stacked, doubling the batch.
                wavs = torch.cat([wavs, wavs_noise], dim=0)

        # compute features
        feats = self.hparams.compute_features(wavs)
        # NOTE(review): assumes compute_features returns (batch, time, feature);
        # the model's Conv1d layers need (batch, feature, time) — confirm.
        feats = feats.transpose(1, 2)

        grl_constant = getattr(self.hparams, "grl_constant", 1.0)
        reconstructed_speech, sex_logits = self.modules.model(feats, grl_constant)
        return reconstructed_speech, sex_logits, feats

    def compute_objectives(self, predictions, batch, stage):
        """Computes the weighted sum of the L1 reconstruction loss and the
        cross-entropy sex-classification loss given predictions and targets."""

        reconstructed_speech, sex_logits, feats = predictions

        # NOTE(review): assumes batch.gender is an integer class-index tensor
        # with one entry per utterance — verify against the data pipeline.
        sex_label = batch.gender
        sex_label = sex_label.to(sex_logits.device).view(-1).long()
        # Training-time augmentation stacks clean + corrupted waveforms, so
        # the labels have to be repeated to match the doubled batch.
        if sex_logits.shape[0] == 2 * sex_label.shape[0]:
            sex_label = torch.cat([sex_label, sex_label], dim=0)

        # The autoencoder only returns the input length exactly when the
        # number of frames is a multiple of 4; compare the overlapping frames.
        n_frames = min(reconstructed_speech.shape[-1], feats.shape[-1])

        reconstruction_loss = nn.L1Loss(reduction='mean')
        sex_classification_loss = nn.CrossEntropyLoss()

        recon_loss = reconstruction_loss(
            reconstructed_speech[..., :n_frames], feats[..., :n_frames]
        )
        sex_loss = sex_classification_loss(sex_logits, sex_label)

        r_weight = getattr(self.hparams, "r_weight", 0.5)
        s_weight = getattr(self.hparams, "s_weight", 0.5)
        return r_weight * recon_loss + s_weight * sex_loss

    


# Build a Brain around the module-level `model` (the ConvAutoencoder instance
# created in the earlier smoke-test cell); `model` inside main() is a local
# and is not visible here.
# NOTE(review): this cell therefore depends on that earlier cell having run.
brain = ConvAutoEncoderBrain({"model": model})