In [1]:
import torch
import spacy
import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torchtext import data, datasets

In [2]:
# Show which GPU the kernel sees. Guarded so that a CPU-only kernel displays
# "cpu" instead of raising, matching the CPU fallback used when `device` is
# chosen further down.
torch.cuda.get_device_name() if torch.cuda.is_available() else "cpu"

'Tesla T4'

In [3]:
# Load spaCy's English pipeline; predict_sentiment below reuses its tokenizer.
# NOTE(review): the "en" shortcut name is presumably only resolvable on older
# spaCy installs — confirm against the environment's spaCy version.
spacy_en = spacy.load("en")

In [4]:
def get_bigrams(x):
    """Append every distinct bigram of a token list to that same list.

    Adjacent token pairs are joined with a single space and added after the
    original tokens. Each distinct pair is added once, in order of first
    appearance, so the output is the same on every run.

    Args:
        x: list of string tokens. It is extended in place.

    Returns:
        The same list object ``x``, now holding the original tokens followed
        by the unique bigram strings.
    """
    # dict.fromkeys de-duplicates like a set but keeps insertion order; a set
    # of string tuples iterates in an order that varies with the hash seed.
    # The dict is fully built before x is extended, so the pairs only cover
    # the original tokens.
    unique_pairs = dict.fromkeys(zip(x, x[1:]))
    x.extend(' '.join(pair) for pair in unique_pairs)

    return x

In [5]:
# Text field: tokenised with spaCy, then get_bigrams appends bigram tokens to
# each example's token list.
TEXT = data.Field(tokenize="spacy", preprocessing=get_bigrams)
# Label field stored as float; BCEWithLogitsLoss is applied to these labels later.
LABEL = data.LabelField(dtype=torch.float)

In [6]:
# Load the IMDB train and test splits, processed through the TEXT and LABEL
# fields (the log below shows the archive being downloaded on this run).
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

aclImdb_v1.tar.gz:   0%|          | 0.00/84.1M [00:00<?, ?B/s]

downloading aclImdb_v1.tar.gz


aclImdb_v1.tar.gz: 100%|██████████| 84.1M/84.1M [00:03<00:00, 24.4MB/s]


In [7]:
# Build both vocabularies from the training split only. max_size caps the
# text vocabulary at 25,000 tokens.
TEXT.build_vocab(train_data, max_size=25000)
LABEL.build_vocab(train_data)

In [8]:
len(TEXT.vocab.stoi), LABEL.vocab.stoi

(25002,
 defaultdict(<function torchtext.vocab._default_unk_index>,
             {'neg': 0, 'pos': 1}))

In [9]:
# Sanity check: print the token list and label of the first training example
# only (the break stops after one iteration).
for d in train_data:
    print(vars(d)["text"])
    print(vars(d)["label"])
    break

['Not', 'sure', 'one', 'can', 'call', 'this', 'an', 'anti', '-', 'war', 'film', ',', 'it', 'shows', 'war', 'at', 'an', 'elite', 'level', '.', 'These', 'are', 'elite', 'troops', 'that', 'know', 'what', 'they', 'are', 'doing', 'and', 'take', 'great', 'pride', 'in', 'it', '.', 'Even', 'when', 'they', 'are', 'pacifist', ',', 'they', 'still', 'enjoy', 'the', 'skill', 'level', 'and', 'defeating', 'their', 'foes', ',', 'even', 'if', 'it', 'does', 'go', 'against', 'being', 'a', 'pacifist', '.', 'The', 'movies', 'is', 'slow', 'and', 'rather', 'uneventful', 'and', 'in', 'many', 'ways', 'is', 'rather', 'tame', 'as', 'war', 'movies', 'go', '-', 'more', 'so', 'by', 'todays', 'standards', ',', 'no', 'body', 'parts', 'flying', 'off', 'as', 'in', 'modern', 'movies', '.', 'It', 'is', 'brutal', 'in', 'other', 'ways', 'though', 'as', 'you', 'see', 'killing', 'at', 'a', 'personal', 'level', '.', 'This', 'is', 'more', 'of', 'a', 'thinking', 'man', "'s", 'movie', '.', 'Once', 'you', 'start', 'to', 'watch', 

In [10]:
class FastTextNet(nn.Module):
    """Bag-of-embeddings classifier: embed, average over one axis, project.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_size: width of each embedding vector.
        output_size: number of values produced per example.
        pad_idx: index passed to ``nn.Embedding`` as ``padding_idx``.
    """

    def __init__(self, vocab_size, embedding_size, output_size, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_size,
            padding_idx=pad_idx,
        )
        self.fc = nn.Linear(in_features=embedding_size, out_features=output_size)

    def forward(self, text):
        """Return the linear layer's output for the averaged embeddings.

        ``text`` is a 2-D tensor of token indices. The first two axes of the
        embedded tensor are swapped before pooling, so the average runs over
        the input's first axis (presumably the sequence axis, i.e. input laid
        out as (sequence, batch) — confirm against the iterator settings).
        Every position counts towards the average, padding included.
        """
        swapped = self.embedding(text).permute(1, 0, 2)
        # A kernel as tall as axis 1 and one column wide collapses that axis
        # to length 1 by averaging; squeeze then drops it.
        window = (swapped.shape[1], 1)
        averaged = F.avg_pool2d(swapped, kernel_size=window).squeeze(1)

        return self.fc(averaged)

In [11]:
# Prefer the GPU when one is visible, otherwise run on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 10  # passes over the training iterator
batch_size = 64
vocab_size = len(TEXT.vocab)  # number of rows in the embedding table
embedding_size = 100
output_size = 1  # one logit per example
# Index of the padding token, handed to nn.Embedding as padding_idx.
pad_idx = TEXT.vocab.stoi["<pad>"]

In [12]:
device

device(type='cuda')

In [13]:
# Wrap both splits in batch iterators of `batch_size` examples, placed on `device`.
train_batches, test_batches = data.BucketIterator.splits((train_data, test_data), batch_size=batch_size, device = device)

In [14]:
# Instantiate the model on the selected device; the bare name on the last line
# displays its layer summary.
net = FastTextNet(vocab_size, embedding_size, output_size, pad_idx).to(device)
net

FastTextNet(
  (embedding): Embedding(25002, 100, padding_idx=1)
  (fc): Linear(in_features=100, out_features=1, bias=True)
)

In [15]:
# Adam over all model parameters. BCEWithLogitsLoss consumes raw logits, which
# is why the model's forward pass applies no sigmoid of its own.
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

In [16]:
def get_accuracy(preds, y):
    """Fraction of logits whose rounded sigmoid equals the target.

    Args:
        preds: tensor of raw logits.
        y: tensor of targets compared element-wise against the rounded
            probabilities.

    Returns:
        A 0-dim tensor: number of matches divided by the length of the first
        axis.
    """
    hard_labels = torch.sigmoid(preds).round()
    hits = hard_labels.eq(y).float()

    return hits.sum() / len(hits)

In [17]:
def loop(net, batches, train):
    """Run one pass over ``batches`` and return the mean loss and accuracy.

    Relies on the notebook-level ``device``, ``loss_fn``, ``opt`` and
    ``get_accuracy``.

    Args:
        net: model called on ``batch.text``; its output is squeezed on axis 1.
        batches: sized iterable of batches exposing ``.text`` and ``.label``.
        train: when true, put the model in train mode and take an optimiser
            step per batch; otherwise put it in eval mode and iterate under
            ``torch.no_grad()``.

    Returns:
        ``(mean_loss, mean_acc)`` averaged over batches. The loss is a Python
        float; the accuracy is whatever ``get_accuracy`` returns, averaged.
    """

    def _evaluate(batch):
        # Shared forward pass: logits, then loss and accuracy for one batch.
        inputs = batch.text.to(device)
        targets = batch.label.to(device)
        logits = net(inputs).squeeze(1)
        return loss_fn(logits, targets), get_accuracy(logits, targets)

    losses, accuracies = [], []

    if not train:
        print("Inference Loop:")
        net.eval()
        with torch.no_grad():
            for batch in tqdm.tqdm(batches, total=len(batches)):
                loss, acc = _evaluate(batch)
                losses.append(loss.item())
                accuracies.append(acc)
    else:
        print("Train Loop:")
        net.train()
        for batch in tqdm.tqdm(batches, total=len(batches)):
            loss, acc = _evaluate(batch)

            # Standard update: clear old gradients, backpropagate, step.
            opt.zero_grad()
            loss.backward()
            opt.step()

            losses.append(loss.item())
            accuracies.append(acc)

    # Two blank lines separate the progress bar from the epoch summary.
    print("")
    print("")

    mean_loss = sum(losses) / len(losses)
    mean_acc = sum(accuracies) / len(accuracies)
    return mean_loss, mean_acc

In [18]:
def predict_sentiment(net, text):
    """Print the model's sigmoid score for a raw text string.

    The text is tokenised with the notebook-level ``spacy_en`` tokenizer,
    extended with bigrams via ``get_bigrams``, mapped to indices through
    ``TEXT.vocab.stoi`` and fed to the model as a single-column tensor.

    Side effects: switches ``net`` to eval mode (it is not switched back) and
    prints ``sentiment: <score>``. Returns nothing.

    Args:
        net: model returning one logit for the single example.
        text: raw input string.
    """
    net.eval()
    tokens = get_bigrams([t.text for t in spacy_en.tokenizer(text)])
    indices = [TEXT.vocab.stoi[t] for t in tokens]
    # unsqueeze(1) adds a second axis of size 1, giving shape (n_tokens, 1).
    indices = torch.LongTensor(indices).unsqueeze(1).to(device)

    # Pure inference: skip autograd bookkeeping, as the eval branch of loop() does.
    with torch.no_grad():
        preds = torch.sigmoid(net(indices))

    print(f"sentiment: {preds.item()}")

In [19]:
text = "this is a very good idea"

In [20]:
# Train for `epochs` passes; after each one, evaluate and print a summary line
# plus the model's score on the fixed probe sentence `text`.
# NOTE(review): the "val" metrics below are computed on test_batches, i.e. the
# test split doubles as the validation set — no separate hold-out is used.
for epoch in range(epochs):
    train_loss, train_acc = loop(net, train_batches, True)
    val_loss, val_acc = loop(net, test_batches, False)
    
    print(f"epoch: {epoch} | train_loss: {train_loss:.4f} | train_acc: {train_acc:.4f} | val_loss: {val_loss:.4f} | val_acc: {val_acc:.4f}")
    predict_sentiment(net, text)
    print("")

  0%|          | 0/391 [00:00<?, ?it/s]

Train Loop:


100%|██████████| 391/391 [00:12<00:00, 32.50it/s]
  6%|▌         | 23/391 [00:00<00:01, 224.71it/s]



Inference Loop:


100%|██████████| 391/391 [00:04<00:00, 80.74it/s]
  1%|          | 3/391 [00:00<00:13, 28.13it/s]



epoch: 0 | train_loss: 0.6871 | train_acc: 0.6010 | val_loss: 0.6290 | val_acc: 0.6974
sentiment: 0.6498771905899048

Train Loop:


100%|██████████| 391/391 [00:11<00:00, 34.24it/s]
  6%|▌         | 22/391 [00:00<00:01, 215.10it/s]



Inference Loop:


100%|██████████| 391/391 [00:04<00:00, 84.90it/s]
  1%|          | 3/391 [00:00<00:15, 24.39it/s]



epoch: 1 | train_loss: 0.6391 | train_acc: 0.7556 | val_loss: 0.4822 | val_acc: 0.7726
sentiment: 0.9837432503700256

Train Loop:


100%|██████████| 391/391 [00:11<00:00, 34.02it/s]
  6%|▌         | 24/391 [00:00<00:01, 233.10it/s]



Inference Loop:


100%|██████████| 391/391 [00:04<00:00, 84.59it/s]
  1%|          | 3/391 [00:00<00:14, 26.08it/s]



epoch: 2 | train_loss: 0.5501 | train_acc: 0.8133 | val_loss: 0.4072 | val_acc: 0.8179
sentiment: 0.9999942779541016

Train Loop:


100%|██████████| 391/391 [00:11<00:00, 33.26it/s]
  6%|▌         | 23/391 [00:00<00:01, 229.86it/s]



Inference Loop:


100%|██████████| 391/391 [00:04<00:00, 84.30it/s]
  1%|          | 3/391 [00:00<00:14, 25.89it/s]



epoch: 3 | train_loss: 0.4657 | train_acc: 0.8538 | val_loss: 0.3866 | val_acc: 0.8478
sentiment: 1.0

Train Loop:


100%|██████████| 391/391 [00:11<00:00, 33.75it/s]
  6%|▌         | 23/391 [00:00<00:01, 222.72it/s]



Inference Loop:


100%|██████████| 391/391 [00:04<00:00, 84.65it/s]
  1%|          | 3/391 [00:00<00:15, 24.47it/s]



epoch: 4 | train_loss: 0.3999 | train_acc: 0.8758 | val_loss: 0.3885 | val_acc: 0.8644
sentiment: 1.0

Train Loop:


100%|██████████| 391/391 [00:11<00:00, 33.88it/s]
  6%|▌         | 23/391 [00:00<00:01, 226.39it/s]



Inference Loop:


100%|██████████| 391/391 [00:04<00:00, 84.52it/s]
  1%|          | 3/391 [00:00<00:14, 27.01it/s]



epoch: 5 | train_loss: 0.3512 | train_acc: 0.8902 | val_loss: 0.4109 | val_acc: 0.8709
sentiment: 1.0

Train Loop:


100%|██████████| 391/391 [00:11<00:00, 33.78it/s]
  6%|▌         | 23/391 [00:00<00:01, 228.65it/s]



Inference Loop:


100%|██████████| 391/391 [00:04<00:00, 85.51it/s]
  1%|          | 3/391 [00:00<00:15, 25.49it/s]



epoch: 6 | train_loss: 0.3181 | train_acc: 0.9004 | val_loss: 0.4174 | val_acc: 0.8803
sentiment: 1.0

Train Loop:


100%|██████████| 391/391 [00:11<00:00, 34.12it/s]
  6%|▌         | 24/391 [00:00<00:01, 235.52it/s]



Inference Loop:


100%|██████████| 391/391 [00:04<00:00, 84.30it/s]
  1%|          | 3/391 [00:00<00:14, 26.43it/s]



epoch: 7 | train_loss: 0.2890 | train_acc: 0.9077 | val_loss: 0.4345 | val_acc: 0.8849
sentiment: 1.0

Train Loop:


100%|██████████| 391/391 [00:11<00:00, 33.24it/s]
  6%|▌         | 22/391 [00:00<00:01, 218.29it/s]



Inference Loop:


100%|██████████| 391/391 [00:04<00:00, 84.25it/s]
  1%|          | 3/391 [00:00<00:14, 26.89it/s]



epoch: 8 | train_loss: 0.2679 | train_acc: 0.9139 | val_loss: 0.4525 | val_acc: 0.8879
sentiment: 1.0

Train Loop:


100%|██████████| 391/391 [00:11<00:00, 33.70it/s]
  6%|▌         | 23/391 [00:00<00:01, 221.99it/s]



Inference Loop:


100%|██████████| 391/391 [00:04<00:00, 85.10it/s]



epoch: 9 | train_loss: 0.2492 | train_acc: 0.9188 | val_loss: 0.4710 | val_acc: 0.8914
sentiment: 1.0






In [21]:
def save_checkpoint(net, opt, filename):
    """Write the model and optimiser state dicts to ``filename``.

    The saved object is a dict with exactly two keys: ``"net_dict"`` and
    ``"opt_dict"``. Prints a confirmation once ``torch.save`` returns.

    Args:
        net: module whose ``state_dict()`` is stored.
        opt: optimiser whose ``state_dict()`` is stored.
        filename: destination handed straight to ``torch.save``.
    """
    torch.save(
        {
            "net_dict": net.state_dict(),
            "opt_dict": opt.state_dict(),
        },
        filename,
    )
    print("Checkpoint Saved!")

def load_checkpoint(net, opt, filename, map_location=None):
    """Restore model and optimiser state from a checkpoint, in place.

    Expects the layout written by ``save_checkpoint``: a dict holding
    ``"net_dict"`` and ``"opt_dict"``. No other keys are required; the
    previous read of a ``"losses"`` entry was dropped because
    ``save_checkpoint`` never writes one, so it raised ``KeyError``.

    Args:
        net: module that receives ``"net_dict"`` via ``load_state_dict``.
        opt: optimiser that receives ``"opt_dict"`` via ``load_state_dict``.
        filename: source handed to ``torch.load``.
        map_location: optional device remapping forwarded to ``torch.load``,
            e.g. ``"cpu"`` to open a GPU-saved checkpoint on a CPU-only
            machine. ``None`` keeps ``torch.load``'s default behaviour.

    Raises:
        KeyError: if ``"net_dict"`` or ``"opt_dict"`` is missing.
    """
    check_point = torch.load(filename, map_location=map_location)
    net.load_state_dict(check_point["net_dict"])
    opt.load_state_dict(check_point["opt_dict"])
    print("Checkpoint Loaded!")

In [22]:
save_checkpoint(net, opt, "checkpoint.pth.tar")

Checkpoint Saved!


In [23]:
predict_sentiment(net, "this is a very bad idea")

sentiment: 2.2288126899638883e-09


In [24]:
predict_sentiment(net, "this film is terrible")

sentiment: 0.0


In [25]:
predict_sentiment(net, "you are terrific")

sentiment: 1.0


In [26]:
predict_sentiment(net, "that is horrible")

sentiment: 0.0


In [27]:
predict_sentiment(net, "yeet!!")

sentiment: 1.0


In [28]:
predict_sentiment(net, "what are you doing?")

sentiment: 0.235542431473732
