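# Preliminary Investigations

Sanity checks that the `transformer_wrappers` wrappers (`TransformerWrapper`, `CausalLMWrapper`) reproduce the outputs of the corresponding Hugging Face models exactly. The checks cover:

- the final and per-layer hidden states;
- the attention/feed-forward decomposition of every layer;
- the key/value cache;
- the attention weights;
- greedy generation.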

## Environment preparation

### Imports

In [1]:
import os
import sys

# Make the local `transformer_wrappers` sources importable without installing the package.
# Set TRANSFORMER_WRAPPERS_SRC to point elsewhere (e.g. '/app/Projects/transformer_wrappers/src')
# instead of editing this cell.
TRANSFORMER_WRAPPERS_SRC = os.environ.get(
    'TRANSFORMER_WRAPPERS_SRC', '/home/vincenzoscotti/Projects/transformer_wrappers/src'
)
# Guard keeps the cell idempotent: re-running it does not append duplicate entries.
if TRANSFORMER_WRAPPERS_SRC not in sys.path:
    sys.path.append(TRANSFORMER_WRAPPERS_SRC)

In [2]:
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformer_wrappers.wrappers import TransformerWrapper, CausalLMWrapper

In [4]:
import pickle

### Constants

In [5]:
# Plotting constants: matplotlib-style colour-cycle specs 'C0'..'C9' and line styles.
COLOURS = list(map(lambda idx: f'C{idx}', range(10)))
STYLES = [
    'solid',
    'dotted',
    'dashed',
    'dashdot',
]

In [6]:
# Run on CUDA when it is available, otherwise fall back to the CPU; display the choice.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DEVICE

device(type='cuda')

In [7]:
import os

# Hugging Face access token (required for gated repositories).
# Read it from the HF_TOKEN environment variable rather than hard-coding a secret in the notebook.
# When the variable is unset the value is None, which lets `huggingface_hub` fall back to the
# token cached by `huggingface-cli login`.
TOKEN = os.environ.get('HF_TOKEN')

In [8]:
# Checkpoint under test. Alternatives (swap in as needed):
#   'gpt2-xl', 'mistralai/Mistral-7B-Instruct-v0.2', 'meta-llama/Llama-2-7b-hf'
MODEL = 'google/gemma-7b'
# Keyword arguments passed to `AutoModel*.from_pretrained` directly and to the wrappers as `model_kwargs`.
MODEL_CONFIGS = dict(
    torch_dtype=torch.bfloat16,
    # NOTE(review): presumably eager attention is forced so that `output_attentions=True`
    # returns attention weights — confirm.
    attn_implementation='eager',
    device_map=DEVICE,
    token=TOKEN,
)

# Tokenizer repository; keep it in sync with MODEL (same alternatives as above).
TOKENIZER = 'google/gemma-7b'
TOKENIZER_CONFIGS = dict(token=TOKEN)

### Global

In [9]:
# Prompt fed to every model variant below.
input_string = "The quick brown fox jumps over the lazy dog."

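## Tests

Each test runs the Hugging Face baseline and the wrapper in separate passes and pickles their outputs to `.tmp_*.pkl` files. The *Comparison* cells reload both pickles and require bit-exact agreement (`torch.equal`).

Execution counts restart at 10 in every subsection. This indicates each pass was run in a fresh kernel after re-running the setup cells above. Only load pickle files you produced yourself.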

### Transformer

#### Baseline

In [10]:
# Baseline: the bare Hugging Face model (loaded through `AutoModel`) and its tokenizer.
model = AutoModel.from_pretrained(pretrained_model_name_or_path=MODEL, **MODEL_CONFIGS)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=TOKENIZER, **TOKENIZER_CONFIGS)

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  2.43it/s]


In [11]:
# Encode the prompt as PyTorch tensors and move them to DEVICE.
input_encodings = tokenizer(
    text=input_string, return_tensors='pt'
).to(DEVICE)

In [12]:
# Baseline forward pass, requesting the cache, every hidden state and the attention weights.
output = model(
    **input_encodings,
    use_cache=True,
    output_hidden_states=True,
    output_attentions=True,
    return_dict=True,
)

In [13]:
# Persist the baseline outputs for the Comparison cells.
with open('.tmp_output_transformer.pkl', mode='wb') as f:
    pickle.dump(output, f)

#### Wrapper

In [10]:
# Wrapper under test: the same checkpoint and tokenizer loaded through TransformerWrapper.
model = TransformerWrapper.from_pretrained(
    MODEL,
    model_kwargs=MODEL_CONFIGS,
    tokenizer_name_or_path=TOKENIZER,
    tokenizer_kwargs=TOKENIZER_CONFIGS,
)

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  2.45it/s]


In [11]:
# The wrapper carries its own tokenizer; encode the same prompt with it.
input_encodings = model.tokenizer(
    input_string, return_tensors='pt'
).to(DEVICE)

In [12]:
# Wrapper forward pass: the baseline's flags plus the wrapper-specific switches that also
# return every layer's self-attention and feed-forward outputs.
output_wrapper = model(
    **input_encodings,
    use_cache=True,
    output_hidden_states=True,
    output_attentions=True,
    return_dict=True,
    return_attention_output=True,  # per-layer self-attention output
    return_feed_forward_output=True,  # per-layer feed-forward output
)

In [13]:
# Persist the wrapper outputs for the Comparison cells.
with open('.tmp_output_wrapper_transformer.pkl', mode='wb') as f:
    pickle.dump(output_wrapper, f)

#### Comparison

In [14]:
# Reload the wrapper outputs written by the Wrapper subsection (trusted, locally produced file).
with open('.tmp_output_wrapper_transformer.pkl', mode='rb') as f:
    output_wrapper = pickle.load(f)

In [15]:
# Reload the baseline outputs written by the Baseline subsection (trusted, locally produced file).
with open('.tmp_output_transformer.pkl', mode='rb') as f:
    output = pickle.load(f)

In [16]:
# The final hidden state must be bit-identical between the wrapper and the baseline.
assert torch.equal(
    output_wrapper['output_hidden_state'], output.last_hidden_state
), '`last_hidden_state` not matching.'

In [17]:
# Per-layer hidden states, compared bit-exactly.
# - Index 0 is the embedding output.
# - At index len(model.layers) the baseline tensor is compared with `model.norm` of the wrapper
#   tensor, i.e. the wrapper is expected to store the last hidden state before the final norm.
# The explicit length check is needed because `zip` silently drops extra entries, so a
# missing layer would otherwise go unnoticed.
assert len(output.hidden_states) == len(output_wrapper['hidden_states']), \
    'Number of `hidden_state` tensors not matching.'
for i, (output_hidden_state, output_wrapper_hidden_state) in enumerate(zip(
        output.hidden_states, output_wrapper['hidden_states']
)):
    if i == 0:
        assert torch.equal(
            output_hidden_state, output_wrapper_hidden_state
        ), 'Initial embedding tensors not matching.'
    elif i == len(model.layers):
        assert torch.equal(
            output_hidden_state, model.norm(output_wrapper_hidden_state)
        ), f'`hidden_state` tensors at layer {i} not matching.'
    else:
        assert torch.equal(
            output_hidden_state, output_wrapper_hidden_state
        ), f'`hidden_state` tensors at layer {i} not matching.'

In [18]:
# Residual decomposition: layer i's hidden state is expected to equal the previous hidden state
# plus that layer's self-attention output plus its feed-forward output. As in the plain
# hidden-state check, the last layer is compared after `model.norm`.
# Every per-layer sequence must cover all layers; otherwise `zip` would silently skip some.
assert all(
    count == len(model.layers) for count in (
        len(output.hidden_states) - 1,
        len(output_wrapper['hidden_states']) - 1,
        len(output_wrapper['attention_outputs']),
        len(output_wrapper['feed_forward_outputs']),
    )
), 'Number of per-layer tensors not matching the number of layers.'
for i, (
        output_hidden_state, prev_output_wrapper_hidden_state, attn_output_wrapper, ffnn_output_wrapper
) in enumerate(zip(
        output.hidden_states[1:],
        output_wrapper['hidden_states'][:-1],
        output_wrapper['attention_outputs'],
        output_wrapper['feed_forward_outputs']
), start=1):
    output_wrapper_hidden_state = prev_output_wrapper_hidden_state + attn_output_wrapper + ffnn_output_wrapper
    if i == len(model.layers):
        assert torch.equal(
            output_hidden_state, model.norm(output_wrapper_hidden_state)
        ), f'Composed `hidden_state` tensors at layer {i} not matching.'
    else:
        assert torch.equal(
            output_hidden_state, output_wrapper_hidden_state
        ), f'Composed `hidden_state` tensors at layer {i} not matching.'

In [19]:
# Key/value cache: both caches are expected to yield one (keys, values) pair per layer.
# The length check prevents `zip` from silently ignoring layers missing on one side.
assert len(output.past_key_values) == len(output_wrapper['cache']), \
    'Number of cached layers not matching.'
for i, (output_past_key_values, output_wrapper_past_key_values) in enumerate(zip(
        output.past_key_values, output_wrapper['cache'],
), start=1):
    output_past_keys, output_past_values = output_past_key_values
    output_wrapper_past_keys, output_wrapper_past_values = output_wrapper_past_key_values
    assert torch.equal(
        output_past_keys, output_wrapper_past_keys
    ), f'`key` tensors at layer {i} not matching.'
    assert torch.equal(
        output_past_values, output_wrapper_past_values
    ), f'`value` tensors at layer {i} not matching.'

In [20]:
# Attention weights, one tensor per layer (layers numbered from 1 in the messages).
# The length check prevents `zip` from silently ignoring layers missing on one side.
assert len(output.attentions) == len(output_wrapper['attention_weights']), \
    'Number of `attentions` tensors not matching.'
for i, (output_attentions, output_wrapper_attentions) in enumerate(zip(
        output.attentions, output_wrapper['attention_weights']
), start=1):
    assert torch.equal(
        output_attentions, output_wrapper_attentions
    ), f'`attentions` tensors at layer {i} not matching.'

### Language Model

#### Baseline

In [10]:
# Baseline: the Hugging Face causal language model (`AutoModelForCausalLM`) and its tokenizer.
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=MODEL, **MODEL_CONFIGS)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=TOKENIZER, **TOKENIZER_CONFIGS)

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  2.39it/s]


In [11]:
# Encode the prompt as PyTorch tensors and move them to DEVICE.
input_encodings = tokenizer(
    text=input_string, return_tensors='pt'
).to(DEVICE)

In [12]:
# Greedy decoding (`do_sample=False`); `max_length=16` caps the total length, prompt included.
output = model.generate(
    inputs=input_encodings['input_ids'],
    max_length=16,
    do_sample=False,
)

In [13]:
# Persist the baseline generation for the Comparison cells.
with open('.tmp_output_lm.pkl', mode='wb') as f:
    pickle.dump(output, f)

#### Wrapper

In [10]:
# Wrapper under test: the same checkpoint and tokenizer loaded through CausalLMWrapper.
model = CausalLMWrapper.from_pretrained(
    MODEL,
    model_kwargs=MODEL_CONFIGS,
    tokenizer_name_or_path=TOKENIZER,
    tokenizer_kwargs=TOKENIZER_CONFIGS,
)

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  2.40it/s]


In [11]:
# The wrapper carries its own tokenizer; encode the same prompt with it.
input_encodings = model.tokenizer(
    input_string, return_tensors='pt'
).to(DEVICE)

In [12]:
# Greedy generation through the wrapper, with the same settings as the baseline.
# NOTE(review): in the recorded run this call raised
# `AttributeError: 'NoneType' object has no attribute 'shape'`, so the cells below did not
# receive a fresh `output_wrapper` — TODO investigate before trusting the comparison.
output_wrapper = model.generate(
    input_encodings.input_ids,
    do_sample=False,
    max_length=16,
)

AttributeError: 'NoneType' object has no attribute 'shape'

In [13]:
# Persist the wrapper generation for the Comparison cells.
with open('.tmp_output_wrapper_lm.pkl', mode='wb') as f:
    pickle.dump(output_wrapper, f)

#### Comparison

In [14]:
# Reload the wrapper generation (trusted, locally produced file).
with open('.tmp_output_wrapper_lm.pkl', mode='rb') as f:
    output_wrapper = pickle.load(f)

In [15]:
# Reload the baseline generation (trusted, locally produced file).
with open('.tmp_output_lm.pkl', mode='rb') as f:
    output = pickle.load(f)

In [16]:
# The wrapper's generation result is expected to expose the token ids under 'input_ids';
# they must match the baseline's generated sequence exactly.
assert torch.equal(
    output_wrapper['input_ids'], output
), 'Generated sequences not matching.'