In [1]:
import json

def parse_concatenated_json(path):
    """Parse a file that contains several JSON values written back to back.

    The values may be separated by any amount of whitespace (including none)
    and may span multiple lines or share a line.

    The previous implementation split the stream by counting ``{`` and ``}``
    per line, which also counted braces that occur *inside string values* and
    therefore mis-detected object boundaries. This version lets the JSON
    decoder itself find where each value ends.

    Args:
        path: Path of a UTF-8 encoded text file.

    Returns:
        A list with one decoded Python value per JSON value, in file order.
        An empty or whitespace-only file yields an empty list.

    Raises:
        json.JSONDecodeError: If the file contains malformed or truncated
            JSON (a truncated trailing value is reported, not silently
            dropped).
    """
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()

    decoder = json.JSONDecoder()
    objects = []
    pos = 0
    end = len(text)

    while True:
        # raw_decode() does not tolerate leading whitespace, so skip the gap
        # between two consecutive values manually.
        while pos < end and text[pos].isspace():
            pos += 1
        if pos >= end:
            break

        # Returns the decoded value and the index just past it.
        value, pos = decoder.raw_decode(text, pos)
        objects.append(value)

    return objects


In [2]:
# Load every JSON record of the raw dump into one list; `x` is the pool that
# the later cells draw their random samples from.
x = parse_concatenated_json('/Users/artem.semidetnov/Documents/Predictor/data.json')

In [3]:
def simplify_jsonl(entry : dict) -> dict:
    """Collapse a raw record into a two-field prompt/completion record.

    The prompt is the stringified ``'Context'``, the stringified
    ``'Premises'`` and the ``'Expected type'`` value (used as-is, so it must
    already be a string) joined by the ``<<<break>>>`` separator. The
    completion is the record's ``'Expression'`` value, unchanged.

    Raises:
        KeyError: If one of the four fields is missing from ``entry``.
    """
    separator = "<<<break>>>"
    prompt_parts = [
        str(entry['Context']),
        str(entry['Premises']),
        entry['Expected type'],
    ]
    return {
        'prompt': separator.join(prompt_parts),
        'completion': entry['Expression'],
    }

In [6]:
import random

path_train = '/Users/artem.semidetnov/Documents/Predictor/train1/train.jsonl'
path_valid = '/Users/artem.semidetnov/Documents/Predictor/train1/valid.jsonl'
path_test = '/Users/artem.semidetnov/Documents/Predictor/train1/test.jsonl'

# Split sizes.
n_train, n_valid, n_test = 1000, 100, 100

# Draw ONE sample without replacement and slice it, so that no list position
# of `x` ends up in more than one split. Sampling each split independently
# (as before) lets validation/test entries also appear in the training set,
# which makes the evaluation numbers optimistic.
# Requires len(x) >= n_train + n_valid + n_test, otherwise random.sample
# raises ValueError.
split_pool = random.sample(x, n_train + n_valid + n_test)
train_sample = split_pool[:n_train]
valid_sample = split_pool[n_train:n_train + n_valid]
test_sample = split_pool[n_train + n_valid:]

# Write each split as JSON Lines: one simplified record per line.
for split_path, split_entries in (
    (path_train, train_sample),
    (path_valid, valid_sample),
    (path_test, test_sample),
):
    with open(split_path, 'w', encoding='utf-8') as f:
        for entry in split_entries:
            f.write(json.dumps(simplify_jsonl(entry)) + "\n")

In [None]:
# Destination of the extra 1500-record chunk.
path_chunk = '/Users/artem.semidetnov/Documents/Predictor/train1/datachunk.jsonl'

# Random sample (without replacement) of 1500 raw records.
a_chunk = random.sample(x, 1500)

# One simplified record per line (JSON Lines).
with open(path_chunk, 'w', encoding='utf-8') as f:
    for entry in a_chunk:
        record = simplify_jsonl(entry)
        f.write(json.dumps(record) + "\n")

In [None]:
!pip3 install mlx-lm

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hub repo id shared by the tokenizer and the model loaded below.
# NOTE(review): the "mlx-community/...-4bit" name suggests an MLX-converted,
# quantized checkpoint — confirm that transformers can actually load it.
model_name = 'mlx-community/codegemma-1.1-2b-4bit'
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Causal-LM model that the LoRA and Trainer cells below operate on.
model = AutoModelForCausalLM.from_pretrained(model_name)

In [None]:
from peft import LoraConfig, get_peft_model

# LoRA settings for causal-LM fine-tuning; adapters are attached only to the
# modules named "q_proj" and "v_proj".
lora_config = LoraConfig(
    r=8,                  # rank of LoRA matrices
    lora_alpha=16,        # scaling factor
    target_modules=["q_proj", "v_proj"],  # layers to apply LoRA
    lora_dropout=0.05,    # dropout applied on the LoRA path
    bias="none",          # bias setting passed through to peft
    task_type="CAUSAL_LM"
)

# Rebind `model` to the peft-wrapped version; the original object is no longer
# reachable under this name after this cell runs.
model = get_peft_model(model, lora_config)

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer

# Load the JSONL files written earlier as a dataset with a "train" and a
# "validation" split.
dataset = load_dataset(
    "json",
    data_files={
        "train": "/Users/artem.semidetnov/Documents/Predictor/train1/train.jsonl",
        "validation": "/Users/artem.semidetnov/Documents/Predictor/train1/valid.jsonl",
    },
)

# Tokenizer for the same checkpoint; EOS doubles as the padding token so that
# padding="max_length" in preprocess() has a token to pad with.
tokenizer = AutoTokenizer.from_pretrained("mlx-community/codegemma-1.1-2b-4bit")
tokenizer.pad_token = tokenizer.eos_token
def preprocess(batch):
    """Tokenize a batch of prompt/completion pairs for causal-LM training.

    Each example becomes ``prompt + EOS + completion``, truncated or padded
    to exactly 512 tokens.

    Besides the tokenizer's own outputs, a ``labels`` column is added: a copy
    of ``input_ids`` where padded positions (``attention_mask == 0``) are
    replaced by -100, the index the loss ignores. Without ``labels`` the
    Trainer has no target to compute a loss from. The attention mask is used
    for masking (rather than comparing with the pad id) because the pad token
    is the EOS token, which also appears as the real separator.

    Args:
        batch: Mapping with parallel lists under ``"prompt"`` and
            ``"completion"`` (what ``dataset.map(..., batched=True)`` passes).

    Returns:
        The tokenizer output with an additional ``"labels"`` entry.
    """
    # Combine prompt and completion
    combined_text = [
        p + tokenizer.eos_token + c
        for p, c in zip(batch["prompt"], batch["completion"])
    ]
    # Tokenize
    encoded = tokenizer(combined_text, truncation=True, padding="max_length", max_length=512)
    # Targets for the language-modelling loss; padding is excluded via -100.
    encoded["labels"] = [
        [token_id if keep else -100 for token_id, keep in zip(ids, mask)]
        for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
    ]
    return encoded

# Apply preprocess() to every split; batched=True hands it lists of prompts
# and completions rather than one example at a time.
tokenized_dataset = dataset.map(preprocess, batched=True)

In [None]:
# Sanity check: show which fields the first tokenized training example has.
print(tokenized_dataset["train"][0].keys())

In [None]:
from transformers import Trainer, TrainingArguments

# Training hyper-parameters. With a per-device batch of 1 and 8 accumulation
# steps, gradients from 8 examples are accumulated per optimizer step on each
# device.
# NOTE(review): no evaluation strategy is set here, so the eval_dataset given
# to the Trainer may never be evaluated during training — confirm.
# NOTE(review): bf16=True requires hardware support — confirm for the target
# machine.
training_args = TrainingArguments(
    output_dir="./lora_mlxlm",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    num_train_epochs=3,
    learning_rate=3e-4,
    logging_steps=50,
    save_steps=500,
    save_total_limit=2,
    bf16=True
)

# NOTE(review): the Trainer needs a `labels` column to compute a causal-LM
# loss — TODO confirm the tokenized dataset provides one.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"]
)

trainer.train()

In [None]:
# Alternative training route: LoRA fine-tuning through the mlx_lm command line
# instead of the transformers Trainer.
# NOTE: `model` and `dataset` are rebound to plain strings here so they can be
# interpolated into the shell command via {model} / {dataset}; the same names
# are assigned other values in the transformers cells.

# model = 'mlx-community/codegemma-1.1-2b-4bit'
model = 'mistralai/Mistral-7B-v0.1'
dataset = '/Users/artem.semidetnov/Documents/Predictor/train1/'
# output = './lora-output'

!python3 -m mlx_lm lora --model {model} --data {dataset} --train --fine-tune-type lora --iters 1000 --steps-per-eval 100 --batch-size 4

## Train

In [15]:
model = 'Qwen/Qwen2.5-Coder-0.5B-Instruct'
dataset = '/Users/artem.semidetnov/Documents/Predictor/train1/'

!python3 -m mlx_lm.lora --model {model} --data {dataset} --train --learning-rate 1e-5 --iters 100 --fine-tune-type lora --batch-size 1

Calling `python -m mlx_lm.lora...` directly is deprecated. Use `mlx_lm.lora...` or `python -m mlx_lm lora ...` instead.
Loading pretrained model
Fetching 7 files: 100%|████████████████████████| 7/7 [00:00<00:00, 36291.88it/s]
Loading datasets
Training
Trainable parameters: 0.594% (2.933M/494.033M)
Starting training..., iters: 100
Calculating loss...: 100%|██████████████████████| 25/25 [00:05<00:00,  4.47it/s]
Iter 1: Val loss 2.403, Val took 5.599s
Iter 10: Train loss 2.478, Learning Rate 1.000e-05, It/sec 1.069, Tokens/sec 1459.681, Trained Tokens 13658, Peak mem 8.447 GB
Iter 20: Train loss 2.334, Learning Rate 1.000e-05, It/sec 1.836, Tokens/sec 1798.928, Trained Tokens 23454, Peak mem 8.448 GB
Iter 30: Train loss 1.887, Learning Rate 1.000e-05, It/sec 2.118, Tokens/sec 1809.140, Trained Tokens 31994, Peak mem 8.448 GB
Iter 40: Train loss 1.723, Learning Rate 1.000e-05, It/sec 2.037, Tokens/sec 1798.688, Trained Tokens 40826, Peak mem 8.448 GB
Iter 50: Train loss 1.487,

In [64]:
# Base model and LoRA adapter directory passed to the generate command at the
# end of this cell.
model = 'Qwen/Qwen2.5-Coder-0.5B-Instruct'
adap_path = '/Users/artem.semidetnov/Documents/Predictor/adapters'
# Short smoke-test prompt; only referenced by the commented-out command below.
test_prompt = r'func add(x y : Nat) : Nat'

# !python3 -m mlx_lm.generate --model {model} --max-tokens 500 -adapters-path {adap_path} --prompt {test_prompt}
# prompt = '\\\\func add(x y : Nat) : Nat'
# prompt= r'''['d : D.E', 'D : BottomJoinSemilattice', 'F : Functor D MonoidCat']<<<break>>>['| func-* {x y : E {Dom}} : func (x * y) = func x * func y', '| \\\\infixl 7 * E E : E', \"\\\\func \\\\infix 1 = {A : \\\\Type} (a a' : A) : \\\\Type => a = a'\", '| E : \\\\Set', '| Cod : BaseSet', '| func (E {Dom}) : E {Cod}', \"\\\\func inMap {D : BottomJoinSemilattice} {F1 : Functor D MonoidCat} (d : D.E) : MonoidHom (F.F d) (MonoidLatticeColimit {D} F1) {\\n  | func => \\\\lam (a : E {Dom {\\\\this}}) =>\\n  in~ {\\\\Sigma (j : D.Ob) (F j)} {\\\\lam (s : \\\\Sigma (j : D.Ob) (F j)) (s' : \\\\Sigma (j : D.Ob) (F j)) =>\\n    \\\\Sigma (p : D.Hom s.1 s'.1) (Func {s.1} {s'.1} p s.2 = s'.2)} (d, a)\\n} => \\\\new MonoidHom {\\n  | func-ide => {?hidden}\\n  | func-* => {?hidden}\\n}\", '| Dom : BaseSet']<<<break>>>\\Pi {x y : E {Dom {inMap {D} {F} d}}} ->\n  func {inMap {D} {F} d} (x * y) = func {inMap {D} {F} d} x * func {inMap {D} {F} d} y'''

# Sample prompt in the same layout simplify_jsonl() produces: three parts
# joined by '<<<break>>>'. It is a raw string, so every backslash below is a
# literal character.
# NOTE(review): the value is later placed inside double quotes on a shell
# command line, where the shell may process backslashes and quotes again —
# confirm the model receives the intended text.
prompt = r'''['d : D.E', 'D : BottomJoinSemilattice', 'F : Functor D MonoidCat']<<<break>>>['| func-* {x y : E {Dom}} : func (x * y) = func x * func y', '| \\\\infixl 7 * E E : E', \"\\\\func \\\\infix 1 = {A : \\\\Type} (a a' : A) : \\\\Type => a = a'\", '| E : \\\\Set', '| Cod : BaseSet', '| func (E {Dom}) : E {Cod}', \"\\\\func inMap {D : BottomJoinSemilattice} {F1 : Functor D MonoidCat} (d : D.E) : MonoidHom (F.F d) (MonoidLatticeColimit {D} F1) {\\n  | func => \\\\lam (a : E {Dom {\\\\this}}) =>\\n  in~ {\\\\Sigma (j : D.Ob) (F j)} {\\\\lam (s : \\\\Sigma (j : D.Ob) (F j)) (s' : \\\\Sigma (j : D.Ob) (F j)) =>\\n    \\\\Sigma (p : D.Hom s.1 s'.1) (Func {s.1} {s'.1} p s.2 = s'.2)} (d, a)\\n} => \\\\new MonoidHom {\\n  | func-ide => {?hidden}\\n  | func-* => {?hidden}\\n}\", '| Dom : BaseSet']<<<break>>>\\Pi {x y : E {Dom {inMap {D} {F} d}}} ->\n  func {inMap {D} {F} d} (x * y) = func {inMap {D} {F} d} x * func {inMap {D} {F} d} y'''

# prompt = prompt.replace('{', '(')
# prompt = prompt.replace('}', ')')
# prompt = prompt.replace('\\\\', '\\')


!python3 -m mlx_lm.generate --model {model} --adapter-path {adap_path} --prompt "{prompt}"

Calling `python -m mlx_lm.generate...` directly is deprecated. Use `mlx_lm.generate...` or `python -m mlx_lm generate ...` instead.
Fetching 7 files: 100%|████████████████████████| 7/7 [00:00<00:00, 19341.32it/s]
\lam (d : D.E) => func-* {inMap {D} {F} d}
Prompt: 416 tokens, 2502.451 tokens-per-sec
Generation: 23 tokens, 110.068 tokens-per-sec
Peak memory: 1.805 GB
