In [1]:
import os
import torch
import json
import gc
import sys

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Model and hardware configuration.
model_id = 'mistralai/Mistral-7B-Instruct-v0.2'
device = 'cuda'  # target device for the prompt tensors

# 4-bit NF4 quantization with nested (double) quantization of the
# quantization constants; computation runs in bfloat16.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Tokenizer matching the chosen checkpoint (also provides the chat template).
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [2]:
# Load the causal-LM weights for model_id, quantized on load according to nf4_config.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=nf4_config,
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
# Few-shot examples. Despite the .jsonl extension, the file is parsed as ONE
# JSON document whose top-level "shots" list holds objects with
# "domain", "command" and "solution" keys.
# JSON text is UTF-8; be explicit so the platform default encoding is never used.
with open('shots.jsonl', 'r', encoding='utf-8') as infile:
    data = json.load(infile)

shots = data['shots']
# The prompt cell builds its first turn from element 0, so fail loudly here
# rather than with an IndexError further down.
assert shots, 'shots.jsonl contains no few-shot examples'

# Split the records into parallel lists (same order as in the file).
domains = [shot['domain'] for shot in shots]
commands = [shot['command'] for shot in shots]
solutions = [shot['solution'] for shot in shots]

In [8]:
# Facts that always hold in the warehouse, serialized as '|'-terminated
# "instance <name> <type>" / "predicate <name> <args>" entries.
background_facts = [
    'instance tars robot',
    'instance charging_station zone',
    'instance unload_zone zone',
    'instance shelf_1 zone',
    'instance shelf_2 zone',
    'instance shelf_3 zone',
    'instance shelf_4 zone',
    'predicate robot_available tars',
    'predicate is_recharge_zone charging_station',
    'predicate is_unload_zone unload_zone',
    'predicate is_shelf_zone shelf_1',
    'predicate is_shelf_zone shelf_2',
    'predicate is_shelf_zone shelf_3',
    'predicate is_shelf_zone shelf_4',
]
generated_knowledge = ''.join(f'{fact}|' for fact in background_facts)

system_prompt = f'You are the robot tars, an automatic forklift that will move pallets around a warehouse. Your task is to outline the available instances, predicates, and goals based on the provided domain and command. Answer in the format shown after ### Output ###. This is the domain: {domains[0]}.'

# NOTE(review): Mistral's instruct chat template is believed to have no system
# role and to require strictly alternating user/assistant turns, which is why the
# instructions and the first exemplar ride along in the first user message.
# The exemplar uses the same "Command: ..." / "### Output ### ..." shape as the
# remaining shots (no trailing '.'), so the model sees one consistent format.
first_turn = (
    f'{system_prompt} '
    f'At all times, these instances and predicates are true: {generated_knowledge}. '
    'You do not have to repeat them in your output. '
    f'Command: {commands[0]} '
    f'### Output ### {solutions[0]}'
)

messages = [
    {'role': 'user', 'content': first_turn},
    {'role': 'assistant', 'content': 'Understood. Awaiting new domain and command.'},
]

# Remaining shots become alternating user (command) / assistant (solution) turns.
for shot_command, shot_solution in zip(commands[1:], solutions[1:]):
    messages.append({'role': 'user', 'content': f'Command: {shot_command}'})
    messages.append({'role': 'assistant', 'content': f'### Output ### {shot_solution}'})

# The new command the model should answer.
command = 'Move the three new pallets from the unload zone to shelf 4.'
messages.append({'role': 'user', 'content': f'Command: {command}'})

In [39]:
# do_sample=True makes generation stochastic; seed torch's RNG so a re-run
# samples the same tokens (up to any nondeterminism in the GPU kernels).
SEED = 0
MAX_NEW_TOKENS = 1000  # upper bound on the length of the generated answer
torch.manual_seed(SEED)

# Render the conversation with the tokenizer's chat template into token ids.
encodeds = tokenizer.apply_chat_template(messages, return_tensors='pt')
inputs = encodeds.to(device)

# Pad with EOS: no separate pad token is configured in this notebook.
generated_ids = model.generate(
    inputs,
    pad_token_id=tokenizer.eos_token_id,
    max_new_tokens=MAX_NEW_TOKENS,
    do_sample=True,
)

# Special tokens are deliberately kept: the parsing step searches for '</s>'.
decoded = tokenizer.batch_decode(generated_ids)

In [40]:
start_token = '[/INST]'        # closes each user turn in the rendered prompt
end_token = '</s>'             # end-of-sequence marker left in by batch_decode
delimiter = '|'                # terminates every fact in the expected answer
output_token = '### Output ###'  # tag the answer format is introduced with


def extract_output(text: str,
                   start: str = start_token,
                   end: str = end_token,
                   sep: str = delimiter,
                   tag: str = output_token) -> str:
    """Extract the model's answer from a decoded prompt+generation transcript.

    Each step is skipped when its marker is absent, so no characters are
    chopped off by a ``-1`` "not found" index, and applying the function to
    its own output returns it unchanged (the cell is safe to re-run).

    Steps:
      1. keep the text after the LAST ``start`` marker (the generated turn);
      2. cut at the LAST ``end`` marker (absent if generation was truncated);
      3. truncate just after the LAST ``sep`` to drop trailing chatter;
      4. drop everything up to the FIRST ``tag`` plus following whitespace.

    Args:
        text: a decoded sequence, e.g. one element of ``tokenizer.batch_decode``.
        start, end, sep, tag: markers described above.

    Returns:
        The extracted answer string.
    """
    idx = text.rfind(start)
    if idx != -1:
        text = text[idx + len(start):]

    idx = text.rfind(end)
    if idx != -1:
        text = text[:idx]

    idx = text.rfind(sep)
    if idx != -1:
        text = text[:idx + len(sep)]

    idx = text.find(tag)
    if idx != -1:
        # Strip whatever whitespace follows the tag instead of assuming one char.
        text = text[idx + len(tag):].lstrip()

    return text


decoded[0] = extract_output(decoded[0])
decoded[0]