# Entity Resolution project @ Wavestone
## Fodors-Zagats Restaurant Matching

> *Dataset information: see [Datasets.md](Datasets.md). \
> Description still to be written. Only the raw data is used here, because the non-raw version had already been pre-processed.*

> **Tristan PERROT**


## Import libraries


In [1]:
import os

import torch
import pickle

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cpu


In [2]:
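# Walk up to the project root (the directory containing `model/`), stopping at the filesystem root
while 'model' not in os.listdir() and os.path.dirname(os.getcwd()) != os.getcwd():
    os.chdir('..')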

In [3]:
MODEL_NAME = 'bert-base-uncased'
DATA_NAME = 'fodors-zagats'
COMPUTER = 'gpu4.enst.fr:0'
DATA_DIR = os.path.join('data', DATA_NAME)

## Load data

In [4]:
import numpy as np

from model.BertModel import BertModel
from model.utils import deserialize_entities, load_data

In [5]:
X_train_ids, y_train, X_valid_ids, y_valid, X_test_ids, y_test = load_data(DATA_DIR, comp_er=False)
with open(os.path.join(DATA_DIR, '1_serialized.pkl'), 'rb') as f:
    table_a_serialized = pickle.load(f)
with open(os.path.join(DATA_DIR, '2_serialized.pkl'), 'rb') as f:
    table_b_serialized = pickle.load(f)

In [6]:
# Each id pair is (row in table A, row in table B); join the two serialized records with [SEP]
def to_pairs(ids):
    return [table_a_serialized[i[0]] + ' [SEP] ' + table_b_serialized[i[1]] for i in ids]

X_train, X_valid, X_test = to_pairs(X_train_ids), to_pairs(X_valid_ids), to_pairs(X_test_ids)
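The pickled `*_serialized.pkl` files already store each record as one string. As a hypothetical sketch only (the names `serialize_record` and `serialize_pair` are illustrative, not part of the project), a serializer producing the `[COL] attr [VAL] value` layout visible in the samples below could look like this. It uses plain string formatting and does not reproduce the single quotes that some values carry in the real files.

```python
def serialize_record(record: dict) -> str:
    """Hypothetical serializer: {'name': 'x', 'city': 'y'} -> '[COL] name [VAL] x [COL] city [VAL] y'."""
    # Dict order is preserved, so both tables serialize their columns in the same order
    return " ".join(f"[COL] {col} [VAL] {val}" for col, val in record.items())


def serialize_pair(left: dict, right: dict) -> str:
    # A candidate pair is fed to the model as a single sequence: entity A [SEP] entity B
    return serialize_record(left) + " [SEP] " + serialize_record(right)


a = {"name": "georgia grille", "city": "atlanta", "type": "american"}
b = {"name": "union square cafe", "city": "new york city", "type": "american (new)"}
print(serialize_pair(a, b))
```

Keeping the attribute names in the text lets the model align fields across the two records, which is why the column tags are serialized and not just the values.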

In [7]:
# Display the first 5 samples of the training set
for i in range(5):
    print(f'Sample {i}:')
    print(X_train[i])
    print(f'Label: {y_train[i]}')
    print()

Sample 0:
[COL] name [VAL] 'georgia grille' [COL] addr [VAL] '2290 peachtree rd.  peachtree square shopping center' [COL] city [VAL] atlanta [COL] phone [VAL] 404/352-3517 [COL] type [VAL] american [SEP] [COL] name [VAL] 'union square cafe' [COL] addr [VAL] '21 e. 16th st.' [COL] city [VAL] 'new york city' [COL] phone [VAL] 212-243-4020 [COL] type [VAL] 'american (new)'
Label: 0

Sample 1:
[COL] name [VAL] 'georgia grille' [COL] addr [VAL] '2290 peachtree rd.  peachtree square shopping center' [COL] city [VAL] atlanta [COL] phone [VAL] 404/352-3517 [COL] type [VAL] american [SEP] [COL] name [VAL] 'gotham bar & grill' [COL] addr [VAL] '12 e. 12th st.' [COL] city [VAL] 'new york city' [COL] phone [VAL] 212-620-4020 [COL] type [VAL] 'american (new)'
Label: 1

Sample 2:
[COL] name [VAL] 'four seasons grill room' [COL] addr [VAL] '99 e. 52nd st.' [COL] city [VAL] 'new york' [COL] phone [VAL] 212/754-9494 [COL] type [VAL] american [SEP] [COL] name [VAL] delectables [COL] addr [VAL] '1 margar

### RoBERTa Base
- Architecture: Transformer-based model
- Parameters: ~125 million
- Layers: 12 Transformer layers
- Hidden Size: 768
- Attention Heads: 12
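The ~125M figure follows from these dimensions. The sketch below counts the parameters of a BERT-style encoder; the vocabulary size (50,265), number of positions (514) and token types (1) are assumed from the standard `roberta-base` configuration, the 3072-unit feed-forward size is the usual base-model value, and the classification head is not included.

```python
def encoder_params(vocab, max_pos, type_vocab, hidden=768, layers=12, ffn=3072, pooler=True):
    """Count parameters of a BERT-style encoder: embeddings + Transformer layers + optional pooler."""
    # Embeddings: token, position and token-type tables, followed by one LayerNorm (weight + bias)
    emb = (vocab + max_pos + type_vocab) * hidden + 2 * hidden
    # Self-attention: Q, K, V and output projections, each hidden x hidden plus a bias
    attn = 4 * (hidden * hidden + hidden)
    # Feed-forward block: hidden -> ffn -> hidden, with biases
    ff = (hidden * ffn + ffn) + (ffn * hidden + hidden)
    # Two LayerNorms per layer (after attention and after the feed-forward block)
    norms = 2 * 2 * hidden
    pool = hidden * hidden + hidden if pooler else 0
    return emb + layers * (attn + ff + norms) + pool


print(encoder_params(vocab=50_265, max_pos=514, type_vocab=1))  # 124645632, i.e. ~125M
```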

In [22]:
MODEL_NAME = 'roberta-base'

In [23]:
model = BertModel(model_name=MODEL_NAME, study_name=DATA_NAME + "/" + MODEL_NAME + "@" + COMPUTER, device=device)
train_loader, val_loader, test_loader = model.prepare_data(X_train, y_train, X_valid, y_valid, X_test, y_test, batch_size=32, num_workers=4)
model.fit(train_loader, val_loader, epochs=10, lr=2e-5, weight_decay=0.01, early_stopping=True, patience=3)
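The progress bars in the logs below count batches, not samples. With `batch_size=32`, a loader that keeps the last incomplete batch yields ceil(n / 32) batches: the single validation batch matches the 32-sample support in the reports, the 36 test pairs give 2 batches, and 5 training batches imply between 129 and 160 training pairs (the exact count is not printed). A quick check, assuming `prepare_data` does not set `drop_last=True`:

```python
import math


def n_batches(n_samples: int, batch_size: int = 32) -> int:
    # With drop_last=False the remainder forms one extra, smaller batch
    return math.ceil(n_samples / batch_size)


print(n_batches(32), n_batches(36))  # validation and test sets: 1 2

# Training-set sizes consistent with the "5/5" training progress bar
possible_train_sizes = [n for n in range(1, 200) if n_batches(n) == 5]
print(possible_train_sizes[0], possible_train_sizes[-1])  # 129 160
```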

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model initialized
Device: cuda:0
Model: roberta-base
Study name: Fodors-Zagats/roberta-base@gpu6.enst.fr:0
Epoch 1/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.41batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.20batch/s]


Train loss: 0.6839
Val loss: 0.6630
              precision    recall  f1-score   support

           0       0.62      1.00      0.76        16
           1       1.00      0.38      0.55        16

    accuracy                           0.69        32
   macro avg       0.81      0.69      0.65        32
weighted avg       0.81      0.69      0.65        32
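The report format appears to be scikit-learn's `classification_report`, and its numbers can be reproduced by hand. With 16 examples per class, a class-1 recall of 0.38 means 6 of the 16 matches were found, while all 16 non-matches were predicted correctly. The outcome below is inferred from the rounded report (an assumption, not logged data), and the metrics are recomputed with the standard definitions:

```python
def report(y_true, y_pred, labels=(0, 1)):
    """Per-class (precision, recall, F1), accuracy, and macro averages."""
    rows = {}
    for c in labels:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        rows[c] = (precision, recall, f1)
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    # Macro average: unweighted mean over classes of each metric
    macro = tuple(sum(r[i] for r in rows.values()) / len(rows) for i in range(3))
    return rows, accuracy, macro


# Inferred epoch-1 outcome: 16 non-matches all predicted 0; of the 16 matches, 6 predicted 1 and 10 predicted 0
y_true = [0] * 16 + [1] * 16
y_pred = [0] * 16 + [1] * 6 + [0] * 10
rows, acc, macro = report(y_true, y_pred)
for c, (p, r, f) in rows.items():
    print(c, round(p, 2), round(r, 2), round(f, 2))
print("accuracy", round(acc, 2), "macro", [round(x, 2) for x in macro])
```

Because the classes are balanced (16/16), the weighted averages equal the macro averages, as in the report.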

Epoch 2/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.54batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.15batch/s]


Train loss: 0.5746
Val loss: 0.4454
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 3/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.53batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.58batch/s]


Train loss: 0.3127
Val loss: 0.1251
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 4/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.52batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.17batch/s]


Train loss: 0.1151
Val loss: 0.0153
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 5/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.53batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.27batch/s]


Train loss: 0.0278
Val loss: 0.0057
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 6/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.53batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.28batch/s]


Train loss: 0.0351
Val loss: 0.0030
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 7/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.52batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.20batch/s]


Train loss: 0.0090
Val loss: 0.0092
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 8/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.53batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.22batch/s]


Train loss: 0.0066
Val loss: 0.0019
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 9/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.53batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.34batch/s]


Train loss: 0.0030
Val loss: 0.1142
              precision    recall  f1-score   support

           0       0.94      1.00      0.97        16
           1       1.00      0.94      0.97        16

    accuracy                           0.97        32
   macro avg       0.97      0.97      0.97        32
weighted avg       0.97      0.97      0.97        32

Epoch 10/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.51batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.15batch/s]

Train loss: 0.0082
Val loss: 0.0010
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Training completed
Best weights restored
Mean time per epoch: 3.79





In [24]:
y_pred = model.evaluate(test_loader)

Testing: 100%|████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.12batch/s]

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        18
           1       1.00      1.00      1.00        18

    accuracy                           1.00        36
   macro avg       1.00      1.00      1.00        36
weighted avg       1.00      1.00      1.00        36

Test loss: 0.00093
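The reported losses are presumably the mean cross-entropy over the two-class logits, which is the loss the Hugging Face `*ForSequenceClassification` models compute when labels are passed; the project's `BertModel.evaluate` is not shown, so this is an assumption. It also explains why every run starts near 0.69: a freshly initialized head outputs nearly equal logits, and the cross-entropy of a uniform two-class prediction is ln 2 ≈ 0.693.

```python
import numpy as np


def mean_cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean softmax cross-entropy for integer class labels."""
    # Subtract the row max before exponentiating, for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class in each row
    return float(-log_probs[np.arange(len(labels)), labels].mean())


# Untrained head: equal logits for both classes, so the loss is ln 2
print(mean_cross_entropy(np.zeros((4, 2)), np.array([0, 1, 0, 1])))  # ~0.6931

# Confident, correct predictions drive the loss towards 0
print(mean_cross_entropy(np.array([[8.0, -8.0], [-8.0, 8.0]]), np.array([0, 1])))
```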





In [25]:
# Example of a prediction
print(f'Prediction: {y_pred[0]}')
print(f'Label: {y_test[0]}')
print(f'Sample: {X_test[0]}')

Prediction: 1
Label: 1
Sample: <col>name<val>'le montrachet' <col>addr<val>'3000 w. paradise rd.' <col>city<val>'las vegas' <col>phone<val>702/732-5111 <col>type<val>continental <col>class<val>69<sep><col>name<val>'le montrachet bistro' <col>addr<val>'3000 paradise rd.' <col>city<val>'las vegas' <col>phone<val>702-732-5651 <col>type<val>'french bistro' <col>class<val>69


In [26]:
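from IPython.display import display  # already available in Jupyter; imported explicitly for clarity

# Index of the first matching pair (label 1) in the test set
idx = int(np.nonzero(y_test)[0][0])
e1_df, e2_df = deserialize_entities(X_test[idx])
print('Entity 1:')
display(e1_df)
print('Entity 2:')
display(e2_df)
print(f'Label: {y_test[idx]}')
print(f'Prediction: {y_pred[idx]}')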

Entity 1:

| column | name | addr | city | phone | type | class |
|---|---|---|---|---|---|---|
| 0 | 'le montrachet' | '3000 w. paradise rd.' | 'las vegas' | 702/732-5111 | continental | 69 |

Entity 2:

| column | name | addr | city | phone | type | class |
|---|---|---|---|---|---|---|
| 0 | 'le montrachet bistro' | '3000 paradise rd.' | 'las vegas' | 702-732-5651 | 'french bistro' | 69 |


Label: 1
Prediction: 1
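The serialized test sample printed above uses a `<col>NAME<val>VALUE ... <sep>` layout rather than the `[COL]/[VAL]/[SEP]` layout of the training samples. The project's `deserialize_entities` is not shown. As a hypothetical sketch only, a regex-based parser for the `<col>` layout could look like this. It returns plain dicts instead of DataFrames and assumes that values never contain the literal strings `<col>` or `<sep>`.

```python
import re

# One field is "<col>NAME<val>VALUE"; the value runs until the next "<col>" or the end of the string
FIELD = re.compile(r"<col>(.*?)<val>(.*?)\s*(?=<col>|$)")


def parse_entity(text: str) -> dict:
    return {col.strip(): val.strip() for col, val in FIELD.findall(text)}


def deserialize_pair(serialized: str) -> tuple:
    # Split once on the separator between the two records
    left, right = serialized.split("<sep>", 1)
    return parse_entity(left), parse_entity(right)


sample = ("<col>name<val>'le montrachet' <col>addr<val>'3000 w. paradise rd.' "
          "<col>city<val>'las vegas' <col>phone<val>702/732-5111 <col>type<val>continental "
          "<col>class<val>69<sep><col>name<val>'le montrachet bistro' "
          "<col>addr<val>'3000 paradise rd.' <col>city<val>'las vegas' "
          "<col>phone<val>702-732-5651 <col>type<val>'french bistro' <col>class<val>69")
e1, e2 = deserialize_pair(sample)
print(e1["name"], "|", e2["name"])  # 'le montrachet' | 'le montrachet bistro'
```

Wrapping each dict as `pandas.DataFrame([e1])` would give one-row tables like the ones displayed above.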


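### DistilRoBERTa
- Parameters: ~82 million
- Layers: 6 Transformer layers (half of RoBERTa-base)
- Hidden Size: 768 (same as RoBERTa-base)
- Attention Heads: 12 (same as RoBERTa-base)
- About twice as fast as RoBERTa-base on average, according to its Hugging Face model card.
- Roughly a third fewer parameters (~82M vs. ~125M); it is distilled from RoBERTa-base with the same procedure as DistilBERT, which keeps about 97% of BERT-base's language-understanding performance.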
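Since DistilRoBERTa keeps RoBERTa-base's embeddings and width and only drops 6 of the 12 layers, its size follows from simple arithmetic. This sketch assumes standard BERT-style layer shapes (768 hidden units, a 3072-unit feed-forward block, biases everywhere, two LayerNorms per layer) and takes 124,645,632 as the assumed total for `roberta-base`.

```python
hidden, ffn = 768, 3072

# One Transformer layer: Q/K/V/output projections, two feed-forward matrices, two LayerNorms
attention = 4 * (hidden * hidden + hidden)
feed_forward = (hidden * ffn + ffn) + (ffn * hidden + hidden)
layer_norms = 2 * (2 * hidden)
per_layer = attention + feed_forward + layer_norms

roberta_base = 124_645_632  # assumed total for roberta-base (embeddings + 12 layers + pooler)
distilroberta = roberta_base - 6 * per_layer

print(per_layer)      # 7087872
print(distilroberta)  # 82118400
print(round(1 - distilroberta / roberta_base, 2))  # share of parameters removed: 0.34
```

The embeddings (~39M parameters) are shared between the two models, so halving the depth removes only about a third of the parameters.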

In [27]:
MODEL_NAME = 'distilroberta-base'

In [28]:
model = BertModel(model_name=MODEL_NAME, study_name=DATA_NAME + "/" + MODEL_NAME + "@" + COMPUTER, device=device)
train_loader, val_loader, test_loader = model.prepare_data(X_train, y_train, X_valid, y_valid, X_test, y_test, batch_size=32, num_workers=4)
model.fit(train_loader, val_loader, epochs=10, lr=2e-5, weight_decay=0.01, early_stopping=True, patience=3)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model initialized
Device: cuda:0
Model: distilroberta-base
Study name: Fodors-Zagats/distilroberta-base@gpu6.enst.fr:0
Epoch 1/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  2.77batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.64batch/s]


Train loss: 0.6894
Val loss: 0.6678
              precision    recall  f1-score   support

           0       0.94      1.00      0.97        16
           1       1.00      0.94      0.97        16

    accuracy                           0.97        32
   macro avg       0.97      0.97      0.97        32
weighted avg       0.97      0.97      0.97        32

Epoch 2/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  2.74batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.64batch/s]


Train loss: 0.6450
Val loss: 0.5565
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 3/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  2.72batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.42batch/s]


Train loss: 0.4771
Val loss: 0.2963
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 4/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  2.69batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.44batch/s]


Train loss: 0.2384
Val loss: 0.1643
              precision    recall  f1-score   support

           0       0.94      1.00      0.97        16
           1       1.00      0.94      0.97        16

    accuracy                           0.97        32
   macro avg       0.97      0.97      0.97        32
weighted avg       0.97      0.97      0.97        32

Epoch 5/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  2.74batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.37batch/s]


Train loss: 0.0598
Val loss: 0.0148
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 6/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  2.67batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.47batch/s]


Train loss: 0.0212
Val loss: 0.0338
              precision    recall  f1-score   support

           0       0.94      1.00      0.97        16
           1       1.00      0.94      0.97        16

    accuracy                           0.97        32
   macro avg       0.97      0.97      0.97        32
weighted avg       0.97      0.97      0.97        32

Epoch 7/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  2.73batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.51batch/s]


Train loss: 0.0083
Val loss: 0.1792
              precision    recall  f1-score   support

           0       0.94      1.00      0.97        16
           1       1.00      0.94      0.97        16

    accuracy                           0.97        32
   macro avg       0.97      0.97      0.97        32
weighted avg       0.97      0.97      0.97        32

Epoch 8/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  2.73batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.61batch/s]


Train loss: 0.0042
Val loss: 0.0021
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 9/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  2.70batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.61batch/s]


Train loss: 0.0021
Val loss: 0.0008
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 10/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  2.68batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.47batch/s]

Train loss: 0.0021
Val loss: 0.0007
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Training completed
Best weights restored
Mean time per epoch: 2.25





In [29]:
y_pred_disti = model.evaluate(test_loader)

Testing: 100%|████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.75batch/s]

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        18
           1       1.00      1.00      1.00        18

    accuracy                           1.00        36
   macro avg       1.00      1.00      1.00        36
weighted avg       1.00      1.00      1.00        36

Test loss: 0.00064





### BERT Base
- Parameters: ~110 million
- Layers: 12 Transformer layers
- Hidden Size: 768
- Attention Heads: 12
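BERT-base and RoBERTa-base have the same encoder shape, so the ~15M-parameter gap between them comes from the embedding tables. This sketch assumes the standard Hugging Face configurations: BERT uses a 30,522-token WordPiece vocabulary, 512 positions and 2 token types, while RoBERTa uses a 50,265-token byte-level BPE vocabulary, 514 positions and 1 token type.

```python
hidden = 768


def embedding_params(vocab: int, positions: int, token_types: int) -> int:
    # Word, position and token-type tables, each `hidden` wide, plus one LayerNorm (weight + bias)
    return (vocab + positions + token_types) * hidden + 2 * hidden


bert = embedding_params(vocab=30_522, positions=512, token_types=2)
roberta = embedding_params(vocab=50_265, positions=514, token_types=1)

print(roberta - bert)  # 15163392
```

This matches the difference between the ~125M and ~110M totals, since the 12 encoder layers and the pooler have identical sizes in both models.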

In [30]:
MODEL_NAME = 'bert-base-uncased'

In [31]:
model = BertModel(model_name=MODEL_NAME, study_name=DATA_NAME + "/" + MODEL_NAME + "@" + COMPUTER, device=device)
train_loader, val_loader, test_loader = model.prepare_data(X_train, y_train, X_valid, y_valid, X_test, y_test, batch_size=32, num_workers=4)
model.fit(train_loader, val_loader, epochs=10, lr=2e-5, weight_decay=0.01, early_stopping=True, patience=3)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model initialized
Device: cuda:0
Model: bert-base-uncased
Study name: Fodors-Zagats/bert-base-uncased@gpu6.enst.fr:0
Epoch 1/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.48batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.04batch/s]


Train loss: 0.7211
Val loss: 0.6102
              precision    recall  f1-score   support

           0       0.89      1.00      0.94        16
           1       1.00      0.88      0.93        16

    accuracy                           0.94        32
   macro avg       0.94      0.94      0.94        32
weighted avg       0.94      0.94      0.94        32

Epoch 2/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.48batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.95batch/s]


Train loss: 0.5918
Val loss: 0.4691
              precision    recall  f1-score   support

           0       0.94      1.00      0.97        16
           1       1.00      0.94      0.97        16

    accuracy                           0.97        32
   macro avg       0.97      0.97      0.97        32
weighted avg       0.97      0.97      0.97        32

Epoch 3/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.47batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.90batch/s]


Train loss: 0.4013
Val loss: 0.2965
              precision    recall  f1-score   support

           0       0.94      1.00      0.97        16
           1       1.00      0.94      0.97        16

    accuracy                           0.97        32
   macro avg       0.97      0.97      0.97        32
weighted avg       0.97      0.97      0.97        32

Epoch 4/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.46batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.07batch/s]


Train loss: 0.2312
Val loss: 0.1970
              precision    recall  f1-score   support

           0       0.94      1.00      0.97        16
           1       1.00      0.94      0.97        16

    accuracy                           0.97        32
   macro avg       0.97      0.97      0.97        32
weighted avg       0.97      0.97      0.97        32

Epoch 5/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.49batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.71batch/s]


Train loss: 0.1263
Val loss: 0.1686
              precision    recall  f1-score   support

           0       0.94      1.00      0.97        16
           1       1.00      0.94      0.97        16

    accuracy                           0.97        32
   macro avg       0.97      0.97      0.97        32
weighted avg       0.97      0.97      0.97        32

Epoch 6/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.48batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.75batch/s]


Train loss: 0.0955
Val loss: 0.1397
              precision    recall  f1-score   support

           0       0.94      1.00      0.97        16
           1       1.00      0.94      0.97        16

    accuracy                           0.97        32
   macro avg       0.97      0.97      0.97        32
weighted avg       0.97      0.97      0.97        32

Epoch 7/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.46batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.72batch/s]


Train loss: 0.0577
Val loss: 0.2328
              precision    recall  f1-score   support

           0       0.89      1.00      0.94        16
           1       1.00      0.88      0.93        16

    accuracy                           0.94        32
   macro avg       0.94      0.94      0.94        32
weighted avg       0.94      0.94      0.94        32

Epoch 8/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.47batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.66batch/s]


Train loss: 0.0431
Val loss: 0.2551
              precision    recall  f1-score   support

           0       0.89      1.00      0.94        16
           1       1.00      0.88      0.93        16

    accuracy                           0.94        32
   macro avg       0.94      0.94      0.94        32
weighted avg       0.94      0.94      0.94        32

Epoch 9/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.39batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.89batch/s]

Train loss: 0.0265
Val loss: 0.2512
              precision    recall  f1-score   support

           0       0.89      1.00      0.94        16
           1       1.00      0.88      0.93        16

    accuracy                           0.94        32
   macro avg       0.94      0.94      0.94        32
weighted avg       0.94      0.94      0.94        32

Early stopping triggered
Best weights restored
Mean time per epoch: 3.96
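This is the only run where early stopping fired. The BERT validation loss was best at epoch 6 (0.1397) and did not improve for the next three epochs, so training stopped after epoch 9 and the epoch-6 weights were restored. A minimal sketch of that rule, assuming `fit` monitors validation loss with strict improvement and `patience=3` (the project's implementation may use another metric or a minimum delta):

```python
def early_stopping(val_losses, patience=3):
    """Return (best_epoch, stop_epoch), 1-indexed; stop_epoch is None if all epochs run."""
    best_loss, best_epoch, waited = float("inf"), 0, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # New best: these are the weights that would be restored at the end
            best_loss, best_epoch, waited = loss, epoch, 0
        else:
            waited += 1
            if waited >= patience:
                return best_epoch, epoch
    return best_epoch, None


bert_val = [0.6102, 0.4691, 0.2965, 0.1970, 0.1686, 0.1397, 0.2328, 0.2551, 0.2512]
print(early_stopping(bert_val))  # (6, 9)
```

Applied to the RoBERTa and DistilRoBERTa validation losses above, the same rule never runs out of patience, because a new best always appears within three epochs. This is consistent with those runs reporting "Training completed".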





In [32]:
y_pred_bert = model.evaluate(test_loader)

Testing: 100%|████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.43batch/s]

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        18
           1       1.00      1.00      1.00        18

    accuracy                           1.00        36
   macro avg       1.00      1.00      1.00        36
weighted avg       1.00      1.00      1.00        36

Test loss: 0.02167





### ELECTRA Base (discriminator)
- Parameters: ~110 million
- Layers: 12 Transformer layers
- Hidden Size: 768
- Attention Heads: 12
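The "newly initialized" warnings list the classification-head weights trained from scratch. BERT's head is a single `classifier` linear layer, whereas RoBERTa's (and ELECTRA's, in the warning below) is a `classifier.dense` layer followed by `classifier.out_proj`. Assuming hidden size 768 and 2 labels, the freshly initialized parameter counts are:

```python
hidden, num_labels = 768, 2


def linear(n_in: int, n_out: int) -> int:
    # Weight matrix plus bias vector
    return n_in * n_out + n_out


bert_head = linear(hidden, num_labels)                                      # classifier.weight / classifier.bias
roberta_electra_head = linear(hidden, hidden) + linear(hidden, num_labels)  # dense + out_proj

print(bert_head, roberta_electra_head)  # 1538 592130
```

In both cases the head is tiny compared with the ~110-125M pretrained encoder parameters that are fine-tuned along with it.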

In [33]:
MODEL_NAME = 'google/electra-base-discriminator'

In [34]:
model = BertModel(model_name=MODEL_NAME, study_name=DATA_NAME + "/" + MODEL_NAME + "@" + COMPUTER, device=device)
train_loader, val_loader, test_loader = model.prepare_data(X_train, y_train, X_valid, y_valid, X_test, y_test, batch_size=32, num_workers=4)
model.fit(train_loader, val_loader, epochs=10, lr=2e-5, weight_decay=0.01, early_stopping=True, patience=3)

Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-base-discriminator and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model initialized
Device: cuda:0
Model: google/electra-base-discriminator
Study name: Fodors-Zagats/google/electra-base-discriminator@gpu6.enst.fr:0
Epoch 1/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:04<00:00,  1.25batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.82batch/s]


Train loss: 0.6532
Val loss: 0.5538
              precision    recall  f1-score   support

           0       1.00      0.88      0.93        16
           1       0.89      1.00      0.94        16

    accuracy                           0.94        32
   macro avg       0.94      0.94      0.94        32
weighted avg       0.94      0.94      0.94        32

Epoch 2/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.27batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.74batch/s]


Train loss: 0.4855
Val loss: 0.3889
              precision    recall  f1-score   support

           0       1.00      0.94      0.97        16
           1       0.94      1.00      0.97        16

    accuracy                           0.97        32
   macro avg       0.97      0.97      0.97        32
weighted avg       0.97      0.97      0.97        32

Epoch 3/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.29batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.87batch/s]


Train loss: 0.3274
Val loss: 0.2367
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 4/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.31batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.95batch/s]


Train loss: 0.2136
Val loss: 0.1476
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 5/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.31batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.98batch/s]


Train loss: 0.1386
Val loss: 0.0938
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 6/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.32batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.96batch/s]


Train loss: 0.0933
Val loss: 0.0580
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 7/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.31batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.95batch/s]


Train loss: 0.0610
Val loss: 0.0394
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 8/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.32batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.99batch/s]


Train loss: 0.0446
Val loss: 0.0279
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 9/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.31batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.88batch/s]


Train loss: 0.0323
Val loss: 0.0211
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Epoch 10/10


Training: 100%|███████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.33batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.96batch/s]

Train loss: 0.0258
Val loss: 0.0165
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00        16

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

Training completed
Best weights restored
Mean time per epoch: 4.39





In [35]:
y_pred_electra = model.evaluate(test_loader)

Testing: 100%|████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.53batch/s]

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        18
           1       1.00      1.00      1.00        18

    accuracy                           1.00        36
   macro avg       1.00      1.00      1.00        36
weighted avg       1.00      1.00      1.00        36

Test loss: 0.01513
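All four models reach 1.00 test accuracy on the 36 test pairs, so only the test loss and the training time separate them, and with so few test pairs those differences should not be over-interpreted. The figures below are copied from the logs above; BERT's mean epoch time covers 9 epochs because of early stopping. A small comparison:

```python
# Figures copied from the logs above (test accuracy is 1.00 for every model)
results = {
    "roberta-base": {"epoch_s": 3.79, "test_loss": 0.00093},
    "distilroberta-base": {"epoch_s": 2.25, "test_loss": 0.00064},
    "bert-base-uncased": {"epoch_s": 3.96, "test_loss": 0.02167},
    "google/electra-base-discriminator": {"epoch_s": 4.39, "test_loss": 0.01513},
}

baseline = results["roberta-base"]["epoch_s"]
order = sorted(results.items(), key=lambda kv: kv[1]["test_loss"])
for name, r in order:
    speedup = baseline / r["epoch_s"]  # > 1 means faster per epoch than roberta-base
    print(f"{name:35s} test loss {r['test_loss']:.5f}  epoch {r['epoch_s']:.2f}s  ({speedup:.2f}x vs roberta-base)")
```

On this run, DistilRoBERTa has both the lowest test loss and a ~1.68x per-epoch speed-up over RoBERTa-base, which is below the ~2x reported on its model card.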



