<a href="https://colab.research.google.com/github/tuhinmallick/AI-for-Fashion/blob/main/Train_LLM_Embedding_Models_with_SimCSE_Examples_with_Llama_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

*More details in this article: [Train Better Llama 3 Embeddings with Simple Contrastive Learning](https://newsletter.kaitchup.com/p/train-better-llama-3-embeddings-with)*

This notebook shows how to turn an LLM into a better embedding model with simple contrastive learning (SimCSE).

It uses the LLM2Vec framework, with Llama 3 8B as the example model. For a 7B/8B LLM, it can run on a 24 GB GPU.
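As a rough sanity check on the 24 GB figure, the sketch below estimates the memory taken by the model weights alone when they are loaded in bfloat16, the dtype used in the configs later in this notebook. The parameter count is an assumption ("8B" taken literally, not read from the checkpoint), and the estimate ignores LoRA adapters, optimizer state, and activations, which all have to fit in the remaining headroom.

```python
# Back-of-the-envelope estimate only; APPROX_PARAMS is an assumption,
# not a value read from the actual Llama 3 8B checkpoint.
APPROX_PARAMS = 8.0e9        # assumed: "8B" taken literally
BYTES_PER_PARAM_BF16 = 2     # bfloat16 stores each weight in 2 bytes
GPU_MEMORY_GB = 24           # GPU size quoted in the text


def weight_memory_gb(num_params: float, bytes_per_param: int) -> float:
    """Memory used by the weights alone, in GB (10**9 bytes)."""
    return num_params * bytes_per_param / 1e9


weights_gb = weight_memory_gb(APPROX_PARAMS, BYTES_PER_PARAM_BF16)
headroom_gb = GPU_MEMORY_GB - weights_gb

print(f"weights: ~{weights_gb:.0f} GB, headroom: ~{headroom_gb:.0f} GB")
```

The headroom is small, which is consistent with the configs below using a per-device batch size of 1, gradient checkpointing, and LoRA rather than full fine-tuning.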

We need to install the following packages:

In [None]:
!pip install --upgrade llm2vec transformers
!pip install flash-attn --no-build-isolation

Collecting llm2vec
  Downloading llm2vec-0.1.8-py2.py3-none-any.whl (26 kB)
Collecting transformers
  Downloading transformers-4.42.3-py3-none-any.whl (9.3 MB)
Collecting peft (from llm2vec)
  Downloading peft-0.11.1-py3-none-any.whl (251 kB)
Collecting datasets (from llm2vec)
  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
Collecting evaluate (from llm2vec)
  Downloading evaluate-0.4.2-py3-none-any.whl (84 kB)
Collecting pyarrow>=15.0.0 (from datasets->llm2vec)
  Downloading pyarrow-16.1.

We will use the training scripts from the LLM2Vec repository:

In [None]:
!git clone https://github.com/McGill-NLP/llm2vec.git

Cloning into 'llm2vec'...
remote: Enumerating objects: 751, done.
remote: Counting objects: 100% (169/169), done.
remote: Compressing objects: 100% (90/90), done.
remote: Total 751 (delta 100), reused 86 (delta 79), pack-reused 582
Receiving objects: 100% (751/751), 1.37 MiB | 32.54 MiB/s, done.
Resolving deltas: 100% (421/421), done.


To use Llama 3 8B, you will need to enter your HF token below:

In [None]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


# Llama 3 8B

### MNTP Training

More information about this step is available in this article:
[Turn Llama 3 into an Embedding Model with LLM2Vec](https://kaitchup.substack.com/p/turn-llama-3-into-an-embedding-model)

In [None]:
JSON_CONFIG='''
{
    "model_name_or_path": "meta-llama/Meta-Llama-3-8B",
    "dataset_name": "wikitext",
    "dataset_config_name": "wikitext-103-raw-v1",
    "per_device_train_batch_size": 1,
    "per_device_eval_batch_size": 1,
    "gradient_accumulation_steps": 16,
    "do_train": true,
    "do_eval": true,
    "max_seq_length": 512,
    "mask_token_type": "blank",
    "data_collator_type": "all_mask",
    "mlm_probability": 0.8,
    "overwrite_output_dir": true,
    "output_dir": "./drive/MyDrive/llm2vec/Meta-Llama-3-8B-llm2vec-MNTP-Emb",
    "evaluation_strategy": "steps",
    "eval_steps": 100,
    "save_steps": 200,
    "stop_after_n_steps": 1000,
    "lora_r": 16,
    "gradient_checkpointing": true,
    "torch_dtype": "bfloat16",
    "attn_implementation": "flash_attention_2"
}
'''

with open("mtnp_config_Meta-Llama-3-8B.json", 'w') as f:
  f.write(JSON_CONFIG)
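To make the batch-related settings above concrete, here is a small sketch of the arithmetic. It copies four values from the MNTP config and assumes a single GPU and that `stop_after_n_steps` counts optimizer updates (not micro-batches); both are assumptions rather than something the config states. `num_gpus` is a hypothetical helper argument.

```python
import json

# Values copied from the MNTP config above; everything else is omitted.
mntp_settings = json.loads("""
{
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 16,
    "stop_after_n_steps": 1000,
    "max_seq_length": 512
}
""")


def sequences_per_update(cfg: dict, num_gpus: int = 1) -> int:
    """Sequences contributing to one optimizer update."""
    return (
        cfg["per_device_train_batch_size"]
        * cfg["gradient_accumulation_steps"]
        * num_gpus
    )


seqs_per_update = sequences_per_update(mntp_settings)
# Assumption: training stops after this many optimizer updates.
total_sequences = seqs_per_update * mntp_settings["stop_after_n_steps"]
# Upper bound: every sequence would have to be exactly max_seq_length long.
max_tokens_seen = total_sequences * mntp_settings["max_seq_length"]

print(seqs_per_update, total_sequences, max_tokens_seen)
```

Under these assumptions, the run sees 16 sequences per update and at most about 8.2 million tokens in total, which is only a small slice of the dataset.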

In [None]:
!python llm2vec/experiments/run_mntp.py mtnp_config_Meta-Llama-3-8B.json

Streaming output truncated to the last 5000 lines.
 41% 204/499 [00:44<01:05,  4.53it/s]
 41% 205/499 [00:44<01:04,  4.54it/s]
 41% 206/499 [00:45<01:04,  4.53it/s]
 41% 207/499 [00:45<01:04,  4.56it/s]
 42% 208/499 [00:45<01:03,  4.55it/s]
 42% 209/499 [00:45<01:03,  4.55it/s]
 42% 210/499 [00:45<01:03,  4.55it/s]
 42% 211/499 [00:46<01:03,  4.56it/s]
 42% 212/499 [00:46<01:02,  4.56it/s]
 43% 213/499 [00:46<01:02,  4.55it/s]
 43% 214/499 [00:46<01:02,  4.55it/s]
 43% 215/499 [00:47<01:02,  4.55it/s]
 43% 216/499 [00:47<01:02,  4.54it/s]
 43% 217/499 [00:47<01:02,  4.51it/s]
 44% 218/499 [00:47<01:01,  4.53it/s]
 44% 219/499 [00:47<01:01,  4.56it/s]
 44% 220/499 [00:48<01:01,  4.56it/s]
 44% 221/499 [00:48<01:01,  4.55it/s]
 44% 222/499 [00:48<01:00,  4.54it/s]
 45% 223/499 [00:48<01:00,  4.55it/s]
 45% 224/499 [00:49<01:00,  4.55it/s]
 45% 225/499 [00:49<01:00,  4.56it/s]
 45% 226/499 [00:49<00:59,  4.56i

### Unsupervised contrastive training (SimCSE)

First, we need some training data. Any text works, but text from your target domain is preferable. Here, I use sentences from Wikipedia.

In [None]:
!wget https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt

--2024-06-30 14:36:22--  https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt
Resolving huggingface.co (huggingface.co)... 13.33.30.49, 13.33.30.114, 13.33.30.76, ...
Connecting to huggingface.co (huggingface.co)|13.33.30.49|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/datasets/princeton-nlp/datasets-for-simcse/7b1825863a99aa76479b0456f7c210539dfaeeb69598b41fb4de4f524dd5a706?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27wiki1m_for_simcse.txt%3B+filename%3D%22wiki1m_for_simcse.txt%22%3B&response-content-type=text%2Fplain&Expires=1720017382&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMDAxNzM4Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9kYXRhc2V0cy9wcmluY2V0b24tbmxwL2RhdGFzZXRzLWZvci1zaW1jc2UvN2IxODI1ODYzYTk5YWE3NjQ3OWIwNDU2ZjdjMjEwNTM5ZGZhZWViNjk1OThiNDFmYjRkZTRmNTI0ZGQ1YTcwNj9yZXNwb25zZS1jb250ZW5
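If you would rather train on text from your own domain, one option is to write it to a plain-text file in the same shape as `wiki1m_for_simcse.txt`, that is, one sentence per line. The sketch below is a minimal, hypothetical helper for that: `write_simcse_corpus`, the file name `my_domain_for_simcse.txt`, and the example sentences are all made up for illustration. Whether the LLM2Vec loader accepts an arbitrary file through `dataset_file_path` is an assumption you should verify against the repository.

```python
import tempfile
from pathlib import Path


def write_simcse_corpus(sentences, path) -> int:
    """Write non-empty, de-duplicated sentences to `path`, one per line.

    Returns the number of sentences written.
    """
    seen = set()
    kept = []
    for sentence in sentences:
        # Collapse newlines, tabs, and repeated spaces so each sentence
        # occupies exactly one line in the output file.
        line = " ".join(sentence.split())
        if line and line not in seen:
            seen.add(line)
            kept.append(line)
    Path(path).write_text("".join(f"{line}\n" for line in kept), encoding="utf-8")
    return len(kept)


# Made-up example sentences, only to show the cleaning behaviour.
raw_sentences = [
    "  Silk blends drape better than pure linen. ",
    "",
    "Silk blends drape better than pure linen.",
    "Denim is a sturdy\ncotton twill.",
]

with tempfile.TemporaryDirectory() as tmp_dir:
    corpus_path = Path(tmp_dir) / "my_domain_for_simcse.txt"  # hypothetical name
    kept_count = write_simcse_corpus(raw_sentences, corpus_path)
    corpus_text = corpus_path.read_text(encoding="utf-8")

print(kept_count)
print(corpus_text, end="")
```

In a real run you would write the file next to the notebook instead of into a temporary directory, then set `dataset_file_path` in the SimCSE config to that file.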

Then, the configuration file for SimCSE training:

In [None]:
JSON_CONFIG='''
{
    "model_name_or_path": "meta-llama/Meta-Llama-3-8B",
    "peft_model_name_or_path": "./drive/MyDrive/llm2vec/Meta-Llama-3-8B-llm2vec-MNTP-Emb",
    "simcse_dropout": 0.3,
    "bidirectional": true,
    "pooling_mode": "mean",
    "dataset_name": "Wiki1M",
    "dataset_file_path": "wiki1m_for_simcse.txt",
    "remove_unused_columns": false,
    "learning_rate": 3e-5,
    "do_train": true,
    "loss_scale": 20,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 32,
    "disable_tqdm": false,
    "save_steps": 200,
    "stop_after_n_steps": 1000,
    "max_seq_length": 128,
    "lora_r": 16,
    "gradient_checkpointing": true,
    "torch_dtype": "bfloat16",
    "attn_implementation": "flash_attention_2",
    "output_dir": "./Meta-Llama-3-8B-llm2vec-SimCSE-Emb"
}
'''

with open("SimCSE_config_Meta-Llama-3-8B.json", 'w') as f:
  f.write(JSON_CONFIG)

Run the LLM2Vec SimCSE script:

*Note: If your GPU has more than 24 GB of memory, I highly recommend increasing "per_device_train_batch_size" in the config file above as much as you can, for better performance and faster training.*

In [None]:
!python llm2vec/experiments/run_simcse.py SimCSE_config_Meta-Llama-3-8B.json
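To see why the per-device batch size matters for contrastive training, here is a minimal, generic sketch of a SimCSE-style loss in plain Python. It is not LLM2Vec's implementation. It assumes that each sentence is encoded twice (two dropout-perturbed views), that the other sentences in the same batch serve as negatives, and that `loss_scale` multiplies the cosine similarities. All three are assumptions about how the config values are used, and the toy 2-D embeddings are made up.

```python
import math


def cosine(u, v) -> float:
    """Cosine similarity between two vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def in_batch_contrastive_loss(views_a, views_b, scale: float = 20.0) -> float:
    """Mean cross-entropy where views_b[i] is the positive for views_a[i]
    and every other views_b[j] in the batch acts as a negative."""
    n = len(views_a)
    total = 0.0
    for i in range(n):
        logits = [scale * cosine(views_a[i], views_b[j]) for j in range(n)]
        # Numerically stable log-sum-exp over the batch.
        top = max(logits)
        log_denominator = top + math.log(sum(math.exp(x - top) for x in logits))
        total += log_denominator - logits[i]
    return total / n


# Toy embeddings: two sentences, two views each.
views_a = [[1.0, 0.0], [0.0, 1.0]]
views_b = [[0.8, 0.6], [0.6, 0.8]]

loss_two_sentences = in_batch_contrastive_loss(views_a, views_b)
loss_one_sentence = in_batch_contrastive_loss(views_a[:1], views_b[:1])

print(f"batch of 2: {loss_two_sentences:.4f}")
print(f"batch of 1: {loss_one_sentence:.4f}")
```

In this simplified version, a batch of one has nothing to contrast against, so its loss is zero, and each extra sentence in the batch adds one more negative per example. Gradient accumulation only sums gradients from separate micro-batches, so under these assumptions it does not add negatives the way a larger per-device batch does.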