In [None]:
!mkdir TRAIN_DATA
#!pip install matplotlib scikit-learn pandas scipy setuptools wheel spacy[cuda110,transformers,lookups] ipython && pip install jupyter --upgrade
import pandas as pd
import numpy as np
from IPython.display import display
import re

from itertools import chain
from sklearn.model_selection import train_test_split

import spacy
from spacy import displacy
from spacy.tokens import DocBin
import json
from tqdm import tqdm

def divide_chunks(l, n):
    """Yield successive slices of ``l``, each holding at most ``n`` items.

    The last slice is shorter when ``len(l)`` is not a multiple of ``n``;
    an empty ``l`` yields nothing.
    """
    yield from (l[offset:offset + n] for offset in range(0, len(l), n))


def process_feature_text(text):
    """Rewrite hyphenated fragments of ``text`` without changing its length.

    Applied in this order: ``I-year`` becomes ``1-year``, ``-OR-`` becomes
    `` or ``, then every remaining hyphen becomes a single space.  Each
    replacement is as long as the text it replaces, so character offsets
    into ``text`` stay valid afterwards.
    """
    substitutions = (
        ('I-year', '1-year'),
        ('-OR-', " or "),
        ('-', ' '),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text


def clean_spaces(txt):
    """Replace every newline, tab and carriage return in ``txt`` with a space.

    The substitution is one character for one character: runs of whitespace
    are not collapsed, so the length of ``txt`` is unchanged.
    """
    for whitespace_char in ('\n', '\t', '\r'):
        txt = re.sub(whitespace_char, ' ', txt)
    return txt

class spacy_prep:
    """Turn character-span annotations over a note corpus into spaCy DocBins.

    ``start_prep`` flattens the annotations into one row per span,
    ``training_prep`` groups those rows per note and splits them into
    training/validation sets, and ``create_training`` /
    ``create_validation`` pack one of the two sets into a ``DocBin``.

    Attributes (stored exactly as passed in):
        feature_desc: 2-D array; column 0 is read as the feature id and
            column 1 as the feature text, which becomes the entity label.
        location_desc: rows read as ``(feature id, location string, note id)``.
        note_corpus: note texts, read through ``.loc[note id]`` and
            ``.values`` (presumably a pandas Series -- confirm with caller).
        nlp: blank English spaCy pipeline, used here only for ``make_doc``.
    """

    def __init__(self, feature_desc, location_desc, note_corpus):
        self.feature_desc = feature_desc
        self.location_desc = location_desc
        self.note_corpus = note_corpus
        self.nlp = spacy.blank('en')

    def start_prep(self):
        """Flatten the annotations into one row per annotated span.

        A location string other than the literal ``'[]'`` is parsed by
        dropping its first and last character and reading every whole
        integer in order; consecutive integers are paired as
        ``(start, stop)``.

        Returns:
            list of ``[note_text, [start, stop], label]`` rows, ordered by
            ``feature_desc`` and then by annotation order.  ``note_text`` is
            the corpus note after ``process_feature_text`` and
            ``clean_spaces``.
        """
        location_dict = {label: [] for label in self.feature_desc[:, 1]}

        # feature id -> labels (in feature_desc order), so each annotation
        # row is matched by lookup instead of rescanning every feature.
        labels_by_feature = {}
        for feat in self.feature_desc:
            labels_by_feature.setdefault(feat[0], []).append(feat[1])

        for entry in self.location_desc:
            labels = labels_by_feature.get(entry[0])
            if not labels or entry[1] == '[]':
                continue
            # Strip the enclosing brackets, then collect the integers in
            # order of appearance and pair them up.
            offsets = [int(tok) for tok in re.findall(r'\b\d+\b', entry[1][1:-1])]
            indices = list(divide_chunks(offsets, 2))  # Paired chunks
            note_num = entry[2]
            for label in labels:
                location_dict[label].append((note_num, indices))

        rels = []
        processed_notes = {}  # note id -> normalised text, computed once
        for feat in self.feature_desc:
            label = feat[1]
            for note_num, indexes in location_dict[label]:
                if not indexes:
                    continue
                if note_num not in processed_notes:
                    raw_note = self.note_corpus.loc[note_num]
                    processed_notes[note_num] = clean_spaces(process_feature_text(raw_note))
                my_note = processed_notes[note_num]
                for pair in indexes:
                    rels.append([my_note, [pair[0], pair[1]], label])
        return rels

    def training_prep(self):
        """Group spans per note, split 80/20 and dump the split to JSON.

        Every corpus note that has at least one span becomes an example
        ``[text, {"entities": [(start, stop, label), ...]}]``.  Examples are
        split with ``train_test_split(test_size=0.2, random_state=42)`` and
        the result is written to ``clin_data.json`` in the working directory.

        Returns:
            dict with keys ``'TRAINING_DATA'`` and ``'VALIDATION_DATA'``.
        """
        # Group the span rows by note text in a single pass; the lookup below
        # is then one dictionary access per note rather than a scan of every
        # span row for every note.
        spans_by_text = {}
        for text, (start, stop), label in self.start_prep():
            spans_by_text.setdefault(text, []).append((start, stop, label))

        # NOTE(review): start_prep runs process_feature_text on the note
        # text, whereas the corpus values looked up here are used as stored.
        # A note that process_feature_text alters (e.g. one containing '-')
        # therefore never matches and contributes no example.  The matching
        # rule is kept unchanged so the produced dataset is the same;
        # confirm whether dropping those notes is intended.
        examples = []
        for note in self.note_corpus.values:
            entities = spans_by_text.get(note)
            if entities:
                examples.append([note, {"entities": list(entities)}])

        train_rows, val_rows = train_test_split(examples, test_size=0.2, random_state=42)
        collective_dict = {'TRAINING_DATA': train_rows,
                           'VALIDATION_DATA': val_rows}

        # Serialise before opening the file so a failure leaves any existing
        # file untouched.
        json_string = json.dumps(collective_dict)
        with open('clin_data.json', 'w') as outfile:
            outfile.write(json_string)

        return collective_dict

    def _build_docbin(self, examples):
        """Pack ``[text, {"entities": [...]}]`` examples into a ``DocBin``.

        Spans that do not map onto token boundaries (``char_span`` returns
        ``None`` under ``alignment_mode="contract"``) are skipped with a
        printed notice; spans that cannot be set as entities together with
        the ones already accepted are dropped silently.
        """
        db = DocBin()
        for text, annot in tqdm(examples):
            doc = self.nlp.make_doc(text)
            ents = []

            for start, end, label in annot["entities"]:
                span = doc.char_span(start, end, label=label, alignment_mode="contract")

                # skip if the character indices do not map to a valid span
                if span is None:
                    print("Skipping entity.")
                    continue

                ents.append(span)
                # Trial assignment: spaCy raises ValueError when the new span
                # conflicts with one already accepted, in which case the
                # newcomer is removed again.
                try:
                    doc.ents = ents
                except ValueError:
                    ents.pop()
            doc.ents = ents

            # pack Doc objects into DocBin
            db.add(doc)

        return db

    def create_training(self):
        """Return a ``DocBin`` of the training split.

        Runs ``training_prep`` (which also rewrites ``clin_data.json``).
        """
        return self._build_docbin(self.training_prep()['TRAINING_DATA'])

    def create_validation(self):
        """Return a ``DocBin`` of the validation split.

        Runs ``training_prep`` again; the fixed ``random_state`` makes the
        split identical to the one used by ``create_training`` as long as
        the inputs have not changed.
        """
        return self._build_docbin(self.training_prep()['VALIDATION_DATA'])


In [None]:
def load_data():
    """Build the spaCy training and validation files from the raw CSVs.

    Reads ``features.csv``, ``patient_notes.csv`` and ``train.csv``, prints
    their column indexes, normalises the note and feature texts, and hands
    the resulting arrays to ``spacy_prep``.  The training ``DocBin`` is
    written to ``./TRAIN_DATA/TRAIN_DATA.spacy`` and the validation one to
    ``./TRAIN_DATA/VAL_DATA.spacy``.  Returns ``None``.
    """
    # Load raw data
    features = pd.read_csv('/kaggle/input/nbme-score-clinical-patient-notes/features.csv')
    notes = pd.read_csv('/kaggle/input/nbme-score-clinical-patient-notes/patient_notes.csv')
    train = pd.read_csv('/kaggle/input/nbme-score-clinical-patient-notes/train.csv')
    print("Feature frame columns:\n{}\nNote frame columns:\n{}\nTrain frame columns:\n{}\n\n".format(features.columns, notes.columns, train.columns))

    # Notes: whitespace-normalised text, addressable by note number.
    notes['pn_history'] = notes['pn_history'].apply(clean_spaces)
    note_corpus = notes.set_index('pn_num')['pn_history']

    # Features: one row per distinct text, normalised the same way the
    # labels are later looked up.
    features = features.drop_duplicates('feature_text')
    features['feature_text'] = (
        features['feature_text'].apply(process_feature_text).apply(clean_spaces)
    )

    prepper = spacy_prep(
        features[['feature_num', 'feature_text']].values,
        train[['feature_num', 'location', 'pn_num']].values,
        note_corpus,
    )

    prepper.create_training().to_disk("./TRAIN_DATA/TRAIN_DATA.spacy")
    prepper.create_validation().to_disk("./TRAIN_DATA/VAL_DATA.spacy")

In [None]:
from tqdm import tqdm
def _entity_rows(nlp_output, notes, feature_dict, entry, text_suffix=""):
    """Run the model on the note of one test row and return its result rows.

    Args:
        nlp_output: loaded spaCy pipeline, called on the note text.
        notes: 2-D array; column 0 is read as the note text and column 2 as
            the note number.
        feature_dict: label -> feature number.  The ``'Male'`` and
            ``'17-year'`` entries are overwritten in place on every call.
        entry: ``(case_num, pn_num, feature_num)`` for the test row.
        text_suffix: appended to the note text before inference when
            non-empty.

    Returns:
        One ``[case_num, pn_num, feature_num, ["start end"]]`` row per
        predicted entity whose label maps to ``feature_num``, or a single
        row carrying ``['-1 -1']`` when there is none.

    Raises:
        IndexError: if no note has the requested number (or if the pipeline
            itself raises it).
        KeyError: if a predicted label is missing from ``feature_dict``.
    """
    rel_case_num, rel_note_num, rel_feat_num = entry[0], entry[1], entry[2]

    # These labels are shared by two feature numbers; point each label at
    # whichever number the current row could match.
    # NOTE(review): the key written here is '17-year' (hyphen), while
    # `tester` in this file registers "17 year" (space) and feature texts go
    # through process_feature_text, which turns hyphens into spaces.  If the
    # model emits the spaced label this toggle never affects a lookup --
    # confirm which spelling is correct.
    feature_dict['Male'] = 601 if rel_feat_num == 601 else 11
    feature_dict['17-year'] = 602 if rel_feat_num == 602 else 12

    rel_note_row = np.where(notes[:, 2] == rel_note_num)
    rel_note = notes[rel_note_row][0][0]
    if text_suffix:
        rel_note = rel_note + text_suffix
    rel_doc = nlp_output(rel_note)

    rows = [
        [rel_case_num, rel_note_num, rel_feat_num,
         [str(ent.start_char) + " " + str(ent.end_char)]]
        for ent in rel_doc.ents
        if feature_dict[ent.label_] == rel_feat_num
    ]
    if not rows:
        # Placeholder span for "feature not found in this note".
        rows.append([rel_case_num, rel_note_num, rel_feat_num, ['-1 -1']])
    return rows


def prep_sub(notes, feature_dict, test_csv):
    """Predict entity spans for every row of ``test_csv``.

    Prints ``feature_dict``, requires a GPU, loads the pipeline from
    ``../input/med-models/output_rob/model-best`` and runs it once per test
    row (see ``_entity_rows`` for the row format).

    Args:
        notes: 2-D array of ``(note text, case_num, pn_num)`` rows.
        feature_dict: label -> feature number; mutated in place.
        test_csv: DataFrame with ``case_num``, ``pn_num`` and
            ``feature_num`` columns.

    Returns:
        Flat list of result rows in test-row order; empty when ``test_csv``
        has no rows.
    """
    print(feature_dict)
    test_info = test_csv[['case_num', 'pn_num', 'feature_num']].values
    spacy.require_gpu()
    nlp_output = spacy.load("../input/med-models/output_rob/model-best")

    # Accumulate once, outside the loop, so the result exists even for an
    # empty test set and is not rebuilt on every iteration.
    fin_ents = []
    for entry in tqdm(test_info):
        try:
            rows = _entity_rows(nlp_output, notes, feature_dict, entry)
        except IndexError:
            # Retry the same row with " err" appended to the note text
            # (presumably works around a failure of the pipeline on some
            # inputs -- TODO confirm).  A second IndexError propagates.
            rows = _entity_rows(nlp_output, notes, feature_dict, entry,
                                text_suffix=" err")
        fin_ents.extend(rows)
    return fin_ents
            
            
    

In [None]:
from IPython.display import display
def submission(entities):
    """Write ``submission.csv`` from prediction rows and display the table.

    Args:
        entities: iterable of rows where index 1 is the note number, index 2
            the feature number and index 3 a list of span strings.  The rows
            are read only; they are not modified.

    Each row gets the id ``<note number zero-padded to 5>_<feature number
    zero-padded to 3>`` and the concatenation of its span strings (an empty
    result is stored as NaN).  Rows sharing an id are merged into the first
    occurrence, their locations joined with ``';'`` in encounter order.  A
    location that is exactly ``'-1 -1'`` is written as NaN.

    Returns:
        None.  The table is written to ``submission.csv`` (without the
        index) and passed to ``display``.
    """
    ids = []
    locats = []
    for entry in entities:
        note_num = entry[1]
        feat_num = entry[2]
        span = ''.join(entry[3])
        if span == '':
            span = np.nan
        ids.append(str(note_num).zfill(5) + '_' + str(feat_num).zfill(3))
        locats.append(span)

    subm_df = pd.DataFrame()
    subm_df['id'] = pd.Series(ids)
    subm_df['location'] = pd.Series(locats)

    dup_mask = subm_df['id'].duplicated()
    dup_df = subm_df[dup_mask]
    # Explicit copy: columns are assigned on this frame below.
    subm_df = subm_df[~dup_mask].copy()

    # id -> locations of its repeated rows, in encounter order.
    extra_spans = {}
    for row_id, span in zip(dup_df['id'], dup_df['location']):
        extra_spans.setdefault(row_id, []).append(str(span))
    dup_dict = {row_id: ';'.join(spans) for row_id, spans in extra_spans.items()}
    dup_dict = {k: v for k, v in dup_dict.items() if v}

    subm_df['locale'] = subm_df['id'].apply(lambda x: dup_dict.get(x))
    subm_df['location'] = np.where(~subm_df['locale'].isnull(),
                                   subm_df['location'] + ';' + subm_df['locale'],
                                   subm_df['location'])
    subm_df = subm_df.drop(['locale'], axis=1)
    subm_df['location'] = subm_df['location'].replace('-1 -1', np.nan)

    subm_df.to_csv('submission.csv', index=False)

    display(subm_df)
    


In [None]:
import numpy as np
def tester():
    """Run inference end to end and write the submission file.

    Loads the patient notes and features, normalises their text, builds the
    label -> feature number mapping, then feeds the test rows through
    ``prep_sub`` and the predictions through ``submission``.
    """
    notes_df = pd.read_csv('/kaggle/input/nbme-score-clinical-patient-notes/patient_notes.csv')
    notes_df['pn_history'] = notes_df['pn_history'].apply(clean_spaces)
    notes = notes_df[['pn_history', 'case_num', 'pn_num']].values

    features_df = pd.read_csv('/kaggle/input/nbme-score-clinical-patient-notes/features.csv')
    features_df['feature_text'] = (
        features_df['feature_text'].apply(process_feature_text).apply(clean_spaces)
    )

    feature_rows = features_df[['feature_num', 'feature_text', 'case_num']].values
    # Later rows win when the same feature text occurs more than once.
    feature_dict = {row[1]: row[0] for row in feature_rows}
    # Explicit numbers for two labels, set after the mapping is built.
    feature_dict.update({"17 year": 12, "Male": 11})

    test_csv = pd.read_csv('/kaggle/input/nbme-score-clinical-patient-notes/test.csv')

    submission(prep_sub(notes, feature_dict, test_csv))
'''
    model_test = notes[16][0]
    pn_num = notes[16][2]
    #print(model_test)
    
    nlp_output = spacy.load("../input/med-models/output/model-best")
    doc = nlp_output(model_test)
    displacy.render(doc, style="ent")
    entity_list = []
    
    

    for ent in doc.ents:
        #print("Label: {}, Span: {}:{}".format(feature_dict[ent.label_], ent.start_char, ent.end_char))
        #entity_list.append([feature_dict[ent.label_], case_num, ent.start_char, ent.end_char])
        feat_num = feature_dict[ent.label_]
        feat_num = str(feat_num)
        feat_num = feat_num.zfill(3)

        pat_num = str(pn_num)
        pat_num = pat_num.zfill(5)

        my_id = pat_num+"_"+feat_num
        entity_list.append([my_id, ent.start_char, ent.end_char])
        
    subm_df = pd.DataFrame()
    ids = []
    locats = []
    #subm_df['id'] = pd.Series(entity_list[:, 0])
    for entity in entity_list:
        r_id = entity[0]
        r_start = entity[1]
        r_end = entity[2]
        locale = str(r_start)+" "+ str(r_end)
        ids.append(r_id)
        locats.append(locale)
        
    subm_df['id'] = pd.Series(ids)
    subm_df['location'] = pd.Series(locats)
    
    subm_df.to_csv('submission.csv', index=False)
'''




In [None]:
#!apt update && pip install git+https://github.com/huggingface/transformers torch==1.10.2+cu113 torchvision==0.11.3+cu113 torchaudio==0.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html

In [None]:
#!pip install matplotlib scikit-learn pandas scipy setuptools wheel spacy[cuda114,transformers,lookups] ipython
#!pip download spacy_transformers
#!rm -r ./output

In [None]:
def setup():
    """Run the data-preparation step by calling ``load_data``."""
    load_data()

In [None]:
if __name__ =="__main__":
    # IPython shell escape: install every wheel under
    # ../input/mywheels3/Packages without contacting a package index
    # (--no-index) and without installing dependencies (--no-deps).
    !pip install --no-index --no-deps ../input/mywheels3/Packages/*.whl
    #setup()
    #!python3 -m spacy init fill-config ../input/spacy-params/base_config.cfg config.cfg
    #!python3 -m spacy train config.cfg -g 0 --output ./output
    # The imported name is not referenced again in this cell; presumably the
    # import registers the transformer components that spacy.load needs in
    # prep_sub -- TODO confirm.
    import spacy_transformers
    # Only the inference path runs; the data-prep and training steps above
    # are commented out.
    tester()
