# Entity Resolution project @ Wavestone
## Walmart-Amazon Products Matching

> *Dataset information is available [here](Datasets.md). \
> Description still to be written. Only the raw data is used, because the non-raw version was already pre-processed.*

> **Tristan PERROT**


## Import libraries


In [13]:
import os

import torch
import pickle

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cpu


In [14]:
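# Move up to the project root (the directory containing `model/`); stop at the filesystem root
while 'model' not in os.listdir() and os.getcwd() != os.path.dirname(os.getcwd()):
    os.chdir('..')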

In [15]:
MODEL_NAME = 'bert-base-uncased'
DATA_NAME = 'walmart-amazon'
COMPUTER = 'gpu6.enst.fr:0'
DATA_DIR = os.path.join('data', DATA_NAME)

## Load data

In [16]:
import numpy as np

from model.BertModel import BertModel
from model.utils import deserialize_entities, load_data

In [17]:
X_train_ids, y_train, X_valid_ids, y_valid, X_test_ids, y_test = load_data(DATA_DIR, comp_er=False)
with open(os.path.join(DATA_DIR, '1_serialized.pkl'), 'rb') as f:
    table_a_serialized = pickle.load(f)
with open(os.path.join(DATA_DIR, '2_serialized.pkl'), 'rb') as f:
    table_b_serialized = pickle.load(f)

In [18]:
def build_pairs(pair_ids):
    # Join the serialized records of each (table A id, table B id) pair with a separator
    return [table_a_serialized[ids[0]] + ' [SEP] ' + table_b_serialized[ids[1]] for ids in pair_ids]

X_train, X_valid, X_test = build_pairs(X_train_ids), build_pairs(X_valid_ids), build_pairs(X_test_ids)

In [19]:
# Display the first 5 samples of the training set
for i in range(5):
    print(f'Sample {i}:')
    print(X_train[i])
    print(f'Label: {y_train[i]}')
    print()

Sample 0:
[COL] brand [VAL] Creative Labs [COL] groupname [VAL] Electronics - General [COL] title [VAL] Creative Labs Sound Blaster Tactic3D Alpha [COL] price [VAL] 59.99 [COL] shelfdescr [VAL] Dual mode analog or USB THX TruStudio Pro technology VoiceFX technology [COL] shortdescr [VAL] Experience the evolution of 3D gaming audio with the Sound Blaster Tactic3D Alpha gaming headset. Acoustically optimized 40mm drivers combine with THX TruStudio Pro technology to transform everyday audio into a mind-blowing cinematic experience. Touchscreen software and Tactic Profiles allow you to save and share your favorite settings including THX TruStudio Pro Surround for immersive 360-degree headphone surround and VoiceFX voice morphing technology. The dual mode design allows you to connect via standard minijacks or use the Dual Mode adapter to connect to your PC or Mac s USB port for the full THX TruStudio Pro experience. [COL] longdescr [VAL] Creative Labs Sound Blaster Tactic3D Alpha Dual mode 
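
The sample above follows a `[COL] name [VAL] value` layout, with the two records joined by ` [SEP] `. The project's own serialization code is not shown in this notebook, so the following is only a minimal sketch of how such a string could be produced. The function names `serialize_record` and `serialize_pair`, and the second record, are hypothetical.

```python
def serialize_record(record: dict) -> str:
    """Flatten one product record into '[COL] name [VAL] value' form."""
    return ' '.join(f'[COL] {col} [VAL] {val}' for col, val in record.items())


def serialize_pair(left: dict, right: dict, sep: str = ' [SEP] ') -> str:
    """Concatenate two serialized records with a separator, as in the cell above."""
    return serialize_record(left) + sep + serialize_record(right)


# The first record reuses two fields from Sample 0; the second one is made up.
left = {'brand': 'Creative Labs', 'price': 59.99}
right = {'brand': 'Creative', 'price': 58.5}

pair = serialize_pair(left, right)
print(pair)
```

Keeping the column names in the text lets the model see which attribute each value belongs to, at the cost of longer input sequences.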

### RoBERTa Base
- Architecture: Transformer-based model
- Parameters: ~125 million
- Layers: 12 Transformer layers
- Hidden Size: 768
- Attention Heads: 12
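
The parameter count can be roughly reconstructed from the hyperparameters listed above. The sketch below counts the weights of a BERT-style encoder. The vocabulary size (50265), number of positions (514), token-type vocabulary size (1), and feed-forward width (3072) are assumptions taken from the published `roberta-base` configuration, not from this notebook, and the classification head added during fine-tuning is not included.

```python
def encoder_param_count(vocab_size, max_positions, type_vocab_size,
                        hidden=768, layers=12, intermediate=3072, pooler=True):
    """Count the weights of a BERT-style encoder from its hyperparameters."""
    layer_norm = 2 * hidden  # scale and bias
    # Word, position and token-type embeddings, followed by a LayerNorm
    embeddings = (vocab_size + max_positions + type_vocab_size) * hidden + layer_norm
    # Query, key, value and output projections (weights + biases), then a LayerNorm
    attention = 4 * (hidden * hidden + hidden) + layer_norm
    # Two dense layers (hidden -> intermediate -> hidden), then a LayerNorm
    feed_forward = ((hidden * intermediate + intermediate)
                    + (intermediate * hidden + hidden)
                    + layer_norm)
    pooler_params = hidden * hidden + hidden if pooler else 0
    return embeddings + layers * (attention + feed_forward) + pooler_params


# Assumed roberta-base configuration values (see the note above)
roberta_base = encoder_param_count(vocab_size=50265, max_positions=514, type_vocab_size=1)
print(f'{roberta_base / 1e6:.1f}M parameters')
```

Most of the weights sit in the 12 Transformer layers; the embedding table accounts for roughly a third of the total because of the large vocabulary.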

In [24]:
MODEL_NAME = 'roberta-base'

In [25]:
model = BertModel(model_name=MODEL_NAME, study_name=DATA_NAME + "/" + MODEL_NAME + "@" + COMPUTER, device=device)
train_loader, val_loader, test_loader = model.prepare_data(X_train, y_train, X_valid, y_valid, X_test, y_test, batch_size=32, num_workers=4)
model.fit(train_loader, val_loader, epochs=10, lr=2e-5, weight_decay=0.01, early_stopping=True, patience=3)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model initialized
Device: cuda:0
Model: roberta-base
Study name: Walmart-Amazon/roberta-base@gpu6.enst.fr:0
Epoch 1/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:32<00:00,  1.55batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.37batch/s]


Train loss: 0.6963
Val loss: 0.6887
              precision    recall  f1-score   support

           0       0.75      0.52      0.61       169
           1       0.63      0.82      0.71       169

    accuracy                           0.67       338
   macro avg       0.69      0.67      0.66       338
weighted avg       0.69      0.67      0.66       338

Epoch 2/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:32<00:00,  1.56batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.43batch/s]


Train loss: 0.5500
Val loss: 0.3507
              precision    recall  f1-score   support

           0       0.92      0.75      0.83       169
           1       0.79      0.93      0.86       169

    accuracy                           0.84       338
   macro avg       0.86      0.84      0.84       338
weighted avg       0.86      0.84      0.84       338

Epoch 3/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:32<00:00,  1.55batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.39batch/s]


Train loss: 0.2897
Val loss: 0.2949
              precision    recall  f1-score   support

           0       0.90      0.79      0.84       169
           1       0.81      0.92      0.86       169

    accuracy                           0.85       338
   macro avg       0.86      0.85      0.85       338
weighted avg       0.86      0.85      0.85       338

Epoch 4/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:32<00:00,  1.55batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.38batch/s]


Train loss: 0.2477
Val loss: 0.2920
              precision    recall  f1-score   support

           0       0.93      0.77      0.84       169
           1       0.80      0.94      0.87       169

    accuracy                           0.86       338
   macro avg       0.87      0.86      0.85       338
weighted avg       0.87      0.86      0.85       338

Epoch 5/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:33<00:00,  1.54batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.35batch/s]


Train loss: 0.2417
Val loss: 0.3369
              precision    recall  f1-score   support

           0       0.86      0.81      0.84       169
           1       0.82      0.87      0.84       169

    accuracy                           0.84       338
   macro avg       0.84      0.84      0.84       338
weighted avg       0.84      0.84      0.84       338

Epoch 6/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:33<00:00,  1.54batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.33batch/s]


Train loss: 0.2273
Val loss: 0.2789
              precision    recall  f1-score   support

           0       0.94      0.78      0.85       169
           1       0.81      0.95      0.87       169

    accuracy                           0.86       338
   macro avg       0.87      0.86      0.86       338
weighted avg       0.87      0.86      0.86       338

Epoch 7/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:33<00:00,  1.54batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.24batch/s]


Train loss: 0.2200
Val loss: 0.2975
              precision    recall  f1-score   support

           0       0.89      0.82      0.85       169
           1       0.83      0.89      0.86       169

    accuracy                           0.86       338
   macro avg       0.86      0.86      0.86       338
weighted avg       0.86      0.86      0.86       338

Epoch 8/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:33<00:00,  1.53batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.07batch/s]


Train loss: 0.2258
Val loss: 0.2760
              precision    recall  f1-score   support

           0       0.92      0.79      0.85       169
           1       0.81      0.93      0.87       169

    accuracy                           0.86       338
   macro avg       0.87      0.86      0.86       338
weighted avg       0.87      0.86      0.86       338

Epoch 9/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:33<00:00,  1.52batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.30batch/s]


Train loss: 0.2185
Val loss: 0.2716
              precision    recall  f1-score   support

           0       0.90      0.79      0.84       169
           1       0.81      0.92      0.86       169

    accuracy                           0.85       338
   macro avg       0.86      0.85      0.85       338
weighted avg       0.86      0.85      0.85       338

Epoch 10/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:33<00:00,  1.54batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.12batch/s]

Train loss: 0.2238
Val loss: 0.2983
              precision    recall  f1-score   support

           0       0.88      0.80      0.84       169
           1       0.82      0.89      0.86       169

    accuracy                           0.85       338
   macro avg       0.85      0.85      0.85       338
weighted avg       0.85      0.85      0.85       338

Training completed
Best weights restored
Mean time per epoch: 35.62





In [26]:
y_pred = model.evaluate(test_loader)

Testing: 100%|██████████████████████████████████████████████████████████████| 12/12 [00:02<00:00,  4.11batch/s]

              precision    recall  f1-score   support

           0       0.86      0.84      0.85       180
           1       0.84      0.87      0.85       180

    accuracy                           0.85       360
   macro avg       0.85      0.85      0.85       360
weighted avg       0.85      0.85      0.85       360

Test loss: 0.22031
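
For reference, the per-class figures in the report above derive from the four confusion-matrix counts. The notebook does not print a confusion matrix, so the counts below are hypothetical: they were chosen only because they reproduce the rounded values of the report, and other counts may do so as well. The helper `binary_report` is an illustrative name, not part of the project.

```python
def binary_report(tn, fp, fn, tp):
    """Compute precision, recall and F1 per class, plus accuracy, from raw counts."""
    def prf(correct, predicted, actual):
        # Assumes the class was predicted at least once and is present in the data
        precision = correct / predicted
        recall = correct / actual
        f1 = 2 * precision * recall / (precision + recall)
        return precision, recall, f1

    return {
        0: prf(tn, tn + fn, tn + fp),  # class 0: non-matching pairs
        1: prf(tp, tp + fp, tp + fn),  # class 1: matching pairs
        'accuracy': (tn + tp) / (tn + fp + fn + tp),
    }


# Hypothetical counts, consistent with the rounded report (180 samples per class)
report = binary_report(tn=151, fp=29, fn=24, tp=156)
for label in (0, 1):
    p, r, f1 = report[label]
    print(f'class {label}: precision={p:.2f} recall={r:.2f} f1={f1:.2f}')
print(f"accuracy={report['accuracy']:.2f}")
```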





In [27]:
# Example of a prediction
print(f'Prediction: {y_pred[0]}')
print(f'Label: {y_test[0]}')
print(f'Sample: {X_test[0]}')

Prediction: 0
Label: 0
Sample: <col>brand<val>Asus <col>groupname<val>Electronics - General <col>title<val>Asus F1A75-M Pro AMD A75 Micro ATX Motherboard <col>price<val>112.0 <col>shelfdescr<val>Dual intelligent processors 2 with DIGI VRM- digital power design UEFI BIOS EZ Mode Auto Tuning <col>shortdescr<val>Enjoy precise digital power design and excellent graphics performance with the Asus F1A75-M Pro AMD A75 Micro ATX Motherboard. The AI Suite II provides one-stop access to innovative ASUS features. <col>longdescr<val>Asus F1A75-M Pro AMD A75 Micro ATX Motherboard Dual intelligent processors 2 with DIGI VRM- digital power design UEFI BIOS EZ Mode Auto Tuning Native USB 3.0 SATA 6Gb s support AI Suite II Quad USB 3.0 SATA 6Gb s Support Multi-GPU CrossFireX support <col>orig_shelfdescr<val><li>Dual intelligent processors 2 with DIGI+ VRM- digital power design<li>UEFI BIOS EZ Mode<li>Auto Tuning <col>orig_shortdescr<val>Enjoy precise digital power design and excellent graphics performa

In [28]:
# Index of the first positive (matching) pair in the test set
first_match = np.nonzero(y_test)[0][0]
e1_df, e2_df = deserialize_entities(X_test[first_match])
print('Entity 1:')
display(e1_df)
print('Entity 2:')
display(e2_df)
print(f'Label: {y_test[first_match]}')
print(f'Prediction: {y_pred[first_match]}')

Entity 1:


| column | brand | groupname | title | price | shelfdescr | shortdescr | longdescr | orig_shelfdescr | orig_shortdescr | orig_longdescr | modelno | shipweight | dimensions |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | HP | MP3 Accessories | HP C7973W Ultrium 800 GB WORM Data Cartridge | 29.42 | Storage Capacity 400 GB Native 800 GB Compress... | Ultrium 800GB WORM Data Cartridges offer 800GB... | Storage Capacity 400 GB Native 800 GB Compress... | `<ul><li>Storage Capacity: 400 GB Native, 800 G...` | Ultrium 800GB WORM Data Cartridges offer 800GB... | `<ul><li>Storage Capacity: 400 GB Native, 800 G...` | C7973W | 0.58 | 4.5 x 4.05 x 0.9 inches |


Entity 2:


| column | brand | modelno | category1 | pcategory1 | category2 | pcategory2 | title | listprice | price | prodfeatures | techdetails | proddescrshort | proddescrlong | dimensions | itemweight | shipweight | orig_prodfeatures | orig_techdetails |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | HP | C7973W | Data Cartridges | Blank Media | | | HP LTO Ultrium WORM x 1 - 800 GB - storage med... | 58.16 | 28.54 | | Genuine HP LTO Ultrium Tape New Product | | Data Cartridge - LTO Ultrium LTO-3 - 400 GB Na... | 4.5 x 4.2 x 1.0 inches | 3.2 ounces | 1 pounds | | `<ul><li>Genuine HP LTO / Ultrium Tape</li> <li...` |


Label: 1
Prediction: 1
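
The implementation of `deserialize_entities` lives in `model.utils` and is not shown here. As a rough idea of what such a step involves, the sketch below splits a serialized pair back into two dictionaries. It returns plain dicts rather than DataFrames, and the names `parse_record` and `parse_pair` are hypothetical. The marker strings are parameters because the outputs in this notebook show both `[COL]`/`[VAL]` and `<col>`/`<val>` variants.

```python
def parse_record(text, col_tok='[COL]', val_tok='[VAL]'):
    """Turn '[COL] name [VAL] value ...' back into a {name: value} dict."""
    record = {}
    for chunk in text.split(col_tok):
        if val_tok not in chunk:
            continue  # skip the empty piece before the first column marker
        name, value = chunk.split(val_tok, 1)
        record[name.strip()] = value.strip()
    return record


def parse_pair(text, sep=' [SEP] ', col_tok='[COL]', val_tok='[VAL]'):
    """Split a serialized pair on the separator and parse both sides."""
    left, right = text.split(sep, 1)
    return parse_record(left, col_tok, val_tok), parse_record(right, col_tok, val_tok)


# Shortened example built from a few fields of the entities displayed above
text = ('[COL] brand [VAL] HP [COL] modelno [VAL] C7973W'
        ' [SEP] '
        '[COL] brand [VAL] HP [COL] price [VAL] 28.54')
e1, e2 = parse_pair(text)
print(e1)
print(e2)
```

This simple approach breaks if a field value itself contains one of the marker strings, so a real implementation would need to guard against that case.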


### DistilRoBERTa
- Parameters: ~82 million
- Layers: 6 Transformer layers (half of RoBERTa-base)
- Hidden Size: 768 (same as RoBERTa-base)
- Attention Heads: 12 (same as RoBERTa-base)
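- Roughly twice as fast as RoBERTa-base: the runs in this notebook report a mean time per epoch of 18.49 versus 35.62.
- About 34% fewer parameters than RoBERTa-base (~82 million versus ~125 million), with the same test accuracy (0.85) on this dataset.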

In [29]:
MODEL_NAME = 'distilroberta-base'

In [30]:
model = BertModel(model_name=MODEL_NAME, study_name=DATA_NAME + "/" + MODEL_NAME + "@" + COMPUTER, device=device)
train_loader, val_loader, test_loader = model.prepare_data(X_train, y_train, X_valid, y_valid, X_test, y_test, batch_size=32, num_workers=4)
model.fit(train_loader, val_loader, epochs=10, lr=2e-5, weight_decay=0.01, early_stopping=True, patience=3)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model initialized
Device: cuda:0
Model: distilroberta-base
Study name: Walmart-Amazon/distilroberta-base@gpu6.enst.fr:0
Epoch 1/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:17<00:00,  2.99batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  6.57batch/s]


Train loss: 0.6882
Val loss: 0.7176
              precision    recall  f1-score   support

           0       0.50      0.87      0.63       169
           1       0.49      0.12      0.20       169

    accuracy                           0.50       338
   macro avg       0.49      0.50      0.42       338
weighted avg       0.49      0.50      0.42       338

Epoch 2/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:16<00:00,  3.01batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  7.11batch/s]


Train loss: 0.5634
Val loss: 0.3389
              precision    recall  f1-score   support

           0       0.79      0.87      0.83       169
           1       0.86      0.77      0.81       169

    accuracy                           0.82       338
   macro avg       0.82      0.82      0.82       338
weighted avg       0.82      0.82      0.82       338

Epoch 3/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:16<00:00,  3.02batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  7.37batch/s]


Train loss: 0.2694
Val loss: 0.2953
              precision    recall  f1-score   support

           0       0.79      0.93      0.85       169
           1       0.92      0.75      0.82       169

    accuracy                           0.84       338
   macro avg       0.85      0.84      0.84       338
weighted avg       0.85      0.84      0.84       338

Epoch 4/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:16<00:00,  3.02batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  7.28batch/s]


Train loss: 0.2545
Val loss: 0.2908
              precision    recall  f1-score   support

           0       0.90      0.76      0.82       169
           1       0.79      0.91      0.85       169

    accuracy                           0.84       338
   macro avg       0.84      0.84      0.84       338
weighted avg       0.84      0.84      0.84       338

Epoch 5/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:16<00:00,  3.02batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  6.92batch/s]


Train loss: 0.2326
Val loss: 0.2744
              precision    recall  f1-score   support

           0       0.82      0.83      0.83       169
           1       0.83      0.82      0.82       169

    accuracy                           0.83       338
   macro avg       0.83      0.83      0.83       338
weighted avg       0.83      0.83      0.83       338

Epoch 6/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:16<00:00,  3.02batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  7.35batch/s]


Train loss: 0.2367
Val loss: 0.2934
              precision    recall  f1-score   support

           0       0.96      0.72      0.82       169
           1       0.78      0.97      0.86       169

    accuracy                           0.85       338
   macro avg       0.87      0.85      0.84       338
weighted avg       0.87      0.85      0.84       338

Epoch 7/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:16<00:00,  3.02batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  7.27batch/s]


Train loss: 0.2242
Val loss: 0.3029
              precision    recall  f1-score   support

           0       0.79      0.89      0.84       169
           1       0.88      0.76      0.82       169

    accuracy                           0.83       338
   macro avg       0.83      0.83      0.83       338
weighted avg       0.83      0.83      0.83       338

Epoch 8/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:16<00:00,  3.01batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  7.25batch/s]

Train loss: 0.2171
Val loss: 0.3092
              precision    recall  f1-score   support

           0       0.89      0.78      0.83       169
           1       0.80      0.91      0.85       169

    accuracy                           0.84       338
   macro avg       0.85      0.84      0.84       338
weighted avg       0.85      0.84      0.84       338

Early stopping triggered
Best weights restored
Mean time per epoch: 18.49
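
The stop at epoch 8 is consistent with patience-based early stopping on the validation loss. The internals of `BertModel.fit` are not shown in this notebook, so the sketch below is an assumption about the rule rather than the project's actual code: training stops once the validation loss has failed to improve for `patience` consecutive epochs, and the best epoch is remembered so its weights can be restored.

```python
def early_stopping_epoch(val_losses, patience=3):
    """Return (stop_epoch, best_epoch), both 1-indexed.

    stop_epoch equals len(val_losses) when early stopping is never triggered.
    """
    best_loss, best_epoch, bad_epochs = float('inf'), 0, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # New best validation loss: remember it and reset the counter
            best_loss, best_epoch, bad_epochs = loss, epoch, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch, best_epoch
    return len(val_losses), best_epoch


# Validation losses from the DistilRoBERTa log above
distil_val = [0.7176, 0.3389, 0.2953, 0.2908, 0.2744, 0.2934, 0.3029, 0.3092]
stop_epoch, best_epoch = early_stopping_epoch(distil_val, patience=3)
print(f'Stopped after epoch {stop_epoch}, best epoch was {best_epoch}')
```

Under this rule the best validation loss (0.2744) is reached at epoch 5, and epochs 6 to 8 bring no improvement, which matches the point where the log reports "Early stopping triggered".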





In [31]:
y_pred_disti = model.evaluate(test_loader)

Testing: 100%|██████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  7.48batch/s]


              precision    recall  f1-score   support

           0       0.90      0.78      0.84       180
           1       0.81      0.92      0.86       180

    accuracy                           0.85       360
   macro avg       0.86      0.85      0.85       360
weighted avg       0.86      0.85      0.85       360

Test loss: 0.21455


### BERT Base
- Parameters: ~110 million
- Layers: 12 Transformer layers
- Hidden Size: 768
- Attention Heads: 12

In [32]:
MODEL_NAME = 'bert-base-uncased'

In [33]:
model = BertModel(model_name=MODEL_NAME, study_name=DATA_NAME + "/" + MODEL_NAME + "@" + COMPUTER, device=device)
train_loader, val_loader, test_loader = model.prepare_data(X_train, y_train, X_valid, y_valid, X_test, y_test, batch_size=32, num_workers=4)
model.fit(train_loader, val_loader, epochs=10, lr=2e-5, weight_decay=0.01, early_stopping=True, patience=3)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model initialized
Device: cuda:0
Model: bert-base-uncased
Study name: Walmart-Amazon/bert-base-uncased@gpu6.enst.fr:0
Epoch 1/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:33<00:00,  1.54batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.28batch/s]


Train loss: 0.6581
Val loss: 0.5018
              precision    recall  f1-score   support

           0       0.84      0.64      0.73       169
           1       0.71      0.88      0.79       169

    accuracy                           0.76       338
   macro avg       0.78      0.76      0.76       338
weighted avg       0.78      0.76      0.76       338

Epoch 2/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:33<00:00,  1.54batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:03<00:00,  3.56batch/s]


Train loss: 0.4212
Val loss: 0.3032
              precision    recall  f1-score   support

           0       0.93      0.75      0.83       169
           1       0.79      0.95      0.86       169

    accuracy                           0.85       338
   macro avg       0.86      0.85      0.84       338
weighted avg       0.86      0.85      0.84       338

Epoch 3/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:33<00:00,  1.53batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.24batch/s]


Train loss: 0.2617
Val loss: 0.2683
              precision    recall  f1-score   support

           0       0.77      0.98      0.86       169
           1       0.97      0.71      0.82       169

    accuracy                           0.84       338
   macro avg       0.87      0.84      0.84       338
weighted avg       0.87      0.84      0.84       338

Epoch 4/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:33<00:00,  1.54batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.18batch/s]


Train loss: 0.2336
Val loss: 0.2511
              precision    recall  f1-score   support

           0       0.87      0.84      0.86       169
           1       0.85      0.88      0.86       169

    accuracy                           0.86       338
   macro avg       0.86      0.86      0.86       338
weighted avg       0.86      0.86      0.86       338

Epoch 5/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:33<00:00,  1.54batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.20batch/s]


Train loss: 0.2506
Val loss: 0.2558
              precision    recall  f1-score   support

           0       0.99      0.73      0.84       169
           1       0.79      0.99      0.88       169

    accuracy                           0.86       338
   macro avg       0.89      0.86      0.86       338
weighted avg       0.89      0.86      0.86       338

Epoch 6/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:33<00:00,  1.54batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.03batch/s]


Train loss: 0.2181
Val loss: 0.2566
              precision    recall  f1-score   support

           0       0.85      0.83      0.84       169
           1       0.83      0.86      0.85       169

    accuracy                           0.84       338
   macro avg       0.84      0.84      0.84       338
weighted avg       0.84      0.84      0.84       338

Epoch 7/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:33<00:00,  1.54batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.19batch/s]


Train loss: 0.2072
Val loss: 0.2693
              precision    recall  f1-score   support

           0       0.98      0.73      0.84       169
           1       0.79      0.98      0.87       169

    accuracy                           0.86       338
   macro avg       0.88      0.86      0.86       338
weighted avg       0.88      0.86      0.86       338

Early stopping triggered
Best weights restored
Mean time per epoch: 35.86


In [34]:
y_pred_bert = model.evaluate(test_loader)

Testing: 100%|██████████████████████████████████████████████████████████████| 12/12 [00:02<00:00,  4.22batch/s]

              precision    recall  f1-score   support

           0       0.97      0.77      0.86       180
           1       0.81      0.98      0.88       180

    accuracy                           0.87       360
   macro avg       0.89      0.87      0.87       360
weighted avg       0.89      0.87      0.87       360

Test loss: 0.21891





### ELECTRA Base
- Parameters: ~110 million
- Layers: 12 Transformer layers
- Hidden Size: 768
- Attention Heads: 12

In [22]:
MODEL_NAME = 'google/electra-base-discriminator'

In [23]:
model = BertModel(model_name=MODEL_NAME, study_name=DATA_NAME + "/" + MODEL_NAME + "@" + COMPUTER, device=device)
train_loader, val_loader, test_loader = model.prepare_data(X_train, y_train, X_valid, y_valid, X_test, y_test, batch_size=32, num_workers=4)
model.fit(train_loader, val_loader, epochs=10, lr=2e-5, weight_decay=0.01, early_stopping=True, patience=3)

Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-base-discriminator and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model initialized
Device: cuda:0
Model: google/electra-base-discriminator
Study name: Walmart-Amazon/google/electra-base-discriminator@gpu6.enst.fr:0
Epoch 1/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:38<00:00,  1.33batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  3.77batch/s]


Train loss: 0.6891
Val loss: 0.6943
              precision    recall  f1-score   support

           0       0.57      0.02      0.05       169
           1       0.50      0.98      0.66       169

    accuracy                           0.50       338
   macro avg       0.54      0.50      0.35       338
weighted avg       0.54      0.50      0.35       338

Epoch 2/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:38<00:00,  1.34batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  3.73batch/s]


Train loss: 0.6373
Val loss: 0.4913
              precision    recall  f1-score   support

           0       0.79      0.75      0.77       169
           1       0.76      0.80      0.78       169

    accuracy                           0.78       338
   macro avg       0.78      0.78      0.78       338
weighted avg       0.78      0.78      0.78       338

Epoch 3/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:38<00:00,  1.34batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  3.73batch/s]


Train loss: 0.4026
Val loss: 0.3481
              precision    recall  f1-score   support

           0       0.75      0.99      0.85       169
           1       0.99      0.66      0.79       169

    accuracy                           0.83       338
   macro avg       0.87      0.83      0.82       338
weighted avg       0.87      0.83      0.82       338

Epoch 4/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:37<00:00,  1.34batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  3.76batch/s]


Train loss: 0.2823
Val loss: 0.2078
              precision    recall  f1-score   support

           0       0.89      0.82      0.86       169
           1       0.84      0.90      0.87       169

    accuracy                           0.86       338
   macro avg       0.86      0.86      0.86       338
weighted avg       0.86      0.86      0.86       338

Epoch 5/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:38<00:00,  1.34batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  3.72batch/s]


Train loss: 0.2461
Val loss: 0.1979
              precision    recall  f1-score   support

           0       0.78      1.00      0.88       169
           1       1.00      0.72      0.84       169

    accuracy                           0.86       338
   macro avg       0.89      0.86      0.86       338
weighted avg       0.89      0.86      0.86       338

Epoch 6/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:38<00:00,  1.34batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  3.74batch/s]


Train loss: 0.2518
Val loss: 0.2059
              precision    recall  f1-score   support

           0       0.82      0.92      0.87       169
           1       0.91      0.80      0.85       169

    accuracy                           0.86       338
   macro avg       0.86      0.86      0.86       338
weighted avg       0.86      0.86      0.86       338

Epoch 7/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:38<00:00,  1.34batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  3.72batch/s]


Train loss: 0.2313
Val loss: 0.2012
              precision    recall  f1-score   support

           0       0.80      0.96      0.87       169
           1       0.95      0.76      0.84       169

    accuracy                           0.86       338
   macro avg       0.87      0.86      0.86       338
weighted avg       0.87      0.86      0.86       338

Epoch 8/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:38<00:00,  1.34batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  3.78batch/s]


Train loss: 0.2357
Val loss: 0.1936
              precision    recall  f1-score   support

           0       0.90      0.80      0.85       169
           1       0.82      0.91      0.86       169

    accuracy                           0.86       338
   macro avg       0.86      0.86      0.85       338
weighted avg       0.86      0.86      0.85       338

Epoch 9/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:38<00:00,  1.34batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  3.70batch/s]


Train loss: 0.2315
Val loss: 0.2014
              precision    recall  f1-score   support

           0       0.80      0.96      0.87       169
           1       0.96      0.76      0.84       169

    accuracy                           0.86       338
   macro avg       0.88      0.86      0.86       338
weighted avg       0.88      0.86      0.86       338

Epoch 10/10


Training: 100%|█████████████████████████████████████████████████████████████| 51/51 [00:38<00:00,  1.34batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 11/11 [00:03<00:00,  3.61batch/s]

Train loss: 0.2400
Val loss: 0.1903
              precision    recall  f1-score   support

           0       0.79      0.98      0.87       169
           1       0.97      0.73      0.84       169

    accuracy                           0.86       338
   macro avg       0.88      0.86      0.85       338
weighted avg       0.88      0.86      0.85       338

Training completed
Best weights restored
Mean time per epoch: 41.01





In [24]:
y_pred_electra = model.evaluate(test_loader)

Testing: 100%|██████████████████████████████████████████████████████████████| 12/12 [00:03<00:00,  3.77batch/s]

              precision    recall  f1-score   support

           0       0.78      0.97      0.87       180
           1       0.96      0.73      0.83       180

    accuracy                           0.85       360
   macro avg       0.87      0.85      0.85       360
weighted avg       0.87      0.85      0.85       360

Test loss: 0.25164
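
To compare the four runs side by side, the test accuracy, test loss, and mean time per epoch reported in the outputs above can be gathered in one place. This is only a small sketch that copies the printed figures by hand; the dictionary layout is an illustrative choice, and the unit of the epoch time is not stated in the logs.

```python
# Figures copied from the evaluation and training outputs in this notebook
results = {
    'roberta-base': {'accuracy': 0.85, 'test_loss': 0.22031, 'epoch_time': 35.62},
    'distilroberta-base': {'accuracy': 0.85, 'test_loss': 0.21455, 'epoch_time': 18.49},
    'bert-base-uncased': {'accuracy': 0.87, 'test_loss': 0.21891, 'epoch_time': 35.86},
    'google/electra-base-discriminator': {'accuracy': 0.85, 'test_loss': 0.25164, 'epoch_time': 41.01},
}

# Speed is expressed relative to RoBERTa-base (values above 1 mean faster)
baseline_time = results['roberta-base']['epoch_time']
speedups = {name: baseline_time / r['epoch_time'] for name, r in results.items()}

print(f"{'model':<36}{'accuracy':>9}{'test loss':>11}{'speedup':>9}")
for name, r in results.items():
    print(f"{name:<36}{r['accuracy']:>9.2f}{r['test_loss']:>11.5f}{speedups[name]:>9.2f}")

most_accurate = max(results, key=lambda name: results[name]['accuracy'])
fastest = min(results, key=lambda name: results[name]['epoch_time'])
print(f'Most accurate: {most_accurate}, fastest: {fastest}')
```

On these single runs the accuracy differences are small (0.85 to 0.87 on 360 test pairs), so the ranking should be read with caution.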



