In [1]:
# Project helpers. Later cells use MultiLabelDataset, Tokenizer and
# get_dataloaders, which arrive through these star imports.
# NOTE(review): switch to explicit imports once it is confirmed which of
# `data` / `utils` defines each name.
from data import *
from utils import *
import pandas as pd

df = pd.read_csv(
    "data/data_sample.csv",
    sep="|",  # the sample file is pipe-delimited, not comma-delimited
)
df.shape

(29, 4)

Test the dataset class methods:
- `build_vocab_from_data`
- `build_vocab_from_pretrain_emb`
- `build_with_transformer`

In [2]:
# Build one document per row as "<headline> <text>".
# fillna("") guards against missing fields. String `+` propagates NaN, so a
# single missing headline or body would otherwise turn the whole document
# into NaN (a float), which the tokenizer presumably cannot handle.
data = (
    df["headline"].fillna("").str.strip()
    + " "
    + df["text"].fillna("").str.strip()
).str.strip()

# Vocabulary built from the corpus itself.
dataset = MultiLabelDataset.build_vocab_from_data(
    data=data.to_numpy(),
    labels=df["label"].to_numpy(),
    tokenizer=Tokenizer(),
)

# NOTE(review): the token ids of this item print as all 1s. Check whether 1 is
# the unknown/padding index and whether the vocabulary is built as intended.
dataset[0]

(tensor([1, 1, 1,  ..., 1, 1, 1]),
 tensor([0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1,
         1, 0, 0, 1], dtype=torch.int16))

In [3]:
# Same texts and labels as the previous cell, but the vocabulary now comes
# from the pretrained embedding "glove.6B.50d" instead of the corpus.
# Note: this rebinds `dataset`, replacing the one built in the previous cell.
dataset = MultiLabelDataset.build_vocab_from_pretrain_emb(
    data=data.values,
    labels=df["label"].values,
    tokenizer=Tokenizer(),
    pretrained_name="glove.6B.50d",
)

dataset[1]

(tensor([    0,     0,     0,    12,     0,     0,  4217,     0,     0,     0,
             0,     0,     4,     0,     0,     0,     0,     0,  3069,  5749,
             4,  1087,     0,   919, 24025,     0,  1246,     5,     0,     0,
             0,    14,   970,     7,  1903,  2309,   588,     0,   134, 10393,
             0,     4,     0,     0,     0,     0,     0,   177,  7124,     4,
           408,     0,   997, 24025,     0,     0,     0,  3096,  1852,  2575,
             6,  9068,     5,     0,     0,     0,     0,     4,     0,     0,
             0,     0,     0,   233,  7124,     4,   207,     0,   997, 24025,
             0,     0,     0,   195,     5,     0,     0,     0,     0,     0,
             0]),
 tensor([0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1,
         1, 0, 0, 1], dtype=torch.int16))

Test `get_dataloaders()`:

In [4]:
# Train/test loaders built straight from the CSV, with the GloVe vocabulary.
# NOTE(review): this reports 45 classes, but the label tensors built by
# MultiLabelDataset from the same file in the cells above have 28 entries.
# Confirm both code paths encode labels the same way.
train_loader, test_loader, num_classes = get_dataloaders(
    file="data/data_sample.csv",
    tokenizer=Tokenizer(),
    vocab_from="glove.6B.50d",
)

num_classes

45

In [5]:
# Peek at the first training batch. `next(iter(...))` replaces the
# loop-print-break pattern, and the `None` default avoids StopIteration if
# the loader yields nothing. Ending the cell with the value lets Jupyter
# render it rather than print() it.
train_batch = next(iter(train_loader), None)
train_batch

{'x': tensor([[   0,    0,    0,  ...,    0,    0,    0],
        [   0,  545, 3065,  ...,    0,    0,    0],
        [   0,  896,  108,  ...,    0,    0,    0],
        ...,
        [   0,    0,   12,  ...,    0,    0,    0],
        [   0,    0,    0,  ...,    0,    0,    0],
        [   0,    0,   12,  ...,    0,    0,    0]]), 'y': tensor([[0, 0, 0,  ..., 0, 1, 0],
        [1, 0, 0,  ..., 0, 1, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 1, 0],
        [0, 0, 0,  ..., 0, 0, 0]], dtype=torch.int16), 'lengths': [6471, 754, 590, 534, 457, 304, 174, 173, 135, 133, 120, 117, 111, 109, 105, 102, 100, 91, 85, 74, 73, 69, 65]}


In [6]:
# Peek at the first test batch (GloVe vocabulary). `next(iter(...))` fetches
# only that batch; the `None` default avoids StopIteration on an empty loader.
test_batch = next(iter(test_loader), None)
test_batch

{'x': tensor([[     0,    896,    157,  ...,  35115,  69312, 167123],
        [     0,   1948,    384,  ...,      0,      0,      0],
        [     0,   5428,      0,  ...,      0,      0,      0],
        [     0,    211,      0,  ...,      0,      0,      0],
        [     0,      0,      0,  ...,      0,      0,      0],
        [     0,      0,     12,  ...,      0,      0,      0]]), 'y': tensor([[0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [9]:
# Rebuild the loaders with vocab_from="bert". No tokenizer argument is passed,
# so get_dataloaders falls back to its own default for it (presumably a
# BERT tokenizer; confirm in get_dataloaders).
# Note: this rebinds train_loader / test_loader / num_classes from the GloVe run.
train_loader, test_loader, num_classes = get_dataloaders(
    file="data/data_sample.csv",
    vocab_from="bert",
)

num_classes

45

In [10]:
# Peek at the first test batch from the BERT-vocabulary loader.
# `next(iter(...))` fetches only that batch; the `None` default avoids
# StopIteration on an empty loader.
bert_test_batch = next(iter(test_loader), None)
bert_test_batch

{'texts': tensor([[  101, 24529,  2063,  2758,  2097,  2025,  9190,  7987,  2063,  1011,
          1060,  2006,  5227,  1012,  2044,  1037,  4121,  3872,  1997,  3119,
          1999,  7987,  2063,  1060, 13246,  5183,  3303,  1996,  4361,  4518,
          3863,   102],
        [  101,  2470,  9499,  1011, 22326,  5821,  9725,  1012, 11605,  5253,
         24665,  2368, 23510,  2056, 12941,  2848,  6287,  9725, 22326,  5821,
         13058,  2000,  4965,  2013, 27598,  3038,  1996,  4518,  1055,  6689,
          2018,   102],
        [  101, 17235,  2850,  4160,  5494,  2087,  3161,  2015,  1011,  2258,
          1015,  1012,  1996,  2206,  2020,  1996,  2087,  3161,  3314,  1999,
         17235,  2850,  4160,  6202,  2006,  9857, 12367,  7646, 13058,  2484,
          4261,   102],
        [  101,  8915,  2072,  4297,  1053,  2549,  3463,  1012,  4951, 21020,
          4483,  4297,  3479,  4082,  2592, 14477, 21041,  3064,  1999,  5190,
          3272, 16565,  2566,  3745,  2093,  2706