In [None]:
!pip install transformers

In [2]:
import torch
import tqdm
import torch.nn as nn
from torchtext import data, datasets
from transformers import BertTokenizer, BertModel

In [3]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [4]:
tokens = tokenizer.tokenize("What IS your name?")
tokens

['what', 'is', 'your', 'name', '?']

In [5]:
indices = tokenizer.convert_tokens_to_ids(tokens)
indices

[2054, 2003, 2115, 2171, 1029]

In [6]:
init_token = tokenizer.cls_token
eos_token = tokenizer.sep_token
pad_token = tokenizer.pad_token
unk_token = tokenizer.unk_token

init_token, eos_token, pad_token, unk_token

('[CLS]', '[SEP]', '[PAD]', '[UNK]')

In [7]:
init_token_idx = tokenizer.cls_token_id
eos_token_idx = tokenizer.sep_token_id
pad_token_idx = tokenizer.pad_token_id
unk_token_idx = tokenizer.unk_token_id

init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx

(101, 102, 0, 100)

In [8]:
max_input_len = tokenizer.max_model_input_sizes["bert-base-uncased"]
max_input_len

512

In [9]:
def tokenize_and_cut(text):
    tokens = tokenizer.tokenize(text)[:max_input_len - 2]
    return tokens

In [10]:
TEXT = data.Field(batch_first=True, use_vocab=False, tokenize=tokenize_and_cut, preprocessing=tokenizer.convert_tokens_to_ids,
                  init_token=init_token_idx, eos_token=eos_token_idx, pad_token=pad_token_idx, unk_token=unk_token_idx)

LABEL = data.LabelField(dtype=torch.float)

In [11]:
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

In [12]:
for d in train_data:
    print(vars(d)["text"])
    print(vars(d)["label"])
    break

[19401, 4038, 4626, 14042, 9082, 1006, 2007, 1996, 21358, 7011, 3468, 2520, 15030, 5267, 1007, 1012, 2402, 14911, 8480, 1999, 5365, 6224, 2732, 9527, 1012, 12081, 4616, 13398, 1000, 2035, 1996, 3340, 1999, 15418, 1005, 1055, 6014, 1000, 1999, 1996, 3297, 4012, 15630, 10286, 2100, 3496, 1010, 4606, 7167, 1997, 13528, 2143, 2437, 6987, 2005, 1996, 6288, 1012, 27263, 2036, 19566, 2005, 13082, 3012, 9082, 1005, 3297, 1010, 10433, 2135, 22473, 17727, 18617, 10708, 1997, 1996, 2327, 3340, 1997, 1996, 2154, 1006, 2014, 26699, 2239, 2003, 1037, 17935, 2102, 999, 1007, 1012, 1026, 7987, 1013, 1028, 1026, 7987, 1013, 1028, 1000, 14911, 1010, 1000, 2130, 11269, 2841, 2004, 2016, 11340, 1996, 3297, 2732, 14042, 9082, 2012, 5093, 1010, 4332, 2039, 2014, 4451, 1998, 7928, 1010, 1000, 2821, 2232, 1010, 1045, 2123, 1005, 1056, 2066, 2014, 999, 1000, 1026, 7987, 1013, 1028, 1026, 7987, 1013, 1028, 2026, 6140, 2001, 3819, 1012, 2466, 1010, 3257, 1010, 3772, 2019, 14469, 11084, 1998, 1037, 2442, 2005, 20

In [13]:
LABEL.build_vocab(train_data)

In [14]:
LABEL.vocab.stoi

defaultdict(<function torchtext.vocab._default_unk_index>,
            {'neg': 0, 'pos': 1})

In [15]:
bert = BertModel.from_pretrained("bert-base-uncased")

In [16]:
class Net(nn.Module):
    def __init__(self, bert, hidden_size, output_size):
        super().__init__()
        self.bert = bert
        embedding_size = bert.config.to_dict()["hidden_size"]

        self.gru = nn.GRU(embedding_size, hidden_size, num_layers=2, bidirectional=True, batch_first=True, dropout=0.25)
        self.fc = nn.Linear(2 * hidden_size, output_size)

        self.dropout = nn.Dropout(0.25)

    def forward(self, text):
        with torch.no_grad():
            embedded = self.bert(text)[0]

        _, hidden = self.gru(embedded)
        hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
        output = self.fc(hidden)

        return output

In [17]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 10
batch_size = 128
hidden_size = 256
output_size = 1

In [18]:
device

device(type='cuda')

In [19]:
train_batches, test_batches = data.BucketIterator.splits((train_data, test_data), batch_size=batch_size, 
                                                         device=device)

In [20]:
net = Net(bert, hidden_size, output_size).to(device)
net

Net(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        

In [21]:
for name, param in net.named_parameters():
    if "bert" in name:
        param.requires_grad = False

In [22]:
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

In [23]:
def get_accuracy(preds, y):
    preds = torch.round(torch.sigmoid(preds))
    correct = (preds == y).float()
    acc = correct.sum() / len(correct)

    return acc

In [24]:
def loop(net, batches, train):
    batch_losses = []
    batch_accs = []

    if train:
        print("Train Loop:")
        net.train()
        for batch in tqdm.tqdm(batches, total=len(batches)):
            texts = batch.text.to(device)
            labels = batch.label.to(device)

            preds = net(texts)
            preds = preds.squeeze(1)

            loss = loss_fn(preds, labels)
            acc = get_accuracy(preds, labels)

            opt.zero_grad()
            loss.backward()
            opt.step()

            batch_losses.append(loss.item())
            batch_accs.append(acc)

    else:
        print("Inference Loop:")
        net.eval()
        with torch.no_grad():
            for batch in tqdm.tqdm(batches, total=len(batches)):
                texts = batch.text.to(device)
                labels = batch.label.to(device)

                preds = net(texts)
                preds = preds.squeeze(1)

                loss = loss_fn(preds, labels)
                acc = get_accuracy(preds, labels)

                batch_losses.append(loss.item())
                batch_accs.append(acc) 

    print("")
    print("")
    
    return sum(batch_losses) / len(batch_losses), sum(batch_accs) / len(batch_accs)

In [25]:
def predict_sentiment(net, text):
    net.eval()
    tokens = tokenizer.tokenize(text)[:max_input_len - 2]
    indices = [init_token_idx] + tokenizer.convert_tokens_to_ids(tokens) + [eos_token_idx]
    indices = torch.LongTensor(indices).unsqueeze(0).to(device)
    
    preds = net(indices)
    preds = torch.sigmoid(preds)
    
    print(f"sentiment: {preds.item()}")

In [26]:
text = "this is a very good idea"

In [27]:
for epoch in range(epochs):
    train_loss, train_acc = loop(net, train_batches, True)
    val_loss, val_acc = loop(net, test_batches, False)
    
    print(f"epoch: {epoch} | train_loss: {train_loss:.4f} | train_acc: {train_acc:.4f} | val_loss: {val_loss:.4f} | val_acc: {val_acc:.4f}")
    predict_sentiment(net, text)
    print("")

  0%|          | 0/196 [00:00<?, ?it/s]

Train Loop:


100%|██████████| 196/196 [20:44<00:00,  6.35s/it]
  0%|          | 0/196 [00:00<?, ?it/s]



Inference Loop:


100%|██████████| 196/196 [09:29<00:00,  2.90s/it]
  0%|          | 0/196 [00:00<?, ?it/s]



epoch: 0 | train_loss: 0.4066 | train_acc: 0.8074 | val_loss: 0.2399 | val_acc: 0.9026
sentiment: 0.8344929814338684

Train Loop:


100%|██████████| 196/196 [21:01<00:00,  6.43s/it]
  0%|          | 0/196 [00:00<?, ?it/s]



Inference Loop:


100%|██████████| 196/196 [09:28<00:00,  2.90s/it]
  0%|          | 0/196 [00:00<?, ?it/s]



epoch: 1 | train_loss: 0.2522 | train_acc: 0.8975 | val_loss: 0.2350 | val_acc: 0.9033
sentiment: 0.8451172709465027

Train Loop:


 30%|███       | 59/196 [06:21<14:44,  6.46s/it]

KeyboardInterrupt: ignored

In [28]:
def save_checkpoint(net, opt, filename):
    check_point = {"net_dict": net.state_dict(), "opt_dict": opt.state_dict()}
    torch.save(check_point, filename)
    print("Checkpoint Saved!")

def load_checkpoint(net, opt, filename):
    check_point = torch.load(filename)
    net.load_state_dict(check_point["net_dict"])
    opt.load_state_dict(check_point["opt_dict"])
    losses = check_point["losses"]
    print("Checkpoint Loaded!")

In [29]:
save_checkpoint(net, opt, "checkpoint.pth.tar")

Checkpoint Saved!


In [30]:
predict_sentiment(net, "this is a very bad idea")

sentiment: 0.05635718256235123


In [31]:
predict_sentiment(net, "this film is terrible")

sentiment: 0.04847601428627968


In [32]:
predict_sentiment(net, "you are terrific")

sentiment: 0.9460165500640869


In [33]:
predict_sentiment(net, "that is horrible")

sentiment: 0.12634426355361938


In [34]:
predict_sentiment(net, "yeet!!")

sentiment: 0.378313273191452


In [35]:
predict_sentiment(net, "what are you doing?")

sentiment: 0.4481266438961029
