In [None]:
import torch
from torch.optim import AdamW
from datasets import DatasetDict

from transformers import AutoTokenizer
from tklearn.metrics import Accuracy
from tklearn.nn.utils.data import RecordBatch, AugmentedCollator
from tklearn.nn.callbacks import ProgbarLogger
from tklearn.nn.transformers.classification import TransformerForSequenceClassification

In [None]:
class PromptTransformerForSequenceClassification(TransformerForSequenceClassification):
    """Transformer sequence classifier with an opt-in hook for inspecting metric inputs.

    In this notebook it is trained on ``"<text> [SEP] <candidate intent>"``
    prompts. Apart from the optional logging below it defers entirely to
    ``TransformerForSequenceClassification``.
    """

    # When True, print each (batch, output) pair before it is handed to the
    # metrics. Off by default: dumping full tensors on every step floods the
    # notebook output and slows training. Enable with
    # ``model.log_metric_inputs = True`` while debugging.
    log_metric_inputs = False

    def prepare_metric_inputs(self, batch, output):
        """Optionally log ``batch``/``output``, then return the base-class metric inputs."""
        if self.log_metric_inputs:
            print(batch, output)
        return super().prepare_metric_inputs(batch, output)

In [None]:
# Single checkpoint name shared by model and tokenizer so the two cannot drift apart.
MODEL_NAME = "bert-base-uncased"
SEED = 42

# Seed before building the model so any randomly initialised weights are
# reproducible across runs (presumably the classification head — confirm in tklearn).
torch.manual_seed(SEED)

# num_labels=1 with target_type="binary": one output scoring whether the candidate
# intent appended to the prompt is the utterance's true intent (labels built below
# are 1.0 for a match, 0.0 otherwise).
model = PromptTransformerForSequenceClassification(MODEL_NAME, num_labels=1, target_type="binary")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [None]:
import pandas as pd
from datasets import load_dataset

# Number of intents kept for this small prompt-classification experiment.
N_INTENTS = 5

datasets = load_dataset("clinc/clinc_oos", "imbalanced")

# Maps a ClassLabel id of the "intent" column to its intent name.
intent_id2str = datasets["train"].features["intent"].int2str

# Training-example count per intent, indexed by intent id in ascending order.
intents = pd.Series(datasets["train"]["intent"]).value_counts().sort_index()

# Keep up to N_INTENTS of the intents tied for the largest training count.
# .iloc makes the slice explicitly positional; a bare [:N] on an
# integer-labelled Series is easy to misread as label-based.
# NOTE(review): if fewer than N_INTENTS intents share the maximum count, fewer are kept.
top_intents = intents[intents == intents.max()].iloc[:N_INTENTS]
# The selected intent ids live in the Series index (the values are counts).
top_intent_ids = set(top_intents.index)

datasets = datasets.filter(lambda x: x["intent"] in top_intent_ids)
# Dataset.map merges returned keys into each row, so only the new column is returned.
datasets = datasets.map(lambda x: {"intent_str": intent_id2str(x["intent"])})

# Replace the integer ClassLabel column with its string name under the same column name.
datasets = datasets.remove_columns(["intent"])
datasets = datasets.rename_column("intent_str", "intent")

In [None]:
# datasets["train"]["intent"]

In [None]:
# Contiguous label ids 0..k-1, in the order of top_intents' index, mapped to intent names.
id2label = {position: intent_id2str(intent_id) for position, intent_id in enumerate(top_intents.index)}
# Inverse lookup: intent name -> contiguous label id.
label2id = dict(zip(id2label.values(), id2label.keys()))

In [None]:
def augment_data(examples, indices, candidate_labels=None, sep_token=None):
    """Expand a batch of utterances into binary (prompt, label) pairs.

    Every utterance is paired with every candidate intent name as
    ``"<text> <sep_token> <candidate>"``. The label is 1.0 when the candidate
    equals the utterance's intent and 0.0 otherwise.

    Args:
        examples: Batched columns containing ``"text"`` and ``"intent"``
            (the intent as its string name).
        indices: Row indices of the batch, as supplied by
            ``Dataset.map(..., with_indices=True)``.
        candidate_labels: Iterable of intent names to pair with each text.
            Defaults to the keys of the notebook-global ``label2id``.
        sep_token: Separator placed between text and candidate. Defaults to
            the notebook-global ``tokenizer.sep_token``.

    Returns:
        dict with ``"index"`` (int64 tensor, source-row index of each prompt),
        ``"prompt"`` (list of str) and ``"labels"`` (float32 tensor).
    """
    # Defaults are resolved at call time so Dataset.map can keep calling this
    # with just (examples, indices), exactly as before.
    if candidate_labels is None:
        candidate_labels = label2id
    if sep_token is None:
        sep_token = tokenizer.sep_token
    # Materialise once so a one-shot iterable is not exhausted after the first text.
    candidate_labels = list(candidate_labels)

    prompts, labels, global_indices = [], [], []
    for text, act_intent, global_index in zip(examples["text"], examples["intent"], indices):
        for candidate in candidate_labels:
            prompts.append(f"{text} {sep_token} {candidate}")
            labels.append(act_intent == candidate)
            global_indices.append(global_index)
    return {
        "index": torch.tensor(global_indices, dtype=torch.long),
        "prompt": prompts,
        "labels": torch.tensor(labels, dtype=torch.float32),
    }

# Expand every split with augment_data: batched, with row indices, and with the
# split's original columns dropped so only the columns augment_data returns remain.
augmented_ds = DatasetDict(
    {
        split_name: split.map(
            augment_data,
            with_indices=True,
            batched=True,
            batch_size=8,
            remove_columns=split.column_names,
        )
        for split_name, split in datasets.items()
    }
)

In [None]:
def format_labels(examples):
    """Tokenize the ``prompt`` column of a batch with the notebook-global ``tokenizer``.

    Despite its name, this does not touch the labels. It encodes
    ``examples["prompt"]`` with truncation and ``padding="max_length"`` (no
    explicit ``max_length``, so the tokenizer's own maximum is used). The
    encoding is returned as a plain ``dict`` of PyTorch tensors, which
    ``Dataset.map`` adds as new columns.
    """
    encoding = tokenizer(
        examples["prompt"],
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    return dict(encoding)


# Add tokenizer columns to every split; existing columns (index, prompt, labels) are kept.
augmented_ds = augmented_ds.map(format_labels, batch_size=8, batched=True)

# Make row/slice access return torch tensors.
augmented_ds.set_format(type="torch")

In [None]:
# Inspect split sizes and column names of the augmented, tokenized dataset.
augmented_ds

In [None]:
# Short alias for the augmented, tokenized, torch-formatted splits used for training below.
ds = augmented_ds

In [None]:
# Pick the best available accelerator instead of hard-coding "mps", which
# fails on machines without Apple-silicon GPU support.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# Small subsets and a tiny batch keep this a quick smoke run; increase them for
# a real experiment.
N_TRAIN = 100
N_VALID = 10
BATCH_SIZE = 2
LEARNING_RATE = 5e-5

train_data = ds["train"][:N_TRAIN]
validation_data = ds["validation"][:N_VALID]
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
metrics = {"accuracy": Accuracy()}
callbacks = [ProgbarLogger()]

# NOTE(review): fit()'s return value replaces `model`. This assumes tklearn's
# fit returns the trained model itself — confirm, otherwise the later
# evaluate call would receive the wrong object.
model = model.fit(
    train_data,
    batch_size=BATCH_SIZE,
    optimizer=optimizer,
    validation_data=validation_data,
    metrics=metrics,
    callbacks=callbacks,
)

In [None]:
# Score the trained model on the same validation subset that was passed to fit().
model.evaluate(validation_data)

In [None]:
# list(model.predict_iter(validation_data, collate_fn=collate_fn))

In [None]:
# model.eval()

# with torch.no_grad():
    # output = model.predict_on_batch(batch)
    # print(model.extract_eval_input(batch, output))

In [None]:
# model.base_model.save_pretrained("model")
# !rm -rf model

In [None]:
# list(model.named_parameters())

In [None]:
# NOTE(review): scratch cell unrelated to the training pipeline (it only builds a
# throwaway dict) — consider deleting it from the final notebook.
dict({"a": 100}, c=1)