# Siamese Network with BERT Pooling: Online Contrastive Loss Function

- We train our Siamese network on the SemEval 2014 training data.
- We use the **online contrastive loss function**.
- We then run a k-NN search over the test documents with the test queries (previously generated for BM25) and write the results in `trec_eval` format.
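
As a rough illustration of what "online" means here: the loss looks at each batch and keeps only the hard pairs, that is, matching pairs that are farther apart than the closest non-matching pair, and non-matching pairs that are closer than the farthest matching pair. The sketch below works on precomputed distances in plain Python. The function name, the margin of 0.5, and the lack of a fallback for single-class batches are assumptions made for illustration; this is not the library's code.

```python
def online_contrastive_loss(distances, labels, margin=0.5):
    """Contrastive loss over the hard pairs of one batch (illustrative sketch).

    distances: one distance per (query, document) pair, e.g. 1 - cosine similarity
    labels:    1 for a matching pair, 0 for a non-matching pair
    """
    positives = [d for d, y in zip(distances, labels) if y == 1]
    negatives = [d for d, y in zip(distances, labels) if y == 0]
    if not positives or not negatives:
        raise ValueError("this sketch needs both classes in the batch")

    # Hard positives: matching pairs farther apart than the closest negative pair.
    hard_positives = [d for d in positives if d > min(negatives)]
    # Hard negatives: non-matching pairs closer than the farthest positive pair.
    hard_negatives = [d for d in negatives if d < max(positives)]

    # Pull hard positives together; push hard negatives out to the margin.
    positive_loss = sum(d ** 2 for d in hard_positives)
    negative_loss = sum(max(0.0, margin - d) ** 2 for d in hard_negatives)
    return positive_loss + negative_loss


# Toy batch: two positives (0.1, 0.6) and two negatives (0.3, 0.9).
loss = online_contrastive_loss([0.1, 0.6, 0.3, 0.9], [1, 1, 0, 0])
print(round(loss, 4))
```

In the toy batch only the positive at 0.6 and the negative at 0.3 contribute; the other two pairs are already well separated and are ignored.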

## Google Colab setup

This cell only takes effect when the notebook runs on Google Colab. **Change the working directory path below before running it!**

In [1]:
# Use Google Colab
use_colab = True

# Is this notebook running on Colab?
# If so, then google.colab package (github.com/googlecolab/colabtools)
# should be available in this environment

# Previous version used importlib, but we could do the same thing with
# just attempting to import google.colab
try:
    from google.colab import drive
    colab_available = True
except ImportError:
    colab_available = False

if use_colab and colab_available:
    drive.mount('/content/drive')
    
    # If there's a package I need to install separately, do it here
    #!pip install sentence-transformers==0.3.9 transformers==3.4.0

    # cd to the appropriate working directory under my Google Drive
    %cd '/content/drive/My Drive/CS646_Final_Project/Siamese'
    
    # List the directory contents
    !ls

Mounted at /content/drive
/content/drive/.shortcut-targets-by-id/1zyXK0VOQZwuIfSaMkgJTmwtvSpUOQ1zm/CS646_Final_Project/Siamese
requirements.txt
sbert_bert_ada_joint_contrastive
sbert_bert_ada_joint_online_contrastive
sbert_bert_ada_joint_partially_correct_cosine
sbert_bert_ada_joint_partially_correct_softmax
sentence_bert_contrastive.ipynb
sentence_bert_cosine.ipynb
sentence_bert_online_contrastive.ipynb
sentence_bert_softmax.ipynb
Siamese_results
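
The comment in the cell above mentions an earlier `importlib`-based check. A minimal sketch of that alternative, which tests for a package without importing it, might look like this; `module_available` is a hypothetical helper name, not something from the original notebook.

```python
import importlib.util


def module_available(name):
    """Return True if `name` can be imported, without importing it."""
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # find_spec imports parent packages first, so a missing parent
        # (e.g. no `google` package at all) raises instead of returning None.
        return False


colab_available = module_available("google.colab")
print("google.colab available:", colab_available)
```

The `try`/`import` form used in the notebook is shorter and also binds `drive`, which is why it is a reasonable choice there.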


## PyTorch GPU setup

In [3]:
# torch.device / CUDA Setup
import os

import torch

use_cuda = True
use_colab_tpu = False
colab_tpu_available = False

if use_colab_tpu:
    # COLAB_TPU_ADDR is expected to be set only when a TPU runtime is attached
    colab_tpu_available = bool(os.environ.get('COLAB_TPU_ADDR'))

if use_cuda and torch.cuda.is_available():
    torch_device = torch.device('cuda')

    # Set this to True to make your output immediately reproducible
    # Note: https://pytorch.org/docs/stable/notes/randomness.html
    torch.backends.cudnn.deterministic = False
    
    # cuDNN 'benchmark' mode: set this to False if you want to measure running times more fairly
    # Note: https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
    torch.backends.cudnn.benchmark = True
    
    # Faster Host to GPU copies with page-locked memory
    use_pin_memory = True 

    # CUDA libraries version information
    print("CUDA Version: " + str(torch.version.cuda))
    print("cuDNN Version: " + str(torch.backends.cudnn.version()))
    print("CUDA Device Name: " + str(torch.cuda.get_device_name()))
    print("CUDA Capabilities: "+ str(torch.cuda.get_device_capability()))

elif use_colab_tpu and colab_tpu_available:
    # This needs to be installed separately
    # https://github.com/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb
    import torch_xla 
    import torch_xla.core.xla_model as xm

    torch_device = xm.xla_device()
    use_pin_memory = False

else:
    torch_device = torch.device('cpu')
    use_pin_memory = False

CUDA Version: 10.1
cuDNN Version: 7603
CUDA Device Name: Tesla T4
CUDA Capabilities: (7, 5)


## Import packages

In [4]:
import os
import random
import json
import pathlib

import sentence_transformers
from sentence_transformers import losses
import numpy as np
import jsonlines

In [5]:
# Random seed settings
random_seed = 646
random.seed(random_seed) # Python
np.random.seed(random_seed) # NumPy
torch.manual_seed(random_seed) # PyTorch

<torch._C.Generator at 0x7fea73694c00>

## Load the dataset

In [None]:
# 0-1 label (for contrastive loss)
with open(os.path.join('.', 'data', 'our_datasets', 'laptop_train.json')) as laptop_train_file:
    laptop_train = json.load(laptop_train_file)

with open(os.path.join('.', 'data', 'our_datasets', 'restaurant_train.json')) as restaurants_train_file:
    restaurants_train = json.load(restaurants_train_file)

### Training set: Joint = Laptop + Restaurants

In [None]:
train_combined_examples = []

# Each example pairs the query text "<query[0]>, <query[1]>" with a document
for row in laptop_train + restaurants_train:
    example = sentence_transformers.InputExample(
        texts=[row['query'][0] + ', ' + row['query'][1], row['doc']], label=row['label'])

    train_combined_examples.append(example)

In [None]:
print(train_combined_examples[0])

<InputExample> label: 1, texts: functions, positive; The Macbook arrived in a nice twin packing and sealed in the box, all the functions works great.
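
To make the example layout explicit, here is a small sketch of how one row maps to a sentence pair. The row layout is inferred from the loading code and the printed example (a two-element `query`, a `doc`, and a 0/1 `label`); the names `row_to_pair`, `aspect`, and `sentiment` and the sample row are assumptions for illustration.

```python
def row_to_pair(row):
    """Turn one training row into (texts, label) for a sentence-pair example.

    Assumed row layout, inferred from the loading code:
    {'query': [aspect, sentiment], 'doc': sentence, 'label': 0 or 1}
    """
    aspect, sentiment = row["query"]
    query_text = f"{aspect}, {sentiment}"
    return [query_text, row["doc"]], row["label"]


# Hypothetical row, made up for illustration (not taken from the dataset).
texts, label = row_to_pair(
    {"query": ["battery life", "negative"], "doc": "The battery dies within an hour.", "label": 1}
)
print(texts, label)
```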


## Siamese Network with BERT Pooling (SBERT) Model

- We use the pretrained weights released by the BERT-ADA authors.
- Please download and extract them to the same directory as this notebook: https://github.com/deepopinion/domain-adapted-atsc#release-of-bert-language-models-finetuned-on-a-specific-domain
    - **NOTE**: Because BERT-ADA was trained with an older version of `transformers`, you need to add `"model_type": "bert"` to `config.json`.
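
The `config.json` change from the note above can be done by hand or with a few lines of Python. This is a minimal sketch, assuming `config.json` sits directly inside the extracted model directory; `add_model_type` is a hypothetical helper name.

```python
import json
import pathlib


def add_model_type(model_dir, model_type="bert"):
    """Add "model_type" to <model_dir>/config.json if it is missing."""
    config_path = pathlib.Path(model_dir) / "config.json"
    config = json.loads(config_path.read_text(encoding="utf-8"))
    if "model_type" in config:
        return False  # already present; leave the file untouched
    config["model_type"] = model_type
    config_path.write_text(json.dumps(config, indent=2), encoding="utf-8")
    return True


# Example (run after extracting the archive next to this notebook):
# add_model_type('laptops_and_restaurants_2mio_ep15')
```

The function returns `True` only when it actually modified the file, so running it twice is harmless.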

In [None]:
# Load the pretrained BERT-ADA model
# Extract the tar.xz file
#!tar -xf laptops_and_restaurants_2mio_ep15.tar.xz

pretrained_model_name = 'laptops_and_restaurants_2mio_ep15'

In [6]:
sbert_new_model_name = 'sbert_bert_ada_joint_online_contrastive'

In [None]:
word_embedding_model = sentence_transformers.models.Transformer(
    pretrained_model_name, max_seq_length=256)

pooling_model = sentence_transformers.models.Pooling(
    word_embedding_model.get_word_embedding_dimension())

model = sentence_transformers.SentenceTransformer(
    modules=[word_embedding_model, pooling_model])
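
The pooling module turns the per-token BERT outputs into one fixed-size sentence vector. Assuming the library default of mean pooling applies here (no pooling mode is passed above, so check the installed version), the operation amounts to averaging the non-padding token vectors. A plain-Python sketch with a hypothetical `mean_pool` helper:

```python
def mean_pool(token_vectors, attention_mask):
    """Average the token vectors whose mask entry is 1 (padding is skipped)."""
    kept = [vec for vec, keep in zip(token_vectors, attention_mask) if keep]
    if not kept:
        raise ValueError("attention_mask selects no tokens")
    dim = len(kept[0])
    return [sum(vec[i] for vec in kept) / len(kept) for i in range(dim)]


# Three real tokens plus one padding position, 2-dimensional for readability.
tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [9.0, 9.0]]
mask = [1, 1, 1, 0]
print(mean_pool(tokens, mask))  # [3.0, 4.0]
```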

### Training

In [None]:
# PyTorch DataLoader
train_dataset = sentence_transformers.SentencesDataset(train_combined_examples, model)
train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=16)

# Loss function
# Tuples of (DataLoader, LossFunction)
train_online_contrastive_loss = (train_dataloader, losses.OnlineContrastiveLoss(model))

# Tune the model
model.fit(
    train_objectives=[train_online_contrastive_loss], 
    epochs=20,
    warmup_steps=300,
    weight_decay=0.01,
    use_amp=True)

HBox(children=(HTML(value='Epoch'), FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(HTML(value='Iteration'), FloatProgress(value=0.0, max=725.0), HTML(value='')))









In [None]:
model.save(sbert_new_model_name)
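
For intuition about `warmup_steps=300`: with 20 epochs of 725 iterations each (the count shown by the progress bar), training takes 14,500 optimizer steps. The sketch below assumes that `fit` uses a linear warmup followed by a linear decay, which I believe is the library default; the function name is hypothetical and the values are multipliers applied to the base learning rate.

```python
def warmup_linear_factor(step, warmup_steps, total_steps):
    """Learning-rate multiplier: linear ramp 0 -> 1, then linear decay 1 -> 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


epochs, steps_per_epoch, warmup_steps = 20, 725, 300
total_steps = epochs * steps_per_epoch  # 14500 optimizer steps

for step in (0, 150, 300, 7400, 14500):
    print(step, round(warmup_linear_factor(step, warmup_steps, total_steps), 3))
```

Under these assumptions the learning rate peaks after about 2% of training and then shrinks steadily to zero by the final step.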

### Play with my own sentences

In [8]:
# Uncomment the following line to load the existing trained model.
# model = sentence_transformers.SentenceTransformer(sbert_new_model_name)

In [None]:
query_embedding = model.encode('Windows 8, Positive')
passage_embedding = model.encode("This laptop's design is amazing")

print("Similarity:", sentence_transformers.util.pytorch_cos_sim(query_embedding, passage_embedding))

Similarity: tensor([[0.5820]])
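
The score printed above is a cosine similarity: the dot product of the two embeddings divided by the product of their lengths. A plain-Python sketch of the same quantity on small vectors (the helper name is an assumption; the notebook itself relies on `pytorch_cos_sim`):

```python
import math


def cosine_similarity(u, v):
    """cos(u, v) = (u . v) / (|u| * |v|), in the range [-1, 1]."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_u * norm_v)


# Vectors 45 degrees apart: similarity is 1 / sqrt(2), about 0.7071.
print(cosine_similarity([1.0, 0.0], [1.0, 1.0]))
```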


## k-NN Search

In [None]:
# Get the top k matches from k-NN search
top_k = 400
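
The search in this notebook is a brute-force k-NN: score every document against the query and keep the `top_k` best. The selection step can be sketched in plain Python as follows; `top_k_indices` is a hypothetical helper, and the notebook itself uses `torch.topk` for this.

```python
import heapq


def top_k_indices(scores, k):
    """Return (index, score) pairs for the k highest scores, best first."""
    k = min(k, len(scores))  # never ask for more results than documents
    return heapq.nlargest(k, enumerate(scores), key=lambda pair: pair[1])


scores = [0.12, 0.87, 0.45, 0.66]
print(top_k_indices(scores, 2))         # [(1, 0.87), (3, 0.66)]
print(len(top_k_indices(scores, 400)))  # 4
```

Clamping `k` to the collection size matters because `torch.topk` raises an error when `k` exceeds the number of scores.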

### Generate query results file for `trec_eval` evaluation: Laptop

In [None]:
test_laptop_documents_path = os.path.join('..', 'BM25', 'collection', 'laptop_test.jsonl')
test_laptop_documents_file = jsonlines.open(test_laptop_documents_path)

In [None]:
test_laptop_documents_id = []
test_laptop_documents = []

for d in test_laptop_documents_file:
    test_laptop_documents_id.append(d['id'])
    test_laptop_documents.append(d['contents'])
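
The collection file is read as JSON Lines: one JSON object per line, with the `id` and `contents` fields used above. If `jsonlines` is not installed, the same parsing can be sketched with the standard library alone; the helper name and the two sample documents are made up for illustration.

```python
import json


def read_collection(lines):
    """Parse JSONL lines into parallel lists of document ids and texts."""
    doc_ids, docs = [], []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        record = json.loads(line)
        doc_ids.append(record["id"])
        docs.append(record["contents"])
    return doc_ids, docs


# Hypothetical two-document collection, made up for illustration.
sample = [
    '{"id": "doc1", "contents": "The keyboard feels great."}',
    '{"id": "doc2", "contents": "Battery life is poor."}',
]
doc_ids, docs = read_collection(sample)
print(doc_ids)
```

Passing an open file object works as well, since iterating over a text file yields its lines.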

In [None]:
# Obtain embedding vector of test documents
test_laptop_embeddings = model.encode(test_laptop_documents, convert_to_tensor=True)

In [None]:
test_laptop_queries_path = os.path.join('..', 'BM25', 'test_queries_laptop.txt')
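# Read one query per line and close the file afterwards
with open(test_laptop_queries_path, 'r') as test_laptop_queries_file:
    test_laptop_queries = test_laptop_queries_file.readlines()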

In [None]:
test_laptop_result_path = os.path.join('.', 'query_results', sbert_new_model_name, 'top_' + str(top_k))
pathlib.Path(test_laptop_result_path).mkdir(parents=True, exist_ok=True)
test_laptop_result_file = 'query_results_laptop_' + sbert_new_model_name + '.txt'

In [None]:
!rm {os.path.join(test_laptop_result_path, test_laptop_result_file)}

test_laptop_result_file_path = os.path.join(test_laptop_result_path, test_laptop_result_file)

# Open the results file once instead of reopening it for every line
with open(test_laptop_result_file_path, 'w') as f:
    for q_num, q in enumerate(test_laptop_queries):
        print("Processing query", q_num, ":", q)

        query_embedding = model.encode(q, convert_to_tensor=True)

        cos_scores = sentence_transformers.util.pytorch_cos_sim(query_embedding, test_laptop_embeddings)[0]

        # Retrieve at most top_k documents (fewer if the collection is smaller)
        top_k_retrieved = min(top_k, len(cos_scores))

        # torch.topk returns the highest scores and their indices, best first
        top_results = torch.topk(cos_scores, k=top_k_retrieved)

        # Uncomment to inspect the retrieved documents
        # print("Query:", q)
        # for score, idx in zip(top_results[0], top_results[1]):
        #     print(test_laptop_documents[idx], "(Score: %.4f)" % (score))

        # trec_eval query results file: query_id Q0 doc_id rank score run_tag
        for rank, (score, idx) in enumerate(zip(top_results[0], top_results[1]), start=1):
            line = str(q_num+1) + ' Q0 ' + test_laptop_documents_id[idx] + ' ' + str(rank) + ' ' + '%.8f' % score + ' galago'
            f.write("%s\n" % line)

rm: cannot remove './Siamese_results/sbert_bert_ada_joint_online_contrastive/top_1000/query_results_laptop_sbert_bert_ada_joint_online_contrastive.txt': No such file or directory
Processing query 0 : touch pad, positive

Processing query 1 : Windows 8, negative

Processing query 2 : aluminum casing, positive

Processing query 3 : aesthetics, positive

Processing query 4 : software, positive

Processing query 5 : Firewire 800, positive

Processing query 6 : Legacy programs, neutral

Processing query 7 : Price, negative

Processing query 8 : OpenOffice, positive

Processing query 9 : touchscreen functions, negative

Processing query 10 : "tools" menu, neutral

Processing query 11 : itune, negative

Processing query 12 : create your own bookmarks, positive

Processing query 13 : sound quality, negative

Processing query 14 : downloading apps, positive

Processing query 15 : retina display, neutral

Processing query 16 : Samsung 830 SSD, positive

Processing query 17 : windows 8, positive


Processing query 183 : screen, neutral

Processing query 184 : CUSTOMER SERVICE, positive

Processing query 185 : hardware (keyboard), negative

Processing query 186 : mountain lion, negative

Processing query 187 : track pad, positive

Processing query 188 : performance, negative

Processing query 189 : customer service, positive

Processing query 190 : performs, positive

Processing query 191 : click pads, negative

Processing query 192 : games, neutral

Processing query 193 : remove the card, negative

Processing query 194 : Office, neutral

Processing query 195 : integrate bluetooth devices, positive

Processing query 196 : Microsoft Word, neutral

Processing query 197 : os.x, neutral

Processing query 198 : battery cycle count, positive

Processing query 199 : FireWire 800, neutral

Processing query 200 : battery life, positive

Processing query 201 : headphones, neutral

Processing query 202 : built-in camera, neutral

Processing query 203 : Starts up, positive

Processing query 

Processing query 363 : built, positive

Processing query 364 : OS, negative

Processing query 365 : TRACKPAD, negative

Processing query 366 : intel 4000 graphics chipset, neutral

Processing query 367 : glass, positive

Processing query 368 : construction, positive

Processing query 369 : mouse, neutral

Processing query 370 : USB3 Peripherals, positive

Processing query 371 : slim plastic case, neutral

Processing query 372 : aluminum, positive

Processing query 373 : apple OS, positive

Processing query 374 : card reader, negative

Processing query 375 : software, neutral

Processing query 376 : windows, positive

Processing query 377 : Micron SSD, neutral

Processing query 378 : built-in applications, positive

Processing query 379 : number pad on the keyboard, positive

Processing query 380 : windows, neutral

Processing query 381 : CD/DVD player, neutral

Processing query 382 : Shipping, positive

Processing query 383 : delete key, negative

Processing query 384 : OSX Mountain Li
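
Each line written by the loop above has six space-separated columns: query id, the literal `Q0`, document id, rank, score, and a run tag. A small sketch of that formatting step follows; the helper name, document ids, and scores are made up for illustration, while the `galago` tag and the 8-decimal score match the notebook's code.

```python
def trec_run_line(query_id, doc_id, rank, score, run_tag="galago"):
    """Format one result as: query_id Q0 doc_id rank score run_tag."""
    return f"{query_id} Q0 {doc_id} {rank} {score:.8f} {run_tag}"


# Hypothetical document ids and scores, made up for illustration.
ranked = [("doc_17", 0.91234567), ("doc_03", 0.5)]
for rank, (doc_id, score) in enumerate(ranked, start=1):
    print(trec_run_line(1, doc_id, rank, score))
```

Ranks start at 1 and follow descending score order, which is the order `torch.topk` already returns.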

### Generate query results file for `trec_eval` evaluation: Restaurant

In [None]:
test_restaurants_documents_path = os.path.join('..', 'BM25', 'collection', 'restaurant_test.jsonl')
test_restaurants_documents_file = jsonlines.open(test_restaurants_documents_path)

In [None]:
test_restaurants_documents_id = []
test_restaurants_documents = []

for d in test_restaurants_documents_file:
    test_restaurants_documents_id.append(d['id'])
    test_restaurants_documents.append(d['contents'])

test_restaurants_embeddings = model.encode(test_restaurants_documents, convert_to_tensor=True)

In [None]:
test_restaurants_queries_path = os.path.join('..', 'BM25', 'test_queries_restaurant.txt')
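# Read one query per line and close the file afterwards
with open(test_restaurants_queries_path, 'r') as test_restaurants_queries_file:
    test_restaurants_queries = test_restaurants_queries_file.readlines()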

In [None]:
test_restaurants_result_path = os.path.join('.', 'query_results', sbert_new_model_name, 'top_' + str(top_k))
pathlib.Path(test_restaurants_result_path).mkdir(parents=True, exist_ok=True)
test_restaurants_result_file = 'query_results_restaurants_' + sbert_new_model_name + '.txt'

In [None]:
!rm {os.path.join(test_restaurants_result_path, test_restaurants_result_file)}

test_restaurants_result_file_path = os.path.join(test_restaurants_result_path, test_restaurants_result_file)

# Open the results file once instead of reopening it for every line
with open(test_restaurants_result_file_path, 'w') as f:
    for q_num, q in enumerate(test_restaurants_queries):
        print("Processing query", q_num, ":", q)

        query_embedding = model.encode(q, convert_to_tensor=True)

        cos_scores = sentence_transformers.util.pytorch_cos_sim(query_embedding, test_restaurants_embeddings)[0]

        # Retrieve at most top_k documents (fewer if the collection is smaller)
        top_k_retrieved = min(top_k, len(cos_scores))

        # torch.topk returns the highest scores and their indices, best first
        top_results = torch.topk(cos_scores, k=top_k_retrieved)

        # Uncomment to inspect the retrieved documents
        # print("Query:", q)
        # for score, idx in zip(top_results[0], top_results[1]):
        #     print(test_restaurants_documents[idx], "(Score: %.4f)" % (score))

        # trec_eval query results file: query_id Q0 doc_id rank score run_tag
        for rank, (score, idx) in enumerate(zip(top_results[0], top_results[1]), start=1):
            line = str(q_num+1) + ' Q0 ' + test_restaurants_documents_id[idx] + ' ' + str(rank) + ' ' + '%.8f' % score + ' galago'
            f.write("%s\n" % line)

rm: cannot remove './Siamese_results/sbert_bert_ada_joint_online_contrastive/top_1000/query_results_restaurants_sbert_bert_ada_joint_online_contrastive.txt': No such file or directory
Processing query 0 : chicken and falafel platters, neutral

Processing query 1 : Appetizer, positive

Processing query 2 : waiters, negative

Processing query 3 : taglierini with truffles, positive

Processing query 4 : chicken, negative

Processing query 5 : caviar-topped sturgeon, positive

Processing query 6 : beers, neutral

Processing query 7 : tabouleh, positive

Processing query 8 : cajun shrimp, neutral

Processing query 9 : brisket, positive

Processing query 10 : dumplings, positive

Processing query 11 : sweet corn-foie gras brulee, neutral

Processing query 12 : garlic knots, positive

Processing query 13 : reservations, neutral

Processing query 14 : quality of the meat, negative

Processing query 15 : taste, negative

Processing query 16 : Service, negative

Processing query 17 : brunch, neu

Processing query 180 : hostess, negative

Processing query 181 : [female] servers, positive

Processing query 182 : courses, neutral

Processing query 183 : lobster ravioli, positive

Processing query 184 : hot cakes, neutral

Processing query 185 : sushi, negative

Processing query 186 : outdoors, positive

Processing query 187 : Chef, positive

Processing query 188 : waiters, positive

Processing query 189 : Jap style hamburger steak, neutral

Processing query 190 : bi-level space, positive

Processing query 191 : lamb kebabs, positive

Processing query 192 : music, neutral

Processing query 193 : waiter, negative

Processing query 194 : creme brulee, positive

Processing query 195 : portobello/gorgonzola/sausage appetizer, positive

Processing query 196 : latin food, positive

Processing query 197 : steak, positive

Processing query 198 : rice, positive

Processing query 199 : value, positive

Processing query 200 : hot white mocha, positive

Processing query 201 : Meal, negative

P

Processing query 360 : outdoor eating area, positive

Processing query 361 : lasagna, neutral

Processing query 362 : Japanese food, negative

Processing query 363 : tapas, positive

Processing query 364 : scallops, neutral

Processing query 365 : selection of wines (primarily Spanish), positive

Processing query 366 : seats, positive

Processing query 367 : fried rice, positive

Processing query 368 : vegan cranberry pancakes, neutral

Processing query 369 : apple tart, positive

Processing query 370 : wine, positive

Processing query 371 : quaility, positive

Processing query 372 : salads, positive

Processing query 373 : staff, positive

Processing query 374 : cost, negative

Processing query 375 : Tables, negative

Processing query 376 : chinese food, negative

Processing query 377 : black roasted codfish, positive

Processing query 378 : soups, neutral

Processing query 379 : Amazin' Greens salads, neutral

Processing query 380 : mango chicken, positive

Processing query 381 : Cre

Processing query 537 : chef, neutral

Processing query 538 : repast, positive

Processing query 539 : Chicken Teriyaki, neutral

Processing query 540 : Entrees, neutral

Processing query 541 : Indian food, positive

Processing query 542 : taste, positive

Processing query 543 : sopaipillas, neutral

Processing query 544 : Saag gosht, positive

Processing query 545 : fork, neutral

Processing query 546 : coffee, positive

Processing query 547 : served, neutral

Processing query 548 : entrees, positive

Processing query 549 : king crab salad with passion fruit vinaigrette, positive

Processing query 550 : indian food, positive

Processing query 551 : Pizzas, positive

Processing query 552 : roast chicken, neutral

Processing query 553 : decor, neutral

Processing query 554 : bill, negative

Processing query 555 : MEAL, positive

Processing query 556 : owner, negative

Processing query 557 : rice medley, positive

Processing query 558 : place, negative

Processing query 559 : taramasalata