In [48]:
# Load the entity-id -> label mapping and wrap each entity as a corpus
# document: the label becomes the document text, title and metadata stay empty.
import json

# Decode explicitly as UTF-8: labels contain non-ASCII characters
# (e.g. 'Mirosław Bork') and the platform default encoding may differ.
with open('../data/entities.json', 'r', encoding='utf-8') as fin:
    entities = json.load(fin)

corpus = [
    {"_id": entity_id, "title": "", "text": entity_label, "metadata": {}}
    for entity_id, entity_label in entities.items()
]

print(len(corpus), 'entities')
print(corpus[0])

28497 entities
{'_id': 'Q1938494', 'title': '', 'text': 'Mirosław Bork', 'metadata': {}}


In [49]:
# load data from https://github.com/askplatypus/wikidata-simplequestions
data_path = '/ivi/ilps/personal/svakule/spoken_qa/'
split = 'valid'  # 'train' 'valid' ...
# path_to_questions = data_path + 'annotated_wd_data_%s.txt' % split
path_to_questions = data_path + 'sqs_valid_wikidata2018_09_11.txt'  # filtered subset

# Decode explicitly as UTF-8: the questions contain non-ASCII characters
# (e.g. 'vujačić') and the platform default encoding may differ.
with open(path_to_questions, encoding='utf-8') as fin:
    lines = fin.readlines()

print(len(lines), 'questions in total')

1026 questions in total


In [50]:
# Keep only the questions whose first field (the entity id) has a known label,
# and build the query / qrels records for them.
dataset = 'WD18/'

corpus_path = data_path + dataset + "entities.jsonl"
query_path = data_path + dataset + "%s_original.jsonl" % split
qrels_path = data_path + dataset + "%s.tsv" % split

queries, qrels = [], []
indices = []  # positions in `lines` of the questions that were kept
for i, line in enumerate(lines):
    if not line.strip():
        # Tolerate blank lines instead of failing on the 4-way unpack below.
        continue
    # Four tab-separated fields per line (presumably subject, predicate,
    # object, question — verify against the dataset README). Only the first
    # and the last field are used here.
    s, _p, _o, q = line.strip('\n').split('\t')
    if s in entities:
        query_id = str(i)  # the line position doubles as the query id
        indices.append(i)
        queries.append({"_id": query_id, "text": q, "metadata": {}})
        qrels.append([query_id, s, '1'])

print(len(queries), 'samples with entities recognised')
print(queries[0], qrels[0])

995 samples with entities recognised
{'_id': '0', 'text': 'where was sasha vujačić born', 'metadata': {}} ['0', 'Q318926', '1']


In [51]:
# Write the dataset to disk: corpus and queries as one JSON object per line,
# qrels as tab-separated rows.
import jsonlines

with open(corpus_path, 'w') as out_file:
    for d in corpus:
        print(json.dumps(d), file=out_file)

with open(query_path, 'w') as out_file:
    for d in queries:
        print(json.dumps(d), file=out_file)

with open(qrels_path, 'w') as out_file:
    for qrel in qrels:
        print('\t'.join(qrel), file=out_file)

In [52]:
from beir.datasets.data_loader import GenericDataLoader

# Read the three files back through the BEIR loader.
# Note: this rebinds `corpus`, `queries` and `qrels` to whatever
# `load_custom()` returns, replacing the lists built in earlier cells.
beir_loader = GenericDataLoader(
    corpus_file=corpus_path,
    query_file=query_path,
    qrels_file=qrels_path,
)
corpus, queries, qrels = beir_loader.load_custom()

# ASR Transcriptions

In [65]:
# Generate speech using Google API (gTTS).
# The synthesis code is kept commented out for reference; the only live
# statement in this cell is the `wav_path` assignment.
# import time
# from gtts import gTTS

# Directory holding the audio files; the commented-out loop below saved one
# <query id>.mp3 per question into it.
wav_path = "/ivi/ilps/personal/svakule/spoken_qa/sqs_valid_wikidata2018_09_11/"
# Another directory noted by the author (purpose not stated):
# /ivi/ilps/personal/svakule/spoken_qa/gtts

# for d in queries:
#     if int(d['_id']) > 2837:  # presumably resumes an interrupted run — TODO confirm
#         time.sleep(10)
#         speech = gTTS(text=d['text'], lang='en', slow=False)
#         speech.save(wav_path+d['_id']+'.mp3')

# Manual follow-up steps noted by the author:
# - check number of files generated
# - scp
# - for i in *.mp3; do ffmpeg -i $i "${i%.mp3}.wav"; done
# - remove mp3 files

In [68]:
# Transcribe the audio files with wav2vec 2.0 and rebuild `queries` from the
# lower-cased ASR output (one record per audio file whose index was kept).
import os
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
from datasets import load_dataset
import soundfile as sf

# Fall back to CPU when no GPU is visible instead of failing on .to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
model.to(device)

# Set membership instead of scanning the `indices` list once per audio file.
kept_indices = set(indices)

queries = []
for file in os.listdir(wav_path):
    # NOTE(review): the file stem is shifted by one to obtain the query index;
    # confirm that the audio files in `wav_path` are numbered from 1.
    i = int(file.split('.')[0]) - 1
    if i not in kept_indices:
        continue

    # NOTE(review): `samplerate` is read but never checked. The model card for
    # wav2vec2-base-960h states 16 kHz input — TODO confirm the wav files match.
    speech, samplerate = sf.read(os.path.join(wav_path, file))

    input_values = tokenizer(speech, return_tensors="pt", padding="longest").input_values
    input_values = input_values.to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = tokenizer.batch_decode(predicted_ids)[0].lower()

    queries.append({"_id": str(i), "text": transcription, "metadata": {}})

print(len(queries), 'questions transcribed')
print(queries[0])

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


995 questions transcribed
{'_id': '217', 'text': 'what is marfines van scok vix job', 'metadata': {}}


In [70]:
# Point `query_path` at the ASR variant of the query file and write the
# transcribed queries there, one JSON object per line.
query_path = data_path + dataset + "%s_wav2vec2-base-960h.jsonl" % split

with open(query_path, 'w') as out_file:
    for d in queries:
        print(json.dumps(d), file=out_file)

In [None]:
# from beir.datasets.data_loader import GenericDataLoader

# corpus, queries, qrels = GenericDataLoader(
#     corpus_file=corpus_path, 
#     query_file=query_path, 
#     qrels_file=qrels_path).load_custom()