In [2]:
import torch
import torch.nn.functional as F
from datasets import Dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
from transformers import BertConfig, BertTokenizerFast, BertModel

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
device = torch.device('mps')

In [7]:
### load uncased bert model
bert_name = "bert-base-uncased"
tokenizer = BertTokenizerFast.from_pretrained(bert_name)
config = BertConfig.from_pretrained(bert_name, output_hidden_states=True, output_attentions=True)
model = BertModel.from_pretrained(bert_name, config).to(device)



In [34]:
#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


In [18]:
all_tensors = tokenizer(['i love kira'], truncation=True, return_tensors='pt', padding=True, max_length=100)
all_tensors = all_tensors.to(device)

In [46]:
emb = model(**all_tensors)
avg_emb = mean_pooling(emb, all_tensors['attention_mask'])
avg_emb = F.normalize(avg_emb, p=2, dim=1)

In [44]:
avg_emb.matmul(avg_emb.T)

tensor([[74.7504]], device='mps:0', grad_fn=<MmBackward0>)

In [58]:
### load transformer model as benchmark
model_st = SentenceTransformer('all-mpnet-base-v2')



In [11]:
emb_st = model_st.encode(['i love kira'], convert_to_tensor=True)

In [49]:
emb_st.matmul(avg_emb.T)

tensor([[0.0230]], device='mps:0', grad_fn=<MmBackward0>)

In [60]:
train_dataset = Dataset.from_dict({
    "sentence1": ["It's nice weather outside today.", "He drove to work."],
    "sentence2": ["It's so sunny.", "She walked to the store."],
    "label": [1, 0],
})
loss = losses.ContrastiveLoss(model_st)

trainer = SentenceTransformerTrainer(
    model=model_st,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

100%|██████████| 3/3 [00:03<00:00,  1.31s/it]

{'train_runtime': 3.9204, 'train_samples_per_second': 1.53, 'train_steps_per_second': 0.765, 'train_loss': 0.011350090305010477, 'epoch': 3.0}





TrainOutput(global_step=3, training_loss=0.011350090305010477, metrics={'train_runtime': 3.9204, 'train_samples_per_second': 1.53, 'train_steps_per_second': 0.765, 'total_flos': 0.0, 'train_loss': 0.011350090305010477, 'epoch': 3.0})

In [67]:
train_dataset

Dataset({
    features: ['sentence1', 'sentence2', 'label'],
    num_rows: 2
})

In [55]:
loss

ContrastiveLoss(
  (model): SentenceTransformer(
    (0): Transformer({'max_seq_length': 384, 'do_lower_case': False}) with Transformer model: MPNetModel 
    (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
    (2): Normalize()
  )
)

In [66]:
from datasets import load_dataset

# Indicate the dataset id from the Hub
dataset_id = "sentence-transformers/msmarco-msmarco-distilbert-base-tas-b"
dataset = load_dataset(dataset_id, split="train", config='')
"""
Dataset({
   features: ['query', 'answer'],
   num_rows: 100231
})
"""
print(dataset[0])

TypeError: ParquetConfig.__init__() got an unexpected keyword argument 'config'

In [64]:
dataset[0]

{'query': 'when did richmond last play in a preliminary final',
 'answer': "Richmond Football Club Richmond began 2017 with 5 straight wins, a feat it had not achieved since 1995. A series of close losses hampered the Tigers throughout the middle of the season, including a 5-point loss to the Western Bulldogs, 2-point loss to Fremantle, and a 3-point loss to the Giants. Richmond ended the season strongly with convincing victories over Fremantle and St Kilda in the final two rounds, elevating the club to 3rd on the ladder. Richmond's first final of the season against the Cats at the MCG attracted a record qualifying final crowd of 95,028; the Tigers won by 51 points. Having advanced to the first preliminary finals for the first time since 2001, Richmond defeated Greater Western Sydney by 36 points in front of a crowd of 94,258 to progress to the Grand Final against Adelaide, their first Grand Final appearance since 1982. The attendance was 100,021, the largest crowd to a grand final sin

In [68]:
all_nli_triplet_train = load_dataset("sentence-transformers/all-nli", "triplet", split="train[:10000]")

Downloading readme: 100%|██████████| 5.15k/5.15k [00:00<00:00, 12.1MB/s]
Downloading data: 100%|██████████| 38.4M/38.4M [00:06<00:00, 6.11MB/s]
Downloading data: 100%|██████████| 782k/782k [00:00<00:00, 3.71MB/s]
Downloading data: 100%|██████████| 810k/810k [00:00<00:00, 4.48MB/s]
Generating train split: 100%|██████████| 557850/557850 [00:00<00:00, 4287100.89 examples/s]
Generating dev split: 100%|██████████| 6584/6584 [00:00<00:00, 2740972.46 examples/s]
Generating test split: 100%|██████████| 6609/6609 [00:00<00:00, 2750561.14 examples/s]


In [6]:
from datasets import load_dataset

In [7]:
# (premise, hypothesis) + label
all_nli_pair_class_train = load_dataset("sentence-transformers/all-nli", "pair-class", split="train[:10000]")

Downloading data: 100%|██████████| 69.5M/69.5M [00:10<00:00, 6.69MB/s]
Downloading data: 100%|██████████| 1.57M/1.57M [00:00<00:00, 2.39MB/s]
Downloading data: 100%|██████████| 1.61M/1.61M [00:00<00:00, 4.15MB/s]
Generating train split: 100%|██████████| 942069/942069 [00:00<00:00, 4441820.71 examples/s]
Generating dev split: 100%|██████████| 19657/19657 [00:00<00:00, 3971073.78 examples/s]
Generating test split: 100%|██████████| 19656/19656 [00:00<00:00, 3727090.39 examples/s]


In [13]:
all_nli_pair_class_train[4]

{'premise': 'Children smiling and waving at camera',
 'hypothesis': 'There are children present',
 'label': 0}