In [None]:
import os
import re
import gc
import sys
import time
import json
import random
import unicodedata
import multiprocessing
from functools import partial, lru_cache

import emoji

import numpy as np
import pandas as pd
from sklearn.externals import joblib
from tqdm import tqdm, tqdm_notebook

from nltk import TweetTokenizer
from nltk.stem import PorterStemmer, SnowballStemmer
from nltk.stem.lancaster import LancasterStemmer

import torch
from torch import nn
from torch.utils import data
from torch.nn import functional as F
from gensim.models import KeyedVectors
from keras.preprocessing.sequence import pad_sequences

class SequenceBucketCollator():
    """Collate fn that trims a padded batch to the length chosen by `choose_length`.

    Each dataset item is a tuple of tensors. The collator stacks every field
    across the batch. It then cuts the sequence field down to its LAST `length`
    columns, where `length = choose_length(lengths_in_batch)` (e.g. `torch.max`).
    Keeping the rightmost columns assumes the sequences are left-padded.
    NOTE(review): this notebook uses keras `pad_sequences` with its defaults
    (pre-padding); confirm if that call changes.

    Args:
        choose_length: callable mapping the stacked length tensor to a single
            target length (a 0-d tensor or an int).
        sequence_index: position of the token-id tensor inside each item.
        length_index: position of the per-item length tensor inside each item.
        label_index: optional position of the label tensor. When given, the
            call returns `(inputs, labels)`, with the label field removed from
            `inputs`.
    """

    def __init__(self, choose_length, sequence_index, length_index, label_index=None):
        self.choose_length = choose_length
        self.sequence_index = sequence_index
        self.length_index = length_index
        self.label_index = label_index

    def __call__(self, batch):
        batch = [torch.stack(field) for field in zip(*batch)]

        sequences = batch[self.sequence_index]
        lengths = batch[self.length_index]

        # Take the width from the batch itself instead of a notebook-global
        # `maxlen`, so the collator works for any padded width.
        width = sequences.shape[1]
        # Keep at least one column (an all-empty batch would otherwise reach the
        # model as a zero-length sequence) and never more than the padded width.
        keep = min(max(int(self.choose_length(lengths)), 1), width)
        # The real tokens sit in the rightmost `keep` columns. The previous mask
        # `arange(maxlen, 0, -1) < length` kept only `length - 1` columns and so
        # dropped the first token of the longest sequence in every batch.
        batch[self.sequence_index] = sequences[:, width - keep:]

        if self.label_index is not None:
            inputs = [x for i, x in enumerate(batch) if i != self.label_index]
            return inputs, batch[self.label_index]

        return batch


# str.translate table for typographic noise:
#   * soft hyphen, DEL, BOM and zero-width / bidi formatting marks are deleted;
#   * curly quotes, backtick and guillemets become ASCII quotes;
#   * small-capital, Cyrillic and Greek look-alike letters are folded to the
#     ASCII capital they resemble;
#   * em and en dashes become "-".
CUSTOM_TABLE = str.maketrans(
    {
        # -- deleted outright (value None) ------------------------------------
        **dict.fromkeys(
            ["\xad", "\x7f", "\ufeff", "\u200b", "\u200e", "\u202a", "\u202c"]
        ),
        # -- quotes -------------------------------------------------------------
        "‘": "'",
        "’": "'",
        "`": "'",
        "“": '"',
        "”": '"',
        "«": '"',
        "»": '"',
        # -- look-alike letters -------------------------------------------------
        "ɢ": "G",
        "ɪ": "I",
        "ɴ": "N",
        "ʀ": "R",
        "ʏ": "Y",
        "ʙ": "B",
        "ʜ": "H",
        "ʟ": "L",
        "ғ": "F",
        "ᴀ": "A",
        "ᴄ": "C",
        "ᴅ": "D",
        "ᴇ": "E",
        "ᴊ": "J",
        "ᴋ": "K",
        "ᴍ": "M",
        "Μ": "M",
        "ᴏ": "O",
        "ᴘ": "P",
        "ᴛ": "T",
        "ᴜ": "U",
        "ᴡ": "W",
        "ᴠ": "V",
        "ĸ": "K",
        "в": "B",
        "м": "M",
        "н": "H",
        "т": "T",
        "ѕ": "S",
        # -- dashes -------------------------------------------------------------
        "—": "-",
        "–": "-",
    }
)

# (censored spelling, restored word) pairs. `normalize` applies them in list
# order, so longer spellings such as "a**hole" must stay ahead of their prefix
# "a**". The patterns have no word boundaries, so they also match inside
# longer words.
WORDS_REPLACER = [
    ("sh*t", "shit"),
    ("s**t", "shit"),
    ("f*ck", "fuck"),
    ("fu*k", "fuck"),
    ("f**k", "fuck"),
    ("f*****g", "fucking"),
    ("f***ing", "fucking"),
    ("f**king", "fucking"),
    ("p*ssy", "pussy"),
    ("p***y", "pussy"),
    ("pu**y", "pussy"),
    ("p*ss", "piss"),
    ("b*tch", "bitch"),
    ("bit*h", "bitch"),
    ("h*ll", "hell"),
    ("h**l", "hell"),
    ("cr*p", "crap"),
    ("d*mn", "damn"),
    ("stu*pid", "stupid"),
    ("st*pid", "stupid"),
    ("n*gger", "nigger"),
    ("n***ga", "nigger"),
    ("f*ggot", "faggot"),
    ("scr*w", "screw"),
    ("pr*ck", "prick"),
    ("g*d", "god"),
    ("s*x", "sex"),
    ("a*s", "ass"),
    ("a**hole", "asshole"),
    ("a***ole", "asshole"),
    ("a**", "ass"),
]

# Case-insensitive compiled patterns. re.escape turns each censoring "*" into a
# literal "\*". The previous `pat.replace("*", "\*")` only worked because "\*"
# is an *invalid* string escape that Python keeps verbatim, and newer Pythons
# warn about it.
REGEX_REPLACER = [
    (re.compile(re.escape(pat), flags=re.IGNORECASE), repl)
    for pat, repl in WORDS_REPLACER
]

RE_SPACE = re.compile(r"\s")
RE_MULTI_SPACE = re.compile(r"\s+")

# Every combining non-spacing mark (Unicode category "Mn") maps to None, i.e.
# str.translate deletes it. `normalize` applies this after NFKD, so accents
# are stripped.
NMS_TABLE = dict.fromkeys(
    i for i in range(sys.maxunicode + 1) if unicodedata.category(chr(i)) == "Mn"
)


def _script_table(first, last, placeholder):
    """Map every code point in the INCLUSIVE range [first, last] to `placeholder`."""
    return dict.fromkeys(range(first, last + 1), placeholder)


# Collapse whole non-Latin script ranges to one representative character.
# Bounds are inclusive, as in the Unicode block charts. The earlier
# `range(first, last)` form silently skipped the final code point of each
# range (e.g. U+06FF, U+3096, U+30FF, U+2FD5).
HEBREW_TABLE = _script_table(0x0590, 0x05FF, "א")
ARABIC_TABLE = _script_table(0x0600, 0x06FF, "ا")
CHINESE_TABLE = _script_table(0x4E00, 0x9FFF, "是")
KANJI_TABLE = _script_table(0x2E80, 0x2FD5, "ッ")
HIRAGANA_TABLE = _script_table(0x3041, 0x3096, "ッ")
KATAKANA_TABLE = _script_table(0x30A0, 0x30FF, "ッ")

# Combined table. Order matters: dict.update lets later tables override earlier
# ones on shared keys. For example, combining marks inside the Hebrew/Arabic
# ranges end up mapped to the script placeholder instead of being deleted.
TABLE = dict()
TABLE.update(CUSTOM_TABLE)
TABLE.update(NMS_TABLE)
# Non-english languages
TABLE.update(CHINESE_TABLE)
TABLE.update(HEBREW_TABLE)
TABLE.update(ARABIC_TABLE)
TABLE.update(HIRAGANA_TABLE)
TABLE.update(KATAKANA_TABLE)
TABLE.update(KANJI_TABLE)


EMOJI_REGEXP = emoji.get_emoji_regexp()

# emoji character -> " EMJ <alias words> " (alias colons stripped,
# underscores turned into spaces).
UNICODE_EMOJI_MY = dict(
    (symbol, f" EMJ {alias.strip(':').replace('_', ' ')} ")
    for symbol, alias in emoji.UNICODE_EMOJI_ALIAS.items()
)


def my_demojize(string: str) -> str:
    """Spell emoji out as words and remove every U+FE0F variation selector.

    Each EMOJI_REGEXP match is replaced by its UNICODE_EMOJI_MY entry. A match
    with no entry is left unchanged.
    """
    lookup = UNICODE_EMOJI_MY.get
    spelled = EMOJI_REGEXP.sub(lambda hit: lookup(hit.group(0), hit.group(0)), string)
    # Plain substring removal, equivalent to re.sub with this literal pattern.
    return spelled.replace("\ufe0f", "")


def normalize(text: str) -> str:
    """Canonicalize a raw comment before tokenization.

    Steps, in order:
      1. spell out emoji (`my_demojize`);
      2. turn every whitespace character into a space;
      3. NFKD-decompose and apply `TABLE` via str.translate;
      4. collapse whitespace runs and strip the ends;
      5. restore censored swear words with `REGEX_REPLACER`, in list order.
    """
    cleaned = RE_SPACE.sub(" ", my_demojize(text))
    cleaned = unicodedata.normalize("NFKD", cleaned).translate(TABLE)
    cleaned = RE_MULTI_SPACE.sub(" ", cleaned).strip()
    for censored, uncensored in REGEX_REPLACER:
        cleaned = censored.sub(uncensored, cleaned)
    return cleaned


PORTER_STEMMER = PorterStemmer()
LANCASTER_STEMMER = LancasterStemmer()
SNOWBALL_STEMMER = SnowballStemmer("english")


def word_forms(word):
    """Lazily yield candidate spellings of `word` for embedding lookup.

    Order: as-is, lower, upper, capitalized, then the Porter, Lancaster and
    Snowball stems. Duplicates are not removed.
    """
    yield word
    for recase in (str.lower, str.upper, str.capitalize):
        yield recase(word)
    for stemmer in (PORTER_STEMMER, LANCASTER_STEMMER, SNOWBALL_STEMMER):
        yield stemmer.stem(word)
    
def maybe_get_embedding(word, model):
    """Return the first vector `model` has for a spelling variant of `word`, else None.

    Tries every form from `word_forms(word)`. If none is in `model`, retries
    with leading/trailing "-" and "'" stripped.

    Args:
        word: token to look up.
        model: container supporting `form in model` and `model[form]`
            (a gensim KeyedVectors in this notebook).
    Returns:
        The stored vector, or None when no candidate form is present.
    """
    candidates = [word]
    stripped = word.strip("-'")
    # Only retry when stripping changed the token. Re-checking an identical
    # word would repeat every lookup (and three stemmer calls) for nothing.
    if stripped != word:
        candidates.append(stripped)

    for candidate in candidates:
        for form in word_forms(candidate):
            if form in model:
                return model[form]

    return None


def gensim_to_embedding_matrix(word2index, path):
    """Build an embedding matrix for `word2index` from a saved gensim KeyedVectors file.

    Row `i` holds the vector that `maybe_get_embedding` finds for the word with
    index `i`. Rows of words with no match, and indices not used by
    `word2index` (e.g. 0, which the tokenizer cell never assigns), stay zero.

    Args:
        word2index: mapping word -> non-negative integer row index.
        path: file loadable with `KeyedVectors.load`; it is memory-mapped read-only.
    Returns:
        (embedding_matrix, unknown_words): a float32 array of shape
        (max_index + 1, model.vector_size), and the words with no vector.
    """
    model = KeyedVectors.load(path, mmap="r")
    # default=0 keeps an empty vocabulary from raising ValueError in max();
    # the result is then a single all-zero row.
    n_rows = max(word2index.values(), default=0) + 1
    embedding_matrix = np.zeros((n_rows, model.vector_size), dtype=np.float32)
    unknown_words = []

    for word, i in word2index.items():
        vector = maybe_get_embedding(word, model)
        if vector is None:
            unknown_words.append(word)
        else:
            embedding_matrix[i] = vector

    return embedding_matrix, unknown_words

In [None]:
%%time

# Raw competition test set; its row order is the order of the submission.
test = pd.read_csv(
    filepath_or_buffer='../input/jigsaw-unintended-bias-in-toxicity-classification/test.csv',
)

In [None]:
# test = test.head(1000)

In [None]:
%%time
# Normalize every comment in two worker processes; Pool.map preserves input order.
with multiprocessing.Pool(processes=2) as pool:
    text_list = pool.map(normalize, test["comment_text"].tolist())

In [None]:
%%time
tknzr = TweetTokenizer(strip_handles=True, reduce_len=True)

# Build the vocabulary on the fly. Ids start at 1 and are handed out in order
# of first appearance, so index 0 is never assigned to a token.
test_word_sequences = []
word_dict = {}
word_index = 1

for doc in text_list:
    doc_ids = []
    for token in tknzr.tokenize(doc):
        token_id = word_dict.get(token)
        if token_id is None:
            token_id = word_dict[token] = word_index
            word_index += 1
        doc_ids.append(token_id)
    test_word_sequences.append(doc_ids)

In [None]:
%%time

# Per-comment token counts, in test-set order.
test_lengths = torch.from_numpy(np.array([len(x) for x in test_word_sequences]))
# Convert the 0-d tensor to a plain int. pad_sequences, slicing and the
# f-string then all see a number; the f-string used to print "tensor(...)".
maxlen = int(test_lengths.max())
print(f"Max len = {maxlen}")
maxlen = min(maxlen, 400)  # cap the padded width at 400 tokens

# pad_sequences is called with its keras defaults (pre-padding and
# pre-truncation with 0), which is why the bucketing collator keeps the
# rightmost columns.
x_test_padded = torch.tensor(pad_sequences(test_word_sequences, maxlen=maxlen)).long()
test_collator = SequenceBucketCollator(torch.max, sequence_index=0, length_index=1)

del text_list, test_word_sequences, tknzr
gc.collect()

In [None]:
%%time

# Align each pretrained embedding to `word_dict`. The list of words without a
# vector (second return value) is not needed here. The files are loaded one at
# a time, in this order.
glove_matrix, crawl_matrix, para_matrix, w2v_matrix = (
    gensim_to_embedding_matrix(word_dict, path)[0]
    for path in (
        "../input/gensim-embeddings-dataset/glove.840B.300d.gensim",
        "../input/gensim-embeddings-dataset/crawl-300d-2M.gensim",
        "../input/gensim-embeddings-dataset/paragram_300_sl999.gensim",
        "../input/gensim-embeddings-dataset/GoogleNews-vectors-negative300.gensim",
    )
)

In [None]:
%%time

def one_hot_char_embeddings(word2index, char_vectorizer):
    """Turn every vocabulary word into a dense character-feature row.

    Builds a word list indexed like `word2index` (unused indices get ""), runs
    it through `char_vectorizer.transform`, and densifies the result.

    Args:
        word2index: mapping word -> non-negative integer row index.
        char_vectorizer: fitted object whose `transform(list_of_str)` returns
            something with `.toarray()`. Presumably a scikit-learn vectorizer
            producing a sparse matrix — TODO confirm against the pickle.
    Returns:
        float32 array with one row per index 0..max_index.
    """
    # default=0 keeps an empty vocabulary from raising ValueError in max(),
    # matching gensim_to_embedding_matrix's row count.
    words = [""] * (max(word2index.values(), default=0) + 1)
    for word, i in word2index.items():
        words[i] = word

    return char_vectorizer.transform(words).toarray().astype(np.float32)

# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — only point it at the trusted solution dataset.
char_matrix = one_hot_char_embeddings(
    word2index=word_dict,
    char_vectorizer=joblib.load('../input/jigsaw-solution-ver-1/char_vectorizer.pkl'),
)

In [None]:
LSTM_UNITS = 128
# Width of the pooled feature vector: 3 pooled tensors x 2 directions x LSTM_UNITS.
DENSE_HIDDEN_UNITS = 6 * LSTM_UNITS


class SpatialDropout(nn.Dropout2d):
    """Dropout that zeroes whole feature channels across every timestep.

    Takes and returns a 3-D tensor whose channel axis is the last one —
    presumably (batch, seq_len, embed_dim) as produced by nn.Embedding. The
    tensor is viewed as (batch, channels, seq_len, 1), so Dropout2d drops
    entire channels.
    """

    def forward(self, x):
        channels_first = x.permute(0, 2, 1).unsqueeze(3)
        dropped = super().forward(channels_first)
        return dropped.squeeze(3).permute(0, 2, 1)


class NeuralNet(nn.Module):
    """Two-layer BiLSTM classifier over frozen pretrained embeddings.

    Output columns: [0] main logit, [1:7] six auxiliary logits, and
    [7:7 + output_aux_sub] extra "sub" logits (18 columns in total with the
    default 11). The main head also receives the auxiliary and sub logits.

    Args:
        embedding_matrix: 2-D array (vocab_size, embed_dim), copied into a
            frozen nn.Embedding.
        output_aux_sub: number of extra sub-target logits.
    """

    def __init__(self, embedding_matrix, output_aux_sub=11):
        super().__init__()
        vocab_size, embed_size = embedding_matrix.shape[0], embedding_matrix.shape[1]

        # Keep this module creation order: it fixes the state_dict key order
        # and the RNG draws used for the initial weights.
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.embedding.weight = nn.Parameter(torch.tensor(embedding_matrix, dtype=torch.float32))
        self.embedding.weight.requires_grad = False
        self.embedding_dropout = SpatialDropout(0.3)

        self.lstm1 = nn.LSTM(embed_size, LSTM_UNITS, bidirectional=True, batch_first=True)
        self.lstm2 = nn.LSTM(2 * LSTM_UNITS, LSTM_UNITS, bidirectional=True, batch_first=True)

        self.linear1 = nn.Linear(DENSE_HIDDEN_UNITS, DENSE_HIDDEN_UNITS)
        self.linear2 = nn.Linear(DENSE_HIDDEN_UNITS, DENSE_HIDDEN_UNITS)

        self.linear_out = nn.Linear(DENSE_HIDDEN_UNITS + 6 + output_aux_sub, 1)
        self.linear_aux_out = nn.Linear(DENSE_HIDDEN_UNITS, 6)
        self.linear_sub_out = nn.Linear(DENSE_HIDDEN_UNITS, output_aux_sub)

    def forward(self, x, lengths=None):
        """Return (batch, 1 + 6 + output_aux_sub) logits for token ids `x`.

        `lengths` is accepted for API compatibility but is not used.
        """
        embedded = self.embedding_dropout(self.embedding(x))

        seq1, _ = self.lstm1(embedded)
        seq2, _ = self.lstm2(seq1)

        # Pool over the time axis (dim 1, because batch_first=True).
        pooled = torch.cat(
            (seq1.mean(dim=1), seq2.max(dim=1)[0], seq2.mean(dim=1)), dim=1
        )
        # Residual sum of the pooled features and two ReLU projections of them.
        hidden = pooled + F.relu(self.linear1(pooled)) + F.relu(self.linear2(pooled))

        aux_logits = self.linear_aux_out(hidden)
        sub_logits = self.linear_sub_out(hidden)
        main_logit = self.linear_out(torch.cat((hidden, aux_logits, sub_logits), dim=1))
        return torch.cat([main_logit, aux_logits, sub_logits], dim=1)
    
    
def get_lstm_preds(model_name, embedding_matrix):
    """Run one saved NeuralNet checkpoint over the padded test set on the GPU.

    Args:
        model_name: checkpoint file name inside ../input/jigsaw-solution-ver-1/.
        embedding_matrix: embedding columns for this checkpoint. It also
            replaces the checkpoint's stored 'embedding.weight'.
    Returns:
        float32 array (len(x_test_padded), 18) of sigmoid outputs, in test-set
        order (the loader does not shuffle).
    """
    net = NeuralNet(embedding_matrix)
    state = torch.load('../input/jigsaw-solution-ver-1/' + model_name)
    state['embedding.weight'] = torch.tensor(embedding_matrix)
    net.load_state_dict(state)
    net = net.cuda()
    for weight in net.parameters():
        weight.requires_grad = False
    net = net.eval()

    batch_size = 256
    loader = data.DataLoader(
        data.TensorDataset(x_test_padded, test_lengths),
        batch_size=batch_size,
        collate_fn=test_collator,
    )

    preds = np.zeros((len(loader.dataset), 18), dtype=np.float32)
    with torch.no_grad():
        for batch_no, batch in enumerate(loader):
            start = batch_no * batch_size
            tokens = batch[0].cuda()
            preds[start:start + batch_size] = torch.sigmoid(net(tokens)).cpu().numpy()

    return preds

# LSTM inference

In [None]:
%%time

# Checkpoints that all take [glove | crawl | w2v | char] embedding columns.
lstm_models = [
    'Notebook_100_5.bin',
    'Notebook_100_6.bin',
    'Notebook_100_7.bin',
    'Notebook_100_8.bin',
    'Notebook_100_9.bin',
    'Notebook_100_10.bin',
    'Notebook_100_11.bin',
    'Notebook_100_12.bin',
]

embedding_matrix = np.concatenate(
    [glove_matrix, crawl_matrix, w2v_matrix, char_matrix],
    axis=1,
)

# One (n_test, 18) prediction array per checkpoint, in list order.
all_lstm_preds = []
for model_name in lstm_models:
    print(model_name)
    preds = get_lstm_preds(model_name, embedding_matrix)
    all_lstm_preds.append(preds)
    gc.collect()  # release the checkpoint before loading the next one

In [None]:
# Sanity check. Every source matrix is built as float32, so their
# concatenation must be float32 too (NeuralNet's embedding weights are float32).
print(embedding_matrix.dtype)
assert embedding_matrix.dtype == np.float32, embedding_matrix.dtype

In [None]:
# lstm_models = ['Notebook_100_13.bin']

# embedding_matrix = np.concatenate([glove_matrix, crawl_matrix, para_matrix, char_matrix], axis=1)

# for model_name in lstm_models:
#     print(model_name)
#     preds = get_lstm_preds(model_name, embedding_matrix)
#     all_lstm_preds.append(preds)
#     gc.collect()

In [None]:
%%time

# Each checkpoint below takes its own embedding column order (presumably the
# order it was trained with), so each gets its own concatenation. This replaces
# four copy-pasted blocks; the order and the appended predictions are unchanged.
_extra_lstm_runs = [
    ('Notebook_100_1.bin', (para_matrix, crawl_matrix, w2v_matrix, char_matrix)),
    ('Notebook_100_2.bin', (glove_matrix, crawl_matrix, w2v_matrix, char_matrix)),
    ('Notebook_100_3.bin', (glove_matrix, para_matrix, w2v_matrix, char_matrix)),
    ('Notebook_100_4.bin', (glove_matrix, para_matrix, crawl_matrix, char_matrix)),
]
lstm_models = [name for name, _ in _extra_lstm_runs]

for model_name, sources in _extra_lstm_runs:
    print(model_name)
    embedding_matrix = np.concatenate(sources, axis=1)
    preds = get_lstm_preds(model_name, embedding_matrix)
    all_lstm_preds.append(preds)
    gc.collect()  # same per-checkpoint cleanup as the first LSTM loop

# Drop the extra references to the source matrices, so that deleting the
# *_matrix variables later actually frees their memory.
del _extra_lstm_runs, sources

In [None]:
def simple_magic(preds):
    """Fold an (n, 18) prediction array into one score per row.

    Column 0 is the main output (the first NeuralNet output column). Columns 1,
    4 and -1 are other heads; their target names are not visible here — TODO
    document. Returns p0 + 0.05*p1 - 0.05*p[-1] - 0.05*p4.
    """
    main = preds[:, 0]
    return main + preds[:, 1] * 0.05 - preds[:, -1] * 0.05 - preds[:, 4] * 0.05


def sophisticated_magic(preds):
    """Like `simple_magic`, but with 0.03 weights on columns -1 and 4.

    Adds two penalties, 0.05*p14 and 0.05*p10, each scaled by (1 - p0), so they
    bite hardest where the main prediction is low.
    """
    main = preds[:, 0]
    blended = main + preds[:, 1] * 0.05 - preds[:, -1] * 0.03 - preds[:, 4] * 0.03
    not_main = 1 - main
    return blended - preds[:, 14] * not_main * 0.05 - preds[:, 10] * not_main * 0.05

In [None]:
# from scipy import stats

# One row per LSTM checkpoint: its 18 output columns folded into a single score.
all_lstm_preds = np.stack(list(map(sophisticated_magic, all_lstm_preds)))
# preds_lstm = np.median(all_lstm_preds, axis=0)  # MEDIAN ensemble
# preds_lstm = stats.trim_mean(all_lstm_preds, 0.1, axis=0) # 10%-trimmed-mean

In [None]:
# Per-model scores keyed by model name. The LSTM ensemble enters as the mean
# over checkpoints (rows of all_lstm_preds).
STORE = {}
STORE['lstm'] = np.mean(all_lstm_preds, axis=0)

In [None]:
# Free everything from the LSTM stage before loading the transformer models.
# char_matrix used to be left alive even though nothing below uses it.
del crawl_matrix, glove_matrix, para_matrix, w2v_matrix, char_matrix, embedding_matrix
del test_lengths, maxlen, x_test_padded, test_collator
del lstm_models, all_lstm_preds
del word_dict, word_index
gc.collect()

# BERT inference

In [None]:
from pytorch_pretrained_bert import BertTokenizer, BertConfig, BertForSequenceClassification

UNCASED_BERT_MODEL_PATH = "../input/transformer-tokenizers/bert-base-uncased/"
CASED_BERT_MODEL_PATH = "../input/transformer-tokenizers/bert-base-cased/"


def clip_to_max_len(batch):
    """Collate (token_ids, length) pairs into one id tensor cut to the batch's longest length.

    Sequences are assumed right-padded, so only the first `max(length)` columns
    are kept. The lengths themselves are not returned.
    """
    token_ids, lengths = (torch.stack(column) for column in zip(*batch))
    return token_ids[:, : int(lengths.max())]

def prepare_bert(bin_file, is_cased):
    """Load a fine-tuned 18-output BertForSequenceClassification for GPU inference.

    The architecture config comes from the cased or uncased tokenizer directory.
    Weights from `bin_file` are loaded on CPU; the model is then frozen,
    switched to eval mode and moved to CUDA.
    """
    config_dir = CASED_BERT_MODEL_PATH if is_cased else UNCASED_BERT_MODEL_PATH
    model = BertForSequenceClassification(
        BertConfig(os.path.join(config_dir, "bert_config.json")),
        num_labels=18,
    )
    model.load_state_dict(torch.load(bin_file, map_location="cpu"))
    for param in model.parameters():
        param.requires_grad = False
    return model.eval().cuda()

def prepare_tokenizer(is_cased):
    """Return the BERT WordPiece tokenizer matching the cased or uncased checkpoints.

    Lower-casing is enabled exactly when the uncased vocabulary is used.
    """
    vocab_dir = CASED_BERT_MODEL_PATH if is_cased else UNCASED_BERT_MODEL_PATH
    return BertTokenizer.from_pretrained(vocab_dir, do_lower_case=not is_cased)
    
    
def apply_bert(model, loader, device="cuda"):
    """Return sigmoid(model(batch)) for every batch of `loader`, stacked in loader order.

    Args:
        model: callable taking (token_ids, attention_mask=...) and returning
            (batch, 18) logits.
        loader: DataLoader yielding id batches where 0 means padding. Rows are
            written in iteration order using `loader.batch_size`.
        device: where each batch is moved before the forward pass. Defaults to
            CUDA, as before.
    Returns:
        float32 array (len(loader.dataset), 18).
    """
    preds = np.zeros((len(loader.dataset), 18), dtype=np.float32)
    # Inference only. no_grad skips autograd bookkeeping regardless of the
    # parameters' requires_grad flags (get_lstm_preds does the same).
    with torch.no_grad():
        for i, X in enumerate(loader):
            X = X.to(device)
            p = torch.sigmoid(model(X, attention_mask=(X > 0)))
            start = i * loader.batch_size
            preds[start:start + loader.batch_size] = p.cpu().numpy()
    return preds

## UNCASED

In [None]:
%%time
tokenizer = prepare_tokenizer(is_cased=False)
MAX_LEN = 400 - 2  # leaves room for [CLS] and [SEP] in a 400-token window


def convert_line(text):
    """WordPiece-encode `text` as [CLS] + first MAX_LEN tokens + [SEP], right-padded with 0."""
    tokens_a = tokenizer.tokenize(text)[:MAX_LEN]
    one_token = tokenizer.convert_tokens_to_ids(["[CLS]"] + tokens_a + ["[SEP]"])
    one_token += [0] * (MAX_LEN - len(tokens_a))
    return one_token


with multiprocessing.Pool(processes=2) as pool:
    sequences = pool.map(convert_line, test.comment_text)
sequences = np.array(sequences)

# Length = index of the first padding id (0). Rows with no padding are full width.
lengths = np.argmax(sequences == 0, axis=1)
lengths[lengths == 0] = sequences.shape[1]

# Run inference in length order so each batch is clipped tightly...
ids = lengths.argsort(kind="stable")
# ...then undo that order. `ids` is a permutation, so its argsort is its
# inverse. The previous `test.id.values[ids].argsort()` equalled the inverse
# only when test.csv is sorted ascending by id; otherwise predictions were
# silently written to the wrong rows.
inverse_ids = ids.argsort(kind="stable")

sequences = torch.from_numpy(sequences)
lengths = torch.from_numpy(lengths)

test_dataset = data.TensorDataset(sequences, lengths)
test_loader = data.DataLoader(data.Subset(test_dataset, ids), batch_size=128, collate_fn=clip_to_max_len)

In [None]:
%%time

UNCASED_BERT_MODELS = [
    "BERT_exp_BERT_3_decay_5_epoch_1",
    "final-pipe1-raw_1",
    "final-pipe1-raw_2",
    "final-pipe1-raw_3",
    "final-pipe2-raw_1",
    "final-pipe2-raw_2",
    "final-pipe2-raw_3",
    "final-pipe4-wiki_raw_1",
    "final-pipe4-wiki_raw_2",
    "final-pipe4-wiki_raw_3",
]

# (checkpoint directory, checkpoint names). STORE is filled in this order,
# which fixes the DataFrame column order later on.
_uncased_sources = [
    ("../input/toxic-models-zoo", UNCASED_BERT_MODELS),
    ("../input/jigsaw-solution-ver-1", ["final-pipe2-raw_4", "BERT_exp_BERT_3_epoch_1"]),
]
for folder, names in _uncased_sources:
    for path in names:
        model = prepare_bert(f"{folder}/{path}.bin", is_cased=False)
        preds = apply_bert(model, test_loader)
        # Predictions come back in length-sorted order; inverse_ids restores test order.
        STORE[path] = sophisticated_magic(preds)[inverse_ids]
        # Release this model before the next one is built and moved to the GPU.
        # Otherwise the old model stays referenced and two BERTs are resident
        # at the peak.
        del model

## CASED

In [None]:
%%time
tokenizer = prepare_tokenizer(is_cased=True)


def convert_line(text):
    """Same encoding as the uncased cell, now using the cased `tokenizer` global.

    [CLS] + first MAX_LEN WordPiece tokens + [SEP], right-padded with 0.
    MAX_LEN is the value set in the uncased cell.
    """
    tokens_a = tokenizer.tokenize(text)[:MAX_LEN]
    one_token = tokenizer.convert_tokens_to_ids(["[CLS]"] + tokens_a + ["[SEP]"])
    one_token += [0] * (MAX_LEN - len(tokens_a))
    return one_token


with multiprocessing.Pool(processes=2) as pool:
    sequences = pool.map(convert_line, test.comment_text)
sequences = np.array(sequences)

# Length = index of the first padding id (0). Rows with no padding are full width.
lengths = np.argmax(sequences == 0, axis=1)
lengths[lengths == 0] = sequences.shape[1]

# Run inference in length order; `ids` is a permutation, so its argsort is its
# inverse. The previous `test.id.values[ids].argsort()` was only correct when
# test.csv is sorted ascending by id.
ids = lengths.argsort(kind="stable")
inverse_ids = ids.argsort(kind="stable")

sequences = torch.from_numpy(sequences)
lengths = torch.from_numpy(lengths)

test_dataset = data.TensorDataset(sequences, lengths)
test_loader = data.DataLoader(data.Subset(test_dataset, ids), batch_size=128, collate_fn=clip_to_max_len)

In [None]:
%%time

CASED_BERT_MODELS = [
    "BERT_exp_BERT_3_cased_decay_epoch_1",
    "final-pipe2-cased_1",
    "final-pipe2-cased_2",
    "final-pipe2-cased_3",
#     "final-pipe2-cased_4",
#     "final-pipe2-cased_5",
#     "final-pipe2-cased_6",
    "final-pipe5-wiki_cased_1",
    "final-pipe5-wiki_cased_2",
    "final-pipe5-wiki_cased_3",
]

# (checkpoint directory, checkpoint names). STORE is filled in this order.
_cased_sources = [
    ("../input/toxic-models-zoo", CASED_BERT_MODELS),
    ("../input/jigsaw-solution-ver-1", ["final-pipe2-cased_10"]),
]
for folder, names in _cased_sources:
    for path in names:
        model = prepare_bert(f"{folder}/{path}.bin", is_cased=True)
        preds = apply_bert(model, test_loader)
        # Predictions come back in length-sorted order; inverse_ids restores test order.
        STORE[path] = sophisticated_magic(preds)[inverse_ids]
        # Release this model before the next one is moved to the GPU, so only
        # one BERT is resident at a time.
        del model

## GPT2CNN

In [None]:
import regex as re
from io import open

@lru_cache()
def bytes_to_unicode():
    """Return GPT-2's reversible byte -> printable-unicode-character table.

    Bytes that are already printable ("!".."~", "¡".."¬", "®".."ÿ") map to
    themselves. The remaining 68 bytes map, in ascending byte order, to chr(256),
    chr(257), ..., so every byte gets a visible and distinct character. The
    result is cached, so all callers share one dict.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {byte: chr(byte) for byte in printable}
    next_code = 2 ** 8
    for byte in range(2 ** 8):
        if byte not in table:
            table[byte] = chr(next_code)
            next_code += 1
    return table

def get_pairs(word):
    """Return the set of adjacent symbol pairs in `word` (a tuple of symbols or a str).

    Example: ("l", "o", "w") -> {("l", "o"), ("o", "w")}. Words with fewer than
    two symbols give an empty set. This includes the empty word, which
    previously raised IndexError on `word[0]`.
    """
    return set(zip(word, word[1:]))

class MonkeyPatchedGPT2Tokenizer(object):
    """Minimal GPT-2 byte-level BPE tokenizer that reads local files only.

    Loads `vocab.json` and `merges.txt` from a directory, with no download or
    cache lookup. Text is split with `self.pat`. Each piece is UTF-8 encoded,
    mapped byte-by-byte through `bytes_to_unicode()`, then merged with the
    ranked BPE merges. `self.pat` uses Unicode property classes that the stdlib
    `re` does not support, so it needs the third-party `regex` module (imported
    as `re` in this cell).
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """Build a tokenizer from `<dir>/vocab.json` and `<dir>/merges.txt`.

        `cache_dir` is ignored. A `max_len` kwarg is capped at 1024, and a
        `special_tokens` kwarg is forwarded to the constructor.
        """
        vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
        merges_file = os.path.join(pretrained_model_name_or_path, "merges.txt")

        max_len = 1024
        kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
        special_tokens = kwargs.pop('special_tokens', [])
        tokenizer = cls(vocab_file, merges_file, special_tokens=special_tokens, *inputs, **kwargs)
        return tokenizer

    def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
        self.max_len = max_len if max_len is not None else int(1e12)
        # The vocab keys are byte-level symbols, including characters >= U+0100
        # from bytes_to_unicode(), so decode as UTF-8 whatever the locale default.
        # `with` also closes the handle, which was previously leaked.
        with open(vocab_file, encoding='utf-8') as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        with open(merges_file, encoding='utf-8') as merges_handle:
            # [1:-1] drops the first line (presumably a version header) and the
            # final element (empty when the file ends with a newline).
            bpe_data = merges_handle.read().split('\n')[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
        # Lower rank = listed earlier = merged first.
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # NOTE: re.IGNORECASE should have been added so that BPE merges can also
        # happen for capitalized versions of contractions.
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

        self.special_tokens = {}
        self.special_tokens_decoder = {}
        self.set_special_tokens(special_tokens)

    def __len__(self):
        """Vocabulary size including special tokens."""
        return len(self.encoder) + len(self.special_tokens)

    def set_special_tokens(self, special_tokens):
        """Assign ids after the base vocabulary to `special_tokens`, in order (falsy clears them)."""
        if not special_tokens:
            self.special_tokens = {}
            self.special_tokens_decoder = {}
            return
        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}

    def bpe(self, token):
        """Apply the ranked BPE merges to `token`; return its final symbols joined by spaces.

        Results are memoized in `self.cache`. Tokens with fewer than two symbols
        are returned unchanged and are not cached.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Merge the lowest-ranked pair present; stop once no remaining pair
            # is in the merge table.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: copy the tail unchanged.
                    # (Previously a bare `except:` that also swallowed e.g.
                    # KeyboardInterrupt.)
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def tokenize(self, text):
        """Split `text` with `self.pat`, byte-encode each piece and BPE it into sub-tokens."""
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.bpe(token).split(' '))
        return bpe_tokens

    def convert_tokens_to_ids(self, tokens):
        """Map one token (str -> int) or a list of tokens (-> list of ints) to ids.

        Special tokens take precedence. Unknown tokens map to 0.
        NOTE(review): convert_line_gpt2 also pads with 0, so check whether 0 is
        a real id in this vocab.json.
        Prints a warning when the list is longer than `self.max_len`.
        """
        ids = []
        if isinstance(tokens, str):
            if tokens in self.special_tokens:
                return self.special_tokens[tokens]
            else:
                return self.encoder.get(tokens, 0)
        for token in tokens:
            if token in self.special_tokens:
                ids.append(self.special_tokens[token])
            else:
                ids.append(self.encoder.get(token, 0))
        if len(ids) > self.max_len:
            print(
                "Token indices sequence length is longer than the specified maximum "
                " sequence length for this OpenAI GPT model ({} > {}). Running this"
                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
            )
        return ids

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """Converts a sequence of ids into BPE tokens using the vocab (KeyError on unknown ids)."""
        tokens = []
        for i in ids:
            if i in self.special_tokens_decoder:
                if not skip_special_tokens:
                    tokens.append(self.special_tokens_decoder[i])
            else:
                tokens.append(self.decoder[i])
        return tokens

    def encode(self, text):
        """Tokenize `text` and return its ids."""
        return self.convert_tokens_to_ids(self.tokenize(text))

    def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True):
        """Turn ids back into text (inverse byte mapping, then UTF-8 decode with `self.errors`).

        With `clean_up_tokenization_spaces`, removes '<unk>' and re-attaches
        punctuation and contractions that had been split off by spaces.
        """
        text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        if clean_up_tokenization_spaces:
            text = text.replace('<unk>', '')
            text = text.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
        return text


In [None]:
from pytorch_pretrained_bert import GPT2Config, GPT2Model
from pytorch_pretrained_bert.modeling_gpt2 import GPT2PreTrainedModel

MAX_LEN = 250  # GPT-2 inputs are truncated / padded to 250 BPE tokens (replaces BERT's MAX_LEN)
GPT2_TOKENIZER = MonkeyPatchedGPT2Tokenizer.from_pretrained("../input/transformer-tokenizers/gpt2/")


def convert_line_gpt2(text):
    """Encode `text` as its first MAX_LEN GPT-2 token ids, right-padded with 0 to MAX_LEN."""
    token_ids = GPT2_TOKENIZER.convert_tokens_to_ids(GPT2_TOKENIZER.tokenize(text)[:MAX_LEN])
    return token_ids + [0] * (MAX_LEN - len(token_ids))


def apply_gpt(model, loader, device="cuda"):
    """Return sigmoid(model(batch)) for every batch of `loader`, stacked in loader order.

    Args:
        model: callable mapping a (batch, seq) id tensor to (batch, 18) logits.
        loader: DataLoader yielding id batches. Rows are written in iteration
            order using `loader.batch_size`.
        device: where each batch is moved before the forward pass. Defaults to
            CUDA, as before.
    Returns:
        float32 array (len(loader.dataset), 18).
    """
    preds = np.zeros((len(loader.dataset), 18), dtype=np.float32)
    # Inference only; matches apply_bert / get_lstm_preds.
    with torch.no_grad():
        for i, X in enumerate(loader):
            p = torch.sigmoid(model(X.to(device)))
            start = i * loader.batch_size
            preds[start:start + loader.batch_size] = p.cpu().numpy()
    return preds


class GPT2CNN(GPT2PreTrainedModel):
    """GPT-2 encoder followed by a two-layer 1-D CNN head with max-over-positions pooling.

    Hidden states with 768 channels (hard-coded; presumably config.n_embd for
    the default GPT2Config()) go through Conv1d(768->256, k=3) + ReLU, then
    Conv1d(256->num_labels, k=3). The maximum over positions gives one logit
    per label.
    """

    def __init__(self, config, num_labels):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.cnn1 = nn.Conv1d(768, 256, kernel_size=3, padding=1)
        self.cnn2 = nn.Conv1d(256, num_labels, kernel_size=3, padding=1)
        self.apply(self.init_weights)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
        """Return (batch, num_labels) logits; `lm_labels` is accepted but unused."""
        hidden, _ = self.transformer(input_ids, position_ids, token_type_ids, past)
        # Conv1d expects (batch, channels, positions).
        features = F.relu(self.cnn1(hidden.transpose(1, 2)))
        logits_per_position = self.cnn2(features)
        return logits_per_position.max(dim=2)[0]

In [None]:
%%time
with multiprocessing.Pool(processes=2) as pool:
    sequences = np.array(pool.map(convert_line_gpt2, test.comment_text))

# Real length = position just after the last non-zero id (rows are right-padded
# with 0). The previous "first zero" rule was wrong here: 0 can also be a real
# token id (convert_tokens_to_ids maps unknown tokens to 0, and NOTE(review):
# in the standard GPT-2 vocab.json id 0 is itself a token). A comment
# containing it got a too-short length and could have its tail clipped by
# clip_to_max_len. Remaining ambiguity: a genuine trailing id-0 token cannot be
# told apart from padding and may still be trimmed.
non_pad = sequences != 0
lengths = sequences.shape[1] - np.argmax(non_pad[:, ::-1], axis=1)
# All-zero rows would otherwise count as full width; give them the minimal
# length so they never widen a batch.
lengths[~non_pad.any(axis=1)] = 1
del non_pad

sequences = torch.from_numpy(sequences)
lengths = torch.from_numpy(lengths)

# No length sorting here, so predictions come back directly in test order.
test_dataset = data.TensorDataset(sequences, lengths)
test_loader = data.DataLoader(test_dataset, batch_size=64, collate_fn=clip_to_max_len)

In [None]:
%%time

GPT2CNN_MODELS = [
    "BERT_exp_GPT2_CNN_epoch_1",
    "BERT_exp_GPT2_CNN_seed_epoch_1",
#     "final-pipe6-gpt_wiki_1",
]

for path in GPT2CNN_MODELS:
    model = GPT2CNN(GPT2Config(), num_labels=18)
    model.load_state_dict(
        torch.load(f"../input/toxic-models-zoo/{path}.bin", map_location="cpu")
    )
    for p in model.parameters():
        p.requires_grad = False
    model = model.eval()
    model = model.cuda()
    preds = apply_gpt(model, test_loader)
    # test_loader is not length-sorted, so no reordering is needed.
    STORE[path] = sophisticated_magic(preds)
    # Release this model before the next one is built and moved to the GPU,
    # so two GPT-2 models are never resident at the same time.
    del model

# MERGING

In [None]:
# One column per model (STORE insertion order), one row per test comment.
df = pd.DataFrame.from_dict(STORE)

In [None]:
df.head(n=5)

In [None]:
df.corr(method="pearson")

In [None]:
df.corr(method="kendall")

In [None]:
# Blend weight per STORE column. The terms are summed in this order (same as the
# original expression). Commented-out entries are models left out of the blend.
ENSEMBLE_WEIGHTS = {
    ### LSTM
    "lstm": 4.5,

    ### BERT
    # best 90%-model
    "BERT_exp_BERT_3_decay_5_epoch_1": 1,
    "BERT_exp_BERT_3_cased_decay_epoch_1": 1,

    # GPT-CNN
    "BERT_exp_GPT2_CNN_epoch_1": 2,
    "BERT_exp_GPT2_CNN_seed_epoch_1": 2,
    # "final-pipe6-gpt_wiki_1": 2,

    # fine-tuned
    "final-pipe1-raw_1": 1,
    "final-pipe1-raw_2": 1,
    "final-pipe1-raw_3": 1,
    # base-uncased
    "final-pipe2-raw_1": 1,
    "final-pipe2-raw_2": 1,
    "final-pipe2-raw_3": 1,
    # cased (note: 4, 5, 6 don't really improve results)
    "final-pipe2-cased_1": 1,
    "final-pipe2-cased_2": 1,
    "final-pipe2-cased_3": 1,
    # "final-pipe2-cased_4": 1,
    # "final-pipe2-cased_5": 1,
    # "final-pipe2-cased_6": 1,
    # Wiki-ft
    "final-pipe4-wiki_raw_1": 1,
    "final-pipe4-wiki_raw_2": 1,
    "final-pipe4-wiki_raw_3": 1,
    # Wiki-cased
    "final-pipe5-wiki_cased_1": 1,
    "final-pipe5-wiki_cased_2": 1,
    "final-pipe5-wiki_cased_3": 1,

    "final-pipe2-raw_4": 1,
    "final-pipe2-cased_10": 1,
    "BERT_exp_BERT_3_epoch_1": 1,
}

final_preds = sum(weight * df[name] for name, weight in ENSEMBLE_WEIGHTS.items())

In [None]:
# One row per test comment. final_preds carries df's default RangeIndex, which
# lines up with test's default RangeIndex (both follow test-set row order).
submission = test[['id']].assign(prediction=final_preds)

In [None]:
submission.head(n=5)

In [None]:
submission.to_csv(path_or_buf='submission.csv', index=False)

In [None]:
# Peak GPU memory on device 0 over the whole run, in MiB.
# `max_memory_cached` was renamed to `max_memory_reserved` in newer PyTorch
# (the old name is deprecated); fall back to it on older versions.
_max_reserved = getattr(torch.cuda, "max_memory_reserved", None) or torch.cuda.max_memory_cached
print(f"peak allocated MiB: {torch.cuda.max_memory_allocated(0) / 1024 / 1024}")
print(f"peak reserved  MiB: {_max_reserved(0) / 1024 / 1024}")