In [1]:
import torch
import pickle
import os
from datetime import datetime
from tqdm import tqdm
from model import get_model
from data_utils import read_abc, collate_function
from dataset import ABCDataset
import youtokentome as yttm
from transformers import Trainer, TrainingArguments, TrainerCallback, EncoderDecoderModel, EncoderDecoderConfig
from pathlib import Path

In [2]:
def get_training_files(dir, pattern="*.abc"):
    """Return the files in ``dir`` matching ``pattern``, sorted by path.

    ``Path.glob`` yields entries in no guaranteed order (it follows the
    filesystem's directory listing). Sorting makes the result deterministic, so
    a rebuilt ``train_data`` always has the same order. That in turn fixes which
    examples ``train_data[:5]`` hands to the testing callback.

    Args:
        dir: Directory to search (``str`` or ``Path``). The search is not recursive.
        pattern: Glob pattern relative to ``dir``; defaults to ABC notation files.

    Returns:
        ``list[Path]`` of matching paths in sorted order.
    """
    return sorted(Path(dir).glob(pattern))

def save_train_data(data, base_filename="train_data", tokenizer=None):
    """Persist tokenized training examples to disk.

    Writes two files:
      * ``{base_filename}.pkl``: the ``data`` object, pickled, for reloading
        with ``load_train_data``.
      * ``{base_filename}.txt``: a human-readable dump with each example's
        keys and notes decoded back to text.

    Args:
        data: Sequence of ``(keys_tokens, notes_tokens)`` pairs, as built by the
            data-prep cell. It is iterated twice (pickle, then text dump), so it
            should be a list rather than a one-shot iterator.
        base_filename: Path prefix (without extension) for both output files.
        tokenizer: Object with a ``decode(ids)`` method returning a list of
            strings (e.g. ``yttm.BPE``). Defaults to the notebook-global
            ``tokenizer``, so existing ``save_train_data(train_data)`` calls
            keep working.

    Raises:
        NameError: if no tokenizer is passed and no global ``tokenizer`` exists.
    """
    if tokenizer is None:
        # Backward compatibility: the original implementation read the
        # notebook-global `tokenizer` implicitly.
        try:
            tokenizer = globals()["tokenizer"]
        except KeyError:
            raise NameError(
                "save_train_data needs a tokenizer: pass tokenizer=... "
                "or define a global `tokenizer` first"
            ) from None

    # Machine-readable cache, reloaded by load_train_data().
    with open(f"{base_filename}.pkl", "wb") as f:
        pickle.dump(data, f)

    # Human-readable dump for eyeballing what the model is trained on.
    with open(f"{base_filename}.txt", "w", encoding="utf-8") as f:
        for keys_tokens, notes_tokens in data:
            keys_text = tokenizer.decode(keys_tokens)
            notes_text = tokenizer.decode(notes_tokens)
            # decode() is used as returning a list of strings; only its first
            # entry is shown for the keys. Guard against an empty result
            # instead of raising IndexError midway through the dump.
            first_key = keys_text[0] if keys_text else ""
            f.write("KEYS: " + first_key + "\n")
            f.write("NOTES: " + "\n".join(notes_text) + "\n")
            f.write("-" * 80 + "\n")

def load_train_data(filename="train_data.pkl"):
    """Load the pickled training-data cache written by ``save_train_data``.

    Args:
        filename: Path of the pickle cache.

    Returns:
        The unpickled object, or ``None`` when there is no usable cache. In that
        case the caller rebuilds the data from scratch.

    Security: ``pickle.load`` can execute arbitrary code. Only point this at
    cache files this notebook produced itself.
    """
    import warnings

    try:
        with open(filename, "rb") as f:
            return pickle.load(f)
    except FileNotFoundError:
        # No cache yet: the normal first-run case, so no warning.
        return None
    except Exception as exc:  # cache is best-effort: rebuild on any read failure
        # Typical causes: a truncated pickle from an interrupted save, or a
        # pickle referencing classes that no longer import. Warn so the
        # (slow) rebuild is not silent. KeyboardInterrupt/SystemExit are not
        # Exception subclasses, so they propagate instead of being swallowed.
        warnings.warn(f"Ignoring unreadable training-data cache {filename!r}: {exc!r}")
        return None

def test_model(model, tokenizer, keys_tokens, notes_tokens):
    notes_tokens_in = notes_tokens[:8]
    notes_tokens_out = notes_tokens[8:]
    notes_tokens_in = [item for sublist in notes_tokens_in for item in sublist]
    notes_tokens_out = [item for sublist in notes_tokens_out for item in sublist]
    context_tokens = [2] + keys_tokens + notes_tokens_in + [3]
    label = [2] + notes_tokens_out + [3]
    
    # print(f"Context Tokens Fed to Generate: {context_tokens}")

    context_tokens = torch.tensor(context_tokens, dtype=torch.long).unsqueeze(0)

    # print(f"context (reshaped): {context_tokens}")
    # print(f"label: {label}")

    if torch.cuda.is_available():
        context_tokens = context_tokens.cuda()

    gen_tokens = model.generate(input_ids=context_tokens, 
                                max_length=500, 
                                min_length=64,
                                early_stopping=False,
                                # no_repeat_ngram_size=4,
                                # length_penalty=1.2,
                                repetition_penalty=1.1,
                                # Beam search
                                # do_sample=False,
                                # num_beams=15,
                                # Sampling
                                do_sample = True,
                                temperature = 0.7,
                                top_k = 50,
                                )
                                
    gen_tokens = gen_tokens[0].tolist()
    
    # print(f"Context Token IDs: {context_tokens}")
    # print(f"Label Token IDs: {label}")
    # print(f"Generated Token IDs: {gen_tokens}")

    pred = tokenizer.decode(gen_tokens, ignore_ids=[0,1,2,3])[0]
    label = tokenizer.decode(label, ignore_ids=[0,1,2,3])[0]
    pred = pred.replace(" ", "").replace("|", "|\n")
    label = label.replace(" ", "").replace("|", "|\n")

    return pred, label

class TestingCallback(TrainerCallback):
    """Periodically sample the model on a few fixed examples and dump the results.

    Every ``every_n_steps`` steps (checked in ``on_step_end``), runs
    ``test_model`` on each ``(keys_tokens, notes_tokens)`` pair in ``test_data``.
    It writes each example's input, target and prediction to
    ``<output_dir>/test_output_step_<step>_<timestamp>.txt``.
    """

    def __init__(self, model, tokenizer, test_data, every_n_steps=100, output_dir="test_outputs"):
        """
        Args:
            model: Model passed to ``test_model`` for generation.
            tokenizer: Tokenizer used by ``test_model`` and to decode the inputs.
            test_data: Sequence of ``(keys_tokens, notes_tokens)`` pairs to sample on.
            every_n_steps: Positive interval, in ``state.global_step`` units.
            output_dir: Directory for the report files (created if missing).

        Raises:
            ValueError: if ``every_n_steps`` is not positive. Zero would otherwise
                raise ZeroDivisionError on the first training step.
        """
        if every_n_steps <= 0:
            raise ValueError(f"every_n_steps must be positive, got {every_n_steps}")
        self.model = model
        self.tokenizer = tokenizer
        self.test_data = test_data
        self.every_n_steps = every_n_steps
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

    def on_step_end(self, args, state, control, **kwargs):
        """Trainer hook: write a sample report when ``global_step`` hits the interval."""
        if state.global_step % self.every_n_steps != 0:
            return
        # In multi-process training every rank runs callbacks. Only the main
        # process should spend time generating and writing the report.
        if not state.is_world_process_zero:
            return

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = os.path.join(
            self.output_dir, f"test_output_step_{state.global_step}_{timestamp}.txt"
        )

        with open(output_file, "w", encoding="utf-8") as f:
            f.write(f"Test Results at Step {state.global_step}\n")
            f.write("=" * 80 + "\n\n")

            for i, (keys_tokens, notes_tokens) in enumerate(self.test_data):
                pred, target = test_model(self.model, self.tokenizer, keys_tokens, notes_tokens)

                f.write(f"Example {i+1}\n")
                f.write("-" * 40 + "\n")
                f.write("Input:\n")
                f.write('\n'.join(self.tokenizer.decode(keys_tokens)) + "\n")
                # Show the same 8 leading bars that test_model uses as its prompt by default.
                f.write('\n'.join(self.tokenizer.decode(notes_tokens)[:8]) + "\n\n")
                f.write("Target:\n")
                f.write(target + "\n")
                f.write("Predicted:\n")
                f.write(pred + "\n\n")
                f.write("=" * 80 + "\n\n")

In [3]:
# --- Configuration ---------------------------------------------------------
# Directory of cleaned .abc tunes; only read when the training-data cache is rebuilt.
train_dir = "./cleaned_data"
# Exclusive bounds on `sequence_len` (computed below) for keeping a tune; only
# applied when the cache is rebuilt.
min_sequence_length = 64
max_sequence_length = 500
# Used to load model weights below (a falsy value builds a fresh model via
# get_model) and passed to trainer.train(resume_from_checkpoint=...) in the next cell.
checkpoint = "./ABCModel/checkpoint-21255"

# 42510 examples / batch 2 x 2 epochs -> 42510 steps, matching the progress bar below.
# NOTE(review): learning_rate=1e-6 is very low; presumably deliberate for this
# resumed run. Confirm.
training_args = TrainingArguments(
    output_dir="./ABCModel",
    overwrite_output_dir=True,
    num_train_epochs=2,
    per_device_train_batch_size=2,
    save_strategy="steps",
    save_steps=5000,
    # save_strategy="best",
    # metric_for_best_model="loss",
    # evaluation_strategy="epoch",
    save_total_limit=10,
    gradient_accumulation_steps=1,
    dataloader_num_workers=0,
    # remove_unused_columns=False,
    learning_rate=1e-6,
    bf16=False,
    save_safetensors=False,
    # optim="adamw_torch_fused",
    dataloader_pin_memory=True,
    logging_dir="./logs",
    logging_strategy="steps",
    logging_steps=100,
)

# --- Tokenizer, model, data -------------------------------------------------
# Note: save_train_data reads this global `tokenizer` implicitly.
tokenizer = yttm.BPE("abc.yttm")
# NOTE: the output below is transformers warning that from v4.50
# EncoderDecoderModel may lose `generate`. test_model/TestingCallback rely on
# it, so verify before upgrading transformers.
model = EncoderDecoderModel.from_pretrained(checkpoint) if checkpoint else get_model(vocab_size=tokenizer.vocab_size())
# Reuse the pickled, already-tokenized dataset (train_data.pkl) if present.
# NOTE(review): the cache is not keyed on train_dir, the min/max lengths or the
# tokenizer. Delete train_data.pkl after changing any of them, or the stale
# cache is silently reused.
cached_data = load_train_data()

if cached_data is not None:
    print("Using cached training data")
    train_data = cached_data
else:
    print("Creating new training data...")
    train_paths = get_training_files(train_dir)
    train_data = []
    for p in tqdm(train_paths):
        (keys, notes) = read_abc(p)
        # Skip files for which read_abc returned no keys.
        if keys is None:
            continue

        # NOTE(review): yttm's BPE.encode is documented to take a *list* of
        # strings, but here it receives a single str. Confirm what it returns
        # (ids for the whole string vs. one list per character). sequence_len
        # and test_model's bar flattening both depend on it.
        keys_tokens = tokenizer.encode(keys)
        # Encode bar by bar (re-appending the " | " separator) so examples can
        # later be split at bar boundaries; test_model uses the first 8 bars as context.
        bars = notes.split(" | ")
        notes_tokens = [tokenizer.encode(i + " | ") for i in bars]
        sequence_len = sum(len(i) for i in notes_tokens)
        
        if min_sequence_length < sequence_len < max_sequence_length:
            train_data.append((keys_tokens, notes_tokens))

    # Write train_data.pkl (+ a readable train_data.txt) for future runs.
    save_train_data(train_data) 

print(f"Total training examples: {len(train_data)}")

train_dataset = ABCDataset(train_data)
# Samples the first 5 examples every 10 steps.
# NOTE(review): these are training examples, not held-out data. Generating 5
# samples of up to 500 tokens every 10 steps is also expensive compared with
# logging_steps=100; consider a larger interval.
testing_callback = TestingCallback(model, tokenizer, train_data[:5], every_n_steps=10)

# No eval_dataset / compute_metrics configured (see the commented-out arguments).
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_function,
    train_dataset=train_dataset,
    # eval_dataset=test_dataset,
    # compute_metrics=compute_eval_metrics,
    callbacks=[testing_callback]
)

EncoderDecoderModel has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


Using cached training data
Total training examples: 42510


In [None]:
# Continue training from `checkpoint`, set in the config cell.
# The model weights were already loaded from the same path above.
trainer.train(resume_from_checkpoint=checkpoint)

  0%|          | 0/42510 [00:00<?, ?it/s]

