<a href="https://colab.research.google.com/github/sushant2076/py/blob/master/DLHC_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive; the data paths used further down live under
# /content/drive/MyDrive. This import is only available inside Google Colab.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Clone the snomed_graph repository into the working directory.
# NOTE(review): nothing in the visible cells imports from it — confirm it is still needed.
!git clone https://github.com/VerataiLtd/snomed_graph.git

Cloning into 'snomed_graph'...
remote: Enumerating objects: 35, done.[K
remote: Counting objects: 100% (35/35), done.[K
remote: Compressing objects: 100% (26/26), done.[K
remote: Total 35 (delta 19), reused 24 (delta 8), pack-reused 0[K
Receiving objects: 100% (35/35), 26.40 KiB | 458.00 KiB/s, done.
Resolving deltas: 100% (19/19), done.


In [None]:
from pathlib import Path
import pandas as pd

In [None]:
def load_snomed_df(data_path, release="INT_20230531"):
    """
    Create a SNOMED CT concept DataFrame.

    Derived from: https://github.com/CogStack/MedCAT/blob/master/medcat/utils/preprocess_snomed.py

    Args:
        data_path: Directory (``str`` or ``pathlib.Path``) containing the RF2 files
            ``sct2_Concept_Snapshot_<release>.txt`` and
            ``sct2_Description_Snapshot-en_<release>.txt``.
        release: Suffix shared by both file names. The default keeps the
            20230531 International release that this notebook was written for.

    Returns:
        pandas.DataFrame: one row per active description of an active concept,
        with string columns ``concept_id``, ``concept_name``, ``name_type``
        ("P" for typeId 900000000000003001, "A" for typeId 900000000000013009)
        and ``hierarchy`` (the parenthesised suffix of ``concept_name``). Rows
        whose name has no such suffix are dropped.
    """
    data_path = Path(data_path)  # accept plain strings as well as Path objects

    def _read_file_and_subset_to_active(filename):
        # Tab-separated file with a header row; every field is kept as a stripped string.
        with open(filename, encoding="utf-8") as f:
            entities = [[n.strip() for n in line.split("\t")] for line in f]
        df = pd.DataFrame(entities[1:], columns=entities[0])
        return df[df.active == "1"]

    active_terms = _read_file_and_subset_to_active(
        data_path / f"sct2_Concept_Snapshot_{release}.txt"
    )
    active_descs = _read_file_and_subset_to_active(
        data_path / f"sct2_Description_Snapshot-en_{release}.txt"
    )

    # Both files have an "id" column, so the concept id comes out of the merge as "id_x".
    df = pd.merge(active_terms, active_descs, left_on=["id"], right_on=["conceptId"], how="inner")[
        ["id_x", "term", "typeId"]
    ].rename(columns={"id_x": "concept_id", "term": "concept_name", "typeId": "name_type"})

    # active description or active synonym
    df["name_type"] = df["name_type"].replace(
        ["900000000000003001", "900000000000013009"], ["P", "A"]
    )
    # .copy() so the column assignment below writes to an independent frame rather
    # than to a slice of `df` (avoids SettingWithCopyWarning / lost writes).
    active_snomed_df = df[df.name_type.isin(["P", "A"])].copy()

    active_snomed_df["hierarchy"] = active_snomed_df["concept_name"].str.extract(
        r"\((\w+\s?.?\s?\w+.?\w+.?\w+.?)\)$", expand=False
    )
    active_snomed_df = active_snomed_df[active_snomed_df.hierarchy.notnull()].reset_index(drop=True)
    return active_snomed_df

In [None]:
# Root folder of the SNOMED CT RF2 release on the mounted Drive.
snomed_rf2_path = Path(
    "/content/drive/MyDrive/DLHC_Project/SnomedCT_InternationalRF2_PRODUCTION_20230531T120000Z_Challenge_Edition"
)

In [None]:
# Load the active concepts/descriptions from the Snapshot/Terminology folder and show the row count.
df = load_snomed_df(snomed_rf2_path / "Snapshot" / "Terminology")
df.shape[0]

364323

In [None]:
# Values of the `hierarchy` column (the parenthesised suffix of the concept name) to keep.
concept_type_subset = [
    "procedure",                    # top level category
    "body structure",               # top level category
    "finding",                      # top level category
    "disorder",                     # child of finding
    "morphologic abnormality",      # child of body structure
    "regime/therapy",               # child of procedure
    "cell structure",               # child of body structure
]

In [None]:

# Keep rows whose hierarchy is one of the selected concept types ...
in_selected_hierarchy = df.hierarchy.isin(concept_type_subset)
# ... and whose name_type is "P" (preferred terms only, i.e. one row per concept, drop synonyms).
is_preferred_term = df.name_type == "P"
# Independent copy so later column edits do not touch `df`.
filtered_df = df.loc[in_selected_hierarchy & is_preferred_term].copy()


In [None]:
filtered_df.shape[0]

218467

In [None]:
filtered_df.hierarchy.value_counts()

disorder                   83431
procedure                  55981
finding                    35545
body structure             34914
morphologic abnormality     4969
regime/therapy              3110
cell structure               517
Name: hierarchy, dtype: int64

In [None]:
# Drop the now-constant name_type column. Rebinding (instead of inplace=True) together with
# errors="ignore" keeps this cell re-runnable: a second run no longer raises KeyError.
filtered_df = filtered_df.drop(columns="name_type", errors="ignore")
# Persist the flattened concept table (the DataFrame index is written as the first column).
filtered_df.to_csv("flattened_terminology.csv")

In [None]:
# Folder holding the entity-linking challenge files on the mounted Drive.
challenge_dir = Path("/content/drive/MyDrive/DLHC_Project/snomed-ct-entity-linking-challenge-1.0.0")
# Discharge notes (note_id, text) and their span annotations (note_id, start, end, concept_id).
training_set = pd.read_csv(challenge_dir / "mimic-iv_notes_training_set.csv")
training_annot = pd.read_csv(challenge_dir / "train_annotations.csv")

In [None]:

#BiLSTM-CNN
#FLERT

In [None]:
len(training_set["text"][0])

4279

In [None]:
training_set["text"][0]

' \nName:  ___                  Unit No:   ___\n \nAdmission Date:  ___              Discharge Date:   ___\n \nDate of Birth:  ___             Sex:   M\n \nService: SURGERY\n \nAllergies: \nPenicillins\n \nAttending: ___.\n \nChief Complaint:\nBiliary pancreatitis\n \nMajor Surgical or Invasive Procedure:\n___: Laparoscopic cholecystectomy\n\n \nHistory of Present Illness:\nMr. ___ is a ___ man who had severe biliary \npancreatitis resulting in pancreatic necrosis for which he was \ntreated with nasojejunal feedings and pancreatic rest.  He had \ninitially had multisystem organ failure, which improved. Mr. \n___ has a large postnecrotic pseudocyst, which has been \ndrained through a minimally invasive approach into his GI tract. \n He has some debris, but this is not currently infected. The \npatient was followed by Dr. ___ in his ___ \nclinic to discuss cholecystectomy. After discussion of all \nrisks, benefits and possible outcomes, patient was scheduled for \nelective cholecystectom

In [None]:
filtered_df

Unnamed: 0,concept_id,concept_name,hierarchy
3,104001,Excision of lesion of patella (procedure),procedure
4,106004,Structure of posterior carpal region (body str...,body structure
5,107008,Structure of fetal part of placenta (body stru...,body structure
6,108003,Entire condylar emissary vein (body structure),body structure
7,109006,Anxiety disorder of childhood OR adolescence (...,disorder
...,...,...,...
364312,971918681000119107,Chronic respiratory failure due to obstructive...,disorder
364313,972604701000119104,Acquired arteriovenous malformation of vascula...,disorder
364315,978253001000132109,Small bowel enteroscopy normal (finding),finding
364317,985355341000119101,Malignant melanoma of skin of left wrist (diso...,disorder


In [None]:
# Write the filtered concept table to tags.csv (same frame that was saved as flattened_terminology.csv).
filtered_df.to_csv("tags.csv")

In [None]:
training_annot

Unnamed: 0,note_id,start,end,concept_id
0,10060142-DS-9,179,190,91936005
1,10060142-DS-9,228,248,95563007
2,10060142-DS-9,294,322,45595009
3,10060142-DS-9,390,411,95563007
4,10060142-DS-9,425,444,1835003
...,...,...,...,...
51569,19926965-DS-14,9216,9227,76948002
51570,19926965-DS-14,9257,9261,22253000
51571,19926965-DS-14,9298,9302,22253000
51572,19926965-DS-14,9318,9323,386661006


In [None]:
training_set

Unnamed: 0,note_id,text
0,10060142-DS-9,\nName: ___ Unit No: ___\...
1,10097089-DS-8,\nName: ___ Unit No: ___\...
2,10124346-DS-4,\nName: ___ Unit No: ___\n \n...
3,10302979-DS-5,\nName: ___ Unit No: ___\n...
4,10352433-DS-20,\nName: ___ Unit No: ___\...
...,...,...
199,19859532-DS-19,\nName: ___ Unit No: ___...
200,19871603-DS-14,\nName: ___ Unit No: ___\...
201,19884924-DS-14,\nName: ___ Unit No: __...
202,19895550-DS-7,\nName: ___ Unit No: ___\n \...


In [None]:
# Attach the full note text to every annotation row; the left join keeps all annotations.
dataset_merged = pd.merge(training_annot, training_set, on='note_id', how='left')

In [None]:
dataset_merged

Unnamed: 0,note_id,start,end,concept_id,text
0,10060142-DS-9,179,190,91936005,\nName: ___ Unit No: ___\...
1,10060142-DS-9,228,248,95563007,\nName: ___ Unit No: ___\...
2,10060142-DS-9,294,322,45595009,\nName: ___ Unit No: ___\...
3,10060142-DS-9,390,411,95563007,\nName: ___ Unit No: ___\...
4,10060142-DS-9,425,444,1835003,\nName: ___ Unit No: ___\...
...,...,...,...,...,...
51569,19926965-DS-14,9216,9227,76948002,\nName: ___ Unit No: ___\n \...
51570,19926965-DS-14,9257,9261,22253000,\nName: ___ Unit No: ___\n \...
51571,19926965-DS-14,9298,9302,22253000,\nName: ___ Unit No: ___\n \...
51572,19926965-DS-14,9318,9323,386661006,\nName: ___ Unit No: ___\n \...


In [None]:
# load_snomed_df keeps every field as a string; cast concept_id to a numeric dtype so it can be
# merged with the annotations' concept_id (presumably parsed as integers by read_csv — confirm).
filtered_df['concept_id'] = pd.to_numeric(filtered_df['concept_id'])

In [None]:
# Add concept_name and hierarchy to each annotation. With how="left", annotations whose
# concept_id is not present in filtered_df are kept and get NaN in those columns.
dataset_train = pd.merge(dataset_merged, filtered_df, on="concept_id", how="left")

In [None]:
dataset_train

Unnamed: 0,note_id,start,end,concept_id,text,concept_name,hierarchy
0,10060142-DS-9,179,190,91936005,\nName: ___ Unit No: ___\...,Allergy to penicillin (finding),finding
1,10060142-DS-9,228,248,95563007,\nName: ___ Unit No: ___\...,Gallstone pancreatitis (disorder),disorder
2,10060142-DS-9,294,322,45595009,\nName: ___ Unit No: ___\...,Laparoscopic cholecystectomy (procedure),procedure
3,10060142-DS-9,390,411,95563007,\nName: ___ Unit No: ___\...,Gallstone pancreatitis (disorder),disorder
4,10060142-DS-9,425,444,1835003,\nName: ___ Unit No: ___\...,Necrosis of pancreas (disorder),disorder
...,...,...,...,...,...,...,...
51569,19926965-DS-14,9216,9227,76948002,\nName: ___ Unit No: ___\n \...,Severe pain (finding),finding
51570,19926965-DS-14,9257,9261,22253000,\nName: ___ Unit No: ___\n \...,Pain (finding),finding
51571,19926965-DS-14,9298,9302,22253000,\nName: ___ Unit No: ___\n \...,Pain (finding),finding
51572,19926965-DS-14,9318,9323,386661006,\nName: ___ Unit No: ___\n \...,Fever (finding),finding


In [None]:
import re  # imported here so this cell also runs on a fresh kernel

# Cut everything before the first "S?r" sequence (in these notes, the "Service:" line) from each
# note and shift the annotation offsets by the same amount.
# Whole-column assignments replace the original chained-indexing writes
# (dataset_train["text"][row] = ...), which trigger SettingWithCopyWarning and are not guaranteed
# to modify the frame. The regex also cannot run past the end of the string, unlike text[j + 2].
_header_pattern = re.compile(r"S.r", re.DOTALL)  # 'S', any one character, 'r'


def _header_offset(note_text):
    """Return the index of the first 'S?r' sequence in note_text, or 0 if there is none."""
    found = _header_pattern.search(note_text)
    return found.start() if found else 0


_offsets = dataset_train["text"].map(_header_offset)
dataset_train["text"] = [note[cut:] for note, cut in zip(dataset_train["text"], _offsets)]
dataset_train["start"] = dataset_train["start"] - _offsets
dataset_train["end"] = dataset_train["end"] - _offsets


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dataset_train["text"][row]= dataset_train["text"][row][j:]
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dataset_train["start"][row]-=j
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dataset_train["end"][row]-=j


In [None]:
dataset_train

Unnamed: 0,note_id,start,end,concept_id,text,concept_name,hierarchy
0,10060142-DS-9,31,42,91936005,Service: SURGERY\n \nAllergies: \nPenicillins\...,Allergy to penicillin (finding),finding
1,10060142-DS-9,80,100,95563007,Service: SURGERY\n \nAllergies: \nPenicillins\...,Gallstone pancreatitis (disorder),disorder
2,10060142-DS-9,146,174,45595009,Service: SURGERY\n \nAllergies: \nPenicillins\...,Laparoscopic cholecystectomy (procedure),procedure
3,10060142-DS-9,242,263,95563007,Service: SURGERY\n \nAllergies: \nPenicillins\...,Gallstone pancreatitis (disorder),disorder
4,10060142-DS-9,277,296,1835003,Service: SURGERY\n \nAllergies: \nPenicillins\...,Necrosis of pancreas (disorder),disorder
...,...,...,...,...,...,...,...
51569,19926965-DS-14,9071,9082,76948002,Service: SURGERY\n \nAllergies: \npenicillin G...,Severe pain (finding),finding
51570,19926965-DS-14,9112,9116,22253000,Service: SURGERY\n \nAllergies: \npenicillin G...,Pain (finding),finding
51571,19926965-DS-14,9153,9157,22253000,Service: SURGERY\n \nAllergies: \npenicillin G...,Pain (finding),finding
51572,19926965-DS-14,9173,9178,386661006,Service: SURGERY\n \nAllergies: \npenicillin G...,Fever (finding),finding


In [None]:
# One note_id per annotation row, in row order.
l = list(dataset_train["note_id"])

In [None]:
len(set(l))

204

In [None]:
# Snapshot of dataset_train before the annotations column is added.
# NOTE(review): not referenced again in the visible cells — confirm it is still needed.
dataset_train_ = dataset_train.copy()

In [None]:
# Build one (mention text, concept name) tuple per annotation row.
# The mention is the note slice [start:end] with newlines removed and every character in
# `filters` surrounded by spaces, so that it tokenises the same way as the cleaned note text.
annotation = []
filters = ["!", "#", "$", "%", "&", "(", ")", "/", "*", ".", ":", ";", "<", "=", ">", "?", "@", "[",
               "\\", "]", "_", "`", "{", "}", "~", "'"]

# A single pass over the four columns replaces the per-row dataset_train[col][i] lookups.
for start, end, note_text, concept in zip(
    dataset_train["start"],
    dataset_train["end"],
    dataset_train["text"],
    dataset_train["concept_name"],
):
    mention = note_text[start:end].replace('\n', '')
    for symbol in filters:
        # str.replace is a no-op when the symbol is absent, so no membership test is needed.
        mention = mention.replace(symbol, " " + symbol + " ")
    annotation.append((mention, concept))



In [None]:
# Store each row's (mention text, concept name) tuple; `annotation` was built in row order.
dataset_train['annotations'] = annotation

In [None]:
# Keep only the note text and its annotation tuple.
dataset = dataset_train[["text", "annotations"]]

In [None]:
# Collapse to one row per distinct note text, gathering that note's annotation tuples into a list.
# groupby sorts by the grouping key by default, so rows come out ordered by text, not by note_id.
dataset_row_merged = dataset.groupby('text')["annotations"].agg(list).reset_index()

In [None]:
dataset_row_merged["annotations"][1]

[('Codeine', 'Allergy to codeine (finding)'),
 ('adhesive tape', 'Allergy to adhesive agent (finding)'),
 ('chest pain', 'Chest pain (finding)'),
 ('fatigue', 'Fatigue (finding)'),
 ('Coronary Artery Bypass Grafting x 2',
  'Coronary artery bypass grafts x 2 (procedure)'),
 ('left internal mammary artery',
  'Structure of left internal thoracic artery (body structure)'),
 ('left anterior descending coronary artery',
  'Structure of anterior descending branch of left coronary artery (body structure)'),
 ('saphenous vein graft',
  'Aortocoronary artery bypass of one coronary artery with saphenous vein graft (procedure)'),
 ('obtuse marginal',
  'Structure of obtuse marginal branch of circumflex branch of left coronary artery (body structure)'),
 ('cardiac risk factors',
  'Assessment for risk of cardiovascular disease (procedure)'),
 ('chest pain', 'Chest pain (finding)'),
 ('angina', 'Angina (disorder)'),
 ('Stress test', 'Electrocardiogram with exercise test (procedure)'),
 ('anterosep

In [None]:
def _pad_punctuation(note_text):
    """Remove newlines and surround every character listed in `filters` with spaces."""
    note_text = note_text.replace('\n', '')
    for symbol in filters:
        note_text = note_text.replace(symbol, " " + symbol + " ")
    return note_text


# Apply the same clean-up to the full note text that was applied to the annotation mentions.
# A single column assignment replaces the chained-indexing writes
# (dataset_row_merged["text"][i] = ...), which are not guaranteed to update the frame.
# NOTE: re-running this cell pads the punctuation again; rebuild dataset_row_merged first.
dataset_row_merged["text"] = dataset_row_merged["text"].map(_pad_punctuation)

In [None]:
dataset_row_merged["text"][1]

"Service :  CARDIOTHORACIC Allergies :  Codeine  /  adhesive tape  /  Percocet Attending :   _  _  _  .  Chief Complaint : Recurrent chest pain, fatigue Major Surgical or Invasive Procedure : Coronary Artery Bypass Grafting x 2  ( left internal mammary artery to the left anterior descending coronary artery ;  saphenous vein graft to the obtuse marginal branch )  History of Present Illness : Mrs .   _  _  _  is a  _  _  _  year old female with multiple cardiac risk factors whopresented with recurrent chest pain concerning for angina .  Stress test was notable for mild anteroseptal ischemia .  Subsequent cardiac catheterization showed 60 %  left main lesion .  Of note, she has history of SVT .  She reportedoccasional episodes of palpitations and diaphoresis sometimes associated with chest pain .  These episodes occurred several times per week .   Based upon the above findings, she was referred to Dr .   _  _  _  surgical revascularization .  Past Medical History : 1 .   Hypertension2 .  

In [None]:
# Persist the per-note table. The annotations column holds Python lists of tuples, so CSV stores
# their string representation.
dataset_row_merged.to_csv('df.csv')

In [None]:
import pandas as pd
import numpy as np

In [None]:
# dataset_row_merged = pd.read_csv("df.csv")

In [None]:
dataset_row_merged["annotations"][0]

[('Codeine', 'Allergy to codeine (finding)'),
 ('Quinine', 'Allergy to quinine (finding)'),
 ('non-small cell lung cancer', 'Non-small cell lung cancer (disorder)'),
 ('cervical mediastinoscopy', 'Cervical mediastinoscopy (procedure)'),
 ('flexible bronchoscopy', 'Flexible bronchoscopy (procedure)'),
 ('smoker', 'Smoker (finding)'),
 ('Hep C', 'Viral hepatitis type C (disorder)'),
 ('HCC', 'Liver cell carcinoma (disorder)'),
 ('liver transplant', 'Transplantation of liver (procedure)'),
 ('RUL', 'Structure of upper lobe of right lung (body structure)'),
 ('PET', 'Positron emission tomography (procedure)'),
 ('nodule', 'Nodule (morphologic abnormality)'),
 ('non-small cell carcinoma',
  'Non-small cell carcinoma (morphologic abnormality)'),
 ('asymptomatic', 'Asymptomatic (finding)'),
 ('SOB', 'Dyspnea (finding)'),
 ('cough', 'Cough (finding)'),
 ('fevers', 'Fever (finding)'),
 ('chills', 'Chill (finding)'),
 ('nightsweats', 'Night sweats (finding)'),
 ('biopsied', 'Biopsy (procedure)')

In [None]:
from tqdm import tqdm
from difflib import SequenceMatcher
import re
import pickle

In [None]:
def clean(text):
    '''
    Helper function that adds a single space before each punctuation character
    for better tokenization.

    Args:
        text: input string.

    Returns:
        A copy of ``text`` in which every character listed in ``filters`` is
        preceded by exactly one extra space.
    '''
    filters = ["!", "#", "$", "%", "&", "(", ")", "/", "*", ".", ":", ";", "<", "=", ">", "?", "@", "[",
               "\\", "]", "_", "`", "{", "}", "~", "'"]
    # One translate pass pads each occurrence once. The previous loop ran a global
    # replace for every occurrence found, so a symbol appearing k times ended up
    # with k spaces in front of it.
    padding = {ord(symbol): " " + symbol for symbol in filters}
    return text.translate(padding)

In [None]:
def matcher(string, pattern):
    '''
    Locate the first occurrence of ``pattern`` (stripped of surrounding
    whitespace) inside ``string``.

    Returns a pair ``(match_list, string)``:
      * when the pattern is found, ``match_list`` is ``[(start, end)]`` and the
        returned string has that first occurrence overwritten with "X"
        characters of the same length;
      * otherwise ``match_list`` is empty and ``string`` is returned unchanged.
    '''
    needle = pattern.strip()
    first = string.find(needle)
    if first == -1:
        return [], string
    masked = string.replace(needle, "X" * len(needle), 1)
    return [(first, first + len(needle))], masked

In [None]:
# def mark_sentence(s, match_list):
#     '''
#     Marks all the entities in the sentence as per the BIO scheme.
#     '''
#     word_dict = []
#     for word in s.split():
#         word_dict[word] = 'O'

#     for start, end, e_type in match_list:
#         temp_str = s[start:end]
#         tmp_list = temp_str.split()
#         print(tmp_list)
#         if len(tmp_list) > 1:
#             word_dict[tmp_list[0]] = 'B-' + e_type
#             for w in tmp_list[1:]:
#                 word_dict[w] = 'I-' + e_type
#         else:
#             word_dict[temp_str] = 'B-' + e_type
#     return word_dict

In [None]:
def mark_sentence(text, match_list):
  '''
  Produce one BIO label per whitespace-separated token of ``text``.

  Args:
    text: the sentence/document to label.
    match_list: iterable of ``(start, end, e_type)``; ``text[start:end]`` is the
      mention and ``e_type`` the entity type appended after "B-" / "I-".

  Returns:
    A list with ``len(text.split())`` entries. Every occurrence of a mention's
    token sequence is labelled "B-<e_type>" on its first token and "I-<e_type>"
    on the rest; all other tokens are "O". Tokens already claimed by an earlier
    entry in ``match_list`` are not relabelled.

  Labels are kept in a separate list instead of overwriting the tokens, so a
  text token that merely looks like a label (e.g. "B-cell") is no longer
  mistaken for one. Empty spans are ignored, and a mention must fit entirely
  within the token list to match.
  '''
  tokens = text.split()
  labels = [None] * len(tokens)  # None means "not labelled yet"
  for start, end, e_type in match_list:
    span_tokens = text[start:end].split()
    width = len(span_tokens)
    if width == 0:
      # An empty/whitespace-only span used to tag every token as "B-".
      continue
    # Stop at len(tokens) - width so a window never runs past the end; previously a
    # truncated window at the end of the list counted as a match.
    for i in range(len(tokens) - width + 1):
      window_free = all(label is None for label in labels[i:i + width])
      if window_free and tokens[i:i + width] == span_tokens:
        labels[i] = 'B-' + e_type
        for k in range(1, width):
          labels[i + k] = 'I-' + e_type
  return [label if label is not None else 'O' for label in labels]







In [None]:
def create_data(df, filepath):
    '''
    Write ``df`` as a two-column token/label file and return the labels per text.

    Args:
        df: DataFrame with a ``text`` column (string) and an ``annotations``
            column (list of ``(mention, concept_name)`` pairs).
        filepath: output file; one "<token> <label>" line per token and a blank
            line after each text. Missing parent directories are created.

    Returns:
        pandas.DataFrame with columns ``text`` and ``labels`` (labels joined by a
        single space), indexed 1..N in input order.
    '''
    # Create the output folder if needed so the function works on a fresh runtime.
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)

    records = []
    with open(filepath, 'w', encoding='utf-8') as f:
        for text, annotation in zip(df.text, df.annotations):
            match_list = []
            for entry in annotation:
                # As before, each mention is looked up in the unmasked text; the masked
                # string returned by matcher() is not used.
                spans, _ = matcher(text, entry[0])
                if not spans:
                    # Mention not found verbatim: skip it instead of failing on spans[0].
                    continue
                match_list.append((spans[0][0], spans[0][1], entry[1]))

            labels = mark_sentence(text, match_list)
            records.append({'text': text, 'labels': ' '.join(labels)})

            # NOTE(review): concept names contain spaces, so the label column itself contains
            # whitespace — confirm the downstream column reader copes with that.
            for token, label in zip(text.split(), labels):
                f.write(token + ' ' + label + '\n')
            f.write('\n')

    # Build the frame once; DataFrame.append in a loop is quadratic and was removed in pandas 2.0.
    return pd.DataFrame(records, columns=['text', 'labels'], index=range(1, len(records) + 1))

In [None]:
# Write the token/label file for the whole corpus and keep the per-note label table.
final_dataset = create_data(dataset_row_merged, "train/data.txt")

  df2 = df2.append(pd.Series(data_to_append, name=index))
  df2 = df2.append(pd.Series(data_to_append, name=index))
  df2 = df2.append(pd.Series(data_to_append, name=index))
  df2 = df2.append(pd.Series(data_to_append, name=index))
  df2 = df2.append(pd.Series(data_to_append, name=index))
  df2 = df2.append(pd.Series(data_to_append, name=index))
  df2 = df2.append(pd.Series(data_to_append, name=index))
  df2 = df2.append(pd.Series(data_to_append, name=index))
  df2 = df2.append(pd.Series(data_to_append, name=index))
  df2 = df2.append(pd.Series(data_to_append, name=index))
  df2 = df2.append(pd.Series(data_to_append, name=index))
  df2 = df2.append(pd.Series(data_to_append, name=index))
  df2 = df2.append(pd.Series(data_to_append, name=index))
  df2 = df2.append(pd.Series(data_to_append, name=index))
  df2 = df2.append(pd.Series(data_to_append, name=index))
  df2 = df2.append(pd.Series(data_to_append, name=index))
  df2 = df2.append(pd.Series(data_to_append, name=index))
  df2 = df2.ap

In [None]:
# Save the per-note text/labels table.
final_dataset.to_csv("NER_dataset.csv")

In [None]:
# Split the per-note rows by position: first 184 for training, next 15 for test, the rest for dev.
# NOTE(review): dataset_row_merged comes from a groupby on "text", so this split follows the
# sorted-text order rather than a random shuffle — confirm that is intended.
df_train = dataset_row_merged[:184]
df_test = dataset_row_merged[184:199]
df_dev = dataset_row_merged[199:]

# Write one token/label file per split into the train/ folder read by the corpus loader below.
# The returned frames are not stored; the last call's result is the cell's displayed output.
create_data(df_train, "train/train.txt")
create_data(df_test, "train/test.txt")
create_data(df_dev, "train/dev.txt")


In [None]:
# Install Flair into the runtime.
# NOTE(review): no version is pinned, so the Flair API used below may differ between runs.
!pip3 install flair

In [None]:
# Upgrade urllib3 after installing Flair.
# NOTE(review): the reason for this upgrade is not recorded here — document it or drop it.
!pip3 install --upgrade urllib3

In [None]:
# Load the train/test/dev column files written by create_data as a Flair corpus.
from flair.data import Corpus
from flair.datasets import ColumnCorpus

# define columns
columns = {0 : 'text', 1 : 'ner'}
# directory where the data resides
data_folder = 'train'
# initializing the corpus
corpus: Corpus = ColumnCorpus(data_folder, columns,
                              train_file = 'train.txt',
                              test_file = 'test.txt',
                              dev_file = 'dev.txt')

In [None]:
print(len(corpus.train))

In [None]:
print(corpus.train[0].to_tagged_string('ner'))

In [None]:
# tag to predict
tag_type = 'ner'
# make tag dictionary from the corpus
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)

In [None]:
# Token embeddings for the first tagger: GloVe word vectors wrapped in a StackedEmbeddings
# container, so further embedding types can be added to the list.
from flair.embeddings import WordEmbeddings, StackedEmbeddings
from flair.embeddings import TokenEmbeddings
from typing import List
embedding_types : List[TokenEmbeddings] = [
        WordEmbeddings('glove'),
        ## other embeddings
        ]
embeddings : StackedEmbeddings = StackedEmbeddings(
                                 embeddings=embedding_types)

In [None]:
# First tagger: sequence tagger over the stacked embeddings with hidden size 256 and a CRF layer.
from flair.models import SequenceTagger
tagger : SequenceTagger = SequenceTagger(hidden_size=256,
                                       embeddings=embeddings,
                                       tag_dictionary=tag_dictionary,
                                       tag_type=tag_type,
                                       use_crf=True)
print(tagger)

In [None]:
# Train the first tagger on the corpus; outputs are written under resources/taggers/example-ner
# (the next cell loads final-model.pt from that folder).
from flair.trainers import ModelTrainer
trainer : ModelTrainer = ModelTrainer(tagger, corpus)
trainer.train('resources/taggers/example-ner',
              learning_rate=0.0001,
              mini_batch_size=4,
              max_epochs=150)

In [None]:
# Sanity check: reload the trained tagger and tag a hard-coded excerpt of a discharge note.
# NOTE(review): the excerpt is raw note text (newlines kept, punctuation not padded), unlike the
# cleaned text the training files were built from — confirm this is intended.
from flair.data import Sentence
from flair.models import SequenceTagger
# load the trained model
model = SequenceTagger.load('/content/resources/taggers/example-ner/final-model.pt')
# create example sentence
sentence = Sentence('Service: CARDIOTHORACIC\n \nAllergies: \nCodeine / Quinine\n \nAttending: ___\n \nChief Complaint:\nnon-small cell lung cancer\n \nMajor Surgical or Invasive Procedure:\nvideo assisted cervical mediastinoscopy, flexible bronchoscopy\n\n \nHistory of Present Illness:\nMr. ___ is a ___ current ___ py smoker, hx Hep C & HCC s/p\nliver transplant ___, with 3cm RUL PET-avid nodule with\npathology showing non-small cell carcinoma, here for follow-up.\n\nHe has been asymptomatic since previous visit. No SOB, cough. No\nfevers, chills, nightsweats. He had EBUS which biopsied lesion\nshowing non-small cell carcinoma. Lymph nodes were not biopsied. \n\nPET negative for distant metastasis.\n \nPast Medical History:\nHEPATITIS C (genotype 1a, no sequelae of chronic liver disease \nat this point, stage III fibrosis on biopsy; now s/p liver \ntransplant)\nHEPATOCELLULAR CARCINOMA (s/p RFA on ___ to segment VIa \nlesion; now s/p liver transplant)')
# predict the tags
model.predict(sentence)
print(sentence.to_tagged_string())

In [None]:
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

In [None]:
# Label dictionary for the transformer-based tagger, built from the corpus without an <unk> entry.
label_type = 'ner'
label_dict = corpus.make_label_dictionary(label_type=label_type, add_unk=False)
print(label_dict)

In [None]:
# Fine-tunable xlm-roberta-large embeddings: last layer only, first-subtoken pooling, with
# surrounding context enabled. This rebinds `embeddings`, replacing the GloVe stack defined earlier.
embeddings = TransformerWordEmbeddings(model='xlm-roberta-large',
                                       layers="-1",
                                       subtoken_pooling="first",
                                       fine_tune=True,
                                       use_context=True,
                                       )

In [None]:
# Second tagger: built on the transformer embeddings with the CRF, RNN and embedding
# reprojection all switched off. This rebinds `tagger`, replacing the first model.
tagger = SequenceTagger(hidden_size=256,
                        embeddings=embeddings,
                        tag_dictionary=label_dict,
                        tag_type='ner',
                        use_crf=False,
                        use_rnn=False,
                        reproject_embeddings=False,
                        )

In [None]:
# Trainer for the transformer-based tagger on the same corpus (rebinds `trainer`).
trainer = ModelTrainer(tagger, corpus)

In [None]:
# Fine-tune the transformer tagger; outputs are written under resources/taggers/sota-ner-flert.
# A mini-batch of 1 with chunk size 1 is used (see the inline note about larger GPUs).
trainer.fine_tune('resources/taggers/sota-ner-flert',
                  learning_rate=5.0e-6,
                  mini_batch_size=1,
                  mini_batch_chunk_size=1,  # remove this parameter to speed up computation if you have a big GPU
                  )