In [1]:
# --- configuration ---
SAMPLE = 4     # 1-based dev-set sample: audio "<SAMPLE>.wav" and annotation line number SAMPLE
CTC_DEPTH = 5  # top-k token candidates kept per CTC frame (rows of the search matrix)
NPATHS = 10    # upper bound on the number of graph paths scored per query word

# Bottom up: build a graph of token transitions from the top-k CTC predictions

In [2]:
# Run wav2vec2 on one dev-set recording and keep, for every output frame,
# the CTC_DEPTH most likely token ids.
import os
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from datasets import load_dataset
import soundfile as sf

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
# NB: despite its name this is a Wav2Vec2Processor (feature extractor + tokenizer);
# the CTC tokenizer itself is reached through `tokenizer.tokenizer`.
tokenizer = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")


path = "../data/dev/"

file = str(SAMPLE) + '.wav'
speech, samplerate = sf.read(os.path.join(path, file))
i = int(file.split('.')[0]) - 1  # 0-based sample index (equals SAMPLE - 1)

# Pass the file's sampling rate so the processor can reject audio recorded at a
# rate the model was not trained on, instead of silently producing bad features.
input_values = tokenizer(
    speech, sampling_rate=samplerate, return_tensors="pt", padding="longest"
).input_values
with torch.no_grad():  # inference only: no autograd graph is needed
    logits = model(input_values).logits

# Top-k candidates along the vocabulary axis; [0] selects the single batch item,
# giving a (frames, CTC_DEPTH) tensor of token ids.
ctc_table = torch.topk(logits, k=CTC_DEPTH, dim=-1)
predicted_ids = ctc_table.indices[0]

print(predicted_ids)
print(predicted_ids.shape)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
It is strongly recommended to pass the ``sampling_rate`` argument to this function.Failing to do so can result in silent errors that might be hard to debug.


tensor([[ 0,  4,  7,  9, 10],
        [ 0,  4,  7,  9, 10],
        [ 0,  4,  7,  9,  6],
        ...,
        [ 0,  4,  7, 10,  6],
        [ 0,  4,  7, 10,  9],
        [ 0,  4,  7, 10,  9]])
torch.Size([365, 5])


In [3]:
import numpy as np
# Transpose to one row per top-k rank and one column per frame.
predictions = np.array(predicted_ids).T
print(predictions)
print(predictions.shape)
# Row-major flattening: cell (rank r, frame f) lands at position r * n_frames + f,
# which is how graph nodes are numbered when the edges are built.
indices = predictions.flatten()
print(indices)

[[ 0  0  0 ...  0  0  0]
 [ 4  4  4 ...  4  4  4]
 [ 7  7  7 ...  7  7  7]
 [ 9  9  9 ... 10 10 10]
 [10 10  6 ...  6  9  9]]
(5, 365)
[0 0 0 ... 6 9 9]


In [4]:
# generate adjacencies
def connect(predictions, t, k, n):
    """Return the graph edges leaving cell (row ``n``, column ``t``) of ``predictions``.

    ``predictions`` is a 2-D array whose columns are frames and whose rows are
    candidate token ids. A cell (row, col) is identified by
    ``row * predictions.shape[1] + col``, i.e. its position in
    ``predictions.flatten()``.

    Starting at column ``k`` (normally ``t + 1``), an edge ``[source, target]`` is
    emitted to every non-zero cell of that column. Token id 0 is treated as a blank.
    When a column contains a blank, the search also continues into the next
    column, so blanks are skipped transitively until a blank-free column or the
    last column is reached. Edges are listed depth-first: the edges found through a
    blank appear exactly where that blank sits in its column.

    Args:
        predictions: 2-D array of token ids (rows = candidates, columns = frames).
        t: column of the source cell.
        k: first column to search for targets.
        n: row of the source cell.

    Returns:
        list of ``[source_id, target_id]`` pairs.

    Example (2 candidate rows, 6 frames)::

        >>> p = np.array([[2, 0, 0, 4, 5, 5],
        ...               [4, 5, 5, 6, 6, 1]])
        >>> connect(p, 0, 1, 0)
        [[0, 3], [0, 9], [0, 8], [0, 7]]
    """
    n_rows = predictions.shape[0]
    n_cols = predictions.shape[1]
    source = n * n_cols + t
    edges = []
    # An explicit stack replaces recursion. When most frames have a blank among
    # their candidates, the skip chain can span almost the whole utterance and
    # would exceed Python's recursion limit on long recordings.
    # Each entry is (column, first row still to examine in that column).
    stack = [(k, 0)]
    while stack:
        col, start = stack.pop()
        for j in range(start, n_rows):
            if predictions[j][col] != 0:
                edges.append([source, j * n_cols + col])
            elif col < n_cols - 1:
                # Blank: explore the next column first, then resume this one at j + 1.
                stack.append((col, j + 1))
                stack.append((col + 1, 0))
                break
    return edges

# Collect every edge of the bigram graph: each non-blank cell (row n, frame t)
# is linked to the cells reachable in the following frame(s).
edges = []
for t in range(predictions.shape[1] - 1):  # the last frame has no successor
    for n in range(predictions.shape[0]):
        if predictions[n][t] == 0:
            continue  # token id 0 is treated as blank and never starts an edge
        edges.extend(connect(predictions, t, t + 1, n))

print(len(edges))

315267


# Top down: measure how well each word of the entity label is spelled out by a path in that graph

In [5]:
# load entities
import json

path = '../data/'

# JSON text is UTF-8 by specification. Decode explicitly so labels with non-ASCII
# characters (they are later passed through unidecode) read the same way on every
# platform, regardless of the locale's default encoding.
with open(path + 'entities.json', 'r', encoding='utf-8') as fin:
    entities = json.load(fin)  # presumably {entity id: label}; it is indexed by subject id
print(len(entities), 'entity labels')

28497 entity labels


In [6]:
import networkx as nx
import itertools
import difflib


def get_overlap(s1, s2):
    """Return the longest contiguous run shared by ``s1`` and ``s2``, sliced from ``s1``.

    Returns an empty slice of ``s1`` when the inputs have nothing in common. Ties
    follow ``SequenceMatcher.find_longest_match``: the earliest match in ``s1`` wins.
    """
    # autojunk=False: with the default heuristic, characters that are frequent in a
    # ``s2`` of 200+ items are classed as "popular" and can no longer start a match,
    # so long decoded paths would silently report (near-)empty overlaps.
    matcher = difflib.SequenceMatcher(None, s1, s2, autojunk=False)
    block = matcher.find_longest_match(0, len(s1), 0, len(s2))
    return s1[block.a:block.a + block.size]


def match(edges, indices, query_str, tokenizer, n_paths=NPATHS):
    """Score how well ``query_str`` is spelled out by a path through the bigram graph.

    Only edges whose (source token, target token) pair is a consecutive token pair
    of the query are kept. The longest path of the resulting graph is decoded back
    to text and compared with the query. The decoded text and its overlap with the
    query are printed.

    Args:
        edges: iterable of ``[source, target]`` node ids (positions in ``indices``).
        indices: flat sequence mapping node id -> token id.
        query_str: word to look for, in the tokenizer's alphabet.
        tokenizer: object exposing ``.tokenizer``, which must be callable (returning a
            mapping with ``'input_ids'``) and provide ``convert_ids_to_tokens``.
        n_paths: maximum number of candidate paths to score. Currently only the single
            longest path is generated, so values above 1 have no further effect.

    Returns:
        The best fraction of ``query_str`` covered by its longest common substring
        with a decoded path, between 0 and 1; 0.0 for an empty query.
    """
    if not query_str:
        return 0.0  # avoids dividing by len(query_str) == 0 below

    query_ids = tokenizer.tokenizer(query_str)['input_ids']
    # Consecutive token pairs of the query, as tuples for O(1) membership tests.
    query_bigrams = {tuple(query_ids[i:i + 2]) for i in range(len(query_ids) - 1)}

    # Keep only the edges that realise one of the query bigrams.
    bigram_edges = [e for e in edges
                    if (indices[e[0]], indices[e[1]]) in query_bigrams]

    # The edges produced by `connect` always point to a later frame, so the graph is
    # acyclic and dag_longest_path applies. (Enumerating every simple path between
    # node pairs is an alternative, but is exponential in the worst case.)
    DG = nx.DiGraph()
    DG.add_edges_from(bigram_edges)
    all_paths = [nx.dag_longest_path(DG)]

    # Take the maximum overlap over the considered paths.
    best = 0.0
    for path in all_paths[:n_paths]:
        word = ''.join(tokenizer.tokenizer.convert_ids_to_tokens([indices[i] for i in path]))
        print(word)
        overlap = get_overlap(query_str, word)
        print(overlap)
        best = max(best, len(overlap) / len(query_str))
    return best

In [11]:
# load original question
import re
from unidecode import unidecode

# Punctuation removed before comparison. Raw string so the backslashes reach `re`
# verbatim instead of being invalid Python string escapes such as "\," (a
# SyntaxWarning since Python 3.12); the character class itself is unchanged.
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"]'


def normalize_text(text):
    """Strip ignored punctuation, lower-case, transliterate to ASCII and collapse
    runs of the same character into one character (e.g. "anna" -> "ana")."""
    text = re.sub(chars_to_ignore_regex, '', text).lower()
    text = unidecode(text)
    # Keep index 0 unconditionally: comparing it with text[-1] would wrap around and
    # drop the first character whenever it equals the last one.
    return ''.join(c for i, c in enumerate(text) if i == 0 or c != text[i - 1])


path = '../data/'
# NOTE(review): assumes the dataset file is UTF-8 encoded; confirm.
with open(path + 'annotated_wd_data_valid_answerable.txt', encoding='utf-8') as fin:
    lines = fin.readlines()

l = lines[SAMPLE - 1]  # SAMPLE is 1-based
# subject [tab] property [tab] object [tab] question
s, p, o, q = l.strip('\n').split('\t')

q = normalize_text(q)
print(q)

s_label = normalize_text(entities[s])
print(s_label)

in what french city did antoine de fevin die 
antoine de fevin


In [15]:
# encode entity label
query_str = s_label
print(query_str)

# Score each multi-character word separately: a one-character word yields no token
# bigram, so `match` could never find it.
q_words = [w for w in query_str.split() if len(w) > 1]
print(q_words)

matches = 0
for word in q_words:
    # Decoded paths print in upper case, so presumably the CTC vocabulary is upper-case.
    query_str = word.upper()
    matches += match(edges, indices, query_str, tokenizer)

# Mean per-word overlap ratio; guard against labels without any multi-character word.
print('%.2f words matched' % (matches / len(q_words) if q_words else 0))

antoine de fevin
['antoine', 'de', 'fevin']
INTOINTOINTOINE
NTOINE
DE
DE
FEVIN
FEVIN
0.95 words matched
