# Parallel transformers

## Environment

### Imports

In [ ]:
import os
import sys

# sys.path entries are used verbatim: a leading '~' is never expanded when
# modules are looked up, so resolve the home directory explicitly before
# registering the local source tree.
sys.path.append(os.path.expanduser('~/Projects/transformer_wrappers/src'))

In [ ]:
import torch

In [ ]:
from transformer_wrappers.wrappers import ParallelCausalLMWrapper

### Constants and globals

In [ ]:
# Access token forwarded under the 'token' key of both MODEL_CONFIGS and TOKENIZER_CONFIGS.
TOKEN = None  # TODO: HF token

In [ ]:
# Base model identifier. Options listed by the author:
#   google/gemma-7b | mistralai/Mistral-7B-v0.3 | meta-llama/Meta-Llama-3-8B
MODEL = 'meta-llama/Meta-Llama-3-8B'
# TODO: fine-tuned model path
FINE_TUNED_MODEL = './resources/models/fine_tuned/...'
# Keyword arguments handed to the wrapper as `model_kwargs` when loading a model.
MODEL_CONFIGS = dict(
    token=TOKEN,
    # Use the GPU when one is visible to torch, otherwise fall back to the CPU.
    device_map='cuda' if torch.cuda.is_available() else 'cpu',
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [ ]:
# The tokenizer is loaded from the base model identifier.
# Make sure to match it with the base model when using fine-tuned models.
TOKENIZER = MODEL
# Keyword arguments handed to the wrapper as `tokenizer_kwargs`.
# Pad token per model, as listed by the author:
#   mistralai/Mistral-7B-v0.3 -> '</s>' | meta-llama/Meta-Llama-3-8B -> '<|end_of_text|>'
TOKENIZER_CONFIGS = dict(
    token=TOKEN,
    pad_token='<|end_of_text|>',
)

In [ ]:
import itertools

# Maps each parallelisation option (used as a keyword argument of
# `model.generate` further down) to the human-readable label printed
# next to its value.
PARALLELISATION_CONFIG_KEYS_MAPPING = {
    'p_rate': 'Parallelisation rate',
    'block_parallel': 'Layer level parallelisation',
    'iterative': 'Iterative',
    'scaling': 'Normalisation'
}
# Every combination of option values, one dict per combination.
# The keys are taken from the mapping above (dicts iterate in insertion order),
# so each config is guaranteed to have a printable label; the value lists are
# given in the same order as the mapping keys.
PARALLELISATION_CONFIGS = [
    dict(zip(PARALLELISATION_CONFIG_KEYS_MAPPING, values))
    for values in itertools.product([2, 4], [True, False], [True, False], [True, False])
]

In [ ]:
# Text that is tokenised and used as the generation input in the cells below.
PROMPT = 'Q: What is the capital of Italy? A:'

## Parallelise pre-trained model

Load pre-trained model into parallel wrapper

In [ ]:
# Wrap the base model; model and tokenizer options are forwarded separately.
loader_kwargs = {
    'model_kwargs': MODEL_CONFIGS,
    'tokenizer_kwargs': TOKENIZER_CONFIGS,
}
model = ParallelCausalLMWrapper.from_pretrained(MODEL, **loader_kwargs)

Iterate through the parallelisation configs; for each one, generate from the given prompt and print the generated output.

In [ ]:
print('#' * 32)
for configs in PARALLELISATION_CONFIGS:
    # Header: one "<label>: <value>" line per parallelisation option.
    print('\n'.join(f'{PARALLELISATION_CONFIG_KEYS_MAPPING[k]}: {v}' for k, v in configs.items()))
    print('-' * 32)
    # Move the encoded prompt to the device the model was asked to load on
    # (the 'device_map' entry of MODEL_CONFIGS).
    input_encodings = model.tokenizer(PROMPT, return_tensors='pt').to(MODEL_CONFIGS['device_map'])
    # The parallelisation options are passed straight through as generation keyword arguments.
    output = model.generate(input_encodings.input_ids, **configs, do_sample=False, max_length=32)
    # NOTE(review): assumes the wrapper's generate() returns a mapping with an
    # 'input_ids' entry holding one sequence per prompt - confirm against the wrapper.
    text = model.tokenizer.decode(output['input_ids'][0])
    print(repr(text))
    print('#' * 32)
    

## Fine-tuned parallelised model

Load fine-tuned model into parallel wrapper

In [ ]:
# Wrap the fine-tuned checkpoint (peft=True); the tokenizer is loaded from
# TOKENIZER rather than from the fine-tuned model path.
fine_tuned_loader_kwargs = {
    'model_kwargs': MODEL_CONFIGS,
    'peft': True,
    'tokenizer_name_or_path': TOKENIZER,
    'tokenizer_kwargs': TOKENIZER_CONFIGS,
}
model = ParallelCausalLMWrapper.from_pretrained(FINE_TUNED_MODEL, **fine_tuned_loader_kwargs)

Generate from the given prompt and print the generated output.

In [ ]:
# Move the encoded prompt to the device the model was asked to load on
# (the 'device_map' entry of MODEL_CONFIGS).
input_encodings = model.tokenizer(PROMPT, return_tensors='pt').to(MODEL_CONFIGS['device_map'])
# No parallelisation options are passed here, so the wrapper's own settings apply.
output = model.generate(input_encodings.input_ids, do_sample=False, max_length=32)
# NOTE(review): assumes the wrapper's generate() returns a mapping with an
# 'input_ids' entry holding one sequence per prompt - confirm against the wrapper.
text = model.tokenizer.decode(output['input_ids'][0])
print(repr(text))