Knowledge Distillation using MedNLI as an intermediate dataset
1) train a BERT model on AP and evaluate it 
2) use trained model to label MedNLI data with AP classes
3) train a BERT model on the dataset from (2)
4) evaluate model from (3) on the AP test set (as in (1))

In [None]:
import json
import re
import numpy as np
from pathlib import Path

from datasets import load_dataset

import torch

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report


In [None]:
import os

# Expose only GPUs 1 and 2 to this process, then use CUDA when it is usable.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
torch.cuda.empty_cache()

import sys

# `set_seed` is imported from the project-local ../utils/ directory.
sys.path.append('../utils/')
from utils import set_seed

set_seed(42)

# Pretrained checkpoint to fine-tune, and the layer-sharing scheme that
# `share_layer` interprets ('none', or '<prefix>-<k>' such as 'even-6').
lm_name = "bert-base-cased"
layer_sharing_mode = 'even-6'

# Root directory for Trainer checkpoints and the MedNLI prediction files.
OUT_PATH = f'../output/{lm_name}/'

Model training/loading

In [None]:
# When True, the AP model is loaded from AP_BEST_DIR instead of being trained.
load_checkpoint_ap = False
# When True, the MedNLI model is loaded from MEDNLI_BEST_DIR/<layer_sharing_mode>/
# instead of being trained.
load_checkpoint_mednli = False 
# When True, the MedNLI model is trained/evaluated on the oversampled splits.
mednli_oversampled = False
# Directories the trained AP / MedNLI models are saved to and loaded from.
AP_BEST_DIR = './AP/best/'
MEDNLI_BEST_DIR = './MedNLI/best/'

Load AP Dataset

In [None]:
# TODO: adapt path to data files
# One CSV file per split: ./ap/train.csv, ./ap/dev.csv, ./ap/test.csv
data_files = {split: f'./ap/{split}.csv' for split in ("train", "dev", "test")}

data = load_dataset("csv", data_files=data_files)

# Integer ids for the relation classes are derived from the training split only.
LABEL_ENCODER = LabelEncoder()
LABEL_ENCODER.fit(data['train']['Relation'])
# Show the class names in encoder order.
LABEL_ENCODER.classes_


AP Preprocessing

In [None]:
# Matches bracketed '[** ... **]' spans
# (presumably de-identification markers in the notes; verify against the data).
PHI_PATTERN = re.compile(r'\[\*\*[^\]]+\*\*\]')

# Cased tokenizer for `lm_name`, extended with the '@@PHI@@' placeholder that
# the preprocessing functions substitute for every PHI_PATTERN match.
tokenizer = AutoTokenizer.from_pretrained(lm_name, do_lower_case=False)
tokenizer.add_tokens(['@@PHI@@'], special_tokens=True)

def preprocess(row):
    """Prepare one AP row for tokenization.

    Every PHI_PATTERN match in the 'Assessment' and 'Plan Subsection' texts is
    replaced by '@@PHI@@'. The 'Relation' string is encoded to its integer id
    with LABEL_ENCODER; rows with an empty/missing relation get label 0.

    Returns a dict with the keys 'Assessment', 'Plan Subsection' and 'label'.
    """
    assessment = PHI_PATTERN.sub('@@PHI@@', row['Assessment'])
    plan_subsection = PHI_PATTERN.sub('@@PHI@@', row['Plan Subsection'])

    relation = row['Relation']
    if relation:
        label = LABEL_ENCODER.transform([relation])[0]
    else:
        # no label available (during test): fall back to a default label of 0
        label = 0

    return {
        'Assessment': assessment,
        'Plan Subsection': plan_subsection,
        'label': label,
    }

def tokenize(examples):
    """Tokenize 'Assessment' / 'Plan Subsection' as a text pair.

    Inputs longer than 512 tokens are truncated. Returns whatever the
    module-level `tokenizer` returns for the pair.
    """
    first_texts = examples['Assessment']
    second_texts = examples['Plan Subsection']
    return tokenizer(first_texts, second_texts, truncation=True, max_length=512)

In [None]:
# Row-wise cleaning/label encoding first, then batched tokenization.
data = data.map(preprocess).map(tokenize, batched=True)

In [None]:
# Sanity check: decode the first training example's input ids back to text.
tokenizer.decode(data['train'][0]['input_ids'])

Train Model

In [None]:

# Sequence classifier with one output per class known to LABEL_ENCODER.
# The embedding matrix is resized to the tokenizer length because '@@PHI@@'
# was added to the tokenizer vocabulary.
model = AutoModelForSequenceClassification.from_pretrained(lm_name, num_labels=len(LABEL_ENCODER.classes_))
model.resize_token_embeddings(len(tokenizer))


from sklearn.metrics import precision_recall_fscore_support

def compute_metrics(pred):
    """Macro-averaged precision, recall and F1 for a prediction object.

    `pred.label_ids` holds the gold class ids; the predicted class is the
    argmax over the last axis of `pred.predictions`.

    Returns a dict with the keys 'f1', 'precision' and 'recall'.
    """
    gold = pred.label_ids
    predicted = pred.predictions.argmax(-1)
    macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
        gold, predicted, average='macro'
    )
    return {'f1': macro_f1, 'precision': macro_precision, 'recall': macro_recall}


# Fine-tuning configuration for the AP model: log, evaluate and checkpoint
# every 20 steps; the checkpoint with the highest dev 'f1' is loaded at the end.
training_args = TrainingArguments(
    output_dir=OUT_PATH,
    # optimisation
    learning_rate=5e-5,
    num_train_epochs=3,
    weight_decay=0,
    warmup_ratio=0,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=512,
    # logging / evaluation / checkpointing cadence
    logging_strategy="steps",
    logging_steps=20,
    evaluation_strategy="steps",
    eval_steps=20,
    save_strategy="steps",
    save_steps=20,
    save_total_limit=2,
    # best-model selection
    load_best_model_at_end=True,
    metric_for_best_model='f1',
    greater_is_better=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=data["train"],
    eval_dataset=data["dev"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

if not load_checkpoint_ap:
    # Fine-tune on AP and persist the resulting model.
    trainer.train()
    trainer.save_model(AP_BEST_DIR)
else:
    # Reuse the previously saved AP model instead of training.
    loaded_model_ap = AutoModelForSequenceClassification.from_pretrained(
        AP_BEST_DIR, local_files_only=True
    )
    loaded_model_ap.to(device)
    trainer.model = loaded_model_ap
    model = loaded_model_ap

Evaluation on AP

In [None]:
def predict(trainer, data, split_name: str):
    """Print a per-class classification report for one split.

    Runs `trainer.predict` on `data[split_name]`, maps the predicted (argmax)
    and gold label ids back to their string names with LABEL_ENCODER, and
    prints sklearn's `classification_report` with three digits.

    Parameters:
        trainer: object exposing `predict(dataset)` whose result has
            `predictions` and `label_ids`.
        data: mapping from split name to dataset.
        split_name: key of the split to evaluate.

    Returns None; the report is only printed.
    """
    print(f'Metrics for {split_name}')

    preds = trainer.predict(data[split_name])
    # Only the argmax is needed here; class probabilities are not reported.
    y_pred = LABEL_ENCODER.inverse_transform(np.argmax(preds.predictions, axis=-1))
    y_true = LABEL_ENCODER.inverse_transform(preds.label_ids)

    print(f"Evaluate {split_name}\n")
    print(classification_report(y_true, y_pred, digits=3))

In [None]:
# Classification report of the AP model on the AP test split.
predict(trainer, data, split_name='test')


Loading MedNLI

In [None]:
# TODO: adapt path to data files
# Only the train and dev splits of MedNLI are loaded.
mednli_data_files = {
    split: f'./physionet.org/files/mednli/1.0.0/mli_{split}_v1.jsonl'
    for split in ("train", "dev")
}
mednli_data = load_dataset("json", data_files=mednli_data_files)

def mednli_preprocess(row):
    """Map one MedNLI row onto the column names used for the AP data.

    'sentence1' becomes 'Assessment' and 'sentence2' becomes 'Plan Subsection',
    each with PHI_PATTERN matches replaced by '@@PHI@@'. The 'label' is a
    constant placeholder of 0; the real labels are assigned later from the AP
    model's predictions.

    Returns a dict with the keys 'Assessment', 'Plan Subsection' and 'label'.
    """
    first_sentence = PHI_PATTERN.sub('@@PHI@@', row['sentence1'])
    second_sentence = PHI_PATTERN.sub('@@PHI@@', row['sentence2'])
    return {
        'Assessment': first_sentence,
        'Plan Subsection': second_sentence,
        'label': 0,  # placeholder, no AP label exists for MedNLI rows yet
    }

# Row-wise column mapping first, then batched tokenization (same tokenizer as AP).
mednli_data = mednli_data.map(mednli_preprocess).map(tokenize, batched=True)

In [None]:
# Number of rows in the MedNLI train split.
len(mednli_data['train'])

In [None]:
# Number of rows in the MedNLI dev split.
len(mednli_data['dev'])

In [None]:
def mednli_predict(trainer, data, split_name: str, out_path):
    """Label one MedNLI split with `trainer` and write the predictions to disk.

    Two files are written into `out_path` (created if missing, including
    parent directories):
      * y_pred_mednli<split_name>.txt - one predicted label string per line
      * y_pred_proba_mednli_<split_name>.jsonl - one JSON list of class
        probabilities per line

    Parameters:
        trainer: object exposing `predict(dataset)` whose result has `predictions`.
        data: mapping from split name to dataset.
        split_name: key of the split to label.
        out_path: output directory (str or Path).

    Returns:
        (y_pred, y_pred_proba): the label strings produced by
        LABEL_ENCODER.inverse_transform on the argmax predictions, and a torch
        tensor with the softmax over dim 1 of the predictions.
    """
    print(f"Labeling MedNLI_{split_name}...\n")
    out_path = Path(out_path)
    # parents=True: the target is a nested directory that may not exist yet
    # (e.g. when no training run has created it), which made a plain mkdir fail.
    out_path.mkdir(parents=True, exist_ok=True)

    preds = trainer.predict(data[split_name])
    y_pred = LABEL_ENCODER.inverse_transform(np.argmax(preds.predictions, axis=-1))
    y_pred_proba = torch.nn.functional.softmax(torch.tensor(preds.predictions), dim=1)

    # NOTE(review): unlike the probabilities file below, this file name has no
    # underscore before the split name; kept unchanged in case it is read elsewhere.
    with open(out_path / f"y_pred_mednli{split_name}.txt", "w") as fout:
        fout.writelines(f"{label}\n" for label in y_pred)

    with open(out_path / f"y_pred_proba_mednli_{split_name}.jsonl", "w") as fout:
        for probabilities in y_pred_proba.tolist():
            json.dump(probabilities, fout)
            fout.write('\n')

    return y_pred, y_pred_proba

# Label both MedNLI splits with the AP model: hard labels are the predicted
# class names, soft labels the per-class probabilities.
hard_labels, soft_labels = mednli_predict(
    trainer, mednli_data, split_name='train', out_path=OUT_PATH
)
hard_labels_dev, soft_labels_dev = mednli_predict(
    trainer, mednli_data, split_name='dev', out_path=OUT_PATH
)


In [None]:
from collections import Counter
# Share (in %) of the four most frequent predicted AP labels in MedNLI train.
mednli_label_cntr = Counter(hard_labels)
n_train_labels = len(hard_labels)
for label, count in mednli_label_cntr.most_common(4):
    print(label + '      : ', np.round(count / n_train_labels * 100, 2))

In [None]:
# Share (in %) of the four most frequent predicted AP labels in MedNLI dev.
mednli_label_cntr_dev = Counter(hard_labels_dev)
n_dev_labels = len(hard_labels_dev)
for label, count in mednli_label_cntr_dev.most_common(4):
    print(label + '      : ', np.round(count / n_dev_labels * 100, 2))

In [None]:
# Encode the predicted label strings back into integer class ids.
transformed_labels_train, transformed_labels_dev = (
    LABEL_ENCODER.transform(labels) for labels in (hard_labels, hard_labels_dev)
)


In [None]:
def _relabel(dataset, new_labels):
    """Return `dataset` with 'label' set to `new_labels[idx]` for every row index."""
    return dataset.map(
        lambda example, idx: {'label': new_labels[idx]},
        with_indices=True,
    )


# Copy of the MedNLI splits whose placeholder labels are replaced by the
# integer AP labels predicted for them.
updated_mednli_data = mednli_data.copy()
updated_mednli_data['train'] = _relabel(mednli_data['train'], transformed_labels_train)
updated_mednli_data['dev'] = _relabel(updated_mednli_data['dev'], transformed_labels_dev)

Oversampling

In [None]:
from imblearn.over_sampling import RandomOverSampler
from datasets import Dataset

def oversample(data, split='train'):
    """Add a randomly oversampled copy of `data[split]` to `data`.

    The split is converted to a pandas frame and resampled on its 'label'
    column with RandomOverSampler(sampling_strategy='all'). The resampled
    frame gets its 'label' column back plus a 'gold_label' column holding the
    label strings from LABEL_ENCODER, and is stored as a Dataset under
    `split + '_oversampled'`.

    `data` is modified in place and also returned.
    """
    frame = data[split].to_pandas()

    feature_cols = frame.columns.tolist()
    feature_cols.remove('label')

    sampler = RandomOverSampler(sampling_strategy='all')
    resampled, resampled_labels = sampler.fit_resample(frame[feature_cols], frame['label'])

    # re-attach the integer labels and their string names
    resampled['label'] = resampled_labels
    resampled['gold_label'] = LABEL_ENCODER.inverse_transform(resampled['label'])

    data[split + '_oversampled'] = Dataset.from_pandas(resampled)
    return data

# Optionally build oversampled copies of both splits and report the resulting
# label distribution of the dev split.
if mednli_oversampled:
    updated_mednli_data = oversample(updated_mednli_data, split='dev')
    updated_mednli_data = oversample(updated_mednli_data, split='train')
    # Printed explicitly: an expression nested inside `if` is not auto-displayed
    # by the notebook, so the counts were previously computed and discarded.
    print(updated_mednli_data['dev_oversampled'].to_pandas()[['label', 'gold_label']].value_counts())

In [None]:
# Label distribution of the oversampled train split.
if mednli_oversampled:
    # Printed explicitly: an expression nested inside `if` is not auto-displayed
    # by the notebook, so the counts were previously computed and discarded.
    print(updated_mednli_data['train_oversampled'].to_pandas()[['label', 'gold_label']].value_counts())

In [None]:
# Second classifier, initialised from the same pretrained checkpoint and with
# the same number of classes as the AP model; embeddings are resized to the
# tokenizer length because '@@PHI@@' was added to the vocabulary.
mednli_model = AutoModelForSequenceClassification.from_pretrained(lm_name, num_labels=len(LABEL_ENCODER.classes_))
mednli_model.resize_token_embeddings(len(tokenizer))


In [None]:
def share_layer(layer_num, sharing_mode):
    '''
    Decide whether the layer with index 'layer_num' is shared under 'sharing_mode'.

    'sharing_mode' is either 'none' (no layer is shared) or a string whose last
    '-'-separated field is an integer k, e.g. 'even-6'. In the latter case a
    layer is shared when its index is even and smaller than 2*k.

    Returns True if the layer should be shared, False otherwise.
    '''
    if sharing_mode == 'none':
        return False

    shared_count = int(sharing_mode.split('-')[-1])
    upper_bound = 2 * shared_count
    return layer_num % 2 == 0 and layer_num < upper_bound

# Extracts the layer index from a parameter name containing '.layer.<index>.'.
# This replaces a fixed 19-character slice, which only works when the name
# starts with a prefix of exactly that length (e.g. 'bert.encoder.layer.').
_LAYER_INDEX_PATTERN = re.compile(r'\.layer\.(\d+)\.')


def _encoder_layer_index(param_name):
    """Return the layer index found in `param_name`, or None if it has none."""
    match = _LAYER_INDEX_PATTERN.search(param_name)
    return int(match.group(1)) if match else None


ap_model = model.state_dict()

# Copy the AP model's weights into every layer selected by `share_layer` and
# freeze those parameters so that training on MedNLI leaves them untouched.
for name, p in mednli_model.named_parameters():
    layer_num = _encoder_layer_index(name)
    if layer_num is not None and share_layer(layer_num, layer_sharing_mode):
        p.data.copy_(ap_model[name].data)
        # freeze layer
        p.requires_grad = False

mednli_model_params = mednli_model.state_dict()

# Sanity check: shared layers must be frozen, all other layers trainable.
for name, p in mednli_model.named_parameters():
    layer_num = _encoder_layer_index(name)
    if layer_num is None:
        continue
    if share_layer(layer_num, layer_sharing_mode):
        assert not p.requires_grad
    else:
        assert p.requires_grad

In [None]:
# Output directory and datasets depend on whether the oversampled splits are used.
if mednli_oversampled:
    o_dir = '../output/{}/{}/bert-mednli-oversampled/'.format(lm_name, layer_sharing_mode)
    mednli_train_set = updated_mednli_data['train_oversampled']
    mednli_dev_set = updated_mednli_data['dev_oversampled']
else:
    o_dir = '../output/{}/{}/bert-mednli/'.format(lm_name, layer_sharing_mode)
    mednli_train_set = updated_mednli_data['train']
    mednli_dev_set = updated_mednli_data['dev']

# Same schedule as the AP run, but a single epoch: log, evaluate and checkpoint
# every 20 steps and load the checkpoint with the highest dev 'f1' at the end.
training_args_mednli = TrainingArguments(
    output_dir=o_dir,
    num_train_epochs=1,
    learning_rate=5e-5,
    warmup_ratio=0,
    weight_decay=0,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=512,
    logging_strategy="steps",
    logging_steps=20,
    evaluation_strategy="steps",
    eval_steps=20,
    save_strategy="steps",
    save_steps=20,
    load_best_model_at_end=True,
    metric_for_best_model='f1',
    greater_is_better=True,
    save_total_limit=2,
)

mednli_trainer = Trainer(
    model=mednli_model,
    args=training_args_mednli,
    train_dataset=mednli_train_set,
    eval_dataset=mednli_dev_set,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# Saved models live in one sub-directory of MEDNLI_BEST_DIR per sharing mode.
# NOTE(review): the directory does not depend on `mednli_oversampled`, so an
# oversampled training run saves into the same place as a non-oversampled one.
mednli_best_dir = MEDNLI_BEST_DIR + layer_sharing_mode + '/'

if load_checkpoint_mednli:
    print('Loading checkpoint...')
    if mednli_oversampled:
        raise ValueError(
            'Loading a MedNLI checkpoint is not supported when mednli_oversampled is True.'
        )
    # Only these sharing modes can be loaded from a saved checkpoint.
    if layer_sharing_mode not in ('none', 'even-6', 'even-3'):
        raise ValueError(
            'No MedNLI checkpoint for layer_sharing_mode={!r}; '
            "expected 'none', 'even-6' or 'even-3'.".format(layer_sharing_mode)
        )

    loaded_model_mednli = AutoModelForSequenceClassification.from_pretrained(
        mednli_best_dir, local_files_only=True
    )
    loaded_model_mednli.to(device)
    mednli_trainer.model = loaded_model_mednli
else:
    mednli_trainer.train()
    mednli_trainer.save_model(mednli_best_dir)


In [None]:
# Classification report of the MedNLI-trained model on the AP test split.
predict(mednli_trainer, data, split_name='test')


In [None]:
# Classification report of the AP model on the same AP test split.
predict(trainer, data, split_name='test')
