In [1]:
import argparse
import os

import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import IterableDataset, Dataset
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    logging,
    set_seed,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class CodeDataset(Dataset):
    def __init__(self, data_dir, tokenizer, prompt):
        self.data_dir = data_dir
        self.example_folders = os.listdir(self.data_dir)

        self.data = []

        # Number of prompts for in each folder
        prompt_lengths = []
        self.total_length = 0

        for d in self.example_folders:
            prompts_dir = os.path.join(self.data_dir, d, "prompts.txt")

            with open(prompts_dir, 'r') as f:
                prompts = f.read()

            prompt_length = len(prompts.split("\n"))
            prompt_lengths.append(prompt_length)

            self.total_length += prompt_length

        self.prompt_indexes = [0]
        for i in prompt_lengths[:-1]:
            self.prompt_indexes.append(self.prompt_indexes[-1] + i)

        self.DSL = prompt
        self.tokenizer = tokenizer

    def __len__(self):
        return self.total_length

    def __getitem__(self, idx):
        folder_idx = 0
        for idx, val in enumerate(self.prompt_indexes):
            if idx >= val:
                folder_idx = idx

        prompt_idx = idx - self.prompt_indexes[folder_idx]

        code_dir = os.path.join(
            self.data_dir, self.example_folders[folder_idx], "code.txt")
        prompts_dir = os.path.join(
            self.data_dir, self.example_folders[folder_idx], "prompts.txt")

        with open(code_dir, 'r') as f:
            code = f.read()

        with open(prompts_dir, 'r') as f:
            prompts = f.read().split("\n")

        code = "\n\t".join(code.split("\n"))

        final_prompt = self.DSL + prompts[prompt_idx] + "\n\t" + code
        input_ids = self.tokenizer(
            final_prompt, padding="max_length", max_length=256)["input_ids"]
        return {"input_ids": input_ids, "labels": input_ids}

In [3]:
device = "cuda"

In [4]:
loaded_model = AutoModelForCausalLM.from_pretrained(
        "./checkpoints/checkpoint-300/",
        trust_remote_code=True
    ).to(device)

Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.


In [5]:
tokenizer = AutoTokenizer.from_pretrained(
        "bigcode/santacoder", use_auth_token=True)

tokenizer.pad_token = tokenizer.eos_token_id
tokenizer.padding_side = "left"

In [6]:
with open("DSL.txt", 'r') as f:
    prompt = f.read()

valid_dataset = CodeDataset("./test", tokenizer, prompt)

In [11]:
with open("example_input.txt", 'r') as f:
    example_input = f.read()

print(example_input)

def go_to(location : str)
def find(object : str)
def pick_up(object : str)
def put_down(object : str)
def find(object : str)
def ask(person : str, question : str, options: Optional[List[str]])
def say(message : str)
'''
def main():
    # Using the functions defined above, write a script to do the following: Go to the kitchen. Then go to the living room. Repeat this 4 times.


In [12]:
inputs = tokenizer.encode(example_input, return_tensors="pt", padding="max_length", max_length=256).to(device)

In [13]:
outputs = loaded_model.generate(inputs, max_new_tokens=128)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [14]:
print(tokenizer.decode(outputs[0]))

<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|