In [1]:
# Additional modules to docker environment
# ------------------------------------------------

!pip install datasets torch_optimizer lion_pytorch clang_repl_kernel --break-system-packages
!pip install --upgrade clang-repl-kernel  --break-system-packages

[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/torchtext-0.18.0a0+9bed85d-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/torchaudio-2.6.0a0+d883142-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/looseversion-1.3.0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /usr

In [2]:
# Prompt to generate sample data for reasoning training
# ------------------------------------------------

In [3]:
# import and setup
# ------------------------------------------------

import os
import copy
import json
import torch
import gc
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
)
from transformers.optimization import Adafactor
from datasets import Dataset

# For hyperparameter optimization
import optuna
import pickle
from Config import SimpleConfig
from ClangReplInterface import ClangReplInterface

# Checkpoint locations used by this notebook. They are only defined in this
# cell; nothing here reads or writes them.
# NOTE(review): presumably "sample" holds the reference model and "reasoning"
# holds the model trained by this notebook — confirm against the training cells.
ref_checkpoint_path = "./saved_models/sample/checkpoint.pt"
last_checkpoint_path = "./saved_models/reasoning/checkpoint.pt"
# Path prefix for per-epoch checkpoint directories (an epoch suffix is
# presumably appended by the code that saves them — TODO confirm).
checkpoint_dir_pre = "./saved_models/reasoning/epoch_"

# JSON file with test targets. Note that this name is reassigned to a
# different path in the "Load reasoning dataset" cell before it is used there.
test_target_object_file = "./convert/test_target_and_object/converted_data.json"

# Project-local configuration object; later cells read config.dataset_file.
config = SimpleConfig()




In [4]:
# Utility functions
# ------------------------------------------------

import re

def remove_comments(code: str):
    """Return ``code`` with C/C++ style comments deleted.

    Two forms are removed:
      * ``// ...`` up to (but not including) the end of the line, and
      * ``/* ... */`` blocks, which may span several lines.

    The match is purely textual: a ``//`` or ``/*`` that appears inside a
    string literal is treated as a comment start as well.
    """
    line_comment = r'//.*?$'
    block_comment = r'/\*.*?\*/'
    return re.sub(
        f'{line_comment}|{block_comment}',
        '',
        code,
        # DOTALL lets a block comment cross newlines; MULTILINE makes `$`
        # match at every end of line so a line comment stops there.
        flags=re.DOTALL | re.MULTILINE,
    )


def find_all_tag_indexes(text, tag):
    """Return a list of starting indexes where the tag occurs in the text.

    Occurrences are found left to right and do not overlap: after a match the
    search resumes immediately behind the matched tag.

    Args:
        text: String to search.
        tag: Substring to look for.

    Returns:
        List of integer offsets into ``text`` in ascending order. The list is
        empty when the tag does not occur or when ``tag`` is empty.
    """
    # An empty tag matches at every position without ever advancing the
    # search cursor, which would loop forever; treat it as "no occurrences".
    if not tag:
        return []

    indexes = []
    idx = text.find(tag)
    while idx != -1:
        indexes.append(idx)
        idx = text.find(tag, idx + len(tag))
    return indexes


def get_tag_start_end(idx, starts, ends, tag, full_text):
    """Return the stripped text enclosed by the ``idx``-th tag pair.

    ``starts[idx]`` is the offset of the opening tag and ``ends[idx]`` the
    offset of the closing tag inside ``full_text``. The opening tag itself
    (``len(tag)`` characters) is skipped, and surrounding whitespace is
    removed from the result.
    """
    content_span = slice(starts[idx] + len(tag), ends[idx])
    return full_text[content_span].strip()
    

def reward_correct(full_text):
    """Score a generated answer by running its Clang-repl test on the test target.

    The text must contain well-formed ``<Test Target>...</Test Target>`` and
    ``<Clang-repl Test>...</Clang-repl Test>`` sections. The last test target
    is flattened to a single ``>>> `` prompt line (comments removed, newlines
    dropped), the first Clang-repl test is appended, and the combination is
    passed to ``ClangReplInterface().run_verify``.

    Args:
        full_text: Prompt plus generated completion.

    Returns:
        Tuple ``(reward, message)``. ``reward`` is 2.0 for result ``'ok'``,
        1.0 for ``'fail'`` and 0.0 for ``'error'`` or for malformed tags.
        ``message`` is the ``run_verify`` response, or a description of the
        tag problem when the text is malformed.

    Raises:
        AssertionError: If ``run_verify`` reports an unknown result.
    """
    # handle only first answer
    reward = 0.0
    test_target_open = find_all_tag_indexes(full_text, "<Test Target>")
    test_target_close = find_all_tag_indexes(full_text, "</Test Target>")
    clang_repl_open = find_all_tag_indexes(full_text, "<Clang-repl Test>")
    clang_repl_close = find_all_tag_indexes(full_text, "</Clang-repl Test>")
    if len(test_target_open) == 0 or len(test_target_close) == 0 or len(clang_repl_open) == 0 or len(clang_repl_close) == 0:
        return reward, '<Test Target> or <Clang-repl Test> not found'
    if len(test_target_open) != len(test_target_close) or len(clang_repl_open) != len(clang_repl_close):
        return reward, '<Test Target> or <Clang-repl Test> pair not match'
    if not all(x < y for x, y in zip(test_target_open, test_target_close)):
        return reward, '<Test Target> not closed properly'
    if not all(x < y for x, y in zip(clang_repl_open, clang_repl_close)):
        return reward, '<Clang-repl Test> not closed properly'

    # The opening-tag length must come from the tag whose offsets were
    # searched ("<Test Target>"), not from an unrelated tag name.
    target_text = get_tag_start_end(-1, test_target_open, test_target_close, "<Test Target>", full_text)
    target_text = remove_comments(target_text)
    # Collapse the target onto one line behind a ">>> " prompt marker.
    target_text = ">>> " + target_text.replace('\n', '')

    # Only the first Clang-repl test is evaluated; at least one exists here
    # because the checks above reject texts without such a section.
    clang_repl_test = get_tag_start_end(0, clang_repl_open, clang_repl_close, "<Clang-repl Test>", full_text)
    clang_repl = ClangReplInterface()
    test_case_with_target = target_text + '\n' + clang_repl_test
    result, response = clang_repl.run_verify(test_case_with_target)

    if result == 'ok':
        return 2.0, response
    if result == 'fail':
        return 1.0, response
    if result == 'error':
        return 0.0, response
    # Raised explicitly so the check survives `python -O`.
    raise AssertionError(f"unexpected run_verify result: {result!r}")

In [5]:
# Load sample dataset
# ------------------------------------------------

def load_sample_dataset(pk_file):
    """Load pickled text samples and wrap each one as a dataset record.

    Args:
        pk_file: Path of a pickle file containing an iterable of strings.

    Returns:
        List of ``{"content": <sample> + "<|endoftext|>"}`` dicts, one per
        sample, in the order stored in the pickle.
    """
    # Open the file the caller asked for rather than a path taken from the
    # global config, so the parameter is actually honoured.
    # NOTE(security): pickle.load can execute arbitrary code; only use this
    # with files from a trusted source.
    with open(pk_file, "rb") as f:
        global_samples = pickle.load(f)
    return [{"content": sample + "<|endoftext|>"} for sample in global_samples]

# Build the sample records from the pickle file named by config.dataset_file.
train_data_sample = load_sample_dataset(config.dataset_file)

In [6]:
# Load reasoning dataset
# ------------------------------------------------

def get_test_target_content(full_text):
    """Return the stripped body of the last ``<Test Target>`` section.

    The last opening offset is paired with the last closing offset. An
    ``IndexError`` is raised when either tag is missing from ``full_text``.
    """
    opening_tag = "<Test Target>"
    opening_offsets = find_all_tag_indexes(full_text, opening_tag)
    closing_offsets = find_all_tag_indexes(full_text, "</Test Target>")
    return get_tag_start_end(-1, opening_offsets, closing_offsets, opening_tag, full_text)

def load_reasoning_dataset(test_target_object_file, train_per_category=14):
    """Build train/validation prompt lists from a JSON file of test targets.

    The JSON file must hold a list of objects with ``category`` and
    ``content`` keys. The ``<Test Target>`` body of every ``content`` is
    embedded into an instruction prompt. Within each category the first
    ``train_per_category`` items go to the training list and the remaining
    items go to the validation list.

    Args:
        test_target_object_file: Path of the UTF-8 encoded JSON file.
        train_per_category: Number of leading items per category that are
            used for training (default 14).

    Returns:
        Tuple ``(train, val)`` of lists of ``{"content": prompt}`` dicts.
    """
    with open(test_target_object_file, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # Group contents by category. A dict keeps first-seen order, so the
    # resulting lists are identical on every run (iterating a set of strings
    # would give an order that depends on hash randomization).
    data_dic = {}
    for item in data:
        data_dic.setdefault(item['category'], []).append(item['content'])

    train = []
    val = []
    for contents in data_dic.values():
        for idx, item in enumerate(contents):
            # NOTE(review): the stray "n" before <Test Target> and the
            # spelling "Wrtie" are kept byte-for-byte because changing the
            # prompt changes what the model sees — confirm whether other data
            # uses the same template before correcting them.
            content = f"### Instruction\n\nn<Test Target>\n{get_test_target_content(item)}\n</Test Target>\nWrtie a Clang-repl Test\n### Response\n"
            record = {"content": content}
            if idx >= train_per_category:
                val.append(record)
            else:
                train.append(record)

    return train, val

# Replaces the test_target_object_file path assigned in the setup cell.
test_target_object_file = "manual_data_set/ReasoningTestTarget.json"
# load_reasoning_dataset returns (train, val); keep both splits.
reasoning_dataset, val_reasoning_dataset = load_reasoning_dataset(test_target_object_file)

# Bare expression: the notebook displays the whole training prompt list.
reasoning_dataset

[{'content': '### Instruction\n\nn<Test Target>\nint factorial(int n) {     if (n <= 1) {         return 1;     }     return n * factorial(n - 1); }\n</Test Target>\nWrtie a Clang-repl Test\n### Response\n'},
 {'content': '### Instruction\n\nn<Test Target>\nint fibonacci(int n) {     if (n <= 1) {         return n;     }     return fibonacci(n - 1) + fibonacci(n - 2); }\n</Test Target>\nWrtie a Clang-repl Test\n### Response\n'},
 {'content': '### Instruction\n\nn<Test Target>\nint sumToN(int n) {     if (n <= 0) {         return 0;     }     return n + sumToN(n - 1); }\n</Test Target>\nWrtie a Clang-repl Test\n### Response\n'},
 {'content': '### Instruction\n\nn<Test Target>\nint power(int base, int exponent) {     if (exponent == 0) {         return 1;     }     return base * power(base, exponent - 1); }\n</Test Target>\nWrtie a Clang-repl Test\n### Response\n'},
 {'content': '### Instruction\n\nn<Test Target>\nint gcd(int a, int b) {     if (b == 0) {         return a;     }     retu

In [7]:
# Check token length
# ------------------------------------------------

from torchtext.data.utils import get_tokenizer

# Instantiate torchtext's basic English tokenizer
# NOTE(review): this is not the AutoTokenizer imported in the setup cell, so
# the counts below may differ from the model's own token counts — confirm
# before using them to choose a maximum sequence length.
tokenizer = get_tokenizer("basic_english")

# Iterate over the dataset, tokenize the content, and print token lengths
# (training split only; the validation split is not checked here).
for data in reasoning_dataset:
    content = data['content']
    # Tokenize the content using the basic English tokenizer
    tokens = tokenizer(content)
    print(f"Token length: {len(tokens)}\n")

Token length: 39

Token length: 44

Token length: 39

Token length: 44

Token length: 42

Token length: 61

Token length: 43

Token length: 39

Token length: 57

Token length: 42

Token length: 95

Token length: 63

Token length: 42

Token length: 41

Token length: 67

Token length: 74

Token length: 65

Token length: 214

Token length: 137

Token length: 166

Token length: 109

Token length: 196

Token length: 232

Token length: 83

Token length: 104

Token length: 150

Token length: 69

Token length: 131

Token length: 57

Token length: 57

Token length: 69

Token length: 112

Token length: 147

Token length: 53

Token length: 69

Token length: 103

Token length: 105

Token length: 108

Token length: 52

Token length: 57

Token length: 53

Token length: 84

Token length: 42

Token length: 42

Token length: 39

Token length: 42

Token length: 46

Token length: 55

Token length: 45

Token length: 49

Token length: 54

Token length: 46

Token length: 44

Token length: 52

Token length: 



In [8]:
# Check val dataset
# ------------------------------------------------

# Bare expression: display the validation prompts returned by load_reasoning_dataset.
val_reasoning_dataset

[{'content': '### Instruction\n\nn<Test Target>\nint sumEvenElements(int arr[], int n) {     if (n <= 0) {         return 0;     }          int sum = sumEvenElements(arr, n - 1);     if (arr[n-1] % 2 == 0) {         sum += arr[n-1];     }     return sum; }\n</Test Target>\nWrtie a Clang-repl Test\n### Response\n'},
 {'content': '### Instruction\n\nn<Test Target>\nvoid pigeonholeSort(int arr[], int n) {     int min = arr[0], max = arr[0];          for (int i = 1; i < n; i++) {         if (arr[i] < min)             min = arr[i];         if (arr[i] > max)             max = arr[i];     }          int range = max - min + 1;     int pigeonholes[range];          for (int i = 0; i < range; i++)         pigeonholes[i] = 0;          for (int i = 0; i < n; i++)         pigeonholes[arr[i] - min]++;          int index = 0;     for (int i = 0; i < range; i++) {         while (pigeonholes[i] > 0) {             arr[index++] = i + min;             pigeonholes[i]--;         }     } }\n</Test Target>\nWr