In [None]:
import torch
from torch.optim import AdamW

from transformers import AutoTokenizer
from tklearn.metrics import Accuracy
from tklearn.nn.utils.data import RecordBatch, AugmentedCollator
from tklearn.nn.callbacks import ProgbarLogger
from tklearn.nn.transformers.classification import TransformerForSequenceClassification

In [None]:
# Single source of truth for the checkpoint so the model and the tokenizer
# are always loaded from the same pretrained weights/vocabulary.
MODEL_NAME = "bert-base-uncased"

# num_labels=1 with a binary target: the classifier emits a single logit per input.
model = TransformerForSequenceClassification(MODEL_NAME, num_labels=1, target_type="binary")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [None]:
# tokens = tokenizer(["this is an example text"], return_tensors="pt")

# model.output_layer(model.base_model(**tokens).pooler_output)

# model.eval()

# with torch.no_grad():
#     print(model(**tokens).logits)

# tokenizer([
#     "this is an example text", 
#     "this is an example text"
# ], return_tensors="pt", padding="max_length", truncation=True)

In [None]:
import pandas as pd
from datasets import load_dataset

# CLINC150 out-of-scope intent dataset, "imbalanced" configuration.
ds = load_dataset(path="clinc/clinc_oos", name="imbalanced")

In [None]:
# Converts an integer intent id into its human-readable intent name.
intent_id2str = ds["train"].features["intent"].int2str

# Number of training examples per intent id, ordered by intent id.
intents = pd.Series(ds["train"]["intent"]).value_counts().sort_index()

# Keep at most the first 5 intents that share the maximum training count.
# `.head` is unambiguously positional, unlike `Series[:5]` on an integer index.
top_intents = intents[intents == intents.max()].head(5)

# `value in series` tests the Series *index* (intent ids here), not its values
# (counts); use an explicit id set so the membership test says what it means.
top_intent_ids = set(top_intents.index.tolist())

ds = ds.filter(lambda x: x["intent"] in top_intent_ids)
ds = ds.map(lambda x: {**x, "intent_str": intent_id2str(x["intent"])})

ds

In [None]:
# ds["train"]["intent_str"]

In [None]:
# Re-number the selected intents contiguously (0..n-1, in top_intents order)
# and map each new id to its intent name; label2id is the inverse mapping.
id2label = {
    new_id: intent_id2str(intent_id)
    for new_id, intent_id in enumerate(top_intents.index.tolist())
}
label2id = dict(zip(id2label.values(), id2label.keys()))

In [None]:
# import random
# random.choice(id2label)

In [None]:
def format_labels(tab):
    """Attach the contiguous label id for a single example.

    Looks up the example's ``intent_str`` in the notebook-level ``label2id``
    mapping and returns it under ``"labels"``; ``"text"`` is passed through
    unchanged.
    """
    label_id = label2id[tab["intent_str"]]
    return {"text": tab["text"], "labels": label_id}

ds = ds.map(function=format_labels)

ds.set_format(type="torch")

In [None]:
import random

# Note: `random.choices(population, k=...)` samples *with replacement*, so when
# the collator below is given `k`, its sampled negative labels may repeat;
# `random.sample` would draw without replacement instead.

In [None]:
import random
from joblib import Memory

class PromptAugmentedTrainingCollator(AugmentedCollator):
    """Turn multi-class examples into binary (text, candidate label) prompts.

    Each example ``(text, label)`` is expanded into one prompt per candidate
    label, of the form ``"<text> <sep> <label name>"``. The target for a prompt
    is 1 when the candidate is the example's true label and 0 otherwise.

    Args:
        tokenizer: Callable tokenizer used to encode the prompts; it is called
            with ``return_tensors="pt"`` (Hugging Face tokenizer API).
        id2label: Mapping from integer label id to label name.
        sep_token: Separator placed between text and label name. Defaults to
            ``tokenizer.sep_token`` when set, else ``"[SEP]"``.
        k: If given, the number of negative labels drawn per example with
            ``random.choices`` (with replacement). If ``None``, every other
            label is used as a negative.
    """

    def __init__(self, tokenizer, id2label, sep_token=None, k=None):
        super().__init__()
        if sep_token is None:
            # Some tokenizers expose the attribute but leave it as None; do not
            # let the literal string "None" leak into the prompts.
            sep_token = getattr(tokenizer, "sep_token", None) or "[SEP]"
        self.tokenizer = tokenizer
        self.sep_token = sep_token
        self.id2label = id2label
        # Negative candidate ids for every label id, kept in id2label's order
        # so the generated prompt order is deterministic.
        self.ids = {
            cls: [other for other in id2label if other != cls] for cls in id2label
        }
        self.k = k

    def generate(self, batch: RecordBatch) -> RecordBatch:
        """Expand ``batch`` into tokenized, binary-labelled prompts.

        ``"text"`` (and ``"labels"``, if present) are popped from ``batch.x``;
        when ``"labels"`` is absent the true labels are taken from ``batch.y``.
        The binary targets are returned as ``y`` when the incoming batch had a
        ``y``, otherwise under the ``"labels"`` key of ``x``. ``index`` repeats
        each source example's index once per generated prompt.
        """
        prompts = []
        labels = []
        indexes = []
        texts = batch.x.pop("text")
        true_labels = batch.x.pop("labels", batch.y)
        for txt, true_lbl, idx in zip(texts, true_labels, batch.index):
            if isinstance(true_lbl, torch.Tensor):
                true_lbl = true_lbl.item()
            negatives = self.ids[true_lbl]
            if self.k is not None:
                negatives = random.choices(negatives, k=self.k)
            # The true label comes first, then the negatives.
            for lbl in [true_lbl, *negatives]:
                # Name the *candidate* label; naming the true label here would
                # make every prompt identical while the targets differ.
                cls = self.id2label[lbl]
                prompts.append(f"{txt} {self.sep_token} {cls}")
                labels.append(true_lbl == lbl)
                indexes.append(idx)
        tokens = self.tokenizer(
            prompts, return_tensors="pt", padding="max_length", truncation=True
        )
        labels = torch.tensor(labels, dtype=torch.long)
        # `batch.y` may be a multi-element tensor, whose truth value is
        # ambiguous and raises; test for presence explicitly instead.
        if batch.y is not None:
            return RecordBatch(dict(tokens), labels, index=indexes)
        return RecordBatch({**tokens, "labels": labels}, index=indexes)


# Expand each example into one prompt per intent (k=None: all other intents are negatives).
collate_fn = PromptAugmentedTrainingCollator(tokenizer=tokenizer, id2label=id2label)

In [None]:
# Prefer Apple-silicon MPS, then CUDA, and fall back to CPU so the notebook
# still runs on machines without an MPS device.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# Only the first N_TRAIN / N_VALIDATION examples are used; increase for a full run.
N_TRAIN, N_VALIDATION = 100, 10
train_data = ds["train"][:N_TRAIN]
validation_data = ds["validation"][:N_VALIDATION]
optimizer = AdamW(model.parameters(), lr=5e-5)
metrics = {"accuracy": Accuracy()}
callbacks = [ProgbarLogger()]

model = model.fit(
    train_data,
    batch_size=2,
    optimizer=optimizer,
    validation_data=validation_data,
    metrics=metrics,
    callbacks=callbacks,
    collate_fn=collate_fn,
)

In [None]:
# Display the ProgbarLogger callback instance that was passed to fit().
callbacks[0]

In [None]:
# Evaluate the trained model on the validation subset.
# NOTE(review): fit() was given collate_fn and metrics explicitly, but they are
# not passed here; confirm evaluate() reuses them, otherwise the raw "text"
# column may reach the model untokenized.
model.evaluate(validation_data)

In [None]:
# Consume the prediction iterator for the validation subset into a list for display.
# NOTE(review): as with evaluate(), collate_fn is not passed here; confirm
# predict_iter() applies the same prompt-augmenting collator used in fit().
list(model.predict_iter(validation_data))

In [None]:
# model.eval()

# with torch.no_grad():
    # output = model.predict_on_batch(batch)
    # print(model.extract_eval_input(batch, output))

In [None]:
# model.base_model.save_pretrained("model")
# !rm -rf model

In [None]:
# list(model.named_parameters())

In [None]:
# Scratch check: dict(mapping, **kwargs) copies the mapping and adds the
# keyword entries, i.e. it equals this literal.
{"a": 100, "c": 1}