In [93]:
import nltk
import pandas as pd
from nltk.corpus import stopwords
from textblob import Word
from sklearn.preprocessing import LabelEncoder
from collections import Counter
import wordcloud
from sklearn.metrics import classification_report,confusion_matrix,accuracy_score
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from collections import Counter
import numpy as np

Samples = 10

In [94]:
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [95]:
datas = pd.read_csv("IMDB Dataset.csv")
reviews = datas["review"]
Y = np.array([{"positive": 1, "negative":0}[i] for i in datas["sentiment"]])

In [99]:
#Taking only 2 reviews so the computer doesn't die
X = [i for i in reviews][:Samples]
Y = Y[:Samples]
print(X)

["One of the other reviewers has mentioned that after watching just 1 Oz episode you'll be hooked. They are right, as this is exactly what happened with me.<br /><br />The first thing that struck me about Oz was its brutality and unflinching scenes of violence, which set in right from the word GO. Trust me, this is not a show for the faint hearted or timid. This show pulls no punches with regards to drugs, sex or violence. Its is hardcore, in the classic use of the word.<br /><br />It is called OZ as that is the nickname given to the Oswald Maximum Security State Penitentary. It focuses mainly on Emerald City, an experimental section of the prison where all the cells have glass fronts and face inwards, so privacy is not high on the agenda. Em City is home to many..Aryans, Muslims, gangstas, Latinos, Christians, Italians, Irish and more....so scuffles, death stares, dodgy dealings and shady agreements are never far away.<br /><br />I would say the main appeal of the show is due to the f

In [100]:
X = [nltk.word_tokenize(i) for i in X]
print(X)

[['One', 'of', 'the', 'other', 'reviewers', 'has', 'mentioned', 'that', 'after', 'watching', 'just', '1', 'Oz', 'episode', 'you', "'ll", 'be', 'hooked', '.', 'They', 'are', 'right', ',', 'as', 'this', 'is', 'exactly', 'what', 'happened', 'with', 'me.', '<', 'br', '/', '>', '<', 'br', '/', '>', 'The', 'first', 'thing', 'that', 'struck', 'me', 'about', 'Oz', 'was', 'its', 'brutality', 'and', 'unflinching', 'scenes', 'of', 'violence', ',', 'which', 'set', 'in', 'right', 'from', 'the', 'word', 'GO', '.', 'Trust', 'me', ',', 'this', 'is', 'not', 'a', 'show', 'for', 'the', 'faint', 'hearted', 'or', 'timid', '.', 'This', 'show', 'pulls', 'no', 'punches', 'with', 'regards', 'to', 'drugs', ',', 'sex', 'or', 'violence', '.', 'Its', 'is', 'hardcore', ',', 'in', 'the', 'classic', 'use', 'of', 'the', 'word.', '<', 'br', '/', '>', '<', 'br', '/', '>', 'It', 'is', 'called', 'OZ', 'as', 'that', 'is', 'the', 'nickname', 'given', 'to', 'the', 'Oswald', 'Maximum', 'Security', 'State', 'Penitentary', '.',

In [101]:
all_words = [i for j in X for i in j]
word_counts = Counter(all_words)



In [102]:
vocab = {word: i+2 for i, (word, _) in enumerate(word_counts.most_common())}
vocab['<PAD>'] = 0
vocab['<UNK>'] = 1

In [103]:
def encode_review(review, vocab, max_len=200):
    tokens = nltk.word_tokenize(review)
    encoded = [vocab.get(word, vocab['<UNK>']) for word in tokens]
    if len(encoded) < max_len:
        encoded += [vocab['<PAD>']] * (max_len - len(encoded))
    else:
        encoded = encoded[:max_len]
    return encoded

X = np.array([encode_review(r, vocab) for r in [i for i in reviews][:Samples]])


In [104]:
class ReviewDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.long)
        self.y = torch.tensor(y, dtype=torch.float32)
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2)

train_data = ReviewDataset(X_train, y_train)
test_data = ReviewDataset(X_test, y_test)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64)


In [105]:
class SentimentLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim, n_layers=2, dropout=0.5):
        super(SentimentLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=n_layers, batch_first=True, dropout=dropout)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        embedded = self.embedding(x)
        lstm_out, (h_n, c_n) = self.lstm(embedded)
        out = self.fc(self.dropout(h_n[-1]))  # last hidden state
        return torch.sigmoid(out)

vocab_size = len(vocab)
model = SentimentLSTM(vocab_size=vocab_size, embed_dim=128, hidden_dim=256, output_dim=1)


In [106]:
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 5
for epoch in range(epochs):
    model.train()
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y.view(-1,1))
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch_X, batch_y in test_loader:
            preds = model(batch_X).squeeze()
            predicted = (preds >= 0.5).float()
            correct += (predicted == batch_y).sum().item()
            total += batch_y.size(0)
    acc = correct / total
    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}, Val Acc: {acc:.4f}")


Epoch 1/5, Loss: 0.6902, Val Acc: 1.0000
Epoch 2/5, Loss: 0.6798, Val Acc: 1.0000
Epoch 3/5, Loss: 0.6521, Val Acc: 1.0000
Epoch 4/5, Loss: 0.6373, Val Acc: 1.0000
Epoch 5/5, Loss: 0.6107, Val Acc: 1.0000


# I KNOW the number of samples is low, the computer was taking way to much time to train on higher numbers