In [35]:
from pyspark.sql import SparkSession
import pyspark
from functools import partial
from pathlib import Path
from html import unescape
import logging as log
import json
import math



# Configure the root logger at INFO level so the log.info(...) messages below are emitted.
log.basicConfig(level=log.INFO)
# Column names and types, as a DDL-style string, passed as `schema=` when the raw TSV is read.
SCHEMA = 'lang STRING, ds_name STRING, src STRING, eng STRING'

def row_to_tsv(row):
    """Serialize one corpus row as a single tab-separated line.

    Args:
        row: any object exposing ``lang``, ``ds_name``, ``src`` and ``eng``
            attributes (presumably a Spark ``Row`` read with ``SCHEMA``).

    Returns:
        ``lang<TAB>ds_name<TAB>src<TAB>eng`` as a ``str``, without a trailing newline.
    """
    fields = (row.lang, row.ds_name, row.src, row.eng)
    # A missing value must become an empty field: formatting None directly
    # would write the literal text "None" into the output file.
    return '\t'.join('' if field is None else str(field) for field in fields)


In [5]:
# This machine has a lot of memory, so the driver is given 120g -- use all of it.
spark = (
    SparkSession.builder
    .appName("Raw data cleanup on PySpark")
    .config("spark.driver.memory", "120g")
    .getOrCreate()
)
spark  # bare last expression: shows the session summary as the cell output

In [4]:
# NOTE(review): this shuts down the session created in the cell above. The execution
# counts (In [4] here, In [5] above) suggest it was run before the session was
# re-created; on a top-to-bottom run the later `spark.read` cell would presumably fail.
spark.stop()

In [6]:
# Load the merged raw corpus. The two wiki_matrix_v1 datasets are excluded
# because there was an error processing them.
raw_file = 'merged/train.raw.tsv'
raw_df = (
    spark.read
    .csv(raw_file, sep='\t', schema=SCHEMA)
    .filter("ds_name != 'wiki_matrix_v1-tam_eng' and ds_name != 'wiki_matrix_v1-tel_eng'")
)


# Deduplicate

In [36]:
# Keep one row per distinct (src, eng) pair, format each surviving row with
# row_to_tsv, and save the result under `dedup_file`.
dedup_file = 'merged/train.raw.dedup.tsv'
(
    raw_df
    .dropDuplicates(['src', 'eng'])
    .rdd
    .map(row_to_tsv)
    .saveAsTextFile(dedup_file)
)

# Exclude Test Pairs


In [None]:
# Collect (source_path, english_path) pairs for every held-out test set.
tests = []
for eng in Path('merged/tests').glob("*.eng"):
    # The source suffix is derived from the file name: a name ending in
    # "-xxx_eng.eng" yields "xxx" (last "-" chunk, text before the first "_").
    src_name = eng.name.replace(".eng", "").split("-")[-1].split("_")[0]
    src = eng.with_suffix("." + src_name)
    assert src.exists(), f"Missing source side for {eng}: expected {src}"
    tests.append((src, eng))

# Report the size of the list built above (the original referenced an undefined name).
log.info(f"Found {len(tests)} held out sets")

# Length ratios -- mean and std dev -- to spot odd ones

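Lengths can be counted in characters (`map_char_lenghts`), since we haven't tokenized the data, or in whitespace-separated words (`map_word_lenghts`); the statistics below use the word counts.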

In [14]:
def map_char_lenghts(rec):
    """Key a record by language, paired with its character counts.

    Each side is stripped of surrounding whitespace before counting; a side
    that is ``None`` or empty counts as 0.

    Returns:
        ``(rec.lang, (src_chars, eng_chars))``
    """
    src_chars = 0
    if rec.src:
        src_chars = len(rec.src.strip())
    eng_chars = 0
    if rec.eng:
        eng_chars = len(rec.eng.strip())
    return rec.lang, (src_chars, eng_chars)
def map_word_lenghts(rec):
    """Key a record by language, paired with its whitespace-token counts.

    Each side is stripped and split on whitespace; a side that is ``None`` or
    empty counts as 0 tokens.

    Returns:
        ``(rec.lang, (src_words, eng_words))``
    """
    src_words = 0
    if rec.src:
        src_words = len(rec.src.strip().split())
    eng_words = 0
    if rec.eng:
        eng_words = len(rec.eng.strip().split())
    return rec.lang, (src_words, eng_words)


# (lang, (src_words, eng_words)) for every record where both sides have at least one word.
# MEMORY_AND_DISK: keep it in memory if there is room, spill to disk when there isn't.
len_rdd = (
    raw_df.rdd
    .map(map_word_lenghts)
    .filter(lambda kv: kv[1][0] > 0 and kv[1][1] > 0)
    .persist(pyspark.StorageLevel.MEMORY_AND_DISK)
)

In [15]:
def key_count(rdd):
    """Count how many records a (key, value) RDD holds for each key.

    Every value is replaced by 1, the ones are summed per key, and the result
    of ``collectAsMap()`` (key -> count) is returned.
    """
    ones = rdd.mapValues(lambda _value: 1)
    totals = ones.reduceByKey(lambda left, right: left + right)
    return totals.collectAsMap()
    
def compute_key_mean(rdd, precision=4):
    """Average the values of a (key, value) RDD separately for each key.

    Args:
        rdd: pair RDD whose values can be added and divided.
        precision: number of decimal places the means are rounded to.

    Returns:
        dict mapping each key to its rounded mean.
    """
    def with_count(value):
        # Seed the running (sum, count) pair for a single value.
        return value, 1

    def merge(left, right):
        # (s1, c1) + (s2, c2) -> (s1 + s2, c1 + c2)
        return left[0] + right[0], left[1] + right[1]

    sums_and_counts = rdd.mapValues(with_count).reduceByKey(merge)
    raw_means = sums_and_counts.mapValues(lambda sc: sc[0] / sc[1]).collectAsMap()

    rounded = {}
    for key, mean in raw_means.items():
        rounded[key] = round(mean, precision)
    return rounded

# Per record: (lang, src_len / tgt_len); then the mean ratio per language.
len_ratio_rdd = len_rdd.mapValues(lambda lens: lens[0] / lens[1])
len_ratio_mean = compute_key_mean(len_ratio_rdd)
print(list(len_ratio_mean.items())[:5])

def deviation(rec):
    """Squared difference between a record's ratio and its language's mean ratio.

    ``rec`` is a ``(lang, ratio)`` pair; the mean is looked up in the
    notebook-level ``len_ratio_mean`` dict. Returns ``(lang, squared_difference)``.
    """
    lang, ratio = rec
    diff = len_ratio_mean[lang] - ratio
    return lang, diff ** 2

# The per-language mean of the squared deviations, then its square root (rounded to 4 places).
len_ratio_std_sq = compute_key_mean(len_ratio_rdd.map(deviation))
len_ratio_std = dict(
    (lang_code, round(math.sqrt(mean_sq_dev), 4))
    for lang_code, mean_sq_dev in len_ratio_std_sq.items()
)

print(list(len_ratio_std.items())[:5])

# Number of ratio records per language.
counts = key_count(len_ratio_rdd)

[('oar', 0.7738), ('orm', 0.9523), ('twi', 1.1435), ('kik', 1.026), ('kaz', 0.954)]
[('oar', 0.2412), ('orm', 0.2886), ('twi', 0.3254), ('kik', 0.3085), ('kaz', 0.4442)]


In [17]:

# Alternative output name for character-based lengths:
#len_stats_file = 'lang-char-len-stats.tsv'
len_stats_file = 'lang-word-len-stats-raw.tsv'

# The mean and std tables must cover exactly the same languages.
all_langs = set(len_ratio_std.keys()) | set(len_ratio_mean.keys())
assert all_langs == set(len_ratio_std.keys())
assert all_langs == set(len_ratio_mean.keys())
# Order languages by record count, largest first.
all_langs = sorted(all_langs, key=lambda code: counts.get(code), reverse=True)

# One (lang, mean, std, count) tuple per language, in that order.
all_stats = [
    (lang, round(len_ratio_mean[lang], 4), round(len_ratio_std[lang], 4), counts[lang])
    for lang in all_langs
]

with open(len_stats_file, 'w') as out:
    out.write("Lang\tMean\tSTD\tSentences\n")
    # The count column is written with thousands separators.
    out.writelines(
        f"{lang}\t{mu}\t{std}\t{n_sents:,}\n" for lang, mu, std, n_sents in all_stats
    )
log.info(f"Wrote {len_stats_file}")

INFO:root:Wrote lang-word-len-stats-raw.tsv


In [22]:
# How many languages have more than 10,000 records?
sum(1 for stats in all_stats if stats[-1] > 10_000)

350

In [34]:

# Restrict the report to languages with more than 10,000 records. all_stats is
# sorted by count (largest first), so this selects the same leading rows as the
# previously hard-coded all_stats[:350], but stays correct when the data changes.
frequent_stats = [row for row in all_stats if row[-1] > 10_000]

# Mean source/English word ratio below 2/3: source side has far fewer words.
for lang, mu, std, _ in frequent_stats:
    if mu < 2/3:
        print('Fewer words:', lang, 'mean:',  mu)

# Mean ratio above 3/2: source side has far more words.
for lang, mu, std, _ in frequent_stats:
    if mu > 3/2:
        print('Many words:', lang, 'mean:',  mu)

# Ratio varies a lot within the language.
for lang, mu, std, _ in frequent_stats:
    if std > 2:
        print('High variance:', lang, 'std:', std)

Fewer words: zho mean: 0.1856
Fewer words: tha mean: 0.5345
Fewer words: jpn mean: 0.1361
Fewer words: mal mean: 0.6509
Fewer words: mya mean: 0.6278
Fewer words: kan mean: 0.5852
Fewer words: kal mean: 0.6631
Fewer words: cmn mean: 0.1765
Fewer words: syr mean: 0.6349
Fewer words: chr mean: 0.6648
Fewer words: shi mean: 0.607
Fewer words: tlh mean: 0.5933
Fewer words: iku mean: 0.6186
Many words: hrv mean: 1.622
Many words: tah mean: 1.569
Many words: ton mean: 1.686
Many words: yap mean: 1.5157
Many words: tvl mean: 1.5704
Many words: hne mean: 1.6882
Many words: jsl mean: 2.2187
Many words: kvk mean: 1.5521
Many words: ada mean: 1.5286
Many words: ksw mean: 1.516
Many words: lao mean: 1.8509
Many words: csl mean: 1.7552
Many words: quc mean: 1.582
Many words: chq mean: 1.8047
Many words: gbi mean: 1.7155
Many words: ang mean: 1.7058
Many words: tss mean: 1.8399
Many words: kac mean: 1.5439
High variance: nld std: 2.1228
High variance: ell std: 2.2567
High variance: slk std: 2.0236
H

In [None]:
import re

class MyFilter:
    """Decide whether a parallel-sentence record should be dropped, and why.

    Calling an instance with a ``(lang, provenance, src, eng)`` record returns
    the reason code of the first failing check -- ``'EMPTY'``, ``'COPY'``,
    ``'MIN_LEN'``, ``'MAX_LEN'``, ``'LEN_RATIO'``, ``'MAX_WORD_LEN'`` or
    ``'HTTP'`` -- or ``None`` when the record passes every check. Lengths are
    counted in whitespace-separated tokens.
    """

    def __init__(self, min_len=1, max_len=120, len_ratio=5, max_word_len=30,
                 remove_urls=True):
        """Store the thresholds; the defaults match the original fixed values.

        Args:
            min_len: minimum number of tokens required on each side.
            max_len: maximum number of tokens allowed on each side.
            len_ratio: largest allowed src/eng token-count ratio (and, via its
                reciprocal, the smallest).
            max_word_len: maximum number of characters allowed in one token.
            remove_urls: when true, records containing ``'http'`` are flagged.
        """
        self.min_len = min_len
        self.max_len = max_len
        self.len_ratio = len_ratio
        self.max_word_len = max_word_len
        self.remove_urls = remove_urls

    def __call__(self, rec):
        """Return the drop reason for ``rec``, or ``None`` to keep it.

        Args:
            rec: a 4-item sequence ``(lang, provenance, src, eng)``; ``src``
                and ``eng`` may be ``None``.
        """
        lang, provenance, src, eng = rec
        if not src or not eng or not src.strip() or not eng.strip():
            return 'EMPTY'
        # Strip BOTH sides before comparing. Previously the stripped English
        # text went into an unused variable, so a copy that differed only in
        # surrounding whitespace was not detected.
        src, eng = src.strip(), eng.strip()

        if src == eng:
            return 'COPY'

        src_toks = src.split()
        eng_toks = eng.split()
        n_src, n_eng = len(src_toks), len(eng_toks)

        if n_src < self.min_len or n_eng < self.min_len:
            return 'MIN_LEN'

        if n_src > self.max_len or n_eng > self.max_len:
            return 'MAX_LEN'

        # n_eng is at least 1 here: empty sides were rejected above.
        if not (1/self.len_ratio <= n_src/n_eng <= self.len_ratio):
            return 'LEN_RATIO'

        if any(len(t) > self.max_word_len for t in src_toks)\
            or any(len(t) > self.max_word_len for t in eng_toks):
            return 'MAX_WORD_LEN'

        # Honour the remove_urls switch (it was set but never consulted before).
        if self.remove_urls and ('http' in src or 'http' in eng):
            return 'HTTP'

        return None # no reason to drop

# Filter instance using the default thresholds set in MyFilter.__init__.
my_filter = MyFilter()


TODO: 

- [ ] XML unescape
- [ ] drop URLs 
- [ ] drop abnormal lengths
- [ ] drop empty 
- [ ] drop copy
- [ ] drop if 