In [2]:
from get_datasets import get_data
train_dataset, val_dataset, test_dataset, vocab = get_data('SentimentClassifier/')

In [3]:
from torch.utils.data import DataLoader

def sorted_collate(batch):
    x, y = zip(*batch)
    order = np.argsort([len(xi) for xi in x])[::-1]
    x = [torch.LongTensor(x[i]) for i in order]
    y = torch.stack([y[i] for i in order])[:, None]
    return x,y

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, collate_fn=sorted_collate)
val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False, collate_fn=sorted_collate)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False, collate_fn=sorted_collate)

In [4]:
from torch import jit
import torch
from torch import nn

class SkipGRUCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        """Simple SkipGRU cell.
        Args:
            input_size (int): size of input space of GRU
            hidden_size (int): size of hidden space of GRU
        """
        super().__init__()
        self.inner_gru = nn.GRUCell(input_size, hidden_size)

    @jit.script_method
    def forward(self, input, hidden, mix):
        """
        Args:
            input (torch.FloatTensor): input vector for the current step
            hidden (torch.FloatTensor): hidden vector for the previous step
            mix (torch.FloatTensor): mixture vector which defines proportion of mix between previous step and current.
        Returns:
            hidden (torch.FloatTensor): hidden vetor for the current step
        """
        # type: (Tensor, Tensor, Tensor) -> Tensor

        output = self.inner_gru(input, hidden)
        output = output * mix + hidden * (1 - mix)

        return output

from typing import List, Tuple

class SkipGRULayer(jit.ScriptModule):
    __constants__ = ['hidden_size']
    def __init__(self, input_size, hidden_size):
        """
        Layer consisted of SkipGRU cells.
        Args:
            input_size (int): size of input space of GRU
            hidden_size (int): size of hidden space of GRU
        """
        super().__init__()
        self.cell = SkipGRUCell(input_size, hidden_size)
        self.hidden_size = hidden_size

    @jit.script_method
    def forward(self, input, mix):
        """
        Forward pass on SkipGRU layer.
        Args:
            input (torch.FloatTensor): inputs for the SkipGRU layer of form [BxTxN]
            mix (torch.FloatTensor): mixture coefficients for thehidden steps of form [BxTx1]
        Returns:
            h (torch.FloatTensor): history of outputs by steps. Size [BxTxH]
            o (torch.FloatTensor): last hidden step of size [BxH]
        """
        # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
        hidden = torch.zeros(input.shape[0], self.hidden_size,
                             device=self.cell.inner_gru.weight_ih.device)
        inputs = input.unbind(1)
        mixtures = mix.unbind(1)
        outputs = jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            hidden = self.cell(inputs[i], hidden, mixtures[i])
            outputs += [hidden]
        return torch.stack(outputs, 1), hidden


class SkipGRU(nn.Module):
    def __init__(self, input_size=None, hidden_size=None, layer=None, do_copy_weights=False):
        """
        SkipRNN layer with trivial binarizer (All betas are directly parametrized.)
        Args:
            input_size (int): size of input space for SkipGRU
            hidden_size (int): size of hidden space for SkipGRU
            layer (nn.GRU): original GRU layers to copy sizes and make a link to weights
        """
        super().__init__()
        if layer is not None:
            self.layer = SkipGRULayer(layer.input_size, layer.hidden_size)
            self.impute_weights(layer, do_copy_weights)
        elif (input_size is not None) and (hidden_size is not None):
            self.layer = SkipGRULayer(input_size, hidden_size)
        else:
            raise Exception('Either layer or input_size & hidden_size are required')

    def impute_weights(self, donor_layer, copy=False):
        """
        Get weights from instance of nn.GRU.
        Args:
            donor_layer (nn.GRU): pretrained layer to get weights from.
            copy (bool): if True, will make copy of weights instead of linkage.
        """
        assert isinstance(donor_layer, nn.GRU), 'Wrong type of donor layer. GRU required!'

        if copy:
            self.layer.cell.inner_gru.weight_ih.data = donor_layer.weight_ih_l0.clone()
            self.layer.cell.inner_gru.weight_hh.data = donor_layer.weight_hh_l0.clone()
            self.layer.cell.inner_gru.bias_ih.data = donor_layer.bias_ih_l0.clone()
            self.layer.cell.inner_gru.bias_hh.data = donor_layer.bias_hh_l0.clone()
        else:
            self.layer.cell.inner_gru.weight_ih = donor_layer.weight_ih_l0
            self.layer.cell.inner_gru.weight_hh = donor_layer.weight_hh_l0
            self.layer.cell.inner_gru.bias_ih = donor_layer.bias_ih_l0
            self.layer.cell.inner_gru.bias_hh = donor_layer.bias_hh_l0

    def forward(self, x, b, l=None):
        """
        Make forward SkipGRU pass.
        Args:
            x (torch.FloatTensor): input of size [BxTxN]. NB! sequences should be padded from the end, not from the start position.
            u (torch.FloatTensor): mixture coefficients of size [BxTx1]
            l (torch.LongTensor): lengths of the padded sequences. If not provided, output will contain last elements of sequences.
        Returns:
            h (torch.FloatTensor, nn.utils.rnn.PackedSequence): history of hidden states of size [BxTxH]
            o (torch.FloatTensor): last hidden state of size [BxH]
        """
        h, o = self.layer(x, b)
        if l is not None:
            o = h[torch.arange(h.shape[0]), l-1]
        return h, o

In [5]:
import torch
import numpy as np
from torch import nn
#from skiprnn import SkipGRU
import torch.functional as F

class IMDbRNN(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size, num_classes):
        super().__init__()
#         self.rnn = nn.GRU(embedding_size, hidden_size)
        self.rnn = SkipGRU(embedding_size, hidden_size)
        self.cls = nn.Linear(hidden_size, num_classes)
        self.emb = nn.Embedding(vocab_size, embedding_size)
    
    def prepare_sequence(self, x):
        l = torch.LongTensor([len(xi) for xi in x])
        x = torch.split_with_sizes(self.emb(torch.cat(x).to(self.emb.weight.device)), l.unbind(0))
        x = nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=5)
        return x, l
    
    def forward(self, x):
        x, l = self.prepare_sequence(x)
        #на данном этапе учим с маской из одних единиц
        b = torch.ones(x.shape[:2]).unsqueeze(2).to(self.emb.weight.device)
        h,o = self.rnn(x,b,l)
        return self.cls(o)

In [6]:
device = torch.device('cpu')
model = IMDbRNN(2000, 32, 16, 1)
model.to(device)

IMDbRNN(
  (rnn): SkipGRU(
    (layer): SkipGRULayer(
      original_name=SkipGRULayer
      (cell): SkipGRUCell(
        original_name=SkipGRUCell
        (inner_gru): ScriptModule(original_name=GRUCell)
      )
    )
  )
  (cls): Linear(in_features=16, out_features=1, bias=True)
  (emb): Embedding(2000, 32)
)

In [7]:
import torch
import torch.nn as nn
from catalyst.dl import SupervisedRunner
from catalyst.dl.callbacks import AUCCallback, F1ScoreCallback, CheckpointCallback, CriterionCallback, CriterionAggregatorCallback

  data = yaml.load(f.read()) or {}


In [8]:
num_epochs = 100
logdir = "./logs/exp3"
loaders = {
    "train": train_loader,
    "valid": val_loader
}

criterion = {
   "bce": nn.BCEWithLogitsLoss()
}

optimizer = torch.optim.Adam([
    {'params': model.parameters(), 'lr': 3e-4}])


scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=4)

runner = SupervisedRunner()

In [8]:
runner.train(
    model=model,
    criterion=criterion,
    scheduler=scheduler,
    callbacks=[   
               CriterionCallback(prefix="loss_bce",
                     criterion_key='bce', multiplier=1.),
               CriterionAggregatorCallback(prefix="loss",loss_keys=['loss_bce']),
               AUCCallback(num_classes=1), 
                F1ScoreCallback(),
               CheckpointCallback(save_n_best=3)
                ],

    optimizer=optimizer,
    main_metric='auc/_mean',
    minimize_metric=False,
    loaders=loaders,
    logdir=logdir,
    num_epochs=num_epochs,
#     fp16={"opt_level": "O1"},
    verbose=False
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

[2019-12-06 13:52:43,204] 
1/100 * Epoch 1 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=738.5268 | _timers/batch_time=0.7703 | _timers/data_time=0.0107 | _timers/model_time=0.7595 | auc/_mean=0.5220 | auc/class_0=0.5220 | f1_score=0.5011 | loss=0.6922 | loss_bce=0.6922
1/100 * Epoch 1 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=1178.0794 | _timers/batch_time=0.9482 | _timers/data_time=0.0192 | _timers/model_time=0.9290 | auc/_mean=0.5301 | auc/class_0=0.5301 | f1_score=0.3597 | loss=0.6911 | loss_bce=0.6911


KeyboardInterrupt: 

In [9]:
state = torch.load("best_skiprnn.pth",map_location='cpu')

model.load_state_dict(state['model_state_dict'])

<All keys matched successfully>

In [10]:
model.eval()

IMDbRNN(
  (rnn): SkipGRU(
    (layer): SkipGRULayer(
      original_name=SkipGRULayer
      (cell): SkipGRUCell(
        original_name=SkipGRUCell
        (inner_gru): ScriptModule(original_name=GRUCell)
      )
    )
  )
  (cls): Linear(in_features=16, out_features=1, bias=True)
  (emb): Embedding(2000, 32)
)

In [13]:
import numpy as np
from tqdm import tqdm
from sklearn.metrics import roc_auc_score



In [14]:
predictions, labels = [], []
with torch.no_grad():
    for x,y in tqdm(test_loader):
        predictions.append(model(x).detach().cpu().numpy())
        labels.append(y.detach().cpu().numpy())
predictions = np.concatenate(predictions)
labels = np.concatenate(labels)

roc_auc_score(labels, predictions)

100%|██████████| 25/25 [00:21<00:00,  1.14it/s]


0.7049218304

## HintonBinarizer

In [15]:
from binarizers import HintonBinarizer

In [16]:
hb = HintonBinarizer()

In [17]:
for x,y in train_loader:
    break

In [18]:
x[0].shape

torch.Size([1306])

In [19]:
from torch.autograd import Variable


class IMDbRNN_bin(nn.Module):
    def __init__(self, sent_size):
        super().__init__()
#         self.rnn = nn.GRU(embedding_size, hidden_size)
        self.rnn = model.rnn
        self.cls = model.cls
        self.emb = model.emb
        
        self.u = Variable(torch.randn(1,sent_size,1), requires_grad=True)
    
    def prepare_sequence(self, x):
        l = torch.LongTensor([len(xi) for xi in x])
        x = torch.split_with_sizes(self.emb(torch.cat(x).to(self.emb.weight.device)), l.unbind(0))
        x = nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=5)
        return x, l
    
    def forward(self, x):
#         print(x.shape,x[0].shape)
        
        x = [i for i in x]
        x, l = self.prepare_sequence(x)
        
        #считаем с единичной маской
        b = torch.ones(x.shape[:2]).unsqueeze(2).to(self.emb.weight.device)
        h,o = self.rnn(x,b,l)
        
        #считаем с бинаризованной маской u
        b_new = torch.repeat_interleave(hb(self.u), repeats=len(x), dim=0).to(self.emb.weight.device)
        h_new,o_new = self.rnn(x,b_new,l)
        
        #расстояние между выходами сети на полном и прореженом предложении
        dist = ((o-o_new)**2).sum(1)
        
        #среднее значение бинарной маски
        mean = b_new.squeeze(2).mean(1)
        
        #финальное предсказание класса на прореженом сообщении
        target = self.cls(o_new)
        
        return {'cls': target, 'dist': dist, 'mean': mean}

In [20]:
model_bin = IMDbRNN_bin(x[0].shape[0])

Делаем датасет из предложения x[0] размножив его несколько раз

In [21]:
from torch.utils.data import Dataset as BaseDataset

class Dataset(BaseDataset):

    
    def __init__(
            self, x, y, idx
    ):
        
        self.x = x
        self.y = y
        self.idx = idx
    
    def __getitem__(self, idx):

        return {'features': self.x[self.idx], 'target_cls':self.y[self.idx],
                'target_dist': torch.tensor(0).float(), 'target_mean': torch.tensor(0).float()}
        
    def __len__(self):
        return 8


new_train_loader = DataLoader(Dataset(x,y,0), shuffle=True, batch_size=2)
new_val_loader = DataLoader(Dataset(x,y,0), shuffle=True, batch_size=2)


In [22]:
num_epochs = 100  # change me
logdir = "./logs/Hinton"
loaders = {
    "train": new_train_loader,
    "valid": new_val_loader
}

criterion = {
   "mse": nn.MSELoss()
}

optimizer = torch.optim.Adam([
    {'params': model_bin.u, 'lr': 3e-3}])


scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=4)

runner = SupervisedRunner(input_target_key=None,output_key=None)

In [23]:
runner.train(
    model=model_bin,
    criterion=criterion,
    scheduler=scheduler,
    callbacks=[   
               CriterionCallback(prefix="loss_dist", input_key='target_dist', output_key='dist',
                     criterion_key='mse', multiplier=0.5),
               CriterionCallback(prefix="loss_mean", input_key='target_mean', output_key='mean',
                     criterion_key='mse', multiplier=0.5),
        
               CriterionAggregatorCallback(prefix="loss",loss_keys=['loss_dist','loss_mean']),
#                AUCCallback(num_classes=1, input_key='target_cls', output_key='cls'), 
                F1ScoreCallback(input_key='target_cls', output_key='cls'),
               CheckpointCallback(save_n_best=3)
                ],

    optimizer=optimizer,
    main_metric='loss',
    minimize_metric=True,
    loaders=loaders,
    logdir=logdir,
    num_epochs=num_epochs,
#     fp16={"opt_level": "O1"},
    verbose=False
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

[2019-12-06 17:42:39,393] 
1/100 * Epoch 1 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=4.5196 | _timers/batch_time=0.4474 | _timers/data_time=0.0003 | _timers/model_time=0.4470 | f1_score=7.940e-07 | loss=0.2718 | loss_dist=0.2263 | loss_mean=0.0455
1/100 * Epoch 1 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=9.4056 | _timers/batch_time=0.2171 | _timers/data_time=0.0004 | _timers/model_time=0.2167 | f1_score=8.099e-07 | loss=0.2710 | loss_dist=0.2257 | loss_mean=0.0453
[2019-12-06 17:42:44,066] 
2/100 * Epoch 2 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=5.3819 | _timers/batch_time=0.3745 | _timers/data_time=0.0004 | _timers/model_time=0.3741 | f1_score=8.095e-07 | loss=0.2708 | loss_dist=0.2257 | loss_mean=0.0450
2/100 * Epoch 2 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=10.2587 | _timers/batch_time=0.1992 | _timers/data_time=0.0003 | _timers/model_time=0.1989 | f1_score=8.131e-07 | loss=0.2705 | loss_dist

[2019-12-06 17:43:54,046] 
17/100 * Epoch 17 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=5.6214 | _timers/batch_time=0.3564 | _timers/data_time=0.0003 | _timers/model_time=0.3560 | f1_score=7.030e-07 | loss=0.0643 | loss_dist=0.0283 | loss_mean=0.0361
17/100 * Epoch 17 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=10.4071 | _timers/batch_time=0.1962 | _timers/data_time=0.0002 | _timers/model_time=0.1960 | f1_score=7.785e-07 | loss=0.0597 | loss_dist=0.0238 | loss_mean=0.0359
[2019-12-06 17:43:58,409] 
18/100 * Epoch 18 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=5.7331 | _timers/batch_time=0.3503 | _timers/data_time=0.0004 | _timers/model_time=0.3498 | f1_score=7.897e-07 | loss=0.0592 | loss_dist=0.0235 | loss_mean=0.0357
18/100 * Epoch 18 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=10.3656 | _timers/batch_time=0.1969 | _timers/data_time=0.0004 | _timers/model_time=0.1965 | f1_score=8.861e-07 | loss=0.0579 | 

[2019-12-06 17:45:11,015] 
33/100 * Epoch 33 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=5.5506 | _timers/batch_time=0.3622 | _timers/data_time=0.0004 | _timers/model_time=0.3618 | f1_score=1.015e-06 | loss=0.0307 | loss_dist=0.0030 | loss_mean=0.0277
33/100 * Epoch 33 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=9.8237 | _timers/batch_time=0.2080 | _timers/data_time=0.0004 | _timers/model_time=0.2076 | f1_score=1.016e-06 | loss=0.0299 | loss_dist=0.0030 | loss_mean=0.0269
[2019-12-06 17:45:15,801] 
34/100 * Epoch 34 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=5.3750 | _timers/batch_time=0.3741 | _timers/data_time=0.0004 | _timers/model_time=0.3737 | f1_score=1.016e-06 | loss=0.0298 | loss_dist=0.0030 | loss_mean=0.0269
34/100 * Epoch 34 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=9.1978 | _timers/batch_time=0.2218 | _timers/data_time=0.0005 | _timers/model_time=0.2212 | f1_score=1.016e-06 | loss=0.0297 | lo

[2019-12-06 17:46:25,088] 
49/100 * Epoch 49 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=5.5588 | _timers/batch_time=0.3607 | _timers/data_time=0.0004 | _timers/model_time=0.3602 | f1_score=8.604e-07 | loss=0.0183 | loss_dist=0.0003 | loss_mean=0.0180
49/100 * Epoch 49 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=9.6411 | _timers/batch_time=0.2112 | _timers/data_time=0.0005 | _timers/model_time=0.2107 | f1_score=8.702e-07 | loss=0.0182 | loss_dist=0.0003 | loss_mean=0.0179
[2019-12-06 17:46:29,627] 
50/100 * Epoch 50 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=5.4898 | _timers/batch_time=0.3654 | _timers/data_time=0.0003 | _timers/model_time=0.3650 | f1_score=8.745e-07 | loss=0.0181 | loss_dist=0.0003 | loss_mean=0.0178
50/100 * Epoch 50 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=10.1425 | _timers/batch_time=0.2008 | _timers/data_time=0.0003 | _timers/model_time=0.2005 | f1_score=8.760e-07 | loss=0.0179 | l

[2019-12-06 17:47:40,243] 
65/100 * Epoch 65 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=5.1639 | _timers/batch_time=0.3884 | _timers/data_time=0.0005 | _timers/model_time=0.3879 | f1_score=1.036e-06 | loss=0.0121 | loss_dist=0.0005 | loss_mean=0.0116
65/100 * Epoch 65 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=9.8428 | _timers/batch_time=0.2077 | _timers/data_time=0.0004 | _timers/model_time=0.2073 | f1_score=1.031e-06 | loss=0.0118 | loss_dist=0.0005 | loss_mean=0.0113
[2019-12-06 17:47:44,964] 
66/100 * Epoch 66 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=5.4891 | _timers/batch_time=0.3665 | _timers/data_time=0.0003 | _timers/model_time=0.3661 | f1_score=1.030e-06 | loss=0.0117 | loss_dist=0.0005 | loss_mean=0.0112
66/100 * Epoch 66 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=9.8360 | _timers/batch_time=0.2084 | _timers/data_time=0.0004 | _timers/model_time=0.2080 | f1_score=1.030e-06 | loss=0.0117 | lo

[2019-12-06 17:48:57,603] 
81/100 * Epoch 81 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=5.5849 | _timers/batch_time=0.3593 | _timers/data_time=0.0004 | _timers/model_time=0.3588 | f1_score=1.089e-06 | loss=0.0091 | loss_dist=0.0011 | loss_mean=0.0080
81/100 * Epoch 81 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=10.2098 | _timers/batch_time=0.2010 | _timers/data_time=0.0003 | _timers/model_time=0.2007 | f1_score=1.020e-06 | loss=0.0087 | loss_dist=0.0006 | loss_mean=0.0081
[2019-12-06 17:49:02,288] 
82/100 * Epoch 82 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=5.4742 | _timers/batch_time=0.3679 | _timers/data_time=0.0004 | _timers/model_time=0.3675 | f1_score=9.896e-07 | loss=0.0086 | loss_dist=0.0005 | loss_mean=0.0080
82/100 * Epoch 82 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=9.7907 | _timers/batch_time=0.2093 | _timers/data_time=0.0003 | _timers/model_time=0.2089 | f1_score=9.588e-07 | loss=0.0083 | l

[2019-12-06 17:50:11,637] 
97/100 * Epoch 97 (train): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=5.6063 | _timers/batch_time=0.3583 | _timers/data_time=0.0004 | _timers/model_time=0.3578 | f1_score=1.168e-06 | loss=0.0092 | loss_dist=0.0027 | loss_mean=0.0065
97/100 * Epoch 97 (valid): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=10.2885 | _timers/batch_time=0.1979 | _timers/data_time=0.0003 | _timers/model_time=0.1976 | f1_score=1.168e-06 | loss=0.0091 | loss_dist=0.0027 | loss_mean=0.0064
[2019-12-06 17:50:16,135] 
98/100 * Epoch 98 (train): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=5.6234 | _timers/batch_time=0.3572 | _timers/data_time=0.0003 | _timers/model_time=0.3568 | f1_score=1.168e-06 | loss=0.0090 | loss_dist=0.0027 | loss_mean=0.0063
98/100 * Epoch 98 (valid): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=10.1119 | _timers/batch_time=0.2020 | _timers/data_time=0.0003 | _timers/model_time=0.2016 | f1_score=1.169e-06 | loss=0.0090 | 

In [48]:
model_bin.eval()
predictions, labels = [], []
with torch.no_grad():
    for x,y in tqdm(test_loader):
        try:
        
            predictions.append(model_bin(x)['cls'].detach().cpu().numpy())
            labels.append(y.detach().cpu().numpy())
        except: pass
predictions = np.concatenate(predictions)
labels = np.concatenate(labels)

roc_auc_score(labels, predictions)

100%|██████████| 25/25 [00:43<00:00,  1.75s/it]


0.5299492057635427

## ConcreteBinarizer

In [51]:
from binarizers import ConcreteBinarizer

In [52]:
cb = ConcreteBinarizer()

In [53]:
from torch.autograd import Variable


class IMDbRNN_bin(nn.Module):
    def __init__(self, sent_size):
        super().__init__()
#         self.rnn = nn.GRU(embedding_size, hidden_size)
        self.rnn = model.rnn
        self.cls = model.cls
        self.emb = model.emb
        
        self.u = Variable(torch.randn(1,sent_size,1), requires_grad=True)
        
    
    def prepare_sequence(self, x):
        l = torch.LongTensor([len(xi) for xi in x])
        x = torch.split_with_sizes(self.emb(torch.cat(x).to(self.emb.weight.device)), l.unbind(0))
        x = nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=5)
        return x, l
    
    def forward(self, x):
#         print(x.shape,x[0].shape)
        
        x = [i for i in x]
        x, l = self.prepare_sequence(x)
        
        #считаем с единичной маской
        b = torch.ones(x.shape[:2]).unsqueeze(2).to(self.emb.weight.device)
        h,o = self.rnn(x,b,l)
        
        #считаем с бинаризованной маской u
        b_new = torch.repeat_interleave(cb(self.u), repeats=len(x), dim=0).to(self.emb.weight.device)
        h_new,o_new = self.rnn(x,b_new,l)
        
        #расстояние между выходами сети на полном и прореженом предложении
        dist = ((o-o_new)**2).sum(1)
        
        #среднее значение бинарной маски
        mean = b_new.squeeze(2).mean(1)
        
        #финальное предсказание класса на прореженом сообщении
        target = self.cls(o_new)
        
        return {'cls': target, 'dist': dist, 'mean': mean}

In [54]:
model_bin = IMDbRNN_bin(x[0].shape[0])

In [55]:
from torch.utils.data import Dataset as BaseDataset

class Dataset(BaseDataset):

    
    def __init__(
            self, x, y, idx
    ):
        
        self.x = x
        self.y = y
        self.idx = idx
    
    def __getitem__(self, idx):

        return {'features': self.x[self.idx], 'target_cls':self.y[self.idx],
                'target_dist': torch.tensor(0).float(), 'target_mean': torch.tensor(0).float()}
        
    def __len__(self):
        return 8


new_train_loader = DataLoader(Dataset(x,y,0), shuffle=True, batch_size=2)
new_val_loader = DataLoader(Dataset(x,y,0), shuffle=True, batch_size=2)



In [56]:
num_epochs = 100  # change me
logdir = "./logs/Concrete"
loaders = {
    "train": new_train_loader,
    "valid": new_val_loader
}

criterion = {
   "mse": nn.MSELoss()
}

optimizer = torch.optim.Adam([
    {'params': model_bin.u, 'lr': 3e-3}])


scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=4)

runner = SupervisedRunner(input_target_key=None,output_key=None)

In [57]:
runner.train(
    model=model_bin,
    criterion=criterion,
    scheduler=scheduler,
    callbacks=[   
               CriterionCallback(prefix="loss_dist", input_key='target_dist', output_key='dist',
                     criterion_key='mse', multiplier=0.5),
               CriterionCallback(prefix="loss_mean", input_key='target_mean', output_key='mean',
                     criterion_key='mse', multiplier=0.5),
        
               CriterionAggregatorCallback(prefix="loss",loss_keys=['loss_dist','loss_mean']),
#                AUCCallback(num_classes=1, input_key='target_cls', output_key='cls'), 
                F1ScoreCallback(input_key='target_cls', output_key='cls'),
               CheckpointCallback(save_n_best=3)
                ],

    optimizer=optimizer,
    main_metric='loss',
    minimize_metric=True,
    loaders=loaders,
    logdir=logdir,
    num_epochs=num_epochs,
#     fp16={"opt_level": "O1"},
    verbose=False
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

[2019-12-06 18:14:22,067] 
1/100 * Epoch 1 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=6.3213 | _timers/batch_time=0.3169 | _timers/data_time=0.0007 | _timers/model_time=0.3160 | f1_score=8.577e-07 | loss=3.2292 | loss_dist=3.1015 | loss_mean=0.1277
1/100 * Epoch 1 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=11.5877 | _timers/batch_time=0.1756 | _timers/data_time=0.0003 | _timers/model_time=0.1752 | f1_score=8.689e-07 | loss=3.1823 | loss_dist=3.0550 | loss_mean=0.1273
[2019-12-06 18:14:25,970] 
2/100 * Epoch 2 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=6.3997 | _timers/batch_time=0.3137 | _timers/data_time=0.0004 | _timers/model_time=0.3132 | f1_score=1.236e-06 | loss=0.6186 | loss_dist=0.4880 | loss_mean=0.1306
2/100 * Epoch 2 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=11.7097 | _timers/batch_time=0.1746 | _timers/data_time=0.0003 | _timers/model_time=0.1742 | f1_score=1.251e-06 | loss=0.6099 | loss_dis

[2019-12-06 18:15:27,177] 
17/100 * Epoch 17 (train): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=6.1392 | _timers/batch_time=0.3276 | _timers/data_time=0.0004 | _timers/model_time=0.3272 | f1_score=1.250e-06 | loss=0.3591 | loss_dist=0.2376 | loss_mean=0.1216
17/100 * Epoch 17 (valid): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=10.2200 | _timers/batch_time=0.2003 | _timers/data_time=0.0004 | _timers/model_time=0.1999 | f1_score=1.253e-06 | loss=0.3565 | loss_dist=0.2350 | loss_mean=0.1215
[2019-12-06 18:15:31,570] 
18/100 * Epoch 18 (train): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=5.7391 | _timers/batch_time=0.3519 | _timers/data_time=0.0003 | _timers/model_time=0.3515 | f1_score=1.081e-06 | loss=0.3549 | loss_dist=0.2339 | loss_mean=0.1209
18/100 * Epoch 18 (valid): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=10.6587 | _timers/batch_time=0.1932 | _timers/data_time=0.0004 | _timers/model_time=0.1928 | f1_score=1.082e-06 | loss=0.3529 | 

[2019-12-06 18:16:36,804] 
33/100 * Epoch 33 (train): _base/lr=0.0020 | _base/momentum=0.9000 | _timers/_fps=5.8860 | _timers/batch_time=0.3417 | _timers/data_time=0.0004 | _timers/model_time=0.3413 | f1_score=1.610e-06 | loss=0.2352 | loss_dist=0.1171 | loss_mean=0.1181
33/100 * Epoch 33 (valid): _base/lr=0.0020 | _base/momentum=0.9000 | _timers/_fps=10.5727 | _timers/batch_time=0.1930 | _timers/data_time=0.0002 | _timers/model_time=0.1928 | f1_score=1.617e-06 | loss=0.2343 | loss_dist=0.1164 | loss_mean=0.1179
[2019-12-06 18:16:40,962] 
34/100 * Epoch 34 (train): _base/lr=0.0020 | _base/momentum=0.9000 | _timers/_fps=6.0177 | _timers/batch_time=0.3330 | _timers/data_time=0.0003 | _timers/model_time=0.3327 | f1_score=1.056e-06 | loss=0.4652 | loss_dist=0.3510 | loss_mean=0.1141
34/100 * Epoch 34 (valid): _base/lr=0.0020 | _base/momentum=0.9000 | _timers/_fps=10.4361 | _timers/batch_time=0.1943 | _timers/data_time=0.0002 | _timers/model_time=0.1941 | f1_score=1.061e-06 | loss=0.4639 | 

[2019-12-06 18:17:40,722] 
49/100 * Epoch 49 (train): _base/lr=0.0016 | _base/momentum=0.9000 | _timers/_fps=6.5369 | _timers/batch_time=0.3073 | _timers/data_time=0.0004 | _timers/model_time=0.3069 | f1_score=1.292e-06 | loss=0.2648 | loss_dist=0.1536 | loss_mean=0.1111
49/100 * Epoch 49 (valid): _base/lr=0.0016 | _base/momentum=0.9000 | _timers/_fps=11.8972 | _timers/batch_time=0.1707 | _timers/data_time=0.0005 | _timers/model_time=0.1702 | f1_score=1.295e-06 | loss=0.2628 | loss_dist=0.1518 | loss_mean=0.1110
[2019-12-06 18:17:44,494] 
50/100 * Epoch 50 (train): _base/lr=0.0016 | _base/momentum=0.9000 | _timers/_fps=6.6405 | _timers/batch_time=0.3019 | _timers/data_time=0.0003 | _timers/model_time=0.3015 | f1_score=1.274e-06 | loss=0.2102 | loss_dist=0.0979 | loss_mean=0.1122
50/100 * Epoch 50 (valid): _base/lr=0.0016 | _base/momentum=0.9000 | _timers/_fps=11.8715 | _timers/batch_time=0.1709 | _timers/data_time=0.0004 | _timers/model_time=0.1705 | f1_score=1.278e-06 | loss=0.2096 | 

[2019-12-06 18:18:43,996] 
65/100 * Epoch 65 (train): _base/lr=0.0012 | _base/momentum=0.9000 | _timers/_fps=6.5917 | _timers/batch_time=0.3049 | _timers/data_time=0.0004 | _timers/model_time=0.3045 | f1_score=1.356e-06 | loss=0.4229 | loss_dist=0.3153 | loss_mean=0.1076
65/100 * Epoch 65 (valid): _base/lr=0.0012 | _base/momentum=0.9000 | _timers/_fps=11.7263 | _timers/batch_time=0.1728 | _timers/data_time=0.0003 | _timers/model_time=0.1725 | f1_score=1.357e-06 | loss=0.4227 | loss_dist=0.3152 | loss_mean=0.1075
[2019-12-06 18:18:47,871] 
66/100 * Epoch 66 (train): _base/lr=0.0012 | _base/momentum=0.9000 | _timers/_fps=6.5135 | _timers/batch_time=0.3088 | _timers/data_time=0.0003 | _timers/model_time=0.3085 | f1_score=1.132e-06 | loss=0.2220 | loss_dist=0.1130 | loss_mean=0.1089
66/100 * Epoch 66 (valid): _base/lr=0.0012 | _base/momentum=0.9000 | _timers/_fps=11.7317 | _timers/batch_time=0.1736 | _timers/data_time=0.0003 | _timers/model_time=0.1733 | f1_score=1.137e-06 | loss=0.2208 | 

[2019-12-06 18:19:47,320] 
81/100 * Epoch 81 (train): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=6.5210 | _timers/batch_time=0.3070 | _timers/data_time=0.0004 | _timers/model_time=0.3066 | f1_score=1.667e-06 | loss=0.3028 | loss_dist=0.1985 | loss_mean=0.1042
81/100 * Epoch 81 (valid): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=11.8073 | _timers/batch_time=0.1724 | _timers/data_time=0.0002 | _timers/model_time=0.1722 | f1_score=1.668e-06 | loss=0.3023 | loss_dist=0.1982 | loss_mean=0.1041
[2019-12-06 18:19:51,106] 
82/100 * Epoch 82 (train): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=6.6520 | _timers/batch_time=0.3021 | _timers/data_time=0.0004 | _timers/model_time=0.3016 | f1_score=1.797e-06 | loss=0.1549 | loss_dist=0.0452 | loss_mean=0.1097
82/100 * Epoch 82 (valid): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=11.9542 | _timers/batch_time=0.1707 | _timers/data_time=0.0003 | _timers/model_time=0.1704 | f1_score=1.798e-06 | loss=0.1547 | 

[2019-12-06 18:20:50,280] 
97/100 * Epoch 97 (train): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=6.2791 | _timers/batch_time=0.3233 | _timers/data_time=0.0004 | _timers/model_time=0.3229 | f1_score=1.407e-06 | loss=0.1394 | loss_dist=0.0387 | loss_mean=0.1007
97/100 * Epoch 97 (valid): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=11.1642 | _timers/batch_time=0.1810 | _timers/data_time=0.0003 | _timers/model_time=0.1806 | f1_score=1.408e-06 | loss=0.1391 | loss_dist=0.0384 | loss_mean=0.1007
[2019-12-06 18:20:54,417] 
98/100 * Epoch 98 (train): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=6.2374 | _timers/batch_time=0.3245 | _timers/data_time=0.0003 | _timers/model_time=0.3241 | f1_score=1.528e-06 | loss=0.1387 | loss_dist=0.0387 | loss_mean=0.1000
98/100 * Epoch 98 (valid): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=10.8333 | _timers/batch_time=0.1879 | _timers/data_time=0.0003 | _timers/model_time=0.1876 | f1_score=1.529e-06 | loss=0.1391 | 

### HardKumaBinarizer

In [59]:
from binarizers import HardKumaBinarizer

In [60]:
hkb = HardKumaBinarizer()

In [61]:
from torch.autograd import Variable


class IMDbRNN_bin(nn.Module):
    def __init__(self, sent_size):
        super().__init__()
#         self.rnn = nn.GRU(embedding_size, hidden_size)
        self.rnn = model.rnn
        self.cls = model.cls
        self.emb = model.emb
        
        self.u = Variable(torch.randn(1,sent_size,1), requires_grad=True)
        self.v = Variable(torch.randn(1,sent_size,1), requires_grad=True)
    
    def prepare_sequence(self, x):
        l = torch.LongTensor([len(xi) for xi in x])
        x = torch.split_with_sizes(self.emb(torch.cat(x).to(self.emb.weight.device)), l.unbind(0))
        x = nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=5)
        return x, l
    
    def forward(self, x):
#         print(x.shape,x[0].shape)
        
        x = [i for i in x]
        x, l = self.prepare_sequence(x)
        
        #считаем с единичной маской
        b = torch.ones(x.shape[:2]).unsqueeze(2).to(self.emb.weight.device)
        h,o = self.rnn(x,b,l)
        
        #считаем с бинаризованной маской u
        b_new = torch.repeat_interleave(hkb(self.u, self.v), repeats=len(x), dim=0).to(self.emb.weight.device)
        h_new,o_new = self.rnn(x,b_new,l)
        
        #расстояние между выходами сети на полном и прореженом предложении
        dist = ((o-o_new)**2).sum(1)
        
        #среднее значение бинарной маски
        mean = b_new.squeeze(2).mean(1)
        
        #финальное предсказание класса на прореженом сообщении
        target = self.cls(o_new)
        
        return {'cls': target, 'dist': dist, 'mean': mean}

In [62]:
model_bin = IMDbRNN_bin(x[0].shape[0])
class Dataset(BaseDataset):

    
    def __init__(
            self, x, y, idx
    ):
        
        self.x = x
        self.y = y
        self.idx = idx
    
    def __getitem__(self, idx):

        return {'features': self.x[self.idx], 'target_cls':self.y[self.idx],
                'target_dist': torch.tensor(0).float(), 'target_mean': torch.tensor(0).float()}
        
    def __len__(self):
        return 8


new_train_loader = DataLoader(Dataset(x,y,0), shuffle=True, batch_size=2)
new_val_loader = DataLoader(Dataset(x,y,0), shuffle=True, batch_size=2)

In [63]:
num_epochs = 100  # change me
logdir = "./logs/HardKuma"
loaders = {
    "train": new_train_loader,
    "valid": new_val_loader
}

criterion = {
   "mse": nn.MSELoss()
}

optimizer = torch.optim.Adam([
    {'params': model_bin.u, 'lr': 3e-3}])


scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=4)

runner = SupervisedRunner(input_target_key=None,output_key=None)

In [64]:
runner.train(
    model=model_bin,
    criterion=criterion,
    scheduler=scheduler,
    callbacks=[   
               CriterionCallback(prefix="loss_dist", input_key='target_dist', output_key='dist',
                     criterion_key='mse', multiplier=0.5),
               CriterionCallback(prefix="loss_mean", input_key='target_mean', output_key='mean',
                     criterion_key='mse', multiplier=0.5),
        
               CriterionAggregatorCallback(prefix="loss",loss_keys=['loss_dist','loss_mean']),
#                AUCCallback(num_classes=1, input_key='target_cls', output_key='cls'), 
                F1ScoreCallback(input_key='target_cls', output_key='cls'),
               CheckpointCallback(save_n_best=3)
                ],

    optimizer=optimizer,
    main_metric='loss',
    minimize_metric=True,
    loaders=loaders,
    logdir=logdir,
    num_epochs=num_epochs,
#     fp16={"opt_level": "O1"},
    verbose=False
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

[2019-12-06 18:22:44,807] 
1/100 * Epoch 1 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=5.8000 | _timers/batch_time=0.3458 | _timers/data_time=0.0002 | _timers/model_time=0.3455 | f1_score=5.300e-07 | loss=2.6945 | loss_dist=2.5584 | loss_mean=0.1361
1/100 * Epoch 1 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=10.9563 | _timers/batch_time=0.1864 | _timers/data_time=0.0002 | _timers/model_time=0.1861 | f1_score=5.511e-07 | loss=2.5137 | loss_dist=2.3779 | loss_mean=0.1358
[2019-12-06 18:22:49,198] 
2/100 * Epoch 2 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=5.8326 | _timers/batch_time=0.3468 | _timers/data_time=0.0003 | _timers/model_time=0.3465 | f1_score=6.214e-07 | loss=13.2910 | loss_dist=13.1538 | loss_mean=0.1372
2/100 * Epoch 2 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=11.0507 | _timers/batch_time=0.1853 | _timers/data_time=0.0002 | _timers/model_time=0.1850 | f1_score=6.311e-07 | loss=13.0239 | loss_

[2019-12-06 18:23:52,440] 
17/100 * Epoch 17 (train): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=6.3261 | _timers/batch_time=0.3163 | _timers/data_time=0.0003 | _timers/model_time=0.3160 | f1_score=1.320e-06 | loss=0.5241 | loss_dist=0.3872 | loss_mean=0.1368
17/100 * Epoch 17 (valid): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=11.5995 | _timers/batch_time=0.1752 | _timers/data_time=0.0004 | _timers/model_time=0.1747 | f1_score=1.328e-06 | loss=0.5137 | loss_dist=0.3769 | loss_mean=0.1368
[2019-12-06 18:23:56,342] 
18/100 * Epoch 18 (train): _base/lr=0.0024 | _base/momentum=0.9000 | _timers/_fps=6.5595 | _timers/batch_time=0.3056 | _timers/data_time=0.0002 | _timers/model_time=0.3053 | f1_score=1.206e-06 | loss=0.3433 | loss_dist=0.2080 | loss_mean=0.1353
18/100 * Epoch 18 (valid): _base/lr=0.0024 | _base/momentum=0.9000 | _timers/_fps=11.2003 | _timers/batch_time=0.1855 | _timers/data_time=0.0005 | _timers/model_time=0.1850 | f1_score=1.218e-06 | loss=0.3396 | 

[2019-12-06 18:24:58,511] 
33/100 * Epoch 33 (train): _base/lr=0.0020 | _base/momentum=0.9000 | _timers/_fps=6.2059 | _timers/batch_time=0.3257 | _timers/data_time=0.0003 | _timers/model_time=0.3254 | f1_score=1.351e-06 | loss=0.2424 | loss_dist=0.1057 | loss_mean=0.1368
33/100 * Epoch 33 (valid): _base/lr=0.0020 | _base/momentum=0.9000 | _timers/_fps=10.7376 | _timers/batch_time=0.1920 | _timers/data_time=0.0004 | _timers/model_time=0.1916 | f1_score=1.355e-06 | loss=0.2412 | loss_dist=0.1045 | loss_mean=0.1368
[2019-12-06 18:25:02,837] 
34/100 * Epoch 34 (train): _base/lr=0.0020 | _base/momentum=0.9000 | _timers/_fps=5.8015 | _timers/batch_time=0.3492 | _timers/data_time=0.0004 | _timers/model_time=0.3486 | f1_score=1.125e-06 | loss=0.3071 | loss_dist=0.1724 | loss_mean=0.1347
34/100 * Epoch 34 (valid): _base/lr=0.0020 | _base/momentum=0.9000 | _timers/_fps=11.6071 | _timers/batch_time=0.1759 | _timers/data_time=0.0003 | _timers/model_time=0.1756 | f1_score=1.130e-06 | loss=0.3044 | 

[2019-12-06 18:26:06,689] 
49/100 * Epoch 49 (train): _base/lr=0.0016 | _base/momentum=0.9000 | _timers/_fps=5.6938 | _timers/batch_time=0.3531 | _timers/data_time=0.0004 | _timers/model_time=0.3526 | f1_score=1.558e-06 | loss=0.2376 | loss_dist=0.1026 | loss_mean=0.1350
49/100 * Epoch 49 (valid): _base/lr=0.0016 | _base/momentum=0.9000 | _timers/_fps=10.0012 | _timers/batch_time=0.2033 | _timers/data_time=0.0003 | _timers/model_time=0.2030 | f1_score=1.560e-06 | loss=0.2367 | loss_dist=0.1017 | loss_mean=0.1349
[2019-12-06 18:26:11,242] 
50/100 * Epoch 50 (train): _base/lr=0.0016 | _base/momentum=0.9000 | _timers/_fps=5.6107 | _timers/batch_time=0.3578 | _timers/data_time=0.0004 | _timers/model_time=0.3573 | f1_score=1.421e-06 | loss=0.2255 | loss_dist=0.0921 | loss_mean=0.1334
50/100 * Epoch 50 (valid): _base/lr=0.0016 | _base/momentum=0.9000 | _timers/_fps=10.4859 | _timers/batch_time=0.1940 | _timers/data_time=0.0002 | _timers/model_time=0.1938 | f1_score=1.424e-06 | loss=0.2244 | 

[2019-12-06 18:27:11,801] 
65/100 * Epoch 65 (train): _base/lr=0.0013 | _base/momentum=0.9000 | _timers/_fps=5.8216 | _timers/batch_time=0.3437 | _timers/data_time=0.0002 | _timers/model_time=0.3435 | f1_score=1.479e-06 | loss=0.3088 | loss_dist=0.1752 | loss_mean=0.1335
65/100 * Epoch 65 (valid): _base/lr=0.0013 | _base/momentum=0.9000 | _timers/_fps=7.4523 | _timers/batch_time=0.2735 | _timers/data_time=0.0005 | _timers/model_time=0.2729 | f1_score=1.481e-06 | loss=0.3077 | loss_dist=0.1742 | loss_mean=0.1335
[2019-12-06 18:27:16,659] 
66/100 * Epoch 66 (train): _base/lr=0.0013 | _base/momentum=0.9000 | _timers/_fps=5.1382 | _timers/batch_time=0.3932 | _timers/data_time=0.0004 | _timers/model_time=0.3928 | f1_score=1.421e-06 | loss=0.2137 | loss_dist=0.0807 | loss_mean=0.1331
66/100 * Epoch 66 (valid): _base/lr=0.0013 | _base/momentum=0.9000 | _timers/_fps=11.2467 | _timers/batch_time=0.1813 | _timers/data_time=0.0003 | _timers/model_time=0.1810 | f1_score=1.423e-06 | loss=0.2127 | l

[2019-12-06 18:28:18,288] 
81/100 * Epoch 81 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=6.4692 | _timers/batch_time=0.3110 | _timers/data_time=0.0003 | _timers/model_time=0.3106 | f1_score=1.437e-06 | loss=0.3270 | loss_dist=0.1945 | loss_mean=0.1325
81/100 * Epoch 81 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=11.7204 | _timers/batch_time=0.1740 | _timers/data_time=0.0004 | _timers/model_time=0.1736 | f1_score=1.439e-06 | loss=0.3255 | loss_dist=0.1931 | loss_mean=0.1324
[2019-12-06 18:28:22,072] 
82/100 * Epoch 82 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=6.6110 | _timers/batch_time=0.3032 | _timers/data_time=0.0003 | _timers/model_time=0.3029 | f1_score=1.640e-06 | loss=0.1962 | loss_dist=0.0603 | loss_mean=0.1358
82/100 * Epoch 82 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=11.8844 | _timers/batch_time=0.1712 | _timers/data_time=0.0003 | _timers/model_time=0.1708 | f1_score=1.642e-06 | loss=0.1958 | 

[2019-12-06 18:29:26,015] 
97/100 * Epoch 97 (train): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=5.8393 | _timers/batch_time=0.3446 | _timers/data_time=0.0003 | _timers/model_time=0.3443 | f1_score=1.518e-06 | loss=0.1814 | loss_dist=0.0481 | loss_mean=0.1333
97/100 * Epoch 97 (valid): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=10.6915 | _timers/batch_time=0.1938 | _timers/data_time=0.0002 | _timers/model_time=0.1935 | f1_score=1.520e-06 | loss=0.1810 | loss_dist=0.0477 | loss_mean=0.1333
[2019-12-06 18:29:30,058] 
98/100 * Epoch 98 (train): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=6.2063 | _timers/batch_time=0.3240 | _timers/data_time=0.0003 | _timers/model_time=0.3236 | f1_score=1.525e-06 | loss=0.2151 | loss_dist=0.0839 | loss_mean=0.1312
98/100 * Epoch 98 (valid): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=11.0639 | _timers/batch_time=0.1841 | _timers/data_time=0.0004 | _timers/model_time=0.1836 | f1_score=1.526e-06 | loss=0.2146 | 

### Bernoulli p = 0.5

In [66]:
from torch.autograd import Variable


class IMDbRNN_bin(nn.Module):
    def __init__(self, sent_size):
        super().__init__()
#         self.rnn = nn.GRU(embedding_size, hidden_size)
        self.rnn = model.rnn
        self.cls = model.cls
        self.emb = model.emb
        
        self.u = Variable(torch.randn(1,sent_size,1), requires_grad=True)
    
    def prepare_sequence(self, x):
        l = torch.LongTensor([len(xi) for xi in x])
        x = torch.split_with_sizes(self.emb(torch.cat(x).to(self.emb.weight.device)), l.unbind(0))
        x = nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=5)
        return x, l
    
    def forward(self, x):
#         print(x.shape,x[0].shape)
        
        x = [i for i in x]
        x, l = self.prepare_sequence(x)
        
        #считаем с единичной маской
        b = torch.ones(x.shape[:2]).unsqueeze(2).to(self.emb.weight.device)
        h,o = self.rnn(x,b,l)               
        
        #считаем с бинаризованной маской u
        #torch.bernoulli(torch.tensor([1.0]),)
        b_new = torch.bernoulli(torch.ones(x.shape[:2]),0.5).unsqueeze(2).to(self.emb.weight.device)
        #b_new = torch.repeat_interleave(bb(0.4), repeats=len(x), dim=0).to(self.emb.weight.device)
        h_new,o_new = self.rnn(x,b_new,l)
        
        #расстояние между выходами сети на полном и прореженом предложении
        dist = ((o-o_new)**2).sum(1)
        
        #среднее значение бинарной маски
        mean = b_new.squeeze(2).mean(1)
        
        #финальное предсказание класса на прореженом сообщении
        target = self.cls(o_new)
        
        return {'cls': target, 'dist': dist, 'mean': mean}

In [67]:
model_bin = IMDbRNN_bin(x[0].shape[0])
class Dataset(BaseDataset):

    
    def __init__(
            self, x, y, idx
    ):
        
        self.x = x
        self.y = y
        self.idx = idx
    
    def __getitem__(self, idx):

        return {'features': self.x[self.idx], 'target_cls':self.y[self.idx],
                'target_dist': torch.tensor(0).float(), 'target_mean': torch.tensor(0).float()}
        
    def __len__(self):
        return 8


new_train_loader = DataLoader(Dataset(x,y,0), shuffle=True, batch_size=2)
new_val_loader = DataLoader(Dataset(x,y,0), shuffle=True, batch_size=2)

In [68]:
num_epochs = 100  # change me
logdir = "./logs/b50"
loaders = {
    "train": new_train_loader,
    "valid": new_val_loader
}

criterion = {
   "mse": nn.MSELoss()
}

optimizer = torch.optim.Adam([
    {'params': model_bin.u, 'lr': 3e-3}])


scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=4)

runner = SupervisedRunner(input_target_key=None,output_key=None)

In [69]:
runner.train(
    model=model_bin,
    criterion=criterion,
    scheduler=scheduler,
    callbacks=[   
               CriterionCallback(prefix="loss_dist", input_key='target_dist', output_key='dist',
                     criterion_key='mse', multiplier=0.5),
               CriterionCallback(prefix="loss_mean", input_key='target_mean', output_key='mean',
                     criterion_key='mse', multiplier=0.5),
        
               CriterionAggregatorCallback(prefix="loss",loss_keys=['loss_dist','loss_mean']),
#                AUCCallback(num_classes=1, input_key='target_cls', output_key='cls'), 
                F1ScoreCallback(input_key='target_cls', output_key='cls'),
               CheckpointCallback(save_n_best=3)
                ],

    optimizer=optimizer,
    main_metric='loss',
    minimize_metric=True,
    loaders=loaders,
    logdir=logdir,
    num_epochs=num_epochs,
#     fp16={"opt_level": "O1"},
    verbose=False
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

[2019-12-06 18:34:08,611] 
1/100 * Epoch 1 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=6.0348 | _timers/batch_time=0.3343 | _timers/data_time=0.0003 | _timers/model_time=0.3340 | f1_score=4.040e-07 | loss=24.2843 | loss_dist=24.1560 | loss_mean=0.1283
1/100 * Epoch 1 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=10.6520 | _timers/batch_time=0.1943 | _timers/data_time=0.0003 | _timers/model_time=0.1940 | f1_score=4.040e-07 | loss=24.2843 | loss_dist=24.1560 | loss_mean=0.1283
[2019-12-06 18:34:12,822] 
2/100 * Epoch 2 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=5.9265 | _timers/batch_time=0.3413 | _timers/data_time=0.0003 | _timers/model_time=0.3409 | f1_score=7.147e-07 | loss=5.0791 | loss_dist=4.9562 | loss_mean=0.1230
2/100 * Epoch 2 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=11.1146 | _timers/batch_time=0.1859 | _timers/data_time=0.0004 | _timers/model_time=0.1855 | f1_score=7.147e-07 | loss=5.0791 | loss

[2019-12-06 18:35:14,407] 
17/100 * Epoch 17 (train): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=6.5094 | _timers/batch_time=0.3081 | _timers/data_time=0.0003 | _timers/model_time=0.3077 | f1_score=5.519e-07 | loss=2.5328 | loss_dist=2.4046 | loss_mean=0.1282
17/100 * Epoch 17 (valid): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=11.4901 | _timers/batch_time=0.1775 | _timers/data_time=0.0004 | _timers/model_time=0.1770 | f1_score=5.519e-07 | loss=2.5328 | loss_dist=2.4046 | loss_mean=0.1282
[2019-12-06 18:35:18,234] 
18/100 * Epoch 18 (train): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=6.5154 | _timers/batch_time=0.3077 | _timers/data_time=0.0006 | _timers/model_time=0.3071 | f1_score=9.825e-07 | loss=1.0114 | loss_dist=0.8853 | loss_mean=0.1261
18/100 * Epoch 18 (valid): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=11.6510 | _timers/batch_time=0.1750 | _timers/data_time=0.0004 | _timers/model_time=0.1745 | f1_score=9.825e-07 | loss=1.0114 | 

[2019-12-06 18:36:19,677] 
33/100 * Epoch 33 (train): _base/lr=0.0020 | _base/momentum=0.9000 | _timers/_fps=6.3349 | _timers/batch_time=0.3170 | _timers/data_time=0.0003 | _timers/model_time=0.3167 | f1_score=9.034e-07 | loss=1.4922 | loss_dist=1.3672 | loss_mean=0.1250
33/100 * Epoch 33 (valid): _base/lr=0.0020 | _base/momentum=0.9000 | _timers/_fps=10.1426 | _timers/batch_time=0.2009 | _timers/data_time=0.0004 | _timers/model_time=0.2005 | f1_score=9.034e-07 | loss=1.4922 | loss_dist=1.3672 | loss_mean=0.1250
[2019-12-06 18:36:23,668] 
34/100 * Epoch 34 (train): _base/lr=0.0020 | _base/momentum=0.9000 | _timers/_fps=6.3335 | _timers/batch_time=0.3165 | _timers/data_time=0.0004 | _timers/model_time=0.3160 | f1_score=6.886e-07 | loss=14.2114 | loss_dist=14.0850 | loss_mean=0.1265
34/100 * Epoch 34 (valid): _base/lr=0.0020 | _base/momentum=0.9000 | _timers/_fps=11.0508 | _timers/batch_time=0.1835 | _timers/data_time=0.0004 | _timers/model_time=0.1831 | f1_score=6.886e-07 | loss=14.2114

[2019-12-06 18:37:24,165] 
49/100 * Epoch 49 (train): _base/lr=0.0014 | _base/momentum=0.9000 | _timers/_fps=6.2584 | _timers/batch_time=0.3198 | _timers/data_time=0.0004 | _timers/model_time=0.3193 | f1_score=2.646e-07 | loss=22.5015 | loss_dist=22.3776 | loss_mean=0.1239
49/100 * Epoch 49 (valid): _base/lr=0.0014 | _base/momentum=0.9000 | _timers/_fps=11.3785 | _timers/batch_time=0.1789 | _timers/data_time=0.0002 | _timers/model_time=0.1787 | f1_score=2.646e-07 | loss=22.5015 | loss_dist=22.3776 | loss_mean=0.1239
[2019-12-06 18:37:28,132] 
50/100 * Epoch 50 (train): _base/lr=0.0014 | _base/momentum=0.9000 | _timers/_fps=6.3551 | _timers/batch_time=0.3163 | _timers/data_time=0.0003 | _timers/model_time=0.3160 | f1_score=1.592e-07 | loss=12.5136 | loss_dist=12.3879 | loss_mean=0.1258
50/100 * Epoch 50 (valid): _base/lr=0.0014 | _base/momentum=0.9000 | _timers/_fps=11.3304 | _timers/batch_time=0.1802 | _timers/data_time=0.0003 | _timers/model_time=0.1799 | f1_score=1.592e-07 | loss=12.

[2019-12-06 18:38:27,596] 
65/100 * Epoch 65 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=6.2262 | _timers/batch_time=0.3218 | _timers/data_time=0.0002 | _timers/model_time=0.3215 | f1_score=7.077e-07 | loss=2.9024 | loss_dist=2.7773 | loss_mean=0.1251
65/100 * Epoch 65 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=10.5392 | _timers/batch_time=0.1977 | _timers/data_time=0.0003 | _timers/model_time=0.1974 | f1_score=7.077e-07 | loss=2.9024 | loss_dist=2.7773 | loss_mean=0.1251
[2019-12-06 18:38:31,881] 
66/100 * Epoch 66 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=5.9060 | _timers/batch_time=0.3389 | _timers/data_time=0.0002 | _timers/model_time=0.3386 | f1_score=2.943e-07 | loss=12.2013 | loss_dist=12.0794 | loss_mean=0.1219
66/100 * Epoch 66 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=9.8835 | _timers/batch_time=0.2057 | _timers/data_time=0.0004 | _timers/model_time=0.2053 | f1_score=2.943e-07 | loss=12.2013 

[2019-12-06 18:39:29,633] 
81/100 * Epoch 81 (train): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=6.5241 | _timers/batch_time=0.3072 | _timers/data_time=0.0003 | _timers/model_time=0.3068 | f1_score=9.280e-07 | loss=3.8589 | loss_dist=3.7336 | loss_mean=0.1253
81/100 * Epoch 81 (valid): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=11.7839 | _timers/batch_time=0.1738 | _timers/data_time=0.0003 | _timers/model_time=0.1735 | f1_score=9.280e-07 | loss=3.8589 | loss_dist=3.7336 | loss_mean=0.1253
[2019-12-06 18:39:33,936] 
82/100 * Epoch 82 (train): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=6.3509 | _timers/batch_time=0.3187 | _timers/data_time=0.0003 | _timers/model_time=0.3183 | f1_score=7.649e-07 | loss=18.8165 | loss_dist=18.6904 | loss_mean=0.1262
82/100 * Epoch 82 (valid): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=10.3668 | _timers/batch_time=0.1947 | _timers/data_time=0.0002 | _timers/model_time=0.1945 | f1_score=7.649e-07 | loss=18.8165

[2019-12-06 18:40:39,939] 
97/100 * Epoch 97 (train): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=6.2483 | _timers/batch_time=0.3210 | _timers/data_time=0.0004 | _timers/model_time=0.3205 | f1_score=9.544e-07 | loss=8.9184 | loss_dist=8.7955 | loss_mean=0.1230
97/100 * Epoch 97 (valid): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=10.7169 | _timers/batch_time=0.1910 | _timers/data_time=0.0002 | _timers/model_time=0.1907 | f1_score=9.544e-07 | loss=8.9184 | loss_dist=8.7955 | loss_mean=0.1230
[2019-12-06 18:40:44,240] 
98/100 * Epoch 98 (train): _base/lr=0.0005 | _base/momentum=0.9000 | _timers/_fps=5.9723 | _timers/batch_time=0.3374 | _timers/data_time=0.0003 | _timers/model_time=0.3371 | f1_score=1.824e-07 | loss=17.6124 | loss_dist=17.4898 | loss_mean=0.1226
98/100 * Epoch 98 (valid): _base/lr=0.0005 | _base/momentum=0.9000 | _timers/_fps=10.0376 | _timers/batch_time=0.2036 | _timers/data_time=0.0002 | _timers/model_time=0.2034 | f1_score=1.824e-07 | loss=17.6124

### Bernoulli p = 0.9

In [71]:
p = 0.9

In [72]:
from torch.autograd import Variable


class IMDbRNN_bin(nn.Module):
    def __init__(self, sent_size):
        super().__init__()
#         self.rnn = nn.GRU(embedding_size, hidden_size)
        self.rnn = model.rnn
        self.cls = model.cls
        self.emb = model.emb
        
        self.u = Variable(torch.randn(1,sent_size,1), requires_grad=True)
    
    def prepare_sequence(self, x):
        l = torch.LongTensor([len(xi) for xi in x])
        x = torch.split_with_sizes(self.emb(torch.cat(x).to(self.emb.weight.device)), l.unbind(0))
        x = nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=5)
        return x, l
    
    def forward(self, x):
#         print(x.shape,x[0].shape)
        
        x = [i for i in x]
        x, l = self.prepare_sequence(x)
        
        #считаем с единичной маской
        b = torch.ones(x.shape[:2]).unsqueeze(2).to(self.emb.weight.device)
        h,o = self.rnn(x,b,l)               
        
        #считаем с бинаризованной маской u
        #torch.bernoulli(torch.tensor([1.0]),)
        b_new = torch.bernoulli(torch.ones(x.shape[:2]),p).unsqueeze(2).to(self.emb.weight.device)
        #b_new = torch.repeat_interleave(bb(0.4), repeats=len(x), dim=0).to(self.emb.weight.device)
        h_new,o_new = self.rnn(x,b_new,l)
        
        #расстояние между выходами сети на полном и прореженом предложении
        dist = ((o-o_new)**2).sum(1)
        
        #среднее значение бинарной маски
        mean = b_new.squeeze(2).mean(1)
        
        #финальное предсказание класса на прореженом сообщении
        target = self.cls(o_new)
        
        return {'cls': target, 'dist': dist, 'mean': mean}

In [73]:
model_bin = IMDbRNN_bin(x[0].shape[0])
class Dataset(BaseDataset):

    
    def __init__(
            self, x, y, idx
    ):
        
        self.x = x
        self.y = y
        self.idx = idx
    
    def __getitem__(self, idx):

        return {'features': self.x[self.idx], 'target_cls':self.y[self.idx],
                'target_dist': torch.tensor(0).float(), 'target_mean': torch.tensor(0).float()}
        
    def __len__(self):
        return 8


new_train_loader = DataLoader(Dataset(x,y,0), shuffle=True, batch_size=2)
new_val_loader = DataLoader(Dataset(x,y,0), shuffle=True, batch_size=2)

In [74]:
num_epochs = 100  # change me
logdir = "./logs/b90"
loaders = {
    "train": new_train_loader,
    "valid": new_val_loader
}

criterion = {
   "mse": nn.MSELoss()
}

optimizer = torch.optim.Adam([
    {'params': model_bin.u, 'lr': 3e-3}])


scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=4)

runner = SupervisedRunner(input_target_key=None,output_key=None)

In [75]:
runner.train(
    model=model_bin,
    criterion=criterion,
    scheduler=scheduler,
    callbacks=[   
               CriterionCallback(prefix="loss_dist", input_key='target_dist', output_key='dist',
                     criterion_key='mse', multiplier=0.5),
               CriterionCallback(prefix="loss_mean", input_key='target_mean', output_key='mean',
                     criterion_key='mse', multiplier=0.5),
        
               CriterionAggregatorCallback(prefix="loss",loss_keys=['loss_dist','loss_mean']),
#                AUCCallback(num_classes=1, input_key='target_cls', output_key='cls'), 
                F1ScoreCallback(input_key='target_cls', output_key='cls'),
               CheckpointCallback(save_n_best=3)
                ],

    optimizer=optimizer,
    main_metric='loss',
    minimize_metric=True,
    loaders=loaders,
    logdir=logdir,
    num_epochs=num_epochs,
#     fp16={"opt_level": "O1"},
    verbose=False
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

[2019-12-06 19:03:35,055] 
1/100 * Epoch 1 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=5.4257 | _timers/batch_time=0.3689 | _timers/data_time=0.0004 | _timers/model_time=0.3685 | f1_score=1.460e-06 | loss=0.4123 | loss_dist=0.0059 | loss_mean=0.4064
1/100 * Epoch 1 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=10.0839 | _timers/batch_time=0.2021 | _timers/data_time=0.0003 | _timers/model_time=0.2018 | f1_score=1.460e-06 | loss=0.4123 | loss_dist=0.0059 | loss_mean=0.4064
[2019-12-06 19:03:39,127] 
2/100 * Epoch 2 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=6.4065 | _timers/batch_time=0.3126 | _timers/data_time=0.0004 | _timers/model_time=0.3122 | f1_score=1.628e-06 | loss=0.4067 | loss_dist=0.0042 | loss_mean=0.4025
2/100 * Epoch 2 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=10.5202 | _timers/batch_time=0.1953 | _timers/data_time=0.0003 | _timers/model_time=0.1950 | f1_score=1.628e-06 | loss=0.4067 | loss_dis

[2019-12-06 19:04:43,426] 
17/100 * Epoch 17 (train): _base/lr=0.0024 | _base/momentum=0.9000 | _timers/_fps=6.1057 | _timers/batch_time=0.3285 | _timers/data_time=0.0003 | _timers/model_time=0.3282 | f1_score=1.627e-06 | loss=0.4139 | loss_dist=0.0061 | loss_mean=0.4078
17/100 * Epoch 17 (valid): _base/lr=0.0024 | _base/momentum=0.9000 | _timers/_fps=11.4799 | _timers/batch_time=0.1778 | _timers/data_time=0.0003 | _timers/model_time=0.1775 | f1_score=1.627e-06 | loss=0.4139 | loss_dist=0.0061 | loss_mean=0.4078
[2019-12-06 19:04:47,350] 
18/100 * Epoch 18 (train): _base/lr=0.0024 | _base/momentum=0.9000 | _timers/_fps=6.3384 | _timers/batch_time=0.3163 | _timers/data_time=0.0003 | _timers/model_time=0.3159 | f1_score=1.762e-06 | loss=0.4156 | loss_dist=0.0105 | loss_mean=0.4051
18/100 * Epoch 18 (valid): _base/lr=0.0024 | _base/momentum=0.9000 | _timers/_fps=11.6235 | _timers/batch_time=0.1754 | _timers/data_time=0.0003 | _timers/model_time=0.1751 | f1_score=1.762e-06 | loss=0.4156 | 

[2019-12-06 19:05:48,529] 
33/100 * Epoch 33 (train): _base/lr=0.0018 | _base/momentum=0.9000 | _timers/_fps=6.1598 | _timers/batch_time=0.3248 | _timers/data_time=0.0004 | _timers/model_time=0.3244 | f1_score=1.581e-06 | loss=0.4157 | loss_dist=0.0105 | loss_mean=0.4052
33/100 * Epoch 33 (valid): _base/lr=0.0018 | _base/momentum=0.9000 | _timers/_fps=11.5970 | _timers/batch_time=0.1762 | _timers/data_time=0.0003 | _timers/model_time=0.1760 | f1_score=1.581e-06 | loss=0.4157 | loss_dist=0.0105 | loss_mean=0.4052
[2019-12-06 19:05:52,708] 
34/100 * Epoch 34 (train): _base/lr=0.0018 | _base/momentum=0.9000 | _timers/_fps=5.9783 | _timers/batch_time=0.3369 | _timers/data_time=0.0004 | _timers/model_time=0.3365 | f1_score=1.640e-06 | loss=0.4096 | loss_dist=0.0078 | loss_mean=0.4018
34/100 * Epoch 34 (valid): _base/lr=0.0018 | _base/momentum=0.9000 | _timers/_fps=10.3590 | _timers/batch_time=0.1977 | _timers/data_time=0.0004 | _timers/model_time=0.1973 | f1_score=1.640e-06 | loss=0.4096 | 

[2019-12-06 19:06:55,347] 
49/100 * Epoch 49 (train): _base/lr=0.0013 | _base/momentum=0.9000 | _timers/_fps=5.7135 | _timers/batch_time=0.3510 | _timers/data_time=0.0003 | _timers/model_time=0.3507 | f1_score=1.390e-06 | loss=0.4668 | loss_dist=0.0653 | loss_mean=0.4015
49/100 * Epoch 49 (valid): _base/lr=0.0013 | _base/momentum=0.9000 | _timers/_fps=11.1566 | _timers/batch_time=0.1836 | _timers/data_time=0.0003 | _timers/model_time=0.1832 | f1_score=1.390e-06 | loss=0.4668 | loss_dist=0.0653 | loss_mean=0.4015
[2019-12-06 19:06:59,552] 
50/100 * Epoch 50 (train): _base/lr=0.0013 | _base/momentum=0.9000 | _timers/_fps=6.2091 | _timers/batch_time=0.3252 | _timers/data_time=0.0003 | _timers/model_time=0.3248 | f1_score=1.427e-06 | loss=0.4353 | loss_dist=0.0297 | loss_mean=0.4056
50/100 * Epoch 50 (valid): _base/lr=0.0013 | _base/momentum=0.9000 | _timers/_fps=10.1371 | _timers/batch_time=0.2006 | _timers/data_time=0.0003 | _timers/model_time=0.2002 | f1_score=1.427e-06 | loss=0.4353 | 

[2019-12-06 19:07:59,009] 
65/100 * Epoch 65 (train): _base/lr=0.0009 | _base/momentum=0.9000 | _timers/_fps=6.3808 | _timers/batch_time=0.3159 | _timers/data_time=0.0003 | _timers/model_time=0.3156 | f1_score=1.685e-06 | loss=0.4139 | loss_dist=0.0064 | loss_mean=0.4075
65/100 * Epoch 65 (valid): _base/lr=0.0009 | _base/momentum=0.9000 | _timers/_fps=10.8258 | _timers/batch_time=0.1887 | _timers/data_time=0.0003 | _timers/model_time=0.1884 | f1_score=1.685e-06 | loss=0.4139 | loss_dist=0.0064 | loss_mean=0.4075
[2019-12-06 19:08:03,304] 
66/100 * Epoch 66 (train): _base/lr=0.0009 | _base/momentum=0.9000 | _timers/_fps=5.9379 | _timers/batch_time=0.3380 | _timers/data_time=0.0003 | _timers/model_time=0.3377 | f1_score=1.552e-06 | loss=0.4232 | loss_dist=0.0171 | loss_mean=0.4061
66/100 * Epoch 66 (valid): _base/lr=0.0009 | _base/momentum=0.9000 | _timers/_fps=10.4001 | _timers/batch_time=0.1965 | _timers/data_time=0.0003 | _timers/model_time=0.1963 | f1_score=1.552e-06 | loss=0.4232 | 

[2019-12-06 19:09:02,466] 
81/100 * Epoch 81 (train): _base/lr=0.0007 | _base/momentum=0.9000 | _timers/_fps=6.5761 | _timers/batch_time=0.3052 | _timers/data_time=0.0002 | _timers/model_time=0.3050 | f1_score=1.336e-06 | loss=0.4049 | loss_dist=0.0045 | loss_mean=0.4004
81/100 * Epoch 81 (valid): _base/lr=0.0007 | _base/momentum=0.9000 | _timers/_fps=9.7261 | _timers/batch_time=0.2093 | _timers/data_time=0.0006 | _timers/model_time=0.2087 | f1_score=1.336e-06 | loss=0.4049 | loss_dist=0.0045 | loss_mean=0.4004
[2019-12-06 19:09:06,457] 
82/100 * Epoch 82 (train): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=6.5747 | _timers/batch_time=0.3056 | _timers/data_time=0.0003 | _timers/model_time=0.3052 | f1_score=1.270e-06 | loss=0.4192 | loss_dist=0.0151 | loss_mean=0.4041
82/100 * Epoch 82 (valid): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=11.6705 | _timers/batch_time=0.1752 | _timers/data_time=0.0003 | _timers/model_time=0.1749 | f1_score=1.270e-06 | loss=0.4192 | l

[2019-12-06 19:10:04,055] 
97/100 * Epoch 97 (train): _base/lr=0.0005 | _base/momentum=0.9000 | _timers/_fps=6.4702 | _timers/batch_time=0.3106 | _timers/data_time=0.0003 | _timers/model_time=0.3102 | f1_score=1.421e-06 | loss=0.4277 | loss_dist=0.0202 | loss_mean=0.4074
97/100 * Epoch 97 (valid): _base/lr=0.0005 | _base/momentum=0.9000 | _timers/_fps=11.6600 | _timers/batch_time=0.1756 | _timers/data_time=0.0003 | _timers/model_time=0.1752 | f1_score=1.421e-06 | loss=0.4277 | loss_dist=0.0202 | loss_mean=0.4074
[2019-12-06 19:10:07,895] 
98/100 * Epoch 98 (train): _base/lr=0.0005 | _base/momentum=0.9000 | _timers/_fps=6.5459 | _timers/batch_time=0.3063 | _timers/data_time=0.0003 | _timers/model_time=0.3059 | f1_score=1.317e-06 | loss=2.9827 | loss_dist=2.5786 | loss_mean=0.4041
98/100 * Epoch 98 (valid): _base/lr=0.0005 | _base/momentum=0.9000 | _timers/_fps=11.7777 | _timers/batch_time=0.1730 | _timers/data_time=0.0003 | _timers/model_time=0.1727 | f1_score=1.317e-06 | loss=2.9827 | 

### Bernoulli p = 0.2

In [76]:
p = 0.2

In [77]:
from torch.autograd import Variable


class IMDbRNN_bin(nn.Module):
    def __init__(self, sent_size):
        super().__init__()
#         self.rnn = nn.GRU(embedding_size, hidden_size)
        self.rnn = model.rnn
        self.cls = model.cls
        self.emb = model.emb
        
        self.u = Variable(torch.randn(1,sent_size,1), requires_grad=True)
    
    def prepare_sequence(self, x):
        l = torch.LongTensor([len(xi) for xi in x])
        x = torch.split_with_sizes(self.emb(torch.cat(x).to(self.emb.weight.device)), l.unbind(0))
        x = nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=5)
        return x, l
    
    def forward(self, x):
#         print(x.shape,x[0].shape)
        
        x = [i for i in x]
        x, l = self.prepare_sequence(x)
        
        #считаем с единичной маской
        b = torch.ones(x.shape[:2]).unsqueeze(2).to(self.emb.weight.device)
        h,o = self.rnn(x,b,l)               
        
        #считаем с бинаризованной маской u
        #torch.bernoulli(torch.tensor([1.0]),)
        b_new = torch.bernoulli(torch.ones(x.shape[:2]),p).unsqueeze(2).to(self.emb.weight.device)
        #b_new = torch.repeat_interleave(bb(0.4), repeats=len(x), dim=0).to(self.emb.weight.device)
        h_new,o_new = self.rnn(x,b_new,l)
        
        #расстояние между выходами сети на полном и прореженом предложении
        dist = ((o-o_new)**2).sum(1)
        
        #среднее значение бинарной маски
        mean = b_new.squeeze(2).mean(1)
        
        #финальное предсказание класса на прореженом сообщении
        target = self.cls(o_new)
        
        return {'cls': target, 'dist': dist, 'mean': mean}

In [78]:
model_bin = IMDbRNN_bin(x[0].shape[0])
class Dataset(BaseDataset):

    
    def __init__(
            self, x, y, idx
    ):
        
        self.x = x
        self.y = y
        self.idx = idx
    
    def __getitem__(self, idx):

        return {'features': self.x[self.idx], 'target_cls':self.y[self.idx],
                'target_dist': torch.tensor(0).float(), 'target_mean': torch.tensor(0).float()}
        
    def __len__(self):
        return 8


new_train_loader = DataLoader(Dataset(x,y,0), shuffle=True, batch_size=2)
new_val_loader = DataLoader(Dataset(x,y,0), shuffle=True, batch_size=2)

In [79]:
num_epochs = 100  # change me
logdir = "./logs/b20"
loaders = {
    "train": new_train_loader,
    "valid": new_val_loader
}

criterion = {
   "mse": nn.MSELoss()
}

optimizer = torch.optim.Adam([
    {'params': model_bin.u, 'lr': 3e-3}])


scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=4)

runner = SupervisedRunner(input_target_key=None,output_key=None)

In [80]:
runner.train(
    model=model_bin,
    criterion=criterion,
    scheduler=scheduler,
    callbacks=[   
               CriterionCallback(prefix="loss_dist", input_key='target_dist', output_key='dist',
                     criterion_key='mse', multiplier=0.5),
               CriterionCallback(prefix="loss_mean", input_key='target_mean', output_key='mean',
                     criterion_key='mse', multiplier=0.5),
        
               CriterionAggregatorCallback(prefix="loss",loss_keys=['loss_dist','loss_mean']),
#                AUCCallback(num_classes=1, input_key='target_cls', output_key='cls'), 
                F1ScoreCallback(input_key='target_cls', output_key='cls'),
               CheckpointCallback(save_n_best=3)
                ],

    optimizer=optimizer,
    main_metric='loss',
    minimize_metric=True,
    loaders=loaders,
    logdir=logdir,
    num_epochs=num_epochs,
#     fp16={"opt_level": "O1"},
    verbose=False
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

[2019-12-06 19:10:21,998] 
1/100 * Epoch 1 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=6.4760 | _timers/batch_time=0.3105 | _timers/data_time=0.0004 | _timers/model_time=0.3101 | f1_score=8.282e-08 | loss=44.1146 | loss_dist=44.0937 | loss_mean=0.0209
1/100 * Epoch 1 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=11.7129 | _timers/batch_time=0.1738 | _timers/data_time=0.0004 | _timers/model_time=0.1734 | f1_score=8.282e-08 | loss=44.1146 | loss_dist=44.0937 | loss_mean=0.0209
[2019-12-06 19:10:25,875] 
2/100 * Epoch 2 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=6.4616 | _timers/batch_time=0.3110 | _timers/data_time=0.0003 | _timers/model_time=0.3107 | f1_score=1.666e-07 | loss=40.4243 | loss_dist=40.4041 | loss_mean=0.0202
2/100 * Epoch 2 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=11.2197 | _timers/batch_time=0.1814 | _timers/data_time=0.0004 | _timers/model_time=0.1810 | f1_score=1.666e-07 | loss=40.4243 | l

[2019-12-06 19:11:23,137] 
17/100 * Epoch 17 (train): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=6.5978 | _timers/batch_time=0.3040 | _timers/data_time=0.0003 | _timers/model_time=0.3036 | f1_score=3.815e-07 | loss=29.6728 | loss_dist=29.6522 | loss_mean=0.0206
17/100 * Epoch 17 (valid): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=11.6172 | _timers/batch_time=0.1750 | _timers/data_time=0.0005 | _timers/model_time=0.1745 | f1_score=3.815e-07 | loss=29.6728 | loss_dist=29.6522 | loss_mean=0.0206
[2019-12-06 19:11:26,991] 
18/100 * Epoch 18 (train): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=6.5125 | _timers/batch_time=0.3082 | _timers/data_time=0.0004 | _timers/model_time=0.3078 | f1_score=2.853e-07 | loss=9.7974 | loss_dist=9.7764 | loss_mean=0.0210
18/100 * Epoch 18 (valid): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=11.5090 | _timers/batch_time=0.1767 | _timers/data_time=0.0005 | _timers/model_time=0.1762 | f1_score=2.853e-07 | loss=9.797

[2019-12-06 19:12:24,311] 
33/100 * Epoch 33 (train): _base/lr=0.0022 | _base/momentum=0.9000 | _timers/_fps=6.5900 | _timers/batch_time=0.3047 | _timers/data_time=0.0003 | _timers/model_time=0.3044 | f1_score=2.745e-07 | loss=23.9852 | loss_dist=23.9643 | loss_mean=0.0209
33/100 * Epoch 33 (valid): _base/lr=0.0022 | _base/momentum=0.9000 | _timers/_fps=11.4609 | _timers/batch_time=0.1779 | _timers/data_time=0.0005 | _timers/model_time=0.1774 | f1_score=2.745e-07 | loss=23.9852 | loss_dist=23.9643 | loss_mean=0.0209
[2019-12-06 19:12:28,113] 
34/100 * Epoch 34 (train): _base/lr=0.0020 | _base/momentum=0.9000 | _timers/_fps=6.5793 | _timers/batch_time=0.3054 | _timers/data_time=0.0004 | _timers/model_time=0.3050 | f1_score=2.523e-07 | loss=23.4727 | loss_dist=23.4525 | loss_mean=0.0203
34/100 * Epoch 34 (valid): _base/lr=0.0020 | _base/momentum=0.9000 | _timers/_fps=11.6502 | _timers/batch_time=0.1753 | _timers/data_time=0.0005 | _timers/model_time=0.1748 | f1_score=2.523e-07 | loss=23.

[2019-12-06 19:13:25,464] 
49/100 * Epoch 49 (train): _base/lr=0.0014 | _base/momentum=0.9000 | _timers/_fps=6.6157 | _timers/batch_time=0.3039 | _timers/data_time=0.0003 | _timers/model_time=0.3035 | f1_score=1.278e-07 | loss=29.1340 | loss_dist=29.1146 | loss_mean=0.0194
49/100 * Epoch 49 (valid): _base/lr=0.0014 | _base/momentum=0.9000 | _timers/_fps=11.5164 | _timers/batch_time=0.1774 | _timers/data_time=0.0005 | _timers/model_time=0.1768 | f1_score=1.278e-07 | loss=29.1340 | loss_dist=29.1146 | loss_mean=0.0194
[2019-12-06 19:13:29,280] 
50/100 * Epoch 50 (train): _base/lr=0.0014 | _base/momentum=0.9000 | _timers/_fps=6.6524 | _timers/batch_time=0.3017 | _timers/data_time=0.0004 | _timers/model_time=0.3013 | f1_score=1.487e-07 | loss=43.7522 | loss_dist=43.7314 | loss_mean=0.0208
50/100 * Epoch 50 (valid): _base/lr=0.0014 | _base/momentum=0.9000 | _timers/_fps=11.5598 | _timers/batch_time=0.1763 | _timers/data_time=0.0004 | _timers/model_time=0.1758 | f1_score=1.487e-07 | loss=43.

[2019-12-06 19:14:26,519] 
65/100 * Epoch 65 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=6.6293 | _timers/batch_time=0.3025 | _timers/data_time=0.0003 | _timers/model_time=0.3021 | f1_score=3.429e-07 | loss=8.0419 | loss_dist=8.0215 | loss_mean=0.0204
65/100 * Epoch 65 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=11.7883 | _timers/batch_time=0.1731 | _timers/data_time=0.0003 | _timers/model_time=0.1727 | f1_score=3.429e-07 | loss=8.0419 | loss_dist=8.0215 | loss_mean=0.0204
[2019-12-06 19:14:30,332] 
66/100 * Epoch 66 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=6.6194 | _timers/batch_time=0.3035 | _timers/data_time=0.0004 | _timers/model_time=0.3030 | f1_score=1.234e-07 | loss=35.8699 | loss_dist=35.8511 | loss_mean=0.0187
66/100 * Epoch 66 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=11.8697 | _timers/batch_time=0.1712 | _timers/data_time=0.0002 | _timers/model_time=0.1710 | f1_score=1.234e-07 | loss=35.8699

[2019-12-06 19:15:29,555] 
81/100 * Epoch 81 (train): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=6.2849 | _timers/batch_time=0.3198 | _timers/data_time=0.0003 | _timers/model_time=0.3195 | f1_score=1.118e-07 | loss=25.7942 | loss_dist=25.7746 | loss_mean=0.0196
81/100 * Epoch 81 (valid): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=10.9493 | _timers/batch_time=0.1868 | _timers/data_time=0.0003 | _timers/model_time=0.1864 | f1_score=1.118e-07 | loss=25.7942 | loss_dist=25.7746 | loss_mean=0.0196
[2019-12-06 19:15:33,521] 
82/100 * Epoch 82 (train): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=6.4341 | _timers/batch_time=0.3121 | _timers/data_time=0.0003 | _timers/model_time=0.3118 | f1_score=2.945e-07 | loss=29.1933 | loss_dist=29.1723 | loss_mean=0.0211
82/100 * Epoch 82 (valid): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=11.7569 | _timers/batch_time=0.1743 | _timers/data_time=0.0003 | _timers/model_time=0.1740 | f1_score=2.945e-07 | loss=29.

[2019-12-06 19:16:34,545] 
97/100 * Epoch 97 (train): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=6.5857 | _timers/batch_time=0.3045 | _timers/data_time=0.0003 | _timers/model_time=0.3041 | f1_score=2.238e-07 | loss=19.6927 | loss_dist=19.6745 | loss_mean=0.0183
97/100 * Epoch 97 (valid): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=11.2647 | _timers/batch_time=0.1810 | _timers/data_time=0.0004 | _timers/model_time=0.1805 | f1_score=2.238e-07 | loss=19.6927 | loss_dist=19.6745 | loss_mean=0.0183
[2019-12-06 19:16:38,396] 
98/100 * Epoch 98 (train): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=6.4254 | _timers/batch_time=0.3118 | _timers/data_time=0.0005 | _timers/model_time=0.3113 | f1_score=1.175e-07 | loss=51.5079 | loss_dist=51.4893 | loss_mean=0.0186
98/100 * Epoch 98 (valid): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=11.5065 | _timers/batch_time=0.1779 | _timers/data_time=0.0003 | _timers/model_time=0.1776 | f1_score=1.175e-07 | loss=51.

In [81]:
model_bin.eval()
predictions, labels = [], []
with torch.no_grad():
    for x,y in tqdm(test_loader):
        try:
        
            predictions.append(model_bin(x)['cls'].detach().cpu().numpy())
            labels.append(y.detach().cpu().numpy())
        except: pass
predictions = np.concatenate(predictions)
labels = np.concatenate(labels)

roc_auc_score(labels, predictions)

100%|██████████| 25/25 [00:40<00:00,  1.64s/it]


0.6008593376

### Bernoulli p = 0.8

In [82]:
p = 0.8

In [83]:
from torch.autograd import Variable


class IMDbRNN_bin(nn.Module):
    def __init__(self, sent_size):
        super().__init__()
#         self.rnn = nn.GRU(embedding_size, hidden_size)
        self.rnn = model.rnn
        self.cls = model.cls
        self.emb = model.emb
        
        self.u = Variable(torch.randn(1,sent_size,1), requires_grad=True)
    
    def prepare_sequence(self, x):
        l = torch.LongTensor([len(xi) for xi in x])
        x = torch.split_with_sizes(self.emb(torch.cat(x).to(self.emb.weight.device)), l.unbind(0))
        x = nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=5)
        return x, l
    
    def forward(self, x):
#         print(x.shape,x[0].shape)
        
        x = [i for i in x]
        x, l = self.prepare_sequence(x)
        
        #считаем с единичной маской
        b = torch.ones(x.shape[:2]).unsqueeze(2).to(self.emb.weight.device)
        h,o = self.rnn(x,b,l)               
        
        #считаем с бинаризованной маской u
        #torch.bernoulli(torch.tensor([1.0]),)
        b_new = torch.bernoulli(torch.ones(x.shape[:2]),p).unsqueeze(2).to(self.emb.weight.device)
        #b_new = torch.repeat_interleave(bb(0.4), repeats=len(x), dim=0).to(self.emb.weight.device)
        h_new,o_new = self.rnn(x,b_new,l)
        
        #расстояние между выходами сети на полном и прореженом предложении
        dist = ((o-o_new)**2).sum(1)
        
        #среднее значение бинарной маски
        mean = b_new.squeeze(2).mean(1)
        
        #финальное предсказание класса на прореженом сообщении
        target = self.cls(o_new)
        
        return {'cls': target, 'dist': dist, 'mean': mean}

In [84]:
model_bin = IMDbRNN_bin(x[0].shape[0])
class Dataset(BaseDataset):

    
    def __init__(
            self, x, y, idx
    ):
        
        self.x = x
        self.y = y
        self.idx = idx
    
    def __getitem__(self, idx):

        return {'features': self.x[self.idx], 'target_cls':self.y[self.idx],
                'target_dist': torch.tensor(0).float(), 'target_mean': torch.tensor(0).float()}
        
    def __len__(self):
        return 8


new_train_loader = DataLoader(Dataset(x,y,0), shuffle=True, batch_size=2)
new_val_loader = DataLoader(Dataset(x,y,0), shuffle=True, batch_size=2)

In [85]:
num_epochs = 100  # change me
logdir = "./logs/b80"
loaders = {
    "train": new_train_loader,
    "valid": new_val_loader
}

criterion = {
   "mse": nn.MSELoss()
}

optimizer = torch.optim.Adam([
    {'params': model_bin.u, 'lr': 3e-3}])


scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=4)

runner = SupervisedRunner(input_target_key=None,output_key=None)

In [86]:
runner.train(
    model=model_bin,
    criterion=criterion,
    scheduler=scheduler,
    callbacks=[   
               CriterionCallback(prefix="loss_dist", input_key='target_dist', output_key='dist',
                     criterion_key='mse', multiplier=0.5),
               CriterionCallback(prefix="loss_mean", input_key='target_mean', output_key='mean',
                     criterion_key='mse', multiplier=0.5),
        
               CriterionAggregatorCallback(prefix="loss",loss_keys=['loss_dist','loss_mean']),
#                AUCCallback(num_classes=1, input_key='target_cls', output_key='cls'), 
                F1ScoreCallback(input_key='target_cls', output_key='cls'),
               CheckpointCallback(save_n_best=3)
                ],

    optimizer=optimizer,
    main_metric='loss',
    minimize_metric=True,
    loaders=loaders,
    logdir=logdir,
    num_epochs=num_epochs,
#     fp16={"opt_level": "O1"},
    verbose=False
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

[2019-12-06 19:17:33,521] 
1/100 * Epoch 1 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=6.2646 | _timers/batch_time=0.3201 | _timers/data_time=0.0003 | _timers/model_time=0.3198 | f1_score=1.598e-06 | loss=0.3440 | loss_dist=0.0239 | loss_mean=0.3201
1/100 * Epoch 1 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=11.2860 | _timers/batch_time=0.1796 | _timers/data_time=0.0003 | _timers/model_time=0.1793 | f1_score=1.598e-06 | loss=0.3440 | loss_dist=0.0239 | loss_mean=0.3201
[2019-12-06 19:17:37,423] 
2/100 * Epoch 2 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=6.4092 | _timers/batch_time=0.3135 | _timers/data_time=0.0003 | _timers/model_time=0.3131 | f1_score=1.488e-06 | loss=0.3445 | loss_dist=0.0282 | loss_mean=0.3163
2/100 * Epoch 2 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=11.4898 | _timers/batch_time=0.1763 | _timers/data_time=0.0003 | _timers/model_time=0.1760 | f1_score=1.488e-06 | loss=0.3445 | loss_dis

[2019-12-06 19:18:38,774] 
17/100 * Epoch 17 (train): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=6.2604 | _timers/batch_time=0.3209 | _timers/data_time=0.0003 | _timers/model_time=0.3206 | f1_score=1.430e-06 | loss=0.3476 | loss_dist=0.0264 | loss_mean=0.3212
17/100 * Epoch 17 (valid): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=11.5361 | _timers/batch_time=0.1766 | _timers/data_time=0.0002 | _timers/model_time=0.1763 | f1_score=1.430e-06 | loss=0.3476 | loss_dist=0.0264 | loss_mean=0.3212
[2019-12-06 19:18:42,634] 
18/100 * Epoch 18 (train): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=6.5476 | _timers/batch_time=0.3067 | _timers/data_time=0.0002 | _timers/model_time=0.3065 | f1_score=1.295e-06 | loss=1.7903 | loss_dist=1.4686 | loss_mean=0.3217
18/100 * Epoch 18 (valid): _base/lr=0.0027 | _base/momentum=0.9000 | _timers/_fps=11.4291 | _timers/batch_time=0.1785 | _timers/data_time=0.0003 | _timers/model_time=0.1782 | f1_score=1.295e-06 | loss=1.7903 | 

[2019-12-06 19:19:43,738] 
33/100 * Epoch 33 (train): _base/lr=0.0022 | _base/momentum=0.9000 | _timers/_fps=5.9268 | _timers/batch_time=0.3384 | _timers/data_time=0.0004 | _timers/model_time=0.3380 | f1_score=1.364e-06 | loss=0.3457 | loss_dist=0.0267 | loss_mean=0.3190
33/100 * Epoch 33 (valid): _base/lr=0.0022 | _base/momentum=0.9000 | _timers/_fps=10.1762 | _timers/batch_time=0.1991 | _timers/data_time=0.0002 | _timers/model_time=0.1988 | f1_score=1.364e-06 | loss=0.3457 | loss_dist=0.0267 | loss_mean=0.3190
[2019-12-06 19:19:48,171] 
34/100 * Epoch 34 (train): _base/lr=0.0020 | _base/momentum=0.9000 | _timers/_fps=5.6590 | _timers/batch_time=0.3542 | _timers/data_time=0.0004 | _timers/model_time=0.3537 | f1_score=1.247e-06 | loss=3.3798 | loss_dist=3.0601 | loss_mean=0.3197
34/100 * Epoch 34 (valid): _base/lr=0.0020 | _base/momentum=0.9000 | _timers/_fps=9.9862 | _timers/batch_time=0.2025 | _timers/data_time=0.0003 | _timers/model_time=0.2021 | f1_score=1.247e-06 | loss=3.3798 | l

[2019-12-06 19:20:49,223] 
49/100 * Epoch 49 (train): _base/lr=0.0016 | _base/momentum=0.9000 | _timers/_fps=6.4429 | _timers/batch_time=0.3114 | _timers/data_time=0.0002 | _timers/model_time=0.3112 | f1_score=1.429e-06 | loss=0.4191 | loss_dist=0.1049 | loss_mean=0.3141
49/100 * Epoch 49 (valid): _base/lr=0.0016 | _base/momentum=0.9000 | _timers/_fps=11.3275 | _timers/batch_time=0.1793 | _timers/data_time=0.0002 | _timers/model_time=0.1791 | f1_score=1.429e-06 | loss=0.4191 | loss_dist=0.1049 | loss_mean=0.3141
[2019-12-06 19:20:53,075] 
50/100 * Epoch 50 (train): _base/lr=0.0016 | _base/momentum=0.9000 | _timers/_fps=6.5503 | _timers/batch_time=0.3060 | _timers/data_time=0.0002 | _timers/model_time=0.3057 | f1_score=1.501e-06 | loss=0.3790 | loss_dist=0.0557 | loss_mean=0.3233
50/100 * Epoch 50 (valid): _base/lr=0.0016 | _base/momentum=0.9000 | _timers/_fps=11.3406 | _timers/batch_time=0.1790 | _timers/data_time=0.0004 | _timers/model_time=0.1786 | f1_score=1.501e-06 | loss=0.3790 | 

[2019-12-06 19:21:51,390] 
65/100 * Epoch 65 (train): _base/lr=0.0012 | _base/momentum=0.9000 | _timers/_fps=6.5043 | _timers/batch_time=0.3088 | _timers/data_time=0.0003 | _timers/model_time=0.3085 | f1_score=1.513e-06 | loss=0.3897 | loss_dist=0.0628 | loss_mean=0.3269
65/100 * Epoch 65 (valid): _base/lr=0.0012 | _base/momentum=0.9000 | _timers/_fps=11.7615 | _timers/batch_time=0.1724 | _timers/data_time=0.0003 | _timers/model_time=0.1720 | f1_score=1.513e-06 | loss=0.3897 | loss_dist=0.0628 | loss_mean=0.3269
[2019-12-06 19:21:55,249] 
66/100 * Epoch 66 (train): _base/lr=0.0012 | _base/momentum=0.9000 | _timers/_fps=6.5303 | _timers/batch_time=0.3074 | _timers/data_time=0.0004 | _timers/model_time=0.3070 | f1_score=1.413e-06 | loss=0.3512 | loss_dist=0.0313 | loss_mean=0.3199
66/100 * Epoch 66 (valid): _base/lr=0.0012 | _base/momentum=0.9000 | _timers/_fps=11.6167 | _timers/batch_time=0.1755 | _timers/data_time=0.0003 | _timers/model_time=0.1751 | f1_score=1.413e-06 | loss=0.3512 | 

[2019-12-06 19:22:53,756] 
81/100 * Epoch 81 (train): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=6.5173 | _timers/batch_time=0.3081 | _timers/data_time=0.0004 | _timers/model_time=0.3076 | f1_score=1.144e-06 | loss=7.1780 | loss_dist=6.8574 | loss_mean=0.3206
81/100 * Epoch 81 (valid): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=11.5806 | _timers/batch_time=0.1749 | _timers/data_time=0.0003 | _timers/model_time=0.1746 | f1_score=1.144e-06 | loss=7.1780 | loss_dist=6.8574 | loss_mean=0.3206
[2019-12-06 19:22:57,601] 
82/100 * Epoch 82 (train): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=6.5928 | _timers/batch_time=0.3044 | _timers/data_time=0.0003 | _timers/model_time=0.3040 | f1_score=1.280e-06 | loss=0.3351 | loss_dist=0.0214 | loss_mean=0.3137
82/100 * Epoch 82 (valid): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=11.5672 | _timers/batch_time=0.1763 | _timers/data_time=0.0003 | _timers/model_time=0.1760 | f1_score=1.280e-06 | loss=0.3351 | 

[2019-12-06 19:23:56,435] 
97/100 * Epoch 97 (train): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=6.4687 | _timers/batch_time=0.3102 | _timers/data_time=0.0003 | _timers/model_time=0.3099 | f1_score=1.567e-06 | loss=0.3516 | loss_dist=0.0318 | loss_mean=0.3198
97/100 * Epoch 97 (valid): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=11.5607 | _timers/batch_time=0.1762 | _timers/data_time=0.0002 | _timers/model_time=0.1759 | f1_score=1.567e-06 | loss=0.3516 | loss_dist=0.0318 | loss_mean=0.3198
[2019-12-06 19:24:00,300] 
98/100 * Epoch 98 (train): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=6.4630 | _timers/batch_time=0.3105 | _timers/data_time=0.0003 | _timers/model_time=0.3102 | f1_score=1.024e-06 | loss=5.7255 | loss_dist=5.4093 | loss_mean=0.3162
98/100 * Epoch 98 (valid): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=11.5353 | _timers/batch_time=0.1768 | _timers/data_time=0.0003 | _timers/model_time=0.1764 | f1_score=1.024e-06 | loss=5.7255 | 

### Bernoulli p = 0.99

In [87]:
p = 0.99

In [88]:
from torch.autograd import Variable


class IMDbRNN_bin(nn.Module):
    def __init__(self, sent_size):
        super().__init__()
#         self.rnn = nn.GRU(embedding_size, hidden_size)
        self.rnn = model.rnn
        self.cls = model.cls
        self.emb = model.emb
        
        self.u = Variable(torch.randn(1,sent_size,1), requires_grad=True)
    
    def prepare_sequence(self, x):
        l = torch.LongTensor([len(xi) for xi in x])
        x = torch.split_with_sizes(self.emb(torch.cat(x).to(self.emb.weight.device)), l.unbind(0))
        x = nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=5)
        return x, l
    
    def forward(self, x):
#         print(x.shape,x[0].shape)
        
        x = [i for i in x]
        x, l = self.prepare_sequence(x)
        
        #считаем с единичной маской
        b = torch.ones(x.shape[:2]).unsqueeze(2).to(self.emb.weight.device)
        h,o = self.rnn(x,b,l)               
        
        #считаем с бинаризованной маской u
        #torch.bernoulli(torch.tensor([1.0]),)
        b_new = torch.bernoulli(torch.ones(x.shape[:2]),p).unsqueeze(2).to(self.emb.weight.device)
        #b_new = torch.repeat_interleave(bb(0.4), repeats=len(x), dim=0).to(self.emb.weight.device)
        h_new,o_new = self.rnn(x,b_new,l)
        
        #расстояние между выходами сети на полном и прореженом предложении
        dist = ((o-o_new)**2).sum(1)
        
        #среднее значение бинарной маски
        mean = b_new.squeeze(2).mean(1)
        
        #финальное предсказание класса на прореженом сообщении
        target = self.cls(o_new)
        
        return {'cls': target, 'dist': dist, 'mean': mean}

In [89]:
model_bin = IMDbRNN_bin(x[0].shape[0])
class Dataset(BaseDataset):

    
    def __init__(
            self, x, y, idx
    ):
        
        self.x = x
        self.y = y
        self.idx = idx
    
    def __getitem__(self, idx):

        return {'features': self.x[self.idx], 'target_cls':self.y[self.idx],
                'target_dist': torch.tensor(0).float(), 'target_mean': torch.tensor(0).float()}
        
    def __len__(self):
        return 8


new_train_loader = DataLoader(Dataset(x,y,0), shuffle=True, batch_size=2)
new_val_loader = DataLoader(Dataset(x,y,0), shuffle=True, batch_size=2)

In [90]:
num_epochs = 100  # change me
logdir = "./logs/b99"
loaders = {
    "train": new_train_loader,
    "valid": new_val_loader
}

criterion = {
   "mse": nn.MSELoss()
}

optimizer = torch.optim.Adam([
    {'params': model_bin.u, 'lr': 3e-3}])


scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=4)

runner = SupervisedRunner(input_target_key=None,output_key=None)

In [91]:
runner.train(
    model=model_bin,
    criterion=criterion,
    scheduler=scheduler,
    callbacks=[   
               CriterionCallback(prefix="loss_dist", input_key='target_dist', output_key='dist',
                     criterion_key='mse', multiplier=0.5),
               CriterionCallback(prefix="loss_mean", input_key='target_mean', output_key='mean',
                     criterion_key='mse', multiplier=0.5),
        
               CriterionAggregatorCallback(prefix="loss",loss_keys=['loss_dist','loss_mean']),
#                AUCCallback(num_classes=1, input_key='target_cls', output_key='cls'), 
                F1ScoreCallback(input_key='target_cls', output_key='cls'),
               CheckpointCallback(save_n_best=3)
                ],

    optimizer=optimizer,
    main_metric='loss',
    minimize_metric=True,
    loaders=loaders,
    logdir=logdir,
    num_epochs=num_epochs,
#     fp16={"opt_level": "O1"},
    verbose=False
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

[2019-12-06 20:00:32,519] 
1/100 * Epoch 1 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=5.8344 | _timers/batch_time=0.3453 | _timers/data_time=0.0003 | _timers/model_time=0.3449 | f1_score=1.611e-06 | loss=0.4885 | loss_dist=6.300e-06 | loss_mean=0.4885
1/100 * Epoch 1 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=10.2343 | _timers/batch_time=0.1994 | _timers/data_time=0.0003 | _timers/model_time=0.1991 | f1_score=1.611e-06 | loss=0.4885 | loss_dist=6.300e-06 | loss_mean=0.4885
[2019-12-06 20:00:37,207] 
2/100 * Epoch 2 (train): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=5.7884 | _timers/batch_time=0.3467 | _timers/data_time=0.0004 | _timers/model_time=0.3463 | f1_score=1.635e-06 | loss=0.4897 | loss_dist=4.855e-06 | loss_mean=0.4897
2/100 * Epoch 2 (valid): _base/lr=0.0030 | _base/momentum=0.9000 | _timers/_fps=8.6677 | _timers/batch_time=0.2389 | _timers/data_time=0.0005 | _timers/model_time=0.2384 | f1_score=1.635e-06 | loss=0.4897 | 

[2019-12-06 20:01:41,301] 
17/100 * Epoch 17 (train): _base/lr=0.0022 | _base/momentum=0.9000 | _timers/_fps=6.2698 | _timers/batch_time=0.3212 | _timers/data_time=0.0004 | _timers/model_time=0.3208 | f1_score=1.572e-06 | loss=0.4902 | loss_dist=0.0002 | loss_mean=0.4900
17/100 * Epoch 17 (valid): _base/lr=0.0022 | _base/momentum=0.9000 | _timers/_fps=11.2537 | _timers/batch_time=0.1822 | _timers/data_time=0.0002 | _timers/model_time=0.1819 | f1_score=1.572e-06 | loss=0.4902 | loss_dist=0.0002 | loss_mean=0.4900
[2019-12-06 20:01:45,276] 
18/100 * Epoch 18 (train): _base/lr=0.0022 | _base/momentum=0.9000 | _timers/_fps=6.5044 | _timers/batch_time=0.3086 | _timers/data_time=0.0004 | _timers/model_time=0.3081 | f1_score=1.616e-06 | loss=0.4896 | loss_dist=6.700e-06 | loss_mean=0.4896
18/100 * Epoch 18 (valid): _base/lr=0.0022 | _base/momentum=0.9000 | _timers/_fps=10.1222 | _timers/batch_time=0.1993 | _timers/data_time=0.0002 | _timers/model_time=0.1991 | f1_score=1.616e-06 | loss=0.4896

[2019-12-06 20:02:47,586] 
33/100 * Epoch 33 (train): _base/lr=0.0018 | _base/momentum=0.9000 | _timers/_fps=6.1492 | _timers/batch_time=0.3272 | _timers/data_time=0.0003 | _timers/model_time=0.3269 | f1_score=1.604e-06 | loss=0.4900 | loss_dist=2.742e-06 | loss_mean=0.4900
33/100 * Epoch 33 (valid): _base/lr=0.0018 | _base/momentum=0.9000 | _timers/_fps=10.6862 | _timers/batch_time=0.1911 | _timers/data_time=0.0003 | _timers/model_time=0.1908 | f1_score=1.604e-06 | loss=0.4900 | loss_dist=2.742e-06 | loss_mean=0.4900
[2019-12-06 20:02:51,705] 
34/100 * Epoch 34 (train): _base/lr=0.0018 | _base/momentum=0.9000 | _timers/_fps=6.0575 | _timers/batch_time=0.3312 | _timers/data_time=0.0003 | _timers/model_time=0.3308 | f1_score=1.613e-06 | loss=0.4892 | loss_dist=3.705e-09 | loss_mean=0.4892
34/100 * Epoch 34 (valid): _base/lr=0.0018 | _base/momentum=0.9000 | _timers/_fps=11.1244 | _timers/batch_time=0.1849 | _timers/data_time=0.0004 | _timers/model_time=0.1845 | f1_score=1.613e-06 | loss=

[2019-12-06 20:03:55,187] 
49/100 * Epoch 49 (train): _base/lr=0.0013 | _base/momentum=0.9000 | _timers/_fps=5.8413 | _timers/batch_time=0.3467 | _timers/data_time=0.0004 | _timers/model_time=0.3462 | f1_score=1.520e-06 | loss=0.4911 | loss_dist=0.0004 | loss_mean=0.4907
49/100 * Epoch 49 (valid): _base/lr=0.0013 | _base/momentum=0.9000 | _timers/_fps=10.6046 | _timers/batch_time=0.1954 | _timers/data_time=0.0005 | _timers/model_time=0.1948 | f1_score=1.520e-06 | loss=0.4911 | loss_dist=0.0004 | loss_mean=0.4907
[2019-12-06 20:03:59,294] 
50/100 * Epoch 50 (train): _base/lr=0.0013 | _base/momentum=0.9000 | _timers/_fps=6.1822 | _timers/batch_time=0.3266 | _timers/data_time=0.0003 | _timers/model_time=0.3262 | f1_score=1.609e-06 | loss=0.4905 | loss_dist=1.240e-08 | loss_mean=0.4905
50/100 * Epoch 50 (valid): _base/lr=0.0013 | _base/momentum=0.9000 | _timers/_fps=10.7855 | _timers/batch_time=0.1890 | _timers/data_time=0.0003 | _timers/model_time=0.1887 | f1_score=1.609e-06 | loss=0.4905

[2019-12-06 20:05:01,876] 
65/100 * Epoch 65 (train): _base/lr=0.0009 | _base/momentum=0.9000 | _timers/_fps=5.9937 | _timers/batch_time=0.3342 | _timers/data_time=0.0004 | _timers/model_time=0.3337 | f1_score=1.554e-06 | loss=0.4895 | loss_dist=0.0002 | loss_mean=0.4894
65/100 * Epoch 65 (valid): _base/lr=0.0009 | _base/momentum=0.9000 | _timers/_fps=10.9030 | _timers/batch_time=0.1870 | _timers/data_time=0.0004 | _timers/model_time=0.1866 | f1_score=1.554e-06 | loss=0.4895 | loss_dist=0.0002 | loss_mean=0.4894
[2019-12-06 20:05:06,182] 
66/100 * Epoch 66 (train): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=6.1061 | _timers/batch_time=0.3283 | _timers/data_time=0.0004 | _timers/model_time=0.3279 | f1_score=1.623e-06 | loss=0.4889 | loss_dist=0.0003 | loss_mean=0.4886
66/100 * Epoch 66 (valid): _base/lr=0.0008 | _base/momentum=0.9000 | _timers/_fps=10.0226 | _timers/batch_time=0.2035 | _timers/data_time=0.0002 | _timers/model_time=0.2032 | f1_score=1.623e-06 | loss=0.4889 | 

[2019-12-06 20:06:07,636] 
81/100 * Epoch 81 (train): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=6.2121 | _timers/batch_time=0.3220 | _timers/data_time=0.0004 | _timers/model_time=0.3216 | f1_score=1.607e-06 | loss=0.4888 | loss_dist=5.937e-08 | loss_mean=0.4888
81/100 * Epoch 81 (valid): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=10.8772 | _timers/batch_time=0.1873 | _timers/data_time=0.0004 | _timers/model_time=0.1869 | f1_score=1.607e-06 | loss=0.4888 | loss_dist=5.937e-08 | loss_mean=0.4888
[2019-12-06 20:06:11,826] 
82/100 * Epoch 82 (train): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=5.7567 | _timers/batch_time=0.3480 | _timers/data_time=0.0004 | _timers/model_time=0.3475 | f1_score=1.641e-06 | loss=0.4908 | loss_dist=0.0002 | loss_mean=0.4906
82/100 * Epoch 82 (valid): _base/lr=0.0006 | _base/momentum=0.9000 | _timers/_fps=11.4310 | _timers/batch_time=0.1797 | _timers/data_time=0.0003 | _timers/model_time=0.1793 | f1_score=1.641e-06 | loss=0.4

[2019-12-06 20:07:13,699] 
97/100 * Epoch 97 (train): _base/lr=0.0005 | _base/momentum=0.9000 | _timers/_fps=6.0985 | _timers/batch_time=0.3297 | _timers/data_time=0.0003 | _timers/model_time=0.3294 | f1_score=1.496e-06 | loss=0.4941 | loss_dist=0.0039 | loss_mean=0.4902
97/100 * Epoch 97 (valid): _base/lr=0.0005 | _base/momentum=0.9000 | _timers/_fps=11.3443 | _timers/batch_time=0.1801 | _timers/data_time=0.0003 | _timers/model_time=0.1798 | f1_score=1.496e-06 | loss=0.4941 | loss_dist=0.0039 | loss_mean=0.4902
[2019-12-06 20:07:17,876] 
98/100 * Epoch 98 (train): _base/lr=0.0005 | _base/momentum=0.9000 | _timers/_fps=6.2656 | _timers/batch_time=0.3222 | _timers/data_time=0.0003 | _timers/model_time=0.3218 | f1_score=1.630e-06 | loss=0.4928 | loss_dist=0.0027 | loss_mean=0.4901
98/100 * Epoch 98 (valid): _base/lr=0.0005 | _base/momentum=0.9000 | _timers/_fps=10.2134 | _timers/batch_time=0.1975 | _timers/data_time=0.0002 | _timers/model_time=0.1973 | f1_score=1.630e-06 | loss=0.4928 | 