In [1]:
# add autoreload
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import transformers
import torch

import os
import sys
import gc

from typing import List, Tuple, Dict, Union

In [2]:
from setfit import SetFitModel
from setfit import SetFitTrainer
from datasets import Dataset, DatasetDict
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix, roc_auc_score

from transformers import TrainingArguments, Trainer
from transformers import pipeline
from transformers import DataCollatorWithPadding
from transformers import EvalPrediction

from torchinfo import summary
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import evaluate

from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer

import benedict
import random

In [3]:
# Seed every random number generator used in this notebook (PyTorch CPU and
# CUDA, NumPy, Python's `random`) with the same fixed value.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Request deterministic cuDNN kernels and switch off cuDNN's benchmark mode,
# favouring run-to-run reproducibility over speed.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    # Report the visible GPUs and the CUDA allocator counters at start-up.
    print("There are %d GPU(s) available." % torch.cuda.device_count())
    print("We will use the GPU:", torch.cuda.get_device_name(0))
    print(f"Memory Allocated: {torch.cuda.memory_allocated()}")
    print(f"Max memory Allocated: {torch.cuda.max_memory_allocated()}")
    print(f"Memory reserved: {torch.cuda.memory_reserved()}")
    print(f"Max memory reserved: {torch.cuda.max_memory_reserved()}")
else:
    print("No GPU available, using the CPU instead.")

There are 1 GPU(s) available.
We will use the GPU: NVIDIA RTX A4000
Memory Allocated: 0
Max memory Allocated: 0
Memory reserved: 0
Max memory reserved: 0


In [5]:
# --- Run configuration ---
# Expand abbreviations in the report text (uses deabber together with ABBREVIATIONS).
deabbreviate = False
# Pass the report text through echo_utils.report_filter and drop rows left without text.
filter_reports = True
# Target to model; also the file stem of the JSONL label file that is loaded.
Class = 'tricuspid_regurgitation' # other targets: lv_dil, rv_dil, pe, aortic_regurgitation, diastolic_dysfunction, lv_syst_func
# Phrase lists handed to echo_utils.report_filter as flag_terms / save_terms.
# NOTE(review): presumably flag terms mark reports without usable content and
# save terms rescue a flagged report — confirm against echo_utils.
FLAG_TERMS = ['uitslag zie medische status', 'zie status', 'zie verslag status', 'slecht echovenster', 'echo overwegen', 'ge echo',
              'geen echovenster', 'geen beoordeelbaar echo', 'geen beoordeelbare echo', 'verslag op ic']
SAVE_TERMS = ['goed', 'geen', 'normaal', 'normale']
# When True, every label is expanded to the label set listed for it in MULTILABELS.
use_multilabel = False

# Label -> labels that are switched on together with it (only read when
# use_multilabel is True).
# NOTE(review): 'Present' does not occur among the labels found in the data
# (Normal, No label, Mild, Moderate, Severe), so translating this table to
# ids via the label map would raise KeyError — confirm before enabling.
MULTILABELS = {'Mild': ['Mild', 'Present'], 
               'Severe': ['Severe', 'Present'],
               'Moderate': ['Moderate', 'Present'],
               'Normal': ['Normal'],
               'No label': ['No label'],
               'Present': ['Present'],
               }


In [6]:
#Add the src folder to the path
sys.path.append(os.path.abspath(os.path.join('..', 'src')))
import deabber, echo_utils

In [7]:
# Load the abbreviation table from YAML; the optional deabbreviation step
# reads its ['nl']['echocardiogram'] entry.
ABBREVIATIONS = benedict.benedict("../assets/abbreviations.yml")

In [8]:
plt.style.use('ggplot')
def plot_history(history, val=0):
    """Draw accuracy and loss per epoch in two side-by-side panels.

    Parameters
    ----------
    history : object exposing a ``history`` mapping
        The mapping must hold the per-epoch lists ``'accuracy'`` and
        ``'loss'``; when ``val`` equals 1 it must also hold
        ``'val_accuracy'`` and ``'val_loss'``.
    val : int, default 0
        Pass 1 to overlay the validation curves on the training curves.
    """
    show_val = (val == 1)
    log = history.history

    # (train key, validation key, panel title, train legend, validation legend)
    panels = (
        ('accuracy', 'val_accuracy', 'Accuracy', 'Training accuracy', 'Validation accuracy'),
        ('loss', 'val_loss', 'Loss', 'Training loss', 'Validation loss'),
    )

    # Fetch every series before any drawing, so a missing key raises before
    # a figure is created.
    train_series, val_series = {}, {}
    for train_key, val_key, _, _, _ in panels:
        train_series[train_key] = log[train_key]
        if show_val:
            val_series[train_key] = log[val_key]

    # Epochs are numbered from 1; both panels share the same x values.
    epochs = range(1, len(train_series['accuracy']) + 1)

    plt.figure(figsize=(12, 5))
    for position, (train_key, _, heading, train_label, val_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epochs, train_series[train_key], 'b', label=train_label)
        if show_val:
            plt.plot(epochs, val_series[train_key], 'r', label=val_label)
        plt.ylabel(train_key)
        plt.xlabel('epoch')
        plt.title(heading)
        plt.legend()

In [9]:
def compute_metrics_binomial(logits_and_labels, averaging='macro'):
    """Compute accuracy, F1, precision and recall for single-label predictions.

    Parameters
    ----------
    logits_and_labels : tuple
        ``(logits, labels)``; the predicted class is the argmax of ``logits``
        over the last axis and is compared against ``labels``.
    averaging : str, default 'macro'
        Passed as ``average`` to the scikit-learn F1/precision/recall scorers.

    Returns
    -------
    dict
        Keys ``'accuracy'``, ``'f1_score'``, ``'precision'`` and ``'recall'``.
        The historical, misspelt key ``'recal'`` is still included with the
        same value so existing consumers keep working.
    """
    logits, labels = logits_and_labels
    predictions = np.argmax(logits, axis=-1)
    acc = np.mean(predictions == labels)
    f1 = f1_score(labels, predictions, average=averaging)
    prec = precision_score(labels, predictions, average=averaging)
    rec = recall_score(labels, predictions, average=averaging)
    return {
        'accuracy': acc,
        'f1_score': f1,
        'precision': prec,
        'recall': rec,
        # Deprecated alias: the key was originally misspelt.
        'recal': rec,
    }

# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    """Score multi-label predictions given raw logits and indicator targets.

    Parameters
    ----------
    predictions : array-like
        Raw model outputs (logits); a sigmoid is applied element-wise.
        Expected layout is (batch_size, num_labels).
    labels : array-like
        Ground-truth indicator matrix with the same layout as ``predictions``.
    threshold : float, default 0.5
        A label counts as predicted when its probability is >= ``threshold``.

    Returns
    -------
    dict
        Macro and weighted F1, precision, recall and ROC AUC, plus
        ``'accuracy'`` (``accuracy_score`` on the thresholded predictions).
        Both ROC AUC entries are ``None`` when scikit-learn cannot compute
        them (it raises ``ValueError``, e.g. when a label column of
        ``labels`` contains only one class).
    """
    # Logits -> per-label probabilities; convert to NumPy once so that the
    # thresholding and all scikit-learn scorers receive plain arrays.
    probs = torch.sigmoid(torch.Tensor(predictions)).numpy()
    # Threshold into a float 0/1 indicator matrix.
    y_pred = (probs >= threshold).astype(float)
    y_true = labels

    f1_macro = f1_score(y_true=y_true, y_pred=y_pred, average='macro')
    f1_weighted = f1_score(y_true=y_true, y_pred=y_pred, average='weighted')
    prec_macro = precision_score(y_true=y_true, y_pred=y_pred, average='macro')
    prec_weighted = precision_score(y_true=y_true, y_pred=y_pred, average='weighted')
    recall_macro = recall_score(y_true=y_true, y_pred=y_pred, average='macro')
    recall_weighted = recall_score(y_true=y_true, y_pred=y_pred, average='weighted')

    try:
        roc_auc_weighted = roc_auc_score(y_true, probs, average='weighted')
        roc_auc_macro = roc_auc_score(y_true, probs, average='macro')
    except ValueError:
        # ROC AUC is undefined for this batch of targets; report it as
        # missing instead of aborting the evaluation. Only ValueError is
        # caught so that interrupts and genuine bugs still surface.
        roc_auc_weighted = None
        roc_auc_macro = None

    accuracy = accuracy_score(y_true, y_pred)
    # return as dictionary
    metrics = {'f1_macro': f1_macro,
               'f1_weighted': f1_weighted,
               'prec_macro': prec_macro,
               'prec_weighted': prec_weighted,
               'recall_macro': recall_macro,
               'recall_weighted': recall_weighted,
               'roc_auc_macro': roc_auc_macro,
               'roc_auc_weighted': roc_auc_weighted,
               'accuracy': accuracy}
    return metrics

def compute_metrics(p: EvalPrediction):
    """Metric callback: score an ``EvalPrediction`` with ``multi_label_metrics``.

    When ``p.predictions`` is a tuple only its first element is scored;
    ``p.label_ids`` supplies the targets.
    """
    raw_output = p.predictions
    if isinstance(raw_output, tuple):
        raw_output = raw_output[0]
    return multi_label_metrics(predictions=raw_output, labels=p.label_ids)
     


In [10]:
def one_hot_encode(labels, num_classes):
    """Return a ``(len(labels), num_classes)`` tensor with one 1 per row.

    ``labels`` is a list or 1D tensor of class indices; row ``i`` of the
    result is zero everywhere except at column ``labels[i]``.
    """
    n_rows = len(labels)
    encoded = torch.zeros((n_rows, num_classes))
    # Paired (row, column) indexing: one assignment per sample.
    encoded[torch.arange(n_rows), labels] = 1
    return encoded

def _mlabel_tuple_creator(x: List[int],
                          multilabels:Dict[int,List[int]],
                          num_classes: int=None)\
                          ->List[Tuple[int,...]]:
                              
    res = [(_sc for _sc in multilabels[sc]) for sc in x]
    return res

def multi_hot_encoding(x: List[int], 
                       multilabels: Union[Dict[int,List[int]], None]=None,
                       num_classes: int=None)\
                           ->torch.Tensor:
    """Encode class ids as indicator rows over ``num_classes`` columns.

    With ``multilabels`` left as ``None`` each row is a plain one-hot vector
    (see ``one_hot_encode``). Otherwise every id is first expanded through
    ``multilabels`` and the resulting label sets are binarized with
    ``MultiLabelBinarizer`` over the classes ``0 .. num_classes - 1``.
    """
    # Guard clause: no implication map means exactly one active class per row.
    if multilabels is None:
        return one_hot_encode(x, num_classes=num_classes)

    binarizer = MultiLabelBinarizer(classes=range(num_classes))
    label_sets = _mlabel_tuple_creator(x, multilabels)
    indicator = binarizer.fit_transform(label_sets)
    return torch.Tensor(indicator)

# Load documents

In [11]:
# Tokenizer of the CLTL/MedRoBERTa.nl checkpoint, the same checkpoint that is
# loaded for classification further down.
tokenizer = AutoTokenizer.from_pretrained("CLTL/MedRoBERTa.nl")

In [12]:
# Make the research-data folder the working directory; the relative paths used
# to load the labels and the train/test hashes resolve against it.
# NOTE(review): hard-coded absolute path — only valid where this share is mounted.
os.chdir('T://lab_research/RES-Folder-UPOD/Echo_label/E_ResearchData/2_ResearchData')

In [13]:
# Read the labelled reports for the selected target (`Class`) from a JSON Lines file.
labeled_documents = pd.read_json(f"./echo_doc_labels/{Class}.jsonl", lines=True)

In [14]:
# Class balance of the raw labels.
labeled_documents.label.value_counts()

label
No label    2545
Normal      1801
Mild         408
Moderate     187
Severe        59
Name: count, dtype: int64

In [15]:
# Map every distinct label string to an integer id. Ids are assigned in the
# order in which the labels first appear in the file (pandas `unique()` order),
# so the mapping depends on row order.
Target_maps = {Label:i for i,Label in enumerate(labeled_documents['label'].unique())}

In [16]:
# Display the label -> id mapping.
Target_maps

{'Normal': 0, 'No label': 1, 'Mild': 2, 'Moderate': 3, 'Severe': 4}

In [17]:
# Load the train/test hashes: one CSV per split. Labelled reports are later
# assigned to a split by matching their hash against the `input_hash` column.
test_hashes = pd.read_csv('./test_echoid.csv', sep=',')
train_hashes = pd.read_csv('./train_echoid.csv', sep=',')
print(train_hashes.columns)

# Number of distinct hashes per split.
print(f"Train hashes: {train_hashes.input_hash.nunique()}")
print(f"Test hashes: {test_hashes.input_hash.nunique()}")

Index(['ECHO_StudyID', 'ECHO_StudyID.1', 'input_hash', 'task_hash'], dtype='object')
Train hashes: 96026
Test hashes: 24051


In [18]:
# Label names in id order (the dict was filled in id order).
print(Target_maps.keys())

dict_keys(['Normal', 'No label', 'Mild', 'Moderate', 'Severe'])


In [19]:
# Build the modelling frame: one row per report with its text, integer label
# and the hash that assigns it to the train or test split.
# A copy is taken so that `labeled_documents` keeps its original column names
# and string labels. Renaming and mapping it in place made this cell yield
# all-NaN labels when re-run, and broke the cells that read
# `labeled_documents.label`.
DF = labeled_documents.copy()
DF.columns = ['sentence', 'labels', '_input_hash']

# Lookup tables between label strings and integer ids; the same tables are
# handed to the model configuration later on.
label2id = Target_maps
id2label = {v:k for k,v in label2id.items()}
num_labels = len(label2id)
DF['labels'] = DF['labels'].map(label2id)


if filter_reports:
    # The first element of report_filter's result replaces the text column;
    # rows whose filtered text is missing are dropped.
    DF = DF.assign(sentence = echo_utils.report_filter(DF.sentence, 
                                            flag_terms=FLAG_TERMS, 
                                            save_terms=SAVE_TERMS)[0])
    DF = DF.loc[DF.sentence.notna()]

if deabbreviate:
    # Expand abbreviations using the Dutch echocardiogram abbreviation table.
    DeAbber = deabber.deabber(model_type='sbert', 
                              abbreviations=ABBREVIATIONS['nl']['echocardiogram'], 
                              min_sim=0.5, top_k=10)
    DF = DF.assign(sentence=DeAbber.deabb(DF.sentence.values, TokenRadius=3))


In [20]:
# Translate MULTILABELS from label names to integer ids, or disable the
# multi-label expansion altogether (None) when use_multilabel is off.
# NOTE(review): MULTILABELS mentions 'Present', which is not a key of the
# label map built from the data; with use_multilabel=True the lookup below
# would raise KeyError — confirm before enabling.
_multilabels = (
    {
        label2id[name]: [label2id[implied] for implied in implied_names]
        for name, implied_names in MULTILABELS.items()
    }
    if use_multilabel
    else None
)

In [21]:

# TODO: make proper
# Split the labelled reports with the predefined hash lists; only the text and
# the integer label are kept.
DFtrain = DF.loc[DF._input_hash.isin(train_hashes.input_hash), ['sentence', 'labels']]
DFtest = DF.loc[DF._input_hash.isin(test_hashes.input_hash), ['sentence', 'labels']]

# Per-split class counts.
print("Train labels:")
print(DFtrain.labels.value_counts())
print("Test labels:")
print(DFtest.labels.value_counts())

# Wrap both splits as HuggingFace datasets.
TrainSet = Dataset.from_pandas(DFtrain)
TestSet = Dataset.from_pandas(DFtest)

HF_DataSet = DatasetDict(
    {'train' : TrainSet,
     'test': TestSet,
    }
)

# Tokenise in batches: texts are truncated to at most 256 tokens and padded to
# the longest text of their batch; the raw text column is dropped afterwards.
Tokenized_DataSet = HF_DataSet.map(lambda batch: tokenizer(batch, truncation=True, 
                                                                  padding=True, 
                                                                  max_length=256),
                                  input_columns='sentence',
                                  batched=True,
                                  remove_columns=['sentence'])

Train labels:
labels
1    1959
0    1425
2     324
3     147
4      48
Name: count, dtype: int64
Test labels:
labels
1    494
0    342
2     74
3     34
4     11
Name: count, dtype: int64


Map:   0%|          | 0/3903 [00:00<?, ? examples/s]

Map:   0%|          | 0/955 [00:00<?, ? examples/s]

In [22]:
# Replace the integer class ids by indicator vectors (one-hot, or multi-hot
# when `_multilabels` is set).
# NOTE(review): the cell overwrites Tokenized_DataSet, so re-running it would
# feed already-encoded labels back into multi_hot_encoding — run it once after
# the tokenisation cell.
Tokenized_DataSet = (Tokenized_DataSet
                      #.map(lambda x : {"float_labels": x["labels"].to(torch.float)}, remove_columns=["labels"])
                      .map(lambda x: {"labels": 
                          multi_hot_encoding(x['labels'], 
                                             multilabels=_multilabels, 
                                             num_classes=num_labels)}, 
                           batched=True, remove_columns=['labels']))                      
                      #.rename_column("float_labels", "labels"))

# Return only the model inputs and the labels, as torch tensors.
Tokenized_DataSet.set_format("torch", 
                             columns=['input_ids', 'attention_mask', 'labels'])

# Collator that pads the examples of each batch when the batch is assembled.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Map:   0%|          | 0/3903 [00:00<?, ? examples/s]

Map:   0%|          | 0/955 [00:00<?, ? examples/s]

In [23]:
# References for the multi-label fine-tuning setup:
# https://colab.research.google.com/drive/1aue7x525rKy6yYLqqt-5Ll96qjQvpqS7#scrollTo=1eVCRpcLUW-y
# https://github.com/NielsRogge/Transformers-Tutorials/blob/master/BERT/Fine_tuning_BERT_(and_friends)_for_multi_label_text_classification.ipynb
# Load the pretrained encoder with a sequence-classification head sized to the
# number of labels. problem_type='multi_label_classification' selects the
# per-label (sigmoid / binary cross-entropy) loss, which is why the labels
# were converted to indicator vectors above.
medroberta_clf = AutoModelForSequenceClassification.from_pretrained("CLTL/MedRoBERTa.nl", 
                                                num_labels=len(id2label.keys()),
                                                problem_type='multi_label_classification',
                                                id2label=id2label,
                                                label2id=label2id)


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at CLTL/MedRoBERTa.nl and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [24]:
# Parameter count per module.
# NOTE: `summary` is already imported in the import cell at the top; this
# re-import is redundant.
from torchinfo import summary
summary(medroberta_clf)

Layer (type:depth-idx)                                       Param #
RobertaForSequenceClassification                             --
├─RobertaModel: 1-1                                          --
│    └─RobertaEmbeddings: 2-1                                --
│    │    └─Embedding: 3-1                                   39,936,000
│    │    └─Embedding: 3-2                                   394,752
│    │    └─Embedding: 3-3                                   768
│    │    └─LayerNorm: 3-4                                   1,536
│    │    └─Dropout: 3-5                                     --
│    └─RobertaEncoder: 2-2                                   --
│    │    └─ModuleList: 3-6                                  85,054,464
├─RobertaClassificationHead: 1-2                             --
│    └─Linear: 2-3                                           590,592
│    └─Dropout: 2-4                                          --
│    └─Linear: 2-5                                           3,845
To

## MedRoBERTa.nl -- Training setup

In [25]:
# Output directory handed to the Trainer.
# NOTE(review): hard-coded absolute path — only valid where this share is mounted.
train_dir = "T://laupodteam/AIOS/Bram/data/tmp"
# Evaluation metric (a key returned by multi_label_metrics) used to select the
# best checkpoint.
metric_name = 'f1_macro'
# Evaluate and save a checkpoint after every epoch, so that the best epoch
# according to `metric_name` can be reloaded once training has finished.
training_args = TrainingArguments(output_dir=train_dir,
                                  evaluation_strategy='epoch',
                                  save_strategy='epoch',
                                  num_train_epochs=20,
                                  learning_rate=5e-5,
                                  per_device_train_batch_size=16,
                                  weight_decay=0.01,
                                  per_device_eval_batch_size=10,
                                  load_best_model_at_end=True, 
                                  metric_for_best_model=metric_name)

In [26]:
# NOTE: the test split is also the evaluation set used during training, and
# the training arguments reload the best checkpoint by evaluation metric.
# The test metrics therefore also drive model selection and are
# optimistically biased; a separate validation split would avoid this.
trainer = Trainer(medroberta_clf,
                  training_args,
                  train_dataset = Tokenized_DataSet["train"],
                  eval_dataset = Tokenized_DataSet["test"],
                  tokenizer=tokenizer,
                  data_collator=data_collator,
                  compute_metrics=compute_metrics)

## MedRoBERTa.nl -- run Training

In [27]:
# Fine-tune the model; evaluation metrics are reported after every epoch.
trainer.train()

  0%|          | 0/4880 [00:00<?, ?it/s]

  0%|          | 0/96 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.11559130251407623, 'eval_f1_macro': 0.6445338809682433, 'eval_f1_weighted': 0.902515077239281, 'eval_prec_macro': 0.6286947683208852, 'eval_prec_weighted': 0.9087971285681161, 'eval_recall_macro': 0.6661128382490611, 'eval_recall_weighted': 0.900523560209424, 'eval_roc_auc_macro': 0.9656328408714867, 'eval_roc_auc_weighted': 0.9764462306888325, 'eval_accuracy': 0.8858638743455497, 'eval_runtime': 9.9861, 'eval_samples_per_second': 95.633, 'eval_steps_per_second': 9.613, 'epoch': 1.0}


  0%|          | 0/96 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.0751369521021843, 'eval_f1_macro': 0.695415088424039, 'eval_f1_weighted': 0.9304613915276896, 'eval_prec_macro': 0.6749854875120939, 'eval_prec_weighted': 0.9305719576821819, 'eval_recall_macro': 0.7206748067429182, 'eval_recall_weighted': 0.9319371727748691, 'eval_roc_auc_macro': 0.9747869192280119, 'eval_roc_auc_weighted': 0.9868927931969388, 'eval_accuracy': 0.9298429319371728, 'eval_runtime': 9.8322, 'eval_samples_per_second': 97.13, 'eval_steps_per_second': 9.764, 'epoch': 2.0}
{'loss': 0.139, 'grad_norm': 0.6982333660125732, 'learning_rate': 4.487704918032787e-05, 'epoch': 2.05}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.07805199921131134, 'eval_f1_macro': 0.8817279877392761, 'eval_f1_weighted': 0.9485011802618889, 'eval_prec_macro': 0.868329950795917, 'eval_prec_weighted': 0.9526238119881009, 'eval_recall_macro': 0.9018917183003872, 'eval_recall_weighted': 0.9465968586387434, 'eval_roc_auc_macro': 0.9877065914569269, 'eval_roc_auc_weighted': 0.9881468250958666, 'eval_accuracy': 0.9445026178010472, 'eval_runtime': 9.7942, 'eval_samples_per_second': 97.506, 'eval_steps_per_second': 9.802, 'epoch': 3.0}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.08167070895433426, 'eval_f1_macro': 0.8821114311233842, 'eval_f1_weighted': 0.9423341980712105, 'eval_prec_macro': 0.8566061452796147, 'eval_prec_weighted': 0.9461625780281621, 'eval_recall_macro': 0.9139735570076126, 'eval_recall_weighted': 0.9392670157068063, 'eval_roc_auc_macro': 0.9896651653918657, 'eval_roc_auc_weighted': 0.9887690186011976, 'eval_accuracy': 0.93717277486911, 'eval_runtime': 10.0038, 'eval_samples_per_second': 95.464, 'eval_steps_per_second': 9.596, 'epoch': 4.0}
{'loss': 0.0372, 'grad_norm': 0.0470261387526989, 'learning_rate': 3.975409836065574e-05, 'epoch': 4.1}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.0804516077041626, 'eval_f1_macro': 0.9137080148707784, 'eval_f1_weighted': 0.952054027425345, 'eval_prec_macro': 0.9057554302841458, 'eval_prec_weighted': 0.9525901672951245, 'eval_recall_macro': 0.9223892601601579, 'eval_recall_weighted': 0.9518324607329843, 'eval_roc_auc_macro': 0.9878580847678418, 'eval_roc_auc_weighted': 0.9888850456705511, 'eval_accuracy': 0.9507853403141361, 'eval_runtime': 9.7866, 'eval_samples_per_second': 97.583, 'eval_steps_per_second': 9.809, 'epoch': 5.0}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.09145542979240417, 'eval_f1_macro': 0.9212057801877389, 'eval_f1_weighted': 0.9527494783363246, 'eval_prec_macro': 0.920944523094601, 'eval_prec_weighted': 0.9552098690300835, 'eval_recall_macro': 0.9241707511057357, 'eval_recall_weighted': 0.9518324607329843, 'eval_roc_auc_macro': 0.9810517758243463, 'eval_roc_auc_weighted': 0.9851303680136931, 'eval_accuracy': 0.9507853403141361, 'eval_runtime': 9.8209, 'eval_samples_per_second': 97.242, 'eval_steps_per_second': 9.775, 'epoch': 6.0}
{'loss': 0.017, 'grad_norm': 0.022935491055250168, 'learning_rate': 3.463114754098361e-05, 'epoch': 6.15}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.10030101984739304, 'eval_f1_macro': 0.8937880726953186, 'eval_f1_weighted': 0.9496126457926689, 'eval_prec_macro': 0.9039964854020912, 'eval_prec_weighted': 0.9552982624298355, 'eval_recall_macro': 0.8922499715069374, 'eval_recall_weighted': 0.9465968586387434, 'eval_roc_auc_macro': 0.9825153027368921, 'eval_roc_auc_weighted': 0.9838928899608809, 'eval_accuracy': 0.9465968586387434, 'eval_runtime': 9.8837, 'eval_samples_per_second': 96.624, 'eval_steps_per_second': 9.713, 'epoch': 7.0}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.09036996215581894, 'eval_f1_macro': 0.9135127252425879, 'eval_f1_weighted': 0.956162130929461, 'eval_prec_macro': 0.9266951044767104, 'eval_prec_weighted': 0.959938372964265, 'eval_recall_macro': 0.9047257263975531, 'eval_recall_weighted': 0.9539267015706806, 'eval_roc_auc_macro': 0.9803616566843413, 'eval_roc_auc_weighted': 0.9880170835461857, 'eval_accuracy': 0.9539267015706806, 'eval_runtime': 9.9878, 'eval_samples_per_second': 95.617, 'eval_steps_per_second': 9.612, 'epoch': 8.0}
{'loss': 0.0106, 'grad_norm': 0.015553292818367481, 'learning_rate': 2.9508196721311478e-05, 'epoch': 8.2}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.10160262137651443, 'eval_f1_macro': 0.8928097147860807, 'eval_f1_weighted': 0.951122088275916, 'eval_prec_macro': 0.9048928477398995, 'eval_prec_weighted': 0.9551325719754301, 'eval_recall_macro': 0.8872223194204618, 'eval_recall_weighted': 0.9486910994764398, 'eval_roc_auc_macro': 0.9844294702320283, 'eval_roc_auc_weighted': 0.9868429202265468, 'eval_accuracy': 0.9476439790575916, 'eval_runtime': 9.9786, 'eval_samples_per_second': 95.705, 'eval_steps_per_second': 9.621, 'epoch': 9.0}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.10038677603006363, 'eval_f1_macro': 0.8927920086051045, 'eval_f1_weighted': 0.950495532836159, 'eval_prec_macro': 0.895197975269242, 'eval_prec_weighted': 0.9515659546517277, 'eval_recall_macro': 0.8977453454233639, 'eval_recall_weighted': 0.9507853403141361, 'eval_roc_auc_macro': 0.9846480249609104, 'eval_roc_auc_weighted': 0.9875068584170793, 'eval_accuracy': 0.9486910994764398, 'eval_runtime': 9.8038, 'eval_samples_per_second': 97.411, 'eval_steps_per_second': 9.792, 'epoch': 10.0}
{'loss': 0.0036, 'grad_norm': 0.006827797740697861, 'learning_rate': 2.4385245901639343e-05, 'epoch': 10.25}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.103363998234272, 'eval_f1_macro': 0.9219978338652453, 'eval_f1_weighted': 0.9559652489161323, 'eval_prec_macro': 0.9182716827701807, 'eval_prec_weighted': 0.9578353835308244, 'eval_recall_macro': 0.9275032333855864, 'eval_recall_weighted': 0.9549738219895288, 'eval_roc_auc_macro': 0.9811040304847657, 'eval_roc_auc_weighted': 0.987684447252445, 'eval_accuracy': 0.9549738219895288, 'eval_runtime': 9.967, 'eval_samples_per_second': 95.817, 'eval_steps_per_second': 9.632, 'epoch': 11.0}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.10690831393003464, 'eval_f1_macro': 0.9089686960765277, 'eval_f1_weighted': 0.9548926490103973, 'eval_prec_macro': 0.9129521744085171, 'eval_prec_weighted': 0.9558606372981961, 'eval_recall_macro': 0.9076083661222981, 'eval_recall_weighted': 0.9549738219895288, 'eval_roc_auc_macro': 0.9788404448032028, 'eval_roc_auc_weighted': 0.9881071259834188, 'eval_accuracy': 0.9539267015706806, 'eval_runtime': 9.9866, 'eval_samples_per_second': 95.628, 'eval_steps_per_second': 9.613, 'epoch': 12.0}
{'loss': 0.0021, 'grad_norm': 0.005365756805986166, 'learning_rate': 1.9262295081967212e-05, 'epoch': 12.3}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.10377667099237442, 'eval_f1_macro': 0.9230239856591413, 'eval_f1_weighted': 0.9559245180690369, 'eval_prec_macro': 0.9207606248240923, 'eval_prec_weighted': 0.9575649145401792, 'eval_recall_macro': 0.9271433593415018, 'eval_recall_weighted': 0.9549738219895288, 'eval_roc_auc_macro': 0.9821784416721464, 'eval_roc_auc_weighted': 0.9892353944373723, 'eval_accuracy': 0.9549738219895288, 'eval_runtime': 9.9532, 'eval_samples_per_second': 95.949, 'eval_steps_per_second': 9.645, 'epoch': 13.0}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.10007030516862869, 'eval_f1_macro': 0.9272345257645862, 'eval_f1_weighted': 0.9594803565562718, 'eval_prec_macro': 0.9254604905309132, 'eval_prec_weighted': 0.962018137170412, 'eval_recall_macro': 0.9320773955448878, 'eval_recall_weighted': 0.9581151832460733, 'eval_roc_auc_macro': 0.9811226917768833, 'eval_roc_auc_weighted': 0.9892526140738286, 'eval_accuracy': 0.9581151832460733, 'eval_runtime': 9.6855, 'eval_samples_per_second': 98.601, 'eval_steps_per_second': 9.912, 'epoch': 14.0}
{'loss': 0.0013, 'grad_norm': 0.004257210064679384, 'learning_rate': 1.4139344262295081e-05, 'epoch': 14.34}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.1043650433421135, 'eval_f1_macro': 0.9267789730300187, 'eval_f1_weighted': 0.9580345807554805, 'eval_prec_macro': 0.9222097166104056, 'eval_prec_weighted': 0.9601403987535763, 'eval_recall_macro': 0.9339703816484002, 'eval_recall_weighted': 0.9570680628272251, 'eval_roc_auc_macro': 0.9798190615534554, 'eval_roc_auc_weighted': 0.9880342458098413, 'eval_accuracy': 0.9570680628272251, 'eval_runtime': 9.9644, 'eval_samples_per_second': 95.841, 'eval_steps_per_second': 9.634, 'epoch': 15.0}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.1093260645866394, 'eval_f1_macro': 0.9239719424642185, 'eval_f1_weighted': 0.9576369433080657, 'eval_prec_macro': 0.9188488104610443, 'eval_prec_weighted': 0.959176369437233, 'eval_recall_macro': 0.9307907314099264, 'eval_recall_weighted': 0.9570680628272251, 'eval_roc_auc_macro': 0.9791083569738348, 'eval_roc_auc_weighted': 0.987577114583819, 'eval_accuracy': 0.9570680628272251, 'eval_runtime': 9.907, 'eval_samples_per_second': 96.397, 'eval_steps_per_second': 9.69, 'epoch': 16.0}
{'loss': 0.0006, 'grad_norm': 0.002441456774249673, 'learning_rate': 9.016393442622952e-06, 'epoch': 16.39}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.11132070422172546, 'eval_f1_macro': 0.9279845741393284, 'eval_f1_weighted': 0.9581686028386087, 'eval_prec_macro': 0.9217842067458516, 'eval_prec_weighted': 0.9594082284612236, 'eval_recall_macro': 0.9366730843511029, 'eval_recall_weighted': 0.9581151832460733, 'eval_roc_auc_macro': 0.9793986197023508, 'eval_roc_auc_weighted': 0.9875711658205673, 'eval_accuracy': 0.9581151832460733, 'eval_runtime': 9.8939, 'eval_samples_per_second': 96.524, 'eval_steps_per_second': 9.703, 'epoch': 17.0}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.11131041496992111, 'eval_f1_macro': 0.9282592837378612, 'eval_f1_weighted': 0.9586604912296987, 'eval_prec_macro': 0.9223086305799034, 'eval_prec_weighted': 0.960347249148374, 'eval_recall_macro': 0.9366730843511029, 'eval_recall_weighted': 0.9581151832460733, 'eval_roc_auc_macro': 0.979460855388302, 'eval_roc_auc_weighted': 0.9876584894319457, 'eval_accuracy': 0.9581151832460733, 'eval_runtime': 9.8614, 'eval_samples_per_second': 96.842, 'eval_steps_per_second': 9.735, 'epoch': 18.0}
{'loss': 0.0006, 'grad_norm': 0.001976240426301956, 'learning_rate': 3.89344262295082e-06, 'epoch': 18.44}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.11180665343999863, 'eval_f1_macro': 0.9294571918677139, 'eval_f1_weighted': 0.9592051184798585, 'eval_prec_macro': 0.9245175175060814, 'eval_prec_weighted': 0.9604787806308706, 'eval_recall_macro': 0.9370779426506981, 'eval_recall_weighted': 0.9591623036649215, 'eval_roc_auc_macro': 0.9789140112961008, 'eval_roc_auc_weighted': 0.9874016642343878, 'eval_accuracy': 0.9591623036649215, 'eval_runtime': 9.7982, 'eval_samples_per_second': 97.467, 'eval_steps_per_second': 9.798, 'epoch': 19.0}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.11193442344665527, 'eval_f1_macro': 0.9292575632108511, 'eval_f1_weighted': 0.9586888014301712, 'eval_prec_macro': 0.924105222016, 'eval_prec_weighted': 0.9594124247560006, 'eval_recall_macro': 0.9370779426506981, 'eval_recall_weighted': 0.9591623036649215, 'eval_roc_auc_macro': 0.9788362594832911, 'eval_roc_auc_weighted': 0.9874004260366831, 'eval_accuracy': 0.9581151832460733, 'eval_runtime': 9.833, 'eval_samples_per_second': 97.122, 'eval_steps_per_second': 9.763, 'epoch': 20.0}
{'train_runtime': 2709.4314, 'train_samples_per_second': 28.81, 'train_steps_per_second': 1.801, 'train_loss': 0.021756304021863664, 'epoch': 20.0}


TrainOutput(global_step=4880, training_loss=0.021756304021863664, metrics={'train_runtime': 2709.4314, 'train_samples_per_second': 28.81, 'train_steps_per_second': 1.801, 'total_flos': 1.026950110036992e+16, 'train_loss': 0.021756304021863664, 'epoch': 20.0})

In [28]:
# Evaluate the model held by the trainer after training on the test split
# (with load_best_model_at_end=True this is the best checkpoint, not the last epoch).
# The assignment repeats the eval_dataset already passed to the Trainer.
trainer.eval_dataset = Tokenized_DataSet["test"]
trainer.evaluate()

  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.11180665343999863,
 'eval_f1_macro': 0.9294571918677139,
 'eval_f1_weighted': 0.9592051184798585,
 'eval_prec_macro': 0.9245175175060814,
 'eval_prec_weighted': 0.9604787806308706,
 'eval_recall_macro': 0.9370779426506981,
 'eval_recall_weighted': 0.9591623036649215,
 'eval_roc_auc_macro': 0.9789140112961008,
 'eval_roc_auc_weighted': 0.9874016642343878,
 'eval_accuracy': 0.9591623036649215,
 'eval_runtime': 10.2565,
 'eval_samples_per_second': 93.112,
 'eval_steps_per_second': 9.36,
 'epoch': 20.0}

In [29]:
#medroberta_clf_pipe = pipeline('text-classification', model=medroberta_clf, tokenizer=tokenizer)

In [30]:
# get current directory of .py file, i.e. NOT os.getcwd()
#os.path.dirname(os.path.realpath(__file__))