In [1]:
import numpy as np 
class LengthSampler:
    def __init__(self, min_value, max_value):
        self.values = list(range(min_value, max_value))
    def __call__(self):
        return np.random.choice(self.values)

In [5]:
import torch
config = {
    "txt_in_min_len": 2,
    "txt_in_max_len": 8,
    "txt_out_min_len": 4,
    "txt_out_max_len": 16,
    "model_name": "gpt2",
    "batch_size": 256,
    "steps": 20000,
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
input_size = LengthSampler(config["txt_in_min_len"], config["txt_in_max_len"])
output_size = LengthSampler(config["txt_out_min_len"], config["txt_out_max_len"])
print(input_size())

2


In [7]:
# test the 
from trl import AutoModelForCausalLMWithValueHead
gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(config['model_name'])

In [8]:
# save the model 

gpt2_model.save_pretrained(".result/")

In [154]:
# traing the trl for query expansion
from transformers import AutoTokenizer
from transformers import GPT2Tokenizer
import pandas as pd 
gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
gpt2_model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

df = pd.read_csv(f"/projects/futhark1/data/wzm289/code/RLSeq2SeqPytorch/scifact/train.source", sep="\t", names=["source"])

gpt2_tokenizer(df["source"][:10].to_list(), truncation=True,max_length=20, padding="max_length")


{'input_ids': [[6030, 352, 39569, 318, 3917, 351, 11800, 22146, 5945, 602, 287, 309, 842, 2478, 13, 50256, 50256, 50256, 50256, 50256], [818, 1007, 38516, 10693, 25451, 278, 4077, 781, 35414, 7532, 739, 262, 1630, 286, 262, 17689, 17, 32972, 11, 517], [38208, 312, 876, 7446, 2465, 36206, 3563, 2751, 45829, 13, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256], [1404, 44, 290, 5325, 18, 3519, 7532, 389, 4688, 329, 34244, 7446, 2465, 13, 50256, 50256, 50256, 50256, 50256, 50256], [8021, 27289, 3513, 34721, 318, 517, 13205, 284, 8668, 3357, 621, 15964, 8027, 10906, 13, 50256, 50256, 50256, 50256, 50256], [45, 28978, 40, 23005, 2728, 6625, 284, 497, 37040, 499, 500, 13, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256], [28780, 934, 29906, 389, 407, 33344, 286, 10255, 5958, 10576, 287, 37483, 513, 51, 18, 4778, 13, 50256, 50256, 50256], [12832, 1133, 450, 7592, 286, 29430, 1921, 5640, 6049, 3349, 23992, 13, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256], 

In [155]:
text = "<|startoftext|>Source:Type 1 Diabetes is associated with subtle perturbations in T reg development.<|pad|>Target:Autoimmune diseases are thought to result from imbalances in normal immune physiology and regulation. Here, we show that autoimmune disease susceptibility and resistance alleles on mouse chromosome 3 (Idd3) correlate with differential expression of the key immunoregulatory cytokine interleukin-2 (IL-2). In order to test directly that an approximately twofold reduction in IL-2 underpins the Idd3-linked destabilization of immune homeostasis, we show that engineered haplodeficiency of Il2 gene expression not only reduces T cell IL-2 production by twofold but also mimics the autoimmune dysregulatory effects of the naturally occurring susceptibility alleles of Il2. Reduced IL-2 production achieved by either genetic mechanism correlates with reduced function of CD4+ CD25+ regulatory T cells, which are critical for maintaining immune homeostasis.<|endoftext|>"
gpt2_tokenizer(text, truncation=True, max_length=20, padding="max_length")


{'input_ids': [27, 91, 9688, 1659, 5239, 91, 29, 7416, 25, 6030, 352, 39569, 318, 3917, 351, 11800, 22146, 5945, 602, 287], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [162]:
# load dataset
from torch.utils.data import Dataset

import torch
class S2Sdataset(Dataset):
    def __init__(self, max_length=20, data_type="train"):
        self.input_ids = []
        self.attn_masks = []

        df_source = pd.read_csv(
            f"/projects/futhark1/data/wzm289/code/RLSeq2SeqPytorch/scifact/{data_type}.source", sep="\t", names=["source"]
        )
        df_target = pd.read_csv(
            f"/projects/futhark1/data/wzm289/code/RLSeq2SeqPytorch/scifact/{data_type}.target", sep="\t", names=["target"]
        )
        # print(df_source)
        self.df_train = "Source:" + df_source["source"]
        

        for item in self.df_train:
            # tokenize

            encodings_dict = gpt2_tokenizer(item, truncation=True, max_length=max_length, padding="max_length")

            self.input_ids.append(torch.tensor(encodings_dict["input_ids"]))
            self.attn_masks.append(torch.tensor(encodings_dict["attention_mask"]))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, index):
        return {"query":self.df_train [index], "input_ids":self.input_ids[index], "attn_masks":self.attn_masks[index]}

In [163]:
train_dataset = S2Sdataset(data_type = "train")
next(iter(train_dataset))

{'query': 'Source:Type 1 Diabetes is associated with subtle perturbations in T reg development.',
 'input_ids': tensor([ 7416,    25,  6030,   352, 39569,   318,  3917,   351, 11800, 22146,
          5945,   602,   287,   309,   842,  2478,    13, 50256, 50256, 50256]),
 'attn_masks': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0])}

In [164]:
from torch.utils.data import DataLoader
dataloader = DataLoader(
    train_dataset, batch_size=config["batch_size"]
)
next(iter(dataloader))

{'query': ['Source:Type 1 Diabetes is associated with subtle perturbations in T reg development.',
  'Source:In transgenic mice harboring green florescent protein under the control of the Sox2 promoter, more than 50 percent of the cells with green florescent colocalize with cell proliferation markers.',
  'Source:Oxidative DNA damage activates STING signalling.',
  'Source:ATM and Rad3 related protein are critical for sensing DNA damage.',
  'Source:Assessing treatment adherence is more beneficial to clinical practice than measuring routine outcomes.',
  'Source:N348I mutations cause resistance to nevirapine.',
  'Source:Cellular clocks are not predictive of mitosis timing in NIH 3T3 cells.',
  'Source:Acute ablation of KRAS causes severe growth impairment.',
  "Source:The World Health Organization's (WHO) data collection process is biased downward by unequal selection of larger outbreaks.",
  'Source:CD44v6 is not associated with constitutive and reprogrammed cancer stem cells driving

# training the model 

In [165]:
from trl.ppo import PPOTrainer
ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **config)

total_ppo_epochs = int(np.ceil(config["steps"] / config["batch_size"]))

In [166]:
gen_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": gpt2_tokenizer.eos_token_id,
}


In [167]:
generated = gpt2_tokenizer("hello world", return_tensors="pt").input_ids
gpt2_model.generate(generated, max_new_tokens=3, **gen_kwargs)

tensor([[31373,   995, 13428,    13,  2864]])

In [169]:
from tqdm import tqdm
import time
for epoch, batch in tqdm(zip(range(total_ppo_epochs), iter(dataloader))):
    logs, timing = dict(), dict()
    t0 = time.time()
    query_tensors = batch["input_ids"]
    #### Get response from gpt2
    t = time.time()
    response_tensors = []
    for i in range(config["batch_size"]):
        gen_len = 10
        response = gpt2_model.generate(
            query_tensors[i].unsqueeze(dim=0), max_new_tokens=gen_len, **gen_kwargs
        )
        response_tensors.append(response.squeeze()[-gen_len:])
    batch["response"] = [gpt2_tokenizer.decode(r.squeeze()) for r in response_tensors]
    print(batch)
    timing["time/get_response"] = time.time() - t

    #### Compute sentiment score
    t = time.time()
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    # pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    # rewards = torch.tensor([output[1]["score"] for output in pipe_outputs]).to(device)
    rewards = torch.ones(1,len(batch))
    timing["time/get_sentiment_preds"] = time.time() - t

    #### Run PPO step
    t = time.time()
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    timing["time/optimization"] = time.time() - t

    #### Log everything
    timing["time/epoch"] = time.time() - t0
    table_rows = [
        list(r) for r in zip(batch["query"], batch["response"], rewards.cpu().tolist())
    ]
    # logs.update(
    #     {
    #         "game_log": wandb.Table(
    #             columns=["query", "response", "reward"], rows=table_rows
    #         )
    #     }
    # )
    logs.update(timing)
    logs.update(stats)
    logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy()
    print(torch.mean(rewards).cpu().numpy())
    logs["env/reward_std"] = torch.std(rewards).cpu().numpy()
    logs["env/reward_dist"] = rewards.cpu().numpy()
    # wandb.log(logs)

0it [00:00, ?it/s]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set