In [1]:
# Show the kernel's current working directory.
# NOTE(review): the bare `data`, `model` and `utils` imports in the next cell
# presumably resolve relative to this directory — confirm.
# NOTE(review): the saved output exposes an absolute local path; consider
# clearing it before sharing the notebook.
%pwd

'/home/rydevera3/data-science/text-mining-titans/nbme-patient-notes'

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from torch.utils.data import DataLoader

from data import NBMEDataset, load_training_data
from model import NBMEModel
from utils import (
    Configuration,
    create_labels_for_scoring,
    get_character_probabilities,
    get_predictions,
    get_thresholded_sequences,
    get_score,
    training_function,
    validation_function,
)

  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# Tunable settings for this run, gathered in one place
VALIDATION_FOLD = 4    # rows with this fold_number are held out for validation
BATCH_SIZE = 8
LEARNING_RATE = 1e-5

# Load the configuration and the training data
config = Configuration()
data = load_training_data(config=config)

# Pick the compute device. The second GPU ('cuda:1') is preferred, but that
# ordinal only exists when at least two GPUs are visible; on a single-GPU
# machine fall back to 'cuda:0' instead of failing with an invalid device
# ordinal, and use the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')

# Split into training and validation frames by fold number
train_df = data.loc[data['fold_number'] != VALIDATION_FOLD].reset_index(drop=True)
valid_df = data.loc[data['fold_number'] == VALIDATION_FOLD].reset_index(drop=True)

# Raw note texts and reference labels used later when scoring the validation fold
valid_patient_notes_texts = valid_df['pn_history'].values
valid_labels = valid_df['location'].apply(create_labels_for_scoring)

# Create the datasets
training_dataset = NBMEDataset(train_df, config)
valid_dataset = NBMEDataset(valid_df, config)

# Data loaders: training is shuffled and drops the last incomplete batch;
# validation keeps every sample in its original order so that predictions
# line up row-for-row with valid_df.
train_loader = DataLoader(
    training_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=True,
    drop_last=False,
)

# Build the model and move it to the selected device
model = NBMEModel(config=config)
model = model.to(device=device)

# reduction='none' keeps the element-wise loss; how it is masked/reduced is up
# to training_function (not visible in this notebook).
criterion = nn.BCEWithLogitsLoss(reduction='none')
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# Train for 15 epochs; after each epoch, run validation and print the score
# for the held-out fold.
for epoch in range(15):
    # One pass over the training loader (no learning-rate scheduler is used)
    training_function(
        config=config,
        train_loader=train_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=None,
        device=device,
    )

    # Get the probability outputs
    predictions, labels = validation_function(
        config, valid_loader, model, device
    )

    # Reshape the predictions and labels to one row per validation sample and
    # one column per token position (config.max_length columns).
    # NOTE: `labels` is not used for scoring below; the score is computed
    # against `valid_labels`.
    samples = len(valid_df)
    predictions = predictions.reshape((samples, config.max_length))
    labels = labels.reshape((samples, config.max_length))

    # Get character probabilities
    # (presumably maps the per-token predictions onto the characters of each
    # note text — confirm against utils.get_character_probabilities)
    character_probabilities = get_character_probabilities(
        valid_patient_notes_texts,
        predictions,
        config
    )

    # Threshold the character probabilities, convert them to predictions and
    # score them against the reference labels of the validation fold
    results = get_thresholded_sequences(character_probabilities)
    preds = get_predictions(results)
    score = get_score(valid_labels, preds)
    print(score)

Some weights of the model checkpoint at microsoft/deberta-v3-base were not used when initializing DebertaV2Model: ['lm_predictions.lm_head.LayerNorm.bias', 'mask_predictions.dense.weight', 'mask_predictions.dense.bias', 'mask_predictions.classifier.bias', 'lm_predictions.lm_head.dense.bias', 'lm_predictions.lm_head.LayerNorm.weight', 'mask_predictions.LayerNorm.weight', 'lm_predictions.lm_head.dense.weight', 'mask_predictions.LayerNorm.bias', 'mask_predictions.classifier.weight', 'lm_predictions.lm_head.bias']
- This IS expected if you are initializing DebertaV2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaV2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
1430it [06:36,  3.60i

0.8210371540210479


1430it [06:37,  3.60it/s]
358it [00:53,  6.66it/s]


0.8478710364381493


1430it [06:35,  3.61it/s]
358it [00:53,  6.65it/s]


0.8576069851261484


1430it [06:36,  3.61it/s]
358it [00:53,  6.68it/s]


0.8540243210576947


1430it [06:36,  3.61it/s]
358it [00:53,  6.67it/s]


0.861051495412729


1430it [06:36,  3.60it/s]
358it [00:53,  6.65it/s]


0.8611132817587456


1430it [06:37,  3.60it/s]
358it [00:53,  6.68it/s]


0.8582820743478156


1430it [06:36,  3.61it/s]
358it [00:53,  6.67it/s]


0.8627476022017599


1430it [06:36,  3.61it/s]
358it [00:53,  6.67it/s]


0.8644190018343982


1430it [06:36,  3.61it/s]
358it [00:53,  6.67it/s]


0.8628322586480933


1430it [06:36,  3.61it/s]
358it [00:53,  6.66it/s]


0.8597984127181889


1430it [06:36,  3.61it/s]
358it [00:53,  6.67it/s]


0.8581267554035901


1430it [06:36,  3.60it/s]
358it [00:53,  6.66it/s]


0.8631076724693747


1430it [06:36,  3.61it/s]
358it [00:53,  6.66it/s]


0.8627729130933804


1430it [06:36,  3.60it/s]
358it [00:53,  6.65it/s]


0.8618494851337895


In [4]:
from data import build_nbme_labels

In [5]:
# Preview the first rows of the validation fold
valid_df.head()

Unnamed: 0,id,case_num,pn_num,feature_num,annotation,location,feature_text,pn_history,annotation_length,fold_number
0,00016_000,0,16,0,[dad with recent heart attcak],[696 724],Family-history-of-MI-OR-Family-history-of-myoc...,HPI: 17yo M presents with palpitations. Patien...,1,4
1,00016_001,0,16,1,"[mom with ""thyroid disease]",[668 693],Family-history-of-thyroid-disorder,HPI: 17yo M presents with palpitations. Patien...,1,4
2,00016_002,0,16,2,[chest pressure],[203 217],Chest-pressure,HPI: 17yo M presents with palpitations. Patien...,1,4
3,00016_003,0,16,3,"[intermittent episodes, episode]","[70 91, 176 183]",Intermittent-symptoms,HPI: 17yo M presents with palpitations. Patien...,2,4
4,00016_004,0,16,4,[felt as if he were going to pass out],[222 258],Lightheaded,HPI: 17yo M presents with palpitations. Patien...,1,4


In [6]:
# Pick a single validation row and pull out the fields needed to rebuild its labels
index = 1
text, location, annotation_length = (
    valid_df[column].values[index]
    for column in ('pn_history', 'location', 'annotation_length')
)

In [7]:
# Build the labels for the selected sample from its text and annotation data.
# NOTE: this rebinds `labels`, which previously held the reshaped validation
# labels produced inside the training loop.
labels = build_nbme_labels(config, text, annotation_length, location)

In [8]:
# Tokenize the selected note, padded to config.max_length, and keep the offset
# mapping so each token can be traced back to a span of the text.
# NOTE(review): no `truncation` option is passed, so a note longer than
# config.max_length would presumably not be cut — confirm this matches what
# build_nbme_labels does.
tokenizer_options = dict(
    add_special_tokens=True,
    max_length=config.max_length,
    padding='max_length',
    return_offsets_mapping=True,
)
encoded = config.tokenizer(text, **tokenizer_options)

In [9]:
# Display the full note text of the selected sample
text

'HPI: 17yo M presents with palpitations. Patient reports 3-4 months of intermittent episodes of "heart beating/pounding out of my chest." 2 days ago during a soccer game had an episode, but this time had chest pressure and felt as if he were going to pass out (did not lose conciousness). Of note patient endorses abusing adderall, primarily to study (1-3 times per week). Before recent soccer game, took adderrall night before and morning of game. Denies shortness of breath, diaphoresis, fevers, chills, headache, fatigue, changes in sleep, changes in vision/hearing, abdominal paun, changes in bowel or urinary habits. \r\nPMHx: none\r\nRx: uses friends adderrall\r\nFHx: mom with "thyroid disease," dad with recent heart attcak\r\nAll: none\r\nImmunizations: up to date\r\nSHx: Freshmen in college. Endorses 3-4 drinks 3 nights / week (on weekends), denies tabacco, endorses trying marijuana. Sexually active with girlfriend x 1 year, uses condoms'

In [10]:
import numpy as np
# Show the tokens whose label equals 1, as a sanity check that the labeled
# tokens match the annotated phrase.
# (assumes `labels` has one entry per token in `encoded` — TODO confirm)
np.array(encoded.tokens())[labels == 1]

array(['▁mom', '▁with', '▁"', 'thyroid', '▁disease'], dtype='<U13')

In [11]:
# Display the raw location annotation for the selected sample
# (presumably 'start end' character offsets into `text` — confirm)
location

['668 693']

In [12]:
# Display the (start, end) offsets returned by the tokenizer for every token.
# NOTE(review): this prints the whole mapping, one pair per token position;
# consider slicing it to the region of interest to keep the output short.
encoded['offset_mapping']

[(0, 0),
 (0, 2),
 (2, 3),
 (3, 4),
 (4, 7),
 (7, 9),
 (9, 11),
 (11, 20),
 (20, 25),
 (25, 38),
 (38, 39),
 (39, 47),
 (47, 55),
 (55, 57),
 (57, 58),
 (58, 59),
 (59, 66),
 (66, 69),
 (69, 82),
 (82, 91),
 (91, 94),
 (94, 96),
 (96, 101),
 (101, 109),
 (109, 110),
 (110, 118),
 (118, 122),
 (122, 125),
 (125, 128),
 (128, 134),
 (134, 135),
 (135, 136),
 (136, 138),
 (138, 143),
 (143, 147),
 (147, 154),
 (154, 156),
 (156, 163),
 (163, 168),
 (168, 172),
 (172, 175),
 (175, 183),
 (183, 184),
 (184, 188),
 (188, 193),
 (193, 198),
 (198, 202),
 (202, 208),
 (208, 217),
 (217, 221),
 (221, 226),
 (226, 229),
 (229, 232),
 (232, 235),
 (235, 240),
 (240, 246),
 (246, 249),
 (249, 254),
 (254, 258),
 (258, 260),
 (260, 263),
 (263, 267),
 (267, 272),
 (272, 281),
 (281, 285),
 (285, 286),
 (286, 287),
 (287, 290),
 (290, 295),
 (295, 303),
 (303, 312),
 (312, 320),
 (320, 329),
 (329, 330),
 (330, 340),
 (340, 343),
 (343, 349),
 (349, 351),
 (351, 352),
 (352, 353),
 (353, 354),
 (354