In [9]:
import os
import json
import torch
from torch import nn
from torch.nn import functional as F
from IPython.display import display
import ipywidgets as widgets

In [10]:
# ----------------------
# Load Vocabulary
# ----------------------
def load_vocab(model_name, models_dir="models"):
    """Load the vocabulary saved next to a trained model.

    Reads ``<models_dir>/<model_name>/vocab-dict.json`` and returns the parsed
    JSON object. The rest of this notebook treats it as a ``str -> int``
    token-to-index mapping.

    Args:
        model_name: Name of the model sub-directory.
        models_dir: Root directory containing model sub-directories. Defaults
            to ``"models"``, resolved relative to the current working directory.

    Returns:
        The deserialized vocabulary object.

    Raises:
        FileNotFoundError: If the vocabulary file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    vocab_path = os.path.join(models_dir, model_name, "vocab-dict.json")
    # JSON text is UTF-8; without an explicit encoding, open() uses the platform
    # locale (e.g. cp1252 on Windows) and can mis-decode non-ASCII tokens.
    with open(vocab_path, "r", encoding="utf-8") as json_file:
        return json.load(json_file)

# ----------------------
# Load Model
# ----------------------
def load_model(model_name, vocab, embed_size=100, num_hiddens=200, models_dir="models"):
    """Rebuild the decomposable-attention NLI model and load its trained weights.

    The network embeds both sentences with a shared embedding table, then runs
    attend -> compare -> aggregate MLP stages and ends in a 3-way linear
    classifier. Weights are read from
    ``<models_dir>/<model_name>/model-state.pt`` onto the CPU.

    The size defaults reproduce the original hard-coded architecture
    (embedding 100, hidden 200). They must match the sizes the checkpoint was
    trained with, otherwise ``load_state_dict`` fails on the shape mismatch.

    Args:
        model_name: Name of the model sub-directory.
        vocab: Token-to-index mapping; only ``len(vocab)`` is used, to size
            the embedding table.
        embed_size: Word-embedding dimension.
        num_hiddens: Hidden width of every MLP stage.
        models_dir: Root directory containing model sub-directories.

    Returns:
        The model on CPU, switched to evaluation mode (dropout disabled).
    """

    def mlp(in_size, hidden_size, flatten):
        """Two (Dropout -> Linear -> ReLU) stages, each optionally followed by Flatten.

        The layer order determines the ``nn.Sequential`` indices used in the
        saved ``state_dict`` keys, so it must not change. This is why the
        Flatten layers are kept even where they may not alter the shape.
        """
        layers = []
        for in_features in (in_size, hidden_size):
            layers += [nn.Dropout(0.2), nn.Linear(in_features, hidden_size), nn.ReLU()]
            if flatten:
                layers.append(nn.Flatten(start_dim=1))
        return nn.Sequential(*layers)

    class Attend(nn.Module):
        """Soft-align two embedded sequences against each other."""

        def __init__(self, num_inputs, num_hiddens, **kwargs):
            super().__init__(**kwargs)
            self.f = mlp(num_inputs, num_hiddens, flatten=False)

        def forward(self, A, B):
            # Shapes assume A: (batch, m, d) and B: (batch, n, d) - TODO confirm.
            f_A = self.f(A)
            f_B = self.f(B)
            # e[b, i, j]: unnormalized score between token i of A and token j of B.
            e = torch.bmm(f_A, f_B.permute(0, 2, 1))
            # beta: for each token of A, a softmax-weighted mix of B's embeddings.
            beta = torch.bmm(F.softmax(e, dim=-1), B)
            # alpha: for each token of B, a softmax-weighted mix of A's embeddings.
            alpha = torch.bmm(F.softmax(e.permute(0, 2, 1), dim=-1), A)
            return beta, alpha

    class Compare(nn.Module):
        """Compare each token with its aligned counterpart from the other sentence."""

        def __init__(self, num_inputs, num_hiddens, **kwargs):
            super().__init__(**kwargs)
            self.g = mlp(num_inputs, num_hiddens, flatten=False)

        def forward(self, A, B, beta, alpha):
            # Concatenate on the feature axis, so the MLP input width is 2 * embed_size.
            V_A = self.g(torch.cat([A, beta], dim=2))
            V_B = self.g(torch.cat([B, alpha], dim=2))
            return V_A, V_B

    class Aggregate(nn.Module):
        """Pool the comparison vectors and map them to class scores."""

        def __init__(self, num_inputs, num_hiddens, num_outputs, **kwargs):
            super().__init__(**kwargs)
            self.h = mlp(num_inputs, num_hiddens, flatten=True)
            self.linear = nn.Linear(num_hiddens, num_outputs)

        def forward(self, V_A, V_B):
            # Sum over dim 1 (presumably the token axis), then concatenate both
            # sentences: the MLP input width is 2 * num_hiddens.
            V_A = V_A.sum(dim=1)
            V_B = V_B.sum(dim=1)
            return self.linear(self.h(torch.cat([V_A, V_B], dim=1)))

    class DecomposableAttention(nn.Module):
        """Full model: shared embedding followed by attend, compare and aggregate."""

        def __init__(self, vocab, embed_size, num_hiddens, num_inputs_attend=100,
                     num_inputs_compare=200, num_inputs_agg=400, **kwargs):
            super().__init__(**kwargs)
            self.embedding = nn.Embedding(len(vocab), embed_size)
            self.attend = Attend(num_inputs_attend, num_hiddens)
            self.compare = Compare(num_inputs_compare, num_hiddens)
            self.aggregate = Aggregate(num_inputs_agg, num_hiddens, num_outputs=3)

        def forward(self, X):
            premises, hypotheses = X
            A = self.embedding(premises)
            B = self.embedding(hypotheses)
            beta, alpha = self.attend(A, B)
            V_A, V_B = self.compare(A, B, beta, alpha)
            return self.aggregate(V_A, V_B)

    # Stage input widths follow from the concatenations in Compare/Aggregate;
    # with the defaults they equal the original 100 / 200 / 400.
    model = DecomposableAttention(
        vocab,
        embed_size,
        num_hiddens,
        num_inputs_attend=embed_size,
        num_inputs_compare=2 * embed_size,
        num_inputs_agg=2 * num_hiddens,
    )
    model_path = os.path.join(models_dir, model_name, "model-state.pt")
    state_dict = torch.load(model_path, weights_only=True, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    # Return an inference-ready model. Dropout would otherwise make direct calls
    # non-deterministic.
    model.eval()
    return model

# ----------------------
# Predict Function
# ----------------------
def predict_snli(model, vocab, premises, hypotheses):
    """Classify the relation between a premise sentence and a hypothesis sentence.

    Both sentences are split on whitespace (no lower-casing or punctuation
    splitting) and mapped to indices with ``vocab``. Tokens missing from the
    vocabulary map to ``"<unk>"`` when the vocabulary defines it, otherwise
    to ``"<pad>"``.

    Args:
        model: ``nn.Module`` called on a ``(premise, hypothesis)`` tuple of
            ``(1, seq_len)`` index tensors. It returns per-class scores whose
            argmax along dim 1 is the label index. The model is switched to
            eval mode.
        vocab: Token-to-index mapping that must contain ``"<pad>"``.
        premises: The premise sentence (a single string despite the name).
        hypotheses: The hypothesis sentence (a single string).

    Returns:
        ``"entailment"`` for label index 0, ``"contradiction"`` for 1, and
        ``"neutral"`` otherwise.

    Raises:
        ValueError: If either sentence contains no tokens.
    """
    premise_tokens = premises.split()
    hypothesis_tokens = hypotheses.split()
    if not premise_tokens or not hypothesis_tokens:
        raise ValueError("Premise and hypothesis must each contain at least one token.")

    # NOTE(review): out-of-vocabulary tokens used to always map to "<pad>".
    # "<unk>" is preferred when the vocab has one; confirm the training pipeline
    # mapped unknown tokens the same way.
    unk_index = vocab["<unk>"] if "<unk>" in vocab else vocab["<pad>"]

    # Build inputs on the same device as the model's weights (CPU if it has none).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    def to_batch(tokens):
        """Map tokens to a (1, len(tokens)) integer index tensor."""
        indices = [vocab.get(token, unk_index) for token in tokens]
        # Explicit dtype: nn.Embedding needs integer indices.
        return torch.tensor(indices, dtype=torch.long, device=device).unsqueeze(0)

    model.eval()
    with torch.no_grad():
        scores = model((to_batch(premise_tokens), to_batch(hypothesis_tokens)))
    label_idx = torch.argmax(scores, dim=1).item()
    return {0: "entailment", 1: "contradiction"}.get(label_idx, "neutral")

In [11]:
# ----------------------
# User Interface
# ----------------------
def create_nli_interface(model_name="decomposable-attention"):
    """Display an ipywidgets form that runs NLI inference on user-entered text.

    Loads the vocabulary and model for ``model_name`` once. It then renders a
    premise box, a hypothesis box, an "Infer" button and a read-only result box.
    Each click classifies the current text pair with ``predict_snli``.

    Args:
        model_name: Name of the model sub-directory under ``models/`` holding
            ``vocab-dict.json`` and ``model-state.pt``.

    Returns:
        None. The form is shown via ``display``.
    """
    vocab = load_vocab(model_name)
    model = load_model(model_name, vocab)

    title = widgets.Label(value="Natural Language Inference")
    premise_input = widgets.Textarea(
        description="Premise:",
        placeholder="e.g. A soccer game with multiple males playing.",
    )
    # The hypothesis example is a distinct sentence (not a copy of the premise),
    # so the placeholder pair illustrates an actual inference.
    hypothesis_input = widgets.Textarea(
        description="Hypothesis:",
        placeholder="e.g. Some men are playing a sport.",
    )
    output_area = widgets.Textarea(value="Result:", layout=widgets.Layout(height='50px'), disabled=True)
    infer_button = widgets.Button(description="Infer")

    def on_infer_clicked(_button):
        """Button callback: classify the current premise/hypothesis pair."""
        premise = premise_input.value
        hypothesis = hypothesis_input.value
        # Whitespace-only text has no tokens to classify, so treat it as empty.
        if premise.strip() and hypothesis.strip():
            result = predict_snli(model, vocab, premise, hypothesis)
            output_area.value = f"Result: {result.capitalize()}"
        else:
            output_area.value = "Please enter some text for inference."

    infer_button.on_click(on_infer_clicked)

    display(widgets.VBox([title, premise_input, hypothesis_input, infer_button, output_area]))

In [12]:
# Build and display the inference form for the model stored under models/decomposable-attention.
create_nli_interface(model_name="decomposable-attention")

VBox(children=(Label(value='Natural Language Inference'), Textarea(value='', description='Premise:', placehold…