In [1]:
# Calculate embeddings for ParlaMint data

In [1]:
import json

# Load the session index. Later cells iterate it as a mapping of
# language -> list of session dicts (each with 'xml_path' and 'date' keys).
with open('artefacts/parlamint.json', 'r') as index_file:
    parlamint = json.loads(index_file.read())

In [2]:
from datetime import datetime
from queue import Empty
import json
import pandas as pd
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer
from pathlib import Path
import multiprocessing as mp
import torch
import threading
import time

# Module-level state shared by the functions below:
#   model           - assigned inside worker_init() once a worker has loaded it
#   processed_count - replaced further down by a shared mp.Value counter
#   start           - wall-clock reference used by print_progress() for the ETA
model = processed_count = None
start = datetime.now()
def worker_init(gpu_id, task_queue):
    """Worker-process entry point: load the model on one GPU and drain a queue.

    Args:
        gpu_id: Index used to build the ``cuda:{gpu_id}`` device string the
            model is moved to.
        task_queue: Queue yielding ``(session, xml_path, embeddings_path)``
            tuples; each tuple is forwarded to ``process_session``.

    The loaded model is stored in the module-level ``model`` global, which
    ``get_embeddings`` reads. The loop ends as soon as the queue has been
    empty for one second.
    """
    global model
    device = f"cuda:{gpu_id}"
    print('Initializing worker on', device)
    model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
    model = model.to(torch.device(device))

    while True:
        try:
            session, xml_path, embeddings_path = task_queue.get(timeout=1)
        except Empty:
            # Nothing arrived within the timeout: treat the queue as exhausted.
            break

        # process_session() already increments the shared processed_count.
        # Incrementing it here as well counted every session twice, which
        # inflated the progress display and skewed the ETA.
        process_session(session, xml_path, embeddings_path)

def get_embeddings(xml_file):
    """Encode every sentence (and its notes) found in one XML file.

    The file is parsed with BeautifulSoup's 'xml' parser. For each <s> element
    inside a <seg>, the text is rebuilt from its <w> and <pc> children: a <w>
    contributes a leading space plus its text, a <pc> contributes its text
    with no space, and the result is stripped. The sentence text and the text
    of any <note> descendants are passed together to the module-level
    ``model.encode``.

    Args:
        xml_file: Path of the XML file to read (opened as UTF-8).

    Returns:
        A DataFrame with columns 'sentence_id', 'text', 'embedding' (first
        item of the encode result), 'notes' (list of note texts) and
        'note_embeddings' (remaining items of the encode result).
    """
    # Text contributed before each token, keyed by element name.
    joiner_by_tag = {'w': ' ', 'pc': ''}

    with open(xml_file, "r", encoding="utf8") as handle:
        soup = BeautifulSoup(handle.read(), 'xml')

    rows = []
    for segment in soup.find_all('seg'):
        for sent in segment.find_all('s'):
            pieces = []
            for token in sent.find_all(['w', 'pc']):
                token_text = token.get_text()
                joiner = joiner_by_tag.get(token.name)
                # Skip empty tokens (and any element name not handled above).
                if not token_text or joiner is None:
                    continue
                pieces.append(joiner + token_text)
            sentence_text = ''.join(pieces).strip()

            note_texts = [note.get_text() for note in sent.find_all('note')]
            # One encode call per sentence: sentence first, then its notes.
            vectors = model.encode([sentence_text] + note_texts)
            rows.append([sent.get('xml:id'), sentence_text, vectors[0], note_texts, vectors[1:]])

    return pd.DataFrame(rows, columns=['sentence_id', 'text', 'embedding', 'notes', 'note_embeddings'])

def process_session(session, xml_path, embeddings_path):
    """Compute embeddings for one session and store them as JSON lines.

    Args:
        session: Session record; returned unchanged.
        xml_path: Path of the session's XML file, passed to ``get_embeddings``.
        embeddings_path: ``pathlib.Path`` of the output file. Its parent
            directory is created if missing.

    Returns:
        The ``session`` argument.

    Side effects:
        Writes the output file and increments the shared ``processed_count``.
    """
    embeddings = get_embeddings(xml_path)
    embeddings_path.parent.mkdir(parents=True, exist_ok=True)

    # Write to a sibling temp file and rename it into place. Writing straight
    # to the final path meant an interrupted run could leave a truncated but
    # non-empty file, which the "exists and size > 0" resume check would then
    # treat as finished and never regenerate.
    tmp_path = embeddings_path.with_name(embeddings_path.name + '.tmp')
    embeddings.to_json(tmp_path, lines=True, orient='records')
    tmp_path.replace(embeddings_path)

    with processed_count.get_lock():
        processed_count.value += 1
    return session

def print_progress(total_sessions):
    """Print a one-line progress/ETA report once per second.

    Reads the shared ``processed_count`` under its lock, estimates the
    remaining time from the elapsed time since the module-level ``start``,
    and rewrites the same console line (the output ends with a carriage
    return). Returns once the count reaches ``total_sessions``.

    Args:
        total_sessions: Number of sessions at which reporting stops.
    """
    while True:
        with processed_count.get_lock():
            done = processed_count.value

        if done != 0:
            elapsed = datetime.now() - start
            # Average time per session so far, scaled by what is left;
            # fractional seconds are cut off for display.
            eta = str(elapsed / done * (total_sessions - done)).split('.')[0]
        else:
            # No throughput data yet.
            eta = 'unknown'

        print(f"Processed {done}/{total_sessions} sessions, eta:", eta.ljust(20), end='\r')

        if done >= total_sessions:
            break
        time.sleep(1)

N_WORKERS = 8
N_GPUS = torch.cuda.device_count()
if N_GPUS == 0:
    # Fail with a clear message instead of a ZeroDivisionError from `i % N_GPUS`.
    raise RuntimeError('No CUDA devices available; at least one GPU is required.')
processed_count = mp.Value('i', 0)

task_queues = [mp.Queue() for _ in range(N_WORKERS)]

# Queue the work BEFORE starting the workers: a worker exits as soon as its
# queue has been empty for one second, so a queue must not be filled late.
# Each language is assigned to one queue (round-robin over languages).
total_sessions = 0
for i, (language, sessions) in enumerate(parlamint.items()):
    task_queue = task_queues[i % N_WORKERS]
    for session in sessions:
        xml_path = Path(session["xml_path"])
        embeddings_path = Path('artefacts/parlamint') / language / session['date'][:4] / f'{xml_path.stem}.json'
        # Record the output location for every session, including those whose
        # output already exists, so the saved index stays complete after a
        # resumed run.
        session['embeddings_path'] = str(embeddings_path)
        if embeddings_path.exists() and embeddings_path.stat().st_size > 0:
            continue
        task_queue.put((session, xml_path, embeddings_path))
        # Count only queued sessions; skipped ones never increment
        # processed_count, so including them would keep the progress loop
        # from ever reaching its total.
        total_sessions += 1

workers = []
for i in range(N_WORKERS):
    worker = mp.Process(target=worker_init, args=(i % N_GPUS, task_queues[i]))
    worker.start()
    workers.append(worker)
    time.sleep(0.1)

progress_thread = threading.Thread(target=print_progress, args=(total_sessions,))
progress_thread.daemon = True
progress_thread.start()

for worker in workers:
    worker.join()

# Bounded wait: if a worker died before finishing its tasks the counter never
# reaches the total, and an unbounded join would hang the cell forever.
progress_thread.join(timeout=2)

with open('artefacts/parlamint.json', 'w') as f:
    json.dump(parlamint, f, indent=2)


Initializing worker on cuda:0
Initializing worker on cuda:1
Initializing worker on cuda:0
Initializing worker on cuda:1
Initializing worker on cuda:0
Initializing worker on cuda:1
Initializing worker on cuda:0
Initializing worker on cuda:1
Processed 132/5689 sessions, eta: 13:02:42            

Process Process-2:
Process Process-3:
Process Process-1:
Process Process-5:
Traceback (most recent call last):


KeyboardInterrupt: 

Process Process-4:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/vidklopcic/anaconda3/envs/dhh23/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/vidklopcic/anaconda3/envs/dhh23/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/vidklopcic/anaconda3/envs/dhh23/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
  File "/home/vidklopcic/anaconda3/envs/dhh23/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
  File "/tmp/ipykernel_385301/3767796123.py", line 29, in worker_init
    process_session(session, xml_path, embeddings_path)
  File "/home/vidklopcic/anaconda3/envs/dhh23/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/vidklopcic/an