In [1]:
from unsloth import FastLanguageModel
from poi import settings
from poi.dataset.llm import load_prompt_completion_llm_dataset, load_tokenized_llm_dataset
from poi.llm import LLMConfig, inference, load_fast_inference_model


# Run configuration. The checkpoint location (config.output_dir) and the
# sequence length (config.max_length) used for loading below are read from it.
config = LLMConfig(
    run_name="llama3-nyc-test-full-fintune",
    num_epochs=8,
    batch_size=4,
    gradient_accumulation_steps=16,
    do_eval=True,
    resume_from_checkpoint=True,
)

# Prompt/completion splits stored under the NYC "paper" dataset directory.
DATASET_DIR = settings.DATASETS_DIR / "NYC" / "LLM Dataset" / "paper"
train_ds = load_prompt_completion_llm_dataset(DATASET_DIR / "train_codebook.json")
test_ds = load_prompt_completion_llm_dataset(DATASET_DIR / "test_codebook.json")


# Load the model from the run's output directory with both quantization flags
# off. The second element of the returned pair is not needed here and is
# discarded.
model, _ = FastLanguageModel.from_pretrained(
    model_name=config.output_dir.as_posix(),
    max_seq_length=config.max_length,
    dtype=None,
    load_in_4bit=False,
    load_in_8bit=False,
)
# Switch the loaded model to Unsloth's inference mode before generating.
model = FastLanguageModel.for_inference(model)


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


  from .autonotebook import tqdm as notebook_tqdm
Skipping import of cpp extensions due to incompatible torch version 2.8.0+cu128 for torchao version 0.14.0         Please see GitHub issue #2919 for more info


🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.10.9: Fast Llama patching. Transformers: 4.56.2.
   \\   /|    NVIDIA RTX 6000 Ada Generation. Num GPUs = 1. Max memory: 47.372 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu128. CUDA: 8.9. CUDA Toolkit: 12.8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.03it/s]
Unsloth 2025.10.9 patched 32 layers with 0 QKV layers, 0 O layers and 0 MLP layers.


In [2]:
from tqdm import tqdm


def eval(model, ds, max_retries=5, prefix="<a_"):
    """Return the fraction of examples whose completion appears in the model output.

    For every example, ``prefix`` is appended to the example's prompt and the
    result is passed to ``inference`` together with the notebook-level
    ``config``. The stripped output is re-prefixed with ``prefix`` and the
    example counts as correct when its ``"completion"`` string occurs anywhere
    in that text (substring match, not exact equality).

    NOTE(review): this name shadows Python's built-in ``eval`` for the rest of
    the session; it is kept because other cells call it by this name.

    Args:
        model: Model object forwarded unchanged to ``inference``.
        ds: Sized, integer-indexable collection whose items expose ``"prompt"``
            and ``"completion"`` string entries.
        max_retries: Maximum number of ``inference`` calls per example; another
            attempt is made only while the stripped output is empty.
        prefix: Text appended to the prompt as a generation hint and prepended
            to the output before matching.

    Returns:
        ``correct / len(ds)`` as a float, or ``0.0`` when ``ds`` is empty.
    """
    total = len(ds)
    if total == 0:
        # Nothing to score; returning early avoids a ZeroDivisionError below.
        return 0.0

    correct = 0
    for i in tqdm(range(total)):
        # Fetch the row once instead of re-indexing the dataset on every retry.
        example = ds[i]
        prompt = example["prompt"] + prefix  # provide the prefix as a hint

        res = ""
        for _attempt in range(max_retries):
            res = inference(config, model, prompt).strip()
            if res != "":
                break

        # The prefix was supplied as part of the prompt, so add it back before
        # comparing against the reference completion.
        res = prefix + res
        if example["completion"] in res:
            correct += 1
    return correct / total


In [3]:
# NYC base trained on paper-provided NYC train_codebook.json, best model, no quantization
#
# `eval` here is the notebook's own accuracy function defined in an earlier
# cell, not Python's built-in. It returns correct / total, so the printed
# values are fractions rather than percentages.
# Each call runs generation once or more per example and is slow: in the
# recorded run the 876 test examples took about 9 minutes and the 2848 train
# examples about 29 minutes.
test_res = eval(model, test_ds)
print(f"Test Accuracy: {test_res}")

# Accuracy on the training split, for comparison with the test score above.
train_res = eval(model, train_ds)
print(f"Train Accuracy: {train_res}")

  0%|          | 0/876 [00:00<?, ?it/s]

100%|██████████| 876/876 [08:52<00:00,  1.65it/s]


Test Accuracy: 0.3276255707762557


100%|██████████| 2848/2848 [28:52<00:00,  1.64it/s]

Train Accuracy: 0.37429775280898875



