In [73]:
import os
import pandas as pd
import numpy as np
import json
import re
import nltk
from nltk.tokenize import sent_tokenize 
from transformers import BertTokenizer, AutoTokenizer
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import transformers
from tqdm import tqdm
import glob

import tensorflow_hub as hub
import tensorflow as tf

from datetime import datetime

import pickle

import warnings
warnings.filterwarnings('ignore')

# Config

In [74]:
# Which environment the notebook runs in; selects the path set below.
platform = 'Azure'
model_name = 'epoch_2_model_azure_roberta_base_cleaned_xtra_label.bin'

# (bert_path, test_path, model_path) for every known platform.
_PLATFORM_PATHS = {
    'Azure': ('/home/thanish/transformer_models/bert_base_uncased',
              '../test/*',
              '../output/'),
    'Kaggle': ('../input/bertlargeuncasedpytorch',
               '/kaggle/input/coleridgeinitiative-show-us-the-data/test/*',
               '../input/coleridgemodels/'),
}

# Any other platform value falls back to the local workstation layout.
_LOCAL_PATHS = ('C:/Users/thanisb/Documents/transformer_models/bert_base_uncased/',
                '../test/*',
                '../output/')

bert_path, test_path, model_path = _PLATFORM_PATHS.get(platform, _LOCAL_PATHS)

# Single settings dictionary read by the rest of the notebook.
config = dict(
    MAX_LEN=512,
    tokenizer=AutoTokenizer.from_pretrained('roberta-base' , do_lower_case=True),
    batch_size=20,
    Epoch=3,
    test_path=test_path,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    model_path=model_path,
    model_name=model_name,
)

In [75]:
def clean_text(txt):
    '''
    Lower-case the string form of `txt` and collapse every run of characters
    other than ASCII letters and digits into a single space.
    '''
    lowered = str(txt).lower()
    return re.sub('[^A-Za-z0-9]+', ' ', lowered)

In [76]:
def data_joining(data_dict_id):
    '''
    Concatenate the 'text' field of every section of one document into a
    single string, separating consecutive sections with '. '.

    Sections are read by integer position (0 .. len - 1); the raw text is
    used as-is, without cleaning.
    '''
    section_texts = (data_dict_id[idx]['text'] for idx in range(len(data_dict_id)))
    return '. '.join(section_texts)

# Reading the test dataset

In [77]:
def read_test_json(test_data_folder):
    '''
    Read every file matched by a glob pattern as JSON.

    Parameters
    ----------
    test_data_folder : str
        Glob pattern selecting the input files (e.g. "../test/*").

    Returns
    -------
    dict
        Maps each file's base name without its extension (the document id)
        to the parsed JSON content of that file.

    Raises
    ------
    json.JSONDecodeError
        If a matched file does not contain valid JSON.
    '''
    # Resolve the pattern once and sort it so the result order is reproducible.
    json_paths = sorted(glob.glob(test_data_folder))
    total_files = len(json_paths)

    test_text_data = {}
    for i, test_json_loc in enumerate(json_paths):
        # os.path handles both "/" and "\\" separators, and splitext strips the
        # actual extension instead of assuming a 5-character ".json" suffix.
        filename = os.path.splitext(os.path.basename(test_json_loc))[0]

        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open(test_json_loc, 'r', encoding='utf-8') as f:
            test_text_data[filename] = json.load(f)

        if (i % 1000) == 0:
            print(f"Completed {i}/{total_files}")

    print("All files read")
    return test_text_data

In [78]:
# Use the platform-specific pattern from the config instead of a hard-coded
# path, so switching `platform` actually changes where the data is read from.
test_data_dict = read_test_json(test_data_folder=config['test_path'])

Completed 0/4
All files read


# Reading the saved model file

In [80]:
# Build the architecture: RoBERTa token classifier with 3 output labels.
model = transformers.RobertaForTokenClassification.from_pretrained('roberta-base',  num_labels = 3)

# Wrap exactly once, BEFORE loading the weights. Presumably the checkpoint was
# saved from a DataParallel-wrapped model (keys prefixed with "module.");
# load_state_dict below is strict, so a mismatch would raise -- TODO confirm.
model = nn.DataParallel(model)

# Locate the trained checkpoint (join tolerates a missing trailing separator).
trained_model_name = os.path.join(config['model_path'], config['model_name'])
print("Trained model checkpoint:", trained_model_name)

# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(trained_model_name, map_location = config['device'])
print("Checkpoint loaded")

# Copy the trained weights into the initialized model.
model.load_state_dict(checkpoint)
print("Model loaded with all keys matching with the checkpoint")

# Move to the target device. The model is already a DataParallel instance, so
# it is NOT wrapped a second time (the original nested DataParallel twice).
model = model.to(config['device'])

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForTokenClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForTokenClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able

Trained model checkpoint: ../output/epoch_2_model_azure_roberta_base_cleaned_xtra_label.bin
Checkpoint loaded
Model loaded with all keys matching with the checkpoint


In [82]:
# Prediction
def prediction_fn(tokenized_sub_sentence):
    '''
    Run the global `model` on one tokenized sentence and return the predicted
    keyword phrases.

    Parameters
    ----------
    tokenized_sub_sentence : list of str
        Tokens as produced by config['tokenizer'].tokenize(...). A leading
        'Ġ' on a token is treated as "this token starts a new word".

    Returns
    -------
    list of str
        One entry per maximal run of consecutive tokens whose predicted label
        is non-zero. Tokens in a run are glued together (a space is inserted
        before tokens starting with 'Ġ'), the 'Ġ' markers are removed and the
        result is stripped of surrounding whitespace.

    Notes
    -----
    Label 0 is treated as "not part of a keyword"; every other label value is
    treated as "part of a keyword". No special tokens are added to the input.
    '''

    tkns = tokenized_sub_sentence
    indexed_tokens = config['tokenizer'].convert_tokens_to_ids(tkns)
    # NOTE(review): this all-zero list is passed below as `attention_mask`.
    # In the Hugging Face convention 0 means "masked"; an all-ones mask would
    # be the usual choice. Confirm this matches how the model was trained
    # before changing it.
    segments_ids = [0] * len(indexed_tokens)

    # Batch of size 1 on the configured device.
    tokens_tensor = torch.tensor([indexed_tokens]).to(config['device'])
    segments_tensors = torch.tensor([segments_ids]).to(config['device'])
    
    model.eval()
    with torch.no_grad():
        logit = model(tokens_tensor, 
                      token_type_ids=None,
                      attention_mask=segments_tensors)

        # First model output, argmax over axis 2, then the single batch item:
        # one predicted label per input token.
        logit_new = logit[0].argmax(2).detach().cpu().numpy().tolist()
        prediction = logit_new[0]

        # `kword` accumulates the keyword currently being built;
        # `kword_list` collects the finished keywords.
        kword = ''
        kword_list = []

        for k, j in enumerate(prediction):
            if (len(prediction)>1):

                if (j!=0) & (k==0):
                    # Keyword starts on the very first token of the sentence.
                    begin = tkns[k]
                    kword = begin

                elif (j!=0) & (k>=1) & (prediction[k-1]==0):
                    # Keyword starts in the middle of the sentence.
                    begin = tkns[k]
                    previous = tkns[k-1]

                    # If the first tagged token is a word continuation (no 'Ġ'),
                    # the preceding (untagged) token is prepended to it.
                    if not begin.startswith('Ġ'):
                        kword = previous + begin[:]
                    else:
                        kword = begin

                    if k == (len(prediction) - 1):
                        # The keyword both starts and ends on the last token.
                        kword_list.append(kword.rstrip().lstrip().replace('Ġ', ''))

                elif (j!=0) & (k>=1) & (prediction[k-1]!=0):
                    # Continuation of the keyword that is already being built.
                    inter = tkns[k]

                    # Sub-word pieces are appended directly; a token starting
                    # with 'Ġ' begins a new word and is separated by a space.
                    if not inter.startswith('Ġ'):
                        kword = kword + "" + inter[:]
                    else:
                        kword = kword + " " + inter


                    if k == (len(prediction) - 1):
                        # The keyword runs up to the last token: flush it here,
                        # because no following label-0 token will do so.
                        kword_list.append(kword.rstrip().lstrip().replace('Ġ', ''))

                elif (j==0) & (k>=1) & (prediction[k-1] !=0):
                    # First label-0 token after a keyword: the keyword is complete.
                    kword_list.append(kword.rstrip().lstrip().replace('Ġ', ''))
                    kword = ''
                    inter = ''
            else:
                # Single-token sentence: the token itself is the keyword if tagged.
                if (j!=0):
                    begin = tkns[k]
                    kword = begin
                    kword_list.append(kword.rstrip().lstrip().replace('Ġ', ''))
    return kword_list


In [83]:
def long_sent_split(text, max_length=64):
    '''
    Split a long input into consecutive chunks of at most `max_length` units.

    Parameters
    ----------
    text : str or sequence of str
        Either a string, which is split on single spaces into words, or an
        already tokenized sequence (e.g. a list of tokens).
    max_length : int, default 64
        Maximum number of words/tokens per chunk.

    Returns
    -------
    list
        For a string input: a list of strings, each the space-joined words of
        one chunk. For a token-sequence input: a list of token lists.

    Raises
    ------
    ValueError
        If `max_length` is smaller than 1.
    '''
    if max_length < 1:
        raise ValueError("max_length must be at least 1")

    # The original only handled strings and raised AttributeError when given
    # a token list; both input kinds are supported here.
    is_token_sequence = not isinstance(text, str)
    units = list(text) if is_token_sequence else text.split(" ")

    final_sent_split = []
    for start in range(0, len(units), max_length):
        window = units[start: start + max_length]
        final_sent_split.append(window if is_token_sequence else " ".join(window))
    return final_sent_split

In [84]:
def get_predictions(data_dict):
    '''
    Predict keyword phrases for every document in `data_dict`.

    For each document the section texts are joined, split into sentences,
    cleaned and tokenized; each tokenized sentence is passed to
    `prediction_fn`. Sentences longer than config['MAX_LEN'] tokens are
    processed in smaller chunks.

    Parameters
    ----------
    data_dict : dict
        Maps a document id to the list of sections of that document (each
        section is expected to provide a 'text' entry).

    Returns
    -------
    dict
        Maps each document id to the list of distinct predicted keyword
        strings, in order of first appearance.
    '''
    results = {}

    for Id in data_dict.keys():
        current_id_predictions = []

        print(Id)
        sentences = data_joining(data_dict[Id])

        for sub_sentence in sent_tokenize(sentences):
            cleaned_sub_sentence = clean_text(sub_sentence)

            # Tokenize the sentence
            tokenized_sub_sentence = config['tokenizer'].tokenize(" " + cleaned_sub_sentence)

            if len(tokenized_sub_sentence) == 0:
                # Nothing to predict on for an empty sentence.
                continue

            if len(tokenized_sub_sentence) <= config['MAX_LEN']:
                # The sentence fits into a single model call.
                current_id_predictions.extend(prediction_fn(tokenized_sub_sentence))
                continue

            # Long sentence: predict chunk by chunk. The tokens are joined with
            # spaces because long_sent_split works on space-separated text, then
            # every chunk is split back into its tokens. This round trip assumes
            # the tokens contain no literal spaces (the tokenizer marks word
            # starts with 'Ġ' instead) -- verify if the tokenizer changes.
            # The original referenced an undefined name (`sent`) here and handed
            # the token list straight to the string splitter.
            joined_tokens = " ".join(tokenized_sub_sentence)
            for chunk_text in long_sent_split(text=joined_tokens):
                chunk_tokens = [tok for tok in chunk_text.split(" ") if tok]
                if chunk_tokens:
                    current_id_predictions.extend(prediction_fn(chunk_tokens))

        # De-duplicate while keeping first-seen order so reruns are reproducible.
        results[Id] = list(dict.fromkeys(current_id_predictions))

    print("All predictions completed")

    return results

In [87]:
def remove_few_word_prediction(prediction_dict):
    '''
    Drop short predictions from every document's prediction list.

    A prediction is kept when it has more than two space-separated words, or
    when it contains one of the protected substrings ('adni', 'cccsl',
    'ibtracs', 'slosh model'). Everything else is discarded. The relative
    order of the kept predictions is unchanged.
    '''
    protected_terms = ('adni', 'cccsl', 'ibtracs', 'slosh model')

    def _is_kept(candidate):
        # Long enough on its own, or explicitly protected despite being short.
        has_enough_words = len(candidate.split(" ")) > 2
        return has_enough_words or any(term in candidate for term in protected_terms)

    return {
        doc_id: [candidate for candidate in candidates if _is_kept(candidate)]
        for doc_id, candidates in prediction_dict.items()
    }

In [1]:
# Run keyword prediction over every loaded test document; the result maps each
# document id to its list of distinct predicted keyword strings.
results_1 = get_predictions(data_dict = test_data_dict)

In [89]:
# Filter the raw predictions (run label in the original notebook:
# "Roberta - epoch 0") and report how many keywords remain per document.
res = remove_few_word_prediction(prediction_dict=results_1)
for kept_keywords in res.values():
    print(len(kept_keywords))
res

5
38
34
16


{'2100032a-7c33-4bff-97ef-690822c43466': ['heart and aging research in genomic epidemiology',
  'helsinki birth cohort',
  'ssgac cognition study',
  'adni',
  'affymetrix genome wide snp'],
 '2f392438-e215-4169-bebf-21ac4ff253e1': ['education secondary education and higher',
  'center for education statistics institute',
  'university diploma of specialized higher studies',
  'ppps exchange rate data',
  'international standard classification of education',
  'current population survey',
  'national education systems ines survey',
  'end of upper secondary education',
  'certificates of higher education',
  'progress in international reading literacy study',
  'international data base',
  'pisa data collection effort',
  'education combined field of study',
  'program for international student assessment',
  'online education database',
  'integrated postsecondary education data system',
  'oecd national accounts database',
  'high school for upper secondary school',
  'development in

In [90]:
# One row per (document id, predicted keyword): build a frame with one list
# of keywords per document, then expand each list into separate rows.
prediction_df = (
    pd.DataFrame({'Id': list(res.keys()),
                  'PredictionString': list(res.values())})
    .explode('PredictionString')
)
prediction_df

Unnamed: 0,Id,PredictionString
0,2100032a-7c33-4bff-97ef-690822c43466,heart and aging research in genomic epidemiology
0,2100032a-7c33-4bff-97ef-690822c43466,helsinki birth cohort
0,2100032a-7c33-4bff-97ef-690822c43466,ssgac cognition study
0,2100032a-7c33-4bff-97ef-690822c43466,adni
0,2100032a-7c33-4bff-97ef-690822c43466,affymetrix genome wide snp
...,...,...
3,8e6996b4-ca08-4c0b-bed2-aaf07a4c6a60,rural urban continuum codes
3,8e6996b4-ca08-4c0b-bed2-aaf07a4c6a60,current retail trends
3,8e6996b4-ca08-4c0b-bed2-aaf07a4c6a60,agriculture and food
3,8e6996b4-ca08-4c0b-bed2-aaf07a4c6a60,food access research atlas
