In [1]:
import torch
import pandas as pd

In [2]:
class Collator:
    """Collate a list of ``(sample, label)`` pairs into one padded batch.

    ``sample`` and ``label["name"]`` are sequences of token ids (presumably
    1-D integer tensors produced by the Dataset -- TODO confirm). Both are
    right-padded with the converter's ``"<PAD>"`` id via
    ``torch.nn.utils.rnn.pad_sequence``.

    Args:
        converter: Mapping-like object; ``converter["<PAD>"]`` must return the
            integer id used as the padding value.
    """

    def __init__(self, converter):
        self.converter = converter

    def __call__(self, batch):
        """Pad every sample and every label name in ``batch``.

        Args:
            batch: Non-empty list of ``(sample, label)`` pairs, where ``label``
                is a dict containing at least a ``"name"`` entry.

        Returns:
            tuple: ``(padded_samples, padded_labels, samples_attention_mask,
            names_attention_mask)``.

            - ``padded_samples`` / ``padded_names`` are batch-first padded tensors.
            - ``padded_labels`` is a list of new dicts: a shallow copy of each
              label with ``"name"`` replaced by its padded row.
            - The masks are ``uint8`` tensors with 1 where the value is not the
              pad id.

        Raises:
            ValueError: If ``batch`` is empty.
        """
        if not batch:
            raise ValueError("Collator received an empty batch")

        pad_id = self.converter["<PAD>"]

        samples, labels = map(list, zip(*batch))
        names = [label["name"] for label in labels]

        padded_samples = torch.nn.utils.rnn.pad_sequence(samples, batch_first=True, padding_value=pad_id)
        padded_names = torch.nn.utils.rnn.pad_sequence(names, batch_first=True, padding_value=pad_id)

        # Masks are derived from values rather than lengths, so pad ids that are
        # already present in the inputs (e.g. pre-padded names) are masked too.
        samples_attention_mask = (padded_samples != pad_id).byte()
        names_attention_mask = (padded_names != pad_id).byte()

        # Build new dicts so the caller's label dicts are not mutated.
        padded_labels = [
            {**label, "name": padded_name}
            for label, padded_name in zip(labels, padded_names)
        ]

        return padded_samples, padded_labels, samples_attention_mask, names_attention_mask

In [3]:
# Helpers
def extract_all_symbols(path_to_csv):
    """Return the sorted set of characters that appear in any data cell of a CSV.

    Every column is read as text, so the vocabulary covers all fields, not just
    the name column. Header names are not included. Missing cells contribute
    nothing.

    Args:
        path_to_csv: Path or file-like object accepted by ``pd.read_csv``.

    Returns:
        list[str]: Unique single characters, sorted by code point.
    """
    df = pd.read_csv(path_to_csv, dtype=str)

    # fillna first: even with dtype=str, missing cells are NaN, and astype(str)
    # would turn them into the literal "nan", leaking 'n'/'a' into the vocabulary.
    cells = df.fillna("").to_numpy().ravel()

    return sorted(set("".join(cells)))

def extract_all_unique_column_values(path_to_csv, column):
    """List the distinct non-missing values of one CSV column.

    Values are parsed with pandas' default type inference. Missing cells and
    empty strings are skipped. Order follows first appearance in the file.

    Args:
        path_to_csv: Path or file-like object accepted by ``pd.read_csv``.
        column: Name of the column to inspect.

    Returns:
        list: Unique values as Python scalars.
    """
    column_values = pd.read_csv(path_to_csv)[column]
    present = column_values.dropna().loc[lambda series: series != ""]
    return present.unique().tolist()

In [5]:
%run 2_converter.ipynb
%run 3_dataset.ipynb

name_symbols = extract_all_symbols("data.csv")
name_converter = Converter(symbols=name_symbols, special_symbols=["<PAD>", "<BOS>", "<EOS>", "<NONE>"])

unit_symbols = extract_all_unique_column_values("data.csv", "unit")
unit_converter = Converter(symbols=unit_symbols, special_symbols=["<NONE>"])

tax_symbols = extract_all_unique_column_values("data.csv", "tax_category")
tax_converter = Converter(symbols=tax_symbols, special_symbols=["<NONE>"])

dataset = Dataset("data.csv", name_converter=name_converter, unit_converter=unit_converter, tax_converter=tax_converter)

collator = Collator(converter=name_converter)

print("#######################################################")
collator([dataset[0], dataset[1]])

[' ', '*', ',', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '=', 'A', 'B', 'C', 'D', 'E', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'O', 'P', 'R', 'S', 'T', 'U', 'W', 'X', 'Y', 'Z', 'a', 'c', 'd', 'e', 'g', 'k', 'l', 'n', 'o', 's', 't', 'u', 'x', 'y', 'z', 'Ł', 'ł']
['kg', 'g', 'szt']
['A', 'B', 'C', 't']
#######################################################


(tensor([[36, 33, 30, 29, 21, 31, 28, 25, 32, 32, 21,  5, 11,  7, 11, 29, 26,  5,
          21,  5, 11,  5, 55, 12,  7, 12, 19,  5, 12,  7, 12, 19, 21],
         [42, 25, 30, 29, 28, 29, 39, 21, 36, 32, 25,  5, 11, 10, 10, 26,  5, 21,
           5, 11,  5, 55, 13,  7, 19, 19,  5, 13,  7, 19, 19, 21,  1]]),
 [{'name': tensor([36, 33, 30, 29, 21, 31, 28, 25, 32, 32, 21,  1]),
   'unit': tensor(2),
   'tax_category': tensor(0),
   'quantity': tensor(-1.),
   'amount': tensor(-1.),
   'price': tensor(2.2900),
   'total_price': tensor(2.2900),
   'quantity_present': tensor(0, dtype=torch.int8),
   'amount_present': tensor(0, dtype=torch.int8),
   'price_present': tensor(1, dtype=torch.int8),
   'total_price_present': tensor(1, dtype=torch.int8)},
  {'name': tensor([42, 25, 30, 29, 28,  5, 29, 39, 21, 36, 32, 25]),
   'unit': tensor(3),
   'tax_category': tensor(0),
   'quantity': tensor(1.),
   'amount': tensor(-1.),
   'price': tensor(3.9900),
   'total_price': tensor(3.9900),
   'quantity