# Run the tuned models

## Set up Python environment
The following libraries are used for this method (`requirements.txt` file):

```
torch
accelerate @ git+https://github.com/huggingface/accelerate.git
bitsandbytes
datasets==2.13.1
transformers @ git+https://github.com/huggingface/transformers.git
peft @ git+https://github.com/huggingface/peft.git
trl @ git+https://github.com/lvwerra/trl.git
scipy
```

Then install the requirements and import the libraries:

In [1]:
import bitsandbytes as bnb
from datasets import load_dataset
from functools import partial
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, AutoPeftModelForCausalLM
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, BitsAndBytesConfig, DataCollatorForLanguageModeling, Trainer, TrainingArguments
import networkx as nx

## Load the tuned model into memory

The fine-tuned models require the base model as well as the adaptors.

In [2]:
import os

# Root folder of the project on the local machine
base_dir = '/home/manish/thesis-implementations/quest_generation/llama2/'
# Sub-folder of base_dir that stores the models
model_dir = 'models'
# Identifier of the base model; also used as a relative sub-path below model_dir
model_name = 'meta-llama/llama-2-13b-hf'
# Location of the base model: <base_dir>/<model_dir>/<model_name>
model_path = os.path.join(os.path.join(base_dir, model_dir), model_name)

## Model type and KG depth

To try varying KG depths and model types, we can change the following variables:

In [None]:
# Prompt variant to run; keep exactly one assignment active.
#   'no_kg'   - the prompt contains the plots and the intro blurb only
#   'text_kg' - the prompt is prefixed with a textual "### Background:" section
#   'tree_kg' - built like 'no_kg' by create_generation_prompt_formats
# TRAIN_TYPE also selects the results sub-folder the adapter is loaded from.
# TRAIN_TYPE = 'no_kg'
TRAIN_TYPE = 'text_kg'
# TRAIN_TYPE = 'tree_kg'
# Amount of knowledge-graph information added to the background: it bounds the
# number of sampled relations per entity and is the depth limit of the KG search.
KG_DEPTH = 2

In [None]:
# Folder of the fine-tuned adapter for the selected TRAIN_TYPE
# (passed to load_pretrained_model as the model to load)
output_dir = os.path.join(base_dir, model_dir, 'results', model_name, f'{TRAIN_TYPE}')
# Show the path and whether it exists on disk
print(output_dir, os.path.exists(output_dir))

/home/manish/thesis-implementations/quest_generation/llama2/models/results/meta-llama/llama-2-13b-hf/text_kg True


In [None]:
# Show the base-model path and whether it exists on disk
print(model_path, os.path.exists(model_path))

/home/manish/thesis-implementations/quest_generation/llama2/models/meta-llama/llama-2-13b-hf True


## Create a bitsandbytes configuration and load the model and tokenizer
This will allow us to load our LLM in 4 bits. This way, we can divide the used memory by 4 and import the model on smaller devices. We choose to apply bfloat16 compute data type and nested quantization for memory-saving purposes.

In [None]:
def create_bnb_config():
    """
    Build the quantization settings used when loading the model
    :return: BitsAndBytesConfig with 4-bit loading, double (nested)
        quantization, the "nf4" quantization type and bfloat16 compute dtype
    """
    return BitsAndBytesConfig(
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        load_in_4bit=True,
    )

In [3]:
def load_tokenizer(model_name):
    """
    Load the tokenizer stored under model_name
    :param model_name: name or path handed to AutoTokenizer.from_pretrained
    :return: the tokenizer, with its padding token set to its EOS token
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    # Needed for LLaMA tokenizer: reuse the EOS token for padding
    tok.pad_token = tok.eos_token
    return tok

In [None]:
def load_pretrained_model(model_path, bnb_config, base_model_path):
    """
    Load the fine-tuned PEFT model and the tokenizer of its base model
    :param model_path: path handed to AutoPeftModelForCausalLM.from_pretrained
    :param bnb_config: quantization config used while loading the model
    :param base_model_path: path the tokenizer is loaded from
    :return: tuple (model, tokenizer)
    """
    # The same memory budget is assigned to every visible GPU
    per_gpu_budget = f'{12288}MB'
    gpu_budgets = {}
    for gpu_index in range(torch.cuda.device_count()):
        gpu_budgets[gpu_index] = per_gpu_budget

    peft_model = AutoPeftModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map="auto",
        max_memory=gpu_budgets,
    )
    base_tokenizer = load_tokenizer(base_model_path)

    return peft_model, base_tokenizer

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Build the 4-bit quantization config, then load the adapter model from
# output_dir and the tokenizer from the base-model path
bnb_config = create_bnb_config()
model, tokenizer = load_pretrained_model(output_dir, bnb_config, model_path)

## Quest Dataset

Load the quest dataset and create the prompts accordingly

In [25]:
import json
import numpy as np
import os
import random

# Folder that contains the JSON-lines files of the quest dataset
dataset_path = os.path.join(base_dir, 'data')
train_file = 'train.jsonl'
val_file = 'val.jsonl'
# Split name -> file name; passed to load_dataset together with dataset_path
data_files = dict(train=train_file, val=val_file)

In [11]:
# Special tokens of the loaded tokenizer
# (load_tokenizer sets pad_token to eos_token, so PAD_TOKEN equals EOS_TOKEN)
EOS_TOKEN = tokenizer.eos_token
PAD_TOKEN = tokenizer.pad_token
BOS_TOKEN = tokenizer.bos_token

## Load the full game KG

In [12]:
# Game id -> file name of that game's knowledge graph (.gml)
map_game = dict(
    TESO='TESOblivion_KG.gml',
    TESS='TESSkyrim_KG.gml',
    TL2='Torchlight2_KG.gml',
    MC='Minecraft_KG.gml',
    BG1='BaldursGate1_KG.gml',
    BG2='BaldursGate2_KG.gml',
)

# Game id -> loaded knowledge graph; starts empty and is filled when the
# graph files are read
kg_map = dict()

In [13]:
# Folder that holds one .gml knowledge-graph file per game
kg_base_dir = '/home/manish/thesis-implementations/data/VartinenFormatted/KGs'

# Read every game's graph once and store it in kg_map under its game id
for gid, gname in map_game.items():
    kg_path = os.path.join(kg_base_dir, gname)
    kg = nx.read_gml(kg_path)
    kg_map[gid] = kg

In [31]:
def create_generation_prompt_formats(input):
    """
    Build the generation prompt and the reference output for one quest sample.

    The prompt (stored in ``input['text']``) consists of an optional
    "### Background:" section, the "### Plots:" section and the intro blurb,
    concatenated using two newline characters. The reference (stored in
    ``input['output']``) is the "### Quest:" section followed by "### End".

    Reads the module-level settings ``TRAIN_TYPE`` and ``KG_DEPTH``; when
    ``TRAIN_TYPE == 'text_kg'`` and ``KG_DEPTH > 0`` it also reads the game's
    knowledge graph from ``kg_map``. Relations are sampled with ``random``,
    so the background is not deterministic unless the seed is fixed.

    :param input: input dictionary with the keys 'plots', 'kbs', 'quest' and
        (for the knowledge-graph lookup) 'game'
    :return: the same dictionary with the keys 'text' and 'output' added
    """

    BACKGROUND = "### Background:"
    PLOTS_KEY = "### Plots:"
    INTRO_BLURB = "The quest related to the above information is as follows:"
    QUEST = "### Quest:"
    END_KEY = "### End"

    blurb = f"{INTRO_BLURB}"  # add intro blurb - model system instruction

    # add background - only if knowledge graph as text
    # create sentence tree from plots using kg - only if knowledge graph as tree
    background = ''
    plots_str = '\n'.join(input['plots'])
    plots = f"{PLOTS_KEY}\n{plots_str}"  # add plots - key plot points

    if TRAIN_TYPE == 'text_kg':
        # (entity, entity) pairs and entity names that were already verbalized,
        # used to avoid repeating sentences in the background
        completed_rels = []
        completed_nodes = []

        for kb in input['kbs']:
            entity = kb['name']
            e_desc = kb['description']
            e_type = kb['type']
            e_relations = kb['relations']

            background += f'{entity} is a {e_type}. '

            # add depth of information from KG
            # KG_DEPTH <= 0 will add all information
            # KG_DEPTH = 1 will only add description
            # KG_DEPTH > 1 will randomly add (KG_DEPTH - 1) number
            # of relations and the description
            if KG_DEPTH > 0:
                if entity != e_desc:
                    background += f'{entity} is {e_desc}. '
                # never ask random.sample for more relations than exist
                depth = min(KG_DEPTH - 1, len(e_relations))
                relations = random.sample(e_relations, depth)
            else:
                relations = e_relations

            for rel in relations:
                background += f'{entity} is {rel[0]} {rel[1]}. '
                completed_rels.append((entity, rel[1]))
            completed_nodes.append(entity)
            background += '\n'

        if KG_DEPTH > 0:
            kg = kg_map[input['game']]
            all_nodes = kg.nodes(data=True)
            # lower-cased once; it does not change while the nodes are scanned
            plots_lower = plots.lower()
            for node in all_nodes:
                entity = node[0]
                # only expand entities that are mentioned in the plots
                if entity.lower() in plots_lower:
                    edges = list(nx.dfs_edges(kg, source=entity, depth_limit=KG_DEPTH))
                    for ent1, ent2 in edges:
                        if (ent1, ent2) in completed_rels or (ent2, ent1) in completed_rels:
                            continue
                        e1_type = all_nodes[ent1]['type']
                        e1_desc = all_nodes[ent1]['description']
                        e2_type = all_nodes[ent2]['type']
                        e2_desc = all_nodes[ent2]['description']

                        if ent1 not in completed_nodes:
                            background += f'{ent1} is a {e1_type}. '
                            if e1_desc != ent1:
                                background += f'{ent1} is {e1_desc}. '
                            completed_nodes.append(ent1)
                            background += '\n'

                        if ent2 not in completed_nodes:
                            background += f'{ent2} is a {e2_type}. '
                            if e2_desc != ent2:
                                background += f'{ent2} is {e2_desc}. '
                            completed_nodes.append(ent2)
                            background += '\n'

                        rel = kg[ent1][ent2]['label']
                        if rel == 'connected to':
                            background += f'{ent1} is {rel} {ent2}. '
                        if rel == 'present in':
                            if e1_type == 'location':
                                # ent1 is the location, so ent2 is what is
                                # present in it (previously ent2 was written
                                # on both sides of the relation)
                                background += f'{ent2} is {rel} {ent1}. '
                            else:
                                background += f'{ent1} is {rel} {ent2}. '
                        if rel == 'held by':
                            if e1_type == 'character':
                                background += f'{ent2} is {rel} {ent1}. '
                            else:
                                background += f'{ent1} is {rel} {ent2}. '
                        background += '\n'
                        completed_rels.append((ent1, ent2))

        background = f"{BACKGROUND}\n{background}"

    # add concatenated quest text
    quest_str = ''
    for k, v in input['quest'].items():
        if k == 'description':
            continue
        if k == 'tasks':
            # NOTE(review): v[:-1] leaves out the last task - presumably
            # intentional, confirm against the training-time prompt builder
            value = '\n ' + '\n '.join(np.char.capitalize(v[:-1]))
        else:
            value = v.capitalize()
        quest_str += f'{k.capitalize()}: {value}\n'
    quest = f"{QUEST}\n{quest_str}"  # add quest output

    end = f"{END_KEY}"  # add end key

    # the background is only part of the prompt for the text-KG variant
    if TRAIN_TYPE in ['no_kg', 'tree_kg']:
        parts_p = [part for part in [plots, blurb] if part]
    else:
        parts_p = [part for part in [background, plots, blurb] if part]

    parts_o = [part for part in [quest, end] if part]

    formatted_prompt = "\n\n".join(parts_p)
    formatted_output = "\n\n".join(parts_o)
    input['text'] = formatted_prompt
    input['output'] = formatted_output

    return input

In [18]:
# SOURCE https://github.com/databrickslabs/dolly/blob/master/training/trainer.py
def get_max_length(model):
    """
    Look up the maximum sequence length in the model configuration
    :param model: object whose ``config`` may define "n_positions",
        "max_position_embeddings" or "seq_length"
    :return: the first truthy value among those attributes (checked in that
        order), or 1024 when none of them is set
    """
    candidate_attrs = ("n_positions", "max_position_embeddings", "seq_length")
    for attr_name in candidate_attrs:
        max_length = getattr(model.config, attr_name, None)
        if max_length:
            print(f"Found max lenth: {max_length}")
            return max_length

    # no usable setting in the config: fall back to the default
    max_length = 1024
    print(f"Using default max length: {max_length}")
    return max_length


def preprocess_batch(batch, tokenizer, max_length):
    """
    Tokenize the "text" entry of a batch, truncating to max_length
    :param batch: mapping with a "text" entry
    :param tokenizer: callable that accepts the texts plus the keyword
        arguments ``max_length`` and ``truncation``
    :param max_length: maximum length passed on to the tokenizer
    :return: whatever the tokenizer returns
    """
    texts = batch["text"]
    encoded = tokenizer(texts, max_length=max_length, truncation=True)
    return encoded

In [15]:
# SOURCE https://github.com/databrickslabs/dolly/blob/master/training/trainer.py
def preprocess_val_dataset(tokenizer: AutoTokenizer, max_length: int, dataset: str, include_kg: bool = True):
    """Add the prompt to every sample, tokenize it and drop over-long samples
    :param tokenizer (AutoTokenizer): Model Tokenizer
    :param max_length (int): Maximum number of tokens to emit from tokenizer
    :param dataset: dataset to process; annotated as ``str`` but used as an
        object that provides ``map`` and ``filter``
    :param include_kg (bool): not read by this function
    :return: the processed dataset with the tokenizer output added and the
        columns "id", "game", "kbs", "plots" and "quest" removed
    """

    print("Preprocessing dataset...")

    # Step 1: add the 'text' prompt and the 'output' reference to each sample
    with_prompts = dataset.map(create_generation_prompt_formats)

    # Step 2: tokenize in batches and drop the raw input columns
    tokenize_batch = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)
    raw_columns = ["id", "game", "kbs", "plots", "quest"]
    tokenized = with_prompts.map(
        tokenize_batch,
        batched=True,
        remove_columns=raw_columns,
    )

    # Step 3: keep only the samples that are strictly shorter than max_length
    def _is_short_enough(sample):
        return len(sample["input_ids"]) < max_length

    return tokenized.filter(_is_short_enough)

In [26]:
# Load the 'train' and 'val' splits declared in data_files from dataset_path
dataset = load_dataset(dataset_path, data_files=data_files)
train_dataset = dataset['train']
val_dataset = dataset['val']

Found cached dataset json (/home/manish/.cache/huggingface/datasets/json/data-6359e290ba54d2fa/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/2 [00:00<?, ?it/s]

In [19]:
## Preprocess dataset
# The maximum length is taken from the model config; the validation split is
# then turned into prompts, tokenized and filtered by that length
max_length = get_max_length(model)
val_dataset = preprocess_val_dataset(tokenizer, max_length, val_dataset)

Found max lenth: 4096
Preprocessing dataset...


Map:   0%|          | 0/77 [00:00<?, ? examples/s]

Map:   0%|          | 0/77 [00:00<?, ? examples/s]

Filter:   0%|          | 0/77 [00:00<?, ? examples/s]

## Inference

Get the outputs from the processed inputs

In [23]:
# Run a full garbage collection before generation starts
# (the cell output is the value returned by gc.collect())
import gc
gc.collect()

5109

In [None]:
from tqdm.auto import tqdm

# One dict per validation sample: prompt, generated text and reference text
results = []

# The device and the stop-token ids are the same for every sample, so they are
# computed once instead of on every iteration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Token ids produced by tokenizing '### End'; they are passed as end-of-sequence
# ids so that generation stops at the end marker
stop_token_ids = tokenizer('### End')['input_ids']

for item in tqdm(val_dataset):
    inp_txt_len = len(item['text'])
    print(inp_txt_len)

    # Specify input: a batch of one sample. The attention mask is moved to the
    # same device as the input ids (it used to stay on the CPU)
    inp = torch.tensor([item['input_ids']]).to(device)
    attn_mask = torch.tensor([item['attention_mask']]).to(device)

    # Get answer
    # (Adjust max_new_tokens variable as you wish (maximum number of tokens the model can generate to answer the input))
    outputs = model.generate(
        input_ids=inp,
        attention_mask=attn_mask,
        max_new_tokens=250,
        eos_token_id=stop_token_ids,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode output & append to output list.
    # The prompt is removed by character count: the first inp_txt_len + 2
    # characters of the decoded text are skipped (presumably the prompt plus
    # the separator in front of the generated quest - confirm)
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)[inp_txt_len + 2:]
    result = {'input': item['text'], 'output_gen': output_text, 'output_actual': item['output']}
    results.append(result)


In [54]:
# Show the settings the results above were generated with
print(TRAIN_TYPE, f'KG Depth = {KG_DEPTH}')

text_kg


In [68]:
# Results file for this run, named after the prompt variant and the KG depth
out_file = os.path.join(base_dir, 'outputs', f'{TRAIN_TYPE}_{KG_DEPTH}_results.jsonl')
print('Saving results  @ ', out_file)

/home/manish/thesis-implementations/quest_generation/llama2/outputs/text_kg_results.jsonl


In [50]:
# Create the target folder when it is missing, so the results of a long
# inference run are not lost to a FileNotFoundError at save time
os.makedirs(os.path.dirname(out_file), exist_ok=True)

# Write the results as JSON Lines: one JSON object per line
with open(out_file, 'w') as outfile:
    for result in results:
        json.dump(result, outfile)
        outfile.write('\n')

## Random examples

In [57]:
def create_input_prompt(background_info, plots):
    """
    Assemble a hand-written prompt from background sentences and plot points
    :param background_info: iterable of background lines; when it is empty or
        None the "### Background:" section is left out
    :param plots: iterable of plot lines
    :return: prompt text that ends with the "### Quest:" header
    """
    sections = []
    if background_info:
        sections.append('### Background:\n\n' + '\n'.join(background_info) + '\n\n')
    sections.append('### Plots:\n\n' + "\n".join(plots) + '\n')
    sections.append('\n### Quest:\n')
    return ''.join(sections)

In [76]:
# Hand-written background: one line per entity
bg_info = [
    # '',
    'Delilah is a character. Delilah is present in Brigmore Manor. Delilah is the leader of the Brigmore Witches.',
    'Brigmore Manor is a location. Brigmore Manor is connected to Empire of the Isles. Brigmore Manor is a lair for the city\'s outlaws.',
    'Empire of the Isles is a location. Empire of the Isles is connected to Brigmore Manor. Empire of the Isles is a kingdom.',
]
# Hand-written plot points; the commented-out line is an alternative second plot
plots = [
    'Delilah and her coven are planning something that threatens everyone across the Empire of the Isles.',
    'Delilah must be stopped'
    # 'Infiltrate the ruins of Brigmore Manor and stop Delilah'
]

# Build the prompt text for the example and display it
str_inp = create_input_prompt(bg_info, plots)
print(str_inp)

### Background:

Delilah is a character. Delilah is present in Brigmore Manor. Delilah is the leader of the Brigmore Witches.
Brigmore Manor is a location. Brigmore Manor is connected to Empire of the Isles. Brigmore Manor is a lair for the city's outlaws.
Empire of the Isles is a location. Empire of the Isles is connected to Brigmore Manor. Empire of the Isles is a kingdom.

### Plots:

Delilah and her coven are planning something that threatens everyone across the Empire of the Isles.
Delilah must be stopped

### Quest:



In [37]:
# Tokenize the example prompt; 'input_ids' and 'attention_mask' are read from the result below
token_inp = tokenizer(str_inp)

In [None]:
# Specify device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Batch of one example. The attention mask is moved to the same device as the
# input ids (it used to stay on the CPU)
inp = torch.tensor([token_inp['input_ids']]).to(device)
attn_mask = torch.tensor([token_inp['attention_mask']]).to(device)

# Get answer
# (Adjust max_new_tokens variable as you wish (maximum number of tokens the model can generate to answer the input))
outputs = model.generate(
    input_ids=inp,
    attention_mask=attn_mask,
    max_new_tokens=250,
    # token ids of the '### End' marker are used as end-of-sequence ids
    eos_token_id=tokenizer('### End')['input_ids'],
    pad_token_id=tokenizer.eos_token_id,
)

In [39]:
# Decode the first generated sequence (prompt included) without special tokens
print(tokenizer.decode(outputs[0], skip_special_tokens=True))


### Background:
Gilgondorin is a character. Gilgondorin is a keeper of local legends. Gilgondorin is present in Niben Bay. 
Niben Bay is a location. Niben Bay is a sea bay. Niben Bay is connected to Bawnwatch Camp. 
Bawnwatch Camp is a location. Bawnwatch Camp is a camp on the shore of Niben Bay. 
Watchman is a character. Watchman is a harmless, sad ghost of a sailor. Watchman is present in Niben Bay. 


### Plots:
the player asked Gilgondorin about the Watchman
Gilgondorin pinpoints where the Watchman appears on the player's map

The quest related to the above information is as follows:

### Quest:
Title: A watchful eye
Objective: Talk to the watchman
Tasks: 
 Travel to bawnwatch camp.
 Find the watchman


### End


In [67]:
# Decode the first generated sequence (prompt included) without special tokens
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

### Background:



### Plots:

Delilah and her coven are planning something that threatens everyone across the Empire of the Isles.
Infiltrate the ruins of Brigmore Manor and stop Delilah

### Quest:

Title: The brigmore conspiracy
Objective: Infiltrate the ruins of brigmore manor and stop delilah
Tasks: 
 Find your way to the ruins of brigmore manor


### End


In [None]:
# Decode the first generated sequence (prompt included) without special tokens
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In [None]:
# Plain-text dump of the results, stored under '<TRAIN_TYPE>_<KG_DEPTH>'
results_txt = os.path.join(base_dir, model_dir, 'results', model_name, f'{TRAIN_TYPE}_{KG_DEPTH}/results.txt')
# That folder is not the adapter folder (which is named after TRAIN_TYPE only),
# so create it when it does not exist yet instead of failing on open()
os.makedirs(os.path.dirname(results_txt), exist_ok=True)

with open(results_txt, 'w') as outfile:
    for result in results:
        # one result dict per entry, followed by a separator line of 40 dashes
        outfile.write(f'{result}\n{"-"*40}\n')