TODO:
- Adjust fuzz ratio to catch all labels

In [1]:
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import re
import seaborn as sns
from tqdm import tqdm
from fuzzywuzzy import fuzz

# Paths of the raw JSON documents in each split directory.
train_example_paths = glob.glob('data/train/*.json')
test_example_paths = glob.glob('data/test/*.json')

# Document ids: each directory entry's name up to its first '.'.
train_example_names = [fn.split('.')[0] for fn in os.listdir('data/train')]
test_example_names = [fn.split('.')[0] for fn in os.listdir('data/test')]

# Rows of data/train.csv, plus the subsets whose Id is present in each directory.
# (The CSV used to be read and filtered twice in a row; once is enough.)
metadata = pd.read_csv('data/train.csv')
metadata_train = metadata.loc[metadata.Id.isin(train_example_names)]
metadata_test = metadata.loc[metadata.Id.isin(test_example_names)]

In [2]:
def clean_text(txt):
    """Lower-case ``str(txt)`` and replace every run of non-alphanumeric characters with one space."""
    lowered = str(txt).lower()
    return re.sub('[^A-Za-z0-9]+', ' ', lowered)

def remove_punc(txt):
    """Replace every run of non-alphanumeric characters in ``str(txt)`` with one space (case is kept)."""
    as_text = str(txt)
    return re.sub('[^A-Za-z0-9]+', ' ', as_text)

def get_doc_id(doc_path):
    """Return the document id encoded in ``doc_path``.

    The id is the final path component up to its first '.', e.g.
    'data/train/abc.json' -> 'abc'.

    The previous version ignored ``doc_path`` and always derived the id from
    ``train_example_names[0]``, so every path mapped to the same id.
    """
    return os.path.split(doc_path)[-1].split('.')[0]

def load_train_example(i: int):
    """Load the ``i``-th training document together with its metadata rows.

    Args:
        i: Position in ``train_example_paths``.

    Returns:
        A dict with 'doc' (the parsed JSON content) and 'meta' (the rows of
        ``metadata`` whose Id equals ``get_doc_id(doc_path)``).

    Raises:
        IndexError: If ``i`` is outside ``train_example_paths``.
    """
    doc_path = train_example_paths[i]
    # Decode explicitly as UTF-8 instead of the platform default, matching the
    # other file opens in this notebook.
    with open(doc_path, encoding="utf-8") as f:
        data = json.load(f)
    return {'doc': data, 'meta': metadata.loc[metadata.Id == get_doc_id(doc_path)]}

def load_train_example_by_name(name):
    """Load and return the parsed JSON of the training document called ``name``.

    Args:
        name: Document id; the file read is ``data/train/<name>.json``.

    Returns:
        Whatever ``json.load`` produces for that file.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    doc_path = os.path.join('data/train', name + '.json')
    # Decode explicitly as UTF-8 instead of the platform default, matching the
    # other file opens in this notebook.
    with open(doc_path, encoding="utf-8") as f:
        data = json.load(f)
    return data

def delete_file(filename):
    """Remove ``filename``; do nothing if it does not exist.

    The removal is attempted directly rather than after an ``os.path.exists``
    check, so a file that disappears between the check and the removal can no
    longer raise.
    """
    try:
        os.remove(filename)
    except FileNotFoundError:
        # Already gone: that is the desired end state.
        pass

## Split Data

In [3]:
import random

# Shuffle the document ids with a fixed seed and split them into train/validation.
# The ids come from os.listdir, whose order is arbitrary, so sort them first:
# otherwise the same seed can produce a different split on another machine.
docIdx = sorted(train_example_names)
random.seed(42)
random.shuffle(docIdx)

# Fraction of documents assigned to the training split; the rest is validation.
train_ratio = 0.85
n_train = int(len(docIdx) * train_ratio)
n_val = len(docIdx) - n_train

train_idx = docIdx[:n_train]
val_idx = docIdx[n_train:]

print(f'train size: {len(train_idx)}')
print(f'val size: {len(val_idx)}')

train size: 12168
val size: 2148


## Generate Dataset and Features

In [4]:
# Matches one or more consecutive whitespace characters; compiled once and
# reused by make_single_whitespace.
_RE_COMBINE_WHITESPACE = re.compile(r"\s+")

def preprocess_tokenize_doc(doc_json):
    """Flatten a document into a list of punctuation-free tokens.

    Args:
        doc_json: Iterable of sections; each section's ``'text'`` entry is used.

    Returns:
        The tokens in document order, with their original case. A document
        with no alphanumeric text yields an empty list.
    """
    doc_text = ' '.join(remove_punc(sec['text']) for sec in doc_json)
    doc_text = make_single_whitespace(doc_text)

    # After the two steps above the text is alphanumeric runs separated by
    # single spaces, so split() matches split(' ') for any non-empty text --
    # but it returns [] instead of [''] when the text is empty.
    return doc_text.split()

def indices(lst, element):
    """Return the positions of the items of ``lst`` that contain ``element``.

    Containment is tested with ``element in item``, so for strings this is a
    substring match rather than an equality test.
    """
    matches = []
    for position, candidate in enumerate(lst):
        if element in candidate:
            matches.append(position)
    return matches

def make_single_whitespace(text):
    """Collapse every whitespace run in ``text`` to one space and trim both ends."""
    collapsed = _RE_COMBINE_WHITESPACE.sub(" ", text)
    return collapsed.strip()

## Create dataframe for tokens and targets

In [5]:
def get_doc(doc_id, reduce_tokens = False, n_context = 250):
    """Build the token/target table for one training document.

    Args:
        doc_id: Document id, matched against ``metadata_train.Id`` and used to
            load the document from disk.
        reduce_tokens: When True, keep only the tokens lying within
            ``n_context`` tokens of a label occurrence.
        n_context: Number of tokens kept before and after each occurrence when
            ``reduce_tokens`` is True.

    Returns:
        A DataFrame with a 'TOKEN' column (tokens in their original case) and
        a 'TARGET' column holding the strings '1' (token belongs to a label
        occurrence) or '0'. When ``reduce_tokens`` is True the rows keep their
        original positions as index.

    Raises:
        AssertionError: If one of the document's labels has no occurrence.
    """
    raw_labels = metadata_train.loc[metadata_train.Id == doc_id, 'dataset_label'].values
    # Normalise the labels the same way the document text is normalised, and
    # drop duplicates (order preserved) so each label is searched only once.
    doc_labels = list(dict.fromkeys(
        make_single_whitespace(remove_punc(l.strip())).lower() for l in raw_labels
    ))

    doc = load_train_example_by_name(doc_id)
    doc_tokens = preprocess_tokenize_doc(doc)
    doc_tokens_lower = [t.lower() for t in doc_tokens]
    n_tokens = len(doc_tokens)

    # Targets for dataset names will be 1
    target_arr = np.zeros(n_tokens, dtype = 'uint8')

    # Marks the tokens lying within n_context tokens of a label occurrence
    keep_mask = np.zeros(n_tokens, dtype = 'bool')

    # Lower-cased n-grams of the document, built once per label length rather
    # than once per label.
    ngrams_by_len = {}

    for l in doc_labels:
        n_label_tokens = len(l.split(' '))
        if n_label_tokens not in ngrams_by_len:
            ngrams_by_len[n_label_tokens] = [
                ' '.join(doc_tokens_lower[i:i + n_label_tokens])
                for i in range(n_tokens - n_label_tokens + 1)
            ]

        # Substring match of the label against each n-gram of the same length
        occurrences = indices(ngrams_by_len[n_label_tokens], l)

        assert len(occurrences) != 0, f'Label {l} not found in doc {doc_id}'
        for o in occurrences:
            keep_start = max(0, o - n_context)
            keep_end = min(o + n_context + n_label_tokens, n_tokens)
            keep_mask[keep_start: keep_end] = True
            target_arr[o: o + n_label_tokens] = 1

    doc_df = pd.DataFrame()
    doc_df['TOKEN'] = doc_tokens
    doc_df['TARGET'] = target_arr
    doc_df['TARGET'] = doc_df['TARGET'].astype('str')
    if reduce_tokens:
        doc_df = doc_df.loc[keep_mask]

    return doc_df

## Create Generators

In [6]:
# Load dataset names
# Read the candidate dataset names used for augmentation. Only lines longer
# than 25 characters are kept (the length is measured on the raw line, line
# ending included); each kept line is stripped of punctuation and has its
# whitespace collapsed.
with open('data/dataset_names.txt', 'r', encoding="utf-8") as f:
    us_dataset_names = [
        make_single_whitespace(remove_punc(raw_line))
        for raw_line in f
        if len(raw_line) > 25
    ]

In [7]:
import random

# Number of passes train_generator makes over the training documents.
n_generator_repeat = 8
# Number of training documents consumed per pass (here: all of them).
n_examples = len(train_idx)

def replace_target(x, lst):
    """Append ``x``, or a random replacement for it, to ``lst``.

    ``x`` is a DataFrame whose first TARGET value decides what happens: '0'
    means the rows are appended unchanged; anything else means a new frame is
    appended instead, holding the tokens of a name drawn at random from
    ``us_dataset_names``, all with TARGET '1'.
    """
    is_dataset_name = x.TARGET.iloc[0] != '0'
    if not is_dataset_name:
        # Plain text is never augmented.
        lst.append(x)
        return

    substitute_tokens = random.choice(us_dataset_names).split(' ')
    replacement = pd.DataFrame({'TOKEN': substitute_tokens, 'TARGET': '1'})
    lst.append(replacement)

def train_generator():
    """Yield, for each pass and each training document, its 'TOKEN TARGET' lines.

    The first ``n_examples`` ids of ``train_idx`` are visited
    ``n_generator_repeat`` times. The first pass yields the documents as they
    are; every later pass replaces each run of target tokens with a randomly
    chosen dataset name (see ``replace_target``).

    Yields:
        A list of strings, one per token, formatted as '<token> <target>'.
    """
    for i_repeat in range(1, n_generator_repeat + 1):
        print(f'X_train generator repeat: {i_repeat}')
        for doc_id in tqdm(train_idx[:n_examples]):
            doc_df = get_doc(doc_id, reduce_tokens = True).reset_index()
            if i_repeat > 1:
                df_pieces = []
                # Do augmentation

                # Group consecutive rows sharing the same TARGET so that each
                # dataset mention is replaced as a whole.
                gb = doc_df.groupby((doc_df.TARGET.shift() != doc_df.TARGET).cumsum())
                for _, group in gb:
                    replace_target(group, df_pieces)

                doc_df = pd.concat(df_pieces, ignore_index= True, axis = 0)

            # Convert to string. Zipping the two columns avoids the per-row
            # Series that iterrows() would build for every token.
            lines = [
                token + ' ' + target
                for token, target in zip(doc_df['TOKEN'], doc_df['TARGET'])
            ]

            # Drop the frame before suspending at the yield.
            del doc_df
            yield lines

def val_generator():
    """Yield, for each validation document, its 'TOKEN TARGET' lines.

    Documents are taken from ``val_idx`` in order, with all their tokens
    (``reduce_tokens = False``) and without augmentation.

    Yields:
        A list of strings, one per token, formatted as '<token> <target>'.
    """
    for doc_id in val_idx:
        doc_df = get_doc(doc_id, reduce_tokens = False).reset_index()

        # No augmentation. Zipping the two columns avoids the per-row Series
        # that iterrows() would build for every token.
        lines = [
            token + ' ' + target
            for token, target in zip(doc_df['TOKEN'], doc_df['TARGET'])
        ]

        # Drop the frame before suspending at the yield.
        del doc_df
        yield lines

## Create Dataset

## Training

In [8]:
# Create the training generator; no document is processed until it is advanced.
train_gen = train_generator()

In [9]:
lines = []

# Drain the training generator. Every document's lines are followed by an
# empty string, so documents end up separated by a blank line.
try:
    for doc_lines in train_gen:
        lines.extend(doc_lines)
        lines.append('')
except Exception as e:
    # An error raised inside the generator ends it for good: report it and
    # keep whatever was collected so far. Normal exhaustion no longer passes
    # through here, so it no longer prints an empty "error" line.
    print(e)

  0%|          | 0/12168 [00:00<?, ?it/s]X_train generator repeat: 1
100%|██████████| 12168/12168 [1:22:39<00:00,  2.45it/s]
  0%|          | 0/12168 [00:00<?, ?it/s]X_train generator repeat: 2
  6%|▌         | 717/12168 [04:51<1:17:30,  2.46it/s]


KeyboardInterrupt: 

In [12]:
# Write the collected training lines to disk, one entry per row.
with open('data/train_data.txt', 'w', encoding="utf-8") as f:
    for entry in lines:
        try:
            f.write(entry + '\n')
        except Exception as e:
            # Best effort: report the failure and carry on with the next entry.
            print(e)

## Val

In [None]:
import gc

# Drop the training lines and run a garbage collection before building the
# validation set.
del lines
gc.collect()
# Create the validation generator; no document is processed until it is advanced.
val_gen = val_generator()

In [None]:
lines = []

# Drain the validation generator. Every document's lines are followed by an
# empty string, so documents end up separated by a blank line.
try:
    for doc_lines in val_gen:
        lines.extend(doc_lines)
        lines.append('')
except Exception as e:
    # An error raised inside the generator ends it for good: report it and
    # keep whatever was collected so far. Normal exhaustion no longer passes
    # through here, so it no longer prints an empty "error" line.
    print(e)

In [None]:
# Write the collected validation lines to disk, one entry per row.
with open('data/val_data.txt', 'w', encoding="utf-8") as f:
    for entry in lines:
        try:
            f.write(entry + '\n')
        except Exception as e:
            # Best effort: report the failure and carry on with the next entry.
            print(e)