# Brand Extraction

In [None]:
import pandas as pd
from re import search
from random import seed
from numpy.random import seed as np_seed
from time import perf_counter
from eda import stop_words, get_only_chars, eda

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

# Report the installed PyTorch version.
print("PyTorch version:", torch.__version__)

## Miscellaneous

In [None]:
# Notebook-wide settings: `verbose` is passed to the training helpers to
# switch their progress output on, and `device` names where tensors and the
# model are placed (GPU when CUDA is available, otherwise CPU).
verbose = True
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


class SeqToSeqDataset(Dataset):
    """Index-aligned container of (input_ids, attention_mask, labels) examples.

    The three arguments are stored as given; they only need to support
    ``len()`` (for ``input_ids``) and integer indexing. Example ``i`` is the
    triple formed by the ``i``-th element of each of them.
    """

    def __init__(self, input_ids, attention_mask, labels):
        self.input_ids, self.attention_mask, self.labels = (
            input_ids,
            attention_mask,
            labels,
        )

    def __len__(self):
        # The dataset length is defined by the number of input sequences.
        return len(self.input_ids)

    def __getitem__(self, idx):
        # Return the aligned triple for position ``idx``.
        return self.input_ids[idx], self.attention_mask[idx], self.labels[idx]


def train(dataloader, model, optimizer, device="cuda", step_size=1, verbose=False):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        dataloader: Iterable yielding ``(input_ids, attention_mask, labels)``
            batches; must expose ``.dataset`` and support ``len()``.
        model: Callable accepting ``input_ids``, ``attention_mask`` and
            ``labels`` keyword arguments and returning an object with a
            ``.loss`` tensor.
        optimizer: Optimizer exposing ``step()`` and ``zero_grad()``.
        device: Device the batch tensors are moved to.
        step_size: Number of batches whose gradients are accumulated before
            each optimizer step. The loss is not rescaled, so the applied
            gradient is the sum over the accumulated batches.
        verbose: When true, print the loss every 100 batches.

    Returns:
        The sum of the batch losses divided by the number of batches, or
        ``nan`` when the dataloader yields no batches.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.train()
    optimizer.zero_grad()
    steps = 0
    train_loss = 0.0
    for batch, (input_ids, attention_mask, labels) in enumerate(dataloader):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        # Compute prediction error
        pred = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = pred.loss
        # Read the scalar once; reused for the running total and the log line.
        loss_value = loss.item()
        train_loss += loss_value

        # Backpropagation
        loss.backward()
        steps += 1
        if steps % step_size == 0:
            optimizer.step()
            optimizer.zero_grad()

        if verbose and batch % 100 == 0:
            current = batch * len(input_ids)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")

    # Flush gradients left over from an incomplete accumulation window only.
    # Stepping unconditionally would apply an extra update right after the
    # last in-loop step (with cleared gradients), which for optimizers with
    # momentum or weight decay still moves the parameters.
    if steps % step_size != 0:
        optimizer.step()
        optimizer.zero_grad()

    if num_batches == 0:
        # No batches: the mean loss is undefined rather than a division error.
        return float("nan")
    return train_loss / num_batches


def test(dataloader, model, device="cuda", verbose=False):
    """Evaluate ``model`` on ``dataloader``: mean loss and exact-match accuracy.

    For every batch the teacher-forced loss is read from ``model(...).loss``
    and sequences are produced with ``model.generate(..., do_sample=False)``.
    A prediction counts as correct when, after removing token id ``0`` from
    the generated sequence and ``-100`` from the label, the two token
    sequences are identical.

    Args:
        dataloader: Iterable yielding ``(input_ids, attention_mask, labels)``
            batches; must expose ``.dataset`` and support ``len()``.
        model: Model exposing ``eval()``, ``generate()`` and a forward call
            that returns an object with a ``.loss`` tensor.
        device: Device the batch tensors are moved to.
        verbose: When true, print accuracy and average loss.

    Returns:
        ``(test_loss, correct)``: the batch losses averaged over the number
        of batches, and the fraction of dataset examples matched exactly.
        Each is ``nan`` when its denominator is zero.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for input_ids, attention_mask, labels in dataloader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)
            pred = model(
                input_ids=input_ids, attention_mask=attention_mask, labels=labels
            )
            test_loss += pred.loss.item()
            output_sequences = model.generate(
                input_ids=input_ids, attention_mask=attention_mask, do_sample=False
            )
            for output_sequence, label in zip(output_sequences, labels):
                # Token id 0 is hard-coded as the filler to strip from the
                # generated ids. NOTE(review): this matches T5's pad /
                # decoder-start id as far as I know; confirm before using a
                # different model family.
                output_sequence = output_sequence[output_sequence != 0]
                # -100 marks the label positions get_dataset masked out.
                label = label[label != -100]
                # torch.equal checks shape and contents in one tensor op
                # instead of a Python-level loop over individual elements.
                if torch.equal(output_sequence, label):
                    correct += 1
    # An empty loader/dataset yields nan instead of a ZeroDivisionError.
    test_loss = test_loss / num_batches if num_batches else float("nan")
    correct = correct / size if size else float("nan")
    if verbose:
        print(
            f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}"
        )
    return test_loss, correct


def learn(
    training_data,
    test_data,
    model,
    optimizer,
    batch_size=64,
    device="cuda",
    epochs=5,
    step_size=1,
    file=None,
    verbose=False,
):
    """Train ``model`` for ``epochs`` epochs, evaluating after each one.

    Args:
        training_data: Dataset wrapped in a ``DataLoader`` for ``train``.
        test_data: Dataset wrapped in a ``DataLoader`` for ``test``.
        model: Model passed through to ``train`` and ``test``.
        optimizer: Optimizer passed through to ``train``.
        batch_size: Batch size used for both data loaders.
        device: Device passed through to ``train`` and ``test``.
        epochs: Number of train/evaluate rounds.
        step_size: Gradient-accumulation window passed to ``train``.
        file: Optional path prefix; when set, the whole model object is
            saved with ``torch.save`` to ``f"{file}-{epoch}.pth"`` after
            every epoch.
        verbose: When true, print batch shapes, per-epoch progress and timing.

    Returns:
        ``(train_losses, test_losses, corrects)``: three lists with one entry
        per epoch, holding the values returned by ``train`` and ``test``.
    """
    # Create data loaders. NOTE(review): no ``shuffle`` is requested, so
    # batches are drawn in dataset order on every epoch.
    train_dataloader = DataLoader(training_data, batch_size=batch_size)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    if verbose:
        # Peek at a single batch purely to report shapes and dtypes; skipped
        # entirely in quiet mode so no batch is collated for nothing.
        for input_ids, attention_mask, labels in test_dataloader:
            print("Shape of input ids: ", input_ids.shape, input_ids.dtype)
            print(
                "Shape of attention mask: ", attention_mask.shape, attention_mask.dtype
            )
            print("Shape of labels: ", labels.shape, labels.dtype)
            print(f"Using {device} device")
            print(model)
            break

    tic = perf_counter()

    train_losses = []
    test_losses = []
    corrects = []

    for t in range(epochs):
        # Per-epoch start time, so the report below states this epoch's own
        # duration instead of the time since training began.
        epoch_tic = perf_counter()
        if verbose:
            print(f"Epoch {t+1}\n-------------------------------")
        train_loss = train(
            train_dataloader, model, optimizer, device, step_size, verbose
        )
        test_loss, correct = test(test_dataloader, model, device, verbose)
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        corrects.append(correct)
        if file:
            # The entire model object is pickled (not just its state_dict).
            torch.save(model, f"{file}-{t+1}.pth")
            if verbose:
                print(f"Saved PyTorch Model State to {file}-{t+1}.pth")
        if verbose:
            toc = perf_counter()
            print(
                f"Done Epoch {t+1} in {toc - epoch_tic} seconds "
                f"(total {toc - tic} seconds) \n"
            )
    if verbose:
        print("Done!")

    return train_losses, test_losses, corrects


def get_augmented_sentences(
    df,
    col1,
    col2,
    alpha_sr=0.1,
    alpha_ri=0.1,
    alpha_rs=0.1,
    p_rd=0.1,
    num_aug=9,
):
    """Build augmented ``col2`` sentences for every row of ``df`` with ``eda``.

    For each row, ``eda`` is called on the ``col2`` text with the given
    ``alpha_sr``/``alpha_ri``/``alpha_rs``/``p_rd``/``num_aug`` settings and
    with ``stop_words`` extended by the whitespace-separated words of the
    row's ``col1`` value (presumably so the label words are left untouched;
    confirm against the ``eda`` module).

    Returns:
        A long-format DataFrame with columns ``col1`` and ``col2``: one row
        per sentence returned by ``eda``, paired with the ``col1`` value of
        the row it was generated from.
    """

    def augment_row(row):
        extra_stop_words = stop_words + row[col1].split()
        sentences = eda(
            row[col2],
            alpha_sr,
            alpha_ri,
            alpha_rs,
            p_rd,
            num_aug,
            extra_stop_words,
        )
        return pd.Series(sentences)

    # One row per source row, one column per generated sentence. Transposing
    # and relabelling the columns with the col1 values lets melt() pair each
    # sentence with its label.
    index_to_label = df[col1].to_dict()
    return (
        df.apply(augment_row, axis=1)
        .transpose()
        .rename(columns=index_to_label)
        .melt(var_name=col1, value_name=col2)
    )


def get_augmented_labels(df, col1, col2, num_aug=9):
    """Create ``num_aug`` copies of ``df`` with labels randomly substituted.

    For every row whose ``col1`` value occurs in its ``col2`` text as a
    space-delimited phrase (per ``isin``), a replacement is drawn at random
    from ``df[col1]``; the phrase is replaced in the text and the replacement
    becomes the row's new ``col1`` value. Other rows are copied unchanged.

    Returns:
        A DataFrame with columns ``col1`` and ``col2`` holding the
        ``num_aug`` passes stacked under a fresh 0..n-1 index.
    """

    def swap_label(row):
        label, text = row[col1], row[col2]
        if isin(label, text):
            # Pad with spaces so only whole-phrase occurrences are replaced.
            padded_replacement = f" {df[col1].sample().iloc[0]} "
            text = f" {text} ".replace(f" {label} ", padded_replacement).strip()
            label = padded_replacement.strip()
        return pd.Series([label, text], index=[col1, col2])

    passes = []
    for _ in range(num_aug):
        passes.append(df.apply(swap_label, axis=1))
    return pd.concat(passes, ignore_index=True)


def isin(value, sentence):
    """Return whether ``value`` occurs in ``sentence`` as a space-delimited phrase.

    Both arguments are padded with a single space on each side before the
    substring test, so partial-word matches are rejected.
    """
    return f" {value} " in f" {sentence} "


def startswith(sentence, value):
    """Return whether ``sentence`` begins with ``value`` as a whole-word prefix.

    A single trailing space is appended to both arguments before comparing,
    so ``value`` must end on a word boundary of ``sentence``.
    """
    return f"{sentence} ".startswith(f"{value} ")


def get_dataset(
    tokenizer,
    df,
    col1,
    col2,
    task_prefix="",
    max_source_length=512,
    max_target_length=128,
):
    """Tokenize a source/target column pair into a ``SeqToSeqDataset``.

    Args:
        tokenizer: Callable tokenizer returning objects with ``input_ids``
            (and ``attention_mask`` for the sources) and exposing
            ``pad_token_id``.
        df: DataFrame holding the text columns.
        col1: Name of the source-text column.
        col2: Name of the target-text column.
        task_prefix: String prepended to every source text.
        max_source_length: Truncation length for the sources.
        max_target_length: Truncation length for the targets.

    Returns:
        A ``SeqToSeqDataset`` of source ids, source attention mask and target
        ids, where target positions equal to the pad token id are set to -100.
    """
    sources = [task_prefix + text for text in df[col1]]
    source_encoding = tokenizer(
        sources,
        padding="longest",
        max_length=max_source_length,
        truncation=True,
        return_tensors="pt",
    )

    target_encoding = tokenizer(
        df[col2].tolist(),
        padding="longest",
        max_length=max_target_length,
        truncation=True,
    )

    # Targets come back as plain lists; convert, then mask the padding.
    labels = torch.tensor(target_encoding.input_ids)
    labels[labels == tokenizer.pad_token_id] = -100

    return SeqToSeqDataset(
        source_encoding.input_ids, source_encoding.attention_mask, labels
    )

## Obtaining Data

In [None]:
# Load the raw CSV; the trailing bare name displays the frame as the cell's output.
df1 = pd.read_csv("../data/Hackathon_Ideal_Data.csv")
df1

## Scrubbing Data

In [None]:
# Keep only the brand (MBRD) and product (BRD) columns under readable names.
# rename() returns a new frame, so assigning into df2 never touches df1.
df2 = df1[["MBRD", "BRD"]].rename(columns={"MBRD": "brand", "BRD": "product"})

# Normalise each text column with get_only_chars, trim surrounding whitespace
# and drop rows left empty. The explicit .copy() after each boolean filter
# makes df2 an independent frame, so the column assignment that follows does
# not raise pandas' SettingWithCopyWarning.
df2["brand"] = df2["brand"].apply(get_only_chars).str.strip()
df2 = df2[df2["brand"].str.len() > 0].copy()
df2["product"] = df2["product"].apply(get_only_chars).str.strip()
df2 = df2[df2["product"].str.len() > 0].copy()
df2

In [None]:
# Randomly draw 70% of the cleaned rows as the base training set (fixed seed).
training_set1 = df2.sample(frac=0.7, random_state=1)
training_set1

In [None]:
# Validation set: a seeded 70% sample of the rows not used for training
# (about 21% of df2).
validation_set = df2.drop(training_set1.index).sample(frac=0.7, random_state=1)
validation_set

In [None]:
# Test set: whatever remains after removing the training and validation rows.
test_set = df2.drop(training_set1.index).drop(validation_set.index)
test_set

In [None]:
# Seed Python's `random` module before augmenting.
# NOTE(review): presumably `eda` draws from `random`; confirm in the eda module.
seed(1)
augmented_sentences = get_augmented_sentences(
    training_set1,
    "brand",
    "product",
    alpha_sr=0.1,
    alpha_ri=0.1,
    alpha_rs=0.1,
    p_rd=0.1,
    num_aug=4,
)
# Training variant 2: original rows followed by the sentence-augmented rows
# (indices are not reset by this concat).
training_set2 = pd.concat([training_set1, augmented_sentences])
training_set2

In [None]:
# Seed NumPy's global RNG before augmenting; get_augmented_labels calls
# pandas' sample() without a random_state.
np_seed(1)
augmented_labels = get_augmented_labels(training_set1, "brand", "product", num_aug=4)
# Training variant 3: original rows followed by the brand-substituted rows.
training_set3 = pd.concat([training_set1, augmented_labels])
training_set3

In [None]:
# Training variant 4: original rows plus both kinds of augmented rows.
training_set4 = pd.concat([training_set2, augmented_labels])
training_set4

## Exploring Data

In [None]:
# Distribution of the number of space-separated words per training-set brand.
training_set1["brand"].str.split(" ").str.len().value_counts().sort_index()

In [None]:
# How many training rows contain the brand as a whole-word phrase inside the
# product text. isin() takes (value, sentence), so the brand is the value and
# the product is the sentence being searched.
training_set1.apply(
    lambda x: isin(x["brand"], x["product"]), axis=1
).value_counts().sort_index()

In [None]:
# How many training rows have a product text that begins with the brand
# (startswith() takes the sentence first, then the value).
training_set1.apply(
    lambda x: startswith(x["product"], x["brand"]), axis=1
).value_counts().sort_index()

In [None]:
# How many training rows contain the brand somewhere in the product text but
# not at its start. isin() takes (value, sentence) while startswith() takes
# (sentence, value), hence the different argument order in the two calls.
training_set1.apply(
    lambda x: isin(x["brand"], x["product"])
    and not startswith(x["product"], x["brand"]),
    axis=1,
).value_counts().sort_index()

## Modelling Data

In [None]:
# NOTE(review): torch.load unpickles whole Python objects here, so only load
# files you trust. Newer PyTorch releases default to weights_only=True, in
# which case these full-object pickles need weights_only=False.
tokenizer = torch.load(r"..\models\t5-small-tokenizer.pth")
model = torch.load(r"..\models\t5-small-model.pth").to(device)

# The model reads the product text and must emit the brand, so "product" is
# the source column and "brand" the target, the same orientation used for
# validation_data. get_dataset needs the whole frame plus both column names.
training_data1 = get_dataset(tokenizer, training_set1, "product", "brand")
training_data2 = get_dataset(tokenizer, training_set2, "product", "brand")
training_data3 = get_dataset(tokenizer, training_set3, "product", "brand")
training_data4 = get_dataset(tokenizer, training_set4, "product", "brand")
validation_data = get_dataset(tokenizer, validation_set, "product", "brand")

# Training variant used for this run. The un-augmented baseline is selected
# to match the plain "t5-small" checkpoint prefix below; swap in
# training_data2/3/4 (and a matching `file` prefix) for the augmented sets.
training_data = training_data1

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
train_losses, test_losses, corrects = learn(
    training_data,
    validation_data,
    model,
    optimizer,
    batch_size=64,
    device=device,
    epochs=387,
    step_size=1,
    file=r"..\models\t5-small",
    verbose=verbose,
)

## Interpreting Data

In [None]:
# Reload the tokenizer and a saved checkpoint from absolute local paths.
# NOTE(review): torch.load unpickles whole objects; only load trusted files.
tokenizer = torch.load(r"D:\models\seq\t5-small-tokenizer.pth")
model = torch.load(r"D:\brands\t5-small-edabr-39.pth").to(device)
validation_data = get_dataset(tokenizer, validation_set, "product", "brand")
# test() expects a loader (it reads `.dataset` and iterates batches), so the
# name is rebound from the Dataset to a DataLoader over it.
validation_data = DataLoader(validation_data, batch_size=256)
test_loss, correct = test(validation_data, model, device=device, verbose=verbose)

In [None]:
# Model-free baseline on the test set: predict the brand as the first
# space-delimited word of the product text and report exact-match accuracy.
size = len(test_set)
outputs = test_set["product"].str.extract("^([^ ]*) ?.*$", expand=False)
correct = (outputs == test_set["brand"]).sum() / size
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: N/A")