In [None]:
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupKFold
import sys
sys.setrecursionlimit(1000000)

In [None]:
### https://www.kaggle.com/columbia2131/jigsaw-cv-strategy-by-union-find
class UnionFind():
    """Disjoint-set forest over the integers ``0 .. n-1``.

    ``parents[i]`` is the parent index of ``i``. For a root the entry is
    negative and its absolute value is the number of elements in that set.
    """

    def __init__(self, n):
        self.n = n
        # Every element starts as the root of its own one-element set.
        self.parents = [-1] * n

    def find(self, x):
        """Return the root of the set containing ``x``.

        Every node visited on the way up is re-pointed directly at the root.
        """
        # Pass 1: climb until we hit an entry that marks a root (negative).
        root = x
        while self.parents[root] >= 0:
            root = self.parents[root]

        # Pass 2: walk the same path again and attach each node to the root.
        node = x
        while node != root:
            above = self.parents[node]
            self.parents[node] = root
            node = above
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y`` (no-op if already merged)."""
        keep, absorb = self.find(x), self.find(y)
        if keep != absorb:
            # A less negative entry means a smaller set; the larger set's root
            # survives, and on a tie the root of ``x`` survives.
            if self.parents[keep] > self.parents[absorb]:
                keep, absorb = absorb, keep
            self.parents[keep] += self.parents[absorb]
            self.parents[absorb] = keep


def get_group_unionfind(train: pd.DataFrame):
    """Label every row with the connected component its texts belong to.

    Each row links its ``less_toxic`` text to its ``more_toxic`` text. Texts
    joined by any chain of such links end up in the same union-find set, and
    the set's root index is used as the group id.

    Parameters
    ----------
    train : pd.DataFrame
        Must contain the columns ``less_toxic`` and ``more_toxic``.

    Returns
    -------
    pd.DataFrame
        A new frame equal to ``train`` plus a ``group`` column (looked up from
        each row's ``less_toxic`` text). The input frame is left unmodified.
    """
    # Number the texts in order of first appearance. Iterating a set of
    # strings would give an order that changes between interpreter runs
    # (string hashing is randomized), making the group ids irreproducible.
    all_text = pd.concat(
        [train['less_toxic'], train['more_toxic']], ignore_index=True
    ).unique()
    text2num = {text: i for i, text in enumerate(all_text)}

    # Temporary id series live in locals rather than as columns on `train`,
    # so the caller's frame is never mutated.
    less_ids = train['less_toxic'].map(text2num)
    more_ids = train['more_toxic'].map(text2num)

    uf = UnionFind(len(text2num))
    for less_id, more_id in zip(less_ids, more_ids):
        uf.union(less_id, more_id)

    text2group = {text: uf.find(i) for text, i in text2num.items()}
    return train.assign(group=train['less_toxic'].map(text2group))

In [None]:
# Read validation_data.csv and attach a `group` column to it via
# get_group_unionfind.
validation_csv = "../input/jigsaw-toxic-severity-rating/validation_data.csv"
train = get_group_unionfind(pd.read_csv(validation_csv))

In [None]:
class Node:
    """Graph vertex that records only its outgoing edges."""

    def __init__(self):
        # A fresh list per instance; callers append child indices to it.
        self.child = list()
        
def _implied_descendants(nodes, root):
    """Return the nodes discovered below ``root``'s direct children.

    Performs a depth-first walk starting from each entry of
    ``nodes[root].child`` and returns, in discovery order, every node that is
    first reached through some other node's ``child`` list. Each node is
    returned at most once. ``root`` itself is not pre-marked as seen, so it is
    returned too when a cycle leads back to it.

    The walk keeps an explicit stack of child iterators instead of recursing,
    so its depth is not bounded by the interpreter's recursion limit.
    """
    seen = set()
    found = []
    for start in nodes[root].child:
        if start in seen:
            # Already fully explored (e.g. the same edge listed twice):
            # walking it again could not discover anything new.
            continue
        seen.add(start)
        stack = [iter(nodes[start].child)]
        while stack:
            for nxt in stack[-1]:
                if nxt in seen:
                    continue
                found.append(nxt)
                seen.add(nxt)
                # Descend; the parent's iterator stays on the stack and is
                # resumed once this branch is exhausted.
                stack.append(iter(nodes[nxt].child))
                break
            else:
                stack.pop()
    return found


# Column buffers for the extra comparison rows implied by chains of existing
# rows (A more toxic than B, B more toxic than C => A more toxic than C).
worker = []
less_toxic = []
more_toxic = []
group_new = []

groups = np.unique(train.group.values)

for group in groups:
    tmp_df = train[train.group == group]
    tmp_texts = np.unique(np.concatenate([tmp_df.less_toxic.values, tmp_df.more_toxic.values]))
    # Text -> position in tmp_texts, used as the node index within this group.
    tmp_dict = {text: i for i, text in enumerate(tmp_texts)}
    Tree = [Node() for _ in range(len(tmp_texts))]

    # One directed edge per row: more_toxic -> less_toxic.
    for row in tmp_df.itertuples():
        Tree[tmp_dict[row.more_toxic]].child.append(tmp_dict[row.less_toxic])

    for i in range(len(Tree)):
        for j in _implied_descendants(Tree, i):
            worker.append(-1)  # -1 marks a derived row, not an original one
            less_toxic.append(tmp_texts[j])
            more_toxic.append(tmp_texts[i])
            group_new.append(group)

train_new = pd.DataFrame({
    "worker": worker,
    "less_toxic": less_toxic,
    "more_toxic": more_toxic,
    "group": group_new
})
# A cycle in the comparisons makes a text reachable from itself; drop those.
train_new = train_new[train_new.less_toxic != train_new.more_toxic]
display(train_new)

In [None]:
# Stack the derived rows under the existing ones, then keep only the first
# occurrence of each (less_toxic, more_toxic) pair and renumber the index.
train = (
    pd.concat([train, train_new])
    .drop_duplicates(subset=["less_toxic", "more_toxic"], keep="first")
    .reset_index(drop=True)
)

In [None]:
### https://www.kaggle.com/columbia2131/jigsaw-cv-strategy-by-union-find
# Give every row a validation fold such that all rows of one `group` land in
# the same fold.
group_kfold = GroupKFold(n_splits=5)

# split() yields *positional* row indices, so collect the fold ids in a
# positional integer array rather than writing through label-based `.loc`
# (which is only equivalent while the index is a fresh RangeIndex).
fold_of_row = np.full(len(train), -1, dtype=int)
for fold, (trn_idx, val_idx) in enumerate(group_kfold.split(train, groups=train['group'])):
    fold_of_row[val_idx] = fold
assert (fold_of_row >= 0).all(), "every row must be assigned to a fold"

train["fold"] = fold_of_row
train.to_csv('train_noleak.csv', index=False)
display(train)