In [2]:
import torch
import torchaudio

# Report library versions for reproducibility.
for module in (torch, torchaudio):
    print(module.__version__)


# Run the acoustic model on the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

1.12.1+cu113
0.12.1+cu113
cuda


## Overview

The alignment process consists of the following steps:

1. Estimate the frame-wise label probabilities from the audio waveform.
2. Generate the trellis matrix, which represents the probability of each
   label being aligned at each time step.
3. Find the most likely path through the trellis matrix.

In this example, we use the `Wav2Vec2` model from `torchaudio` for
acoustic feature extraction.




In [3]:
files = ["3130303030365f3431", "3130303030365f3936", "3130313036335f313333", "3130313036335f3935", "3130313733395f313331", "3130313733395f3636", "3130323533385f3234", "3130323533385f3336", "3130323533385f3439", "3130323630395f3737", "3130323735315f30", "3130323735315f3237", "3130323735315f3434", "3130323735315f3632", "3130333034385f313333", "3130333034385f3139", "3131393237365f3535", "3132313536335f3134", "3132313536335f3238", "3132313536335f3536", "3132313536335f3730", "3132313830325f3431", "3132313830325f3534", "3132313830325f3935", "3132313931355f30", "3132313931355f3131", "3132313931375f30", "3132313931375f3233", "3132313931375f3639", "3132313931375f3831", "3132323535335f30", "3132323535335f313035", "3132323535335f3330", "3132323535335f3435", "3132323833305f3332", "3132323833305f3337", "3132383034385f3633", "3132393131325f3232", "3132393131325f3434", "3132393131325f3737", "3133383239305f3230", "3133383239305f3731", "37303234365f3231", "37303234365f3431", "37303234365f3732", "37303238305f3532", "37303631385f3631", "37303637335f3530", "37303835325f3330", "37303836315f3430", "37303836315f3530", "37303836315f3630", "37303837315f3130", "37303837315f3134", "37303837315f3137", "37303932305f3431", "37303932305f3534", "37303932305f3638", "37303936325f3438", "37303936325f3536", "37313037305f3534", "37313037305f3831", "37313635365f313035", "37313931325f3531", "37323031345f30", "37323031345f3630", "37323031375f3330", "37323031385f3535", "37323031385f3639", "37323031385f3937", "37323137385f30", "37323137385f3237", "37323233365f30", "37323233365f3132", "37323233365f3437", "37323233365f3730", "37323234355f3539", "37323331345f3133", "37323331345f3430", "37323331345f3437", "37323430365f3233", "37323430365f3436", "37323430365f3538", "37323430365f3831", "37323737365f313036", "37323737365f3135", "37323737365f3631", "37323739315f30", "37323739315f3539", "37323739315f38", "37333233335f3331", "37333233335f3633", "37333233335f3733", "37333331365f3930", "37333433375f3534", 
"37333433375f3638", "37333433375f3831", "37333436365f3237", "37333436365f3334", "37333437335f313032", "37333437335f3733", "37333535375f3131", "37333535375f3433", "37333535375f3634", "37333738355f3536", "37333932325f3338", "37343132365f3332", "37343132365f3430", "37343132365f38", "37343337305f3138", "37343337305f3534", "37343337305f3839", "37343438375f3535", "37343438375f3832", "37343831335f3637", "37343836395f313035", "37343836395f313232", "37343836395f3335", "37343836395f3837", "37343838385f3136", "37343838385f3635", "37343838385f3831", "37353131315f3134", "37353131315f3432", "37353131315f3536", "37353131315f3730", "37353233365f3139", "37353233365f3935", "37353539325f3235", "37353539325f3337", "37353539325f3439", "37353539325f3836", "37353630315f3537", "37363034375f313131", "37363034375f3438", "37363034375f3739", "37363131375f3235", "37363131375f3838", "37363138355f3436", "37363138355f3736", "37363435345f3530", "37363539395f3237", "37363637385f3337", "37363637385f3433", "37363637385f36", "37363834385f3131", "37363834385f3232", "37363834385f3333", "37363837375f313232", "37363837375f3335", "37363930335f3434", "37363930335f3632", "37373131315f3931", "37373239345f3530", "37373239345f3538", "37373332315f3438", "37373332315f37", "37383132355f30", "37383132355f3131", "37383132355f3639", "37383132355f3830", "37383331355f3531", "37383331355f3630", "37383331355f39", "37383337365f3835", "37383431305f3335", "37383431305f3434", "37383431305f3532", "37383439365f3138", "37383439365f3431", "37383439365f36", "37383531345f30", "37383531345f3234", "37383531345f3430", "37383539335f30", "37383539335f3237", "37383539335f3632", "37383935395f3735", "37393236365f3435", "37393237375f3137", "37393237375f3236", "37393237375f3531", "37393237375f3630", "37393331325f313232", "37393331325f313432", "37393331325f3631", "37393337345f3233", "37393337345f3437", "37393337345f3539", "37393632395f30", "37393632395f3338", "37393632395f3438", "37393632395f3637", "37393832355f3437", "37393832355f3537", 
"37393832355f3636", "37393832365f3736", "37393838355f313033", "37393838355f3434", "37393838355f3539", "37393838355f3734", "37393938355f3136", "37393938355f3535", "38303133375f3135", "38303133375f3239", "38303133375f37", "38303336305f30", "38303336305f3236", "38303336305f3333", "38303336305f3436", "38303339395f3133", "38303339395f3230", "38303339395f3334", "38303339395f3437", "38303635395f3639", "38303730395f30", "38303730395f313239", "38303730395f3138", "38303730395f3734", "38303838345f3431", "38303932325f30", "38303932325f313032", "38303932325f3239", "38303932325f3837", "38303934305f3136", "38303934305f3234", "38303934305f3339", "38303934305f3535", "38303936395f3238", "38303936395f3639", "38313033355f3235", "38313033355f3331", "38313330375f3531", "38313330375f3836", "38313531365f3230", "38313531365f3334", "38313531365f3430", "38313534315f3336", "38313534315f3534", "38313537315f3335", "38313537315f3437", "38313537315f3832", "38313731365f30", "38313731365f3135", "38313731365f3532", "38313731365f37", "38313831355f3334", "38313831355f3432", "38313933305f3334", "38313933305f3430", "38313933305f36", "38313934335f30", "38313934335f3139", "38313934335f38", "38313937345f3736", "38323039355f3238", "38323039355f3432", "38323039355f3639", "38323133355f3535", "38323135375f3432", "38323135375f3533", "38323135375f3734", "38323136365f313035", "39333534315f3637"]
print(files[0])

!wget https://dl-challenge.zalo.ai/lyric-alignment/public_test.zip
!unzip public_test.zip

3130303030365f3431
--2022-11-12 08:56:28--  https://dl-challenge.zalo.ai/lyric-alignment/public_test.zip
Resolving dl-challenge.zalo.ai (dl-challenge.zalo.ai)... 49.213.78.231
Connecting to dl-challenge.zalo.ai (dl-challenge.zalo.ai)|49.213.78.231|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 904253903 (862M) [application/zip]
Saving to: ‘public_test.zip’


2022-11-12 08:57:18 (17.6 MB/s) - ‘public_test.zip’ saved [904253903/904253903]

Archive:  public_test.zip
   creating: public_test/songs/
   creating: public_test/lyrics/
  inflating: public_test/songs/3130303030365f3431.wav  
  inflating: public_test/songs/3130303030365f3936.wav  
  inflating: public_test/songs/3130313036335f313333.wav  
  inflating: public_test/songs/3130313036335f3935.wav  
  inflating: public_test/songs/3130313733395f313331.wav  
  inflating: public_test/songs/3130313733395f3636.wav  
  inflating: public_test/songs/3130323533385f3234.wav  
  inflating: public_test/songs/3130323533385

In [58]:
import glob
import os
import re
import sys
import unicodedata
from dataclasses import dataclass

def convert(text, table=None):
    """
    Convert from 'Tieng Viet co dau' to 'Tieng Viet khong dau'.

    Applies every (regex, replacement) pair of the table to the text, once as
    written and once with both the regex and the replacement upper-cased.

    text: input string to be converted.
    table: optional mapping of regex -> replacement. When omitted, the notebook-level
        ``patterns`` dict is used (it is defined in a later cell, which must have been
        executed before convert() is called).
    Return: the converted string.
    """
    if table is None:
        table = patterns
    # Compose to NFC first: lyrics stored in decomposed (NFD) form carry combining
    # accents that the precomposed character classes in the table would never match.
    output = unicodedata.normalize("NFC", text)
    for regex, replace in table.items():
        output = re.sub(regex, replace, output)
        output = re.sub(regex.upper(), replace.upper(), output)
    return output

def get_trellis(emission, tokens, blank_id=0):
    """Build the forced-alignment trellis of ``tokens`` over ``emission``.

    emission: 2-D tensor indexed ``[frame, label]`` of frame-wise scores
        (align_file passes the log_softmax output of the acoustic model).
    tokens: label indices to align, in transcript order (may be empty).
    blank_id: label index used for "stay on the current token" transitions.

    Returns a ``(num_frame + 1, len(tokens) + 1)`` tensor whose entry ``[t, j]`` is the
    best accumulated score of having consumed the first ``j`` tokens after ``t`` frames.
    The extra row/column (index 0) stand for "no frame / no token consumed yet".
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    # Trellis has extra dimensions for both the time axis and the tokens.
    # The extra token dim represents <SoS> (start-of-sentence); the extra
    # time dim simplifies the recurrence below.
    trellis = torch.empty((num_frame + 1, num_tokens + 1))
    trellis[0, 0] = 0
    # Column 0: only blanks emitted so far.
    trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)
    if num_tokens == 0:
        # Nothing to align; the negative-index slices below would otherwise
        # (with -0) overwrite the whole first row/column.
        return trellis
    # No token can have been consumed before the first frame.
    trellis[0, 1:] = -float("inf")
    # +inf sentinel in the last rows of column 0, as in the original recurrence. It only
    # spreads into cells with t - j >= num_frame + 1 - num_tokens, which never feed the
    # last column, so the alignment result is unaffected by it.
    trellis[max(0, num_frame + 1 - num_tokens):, 0] = float("inf")

    # Loop-invariant lookups, hoisted out of the per-frame loop.
    blank_scores = emission[:, blank_id]
    token_scores = emission[:, tokens]
    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + blank_scores[t],
            # Score for changing to the next token
            trellis[t, :-1] + token_scores[t],
        )
    return trellis

@dataclass
class Point:
    """One frame of an alignment path, in transcript/emission (non-trellis) coordinates."""

    token_index: int  # index into the token list / transcript
    time_index: int  # index of the emission frame
    score: float  # exp() of the emission score of the label chosen at this frame


def backtrack(trellis, emission, tokens, blank_id=0):
    """Recover the most likely alignment path from a trellis built by ``get_trellis``.

    Starts at the frame where the last token's trellis column peaks and walks back
    one frame at a time, deciding whether the current token was held (blank emitted)
    or entered at that frame (token emitted).

    Returns a list of ``Point`` in increasing time order, or ``[]`` when there are no
    tokens to align.

    Raises ValueError("Failed to align") if frame 0 is reached before every token
    has been consumed.
    """
    # Note:
    # j and t are indices for trellis, which has extra dimensions
    # for time and tokens at the beginning.
    # Trellis time index T corresponds to emission frame T-1, and
    # trellis token index J corresponds to transcript token J-1.
    j = trellis.size(1) - 1
    if j == 0:
        # Nothing to align (tokens[j - 1] below would wrap around to tokens[-1]).
        return []
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # 1. Was token J held from T-1 (blank emitted) or entered at T (token emitted)?
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
        advanced = bool(changed > stayed)

        # 2. Store the path with the probability of the label chosen at this frame.
        label = tokens[j - 1] if advanced else blank_id
        prob = emission[t - 1, label].exp().item()
        path.append(Point(j - 1, t - 1, prob))

        # 3. Move to the previous token once this one has been entered.
        if advanced:
            j -= 1
            if j == 0:
                break
    else:
        raise ValueError("Failed to align")
    return path[::-1]


# Merge the labels
@dataclass
class Segment:
    """A labelled run of emission frames ``[start, end)`` with its mean score."""

    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        """Number of frames covered by the segment."""
        return self.end - self.start


def merge_repeats(path, transcript):
    """Collapse consecutive path points sharing a token index into one Segment each.

    The segment label is the transcript character at that token index, it spans
    from the first point's frame to one past the last point's frame, and its score
    is the mean of the points' scores.
    """
    segments = []
    total = len(path)
    run_start = 0
    while run_start < total:
        token = path[run_start].token_index
        run_end = run_start + 1
        while run_end < total and path[run_end].token_index == token:
            run_end += 1
        run = path[run_start:run_end]
        mean_score = sum(point.score for point in run) / len(run)
        segments.append(Segment(transcript[token], run[0].time_index, run[-1].time_index + 1, mean_score))
        run_start = run_end
    return segments


def _word_segment(chars):
    """Join a run of character segments into a single word Segment (length-weighted score)."""
    weighted = sum(seg.score * seg.length for seg in chars)
    total_frames = sum(seg.length for seg in chars)
    return Segment("".join(seg.label for seg in chars), chars[0].start, chars[-1].end, weighted / total_frames)


# Merge words
def merge_words(segments, separator="|"):
    """Group character segments into word segments, splitting on ``separator``.

    Separator segments are dropped and empty groups (leading, trailing or repeated
    separators) produce no word.
    """
    words = []
    current = []
    for seg in segments:
        if seg.label == separator:
            if current:
                words.append(_word_segment(current))
            current = []
        else:
            current.append(seg)
    if current:
        words.append(_word_segment(current))
    return words


In [60]:
bundle = torchaudio.pipelines.WAV2VEC2_ASR_LARGE_960H
model = bundle.get_model().to(device)
labels = bundle.get_labels()

import json

# Character -> label index for the acoustic model's vocabulary (built once, not per song).
dictionary = {c: i for i, c in enumerate(labels)}
# Labels that must never appear inside a word: index 0 is what get_trellis/backtrack
# treat as the blank by default (blank_id=0), and "|" is merge_words' word separator.
_NON_WORD_LABELS = {labels[0], "|"}


def _normalize_words(words):
    """Map each lyric word to the model-label characters it can be aligned with.

    Returns one string per input word, in order. Characters without a model label
    (e.g. punctuation not stripped by convert) are dropped, so a word made only of
    such characters maps to "".
    """
    normalized = []
    for word in words:
        converted = convert(word.upper())
        normalized.append("".join(c for c in converted if c in dictionary and c not in _NON_WORD_LABELS))
    return normalized


def _load_emission(wav_path):
    """Run the acoustic model on one song.

    Returns (frame-wise log-probabilities indexed [frame, label], number of audio
    samples, audio sample rate as reported by torchaudio.load).
    """
    waveform, sample_rate = torchaudio.load(wav_path)
    # Down-mix to one channel: the model treats dim 0 as the batch and only the first
    # emission sequence is used, so extra channels would only cost compute.
    mono = waveform.mean(dim=0, keepdim=True)
    model_input = mono
    if sample_rate != bundle.sample_rate:
        # The pipeline expects audio at bundle.sample_rate; feeding the file's native
        # rate would stretch the singing in time from the model's point of view.
        model_input = torchaudio.functional.resample(mono, sample_rate, bundle.sample_rate)
    with torch.inference_mode():
        emissions, _ = model(model_input.to(device))
        emissions = torch.log_softmax(emissions, dim=-1)
    return emissions[0].cpu(), mono.size(1), sample_rate


def align_file(i):
    """Force-align the lyrics of ``files[i]`` against its audio.

    Reads ``public_test/lyrics/<id>.txt`` and ``public_test/songs/<id>.wav`` and returns
    ``[{"s": 0, "e": song_ms, "l": [{"s": start_ms, "e": end_ms, "d": word}, ...]}]``
    with exactly one entry per whitespace-separated lyric word, in order. A word with
    nothing alignable (e.g. a lone punctuation mark) gets a zero-length span at the
    previous word's end, so every "d" stays paired with its own timing.
    """
    song_id = files[i]
    with open(f"public_test/lyrics/{song_id}.txt", encoding="utf-8") as fh:
        # str.split() handles newlines/tabs and never yields empty words at the ends.
        words = fh.read().split()

    normalized = _normalize_words(words)
    # Indices of the words that take part in the alignment; merge_words yields exactly
    # one segment per entry, in the same order.
    aligned = [k for k, text in enumerate(normalized) if text]
    transcript = "|".join(normalized[k] for k in aligned)
    tokens = [dictionary[c] for c in transcript]

    emission, num_samples, sample_rate = _load_emission(f"public_test/songs/{song_id}.wav")

    spans = {}
    if tokens:
        trellis = get_trellis(emission, tokens)
        path = backtrack(trellis, emission, tokens)
        word_segments = merge_words(merge_repeats(path, transcript))
        # Audio samples (at the file's own rate) per emission frame.
        ratio = num_samples / (trellis.size(0) - 1)
        for k, seg in zip(aligned, word_segments):
            x0 = int(ratio * seg.start)
            x1 = int(ratio * seg.end)
            spans[k] = ((x0 / sample_rate) * 1000, (x1 / sample_rate) * 1000)

    entries = []
    prev_end = 0
    for k, word in enumerate(words):
        start, end = spans.get(k, (prev_end, prev_end))
        entries.append({"s": start, "e": end, "d": word})
        prev_end = end

    length_ms = (num_samples / sample_rate) * 1000
    return [{"s": 0, "e": length_ms, "l": entries}]

In [None]:
# Regex -> replacement table used by convert(): strips Vietnamese diacritics,
# removes punctuation, and spells out digits. convert() also applies each pair
# upper-cased, so only the lower-case forms are listed here.
patterns = {
    '[àáảãạăắằẵặẳâầấậẫẩ]': 'a',
    '[đ]': 'd',
    '[èéẻẽẹêềếểễệ]': 'e',
    '[ìíỉĩị]': 'i',
    '[òóỏõọôồốổỗộơờớởỡợ]': 'o',
    '[ùúủũụưừứửữự]': 'u',
    '[ỳýỷỹỵ]': 'y',
    # Raw string: "\." and "\?" in a plain literal are invalid escape sequences
    # (DeprecationWarning; SyntaxWarning since Python 3.12). "-" is removed too because
    # in the WAV2VEC2_ASR label set it is label 0, which the aligner uses as the blank.
    r'[\.“”‘,\?!%"’:;()…\-]': '',
    '0': 'KHONG', '1': 'MOT', '2': 'HAI', '3': 'BA', '4': 'BON',
    '5': 'NAM', '6': 'SAU', '7': 'BAY', '8': 'TAM', '9': 'CHIN',
}

import os
import codecs

# Align every song that does not have a submission file yet; existing JSON files are
# skipped so an interrupted run can simply be resumed.
os.makedirs("submit", exist_ok=True)
for idx, name in enumerate(files):
    jsonfile = f"submit/{name}.json"
    if os.path.exists(jsonfile):
        continue
    print(idx, name)
    data = align_file(idx)
    # Write to a temporary name and rename afterwards, so a crash mid-write never
    # leaves a truncated JSON that the exists-check above would then skip forever.
    tmpfile = jsonfile + ".tmp"
    with open(tmpfile, "w", encoding="utf-8") as fh:
        json.dump(data, fh, ensure_ascii=False)
    os.replace(tmpfile, jsonfile)

In [None]:
!ls submit
# !cat submit/3130303030365f3431.json
!tar -czf submit.tar.gz submit