<a href="https://colab.research.google.com/github/vishal-burman/PyTorch-Architectures/blob/master/research/modeling_PolyEncoder/test_sample_PolyEncoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install -q transformers
! pip install -q datasets

In [None]:
# Remove any previous checkout so the clone below starts from a clean state.
! rm -rf PyTorch-Architectures/
! git clone https://github.com/vishal-burman/PyTorch-Architectures.git
# Switch into the repo root; presumably this is what lets the later
# `from research import PolyEncoder` resolve — confirm.
# NOTE(review): both paths above are relative, so re-running this cell without
# restarting the kernel operates inside the already-entered repo directory and
# produces a nested clone.
%cd PyTorch-Architectures/

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from research import PolyEncoder

In [None]:
# Load the "mrpc" configuration of "glue"; the cells below index the result
# by the split names "train", "validation" and "test".
dataset = load_dataset("glue", "mrpc")

In [5]:
def print_details(samples):
  """Print pair counts and whitespace-token length statistics for a split.

  Args:
    samples: A sized iterable (must support ``len``) of mappings, each with
      string values under "sentence1" and "sentence2" and a numeric "label".
      Labels are summed to count positive pairs, so they are expected to be
      0 (negative) or 1 (positive).

  Returns:
    None. Seven summary lines are written to stdout.

  Notes:
    Lengths are word counts from ``str.split()``, not tokenizer token counts.
    An empty split reports 0 for every maximum and minimum.
  """
  labels = []
  lens_sent1, lens_sent2 = [], []
  for sample in samples:
    # Split each sentence once and keep the count; min and max are derived
    # from the collected lengths rather than from fixed sentinel values, so
    # sentences of any length are reported correctly.
    lens_sent1.append(len(sample["sentence1"].split()))
    lens_sent2.append(len(sample["sentence2"].split()))
    labels.append(sample["label"])

  num_positive = sum(labels)

  print(f"Total pairs --> {len(samples)}")
  print(f"No of positive pairs --> {num_positive}")
  print(f"No of negative pairs --> {len(labels) - num_positive}")

  # default=0 keeps an empty split from raising ValueError in max()/min().
  print(f"Maximum length of sentence1 --> {max(lens_sent1, default=0)}")
  print(f"Maximum length of sentence2 --> {max(lens_sent2, default=0)}")

  print(f"Minimum length of sentence1 --> {min(lens_sent1, default=0)}")
  print(f"Minimum length of sentence2 --> {min(lens_sent2, default=0)}")

# Print the same statistics for each split. The leading "\n" on the later
# headings separates consecutive reports with a blank line.
split_headings = (
  ("train", "For Training Set:"),
  ("validation", "\nFor Validation Set:"),
  ("test", "\nFor Test Set:"),
)
for split_name, heading in split_headings:
  print(heading)
  print_details(dataset[split_name])

For Training Set:
Total pairs --> 3668
No of positive pairs --> 2474
No of negative pairs --> 1194
Maximum length of sentence1 --> 39
Maximum length of sentence2 --> 42
Minimum length of sentence1 --> 7
Minimum length of sentence2 --> 8

For Validation Set:
Total pairs --> 408
No of positive pairs --> 279
No of negative pairs --> 129
Maximum length of sentence1 --> 35
Maximum length of sentence2 --> 34
Minimum length of sentence1 --> 9
Minimum length of sentence2 --> 9

For Test Set:
Total pairs --> 1725
No of positive pairs --> 1147
No of negative pairs --> 578
Maximum length of sentence1 --> 36
Maximum length of sentence2 --> 35
Minimum length of sentence1 --> 8
Minimum length of sentence2 --> 7


In [6]:
class CustomDataset(Dataset):
  """Sentence-pair dataset that keeps raw text and tokenizes per batch.

  Items are returned as plain strings plus a label; `collate_fn` turns a list
  of such items into the tensors of one batch by calling the tokenizer on the
  first and second sentences separately.
  """

  def __init__(self, tokenizer, list_samples, max_length=64):
    """Store the tokenizer and split the samples into three parallel lists.

    Args:
      tokenizer: Callable invoked in `collate_fn` as
        ``tokenizer(sentences, max_length=..., padding=True,
        truncation=True, return_tensors="pt")``; its result must be
        indexable by "input_ids" and "attention_mask".
      list_samples: Indexable samples where position 0 is the first
        sentence, position 1 the second sentence and position 2 the label.
      max_length: Value forwarded to the tokenizer as ``max_length``.
    """
    def column(position):
      # One full pass over list_samples per extracted column.
      return [row[position] for row in list_samples]

    self.tokenizer = tokenizer
    self.list_sent1 = column(0)
    self.list_sent2 = column(1)
    self.labels = column(2)
    self.max_length = max_length

    n_first, n_second = len(self.list_sent1), len(self.list_sent2)
    assert n_first == n_second, f"Length mismatch {n_first} -- {n_second}"

  def __len__(self):
    """Return the number of sentence pairs."""
    return len(self.list_sent1)

  def __getitem__(self, idx):
    """Return the untokenized pair and its label stored at `idx`."""
    return dict(
        sent1=self.list_sent1[idx],
        sent2=self.list_sent2[idx],
        labels=self.labels[idx],
    )

  def _encode(self, sentences):
    """Tokenize a list of sentences with this dataset's fixed settings."""
    return self.tokenizer(
        sentences,
        max_length=self.max_length,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

  def collate_fn(self, batch):
    """Build one batch from a sequence of `__getitem__` results.

    Args:
      batch: Iterable of dicts with the keys "sent1", "sent2" and "labels".

    Returns:
      Dict with "context_input_ids" / "context_attention_mask" (from the
      first sentences), "candidate_input_ids" / "candidate_attention_mask"
      (from the second sentences) and "labels" as a ``torch.long`` tensor.
    """
    # Read every item exactly once, then fan out into per-field lists.
    rows = [(item["sent1"], item["sent2"], item["labels"]) for item in batch]
    contexts = [row[0] for row in rows]
    candidates = [row[1] for row in rows]
    targets = [row[2] for row in rows]

    context_tokens = self._encode(contexts)
    candidate_tokens = self._encode(candidates)

    return {
        "context_input_ids": context_tokens["input_ids"],
        "context_attention_mask": context_tokens["attention_mask"],
        "candidate_input_ids": candidate_tokens["input_ids"],
        "candidate_attention_mask": candidate_tokens["attention_mask"],
        "labels": torch.tensor(targets, dtype=torch.long),
    }

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# The same checkpoint name is used for the tokenizer and passed to PolyEncoder
# as encoder_name.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# NOTE(review): distilbert-base-uncased is documented with a 768-dim hidden
# state; hidden_size=784 may be a typo for 768. Whether it matters depends on
# how PolyEncoder uses hidden_size (not visible here) — confirm.
model = PolyEncoder(poly_m=2,
                    hidden_size=784,
                    num_labels=2,
                    encoder_name=model_name)
model.to(device)

# Count only parameters that have requires_grad set.
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable Parameters --> {params}")

In [8]:
# Flatten the train and validation splits into plain
# (sentence1, sentence2, label) tuples.
train_samples = [
  (row["sentence1"], row["sentence2"], row["label"])
  for row in dataset["train"]
]
valid_samples = [
  (row["sentence1"], row["sentence2"], row["label"])
  for row in dataset["validation"]
]

print(f"Total Train samples --> {len(train_samples)}")
print(f"Total Validation samples --> {len(valid_samples)}")

Total Train samples --> 3668
Total Validation samples --> 408


In [9]:
# Wrap each list of samples; max_length is left at CustomDataset's default.
train_dataset = CustomDataset(list_samples=train_samples, tokenizer=tokenizer)
valid_dataset = CustomDataset(list_samples=valid_samples, tokenizer=tokenizer)