# Entity Resolution project @ Wavestone
## Test generation for the ER project presented at Wavestone on 04/11/2024

> *Dataset information is available [here](https://data.dws.informatik.uni-mannheim.de/benchmarkmatchingtasks/index.html)*

> **Tristan PERROT**

In [65]:
import os

import pandas as pd
import torch
import pickle

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda:1


In [66]:
import numpy as np
import torch.nn as nn
from sentence_transformers import CrossEncoder, InputExample
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
from sklearn.metrics import classification_report
from torch.utils.data import DataLoader
from torch.nn import BCEWithLogitsLoss

In [67]:
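# Walk up from the notebook's directory until the project root (the one containing `model/`) is reached
while 'model' not in os.listdir():
    parent = os.path.dirname(os.getcwd())
    if parent == os.getcwd():
        raise FileNotFoundError("Could not find the project root containing 'model/'")
    os.chdir(parent)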

from model.utils import load_data
from model.BertModel import BertModel

In [68]:
COMPUTER = 'gpu4.enst.fr:' + str(device.index)
final_table_exports = ""

In [69]:
SET_1 = 'walmart'
SET_2 = 'amazon'
DATA_NAME = f'{SET_1}-{SET_2}'
DATA_DIR = os.path.join('data', DATA_NAME)

In [70]:
!python3 -u import-data.py --dataset walmart-amazon -ca-rm 'id,upc' -cb-rm 'url,asin'

Downloaded dataset: walmart-amazon 

Loaded dataset: walmart-amazon 

Table A:
   subject_id        id  ...  shipweight                dimensions
0           1  14249992  ...        2.00                       NaN
1           2  10928662  ...        0.95  6.75 x 5.75 x 5.5 inches
2           3  11961447  ...        0.05                       NaN
3           4  13044637  ...         NaN                       NaN
4           5  13214131  ...        5.25   72.5 x 2.5 x 1.5 inches

[5 rows x 17 columns] 

Table B:
   subject_id  ...                                          groupname
0           1  ...  Headphone Accessories Accessories Supplies Equ...
1           2  ...             Inkjet Printer Ink Printer Ink Toner  
2           3  ...                Computers Accessories Electronics  
3           4  ...                Mice Keyboards Mice Input Devices  
4           5  ...                Mice Keyboards Mice Input Devices  

[5 rows x 23 columns] 

Pairs train:
   source_id  target_id  ma

In [71]:
X_train_ids, y_train, X_valid_ids, y_valid, X_test_ids, y_test = load_data(DATA_DIR, comp_er=True)
with open(os.path.join(DATA_DIR, 'table_a.pkl'), 'rb') as f:
    table_a_serialized = pickle.load(f)
with open(os.path.join(DATA_DIR, 'table_b.pkl'), 'rb') as f:
    table_b_serialized = pickle.load(f)

In [72]:
table_a_serialized[:5]

["[COL] brand [VAL] Draper [COL] groupname [VAL] Electronics - General [COL] title [VAL] Draper Infrared Remote Transmitter [COL] price [VAL] 58.45 [COL] shelfdescr [VAL] Infrared transmitter. 3-button operation for instant access to up  down and stop functions. Fully compatible with learnable IR master control systems. Receiver sold seperately plugs into the Draper low-voltage control unit LVC-III sold separately . [COL] shortdescr [VAL]  [COL] longdescr [VAL] DR1143Infrared transmitter. 3-button operation for instant access to up  down and stop functions. Fully compatible with learnable IR master control systems. Receiver sold seperately plugs into the Draper low-voltage control unit LVC-III sold separately . [COL] imageurl [VAL] http://i.walmartimages.com/i/mp/00/64/10/92/16/0064109216245_P255045_300X300.jpg [COL] orig_shelfdescr [VAL] Infrared transmitter. 3-button operation for instant access to ''up'', ''down'', and ''stop'' functions. Fully compatible with ''learnable'' IR maste
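The serialized records above follow a `[COL] <attribute> [VAL] <value>` pattern. The serialization itself happens inside `import-data.py`, which is not shown here, so the following is only a minimal sketch of how such a string could be built from one row. It assumes that columns keep their table order and that missing values (`None` or NaN) become empty strings; `serialize_record` is a hypothetical helper, not the project's code.

```python
import math
from typing import Any, Mapping


def serialize_record(record: Mapping[str, Any]) -> str:
    """Serialize one row as '[COL] <name> [VAL] <value> ...' (hypothetical helper)."""
    parts = []
    for column, value in record.items():
        # Assumption: missing values (None or NaN) are rendered as empty strings
        if value is None or (isinstance(value, float) and math.isnan(value)):
            value = ""
        parts.append(f"[COL] {column} [VAL] {value}")
    return " ".join(parts)


row = {"brand": "Draper", "title": "Draper Infrared Remote Transmitter", "price": 58.45, "shortdescr": None}
print(serialize_record(row))
```

An empty value leaves two spaces before the next `[COL]`, which is what the `shortdescr` field shows in the output above.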

In [73]:
# Collect every labelled matching pair (source_id, target_id) across the three splits
all_true_matches = set()
for ids, labels in ((X_train_ids, y_train), (X_valid_ids, y_valid), (X_test_ids, y_test)):
    for pair, label in zip(ids, labels):
        if label == 1:
            all_true_matches.add((pair[0], pair[1]))

In [74]:
X_train = [(table_a_serialized[int(x[0])], table_b_serialized[int(x[1])]) for x in X_train_ids]
X_valid = [(table_a_serialized[int(x[0])], table_b_serialized[int(x[1])]) for x in X_valid_ids]
X_test = [(table_a_serialized[int(x[0])], table_b_serialized[int(x[1])]) for x in X_test_ids]
X_train[:5]

[("[COL] brand [VAL] LaCie [COL] groupname [VAL] Electronics - General [COL] title [VAL] LaCie Portable 8x DVD-RW Drive with LightScribe [COL] price [VAL] 57.6 [COL] shelfdescr [VAL] LightScribe labeling software USB bus-powered Includes USB cable [COL] shortdescr [VAL] With the LaCie Portable 8x DVD- RW Drive with LightScribe you can easily create your own music video and photo CDs or DVDs wherever you go. It can conveniently record to double layer DVD-R RW DVD R RW and CD-R RW and it s USB bus powered so you only need one cable to start burning CDs and DVDs. An easy-to-use interface allows you to piece together sophisticated full-featured DVD movies in a matter of minutes complete with menus thumbnails and backgrounds for a totally customized DVD. It is equipped with LightScribe an innovative technology that allows you to burn silkscreen-quality labels directly onto CDs and DVDs with a laser instead of a printer. [COL] longdescr [VAL] LaCie Portable 8x DVD-RW Drive with LightScribe W

In [75]:
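# Concatenate both serialized records into one input string per pair.
# Note: RoBERTa's own separator token is '</s>'; the literal '[SEP]' string is not a special
# token in its vocabulary and is tokenized as ordinary text.
X_train = [table_a_serialized[int(i[0])] + ' [SEP] ' + table_b_serialized[int(i[1])] for i in X_train_ids]
X_valid = [table_a_serialized[int(i[0])] + ' [SEP] ' + table_b_serialized[int(i[1])] for i in X_valid_ids]
X_test = [table_a_serialized[int(i[0])] + ' [SEP] ' + table_b_serialized[int(i[1])] for i in X_test_ids]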
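The classification reports printed during training show a strong class imbalance: 232 matches out of 3,131 validation pairs, about 7.4%. `BCEWithLogitsLoss` is imported above but is not used in the visible cells, and the loss inside `BertModel` is not shown. If a single-logit binary loss were used, one common option is to up-weight positives with `pos_weight = n_negatives / n_positives`. The following is a minimal sketch assuming 0/1 labels; `positive_weight` is a hypothetical helper.

```python
from collections import Counter
from typing import Iterable


def positive_weight(labels: Iterable[int]) -> float:
    """Return n_negatives / n_positives for 0/1 labels (hypothetical helper)."""
    counts = Counter(int(y) for y in labels)
    if counts[1] == 0:
        raise ValueError("no positive labels")
    return counts[0] / counts[1]


# Illustration with the validation supports from the training reports (2899 vs 232);
# in practice the weight would be computed from y_train.
w = positive_weight([0] * 2899 + [1] * 232)
print(round(w, 2))  # 12.5
```

With PyTorch the value would be passed as `BCEWithLogitsLoss(pos_weight=torch.tensor([w]))`, which assumes a one-logit output head; with a two-logit head and `CrossEntropyLoss`, the analogous knob is its `weight` argument.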

In [76]:
MODEL_NAME = 'distilroberta-base'

In [77]:
model = BertModel(model_name=MODEL_NAME, study_name=DATA_NAME + "/" + MODEL_NAME + "@" + COMPUTER, device=device)
train_loader, val_loader, test_loader = model.prepare_data(X_train, y_train, X_valid, y_valid, X_test, y_test, batch_size=32, num_workers=4)
model.fit(train_loader, val_loader, epochs=10, lr=2e-5, weight_decay=0.01, early_stopping=True, patience=3)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model initialized
Device: cuda:1
Model: distilroberta-base
Study name: walmart-amazon/distilroberta-base@gpu4.enst.fr:1
Epoch 1/10


Training: 100%|█████████████| 341/341 [01:48<00:00,  3.14batch/s]
Validation: 100%|█████████████| 98/98 [00:10<00:00,  9.33batch/s]


Train loss: 0.2726
Val loss: 0.2634
              precision    recall  f1-score   support

           0       0.93      1.00      0.96      2899
           1       0.00      0.00      0.00       232

    accuracy                           0.93      3131
   macro avg       0.46      0.50      0.48      3131
weighted avg       0.86      0.93      0.89      3131

Epoch 2/10


Training: 100%|█████████████| 341/341 [01:50<00:00,  3.09batch/s]
Validation: 100%|█████████████| 98/98 [00:10<00:00,  9.27batch/s]


Train loss: 0.2585
Val loss: 0.2181
              precision    recall  f1-score   support

           0       0.93      1.00      0.96      2899
           1       0.00      0.00      0.00       232

    accuracy                           0.93      3131
   macro avg       0.46      0.50      0.48      3131
weighted avg       0.86      0.93      0.89      3131

Epoch 3/10


Training: 100%|█████████████| 341/341 [01:50<00:00,  3.09batch/s]
Validation: 100%|█████████████| 98/98 [00:10<00:00,  9.25batch/s]


Train loss: 0.1919
Val loss: 0.1950
              precision    recall  f1-score   support

           0       0.94      1.00      0.97      2899
           1       0.89      0.25      0.40       232

    accuracy                           0.94      3131
   macro avg       0.92      0.63      0.68      3131
weighted avg       0.94      0.94      0.93      3131

Epoch 4/10


Training: 100%|█████████████| 341/341 [01:51<00:00,  3.06batch/s]
Validation: 100%|█████████████| 98/98 [00:10<00:00,  8.93batch/s]


Train loss: 0.1546
Val loss: 0.1760
              precision    recall  f1-score   support

           0       0.95      0.99      0.97      2899
           1       0.82      0.36      0.50       232

    accuracy                           0.95      3131
   macro avg       0.88      0.68      0.74      3131
weighted avg       0.94      0.95      0.94      3131

Epoch 5/10


Training: 100%|█████████████| 341/341 [01:52<00:00,  3.04batch/s]
Validation: 100%|█████████████| 98/98 [00:11<00:00,  8.81batch/s]


Train loss: 0.1382
Val loss: 0.1818
              precision    recall  f1-score   support

           0       0.96      0.98      0.97      2899
           1       0.66      0.43      0.52       232

    accuracy                           0.94      3131
   macro avg       0.81      0.71      0.74      3131
weighted avg       0.93      0.94      0.94      3131

Epoch 6/10


Training: 100%|█████████████| 341/341 [01:51<00:00,  3.05batch/s]
Validation: 100%|█████████████| 98/98 [00:11<00:00,  8.84batch/s]


Train loss: 0.1286
Val loss: 0.2004
              precision    recall  f1-score   support

           0       0.95      0.99      0.97      2899
           1       0.80      0.38      0.52       232

    accuracy                           0.95      3131
   macro avg       0.88      0.69      0.75      3131
weighted avg       0.94      0.95      0.94      3131

Epoch 7/10


Training: 100%|█████████████| 341/341 [01:51<00:00,  3.05batch/s]
Validation: 100%|█████████████| 98/98 [00:11<00:00,  8.81batch/s]

Train loss: 0.1256
Val loss: 0.1763
              precision    recall  f1-score   support

           0       0.96      0.99      0.97      2899
           1       0.77      0.42      0.55       232

    accuracy                           0.95      3131
   macro avg       0.86      0.71      0.76      3131
weighted avg       0.94      0.95      0.94      3131

Early stopping triggered
Best weights restored
Mean time per epoch: 121.62
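The log is consistent with patience-based early stopping on validation loss. The best value (0.1760) is reached at epoch 4, and epochs 5, 6 and 7 fail to improve on it, so with `patience=3` training stops after epoch 7, presumably restoring the epoch-4 weights. `BertModel.fit` is not shown, so this is only a sketch of that rule, assuming the monitored quantity is validation loss with a strict "lower is better" comparison; `early_stopping_epoch` is a hypothetical helper.

```python
from typing import Sequence, Tuple


def early_stopping_epoch(val_losses: Sequence[float], patience: int) -> Tuple[int, int]:
    """Return (best_epoch, stop_epoch), 1-indexed, for patience-based early stopping."""
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch
    # No early stop: training ran through every epoch
    return best_epoch, len(val_losses)


# Validation losses copied from the log above
val_losses = [0.2634, 0.2181, 0.1950, 0.1760, 0.1818, 0.2004, 0.1763]
print(early_stopping_epoch(val_losses, patience=3))  # (4, 7)
```

Monitoring the class-1 F1 instead would not have stopped at epoch 7, since epoch 7 has the best validation F1 (0.55) of the run.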

In [78]:
y_pred = model.evaluate(test_loader)


Testing: 100%|████████████████| 49/49 [00:05<00:00,  8.44batch/s]

              precision    recall  f1-score   support

           0       0.95      0.99      0.97      1429
           1       0.78      0.39      0.52       114

    accuracy                           0.95      1543
   macro avg       0.86      0.69      0.75      1543
weighted avg       0.94      0.95      0.94      1543

Test loss: 0.19140
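Accuracy is a weak signal on this data: always predicting "no match" already scores 1429 / 1543 ≈ 0.926 on the test set, which is what epoch 1 did on validation (0.93 accuracy with zero recall on class 1). The F1 of the match class is the more informative number. F1 is the harmonic mean of precision and recall, and the averages weight the per-class F1 scores either equally (macro) or by support (weighted). The sketch below recomputes them from the rounded values in the report above, so the results agree only up to rounding.

```python
def f1(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0 when both are 0)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Rounded per-class (precision, recall, support) from the test report above
report = {0: (0.95, 0.99, 1429), 1: (0.78, 0.39, 114)}

f1_per_class = {c: f1(p, r) for c, (p, r, _) in report.items()}
total = sum(n for _, _, n in report.values())
macro_f1 = sum(f1_per_class.values()) / len(f1_per_class)
weighted_f1 = sum(f1_per_class[c] * n for c, (_, _, n) in report.items()) / total
majority_accuracy = report[0][2] / total  # accuracy of always predicting "no match"

print(f"F1 class 1: {f1_per_class[1]:.2f}")  # 0.52
print(f"macro F1: {macro_f1:.3f}, weighted F1: {weighted_f1:.3f}")
print(f"majority-class accuracy: {majority_accuracy:.3f}")
```

The recomputed macro and weighted F1 land within rounding of the reported 0.75 and 0.94.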