# Merge the LoRA adapter into the base model

In [None]:
import onnxruntime as ort

# Show which ONNX Runtime execution providers this kernel's install exposes.
available_providers = ort.get_available_providers()
print(available_providers)

['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']


In [None]:
# Pin the library versions used for the merge.
# %pip (instead of !pip) installs into the environment of the running kernel,
# which `!pip` does not guarantee when several Python installs are on PATH.
%pip install "transformers==4.37.2" "peft==0.7.1"

Collecting transformers==4.37.2
  Downloading transformers-4.37.2-py3-none-any.whl.metadata (129 kB)
Collecting peft==0.7.1
  Downloading peft-0.7.1-py3-none-any.whl.metadata (25 kB)
Collecting huggingface-hub<1.0,>=0.19.3 (from transformers==4.37.2)
  Downloading huggingface_hub-0.31.1-py3-none-any.whl.metadata (13 kB)
Collecting regex!=2019.12.17 (from transformers==4.37.2)
  Downloading regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collecting tokenizers<0.19,>=0.14 (from transformers==4.37.2)
  Downloading tokenizers-0.15.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting safetensors>=0.4.1 (from transformers==4.37.2)
  Downloading safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Collecting accelerate>=0.21.0 (from peft==0.7.1)
  Downloading accelerate-1.6.0-py3-none-any.whl.metadata (19 kB)
Collecting hf-xet<2.0.0,>=1.1.0 (from huggingface-hub<1.0,>=0.19.

In [None]:
import os

from huggingface_hub import login

# SECURITY: never hard-code an access token in a notebook. A literal token was
# previously committed in this cell and must be revoked in the Hugging Face
# account settings.
# The token is read from the HF_TOKEN environment variable; when it is unset,
# login(token=None) falls back to an interactive prompt.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
import json
from pathlib import Path

def clean_adapter_config(config_path):
    """Remove a fixed set of keys from an adapter config JSON file, in place.

    The file is read, every key listed in ``UNNEEDED_KEYS`` that is present at
    the top level is dropped (one progress line is printed per removed key),
    and the result is written back to the same path with 2-space indentation.

    NOTE(review): these keys are presumably ones the pinned ``peft`` version
    does not recognise — confirm against the peft release actually installed.

    Args:
        config_path: Path (str or ``pathlib.Path``) to ``adapter_config.json``.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
    """
    UNNEEDED_KEYS = [
        "corda_config",
        "eva_config",
        "megatron_config",
        "megatron_core",
        "loftq_config",
        "layers_pattern",
        "layer_replication",
        "auto_mapping",
        "revision",
        "modules_to_save",
        "trainable_token_indices",
        "use_dora",
        "use_rslora",
        "rank_pattern",
        "fan_in_fan_out",
        "init_lora_weights",
        "exclude_modules",
        "lora_bias",
        "layers_to_transform"
    ]

    path = Path(config_path)
    if not path.exists():
        raise FileNotFoundError(f"Config file not found: {path}")

    # JSON is UTF-8 by specification; an explicit encoding keeps the read from
    # depending on the platform's locale default.
    with open(path, "r", encoding="utf-8") as f:
        config = json.load(f)

    for key in UNNEEDED_KEYS:
        if key in config:
            print(f"🧹 Removing: {key}")
            config.pop(key)

    # The file is overwritten in place; no backup copy is kept.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    print(f"Cleaned config saved to: {path}")

# Drop the unneeded keys from the adapter config before the LoRA merge.
ADAPTER_CONFIG_PATH = "../fine_tuned_lora_model/adapter_config.json"
clean_adapter_config(ADAPTER_CONFIG_PATH)


🧹 Removing: lora_bias
🧹 Removing: layers_to_transform
Cleaned config saved to: ../fine_tuned_lora_model/adapter_config.json


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from peft import PeftModel
import torch
import pathlib

BASE = "meta-llama/Llama-2-7b-hf"
ADAPTER = "../fine_tuned_lora_model"
MERGED = pathlib.Path("../llama2-legal-merged")

# Load the base model in FP16 together with its tokenizer.
model = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(BASE)

# Attach the fine-tuned LoRA adapter to the base model.
model = PeftModel.from_pretrained(model, ADAPTER)

# Fold the LoRA deltas into the base weights.
# merge_and_unload() RETURNS the plain transformers model with the adapter
# merged in. The result must be kept: saving the PeftModel wrapper instead
# writes adapter files (adapter_model.bin / adapter_config.json), not a full
# checkpoint that AutoModelForCausalLM.from_pretrained can load.
model = model.merge_and_unload()

# Save the merged model. save_pretrained() on the unwrapped model also writes
# config.json, so no separate AutoConfig save is needed.
# safe_serialization=False keeps the PyTorch .bin weight format.
model.save_pretrained(MERGED, safe_serialization=False)
tokenizer.save_pretrained(MERGED)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

('../llama2-legal-merged/tokenizer_config.json',
 '../llama2-legal-merged/special_tokens_map.json',
 '../llama2-legal-merged/tokenizer.json')

In [None]:
# Rename adapter_model.bin to pytorch_model.bin inside the merged-model folder.
# NOTE(review): adapter_model.bin is the file name PEFT uses for adapter-only
# weights; renaming it presumably does not turn it into a full model checkpoint —
# verify that ../llama2-legal-merged really contains the merged base weights.
# If the merged model was saved as a regular transformers checkpoint, this file
# will not exist and this step is unnecessary.
!mv ../llama2-legal-merged/adapter_model.bin ../llama2-legal-merged/pytorch_model.bin

# Export to ONNX

In [None]:
# tensorrt_llm==0.9.0
# Installs optimum-nvidia, which adds `optimum-cli export trtllm`.
# %pip makes the install target the running kernel's environment; the note is
# kept on its own line so it is never passed through to pip's command line.
# NOTE(review): confirm that the `trtllm` extra and the `>=1.18.0` range exist
# for optimum-nvidia on the package index being used.
%pip install "optimum-nvidia[trtllm]>=1.18.0"

Collecting optimum[exporters]
  Downloading optimum-1.24.0-py3-none-any.whl.metadata (21 kB)
Collecting onnxruntime (from optimum[exporters])
  Downloading onnxruntime-1.21.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.7 kB)
Collecting timm (from optimum[exporters])
  Downloading timm-1.0.15-py3-none-any.whl.metadata (52 kB)
Downloading onnxruntime-1.21.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.0/16.0 MB[0m [31m69.6 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hDownloading optimum-1.24.0-py3-none-any.whl (433 kB)
Downloading timm-1.0.15-py3-none-any.whl (2.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m84.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: onnxruntime, timm, optimum
Successfully installed onnxruntime-1.21.1 optimum-1.24.0 timm-1.0.15


In [None]:
# Export the LoRA-merged Hugging Face checkpoint for TensorRT-LLM.
#   --model            the LoRA-merged FP16 HF folder
#   --task             causal-lm-with-past keeps the KV-cache graph
#   --dtype            FP16 weights (best perf, no extra VRAM)
#   --sequence-length  max prompt length you need
#   --batch-size       build vars baked into the engine
# The flag notes live up here on purpose: any text after a trailing "\" stops
# the backslash from continuing the line, which cuts the command short.
# NOTE(review): flag names are unchanged from the original cell; confirm them
# with `optimum-cli export trtllm --help` for the installed optimum-nvidia.
!optimum-cli export trtllm \
  --model ../llama2-legal-merged \
  --task causal-lm-with-past \
  --dtype fp16 \
  --sequence-length 4096 \
  --batch-size 1 \
  --device cuda \
  --library transformers \
  --output-dir ../llama2-legal-trtllm

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loading checkpoint shards:   0%|                          | 0/2 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/opt/conda/bin/optimum-cli", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/opt/conda/lib/python3.12/site-packages/optimum/commands/optimum_cli.py", line 208, in main
    service.run()
  File "/opt/conda/lib/python3.12/site-packages/optimum/commands/export/onnx.py", line 265, in run
    main_export(
  File "/opt/conda/lib/python3.12/site-packages/optimum/exporters/onnx/__main__.py", line 305, in main_export
    model = TasksManager.get_model_from_task(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.12/site-packages/optimum/exporters/tasks.py", line 2283, in get_model_from_task
    model = model_class.from_pretrained(model_name_or_path, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py", line 566

In [None]:
# Build a TensorRT-LLM engine from ../llama2-legal-trtllm into ../llama2-legal-engine.
# NOTE(review): confirm every flag with `trtllm-build --help` for the installed
# TensorRT-LLM version; --dtype, --tp_size and --enable_kv_cache in particular
# may belong to the checkpoint-conversion step rather than to trtllm-build.
# NOTE(review): --max_input_len equals --max_seq_len (both 4096); if max_seq_len
# counts prompt plus generated tokens, a full-length prompt leaves no room for
# output — confirm and raise max_seq_len or lower max_input_len.
!trtllm-build \
  --checkpoint_dir ../llama2-legal-trtllm \
  --output_dir     ../llama2-legal-engine \
  --dtype          fp16 \
  --max_batch_size 1 \
  --max_input_len 4096 \
  --max_seq_len    4096 \
  --tp_size        1 \
  --enable_kv_cache

In [None]:
import onnx

# Only tensor metadata is needed here. load_external_data=False stops onnx from
# reading external weight files into RAM (a model this size presumably stores
# its weights as external data); each initializer's data_type is still available.
# NOTE(review): no cell in this notebook writes ../llama2-legal-onnx — confirm
# the file comes from an earlier ONNX export.
model = onnx.load("../llama2-legal-onnx/model.onnx", load_external_data=False)

# Prints the set of TensorProto data_type codes; 10 == onnx.TensorProto.FLOAT16.
print({tensor.data_type for tensor in model.graph.initializer})

{10}


# Quick test in ONNX Runtime

In [None]:
import onnxruntime as ort

# Print the execution providers ONNX Runtime reports for this environment.
available_providers = ort.get_available_providers()
print(available_providers)

['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
['CPUExecutionProvider']


In [None]:
import time

import torch
from tensorrt_llm.runtime import ModelRunner
from transformers import AutoTokenizer

# Smoke test: run one prompt through the built TensorRT-LLM engine.
# NOTE(review): written against the ModelRunner API (from_dir / generate) of
# tensorrt_llm.runtime; confirm the signatures for the installed version.
tokenizer = AutoTokenizer.from_pretrained("../llama2-legal-merged")

# The engine consumes token ids, not raw text, so the prompt is tokenized here.
# If the tokenizer defines no pad token, EOS is used as the padding id.
end_id = tokenizer.eos_token_id
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else end_id

runner = ModelRunner.from_dir(engine_dir="../llama2-legal-engine")

prompt = "One‑sentence summary of clause 7.2:"
input_ids = torch.tensor(tokenizer.encode(prompt), dtype=torch.int32)

# perf_counter() is monotonic, and the clock stops before decoding/printing so
# the reported latency covers generation only.
start = time.perf_counter()
output_ids = runner.generate(
    [input_ids],                # one tensor of token ids per batch entry
    max_new_tokens=64,
    end_id=end_id,
    pad_id=pad_id,
)
latency = time.perf_counter() - start

# output_ids is indexed as (batch, beam, token) and includes the prompt tokens;
# slice them off so only the completion is decoded.
outputs = [
    tokenizer.decode(output_ids[0, 0, input_ids.numel():], skip_special_tokens=True)
]
print(outputs[0])
print("Latency:", latency, "s")

# Wrap with FastAPI or Triton (Don't Run)

In [None]:
# Start a detached Triton Inference Server container with GPU access, publishing
# port 8000 and mounting the host model repository at /models.
# NOTE(review): there is no leading `!` or `%%bash`, so this cell is not valid
# Python and will fail if executed — presumably intentional given the
# "Dont Run" heading; confirm, or move it into a markdown code fence.
# NOTE(review): /home/cc/triton_repo is a machine-specific absolute path.
docker run -d --gpus all -p 8000:8000 \
  -v /home/cc/triton_repo:/models \
  nvcr.io/nvidia/tritonserver:24.05-py3 \
  tritonserver --model-repository=/models


# Build a FastAPI ONNX micro-service (pattern from the hand-out)
````
docker compose -f docker-compose-fastapi.yaml up -d --build
````

````
curl -X POST http://<IP>:8000/generate \
     -H "Content-Type: application/json" \
     -d '{"prompt":"Summarise clause 7.2 in two lines"}'
````
