# Jigsaw Toxicity Inference with FNet
## Table of Contents
* [1. Configuration](#1.)
* [2. Setup](#2.)
* [3. Tokenizer](#3.)
* [4. FNet Model](#4.)
* [5. Submission](#5.)
* [6. References](#6.)

<font color="red" size="3">If you found it useful and would like to back me up, just upvote.</font>

This is the inference notebook. For the training notebook, visit [here](https://www.kaggle.com/lonnieqin/jigsaw-toxicity-prediction-with-fnet).

<a id="1."></a>
## 1. Configuration

In [None]:
class Config:
    """Hyperparameters shared by the model definition and the inference pipeline."""

    # Size of the token-id space; used as the token embedding's input dimension.
    vocab_size: int = 15000
    # Number of token ids per example; used as the model input length and as
    # the position embedding's input dimension.
    sequence_length: int = 100
    # Number of examples per batch when building the prediction dataset.
    batch_size: int = 1024
    # Width of the token/position embeddings and of the encoder output.
    embed_dim: int = 256
    # Width of the hidden dense layer inside the FNet encoder.
    latent_dim: int = 256


# Single shared settings object used by the cells below.
config = Config()

<a id="2."></a>
## 2. Setup

In [None]:
import pandas as pd
import tensorflow as tf
import pathlib
import random
import string
import re
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
import os
import sklearn
from sklearn.model_selection import train_test_split
from nltk.tokenize import TweetTokenizer 
from nltk.stem.porter import PorterStemmer
from nltk.stem import WordNetLemmatizer
from scipy.stats import rankdata
import json
import sys
# Raise the interpreter's recursion limit to 100000 frames.
# NOTE(review): nothing visible in this notebook recurses deeply — confirm
# whether this is still required.
sys.setrecursionlimit(100000)

<a id="3."></a>
## 3. Tokenizer

In [None]:
class Tokenizer:
    """Word-level tokenizer: cleans raw text and maps tokens to integer ids.

    Typical usage is ``preprocess_string`` on each raw text to obtain a list
    of tokens, followed by ``fit_transform`` (training) or ``load`` +
    ``transform`` (inference) to turn token lists into id sequences.

    Id ``0`` is used for padding, id ``1`` for the out-of-vocabulary token,
    and the next ids for the optional begin/end-of-sequence tokens. Regular
    words receive ids in order of first appearance during ``fit_transform``.
    """

    # NOTE(review): preprocess_string replaces ASCII apostrophes with spaces
    # before tokenizing, so entries such as "he'd" cannot match a token made of
    # ASCII characters. The list is left untouched so that preprocessing stays
    # identical to what previously saved vocabularies were built with.
    stopwords = set([
        "a", "about", "above", "after", "again", "against", "all", "am", "an",
        "and", "any", "are", "as", "at", "be", "because", "been", "before",
        "being", "below", "between", "both", "but", "by", "could", "did", "do",
        "does", "doing", "down", "during", "each", "few", "for", "from",
        "further", "had", "has", "have", "having", "he", "he'd", "he'll",
        "he's", "her", "here", "here's", "hers", "herself", "him", "himself",
        "his", "how", "how's", "i", "i'd", "i'll", "i'm", "i've", "if", "in",
        "into", "is", "it", "it's", "its", "itself", "let's", "me", "more",
        "most", "my", "myself", "nor", "of", "on", "once", "only", "or",
        "other", "ought", "our", "ours", "ourselves", "out", "over", "own",
        "same", "she", "she'd", "she'll", "she's", "should", "so", "some",
        "such", "than", "that", "that's", "the", "their", "theirs", "them",
        "themselves", "then", "there", "there's", "these", "they", "they'd",
        "they'll", "they're", "they've", "this", "those", "through", "to",
        "too", "under", "until", "up", "very", "was", "we", "we'd", "we'll",
        "we're", "we've", "were", "what", "what's", "when", "when's", "where",
        "where's", "which", "while", "who", "who's", "whom", "why", "why's",
        "with", "would", "you", "you'd", "you'll", "you're", "you've", "your",
        "yours", "yourself", "yourselves",
    ])

    tweet_tokenizer = TweetTokenizer()

    stemmer = PorterStemmer()

    lemmatizer = WordNetLemmatizer()

    # Patterns are compiled once instead of on every preprocess_string call.
    # The punctuation class deliberately leaves "?" and "!" alone.
    _PUNCTUATION_RE = re.compile(r'[\n| |.|\"|,|:|\(|\)|#|\'|\{|\}|\*|\/|\$|\—|~|;|=|\[｜\]|\-]+')
    _DIGITS_RE = re.compile("[0-9]+")
    _SPACES_RE = re.compile("[ ]+")

    def __init__(self, vocab_size = None, oov_token = None, bos_token = None, eos_token = None, max_length = 10000):
        """Store the tokenizer settings; no vocabulary is built here.

        Args:
            vocab_size: Maximum vocabulary size including the special tokens,
                or None to keep every word seen by ``fit_transform``.
            oov_token: Token that stands in for words outside the vocabulary.
            bos_token: Optional token prepended to every sequence.
            eos_token: Optional token appended to every sequence.
            max_length: Length every encoded sequence is padded/truncated to.
        """
        self.vocab_size = vocab_size
        self.oov_token = oov_token
        self.max_length = max_length
        self.bos_token = bos_token
        self.eos_token = eos_token

    @staticmethod
    def preprocess_string(text):
        """Clean one raw string and return its list of normalized tokens.

        Steps, in order: lowercase; replace the listed punctuation with
        spaces ("?" and "!" are kept because they tend to carry emotion);
        replace digit runs with spaces; collapse repeated spaces and trim;
        tokenize; drop stop words; lemmatize and then stem what remains.

        Args:
            text: The raw text (a ``str``).

        Returns:
            A list of processed token strings.
        """
        text = text.lower()
        text = Tokenizer._PUNCTUATION_RE.sub(" ", text)
        text = Tokenizer._DIGITS_RE.sub(" ", text)
        text = Tokenizer._SPACES_RE.sub(" ", text)
        text = text.strip(" ")
        items = Tokenizer.tweet_tokenizer.tokenize(text)
        return [
            Tokenizer.stemmer.stem(Tokenizer.lemmatizer.lemmatize(item))
            for item in items
            if item not in Tokenizer.stopwords
        ]

    def _pad_or_truncate(self, sentence, word_index):
        """Append the EOS id (if configured) and fit ``sentence`` to ``max_length``.

        The sequence is cut so that the EOS id still fits, then right-padded
        with zeros up to ``max_length``.
        """
        has_eos = self.eos_token is not None
        body_length = self.max_length - (1 if has_eos else 0)
        sentence = sentence[:body_length]
        if has_eos:
            sentence.append(word_index[self.eos_token])
        sentence += [0] * (self.max_length - len(sentence))
        return sentence

    def fit_transform(self, texts):
        """Build the vocabulary from ``texts`` and return their id sequences.

        Sets ``self.vocab``, ``self.word_index`` and ``self.index_word``.

        Args:
            texts: A re-iterable collection (e.g. a list) of token lists.

        Returns:
            A list of id lists, each padded/truncated to ``max_length``.
        """
        # Reserve the lowest ids for the special tokens; 0 is the padding id.
        current_index = 1
        word_index = {self.oov_token: current_index}
        if self.bos_token is not None:
            current_index += 1
            word_index[self.bos_token] = current_index
        if self.eos_token is not None:
            current_index += 1
            word_index[self.eos_token] = current_index

        word_count = {}
        for text in texts:
            for item in text:
                word_count[item] = word_count.get(item, 0) + 1

        # Most frequent first. sorted() is stable, so words with equal counts
        # keep their first-seen order and the vocabulary is deterministic.
        ranked = sorted(word_count, key=word_count.get, reverse=True)
        if self.vocab_size is None:
            kept = ranked
        else:
            # Clamp at zero so a vocab_size smaller than the number of special
            # tokens keeps no regular words instead of slicing from the end.
            kept = ranked[:max(self.vocab_size - len(word_index), 0)]
        self.vocab = set(word_index) | set(kept)

        sentences = []
        for text in texts:
            sentence = []
            if self.bos_token is not None:
                sentence.append(word_index[self.bos_token])
            for item in text:
                if item in self.vocab:
                    if item not in word_index:
                        # Ids are handed out in order of first appearance.
                        current_index += 1
                        word_index[item] = current_index
                    sentence.append(word_index[item])
                else:
                    sentence.append(word_index[self.oov_token])
            sentences.append(self._pad_or_truncate(sentence, word_index))
        self.word_index = word_index
        self.index_word = {index: word for word, index in word_index.items()}
        return sentences

    def save(self, path):
        """Write the tokenizer settings and vocabulary to ``path`` as JSON.

        ``bos_token`` / ``eos_token`` are only written when they are set.
        Requires a vocabulary built by ``fit_transform`` or ``load``.
        """
        dic = {
            "vocab_size": self.vocab_size,
            "oov_token": self.oov_token,
            "max_length":  self.max_length,
            "vocab": list(self.vocab),
            "index_word": self.index_word,
            "word_index": self.word_index
        }

        if self.bos_token is not None:
            dic["bos_token"] = self.bos_token

        if self.eos_token is not None:
            dic["eos_token"] = self.eos_token

        # Serialize before opening the file so a serialization error does not
        # leave a truncated file behind.
        res = json.dumps(dic)

        with open(path, "w", encoding="utf-8") as f:
            f.write(res)

    def load(self, path):
        """Restore settings and vocabulary from a JSON file written by ``save``.

        Args:
            path: Location of the JSON file.
        """
        with open(path, "r", encoding="utf-8") as f:
            dic = json.load(f)
        self.vocab_size = dic["vocab_size"]
        self.oov_token = dic["oov_token"]
        self.max_length = dic["max_length"]
        self.vocab = set(dic["vocab"])
        # JSON object keys are always strings; convert them back to the
        # integer ids that fit_transform produces.
        self.index_word = {int(index): word for index, word in dic["index_word"].items()}
        self.word_index = dic["word_index"]
        # save() omits these keys when the tokens are unset, so a missing key
        # means "no token" rather than "keep whatever this instance had".
        self.bos_token = dic.get("bos_token")
        self.eos_token = dic.get("eos_token")

    def transform(self, texts):
        """Encode token lists with the existing vocabulary.

        Args:
            texts: An iterable of token lists.

        Returns:
            A list of id lists, each padded/truncated to ``max_length``;
            tokens outside the vocabulary map to the OOV id.
        """
        sentences = []
        for text in texts:
            sentence = []
            if self.bos_token is not None:
                sentence.append(self.word_index[self.bos_token])
            for item in text:
                if item in self.vocab:
                    sentence.append(self.word_index[item])
                else:
                    sentence.append(self.word_index[self.oov_token])
            sentences.append(self._pad_or_truncate(sentence, self.word_index))
        return sentences

In [None]:
# Create an empty tokenizer and restore its settings and vocabulary from the
# JSON file (presumably produced by the training notebook — confirm).
tokenizer = Tokenizer()
tokenizer.load("../input/jigsaw-toxicity-fnet/tokenizer.json")

<a id="4."></a>
## 4. FNet Model

### 4.1 FNet Encoder

In [None]:
class FNetEncoder(layers.Layer):
    """FNet encoder block: Fourier mixing followed by a feed-forward network.

    The input is mixed by adding the real part of its 2D FFT, normalized,
    passed through a two-layer dense projection with a residual connection,
    normalized again and finally passed through dropout.

    Args:
        embed_dim: Width of the block's input and output features.
        dense_dim: Width of the hidden (ReLU) dense layer.
        dropout_rate: Rate of the dropout applied to the block output.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, embed_dim, dense_dim, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        # NOTE(review): keep the creation order of the sub-layers below; saved
        # weights are presumably matched by layer order — confirm before
        # reordering.
        feed_forward_layers = [
            layers.Dense(dense_dim, activation="relu"),
            layers.Dense(embed_dim),
        ]
        self.dense_proj = keras.Sequential(feed_forward_layers)
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.dropout = layers.Dropout(dropout_rate)

    def call(self, inputs):
        """Apply the block to ``inputs`` and return a tensor of the same shape.

        Assumes ``inputs`` is a real-valued tensor whose last two axes are
        (sequence, features) — TODO confirm against callers.
        """
        # fft2d needs complex input; only the real part of the result is kept.
        spectrum = tf.signal.fft2d(tf.cast(inputs, tf.complex64))
        mixed = self.layernorm_1(inputs + tf.math.real(spectrum))
        # Feed-forward sub-block with a residual connection around it.
        projected = self.dense_proj(mixed)
        normalized = self.layernorm_2(mixed + projected)
        return self.dropout(normalized)

<a id="4.2"></a>
### 4.2 Positional Embedding

In [None]:
class PositionalEmbedding(layers.Layer):
    """Sum of a learned token embedding and a learned position embedding.

    Args:
        sequence_length: Number of positions the position embedding covers.
        vocab_size: Number of token ids the token embedding covers.
        embed_dim: Width of both embeddings.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        # NOTE(review): token embedding is created before the position
        # embedding; keep this order so saved weights (presumably matched by
        # order) still line up — confirm before reordering.
        self.token_embeddings = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.position_embeddings = layers.Embedding(input_dim=sequence_length, output_dim=embed_dim)

    def call(self, inputs):
        """Embed token ids and add the embedding of each position index."""
        token_vectors = self.token_embeddings(inputs)
        # Positions 0 .. L-1, where L is the size of the last input axis.
        position_ids = tf.range(start=0, limit=tf.shape(inputs)[-1], delta=1)
        position_vectors = self.position_embeddings(position_ids)
        return token_vectors + position_vectors

    def compute_mask(self, inputs, mask=None):
        """Return a boolean mask that is False wherever the token id is 0."""
        is_real_token = tf.math.not_equal(inputs, 0)
        return is_real_token


<a id="4.3"></a>
### 4.3 FNet Classification Model

In [None]:
def get_fnet_classifier(config):
    """Build the FNet-based binary classifier.

    Architecture: token + position embedding, one FNet encoder block, global
    average pooling over the sequence axis, dropout, three
    Dense(100, relu) + Dropout(0.3) stages, and a single sigmoid output unit.

    Args:
        config: Object providing ``sequence_length``, ``vocab_size``,
            ``embed_dim`` and ``latent_dim``.

    Returns:
        An uncompiled ``keras.Model`` named "fnet" that maps int64 id
        sequences of length ``config.sequence_length`` to one score each.
    """
    # The shape must be a tuple: "(n)" is just the integer n, "(n,)" is the
    # one-element tuple that keras.Input documents.
    inputs = keras.Input(shape=(config.sequence_length,), dtype="int64", name="encoder_inputs")
    x = PositionalEmbedding(config.sequence_length, config.vocab_size, config.embed_dim)(inputs)
    x = FNetEncoder(config.embed_dim, config.latent_dim)(x)
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dropout(0.3)(x)
    for _ in range(3):
        x = layers.Dense(100, activation="relu")(x)
        x = layers.Dropout(0.3)(x)
    output = layers.Dense(1, activation="sigmoid")(x)
    fnet = keras.Model(inputs, output, name="fnet")
    return fnet

In [None]:
# Build the classifier (freshly initialized weights) from the shared settings.
fnet = get_fnet_classifier(config)

In [None]:
# Load saved weights into the model built above (presumably produced by the
# training notebook — confirm); the architecture must match the saved file.
fnet.load_weights("../input/jigsaw-toxicity-fnet/model_latest.h5")

In [None]:
# Print the model's layer-by-layer summary.
fnet.summary()

Let's visualize the Model.

In [None]:
# Draw the layer graph including tensor shapes.
# NOTE(review): plot_model relies on pydot and graphviz being installed —
# confirm they are available in the runtime.
keras.utils.plot_model(fnet, show_shapes=True)

<a id="5."></a>
## 5. Submission

In [None]:
# Read the comments to score and the submission template.
test = pd.read_csv("/kaggle/input/jigsaw-toxic-severity-rating/comments_to_score.csv")
sample_submission = pd.read_csv("/kaggle/input/jigsaw-toxic-severity-rating/sample_submission.csv")

# Clean and tokenize every comment, then encode the tokens as padded id lists.
test["text_preprocess"] = test["text"].apply(Tokenizer.preprocess_string)
test_sequences = tokenizer.transform(test["text_preprocess"].tolist())
print(test_sequences[0])

# Batch the encoded comments for prediction.
test_ds = tf.data.Dataset.from_tensor_slices(test_sequences)
test_ds = test_ds.batch(config.batch_size)
test_ds = test_ds.prefetch(1)

# One score per comment, flattened to a 1-D array.
score = fnet.predict(test_ds).ravel()

# Submit ordinal ranks of the scores (distinct ranks, ties broken by position)
# rather than the raw scores. Rows of the template are assumed to be in the
# same order as the rows of the test file — TODO confirm.
sample_submission["score"] = rankdata(score, method='ordinal')
sample_submission.to_csv("submission.csv", index=False)
sample_submission.head()


<a id="6."></a>
## 6. References
- [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824v3)
- [Attention Is All You Need](https://arxiv.org/abs/1706.03762v5)
- [Text Generation using FNet](https://keras.io/examples/nlp/text_generation_fnet/)
- [English-Spanish Translation: FNet](https://www.kaggle.com/lonnieqin/english-spanish-translation-fnet)