In [None]:
from sentence_transformers import SentenceTransformer
from scipy.stats import wasserstein_distance
from tqdm import tqdm
import numpy as np
import csv

path_synthetic = "/home/abackurs/c/tabular_data/tabular_data/outputs/water_e_10_ns_2000_v_3_mtp_0.5_mcp_0.5_eps_4.0_session_dp/0/samples_csv.csv"
path_train = "/home/abackurs/c/tabular_data/tabular_data/data/water/water_train_csv.csv"

sentence_transformer = SentenceTransformer("sentence-t5-xl").eval()
def get_embeddings_model(sequences):
    batch_size = 128
    
    embeddings = []
    batch_count = len(sequences) // batch_size + (1 if len(sequences) % batch_size != 0 else 0)
    for b_i in tqdm(range(batch_count)):
        batch = sequences[b_i * batch_size:(b_i + 1) * batch_size]
        embeddings_batch = sentence_transformer.encode(batch)
        embeddings.append(embeddings_batch)

    return np.concatenate(embeddings)

def distribution(path):
    with open(path, 'r') as f:
        reader = csv.DictReader(f)
        sequences = list(reader)
    print(sequences[0])
    seq_1 = [s['title'] for s in sequences]
    seq_2 = [s['cleaned_review'] for s in sequences]
    emb_1 = get_embeddings_model(seq_1)
    emb_2 = get_embeddings_model(seq_2)
    dist = [np.dot(e_1, e_2) / (np.linalg.norm(e_1) * np.linalg.norm(e_2)) for e_1, e_2 in zip(emb_1, emb_2)]
    return np.array(dist)

dist_synthetic = distribution(path_synthetic)
dist_train = distribution(path_train)
print(wasserstein_distance(dist_synthetic, dist_train))

{'product_name': 'EcoFriendly Urban Water Bottle 750ml', 'overall_rating': '4.1', 'title': 'Awesome', 'cleaned_review': 'This Tritan bottle is strong. The size is perfect for daily use. It is entirely leak-proof and can be carried easily. At this price, quite reasonable.'}


100%|██████████| 16/16 [00:02<00:00,  7.13it/s]
100%|██████████| 16/16 [00:11<00:00,  1.40it/s]


{'product_name': 'MILTON Shine 800 Stainless Steel Water Bottle, Purple 700 ml Bottle Reviews', 'overall_rating': '3.5', 'rating': '5', 'title': 'Best in the market!', 'cleaned_review': 'nice'}


100%|██████████| 157/157 [00:25<00:00,  6.12it/s]
100%|██████████| 157/157 [01:07<00:00,  2.34it/s]

0.004825691382865279





In [3]:
dist_synthetic

array([0.6478088 , 0.6750095 , 0.89620155, ..., 0.8643617 , 0.6751026 ,
       0.67674476], shape=(1997,), dtype=float32)