In [None]:
# # This Python 3 environment comes with many helpful analytics libraries installed
# # It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# # For example, here's several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# # Input data files are available in the read-only "../input/" directory
# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

### This notebook is a continuation of my notebook [Bert for Token Classification - Training](https://www.kaggle.com/thanish/bert-for-token-classification-training).
Please check it out for the training code.

In [None]:
import os
import pandas as pd
import numpy as np
import json
import re
from nltk.tokenize import sent_tokenize 
from transformers import BertTokenizer, AutoTokenizer
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import transformers
from tqdm import tqdm
import glob

import warnings
warnings.filterwarnings('ignore')

# Config 

In [None]:
# Execution environment selector; only the Kaggle directory layout is defined below.
platform = 'Kaggle'
# File name of the fine-tuned checkpoint that is looked up under `model_path`.
model_name = 'epoch_14_model_sage_bert_base_uncased.bin'

if platform == 'Kaggle':
    bert_path = '../input/huggingface-bert/bert-base-uncased'
    train_path = '/kaggle/input/coleridgeinitiative-show-us-the-data/train/*'
    test_path = '/kaggle/input/coleridgeinitiative-show-us-the-data/test/*'
    model_path = '../input/d/thanish/coleridgemodels/'
else:
    # Fail fast with a clear message: without this guard an unknown platform
    # only surfaces further down as a NameError on `bert_path`.
    raise ValueError(f"Unsupported platform {platform!r}: no data/model paths are defined for it")

# Single dictionary holding every setting the rest of the notebook reads.
# NOTE(review): 'MAX_LEN', 'batch_size', 'Epoch' and 'train_path' are not read by
# any inference cell visible in this notebook -- presumably carried over from the
# training notebook; confirm before removing.
config = {'MAX_LEN': 128,
          # WordPiece tokenizer loaded from the local BERT directory (lower-casing enabled).
          'tokenizer': AutoTokenizer.from_pretrained(bert_path, do_lower_case=True),
          'batch_size': 20,
          'Epoch': 10,
          'train_path': train_path,
          # Glob pattern matching the test JSON files.
          'test_path': test_path,
          # Use the GPU when one is visible to torch, otherwise fall back to CPU.
          'device': 'cuda' if torch.cuda.is_available() else 'cpu',
          'bert_path': bert_path,
          'model_path': model_path,
          'model_name': model_name
         }

In [None]:
def clean_text(txt):
    """Normalise arbitrary input to lower-case alphanumeric words.

    The argument is converted with ``str()`` and lower-cased, then every run of
    characters outside ``A-Za-z0-9`` is replaced by a single space. A leading or
    trailing space produced by that replacement is left in place.
    """
    lowered = str(txt).lower()
    normalised = re.sub('[^A-Za-z0-9]+', ' ', lowered)
    return normalised

In [None]:
# def data_joining(data_dict_id):
#     '''
#     This function is to join all the text data from different sections in the json to a single
#     text file. 
#     '''
#     data_length = len(data_dict_id)

#     temp = [data_dict_id[i]['text'] for i in range(data_length)]
# #     temp = [data_dict_id[i]['text'] for i in range(0, 1)]
#     temp = '. '.join(temp)
    
#     return temp


def data_joining(data_dict_id):
    '''
    Join the text of all sections of one publication into a single string.

    Each section's 'text' field is normalised with clean_text(), stripped and
    whitespace-collapsed. Sections with 15 words or fewer are dropped, exact
    duplicates are removed, and the survivors are joined with '. ' so that a
    sentence tokenizer can split them apart again.

    data_dict_id: sequence of section dicts, each with a 'text' key.
    Returns the joined string ('' when no section qualifies).

    Duplicates are removed with dict.fromkeys rather than set(), so sections keep
    their first-seen order and the output is identical from run to run (set order
    of strings changes with hash randomisation).
    '''
    min_words = 15  # a section must have MORE than this many words to be kept

    kept_sections = []
    for section in data_dict_id:
        text = clean_text(section['text']).strip()
        text = re.sub(' +', ' ', text)

        if len(text.split(" ")) > min_words:
            kept_sections.append(text)

    unique_sections = list(dict.fromkeys(kept_sections))
    return '. '.join(unique_sections)

In [None]:
# %%time
# for id in test_data_dict.keys():
#     print(id)
#     print(len(data_joining(test_data_dict[id]).split(" ")))
#     print(len(data_joining_2(test_data_dict[id]).split(" ")))
    
    
# # 8e6996b4-ca08-4c0b-bed2-aaf07a4c6a60
# # 7540
# # 7689
# # 2100032a-7c33-4bff-97ef-690822c43466
# # 3652
# # 3128
# # 2f392438-e215-4169-bebf-21ac4ff253e1
# # 28673
# # 23368
# # 3f316b38-1a24-45a9-8d8c-4e05a42257c6
# # 10671
# # 9868
# # CPU times: user 67.9 ms, sys: 3.32 ms, total: 71.2 ms
# # Wall time: 69.4 ms

# Reading the test data

In [None]:
def read_test_json(test_data_folder):
    '''
    Read every JSON file matched by a glob pattern.

    test_data_folder: glob pattern (e.g. '<dir>/*') selecting the input files.
    Returns a dictionary whose keys are the file names without their extension
    and whose values are the parsed JSON contents.

    Files are processed in sorted path order so the resulting dictionary (and
    everything derived from it) has the same order on every run.
    '''
    # Expand the pattern once; the same list feeds both the count and the loop.
    json_paths = sorted(glob.glob(test_data_folder))
    total_files = len(json_paths)

    test_text_data = {}
    for i, test_json_loc in enumerate(json_paths):
        # basename/splitext work for any path separator and any extension length.
        filename = os.path.splitext(os.path.basename(test_json_loc))[0]

        with open(test_json_loc, 'r', encoding='utf-8') as f:
            test_text_data[filename] = json.load(f)

        if (i % 1000) == 0:
            print(f"Completed {i}/{total_files}")

    print("All files read")
    return test_text_data

In [None]:
# Load every file matched by the test glob pattern into memory, keyed by file name.
test_data_dict = read_test_json(test_data_folder=config['test_path'])

# Loading the saved model

In [None]:
# initializing the model
# Start from the base BERT weights with a 3-label token-classification head; the
# fine-tuned weights are loaded over this below.
model = transformers.BertForTokenClassification.from_pretrained(config['bert_path'], num_labels=3)
# NOTE(review): wrapping in DataParallel before load_state_dict presumably matches
# `module.`-prefixed keys in the checkpoint -- confirm against the training notebook.
model = nn.DataParallel(model)

# Reading the trained checkpoint model
# os.path.join keeps working whether or not 'model_path' ends with a separator.
trained_model_name = os.path.join(config['model_path'], config['model_name'])
print("Trained model checkpoint:", trained_model_name)
checkpoint = torch.load(trained_model_name, map_location=config['device'])
print("Checkpoint loaded")

# Matching the trained checkpoint model to the initialized model
# (load_state_dict is strict by default, so a key mismatch raises here).
model.load_state_dict(checkpoint)
print("Model loaded with all keys matching with the checkpoint")

# DataParallel requires the wrapped module's parameters to live on its first
# device before forward() is called; move it explicitly instead of relying on the
# wrapper's single-GPU convenience behaviour. Inference only, so switch to eval
# mode (disables dropout) once here.
model = model.to(config['device']).eval()

# Prediction function

In [None]:
# Prediction
def _extract_keywords(tokens, labels):
    """Merge runs of consecutively labelled WordPiece tokens into keyword strings.

    tokens: list of WordPiece tokens.
    labels: predicted label id per token; 0 means "not part of a keyword", any
            other value means the token belongs to a keyword.
    Returns the keywords in order of appearance, each stripped of outer whitespace.
    """
    keywords = []
    current = None  # keyword being assembled; None while outside a labelled run

    for idx, (token, label) in enumerate(zip(tokens, labels)):
        if label != 0:
            is_word_piece = token.startswith('##')
            if current is None:
                # A run opens here. If it opens on a continuation piece in the
                # middle of the sequence, glue it onto the preceding token so the
                # keyword starts with a whole word rather than a '##' fragment.
                if is_word_piece and idx > 0:
                    current = tokens[idx - 1] + token[2:]
                else:
                    current = token
            elif is_word_piece:
                # Continuation piece: append without a space.
                current = current + token[2:]
            else:
                # New word inside the same keyword.
                current = current + " " + token
        elif current is not None:
            # First unlabelled token after a run closes the keyword.
            keywords.append(current.strip())
            current = None

    # A run that reaches the end of the sequence is closed here.
    if current is not None:
        keywords.append(current.strip())

    return keywords


def prediction_fn(tokenized_sub_sentence):
    """Run the token classifier on one tokenized sentence and return its keywords.

    tokenized_sub_sentence: list of WordPiece tokens (as produced by the tokenizer
        in ``config``). An empty list yields an empty result without calling the
        model.
    Returns a list of keyword strings built from the tokens whose predicted label
    is non-zero.

    Reads the module-level ``config`` (tokenizer, device) and ``model``.
    """
    tkns = tokenized_sub_sentence
    if len(tkns) == 0:
        return []

    indexed_tokens = config['tokenizer'].convert_tokens_to_ids(tkns)
    tokens_tensor = torch.tensor([indexed_tokens]).to(config['device'])

    # There is no padding in a single-sentence batch, so every token must be
    # attended to: the mask is all ones. (An all-zero mask marks every position as
    # masked.)
    attention_mask = torch.ones_like(tokens_tensor)

    model.eval()
    with torch.no_grad():
        outputs = model(tokens_tensor,
                        token_type_ids=None,
                        attention_mask=attention_mask)

    # outputs[0] holds the logits, indexed (batch, token, label); take the best
    # label per token for the single sentence in the batch.
    prediction = outputs[0].argmax(2)[0].cpu().tolist()

    return _extract_keywords(tkns, prediction)


In [None]:
def long_sent_split(long_tokens, max_length=64):
    '''
    Split a token list into consecutive chunks of at most `max_length` tokens.

    long_tokens: list of tokens to split.
    max_length:  maximum number of tokens per chunk (default 64).
    Returns a list of lists; the last chunk may be shorter, and an empty input
    gives an empty list.
    Raises ValueError if `max_length` is not positive.
    '''
    if max_length <= 0:
        raise ValueError("max_length must be a positive integer")

    return [long_tokens[start: start + max_length]
            for start in range(0, len(long_tokens), max_length)]

In [None]:
def get_predictions(data_dict, max_tokens=512):
    '''
    Predict keywords for every document in `data_dict`.

    data_dict:  mapping of document id -> list of section dicts (the parsed JSON).
    max_tokens: longest token sequence passed to prediction_fn in one call
                (default 512); longer sentences are split with long_sent_split
                and predicted chunk by chunk.
    Returns a dict mapping each id to its list of unique predicted keywords, in
    order of first appearance (stable from run to run, unlike set ordering).
    '''
    results = {}

    for Id, sections in data_dict.items():
        current_id_predictions = []

        sentences = data_joining(sections)

        for sub_sentence in sent_tokenize(sentences):
            cleaned_sub_sentence = clean_text(sub_sentence)

            # Tokenize the sentence
            tokenized_sub_sentence = config['tokenizer'].tokenize(cleaned_sub_sentence)

            if len(tokenized_sub_sentence) == 0:
                # Nothing to predict on.
                continue

            if len(tokenized_sub_sentence) <= max_tokens:
                current_id_predictions.extend(prediction_fn(tokenized_sub_sentence))
            else:
                # Too long for one pass: predict on each chunk and pool the keywords.
                for sent_tok in long_sent_split(tokenized_sub_sentence):
                    if len(sent_tok) != 0:
                        current_id_predictions.extend(prediction_fn(sent_tok))

        # De-duplicate while keeping first-seen order.
        results[Id] = list(dict.fromkeys(current_id_predictions))

    print("All predictions completed")

    return results

In [None]:
%%time
# Run keyword prediction over every loaded test document; the %%time cell magic
# above reports how long the full pass takes.
results = get_predictions(data_dict = test_data_dict)
# results

In [None]:
def remove_few_word_prediction(prediction_dict):
    """Filter out very short predictions.

    prediction_dict: mapping of id -> iterable of predicted keyword strings.
    Returns a new dict with the same ids. A prediction is kept when splitting it
    on single spaces gives more than two parts, or when it contains one of the
    always-kept terms regardless of its length.
    """
    always_keep = ('adni', 'cccsl', 'ibtracs', 'slosh model')

    def _is_kept(pred):
        # Long enough on its own merits.
        if len(pred.split(" ")) > 2:
            return True
        # Short, but mentions a term that is kept whatever the length.
        return any(term in pred for term in always_keep)

    return {
        ID: [pred for pred in predictions if _is_kept(pred)]
        for ID, predictions in prediction_dict.items()
    }

# Drop one- and two-word predictions (unless they contain a term the filter always
# keeps), then display what is left.
results = remove_few_word_prediction(prediction_dict=results)
results

In [None]:
# Build the submission table: one row per id, with that id's predictions collapsed
# into a single "|"-delimited string.
sub_df = pd.DataFrame({
    'Id': list(results),
    'PredictionString': ["|".join(predictions) for predictions in results.values()],
})
sub_df

In [None]:
# Write the submission file to the current working directory; index=False keeps the
# DataFrame index out of the CSV.
sub_df.to_csv("submission.csv", index=False)

# ------------------------------------------- Consider upvoting if you like it :) ------------------------------------------- 