# Assignment 8

Develop a model for 20 news groups dataset from scikit-learn. Select 20% of data for test set.  

Develop metric learning model with siamese network [3 points] and triplet loss [3 points] (from seminar). 
Use KNN and LSH (any library for approximate nearest neighbor search) for final prediction after the network was trained. [2 points]

! Remember, that LSH gives you a set of neighbor candidates, for which you have to calculate distances to choose top-k nearest neighbors. 

Your quality metric = accuracy score [2 points if acc > 0.8 ]

In [0]:
!pip install -U sentence-transformers
!pip install annoy

In [2]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split

import re
import numpy as np
import pandas as pd
from tqdm import tqdm, tqdm_notebook

import nltk
from nltk import tokenize
nltk.download('punkt')

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset

from annoy import AnnoyIndex
from scipy.spatial import cKDTree
from sentence_transformers import SentenceTransformer

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


# Data Stuff

In [148]:
df = fetch_20newsgroups(subset='all')
X = df['data']
y = df['target']
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42)

len(X_train), len(X_test), len(y_train), len(y_test)

(15076, 3770, 15076, 3770)

In [0]:
def clean_data(text):
    text = re.sub(r'-{2,10}\s.{2,20}\s-{2,25}', '', text)
    text = re.sub(r'\n#', ' ', text)
    text = re.sub(r'\s{2,10}', ' ', text)
    return text

def clean_tqdm():
    for instance in list(tqdm._instances): 
        tqdm._decr_instances(instance)

# Getting Embeddings

In [0]:
# https://paperswithcode.com/paper/sentence-bert-sentence-embeddings-using
# https://github.com/UKPLab/sentence-transformers

In [0]:
model = SentenceTransformer('bert-base-wikipedia-sections-mean-tokens')

In [151]:
clean_tqdm()
sentences_train = [clean_data(text) for text in X_train]
embeddings_train = model.encode(sentences_train, show_progress_bar=True)

sentences_test = [clean_data(text) for text in X_test]
embeddings_test = model.encode(sentences_test, show_progress_bar=True)


Batches: 100%|██████████| 1885/1885 [02:45<00:00,  1.46s/it]
Batches: 100%|██████████| 472/472 [00:42<00:00,  1.07it/s]


In [152]:
x_train = torch.FloatTensor(embeddings_train)
x_test = torch.FloatTensor(embeddings_test)

y_train = torch.FloatTensor(y_train)
y_test = torch.FloatTensor(y_test)

x_train.size(), x_test.size(), y_train.size(), y_test.size()

(torch.Size([15076, 768]),
 torch.Size([3770, 768]),
 torch.Size([15076]),
 torch.Size([3770]))

In [0]:
def get_triplets(embeddings, y):
    """
    Рандомно выбирается эмбеддинг совпадающего класса (добавляется в pos) и эмбеддинг несовпадающего (добавляется в neg).
    """
    pos = []
    neg = []
    for ind, anchor in enumerate(embeddings):
        pos.append(embeddings[np.random.choice(np.where(y == y[ind])[0])])
        neg.append(embeddings[np.random.choice(np.where(y != y[ind])[0])])
    return torch.FloatTensor(pos), torch.FloatTensor(neg)

In [0]:
pos_train, neg_train = get_triplets(embeddings_train, y_train)
pos_test, neg_test = get_triplets(embeddings_test, y_test)

assert pos_train.size() == neg_train.size() == x_train.size()
assert pos_test.size() == neg_test.size() == x_test.size()

In [0]:
batch_size = 1024
train_loader = DataLoader(TensorDataset(x_train, pos_train, neg_train,  y_train), batch_size=batch_size)
test_loader = DataLoader(TensorDataset(x_test, pos_test, neg_test, y_test), batch_size=batch_size)

In [156]:
for el in test_loader:
    print(type(el), len(el))
    print(el[0].size(), el[1].size(), el[2].size(), el[3].size())
    break

<class 'list'> 4
torch.Size([1024, 768]) torch.Size([1024, 768]) torch.Size([1024, 768]) torch.Size([1024])


# Model

In [0]:
def triplet_loss(anchor_embed, pos_embed, neg_embed, margin=.5):
    # https://pytorch.org/docs/stable/nn.html#torch.nn.TripletMarginLoss
    # https://github.com/UKPLab/sentence-transformers/blob/9b94d3fae98970ccdf380139542d93011bb984ea/sentence_transformers/losses/TripletLoss.py#L8

    # positive = F.pairwise_distance(anchor_embed, pos_embed, p=2)  # Варианты для euclidian distance (loss хуже)
    # negative = F.pairwise_distance(anchor_embed, neg_embed, p=2)

    positive = 1 - F.cosine_similarity(anchor_embed, pos_embed)
    negative = 1 - F.cosine_similarity(anchor_embed, neg_embed)
    differ = positive - negative + margin
    return F.relu(differ).mean()
    
class Tripletnet(nn.Module):
    def __init__(self, input_size,  hidden_size, output_size):
        super(Tripletnet, self).__init__()
        self.fc = nn.Linear(input_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.lstm = nn.LSTM(input_size, hidden_size, 2)
        
    def branch(self, x):
        x = self.fc(x)
        x = F.relu(x)
        x = self.out(x)
        return x

    def forward(self, batch):
        anchor, pos, neg = batch[0], batch[1], batch[2]
        
        anchor = self.branch(anchor)
        pos = self.branch(pos)
        neg = self.branch(neg)
        out = triplet_loss(anchor, pos, neg)
        return out

model = Tripletnet(input_size=768, hidden_size=256, output_size=126)
optimizer = optim.Adam(model.parameters())

In [158]:
def _train_epoch(model, iterator, optimizer, curr_epoch):

    model.train()

    running_loss = 0

    n_batches = len(iterator)
    iterator = tqdm_notebook(iterator, total=n_batches, desc='epoch %d' % (curr_epoch), leave=True)

    for i, batch in enumerate(iterator):
        optimizer.zero_grad()

        loss = model(batch)
        loss.backward()
        optimizer.step()

        curr_loss = loss.item()
        
        loss_smoothing = i / (i+1)
        running_loss = loss_smoothing * running_loss + (1 - loss_smoothing) * curr_loss

        iterator.set_postfix(loss='%.5f' % running_loss)

    return running_loss

def _test_epoch(model, iterator):
    model.eval()
    epoch_loss = 0

    n_batches = len(iterator)
    with torch.no_grad():
        for batch in iterator:
            loss = model(batch)
            epoch_loss += loss.data.item()

    return epoch_loss / n_batches

def nn_train(model, train_iterator, valid_iterator, optimizer, n_epochs=100,
          scheduler=None, early_stopping=0):

    prev_loss = 100500
    es_epochs = 0
    best_epoch = None
    history = pd.DataFrame()

    for epoch in range(n_epochs):
        train_loss = _train_epoch(model, train_iterator, optimizer, epoch)
        valid_loss = _test_epoch(model, valid_iterator)

        valid_loss = valid_loss
        print('validation loss %.5f' % valid_loss)

        record = {'epoch': epoch, 'train_loss': train_loss, 'valid_loss': valid_loss}
        history = history.append(record, ignore_index=True)

        if early_stopping > 0:
            if valid_loss > prev_loss:
                es_epochs += 1
            else:
                es_epochs = 0

            if es_epochs >= early_stopping:
                best_epoch = history[history.valid_loss == history.valid_loss.min()].iloc[0]
                print('Early stopping! best epoch: %d val %.5f' % (best_epoch['epoch'], best_epoch['valid_loss']))
                break

            prev_loss = min(prev_loss, valid_loss)

nn_train(model, train_loader, test_loader, optimizer, n_epochs=100)

HBox(children=(IntProgress(value=0, description='epoch 0', max=15, style=ProgressStyle(description_width='init…

validation loss 0.29334


HBox(children=(IntProgress(value=0, description='epoch 1', max=15, style=ProgressStyle(description_width='init…

validation loss 0.27902


HBox(children=(IntProgress(value=0, description='epoch 2', max=15, style=ProgressStyle(description_width='init…

validation loss 0.22370


HBox(children=(IntProgress(value=0, description='epoch 3', max=15, style=ProgressStyle(description_width='init…

validation loss 0.20484


HBox(children=(IntProgress(value=0, description='epoch 4', max=15, style=ProgressStyle(description_width='init…

validation loss 0.20267


HBox(children=(IntProgress(value=0, description='epoch 5', max=15, style=ProgressStyle(description_width='init…

validation loss 0.19061


HBox(children=(IntProgress(value=0, description='epoch 6', max=15, style=ProgressStyle(description_width='init…

validation loss 0.19569


HBox(children=(IntProgress(value=0, description='epoch 7', max=15, style=ProgressStyle(description_width='init…

validation loss 0.18566


HBox(children=(IntProgress(value=0, description='epoch 8', max=15, style=ProgressStyle(description_width='init…

validation loss 0.18230


HBox(children=(IntProgress(value=0, description='epoch 9', max=15, style=ProgressStyle(description_width='init…

validation loss 0.17723


HBox(children=(IntProgress(value=0, description='epoch 10', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.18221


HBox(children=(IntProgress(value=0, description='epoch 11', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.17944


HBox(children=(IntProgress(value=0, description='epoch 12', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.17372


HBox(children=(IntProgress(value=0, description='epoch 13', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.17230


HBox(children=(IntProgress(value=0, description='epoch 14', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.18857


HBox(children=(IntProgress(value=0, description='epoch 15', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.17011


HBox(children=(IntProgress(value=0, description='epoch 16', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.16811


HBox(children=(IntProgress(value=0, description='epoch 17', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.17072


HBox(children=(IntProgress(value=0, description='epoch 18', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.16814


HBox(children=(IntProgress(value=0, description='epoch 19', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.16402


HBox(children=(IntProgress(value=0, description='epoch 20', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.16390


HBox(children=(IntProgress(value=0, description='epoch 21', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15730


HBox(children=(IntProgress(value=0, description='epoch 22', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.16725


HBox(children=(IntProgress(value=0, description='epoch 23', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15829


HBox(children=(IntProgress(value=0, description='epoch 24', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15834


HBox(children=(IntProgress(value=0, description='epoch 25', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15694


HBox(children=(IntProgress(value=0, description='epoch 26', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15904


HBox(children=(IntProgress(value=0, description='epoch 27', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15904


HBox(children=(IntProgress(value=0, description='epoch 28', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.16444


HBox(children=(IntProgress(value=0, description='epoch 29', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.16151


HBox(children=(IntProgress(value=0, description='epoch 30', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15685


HBox(children=(IntProgress(value=0, description='epoch 31', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15768


HBox(children=(IntProgress(value=0, description='epoch 32', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15945


HBox(children=(IntProgress(value=0, description='epoch 33', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15109


HBox(children=(IntProgress(value=0, description='epoch 34', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.16390


HBox(children=(IntProgress(value=0, description='epoch 35', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15529


HBox(children=(IntProgress(value=0, description='epoch 36', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15887


HBox(children=(IntProgress(value=0, description='epoch 37', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15465


HBox(children=(IntProgress(value=0, description='epoch 38', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.16068


HBox(children=(IntProgress(value=0, description='epoch 39', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15283


HBox(children=(IntProgress(value=0, description='epoch 40', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.17273


HBox(children=(IntProgress(value=0, description='epoch 41', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15662


HBox(children=(IntProgress(value=0, description='epoch 42', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15725


HBox(children=(IntProgress(value=0, description='epoch 43', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14723


HBox(children=(IntProgress(value=0, description='epoch 44', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14535


HBox(children=(IntProgress(value=0, description='epoch 45', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15148


HBox(children=(IntProgress(value=0, description='epoch 46', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15663


HBox(children=(IntProgress(value=0, description='epoch 47', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14953


HBox(children=(IntProgress(value=0, description='epoch 48', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14510


HBox(children=(IntProgress(value=0, description='epoch 49', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14952


HBox(children=(IntProgress(value=0, description='epoch 50', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15059


HBox(children=(IntProgress(value=0, description='epoch 51', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15158


HBox(children=(IntProgress(value=0, description='epoch 52', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15055


HBox(children=(IntProgress(value=0, description='epoch 53', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14830


HBox(children=(IntProgress(value=0, description='epoch 54', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15754


HBox(children=(IntProgress(value=0, description='epoch 55', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14005


HBox(children=(IntProgress(value=0, description='epoch 56', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14391


HBox(children=(IntProgress(value=0, description='epoch 57', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15708


HBox(children=(IntProgress(value=0, description='epoch 58', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14878


HBox(children=(IntProgress(value=0, description='epoch 59', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15179


HBox(children=(IntProgress(value=0, description='epoch 60', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15501


HBox(children=(IntProgress(value=0, description='epoch 61', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14784


HBox(children=(IntProgress(value=0, description='epoch 62', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14399


HBox(children=(IntProgress(value=0, description='epoch 63', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14021


HBox(children=(IntProgress(value=0, description='epoch 64', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.13982


HBox(children=(IntProgress(value=0, description='epoch 65', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.13887


HBox(children=(IntProgress(value=0, description='epoch 66', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14217


HBox(children=(IntProgress(value=0, description='epoch 67', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14112


HBox(children=(IntProgress(value=0, description='epoch 68', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14813


HBox(children=(IntProgress(value=0, description='epoch 69', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14629


HBox(children=(IntProgress(value=0, description='epoch 70', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.15242


HBox(children=(IntProgress(value=0, description='epoch 71', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14426


HBox(children=(IntProgress(value=0, description='epoch 72', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.13898


HBox(children=(IntProgress(value=0, description='epoch 73', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14363


HBox(children=(IntProgress(value=0, description='epoch 74', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.13867


HBox(children=(IntProgress(value=0, description='epoch 75', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14406


HBox(children=(IntProgress(value=0, description='epoch 76', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14783


HBox(children=(IntProgress(value=0, description='epoch 77', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14827


HBox(children=(IntProgress(value=0, description='epoch 78', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14809


HBox(children=(IntProgress(value=0, description='epoch 79', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14466


HBox(children=(IntProgress(value=0, description='epoch 80', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14225


HBox(children=(IntProgress(value=0, description='epoch 81', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14938


HBox(children=(IntProgress(value=0, description='epoch 82', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14782


HBox(children=(IntProgress(value=0, description='epoch 83', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14067


HBox(children=(IntProgress(value=0, description='epoch 84', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14408


HBox(children=(IntProgress(value=0, description='epoch 85', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14321


HBox(children=(IntProgress(value=0, description='epoch 86', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14210


HBox(children=(IntProgress(value=0, description='epoch 87', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.13877


HBox(children=(IntProgress(value=0, description='epoch 88', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14898


HBox(children=(IntProgress(value=0, description='epoch 89', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14098


HBox(children=(IntProgress(value=0, description='epoch 90', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.13262


HBox(children=(IntProgress(value=0, description='epoch 91', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14108


HBox(children=(IntProgress(value=0, description='epoch 92', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14006


HBox(children=(IntProgress(value=0, description='epoch 93', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.13791


HBox(children=(IntProgress(value=0, description='epoch 94', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14172


HBox(children=(IntProgress(value=0, description='epoch 95', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.14016


HBox(children=(IntProgress(value=0, description='epoch 96', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.13895


HBox(children=(IntProgress(value=0, description='epoch 97', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.13854


HBox(children=(IntProgress(value=0, description='epoch 98', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.13542


HBox(children=(IntProgress(value=0, description='epoch 99', max=15, style=ProgressStyle(description_width='ini…

validation loss 0.13846


# Final Predicion

In [159]:
DIM = 126
N_TREES = 1000

f = DIM
t = AnnoyIndex(f, "manhattan")
for i in range(x_train.size()[0]):
    v = model.branch(x_train[i])
    t.add_item(i, v)
t.build(N_TREES)
t.save('justincase.ann')

True

In [169]:
clean_tqdm()
count_true = 0

_true = []
_pred = []
for i, vector in tqdm(enumerate(x_test)):

    nn_ids = np.array(t.get_nns_by_vector(model.branch(vector), 1250))
    neighbor_vectors = np.take(x_train, nn_ids, axis=0).tolist()
    
    neighbor_vectors_y = np.take(y_train, nn_ids).tolist()
    pred = neighbor_vectors[cKDTree(neighbor_vectors).query(vector,k=1)[1]]
    ind_pred = neighbor_vectors.index(pred)

    _true.append(y_test[i].item())
    _pred.append(neighbor_vectors_y[ind_pred])
    if y_test[i] == neighbor_vectors_y[ind_pred]: # Сравниваем с тру классом
        count_true += 1

    # if i == 100:
    #     print(count_true)
    #     break

print('Accuracy score: {}'.format(count_true/len(x_test)))


3770it [11:14,  5.45it/s]

Accuracy score: 0.7596816976127321





In [170]:
from sklearn.metrics import accuracy_score
accuracy_score(_true, _pred)

0.7596816976127321