https://aws.amazon.com/blogs/machine-learning/create-and-fine-tune-sentence-transformers-for-enhanced-classification-accuracy/

In [24]:
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [81]:
from sentence_transformers import SentenceTransformer

# 1. Load a pretrained Sentence Transformer model
model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
embedding_size = 384


In [82]:
print(f"Max seq len = {model.max_seq_length}")
sentences = ["This is sentence 1.",
             "This is sentence 2."]

embeddings = model.encode(sentences)
print(type(embeddings))
print(embeddings.shape)

similarities = model.similarity(embeddings, embeddings)
print(similarities)

Max seq len = 256
<class 'numpy.ndarray'>
(2, 384)
tensor([[1.0000, 0.9287],
        [0.9287, 1.0000]])


In [83]:
model.prompts

{}

In [84]:
from torch import nn

num_classes = 9
embedding_size = model.get_sentence_embedding_dimension()

class ClassificationHead(nn.Module):
    def __init__(self, embedding_size: int, num_classes: int):
        super(ClassificationHead, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(embedding_size, num_classes),
        )
    
    def forward(self, features: torch.Tensor):
        x = features["sentence_embedding"]
        x = self.layers(x)
        return x

class SentenceTransformerWithHead(nn.Module):
    def __init__(self, transformer: SentenceTransformer, head: ClassificationHead):
        super(SentenceTransformerWithHead, self).__init__()
        self.transformer = transformer
        self.head = head
    
    def forward(self, input: dict[torch.Tensor, torch.Tensor]):
        # input['input_ids']
        # input['attention_mask']
        features = self.transformer(input)
        logits = self.head(features)
        return logits

head = ClassificationHead(embedding_size, num_classes)
model_with_head = SentenceTransformerWithHead(model, head)
head.to(device)
model_with_head.to(device)
        

SentenceTransformerWithHead(
  (transformer): SentenceTransformer(
    (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
    (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
    (2): Normalize()
  )
  (head): ClassificationHead(
    (layers): Sequential(
      (0): Linear(in_features=384, out_features=9, bias=True)
    )
  )
)

In [85]:
print(next(model.parameters()).device)
print(next(head.parameters()).device)
print(next(model_with_head.parameters()).device)

cuda:0
cuda:0
cuda:0


In [46]:
inputs = model.tokenize(sentences)
inputs = {key: value.to(device) for key, value in inputs.items()}
inputs

{'input_ids': tensor([[ 101, 2023, 2003, 6251, 1015, 1012,  102],
         [ 101, 2023, 2003, 6251, 1016, 1012,  102]], device='cuda:0'),
 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0]], device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}

In [47]:
logits = model_with_head(inputs)

In [86]:
import re

# Freeze all weights except classifier head
# for name, param in model_with_head.named_parameters():
#     # print(name)
#     if not re.search(r"^head", name):
#         # print(name)
#         param.requires_grad = False

for name, param in model_with_head.named_parameters():
    if param.requires_grad:
        print(name)

transformer.0.auto_model.embeddings.word_embeddings.weight
transformer.0.auto_model.embeddings.position_embeddings.weight
transformer.0.auto_model.embeddings.token_type_embeddings.weight
transformer.0.auto_model.embeddings.LayerNorm.weight
transformer.0.auto_model.embeddings.LayerNorm.bias
transformer.0.auto_model.encoder.layer.0.attention.self.query.weight
transformer.0.auto_model.encoder.layer.0.attention.self.query.bias
transformer.0.auto_model.encoder.layer.0.attention.self.key.weight
transformer.0.auto_model.encoder.layer.0.attention.self.key.bias
transformer.0.auto_model.encoder.layer.0.attention.self.value.weight
transformer.0.auto_model.encoder.layer.0.attention.self.value.bias
transformer.0.auto_model.encoder.layer.0.attention.output.dense.weight
transformer.0.auto_model.encoder.layer.0.attention.output.dense.bias
transformer.0.auto_model.encoder.layer.0.attention.output.LayerNorm.weight
transformer.0.auto_model.encoder.layer.0.attention.output.LayerNorm.bias
transformer.0.aut

In [87]:
from torch.utils.data import Dataset, DataLoader
import pandas as pd

class QueryDataset(Dataset):
    def __init__(self, csv_file):
        self.df = pd.read_csv(csv_file)
    
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        item = {
            "Query": row["Query"],
            "Label": row["Label"],

        }
        return item

ds = QueryDataset("intent_classifier_samples.csv")

In [88]:
train_dataloader = DataLoader(ds, batch_size=16, shuffle=True)
sample = next(iter(train_dataloader))

In [76]:
sample["Query"]
sample["Label"]
model.tokenize(sample["Query"])

{'input_ids': tensor([[  101,  2025,  2559,  2005,  2151,  3752, 11433,  2747,  1010,  2021,
           4283,  1012,   102,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0],
         [  101,  2071,  2017,  3073, 20062,  2046,  1996,  8906,  1997,  6048,
           1999,  2115,  2951, 13462,  1029,   102,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0],
         [  101,  2054,  2051,  2515,  1996,  3075,  2485,  2651,  1029,   102,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0],
         [  101,  1045,  1005,  1049,  8025,  2055,  1996,  4353,  1997,  4772,
           2086,  1999,  2115,  2951, 13462,  1012,   102,     0,     0,     0,
           

In [None]:
from torch import nn, optim

loss_function = nn.CrossEntropyLoss(reduction="sum")
optimizer = optim.AdamW(model_with_head.parameters(), lr=0.001)
# lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9)

n_epochs = 30
for edx in range(n_epochs):
    cum_loss = 0
    for bdx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()

        inputs = model.tokenize(batch["Query"])
        inputs = {key: value.to(device) for key, value in inputs.items()}
        labels = batch["Label"].to(device)

        logits = model_with_head(inputs)

        loss = loss_function(logits, labels)
        cum_loss += loss.item()
        loss.backward()
        optimizer.step()
        # lr_scheduler.step()
    print(cum_loss/len(train_dataloader.dataset.df))
