### IMDB Dataset - Classification of Movie Reviews

Let's look at the dataset from the Interneet Movie Database (IMDB) and use it to train a model to classify whether a review is positive or negative.  This is an example of what is called 'sentiment analysis'.  It is used a lot by firms to find out whether their customers are leaving them good or bad reviews, for instance.

In [49]:
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [50]:
import torchtext
from torchtext.datasets import IMDB, SST2
from torchtext.data.utils import get_tokenizer
from collections import Counter, OrderedDict
from torchtext.vocab import Vocab
import torch
from torch.utils.data import DataLoader

NUMBER_OF_WORDS = 10000

# q_texts = []
# q_labels = []

# q_test_texts = []
# q_test_labels = []

# # Test
# q_texts = ["A D", "B", "A B D", "A B C", "C", "A C"]	
# q_labels = [1, 0, 1, 1, 0, 1]  # 1 = contains A, 0 = does not contain A

# q_test_texts = ["A", "B C", "A C D", "A B C D", "C D", "B C D"]
# q_test_labels = [1, 0, 1, 1, 0, 0]  # 1 = contains A, 0 = does not contain A

tokenizer = get_tokenizer('basic_english')
train_data, test_data = IMDB(split=('train', 'test'))

  

# count = 0
# for item in train_data:
#     if item[0] == 1:
#         count += 1
#         q_texts.append(item[1])
#         q_labels.append(item[0] - 1)
#     if count>1000:
#         break

# count = 0
# for item in train_data:
#     if item[0] == 2:
#         count += 1
#         q_texts.append(item[1])
#         q_labels.append(item[0] - 1)
#     if count>1000:
#         break

# count = 0
# for item in test_data:
#     if item[0] == 2:
#         count += 1
#         q_test_texts.append(item[1])
#         q_test_labels.append(item[0] - 1)
#     if count>200:
#         break

# count = 0
# for item in test_data:
#     if item[0] == 1:
#         count += 1
#         q_test_texts.append(item[1])
#         q_test_labels.append(item[0] - 1)
#     if count>200:
#         break

# counter = Counter()
# for text in q_texts:
#     counter.update(tokenizer(text))

counter = Counter()
for text in train_data:
    counter.update(tokenizer(text[1]))

dict = OrderedDict()
for i, (token, _) in enumerate(counter.most_common(NUMBER_OF_WORDS)):
    dict[token] = i
vocab = Vocab(dict)

def n_hot_encoding(text, vocab):
    indices = [vocab[token] for token in tokenizer(text) if token in vocab]
    one_hot = torch.zeros(len(vocab))
    one_hot[indices] = 1
    return one_hot


BATCH_SIZE = 64
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True)


In [51]:
import torch.nn as nn

class ImdbClassifier1(nn.Module):
    def __init__(self):
        super(ImdbClassifier1, self).__init__()
        self.fc1 = nn.Linear(NUMBER_OF_WORDS, 16)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(16, 16)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(16, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        #x = torch.sigmoid(x)
        return x

model = ImdbClassifier1().to(device)

loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


In [52]:
# def q_test_model():
#     model.eval()
#     with torch.no_grad():
#         correct = 0
#         total = 0
#         for text, label in zip(q_test_texts, q_test_labels):
#             label_tensor = torch.tensor(label).float().to(device)
#             label_tensor = label_tensor.unsqueeze(0).unsqueeze(0)
#             x = torch.stack([n_hot_encoding(text, vocab)]).to(device)
#             y_pred = model(x)
#             y_pred =  (y_pred > 0.5).float()    # convert to 1.0 if greater than 0.5, 0.0 otherwise
#             correct += (y_pred == label_tensor).sum().item()
#             total += 1
#         print(f"Accuracy: {(100.0 * correct)/total if total > 0 else 0.0}")

# def q_train_model():
#     model.train()
#     total_loss = 0
#     count = 0
#     for text, label in zip(q_texts, q_labels):
#         optimizer.zero_grad()
#         label_tensor = torch.tensor(label).float().to(device)
#         label_tensor = label_tensor.unsqueeze(0).unsqueeze(0)
#         x = torch.stack([n_hot_encoding(text, vocab)]).to(device)
#         y_pred = model(x)
#         loss = loss_fn(y_pred, label_tensor)
#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item()
#         count += 1
#     print(f"Loss: {total_loss/count if count > 0 else -1}")

In [53]:
#q_train_model()

# q_test_model()
# for epoch in range(100):
#     q_train_model()
#     q_test_model()


In [54]:

def train_model():
    model.train()
    count = 0
    total_loss = 0
    for i, batch in enumerate(train_loader):
        optimizer.zero_grad()
        labels = batch[0].float().to(device)
        labels = (labels - 1.0).unsqueeze(1)  # IMDB dataset has labels 1 and 2, convert them to 0 and 1
        texts = batch[1]
        x = torch.stack([n_hot_encoding(text, vocab) for text in texts]).to(device)
        y_pred = model(x)
        loss = loss_fn(y_pred, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        count += 1        
    print(i, f"Loss: {total_loss/count if count > 0 else -1}")
        

def test_model():
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for i, batch in enumerate(test_loader):
            labels = batch[0].float().to(device)
            labels = (labels - 1.0).unsqueeze(1)  # IMDB dataset has labels 1 and 2, convert them to 0 and 1

            # if labels.sum() < 3:
            #     continue
            # if labels.sum() > 61:
            #     continue

            texts = batch[1]
            x = torch.stack([n_hot_encoding(text, vocab) for text in texts]).to(device)
            y_pred = model(x)
            y_pred =  (y_pred > 0.5).float()    # convert to 1.0 if greater than 0.5, 0.0 otherwise
            correct += (y_pred == labels).sum().item()

            total += labels.size(0)
            
        print(f"Accuracy: {(100.0 * correct)/total if total > 0 else -1}")





test_model()
for epoch in range(100):
    print("Epoch: ", epoch)
    train_model()
    test_model()


Accuracy: 50.0
Epoch:  0
390 Loss: 0.3977440000914247
Accuracy: 88.776
Epoch:  1
390 Loss: 0.19219690342814855
Accuracy: 88.484
Epoch:  2
390 Loss: 0.14438770959794503
Accuracy: 87.868
Epoch:  3
390 Loss: 0.11361071500448448
Accuracy: 87.284
Epoch:  4
390 Loss: 0.0974408501330132
Accuracy: 87.132
Epoch:  5
390 Loss: 0.07966549601629162
Accuracy: 86.704
Epoch:  6
390 Loss: 0.06388930152283501
Accuracy: 86.452
Epoch:  7
390 Loss: 0.05257605593291271
Accuracy: 86.32
Epoch:  8
390 Loss: 0.04274060663646873
Accuracy: 85.936
Epoch:  9
390 Loss: 0.035444547758859586
Accuracy: 85.84
Epoch:  10
390 Loss: 0.0311519572032107
Accuracy: 85.732
Epoch:  11
390 Loss: 0.026095112936314887
Accuracy: 85.692
Epoch:  12
390 Loss: 0.0188684296885455
Accuracy: 85.412
Epoch:  13
390 Loss: 0.014275856491196436
Accuracy: 85.484
Epoch:  14
390 Loss: 0.009198761192147008


KeyboardInterrupt: 