In [9]:
import torch
import pandas as pd

In [10]:
class Collator:
    """Collate ``(sample, label)`` pairs into padded batch tensors.

    Each batch item is a 2-tuple: a 1-D sample tensor and a label mapping that
    has at least a ``"name"`` entry holding a 1-D tensor. Samples and names are
    right-padded to the longest sequence in the batch using the id that the
    converter returns for ``"<PAD>"``.
    """

    def __init__(self, converter):
        """Remember the converter that supplies the padding id.

        Args:
            converter: Any object supporting ``converter["<PAD>"]``; the value
                is passed to ``pad_sequence`` as ``padding_value``.
        """
        self.converter = converter

    def __call__(self, batch):
        """Pad every sample and every label name in ``batch``.

        Args:
            batch: Sequence of ``(sample, label)`` pairs.

        Returns:
            A 4-tuple ``(padded_samples, padded_labels, samples_attention_mask,
            names_attention_mask)``:

            * ``padded_samples`` -- samples stacked batch-first and padded.
            * ``padded_labels`` -- one new dict per input label, identical to
              the input except that ``"name"`` is the padded row. The input
              label dicts are not modified.
            * ``samples_attention_mask`` / ``names_attention_mask`` -- boolean
              tensors that are ``True`` wherever the padded tensor equals the
              padding id (so a real token sharing that id is flagged as well).

        Raises:
            ValueError: If ``batch`` is empty.
        """
        batch = list(batch)
        if not batch:
            # Fail with a clear message instead of an opaque unpacking error.
            raise ValueError("Collator received an empty batch")

        pad_id = self.converter["<PAD>"]

        samples, labels = map(list, zip(*batch))
        names = [label["name"] for label in labels]

        pad_sequence = torch.nn.utils.rnn.pad_sequence
        padded_samples = pad_sequence(samples, batch_first=True, padding_value=pad_id)
        padded_names = pad_sequence(names, batch_first=True, padding_value=pad_id)

        # True marks padding positions.
        samples_attention_mask = (padded_samples == pad_id)
        names_attention_mask = (padded_names == pad_id)

        # Iterating the 2-D name tensor yields one padded row per label.
        padded_labels = [
            {**label, "name": padded_name}
            for label, padded_name in zip(labels, padded_names)
        ]

        return padded_samples, padded_labels, samples_attention_mask, names_attention_mask

In [11]:
# Helpers
def extract_all_symbols(path_to_csv):
    """Return the sorted list of distinct characters found in the CSV's cells.

    Every column is read as text and the characters of all data cells are
    collected; the header row does not contribute.

    Args:
        path_to_csv: Path (or anything ``pandas.read_csv`` accepts) of the CSV.

    Returns:
        A sorted list of single-character strings.
    """
    # na_filter=False keeps cells verbatim: empty cells stay "" and literal
    # values such as "NA" stay as written. Without it, missing cells become
    # NaN and stringify to "nan", leaking the letters 'n' and 'a' into the
    # result while dropping the real characters of NA-like cells.
    df = pd.read_csv(path_to_csv, dtype=str, na_filter=False)

    cells = df.to_numpy().ravel()
    all_text = "".join(map(str, cells))

    return sorted(set(all_text))

def extract_all_unique_column_values(path_to_csv, column):
    """Return the distinct non-missing, non-empty values of one CSV column.

    Values are read as text so that digit-only tokens (e.g. ``"0"``) are
    returned as strings rather than being converted to numbers by type
    inference, matching how ``extract_all_symbols`` reads the file.

    Args:
        path_to_csv: Path (or anything ``pandas.read_csv`` accepts) of the CSV.
        column: Name of the column to inspect.

    Returns:
        A list of unique string values in order of first appearance.

    Raises:
        KeyError: If ``column`` is not present in the CSV.
    """
    df = pd.read_csv(path_to_csv, dtype=str)

    # Missing cells are parsed as NaN and dropped; the "" check is kept as a
    # guard for any empty strings that survive parsing.
    values = df[column].dropna()
    values = values[values != ""]

    return values.unique().tolist()

In [12]:
%run 2_converter.ipynb
%run 3_dataset.ipynb

name_symbols = extract_all_symbols("../data/train.csv")
name_converter = Converter(symbols=name_symbols, special_symbols=["<PAD>", "<BOS>", "<EOS>", "<NONE>"])

unit_symbols = extract_all_unique_column_values("../data/train.csv", "unit")
unit_converter = Converter(symbols=unit_symbols, special_symbols=["<NONE>"])

tax_symbols = extract_all_unique_column_values("../data/train.csv", "tax_category")
tax_converter = Converter(symbols=tax_symbols, special_symbols=["<NONE>"])

dataset = Dataset("../data/data.csv", name_converter=name_converter, unit_converter=unit_converter, tax_converter=tax_converter)

collator = Collator(converter=name_converter)

print("#######################################################")
collator([dataset[0], dataset[1]])

[' ', '!', '"', '#', '%', '&', "'", '(', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '=', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '\\', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '|', 'Ó', 'ó', 'Ą', 'ą', 'Ć', 'ć', 'Ę', 'ę', 'Ł', 'ł', 'ń', 'Ś', 'ś', 'ź', 'Ż', 'ż']
['KG', 'G', 'SZT', 'L', 'ML']
['C', 'A', 'B', 'G', '0', 't', 'r', 'c', '4', 'a', '8', 'l', 'k']
#######################################################


(tensor([[43, 31, 41, 31,  5, 32, 31, 49, 39, 31,  5, 20, 41, 37,  5, 33,  5, 28,
           5, 82, 22, 15, 24, 28,  5, 22, 21, 15, 22, 20, 33,  1,  1,  1],
         [56, 35, 42, 41, 39,  5, 41, 53,  5, 37, 53, 39, 31, 56, 34, 41,  5, 28,
           5, 31,  5, 20,  5, 82, 22, 15, 28, 28,  5, 22, 15, 28, 28, 31]]),
 [{'name': tensor([43, 31, 41, 31,  5, 32, 31, 49, 39, 31,  5, 20, 41, 37,  1,  1]),
   'unit': tensor(2),
   'tax_category': tensor(2),
   'quantity': tensor(9.),
   'amount': tensor(1.),
   'price': tensor(3.5900),
   'total_price': tensor(32.3100),
   'quantity_present': tensor(1, dtype=torch.int8),
   'amount_present': tensor(1, dtype=torch.int8),
   'price_present': tensor(1, dtype=torch.int8),
   'total_price_present': tensor(1, dtype=torch.int8)},
  {'name': tensor([56, 35, 42, 41, 39,  5, 41, 53,  5, 37, 53, 39, 31, 56, 34, 41]),
   'unit': tensor(1),
   'tax_category': tensor(3),
   'quantity': tensor(1.),
   'amount': tensor(-1.),
   'price': tensor(3.9900),
   'tot