In [1]:
import pandas as pd
import numpy as np
import torch

from torchtext import datasets

from torchtext.data import Field, LabelField
from torchtext.data import BucketIterator

from torchtext.vocab import Vectors, GloVe

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
from tqdm.autonotebook import tqdm

from sklearn.metrics import f1_score
import pickle

In [2]:
TEXT = Field(sequential=True, lower=True, include_lengths=True)  # Поле текста
LABEL = LabelField(dtype=torch.float)  # Поле метки

In [3]:
SEED = 1234

torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [4]:
train, test = datasets.IMDB.splits(TEXT, LABEL)  # загрузим датасет
train, valid = train.split(random_state=random.seed(SEED))  # разобьем на части

In [5]:
TEXT.build_vocab(train)
LABEL.build_vocab(train)

In [6]:
BATCH_SIZE = 64

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"

train_iter, valid_iter, test_iter = BucketIterator.splits(
    (train, valid, test), 
    batch_size = BATCH_SIZE,
    sort_within_batch = True,
    device = device)

In [8]:
train_iter.dataset[0]

<torchtext.data.example.Example at 0x1c98bc9e880>

In [9]:
TEXT = Field(sequential=True, lower=True, batch_first=True)  # batch_first тк мы используем conv  
LABEL = LabelField(batch_first=True, dtype=torch.float)

train, tst = datasets.IMDB.splits(TEXT, LABEL)
trn, vld = train.split(random_state=random.seed(SEED))

TEXT.build_vocab(trn)
LABEL.build_vocab(trn)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [10]:
with open('text.pkl', 'wb') as f:
        pickle.dump(TEXT, f)

In [11]:
train_iter, val_iter, test_iter = BucketIterator.splits(
        (trn, vld, tst),
        batch_sizes=(128, 256, 256),
        sort=False,
        sort_key= lambda x: len(x.src),
        sort_within_batch=False,
        device=device,
        repeat=False,
)

In [12]:
class CNN(nn.Module):
    def __init__(
        self,
        vocab_size,
        emb_dim,
        out_channels,
        kernel_sizes,
        dropout=0.5,
    ):
        super().__init__()
        
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.out_channels = out_channels
        
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        
        self.conv_0 = nn.Conv2d(in_channels=1, out_channels=self.out_channels, 
                                kernel_size=(kernel_sizes[0], emb_dim), padding=1, stride=2)  # YOUR CODE GOES HERE
        self.conv_1 = nn.Conv2d(in_channels=1, out_channels=self.out_channels, 
                                kernel_size=(kernel_sizes[1], emb_dim), padding=1, stride=2)  # YOUR CODE GOES HERE
        self.conv_2 = nn.Conv2d(in_channels=1, out_channels=self.out_channels, 
                                kernel_size=(kernel_sizes[2], emb_dim), padding=1, stride=2)  # YOUR CODE GOES HERE
        
        self.fc = nn.Linear(len(kernel_sizes) * out_channels, 1)
        
        self.dropout = nn.Dropout(dropout)
        
        
    def forward(self, text):
        
        embedded = self.embedding(text)
        
        batch_size = embedded.shape[0]
        embedded = embedded.unsqueeze(1)  # may be reshape here
        
        conved_0 = F.relu(self.conv_0(embedded)).view(batch_size, self.out_channels, -1)  # may be reshape here
        conved_1 = F.relu(self.conv_1(embedded)).view(batch_size, self.out_channels, -1)  # may be reshape here
        conved_2 = F.relu(self.conv_2(embedded)).view(batch_size, self.out_channels, -1)  # may be reshape here
        
        pooled_0 = F.max_pool1d(conved_0, conved_0.shape[2]).squeeze(2)
        pooled_1 = F.max_pool1d(conved_1, conved_1.shape[2]).squeeze(2)
        pooled_2 = F.max_pool1d(conved_2, conved_2.shape[2]).squeeze(2)
        
        cat = self.dropout(torch.cat((pooled_0, pooled_1, pooled_2), dim=1))
            
        return self.fc(cat)

In [13]:
kernel_sizes = [3, 5, 6]
vocab_size = len(TEXT.vocab)
out_channels=16
dropout = 0.25
dim = 50
patience=3

model = CNN(vocab_size=vocab_size, emb_dim=dim, out_channels=out_channels,
            kernel_sizes=kernel_sizes, dropout=dropout)

In [14]:
model.to(device)

CNN(
  (embedding): Embedding(202065, 50)
  (conv_0): Conv2d(1, 16, kernel_size=(3, 50), stride=(2, 2), padding=(1, 1))
  (conv_1): Conv2d(1, 16, kernel_size=(5, 50), stride=(2, 2), padding=(1, 1))
  (conv_2): Conv2d(1, 16, kernel_size=(6, 50), stride=(2, 2), padding=(1, 1))
  (fc): Linear(in_features=48, out_features=1, bias=True)
  (dropout): Dropout(p=0.25, inplace=False)
)

In [15]:
opt = torch.optim.Adam(model.parameters())
loss_func = nn.BCEWithLogitsLoss()

In [16]:
max_epochs = 45

In [17]:
import numpy as np

min_loss = np.inf

cur_patience = 0

for epoch in range(1, max_epochs + 1):
    train_loss = 0.0
    model.train()
    pbar = tqdm(enumerate(train_iter), total=len(train_iter), leave=False)
    pbar.set_description(f"Epoch {epoch}")
    for it, batch in pbar: 
        X_train, y_train = batch.text.to(device), batch.label.to(device)
        opt.zero_grad()
        
        answers_train = model(X_train)
        
        loss = loss_func(answers_train.squeeze(), y_train)
        loss.backward()
        
        opt.step()
        train_loss += loss

    train_loss /= len(train_iter)
    val_loss = 0.0
    model.eval()
    pbar = tqdm(enumerate(val_iter), total=len(val_iter), leave=False)
    pbar.set_description(f"Epoch {epoch}")
    with torch.no_grad():
        for it, batch in pbar:
            X_val, y_val = batch.text.to(device), batch.label.to(device)

            answers_val = model(X_val)
            val_loss += loss_func(answers_val.squeeze(), y_val)
            
    val_loss /= len(val_iter)
    if val_loss < min_loss:
        min_loss = val_loss
        best_model = model.state_dict()
    else:
        cur_patience += 1
        if cur_patience == patience:
            cur_patience = 0
            break
    
    print('Epoch: {}, Training Loss: {}, Validation Loss: {}'.format(epoch, train_loss, val_loss))
model.load_state_dict(best_model)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 1, Training Loss: 0.7156025171279907, Validation Loss: 0.668056070804596


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 2, Training Loss: 0.6646580696105957, Validation Loss: 0.6322876811027527


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 3, Training Loss: 0.6204766035079956, Validation Loss: 0.5831075310707092


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 4, Training Loss: 0.5779945850372314, Validation Loss: 0.5488821268081665


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 5, Training Loss: 0.5389863848686218, Validation Loss: 0.5246716141700745


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 6, Training Loss: 0.5040284395217896, Validation Loss: 0.4990682601928711


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 7, Training Loss: 0.47071659564971924, Validation Loss: 0.47555944323539734


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 8, Training Loss: 0.43360671401023865, Validation Loss: 0.4558789134025574


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 9, Training Loss: 0.39875757694244385, Validation Loss: 0.44394415616989136


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 10, Training Loss: 0.3628860116004944, Validation Loss: 0.42425134778022766


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 11, Training Loss: 0.3226604759693146, Validation Loss: 0.41187047958374023


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 12, Training Loss: 0.27999022603034973, Validation Loss: 0.4027194380760193


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 13, Training Loss: 0.245819091796875, Validation Loss: 0.39162373542785645


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 14, Training Loss: 0.21130119264125824, Validation Loss: 0.39207708835601807


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 15, Training Loss: 0.17842216789722443, Validation Loss: 0.3897049129009247


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 16, Training Loss: 0.14864319562911987, Validation Loss: 0.3999641239643097


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

<All keys matched successfully>

In [18]:
pbar = tqdm(enumerate(test_iter), total=len(test_iter), leave=False)

test_loss = 0.0

to_prob = nn.Sigmoid()

y_test_fact_full = []
y_test_pred_full = []
with torch.no_grad():
    for it, batch in pbar:
        X_test, y_test = batch.text.to(device), batch.label.to(device)
        
        answers_test = model(X_test)
        test_loss += loss_func(answers_test.squeeze(), y_test)
        
        y_test_fact_full += y_test.int().tolist()
        y_test_pred_full += answers_test.float().tolist()
        
f1_score(y_test_fact_full, np.round(to_prob(torch.tensor(y_test_pred_full)).numpy(), 0).astype(int))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=98.0), HTML(value='')))

0.7952698344442054

## Интерпретируемость

Посмотрим, куда смотрит наша модель. Достаточно запустить код ниже.

In [19]:
!pip install -q captum

In [20]:
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization

PAD_IND = TEXT.vocab.stoi['pad']

token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)
lig = LayerIntegratedGradients(model, model.embedding)

In [21]:
def forward_with_softmax(inp):
    logits = model(inp)
    return torch.softmax(logits, 0)[0][1]

def forward_with_sigmoid(input):
    return torch.sigmoid(model(input))


# accumalate couple samples in this array for visualization purposes
vis_data_records_ig = []

def interpret_sentence(model, sentence, min_len = 7, label = 0):
    model.eval()
    text = [tok for tok in TEXT.tokenize(sentence)]
    if len(text) < min_len:
        text += ['pad'] * (min_len - len(text))
    indexed = [TEXT.vocab.stoi[t] for t in text]

    model.zero_grad()

    input_indices = torch.tensor(indexed, device=device)
    input_indices = input_indices.unsqueeze(0)
    
    # input_indices dim: [sequence_length]
    seq_length = min_len

    # predict
    pred = forward_with_sigmoid(input_indices).item()
    pred_ind = round(pred)

    # generate reference indices for each sample
    reference_indices = token_reference.generate_reference(seq_length, device=device).unsqueeze(0)

    # compute attributions and approximation delta using layer integrated gradients
    attributions_ig, delta = lig.attribute(input_indices, reference_indices, \
                                           n_steps=5000, return_convergence_delta=True)

    print('pred: ', LABEL.vocab.itos[pred_ind], '(', '%.2f'%pred, ')', ', delta: ', abs(delta))

In [22]:
interpret_sentence(model, 'It was a fantastic performance !', label=1)
interpret_sentence(model, 'Best film ever', label=1)
interpret_sentence(model, 'Such a great show!', label=1)
interpret_sentence(model, 'It was a horrible movie', label=0)
interpret_sentence(model, 'I\'ve never watched something as bad', label=0)
interpret_sentence(model, 'It is a disgusting movie!', label=0)

pred:  pos ( 1.00 ) , delta:  tensor([9.3025e-05], device='cuda:0', dtype=torch.float64)
pred:  neg ( 0.16 ) , delta:  tensor([0.0002], device='cuda:0', dtype=torch.float64)
pred:  pos ( 0.70 ) , delta:  tensor([0.0001], device='cuda:0', dtype=torch.float64)
pred:  neg ( 0.00 ) , delta:  tensor([0.0001], device='cuda:0', dtype=torch.float64)
pred:  neg ( 0.10 ) , delta:  tensor([0.0001], device='cuda:0', dtype=torch.float64)
pred:  pos ( 0.92 ) , delta:  tensor([0.0002], device='cuda:0', dtype=torch.float64)


In [23]:
interpret_sentence(model, 'Hello world, this is great movie !', label=1)
interpret_sentence(model, 'How old this film, it look bad', label=0)
interpret_sentence(model, 'Is this Statham, i dont like him', label=0)

pred:  pos ( 0.89 ) , delta:  tensor([0.0002], device='cuda:0', dtype=torch.float64)
pred:  neg ( 0.11 ) , delta:  tensor([8.9936e-07], device='cuda:0', dtype=torch.float64)
pred:  neg ( 0.45 ) , delta:  tensor([6.9883e-05], device='cuda:0', dtype=torch.float64)


In [24]:
interpret_sentence(model, 'Is this Statham, i dont like him')

pred:  neg ( 0.45 ) , delta:  tensor([6.9883e-05], device='cuda:0', dtype=torch.float64)


In [25]:
torch.save(model.state_dict(), 'model.pt')

In [26]:
len(TEXT.vocab)

202065

In [None]:
# import pandas as pd
# import numpy as np
import torch

# from torchtext import datasets

# from torchtext.data import Field, LabelField
# from torchtext.data import BucketIterator

# from torchtext.vocab import Vectors, GloVe

import torch.nn as nn
import torch.nn.functional as F
import pickle
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization
#import torch.optim as optim
#import random
#from tqdm.autonotebook import tqdm

#from sklearn.metrics import f1_score

In [None]:
with open('text.pkl', 'rb') as f:
        TEXT = pickle.load(f)

In [None]:
class CNN(nn.Module):
    def __init__(
        self,
        vocab_size,
        emb_dim,
        out_channels,
        kernel_sizes,
        dropout=0.5,
    ):
        super().__init__()
        
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.out_channels = out_channels
        
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        
        self.conv_0 = nn.Conv2d(in_channels=1, out_channels=self.out_channels, 
                                kernel_size=(kernel_sizes[0], emb_dim), padding=1, stride=2)  # YOUR CODE GOES HERE
        self.conv_1 = nn.Conv2d(in_channels=1, out_channels=self.out_channels, 
                                kernel_size=(kernel_sizes[1], emb_dim), padding=1, stride=2)  # YOUR CODE GOES HERE
        self.conv_2 = nn.Conv2d(in_channels=1, out_channels=self.out_channels, 
                                kernel_size=(kernel_sizes[2], emb_dim), padding=1, stride=2)  # YOUR CODE GOES HERE
        
        self.fc = nn.Linear(len(kernel_sizes) * out_channels, 1)
        
        self.dropout = nn.Dropout(dropout)
        
        
    def forward(self, text):
        
        embedded = self.embedding(text)
        
        batch_size = embedded.shape[0]
        embedded = embedded.unsqueeze(1)  # may be reshape here
        
        conved_0 = F.relu(self.conv_0(embedded)).view(batch_size, self.out_channels, -1)  # may be reshape here
        conved_1 = F.relu(self.conv_1(embedded)).view(batch_size, self.out_channels, -1)  # may be reshape here
        conved_2 = F.relu(self.conv_2(embedded)).view(batch_size, self.out_channels, -1)  # may be reshape here
        
        pooled_0 = F.max_pool1d(conved_0, conved_0.shape[2]).squeeze(2)
        pooled_1 = F.max_pool1d(conved_1, conved_1.shape[2]).squeeze(2)
        pooled_2 = F.max_pool1d(conved_2, conved_2.shape[2]).squeeze(2)
        
        cat = self.dropout(torch.cat((pooled_0, pooled_1, pooled_2), dim=1))
            
        return self.fc(cat)

In [None]:
kernel_sizes = [3, 5, 6]
vocab_size = len(TEXT.vocab)
out_channels=16
dropout = 0.25
dim = 50
patience=3

model = CNN(vocab_size=vocab_size, emb_dim=dim, out_channels=out_channels,
            kernel_sizes=kernel_sizes, dropout=dropout)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
model.load_state_dict(torch.load('model.pt'))
model.to(device)

In [None]:
PAD_IND = TEXT.vocab.stoi['pad']
token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)
lig = LayerIntegratedGradients(model, model.embedding)

In [None]:
def forward_with_softmax(inp):
    logits = model(inp)
    return torch.softmax(logits, 0)[0][1]

def forward_with_sigmoid(input):
    return torch.sigmoid(model(input))


# accumalate couple samples in this array for visualization purposes
vis_data_records_ig = []

def interpret_sentence(model, sentence, min_len = 7, label = 0):
    model.eval()
    text = [tok for tok in TEXT.tokenize(sentence)]
    if len(text) < min_len:
        text += ['pad'] * (min_len - len(text))
    indexed = [TEXT.vocab.stoi[t] for t in text]

    model.zero_grad()

    input_indices = torch.tensor(indexed, device=device)
    input_indices = input_indices.unsqueeze(0)
    
    # input_indices dim: [sequence_length]
    seq_length = min_len

    # predict
    pred = forward_with_sigmoid(input_indices).item()
    pred_ind = round(pred)

    # generate reference indices for each sample
    reference_indices = token_reference.generate_reference(seq_length, device=device).unsqueeze(0)

    # compute attributions and approximation delta using layer integrated gradients
    attributions_ig, delta = lig.attribute(input_indices, reference_indices, \
                                           n_steps=5000, return_convergence_delta=True)

    print('pred: ', '(', '%.2f'%pred, ')')

In [None]:
interpret_sentence(model, 'Is this Statham, i dislike him')

In [None]:
interpret_sentence(model, 'Hello world, this is good movie !')

In [None]:
interpret_sentence(model, 'Scarry movie is a bad film')

In [None]:
best_model = LSTMTagger(INPUT_DIM, EMB_DIM, HID_DIM, OUTPUT_DIM, DROPOUT, BIDIRECTIONAL).to(device)
best_model.load_state_dict(torch.load('best-val-model.pt'))