In [None]:
# Optional environment setup: uncomment the lines below to install the notebook's dependencies into the kernel's environment.
# %pip install -q sentencepiece
# %pip install -q protobuf
# %pip install -q -U bitsandbytes
# %pip install -q -U git+https://github.com/huggingface/transformers.git
# %pip install -q -U git+https://github.com/huggingface/peft.git
# %pip install -q -U git+https://github.com/huggingface/accelerate.git
# %pip install -q scipy
# %pip install -q -U trl

In [None]:
# Device placement handed to `AutoModelForCausalLM.from_pretrained` when the base model is loaded.
# In the Hugging Face device-map convention the empty-string key stands for the whole model,
# so this places every module on device index 0.
device_map = {"": 0}

In [None]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, PeftModel, PeftConfig, prepare_model_for_kbit_training
from trl import SFTTrainer
from datasets import Dataset, concatenate_datasets
import pandas as pd
import re
import ast
import os
import random
import json
import numpy as np
import gc

In [None]:
class NumpyArrayEncoder(json.JSONEncoder):
    """JSON encoder that knows how to serialize NumPy values.

    * ``np.ndarray`` is emitted as a (nested) list via ``tolist()``.
    * NumPy scalars (``np.int64``, ``np.bool_``, ...) are emitted as the
      matching built-in Python value via ``item()``.

    Anything else is delegated to ``json.JSONEncoder.default``, which raises
    ``TypeError`` for unsupported objects.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Scalars pulled out of an array (e.g. ``arr[0, 0]``) are not Python
        # ints/bools and would otherwise be rejected by the json module.
        if isinstance(obj, np.generic):
            return obj.item()
        return super().default(obj)


def format_instruction_and_output(puzzle_string):
    """Turn a JSON-encoded puzzle into an ``(instruction, output)`` string pair.

    ``puzzle_string`` must decode to a dict with ``"train"`` and ``"test"``
    lists whose items each carry ``"input"`` and ``"output"`` grids.

    The instruction lists every train pair as
    ``Train_<n>_Input=<grid> Train_<n>_Output=<grid>`` (space separated),
    followed by `` Test_<n>_Input=<grid>`` for every test pair. The output is
    the concatenation of ``Test_<n>_Output=<grid>`` for every test pair.
    Grids are rendered with ``str()`` and all spaces removed.
    """

    def compact(grid):
        # "[[1, 2], [3, 4]]" -> "[[1,2],[3,4]]"
        return str(grid).replace(" ", "")

    puzzle = json.loads(puzzle_string)
    demonstrations = puzzle["train"]
    queries = puzzle["test"]

    train_segments = [
        "Train_" + str(number) + "_Input=" + compact(pair["input"])
        + " Train_" + str(number) + "_Output=" + compact(pair["output"])
        for number, pair in enumerate(demonstrations, start=1)
    ]

    test_inputs = []
    test_outputs = []
    for number, pair in enumerate(queries, start=1):
        test_inputs.append(" Test_" + str(number) + "_Input=" + compact(pair["input"]))
        test_outputs.append("Test_" + str(number) + "_Output=" + compact(pair["output"]))

    # Train segments are separated by single spaces; each test input brings
    # its own leading space, and test outputs are concatenated directly.
    instruction = " ".join(train_segments) + "".join(test_inputs)
    output = "".join(test_outputs)

    return instruction, output

In [None]:
def generate_prompt(data_point):
    """Render one dataset row as a ChatML-style training example.

    ``data_point`` must provide ``"instruction"`` (the user turn) and
    ``"output"`` (the assistant turn). Returns the full text with both turns
    wrapped in ``<|im_start|>`` / ``<|im_end|>`` markers.
    """
    user_turn = "<|im_start|>user\n" + str(data_point["instruction"]) + "<|im_end|>"
    assistant_turn = "<|im_start|>assistant\n" + str(data_point["output"]) + "<|im_end|>"
    return user_turn + "\n" + assistant_turn

In [None]:
# LoRA adapter configuration, handed to `SFTTrainer` through its `peft_config` argument.
# No `target_modules` are listed, so PEFT's defaults for the model apply — TODO confirm which layers that covers.
peft_config = LoraConfig(
    lora_alpha=16,  # LoRA scaling factor
    lora_dropout=0.1,  # dropout applied on the LoRA layers
    r=64,  # rank of the low-rank update matrices
    bias="none",  # no bias terms are trained
    task_type="CAUSAL_LM",
)

# bitsandbytes quantization settings used when the base model is loaded (`quantization_config=bnb_config`).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # load the weights in 4-bit precision
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # NOTE(review): the training arguments use fp16=True — confirm the bf16/fp16 mix is intended
    bnb_4bit_use_double_quant=False,
)

In [None]:
# Checkpoint that every per-puzzle fine-tuning run starts from; it is passed to both
# `AutoModelForCausalLM.from_pretrained` and `AutoTokenizer.from_pretrained`.
# base_model_name = "microsoft/Orca-2-13b"
base_model_name = "merged_models/Orca2_merged_finetune_1"  # presumably a local directory holding an earlier merged fine-tune — confirm

# Name of the adapter run; used to build the trainer's output directory ("results/" + name).
new_adapter_name = "TEMP"

In [None]:
# Test-time fine-tuning evaluation.
#
# For every puzzle file under `root_dir`:
#   1. build an augmented training set from the puzzle's own demonstration
#      pairs (leave-one-out, rotations, flips, colour shifts),
#   2. fine-tune a fresh LoRA adapter on top of the base model,
#   3. predict the real test output and compare it with the ground truth.
#
# `count`, `correct` and `too_long_count` are left as module-level results.

root_dir = "data/step_2/evaluation_max_3_trains_1_test/"
count = 0  # test tasks that were actually evaluated
correct = 0  # evaluated test tasks whose prediction matched exactly
too_long_count = 0  # puzzle files abandoned because a prompt exceeded the context

per_device_train_batch_size = 1
gradient_accumulation_steps = 4
max_seq_length = 4096


def augment_grid(grid, rot_num, flip_num, color_num, num_colors=9):
    """Return an augmented copy of ``grid`` as a nested list.

    Transformations are applied in this order:

    1. colour shift: every non-zero colour ``c`` in ``1..num_colors`` becomes
       ``((c - 1 + color_num) % num_colors) + 1``; ``0`` is left untouched,
    2. ``rot_num`` counter-clockwise quarter turns (``np.rot90``),
    3. ``flip_num == 1`` flips top/bottom, ``flip_num > 1`` flips left/right.

    The input ``grid`` is never modified.
    """
    cells = np.array(grid)
    if color_num > 0:
        cells = np.where(cells != 0, (cells - 1 + color_num) % num_colors + 1, cells)
    if rot_num > 0:
        cells = np.rot90(cells, k=rot_num)
    if flip_num == 1:
        cells = np.flipud(cells)
    elif flip_num > 1:
        cells = np.fliplr(cells)
    return cells.tolist()


def augment_pair(pair, rot_num, flip_num, color_num):
    """Apply the same augmentation to the input and output grid of ``pair``."""
    return {
        "input": augment_grid(pair["input"], rot_num, flip_num, color_num),
        "output": augment_grid(pair["output"], rot_num, flip_num, color_num),
    }


def build_augmented_tasks(train_tasks, rot_range=4, flip_range=2, color_range=9):
    """Build leave-one-out training tasks from a puzzle's demonstration pairs.

    For every combination of rotation, flip and colour shift, each
    demonstration pair in turn becomes the single "test" pair while the
    remaining pairs (in cyclic order, starting after the held-out one) become
    the "train" pairs. With a single demonstration pair that pair is used as
    both train and test, since nothing else is available.

    Returns a list of ``{"train": [...], "test": [...]}`` dicts of plain lists.
    """
    num_train_tasks = len(train_tasks)
    tasks = []
    for rot_num in range(rot_range):
        for flip_num in range(flip_range):
            for color_num in range(color_range):
                for held_out in range(num_train_tasks):
                    context_indices = [
                        (held_out + offset) % num_train_tasks
                        for offset in range(1, max(num_train_tasks, 2))
                    ]
                    tasks.append(
                        {
                            "train": [
                                augment_pair(train_tasks[index], rot_num, flip_num, color_num)
                                for index in context_indices
                            ],
                            "test": [
                                augment_pair(train_tasks[held_out], rot_num, flip_num, color_num)
                            ],
                        }
                    )
    return tasks


# The tokenizer is identical for every puzzle, so it is loaded once, on first
# use (it is needed for the length check *before* any model is loaded).
tokenizer = None

for root, _dirs, files in os.walk(root_dir):
    for name in files:
        torch.cuda.empty_cache()

        if ".json" not in name:
            continue

        file_path = os.path.join(root, name)
        print("***********************************************************************")
        print(file_path)

        with open(file_path, "r") as f:
            puzzle_data = json.load(f)

        # Keep at most three demonstration pairs to bound the prompt length.
        train_tasks = puzzle_data["train"][:3]
        test_tasks = puzzle_data["test"]

        if not train_tasks:
            print("no train pairs, skipping")
            continue

        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(base_model_name, add_eos_token=True)
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training

        # The training set depends only on the demonstration pairs, so it is
        # built once per file. It must be built from the *augmented* tasks:
        # serializing the raw file here would leak the real test answer into
        # the fine-tuning data.
        puzzles = []
        for augmented_task in build_augmented_tasks(train_tasks):
            augmented_json = json.dumps(augmented_task, cls=NumpyArrayEncoder)
            aug_instruction, aug_output = format_instruction_and_output(augmented_json)
            puzzles.append({"instruction": aug_instruction, "output": aug_output})

        dataset = Dataset.from_pandas(pd.DataFrame(puzzles))
        text_column = [generate_prompt(data_point) for data_point in dataset]
        dataset = dataset.add_column("prompt", text_column)

        for test_pair in test_tasks:
            test_task = [test_pair]

            eval_json = json.dumps(
                {"train": train_tasks, "test": test_task}, cls=NumpyArrayEncoder
            )
            instruction, expected_output = format_instruction_and_output(eval_json)

            # Skip puzzles whose prompt plus answer cannot fit in the context.
            token_length = len(tokenizer.tokenize(instruction + " " + expected_output))
            if token_length > max_seq_length:
                print("token_length too long, skipping:", token_length)
                too_long_count += 1
                break

            ############################# TRAIN #############################
            # Start every task from a freshly loaded base model.
            model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                quantization_config=bnb_config,
                device_map=device_map,
            )
            model.config.use_cache = False
            model.gradient_checkpointing_enable()
            model = prepare_model_for_kbit_training(model)

            output_dir = "results/" + new_adapter_name

            steps_per_epoch = len(dataset) // (
                per_device_train_batch_size * gradient_accumulation_steps
            )
            print("Steps:", steps_per_epoch)

            training_arguments = TrainingArguments(
                num_train_epochs=4,
                output_dir=output_dir,
                per_device_train_batch_size=per_device_train_batch_size,
                gradient_accumulation_steps=gradient_accumulation_steps,
                optim="paged_adamw_8bit",
                lr_scheduler_type="cosine",
                save_strategy="steps",
                evaluation_strategy="no",
                save_steps=1000,
                logging_steps=1,
                learning_rate=1e-4,
                fp16=True,
                # 0.03 is a fraction of the total steps, i.e. a warmup *ratio*;
                # `warmup_steps` expects an absolute step count.
                warmup_ratio=0.03,
                group_by_length=True,
                gradient_checkpointing=True,
            )

            trainer = SFTTrainer(
                model=model,
                train_dataset=dataset,
                peft_config=peft_config,
                dataset_text_field="prompt",
                max_seq_length=max_seq_length,
                tokenizer=tokenizer,
                args=training_arguments,
            )

            trainer.train()
            fine_tuned_model = trainer.model

            ############################# PREDICT #############################
            query = f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant"

            encodeds = tokenizer(query, return_tensors="pt", add_special_tokens=True)
            input_ids = encodeds["input_ids"]
            attention_mask = encodeds["attention_mask"]
            # The tokenizer was created with add_eos_token=True (wanted for
            # training); an EOS at the end of the *prompt* would tell the model
            # the sequence is already finished, so drop it for generation.
            if input_ids[0, -1].item() == tokenizer.eos_token_id:
                input_ids = input_ids[:, :-1]
                attention_mask = attention_mask[:, :-1]
            input_ids = input_ids.to("cuda")
            attention_mask = attention_mask.to("cuda")
            prompt_length = input_ids.shape[1]

            print("instruction")
            print(instruction)

            print("output")
            print(expected_output)
            print("")

            fine_tuned_model.eval()
            with torch.no_grad():
                generated_ids = fine_tuned_model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    # Without an explicit budget the default `max_length`
                    # applies, which is far shorter than these prompts.
                    max_new_tokens=max(1, max_seq_length - prompt_length),
                    pad_token_id=tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens rather than slicing the
            # decoded string by len(query), which is off once special tokens
            # are rendered.
            result = tokenizer.decode(generated_ids[0, prompt_length:])

            target_array = test_pair["output"]
            match = re.search(r"Test_1_Output\s*=\s*(\[\[.*?\]\])", result)
            if match:
                try:
                    predicted_array = ast.literal_eval(match.group(1))
                    are_equal = np.array_equal(predicted_array, target_array)
                except (ValueError, SyntaxError):
                    # Malformed or ragged prediction: count as wrong, keep going.
                    predicted_array = match.group(1)
                    are_equal = False

                print("\nPrediction:")
                print(predicted_array)

                print("\nGround Truth:")
                print(target_array)

                if are_equal:
                    print("ARRAYS ARE EQUAL!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                    correct += 1
                else:
                    print("ARRAYS NOT EQUAL.")
            else:
                print("\nNo regx match. Full Result:")
                print(result)
                print("\nGround Truth:")
                print(expected_output)

            # The trainer holds a reference to the model; all three names must
            # go before the next task loads another copy onto the GPU.
            del fine_tuned_model, trainer, model
            gc.collect()
            torch.cuda.empty_cache()
            count += 1

print("Evaluated:", count, "Correct:", correct, "Skipped (too long):", too_long_count)