In [13]:
import joblib
import json
import pandas as pd
import numpy as np
import utils
from transformers import AutoTokenizer
import torch
import sys
from collections import Counter

# Extend the import search path with the sibling ../datasets directory before
# importing `configs` (presumably that is where configs.py lives — confirm).
# The path is relative, so this depends on the notebook's working directory.
sys.path.append("../datasets")
import configs

In [14]:
# --- Experiment configuration -------------------------------------------
# Base settings that the rest of the notebook reads.
dataset = "dbpedia"
architecture = "BART"  # selects the tokenizer checkpoint in the next cell
batch_size = 256

# Values derived from the dataset name.
# NOTE(review): the numeric suffix in model_name presumably encodes training
# hyperparameters — confirm against the training script.
data_dir = f"../datasets/{dataset}_dataset"
model_name = f"{dataset}_model_0.9_0.9_20.0"

In [15]:
# Pick the tokenizer checkpoint that matches the configured backbone.
if architecture == "BART":
    tokenizer = AutoTokenizer.from_pretrained("ModelTC/bart-base-mnli")
    # Alternative (larger) checkpoint:
    # tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
elif architecture == "ELECTRA":
    tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")
else:
    # Fail fast: merely printing here would leave `tokenizer` undefined (or,
    # in a long-lived kernel, silently stale from a previous run), and the
    # error would only surface later in the data-loading cell.
    raise ValueError(f"Invalid backbone architecture: {architecture}")

In [16]:
# Load every split found under `data_dir`, tokenized with the tokenizer
# selected above and capped at the per-dataset maximum sequence length.
all_datasets = utils.load_dataset(
    data_dir=data_dir,
    tokenizer=tokenizer,
    max_length=configs.dataset_to_max_length[dataset],
)


def _collate_batch(batch):
    """Turn a list of per-example dicts into a dict of batched tensors.

    Each example must provide "input_ids", "attention_mask" and "label".
    Ids and labels become LongTensors; the attention mask goes through
    ``torch.Tensor`` and therefore takes torch's default tensor type.
    """
    ids = [example["input_ids"] for example in batch]
    masks = [example["attention_mask"] for example in batch]
    labels = [example["label"] for example in batch]
    return {
        "input_ids": torch.LongTensor(ids),
        "attention_mask": torch.Tensor(masks),
        "label": torch.LongTensor(labels),
    }


# One non-shuffled DataLoader per split, keyed by the split's name.
all_dataloaders = {
    split_name: torch.utils.data.DataLoader(
        all_datasets[split_name],
        batch_size=batch_size,
        shuffle=False,
        collate_fn=_collate_batch,
    )
    for split_name in all_datasets.keys()
}

Train data shape:  (24094, 2)


Map:   0%|          | 0/24094 [00:00<?, ? examples/s]

Map:   0%|          | 0/60794 [00:00<?, ? examples/s]

Map:   0%|          | 0/1836 [00:00<?, ? examples/s]

Map:   0%|          | 0/1281 [00:00<?, ? examples/s]

Map:   0%|          | 0/1836 [00:00<?, ? examples/s]

Map:   0%|          | 0/1281 [00:00<?, ? examples/s]

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [17]:
# Locations of the two artifacts saved for this dataset/model pair.
bestk_path = f"artifacts/{dataset}/{model_name}/bestk_train_data_per_proto.joblib"
best_protos_path = f"artifacts/{dataset}/{model_name}/best_protos_per_traineg.joblib"

# NOTE: joblib.load unpickles arbitrary Python objects — only load artifact
# files that come from a trusted source.
bestk_train_data_per_proto = joblib.load(bestk_path)
best_protos_per_traineg = joblib.load(best_protos_path)

In [18]:
# Sanity check: shape of the first element of the loaded artifact.
# NOTE(review): presumably (num_prototypes, k) — confirm against the script
# that produced the artifact.
bestk_train_data_per_proto[0].shape

(16, 5)

In [19]:
# Peek at the first row of that array. The inspection loop later in this
# notebook uses these values as positions into the train split's label column.
bestk_train_data_per_proto[0][0]

array([ 2281,  3695, 13359, 10491,  8389])

In [20]:
# Class distribution of the train split as (label, count) pairs, most
# frequent first. With no argument, most_common() sorts all items by count
# in descending order.
Counter(all_datasets["train"]["label"]).most_common()

[(0, 12480),
 (2, 4588),
 (3, 2147),
 (1, 2101),
 (5, 1911),
 (6, 588),
 (4, 176),
 (8, 78),
 (7, 25)]

In [21]:
# Materialise the train label column once; converting it inside the loop
# repeated the same full-column conversion on every iteration.
train_labels = np.array(all_datasets["train"]["label"])

# The artifact is indexed as [0] -> train-example indices, [1] -> the values
# printed alongside them (presumably similarity scores — confirm).
topk_train_indices = bestk_train_data_per_proto[0]
topk_values = bestk_train_data_per_proto[1]

# Iterate over however many rows the artifact holds instead of a hard-coded 16.
# NOTE(review): the row index is printed as "Label", but the train split only
# contains labels 0-8 while the artifact has 16 rows, so this index is most
# likely a prototype index. The printed text is left unchanged so the stored
# outputs stay valid; consider renaming after confirming.
for label in range(len(topk_train_indices)):
    print(f"Label: {label}")
    # Ground-truth labels of the train examples selected for this row.
    print(train_labels[topk_train_indices[label]])
    print(topk_values[label])
    print("------------------")

Label: 0
[1 2 0 0 0]
[-3.8624659 -3.861605  -3.8615801 -3.8613226 -3.8612463]
------------------
Label: 1
[0 0 0 0 0]
[-3.850351  -3.850011  -3.8479881 -3.84537   -3.8445837]
------------------
Label: 2
[0 5 0 0 0]
[-3.773628  -3.753562  -3.7531555 -3.752609  -3.7500129]
------------------
Label: 3
[5 3 1 6 0]
[-3.7334611 -3.732118  -3.731756  -3.731465  -3.7313373]
------------------
Label: 4
[8 0 2 1 0]
[-3.8137    -3.8127384 -3.8117366 -3.811621  -3.8113782]
------------------
Label: 5
[2 0 0 5 0]
[-3.8440573 -3.8431418 -3.8419259 -3.8415952 -3.841591 ]
------------------
Label: 6
[0 0 7 0 1]
[-3.8162267 -3.815861  -3.8152468 -3.8147988 -3.814158 ]
------------------
Label: 7
[2 1 2 0 0]
[-3.8186543 -3.8178284 -3.8177805 -3.8168666 -3.8151393]
------------------
Label: 8
[0 2 3 0 1]
[-3.8159566 -3.8146105 -3.813936  -3.8121421 -3.8113277]
------------------
Label: 9
[0 0 2 5 5]
[-3.8361278 -3.8330958 -3.8313797 -3.8244326 -3.823787 ]
------------------
Label: 10
[6 0 0 8 6]
[-3.8473