In [None]:
# Step 1: Clone the repo and setup environment

import os
import subprocess

repo_dir = "./tamarind-finetune"
repo_url = "https://github.com/smartrics/tamarind-finetune.git"

# Reuse an existing checkout (fast-forward it) or create a fresh clone.
# `check=True` makes a failing git command raise CalledProcessError and stop the run.
already_cloned = os.path.isdir(repo_dir)
if already_cloned:
    status_message = "Directory 'tamarind-finetune' exists. Pulling latest changes..."
    git_command = ["git", "-C", repo_dir, "pull"]
else:
    status_message = "Directory 'tamarind-finetune' does not exist. Cloning repository..."
    git_command = ["git", "clone", repo_url, repo_dir]

print(status_message)
subprocess.run(git_command, check=True)
print("finished!")

In [None]:
# Work from inside the checkout so the relative paths used below (requirements.txt and the
# *.jsonl data files) resolve against the repository root.
# NOTE(review): the path is relative to the current directory, so re-running this cell after it
# has already succeeded will look for a nested ./tamarind-finetune/tamarind-finetune — restart
# the kernel (or `%cd` back) before re-running.
%cd ./tamarind-finetune

In [None]:
# Install the repository's dependencies (requirements.txt from the cloned checkout) into this
# kernel's environment; `%pip` (not `!pip`) targets the running kernel's interpreter.
# NOTE(review): an earlier note listed Transformers, Datasets, PEFT (LoRA), TRL and BitsAndBytes,
# but this notebook only imports transformers, datasets and huggingface_hub directly — confirm
# the contents of requirements.txt and pin versions there for reproducibility.
%pip install -r requirements.txt


In [None]:
from huggingface_hub import notebook_login

# Sign in to the Hugging Face Hub via the interactive notebook prompt; later cells push to the Hub.
notebook_login()


In [None]:
# --- 1. Prepare the Data ---

from datasets import Dataset, DatasetDict
import json

def _extract_role_contents(messages):
    """Pick the system, user and assistant contents out of a chat-style message list.

    Args:
        messages: iterable of ``{"role": ..., "content": ...}`` dicts; entries that
            are not dicts are ignored.

    Returns:
        A ``(system, user, assistant)`` tuple. For each role the last message with
        non-empty content wins; roles that never appear (or only with empty content)
        are ``None``.
    """
    found = {"system": None, "user": None, "assistant": None}
    for message in messages:
        if not isinstance(message, dict):
            continue
        role = message.get("role")
        content = message.get("content")
        if role in found and content:
            found[role] = content
    return found["system"], found["user"], found["assistant"]


def _with_system_prompt(system_content, user_content):
    """Return ``user_content`` preceded by ``system_content`` and a space, when there is one."""
    return (system_content + " " if system_content else "") + user_content


def load_and_combine_data(spec_file, wf_file):
    """Build a ``Dataset`` of ``input``/``output`` text pairs from two parallel JSONL files.

    Line *i* of ``spec_file`` is paired with line *i* of ``wf_file``. Each line is a
    JSON object with a ``"messages"`` list of chat messages (system/user/assistant).
    For every pair of lines:

    * both records complete (user + assistant) -> inputs and outputs are joined
      with ``" [SEP] "`` (spec first, then wf);
    * only one record complete -> that record alone is used;
    * neither complete -> the pair is skipped with a warning.

    A system prompt, when present, is prepended (plus a space) to the user content.
    Invalid JSON lines and non-object records are skipped with a warning. If the
    files have different numbers of lines, a warning is printed and processing
    stops at the end of the shorter file.

    Args:
        spec_file: path to the spec-side JSONL file (read as UTF-8).
        wf_file: path to the wf-side JSONL file (read as UTF-8).

    Returns:
        A ``datasets.Dataset`` with ``"input"`` and ``"output"`` string columns, or
        ``None`` if either file does not exist.
    """
    from itertools import zip_longest

    data = []
    try:
        # JSONL is UTF-8; don't depend on the platform's default encoding.
        with open(spec_file, "r", encoding="utf-8") as f_spec, \
                open(wf_file, "r", encoding="utf-8") as f_wf:
            # zip() would silently stop at the shorter file; zip_longest lets us detect
            # and report a length mismatch, which means the pairs may be misaligned.
            for line_spec, line_wf in zip_longest(f_spec, f_wf):
                if line_spec is None or line_wf is None:
                    print("Warning: Files have different number of lines. Processing will stop at the shorter file.")
                    break

                try:
                    spec_obj = json.loads(line_spec)
                    wf_obj = json.loads(line_wf)
                except json.JSONDecodeError as e:
                    print(f"Warning: Skipping invalid JSON line: {e}")
                    continue

                if not isinstance(spec_obj, dict) or not isinstance(wf_obj, dict):
                    print(f"Warning: Skipping non-object JSON line: {line_spec.strip()} - {line_wf.strip()}")
                    continue

                spec_system, spec_user, spec_assistant = _extract_role_contents(spec_obj.get("messages") or [])
                wf_system, wf_user, wf_assistant = _extract_role_contents(wf_obj.get("messages") or [])

                spec_complete = bool(spec_user and spec_assistant)
                wf_complete = bool(wf_user and wf_assistant)

                # NOTE(review): pairing is purely by line position; if the records carry an
                # 'id' field, aligning on it would be more robust.
                if spec_complete and wf_complete:
                    data.append({
                        "input": _with_system_prompt(spec_system, spec_user)
                        + " [SEP] "
                        + _with_system_prompt(wf_system, wf_user),
                        "output": spec_assistant + " [SEP] " + wf_assistant,
                    })
                elif spec_complete:
                    data.append({"input": _with_system_prompt(spec_system, spec_user),
                                 "output": spec_assistant})
                elif wf_complete:
                    data.append({"input": _with_system_prompt(wf_system, wf_user),
                                 "output": wf_assistant})
                else:
                    print(f"Warning: Skipping misaligned or incomplete data pair: {line_spec.strip()} - {line_wf.strip()}")
    except FileNotFoundError as e:
        print(f"Error: File not found: {e}")
        return None

    return Dataset.from_dict({
        "input": [item["input"] for item in data],
        "output": [item["output"] for item in data],
    })

# Load and combine data for each split
train_dataset = load_and_combine_data("spec_training_data.jsonl", "wf_training_data.jsonl")
eval_dataset = load_and_combine_data("spec_validation_data.jsonl", "wf_validation_data.jsonl")
test_dataset = load_and_combine_data("spec_test_data.jsonl", "wf_test_data.jsonl")

splits = {"train": train_dataset, "validation": eval_dataset, "test": test_dataset}

# Validate before building the DatasetDict. Raise instead of calling exit(): inside a Jupyter
# kernel exit() shuts the kernel down, whereas an exception simply stops "Run All" here.
missing = [name for name, split in splits.items() if split is None]
if missing:
    raise FileNotFoundError(
        f"Error loading datasets ({', '.join(missing)}). Please check file paths and contents."
    )

# Training and per-epoch evaluation both need at least one usable example.
empty = [name for name in ("train", "validation") if len(splits[name]) == 0]
if empty:
    raise ValueError(
        f"Dataset split(s) {', '.join(empty)} contain no usable examples. Please check file contents."
    )

# Create a DatasetDict
raw_datasets = DatasetDict(splits)



In [None]:
# --- 2. Load Tokenizer and Model ---
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Checkpoint to fine-tune; tokenizer and seq2seq model are loaded from the same id
# (tokenizer first, then model).
model_name = "google/codet5-base"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)

# Optional task prefix prepended to every input text (empty string = no prefix).
prefix = ""

def preprocess_function(examples, max_input_length=128, max_target_length=128):
    """Tokenize a batch of ``{"input": [...], "output": [...]}`` examples for seq2seq training.

    Args:
        examples: batch dict as passed by ``Dataset.map(..., batched=True)``, with
            parallel ``"input"`` (source text) and ``"output"`` (target text) lists.
        max_input_length: truncation length for the source sequences.
        max_target_length: truncation length for the target (label) sequences.

    Returns:
        The tokenizer's encoding of ``prefix + input`` with an added ``"labels"``
        entry holding the target token ids. Sequences are truncated but NOT padded;
        padding has to happen per batch in a data collator.
    """
    inputs = [prefix + doc for doc in examples["input"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # `text_target=` tokenizes with the target-side settings and replaces the deprecated
    # `with tokenizer.as_target_tokenizer():` context manager.
    labels = tokenizer(text_target=examples["output"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Tokenize every split in batches (preprocess_function receives lists of examples).
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

# Re-bind the split names to their tokenized versions for the Trainer;
# the tokenized test split remains available as tokenized_datasets["test"].
train_dataset, eval_dataset = tokenized_datasets["train"], tokenized_datasets["validation"]



In [None]:
# --- 3. Configure Training Arguments ---
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback
from huggingface_hub import notebook_login
import os

import inspect

import torch

output_dir = "./codet5-combined-tuned"  # Adjust output directory
learning_rate = 1e-5  # Adjusted for small dataset
batch_size = 8      # Adjusted for small dataset
num_epochs = 20     # Upper bound; early stopping is expected to end training sooner
gradient_accumulation_steps = 2
weight_decay = 0.01

# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers releases and the
# old name was later removed; use whichever name the installed version accepts.
_training_arg_names = inspect.signature(TrainingArguments).parameters
eval_strategy_key = "eval_strategy" if "eval_strategy" in _training_arg_names else "evaluation_strategy"

# Mixed precision requires a GPU (fp16=True fails on CPU). Prefer bf16 where supported:
# T5-family checkpoints have a known history of fp16 overflow (NaN losses).
use_cuda = torch.cuda.is_available()
use_bf16 = use_cuda and torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    num_train_epochs=num_epochs,
    weight_decay=weight_decay,
    save_strategy="epoch",
    # Must match save_strategy for load_best_model_at_end / early stopping to work.
    **{eval_strategy_key: "epoch"},
    # Bound disk usage over up to 20 epochs; with load_best_model_at_end the best
    # checkpoint is retained in addition to the most recent ones.
    save_total_limit=2,
    logging_dir="./logs",
    fp16=use_cuda and not use_bf16,
    bf16=use_bf16,
    push_to_hub=True,
    # Must be a namespace you can write to, otherwise creating/pushing the repo fails.
    hub_model_id="your_huggingface_username/codet5-combined-tuned", # Update hub name
    load_best_model_at_end=True, # Load the best model (by metric_for_best_model) when training ends
    metric_for_best_model="eval_loss", # Metric tracked for the best model and for early stopping
    greater_is_better=False, # Lower loss is better
)


In [None]:

# --- 4. Define the Trainer with Early Stopping Callback ---
import inspect

from transformers import DataCollatorForSeq2Seq

# preprocess_function truncates but does not pad, so label sequences have varying lengths.
# The Trainer's default collator (DataCollatorWithPadding when a tokenizer is given) does not
# pad `labels`; DataCollatorForSeq2Seq pads them with -100 so padding is ignored by the loss.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Newer transformers releases take the tokenizer as `processing_class` (the `tokenizer`
# keyword is deprecated); pass it under whichever name this version supports.
_trainer_params = inspect.signature(Trainer.__init__).parameters
_tokenizer_key = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    callbacks=[
        # Stop after 3 evaluations without an improvement of at least 0.001 in the metric
        # named by training_args.metric_for_best_model. EarlyStoppingCallback has no
        # `patience`/`monitor`/`min_delta` arguments; those raise a TypeError.
        EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.001),
    ],
    **{_tokenizer_key: tokenizer},
)


In [None]:
# --- 5. Login to Hugging Face Hub ---
from huggingface_hub import whoami

# The notebook already prompts for a token near the top; only prompt again when no valid
# credentials are found, so Restart & Run All does not show a second login widget.
try:
    whoami()
except Exception:  # best effort: no token, invalid token or no network -> prompt
    notebook_login()



In [None]:
# --- 6. Train the Model ---
# Runs the training loop; the EarlyStoppingCallback attached to the trainer may end it
# before num_train_epochs is reached.
print("Starting training with early stopping...")
train_output = trainer.train()
print("Training finished!")



In [None]:
# --- 7. Push the Model to Hugging Face Hub ---
# Uploads trainer.model, which holds the best checkpoint because load_best_model_at_end=True.
print("Pushing model to Hugging Face Hub...")
trainer.push_to_hub()

model_page_url = f"https://huggingface.co/{training_args.hub_model_id}"
print(f"Model pushed to {model_page_url}")

print("Fine-tuning complete.")