In [1]:
import datasets
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


### 1. Load and take a look at the dataset

In [2]:
# NOTE(review): 'wfudata' is presumably a local dataset directory with a loading script;
# trust_remote_code=True lets that script execute, so only use it with trusted code.
wfu_dataset = datasets.load_dataset(
    'wfudata',
    trust_remote_code=True,
)

In [3]:
# Summary of the train split: its columns and row count.
wfu_dataset["train"]

Dataset({
    features: ['filename', 'text', 'label', 'phi'],
    num_rows: 3028
})

In [4]:
# wfu_dataset['train']['phi']['offsets'] does not work: ['phi'] returns one dict per row (a list), so select a row first (see the next cell).

In [5]:
# Character offsets of the first document's PHI spans. Index the row first:
# ['phi'] on the whole split would build the phi entry of every row just to read one.
wfu_dataset['train'][0]['phi']['offsets']

[[156, 160],
 [163, 167],
 [369, 372],
 [835, 839],
 [1102, 1105],
 [1187, 1191],
 [2784, 2804],
 [2809, 2826],
 [2857, 2877],
 [2883, 2891],
 [2892, 2896]]

In [6]:
# Slice rows before selecting the column so only 10 filenames are materialized.
wfu_dataset['train'][:10]['filename']

['00000.xml',
 '00001.xml',
 '00002.xml',
 '00003.xml',
 '00004.xml',
 '00005.xml',
 '00006.xml',
 '00007.xml',
 '00008.xml',
 '00009.xml']

In [7]:
# First 10 raw notes. Slicing rows first avoids materializing the whole 'text'
# column just to keep 10 entries.
# NOTE(review): the output is raw clinical text including the annotated PHI spans;
# clear it before sharing unless the PHI is known to be synthetic.
wfu_dataset['train'][:10]['text']

['Care Navigation - Patient Outreach    Type of Communication: Telephonic    Patient Assessment  Since last encounter, patient reports the following updates: 2/14 @ 1525 HN called pt on home number - family answered the phone and states pt is not available to call her on her mobile number listed in the chart. HN called pt on mobile number - no answer and HN LMTCB.     3/5 Pt repots she is doing okay. States her breathing is fine but she is &quot;out of my pump inhaler&quot;. She reports this is her &quot;rescue inhaler&quot;. HN informed her that she should have a refill on file and she states she does not think she does. HN called patient&apos;s CVS pharmacy to inquire if she has a refill on file. Staff states she does. Pt is aware. Pt states her breathing is okay. Denied chest pain. HN reviewed cardiology office notes fron 2/28 - pt states &quot;I already had a work up from GI. He told me I had ulcers but that is all&quot;. HN advised her to contact PCP of GI if she has more episodes

In [8]:
# Label vocabulary of the 'label' ClassLabel feature; NORMAL (the last entry) tags non-PHI tokens.
wfu_dataset["train"].features["label"].names

['AGE',
 'DATE',
 'EMAIL',
 'HOSPITAL',
 'IDNUM',
 'INITIALS',
 'IPADDRESS',
 'LOCATION',
 'NAME',
 'OTHER',
 'PHONE',
 'URL',
 'NORMAL']

In [9]:
# One document to inspect in detail below.
# NOTE(review): displaying it shows the full note text including the annotated PHI;
# clear/redact the output before sharing unless the PHI is known to be synthetic.
example = wfu_dataset["train"][101]
example

{'filename': '00101.xml',
 'text': '\n\nACCESSION NUMBER:  S15-20324\nRECEIVED: 8/17/2015\nORDERING PHYSICIAN:  DANIEL JEFFREY KIRSE , MD\nPATIENT NAME:  KINDLON, TYLER JOSEPH\nSURGICAL PATHOLOGY REPORT\n\nFINAL PATHOLOGIC DIAGNOSIS\nGROSS EXAMINATION AND DIAGNOSIS\n\nOROPHARYNX, BILATERAL PALATINE TONSILS, TONSILLECTOMY:\n     No gross abnormality (gross only)\n\n\nI have personally reviewed the slides and/or other related\nmaterials referenced, and have edited the report as part of my\npathologic assessment and final interpretation.\n\nElectronically Signed Out By:   S. S. O&apos;Neill, M.D., PhD,\nPathology   8/19/2015 17:16:00\n\nsso/oia\n\nSpecimen(s) Received\n Tonsil and/or adenoids Gross only\n\n\nClinical History\nSleep-disordered breathing.\n\n\n\n\nGross Description\n\nA.  Received labeled &quot;tonsils for gross only&quot; is a specimen\nreceived for gross examination only consisting of two tonsils, 5\ng and 6 g and 3.2 x 2.4 x 1.7 cm and 3.2 x 2.6 x 2 cm,\nrespectively.  T

In [10]:
# Sanity check that the character offsets in example['phi'] line up with the text:
# the 1st and 3rd annotated spans should slice out PHI strings (looks fine).
(
    example["text"],
    example["phi"],
    example["text"][slice(*example["phi"]["offsets"][0])],
    example["text"][slice(*example["phi"]["offsets"][2])],
)

('\n\nACCESSION NUMBER:  S15-20324\nRECEIVED: 8/17/2015\nORDERING PHYSICIAN:  DANIEL JEFFREY KIRSE , MD\nPATIENT NAME:  KINDLON, TYLER JOSEPH\nSURGICAL PATHOLOGY REPORT\n\nFINAL PATHOLOGIC DIAGNOSIS\nGROSS EXAMINATION AND DIAGNOSIS\n\nOROPHARYNX, BILATERAL PALATINE TONSILS, TONSILLECTOMY:\n     No gross abnormality (gross only)\n\n\nI have personally reviewed the slides and/or other related\nmaterials referenced, and have edited the report as part of my\npathologic assessment and final interpretation.\n\nElectronically Signed Out By:   S. S. O&apos;Neill, M.D., PhD,\nPathology   8/19/2015 17:16:00\n\nsso/oia\n\nSpecimen(s) Received\n Tonsil and/or adenoids Gross only\n\n\nClinical History\nSleep-disordered breathing.\n\n\n\n\nGross Description\n\nA.  Received labeled &quot;tonsils for gross only&quot; is a specimen\nreceived for gross examination only consisting of two tonsils, 5\ng and 6 g and 3.2 x 2.4 x 1.7 cm and 3.2 x 2.6 x 2 cm,\nrespectively.  The outer surfaces are tan-pink bos

In [11]:
from collections import Counter

# PHI type frequencies in the train split: flatten every row's phi['types'] list
# (presumably one entry per annotated span, parallel to phi['offsets']) and tally it.
types_stat = [phi_type for doc_phi in wfu_dataset['train']['phi'] for phi_type in doc_phi['types']]
Counter(types_stat)

Counter({'DATE': 14549,
         'NAME': 9092,
         'IDNUM': 1123,
         'HOSPITAL': 1079,
         'AGE': 1058,
         'LOCATION': 550,
         'PHONE': 150,
         'INITIALS': 112,
         'OTHER': 24,
         'EMAIL': 12,
         'URL': 8,
         'IPADDRESS': 1})

In [12]:
# The same PHI type tally for the test split.
types_stat = [phi_type for doc_phi in wfu_dataset['test']['phi'] for phi_type in doc_phi['types']]
Counter(types_stat)

Counter({'DATE': 3329,
         'NAME': 2242,
         'IDNUM': 317,
         'AGE': 263,
         'HOSPITAL': 225,
         'LOCATION': 196,
         'INITIALS': 36,
         'PHONE': 24,
         'OTHER': 5,
         'EMAIL': 4,
         'URL': 3})

### 2. Check the tokenized version

In [13]:
from transformers import AutoTokenizer
from utils import PreProcess, compute_metrics

# Uncased biomedical BERT checkpoint; the model cell further down loads the same one.
MODEL_CHECKPOINT = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)

In [14]:
snippet = example['text'][:50]
print(snippet)
# Tiny max_length/stride to make the overflow windowing visible: each window holds at
# most 10 ids including [CLS]/[SEP], and consecutive windows share 2 tokens.
tmp = tokenizer(
    snippet,
    max_length=10,
    truncation=True,
    return_overflowing_tokens=True,
    stride=2,
)
tmp



ACCESSION NUMBER:  S15-20324
RECEIVED: 8/17/2015


{'input_ids': [[2, 9989, 2529, 30, 4062, 1015, 17, 16725, 4561, 3], [2, 16725, 4561, 4004, 30, 28, 19, 2752, 19, 3], [2, 2752, 19, 4299, 3]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1]], 'overflow_to_sample_mapping': [0, 0, 0]}

In [15]:
# The tokenizer returns a BatchEncoding rather than a plain dict.
tmp.__class__

transformers.tokenization_utils_base.BatchEncoding

In [16]:
# Tokenize every split into overlapping windows (max_length=128, stride=10); PreProcess
# (utils.py) presumably also maps the character-level PHI spans to per-token label ids.
label_str2int = wfu_dataset['train'].features['label'].str2int
preprocess = PreProcess(tokenizer, label_str2int, max_length=128, stride=10)

wfu_dataset_tokenized = wfu_dataset.map(
    preprocess,
    batched=True,
    batch_size=5,
    # Drop the raw columns so only the tokenized outputs remain (train/test share the schema).
    remove_columns=wfu_dataset['train'].column_names,
)

In [17]:
# Decode the first tokenized training example back to text (the uncased tokenizer lowercases it).
first_window_ids = wfu_dataset_tokenized["train"][0]["input_ids"]
tokenizer.decode(first_window_ids)

'[CLS] care navigation - patient outreach type of communication : telephonic patient assessment since last encounter, patient reports the following updates : 2 / 14 @ 1525 hn called pt on home number - family answered the phone and states pt is not available to call her on her mobile number listed in the chart. hn called pt on mobile number - no answer and hn lmtcb. 3 / 5 pt repots she is doing okay. states her breathing is fine but she is & quot ; out of my pump inhaler & quot ;. she reports this is her & quot ; rescue inhaler & quot ;. [SEP]'

In [18]:
# Fields produced by preprocessing: the tokenizer outputs plus per-token 'labels'.
wfu_dataset_tokenized["train"][0].keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'overflow_to_sample_mapping', 'labels'])

In [19]:
# Token id, decoded token and label id for the first training window
# (-100 marks special tokens such as [CLS]; 12 is NORMAL).
first_row = wfu_dataset_tokenized['train'][0]
for token_id, label_id in zip(first_row['input_ids'], first_row['labels']):
    print(token_id, tokenizer.decode(token_id), label_id)

2 [CLS] -100
2859 care 12
17109 navigation 12
17 - 12
2774 patient 12
23465 outreach 12
2601 type 12
1927 of 12
6396 communication 12
30 : 12
10304 tele 12
2102 ##ph 12
3490 ##onic 12
2774 patient 12
4076 assessment 12
3328 since 12
4691 last 12
14535 encounter 12
16 , 12
2774 patient 12
4710 reports 12
1920 the 12
2894 following 12
26404 updates 12
30 : 12
22 2 1
19 / 1
2607 14 1
36 @ 12
14039 152 1
1015 ##5 1
9888 hn 12
6043 called 12
4401 pt 12
1990 on 12
4502 home 12
2529 number 12
17 - 12
3416 family 12
16329 answered 12
1920 the 12
12647 phone 12
1930 and 12
4376 states 12
4401 pt 12
1977 is 12
2084 not 12
3322 available 12
1942 to 12
4567 call 12
3656 her 12
1990 on 12
3656 her 12
8785 mobile 12
2529 number 12
7663 listed 12
1922 in 12
1920 the 12
12583 chart 12
18 . 12
9888 hn 12
6043 called 12
4401 pt 12
1990 on 12
8785 mobile 12
2529 number 12
17 - 12
2239 no 12
11192 answer 12
1930 and 12
9888 hn 12
7735 lm 12
6916 ##tc 12
1014 ##b 12
18 . 12
23 3 1
19 / 1
25 5 1
4401 pt 12


In [30]:
# Token-level label distribution over the train split, skipping -100 (special-token positions).
# NOTE(review): this cell (In [30]) was re-run after the class-weight cell; the weights printed
# there match counts that excluded NORMAL (12), whereas this version keeps it.
tokens_stat = [
    label
    for labels in wfu_dataset_tokenized['train']['labels']
    for label in labels
    if label != -100
]
Counter(tokens_stat)

Counter({12: 1326875,
         1: 84507,
         8: 44520,
         3: 5742,
         4: 5159,
         7: 3763,
         0: 1513,
         10: 1046,
         5: 245,
         9: 110,
         2: 97,
         11: 29,
         6: 3})

In [21]:
from sklearn.utils.class_weight import compute_class_weight

In [22]:
# Per-class loss weights: sklearn's 'balanced' weights (n_tokens / (n_classes * count_c))
# for the PHI classes, plus a fixed small weight for NORMAL.
label_feature = wfu_dataset['train'].features['label']
normal_id = label_feature.str2int('NORMAL')

# Exclude NORMAL explicitly instead of relying on tokens_stat's contents. The counting cell
# (In [30]) keeps NORMAL tokens, and compute_class_weight rejects y labels that are not in
# `classes`. The weights recorded below correspond to counts without NORMAL.
phi_class_ids = np.array([i for i in range(label_feature.num_classes) if i != normal_id])
phi_token_labels = [label for label in tokens_stat if label != normal_id]
phi_weights = compute_class_weight('balanced', classes=phi_class_ids, y=phi_token_labels)

# NOTE(review): 1e-4 nearly removes NORMAL from the loss. The step-200 eval report shows
# NORMAL recall 0.009 and accuracy 0.101, so the model rarely predicts NORMAL. The 'balanced'
# value for NORMAL over all 13 classes would be about 0.085.
NORMAL_CLASS_WEIGHT = 1e-4
# Place NORMAL's weight at its own label id rather than assuming it is the last label.
class_weights = np.full(label_feature.num_classes, NORMAL_CLASS_WEIGHT)
class_weights[phi_class_ids] = phi_weights
class_weights

array([8.08184622e+00, 1.44696100e-01, 1.26060137e+02, 2.12954255e+00,
       2.37019448e+00, 4.99095238e+01, 4.07594444e+03, 3.24949065e+00,
       2.74659329e-01, 1.11162121e+02, 1.16900892e+01, 4.21649425e+02,
       1.00000000e-04])

### 3. Initial training

In [23]:
from transformers import DataCollatorForTokenClassification

# Pads each batch of variable-length windows (the last window of a note is shorter), labels included.
data_collator = DataCollatorForTokenClassification(tokenizer)

In [24]:
import torch
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

class_labels = wfu_dataset['train'].features['label'].names
# Use the first GPU when one is visible; fall back to CPU instead of failing on .to('cuda:0').
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

model = AutoModelForTokenClassification.from_pretrained(
    "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
    num_labels=len(class_labels),
    # Store the label names in the model config so saved checkpoints and predictions
    # report 'AGE', 'DATE', ..., 'NORMAL' instead of generic LABEL_<i> names.
    id2label=dict(enumerate(class_labels)),
    label2id={name: idx for idx, name in enumerate(class_labels)},
).to(device)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [25]:
import sys

# MyTrainer lives in the local mytrainer.py. To pick up edits to that file without
# restarting the kernel, drop the cached module before re-importing:
#     del sys.modules['mytrainer']
from mytrainer import MyTrainer

In [26]:
from datetime import datetime

In [27]:
# Preview of the timestamp format used to name the training log file.
f"{datetime.now():%m_%d_%Y_%H_%M_%S}"

'07_30_2024_17_50_39'

In [28]:
# First short run: 2 epochs, evaluate every 200 steps, log every 50 steps.
training_args = TrainingArguments(
    output_dir="testing_model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="epoch",
    # load_best_model_at_end=True would also need save_strategy to match eval_strategy
    # (transformers requires matching save/eval strategies for that option).
    push_to_hub=False,
    logging_dir='huggingface_logs',
    log_level='info',
    logging_steps=50,
)

# MyTrainer is called with the plain transformers.Trainer arguments plus `class_weights`
# (cast to float32) and a timestamped `log_file`. Dropping those two keyword arguments
# and using Trainer gives the plain baseline setup.
log_file_name = f'logs_{datetime.now():%m_%d_%Y_%H_%M_%S}.txt'
trainer = MyTrainer(
    class_weights=class_weights.astype(np.float32),
    log_file=log_file_name,
    model=model,
    args=training_args,
    train_dataset=wfu_dataset_tokenized["train"],
    eval_dataset=wfu_dataset_tokenized["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

The following columns in the training set don't have a corresponding argument in `BertForTokenClassification.forward` and have been ignored: overflow_to_sample_mapping. If overflow_to_sample_mapping are not expected by `BertForTokenClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 13,260
  Num Epochs = 2
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 1,658
  Number of trainable parameters = 108,901,645
  attn_output = torch.nn.functional.scaled_dot_product_attention(
  3%|▎         | 50/1658 [00:24<12:22,  2.17it/s]

inside log: {'loss': 1.6654, 'grad_norm': 28.366779327392578, 'learning_rate': 1.9396863691194212e-05} 50
{'loss': 1.6654, 'grad_norm': 28.366779327392578, 'learning_rate': 1.9396863691194212e-05, 'epoch': 0.06}


  6%|▌         | 100/1658 [00:47<12:00,  2.16it/s]

inside log: {'loss': 0.6261, 'grad_norm': 0.7614365816116333, 'learning_rate': 1.879372738238842e-05} 100
{'loss': 0.6261, 'grad_norm': 0.7614365816116333, 'learning_rate': 1.879372738238842e-05, 'epoch': 0.12}


  9%|▉         | 150/1658 [01:10<11:40,  2.15it/s]

inside log: {'loss': 0.5824, 'grad_norm': 3.824090003967285, 'learning_rate': 1.819059107358263e-05} 150
{'loss': 0.5824, 'grad_norm': 3.824090003967285, 'learning_rate': 1.819059107358263e-05, 'epoch': 0.18}


 12%|█▏        | 200/1658 [01:33<11:12,  2.17it/s]The following columns in the evaluation set don't have a corresponding argument in `BertForTokenClassification.forward` and have been ignored: overflow_to_sample_mapping. If overflow_to_sample_mapping are not expected by `BertForTokenClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 3279
  Batch size = 16


inside log: {'loss': 0.5639, 'grad_norm': 0.7461574077606201, 'learning_rate': 1.758745476477684e-05} 200
{'loss': 0.5639, 'grad_norm': 0.7461574077606201, 'learning_rate': 1.758745476477684e-05, 'epoch': 0.24}




<class 'numpy.ndarray'>


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 12%|█▏        | 200/1658 [02:03<11:12,  2.17it/s]

inside log: {'eval_loss': 0.5026214122772217, 'eval_f1': 0.1477926675303802, 'eval_report': '              precision    recall  f1-score   support\n\n         AGE      0.003     0.995     0.007       390\n        DATE      0.387     0.997     0.558     19018\n       EMAIL      0.000     0.000     0.000        32\n    HOSPITAL      0.009     0.955     0.019      1066\n       IDNUM      0.094     0.995     0.171      1529\n    INITIALS      0.065     0.675     0.119        83\n   IPADDRESS      0.000     0.000     0.000         0\n    LOCATION      0.747     0.627     0.682      1266\n        NAME      0.198     0.993     0.330     11005\n       OTHER      0.000     0.000     0.000        23\n       PHONE      0.009     0.924     0.017       158\n         URL      0.000     0.000     0.000         6\n      NORMAL      1.000     0.009     0.018    328966\n\n    accuracy                          0.101    363542\n   macro avg      0.193     0.552     0.148    363542\nweighted avg      0.934

 15%|█▌        | 250/1658 [02:26<10:49,  2.17it/s]  

inside log: {'loss': 0.4854, 'grad_norm': 0.13080379366874695, 'learning_rate': 1.698431845597105e-05} 250
{'loss': 0.4854, 'grad_norm': 0.13080379366874695, 'learning_rate': 1.698431845597105e-05, 'epoch': 0.3}


 18%|█▊        | 300/1658 [02:49<10:31,  2.15it/s]

inside log: {'loss': 0.5352, 'grad_norm': 2.6872591972351074, 'learning_rate': 1.638118214716526e-05} 300
{'loss': 0.5352, 'grad_norm': 2.6872591972351074, 'learning_rate': 1.638118214716526e-05, 'epoch': 0.36}


 19%|█▊        | 308/1658 [02:53<10:25,  2.16it/s]

KeyboardInterrupt: 

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Per-class test support (token counts, copied from the step-200 eval report above) against a
# per-class score. NOTE(review): yy does not match that report; presumably per-class F1 from a
# later evaluation that is not recorded in this notebook -- TODO note which run/metric.
plotted_labels = ['AGE', 'DATE', 'EMAIL', 'HOSPITAL', 'IDNUM', 'INITIALS', 'LOCATION', 'NAME', 'PHONE']
xx = [390, 19018, 32, 1066, 1529, 83, 1266, 11005, 158]
yy = [0.907, 0.994, 0.270, 0.879, 0.984, 0.839, 0.920, 0.987, 0.936]

fig, ax = plt.subplots()
ax.plot(xx, yy, 'o')
# Label each point with its PHI class so the plot stands on its own.
for label_name, support, score in zip(plotted_labels, xx, yy):
    ax.annotate(label_name, (support, score), textcoords='offset points', xytext=(4, 4), fontsize=8)
# Supports span roughly three orders of magnitude (32 to 19018), so use a log x-axis.
ax.set_xscale('log')
ax.set(
    xlabel='test support (tokens, log scale)',
    ylabel='per-class score (presumably F1)',
    title='Per-class score vs. test support',
)
plt.show()