In [1]:
import gc
import pathlib
import faiss
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
import torch
from tqdm import tqdm
from typing import Dict, List, Tuple, NamedTuple, Callable, Any
import mylib

In [2]:
class ModelConf(NamedTuple):
    directory: str
    max_seq_length: int
    batch_size: int


class Conf(NamedTuple):
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    pretrained_dir: str = "pretrained/"
    em_active: List[str] = ["paraphrase-MiniLM-L6-v2"]
    em_models: Dict[str, ModelConf] = {
        "paraphrase-MiniLM-L6-v2": ModelConf(
            directory=f"{pretrained_dir}sentence-transformers/paraphrase-MiniLM-L6-v2",
            max_seq_length=128,
            batch_size=1000,
        ),
        "all-MiniLM-L6-v2": ModelConf(
            directory=str(pathlib.Path(f"{pretrained_dir}sentence-transformers/all-MiniLM-L6-v2")),
            max_seq_length=256,
            batch_size=512,
        ),
        "all-mpnet-base-v2": ModelConf(
            directory=f"{pretrained_dir}sentence-transformers/all-mpnet-base-v2",
            max_seq_length=384,
            batch_size=128,
        ),
    }
    search_d_max: float = 1.5
    search_c: int = 50
    search_k: int = 100
    search_nlist: int = 1000
    search_index_file: str = "output/ruddit.index"
        
        
conf = Conf()
print(conf)
if conf.device.type == 'cuda':
    for i in range(torch.cuda.device_count()):
        print(f"device={i}, {torch.cuda.get_device_name(i)}")
        print('Mem Allocated:', round(torch.cuda.memory_allocated(i)/1024**3,1), 'GB')
        print('Mem Cached:   ', round(torch.cuda.memory_reserved(i)/1024**3,1), 'GB')

Conf(device=device(type='cuda'), pretrained_dir='pretrained/', em_active=['paraphrase-MiniLM-L6-v2'], em_models={'paraphrase-MiniLM-L6-v2': ModelConf(directory='pretrained/sentence-transformers/paraphrase-MiniLM-L6-v2', max_seq_length=128, batch_size=1000), 'all-MiniLM-L6-v2': ModelConf(directory='pretrained\\sentence-transformers\\all-MiniLM-L6-v2', max_seq_length=256, batch_size=512), 'all-mpnet-base-v2': ModelConf(directory='pretrained/sentence-transformers/all-mpnet-base-v2', max_seq_length=384, batch_size=128)}, search_d_max=1.5, search_c=50, search_k=100, search_nlist=1000, search_index_file='output/ruddit.index')
device=0, NVIDIA GeForce GTX 1060 6GB
Mem Allocated: 0.0 GB
Mem Cached:    0.0 GB


In [3]:
percentiles=[.01, .05, .1, .2, .3, .4, .5, .6, .7, .8, .9, .95, .99]
pd.set_option("use_inf_as_na", True)
pd.set_option("max_info_columns", 9999)
pd.set_option("display.max_columns", 9999)
pd.set_option("display.max_rows", 9999)
pd.set_option('max_colwidth', 9999)
tqdm.pandas()

In [4]:
df = pd.read_parquet("input/pre_ruddit.parquet")
bws = list(df["bws"])
ruddit_text2 = list(df["text2"])
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5710 entries, 0 to 5709
Data columns (total 7 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   label   5710 non-null   int32  
 1   bws     5710 non-null   float32
 2   worker  5710 non-null   int8   
 3   text    5710 non-null   object 
 4   text1   5710 non-null   object 
 5   text2   5710 non-null   object 
 6   text3   5710 non-null   object 
dtypes: float32(1), int32(1), int8(1), object(4)
memory usage: 228.7+ KB


In [5]:
val_df = pd.read_parquet("input/pre_val.parquet")
val_text = list(val_df["text"])
val_text2 = list(val_df["text2"])
val_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 14251 entries, 0 to 14250
Data columns (total 4 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   text    14251 non-null  object
 1   text1   14251 non-null  object
 2   text2   14251 non-null  object
 3   text3   14251 non-null  object
dtypes: object(4)
memory usage: 445.5+ KB


In [6]:
js18_df = pd.read_parquet("input/pre_js18.parquet")
js18_text2 = list(js18_df["text2"])
js18_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 186644 entries, 0 to 186643
Data columns (total 5 columns):
 #   Column  Non-Null Count   Dtype 
---  ------  --------------   ----- 
 0   worker  186644 non-null  int8  
 1   text    186644 non-null  object
 2   text1   186644 non-null  object
 3   text2   186644 non-null  object
 4   text3   186644 non-null  object
dtypes: int8(1), object(4)
memory usage: 5.9+ MB


In [7]:
%%time
sentences = ruddit_text2 + val_text2 + js18_text2
for name in conf.em_active:
    print(name)
    model = SentenceTransformer(conf.em_models[name].directory, device=conf.device)
    model.max_seq_length = conf.em_models[name].max_seq_length
    em = model.encode(sentences=sentences, 
                      batch_size=conf.em_models[name].batch_size, show_progress_bar=True, convert_to_numpy=True)
    print(f"em.shape={em.shape}")
    faiss.normalize_L2(em)
    cols = [f"{name}_{i:04d}" for i in range(em.shape[1])]
    ruddit_em = em[:len(df)]
    print(f"ruddit_em.shape={ruddit_em.shape}")
    df[cols] = ruddit_em
    df[cols] = df[cols].astype(np.float32)
    i = len(df)
    val_em = em[i:i + len(val_df)]
    print(f"val_em.shape={val_em.shape}")
    val_df[cols] = val_em
    val_df[cols] = val_df[cols].astype(np.float32)
    i += len(val_df)
    js18_em = em[i:i + len(js18_df)]
    print(f"js18_em.shape={js18_em.shape}")
    js18_df[cols] = js18_em
    js18_df[cols] = js18_df[cols].astype(np.float32)
    del model
    gc.collect()
em_cols = cols

paraphrase-MiniLM-L6-v2
[INFO|SentenceTransformer.py:60] 2022-02-06 14:01:19,255 >> Load pretrained SentenceTransformer: pretrained/sentence-transformers/paraphrase-MiniLM-L6-v2
[INFO|SentenceTransformer.py:60] 2022-02-06 14:01:19,255 >> Load pretrained SentenceTransformer: pretrained/sentence-transformers/paraphrase-MiniLM-L6-v2
[INFO|SentenceTransformer.py:60] 2022-02-06 14:01:19,255 >> Load pretrained SentenceTransformer: pretrained/sentence-transformers/paraphrase-MiniLM-L6-v2
[INFO|SentenceTransformer.py:60] 2022-02-06 14:01:19,255 >> Load pretrained SentenceTransformer: pretrained/sentence-transformers/paraphrase-MiniLM-L6-v2


Batches:   0%|          | 0/207 [00:00<?, ?it/s]

em.shape=(206605, 384)
ruddit_em.shape=(5710, 384)


  self[col] = igetitem(value, i)


val_em.shape=(14251, 384)
js18_em.shape=(186644, 384)
Wall time: 4min 38s


In [8]:
%%time
df.to_parquet("output/em_ruddit.parquet", index=False)
val_df.to_parquet("output/em_val.parquet", index=False)
del df, val_df, sentences
gc.collect()

Wall time: 824 ms


0

# Ruddit dataset
- Seed dataset with BWS label
- Generate embeddings
- Build index

In [9]:
%%time
d = ruddit_em.shape[1]
m = 8  # number of subquantizers
quantizer = faiss.IndexFlatIP(d)  # this remains the same
index = faiss.IndexIVFPQ(quantizer, d, conf.search_nlist, m, 8)
# 8 specifies that each sub-vector is encoded as 8 bits
index.verbose = True
index.train(ruddit_em)
index.add(ruddit_em)

Wall time: 921 ms


In [10]:
%%time
faiss.write_index(index, conf.search_index_file)

Wall time: 1 ms


In [11]:
%%time
index = faiss.read_index(conf.search_index_file)
print(f"ntotal={index.ntotal}, is_trained={index.is_trained}")

ntotal=5710, is_trained=True
Wall time: 26.4 ms


In [12]:
index.nprobe = 1
k = 4
distances, ids = index.search(ruddit_em[:20], k)  # sanity check
print(f"I={repr(ids)}\nD={repr(distances)}")

I=array([[   0,    5,    4, 2886],
       [   1, 4966, 3859, 1648],
       [   2,    7, 3918,   16],
       [   3,   23, 3783,  881],
       [   4,   31, 3903,   10],
       [   5,   10, 2694, 3903],
       [   6, 1818,  316,   -1],
       [   7,    2,   16, 3905],
       [   8,   12, 1693,   18],
       [   9,  563, 2454, 2451],
       [  10,    4, 2694,    5],
       [  11,   16,    7,    2],
       [  12,    8, 2890,   18],
       [  13, 2443, 5021, 5626],
       [  14,   26,   33, 3913],
       [  15, 3135, 3515, 3130],
       [  16,   11,    7,    2],
       [  17,   19,   30, 3917],
       [  18,    8,  194, 3914],
       [  19,   30, 3917, 1446]], dtype=int64)
D=array([[4.3200651e-01, 6.8623567e-01, 6.9459498e-01, 7.3324448e-01],
       [4.5601249e-01, 8.4324878e-01, 8.6190856e-01, 9.0924489e-01],
       [2.2245908e-01, 2.9274365e-01, 4.3219829e-01, 4.6194434e-01],
       [3.1575781e-01, 5.9827942e-01, 6.0513955e-01, 6.4260429e-01],
       [2.5046504e-01, 4.3624276e-01, 4.394050

# Validation dataset
- Estimate BWS label based on kNN similarity search

In [13]:
index.nprobe = conf.search_c
distances, ids = index.search(val_em, conf.search_k)  # sanity check
lim = 5
print(f"I={repr(ids[:lim])}\nD={repr(distances[:lim])}")

I=array([[2013, 5191, 5448, 2850, 4644, 5449, 2512, 4735, 5418, 2024, 3990,
        3839, 3622, 3640, 3356, 2838, 4625, 4231, 2267, 2016, 3995, 2196,
        4353, 5534, 3204, 5072, 2438, 5296, 1799, 1984, 4655, 3741, 5052,
        2121, 4551, 4693, 3079,  798, 5502, 3384, 2425, 2844, 4340, 4306,
        2379, 1476, 2336, 4531, 1484, 4529, 2009, 2123, 2333, 2354, 4506,
        5687, 2161,  342,   87, 4654, 1500, 2489, 4089,  189, 3077, 3487,
         343, 5539, 1980, 2055,  660, 3181, 2259, 1685, 3703, 4889, 4630,
        2007, 5600, 5517,  530, 2406, 5258, 3215, 4891, 2120, 3126, 2430,
        2796, 2461, 3366, 4290, 1707, 4631, 2019,  555, 5338, 4394, 4451,
        2124],
       [1993, 2603, 4865, 3462, 2602, 3463, 3591, 3461, 1985, 4123, 2482,
        2339, 2500, 2902, 2137, 3483, 2738, 3691, 2258, 2133, 2833, 2751,
        3569, 2494, 3986, 3267, 2333, 2594, 2151, 1987, 4643, 2188, 3142,
        3589, 4317, 3337, 3249, 2336, 3747, 2300, 3558, 3587, 2120, 4412,
        3413, 2596, 4

In [14]:
for i in range(ids[:lim].shape[0]):
    print(f"========\ni={i}, {val_text2[i]}")
    for j in range(lim):
        if ids[i][j] == -1:
            continue
        print(f"j={j}, {ruddit_text2[ids[i][j]]}")

i=0, " Go rot in hell you evil liittle bastard. " "
j=0, > or speaking ill of the country. Gods not sending you to hell because you said "America [sucks; be very bad or unpleasant]!"
j=1, Evil? No, you are not evil. The absolute worst thing anyone could say about you is that you are a selfish, moral blank whose lazy cynicism and sneering, ironic take on the world encapsulates everything wrong with a generation.
j=2, And may you rot in hell my good sir for no one is good enough for heaven not even you delusional folks.
j=3, that line low-key broke my faith that Americans had any semblance of shame. For God's sake he is advocating collective punishment and the wholesale slaughter of women and children whose only crime is they are RELATED TO BAD GUYS. The Yemen raid made me want to cry, we killed an eight year old girl.
j=4, [Dude; friend] would shake his baby to hell
i=1, Bamafuck has to be stopped. Please help me keep the Styx pages integ intact!
j=0, Gaunt's Ghosts -series of Warhammer

In [15]:
preds = {}
for i in tqdm(range(ids.shape[0])):
    v = 0
    sims = []
    for j in range(ids.shape[1]):
        if ids[i][j] == -1 or distances[i][j] >= conf.search_d_max:
            break
        d_norm = distances[i][j] / conf.search_d_max
        sims.append(1 - d_norm)
    sm = sum(sims)
    for j, s in enumerate(sims):
        v += s / sm * bws[ids[i][j]] 
    preds[val_text[i]] = v

100%|███████████████████████████████████████| 14251/14251 [00:07<00:00, 2034.37it/s]


In [16]:
vdf = pd.read_csv("input/validation_data.csv", engine="c", low_memory=False)
vdf.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 30108 entries, 0 to 30107
Data columns (total 3 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   worker      30108 non-null  int64 
 1   less_toxic  30108 non-null  object
 2   more_toxic  30108 non-null  object
dtypes: int64(1), object(2)
memory usage: 705.8+ KB


In [17]:
score = mylib.comp_metric(preds, validation_data=vdf)
print(f"Average Agreement with Annotators={score:.4f}")

Average Agreement with Annotators=0.6498


# Jigsaw 2018 dataset
- Pseudo labelling: weighted BWS score based on kNN similarity search

In [18]:
index.nprobe = conf.search_c
distances, ids = index.search(js18_em, conf.search_k)  # sanity check
lim = 5
print(f"I={repr(ids[:lim])}\nD={repr(distances[:lim])}")

I=array([[ 641, 2162,  676,  652,  927,  651,  647, 2899, 5569, 4669,  646,
        3044, 4252, 2404,  679,  452,  667,  663,  639,  732,  126, 5130,
        3788, 5692, 2394, 3806, 3986,  659,  653,  662, 5655, 1297, 5595,
        4743,  655, 1449, 3842, 1113,  497, 3683,  669,  678, 2755, 2411,
        4670,  494,  489, 1547, 2431, 3789, 2596, 3012, 2750,  668, 5384,
        3604,  658, 1303, 2904, 3457, 2063, 2372, 2444, 2455,  538, 5512,
         675, 2898,  648, 3587, 2470, 3926, 4718, 2415,  650, 4963,  666,
         487, 3591, 2975, 2422, 2151,  471,  470,  664, 2463,  640, 3024,
        4758, 3699, 4413,  453, 4383, 5308,  657, 4605, 2442,  654, 3499,
        2255],
       [3069,  431, 3063, 4337, 3326, 5634, 5122, 5509, 1019, 3961, 3187,
        5272, 5164, 2793, 3748, 2163, 3956, 5709, 3959, 3420, 4285, 3026,
        3028, 3013, 3421, 3626, 4757, 4581, 4722, 5120, 2485, 1382, 3327,
        1543, 5124, 3596, 2606, 5242, 2893, 1451, 5339, 4631, 1533,  960,
        5220, 2672, 4

In [19]:
for i in range(ids[:lim].shape[0]):
    print(f"========\ni={i}, {js18_text2[i]}")
    for j in range(lim):
        if ids[i][j] == -1:
            continue
        print(f"j={j}, {ruddit_text2[ids[i][j]]}")

i=0, Explanation Why the edits made under my username Hardcore Metallica Fan were reverted? They were not vandalisms, just closure on some GAs after I voted at New York Dolls FAC. And please do not remove the template from the talk page since I am retired now.
j=0, Just to give a different view. From a modding perspective, the redesign finally adds, or is in the process of adding, a lot of native tools that help make moderation a lot easier for new [mods; forum moderator]. You do not have to be a CSS wizard now to set up a good looking subreddit. There are native removal reasons that will hopefully end up in mobile apps instead of being a computer only feature from toolbox. Flair filtering is finally happening. Posting requirements will work properly instead of having Automod remove posts and comments. Yes, there a lot of complaints from [mods; forum moderator] about the redesign, but the improvements made by admins over the past few months have been incredible and show that they are l

In [20]:
preds = []
for i in tqdm(range(ids.shape[0])):
    v = 0
    sims = []
    for j in range(ids.shape[1]):
        if ids[i][j] == -1 or distances[i][j] >= conf.search_d_max:
            break
        d_norm = distances[i][j] / conf.search_d_max
        sims.append(1 - d_norm)
    sm = sum(sims)
    for j, s in enumerate(sims):
        v += s / sm * bws[ids[i][j]] 
    preds.append(v)

100%|█████████████████████████████████████| 186644/186644 [01:31<00:00, 2046.42it/s]


In [21]:
col = "bws"
js18_df[col] = preds
js18_df[col] = js18_df[col].astype(np.float32)
js18_df[col].describe(percentiles=percentiles)

count    186644.000000
mean         -0.025723
std           0.126421
min          -0.457159
1%           -0.290170
5%           -0.212564
10%          -0.170844
20%          -0.123635
30%          -0.091432
40%          -0.064997
50%          -0.039428
60%          -0.011678
70%           0.020655
80%           0.065970
90%           0.141055
95%           0.208229
99%           0.347912
max           0.533427
Name: bws, dtype: float64

In [22]:
cols = ["bws", "worker", "text", "text1", "text2", "text3"]
cols += em_cols
js18_df[cols].info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 186644 entries, 0 to 186643
Data columns (total 390 columns):
 #    Column                        Non-Null Count   Dtype  
---   ------                        --------------   -----  
 0    bws                           186644 non-null  float32
 1    worker                        186644 non-null  int8   
 2    text                          186644 non-null  object 
 3    text1                         186644 non-null  object 
 4    text2                         186644 non-null  object 
 5    text3                         186644 non-null  object 
 6    paraphrase-MiniLM-L6-v2_0000  186644 non-null  float32
 7    paraphrase-MiniLM-L6-v2_0001  186644 non-null  float32
 8    paraphrase-MiniLM-L6-v2_0002  186644 non-null  float32
 9    paraphrase-MiniLM-L6-v2_0003  186644 non-null  float32
 10   paraphrase-MiniLM-L6-v2_0004  186644 non-null  float32
 11   paraphrase-MiniLM-L6-v2_0005  186644 non-null  float32
 12   paraphrase-MiniLM-L6-v2_0006

In [23]:
%%time
js18_df[cols].to_parquet("output/em_js18.parquet", index=False)

Wall time: 7.97 s
