# Finetuned Instruction model from Finetuned Base model with 딸깍
- llama3 base model과 inst model 준비
- kowiki가 학습된 base model 준비
- llama3 base model weights에서 inst model weights를 빼서 가중치 차이(`diff = base − inst`)를 저장
- 저장된 가중치 차이(`diff`)를 kowiki가 학습된 base model weights에서 뺌 (`kowiki_inst = kowiki_base − diff = kowiki_base + (inst − base)`)
- kowiki가 학습된 inst model 완성

In [1]:
import torch, torchtune, yaml
from torchtune import config, utils
from generate import InferenceRecipe
from full_finetune_single_device import FullFinetuneRecipeSingleDevice
from omegaconf import DictConfig
from matplotlib import pyplot as plt

In [2]:
# Template YAML config for the generation recipes. The cells that load the
# original base and instruct models open this file and then override its
# checkpointer fields and device before building the recipe.
config_base = './configs/llams_base_generation_kowiki_20240520__.yaml'

In [3]:
# Checkpointer overrides keyed by model variant. Every variant reads its
# checkpoint from, and writes its output to, the same directory, so each entry
# is expanded from a (variant, directory, checkpoint file) triple.
config_lists = {
    variant: {
        'checkpoint_dir': directory,
        'checkpoint_files': [checkpoint_file],
        'output_dir': directory,
    }
    for variant, directory, checkpoint_file in (
        ('base_origin', 'checkpoints/Meta-Llama-3-8B/original', 'consolidated.00.pth'),
        ('inst_origin', 'checkpoints/Meta-Llama-3-8B-Instruct/original', 'consolidated.00.pth'),
        ('base_lastepoch', 'checkpoints/Llama-3-LLaMS-kowiki', 'meta_model_4.pt'),
    )
}

In [5]:
# Build a generation recipe for the original Llama 3 base checkpoint.
config_item = 'base_origin'

# Parse the shared template config fresh so overrides made here do not leak
# into the configs built for the other models.
with open(config_base) as config_file:
    raw_config = yaml.safe_load(config_file)
config_dict = DictConfig(raw_config)

# Point the checkpointer at this variant's files and pin the model to its
# own device.
checkpointer_overrides = config_lists[config_item]
for option in checkpointer_overrides:
    config_dict.checkpointer[option] = checkpointer_overrides[option]
config_dict.device = 'cuda:0'

base_origin = InferenceRecipe(config_dict)
base_origin.setup(config_dict)

DEBUG:torchtune.utils.logging:Setting manual seed to local seed 1234. Local seed is seed + rank = 1234 + 0
INFO:torchtune.utils.logging:Model is initialized with precision torch.bfloat16.


In [6]:
# Build a generation recipe for the original Llama 3 instruct checkpoint.
config_item = 'inst_origin'

# Re-read the template so this config starts from the file's contents rather
# than from a config already modified for another model.
with open(config_base) as config_file:
    raw_config = yaml.safe_load(config_file)
config_dict = DictConfig(raw_config)

# Swap in the instruct checkpoint locations and use a separate device from
# the base model.
checkpointer_overrides = config_lists[config_item]
for option in checkpointer_overrides:
    config_dict.checkpointer[option] = checkpointer_overrides[option]
config_dict.device = 'cuda:1'

inst_origin = InferenceRecipe(config_dict)
inst_origin.setup(config_dict)

DEBUG:torchtune.utils.logging:Setting manual seed to local seed 1234. Local seed is seed + rank = 1234 + 0
INFO:torchtune.utils.logging:Model is initialized with precision torch.bfloat16.


In [7]:
# Load the model that will receive the weight delta through the fine-tuning
# recipe (rather than the generation recipe) so that its save_checkpoint()
# can be used afterwards to write the modified weights.
config_item = 'base_lastepoch'

# NOTE: unlike the two generation recipes, this config comes from its own
# YAML file and the config_lists overrides are intentionally not applied, so
# the checkpointer settings are whatever that file specifies (presumably the
# kowiki-trained base checkpoint -- confirm against the YAML).
with open('configs/base_lastepoch.yaml') as recipe_file:
    recipe_settings = yaml.safe_load(recipe_file)
config_dict = DictConfig(recipe_settings)
config_dict.device = 'cuda:2'

base_lastepoch = FullFinetuneRecipeSingleDevice(config_dict)
base_lastepoch.setup(config_dict)

DEBUG:torchtune.utils.logging:Setting manual seed to local seed 31344. Local seed is seed + rank = 31344 + 0


Writing logs to /tmp/lora_finetune_output/log_1726186433.txt


INFO:torchtune.utils.logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils.logging:Compiling model with torch.compile...
INFO:torchtune.utils.logging:Memory Stats after model init:
{'peak_memory_active': 16.565969408, 'peak_memory_alloc': 16.565969408, 'peak_memory_reserved': 16.638803968}
INFO:torchtune.utils.logging:Tokenizer is initialized from file.
INFO:torchtune.utils.logging:Optimizer is initialized.
INFO:torchtune.utils.logging:Loss is initialized.
INFO:torchtune.utils.logging:Dataset and Sampler are initialized.


In [8]:
# Compute the per-parameter weight delta (base - instruct) as CPU tensors.
#
# The two models are configured for different devices, so every tensor is
# copied to the CPU before subtracting. Parameters are paired purely by their
# position in the iteration order, so the pairing is validated first: a bare
# zip() would silently stop at the shorter model, and tensors with different
# but broadcastable shapes would subtract without raising.
base_named_params = list(base_origin._model.named_parameters())
inst_params = list(inst_origin._model.parameters())
if len(base_named_params) != len(inst_params):
    raise ValueError(
        f"Parameter count mismatch: base model has {len(base_named_params)}, "
        f"instruct model has {len(inst_params)}"
    )

# Result: one delta tensor per parameter, in the base model's parameter order.
diff_param_base_inst = []
for (param_name, base_param), inst_param in zip(base_named_params, inst_params):
    if base_param.shape != inst_param.shape:
        raise ValueError(
            f"Shape mismatch for parameter '{param_name}': "
            f"base {tuple(base_param.shape)} vs instruct {tuple(inst_param.shape)}"
        )
    # detach() keeps the delta out of any autograd graph.
    diff_param_base_inst.append(
        base_param.detach().to('cpu') - inst_param.detach().to('cpu')
    )

In [9]:
# Subtract the stored (base - instruct) delta from the target model in place,
# i.e. target <- target - (base - instruct) = target + (instruct - base).
#
# The delta list and the target parameters are matched by position only, so
# check that they line up before touching any weights; zip() alone would
# silently leave trailing parameters unmodified.
target_params = list(base_lastepoch._model.parameters())
if len(target_params) != len(diff_param_base_inst):
    raise ValueError(
        f"Parameter count mismatch: target model has {len(target_params)}, "
        f"weight delta has {len(diff_param_base_inst)}"
    )

# no_grad() lets the parameters be modified in place without autograd
# tracking the update.
with torch.no_grad():
    for index, (z, xh) in enumerate(zip(diff_param_base_inst, target_params)):
        if z.shape != xh.shape:
            raise ValueError(
                f"Shape mismatch at parameter index {index}: "
                f"delta {tuple(z.shape)} vs target {tuple(xh.shape)}"
            )
        # Move the delta to wherever this parameter lives instead of assuming
        # a fixed device name.
        xh -= z.to(xh.device)

In [10]:
# Persist the modified weights through the fine-tuning recipe's checkpointer.
# The argument 0 is presumably the epoch index used to name the output file
# (the captured log shows meta_model_0.pt) -- confirm against the recipe.
base_lastepoch.save_checkpoint(0)

INFO:torchtune.utils.logging:Model checkpoint of size 16.06 GB saved to checkpoints/Llama-3-LLaMS-inst-from-base-start/meta_model_0.pt
INFO:torchtune.utils.logging:Recipe checkpoint of size 0.00 GB saved to checkpoints/Llama-3-LLaMS-inst-from-base-start/recipe_state.pt
