# SDoH prediction with multi-modal BERT


## Setup

We'll need [the Transformers library](https://huggingface.co/transformers/) by Hugging Face:

In [30]:

import transformers
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
import torch

import numpy as np
import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from collections import defaultdict
from textwrap import wrap

from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from tqdm.notebook import tqdm

#updated imports for multimodal input 

from dataclasses import dataclass, field
import json
import logging
import os
from typing import Optional

from transformers import (
    AutoTokenizer,
    AutoConfig,
    Trainer,
    EvalPrediction,
    set_seed
)
from transformers.training_args import TrainingArguments

from multimodal_transformers.data import load_data_from_folder
from multimodal_transformers.model import TabularConfig
from multimodal_transformers.model import BertWithTabular#AutoModelWithTabular

#done

%matplotlib inline
%config InlineBackend.figure_format='retina'

sns.set(style='whitegrid', palette='muted', font_scale=1.2)

HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]

sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))

rcParams['figure.figsize'] = 12, 8

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

## Data Exploration


In [31]:
#set data path
va_path = "/home/gridsan/stynan/sdoh_model/sdoh_dataset_processed.xlsx" 

df = pd.read_excel(va_path)

In [32]:
#make list of texts and labels 
df = df[df["edu_level_composite"] != -1] #remove all rows that have missing values or unknowns in edu level 
df.edu_level_composite -= 1
df = df[df["open_response"].notna()]
df.fillna(-1, inplace = True)

df = df.iloc[1:] #removing the first row since it is just not necessary

In [33]:
df.shape

(6834, 107)

In [34]:
PRE_TRAINED_MODEL_NAME = '/home/gridsan/stynan/sdoh_model/bert-base-uncased'

In [35]:
@dataclass
class ModelArguments:
  """
  Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
  """

  model_name_or_path: str = field(
      metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
  )
  config_name: Optional[str] = field(
      default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
  )
  tokenizer_name: Optional[str] = field(
      default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
  )
  cache_dir: Optional[str] = field(
      default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
  )

    
@dataclass
class MultimodalDataTrainingArguments:
  """
  Arguments pertaining to how we combine tabular features
  Using `HfArgumentParser` we can turn this class
  into argparse arguments to be able to specify them on
  the command line.
  """

  data_path: str = field(metadata={
                            'help': 'the path to the csv file containing the dataset'
                        })
  column_info_path: str = field(
      default=None,
      metadata={
          'help': 'the path to the json file detailing which columns are text, categorical, numerical, and the label'
  })

  column_info: dict = field(
      default=None,
      metadata={
          'help': 'a dict referencing the text, categorical, numerical, and label columns'
                  'its keys are text_cols, num_cols, cat_cols, and label_col'
  })

  categorical_encode_type: str = field(default='ohe',
                                        metadata={
                                            'help': 'sklearn encoder to use for categorical data',
                                            'choices': ['ohe', 'binary', 'label', 'none']
                                        })

  numerical_transformer_method: str = field(default='yeo_johnson',
                                            metadata={
                                                'help': 'sklearn numerical transformer to preprocess numerical data',
                                                'choices': ['yeo_johnson', 'box_cox', 'quantile_normal', 'none']
                                            })
  task: str = field(default="classification",
                    metadata={
                        "help": "The downstream training task",
                        "choices": ["classification", "regression"]
                    })

  mlp_division: int = field(default=4,
                            metadata={
                                'help': 'the ratio of the number of '
                                        'hidden dims in a current layer to the next MLP layer'
                            })
  combine_feat_method: str = field(default='individual_mlps_on_cat_and_numerical_feats_then_concat',
                                    metadata={
                                        'help': 'method to combine categorical and numerical features, '
                                                'see README for all the method'
                                    })
  mlp_dropout: float = field(default=0.1,
                              metadata={
                                'help': 'dropout ratio used for MLP layers'
                              })
  numerical_bn: bool = field(default=True,
                              metadata={
                                  'help': 'whether to use batchnorm on numerical features'
                              })
  use_simple_classifier: str = field(default=True,
                                      metadata={
                                          'help': 'whether to use single layer or MLP as final classifier'
                                      })
  mlp_act: str = field(default='relu',
                        metadata={
                            'help': 'the activation function to use for finetuning layers',
                            'choices': ['relu', 'prelu', 'sigmoid', 'tanh', 'linear']
                        })
  gating_beta: float = field(default=0.2,
                              metadata={
                                  'help': "the beta hyperparameters used for gating tabular data "
                                          "see https://www.aclweb.org/anthology/2020.acl-main.214.pdf"
                              })

  def __post_init__(self):
      assert self.column_info != self.column_info_path
      if self.column_info is None and self.column_info_path:
          with open(self.column_info_path, 'r') as f:
              self.column_info = json.load(f)

The tokenizer is doing most of the heavy lifting for us. We also return the review texts, so it'll be easier to evaluate the predictions from our model. Let's split the data:

In [44]:
text_cols = ["open_response"]

cat_cols = ['site_x',
            'gs_text34_x',
            'g1_05', 
            'gs_comorbid1',
            'gs_comorbid2', 
            'g1_06m', 
            'g1_06y', 
            'g1_07a', 
            'g2_01',
            'g4_02',
            'g4_03a',
            'g4_05',
            'g4_08',
            'g5_05',
            'a1_01_2','a1_01_3','a1_01_4','a1_01_5','a1_01_6','a1_01_7','a1_01_8','a1_01_9','a1_01_10','a1_01_11','a1_01_12','a1_01_13','a1_01_14',
            'a3_10',
            'a3_17',
            'a3_18',
            'a4_01',
            'a4_02_1',
            'a4_02_2',
            'a4_02_3',
            'a4_02_4',
            'a4_02_5a',
             'a4_02_5a',
           'a4_02_5b',
            'a4_03',
            'a4_04',
            'a4_05',
           'a4_05',
            'a4_06',
            'a5_02',
            'a5_03',
            'a6_01',
            'a6_02_1',
            'a6_02_2',
            'a6_02_3',
            'a6_02_4',
            'a6_02_5',
            'a6_02_6',
            'a6_02_7',
            'a6_02_8',
            'a6_02_9',
            'a6_02_10',
            'a6_02_11',	
            'a6_02_12a',
            'a6_02_13',	
            'a6_02_14',	
            'a6_02_15',
            'a6_04',
            'a6_05',
            'a6_09',
            'a6_10'
           ]
numerical_cols = ['age_years', 'total_#_hospitals', 'Interview Attemps', 'g4_04', 'g4_07', 'a2_01','a3_11','a5_04' ]

column_info_dict = {
    'text_cols': text_cols,
    'num_cols': numerical_cols,
    'cat_cols': cat_cols,
    'label_col': "edu_level_composite",
    'label_list': [1,2,3,4]
}


model_args = ModelArguments(
    model_name_or_path='/home/gridsan/stynan/sdoh_model/bert-base-uncased'
)

data_args = MultimodalDataTrainingArguments(
    data_path='/home/gridsan/stynan/sdoh_model/',
    combine_feat_method='gating_on_cat_and_num_feats_then_sum',
    column_info=column_info_dict,
    task='classification'
)

training_args = TrainingArguments(
    output_dir="./logs/model_name",
    logging_dir="./logs/runs",
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
    do_predict = True,
    per_device_train_batch_size=16,
    num_train_epochs=20,
    evaluate_during_training=True,
    logging_steps=25,
    eval_steps=250
)

set_seed(training_args.seed)

In [45]:
#tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)

tokenizer_path_or_name = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
print('Specified tokenizer: ', tokenizer_path_or_name)
tokenizer = BertTokenizer.from_pretrained(
    tokenizer_path_or_name,
    cache_dir=model_args.cache_dir,
    model_max_length=512
)

Specified tokenizer:  /home/gridsan/stynan/sdoh_model/bert-base-uncased


In [46]:
df_train, df_test = train_test_split(df, test_size=0.1, random_state=RANDOM_SEED)
df_val, df_test = train_test_split(df, test_size=0.5, random_state=RANDOM_SEED)

In [47]:
df_train.to_csv('train.csv')
df_val.to_csv('val.csv')
df_test.to_csv('test.csv')

In [48]:
# Get Datasets
train_dataset, val_dataset, test_dataset = load_data_from_folder(
    data_args.data_path,
    data_args.column_info['text_cols'],
    tokenizer,
    label_col=data_args.column_info['label_col'],
    label_list=data_args.column_info['label_list'],
    categorical_cols=data_args.column_info['cat_cols'],
    numerical_cols=data_args.column_info['num_cols'],
    sep_text_token_str=tokenizer.sep_token,
)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[num_cols] = df[num_cols].fillna(df[num_cols].median())
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[num_cols] = df[num_cols].fillna(df[num_cols].median())
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[num_cols] = df[num_cols].fillna(df[num_cols].median())
A value is trying to be set on 

In [50]:
config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
num_labels = 4
tabular_config = TabularConfig(num_labels=num_labels,
                               cat_feat_dim=train_dataset.cat_feats.shape[1],
                               numerical_feat_dim=train_dataset.numerical_feats.shape[1],
                               **vars(data_args))
config.tabular_config = tabular_config

## Classification with BERT and Hugging Face

In [51]:

model = BertWithTabular.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir
    )

model = model.to(device)

Some weights of the model checkpoint at /home/gridsan/stynan/sdoh_model/bert-base-uncased were not used when initializing BertWithTabular: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertWithTabular from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertWithTabular from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertWithTabular were not initialized from the model checkpoint at /home/gridsan/stynan/sd

We can use all of this knowledge to create a classifier that uses the BERT model:

In [52]:
from scipy.special import softmax
from sklearn.metrics import (
    auc,
    precision_recall_curve,
    roc_auc_score,
    f1_score,
    confusion_matrix,
    matthews_corrcoef,
)

def calc_classification_metrics(p: EvalPrediction):
  pred_labels = np.argmax(p.predictions, axis=1)
  pred_scores = softmax(p.predictions, axis=1)[:, 1]
  labels = p.label_ids
  if len(np.unique(labels)) == 2:  # binary classification
      roc_auc_pred_score = roc_auc_score(labels, pred_scores)
      precisions, recalls, thresholds = precision_recall_curve(labels,
                                                                pred_scores)
      fscore = (2 * precisions * recalls) / (precisions + recalls)
      fscore[np.isnan(fscore)] = 0
      ix = np.argmax(fscore)
      threshold = thresholds[ix].item()
      pr_auc = auc(recalls, precisions)
      tn, fp, fn, tp = confusion_matrix(labels, pred_labels, labels=[0, 1]).ravel()
      result = {'roc_auc': roc_auc_pred_score,
                'threshold': threshold,
                'pr_auc': pr_auc,
                'recall': recalls[ix].item(),
                'precision': precisions[ix].item(), 'f1': fscore[ix].item(),
                'tn': tn.item(), 'fp': fp.item(), 'fn': fn.item(), 'tp': tp.item()
                }
  else:
      acc = (pred_labels == labels).mean()
      f1 = f1_score(y_true=labels, y_pred=pred_labels, average = 'weighted')
      result = {
          'precidsion'
          "acc": acc,
          "f1": f1,
          "acc_and_f1": (acc + f1) / 2,
          "mcc": matthews_corrcoef(labels, pred_labels)
      }

  return result

### Training

In [53]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    #test_dataset = test_dataset,
    compute_metrics=calc_classification_metrics,
)

In [54]:
%%time
trainer.train()

HBox(children=(HTML(value='Epoch'), FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))



{'loss': 1.292910614013672, 'learning_rate': 4.967616580310881e-05, 'epoch': 0.12953367875647667, 'step': 25}
{'loss': 1.2415359497070313, 'learning_rate': 4.9352331606217614e-05, 'epoch': 0.25906735751295334, 'step': 50}
{'loss': 1.1841990661621093, 'learning_rate': 4.902849740932643e-05, 'epoch': 0.38860103626943004, 'step': 75}
{'loss': 1.212017822265625, 'learning_rate': 4.870466321243523e-05, 'epoch': 0.5181347150259067, 'step': 100}
{'loss': 1.1837298583984375, 'learning_rate': 4.8380829015544046e-05, 'epoch': 0.6476683937823834, 'step': 125}
{'loss': 1.1828997802734376, 'learning_rate': 4.805699481865285e-05, 'epoch': 0.7772020725388601, 'step': 150}
{'loss': 1.1289813232421875, 'learning_rate': 4.773316062176166e-05, 'epoch': 0.9067357512953368, 'step': 175}



HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))

{'loss': 1.0936590576171874, 'learning_rate': 4.740932642487047e-05, 'epoch': 1.0362694300518134, 'step': 200}
{'loss': 1.1113958740234375, 'learning_rate': 4.708549222797928e-05, 'epoch': 1.16580310880829, 'step': 225}
{'loss': 1.155003662109375, 'learning_rate': 4.676165803108808e-05, 'epoch': 1.2953367875647668, 'step': 250}


HBox(children=(HTML(value='Evaluation'), FloatProgress(value=0.0, max=214.0), HTML(value='')))


{'eval_loss': 1.0588400294845621, 'eval_acc': 0.5127304653204565, 'eval_f1': 0.41148551066325006, 'eval_acc_and_f1': 0.46210798799185326, 'eval_mcc': 0.23629790787361776, 'epoch': 1.2953367875647668, 'step': 250}
{'loss': 1.072750244140625, 'learning_rate': 4.643782383419689e-05, 'epoch': 1.4248704663212435, 'step': 275}
{'loss': 1.103773193359375, 'learning_rate': 4.61139896373057e-05, 'epoch': 1.5544041450777202, 'step': 300}
{'loss': 1.06382568359375, 'learning_rate': 4.5790155440414514e-05, 'epoch': 1.6839378238341969, 'step': 325}
{'loss': 1.092694091796875, 'learning_rate': 4.546632124352332e-05, 'epoch': 1.8134715025906736, 'step': 350}
{'loss': 1.105372314453125, 'learning_rate': 4.5142487046632126e-05, 'epoch': 1.9430051813471503, 'step': 375}



HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))

{'loss': 1.012935791015625, 'learning_rate': 4.481865284974093e-05, 'epoch': 2.0725388601036268, 'step': 400}
{'loss': 1.046302490234375, 'learning_rate': 4.4494818652849745e-05, 'epoch': 2.2020725388601035, 'step': 425}
{'loss': 1.030330810546875, 'learning_rate': 4.417098445595855e-05, 'epoch': 2.33160621761658, 'step': 450}
{'loss': 0.976514892578125, 'learning_rate': 4.384715025906736e-05, 'epoch': 2.461139896373057, 'step': 475}
{'loss': 0.97410400390625, 'learning_rate': 4.352331606217617e-05, 'epoch': 2.5906735751295336, 'step': 500}


HBox(children=(HTML(value='Evaluation'), FloatProgress(value=0.0, max=214.0), HTML(value='')))


{'eval_loss': 0.8771921090295449, 'eval_acc': 0.6379865378987416, 'eval_f1': 0.611303937536368, 'eval_acc_and_f1': 0.6246452377175549, 'eval_mcc': 0.4702629571091348, 'epoch': 2.5906735751295336, 'step': 500}
{'loss': 1.0104248046875, 'learning_rate': 4.3199481865284976e-05, 'epoch': 2.7202072538860103, 'step': 525}
{'loss': 0.9827783203125, 'learning_rate': 4.287564766839379e-05, 'epoch': 2.849740932642487, 'step': 550}
{'loss': 0.9934716796875, 'learning_rate': 4.2551813471502595e-05, 'epoch': 2.9792746113989637, 'step': 575}



HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))

{'loss': 0.92469482421875, 'learning_rate': 4.22279792746114e-05, 'epoch': 3.1088082901554404, 'step': 600}
{'loss': 0.8411572265625, 'learning_rate': 4.190414507772021e-05, 'epoch': 3.238341968911917, 'step': 625}
{'loss': 0.8701708984375, 'learning_rate': 4.158031088082901e-05, 'epoch': 3.3678756476683938, 'step': 650}
{'loss': 0.8166259765625, 'learning_rate': 4.1256476683937825e-05, 'epoch': 3.4974093264248705, 'step': 675}
{'loss': 0.8501904296875, 'learning_rate': 4.093264248704664e-05, 'epoch': 3.626943005181347, 'step': 700}
{'loss': 0.83544189453125, 'learning_rate': 4.0608808290155444e-05, 'epoch': 3.756476683937824, 'step': 725}
{'loss': 0.82861328125, 'learning_rate': 4.028497409326425e-05, 'epoch': 3.8860103626943006, 'step': 750}


HBox(children=(HTML(value='Evaluation'), FloatProgress(value=0.0, max=214.0), HTML(value='')))


{'eval_loss': 0.62451481780777, 'eval_acc': 0.7632426104770267, 'eval_f1': 0.7610234148388905, 'eval_acc_and_f1': 0.7621330126579586, 'eval_mcc': 0.6558430528014217, 'epoch': 3.8860103626943006, 'step': 750}



HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))

{'loss': 0.81122314453125, 'learning_rate': 3.9961139896373056e-05, 'epoch': 4.015544041450777, 'step': 775}
{'loss': 0.63302001953125, 'learning_rate': 3.963730569948187e-05, 'epoch': 4.1450777202072535, 'step': 800}
{'loss': 0.58846435546875, 'learning_rate': 3.9313471502590675e-05, 'epoch': 4.274611398963731, 'step': 825}
{'loss': 0.6130322265625, 'learning_rate': 3.898963730569948e-05, 'epoch': 4.404145077720207, 'step': 850}
{'loss': 0.58425048828125, 'learning_rate': 3.8665803108808294e-05, 'epoch': 4.533678756476684, 'step': 875}
{'loss': 0.6055517578125, 'learning_rate': 3.83419689119171e-05, 'epoch': 4.66321243523316, 'step': 900}
{'loss': 0.63435546875, 'learning_rate': 3.801813471502591e-05, 'epoch': 4.7927461139896375, 'step': 925}
{'loss': 0.58943115234375, 'learning_rate': 3.769430051813472e-05, 'epoch': 4.922279792746114, 'step': 950}



HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))

{'loss': 0.540458984375, 'learning_rate': 3.7370466321243525e-05, 'epoch': 5.051813471502591, 'step': 975}
{'loss': 0.4073486328125, 'learning_rate': 3.704663212435233e-05, 'epoch': 5.181347150259067, 'step': 1000}


HBox(children=(HTML(value='Evaluation'), FloatProgress(value=0.0, max=214.0), HTML(value='')))


{'eval_loss': 0.2952756170016616, 'eval_acc': 0.8964003511852502, 'eval_f1': 0.8949642226962585, 'eval_acc_and_f1': 0.8956822869407544, 'eval_mcc': 0.8505040014809454, 'epoch': 5.181347150259067, 'step': 1000}
{'loss': 0.38732666015625, 'learning_rate': 3.6722797927461137e-05, 'epoch': 5.310880829015544, 'step': 1025}
{'loss': 0.34810302734375, 'learning_rate': 3.639896373056995e-05, 'epoch': 5.4404145077720205, 'step': 1050}
{'loss': 0.43593994140625, 'learning_rate': 3.6075129533678755e-05, 'epoch': 5.569948186528498, 'step': 1075}
{'loss': 0.39218505859375, 'learning_rate': 3.575129533678757e-05, 'epoch': 5.699481865284974, 'step': 1100}
{'loss': 0.46605712890625, 'learning_rate': 3.5427461139896374e-05, 'epoch': 5.829015544041451, 'step': 1125}
{'loss': 0.42658935546875, 'learning_rate': 3.510362694300519e-05, 'epoch': 5.958549222797927, 'step': 1150}



HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))

{'loss': 0.37089111328125, 'learning_rate': 3.477979274611399e-05, 'epoch': 6.0880829015544045, 'step': 1175}
{'loss': 0.26534912109375, 'learning_rate': 3.44559585492228e-05, 'epoch': 6.217616580310881, 'step': 1200}
{'loss': 0.2774462890625, 'learning_rate': 3.4132124352331605e-05, 'epoch': 6.347150259067358, 'step': 1225}
{'loss': 0.3357666015625, 'learning_rate': 3.380829015544041e-05, 'epoch': 6.476683937823834, 'step': 1250}


HBox(children=(HTML(value='Evaluation'), FloatProgress(value=0.0, max=214.0), HTML(value='')))


{'eval_loss': 0.1991715617815863, 'eval_acc': 0.9250804799531753, 'eval_f1': 0.9252608285043796, 'eval_acc_and_f1': 0.9251706542287774, 'eval_mcc': 0.8915693770637129, 'epoch': 6.476683937823834, 'step': 1250}
{'loss': 0.23931396484375, 'learning_rate': 3.3484455958549224e-05, 'epoch': 6.606217616580311, 'step': 1275}
{'loss': 0.317763671875, 'learning_rate': 3.3160621761658036e-05, 'epoch': 6.7357512953367875, 'step': 1300}
{'loss': 0.3760107421875, 'learning_rate': 3.283678756476684e-05, 'epoch': 6.865284974093264, 'step': 1325}
{'loss': 0.298056640625, 'learning_rate': 3.251295336787565e-05, 'epoch': 6.994818652849741, 'step': 1350}



HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))

{'loss': 0.2389697265625, 'learning_rate': 3.2189119170984454e-05, 'epoch': 7.124352331606218, 'step': 1375}
{'loss': 0.222236328125, 'learning_rate': 3.186528497409327e-05, 'epoch': 7.253886010362694, 'step': 1400}
{'loss': 0.1974755859375, 'learning_rate': 3.154145077720207e-05, 'epoch': 7.383419689119171, 'step': 1425}
{'loss': 0.2406298828125, 'learning_rate': 3.121761658031088e-05, 'epoch': 7.512953367875648, 'step': 1450}
{'loss': 0.2478125, 'learning_rate': 3.089378238341969e-05, 'epoch': 7.642487046632124, 'step': 1475}
{'loss': 0.207392578125, 'learning_rate': 3.05699481865285e-05, 'epoch': 7.772020725388601, 'step': 1500}


HBox(children=(HTML(value='Evaluation'), FloatProgress(value=0.0, max=214.0), HTML(value='')))


{'eval_loss': 0.15611581033586863, 'eval_acc': 0.9388352355867721, 'eval_f1': 0.9383050274376387, 'eval_acc_and_f1': 0.9385701315122054, 'eval_mcc': 0.9121355580776338, 'epoch': 7.772020725388601, 'step': 1500}
{'loss': 0.244248046875, 'learning_rate': 3.024611398963731e-05, 'epoch': 7.901554404145077, 'step': 1525}



HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))

{'loss': 0.2060009765625, 'learning_rate': 2.9922279792746117e-05, 'epoch': 8.031088082901555, 'step': 1550}
{'loss': 0.1735498046875, 'learning_rate': 2.9598445595854923e-05, 'epoch': 8.160621761658032, 'step': 1575}
{'loss': 0.1785888671875, 'learning_rate': 2.9274611398963732e-05, 'epoch': 8.290155440414507, 'step': 1600}
{'loss': 0.1903662109375, 'learning_rate': 2.8950777202072538e-05, 'epoch': 8.419689119170984, 'step': 1625}
{'loss': 0.17611328125, 'learning_rate': 2.862694300518135e-05, 'epoch': 8.549222797927461, 'step': 1650}
{'loss': 0.2047607421875, 'learning_rate': 2.8303108808290157e-05, 'epoch': 8.678756476683938, 'step': 1675}
{'loss': 0.203388671875, 'learning_rate': 2.7979274611398963e-05, 'epoch': 8.808290155440414, 'step': 1700}
{'loss': 0.1852587890625, 'learning_rate': 2.7655440414507772e-05, 'epoch': 8.937823834196891, 'step': 1725}



HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))

{'loss': 0.1826220703125, 'learning_rate': 2.7331606217616585e-05, 'epoch': 9.067357512953368, 'step': 1750}


HBox(children=(HTML(value='Evaluation'), FloatProgress(value=0.0, max=214.0), HTML(value='')))


{'eval_loss': 0.14628202642374372, 'eval_acc': 0.9461515949663447, 'eval_f1': 0.9460366321192708, 'eval_acc_and_f1': 0.9460941135428078, 'eval_mcc': 0.9222917008029957, 'epoch': 9.067357512953368, 'step': 1750}
{'loss': 0.2185888671875, 'learning_rate': 2.700777202072539e-05, 'epoch': 9.196891191709845, 'step': 1775}
{'loss': 0.1586962890625, 'learning_rate': 2.6683937823834197e-05, 'epoch': 9.32642487046632, 'step': 1800}
{'loss': 0.1746826171875, 'learning_rate': 2.6360103626943007e-05, 'epoch': 9.455958549222798, 'step': 1825}
{'loss': 0.1923779296875, 'learning_rate': 2.6036269430051813e-05, 'epoch': 9.585492227979275, 'step': 1850}
{'loss': 0.1542626953125, 'learning_rate': 2.5712435233160625e-05, 'epoch': 9.715025906735752, 'step': 1875}
{'loss': 0.1759033203125, 'learning_rate': 2.538860103626943e-05, 'epoch': 9.844559585492227, 'step': 1900}
{'loss': 0.1715771484375, 'learning_rate': 2.506476683937824e-05, 'epoch': 9.974093264248705, 'step': 1925}



HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))

{'loss': 0.178447265625, 'learning_rate': 2.4740932642487047e-05, 'epoch': 10.103626943005182, 'step': 1950}
{'loss': 0.1615087890625, 'learning_rate': 2.4417098445595856e-05, 'epoch': 10.233160621761659, 'step': 1975}
{'loss': 0.1514697265625, 'learning_rate': 2.4093264248704665e-05, 'epoch': 10.362694300518134, 'step': 2000}


HBox(children=(HTML(value='Evaluation'), FloatProgress(value=0.0, max=214.0), HTML(value='')))


{'eval_loss': 0.1184456596213137, 'eval_acc': 0.95171202809482, 'eval_f1': 0.9515342907527305, 'eval_acc_and_f1': 0.9516231594237753, 'eval_mcc': 0.9303215421005616, 'epoch': 10.362694300518134, 'step': 2000}
{'loss': 0.16134765625, 'learning_rate': 2.3769430051813475e-05, 'epoch': 10.492227979274611, 'step': 2025}
{'loss': 0.1583349609375, 'learning_rate': 2.344559585492228e-05, 'epoch': 10.621761658031089, 'step': 2050}
{'loss': 0.1565869140625, 'learning_rate': 2.3121761658031087e-05, 'epoch': 10.751295336787564, 'step': 2075}
{'loss': 0.1833642578125, 'learning_rate': 2.27979274611399e-05, 'epoch': 10.880829015544041, 'step': 2100}



HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))

{'loss': 0.1531298828125, 'learning_rate': 2.2474093264248706e-05, 'epoch': 11.010362694300518, 'step': 2125}
{'loss': 0.1168212890625, 'learning_rate': 2.2150259067357515e-05, 'epoch': 11.139896373056995, 'step': 2150}
{'loss': 0.1133447265625, 'learning_rate': 2.182642487046632e-05, 'epoch': 11.26943005181347, 'step': 2175}
{'loss': 0.1337353515625, 'learning_rate': 2.150259067357513e-05, 'epoch': 11.398963730569948, 'step': 2200}
{'loss': 0.1373095703125, 'learning_rate': 2.117875647668394e-05, 'epoch': 11.528497409326425, 'step': 2225}
{'loss': 0.1865966796875, 'learning_rate': 2.0854922279792746e-05, 'epoch': 11.658031088082902, 'step': 2250}


HBox(children=(HTML(value='Evaluation'), FloatProgress(value=0.0, max=214.0), HTML(value='')))

{'loss': 0.12515625, 'learning_rate': 1.89119170984456e-05, 'epoch': 12.435233160621761, 'step': 2400}
{'loss': 0.127138671875, 'learning_rate': 1.8588082901554405e-05, 'epoch': 12.564766839378239, 'step': 2425}
{'loss': 0.1128271484375, 'learning_rate': 1.8264248704663214e-05, 'epoch': 12.694300518134716, 'step': 2450}
{'loss': 0.1319970703125, 'learning_rate': 1.794041450777202e-05, 'epoch': 12.823834196891191, 'step': 2475}
{'loss': 0.1269775390625, 'learning_rate': 1.761658031088083e-05, 'epoch': 12.953367875647668, 'step': 2500}


HBox(children=(HTML(value='Evaluation'), FloatProgress(value=0.0, max=214.0), HTML(value='')))


{'eval_loss': 0.10421038287908999, 'eval_acc': 0.9628328943517706, 'eval_f1': 0.9626922412593568, 'eval_acc_and_f1': 0.9627625678055637, 'eval_mcc': 0.94653790295311, 'epoch': 12.953367875647668, 'step': 2500}



HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))

{'loss': 0.13748046875, 'learning_rate': 1.729274611398964e-05, 'epoch': 13.082901554404145, 'step': 2525}
{'loss': 0.1096630859375, 'learning_rate': 1.6968911917098445e-05, 'epoch': 13.212435233160623, 'step': 2550}
{'loss': 0.131689453125, 'learning_rate': 1.6645077720207254e-05, 'epoch': 13.341968911917098, 'step': 2575}
{'loss': 0.1512109375, 'learning_rate': 1.6321243523316064e-05, 'epoch': 13.471502590673575, 'step': 2600}
{'loss': 0.1172021484375, 'learning_rate': 1.5997409326424873e-05, 'epoch': 13.601036269430052, 'step': 2625}
{'loss': 0.118916015625, 'learning_rate': 1.567357512953368e-05, 'epoch': 13.73056994818653, 'step': 2650}
{'loss': 0.09666015625, 'learning_rate': 1.534974093264249e-05, 'epoch': 13.860103626943005, 'step': 2675}
{'loss': 0.1298583984375, 'learning_rate': 1.5025906735751296e-05, 'epoch': 13.989637305699482, 'step': 2700}



HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))

{'loss': 0.1103759765625, 'learning_rate': 1.4702072538860104e-05, 'epoch': 14.119170984455959, 'step': 2725}
{'loss': 0.1145458984375, 'learning_rate': 1.4378238341968913e-05, 'epoch': 14.248704663212436, 'step': 2750}


HBox(children=(HTML(value='Evaluation'), FloatProgress(value=0.0, max=214.0), HTML(value='')))


{'eval_loss': 0.09437631462726012, 'eval_acc': 0.9625402399765877, 'eval_f1': 0.9623828666178491, 'eval_acc_and_f1': 0.9624615532972184, 'eval_mcc': 0.946245155252108, 'epoch': 14.248704663212436, 'step': 2750}
{'loss': 0.112802734375, 'learning_rate': 1.4054404145077721e-05, 'epoch': 14.378238341968911, 'step': 2775}
{'loss': 0.1399853515625, 'learning_rate': 1.3730569948186529e-05, 'epoch': 14.507772020725389, 'step': 2800}
{'loss': 0.120517578125, 'learning_rate': 1.3406735751295338e-05, 'epoch': 14.637305699481866, 'step': 2825}
{'loss': 0.1113818359375, 'learning_rate': 1.3082901554404146e-05, 'epoch': 14.766839378238341, 'step': 2850}
{'loss': 0.1295458984375, 'learning_rate': 1.2759067357512955e-05, 'epoch': 14.896373056994818, 'step': 2875}



HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))

{'loss': 0.1136376953125, 'learning_rate': 1.2435233160621763e-05, 'epoch': 15.025906735751295, 'step': 2900}
{'loss': 0.131552734375, 'learning_rate': 1.211139896373057e-05, 'epoch': 15.155440414507773, 'step': 2925}
{'loss': 0.1014404296875, 'learning_rate': 1.1787564766839378e-05, 'epoch': 15.284974093264248, 'step': 2950}
{'loss': 0.086689453125, 'learning_rate': 1.1463730569948188e-05, 'epoch': 15.414507772020725, 'step': 2975}
{'loss': 0.10732421875, 'learning_rate': 1.1139896373056995e-05, 'epoch': 15.544041450777202, 'step': 3000}


HBox(children=(HTML(value='Evaluation'), FloatProgress(value=0.0, max=214.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



{'loss': 0.13017578125, 'learning_rate': 8.22538860103627e-06, 'epoch': 16.70984455958549, 'step': 3225}
{'loss': 0.090078125, 'learning_rate': 7.901554404145079e-06, 'epoch': 16.83937823834197, 'step': 3250}


HBox(children=(HTML(value='Evaluation'), FloatProgress(value=0.0, max=214.0), HTML(value='')))


{'eval_loss': 0.08367527191676276, 'eval_acc': 0.9692712906057945, 'eval_f1': 0.9691987107319959, 'eval_acc_and_f1': 0.9692350006688952, 'eval_mcc': 0.9557588926538138, 'epoch': 16.83937823834197, 'step': 3250}
{'loss': 0.107685546875, 'learning_rate': 7.577720207253887e-06, 'epoch': 16.968911917098445, 'step': 3275}



HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))

{'loss': 0.095205078125, 'learning_rate': 7.253886010362694e-06, 'epoch': 17.098445595854923, 'step': 3300}
{'loss': 0.091728515625, 'learning_rate': 6.930051813471503e-06, 'epoch': 17.2279792746114, 'step': 3325}
{'loss': 0.0918603515625, 'learning_rate': 6.6062176165803115e-06, 'epoch': 17.357512953367877, 'step': 3350}
{'loss': 0.1035107421875, 'learning_rate': 6.282383419689119e-06, 'epoch': 17.487046632124354, 'step': 3375}
{'loss': 0.101865234375, 'learning_rate': 5.958549222797928e-06, 'epoch': 17.616580310880828, 'step': 3400}
{'loss': 0.079228515625, 'learning_rate': 5.634715025906736e-06, 'epoch': 17.746113989637305, 'step': 3425}
{'loss': 0.106396484375, 'learning_rate': 5.310880829015545e-06, 'epoch': 17.875647668393782, 'step': 3450}



HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))

{'loss': 0.101064453125, 'learning_rate': 4.9870466321243525e-06, 'epoch': 18.00518134715026, 'step': 3475}
{'loss': 0.0868603515625, 'learning_rate': 4.663212435233161e-06, 'epoch': 18.134715025906736, 'step': 3500}


HBox(children=(HTML(value='Evaluation'), FloatProgress(value=0.0, max=214.0), HTML(value='')))


{'eval_loss': 0.08178338511537064, 'eval_acc': 0.971027216856892, 'eval_f1': 0.9710071521734951, 'eval_acc_and_f1': 0.9710171845151936, 'eval_mcc': 0.958109425351346, 'epoch': 18.134715025906736, 'step': 3500}
{'loss': 0.1038818359375, 'learning_rate': 4.33937823834197e-06, 'epoch': 18.264248704663213, 'step': 3525}
{'loss': 0.0922705078125, 'learning_rate': 4.015544041450777e-06, 'epoch': 18.39378238341969, 'step': 3550}
{'loss': 0.09341796875, 'learning_rate': 3.6917098445595854e-06, 'epoch': 18.523316062176164, 'step': 3575}
{'loss': 0.082939453125, 'learning_rate': 3.367875647668394e-06, 'epoch': 18.65284974093264, 'step': 3600}
{'loss': 0.10423828125, 'learning_rate': 3.044041450777202e-06, 'epoch': 18.78238341968912, 'step': 3625}
{'loss': 0.08978515625, 'learning_rate': 2.7202072538860106e-06, 'epoch': 18.911917098445596, 'step': 3650}



HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=193.0), HTML(value='')))

{'loss': 0.1018359375, 'learning_rate': 2.3963730569948187e-06, 'epoch': 19.041450777202073, 'step': 3675}
{'loss': 0.060986328125, 'learning_rate': 2.0725388601036273e-06, 'epoch': 19.17098445595855, 'step': 3700}
{'loss': 0.0896337890625, 'learning_rate': 1.7487046632124352e-06, 'epoch': 19.300518134715027, 'step': 3725}
{'loss': 0.090400390625, 'learning_rate': 1.4248704663212437e-06, 'epoch': 19.430051813471504, 'step': 3750}


HBox(children=(HTML(value='Evaluation'), FloatProgress(value=0.0, max=214.0), HTML(value='')))


{'eval_loss': 0.07844798672205136, 'eval_acc': 0.9721978343576236, 'eval_f1': 0.9721372771557468, 'eval_acc_and_f1': 0.9721675557566852, 'eval_mcc': 0.9598868526662633, 'epoch': 19.430051813471504, 'step': 3750}
{'loss': 0.0834716796875, 'learning_rate': 1.1010362694300518e-06, 'epoch': 19.559585492227978, 'step': 3775}
{'loss': 0.08548828125, 'learning_rate': 7.772020725388602e-07, 'epoch': 19.689119170984455, 'step': 3800}
{'loss': 0.1129638671875, 'learning_rate': 4.533678756476684e-07, 'epoch': 19.818652849740932, 'step': 3825}
{'loss': 0.0868701171875, 'learning_rate': 1.295336787564767e-07, 'epoch': 19.94818652849741, 'step': 3850}


CPU times: user 1h 10min 6s, sys: 26min 24s, total: 1h 36min 31s
Wall time: 57min 15s


TrainOutput(global_step=3860, training_loss=0.3605542118685233)

In [55]:
torch.save(model.state_dict(), 'best_multimodal_model_state.bin')
model.eval()

BertWithTabular(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

## Evaluation

So how good is our model? Let's start by calculating the accuracy on the test data:

In [57]:
%%time
output = trainer.predict(test_dataset)

HBox(children=(HTML(value='Prediction'), FloatProgress(value=0.0, max=214.0), HTML(value='')))




CPU times: user 29 s, sys: 8.34 s, total: 37.4 s
Wall time: 34.8 s


In [61]:
test_dataset.labels

array([2, 0, 2, ..., 0, 0, 2])

In [59]:
best_predictions = output.predictions

In [63]:
best_predictions

array([[ 4.2196665 , -4.417337  ,  0.21525875,  1.3477846 ],
       [ 5.4384613 ,  1.0589525 , -3.375457  , -3.4238255 ],
       [ 2.7913105 , -4.1859922 ,  5.3319745 , -3.644537  ],
       ...,
       [ 8.368958  , -2.649276  , -2.1181278 , -2.9907756 ],
       [ 8.224091  , -3.5231922 , -1.4661903 , -2.4355128 ],
       [ 2.11276   , -2.4767573 ,  2.5345497 , -1.3108507 ]],
      dtype=float32)

In [64]:
print(classification_report(test_dataset.labels, np.argmax(output.predictions, axis=1)))

              precision    recall  f1-score   support

           0       0.85      0.92      0.88      1501
           1       0.86      0.86      0.86       760
           2       0.85      0.76      0.81       693
           3       0.88      0.82      0.85       463

    accuracy                           0.86      3417
   macro avg       0.86      0.84      0.85      3417
weighted avg       0.86      0.86      0.86      3417



This is similar to the evaluation function, except that we're storing the text of the reviews and the predicted probabilities (by applying the softmax on the model outputs):

## References

- [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805)
- [L11 Language Models - Alec Radford (OpenAI)](https://www.youtube.com/watch?v=BnpB3GrpsfM)
- [The Illustrated BERT, ELMo, and co.](https://jalammar.github.io/illustrated-bert/)
- [BERT Fine-Tuning Tutorial with PyTorch](https://mccormickml.com/2019/07/22/BERT-fine-tuning/)
- [How to Fine-Tune BERT for Text Classification?](https://arxiv.org/pdf/1905.05583.pdf)
- [Huggingface Transformers](https://huggingface.co/transformers/)
- [BERT Explained: State of the art language model for NLP](https://towardsdatascience.com/bert-explained-state-of-the-art-language-model-for-nlp-f8b21a9b6270)