Skip to content

Commit

Permalink
Use tempfile
Browse files Browse the repository at this point in the history
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
  • Loading branch information
quic-kyunggeu committed May 23, 2024
1 parent 63b01f6 commit b5df80f
Showing 1 changed file with 19 additions and 26 deletions.
45 changes: 19 additions & 26 deletions TrainingExtensions/torch/test/python/test_peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
# @@-COPYRIGHT-END-@@
# =============================================================================
import os
import shutil
import pytest

import tempfile
import torch
from safetensors.torch import save_file
from safetensors import safe_open
Expand Down Expand Up @@ -182,16 +182,13 @@ def test_enable_and_load_weights_adapter(self):
qc_lora = sim.model.base_model.model.linear
assert torch.all(qc_lora.lora_B[0].weight == torch.zeros((10, 4)))

meta_path = './tmp'
if not os.path.exists(meta_path):
os.mkdir(meta_path)
tensors = {'base_model.model.linear.lora_A.0.weight': torch.randn((4, 10)),
'base_model.model.linear.lora_B.0.weight': torch.randn((10, 4))}
path = './tmp/weight.safetensor'
save_file(tensors, path)
peft_utils.enable_adapter_and_load_weights(sim, './tmp/weight.safetensor')
assert torch.all(qc_lora.lora_B[0].weight == tensors['base_model.model.linear.lora_B.0.weight'])
shutil.rmtree('./tmp')
with tempfile.TemporaryDirectory() as tmpdir:
tensors = {'base_model.model.linear.lora_A.0.weight': torch.randn((4, 10)),
'base_model.model.linear.lora_B.0.weight': torch.randn((10, 4))}
path = os.path.join(tmpdir, 'weight.safetensor')
save_file(tensors, path)
peft_utils.enable_adapter_and_load_weights(sim, path)
assert torch.all(qc_lora.lora_B[0].weight == tensors['base_model.model.linear.lora_B.0.weight'])

def test_export_adapter_weights(self):
model = one_adapter_model()
Expand All @@ -206,21 +203,17 @@ def test_export_adapter_weights(self):
qc_lora = sim.model.base_model.model.linear
assert torch.all(qc_lora.lora_B[0].weight == torch.zeros((10, 4)))

meta_path = './tmp'
if not os.path.exists(meta_path):
os.mkdir(meta_path)

peft_utils.export_adapter_weights(sim, meta_path, 'weight')
tensor_name = []
with safe_open('./tmp/weight.safetensor', framework="pt", device=0) as f:
for key in f.keys():
tensor_name.append(key)

assert len(tensor_name) == 2
tensors = ['base_model.model.linear.lora_A.0.weight',
'base_model.model.linear.lora_B.0.weight']
assert sorted(tensor_name) == sorted(tensors)
shutil.rmtree('./tmp')
with tempfile.TemporaryDirectory() as tmpdir:
peft_utils.export_adapter_weights(sim, tmpdir, 'weight')
tensor_name = []
with safe_open(os.path.join(tmpdir, 'weight.safetensor'), framework="pt", device=0) as f:
for key in f.keys():
tensor_name.append(key)

assert len(tensor_name) == 2
tensors = ['base_model.model.linear.lora_A.0.weight',
'base_model.model.linear.lora_B.0.weight']
assert sorted(tensor_name) == sorted(tensors)

def _is_frozen(quantizer):
return quantizer._allow_overwrite == False and\
Expand Down

0 comments on commit b5df80f

Please sign in to comment.