In [None]:
import os

# All relative paths used below ("./models/...", "./data/...") are resolved
# against the project root. The historical absolute path is kept as the
# default; set UNIVERSAL_BLOCKER_ROOT to run the notebook from another checkout.
os.chdir(os.environ.get("UNIVERSAL_BLOCKER_ROOT", "/home1/wangtianshu/universal-blocker"))
from pathlib import Path
import pandas as pd

In [None]:
import torch
from datasets.load import Dataset, load_dataset
from datasets.features import Array2D, Sequence, Value, Features
from transformers import AutoModel, AutoTokenizer

# Tokenizer and model are both loaded from the same local checkpoint
# directory, so the path is spelled out once.
checkpoint_dir = "./models/roberta-base/"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
model = AutoModel.from_pretrained(checkpoint_dir)

In [None]:
@torch.no_grad()
def preprocess(
    batch: dict[str, list],
    index_col: str = "id",
    tokenizer=tokenizer,
    model=model,
    max_length: int = 256,
):
    """Serialize each record of a column-oriented batch and embed its tokens.

    For every record, the values of all non-index columns are converted to
    lower-cased strings and joined with single spaces (``None`` values are
    skipped). The texts are tokenized with padding/truncation to exactly
    ``max_length`` tokens, and the token ids are looked up in
    ``model.embeddings.word_embeddings`` only; no other part of the model is
    called.

    Args:
        batch: Mapping from column name to the list of values of that column.
            The batch size is taken from the first column.
        index_col: Columns whose name *contains* this string are left out of
            the serialized text.
        tokenizer: Callable tokenizer returning a mapping with "input_ids"
            as a tensor when called with ``return_tensors="pt"``.
        model: Model exposing ``embeddings.word_embeddings``.
        max_length: Fixed token length every text is padded/truncated to.
            Callers that declare a fixed-shape feature for the output must
            use the same value.

    Returns:
        ``{"inputs_embeds": ndarray}`` where the array has shape
        ``(batch_size, max_length, embedding_dim)``.
    """
    # NOTE(review): this is a substring test, not an equality test. With the
    # default "id" it drops "id" itself but also any column merely containing
    # "id" in its name. Kept as-is; confirm this is intended.
    columns = [c for c in batch.keys() if index_col not in c]
    batch_size = len(next(iter(batch.values())))

    column_values = [batch[c] for c in columns]
    texts = [
        " ".join(str(values[i]).lower() for values in column_values if values[i] is not None)
        for i in range(batch_size)
    ]

    features = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    word_embeddings = model.embeddings.word_embeddings
    # Keep ids and embedding weights on the same device so the lookup also
    # works when the model has been moved off the CPU.
    input_ids = features["input_ids"].to(word_embeddings.weight.device)
    inputs_embeds = word_embeddings(input_ids)

    # Tensor.numpy() only works for host tensors; .cpu() is a no-op on CPU.
    return {"inputs_embeds": inputs_embeds.cpu().numpy()}


data_dir = "./data/blocking/cora"
# Tables are the CSV files whose names start with "1" or "2", in sorted order.
table_paths = sorted(Path(data_dir).glob("[1-2]*.csv"))

# Load every table as its own Dataset.
datasets = []
for table_path in table_paths:
    datasets.append(load_dataset("csv", data_files=str(table_path), split="train"))

# Replace each table by its embedded version. All original columns are
# dropped, leaving only "inputs_embeds". The first dimension (256) has to
# match the padded sequence length produced by `preprocess`.
embedding_dim = model.embeddings.word_embeddings.embedding_dim
for i in range(len(datasets)):
    dataset = datasets[i]
    embeds_schema = Features(
        {"inputs_embeds": Array2D(shape=(256, embedding_dim), dtype="float32")}
    )
    datasets[i] = dataset.map(
        preprocess,
        batched=True,
        batch_size=32,
        features=embeds_schema,
        remove_columns=dataset.column_names,
    )
    
def encoder(batch):
    """No-op placeholder encoder: ignores ``batch`` and returns an empty dict."""
    return dict()

# Run the placeholder encoder over every table in batches of 32, storing the
# mapped result back into the same list slot.
for i in range(len(datasets)):
    dataset = datasets[i]
    datasets[i] = dataset.map(encoder, batched=True, batch_size=32)

In [None]:
@torch.no_grad()
def transform(batch: dict[str, list], index_col: str = "id", tokenizer=tokenizer, model=model):
    """On-the-fly counterpart of ``preprocess`` for use with ``with_transform``.

    This used to be a line-for-line copy of ``preprocess``. It now delegates
    to that function so both code paths serialize and embed records in
    exactly the same way and cannot drift apart.

    Args:
        batch: Mapping from column name to the list of values of that column.
        index_col: Columns whose name contains this string are left out of
            the serialized text.
        tokenizer: Tokenizer forwarded to ``preprocess``.
        model: Model forwarded to ``preprocess``.

    Returns:
        The ``{"inputs_embeds": ndarray}`` mapping produced by ``preprocess``.
    """
    return preprocess(batch, index_col=index_col, tokenizer=tokenizer, model=model)

# Reload the raw tables (uses `table_paths` computed in the previous cell).
datasets = []
for table_path in table_paths:
    datasets.append(load_dataset("csv", data_files=str(table_path), split="train"))

# Attach `transform` as a format transform instead of materializing the
# embeddings with `map`. No `features` / `remove_columns` arguments are
# passed here, unlike the `map`-based variant.
for i in range(len(datasets)):
    dataset = datasets[i]
    datasets[i] = dataset.with_transform(transform)
    
def encoder(batch):
    """Placeholder encoder that does nothing: the input is ignored and ``{}`` is returned."""
    return dict()

# Run the placeholder encoder over every transform-backed table in batches
# of 32, storing the mapped result back into the same list slot.
for i in range(len(datasets)):
    dataset = datasets[i]
    datasets[i] = dataset.map(encoder, batched=True, batch_size=32)