# Entity Resolution project @ Wavestone
## Amazon-Google Products Matching

> *Dataset information is available [here](Datasets.md). \
> Description still to be written. Only the raw data is used, because the non-raw version had already been pre-processed.*

> **Tristan PERROT**


## Import libraries


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import pickle

import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cpu


In [3]:
# Move up to the project root, i.e. the first parent directory that contains `model/`
while 'model' not in os.listdir():
    os.chdir('..')

In [4]:
MODEL_NAME = 'bert-base-uncased'
DATA_NAME = 'amazon-google'
COMPUTER = 'gpu4.enst.fr:0'
DATA_DIR = os.path.join('data', DATA_NAME)

## Load data

In [5]:
import numpy as np

from model.BertModel import BertModel
from model.utils import deserialize_entities, load_data

In [9]:
X_train_ids, y_train, X_valid_ids, y_valid, X_test_ids, y_test = load_data(DATA_DIR, comp_er=False)
with open(os.path.join(DATA_DIR, '1_serialized.pkl'), 'rb') as f:
    table_a_serialized = pickle.load(f)
with open(os.path.join(DATA_DIR, '2_serialized.pkl'), 'rb') as f:
    table_b_serialized = pickle.load(f)

In [10]:
X_train = [table_a_serialized[i[0]] + ' [SEP] ' + table_b_serialized[i[1]] for i in X_train_ids]
X_valid = [table_a_serialized[i[0]] + ' [SEP] ' + table_b_serialized[i[1]] for i in X_valid_ids]
X_test = [table_a_serialized[i[0]] + ' [SEP] ' + table_b_serialized[i[1]] for i in X_test_ids]

In [11]:
# Display the first 5 samples of the training set
for i in range(5):
    print(f'Sample {i}:')
    print(X_train[i])
    print(f'Label: {y_train[i]}')
    print()

Sample 0:
[COL] name [VAL] quickbooks point-of-sale basic 6.0 [COL] description [VAL] quickbooks point-of-sale basic 6.0 retail management software turns any pc into a cash register that does what no ordinary cash register can do: it automatically tracks your inventory and customers while you ring up sales. it's a high-powered retail-management system that tracks inventory sales and customer information to help save you time and better serve your customers. track inventory automatically as you ring up sales. view larger. see each customer's purchase history as you ring up sales and suggest additional purchases based on past preferences. view larger. answer a few questions in the simple setup wizard import your inventory customer and vendor lists from microsoft excel or quickbooks financial software and you're ready to start ringing up sales. view larger. transfer key information directly into quickbooks financial software with one click. view larger. create customized reports that help
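
The sample above follows a `[COL] <attribute> [VAL] <value>` layout, with the two records of a pair joined by ` [SEP] `. The serialization itself happens outside this notebook (the `*_serialized.pkl` files are loaded ready-made), so the following is only a minimal sketch of how such strings could be produced. `serialize_record` and `serialize_pair` are hypothetical helper names, and the example records are made up for illustration, not taken from the dataset.

```python
def serialize_record(record):
    """Flatten a dict of attributes into '[COL] name [VAL] value ...' (hypothetical helper)."""
    return ' '.join(f'[COL] {col} [VAL] {val}' for col, val in record.items())


def serialize_pair(record_a, record_b):
    """Join two serialized records with the ' [SEP] ' marker used in this notebook."""
    return serialize_record(record_a) + ' [SEP] ' + serialize_record(record_b)


# Illustrative records (made up for the example, not real dataset rows)
a = {'name': 'photo editor 2.0', 'price': '19.99'}
b = {'name': 'photo editor v2', 'price': '18.50'}
pair_text = serialize_pair(a, b)
print(pair_text)
```

Keeping the attribute names in the text lets the model see which value belongs to which column, at the cost of a few extra tokens per record.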

### RoBERTa Base
- Architecture: Transformer-based model
- Parameters: ~125 million
- Layers: 12 Transformer layers
- Hidden Size: 768
- Attention Heads: 12

In [156]:
MODEL_NAME = 'roberta-base'

In [157]:
model = BertModel(model_name=MODEL_NAME, study_name=DATA_NAME + "/" + MODEL_NAME + "@" + COMPUTER, device=device)
train_loader, val_loader, test_loader = model.prepare_data(X_train, y_train, X_valid, y_valid, X_test, y_test, batch_size=32, num_workers=4)
model.fit(train_loader, val_loader, epochs=10, lr=2e-5, weight_decay=0.01, early_stopping=True, patience=3)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model initialized
Device: cuda:0
Model: roberta-base
Study name: Amazon-Google/roberta-base@gpu6.enst.fr:0
Epoch 1/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:35<00:00,  1.58batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.57batch/s]


Train loss: 0.6719
Val loss: 0.5654
              precision    recall  f1-score   support

           0       0.98      0.33      0.49       200
           1       0.60      0.99      0.75       200

    accuracy                           0.66       400
   macro avg       0.79      0.66      0.62       400
weighted avg       0.79      0.66      0.62       400

Epoch 2/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:35<00:00,  1.61batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.62batch/s]


Train loss: 0.2827
Val loss: 0.1724
              precision    recall  f1-score   support

           0       0.96      0.81      0.88       200
           1       0.83      0.97      0.90       200

    accuracy                           0.89       400
   macro avg       0.90      0.89      0.89       400
weighted avg       0.90      0.89      0.89       400

Epoch 3/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:35<00:00,  1.60batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.53batch/s]


Train loss: 0.1395
Val loss: 0.2006
              precision    recall  f1-score   support

           0       0.95      0.79      0.86       200
           1       0.82      0.96      0.88       200

    accuracy                           0.87       400
   macro avg       0.88      0.87      0.87       400
weighted avg       0.88      0.87      0.87       400

Epoch 4/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:35<00:00,  1.59batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.53batch/s]


Train loss: 0.1200
Val loss: 0.1608
              precision    recall  f1-score   support

           0       0.99      0.80      0.88       200
           1       0.83      0.99      0.90       200

    accuracy                           0.90       400
   macro avg       0.91      0.90      0.89       400
weighted avg       0.91      0.90      0.89       400

Epoch 5/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:36<00:00,  1.58batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.37batch/s]


Train loss: 0.1138
Val loss: 0.1565
              precision    recall  f1-score   support

           0       0.99      0.81      0.89       200
           1       0.84      0.99      0.91       200

    accuracy                           0.90       400
   macro avg       0.91      0.90      0.90       400
weighted avg       0.91      0.90      0.90       400

Epoch 6/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:36<00:00,  1.57batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.52batch/s]


Train loss: 0.1116
Val loss: 0.1720
              precision    recall  f1-score   support

           0       0.86      0.97      0.91       200
           1       0.97      0.83      0.90       200

    accuracy                           0.91       400
   macro avg       0.91      0.91      0.90       400
weighted avg       0.91      0.91      0.90       400

Epoch 7/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:36<00:00,  1.57batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.54batch/s]


Train loss: 0.1053
Val loss: 0.1708
              precision    recall  f1-score   support

           0       0.87      0.95      0.91       200
           1       0.95      0.86      0.90       200

    accuracy                           0.91       400
   macro avg       0.91      0.91      0.90       400
weighted avg       0.91      0.91      0.90       400

Epoch 8/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:36<00:00,  1.57batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.51batch/s]


Train loss: 0.1008
Val loss: 0.1605
              precision    recall  f1-score   support

           0       0.97      0.82      0.89       200
           1       0.84      0.97      0.90       200

    accuracy                           0.90       400
   macro avg       0.91      0.90      0.90       400
weighted avg       0.91      0.90      0.90       400

Early stopping triggered
Best weights restored
Mean time per epoch: 38.85
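
The log above is consistent with early stopping on the validation loss with `patience=3`: the lowest validation loss (0.1565) is reached at epoch 5, and training stops after three further epochs without improvement. The internals of `BertModel.fit` are not shown in this notebook, so the following is a sketch under that assumption. `early_stopping_epoch` is a hypothetical helper that replays the validation losses printed above.

```python
def early_stopping_epoch(val_losses, patience=3):
    """Return (stop_epoch, best_epoch), both 1-indexed (hypothetical helper).

    Assumes the monitored metric is the validation loss and that any strict
    decrease counts as an improvement.
    """
    best_loss = float('inf')
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    # The patience budget was never exhausted: training ran to the last epoch
    return len(val_losses), best_epoch


# Validation losses of the RoBERTa run, copied from the log above
roberta_val_losses = [0.5654, 0.1724, 0.2006, 0.1608, 0.1565, 0.1720, 0.1708, 0.1605]
stop_epoch, best_epoch = early_stopping_epoch(roberta_val_losses, patience=3)
print(f'stopped at epoch {stop_epoch}, best epoch {best_epoch}')
```

Restoring the best weights then means going back to the epoch-5 checkpoint rather than keeping the epoch-8 weights.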


In [158]:
y_pred = model.evaluate(test_loader)

Testing: 100%|██████████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.37batch/s]

              precision    recall  f1-score   support

           0       0.98      0.92      0.95       199
           1       0.93      0.98      0.95       199

    accuracy                           0.95       398
   macro avg       0.95      0.95      0.95       398
weighted avg       0.95      0.95      0.95       398

Test loss: 0.08975
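
The f1-score column in the report above is the harmonic mean of precision and recall. As a quick check, it can be recomputed from the rounded precision and recall values printed above. Small deviations are possible because the inputs are already rounded to two decimals. `f1_score_from` is a hypothetical helper name.

```python
def f1_score_from(precision, recall):
    """Harmonic mean of precision and recall (hypothetical helper)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Rounded test-set values from the RoBERTa report above
f1_class0 = f1_score_from(0.98, 0.92)
f1_class1 = f1_score_from(0.93, 0.98)
macro_f1 = (f1_class0 + f1_class1) / 2  # unweighted mean over the two classes

print(f'{f1_class0:.2f} {f1_class1:.2f} {macro_f1:.2f}')
```

Because both classes have the same support (199), the macro and weighted averages coincide here.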





In [159]:
# Example of a prediction
print(f'Prediction: {y_pred[0]}')
print(f'Label: {y_test[0]}')
print(f'Sample: {X_test[0]}')

Prediction: 0
Label: 0
Sample: <col>title<val>casual games pack <col>description<val>a collection of 10 engaging amusing and addicting casual games to keep you (and your pc) busy / esrb = e for everyone <col>manufacturer<val>egames <col>price<val>19.99<sep><col>title<val>apple iwork '06 family pack <col>description<val>minimum system requirements macintosh computer with 500mhz or faster powerpc g4 or g5 processor or intel core duo processor mac os x 10.3.9 or 10.4.3 or later 3gb of available disk space 256mb of ram minimum; 512mb recommended 32mb or more video ... <col>manufacturer<val> <col>price<val>99.99


In [160]:
# Index of the first matching pair (label 1) in the test set
first_match = np.nonzero(y_test)[0][0]
e1_df, e2_df = deserialize_entities(X_test[first_match])
print('Entity 1:')
display(e1_df)
print('Entity 2:')
display(e2_df)
print(f'Label: {y_test[first_match]}')
print(f'Prediction: {y_pred[first_match]}')

Entity 1:


| column | title | description | manufacturer | price |
|---|---|---|---|---|
| 0 | corel dvd moviefactory 6.0 plus | ulead dvd moviefactory 6 plus is the award-win... | corel | 79.99 |


Entity 2:


| column | title | description | manufacturer | price |
|---|---|---|---|---|
| 0 | dvd moviefactory 6 plus | corel dvd moviefactory 6 plus windows | | 77.32 |


Label: 1
Prediction: 1
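
`deserialize_entities` comes from `model.utils`, and its implementation is not shown in this notebook. As a rough illustration of what such a step involves, the sketch below splits a serialized pair that uses the `<col>`, `<val>`, and `<sep>` markers printed above into two attribute dictionaries. `parse_entity` and `split_pair` are hypothetical names, plain dicts stand in for the DataFrames returned by the real function, and the input string is an abridged reconstruction (description omitted) of the two entities shown above.

```python
import re


def parse_entity(text):
    """Parse '<col>name<val>value <col>...' into a dict (hypothetical helper)."""
    # Each value runs lazily until the next <col> marker or the end of the string
    pairs = re.findall(r'<col>(.*?)<val>(.*?)(?=<col>|$)', text, flags=re.S)
    return {col.strip(): val.strip() for col, val in pairs}


def split_pair(serialized):
    """Split a serialized pair on '<sep>' and parse both sides."""
    left, right = serialized.split('<sep>', 1)
    return parse_entity(left), parse_entity(right)


# Abridged reconstruction of the matching pair displayed above
serialized = (
    '<col>title<val>corel dvd moviefactory 6.0 plus '
    '<col>manufacturer<val>corel <col>price<val>79.99'
    '<sep>'
    '<col>title<val>dvd moviefactory 6 plus '
    '<col>manufacturer<val> <col>price<val>77.32'
)
e1, e2 = split_pair(serialized)
print(e1)
print(e2)
```

Empty attributes, such as the missing manufacturer of the second entity, come back as empty strings.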


### DistilRoBERTa
- Parameters: ~82 million
- Layers: 6 Transformer layers (half of RoBERTa-base)
- Hidden Size: 768 (same as RoBERTa-base)
- Attention Heads: 12 (same as RoBERTa-base)
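- About twice as fast as RoBERTa-base on average, according to its model card.
- Has about 34% fewer parameters than RoBERTa-base (~82M vs. ~125M) while retaining most of its performance; in this notebook both models reach 0.95 test accuracy.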
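
The parameter counts quoted in these model summaries can be roughly reproduced from the architecture figures. The sketch below counts embedding, encoder, and pooler weights for a BERT-style encoder. The vocabulary size (50,265), number of position embeddings (514), single token-type embedding, and feed-forward size (3,072) are assumptions taken from the published `roberta-base` configuration, not from this notebook, and the classification head added for fine-tuning is ignored. `encoder_param_count` is a hypothetical helper.

```python
def encoder_param_count(num_layers, hidden=768, ffn=3072, vocab=50265,
                        max_positions=514, type_vocab=1, with_pooler=True):
    """Approximate parameter count of a BERT-style encoder (hypothetical helper)."""
    layer_norm = 2 * hidden  # scale and bias
    embeddings = (vocab + max_positions + type_vocab) * hidden + layer_norm
    attention = 4 * (hidden * hidden + hidden)  # query, key, value, output projections
    feed_forward = (hidden * ffn + ffn) + (ffn * hidden + hidden)
    per_layer = attention + feed_forward + 2 * layer_norm
    pooler = hidden * hidden + hidden if with_pooler else 0
    return embeddings + num_layers * per_layer + pooler


roberta_params = encoder_param_count(num_layers=12)
distilroberta_params = encoder_param_count(num_layers=6)
reduction = 1 - distilroberta_params / roberta_params

print(f'RoBERTa-base:  {roberta_params / 1e6:.1f}M')
print(f'DistilRoBERTa: {distilroberta_params / 1e6:.1f}M')
print(f'Reduction:     {reduction:.0%}')
```

Halving the number of layers leaves the embedding table untouched, which is why the parameter count shrinks by less than half.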

In [161]:
MODEL_NAME = 'distilroberta-base'

In [162]:
model = BertModel(model_name=MODEL_NAME, study_name=DATA_NAME + "/" + MODEL_NAME + "@" + COMPUTER, device=device)
train_loader, val_loader, test_loader = model.prepare_data(X_train, y_train, X_valid, y_valid, X_test, y_test, batch_size=32, num_workers=4)
model.fit(train_loader, val_loader, epochs=10, lr=2e-5, weight_decay=0.01, early_stopping=True, patience=3)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model initialized
Device: cuda:0
Model: distilroberta-base
Study name: Amazon-Google/distilroberta-base@gpu6.enst.fr:0
Epoch 1/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:18<00:00,  3.08batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  7.89batch/s]


Train loss: 0.6649
Val loss: 0.3951
              precision    recall  f1-score   support

           0       0.81      0.92      0.86       200
           1       0.90      0.79      0.84       200

    accuracy                           0.85       400
   macro avg       0.86      0.85      0.85       400
weighted avg       0.86      0.85      0.85       400

Epoch 2/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:18<00:00,  3.09batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  7.40batch/s]


Train loss: 0.2813
Val loss: 0.3417
              precision    recall  f1-score   support

           0       0.90      0.70      0.79       200
           1       0.76      0.92      0.83       200

    accuracy                           0.81       400
   macro avg       0.83      0.81      0.81       400
weighted avg       0.83      0.81      0.81       400

Epoch 3/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:18<00:00,  3.04batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  7.46batch/s]


Train loss: 0.1516
Val loss: 0.1779
              precision    recall  f1-score   support

           0       0.90      0.89      0.90       200
           1       0.89      0.91      0.90       200

    accuracy                           0.90       400
   macro avg       0.90      0.90      0.90       400
weighted avg       0.90      0.90      0.90       400

Epoch 4/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:18<00:00,  3.03batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  7.08batch/s]


Train loss: 0.1223
Val loss: 0.1512
              precision    recall  f1-score   support

           0       0.87      0.97      0.92       200
           1       0.97      0.85      0.90       200

    accuracy                           0.91       400
   macro avg       0.92      0.91      0.91       400
weighted avg       0.92      0.91      0.91       400

Epoch 5/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:18<00:00,  3.02batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  7.32batch/s]


Train loss: 0.1116
Val loss: 0.1758
              precision    recall  f1-score   support

           0       0.92      0.88      0.90       200
           1       0.88      0.92      0.90       200

    accuracy                           0.90       400
   macro avg       0.90      0.90      0.90       400
weighted avg       0.90      0.90      0.90       400

Epoch 6/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:18<00:00,  3.04batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  7.50batch/s]


Train loss: 0.1024
Val loss: 0.1740
              precision    recall  f1-score   support

           0       0.93      0.88      0.90       200
           1       0.89      0.93      0.91       200

    accuracy                           0.91       400
   macro avg       0.91      0.91      0.90       400
weighted avg       0.91      0.91      0.90       400

Epoch 7/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:18<00:00,  3.04batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  7.21batch/s]

Train loss: 0.1140
Val loss: 0.1904
              precision    recall  f1-score   support

           0       0.90      0.90      0.90       200
           1       0.90      0.90      0.90       200

    accuracy                           0.90       400
   macro avg       0.90      0.90      0.90       400
weighted avg       0.90      0.90      0.90       400

Early stopping triggered
Best weights restored
Mean time per epoch: 20.46





In [163]:
y_pred_disti = model.evaluate(test_loader)

Testing: 100%|██████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.02batch/s]

              precision    recall  f1-score   support

           0       0.93      0.97      0.95       199
           1       0.97      0.92      0.95       199

    accuracy                           0.95       398
   macro avg       0.95      0.95      0.95       398
weighted avg       0.95      0.95      0.95       398

Test loss: 0.12640





### BERT Base
- Parameters: ~110 million
- Layers: 12 Transformer layers
- Hidden Size: 768
- Attention Heads: 12

In [164]:
MODEL_NAME = 'bert-base-uncased'

In [165]:
model = BertModel(model_name=MODEL_NAME, study_name=DATA_NAME + "/" + MODEL_NAME + "@" + COMPUTER, device=device)
train_loader, val_loader, test_loader = model.prepare_data(X_train, y_train, X_valid, y_valid, X_test, y_test, batch_size=32, num_workers=4)
model.fit(train_loader, val_loader, epochs=10, lr=2e-5, weight_decay=0.01, early_stopping=True, patience=3)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model initialized
Device: cuda:0
Model: bert-base-uncased
Study name: Amazon-Google/bert-base-uncased@gpu6.enst.fr:0
Epoch 1/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:36<00:00,  1.55batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:03<00:00,  4.23batch/s]


Train loss: 0.5158
Val loss: 0.3788
              precision    recall  f1-score   support

           0       0.94      0.73      0.82       200
           1       0.78      0.95      0.86       200

    accuracy                           0.84       400
   macro avg       0.86      0.84      0.84       400
weighted avg       0.86      0.84      0.84       400

Epoch 2/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:36<00:00,  1.57batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.36batch/s]


Train loss: 0.2073
Val loss: 0.1971
              precision    recall  f1-score   support

           0       0.97      0.78      0.87       200
           1       0.82      0.98      0.89       200

    accuracy                           0.88       400
   macro avg       0.90      0.88      0.88       400
weighted avg       0.90      0.88      0.88       400

Epoch 3/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:36<00:00,  1.57batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:03<00:00,  4.27batch/s]


Train loss: 0.1435
Val loss: 0.1745
              precision    recall  f1-score   support

           0       0.92      0.89      0.90       200
           1       0.89      0.92      0.91       200

    accuracy                           0.91       400
   macro avg       0.91      0.91      0.90       400
weighted avg       0.91      0.91      0.90       400

Epoch 4/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:36<00:00,  1.57batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.37batch/s]


Train loss: 0.1166
Val loss: 0.1828
              precision    recall  f1-score   support

           0       0.99      0.80      0.88       200
           1       0.83      0.99      0.90       200

    accuracy                           0.90       400
   macro avg       0.91      0.90      0.89       400
weighted avg       0.91      0.90      0.89       400

Epoch 5/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:36<00:00,  1.57batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.38batch/s]


Train loss: 0.1060
Val loss: 0.1716
              precision    recall  f1-score   support

           0       0.92      0.89      0.91       200
           1       0.89      0.93      0.91       200

    accuracy                           0.91       400
   macro avg       0.91      0.91      0.91       400
weighted avg       0.91      0.91      0.91       400

Epoch 6/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:36<00:00,  1.57batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.33batch/s]


Train loss: 0.1008
Val loss: 0.1710
              precision    recall  f1-score   support

           0       0.92      0.89      0.91       200
           1       0.89      0.93      0.91       200

    accuracy                           0.91       400
   macro avg       0.91      0.91      0.91       400
weighted avg       0.91      0.91      0.91       400

Epoch 7/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:36<00:00,  1.57batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.35batch/s]


Train loss: 0.0937
Val loss: 0.1809
              precision    recall  f1-score   support

           0       0.94      0.88      0.90       200
           1       0.88      0.94      0.91       200

    accuracy                           0.91       400
   macro avg       0.91      0.91      0.91       400
weighted avg       0.91      0.91      0.91       400

Epoch 8/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:36<00:00,  1.57batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.38batch/s]


Train loss: 0.0959
Val loss: 0.1626
              precision    recall  f1-score   support

           0       0.89      0.94      0.92       200
           1       0.94      0.89      0.91       200

    accuracy                           0.92       400
   macro avg       0.92      0.92      0.91       400
weighted avg       0.92      0.92      0.91       400

Epoch 9/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:36<00:00,  1.57batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.41batch/s]


Train loss: 0.0874
Val loss: 0.1936
              precision    recall  f1-score   support

           0       0.91      0.92      0.91       200
           1       0.92      0.91      0.91       200

    accuracy                           0.91       400
   macro avg       0.91      0.91      0.91       400
weighted avg       0.91      0.91      0.91       400

Epoch 10/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:36<00:00,  1.57batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  4.38batch/s]


Train loss: 0.0944
Val loss: 0.1795
              precision    recall  f1-score   support

           0       0.93      0.89      0.91       200
           1       0.89      0.94      0.91       200

    accuracy                           0.91       400
   macro avg       0.91      0.91      0.91       400
weighted avg       0.91      0.91      0.91       400

Training completed
Best weights restored
Mean time per epoch: 39.31


In [166]:
y_pred_bert = model.evaluate(test_loader)

Testing: 100%|██████████████████████████████████████████████████████████████| 13/13 [00:03<00:00,  4.24batch/s]

              precision    recall  f1-score   support

           0       0.97      0.92      0.95       199
           1       0.93      0.97      0.95       199

    accuracy                           0.95       398
   macro avg       0.95      0.95      0.95       398
weighted avg       0.95      0.95      0.95       398

Test loss: 0.07849





### Electra
- Parameters: ~110 million
- Layers: 12 Transformer layers
- Hidden Size: 768
- Attention Heads: 12

In [21]:
MODEL_NAME = 'google/electra-base-discriminator'

In [22]:
model = BertModel(model_name=MODEL_NAME, study_name=DATA_NAME + "/" + MODEL_NAME + "@" + COMPUTER, device=device)
train_loader, val_loader, test_loader = model.prepare_data(X_train, y_train, X_valid, y_valid, X_test, y_test, batch_size=32, num_workers=4)
model.fit(train_loader, val_loader, epochs=10, lr=2e-5, weight_decay=0.01, early_stopping=True, patience=3)

Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-base-discriminator and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model initialized
Device: cuda:0
Model: google/electra-base-discriminator
Study name: Amazon-Google/google/electra-base-discriminator@gpu6.enst.fr:0
Epoch 1/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:41<00:00,  1.39batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:03<00:00,  3.95batch/s]


Train loss: 0.5222
Val loss: 0.2358
              precision    recall  f1-score   support

           0       0.87      0.94      0.90       200
           1       0.94      0.85      0.90       200

    accuracy                           0.90       400
   macro avg       0.90      0.90      0.90       400
weighted avg       0.90      0.90      0.90       400

Epoch 2/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:41<00:00,  1.38batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:03<00:00,  3.93batch/s]


Train loss: 0.1591
Val loss: 0.1694
              precision    recall  f1-score   support

           0       0.90      0.89      0.89       200
           1       0.89      0.91      0.90       200

    accuracy                           0.90       400
   macro avg       0.90      0.90      0.89       400
weighted avg       0.90      0.90      0.89       400

Epoch 3/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:41<00:00,  1.37batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:03<00:00,  3.90batch/s]


Train loss: 0.1084
Val loss: 0.1584
              precision    recall  f1-score   support

           0       0.96      0.81      0.88       200
           1       0.84      0.97      0.90       200

    accuracy                           0.89       400
   macro avg       0.90      0.89      0.89       400
weighted avg       0.90      0.89      0.89       400

Epoch 4/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:41<00:00,  1.37batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:03<00:00,  3.91batch/s]


Train loss: 0.0949
Val loss: 0.1650
              precision    recall  f1-score   support

           0       0.98      0.79      0.88       200
           1       0.82      0.98      0.90       200

    accuracy                           0.89       400
   macro avg       0.90      0.89      0.89       400
weighted avg       0.90      0.89      0.89       400

Epoch 5/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:41<00:00,  1.37batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:03<00:00,  3.91batch/s]


Train loss: 0.0884
Val loss: 0.1574
              precision    recall  f1-score   support

           0       0.87      0.96      0.91       200
           1       0.96      0.85      0.90       200

    accuracy                           0.91       400
   macro avg       0.91      0.91      0.91       400
weighted avg       0.91      0.91      0.91       400

Epoch 6/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:41<00:00,  1.37batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:03<00:00,  3.90batch/s]


Train loss: 0.0885
Val loss: 0.1829
              precision    recall  f1-score   support

           0       0.99      0.79      0.87       200
           1       0.82      0.99      0.90       200

    accuracy                           0.89       400
   macro avg       0.90      0.89      0.89       400
weighted avg       0.90      0.89      0.89       400

Epoch 7/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:41<00:00,  1.36batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:03<00:00,  3.64batch/s]


Train loss: 0.0837
Val loss: 0.1634
              precision    recall  f1-score   support

           0       0.90      0.90      0.90       200
           1       0.90      0.90      0.90       200

    accuracy                           0.90       400
   macro avg       0.90      0.90      0.90       400
weighted avg       0.90      0.90      0.90       400

Epoch 8/10


Training: 100%|█████████████████████████████████████████████████████████████| 57/57 [00:41<00:00,  1.36batch/s]
Validation: 100%|███████████████████████████████████████████████████████████| 13/13 [00:03<00:00,  3.79batch/s]


Train loss: 0.1074
Val loss: 0.1576
              precision    recall  f1-score   support

           0       0.91      0.88      0.89       200
           1       0.88      0.92      0.90       200

    accuracy                           0.90       400
   macro avg       0.90      0.90      0.89       400
weighted avg       0.90      0.90      0.89       400

Early stopping triggered
Best weights restored
Mean time per epoch: 44.87


In [23]:
y_pred_electra = model.evaluate(test_loader)

Testing: 100%|██████████████████████████████████████████████████████████████| 13/13 [00:03<00:00,  3.89batch/s]

              precision    recall  f1-score   support

           0       0.95      0.96      0.96       199
           1       0.96      0.95      0.96       199

    accuracy                           0.96       398
   macro avg       0.96      0.96      0.96       398
weighted avg       0.96      0.96      0.96       398

Test loss: 0.09199
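
To compare the four runs side by side, the snippet below collects the test accuracy, test loss, and mean time per epoch printed in the outputs above and ranks the models. The numbers are copied from this notebook's logs. Each model was run once, so small differences should not be over-interpreted. The time unit is assumed to be seconds, and the `results` structure is only an illustrative layout.

```python
# (test accuracy, test loss, mean time per epoch), copied from the outputs above
results = {
    'roberta-base': (0.95, 0.08975, 38.85),
    'distilroberta-base': (0.95, 0.12640, 20.46),
    'bert-base-uncased': (0.95, 0.07849, 39.31),
    'google/electra-base-discriminator': (0.96, 0.09199, 44.87),
}

fastest = min(results, key=lambda name: results[name][2])
most_accurate = max(results, key=lambda name: results[name][0])
lowest_loss = min(results, key=lambda name: results[name][1])

# Print the runs ordered by mean time per epoch
print(f'{"model":<36}{"acc":>6}{"loss":>10}{"s/epoch":>10}')
for name, (acc, loss, seconds) in sorted(results.items(), key=lambda kv: kv[1][2]):
    print(f'{name:<36}{acc:>6.2f}{loss:>10.5f}{seconds:>10.2f}')
```

On these runs the accuracies are within one point of each other, so training time is the main differentiator.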



