In [1]:
%load_ext autoreload
%autoreload 2
from llamawrapper import LlamaHelper

In [2]:
llama = LlamaHelper(dir='/dlabscratch1/public/llm_weights/llama2_hf/Llama-2-7b-hf', load_in_8bit=True)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
class Llama2Embedding():
    def __init__(self, model, 
                 template: str = "(*):<word> (*):<word> (*):",
                 keyword: str = "<word>", 
                 layer_idx: int = 20):
        """
        model: LlamaWrapper
        """
        self.model = model
        self.template = template 
        self.keyword = keyword
        self.layer_idx = layer_idx
    
    def encode(self, sentences, batch_size=4, **kwargs):
        """
        Returns a list of embeddings for the given sentences.
        Args:
            sentences (`List[str]`): List of sentences to encode
            batch_size (`int`): Batch size for the encoding

        Returns:
            `List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
        """
        #batch_size = 4
        prompts = [self.template.replace(self.keyword, sentence) for sentence in sentences]
        batch = []
        embs = []
        for prompt in prompts:
            batch += [prompt]
            if len(batch) == batch_size:
                latents = self.model.latents_all_layers(batch)[self.layer_idx, :, -1].float()
                latents /= latents.norm(dim=-1, keepdim=True)
                latents = latents.cpu().detach()
                embs += [latent for latent in latents]
                batch = []
                self.model.reset_all()
        if len(batch) > 0:
            latents = self.model.latents_all_layers(batch)[self.layer_idx, :, -1].float()
            latents /= latents.norm(dim=-1, keepdim=True)
            latents = latents.cpu().detach()
            embs += [latent for latent in latents]
            self.model.reset_all()
        return embs


In [4]:
llama.reset_all()
model = Llama2Embedding(llama, template="<word>")
embs = model.encode(["dog dog dog", "3 times dog", "dog", "cat", "house", "day", "night"], 2)

In [5]:
import torch
emb_t = torch.stack(embs)
print(emb_t.shape)

torch.Size([7, 4096])


In [6]:
emb_t @ emb_t.T

tensor([[1.0000, 0.5711, 0.6402, 0.4809, 0.4011, 0.3289, 0.3760],
        [0.5711, 1.0000, 0.7942, 0.5005, 0.3223, 0.2776, 0.3102],
        [0.6402, 0.7942, 1.0000, 0.6648, 0.4768, 0.4172, 0.4398],
        [0.4809, 0.5005, 0.6648, 1.0000, 0.4514, 0.4026, 0.4545],
        [0.4011, 0.3223, 0.4768, 0.4514, 1.0000, 0.3834, 0.4470],
        [0.3289, 0.2776, 0.4172, 0.4026, 0.3834, 1.0000, 0.6181],
        [0.3760, 0.3102, 0.4398, 0.4545, 0.4470, 0.6181, 1.0000]])

In [7]:
print(len(embs), embs[0].shape)

7 torch.Size([4096])


In [15]:
from mteb import MTEB
evaluation = MTEB(tasks=["EmotionClassification"])
evaluation.run(model, batch_size=4)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

{'EmotionClassification': {'mteb_version': '1.2.0',
  'dataset_revision': '4f58c6b202a23cf9a4da393831edf4f9183cad37',
  'mteb_dataset_name': 'EmotionClassification',
  'validation': {'accuracy': 0.27205,
   'f1': 0.24361670554964268,
   'accuracy_stderr': 0.022654414580827283,
   'f1_stderr': 0.016505119951211462,
   'main_score': 0.27205,
   'evaluation_time': 4430.34},
  'test': {'accuracy': 0.2692,
   'f1': 0.2358981752404002,
   'accuracy_stderr': 0.019927619024860944,
   'f1_stderr': 0.011273625869745571,
   'main_score': 0.2692,
   'evaluation_time': 4415.34}}}