In [1]:
import torch
import tqdm
import torch.nn as nn
from torchtext import data, datasets

In [2]:
# Show which accelerator the notebook is running on. The device is chosen with
# a CPU fallback further down, so fall back to a plain label here as well
# rather than raising when no CUDA device is available.
torch.cuda.get_device_name() if torch.cuda.is_available() else "cpu"

'Tesla P100-PCIE-16GB'

In [3]:
# Field used for the sentence text: spaCy tokenization, lowercased tokens.
TEXT = data.Field(tokenize="spacy", lower=True)
# Field used for the class label.
LABEL = data.LabelField()

In [4]:
# Load the SNLI train / validation / test splits, processed with the TEXT and LABEL fields.
train_data, valid_data, test_data = datasets.SNLI.splits(TEXT, LABEL)

In [5]:
# Inspect the first training example (tokenized premise and hypothesis plus its label).
vars(train_data[0])

{'hypothesis': ['a',
  'person',
  'is',
  'training',
  'his',
  'horse',
  'for',
  'a',
  'competition',
  '.'],
 'label': 'neutral',
 'premise': ['a',
  'person',
  'on',
  'a',
  'horse',
  'jumps',
  'over',
  'a',
  'broken',
  'down',
  'airplane',
  '.']}

In [6]:
# Number of examples in the train, validation and test splits.
len(train_data), len(valid_data), len(test_data)

(549367, 9842, 9824)

In [7]:
# Build both vocabularies from the training split only. min_freq=2 sets the
# minimum token frequency required for a word to enter the text vocabulary.
TEXT.build_vocab(train_data, min_freq=2)
LABEL.build_vocab(train_data)

In [8]:
# Text vocabulary size and the label-string -> class-index mapping.
len(TEXT.vocab), LABEL.vocab.stoi

(23566,
 defaultdict(<function torchtext.vocab._default_unk_index>,
             {'contradiction': 1, 'entailment': 0, 'neutral': 2}))

In [9]:
class Net(nn.Module):
    """Sentence-pair classifier that encodes both sentences with one shared BiLSTM.

    Each sentence goes through the same embedding -> linear+ReLU -> 2-layer
    bidirectional LSTM stack. The final hidden states of the top LSTM layer
    (both directions) of premise and hypothesis are concatenated into a
    ``4 * hidden_size`` vector, passed through three ReLU+dropout linear
    layers, and projected to ``output_size`` logits.

    Args:
        input_size: Number of rows in the embedding table (vocabulary size).
        embedding_size: Width of each embedding vector.
        hidden_size: Width of the input projection and of each LSTM direction.
        output_size: Number of output logits (classes).
        pad_idx: Index passed to ``nn.Embedding`` as ``padding_idx``.
    """

    def __init__(self, input_size, embedding_size, hidden_size, output_size, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(input_size, embedding_size, padding_idx=pad_idx)
        self.fc_in = nn.Linear(embedding_size, hidden_size)
        self.lstm_layers = nn.LSTM(hidden_size, hidden_size, num_layers=2, bidirectional=True, dropout=0.25)

        # Two sentences x two directions -> classifier width is 4 * hidden_size.
        self.fc_layers = nn.ModuleList([nn.Linear(hidden_size * 4, hidden_size * 4) for _ in range(3)])
        self.fc_out = nn.Linear(hidden_size * 4, output_size)

        self.dropout = nn.Dropout(0.25)
        self.relu = nn.ReLU()

    def _encode(self, tokens):
        """Encode one batch of token indices, shape (seq_len, batch), into (batch, 2 * hidden_size)."""
        projected = self.relu(self.fc_in(self.embedding(tokens)))
        # NOTE(review): sequences are not packed, so padded time steps are also
        # fed through the LSTM before the final hidden state is read.
        _, (h_n, _) = self.lstm_layers(projected)
        # The last two entries of h_n belong to the top layer (one per direction).
        return torch.cat((h_n[-1, :, :], h_n[-2, :, :]), dim=1)

    def forward(self, p, h):
        """Return class logits for a batch of (premise, hypothesis) pairs.

        Args:
            p: Premise token indices, shape (premise_len, batch).
            h: Hypothesis token indices, shape (hypothesis_len, batch).

        Returns:
            Tensor of shape (batch, output_size) with unnormalised logits.
        """
        # Premise is encoded before hypothesis, same order as the two LSTM calls it replaces.
        p_hidden = self._encode(p)
        h_hidden = self._encode(h)
        hidden = torch.cat((p_hidden, h_hidden), dim=1)

        for fc in self.fc_layers:
            hidden = self.relu(fc(hidden))
            hidden = self.dropout(hidden)

        return self.fc_out(hidden)

In [10]:
# Fix PyTorch's RNG so weight initialisation and dropout are repeatable.
# NOTE(review): this seeds torch only; batch shuffling may draw from a
# different RNG -- confirm before relying on exact reproducibility.
seed = 1234
torch.manual_seed(seed)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training schedule.
epochs = 10
batch_size = 512

# Model dimensions; the vocabulary and label counts come from the built vocabs.
input_size = len(TEXT.vocab)
embedding_size = 300
hidden_size = 300
output_size = len(LABEL.vocab)

# Index of the padding token, passed to the embedding layer as padding_idx.
pad_idx = TEXT.vocab.stoi["<pad>"]

In [11]:
# Display the device that was selected.
device

device(type='cuda')

In [12]:
# Batch every split. train_batches and test_batches are built as before;
# valid_batches is additionally exposed so that per-epoch monitoring can use
# the validation split rather than the held-out test split.
train_batches, valid_batches, test_batches = data.BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=batch_size,
    device=device,
)

In [13]:
# Sanity check: print the tensor shapes of the first training batch only.
for batch in train_batches:
    for field_name in ("premise", "hypothesis", "label"):
        print(getattr(batch, field_name).shape)
    break

torch.Size([40, 512])
torch.Size([28, 512])
torch.Size([512])


In [14]:
# Build the model, move it to the selected device, and display its layer summary.
net = Net(
    input_size=input_size,
    embedding_size=embedding_size,
    hidden_size=hidden_size,
    output_size=output_size,
    pad_idx=pad_idx,
).to(device)
net

Net(
  (embedding): Embedding(23566, 300, padding_idx=1)
  (fc_in): Linear(in_features=300, out_features=300, bias=True)
  (lstm_layers): LSTM(300, 300, num_layers=2, dropout=0.25, bidirectional=True)
  (fc_layers): ModuleList(
    (0): Linear(in_features=1200, out_features=1200, bias=True)
    (1): Linear(in_features=1200, out_features=1200, bias=True)
    (2): Linear(in_features=1200, out_features=1200, bias=True)
  )
  (fc_out): Linear(in_features=1200, out_features=3, bias=True)
  (dropout): Dropout(p=0.25, inplace=False)
  (relu): ReLU()
)

In [15]:
def count_parameters(net):
    """Return the total element count of all parameters of ``net`` that require gradients."""
    total = 0
    for param in net.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [16]:
# Number of trainable parameters in the model.
count_parameters(net)

15096903

In [17]:
# Adam over all model parameters, with cross-entropy on the raw logits as the loss.
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

In [18]:
def get_accuracy(preds, y):
    """Return the fraction of rows whose highest-scoring class equals the label.

    Works on whatever device ``preds`` and ``y`` already live on; it does not
    depend on any notebook-level ``device`` variable.

    Args:
        preds: Score tensor of shape (batch, num_classes).
        y: Integer class labels of shape (batch,).

    Returns:
        Accuracy as a Python float in [0, 1] (NaN for an empty batch).
    """
    correct = preds.argmax(dim=1).eq(y)
    # Mean of the 0/1 hit mask is the accuracy; no helper tensor is needed.
    return correct.float().mean().item()

In [19]:
def loop(net, batches, train):
    """Run one full pass over ``batches`` and return the mean loss and accuracy.

    Uses the notebook-level ``device``, ``loss_fn``, ``opt`` and
    ``get_accuracy``. When ``train`` is truthy the model is put in training
    mode and the optimizer takes one step per batch; otherwise the model is
    put in eval mode and the pass runs with gradients disabled.

    Args:
        net: Model called as ``net(premise, hypothesis)``.
        batches: Sized iterable of batches exposing ``premise``,
            ``hypothesis`` and ``label`` tensors.
        train: Whether to update the model's parameters.

    Returns:
        ``(mean_loss, mean_accuracy)``, each the unweighted mean of the
        per-batch values.

    Raises:
        ZeroDivisionError: If ``batches`` yields no batch.
    """
    training = bool(train)
    batch_losses = []
    batch_accs = []

    print("Train Loop:" if training else "Inference Loop:")
    net.train(training)

    # One shared loop body for both modes; gradient tracking follows `train`.
    with torch.set_grad_enabled(training):
        for batch in tqdm.tqdm(batches, total=len(batches)):
            prem = batch.premise.to(device)
            hypo = batch.hypothesis.to(device)
            labels = batch.label.to(device)

            preds = net(prem, hypo)
            loss = loss_fn(preds, labels)
            acc = get_accuracy(preds, labels)

            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()

            batch_losses.append(loss.item())
            batch_accs.append(acc)

    # Two blank lines separate the progress bar from the epoch summary.
    print("")
    print("")

    return sum(batch_losses) / len(batch_losses), sum(batch_accs) / len(batch_accs)

In [20]:
def predict(net, p, h):
    """Print the predicted label for a single premise / hypothesis pair.

    Uses the notebook-level ``TEXT``, ``LABEL`` and ``device``. Each sentence
    is tokenized with ``TEXT.tokenize``, lowercased, and mapped to indices
    through ``TEXT.vocab.stoi``.

    Args:
        net: Trained model called as ``net(premise, hypothesis)``.
        p: Premise sentence as a raw string.
        h: Hypothesis sentence as a raw string.

    Side effects:
        Switches ``net`` to eval mode (it is not switched back) and prints the
        predicted label string.
    """
    net.eval()

    def _to_column(sentence):
        # Tokenize, lowercase, look up indices, and shape as (seq_len, 1).
        tokens = [t.lower() for t in TEXT.tokenize(sentence)]
        indices = [TEXT.vocab.stoi[t] for t in tokens]
        return torch.LongTensor(indices).unsqueeze(1).to(device)

    premise = _to_column(p)
    hypothesis = _to_column(h)

    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        preds = net(premise, hypothesis)
    preds = preds.argmax(dim=1)

    print(LABEL.vocab.itos[preds.item()])

In [21]:
# Fixed sentence pair that is re-classified after every epoch in the training cell.
premise = "the dog is eating food"
hypothesis = "the dog is playing in the park"

In [22]:
# Train for `epochs` passes, evaluating after each one and re-classifying the
# fixed probe pair to watch the prediction evolve.
# NOTE(review): the metrics labelled "val" are computed on test_batches, i.e.
# the test split.
for epoch in range(epochs):
    train_loss, train_acc = loop(net, train_batches, train=True)
    val_loss, val_acc = loop(net, test_batches, train=False)

    summary = f"epoch: {epoch} | train_loss: {train_loss:.4f} | train_acc: {train_acc:.4f} | val_loss: {val_loss:.4f} | val_acc: {val_acc:.4f}"
    print(summary)
    predict(net, premise, hypothesis)
    print("")

  0%|          | 0/1073 [00:00<?, ?it/s]

Train Loop:


100%|██████████| 1073/1073 [03:06<00:00,  5.76it/s]
 15%|█▌        | 3/20 [00:00<00:00, 26.79it/s]



Inference Loop:


100%|██████████| 20/20 [00:00<00:00, 34.36it/s]
  0%|          | 0/1073 [00:00<?, ?it/s]



epoch: 0 | train_loss: 0.7901 | train_acc: 0.6468 | val_loss: 0.6792 | val_acc: 0.7123
contradiction

Train Loop:


100%|██████████| 1073/1073 [03:07<00:00,  5.71it/s]
 15%|█▌        | 3/20 [00:00<00:00, 26.56it/s]



Inference Loop:


100%|██████████| 20/20 [00:00<00:00, 34.66it/s]
  0%|          | 0/1073 [00:00<?, ?it/s]



epoch: 1 | train_loss: 0.6519 | train_acc: 0.7275 | val_loss: 0.6125 | val_acc: 0.7464
contradiction

Train Loop:


100%|██████████| 1073/1073 [03:06<00:00,  5.76it/s]
 15%|█▌        | 3/20 [00:00<00:00, 26.30it/s]



Inference Loop:


100%|██████████| 20/20 [00:00<00:00, 34.62it/s]
  0%|          | 0/1073 [00:00<?, ?it/s]



epoch: 2 | train_loss: 0.5889 | train_acc: 0.7592 | val_loss: 0.5915 | val_acc: 0.7510
contradiction

Train Loop:


100%|██████████| 1073/1073 [03:06<00:00,  5.76it/s]
 10%|█         | 2/20 [00:00<00:00, 19.95it/s]



Inference Loop:


100%|██████████| 20/20 [00:00<00:00, 34.71it/s]
  0%|          | 0/1073 [00:00<?, ?it/s]



epoch: 3 | train_loss: 0.5425 | train_acc: 0.7816 | val_loss: 0.5687 | val_acc: 0.7689
contradiction

Train Loop:


100%|██████████| 1073/1073 [03:06<00:00,  5.74it/s]
 15%|█▌        | 3/20 [00:00<00:00, 26.56it/s]



Inference Loop:


100%|██████████| 20/20 [00:00<00:00, 34.53it/s]
  0%|          | 0/1073 [00:00<?, ?it/s]



epoch: 4 | train_loss: 0.5133 | train_acc: 0.7955 | val_loss: 0.5587 | val_acc: 0.7739
neutral

Train Loop:


100%|██████████| 1073/1073 [03:08<00:00,  5.70it/s]
 10%|█         | 2/20 [00:00<00:00, 19.96it/s]



Inference Loop:


100%|██████████| 20/20 [00:00<00:00, 34.24it/s]
  0%|          | 0/1073 [00:00<?, ?it/s]



epoch: 5 | train_loss: 0.4705 | train_acc: 0.8144 | val_loss: 0.5529 | val_acc: 0.7768
neutral

Train Loop:


100%|██████████| 1073/1073 [03:06<00:00,  5.74it/s]
 10%|█         | 2/20 [00:00<00:00, 19.78it/s]



Inference Loop:


100%|██████████| 20/20 [00:00<00:00, 34.59it/s]
  0%|          | 0/1073 [00:00<?, ?it/s]



epoch: 6 | train_loss: 0.4392 | train_acc: 0.8275 | val_loss: 0.5613 | val_acc: 0.7811
neutral

Train Loop:


100%|██████████| 1073/1073 [03:06<00:00,  5.74it/s]
 15%|█▌        | 3/20 [00:00<00:00, 27.40it/s]



Inference Loop:


100%|██████████| 20/20 [00:00<00:00, 35.11it/s]
  0%|          | 0/1073 [00:00<?, ?it/s]



epoch: 7 | train_loss: 0.4099 | train_acc: 0.8408 | val_loss: 0.5826 | val_acc: 0.7750
neutral

Train Loop:


100%|██████████| 1073/1073 [03:07<00:00,  5.74it/s]
 15%|█▌        | 3/20 [00:00<00:00, 27.14it/s]



Inference Loop:


100%|██████████| 20/20 [00:00<00:00, 34.62it/s]
  0%|          | 0/1073 [00:00<?, ?it/s]



epoch: 8 | train_loss: 0.3812 | train_acc: 0.8524 | val_loss: 0.6299 | val_acc: 0.7760
contradiction

Train Loop:


100%|██████████| 1073/1073 [03:08<00:00,  5.70it/s]
 15%|█▌        | 3/20 [00:00<00:00, 26.10it/s]



Inference Loop:


100%|██████████| 20/20 [00:00<00:00, 34.53it/s]



epoch: 9 | train_loss: 0.3538 | train_acc: 0.8638 | val_loss: 0.6193 | val_acc: 0.7779
neutral






In [23]:
def save_checkpoint(net, opt, filename):
    """Write the model and optimizer state dicts to ``filename``.

    The file holds a dict with two entries: ``"net_dict"`` (model state) and
    ``"opt_dict"`` (optimizer state). Prints a confirmation once written.
    """
    state = {}
    state["net_dict"] = net.state_dict()
    state["opt_dict"] = opt.state_dict()
    torch.save(state, filename)
    print("Checkpoint Saved!")

def load_checkpoint(net, opt, filename, map_location=None):
    """Restore model and optimizer state from a checkpoint file, in place.

    Expects the file to hold a dict with ``"net_dict"`` and ``"opt_dict"``
    entries, which is the layout ``save_checkpoint`` writes. No other keys
    are required.

    Args:
        net: Model whose parameters are overwritten.
        opt: Optimizer whose state is overwritten.
        filename: Path of the checkpoint file.
        map_location: Forwarded to ``torch.load``; pass e.g. ``"cpu"`` to load
            a checkpoint saved on a GPU onto a machine without one.

    Raises:
        KeyError: If either expected entry is missing from the checkpoint.
    """
    check_point = torch.load(filename, map_location=map_location)
    net.load_state_dict(check_point["net_dict"])
    opt.load_state_dict(check_point["opt_dict"])
    print("Checkpoint Loaded!")

In [24]:
# Persist the trained model and optimizer state next to the notebook.
save_checkpoint(net, opt, "checkpoint.pth.tar")

Checkpoint Saved!


In [25]:
# Probe: the hypothesis differs from the premise only in the subject ("man" vs "woman").
premise = 'a man sitting on a green bench.'
hypothesis = 'a woman sitting on a green bench.'

predict(net, premise, hypothesis)

contradiction


In [26]:
# Probe: the hypothesis differs from the premise only in the bench colour ("green" vs "blue").
premise = 'a man sitting on a green bench.'
hypothesis = 'a man sitting on a blue bench.'

predict(net, premise, hypothesis)

contradiction


In [29]:
# Probe: the hypothesis adds information that the premise does not state.
premise = 'a dog has finished eating'
hypothesis = 'a dog is waiting for her next meal'

predict(net, premise, hypothesis)

neutral


In [28]:
# Probe: the hypothesis attributes a purpose ("training for a race") that the premise does not state.
premise = 'a horse is running.'
hypothesis = 'a horse is training for a race'

predict(net, premise, hypothesis)

neutral


In [45]:
# Probe: the hypothesis adds a destination that the premise does not state.
# NOTE(review): "drving" in the hypothesis looks like a typo for "driving"; it is
# looked up in the vocabulary exactly as written, so the recorded prediction
# reflects the misspelled token.
premise = 'a girl is driving the car'
hypothesis = 'a girl is drving to her home in the car'

predict(net, premise, hypothesis)

neutral


In [34]:
# Probe: the hypothesis is a less specific restatement of the premise.
# NOTE(review): "aganist" in the premise looks like a typo for "against"; it is
# looked up in the vocabulary exactly as written, so the recorded prediction
# reflects the misspelled token.
premise = 'a lady sits on a bench that is aganist a shopping mall'
hypothesis = 'a person sits on a bench'

predict(net, premise, hypothesis)

entailment
