In [56]:
from transformers import (
    Wav2Vec2Processor, Wav2Vec2Model
)
import librosa
import numpy as np
import os

import json
import pandas as pd 
from glob import glob


In [57]:
def load_alignment(path):
    with open(path, encoding="utf-8") as f:
        lines = f.readlines()

        lines = [json.loads(line.strip()) for line in lines]

    return lines

def load_id(path):
    with open(path, encoding="utf-8") as f:
        lines = f.readlines()

        lines = [line.strip() for line in lines]

    return lines

In [58]:
def load_data(data_dir):
    phone_ids = np.load(f'{data_dir}/phone_ids.npy')
    
    phone_scores = np.load(f'{data_dir}/phone_scores.npy')
    sentence_scores = np.load(f'{data_dir}/sentence_scores.npy')

    ids = load_id(f'{data_dir}/id')
    gops = np.load(f'{data_dir}/gop.npy')
    alignments = load_alignment(f'{data_dir}/alignment')
    return ids, phone_ids, phone_scores, sentence_scores, gops, alignments


In [59]:
from torch.utils.data import (
    Dataset,
    DataLoader
)
import librosa
import torch
import numpy as np

def pad_1d(inputs, max_length=None, pad_value=0.0):
    if max_length is None:
        lengths = [len(sample) for sample in inputs]
        max_length = max(lengths)

    for i in range(len(inputs)):
        if inputs[i].shape[0] < max_length:
            inputs[i] = torch.cat(
                (
                    inputs[i], 
                    pad_value * torch.ones(max_length-inputs[i].shape[0])),
                dim=0
            )
        else:
            inputs[i] = inputs[i][0:max_length]
    inputs = torch.stack(inputs, dim=0)
    return inputs

class Wav2vec2Dataset(Dataset):
    def __init__(self, audio_dir, ids, phone_ids, phone_scores, sentence_scores, gops, alignments, processor):
        self.audio_dir = audio_dir
        self.ids = ids
        self.phone_ids = phone_ids
        self.phone_scores = phone_scores
        self.sentence_scores = sentence_scores
        self.gops = gops
        self.alignments = alignments

        self.processor = processor
        self.max_utt_length = 32

    def load_audio(self, path):
        wav, sr = librosa.load(path, sr=16000)

        return wav
    
    def parse_data(self, alignment, id, phone_id, phone_score, sentence_score, gop):
        path = f'{self.audio_dir}/{id}.wav'
        audio = self.load_audio(path)
        
        align_indices = self.get_indices_from_aligments(
            alignment=alignment
        )

        audio = self.processor(
            audio, return_tensors="pt", padding="longest", sampling_rate=16000).input_values
        
        assert audio.shape[0] == 1
        audio = audio.squeeze(0)

        phone_id = torch.tensor(phone_id)
        phone_score = torch.tensor(phone_score)
        sentence_score = torch.tensor(sentence_score)
        gop = torch.tensor(gop)
        
        phone_score[phone_score != -1] /= 50
        sentence_score /= 50

        return {
            "audio": audio,
            "phone_id": phone_id,
            "phone_score": phone_score,
            "sentence_score": sentence_score,
            "gop": gop,
            "align_indices": align_indices
        } 

    def get_indices_from_aligments(self, alignment):
        index = 0
        indices = -1 * torch.ones(alignment[-1][1] + alignment[-1][2])
        for phoneme, start_frame, duration in alignment:
            end_frame = start_frame + duration
            indices[start_frame:end_frame] = index
            index += 1
        
        return indices

    def __getitem__(self, index):
        id = self.ids[index]
        phone_id = self.phone_ids[index]
        phone_score = self.phone_scores[index]
        sentence_score = self.sentence_scores[index]
        gop = self.gops[index]
        alignment = self.alignments[index]

        return self.parse_data(
            id=id,
            phone_id=phone_id,
            phone_score=phone_score,
            sentence_score=sentence_score,
            gop=gop,
            alignment=alignment
        )
    def __len__(self):
        return len(self.ids)

    def pad2indices(self, indices):
        lengths = [len(sample) for sample in indices]
        max_length = max(lengths)

        for i in range(len(indices)):
            if indices[i].shape[0] < max_length:
                padding =  -1 * torch.ones(max_length - len(indices[i]))
                indices[i] = torch.cat(
                    [
                        indices[i],
                        padding
                    ]
                )

        indices = torch.stack(indices, dim=0)
        max_indice = self.max_utt_length - 1
        for i, length in enumerate(lengths):
            max_current_index = indices[i].max().item()

            index = max_current_index
            for j in range(0, max_length):
                if indices[i][j] != -1:
                    continue
                
                if index < max_indice:
                    index += 1
                indices[i][j] = index
                    
        return indices
    
    def collate_fn(self, batch):
        phone_ids = [sample["phone_id"] for sample in batch]
        phone_scores = [sample["phone_score"] for sample in batch]
        sentence_scores = [sample["sentence_score"] for sample in batch]
        gops = [sample["gop"] for sample in batch]
        input_values = [sample["audio"] for sample in batch]
        align_indices = [sample["align_indices"] for sample in batch]

        input_values = pad_1d(input_values, pad_value=0.0)
        indices = self.pad2indices(align_indices)

        phone_ids = torch.stack(phone_ids, dim=0)
        phone_scores = torch.stack(phone_scores, dim=0)
        sentence_scores = torch.stack(sentence_scores, dim=0)
        gops = torch.stack(gops, dim=0)
        
        return {
            "indices": indices,
            "phone_ids": phone_ids,
            "phone_scores": phone_scores,
            "sentence_scores": sentence_scores,
            "gops": gops,
            "input_values": input_values
        }

audio_dir = "/data/audio_data/prep_submission_audio/10"
data_dir = "/data/codes/apa/train/exps/features/test/dev"
ids, phone_ids, phone_scores, sentence_scores, gops, alignments = load_data(data_dir)

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

dataset = Wav2vec2Dataset(
    audio_dir=audio_dir, 
    ids=ids, 
    phone_ids=phone_ids, 
    phone_scores=phone_scores, 
    sentence_scores=sentence_scores, 
    gops=gops, 
    alignments=alignments, 
    processor=processor
)

dataloader = DataLoader(dataset, batch_size=1, collate_fn=dataset.collate_fn)
for batch in dataloader:
    # for i in batch["indices"]:
    #     print(i)
    # print(batch["indices"][0])
    # print(batch)
    break

In [60]:
import torch
from torch.nn import Module

class PrepModel(Module):
    def __init__(self, model):
        super(PrepModel, self).__init__()
        self.model = model
        
        self.num_phone = 43
        self.phn_proj = torch.nn.Linear(self.num_phone, 64)
        
        self.ffw = torch.nn.Linear(
            768+64, 768
        )
        
        self.utt_head = torch.nn.Linear(
            768, 1
        )

    def get_phone_level_features(self, features, indices, device="cuda:0"):
        feature_indices = torch.arange(features.shape[1]).unsqueeze(-1).to(device)
        expanded_indices = feature_indices.expand((-1, 2)).flatten()
        features = features[:, expanded_indices]

        indices[indices==-1] = indices.max() + 1
        indices = torch.nn.functional.one_hot(
            indices.long(), num_classes=int(indices.max().item())+1).float().to(device)

        # indices = indices / indices.sum(1, keepdim=True)

        if indices.shape[1] > features.shape[1]:
            features = torch.matmul(
                indices[:, 0:features.shape[1]].transpose(1, 2), features)
        else:
            features = torch.matmul(
                indices.transpose(1, 2), features[:, 0:indices.shape[1]])
            
        return features

    def forward(self, input_values, phone_ids, indices, sentence_scores=None, device="cuda:0"):
        phn_one_hot = torch.nn.functional.one_hot(
            phone_ids.long()+1, num_classes=self.num_phone).float().to(device)
        phn_embed = self.phn_proj(phn_one_hot)
        
        features = self.model.extract_features(input_values)[0]
        
        features = self.get_phone_level_features(
            features=features,
            indices=indices
        )
        
        features = torch.cat([features, phn_embed], dim=-1)

        features = self.ffw(features)
        
        features = features.mean(dim=1)
        scores = self.utt_head(features)
        
        if sentence_scores is not None:
            loss = self.loss(pred_sentence_scores=scores, label_sentence_scores=sentence_scores)
            return loss, scores 
        
        return scores

    def loss(self, pred_sentence_scores,label_sentence_scores):
        sent_loss = torch.mean((pred_sentence_scores-label_sentence_scores)**2)
        
        return sent_loss
        

In [61]:
# wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
# processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

In [62]:
%cd /data/codes/apa/train
from src.models.wavlm_model import WavLM, WavLMConfig
pretrained_path = "/data/codes/apa/train/exps/ckpts/wavlm-base+.pt"
checkpoint = torch.load(pretrained_path)

config = WavLMConfig(checkpoint['cfg'])
wav2vec = WavLM(config).cuda()
wav2vec.load_state_dict(checkpoint['model'])

/data/codes/apa/train


<All keys matched successfully>

In [63]:
audio_dir = "/data/audio_data/prep_submission_audio/10"
data_dir = "/data/codes/apa/train/exps/features/test/dev"
ids, phone_ids, phone_scores, sentence_scores, gops, alignments = load_data(data_dir)

index = 2500


trainset = Wav2vec2Dataset(
    audio_dir=audio_dir, 
    ids=ids[0:2500], 
    phone_ids=phone_ids[0:2500], 
    phone_scores=phone_scores[0:2500], 
    sentence_scores=sentence_scores[0:2500], 
    gops=gops[0:2500], 
    alignments=alignments[0:2500], 
    processor=processor
)

train_loader = DataLoader(trainset, batch_size=8, shuffle=True, collate_fn=dataset.collate_fn)

testset = Wav2vec2Dataset(
    audio_dir=audio_dir, 
    ids=ids[2500:], 
    phone_ids=phone_ids[2500:], 
    phone_scores=phone_scores[2500:], 
    sentence_scores=sentence_scores[2500:], 
    gops=gops[2500:], 
    alignments=alignments[2500:], 
    processor=processor
)

val_loader = DataLoader(testset, batch_size=16, collate_fn=dataset.collate_fn)

In [64]:
from torch.optim import Adam

device = "cuda:0"
model = PrepModel(wav2vec).to(device)
optimizer = Adam(model.parameters(), lr=1e-5)

In [65]:
def valid_utt(predict, target):
    utt_mse = np.mean(((predict[:, 0] - target[:, 0]) ** 2).numpy())
    utt_mae = np.mean((np.abs(predict[:, 0] - target[:, 0])).numpy())
    
    utt_corr = np.corrcoef(predict[:, 0], target[:, 0])[0, 1]
    return utt_mse, utt_mae, utt_corr


In [66]:
from tqdm import tqdm

for epoch in range(10):
    train_losses, val_losses = [], []
    
    train_tqdm = tqdm(train_loader, desc="Train")
    for batch in train_tqdm:
        indices = batch["indices"].to(device)
        input_values = batch["input_values"].to(device)
        phone_ids = batch["phone_ids"].to(device)
        phone_scores = batch["phone_scores"].to(device)
        gops = batch["gops"].to(device)
        sentence_scores = batch["sentence_scores"].to(device)
        
        optimizer.zero_grad()
        try:
            loss, scores = model(
                input_values=input_values,
                indices=indices,
                sentence_scores=sentence_scores,
                phone_ids=phone_ids,
                device=device
            )
        except:
            continue
        
        loss.backward()
        
        optimizer.step()
        
        train_losses.append(loss.item())
        train_tqdm.set_postfix(loss=loss.item())
    
    pred_scores, label_scores = [], []
    val_tqdm = tqdm(val_loader, desc="Test")
    for batch in val_tqdm:
        indices = batch["indices"].to(device)
        input_values = batch["input_values"].to(device)
        phone_ids = batch["phone_ids"].to(device)
        phone_scores = batch["phone_scores"].to(device)
        gops = batch["gops"].to(device)
        sentence_scores = batch["sentence_scores"].to(device)
        
        with torch.no_grad():
            try:
                loss, scores = model(
                    input_values=input_values,
                    indices=indices,
                    sentence_scores=sentence_scores,
                    phone_ids=phone_ids,
                    device=device
                )
            except:
                continue
        
        pred_scores.append(scores)
        label_scores.append(sentence_scores.unsqueeze(-1))
        
        val_losses.append(loss.item())
        val_tqdm.set_postfix(loss=loss.item())
        
    pred_scores = torch.cat(pred_scores, dim=0)
    label_scores = torch.cat(label_scores, dim=0)
    
    utt_mse, utt_mae, utt_corr = valid_utt(
        predict=pred_scores.cpu(), 
        target=label_scores.cpu()
    )
    
    print(utt_mse, utt_mae, utt_corr)
        
    print(f'Train loss: {np.mean(train_losses)}, Test loss: {np.mean(val_losses)}')

Train: 100%|██████████| 313/313 [00:38<00:00,  8.09it/s, loss=0.281] 
Test: 100%|██████████| 124/124 [00:13<00:00,  9.06it/s, loss=0.161] 


0.1725504496183761 0.3362931657852246 0.01324251992383299
Train loss: 0.25763531581707955, Test loss: 0.17668090462202188


Train: 100%|██████████| 313/313 [00:38<00:00,  8.15it/s, loss=0.0684]
Test: 100%|██████████| 124/124 [00:13<00:00,  9.15it/s, loss=0.145] 


0.16249394062088024 0.33035708415347076 0.06391066009153079
Train loss: 0.2906006127503431, Test loss: 0.1689805659613434


Train: 100%|██████████| 313/313 [00:38<00:00,  8.14it/s, loss=0.128] 
Test: 100%|██████████| 124/124 [00:13<00:00,  9.04it/s, loss=0.145] 


0.1477346029595126 0.31560457250607604 0.027400497056666043
Train loss: 0.18689634244439496, Test loss: 0.15267970934217545


Train: 100%|██████████| 313/313 [00:38<00:00,  8.12it/s, loss=0.0693]
Test: 100%|██████████| 124/124 [00:13<00:00,  9.11it/s, loss=0.191] 


0.1759101222050824 0.33484905845375523 0.0024265110202011256
Train loss: 0.1600074216214537, Test loss: 0.17906056358457048


Train: 100%|██████████| 313/313 [00:38<00:00,  8.09it/s, loss=0.0379]
Test: 100%|██████████| 124/124 [00:13<00:00,  9.07it/s, loss=0.183] 


0.15224956957693642 0.31207981221355896 0.10363701217423818
Train loss: 0.16314415247525896, Test loss: 0.16635495582128737


Train: 100%|██████████| 313/313 [00:38<00:00,  8.16it/s, loss=0.0323]
Test: 100%|██████████| 124/124 [00:13<00:00,  9.17it/s, loss=0.157] 


0.16748336000784503 0.3383526340555351 0.04659179328559895
Train loss: 0.14695374719063717, Test loss: 0.17738075522980853


Train: 100%|██████████| 313/313 [00:38<00:00,  8.03it/s, loss=0.0844]
Test: 100%|██████████| 124/124 [00:13<00:00,  9.05it/s, loss=0.275] 


0.13880625297190777 0.31000552877658066 0.07005133854891882
Train loss: 0.14979522637334164, Test loss: 0.1444640910833847


Train: 100%|██████████| 313/313 [00:38<00:00,  8.12it/s, loss=0.0872]
Test: 100%|██████████| 124/124 [00:13<00:00,  9.06it/s, loss=0.157] 


0.12909420459042448 0.2938289573502949 0.13744630708490985
Train loss: 0.1550404574375162, Test loss: 0.1383062030038138


Train: 100%|██████████| 313/313 [00:39<00:00,  8.00it/s, loss=0.166] 
Test: 100%|██████████| 124/124 [00:13<00:00,  9.12it/s, loss=0.337] 


0.5066941077810867 0.5115063267781924 0.022814030084555802
Train loss: 0.20748309306108717, Test loss: 0.5243153287002663


Train: 100%|██████████| 313/313 [00:38<00:00,  8.11it/s, loss=0.114] 
Test: 100%|██████████| 124/124 [00:13<00:00,  9.11it/s, loss=0.154] 

0.14550136940647562 0.3142473769847608 0.02143494770187068
Train loss: 0.3372367670443356, Test loss: 0.15008283325510927





In [67]:
sentence_scores

tensor([1.3776, 1.1086, 0.6452, 1.8580, 1.4176, 1.1550, 1.1400, 1.5274, 1.8448,
        0.9392, 1.9248, 1.4496, 1.5958, 1.4460], device='cuda:0',
       dtype=torch.float64)

### test

In [68]:
def get_phone_level_features(features, indices):
    # feature_indices = torch.arange(features.shape[1]).unsqueeze(-1)
    # expanded_indices = feature_indices.expand((-1, 2)).flatten()
    # features = features[:, expanded_indices]

    indices[indices==-1] = indices.max() + 1
    indices = torch.nn.functional.one_hot(
        indices.long(), num_classes=int(indices.max().item())+1)
    indices = indices / indices.sum(1, keepdim=True)

    if indices.shape[1] > features.shape[1]:
        features = torch.matmul(
            indices[:, 0:features.shape[1]].transpose(1, 2), features)
    else:
        features = torch.matmul(
            indices.transpose(1, 2), features[:, 0:indices.shape[1]])
        
    features = features[:, :-1]
    return features

def pad2indices(indices, max_utt_length=32):
    lengths = [len(sample) for sample in indices]
    max_length = max(lengths)

    for i in range(len(indices)):
        if indices[i].shape[0] < max_length:
            padding =  -1 * torch.ones(max_length - len(indices[i]))
            indices[i] = torch.cat(
                [
                    indices[i],
                    padding
                ]
            )

    indices = torch.stack(indices, dim=0)
    max_indice = max_utt_length - 1
    for i, length in enumerate(lengths):
        max_current_index = indices[i].max().item()

        index = max_current_index
        for j in range(length, max_length):
            if index < max_indice:
                index += 1
            indices[i][j] = index
                
    return indices

In [69]:
import torch

alignment_1 = [(0, 2), (3, 5), (8, 4), (12, 4), (16, 3), (19, 4)]
features_1 = -1 * torch.ones(alignment_1[-1][0] + alignment_1[-1][1], 8)
indices_1 = -1 * torch.ones(alignment_1[-1][0] + alignment_1[-1][1])
for index, (start, duration) in enumerate(alignment_1):
    features_1[start:start + duration] = index
    indices_1[start:start + duration] = index

# get_phone_level_features(features=features_1, indices=indices_1)


In [70]:
import torch

alignment_2 = [(0, 4), (5, 6), (10, 2), (12, 4)]
features_2 = -1 * torch.ones(alignment_2[-1][0] + alignment_2[-1][1], 8)
indices_2 = -1 * torch.ones(alignment_2[-1][0] + alignment_2[-1][1])
for index, (start, duration) in enumerate(alignment_2):
    features_2[start:start + duration] = index
    indices_2[start:start + duration] = index

# get_phone_level_features(features=features_2, indices=indices_2)


In [71]:
def pad_2d(inputs, max_length=None, pad_value=0.0):
    if max_length is None:
        lengths = [len(sample) for sample in inputs]
        max_length = max(lengths)

    for i in range(len(inputs)):
        if inputs[i].shape[0] < max_length:
            inputs[i] = torch.cat(
                (
                    inputs[i], 
                    pad_value * torch.ones((max_length-inputs[i].shape[0], inputs[i].shape[1]))),
                dim=0
            )
        else:
            inputs[i] = inputs[i][0:max_length]
    inputs = torch.stack(inputs, dim=0)
    return inputs

def pad_1d(inputs, max_length=None, pad_value=0.0):
    if max_length is None:
        lengths = [len(sample) for sample in inputs]
        max_length = max(lengths)

    for i in range(len(inputs)):
        if inputs[i].shape[0] < max_length:
            inputs[i] = torch.cat(
                (
                    inputs[i], 
                    pad_value * torch.ones(max_length-inputs[i].shape[0])),
                dim=0
            )
        else:
            inputs[i] = inputs[i][0:max_length]
    inputs = torch.stack(inputs, dim=0)
    return inputs

In [72]:
indices

tensor([[17., 18., 19.,  ..., 31., 31., 31.],
        [17., 18., 19.,  ..., 31., 31., 31.],
        [17., 18., 19.,  ..., 31., 31., 31.],
        ...,
        [19., 20., 21.,  ..., 31., 31., 31.],
        [15., 16., 17.,  ..., 31., 31., 31.],
        [20., 21., 22.,  ..., 31., 31., 31.]], device='cuda:0')

In [73]:
alignments = [alignment_1, alignment_2]
features = [features_1, features_2]
indices = [indices_1, indices_2]

features = pad_2d(features)
indices = pad2indices(indices, max_utt_length=6)

In [74]:
indices

tensor([[ 0.,  0., -1.,  1.,  1.,  1.,  1.,  1.,  2.,  2.,  2.,  2.,  3.,  3.,
          3.,  3.,  4.,  4.,  4.,  5.,  5.,  5.,  5.],
        [ 0.,  0.,  0.,  0., -1.,  1.,  1.,  1.,  1.,  1.,  2.,  2.,  3.,  3.,
          3.,  3.,  4.,  5.,  5.,  5.,  5.,  5.,  5.]])

In [75]:
get_phone_level_features(features, indices)

tensor([[[0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1., 1., 1., 1.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [3., 3., 3., 3., 3., 3., 3., 3.],
         [4., 4., 4., 4., 4., 4., 4., 4.],
         [5., 5., 5., 5., 5., 5., 5., 5.]],

        [[0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1., 1., 1., 1.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [3., 3., 3., 3., 3., 3., 3., 3.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]]])