In [1]:
import torch
import torch.nn as nn

from sklearn import metrics
from sklearn.model_selection import train_test_split

from dataset import get_dataloader

import pandas as pd
import numpy as np
from tensorflow.keras.utils import to_categorical
from tqdm import tqdm

unable to import 'smart_open.gcs', disabling that module


In [2]:
df = pd.read_parquet('data/QA_bin.parquet')
train, test = train_test_split(df, test_size=0.2, stratify=df['isAnswer'])
train, val = train_test_split(train, test_size=0.3, stratify=train['isAnswer'])

### По совету из https://arxiv.org/abs/1901.10444 пробуем BOREP

In [3]:
class BOREP(nn.Module):
    def __init__(self, 
                 init="uniform", 
                 projection=True, 
                 input_dim=300, 
                 output_dim=500, 
                 activation=None, 
                 pooling="max", 
                 gpu=False):
        super(BOREP, self).__init__()

        self.init = init
        self.projection = projection
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.activation = activation
        self.pooling = pooling
        self.gpu = gpu
        self.proj = self.get_projection()

        if gpu:
            self.cuda()

    def get_projection(self):
        proj = nn.Linear(self.input_dim, self.output_dim)
        if self.init == "orthogonal":
            nn.init.orthogonal_(proj.weight)
        elif self.init == "sparse":
            nn.init.sparse_(proj.weight, sparsity=0.1)
        elif self.init == "normal":
            nn.init.normal_(proj.weight, std=0.1)
        elif self.init == "uniform":
            nn.init.uniform_(proj.weight, a=-0.1, b=0.1)
        elif self.init == "kaiming":
            nn.init.kaiming_uniform_(proj.weight)
        elif self.init == "xavier":
            nn.init.xavier_uniform_(proj.weight)

        nn.init.constant_(proj.bias, 0)

        if self.gpu:
            proj = proj.cuda()
        return proj

    def borep(self, x):
        batch_sz, seq_len = x.size(1), x.size(0)
        out = torch.FloatTensor(seq_len, batch_sz, self.output_dim).zero_()
        for i in range(seq_len):
            if self.projection:
                emb = self.proj(x[i])
            else:
                emb = x[i]
            out[i] = emb
        return out

    def forward(self, batch):
        out = self.borep(batch)
        lengths = [batch.size(0)]*batch.size(1)
        out = self.pool(out, lengths)

        if self.activation is not None:
            out = self.activation(out)
            
        return out

    def encode(self, batch, params):
        return self.forward(batch, params).cpu().detach().numpy()
    
    def sum_pool(self, x, lengths):
        out = torch.FloatTensor(x.size(1), x.size(2)).zero_() # BxF
        for i in range(x.size(1)):
            out[i] = torch.sum(x[:lengths[i],i,:], 0)
        return out

    def mean_pool(self, x, lengths):
        out = torch.FloatTensor(x.size(1), x.size(2)).zero_() # BxF
        for i in range(x.size(1)):
            out[i] = torch.mean(x[:lengths[i],i,:], 0)
        return out

    def max_pool(self, x, lengths):
        out = torch.FloatTensor(x.size(1), x.size(2)).zero_() # BxF
        for i in range(x.size(1)):
            out[i,:] = torch.max(x[:lengths[i],i,:], 0)[0]
        return out

    def min_pool(self, x, lengths):
        out = torch.FloatTensor(x.size(1), x.size(2)).zero_() # BxF
        for i in range(x.size(1)):
            out[i] = torch.min(x[:lengths[i],i,:], 0)[0]
        return out

    def hier_pool(self, x, lengths, n=5):
        out = torch.FloatTensor(x.size(1), x.size(2)).zero_() # BxF
        if x.size(0) <= n: return self.mean_pool(x, lengths) # BxF
        for i in range(x.size(1)):
            sliders = []
            if lengths[i] <= n:
                out[i] = torch.mean(x[:lengths[i],i,:], 0)
                continue
            for j in range(lengths[i]-n):
                win = torch.mean(x[j:j+n,i,:], 0, keepdim=True) # 1xN
                sliders.append(win)
            sliders = torch.cat(sliders, 0)
            out[i] = torch.max(sliders, 0)[0]
        return out

    def pool(self, out, lengths):
        if self.pooling == "mean":
            out = self.mean_pool(out, lengths)
        elif self.pooling == "max":
            out = self.max_pool(out, lengths)
        elif self.pooling == "min":
            out = self.min_pool(out, lengths)
        elif self.pooling == "hier":
            out = self.hier_pool(out, lengths)
        elif self.pooling == "sum":
            out = self.sum_pool(out, lengths)
        else:
            raise ValueError("No valid pooling operation specified!")
        return out

In [4]:
class BOREP_clf(nn.Module):
    def __init__(self, 
                 init="uniform", 
                 projection=True, 
                 input_dim=300, 
                 output_dim=500, 
                 activation=None, 
                 pooling="max", 
                 gpu=False):
        super(BOREP_clf, self).__init__()

        self.borep = BOREP(init, projection, input_dim, output_dim, activation, pooling, gpu)
        self.cosine = nn.CosineSimilarity(dim=1)
        self.lin = nn.Linear(1, 2)
        self.pred = nn.LogSoftmax(dim=1)

        if gpu:
            self.cuda()
            
    def forward(self, q, a):
        q_proj = self.borep(q.permute(1, 0, 2))
        a_proj = self.borep(a.permute(1, 0, 2))
        dist = self.cosine(q_proj, a_proj)
        return self.pred(self.lin(dist.unsqueeze(1)))

In [5]:
test_loader = get_dataloader(test.loc[:, ['Question', 'Answer', 'isAnswer']], batch_size=512)
with torch.no_grad():
    for pool in ["mean", "max", "min", "hier", "sum"]:
        model = BOREP_clf(pooling=pool)
        pred = []
        y_test = []
        for batch in tqdm(test_loader):
            pred.append(model(batch[0].float(), batch[1].float()).cpu().numpy())
            y_test.append(batch[2].cpu().numpy())
        
        y_preds = np.argmax(np.vstack(pred), axis=1)
        y_test = np.concatenate(y_test)
        print(metrics.classification_report(y_test, y_preds))

100%|██████████| 516/516 [48:20<00:00,  5.62s/it]
  _warn_prf(average, modifier, msg_start, len(result))
  0%|          | 0/516 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.00      0.00      0.00    134517
           1       0.49      1.00      0.66    129597

    accuracy                           0.49    264114
   macro avg       0.25      0.50      0.33    264114
weighted avg       0.24      0.49      0.32    264114



100%|██████████| 516/516 [52:13<00:00,  6.07s/it] 
  0%|          | 0/516 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.00      0.00      0.00    134517
           1       0.49      1.00      0.66    129597

    accuracy                           0.49    264114
   macro avg       0.25      0.50      0.33    264114
weighted avg       0.24      0.49      0.32    264114



100%|██████████| 516/516 [51:20<00:00,  5.97s/it] 
  0%|          | 0/516 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.59      0.24      0.34    134517
           1       0.51      0.83      0.63    129597

    accuracy                           0.53    264114
   macro avg       0.55      0.53      0.48    264114
weighted avg       0.55      0.53      0.48    264114



100%|██████████| 516/516 [2:48:51<00:00, 19.63s/it]  
  0%|          | 0/516 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.00      0.00      0.00    134517
           1       0.49      1.00      0.66    129597

    accuracy                           0.49    264114
   macro avg       0.25      0.50      0.33    264114
weighted avg       0.24      0.49      0.32    264114



100%|██████████| 516/516 [1:10:47<00:00,  8.23s/it]


              precision    recall  f1-score   support

           0       0.00      0.00      0.00    134517
           1       0.49      1.00      0.66    129597

    accuracy                           0.49    264114
   macro avg       0.25      0.50      0.33    264114
weighted avg       0.24      0.49      0.32    264114



In [5]:
model = BOREP_clf(pooling="min")

In [6]:
for name, par in model.named_parameters():
    if name == 'borep.proj.weight' or name == 'borep.proj.bias':
        print('here')
        par.requires_grad = False
    
    print(name)
    print(par)

here
borep.proj.weight
Parameter containing:
tensor([[ 0.0145, -0.0860, -0.0364,  ..., -0.0554, -0.0715,  0.0906],
        [ 0.0777, -0.0003, -0.0775,  ...,  0.0226,  0.0027, -0.0217],
        [-0.0733, -0.0102, -0.0371,  ...,  0.0938, -0.0168,  0.0984],
        ...,
        [-0.0106, -0.0914,  0.0477,  ..., -0.0964, -0.0996, -0.0289],
        [-0.0421, -0.0544,  0.0918,  ...,  0.0646,  0.0651, -0.0648],
        [ 0.0766,  0.0916,  0.0704,  ..., -0.0355, -0.0783, -0.0604]])
here
borep.proj.bias
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [7]:
opt = torch.optim.Adam(model.lin.parameters(), lr=1e-3)
criterion = nn.NLLLoss()

In [8]:
train_iter = get_dataloader(train.loc[:, ['Question', 'Answer', 'isAnswer']], batch_size=512)
val_iter = get_dataloader(val.loc[:, ['Question', 'Answer', 'isAnswer']], batch_size=512)
for epoch in range(1, 5):
    running_loss = 0.0
    y_pred = []
    y_true = []
    model.train() 
    for batch in tqdm(train_iter): 
        y = batch[2]
        opt.zero_grad()
        preds = model(batch[0].float(), batch[1].float())
        loss = criterion(preds, y)
        loss.backward()
        opt.step()
        running_loss += loss.item()
        
        y_pred.append(preds.detach().numpy())
        y_true.append(y.detach().numpy())

    epoch_loss = running_loss / len(train_iter)
    
    y_pred = np.argmax(np.vstack(y_pred), axis=1)
    y_true = np.concatenate(y_true)
    print(f'Report after epoch {epoch} training:')
    print(metrics.classification_report(y_true, y_pred))
    
    val_loss = 0.0
    y_pred = []
    y_true = []
    model.eval()
    for batch in val_iter:
        y = batch[2]
        
        preds = model(batch[0].float(), batch[1].float())
        loss = criterion(preds, y)
        val_loss += loss.item()
        
        y_pred.append(preds.detach().numpy())
        y_true.append(y.detach().numpy())
        
    val_loss /= len(val_iter)
    print(f'Report after epoch {epoch} validation:')
    
    y_pred = np.argmax(np.vstack(y_pred), axis=1)
    y_true = np.concatenate(y_true)
    print(metrics.classification_report(y_true, y_pred))
    print('Epoch: {}, Training Loss: {:.3f}, Validation Loss: {:.3f}'.format(epoch, epoch_loss, val_loss))

100%|██████████| 1445/1445 [2:28:18<00:00,  6.16s/it] 


Report after epoch 1 training:


  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.51      1.00      0.67    376649
           1       0.00      0.00      0.00    362870

    accuracy                           0.51    739519
   macro avg       0.25      0.50      0.34    739519
weighted avg       0.26      0.51      0.34    739519

Report after epoch 1 validation:
              precision    recall  f1-score   support

           0       0.51      1.00      0.67    161421
           1       0.00      0.00      0.00    155516

    accuracy                           0.51    316937
   macro avg       0.25      0.50      0.34    316937
weighted avg       0.26      0.51      0.34    316937

Epoch: 1, Training Loss: 0.869, Validation Loss: 0.691


100%|██████████| 1445/1445 [2:25:21<00:00,  6.04s/it]  


Report after epoch 2 training:
              precision    recall  f1-score   support

           0       0.54      0.76      0.63    376649
           1       0.57      0.33      0.42    362870

    accuracy                           0.55    739519
   macro avg       0.56      0.55      0.53    739519
weighted avg       0.56      0.55      0.53    739519

Report after epoch 2 validation:


  0%|          | 0/1445 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.56      0.62      0.59    161421
           1       0.56      0.50      0.53    155516

    accuracy                           0.56    316937
   macro avg       0.56      0.56      0.56    316937
weighted avg       0.56      0.56      0.56    316937

Epoch: 2, Training Loss: 0.691, Validation Loss: 0.690


100%|██████████| 1445/1445 [2:33:36<00:00,  6.38s/it]  


Report after epoch 3 training:
              precision    recall  f1-score   support

           0       0.56      0.63      0.59    376649
           1       0.56      0.49      0.52    362870

    accuracy                           0.56    739519
   macro avg       0.56      0.56      0.56    739519
weighted avg       0.56      0.56      0.56    739519

Report after epoch 3 validation:


  0%|          | 0/1445 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.57      0.48      0.52    161421
           1       0.54      0.63      0.58    155516

    accuracy                           0.55    316937
   macro avg       0.55      0.55      0.55    316937
weighted avg       0.55      0.55      0.55    316937

Epoch: 3, Training Loss: 0.690, Validation Loss: 0.690


100%|██████████| 1445/1445 [2:21:45<00:00,  5.89s/it] 


Report after epoch 4 training:
              precision    recall  f1-score   support

           0       0.57      0.56      0.56    376649
           1       0.55      0.55      0.55    362870

    accuracy                           0.56    739519
   macro avg       0.56      0.56      0.56    739519
weighted avg       0.56      0.56      0.56    739519

Report after epoch 4 validation:
              precision    recall  f1-score   support

           0       0.57      0.57      0.57    161421
           1       0.55      0.55      0.55    155516

    accuracy                           0.56    316937
   macro avg       0.56      0.56      0.56    316937
weighted avg       0.56      0.56      0.56    316937

Epoch: 4, Training Loss: 0.689, Validation Loss: 0.689
