In [None]:
!pip install transformers
!pip install trl
!pip install peft
!pip install wandb

Collecting trl
  Downloading trl-0.9.6-py3-none-any.whl (245 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m245.8/245.8 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Collecting accelerate (from trl)
  Downloading accelerate-0.32.1-py3-none-any.whl (314 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m314.1/314.1 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets (from trl)
  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m24.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting tyro>=0.5.11 (from trl)
  Downloading tyro-0.8.5-py3-none-any.whl (103 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m103.4/103.4 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.4.0->trl)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting n

In [None]:
from datasets import load_dataset
from trl import SFTTrainer
from transformers import TrainingArguments, Trainer
from transformers import AutoModelForCausalLM

from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments

from trl import DPOTrainer, DPOConfig

In [None]:
# Mount Google Drive at /content/drive (Colab-only: relies on the google.colab package).
# The sequence, tokenizer and model paths used later in this notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# adapted from https://github.com/huggingface/trl/blob/main/examples/dpo.py

def extract_codon_prompt(prompt_and_response, search_length=15):
    """Extract the start codon prompt from a prompt and response pair.

    The prompt is simply the first ``search_length`` characters of the text.
    With the default of 15 and codons separated by single spaces, that is the
    first four codons (e.g. ``"ATG GTC CTC GGC"``).

    Args:
        prompt_and_response: Full sequence string (prompt followed by response).
        search_length: Number of leading characters that make up the prompt.

    Returns:
        The leading ``search_length`` characters, or the whole string when it
        is shorter than that.
    """
    # The previous implementation searched the string for its own prefix, which
    # always matches at index 0, so a plain slice is equivalent.
    return prompt_and_response[:search_length]

from typing import Dict, Optional

def split_codon_prompt_and_responses(sample) -> Dict[str, str]:
    """Split a chosen/rejected pair into prompt, chosen response and rejected response.

    The prompt is derived from ``sample["chosen"]`` via ``extract_codon_prompt``;
    both texts then have their first ``len(prompt)`` characters removed.

    NOTE(review): ``sample["rejected"]`` is sliced by the prompt length without
    checking that it actually begins with the same prompt -- confirm this is the
    intended pairing.
    """
    chosen_text = sample["chosen"]
    prompt = extract_codon_prompt(chosen_text)
    prompt_len = len(prompt)
    rejected_text = sample["rejected"]
    return dict(
        prompt=prompt,
        chosen=chosen_text[prompt_len:],
        rejected=rejected_text[prompt_len:],
    )

In [None]:
# adapted from https://github.com/huggingface/trl/blob/ca0af3944d4ce53f0db8d48fdd7c4459c41fd437/trl/trainer/utils.py
# removed addition of EOS token id, causes an error

import os
import random
import warnings
from collections import deque
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import IterableDataset
from transformers import DataCollatorForLanguageModeling, PreTrainedModel, PreTrainedTokenizerBase, TrainerCallback

@dataclass
class DPODataCollatorWithPadding:
    r"""
    DPO DataCollator class that pads the inputs to the maximum length of the batch.

    Calling the collator on a list of features (each providing the strings `prompt`, `chosen` and
    `rejected`) tokenizes every feature with `tokenize_batch_element` and pads the result with `collate`.

    Args:
        tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer used for encoding the data.
        model (Optional[`PreTrainedModel`]):
            The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
            prepare the *decoder_input_ids*.
        padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`):
            padding_strategy to pass to the tokenizer. Stored only; this implementation never reads it.
        max_length (`Optional[int]`, `optional`, defaults to `None`):
            The maximum length of the sequence to be processed. When `None`, decoder-only inputs are not
            truncated.
        max_prompt_length (`Optional[int]`, `optional`, defaults to `None`):
            The maximum length of the prompt to be processed. Required whenever a decoder-only example
            exceeds `max_length` and has to be truncated.
        label_pad_token_id (`int`, defaults to -100):
            The label used for masking.
        padding_value (`int`, defaults to 0):
            The value used for padding the attention masks of decoder-only batches.
        truncation_mode: (`str`, defaults to "keep_end"):
            The truncation mode to use when truncating the prompt ("keep_start" or "keep_end").
        is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `False`):
            Whether or not you model has an encoder_decoder architecture.
        max_target_length (`Optional[int]`, `optional`, defaults to `None`):
            The maximum length of the target to be processed. Only useful for encoder-decoder architectures.
    """
    tokenizer: PreTrainedTokenizerBase
    model: Optional[PreTrainedModel] = None
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_prompt_length: Optional[int] = None
    label_pad_token_id: int = -100
    padding_value: int = 0
    truncation_mode: str = "keep_end"
    is_encoder_decoder: Optional[bool] = False
    max_target_length: Optional[int] = None

    def _mask_eos_tokens(self, tokens):
        """Zero the attention mask at every position whose input id equals the tokenizer's EOS id.

        `tokens` is updated in place and returned for convenience. When the tokenizer has no EOS
        token (`eos_token_id is None`) no position matches and the mask is left as it was.
        """
        eos_token_id = self.tokenizer.eos_token_id
        tokens["attention_mask"] = [
            0 if token_id == eos_token_id else mask
            for token_id, mask in zip(tokens["input_ids"], tokens["attention_mask"])
        ]
        return tokens

    def tokenize_batch_element(
        self,
        prompt: str,
        chosen: str,
        rejected: str,
    ) -> Dict:
        """Tokenize a single batch element.

        At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
            in case the prompt + chosen or prompt + rejected responses is/are too long. First
            we truncate the prompt; if we're still too long, we truncate the chosen/rejected.

        We also create the labels for the chosen/rejected responses, which are of length equal to
            the sum of the length of the prompt and the chosen/rejected response, with
            label_pad_token_id  for the prompt tokens.

        Returns:
            A dict of token lists keyed `{chosen,rejected,prompt}_{input_ids,attention_mask,labels}`
            (decoder-only) or the encoder-decoder equivalents, plus the raw strings under `prompt`,
            `chosen`, `rejected`, `chosen_response_only` and `rejected_response_only`.

        Raises:
            ValueError: if truncation is needed and `truncation_mode` is unknown or
                `max_prompt_length` is unset.
        """
        batch = {}

        if not self.is_encoder_decoder:
            chosen_tokens = self._mask_eos_tokens(self.tokenizer(chosen, add_special_tokens=False))
            rejected_tokens = self._mask_eos_tokens(self.tokenizer(rejected, add_special_tokens=False))
            prompt_tokens = self._mask_eos_tokens(self.tokenizer(prompt, add_special_tokens=False))

            # Upstream TRL appends an EOS token to each response at this point; that step was
            # removed in this adaptation because it caused an error.

            longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

            # With max_length unset there is no length budget, so nothing is truncated.
            if self.max_length is not None:
                # if combined sequence is too long, truncate the prompt
                if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
                    if self.max_prompt_length is None:
                        raise ValueError(
                            "max_prompt_length must be set to truncate sequences longer than max_length"
                        )
                    if self.truncation_mode == "keep_start":
                        prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()}
                    elif self.truncation_mode == "keep_end":
                        prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()}
                    else:
                        raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

                # if that's still too long, truncate the response
                if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
                    response_budget = self.max_length - self.max_prompt_length
                    chosen_tokens = {k: v[:response_budget] for k, v in chosen_tokens.items()}
                    rejected_tokens = {k: v[:response_budget] for k, v in rejected_tokens.items()}

            # Create labels: a copy of the full input ids with the prompt positions masked out
            chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
            rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens}
            prompt_length = len(prompt_tokens["input_ids"])
            for sequence_tokens in (chosen_sequence_tokens, rejected_sequence_tokens):
                sequence_tokens["labels"] = sequence_tokens["input_ids"][:]
                sequence_tokens["labels"][:prompt_length] = [self.label_pad_token_id] * prompt_length

            for k, toks in {
                "chosen": chosen_sequence_tokens,
                "rejected": rejected_sequence_tokens,
                "prompt": prompt_tokens,
            }.items():
                for type_key, tokens in toks.items():
                    if type_key == "token_type_ids":
                        continue
                    batch[f"{k}_{type_key}"] = tokens

        else:
            chosen_tokens = self.tokenizer(
                chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True
            )
            rejected_tokens = self.tokenizer(
                rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True
            )
            prompt_tokens = self.tokenizer(
                prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
            )

            batch["chosen_labels"] = chosen_tokens["input_ids"]
            batch["rejected_labels"] = rejected_tokens["input_ids"]
            batch["prompt_input_ids"] = prompt_tokens["input_ids"]
            batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

            if self.model is not None and hasattr(self.model, "prepare_decoder_input_ids_from_labels"):
                batch["rejected_decoder_input_ids"] = self.model.prepare_decoder_input_ids_from_labels(
                    labels=batch["rejected_labels"]
                )
                batch["chosen_decoder_input_ids"] = self.model.prepare_decoder_input_ids_from_labels(
                    labels=batch["chosen_labels"]
                )

        batch["prompt"] = prompt
        batch["chosen"] = prompt + chosen
        batch["rejected"] = prompt + rejected
        batch["chosen_response_only"] = chosen
        batch["rejected_response_only"] = rejected

        return batch

    def collate(self, batch):
        """Pad the tokenized elements of `batch` to a common length and stack them into tensors.

        Keys ending in `_input_ids`, `_attention_mask` or `_labels` become padded `LongTensor`s;
        every other key is returned as a plain list with one entry per example. For decoder-only
        models, prompt tensors are padded on the left and all other tensors on the right.

        Raises:
            ValueError: if an encoder-decoder batch contains a tensor key with no known padding value.
        """
        padded_batch = {}
        for k in batch[0].keys():
            if not k.endswith(("_input_ids", "_attention_mask", "_labels")):
                # raw strings are passed through untouched
                padded_batch[k] = [ex[k] for ex in batch]
                continue

            if self.is_encoder_decoder:
                to_pad = [torch.LongTensor(ex[k]) for ex in batch]

                if (k.startswith("prompt")) and (k.endswith("input_ids")):
                    padding_value = self.tokenizer.pad_token_id
                elif k.endswith("_attention_mask"):
                    padding_value = 0
                elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k):
                    padding_value = self.label_pad_token_id
                else:
                    raise ValueError(f"Unexpected key in batch '{k}'")
                padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
            else:
                # adapted from https://stackoverflow.com/questions/73256206
                # pad_sequence only pads on the right, so prompts are reversed, padded, then flipped back
                is_prompt_key = "prompt" in k
                if is_prompt_key:
                    to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
                else:
                    to_pad = [torch.LongTensor(ex[k]) for ex in batch]

                if k.endswith("_input_ids"):
                    padding_value = self.tokenizer.pad_token_id
                elif k.endswith("_labels"):
                    padding_value = self.label_pad_token_id
                elif k.endswith("_attention_mask"):
                    padding_value = self.padding_value
                else:
                    raise ValueError(f"Unexpected key in batch '{k}'")

                padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
                # for the prompt, flip back so padding is on left side
                if is_prompt_key:
                    padded_batch[k] = padded_batch[k].flip(dims=[1])

        return padded_batch

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Tokenize each feature's `prompt`, `chosen` and `rejected` strings and return the collated batch."""
        tokenized_batch = []

        for feature in features:
            prompt = feature["prompt"]
            chosen = feature["chosen"]
            rejected = feature["rejected"]

            batch_element = self.tokenize_batch_element(prompt, chosen, rejected)
            tokenized_batch.append(batch_element)

        # return collated batch
        return self.collate(tokenized_batch)

In [None]:
# Utility functions for sequences

# Compute GC content
def batch_gc_values(codon_seqs):
    """Compute GC percentages and GC counts for every sequence in ``codon_seqs``.

    Returns:
        A ``(gc_values, gc_counts)`` pair of lists, aligned with the input order,
        holding the results of ``gc_score`` and ``gc_count`` respectively.
    """
    # single pass so the input is consumed exactly once
    per_seq = [(gc_score(seq), gc_count(seq)) for seq in codon_seqs]
    gc_values = [score for score, _ in per_seq]
    gc_counts = [count for _, count in per_seq]
    return gc_values, gc_counts

# GC content value
def gc_score(sequence):
  """Return the GC content of ``sequence`` as a percentage of its length.

  Only uppercase 'G' and 'C' are counted. The denominator is ``len(sequence)``,
  so any separator characters (e.g. spaces between codons) count toward the
  length. An empty sequence yields 0.0, and a sequence with no G or no C is
  handled instead of raising ``KeyError``.
  """
  if len(sequence) == 0:
    return 0.0

  gc_total = sum(1 for n in sequence if n == 'G' or n == 'C')
  gc_value = float(gc_total)/len(sequence)*100
  return gc_value

# Count of G and C
def gc_count(sequence):
  """Return the number of uppercase 'G' and 'C' characters in ``sequence`` as a float.

  Sequences containing no G and/or no C (including the empty sequence) return
  the appropriate count instead of raising ``KeyError``.
  """
  gc_cnt = float(sum(1 for n in sequence if n == 'G' or n == 'C'))
  return gc_cnt

def codon_tokenizable_seqs(seqs):
    """ Convert seqs to codon tokenizable form

    Each sequence is cut into consecutive 3-character chunks (the last chunk
    may be shorter) which are joined with single spaces.
    """
    return [
        " ".join(seq[start:start + 3] for start in range(0, len(seq), 3))
        for seq in seqs
    ]

def read_fasta(fasta_path):
    """ Read sequences from specified fasta file

    Returns a list with one upper-cased sequence string per FASTA record,
    in the order yielded by ``SeqIO.parse``.
    """
    from Bio import SeqIO
    return [str(rec.seq).upper() for rec in SeqIO.parse(fasta_path, "fasta")]

def read_txt(filepath):
    """ Read sequences from specified txt file

    The file is parsed as tab-separated values; a preview of the first rows is
    printed and the full DataFrame is returned.
    """
    import pandas as pd
    table = pd.read_csv(filepath, sep='\t')
    preview = table.head()
    print(f'Seqs data:\n {preview}')
    return table

In [None]:
from datetime import datetime
import yaml

import argparse
from pathlib import Path

In [None]:
def load_sequences(seqs_path):
    """ Load sequences from the specified path

    Reads the JSON file at ``seqs_path`` and returns the list stored under the
    'mdh_given_natural_sequences' key, printing its length and first entry.
    """
    import json
    with open(seqs_path, 'r') as handle:
        payload = json.loads(handle.read())

    sequences = payload['mdh_given_natural_sequences']
    total = len(sequences)
    print(f"Number of sequences: {total}")
    first_sequence = sequences[0]
    print(first_sequence)

    return sequences

def filter_sequences(sequences):
    """ Filter sequences with constraints

    Keeps each distinct sequence that starts with "ATG" and whose length is a
    multiple of 3. Duplicates are dropped and the first-seen order of the input
    is preserved, so the result is deterministic from run to run (set
    arithmetic would order strings by their per-process hash).

    Six diagnostic counts are printed, in this order: inputs without an ATG
    start (duplicates included), distinct sequences with an ATG start, a
    re-check of the start codon (expected 0), distinct ATG sequences with a
    length not divisible by 3, sequences kept, and a final re-check of both
    constraints (expected 0).

    Returns:
        List of the sequences that satisfy both constraints.
    """
    # 1. Check if "ATG" is not in start seq
    rem = [s for s in sequences if s[:3] != 'ATG']
    print(len(rem))

    # 2. Remove seqs that don't have ATG as start seq (dict.fromkeys dedupes while keeping order)
    filt_seqs = [s for s in dict.fromkeys(sequences) if s[:3] == 'ATG']
    print(len(filt_seqs))

    # 3. Check if "ATG" is not in start seq; there should be 0 at this point
    rem = [s for s in filt_seqs if s[:3] != 'ATG']
    print(len(rem))

    # 4. Check if len of seq is divisible by 3
    rem_l = [s for s in filt_seqs if len(s) % 3 != 0]
    print(len(rem_l))

    # 5. Remove seqs with len not divisible by 3
    filt_l_seqs = [s for s in filt_seqs if len(s) % 3 == 0]
    print(len(filt_l_seqs))

    # 6. Check if "ATG" is not start seq or len of seq is not divisible by 3; should be 0 at this point
    check = [s for s in filt_l_seqs if (s[:3] != 'ATG') or (len(s) % 3 != 0)]
    print(len(check))

    return filt_l_seqs

def split_chosen_rejected(filtered_seqs, run_params):
    """ Split sequences to 'chosen' and 'rejected'

    A sequence is 'chosen' when its ``gc_score`` lies strictly between
    ``run_params["chosen_range_min"]`` and ``run_params["chosen_range_max"]``;
    all others are 'rejected'. Both lists are cut to the size of the smaller
    one before being returned.
    """
    print(f'Chosen gc min.: {run_params["chosen_range_min"]}')
    print(f'Chosen gc max.: {run_params["chosen_range_max"]}')

    # Partition on GC content
    chosen, rejected = [], []
    for seq in filtered_seqs:
        gc_percent = gc_score(seq)
        in_range = run_params["chosen_range_min"] < gc_percent < run_params["chosen_range_max"]
        bucket = chosen if in_range else rejected
        bucket.append(seq)

    print(f"Number of chosen, rejected: {len(chosen)}, {len(rejected)}")

    # Balance the two groups
    num_samples = min(len(chosen), len(rejected))
    print(f'Min. number of sampels: {num_samples}')

    return chosen[:num_samples], rejected[:num_samples]

def train_eval_split(seqs, test_size=0.05):
    """ Split given list of sequences to train and eval sets

    The sequences are converted to a numpy array and shuffled with a fixed
    ``random_state`` of 104; ``test_size`` is forwarded to scikit-learn's
    ``train_test_split``.

    Returns:
        ``(seqs_train, seqs_eval)``.
    """
    import numpy as np
    from sklearn.model_selection import train_test_split

    seq_array = np.array(seqs)
    split_options = dict(random_state=104, test_size=test_size, shuffle=True)
    seqs_train, seqs_eval = train_test_split(seq_array, **split_options)

    return seqs_train, seqs_eval

def prepare_dataset(seqs_path, run_params):
    """Build the train and eval preference datasets from the sequences at ``seqs_path``.

    Pipeline: load -> filter -> split into chosen/rejected -> space-separate
    codons -> 95/5 train/eval split of each group -> pair chosen with rejected
    row by row -> add the 'prompt' column via ``split_codon_prompt_and_responses``.

    Returns:
        ``(train_dataset, eval_dataset)``.
    """
    # Load and filter the training sequences
    filtered_sequences = filter_sequences(load_sequences(seqs_path))

    # Split seqs to 'chosen' and 'rejected'
    chosen, rejected = split_chosen_rejected(filtered_sequences, run_params)

    # Codon tokenizable form (spaced as codons)
    chosen_codon = codon_tokenizable_seqs(chosen)
    rejected_codon = codon_tokenizable_seqs(rejected)

    # Sanity check: inserting spaces must not change the G/C counts
    gc_counts_chosen = batch_gc_values(chosen)[1]
    gc_counts_chosen_codon = batch_gc_values(chosen_codon)[1]
    gc_counts_rej = batch_gc_values(rejected)[1]
    gc_counts_rej_codon = batch_gc_values(rejected_codon)[1]
    assert gc_counts_chosen == gc_counts_chosen_codon, "GC counts do not match 'chosen' set"
    assert gc_counts_rej == gc_counts_rej_codon, "GC counts do not match 'rejected' set"

    # Train / eval split of each group
    chosen_train, chosen_eval = train_eval_split(chosen_codon, test_size=0.05)
    rejected_train, rejected_eval = train_eval_split(rejected_codon, test_size=0.05)

    n_chosen_train, n_chosen_eval = len(chosen_train), len(chosen_eval)
    n_rejected_train, n_rejected_eval = len(rejected_train), len(rejected_eval)
    print(f"Number of chosen train, eval, (total): {n_chosen_train}, {n_chosen_eval}, ({n_chosen_train+n_chosen_eval})")
    print(f"Number of rejected train, eval, (total): {n_rejected_train}, {n_rejected_eval}, ({n_rejected_train+n_rejected_eval})")

    # Train dataset; 'rejected' is cut to the length of 'chosen' so both columns have equal size
    train_ds = Dataset.from_dict({
        'chosen': list(chosen_train),
        'rejected': list(rejected_train[:n_chosen_train]),
    })
    print(f"Train dataset: {train_ds}")

    # Eval dataset, built the same way
    eval_ds = Dataset.from_dict({
        'chosen': list(chosen_eval),
        'rejected': list(rejected_eval[:n_chosen_eval]),
    })
    print(f"Eval dataset: {eval_ds}")

    # Add prompt feature
    train_dataset = train_ds.map(split_codon_prompt_and_responses)
    eval_dataset = eval_ds.map(split_codon_prompt_and_responses)

    return train_dataset, eval_dataset

def get_data_collator(tokenizer):
    """ Get data collator

    Builds a ``DPODataCollatorWithPadding`` around ``tokenizer`` with the fixed
    length limits and padding settings used by this notebook.
    """
    collator_settings = dict(
        max_length=512,
        max_prompt_length=128,
        label_pad_token_id=-100,
        padding_value=0,
        truncation_mode='keep_start',
        is_encoder_decoder=None,  # falsy, so the collator takes its decoder-only path
        max_target_length=128,
    )
    return DPODataCollatorWithPadding(tokenizer, **collator_settings)

def load_tokenizer(tokenizer_file):
    """ Load pretrained tokenizer from given path

    Wraps the ``tokenizers`` file at ``tokenizer_file`` in a
    ``PreTrainedTokenizerFast`` and registers "[PAD]" as its pad token.
    """
    from transformers import AutoTokenizer
    from transformers import PreTrainedTokenizerFast
    from tokenizers import Tokenizer

    raw_tokenizer = Tokenizer.from_file(str(tokenizer_file))
    tokenizer = PreTrainedTokenizerFast(tokenizer_object=raw_tokenizer)
    pad_spec = {"pad_token": "[PAD]"}
    tokenizer.add_special_tokens(pad_spec)

    return tokenizer

def load_model(saved_model_path):
    """ Load pretrained model from the model path

    Two separate instances are loaded from the same path and returned as
    ``(model, model_ref)``.
    """
    from transformers import GPTNeoXForCausalLM, AutoModelForCausalLM

    # Two independent loads so the returned objects do not share parameters
    model, model_ref = (
        AutoModelForCausalLM.from_pretrained(saved_model_path) for _ in range(2)
    )

    return model, model_ref

def print_run_params(run_params):
    """ Print run params

    Prints beta, learning rate, training steps, batch size and the chosen
    property, one per line; when the chosen property is 'gc' the GC range is
    printed as well.
    """
    labelled_keys = (
        ("beta", "beta"),
        ("learning rate", "lr"),
        ("training steps", "steps"),
        ("batch size", "batch_size"),
        ("chosen property", "chosen_property"),
    )
    for label, key in labelled_keys:
        print(f'{label}: {run_params[key]}')

    if run_params["chosen_property"] != 'gc':
        return
    print(f'gc range: [{run_params["chosen_range_min"]}, {run_params["chosen_range_max"]}]')

In [None]:
def perform_training(run_params):
    """Run DPO fine-tuning end to end using the settings in ``run_params``.

    Steps: build the preference datasets, load the tokenizer and the policy /
    reference models, create a timestamped run directory under ``/content/runs``,
    dump the run configuration there as YAML, then train with ``DPOTrainer``.

    Args:
        run_params: dict with the keys ``beta``, ``lr``, ``batch_size``,
            ``chosen_property``, ``chosen_range_min``, ``chosen_range_max``,
            ``ref_model`` (path of the pretrained model) and either ``epochs``
            or, when ``epochs`` is missing/``None``, ``steps``.

    Side effects:
        ``run_params["run_tag"]`` is set to the run's timestamp, and the run
        directory plus ``run_config.yml`` are written to disk.
    """
    seqs_path = '/content/drive/MyDrive/mdh_data/mdh_given_natural_sequences.json'
    tokenizer_file = '/content/drive/MyDrive/codon_wordlevel_100vocab.json'

    saved_model_path = run_params['ref_model']

    out_dir = '/content/runs'
    run_tag = f'{run_params["chosen_property"]}_{int(run_params["chosen_range_min"])}-{int(run_params["chosen_range_max"])}'

    # Prepare dataset
    train_dataset, eval_dataset = prepare_dataset(seqs_path, run_params)
    print(f"final train dataset: {train_dataset}")
    print(f"final eval dataset: {eval_dataset}")
    print(f"train_dataset[0]: {train_dataset[0]}")
    print(f"eval_dataset[0]: {eval_dataset[0]}")

    # Load tokenizer
    tokenizer = load_tokenizer(tokenizer_file)
    print(f"[PAD] token_id: {tokenizer.pad_token_id}")

    # Get data collator and sanity check: collating the eval set must not raise
    data_collator = get_data_collator(tokenizer)
    _ = data_collator(eval_dataset)

    # Get device (reported only; models are not moved explicitly here)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('device:', device)

    # Load fine-tuned model
    model, model_ref = load_model(saved_model_path)
    print(f"model.device, model_ref.device: {model.device}, {model_ref.device}")

    # Get date time
    dt_string = datetime.now().strftime("%d%m%y-%H%M%S")

    # Training length: epochs take precedence; fall back to a fixed step budget
    epochs = run_params.get("epochs")
    if epochs is not None:
        schedule_tag = f'epochs-{epochs}'
        schedule_kwargs = {"num_train_epochs": epochs}
    else:
        schedule_tag = f'steps-{run_params["steps"]}'
        schedule_kwargs = {"max_steps": run_params["steps"]}

    # Create run directory
    run_dir = out_dir + (
        f'/beta-{run_params["beta"]}_lr-{run_params["lr"]}_bs-{run_params["batch_size"]}'
        f'_{schedule_tag}_{run_tag}_{dt_string}'
    )
    Path(run_dir).mkdir(parents=True, exist_ok=True)

    # Update and write run params to yaml file
    run_params["run_tag"] = dt_string # add datetime to run_params
    with open(run_dir+'/run_config.yml', 'w') as yaml_file:
        yaml.dump(run_params, yaml_file, default_flow_style=False)

    # Initialize training arguments. beta and the length limits are set on the
    # config (passing them to DPOTrainer directly is deprecated), and outputs go
    # to this run's own directory rather than the shared parent.
    training_args = DPOConfig(
        output_dir=run_dir,
        beta=run_params["beta"],
        max_length=512,
        max_prompt_length=128,
        per_device_train_batch_size=run_params["batch_size"],
        remove_unused_columns=False,
        gradient_accumulation_steps=1,
        learning_rate=run_params["lr"],
        evaluation_strategy="steps",
        report_to="wandb",
        **schedule_kwargs,
    )

    # Initialize the DPO trainer
    dpo_trainer = DPOTrainer(
        model,
        model_ref,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
    )

    # Train
    dpo_trainer.train()

In [None]:
# Settings for this run: DPO hyper-parameters, the GC-content window that
# defines 'chosen' sequences, and the pretrained model to start from.
run_params = {
    "beta": 0.1,
    "lr": 5e-5,
    "steps": 100,
    "epochs": 10,
    "chosen_property": 'gc',
    "chosen_range_min": 45.0,
    "chosen_range_max": 55.0,
    "batch_size": 8,
    "ref_model": '/content/drive/MyDrive/mdh_data/model_mdh_last.pt',
}

# Print run params
print_run_params(run_params)

perform_training(run_params)

beta: 0.1
learning rate: 5e-05
training steps: 100
batch size: 8
chosen property: gc
gc range: [45.0, 55.0]
Number of sequences: 36631
ATGAGCAAGGTCACCGTCGTAGGGGCCGGCAAGTACGGATCAACCACCGCCATGCGGCTCGCTGAAGCCGACATCGTCGACGAGGTCGTCATGACCGACATTGTCGAGGGTCTACCCCAGGGCCTGGCGCTCGATATCAATCAGTCGCGGCCGCTGCTCGGCTACCGGACCGTCATCACCGGTTCGAACGACTACGCCGCCACTGCCGGCAGCGATGTCGTGGTCATCACCGCCGGGTTGCCGCGCAAGCCAGGCATGAGTCGCATGGACCTCCTCGAGGTCAATGCCAAGATCGTCAAGGACGTCACCGTCCAGATAGCCGAGCATTCGCCGGACGCGGTCATCATCAACGTCACGAACCCGCTCGATCACATGACAACCCTGGCGGCGGAAGTATCCGGCTTCGACACCCGCCGGGTCATGGGACAGGCAGGCATGCTCGACTCCGCTCGTTTCGCTCATTTCATCGCAGAAGTGACCGGTGCCGACATCATGGACGTCGAGGCTCTTACCCTCGGCAGCCACGGAGAGACCATGGTCCCGGTTCCGTCACAAACCAAGGTGGGGGGCAAACTCCTCGCCGATCTCGTCGATGCCGACGCCGTCGAGTCGCTCGTCGACCGGACCCGCAAGGGTGGGGCCGAGGTTGTTGCGCTCCTCAAGACCGGCAGCGCCTATTACGCCCCCTCGGCGGCTGCCGCCAAGATGGTCGAAGCCGTCATTGGAGATACCGGCGAGGTGATGCCGGTATGTGCCTGGATGAGTGGCGAGTACGGGATCTCCGACGTATACCTCGGTGTTCCAGCAAGTCTCGGCAAAGAGGGCGTGAAGGAGATCGTCGAACTCCCGCTCACCGACACTGAGG

Map:   0%|          | 0/7763 [00:00<?, ? examples/s]

Map:   0%|          | 0/409 [00:00<?, ? examples/s]

final train dataset: Dataset({
    features: ['chosen', 'rejected', 'prompt'],
    num_rows: 7763
})
final eval dataset: Dataset({
    features: ['chosen', 'rejected', 'prompt'],
    num_rows: 409
})
train_dataset[0]: {'chosen': ' GTC CTC GGC GCT GCT GGT GGT ATT GGC CAG GCG CTT GCC CTA CTA CTG AAA ACC CAA CTG CCT TCA GGC TCA GAA CTC TCC CTG TAC GAT ATT GCT CCG GTA ACC CCA GGT GTG GCG GTT GAC CTG AGC CAC ATC CCA ACC GCG GTG AAA ATT AAA GGC TTC TCT GGC GAA GAT GCA CGT CCA GCG CTG CAA GGT GCT GAC GTG GTG CTC ATC TCT GCG GGC GTC GCA CGT AAG CCG GGT ATG GAT CGT TCT GAC CTG TTT AAC GTC AAC GCT GGC ATC GTC AAA AAT CTG GTT CAA CAG ATT GCT GAA ACC TGC CCG AAA GCG TGC GTG GGT ATC ATC ACC AAC CCG GTG AAT ACC ACG GTG GCC ATT GCG GCA GAA GTG CTG AAA AAA GCC GGT GTT TAC GAT AAG AAC AAG CTG TTT GGC GTG ACC ACG CTG GAT ATT ATC CGC TCC AAT ACC TTT GTT GCT GAA CTG AAA GGC AAG TCA CCT GCT GAG ATC GAG GTT CCG GTT ATC GGA GGC CAC TCA GGC GTG ACC ATT CTG CCT CTG CTG TCT CAG ATC CCA GGC GTT AGC TTC TCC GAG C


Deprecated positional argument(s) used in DPOTrainer, please use the DPOConfig to set these arguments instead.


Map:   0%|          | 0/7763 [00:00<?, ? examples/s]

Map:   0%|          | 0/409 [00:00<?, ? examples/s]



<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,Rewards/chosen,Rewards/rejected,Rewards/accuracies,Rewards/margins,Logps/rejected,Logps/chosen,Logits/rejected,Logits/chosen
500,0.0848,0.00397,-2.528712,-13.351259,1.0,10.82255,-466.884155,-353.844177,-0.907256,-0.91263
1000,0.012,0.002662,-3.434631,-17.82711,1.0,14.392477,-511.642639,-362.903351,-0.909491,-0.928923
1500,0.0002,0.004292,-3.70154,-19.475037,0.997596,15.773498,-528.121887,-365.572449,-0.886685,-0.906602
2000,0.0001,0.004476,-3.85741,-20.309338,0.997596,16.451931,-536.464905,-367.131165,-0.875277,-0.897353
2500,0.0,0.004403,-3.95447,-20.608116,0.997596,16.653645,-539.452698,-368.101746,-0.871456,-0.894028
3000,0.0,0.004544,-4.075648,-20.946306,0.997596,16.870657,-542.834595,-369.313538,-0.868154,-0.89126
3500,0.0,0.004858,-4.193683,-21.392702,0.997596,17.199018,-547.298523,-370.493866,-0.863146,-0.886994
4000,0.0,0.004958,-4.317677,-21.798922,0.997596,17.481245,-551.360779,-371.733826,-0.858676,-0.883212
4500,0.0,0.005298,-4.441584,-22.245014,0.997596,17.803431,-555.821716,-372.97287,-0.853492,-0.878771
5000,0.0,0.005376,-4.553698,-22.558617,0.997596,18.004919,-558.957703,-374.094055,-0.850195,-0.875942


