# **RADI623: Natural Language Processing**

### Assignment: Natural Language Processing
**Romen Samuel Rodis Wabina** <br>
Student, PhD Data Science in Healthcare and Clinical Informatics <br>
Clinical Epidemiology and Biostatistics, Faculty of Medicine (Ramathibodi Hospital) <br>
Mahidol University

Note: In case of Python Markdown errors, you may access the assignment through this GitHub [Link](https://github.com/rrwabina/RADI605/tree/main)

## **Medical Specialty Identification**

Misdiagnosis through self-diagnosis is a real problem in medicine. According to a report by the [Telegraph](https://www.telegraph.co.uk/news/health/news/11760658/One-in-four-self-diagnose-on-the-internet-instead-of-visiting-the-doctor.html), nearly one in four people self-diagnose on the internet instead of visiting a doctor, and nearly half of those who self-diagnose get their illness wrong, as [reported](https://bigthink.com/health/self-diagnosis/) by Big Think. There could be several root causes, but one may be a general unwillingness or inability to seek professional help.

Eleven percent of the respondents surveyed, for example, could not get an appointment in time. As a result, crucial time is lost during the screening phase of medical treatment, and the early diagnosis that could have led to earlier treatment is missed.

Knowing which medical specialty to focus on lets a patient receive targeted help much sooner by consulting specialist doctors. To reduce waiting times, we can use natural language processing (NLP) to predict which medical specialty a patient should consult.

Given a medical transcript or a description of the patient's condition, this solution predicts the medical specialty in which the patient should seek help. Ideally, given a sufficiently comprehensive transcript (and dataset), one could even predict the specific illness the patient is suffering from.

In [87]:
import numpy  as np
import pandas as pd
import matplotlib.pyplot as plt
import spacy
import re
import logging
import random
import os
import warnings
warnings.filterwarnings('ignore')

from spacy.lang.en.stop_words import STOP_WORDS
from spacy.pipeline.tagger import Tagger
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.preprocessing import LabelEncoder
from datasets import load_dataset
from imblearn.over_sampling import RandomOverSampler
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.stem import WordNetLemmatizer

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import AdamW   # transformers.AdamW is deprecated in favour of torch.optim.AdamW
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, TensorDataset, DataLoader, random_split, RandomSampler, SequentialSampler

import transformers
from transformers import (AutoTokenizer, BertConfig, BertForSequenceClassification, BertModel,
                          BertTokenizer, get_linear_schedule_with_warmup)
from tokenizers import (
    decoders,
    models,
    normalizers,
    pre_tokenizers,
    processors,
    trainers,
    Tokenizer)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

nlp = spacy.load('en_core_web_sm')

In [2]:
data = pd.read_csv('../data/mtsamples.csv')

num_samples = len(data)
num_medical_specialties = data['medical_specialty'].nunique()

def calculate_univariate(data):
    description_lengths   = data['description'].str.len()
    transcription_lengths = data['transcription'].str.len()

    avg_description_length = description_lengths.mean()
    min_description_length = description_lengths.min()
    max_description_length = description_lengths.max()

    avg_transcription_length = transcription_lengths.mean()
    min_transcription_length = transcription_lengths.min()
    max_transcription_length = transcription_lengths.max()

    dictionary = {}
    dictionary['description']   = [avg_description_length, min_description_length, max_description_length]
    dictionary['transcription'] = [avg_transcription_length, min_transcription_length, max_transcription_length]
    return dictionary
summary = calculate_univariate(data)

def plot_classes(data):
    specialty_counts = data['medical_specialty'].value_counts()

    plt.figure(figsize = (10, 5))
    plt.bar(specialty_counts.index, specialty_counts.values)
    plt.xlabel('Medical Specialty')
    plt.ylabel('Frequency')
    plt.title('Distribution of Medical Specialties')
    plt.xticks(rotation = 90)
    plt.show()

def plot_histogram(data):
    description_lengths   = data['description'].str.len()
    transcription_lengths = data['transcription'].str.len()

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].hist(description_lengths, bins=50, alpha=0.8)
    axs[0].set_xlabel('Description Length')
    axs[0].set_ylabel('Frequency')
    axs[0].set_title('Histogram of Description Lengths')

    axs[1].hist(transcription_lengths, bins = 50, alpha = 0.8)
    axs[1].set_xlabel('Transcription Length')
    axs[1].set_ylabel('Frequency')
    axs[1].set_title('Histogram of Transcription Lengths')
    plt.tight_layout()
    plt.show()

In [7]:
data.isnull().sum(axis = 0)

description             0
medical_specialty       0
sample_name             0
transcription           0
keywords             1068
dtype: int64

In [6]:
data['transcription'] = data['transcription'].fillna(data['description'])  # avoids chained inplace assignment (pandas Copy-on-Write)
len(data['medical_specialty'].unique())

40

In [8]:
def get_sentence_word_count(text_list):
    # Returns the total number of sentences and the number of unique (lower-cased) word tokens
    sent_count = 0
    vocab = set()
    for text in text_list:
        sentences = sent_tokenize(str(text).lower())
        sent_count += len(sentences)
        for sentence in sentences:
            vocab.update(word_tokenize(sentence))
    return sent_count, len(vocab)

clinical_text_df = data[data['transcription'].notna()]
sent_count, word_count = get_sentence_word_count(clinical_text_df['transcription'].tolist())

print('Number of sentences in transcriptions column: '    + str(sent_count))
print('Number of unique words in transcriptions column: ' + str(word_count))

Number of sentences in transcriptions column: 140259
Number of unique words in transcriptions column: 35807


In [9]:
data_categories  = clinical_text_df.groupby(clinical_text_df['medical_specialty'])
filtered_data_categories = data_categories.filter(lambda x:x.shape[0] > 50)
final_data_categories = filtered_data_categories.groupby(filtered_data_categories['medical_specialty'])
sent_count, word_count = get_sentence_word_count(filtered_data_categories['transcription'].tolist())

print('Number of sentences in transcriptions column: '    + str(sent_count))
print('Number of unique words in transcriptions column: ' + str(word_count))

Number of sentences in transcriptions column: 130924
Number of unique words in transcriptions column: 35090


In [10]:
reduced_df = filtered_data_categories[['description', 'medical_specialty', 'sample_name', 'transcription', 'keywords']]
reduced_df = reduced_df.drop(reduced_df[reduced_df['transcription'].isna()].index)
reduced_df.to_csv('../data/mtsamples_modified.csv')

The <code>preprocessing</code> function takes a sentence, strips hyperlinks and other non-alphanumeric characters with <code>remove_hyperlinks</code>, filters out unwanted tokens (stop words, symbols, punctuation marks and whitespace), lemmatizes the remaining tokens to their base forms, and returns the cleaned sentence as a string. Stop words are detected with spaCy's <code>token.is_stop</code> flag. The condition <code>token.pos_ != 'SYM' and token.pos_ != 'PUNCT' and token.pos_ != 'SPACE'</code> checks that the token's part-of-speech (POS) tag is not 'SYM' (symbol), 'PUNCT' (punctuation) or 'SPACE' (whitespace), so such tokens are dropped. For every token that passes these filters, the lowercase lemma (base form), obtained from <code>token.lemma_</code>, is appended to the <code>cleaned_tokens</code> list.
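The regular expression in `remove_hyperlinks` does more than remove links: its second alternative, `[^0-9A-Za-z \t]`, replaces every character that is not a letter, digit, space or tab. The sketch below reuses the same pattern (written as a raw string) and, as an added step that is not part of the notebook, collapses the leftover spaces so the effect is easy to see. The sample sentences are made up.

```python
import re

# Same alternatives as remove_hyperlinks(): @handles | any character that is not
# alphanumeric/space/tab | scheme://URLs | a leading "rt" | "http" up to a double quote
PATTERN = r'(@[A-Za-z0-9]+)|([^0-9A-Za-z \t])|(\w+:\/\/\S+)|^rt|http.+?"'

def remove_hyperlinks(sentence):
    cleaned = re.sub(PATTERN, ' ', sentence)
    # Added step (not in the notebook): collapse the runs of spaces left by re.sub
    return ' '.join(cleaned.split())

print(remove_hyperlinks('See https://mtsamples.com for 5 mg/dL glucose.'))
# See for 5 mg dL glucose
```

Units such as `mg/dL`, blood pressures such as `120/80` and decimals such as `36.5` are split into separate tokens, which may or may not be desirable for clinical notes.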

In [11]:
def set_seed(seed):
    if seed:
        logging.info(f'Running in deterministic mode with seed {seed}')
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)
        random.seed(seed)
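        # PYTHONHASHSEED is read at interpreter start-up, so this only affects newly launched Python processes
        os.environ['PYTHONHASHSEED'] = str(seed)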
    else:
        logging.info('Running in non-deterministic mode')
set_seed(2023)

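def remove_hyperlinks(sentence):
    # Replaces @handles, URLs, a leading "rt", and every character that is not a letter,
    # digit, space or tab with a space (so punctuation is removed as well as links).
    # A raw string avoids invalid-escape warnings for \w, \/ and \S.
    sentence = re.sub(
        r'(@[A-Za-z0-9]+)|([^0-9A-Za-z \t])|(\w+:\/\/\S+)|^rt|http.+?"', ' ', sentence)
    return sentence

def preprocessing(sentence):
    sentence = remove_hyperlinks(sentence)
    doc = nlp(sentence)
    cleaned_tokens = []
    for token in doc:
        # Keep tokens that are not stop words, symbols, punctuation or whitespace
        if not token.is_stop and \
            token.pos_ != 'SYM' and \
            token.pos_ != 'PUNCT' and token.pos_ != 'SPACE':
            cleaned_tokens.append(token.lemma_.lower().strip())
    return ' '.join(cleaned_tokens)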

def align_labels_with_tokens(labels, word_ids):
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            new_labels.append(-100)
        else:
            label = labels[word_id]
            if label % 2 == 1:
                label += 1
            new_labels.append(label)
    return new_labels

def tokenize_and_align_labels(tokenizer, examples):
    tokenized_inputs = tokenizer(examples['tokens'], 
                                 truncation = True, 
                                 is_split_into_words = True)
    all_labels = examples['ner_tags']
    new_labels = []
    for i, labels in enumerate(all_labels):
        word_ids = tokenized_inputs.word_ids(i)
        new_labels.append(align_labels_with_tokens(labels, word_ids))
    tokenized_inputs['labels'] = new_labels
    return tokenized_inputs

def to_tokens(tokenizer, sentence):
    inputs = tokenizer(sentence)
    return tokenizer.convert_ids_to_tokens(inputs.input_ids)

def load_preprocessing(path = '../data/mtsamples_modified.csv', n_rows = 50):
    df = pd.read_csv(path)
    # Only the first n_rows rows are kept (50 by default, as in the original run)
    df = df.iloc[:n_rows, :].copy()
    for i, row in df.iterrows():
        df.at[i, 'description']       = preprocessing(row['description'])
        df.at[i, 'medical_specialty'] = preprocessing(row['medical_specialty'])
        df.at[i, 'sample_name']       = preprocessing(row['sample_name'])
        # np.NaN was removed in NumPy 2.0; np.nan works in every version
        df.at[i, 'transcription'] = preprocessing(row['transcription']) if not pd.isnull(row['transcription']) else np.nan
        df.at[i, 'keywords']      = preprocessing(row['keywords'])      if not pd.isnull(row['keywords'])      else np.nan
    return df

def split_data(df):
    shuffle = df.sample(frac = 1, random_state = 42)

    train_data,  test_data = train_test_split(shuffle,    test_size = 0.30, random_state = 42)
    train_data, valid_data = train_test_split(train_data, test_size = 0.15, random_state = 42) 

    train_data.to_csv('../data/train.csv', index = False)
    valid_data.to_csv('../data/valid.csv', index = False)
    test_data.to_csv('../data/test.csv',   index = False)

    data_files = {
        'train': '../data/train.csv',
        'valid': '../data/valid.csv',
        'test' : '../data/test.csv'}
    dataset = load_dataset('csv', data_files = data_files, streaming = True)
    return dataset 

def compute_review_length(example):
    return {'review_length': len(example['transcription'].split())}

def bert_tokenizer(df, use_special):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    input_ids, attention_masks = [], []

    if use_special:
        for index, row in df.iterrows():
            # encode_plus accepts at most a text pair, so the description is the first segment
            # and the remaining free-text fields are joined into the second segment.
            # medical_specialty is the prediction target, so it is not fed to the model.
            second_segment = ' '.join(str(row[col]) for col in ['sample_name', 'transcription', 'keywords']
                                      if not pd.isnull(row[col]))
            encoded_dict = tokenizer.encode_plus(
                str(row['description']),
                second_segment or None,
                add_special_tokens = True,
                max_length = 512,
                padding = 'max_length',
                truncation = True,
                return_attention_mask = True,
                return_tensors = 'pt')
            input_ids.append(encoded_dict['input_ids'])
            attention_masks.append(encoded_dict['attention_mask'])

    else:
        for description in df['description']:
            encoded_dict = tokenizer.encode_plus(
                description,
                add_special_tokens = True,
                max_length = 512,
                padding = 'max_length',
                truncation = True,
                return_attention_mask = True,
                return_tensors = 'pt')
            input_ids.append(encoded_dict['input_ids'])
            attention_masks.append(encoded_dict['attention_mask'])

    input_ids = torch.cat(input_ids, dim = 0)
    attention_masks = torch.cat(attention_masks, dim = 0)
    return input_ids, attention_masks

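def process(examples):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # The CSV splits written by split_data have no 'sentence' column; the text lives in 'transcription'
    tokenized_inputs = tokenizer(
        examples['transcription'], truncation = True, max_length = 512
    )
    return tokenized_inputs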

def count_parameters(model):
    params = [p.numel() for p in model.parameters() if p.requires_grad]
    for item in params:
        print(f'{item:>6}')
    print(f'______\n{sum(params):>6}')
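`align_labels_with_tokens` and `tokenize_and_align_labels` come from a token-classification (NER) workflow: they read an `ner_tags` column and are not called by the specialty classifier in this notebook. The helper expands word-level labels to sub-word tokens. Special tokens (word id `None`) get `-100`, the default `ignore_index` of PyTorch's cross-entropy loss, and continuation pieces of a word turn a `B-` label into the matching `I-` label. That last step assumes an IOB scheme in which `B-` ids are odd and each `I-` id is its `B-` id plus one, as in CoNLL-2003 (`O=0, B-PER=1, I-PER=2, B-ORG=3, I-ORG=4, B-LOC=5, ...`). The `word_ids` list below is hand-written for illustration, not produced by a real tokenizer.

```python
def align_labels_with_tokens(labels, word_ids):
    # Same logic as the notebook's helper, with comments added
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # First sub-token of a new word, or a special token after a word
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            # Special token such as [CLS]/[SEP]: ignored by the loss
            new_labels.append(-100)
        else:
            # Continuation sub-token: B-X (odd id) becomes I-X (id + 1)
            label = labels[word_id]
            if label % 2 == 1:
                label += 1
            new_labels.append(label)
    return new_labels

# Hypothetical words ["Ramathibodi", "Hospital", "Bangkok"] tagged B-ORG, I-ORG, B-LOC
word_labels = [3, 4, 5]
# Hand-written word_ids: [CLS], two pieces of word 0, word 1, two pieces of word 2, [SEP]
word_ids = [None, 0, 0, 1, 2, 2, None]
print(align_labels_with_tokens(word_labels, word_ids))
# [-100, 3, 4, 4, 5, 6, -100]
```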

In [12]:
df = load_preprocessing()
dataset = split_data(df)
next(iter(dataset['train']))
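`split_data` applies two nested splits, so the overall proportions are 70% × 85% = 59.5% train, 70% × 15% = 10.5% validation and 30% test. The row counts for the 50-row subset kept by `load_preprocessing` can be checked by hand. The sketch below assumes scikit-learn's rule for a float `test_size` (the test count is rounded up and the training count is the remainder); `nested_split_sizes` is a hypothetical helper, not part of the notebook.

```python
import math

def nested_split_sizes(n_rows, test_size = 0.30, valid_size = 0.15):
    # First split: test rows are ceil(test_size * n); the rest stays for training
    n_test = math.ceil(test_size * n_rows)
    n_rest = n_rows - n_test
    # Second split is carved out of the remaining training rows
    n_valid = math.ceil(valid_size * n_rest)
    n_train = n_rest - n_valid
    return n_train, n_valid, n_test

print(nested_split_sizes(50))   # (29, 6, 15)
```

With only 6 validation and 15 test rows spread over many specialties, most classes appear at most once in these splits, so the evaluation scores will be very noisy.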

In [13]:
text_column = df['transcription'].astype('str')

label_encoder = LabelEncoder()
labels = label_encoder.fit_transform(df['medical_specialty'])
labels = torch.tensor(labels)

num_classes = len(label_encoder.classes_)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
encoded_inputs = tokenizer.batch_encode_plus(
                    text_column.tolist(),
                    max_length = 512,
                    padding = 'max_length',
                    truncation = True,
                    return_attention_mask = True,
                    return_tensors = 'pt')

input_ids = encoded_inputs['input_ids']
attention_mask = encoded_inputs['attention_mask']
token_type_ids = encoded_inputs['token_type_ids']

In [61]:
class NLPDATASET(Dataset):
    def __init__(self, sequences, labels):
        self.sequences = sequences
        self.labels = labels
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, index):
        sequence = self.sequences[index]
        label = self.labels[index]
        return sequence, label
    
def GENERATE_DATALOADER(input_ids, attention_mask, labels, batch_size = 64, use_sampler = True):
    # The inputs are already tensors, so they are wrapped directly
    # (torch.tensor() on an existing tensor makes a copy and emits a warning)
    dataset = TensorDataset(input_ids, attention_mask, labels)

    # Split first (60/20/20) so that oversampled duplicates cannot leak into validation or test
    train_size = int(0.6 * len(dataset))
    valid_size = int(0.2 * len(dataset))
    tests_size = len(dataset) - train_size - valid_size
    train_dataset, valid_dataset, test_dataset = random_split(dataset, [train_size, valid_size, tests_size])

    if use_sampler:
        # Oversample minority classes within the training split only
        train_idx = train_dataset.indices
        seq_len   = input_ids.shape[1]
        X = np.concatenate((input_ids[train_idx].numpy(), attention_mask[train_idx].numpy()), axis = -1)
        y = labels[train_idx].numpy()

        oversampler = RandomOverSampler(random_state = 42)
        X_resampled, y_resampled = oversampler.fit_resample(X, y)

        train_dataset = TensorDataset(torch.as_tensor(X_resampled[:, :seq_len]),
                                      torch.as_tensor(X_resampled[:, seq_len:]),
                                      torch.as_tensor(y_resampled))

    train_dataloader = DataLoader(
        train_dataset,
        sampler = RandomSampler(train_dataset),
        batch_size = batch_size)
    validation_dataloader = DataLoader(
        valid_dataset,
        sampler = SequentialSampler(valid_dataset),
        batch_size = batch_size)
    test_dataloader = DataLoader(
        test_dataset,
        sampler = SequentialSampler(test_dataset),
        batch_size = batch_size)
    return train_dataloader, validation_dataloader, test_dataloader

train_loader, valid_loader, test_loader = GENERATE_DATALOADER(input_ids, attention_mask, labels, use_sampler = True)
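Random oversampling duplicates rows of minority classes until every class matches the majority class, which is what the default `'auto'` strategy of `RandomOverSampler` does. If it runs before the train/validation/test split, copies of the same transcript can end up on both sides of the split and inflate the test scores, so it belongs on the training split only. The NumPy sketch below illustrates that order on hypothetical toy arrays; `oversample` is an illustrative helper, not the imbalanced-learn implementation.

```python
import numpy as np

def oversample(X, y, seed = 42):
    """Duplicate randomly chosen rows of every non-majority class (sampling with
    replacement) until each class has as many rows as the majority class."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts = True)
    target = counts.max()
    keep = [np.arange(len(y))]                      # all original rows are kept
    for cls, count in zip(classes, counts):
        if count < target:
            idx = np.flatnonzero(y == cls)
            keep.append(rng.choice(idx, size = target - count, replace = True))
    keep = np.concatenate(keep)
    return X[keep], y[keep]

# Hypothetical toy data: 10 rows of "token ids" with imbalanced labels
X = np.arange(20).reshape(10, 2)
y = np.array([0, 1, 0, 2, 0, 0, 1, 0, 2, 0])

# Split first (here the last 3 rows are simply held out), then oversample the training part only
X_train, y_train = X[:7], y[:7]
X_test,  y_test  = X[7:], y[7:]
X_res, y_res = oversample(X_train, y_train)
print(np.bincount(y_res))   # [4 4 4]
```

Because the held-out rows never enter `oversample`, no duplicate of a test row can appear in the training data.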

In [62]:
class LSTMClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSTMClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, inputs):
        _, (hidden, _) = self.lstm(inputs)
        hidden = hidden.squeeze(0)  
        output = self.fc(hidden)
        return output
    
class BasicClassifier(nn.Module):
    def __init__(self, in_features, hidden_size, out_features):
        super(BasicClassifier, self).__init__()
        self.fc1 = torch.nn.Linear(in_features, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, 32)
        self.fc3 = torch.nn.Linear(32, out_features)
                
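    def forward(self, inputs):
        x = F.relu(self.fc1(inputs.squeeze(1)))
        x = F.relu(self.fc2(x))
        # Return raw logits: nn.CrossEntropyLoss applies log-softmax internally, and a ReLU
        # here would clip negative logits to zero (and block their gradients)
        logits = self.fc3(x)
        return logits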

def BERT_EMBEDDING(input_ids, attention_mask, token_type_ids):
    bert_model = BertModel.from_pretrained('bert-base-uncased')
    bert_model.eval()
    with torch.no_grad():
        outputs = bert_model(input_ids = input_ids, 
                             attention_mask = attention_mask, 
                             token_type_ids = token_type_ids)
        bert_embeddings = outputs.last_hidden_state

    batch_size = bert_embeddings.size(0)
    sequence_length = bert_embeddings.size(1)
    bert_embeddings = bert_embeddings.view(batch_size, sequence_length, -1)
    embeddings  = bert_embeddings.permute(1, 0, 2)
    return bert_model, embeddings

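def LSTM_BASELINE(bert_model, embeddings, num_classes):
    input_size  = bert_model.config.hidden_size
    hidden_size = 50
    # The output size must match the label encoder instead of a hard-coded class count
    lstm_model   = LSTMClassifier(input_size, hidden_size, num_classes)
    # NOTE: the LSTM is freshly initialised and never trained here, so these predictions come
    # from random weights and the evaluation in the next cell does not reflect a trained model
    lstm_output  = lstm_model(embeddings)
    output_probs = nn.functional.softmax(lstm_output, dim = 1)
    _, predicted_labels = torch.max(output_probs, dim = 1)
    return output_probs, predicted_labels

bert_model, embeddings = BERT_EMBEDDING(input_ids, attention_mask, token_type_ids)
output_probs, predicted_labels = LSTM_BASELINE(bert_model, embeddings, num_classes)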

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
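A classifier trained with cross-entropy should return raw logits, because `nn.CrossEntropyLoss` applies log-softmax itself. Applying a ReLU to the logits (for example `F.relu(logits)`) does not turn them into probabilities, and it discards information: all negative logits collapse to zero, so classes that were ranked differently become indistinguishable, and no gradient flows back through the clipped entries. The plain-NumPy sketch below uses made-up logits to show the effect on softmax cross-entropy.

```python
import numpy as np

def softmax_cross_entropy(logits, target):
    # Numerically stable log-softmax, then the negative log-likelihood of the target class
    shifted = logits - logits.max()
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return -log_probs[target]

logits = np.array([-2.0, -0.5, -3.0])   # hypothetical raw scores; class 1 is the best guess
target = 1

loss_raw  = softmax_cross_entropy(logits, target)
loss_relu = softmax_cross_entropy(np.maximum(logits, 0.0), target)   # ReLU applied first

print(round(loss_raw, 4), round(loss_relu, 4))   # 0.2664 1.0986
```

After the ReLU every class gets the same score, so the loss equals $\ln 3$ whichever class is correct, and the model's preference for class 1 is lost.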


In [63]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

def accuracy(preds, y):
    # Fraction of rows whose highest-scoring class equals the true label
    predicted = preds.argmax(dim = 1)
    return (predicted == y).float().mean()

def evaluate_predictions(predictions, labels):
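    # Convert to CPU NumPy arrays so scikit-learn also works when the tensors live on a GPU
    predicted_labels = torch.argmax(predictions, dim = 1).cpu().numpy()
    true_labels = labels.cpu().numpy()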

    accuracy  = accuracy_score(true_labels, 
                               predicted_labels)
    
    precision = precision_score(true_labels, 
                               predicted_labels, 
                               average = 'weighted')
    
    recall = recall_score(true_labels, 
                          predicted_labels, 
                          average = 'weighted')
    
    f1 = f1_score(true_labels, 
                  predicted_labels, 
                  average = 'weighted')
    return {
        'Accuracy':     np.round(accuracy, 4),
        'Precision':    np.round(precision, 4),
        'Recall':       np.round(recall, 4),
        'F1-score':     np.round(f1, 4)}
        
evaluate_predictions(output_probs, labels)

{'Accuracy': 0.0, 'Precision': 0.0, 'Recall': 0.0, 'F1-score': 0.0}
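`evaluate_predictions` reports support-weighted averages: each class's precision, recall or F1 is weighted by its share of the true labels. One consequence is that weighted recall always equals accuracy, since $\sum_c \frac{n_c}{N}\cdot\frac{TP_c}{n_c} = \frac{1}{N}\sum_c TP_c$, so the Recall column adds no information beyond Accuracy. The pure-Python sketch below reproduces the weighted scores on a small made-up example. It follows the definition behind scikit-learn's `average='weighted'` but is not its implementation (for instance, it simply uses 0 where a score is undefined).

```python
from collections import Counter

def weighted_scores(y_true, y_pred):
    labels = sorted(set(y_true) | set(y_pred))
    support = Counter(y_true)                 # number of true samples per class
    n = len(y_true)
    prec = rec = f1 = 0.0
    for c in labels:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        pred_c = sum(p == c for p in y_pred)
        p_c = tp / pred_c if pred_c else 0.0
        r_c = tp / support[c] if support[c] else 0.0
        f_c = 2 * p_c * r_c / (p_c + r_c) if (p_c + r_c) else 0.0
        w = support[c] / n                    # weight = share of true samples in class c
        prec += w * p_c
        rec  += w * r_c
        f1   += w * f_c
    acc = sum(t == p for t, p in zip(y_true, y_pred)) / n
    return acc, prec, rec, f1

y_true = [0, 0, 1, 1, 1, 2]
y_pred = [0, 1, 1, 1, 2, 2]
print([round(v, 4) for v in weighted_scores(y_true, y_pred)])
# [0.6667, 0.75, 0.6667, 0.6667]
```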

In [272]:
def train(model, loader, optimizer, criterion, loader_length):
    epoch_loss = 0
    epoch_acc = 0
    model.train() 
    
    # Batches from GENERATE_DATALOADER are (input_ids, attention_mask, labels); the mask is unused here
    for text, _, label in loader:
        label = label.to(device)
        text  = text.to(device)

        predictions = model(text).squeeze(1) 
        loss = criterion(predictions, label)
        acc  = accuracy(predictions, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()            
    return epoch_loss / loader_length, epoch_acc / loader_length

def evaluate(model, loader, criterion, loader_length):
    epoch_loss = 0
    epoch_acc = 0
    model.eval()
    with torch.no_grad():
        # Batches from GENERATE_DATALOADER are (input_ids, attention_mask, labels); the mask is unused here
        for text, _, label in loader:
            label = label.to(device)
            text  = text.to(device)

            predictions = model(text).squeeze(1) 
            
            loss = criterion(predictions, label)
            acc  = accuracy(predictions, label)

            epoch_loss += loss.item()
            epoch_acc += acc.item()
    return epoch_loss / loader_length, epoch_acc / loader_length
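`train` updates every parameter with `requires_grad=True`, and `count_parameters` lists them tensor by tensor. As a worked check, the LSTM baseline can be counted by hand. A single-layer `nn.LSTM` has four gates, each with an input weight matrix, a recurrent weight matrix and two bias vectors, giving $4\,(h(d+h) + 2h)$ parameters; a linear head adds $hc + c$. With $d = 768$ (the hidden size of `bert-base-uncased`), $h = 50$ and, for example, $c = 20$ classes, this comes to 165,020 trainable parameters. `lstm_params` and `linear_params` are illustrative helpers, not part of the notebook.

```python
def lstm_params(input_size, hidden_size):
    # weight_ih (4h x d) + weight_hh (4h x h) + bias_ih (4h) + bias_hh (4h)
    return 4 * hidden_size * input_size + 4 * hidden_size * hidden_size + 8 * hidden_size

def linear_params(in_features, out_features):
    # weight (out x in) + bias (out)
    return in_features * out_features + out_features

d, h, c = 768, 50, 20
total = lstm_params(d, h) + linear_params(h, c)
print(total)   # 165020
```

Almost all of these sit in the input-to-hidden weights (4 × 50 × 768 = 153,600), because the BERT embeddings are wide compared with the LSTM's hidden state.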

In [29]:
def initialize_custom_vocabs(path = '../data/SNMI.csv'):
    clinical = pd.read_csv(path)
    features = ['Preferred Label', 'Synonyms']
    clinical = clinical[features]
    # Series.append was removed in pandas 2.0; pd.concat stacks the two columns instead
    clinical = pd.concat([clinical['Preferred Label'], clinical['Synonyms']])
    clinical = clinical.dropna()

    # Split terms into words on non-word characters and lower-case them,
    # because is_vocab_word looks up token.lower_
    vocab = clinical.str.lower().str.split(r'\W', expand = True).stack().unique()
    vocab = filter(None, vocab)

    filepath = '../data/vocab.txt'
    with open(filepath, 'w') as file_handler:
        for item in vocab:
            file_handler.write('{}\n'.format(item))

    with open(filepath, 'r') as f:
        vocab_words = set(f.read().splitlines())
    return vocab_words

vocab_words = initialize_custom_vocabs()
def is_vocab_word(token):
    return token.lower_ in vocab_words

execute = True
if execute:
    spacy.tokens.Token.set_extension('is_vocab', getter = is_vocab_word, force = True)

def medical_vocabs(text):
    doc = nlp(text)
    # Keep (text, POS) pairs for tokens in the clinical vocabulary; the custom
    # extension must be read from the spaCy Token, not from its text string
    return [(token.text, token.pos_) for token in doc if token._.is_vocab]

for word in vocab_words:
    nlp.vocab[word]
vocab = nlp.vocab
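In `initialize_custom_vocabs`, every preferred label and synonym is split into words on `\W` (any non-word character). Adjacent separators, as in `attack (disorder)`, and trailing punctuation both leave empty strings behind, which is why the result passes through `filter(None, ...)`. Also note that `is_vocab_word` looks up `token.lower_`, so vocabulary entries must be lower-case for capitalised terms to match. The stdlib sketch below walks through both steps on two made-up terms.

```python
import re

# Hypothetical SNOMED-style terms (the real ones come from SNMI.csv)
terms = ['Myocardial infarction', 'Heart attack (disorder)']

pieces = []
for term in terms:
    # Split on every non-word character, like .str.split(r'\W') in the notebook
    pieces.extend(re.split(r'\W', term))
print(pieces)
# ['Myocardial', 'infarction', 'Heart', 'attack', '', 'disorder', '']

# filter(None, ...) drops the empty strings; lower-casing matches the token.lower_ lookup
vocab_words = {word.lower() for word in filter(None, pieces)}
print(sorted(vocab_words))
# ['attack', 'disorder', 'heart', 'infarction', 'myocardial']
```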

In [39]:
def generate_index2word(vocab_words):
    word2index = {'<PAD>': 0,
                  '<UNK>': 1}
    # Sort so that word ids are identical across runs (set iteration order is not)
    for vo in sorted(vocab_words):
        if word2index.get(vo) is None:
            word2index[vo] = len(word2index)

    index2word = {v:k for k, v in word2index.items()}
    return index2word

In [28]:
def initialize_custom_tagger(path = '../data/clinical-stopwords.txt'):
    with open(path, 'r') as f:
        stop_words = set(f.read().splitlines())
    return stop_words

stop_words = initialize_custom_tagger()
def is_stop_word(token):
    return token.lower_ in stop_words

execute = True
if execute:
    spacy.tokens.Token.set_extension('is_stop', getter = is_stop_word, force = True)

def medical_tagger(text):
    doc = nlp(text)
    for token in doc:
        # Token.is_stop is read-only; the stop-word flag lives on the shared Lexeme, which is writable
        nlp.vocab[token.text].is_stop = token.lower_ in stop_words
    return doc

In [26]:
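from spacy.language import Language

# spacy.pipeline.Tagger wraps a trainable Thinc model, so it cannot wrap plain Python functions.
# In spaCy v3, a rule-based step is registered with @Language.component and must return the Doc.
@Language.component('clinical_stop_tagger')
def clinical_stop_tagger(doc):
    for token in doc:
        # Lexeme.is_stop is writable (Token.is_stop is not)
        doc.vocab[token.text].is_stop = token.lower_ in stop_words
    return doc

# Maps a token attribute to the values that should be excluded
excluded_tokens = {}

use_stop_tagger, use_vocab_tagger = False, False
if use_stop_tagger:
    nlp.add_pipe('clinical_stop_tagger', last = True)
    excluded_tokens['is_stop'] = {True}

if use_vocab_tagger:
    # Token._.is_vocab is a getter extension, so it needs no pipeline component
    excluded_tokens['is_vocab'] = {False}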

In [83]:
dataset = pd.read_csv('../data/mtsamples.csv')
dataset['transcription'] = dataset['transcription'].fillna(dataset['description'])

In [95]:
def get_training_corpus():
    for i in range(0, len(dataset), 1000):
        yield dataset[i : i + 1000]['transcription'].astype(str).tolist()   # a plain list of strings per batch

tokenizer = Tokenizer(models.WordPiece(unk_token = '[UNK]'))
tokenizer.normalizer = normalizers.Sequence([normalizers.NFD(), 
                                             normalizers.Lowercase(), 
                                             normalizers.StripAccents()])

tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

special_tokens = ['[UNK]', '[PAD]', '[CLS]', '[SEP]', '[MASK]']
trainer = trainers.WordPieceTrainer(vocab_size = 25000, special_tokens = special_tokens)
tokenizer.train_from_iterator(get_training_corpus(), trainer = trainer)

cls_token_id = tokenizer.token_to_id('[CLS]')
sep_token_id = tokenizer.token_to_id('[SEP]')

tokenizer.post_processor = processors.TemplateProcessing(
    single = f'[CLS]:0 $A:0 [SEP]:0',
    pair   = f'[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1',
    special_tokens = [('[CLS]', cls_token_id), ('[SEP]', sep_token_id)])
tokenizer.decoder = decoders.WordPiece(prefix = '##')
tokenizer.save('../data/tokenizer.json')
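The tokenizer trained above uses the WordPiece model. At encoding time, WordPiece splits each pre-tokenized word greedily: it takes the longest prefix found in the vocabulary, then repeatedly takes the longest remaining piece written with the `##` continuation prefix, and maps the whole word to `[UNK]` if some part cannot be matched. The pure-Python sketch below shows that procedure on a tiny hypothetical vocabulary. It is not the `tokenizers` implementation, which, for example, also caps the length of words it will split.

```python
def wordpiece_encode(word, vocab, unk_token = '[UNK]', prefix = '##'):
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest candidate first and shrink it until it is in the vocabulary
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = prefix + candidate     # continuation pieces carry the ## prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]                     # one unmatched part makes the whole word unknown
        pieces.append(match)
        start = end
    return pieces

toy_vocab = {'card', 'cardio', '##io', '##my', '##opathy', '##o'}   # hypothetical vocabulary
print(wordpiece_encode('cardiomyopathy', toy_vocab))   # ['cardio', '##my', '##opathy']
print(wordpiece_encode('xyz', toy_vocab))              # ['[UNK]']
```

Training on the MTSamples transcriptions is meant to put frequent clinical word pieces into the vocabulary, so fewer medical terms fall back to `[UNK]` or fragment into many short pieces.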

In [103]:
class GENERATE_NEW_TOKENIZER:
    def __init__(self):
        self.tokenizer = Tokenizer(models.WordPiece(unk_token='[UNK]'))
        self.tokenizer.normalizer = normalizers.Sequence([
            normalizers.NFD(),
            normalizers.Lowercase(),
            normalizers.StripAccents()])
        
        self.tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
        self.special_tokens = ['[UNK]', '[PAD]', '[CLS]', '[SEP]', '[MASK]']
        self.trainer = trainers.WordPieceTrainer(
            vocab_size = 50000,
            special_tokens = self.special_tokens)
        self.cls_token_id  = None
        self.sep_token_id  = None

    def get_training_corpus(self, dataset):
        for i in range(0, len(dataset), 1000):
            yield dataset[i: i + 1000]['transcription'].astype(str).tolist()   # a plain list of strings per batch

    def train_tokenizer(self, dataset):
        self.tokenizer.train_from_iterator(
            self.get_training_corpus(dataset),
            trainer = self.trainer)
        self.cls_token_id = self.tokenizer.token_to_id('[CLS]')
        self.sep_token_id = self.tokenizer.token_to_id('[SEP]')

        self.tokenizer.post_processor = processors.TemplateProcessing(
            single = f'[CLS]:0 $A:0 [SEP]:0',
            pair   = f'[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1',
            special_tokens=[('[CLS]', self.cls_token_id), ('[SEP]', self.sep_token_id)])
        self.tokenizer.decoder = decoders.WordPiece(prefix='##')

    def save_tokenizer(self, filepath):
        self.tokenizer.save(filepath)
        
tokenizer = GENERATE_NEW_TOKENIZER()