We will use an RNN this time and see if that helps.

https://www.youtube.com/watch?v=AsNTP8Kwu80

In [1]:
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from transformers import BertTokenizerFast, BertModel, get_scheduler
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.optim import AdamW

# Further Preprocessing

In [3]:
# Read the cleaned corpus and drop the leftover 'Unnamed: 0' column from the CSV.
data = (
    pd.read_csv('./data/cleaned.csv')
    .drop(columns=['Unnamed: 0'])
)
data

Unnamed: 0,cleaned text,label
0,"the pope is infallible, this is a catholic dog...",intj
1,"being you makes you look cute on, because then...",intj
2,"i'm like entp but idiotichey boy, do you want ...",intj
3,give it to ... he has pica since childhood say...,intj
4,frances farmer will have her revenge on seattl...,intj
...,...,...
7232,"god,,pls take care hiro emergency room???? are...",intp
7233,wow last time i got intp i think u upset the f...,intp
7234,a 100% that someone will get his ass kicked so...,entp
7235,if you’re #intj this one is for you | what is ...,infj


In [4]:
# Replace the string labels with integer class ids. The fitted encoder is kept
# in `le` so the ids can be mapped back to the original label strings later.
le = LabelEncoder().fit(data['label'])
data['label'] = le.transform(data['label'])
data

Unnamed: 0,cleaned text,label
0,"the pope is infallible, this is a catholic dog...",10
1,"being you makes you look cute on, because then...",10
2,"i'm like entp but idiotichey boy, do you want ...",10
3,give it to ... he has pica since childhood say...,10
4,frances farmer will have her revenge on seattl...,10
...,...,...
7232,"god,,pls take care hiro emergency room???? are...",11
7233,wow last time i got intp i think u upset the f...,11
7234,a 100% that someone will get his ass kicked so...,3
7235,if you’re #intj this one is for you | what is ...,8


# Dataset Class

In [7]:
class MBTIDataset(Dataset):
    """Map-style dataset pairing tokenizer encodings with per-sample labels."""

    def __init__(self, encodings, labels):
        # encodings: mapping of field name -> per-sample sequences (indexed by position)
        # labels: per-sample label values (indexed by position)
        self.labels = labels
        self.encodings = encodings

    def __len__(self):
        """Return the number of samples, taken from the label sequence."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return a dict of tensors for the sample at position ``idx``.

        Every encoding field is converted to a tensor under its own key, and
        the sample's label is added under the ``'labels'`` key.
        """
        sample = {}
        for field, column in self.encodings.items():
            sample[field] = torch.tensor(column[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

In [6]:
# Hold out 20% of the samples for evaluation; the fixed random_state keeps the split repeatable.
train_text, test_text, train_labels, test_labels = train_test_split(
    data['cleaned text'].tolist(),
    data['label'].tolist(),
    test_size=0.2,
    random_state=45,
)

In [8]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Tokenise each split separately with identical settings (truncate + pad).
train_encodings, test_encodings = (
    tokenizer(text=texts, truncation=True, padding=True)
    for texts in (train_text, test_text)
)

# Wrap each split's encodings together with its labels.
train_data, test_data = (
    MBTIDataset(encodings, labels)
    for encodings, labels in (
        (train_encodings, train_labels),
        (test_encodings, test_labels),
    )
)

In [9]:
# Inspect the first encoded training sample as a sanity check.
train_data[0]

{'input_ids': tensor([  101,  7098,  1997,  2017,  2035,  1045,  3984, 12399,  2063,  2038,
          2025,  2042, 17060,  3689,  2005,  1037,  3232,  1997,  2420,  1010,
          2024,  2017,  4364,  3110,  7929,  1029,  7929, 22708,   999,  2054,
          2006,  3011,  2052,  2191,  2017,  2228,  2008,  1012,  1012,  1012,
          1052,  4246,  2102,  1012,  1012,  1012,  2065,  1045,  2123,  1521,
          1056,  5256,  2039,  2000,  2070,  3492, 15281,  1999,  2026,  2606,
          2059,  1045,  8415,  2000,  2643,  1012,  1012,  1012,  2065,  2017,
          2123,  1521,  1056,  3191,  2026, 12385,  2094,  2077,  2206,  2033,
          2008,  1521,  1055,  1037,  2017,  3291,  1012,  2298,  2012,  2017,
         11065,  2026,  4485,  2525,  1010,  2057,  1521,  2128,  5306,  2085,
         22747,  2860, 18411,  2015,  4429,  4658,  1010, 24978,  2546,  2860,
          2025,  4658,  3597,  4747,  1051,  5358,  2546,  2904,  2000,  1037,
         24978,  2546,  2860,  1006,  1

# Building the Neural Network

https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html

In [17]:
class MBTIPredictor(nn.Module):
    """BERT encoder followed by an RNN and a linear head producing 16 logits.

    The per-token RNN outputs are flattened and fed to the linear layer, whose
    input width is fixed at ``hidden_layer_size * sequence_length``. Batches
    whose padded length differs from ``sequence_length`` are zero-padded or
    truncated along the time axis so that the flattened width always matches.

    Args:
        hidden_layer_size: Size of the RNN hidden state.
        sequence_length: Number of time steps the linear head is sized for.
    """

    def __init__(self, hidden_layer_size, sequence_length):
        super(MBTIPredictor, self).__init__()
        # Remembered so forward() can fit any batch to the width the head expects.
        self.sequence_length = sequence_length
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # 768 is the width of the BERT hidden states consumed by the RNN.
        self.rnn = nn.RNN(768, hidden_layer_size, batch_first=True)
        self.relu = nn.ReLU()
        self.linear = nn.Linear(hidden_layer_size * sequence_length, 16)

    def _fit_to_sequence_length(self, rnn_out):
        """Pad with zeros or truncate ``rnn_out`` to ``sequence_length`` steps (dim 1)."""
        steps = rnn_out.size(1)
        if steps > self.sequence_length:
            return rnn_out[:, :self.sequence_length, :]
        if steps < self.sequence_length:
            missing = self.sequence_length - steps
            # Pad spec is (last dim left, last dim right, time dim left, time dim right).
            return nn.functional.pad(rnn_out, (0, 0, 0, missing))
        return rnn_out

    def forward(self, input_ids, attention_mask):
        """Return logits of shape (batch, 16) for a batch of token ids and masks."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        rnn_out, _ = self.rnn(outputs.last_hidden_state)
        # The tokenizer pads to the longest sample, so the time dimension is not
        # guaranteed to equal sequence_length; normalise it before flattening.
        rnn_out = self._fit_to_sequence_length(rnn_out)
        rnn_out = rnn_out.reshape(rnn_out.size(0), -1)  # Flatten the output
        relu_out = self.relu(rnn_out)
        logits = self.linear(relu_out)

        return logits

In [18]:
# Pick the compute backend: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cpu device


In [19]:
hidden_layer_size = 64
# The model's linear head is sized from sequence_length, so read it from the
# padded training encodings rather than hard-coding a number that may not
# match the width the tokenizer actually produced.
sequence_length = len(train_encodings['input_ids'][0])

model = MBTIPredictor(hidden_layer_size, sequence_length).to(device)
print(model)

MBTIPredictor(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise

# Model Training

In [20]:
# Hyperparameters
learning_rate = 5e-5
batch_size = 16
epochs = 3

# Objective and optimiser over all model parameters.
loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=learning_rate)

# Only the training split is shuffled; the test split keeps its order.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Cosine schedule without warm-up, spanning every batch of every epoch.
num_training_steps = epochs * len(train_loader)
lr_scheduler = get_scheduler(
    name="cosine",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [21]:
def train_loop(dataloader, model, loss_fn, optimizer, device):
    """Run one training pass over ``dataloader``.

    Each batch is moved to ``device``, pushed through ``model``, and used for
    one optimiser update. The loss is printed every 50 batches.

    Note: this function also steps the notebook-level ``lr_scheduler`` after
    every batch; that object is read from the global namespace, not passed in.

    Args:
        dataloader: Iterable of batches with 'input_ids', 'attention_mask'
            and 'labels' entries; must expose a ``dataset`` attribute.
        model: Module called as ``model(input_ids, attention_mask)``.
        loss_fn: Callable taking ``(predictions, labels)``.
        optimizer: Optimiser updating the model parameters.
        device: Target device for the batch tensors.
    """
    total_samples = len(dataloader.dataset)
    model.train()

    for step, features in enumerate(dataloader):
        # Move the batch onto the training device.
        token_ids = features['input_ids'].to(device)
        mask = features['attention_mask'].to(device)
        targets = features['labels'].to(device)

        # Forward pass and loss.
        logits = model(token_ids, mask)
        batch_loss = loss_fn(logits, targets)

        # Backward pass, parameter update, then clear gradients for the next batch.
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Learning-rate schedule advances once per batch.
        lr_scheduler.step()

        if step % 50 != 0:
            continue
        seen = step * len(token_ids)
        print(f"loss: {batch_loss.item():>7f}  [{seen:>5d}/{total_samples:>5d}]")

In [None]:
# Train the existing `model`: the optimizer and LR scheduler were built from its
# parameters, so creating a new model here would leave them updating the old one
# (and MBTIPredictor cannot be constructed without its two size arguments).
# `epochs` is the value set in the training configuration, which also sized the
# LR schedule, so it is reused here instead of being redefined.
for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_loop(train_loader, model, loss_fn, optimizer, device)

print("Training complete!")

In [None]:
def find_accuracy(model, dataloader, device=None):
    """Return the classification accuracy of ``model`` over ``dataloader`` in percent.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that
            returns per-class scores of shape (batch, classes).
        dataloader: Iterable of batches with 'input_ids', 'attention_mask'
            and 'labels' tensors.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used (CPU if it has no parameters),
            so the function does not depend on a notebook-level variable.

    Returns:
        Accuracy as a float in [0, 100]; 0.0 when the dataloader yields no samples.

    The model is switched to evaluation mode and left in that mode.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_data in dataloader:
            input_ids = batch_data['input_ids'].to(device)
            attention_mask = batch_data['attention_mask'].to(device)
            labels = batch_data['labels'].to(device)

            model_output = model(input_ids, attention_mask)

            # Predicted class is the index of the highest score per sample.
            predicted = model_output.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Avoid dividing by zero when there was nothing to evaluate.
    if total == 0:
        return 0.0
    return 100 * correct / total

In [None]:
# Score the model on the test loader and report the result as a percentage.
accuracy = find_accuracy(model, test_loader)
print(f'Model accuracy: {accuracy}%')