In [66]:
import pandas as pd
import torch
import transformers
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertTokenizer

In [67]:
from torch import cuda

# Select the compute target once for the whole notebook:
# the GPU when PyTorch can see one, otherwise the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [68]:
class Triage(Dataset):
    """Torch ``Dataset`` that tokenizes transaction memos for DistilBERT.

    Each item is a dict with:
      * ``ids``     -- token ids, padded/truncated to ``max_len`` (long tensor)
      * ``mask``    -- matching attention mask (long tensor)
      * ``targets`` -- the integer label from ``Tags_Encoded`` (long scalar tensor)

    Args:
        dataframe: Frame with a ``Memo`` text column and an integer
            ``Tags_Encoded`` column. Rows are read by position, so any index
            (e.g. after filtering or shuffling) works.
        tokenizer: A Hugging Face tokenizer exposing ``encode_plus``.
        max_len: Length every sequence is padded/truncated to.
    """

    def __init__(self, dataframe, tokenizer, max_len):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, index):
        # Positional access: label-based ``Memo[index]`` raises KeyError (or
        # returns the wrong row) when the index is not a 0..n-1 RangeIndex.
        memo = str(self.data['Memo'].iloc[index])
        # Collapse runs of whitespace/newlines into single spaces.
        memo = " ".join(memo.split())
        inputs = self.tokenizer.encode_plus(
            memo,
            add_special_tokens=True,
            max_length=self.max_len,
            # Replacement for the deprecated ``pad_to_max_length=True``.
            padding='max_length',
            truncation=True,
        )
        target = self.data['Tags_Encoded'].iloc[index]

        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'targets': torch.tensor(target, dtype=torch.long),
        }

    def __len__(self):
        return self.len

In [69]:
class DistillBERTClass(torch.nn.Module):
    """DistilBERT encoder with an MLP classification head.

    The head takes the encoder's hidden state at sequence position 0 and
    applies Linear -> ReLU -> Dropout -> Linear, producing one logit per class.

    Args:
        num_classes: Number of output logits (default 7).
        pretrained_name: Checkpoint passed to ``DistilBertModel.from_pretrained``.
        dropout: Dropout probability applied before the final linear layer.

    The submodule names ``l1``, ``pre_classifier`` and ``classifier`` form the
    ``state_dict`` keys; renaming them would break loading saved checkpoints.
    """

    def __init__(self, num_classes=7, pretrained_name="distilbert-base-uncased", dropout=0.3):
        super().__init__()
        self.l1 = DistilBertModel.from_pretrained(pretrained_name)
        # Encoder hidden size (768 for the distilbert-base checkpoints).
        hidden_size = self.l1.config.dim
        self.pre_classifier = torch.nn.Linear(hidden_size, hidden_size)
        # Registered once instead of constructing a new ReLU on every forward.
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(dropout)
        self.classifier = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return logits whose last dimension is ``num_classes``."""
        hidden_state = self.l1(input_ids=input_ids, attention_mask=attention_mask)[0]
        # Hidden state at position 0 of each sequence.
        pooler = hidden_state[:, 0]
        pooler = self.relu(self.pre_classifier(pooler))
        pooler = self.dropout(pooler)
        return self.classifier(pooler)

In [None]:
model = DistillBERTClass()

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
# checkpoint files (newer torch versions accept weights_only=True to restrict this).
state_dict = torch.load('best_synthetic.pkl', map_location=device)

# Presumably saved from a DataParallel-wrapped model, whose keys carry a
# leading 'module.' -- strip only that leading prefix, not every occurrence.
prefix = 'module.'
adjusted_state_dict = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in state_dict.items()
}

model.load_state_dict(adjusted_state_dict)
# Assign rather than leave a bare expression, so the cell doesn't print the whole module repr.
model = model.to(device)


In [71]:
def predict(tokenizer_name='distilbert-base-uncased', max_len=512):
    """Interactively classify transaction memos with the loaded model.

    Repeatedly asks for a memo and the tag the user expects. It then runs the
    global ``model`` on the global ``device`` and prints the decoded input,
    the predicted tag and the expected tag. Enter a blank memo or ``q`` to stop.

    Args:
        tokenizer_name: Hugging Face tokenizer checkpoint. Defaults to the
            uncased tokenizer so token ids match the ``distilbert-base-uncased``
            backbone. NOTE(review): if ``best_synthetic.pkl`` was fine-tuned
            with a different tokenizer, pass that name here instead.
        max_len: Length each memo is padded/truncated to.
    """
    tags = ['Funding', 'Operations', 'Misc', 'Food', 'Equipment', 'Programming', 'Travel']
    # Case-insensitive lookup ("food" == "Food"); unrecognised tags encode to -1.
    tag_to_index = {name.lower(): i for i, name in enumerate(tags)}

    # Load once rather than on every iteration of the prompt loop.
    tokenizer = DistilBertTokenizer.from_pretrained(tokenizer_name)
    model.eval()

    while True:
        print("Please enter an example memo for a transaction (blank or 'q' to quit):")
        memo = input("Memo: ")
        if memo.strip().lower() in ('', 'q'):
            break

        print("Please enter the tag you expect for this memo: \n")
        print(f"Tags include:  {tags}")
        tag = input("Tag: ").strip()

        example = pd.DataFrame({
            'Memo': [memo],
            'Tags_Encoded': [tag_to_index.get(tag.lower(), -1)],
        })
        sample = Triage(example, tokenizer, max_len=max_len)[0]

        # Add the batch dimension a DataLoader would otherwise provide.
        ids = sample['ids'].unsqueeze(0).to(device)
        mask = sample['mask'].unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(ids, mask)
        predicted_class = tags[logits.argmax(dim=1)[0].item()]

        # -1 marks an unrecognised tag; it must not index ``tags`` (that would
        # silently yield the last tag), so report it as unknown instead.
        target = sample['targets'].item()
        true_class = tags[target] if target >= 0 else f"Unknown ({tag!r})"

        decoded_input = tokenizer.decode(ids[0], skip_special_tokens=True)
        print(f"\nInput Text: {decoded_input}")
        print(f"Predicted Class: {predicted_class}, True Class: {true_class}")



In [None]:
# Start the interactive prompt; it reads memos and expected tags from stdin via input().
predict()