In [1]:
import torch
from pathlib import Path 
import pandas as pd
from tqdm import tqdm

from nltk.tokenize import sent_tokenize
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
from sonar.models.sonar_text import (
    load_sonar_text_decoder_model,
    load_sonar_text_encoder_model,
    load_sonar_tokenizer,
)

In [2]:
device = torch.device("cuda")
t2enc = load_sonar_text_encoder_model("text_sonar_basic_encoder", device=device).eval()
text_tokenizer = load_sonar_tokenizer("text_sonar_basic_encoder")

embedder = TextToEmbeddingModelPipeline(t2enc, text_tokenizer, device=device)

Using the cached checkpoint of text_sonar_basic_encoder. Set `force` to `True` to download again.


control_symbols=['__ace_Arab__', '__ace_Latn__', '__acm_Arab__', '__acq_Arab__', '__aeb_Arab__', '__afr_Latn__', '__ajp_Arab__', '__aka_Latn__', '__amh_Ethi__', '__apc_Arab__', '__arb_Arab__', '__ars_Arab__', '__ary_Arab__', '__arz_Arab__', '__asm_Beng__', '__ast_Latn__', '__awa_Deva__', '__ayr_Latn__', '__azb_Arab__', '__azj_Latn__', '__bak_Cyrl__', '__bam_Latn__', '__ban_Latn__', '__bel_Cyrl__', '__bem_Latn__', '__ben_Beng__', '__bho_Deva__', '__bjn_Arab__', '__bjn_Latn__', '__bod_Tibt__', '__bos_Latn__', '__bug_Latn__', '__bul_Cyrl__', '__cat_Latn__', '__ceb_Latn__', '__ces_Latn__', '__cjk_Latn__', '__ckb_Arab__', '__crh_Latn__', '__cym_Latn__', '__dan_Latn__', '__deu_Latn__', '__dik_Latn__', '__dyu_Latn__', '__dzo_Tibt__', '__ell_Grek__', '__eng_Latn__', '__epo_Latn__', '__est_Latn__', '__eus_Latn__', '__ewe_Latn__', '__fao_Latn__', '__pes_Arab__', '__fij_Latn__', '__fin_Latn__', '__fon_Latn__', '__fra_Latn__', '__fur_Latn__', '__fuv_Latn__', '__gla_Latn__', '__gle_Latn__', '__glg_

Using the cached tokenizer of text_sonar_basic_encoder. Set `force` to `True` to download again.


In [4]:
filelist = sorted(list(Path("/data/dw-2019/").glob("**/*.jsonl")))
filelist[:5], len(filelist)

([PosixPath('/data/dw-2019/de/2001_10.jsonl'),
  PosixPath('/data/dw-2019/de/2001_11.jsonl'),
  PosixPath('/data/dw-2019/de/2001_12.jsonl'),
  PosixPath('/data/dw-2019/de/2001_9.jsonl'),
  PosixPath('/data/dw-2019/de/2002_1.jsonl')],
 1027)

In [7]:
languages = {"en": "english", "de":"german", "es":"spanish", "fr": "french", "pt":"portuguese"}
sonar_lang = {"en": "eng_Latn", "de":"deu_Latn", "es":"spa_Latn", "fr": "fra_Latn", "pt":"por_Latn"} # https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200

for file in tqdm(filelist, total=len(filelist)):
    df = pd.read_json(file, lines=True)
    output = {
        "sentences": [],
        "language": [],
        "embeddings": [],
        "labels": [],
        "keywords": [],
        "labels": []
    }
    for idx in tqdm(range(len(df["sourceItemMainText"])), leave=False):
        d = df.iloc[idx]
        lang = d['sourceItemLangeCodeGuess']
        sent = sent_tokenize(d['sourceItemMainText'], language=languages[lang])
        if len(sent)==0:
            continue
        with torch.no_grad():
            embeddings = 0 # embedder.predict(sent, source_lang=sonar_lang[lang], batch_size=64)
        labs = [0]*(len(sent)-1) + [1]
        output['sentences'].append(sent)
        output['language'].append(lang)
        output["embeddings"].append(embeddings)
        output["labels"].extend(labs)
        output["keywords"].append(d['sourceItemKeywords'])
        out_path = Path("/disk1/data/preproc_sonar/dw_2019/")/file.parent.stem
        if len(sent)<=3:
            break
    # output['embeddings'] = torch.cat(output['embeddings'])
    # output['labels'] = torch.tensor(output['labels'])
    # out_path.mkdir(parents=True, exist_ok=True)
    # break
    # torch.save(output, out_path/f"{file.stem}.pt")

  7%|▋         | 72/1027 [00:16<03:32,  4.50it/s]


In [8]:
output

{'sentences': [['Seit Monaten schon läuft die Protest-Maschinerie auf Hochtouren.',
   'Mehr als 1000 Diskussionsveranstaltungen hat das Netzwerk Attac seit vergangenem Herbst in Deutschland organisiert, um gegen den G8-Gipfel in Heiligendamm zu mobilisieren.',
   'Auf weiteren Treffen haben sich die Aktivisten auch praktisch auf die Großdemonstration am Samstag (2.6.07) in Rostock vorbereitet: In so genannten Bastelworkshops stellten die Mitglieder Transparente und Großpuppen her, in Aktionstrainings übten sie friedliche Sitzblockaden.',
   '"Es geht nur noch um die G8"\n"Seit dem Jahreswechsel geht es bei uns fast nur noch um die G8", sagt die Attac-Sprecherin Frauke Distelrath.',
   'Der G8-Gipfel, so beschloss es die Vollversammlung im vergangenen Herbst, sollte im ersten Halbjahr 2007 den Schwerpunkt der Arbeit bilden.',
   'Nebenbei wirkt die Mobilisierung wie eine Rekrutierungskampagne: Seit Anfang des Jahres verzeichnet Attac rund 100 Neueintritte monatlich, im Mai waren es sog

In [30]:
torch.tensor(output['embeddings']).shape

ValueError: only one element tensors can be converted to Python scalars

In [26]:
for emb in output['embeddings']:
    print(emb.shape)

torch.Size([25, 1024])
torch.Size([23, 1024])
torch.Size([30, 1024])
torch.Size([21, 1024])
torch.Size([32, 1024])
torch.Size([13, 1024])
torch.Size([42, 1024])
torch.Size([20, 1024])
torch.Size([21, 1024])
torch.Size([15, 1024])
torch.Size([21, 1024])
torch.Size([13, 1024])
torch.Size([20, 1024])
torch.Size([18, 1024])
torch.Size([28, 1024])
torch.Size([31, 1024])
torch.Size([31, 1024])
torch.Size([37, 1024])
torch.Size([26, 1024])
torch.Size([32, 1024])
torch.Size([25, 1024])
torch.Size([20, 1024])
torch.Size([35, 1024])
torch.Size([15, 1024])
torch.Size([19, 1024])
torch.Size([27, 1024])
torch.Size([35, 1024])
torch.Size([22, 1024])
torch.Size([20, 1024])
torch.Size([24, 1024])
torch.Size([34, 1024])
torch.Size([15, 1024])
torch.Size([28, 1024])
torch.Size([17, 1024])
torch.Size([39, 1024])
torch.Size([47, 1024])
torch.Size([38, 1024])
torch.Size([47, 1024])
torch.Size([55, 1024])
torch.Size([39, 1024])
torch.Size([20, 1024])
torch.Size([28, 1024])
torch.Size([32, 1024])
torch.Size(