In [1]:
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule
from sentence_transformers import SentenceTransformer , models

comet_ml is installed but `COMET_API_KEY` is not set.


In [2]:
import random
import torch
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import EarlyStopping
import json

import argparse
from os.path import exists
from os import makedirs

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
# load

from datasets import load_dataset

dataset = load_dataset("PiC/phrase_similarity")

Downloading data: 100%|██████████| 1.42M/1.42M [00:02<00:00, 568kB/s]
Downloading data: 100%|██████████| 202k/202k [00:00<00:00, 428kB/s]
Downloading data: 100%|██████████| 403k/403k [00:00<00:00, 472kB/s]


Generating train split:   0%|          | 0/7004 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [5]:
train_dataset = dataset["train"]
test_dataset = dataset["test"]
validation_dataset = dataset["validation"]

print(len(train_dataset))
print(len(test_dataset))
print(len(validation_dataset))


7004
2000
1000


In [6]:
def extract_contextual_phrase_embeddings(model, list_phrase, sentences, sentence_embeddings, max_length=256):

    def find_sub_list(sl, l):
        sll = len(sl)
        for ind in (i for i, e in enumerate(l) if e == sl[0]):
            if l[ind:ind + sll] == sl:
                return ind, ind + sll - 1

    all_phrase_embs = []

    max_seq_length = model.get_max_seq_length()

    for idx, (phrase, sent) in enumerate(zip(list_phrase, sentences)):
        encoded_phrase = model.tokenizer.encode_plus(text=phrase, max_length=max_length, padding='max_length', truncation=True, add_special_tokens=True)
        encoded_phrase = np.array(encoded_phrase["input_ids"])[np.array(encoded_phrase["attention_mask"]) == 1]

        encoded_sent = model.tokenizer.encode_plus(text=sent, max_length=max_length, padding='max_length', truncation=True, add_special_tokens=True)
        encoded_sent = np.array(encoded_sent["input_ids"])[np.array(encoded_sent["attention_mask"]) == 1]

        try:
            start_idx, end_idx = find_sub_list(list(encoded_phrase[1:-1]), list(encoded_sent))
            if end_idx >= max_seq_length:
                print("Context is too long: Idx {} - Phrase: {} - Sentence: {}".format(idx, phrase, sent))
                all_phrase_embs.append(torch.FloatTensor(0))
                continue
        except:
            print("Phrase not found: Idx {} - Phrase: {} - Sentence: {}".format(idx, phrase, sent))
            all_phrase_embs.append(torch.FloatTensor(0))
            continue

        phrase_indices = list(range(start_idx, end_idx + 1, 1))

        phrase_embs = sentence_embeddings[idx][phrase_indices]
        phrase_embs = phrase_embs.mean(dim=0)

        all_phrase_embs.append(phrase_embs)

    return all_phrase_embs


In [7]:
def encode_in_batch(model, batch_size, text_list, device):
    all_emb_tensor_list = []


    model.eval()

    with torch.no_grad():
        for i in range(0, len(text_list), batch_size):
            batch_text_list = text_list[i:i+batch_size]
            batch_emb_list = model.encode(batch_text_list, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=False)
            all_emb_tensor_list.extend(batch_emb_list)

    return all_emb_tensor_list


def encode_with_context_in_batch(model, batch_size, text_list, context_list, device):
    all_contextual_emb_tensor_list = []

    # print(len(context_list),batch_size)

    model.eval()
    with torch.no_grad():
        for i in range(0, len(context_list), batch_size):
            batch_text_list = text_list[i:i + batch_size]
            batch_context_list = context_list[i:i + batch_size]

            batch_emb_list = model.encode(batch_context_list, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=False, output_value="token_embeddings")

            contextual_phrase_embs = extract_contextual_phrase_embeddings(model, batch_text_list, batch_context_list, batch_emb_list)
            all_contextual_emb_tensor_list.extend(contextual_phrase_embs)

            # print("Encode_with_context_in batch" , len(all_contextual_emb_tensor_list))

    return all_contextual_emb_tensor_list


In [8]:
def load_model(model_path, spanRep=False):
    model = SentenceTransformer(model_path)
    return model

In [9]:

def get_data_emb( split, model_path, device, shuffle=True, contextual=True):

    if split == "train":
        data_list = train_dataset
    elif split == "test":
        data_list = test_dataset
    elif split == "validation":
        data_list = validation_dataset

    phrase1_list = [item['phrase1'] for item in data_list]
    phrase2_list = [item['phrase2'] for item in data_list]
    labels = [item['label'] for item in data_list]

    context1_list = [item['sentence1'] for item in data_list] if contextual else []
    context2_list = [item['sentence2'] for item in data_list] if contextual else []

    model = load_model(model_path, device)
    model.to(device)

    print(device)
    emb_batch_size = 32


    phrase1_emb_tensor_list = encode_with_context_in_batch(model, emb_batch_size, phrase1_list, context1_list, device)
    phrase2_emb_tensor_list = encode_with_context_in_batch(model, emb_batch_size, phrase2_list, context2_list, device)

    combined_phrase_list = []
    for phrase1_emb_tensor, phrase2_emb_tensor, label in zip(phrase1_emb_tensor_list, phrase2_emb_tensor_list, labels):
        if phrase1_emb_tensor.shape[0] > 0 and phrase2_emb_tensor.shape[0] > 0:
            combined_phrase_list.append((phrase1_emb_tensor, phrase2_emb_tensor, label))


    # print(len(combined_phrase_list))

    phrase1_emb_tensor_list, phrase2_emb_tensor_list, labels = zip(*combined_phrase_list)
    assert len(phrase1_emb_tensor_list) == len(phrase2_emb_tensor_list)

    if shuffle:
        import random
        random.seed(42)
        combined = list(zip(phrase1_emb_tensor_list, phrase2_emb_tensor_list, labels))
        random.shuffle(combined)
        phrase1_emb_tensor_list, phrase2_emb_tensor_list, labels = zip(*combined)

    label_tensor = torch.FloatTensor(labels)

    return torch.stack(phrase1_emb_tensor_list), torch.stack(phrase2_emb_tensor_list), label_tensor


In [10]:
class ParaphraseDataset(Dataset):
    def __init__(self, phrase1_tensor, phrase2_tensor, label_tensor):
        self.concat_input = torch.cat((phrase1_tensor, phrase2_tensor), 1)
        self.label = label_tensor

    def __getitem__(self, index):
        return (self.concat_input[index], self.label[index])

    def __len__(self):
        return self.concat_input.size()[0]


In [11]:
class ProbingModel(LightningModule):
    def __init__(self, input_dim=1536, train_dataset=None, valid_dataset=None, test_dataset=None):
        super(ProbingModel, self).__init__()
        self.input_dim = input_dim
        self.linear = nn.Linear(self.input_dim, 256)
        self.linear2 = nn.Linear(256, 1)
        self.output = nn.Sigmoid()

        self.lr = 0.0001
        self.batch_size = 200

        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.test_dataset = test_dataset

        self.test_y = []
        self.test_y_hat = []

    def forward(self, x):
        x = F.relu(self.linear(x))
        x = self.output(self.linear2(x))
        return x.squeeze()

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.valid_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

    def compute_accuracy(self, y_hat, y):
        return ((y_hat >= 0.5).float() == y).float().mean()

    def training_step(self, batch, batch_nb):
          mode = 'train'
          x, y = batch
          y_hat = self(x)
          loss = F.binary_cross_entropy(y_hat, y)
          accuracy = self.compute_accuracy(y_hat, y)
          self.log(f'{mode}_loss', loss, on_epoch=True, on_step=True)
          self.log(f'{mode}_accuracy', accuracy, on_epoch=True, on_step=True)
          return {f'loss': loss, f'{mode}_accuracy':accuracy, 'log': {f'{mode}_loss': loss}}



    def validation_step(self, batch, batch_nb):
        mode = 'val'
        x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy(y_hat, y)
        accuracy = self.compute_accuracy(y_hat, y)
        self.log(f'{mode}_loss', loss, on_epoch=True, on_step=True)
        self.log(f'{mode}_accuracy', accuracy, on_epoch=True, on_step=True)
        return {f'{mode}_loss': loss, f'{mode}_accuracy':accuracy, 'log': {f'{mode}_loss': loss}}



    def test_step(self, batch, batch_nb):
        mode = 'test'
        x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy(y_hat, y)
        accuracy = self.compute_accuracy(y_hat, y)

        # ThangPM: Save predictions
        self.test_y.extend(y)
        self.test_y_hat.extend(y_hat)

        self.log(f'{mode}_loss', loss, on_epoch=True, on_step=True)
        self.log(f'{mode}_accuracy', accuracy, on_epoch=True, on_step=True)
        return {f'{mode}_loss': loss, f'{mode}_accuracy':accuracy, 'log': {f'{mode}_loss': loss}}



In [12]:
import torch

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Assuming MODEL_PATH and other necessary functions and classes are defined elsewhere
    model_path = "whaleloops/phrase-bert"
    phrase1_tensor, phrase2_tensor, label_tensor = get_data_emb("train", model_path, device)
    phrase1_tensor.to(device)
    phrase2_tensor.to(device)
    label_tensor.to(device)
    train_dataset = ParaphraseDataset(phrase1_tensor, phrase2_tensor, label_tensor)

    phrase1_tensor, phrase2_tensor, label_tensor = get_data_emb("validation", model_path, device)
    phrase1_tensor.to(device)
    phrase2_tensor.to(device)
    label_tensor.to(device)
    valid_dataset = ParaphraseDataset(phrase1_tensor, phrase2_tensor, label_tensor)

    phrase1_tensor, phrase2_tensor, label_tensor = get_data_emb("test", model_path, device, shuffle=False)
    phrase1_tensor.to(device)
    phrase2_tensor.to(device)
    label_tensor.to(device)
    test_dataset = ParaphraseDataset(phrase1_tensor, phrase2_tensor, label_tensor)

    # early_stop_callback = EarlyStopping(monitor='epoch_val_accuracy', min_delta=0.00, patience=10, verbose=True, mode='max')

    model = ProbingModel(input_dim=phrase1_tensor.shape[1] * 2, train_dataset=train_dataset,
                            valid_dataset=valid_dataset, test_dataset=test_dataset).to(device)

    trainer = Trainer(max_epochs=100, min_epochs=3)

    trainer.fit(model)
    result = trainer.test(dataloaders=model.test_dataloader())
    print(result)

    # # Writing the result to a text file
    # with open('/content/results.txt', 'w') as file:
    #     file.write(str(result))
    # torch.save(model.state_dict(), '/content/context_phrasebert.pth')

if __name__ == '__main__':
    main()



modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/5.41k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/670 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


tokenizer_config.json:   0%|          | 0.00/632 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

cuda
Phrase not found: Idx 27 - Phrase: de facto power - Sentence: Around 1906, Menelik became incapacitated and "Itege" Taytu Betul became the "de facto" power behind the throne.
Phrase not found: Idx 24 - Phrase: de facto power - Sentence: The early history of juries supports the recognition of the "de facto" power of nullification.
Phrase not found: Idx 31 - Phrase: a side - Sentence: the wharf was designed to allow for double deck boarding on the "a" side of the wharf.
Phrase not found: Idx 13 - Phrase: a side - Sentence: the "a" side of the 45 was "don't know where i'm going," and the "b" side was "marble eyes."
cuda
cuda


/home2/punnavajhala.prakash/miniconda3/envs/salad/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home2/punnavajhala.prakash/miniconda3/envs/salad/li ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params
------------------------------------
0 | linear  | Linear  | 393 K 
1 | linear2 | Linear  | 257   
2 | output  | Sigmoid | 0     
------------------------------------
393 K     Trainable params
0         Non-trainable params
393 K     Total params
1.575     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home2/punnavajhala.prakash/miniconda3/envs/salad/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/home2/punnavajhala.prakash/miniconda3/envs/salad/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/home2/punnavajhala.prakash/miniconda3/envs/salad/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (35) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.
Restoring states from the checkpoint path at /home2/punnavajhala.prakash/lightning_logs/version_1178036/checkpoints/epoch=99-step=3500.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home2/punnavajhala.prakash/lightning_logs/version_1178036/checkpoints/epoch=99-step=3500.ckpt
/home2/punnavajhala.prakash/miniconda3/envs/salad/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss_epoch': 0.8931949734687805, 'test_accuracy_epoch': 0.6039999723434448}]
