In [1]:
from get_datasets import get_data

# Project-local loader; presumably returns the train/val/test splits and the vocabulary.
(train_dataset, val_dataset, test_dataset, vocab) = get_data(
    'SentimentClassifier/'
)

In [2]:
from torch.utils.data import DataLoader
# Imported here (not only in the model cell below): sorted_collate uses them,
# so this cell must not depend on a later cell having been executed.
import numpy as np
import torch


def sorted_collate(batch):
    """Collate (token_ids, label) pairs, ordering samples by decreasing length.

    Args:
        batch: Sequence of (token_ids, label) pairs. ``token_ids`` must be
            convertible by ``torch.LongTensor``; ``label`` must be a tensor
            (labels are combined with ``torch.stack``).

    Returns:
        (x, y): ``x`` is a list of 1-D LongTensors, longest first; ``y`` is the
        stacked labels in the same order with an extra trailing dimension
        (shape (batch, 1) for scalar labels).
    """
    x, y = zip(*batch)
    # argsort is ascending; reverse it so the longest sequence comes first.
    order = np.argsort([len(xi) for xi in x])[::-1]
    x = [torch.LongTensor(x[i]) for i in order]
    y = torch.stack([y[i] for i in order])[:, None]
    return x, y

# Training batches are shuffled; validation/test keep dataset order and use a
# larger batch. All three pad/sort with sorted_collate.
train_loader = DataLoader(dataset=train_dataset, collate_fn=sorted_collate,
                          batch_size=512, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, collate_fn=sorted_collate,
                        batch_size=1024, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, collate_fn=sorted_collate,
                         batch_size=1024, shuffle=False)

In [3]:
import torch
import numpy as np
from torch import nn
from skiprnn import SkipGRU
import torch.functional as F

class IMDbRNN(nn.Module):
    """Sentiment classifier: token embedding -> SkipGRU -> linear head.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_size: Embedding dimension (SkipGRU input size).
        hidden_size: SkipGRU hidden size (classifier input size).
        num_classes: Number of output logits.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, num_classes):
        super().__init__()
        self.rnn = SkipGRU(embedding_size, hidden_size)
        self.cls = nn.Linear(hidden_size, num_classes)
        self.emb = nn.Embedding(vocab_size, embedding_size)

    def prepare_sequence(self, x):
        """Embed a list of 1-D token-id tensors and right-pad them.

        Args:
            x: List of 1-D LongTensors of token ids, one per sample.

        Returns:
            (padded, lengths): ``padded`` has shape
            (batch, max_len, embedding_size) and lives on the embedding's
            device; ``lengths`` is a CPU LongTensor of sequence lengths.
        """
        lengths = torch.LongTensor([len(xi) for xi in x])
        # Embed all tokens with one call, then split back into per-sample chunks.
        embedded = self.emb(torch.cat(x).to(self.emb.weight.device))
        chunks = embedded.split(lengths.tolist())
        # NOTE(review): padding value 5 looks arbitrary; presumably SkipGRU ignores
        # padded steps because lengths are passed to it -- confirm.
        padded = nn.utils.rnn.pad_sequence(chunks, batch_first=True, padding_value=5)
        return padded, lengths

    def forward(self, x):
        """Return classifier logits for a list of token-id sequences."""
        padded, lengths = self.prepare_sequence(x)
        # At this stage we train with a mask made of all ones. It is allocated
        # directly on the input's device, avoiding a CPU tensor plus copy per call.
        mask = torch.ones(*padded.shape[:2], 1, device=padded.device)
        _, o = self.rnn(padded, mask, lengths)
        return self.cls(o)

In [4]:
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): vocab_size is hard-coded to 2000; presumably it matches `vocab` -- confirm.
model = IMDbRNN(vocab_size=2000, embedding_size=32, hidden_size=16, num_classes=1)
model.to(device)

IMDbRNN(
  (rnn): SkipGRU(
    (layer): SkipGRULayer(
      (cell): SkipGRUCell(
        (inner_gru): WeakScriptModuleProxy()
      )
    )
  )
  (cls): Linear(in_features=16, out_features=1, bias=True)
  (emb): Embedding(2000, 32)
)

In [5]:
import torch
import torch.nn as nn
from catalyst.dl import SupervisedRunner
from catalyst.dl.callbacks import AUCCallback, F1ScoreCallback, CheckpointCallback, CriterionCallback, CriterionAggregatorCallback

In [None]:
num_epochs = 100
logdir = "./logs/exp3"

loaders = {"train": train_loader, "valid": val_loader}
criterion = {"bce": nn.BCEWithLogitsLoss()}

# A single parameter group covering the whole classifier.
optimizer = torch.optim.Adam([{'params': model.parameters(), 'lr': 3e-4}])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.9, patience=4
)

runner = SupervisedRunner()

In [None]:
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    callbacks=[
        # BCE is logged as "loss_bce" and then aggregated into "loss".
        CriterionCallback(prefix="loss_bce", criterion_key='bce', multiplier=1.),
        CriterionAggregatorCallback(prefix="loss", loss_keys=['loss_bce']),
        AUCCallback(num_classes=1),
        F1ScoreCallback(),
        CheckpointCallback(save_n_best=3),
    ],
    logdir=logdir,
    num_epochs=num_epochs,
    # Model selection uses validation AUC, maximised.
    main_metric='auc/_mean',
    minimize_metric=False,
    # Mixed precision is disabled; enable with fp16={"opt_level": "O1"}.
    verbose=False,
)

In [6]:
state = torch.load(
    "logs/exp3/checkpoints/best.pth",
    # Load onto CPU; load_state_dict copies the values into the model's existing parameters.
    map_location='cpu',
)
model.load_state_dict(state['model_state_dict'])

<All keys matched successfully>

In [7]:
# Equivalent to model.eval(): switch the model and its submodules to inference mode.
model.train(False)

IMDbRNN(
  (rnn): SkipGRU(
    (layer): SkipGRULayer(
      (cell): SkipGRUCell(
        (inner_gru): WeakScriptModuleProxy()
      )
    )
  )
  (cls): Linear(in_features=16, out_features=1, bias=True)
  (emb): Embedding(2000, 32)
)

In [8]:
import numpy as np
from tqdm import tqdm

predictions, labels = [], []
with torch.no_grad():
    # Under no_grad the outputs carry no autograd history, so .detach() is unnecessary.
    for x, y in tqdm(test_loader):
        predictions.append(model(x).cpu().numpy())
        labels.append(y.cpu().numpy())
predictions, labels = np.concatenate(predictions), np.concatenate(labels)

100%|██████████| 25/25 [00:05<00:00,  4.54it/s]


In [9]:
from sklearn.metrics import roc_auc_score

# The scores are raw logits; ROC AUC only depends on their ranking.
roc_auc_score(y_true=labels, y_score=predictions)

0.9095062592

In [19]:
from binarizers import HintonBinarizer

In [20]:
# Binarizer applied to the learnable mask logits in IMDbRNN_bin.forward (`hb(self.u)`).
# NOTE(review): HintonBinarizer comes from the project-local `binarizers` module;
# presumably a straight-through binarizer -- confirm.
hb = HintonBinarizer()

In [21]:
# Take one (shuffled) training batch; sorted_collate puts the longest review first.
x, y = next(iter(train_loader))

Run the experiment on a single sentence, `x[0]`.

In [22]:
x[0].size()

torch.Size([1177])

In [23]:
from torch.autograd import Variable


class IMDbRNN_bin(nn.Module):
    """Learn a binary token mask for one sentence on top of a trained IMDbRNN.

    ``rnn``, ``cls`` and ``emb`` are the same module objects as in
    ``base_model`` (shared, not copied). The only new tensor is ``u``, one mask
    logit per token, which is binarised with the notebook-global ``hb``.

    Args:
        sent_size: Number of tokens in the sentence the mask is fitted for.
        base_model: Trained model providing ``rnn``, ``cls`` and ``emb``.
            ``None`` (default) falls back to the notebook-global ``model``,
            which is what this class always used before.
    """

    def __init__(self, sent_size, base_model=None):
        super().__init__()
        if base_model is None:
            base_model = model
        self.rnn = base_model.rnn
        self.cls = base_model.cls
        self.emb = base_model.emb

        # nn.Parameter instead of the deprecated autograd.Variable: u is now a
        # registered parameter, so it follows .to(device), appears in
        # parameters() and is saved in state_dict() / checkpoints (a plain
        # tensor attribute was silently left out of checkpoints).
        # Shape (1, sent_size, 1); it is repeated over the batch in forward().
        self.u = nn.Parameter(torch.randn(1, sent_size, 1))

    def prepare_sequence(self, x):
        """Embed a list of 1-D token-id tensors and right-pad them.

        Returns ``(padded, lengths)``: ``padded`` has shape
        (batch, max_len, embedding_dim) and is padded with the value 5;
        ``lengths`` is a CPU LongTensor of the sequence lengths.
        """
        lengths = torch.LongTensor([len(xi) for xi in x])
        embedded = self.emb(torch.cat(x).to(self.emb.weight.device))
        chunks = embedded.split(lengths.tolist())
        padded = nn.utils.rnn.pad_sequence(chunks, batch_first=True, padding_value=5)
        return padded, lengths

    def forward(self, x):
        """Compare RNN outputs on the full sentence and on the masked sentence.

        Args:
            x: Batch of token-id sequences, iterated row by row (e.g. a
               (batch, sent_size) LongTensor produced by the DataLoader).

        Returns:
            dict with
              * ``'cls'``  -- classifier logits on the masked sentence,
              * ``'dist'`` -- squared difference between the full and masked
                RNN outputs, summed over dim 1,
              * ``'mean'`` -- per-sample mean of the binarised mask.
        """
        x, l = self.prepare_sequence(list(x))
        device = self.emb.weight.device

        # Reference pass with an all-ones mask. It does not depend on u, so no
        # autograd graph is built: the gradient w.r.t. u is unchanged, and
        # memory and compute are saved.
        with torch.no_grad():
            b = torch.ones(*x.shape[:2], 1, device=device)
            _, o = self.rnn(x, b, l)

        # Pass with the binarised mask u, repeated for every row of the batch.
        b_new = torch.repeat_interleave(hb(self.u), repeats=len(x), dim=0).to(device)
        _, o_new = self.rnn(x, b_new, l)

        # Distance between the network outputs on the full and the masked sentence.
        dist = ((o - o_new) ** 2).sum(1)
        # Average value of the binary mask.
        mean = b_new.squeeze(2).mean(1)
        # Final class prediction on the masked sentence.
        target = self.cls(o_new)

        return {'cls': target, 'dist': dist, 'mean': mean}

In [24]:
# One mask logit per token of the chosen sentence.
model_bin = IMDbRNN_bin(sent_size=len(x[0]))

Build a dataset from the sentence `x[0]` by repeating it several times.

In [25]:
from torch.utils.data import Dataset as BaseDataset

class Dataset(BaseDataset):
    """Toy dataset that repeats one sample of a batch ``length`` times.

    Every item is the same sentence ``x[idx]`` with its label ``y[idx]``,
    plus zero targets for the ``dist`` and ``mean`` losses.

    Args:
        x: Indexable collection of token-id tensors (e.g. a collated batch).
        y: Labels indexable in parallel with ``x``.
        idx: Position of the sentence in ``x`` / ``y`` to repeat.
        length: Number of (identical) items the dataset reports; defaults to
            8, the previously hard-coded value.
    """

    def __init__(self, x, y, idx, length=8):
        self.x = x
        self.y = y
        self.idx = idx
        self.length = length

    def __getitem__(self, idx):
        # The requested index is intentionally ignored: every item is the
        # same stored sample (self.idx).
        return {'features': self.x[self.idx], 'target_cls': self.y[self.idx],
                'target_dist': torch.tensor(0.0), 'target_mean': torch.tensor(0.0)}

    def __len__(self):
        return self.length


new_train_loader = DataLoader(Dataset(x, y, 0), shuffle=True, batch_size=2)
# Validation order is never shuffled (every item is identical here anyway).
new_val_loader = DataLoader(Dataset(x, y, 0), shuffle=False, batch_size=2)


In [26]:
num_epochs = 100  # tune as needed
logdir = "./logs/exp5"

loaders = {"train": new_train_loader, "valid": new_val_loader}
criterion = {"mse": nn.MSELoss()}

# Only the mask logits u are optimised; the shared pretrained weights are in no param group.
optimizer = torch.optim.Adam([{'params': model_bin.u, 'lr': 3e-3}])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.9, patience=4
)

runner = SupervisedRunner(input_target_key=None, output_key=None)

In [27]:
runner.train(
    model=model_bin,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    callbacks=[
        # MSE of 'dist' against 0: keep the masked output close to the full output.
        CriterionCallback(prefix="loss_dist", input_key='target_dist', output_key='dist',
                          criterion_key='mse', multiplier=0.5),
        # MSE of the mask mean against 0 (presumably: keep as few tokens as possible).
        CriterionCallback(prefix="loss_mean", input_key='target_mean', output_key='mean',
                          criterion_key='mse', multiplier=0.5),
        CriterionAggregatorCallback(prefix="loss", loss_keys=['loss_dist', 'loss_mean']),
        # AUCCallback(num_classes=1, input_key='target_cls', output_key='cls') is left out:
        # every sample is the same sentence with the same label, so AUC is presumably undefined.
        F1ScoreCallback(input_key='target_cls', output_key='cls'),
        CheckpointCallback(save_n_best=3),
    ],
    logdir=logdir,
    num_epochs=num_epochs,
    main_metric='loss',
    minimize_metric=True,
    # Mixed precision is disabled; enable with fp16={"opt_level": "O1"}.
    verbose=False,
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

[2019-11-30 08:06:19,608] 
1/100 * Epoch 1 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=4.0977 | _timers/batch_time=0.4887 | _timers/data_time=0.0005 | _timers/model_time=0.4881 | f1_score=0.9710 | loss=0.6808 | loss_dist=0.6346 | loss_mean=0.0462
1/100 * Epoch 1 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=5.4131 | _timers/batch_time=0.3700 | _timers/data_time=0.0005 | _timers/model_time=0.3695 | f1_score=0.9707 | loss=0.6976 | loss_dist=0.6551 | loss_mean=0.0425
[2019-11-30 08:06:26,372] 
2/100 * Epoch 2 (train): _base/lr=0.0300 | _base/momentum=0.9000 | _timers/_fps=4.1446 | _timers/batch_time=0.4830 | _timers/data_time=0.0007 | _timers/model_time=0.4822 | f1_score=0.9534 | loss=1.4232 | loss_dist=1.3844 | loss_mean=0.0389
2/100 * Epoch 2 (valid): _base/lr=0.0300 | _base/momentum=0.9000 | _timers/_fps=5.4132 | _timers/batch_time=0.3701 | _timers/data_time=0.0005 | _timers/model_time=0.3696 | f1_score=0.9434 | loss=0.4975 | loss_dist=0.4639 | los

[2019-11-30 08:08:14,535] 
18/100 * Epoch 18 (train): _base/lr=0.0243 | _base/momentum=0.9000 | _timers/_fps=4.2309 | _timers/batch_time=0.4731 | _timers/data_time=0.0006 | _timers/model_time=0.4725 | f1_score=0.9686 | loss=0.1090 | loss_dist=0.0892 | loss_mean=0.0198
18/100 * Epoch 18 (valid): _base/lr=0.0243 | _base/momentum=0.9000 | _timers/_fps=5.4747 | _timers/batch_time=0.3661 | _timers/data_time=0.0005 | _timers/model_time=0.3655 | f1_score=0.9686 | loss=0.1093 | loss_dist=0.0892 | loss_mean=0.0201
[2019-11-30 08:08:21,317] 
19/100 * Epoch 19 (train): _base/lr=0.0243 | _base/momentum=0.9000 | _timers/_fps=4.2296 | _timers/batch_time=0.4733 | _timers/data_time=0.0007 | _timers/model_time=0.4726 | f1_score=0.9686 | loss=0.1095 | loss_dist=0.0894 | loss_mean=0.0201
19/100 * Epoch 19 (valid): _base/lr=0.0243 | _base/momentum=0.9000 | _timers/_fps=5.5353 | _timers/batch_time=0.3618 | _timers/data_time=0.0006 | _timers/model_time=0.3612 | f1_score=0.9687 | loss=0.1104 | loss_dist=0.09

[2019-11-30 08:10:09,729] 
35/100 * Epoch 35 (train): _base/lr=0.0197 | _base/momentum=0.9000 | _timers/_fps=4.2128 | _timers/batch_time=0.4752 | _timers/data_time=0.0006 | _timers/model_time=0.4746 | f1_score=0.9439 | loss=0.0268 | loss_dist=0.0061 | loss_mean=0.0207
35/100 * Epoch 35 (valid): _base/lr=0.0197 | _base/momentum=0.9000 | _timers/_fps=5.5101 | _timers/batch_time=0.3635 | _timers/data_time=0.0005 | _timers/model_time=0.3630 | f1_score=0.9439 | loss=0.0263 | loss_dist=0.0061 | loss_mean=0.0203
[2019-11-30 08:10:16,506] 
36/100 * Epoch 36 (train): _base/lr=0.0197 | _base/momentum=0.9000 | _timers/_fps=4.2232 | _timers/batch_time=0.4739 | _timers/data_time=0.0006 | _timers/model_time=0.4733 | f1_score=0.9439 | loss=0.0263 | loss_dist=0.0061 | loss_mean=0.0202
36/100 * Epoch 36 (valid): _base/lr=0.0197 | _base/momentum=0.9000 | _timers/_fps=5.5063 | _timers/batch_time=0.3639 | _timers/data_time=0.0005 | _timers/model_time=0.3634 | f1_score=0.9439 | loss=0.0262 | loss_dist=0.00

[2019-11-30 08:12:04,947] 
52/100 * Epoch 52 (train): _base/lr=0.0177 | _base/momentum=0.9000 | _timers/_fps=4.2128 | _timers/batch_time=0.4752 | _timers/data_time=0.0007 | _timers/model_time=0.4744 | f1_score=0.9335 | loss=0.0201 | loss_dist=0.0030 | loss_mean=0.0171
52/100 * Epoch 52 (valid): _base/lr=0.0177 | _base/momentum=0.9000 | _timers/_fps=5.5106 | _timers/batch_time=0.3635 | _timers/data_time=0.0005 | _timers/model_time=0.3630 | f1_score=0.9335 | loss=0.0200 | loss_dist=0.0030 | loss_mean=0.0170
[2019-11-30 08:12:11,737] 
53/100 * Epoch 53 (train): _base/lr=0.0177 | _base/momentum=0.9000 | _timers/_fps=4.2166 | _timers/batch_time=0.4748 | _timers/data_time=0.0007 | _timers/model_time=0.4741 | f1_score=0.9335 | loss=0.0199 | loss_dist=0.0030 | loss_mean=0.0170
53/100 * Epoch 53 (valid): _base/lr=0.0177 | _base/momentum=0.9000 | _timers/_fps=5.5193 | _timers/batch_time=0.3628 | _timers/data_time=0.0005 | _timers/model_time=0.3623 | f1_score=0.9335 | loss=0.0198 | loss_dist=0.00

[2019-11-30 08:14:00,135] 
69/100 * Epoch 69 (train): _base/lr=0.0177 | _base/momentum=0.9000 | _timers/_fps=4.2316 | _timers/batch_time=0.4730 | _timers/data_time=0.0006 | _timers/model_time=0.4723 | f1_score=0.9335 | loss=0.0186 | loss_dist=0.0030 | loss_mean=0.0156
69/100 * Epoch 69 (valid): _base/lr=0.0177 | _base/momentum=0.9000 | _timers/_fps=5.5080 | _timers/batch_time=0.3636 | _timers/data_time=0.0005 | _timers/model_time=0.3631 | f1_score=0.9335 | loss=0.0186 | loss_dist=0.0030 | loss_mean=0.0156
[2019-11-30 08:14:06,908] 
70/100 * Epoch 70 (train): _base/lr=0.0177 | _base/momentum=0.9000 | _timers/_fps=4.2149 | _timers/batch_time=0.4749 | _timers/data_time=0.0007 | _timers/model_time=0.4742 | f1_score=0.9335 | loss=0.0186 | loss_dist=0.0030 | loss_mean=0.0156
70/100 * Epoch 70 (valid): _base/lr=0.0177 | _base/momentum=0.9000 | _timers/_fps=5.5351 | _timers/batch_time=0.3619 | _timers/data_time=0.0005 | _timers/model_time=0.3613 | f1_score=0.9335 | loss=0.0186 | loss_dist=0.00

[2019-11-30 08:15:55,184] 
86/100 * Epoch 86 (train): _base/lr=0.0159 | _base/momentum=0.9000 | _timers/_fps=4.2216 | _timers/batch_time=0.4742 | _timers/data_time=0.0007 | _timers/model_time=0.4734 | f1_score=0.9335 | loss=0.0174 | loss_dist=0.0030 | loss_mean=0.0144
86/100 * Epoch 86 (valid): _base/lr=0.0159 | _base/momentum=0.9000 | _timers/_fps=5.5182 | _timers/batch_time=0.3629 | _timers/data_time=0.0005 | _timers/model_time=0.3624 | f1_score=0.9335 | loss=0.0174 | loss_dist=0.0030 | loss_mean=0.0144
[2019-11-30 08:16:01,986] 
87/100 * Epoch 87 (train): _base/lr=0.0159 | _base/momentum=0.9000 | _timers/_fps=4.2215 | _timers/batch_time=0.4742 | _timers/data_time=0.0007 | _timers/model_time=0.4735 | f1_score=0.9335 | loss=0.0174 | loss_dist=0.0030 | loss_mean=0.0144
87/100 * Epoch 87 (valid): _base/lr=0.0159 | _base/momentum=0.9000 | _timers/_fps=5.4921 | _timers/batch_time=0.3648 | _timers/data_time=0.0005 | _timers/model_time=0.3642 | f1_score=0.9335 | loss=0.0174 | loss_dist=0.00