# Prerequisites

- Host OS: Ubuntu 20.04 LTS
- Using Docker Image 'mltooling/ml-workspace-gpu' (docker pull mltooling/ml-workspace-gpu)
- Single Nvidia GPU (RTX 3080)

# 0. Import libraries

In [1]:
import bert_ensemble_functions
from datasets import load_dataset

import transformers
import datasets
import huggingface_hub
import pyarrow
import torch
# Record the library versions this notebook was run with, one per line,
# in the order: transformers, datasets, huggingface_hub, pyarrow, torch.
print(
    *(
        lib.__version__
        for lib in (transformers, datasets, huggingface_hub, pyarrow, torch)
    ),
    sep="\n",
)

4.22.1
2.4.0
0.9.1
9.0.0
1.9.0+cu111


In [2]:
# Select the compute device: the first CUDA GPU when one is available,
# otherwise the CPU. For every visible GPU, report its compute capability
# and its name.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print("no cuda device")
else:
    device = torch.device("cuda:0")
    device_count = torch.cuda.device_count()
    print(f"device_count: {device_count}")
    for device_num in range(device_count):
        print(f"device {device_num} capability {torch.cuda.get_device_capability(device_num)}")
        print(f"device {device_num} name {torch.cuda.get_device_name(device_num)}")

device_count: 1
device 0 capability (8, 6)
device 0 name NVIDIA GeForce RTX 3080


In [3]:
#### The number of logical CPUs: counts the lines in /proc/cpuinfo that contain
#### "processor" (Linux only, since it reads /proc/cpuinfo)
!grep -c processor /proc/cpuinfo

20


In [4]:
#### GPU information: runs the nvidia-smi command-line tool, which reports the
#### driver / CUDA versions and per-GPU memory usage and utilization
!nvidia-smi

Tue Nov 22 01:54:39 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0 Off |                  N/A |
|  0%   48C    P8    30W / 370W |    253MiB / 12288MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

# 1. Customize Train Strategy

In [5]:
# --- Compute resources and reproducibility ---
num_cpus, num_gpus = 16, 1
seed = 1234

# --- Pretrained checkpoint ---
# Alternatives: bert-base-multilingual-cased, klue/roberta-base, bert-base-cased, etc.
model_name = "xlm-roberta-base"

# --- Dataset columns ---
text_column, label_column, id_column = "examination", "label", "id"
# Directory name derived from the chosen text column.
custom_dir = "sev_{}_ensemble".format(text_column)

# --- Train / evaluation split ---
train_proportion = 0.5  # train set : eval set = 5 : 5

# --- Hyperparameter optimization (HPO) ---
do_hpo = True
# The three settings below must be set if you want to search for the best
# hyperparameters using Ray Tune.
n_trials = 5
std = 0.1
patience = 5

# 2. Import Data

Two files are needed (`{data_name}_train.csv` and `{data_name}_test.csv`) in your data directory (in this case, `../data_split/`).

In [6]:
# Base name shared by the two CSV files: <data_name>_train.csv and <data_name>_test.csv.
data_name = "cardiovascular_sev_dataset"

# Load both splits into a single DatasetDict keyed by split name.
dataset = load_dataset(
    "csv",
    data_files={
        split: f"../data_split/{data_name}_{split}.csv"
        for split in ("train", "test")
    },
)
dataset

Using custom data configuration default-9aa18915b5f32f1a
Reusing dataset csv (/root/.cache/huggingface/datasets/csv/default-9aa18915b5f32f1a/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a)


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'past_history', 'treatment_effect', 'examination', 'label'],
        num_rows: 3756
    })
    test: Dataset({
        features: ['id', 'past_history', 'treatment_effect', 'examination', 'label'],
        num_rows: 940
    })
})

# 3. Data Preprocessing

In [7]:
# Run the preprocessing pipeline (drop rows with missing values, strip special
# characters, tokenize the text column, split into train / eval / test),
# driven entirely by the settings from the configuration cell.
train_dataset, eval_dataset, test_dataset = bert_ensemble_functions.preprocessing(
    dataset=dataset,
    text_column=text_column,
    label_column=label_column,
    id_column=id_column,
    model_name=model_name,
    train_proportion=train_proportion,
    seed=seed,
    custom_tokenizer_dir=custom_dir,
)

Loading cached processed dataset at /root/.cache/huggingface/datasets/csv/default-9aa18915b5f32f1a/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-3692d4dc669ef6ad.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/csv/default-9aa18915b5f32f1a/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-febaf5912a5cdc9f.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/csv/default-9aa18915b5f32f1a/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-b8479d14096fbd79.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/csv/default-9aa18915b5f32f1a/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-f95228b09d8de56b.arrow


Removing rows with missing value...
Done. (1/4)
Removing special characters...
Done. (2/4)
Tokenining the text column...


Loading cached processed dataset at /root/.cache/huggingface/datasets/csv/default-9aa18915b5f32f1a/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-5890f6c3a9630c71.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/csv/default-9aa18915b5f32f1a/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-c4bbd04bce038625.arrow
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/csv/default-9aa18915b5f32f1a/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-fa668232e196b556.arrow
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/csv/default-9aa18915b5f32f1a/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-fa668232e196b556.arrow


Done. (3/4)
Spliting train-evaluation-test set...
Done. (4/4)


# 4. Modeling

In [8]:
# Build the trainer (and, when do_hpo is True, run the hyperparameter search).
#
# The checkpoint is taken from the `model_name` setting rather than a
# hard-coded string, so the model being fine-tuned is always the same one that
# was passed to preprocessing(); otherwise changing `model_name` in the
# configuration cell would silently leave this step on 'xlm-roberta-base'.
trainer = bert_ensemble_functions.modeling(
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    model_name=model_name,
    # Resources and seed come from the configuration cell.
    num_gpus=num_gpus,
    num_cpus=num_cpus,
    seed=seed,
    # Output and logging directories, relative to the notebook's working directory.
    output_dir='./output',
    logging_dir="./logs",
    # Hyperparameter-search settings from the configuration cell.
    do_hpo=do_hpo,
    std=std,
    n_trials=n_trials,
    patience=patience,
    # Search results are written to <hpo_result_dir>/<hpo_result_dir_subfolder_name>.
    hpo_result_dir="./hpo-results",
    hpo_result_dir_subfolder_name='tune_transformer_pbt',
)

loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--xlm-roberta-base/snapshots/42f548f32366559214515ec137cdd16002968bf6/config.json
Model config XLMRobertaConfig {
  "_name_or_path": "xlm-roberta-base",
  "architectures": [
    "XLMRobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "xlm-roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.22.1",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 250002
}

loading weights file pytorch_model.bin from cache at /root/.cache/huggingface/hub/models--xlm-roberta-base/snapshots/42f

== Status ==
Current time: 2022-11-22 01:55:07 (running for 00:00:00.18)
Memory usage on this node: 9.7/31.1 GiB
PopulationBasedTraining: 0 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PENDING, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------|
| _objective_b0544_00000 | RUNNING  | 172.17.0.3:3937125 |  0.186633 | 2.75091e-05 |              8 |           17 |
| _objective_b0544_00001 | PENDING  |                    |  0.287442 | 4.50373e-05 |              8 |  

== Status ==
Current time: 2022-11-22 01:55:39 (running for 00:00:32.31)
Memory usage on this node: 14.3/31.1 GiB
PopulationBasedTraining: 0 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PENDING, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------|
| _objective_b0544_00000 | RUNNING  | 172.17.0.3:3937125 |  0.186633 | 2.75091e-05 |              8 |           17 |
| _objective_b0544_00001 | PENDING  |                    |  0.287442 | 4.50373e-05 |              8 | 

== Status ==
Current time: 2022-11-22 01:56:08 (running for 00:01:01.74)
Memory usage on this node: 14.7/31.1 GiB
PopulationBasedTraining: 0 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (1 PAUSED, 3 PENDING, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+------

== Status ==
Current time: 2022-11-22 01:56:28 (running for 00:01:21.75)
Memory usage on this node: 14.7/31.1 GiB
PopulationBasedTraining: 0 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (1 PAUSED, 3 PENDING, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+------

== Status ==
Current time: 2022-11-22 01:56:49 (running for 00:01:43.07)
Memory usage on this node: 15.0/31.1 GiB
PopulationBasedTraining: 1 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (2 PAUSED, 2 PENDING, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+------

== Status ==
Current time: 2022-11-22 01:57:09 (running for 00:02:03.08)
Memory usage on this node: 15.0/31.1 GiB
PopulationBasedTraining: 1 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (2 PAUSED, 2 PENDING, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+------

== Status ==
Current time: 2022-11-22 01:57:31 (running for 00:02:24.38)
Memory usage on this node: 14.9/31.1 GiB
PopulationBasedTraining: 1 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (3 PAUSED, 1 PENDING, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+------

== Status ==
Current time: 2022-11-22 01:57:51 (running for 00:02:44.39)
Memory usage on this node: 14.9/31.1 GiB
PopulationBasedTraining: 1 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (3 PAUSED, 1 PENDING, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+------

== Status ==
Current time: 2022-11-22 01:58:12 (running for 00:03:05.78)
Memory usage on this node: 15.1/31.1 GiB
PopulationBasedTraining: 2 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 01:58:32 (running for 00:03:25.79)
Memory usage on this node: 15.1/31.1 GiB
PopulationBasedTraining: 2 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 01:58:54 (running for 00:03:47.32)
Memory usage on this node: 15.2/31.1 GiB
PopulationBasedTraining: 3 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 01:59:14 (running for 00:04:07.33)
Memory usage on this node: 15.2/31.1 GiB
PopulationBasedTraining: 3 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 01:59:37 (running for 00:04:30.47)
Memory usage on this node: 15.2/31.1 GiB
PopulationBasedTraining: 3 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 01:59:57 (running for 00:04:50.48)
Memory usage on this node: 15.2/31.1 GiB
PopulationBasedTraining: 3 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:00:22 (running for 00:05:15.24)
Memory usage on this node: 15.3/31.1 GiB
PopulationBasedTraining: 4 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:00:42 (running for 00:05:35.25)
Memory usage on this node: 15.3/31.1 GiB
PopulationBasedTraining: 4 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:01:05 (running for 00:05:58.11)
Memory usage on this node: 15.2/31.1 GiB
PopulationBasedTraining: 4 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:01:25 (running for 00:06:18.14)
Memory usage on this node: 15.2/31.1 GiB
PopulationBasedTraining: 4 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:01:47 (running for 00:06:40.10)
Memory usage on this node: 15.3/31.1 GiB
PopulationBasedTraining: 5 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:02:07 (running for 00:07:00.11)
Memory usage on this node: 15.3/31.1 GiB
PopulationBasedTraining: 5 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:02:29 (running for 00:07:22.80)
Memory usage on this node: 15.3/31.1 GiB
PopulationBasedTraining: 5 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

Result for _objective_b0544_00002:
  date: 2022-11-22_02-02-48
  done: false
  episodes_total: 0
  epoch: 9.09
  eval_accuracy: 0.7804878048780488
  eval_f1: 0.646288209606987
  eval_loss: 0.5444447994232178
  eval_objective: 1.4267760144850357
  eval_runtime: 3.6839
  eval_samples_per_second: 100.166
  eval_steps_per_second: 12.758
  experiment_id: 623a39e59d5341e4a134f6e6cb379927
  hostname: 3481a8a2ae33
  iterations_since_restore: 2
  node_ip: 172.17.0.3
  objective: 2.8535520289700713
  pid: 3937125
  time_since_restore: 81.53227806091309
  time_this_iter_s: 38.84005117416382
  time_total_s: 122.77449893951416
  timestamp: 1669082568
  timesteps_since_restore: 0
  timesteps_total: 0
  training_iteration: 2
  trial_id: b0544_00002
  warmup_time: 0.0016698837280273438
  
== Status ==
Current time: 2022-11-22 02:02:53 (running for 00:07:46.71)
Memory usage on this node: 15.5/31.1 GiB
PopulationBasedTraining: 5 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/

== Status ==
Current time: 2022-11-22 02:03:13 (running for 00:08:06.73)
Memory usage on this node: 15.5/31.1 GiB
PopulationBasedTraining: 5 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

Result for _objective_b0544_00003:
  date: 2022-11-22_02-03-29
  done: false
  episodes_total: 0
  epoch: 4.52
  eval_accuracy: 0.6476964769647696
  eval_f1: 0.59375
  eval_loss: 0.6592798829078674
  eval_objective: 1.2414464769647697
  eval_runtime: 3.6228
  eval_samples_per_second: 101.854
  eval_steps_per_second: 12.973
  experiment_id: 623a39e59d5341e4a134f6e6cb379927
  hostname: 3481a8a2ae33
  iterations_since_restore: 1
  node_ip: 172.17.0.3
  objective: 2.4828929539295395
  pid: 3937125
  time_since_restore: 41.26776194572449
  time_this_iter_s: 41.26776194572449
  time_total_s: 82.49791860580444
  timestamp: 1669082609
  timesteps_since_restore: 0
  timesteps_total: 0
  training_iteration: 1
  trial_id: b0544_00003
  warmup_time: 0.0016698837280273438
  
== Status ==
Current time: 2022-11-22 02:03:34 (running for 00:08:27.99)
Memory usage on this node: 15.3/31.1 GiB
PopulationBasedTraining: 5 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB h

== Status ==
Current time: 2022-11-22 02:03:54 (running for 00:08:48.00)
Memory usage on this node: 15.3/31.1 GiB
PopulationBasedTraining: 5 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:04:19 (running for 00:09:12.39)
Memory usage on this node: 15.5/31.1 GiB
PopulationBasedTraining: 6 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:04:39 (running for 00:09:32.40)
Memory usage on this node: 15.5/31.1 GiB
PopulationBasedTraining: 6 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:05:01 (running for 00:09:54.89)
Memory usage on this node: 15.4/31.1 GiB
PopulationBasedTraining: 6 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:05:21 (running for 00:10:15.00)
Memory usage on this node: 15.4/31.1 GiB
PopulationBasedTraining: 6 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:05:46 (running for 00:10:39.59)
Memory usage on this node: 15.6/31.1 GiB
PopulationBasedTraining: 7 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:06:06 (running for 00:10:59.60)
Memory usage on this node: 15.6/31.1 GiB
PopulationBasedTraining: 7 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:06:28 (running for 00:11:22.02)
Memory usage on this node: 15.6/31.1 GiB
PopulationBasedTraining: 7 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:06:48 (running for 00:11:42.03)
Memory usage on this node: 15.6/31.1 GiB
PopulationBasedTraining: 7 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:07:09 (running for 00:12:02.40)
Memory usage on this node: 15.4/31.1 GiB
PopulationBasedTraining: 7 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:07:29 (running for 00:12:22.41)
Memory usage on this node: 15.4/31.1 GiB
PopulationBasedTraining: 7 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:07:53 (running for 00:12:46.86)
Memory usage on this node: 15.4/31.1 GiB
PopulationBasedTraining: 8 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

== Status ==
Current time: 2022-11-22 02:08:13 (running for 00:13:06.87)
Memory usage on this node: 15.4/31.1 GiB
PopulationBasedTraining: 8 checkpoints, 0 perturbs
Resources requested: 16.0/16 CPUs, 1.0/1 GPUs, 0.0/14.38 GiB heap, 0.0/7.19 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /workspace/syc/BERT_XGB_ensemble_classification_binary/hpo-results/tune_transformer_pbt
Number of trials: 5/5 (4 PAUSED, 1 RUNNING)
+------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------+------------------+-------------+---------+----------------------+
| Trial name             | status   | loc                |   w_decay |          lr |   train_bs/gpu |   num_epochs |   eval_f1 |   eval_accuracy |   eval_objective |   eval_loss |   epoch |   training_iteration |
|------------------------+----------+--------------------+-----------+-------------+----------------+--------------+-----------+-----------------

2022-11-22 02:08:16,201	INFO tune.py:758 -- Total run time: 789.40 seconds (789.07 seconds for the tuning loop).


In [9]:
# Train the model with the configured Trainer. The call is kept as the cell's last
# expression so the returned TrainOutput (global step, training loss, runtime
# metrics) is displayed below the training log.
# NOTE(review): confirm the trainer was set up with the hyperparameters selected by
# the preceding PBT search before this call -- that wiring is not visible here.
trainer.train()

loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--xlm-roberta-base/snapshots/42f548f32366559214515ec137cdd16002968bf6/config.json
Model config XLMRobertaConfig {
  "_name_or_path": "xlm-roberta-base",
  "architectures": [
    "XLMRobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "xlm-roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.22.1",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 250002
}

loading weights file pytorch_model.bin from cache at /root/.cache/huggingface/hub/models--xlm-roberta-base/snapshots/42f

Step,Training Loss,Validation Loss,Accuracy,F1,Objective
50,No log,0.530988,0.680217,0.628931,1.309148
100,No log,0.602372,0.842818,0.72381,1.566628


The following columns in the evaluation set don't have a corresponding argument in `XLMRobertaForSequenceClassification.forward` and have been ignored: text. If text are not expected by `XLMRobertaForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 369
  Batch size = 8
  nn.utils.clip_grad_norm_(
The following columns in the evaluation set don't have a corresponding argument in `XLMRobertaForSequenceClassification.forward` and have been ignored: text. If text are not expected by `XLMRobertaForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 369
  Batch size = 8


Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=132, training_loss=0.46376572233257873, metrics={'train_runtime': 100.3509, 'train_samples_per_second': 44.006, 'train_steps_per_second': 1.315, 'total_flos': 1157688643584000.0, 'train_loss': 0.46376572233257873, 'epoch': 11.96})

In [10]:
# Save the fine-tuned model (config.json and pytorch_model.bin) into custom_dir.
# NOTE(review): only the model object is written here; if the tokenizer is needed to
# reload the model for inference, confirm it is saved elsewhere.
trainer.model.save_pretrained(custom_dir)

Configuration saved in sev_examination_ensemble/config.json
Model weights saved in sev_examination_ensemble/pytorch_model.bin


In [11]:
# Run the trained model over eval_dataset through the project helper and keep the
# per-example results in `df`; the bare `df` at the end renders it as the cell output.
# NOTE(review): judging by the displayed frame, it holds two score columns (one per
# class), the gold `label` and the predicted `pred` -- confirm against
# bert_ensemble_functions.evaluation.
df = bert_ensemble_functions.evaluation(trainer = trainer,
                                        eval_dataset = eval_dataset,
                                        text_column_name = text_column
                                        )
df

The following columns in the test set don't have a corresponding argument in `XLMRobertaForSequenceClassification.forward` and have been ignored: text. If text are not expected by `XLMRobertaForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 369
  Batch size = 8


Unnamed: 0,id,text,examination_pred_0,examination_pred_1,label,pred
0,44539,Coronary angiography 20110608 Rt CFA c 7Fr...,3.027344,-2.996094,0,0
1,51148,Brain CT angio122Recent infarction in right M...,-1.655273,1.915039,1,1
2,7150,LCA c JL 54 Diffuse ecc 30 LN of pLAD Discret...,3.164062,-3.042969,0,0
3,55731,chest CT외부영상20120730Smooth interlobular septal...,3.134766,-3.021484,0,0
4,8181,Cangio 12126PCI at dRCA035 GW Terumo014 GW ...,-1.128906,1.189453,1,1
...,...,...,...,...,...,...
364,11335,Coronary Angiography 20130131CAD3VD LMSuccessf...,2.761719,-2.722656,0,0
365,7004,CAGPuncture site RtCFA c 7Fr sheath suture ...,2.931641,-2.789062,0,0
366,8325,MRI Lt foot 201253Extensive infarction is not...,0.149536,-0.189819,0,0
367,15366,Brain MRIAcute infarcts in the right anterior ...,-0.205322,0.240479,0,1


In [12]:
# Score the XLM-RoBERTa predictions against the gold labels: confusion matrix
# (rows = true label, columns = predicted label) followed by the scalar metrics.
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)

print("Confusion Matrix")
print(confusion_matrix(df["label"], df["pred"]))
print("-------------------------")

# Evaluate the scorers in the same order as they are reported; each one receives
# (y_true, y_pred) and uses scikit-learn's default settings.
accuracy, f1, recall, precision = (
    scorer(df["label"], df["pred"])
    for scorer in (accuracy_score, f1_score, recall_score, precision_score)
)

print(
    f"Accuracy: {accuracy}",
    f"F1 score: {f1}",
    f"Recall: {recall}",
    f"Precision: {precision}",
    sep="\n",
)

Confusion Matrix
[[231  25]
 [ 29  84]]
-------------------------
Accuracy: 0.8536585365853658
F1 score: 0.7567567567567568
Recall: 0.7433628318584071
Precision: 0.7706422018348624


In [13]:
# Persist the prediction frame as CSV in the current working directory; the file
# name is prefixed with text_column and the DataFrame index is not written.
df.to_csv(f'./{text_column}_bert_result_df.csv', index=False)