In [None]:
import numpy as np
import re
from pathlib import Path
import random
import pandas as pd

from datasets import load_dataset

import torch

from transformers import (
    BertModel,
    AutoTokenizer,
)

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report


In [None]:
import os
import sys

# Pin this process to GPU 1. The variable only takes effect if CUDA has not been
# initialised yet, which is why it is set before the first torch.cuda call.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.cuda.empty_cache()

# `set_seed` comes from the sibling ../utils directory.
sys.path.append('../utils')
from utils import set_seed
set_seed(42)

# Checkpoints that can be probed (local directories or a hub id).
lm_ap = './AP/best/'
mednli_none = './MedNLI/best/none/'
mednli_even_six = './MedNLI/best/even-6/'
mednli_even_three = './MedNLI/best/even-3/'

lms = {
    'teacher': lm_ap,
    'student-none': mednli_none,
    'student-even-6': mednli_even_six,
    'student-even-3': mednli_even_three,
    'pretrained': 'bert-base-cased',
}

# ---- probing configuration ----
lm_name = 'student-even-6'                      # key into `lms`, e.g. 'student-even-3' or 'pretrained'
represntation_type = 'cls'                      # pooling: 'cls' or 'mean'
replacement_strategy = 'other_random_entity'    # 'entity_type' or 'other_random_entity'
max_instances_per_class = 100
split = 'train'
num_tokens_to_replace = 2                       # max number of entities replaced per text

# (short name, checkpoint path or hub id)
lm_name_path = (lm_name, lms[lm_name])

# NOTE(review): built from the checkpoint *path*, e.g. '../output/./MedNLI/best/even-6//';
# not used in this notebook -- confirm whether lm_name_path[0] was intended.
OUT_PATH = '../output/{}/'.format(lm_name_path[1])
AP_TRAIN_PATH = './ap/train.csv'

In [None]:
import datasets
from datasets import disable_caching

# Turn off the datasets transform cache so the .map() calls below are recomputed on
# every run instead of being reloaded from previously written cache files.
datasets.disable_caching()

In [None]:
def get_cls_repr(model, dl, device=None):
    """Return the last-layer position-0 ([CLS]) vector for every example in ``dl``.

    Args:
        model: torch module called as ``model(**batch, output_hidden_states=True)``;
            its output must expose ``hidden_states``, a sequence of per-layer tensors.
        dl: iterable of batches, each a dict of tensors accepted by ``model``.
        device: where each batch is moved before the forward pass. Defaults to the
            device of the model's first parameter ('cpu' if it has none).

    Returns:
        np.ndarray with one row per example, in loader order, or ``None`` when
        ``dl`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'

    model.eval()
    chunks = []
    with torch.no_grad():
        for batch in dl:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch, output_hidden_states=True)
            last_layer = outputs.hidden_states[-1]  # (batch, seq, hidden)
            # position 0 is the [CLS] token for BERT-style tokenizers
            chunks.append(last_layer[:, 0, :].cpu().numpy())

    # A single concatenate avoids the quadratic copying of np.append in a loop.
    return np.concatenate(chunks, axis=0) if chunks else None

def get_mean_repr(model, dl, device=None):
    """Return the attention-mask-weighted mean of last-layer token vectors per example.

    Args:
        model: torch module called as ``model(**batch, output_hidden_states=True)``;
            its output must expose ``hidden_states``, a sequence of per-layer tensors.
        dl: iterable of batches, each a dict of tensors that includes ``attention_mask``.
        device: where each batch is moved before the forward pass. Defaults to the
            device of the model's first parameter ('cpu' if it has none).

    Returns:
        np.ndarray with one row per example, in loader order, or ``None`` when
        ``dl`` yields no batches. A row whose mask is all zeros becomes a zero vector.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'

    model.eval()
    chunks = []
    with torch.no_grad():
        for batch in dl:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch, output_hidden_states=True)
            last_layer = outputs.hidden_states[-1]            # (batch, seq, hidden)
            mask = batch['attention_mask'].unsqueeze(-1)      # (batch, seq, 1)

            # sum the hidden representations of non-padding tokens only
            summed = (last_layer * mask).sum(dim=1)           # (batch, hidden)
            # clamp guards an all-padding row against 0/0 -> NaN
            counts = mask.sum(dim=1).clamp(min=1)             # (batch, 1)
            chunks.append((summed / counts).cpu().numpy())

    # A single concatenate avoids the quadratic copying of np.append in a loop.
    return np.concatenate(chunks, axis=0) if chunks else None

Data and model loading

Load AP Dataset

In [None]:
# Keep only the first row for each distinct Assessment text and persist that
# deduplicated frame; load_dataset then reads it back as a DatasetDict with a
# single "train" split.
train_df = pd.read_csv(AP_TRAIN_PATH)
train_unique = train_df.drop_duplicates(subset='Assessment', keep="first")
# Write train_unique (not train_df) so the deduplication actually takes effect;
# index=False avoids an extra "Unnamed: 0" column in the reloaded dataset.
train_unique.to_csv('train_unique.csv', index=False)

data_files = {
    "train": './train_unique.csv',
}
data = load_dataset("csv", data_files=data_files)

# Encode the Relation labels of the probed split as integers.
LABEL_ENCODER = LabelEncoder()
LABEL_ENCODER.fit(data[split]['Relation'])
LABEL_ENCODER.classes_


In [None]:
import stanza
nlp = stanza.Pipeline('en', package='mimic', processors={'ner': 'i2b2'})

# Collect every entity surface form found in the training Assessments and Plan
# Subsections, grouped by lower-cased entity type (sets avoid the quadratic
# list concatenation of accumulating with `+`).
entity_surface_forms = {}
for column in ('Assessment', 'Plan Subsection'):
    for text in data['train'][column]:
        for ent in nlp(text).ents:
            entity_surface_forms.setdefault(ent.type.lower(), set()).add(ent.text)

# Sort instead of list(set(...)): str hashing is randomised per interpreter
# (PYTHONHASHSEED), so set order -- and therefore which entity random.choice()
# picks under a fixed seed -- would otherwise change between kernel restarts.
entities_vocab = {etype: sorted(forms) for etype, forms in entity_surface_forms.items()}

In [None]:
# Rebuild the entity vocabulary from scratch: lower-cased entity type -> sorted
# list of distinct surface forms seen in the training Assessments and Plan Subsections.
entity_surface_forms = {}
for column in ('Assessment', 'Plan Subsection'):
    for text in data['train'][column]:
        for ent in nlp(text).ents:
            entity_surface_forms.setdefault(ent.type.lower(), set()).add(ent.text)

# Sort instead of list(set(...)): str hashing is randomised per interpreter
# (PYTHONHASHSEED), so set order -- and therefore which entity random.choice()
# picks under a fixed seed -- would otherwise change between kernel restarts.
entities_vocab = {etype: sorted(forms) for etype, forms in entity_surface_forms.items()}

AP Preprocessing

In [None]:
# Load the tokenizer that belongs to the selected checkpoint, without lower-casing.
tokenizer = AutoTokenizer.from_pretrained(lm_name_path[1], do_lower_case=False)

# preprocess() replaces de-identification tags with '@@PHI@@'; register it for the
# vanilla hub checkpoint (the fine-tuned checkpoints presumably already have it -- confirm).
if lm_name_path[1] == 'bert-base-cased':
    tokenizer.add_tokens(['@@PHI@@'], special_tokens=True)

def remove_entities(x, replacement_strategy = 'entity_type', nlp_pipeline=None, vocab=None, max_replacements=None):
    """Replace clinical entities in ``x``, scanning from the last entity to the first.

    Each entity found by the NER pipeline is replaced either by its lower-cased type
    name (``'entity_type'``) or by a *different* surface form of the same type drawn
    at random from ``vocab`` (``'other_random_entity'``). Entities whose text already
    equals their type name are skipped, as are entities for which ``vocab`` offers no
    distinct alternative.

    Args:
        x: text to perturb.
        replacement_strategy: ``'entity_type'`` or ``'other_random_entity'``.
        nlp_pipeline: callable returning a doc with ``.ents`` (each having ``text``,
            ``type`` and ``start_char``); defaults to the global ``nlp``.
        vocab: mapping of lower-cased entity type -> list of surface forms; defaults
            to the global ``entities_vocab``.
        max_replacements: return as soon as exactly this many replacements were made;
            defaults to the global ``num_tokens_to_replace``. A value below 1 is never
            reached, so every eligible entity is replaced.

    Returns:
        The perturbed text (``x`` itself if nothing could be replaced).

    Raises:
        ValueError: unknown ``replacement_strategy``, or a replacement left the text unchanged.
    """
    if replacement_strategy not in ('entity_type', 'other_random_entity'):
        raise ValueError(f'unknown replacement_strategy: {replacement_strategy!r}')
    if nlp_pipeline is None:
        nlp_pipeline = nlp
    if vocab is None:
        vocab = entities_vocab
    if max_replacements is None:
        max_replacements = num_tokens_to_replace

    doc = nlp_pipeline(x)
    original = x
    replaced_tokens = 0

    # Right-to-left, so an edit never shifts the offsets of entities still to process.
    for e in reversed(doc.ents):
        entity_type = e.type.lower()
        if e.text.lower() == entity_type:
            continue

        if replacement_strategy == 'entity_type':
            replacement = entity_type
        else:
            pool = vocab.get(entity_type, [])
            # Without this check the rejection loop below never terminates when every
            # candidate equals the entity itself (e.g. a type with one surface form).
            if not any(c.replace('\n', ' ').lower() != e.text.lower() for c in pool):
                continue
            while True:
                replacement = random.choice(pool).replace('\n', ' ')
                if replacement.lower() != e.text.lower():
                    break

        # Replace the first occurrence within the suffix that starts at the entity,
        # i.e. the entity itself.
        to_replace = original[e.start_char:]
        replaced = original[:e.start_char] + to_replace.replace(e.text, replacement, 1)
        if replaced == original:
            raise ValueError(f'replacing entity {e.text!r} did not change the text')
        original = replaced
        replaced_tokens += 1

        if replaced_tokens == max_replacements:
            return original

    return original

In [None]:
PHI_PATTERN = re.compile(r'\[\*\*[^\]]+\*\*\]')

def preprocess(row):   
    d = {
        'Assessment': PHI_PATTERN.sub('@@PHI@@', row['Assessment']),
        'Plan Subsection': PHI_PATTERN.sub('@@PHI@@', row['Plan Subsection']),
    }

    d['assessment w/o ents'] = remove_entities(d['Assessment'], replacement_strategy)
    d['plan w/o ents'] = remove_entities(d['Plan Subsection'], replacement_strategy)

    if row['Relation']:
        d['label'] = LABEL_ENCODER.transform([row['Relation']])[0]
    else:
        d['label'] = 0 # if we have no label (during test), we just use a default label of 0
    
    return d

def _tokenize_pair(examples, first_col, second_col, max_length=512):
    """Tokenize two text columns as sentence pairs, truncated and padded to ``max_length``."""
    return tokenizer(
        examples[first_col],
        examples[second_col],
        truncation=True,
        max_length=max_length,
        padding='max_length'
    )

def tokenize_pos(examples):
    """Tokenize the original (PHI-masked) Assessment / Plan Subsection pair."""
    return _tokenize_pair(examples, 'Assessment', 'Plan Subsection')

def tokenize_neg(examples):
    """Tokenize the entity-replaced Assessment / Plan Subsection pair."""
    return _tokenize_pair(examples, 'assessment w/o ents', 'plan w/o ents')

In [None]:
data.cleanup_cache_files()
# Bind the preprocessed dataset to a new name instead of overwriting `data`: the
# cell stays idempotent (re-running does not re-preprocess already-masked text) and
# earlier cells that read `data` still see the raw rows.
data_processed = data.map(preprocess, load_from_cache_file=False)
data_pos = data_processed.map(tokenize_pos, batched=True, load_from_cache_file=False)
data_neg = data_processed.map(tokenize_neg, batched=True, load_from_cache_file=False)

In [None]:
# Sanity check: decode the first original (positive) example.
tokenizer.decode(token_ids=data_pos[split][0]['input_ids'])

In [None]:
# Sanity check: decode the first entity-replaced (negative) example.
tokenizer.decode(token_ids=data_neg[split][0]['input_ids'])

In [None]:
# https://huggingface.co/docs/datasets/v1.11.0/quicktour.html
dataset_pos = data_pos[split].map(lambda examples: {'labels': examples['label']}, batched=True)
dataset_pos.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask'])
dataloader_pos = torch.utils.data.DataLoader(dataset_pos, batch_size=32)

dataset_neg = data_neg[split].map(lambda examples: {'labels': examples['label']}, batched=True)
dataset_neg.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask'])
dataloader_neg = torch.utils.data.DataLoader(dataset_neg, batch_size=32)

Load Model




In [None]:
# Load the encoder for the selected checkpoint and place it on the configured device.
model = BertModel.from_pretrained(lm_name_path[1])
model = model.to(device)

# For the vanilla checkpoint the tokenizer gained '@@PHI@@', so the embedding
# matrix must grow to the new vocabulary size.
if lm_name_path[1] == 'bert-base-cased':
    model.resize_token_embeddings(len(tokenizer))



In [None]:
def get_cls_repr(model, dl, device=None):
    """Return the last-layer position-0 ([CLS]) vector for every example in ``dl``.

    Args:
        model: torch module called as ``model(**batch, output_hidden_states=True)``;
            its output must expose ``hidden_states``, a sequence of per-layer tensors.
        dl: iterable of batches, each a dict of tensors accepted by ``model``.
        device: where each batch is moved before the forward pass. Defaults to the
            device of the model's first parameter ('cpu' if it has none).

    Returns:
        np.ndarray with one row per example, in loader order, or ``None`` when
        ``dl`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'

    model.eval()
    chunks = []
    with torch.no_grad():
        for batch in dl:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch, output_hidden_states=True)
            last_layer = outputs.hidden_states[-1]  # (batch, seq, hidden)
            # position 0 is the [CLS] token for BERT-style tokenizers
            chunks.append(last_layer[:, 0, :].cpu().numpy())

    # A single concatenate avoids the quadratic copying of np.append in a loop.
    return np.concatenate(chunks, axis=0) if chunks else None

def get_mean_repr(model, dl, device=None):
    """Return the attention-mask-weighted mean of last-layer token vectors per example.

    Args:
        model: torch module called as ``model(**batch, output_hidden_states=True)``;
            its output must expose ``hidden_states``, a sequence of per-layer tensors.
        dl: iterable of batches, each a dict of tensors that includes ``attention_mask``.
        device: where each batch is moved before the forward pass. Defaults to the
            device of the model's first parameter ('cpu' if it has none).

    Returns:
        np.ndarray with one row per example, in loader order, or ``None`` when
        ``dl`` yields no batches. A row whose mask is all zeros becomes a zero vector.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'

    model.eval()
    chunks = []
    with torch.no_grad():
        for batch in dl:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch, output_hidden_states=True)
            last_layer = outputs.hidden_states[-1]            # (batch, seq, hidden)
            mask = batch['attention_mask'].unsqueeze(-1)      # (batch, seq, 1)

            # sum the hidden representations of non-padding tokens only
            summed = (last_layer * mask).sum(dim=1)           # (batch, hidden)
            # clamp guards an all-padding row against 0/0 -> NaN
            counts = mask.sum(dim=1).clamp(min=1)             # (batch, 1)
            chunks.append((summed / counts).cpu().numpy())

    # A single concatenate avoids the quadratic copying of np.append in a loop.
    return np.concatenate(chunks, axis=0) if chunks else None

In [None]:
# Embed the original (pos) and entity-replaced (neg) examples with the same model,
# pooled according to `represntation_type`.
repr_fns = {'cls': get_cls_repr, 'mean': get_mean_repr}
if represntation_type not in repr_fns:
    raise ValueError(
        f"unknown represntation_type {represntation_type!r}; expected one of {sorted(repr_fns)}"
    )
reprs_pos = repr_fns[represntation_type](model, dataloader_pos)
reprs_neg = repr_fns[represntation_type](model, dataloader_neg)

In [None]:
# Positive and negative representations are sliced in parallel below, so both
# passes over the data must have produced the same number of rows and dimensions.
assert reprs_pos.shape == reprs_neg.shape, (reprs_pos.shape, reprs_neg.shape)
reprs_pos.shape

Training Set Probing

In [None]:
# combinging pos and neg data
num_instances = min(max_instances_per_class, len(reprs_pos))
X = np.concatenate((reprs_pos[:num_instances], reprs_neg[:num_instances]), axis=0)
y = [1]*num_instances + [0]*num_instances

In [None]:
# Seed shared by this shuffle and the cross-validation splitter below.
random_state = 23

from sklearn.utils import shuffle
X_shuffled, y_shuffled = shuffle(X, y, random_state=random_state)

In [None]:
# Number of rows in the probing set (both classes).
X_shuffled.shape[0]

In [None]:
from sklearn.model_selection import cross_validate, StratifiedKFold, KFold
from sklearn.linear_model import LogisticRegression

# L1-regularised logistic-regression probe ('saga' supports the l1 penalty).
clf = LogisticRegression(solver='saga', penalty='l1', max_iter=500)

# 4-fold stratified CV, seeded for reproducibility.
folds = StratifiedKFold(n_splits=4, random_state=random_state, shuffle=True)

lr_scores = cross_validate(
    clf,
    X_shuffled,
    y_shuffled,
    scoring=['f1_macro', 'accuracy'],
    cv=folds,
    return_estimator=True,
)


In [None]:
# Summarise the cross-validation folds as percentages rounded to 2 decimals.
fold_f1 = lr_scores['test_f1_macro']
print(fold_f1)

mean_f1_macro = np.round(np.mean(fold_f1) * 100, 2)
std_f1_macro = np.round(np.std(fold_f1) * 100, 2)
mean_test_accuracy = np.round(np.mean(lr_scores['test_accuracy']) * 100, 2)

# `mean_accuracy` / `std` are the values written to the results CSV. Despite the
# name they hold macro-F1 (the quantity this cell has always reported), so that
# rows appended to the CSV stay comparable.
mean_accuracy = mean_f1_macro
std = std_f1_macro

print(f'macro-F1: {mean_f1_macro} +/- {std_f1_macro} | accuracy: {mean_test_accuracy}')

In [None]:
# Append one comma-separated row describing this probing run.
result_row = [split, lm_name_path[0], replacement_strategy, represntation_type,
              num_instances, mean_accuracy, std]
with open('./probing_results.csv', 'a') as results_file:
    results_file.write(','.join(str(value) for value in result_row) + '\n')