# block-internals

> An experiment to examine the internals of a self-attention block


In [None]:
#| default_exp experiments.block_internals

In [None]:
# | hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| hide
from fastcore.test import *

In [None]:
#| export
import argparse
from dataclasses import dataclass
from pathlib import Path
import tempfile
from typing import Iterator, Tuple

In [None]:
#| export
import click
import torch    
from tqdm.auto import tqdm

In [None]:
#| export
from transformer_experiments.common.databatcher import DataBatcher
from transformer_experiments.dataset_split import split_text_dataset
from transformer_experiments.datasets.tinyshakespeare import (
    TinyShakespeareDataSet,
)
from transformer_experiments.models.transformer import (
    n_embed,
    n_layer,
    TransformerLanguageModel
)
from transformer_experiments.models.transformer_helpers import (
    EncodingHelpers,
    TransformerAccessors
)
from transformer_experiments.trained_models.tinyshakespeare_transformer import (
    create_model_and_tokenizer, 
)


In [None]:
#| export
@dataclass
class BlockInternalsResult:
    """Intermediate values a block produced for a single token position.

    Instances are yielded by `BlockInternalsExperiment.raw_results()`, one per
    (sample, position) pair found in the saved results.
    """

    # The sample's tokens from its start through this position, stringified
    substring: str
    # Slice at this position of the saved input to the block's `sa.proj`
    heads_output: torch.Tensor
    # Slice at this position of the saved output of the block's `sa.proj`
    proj_output: torch.Tensor
    # Slice at this position of the saved output of the block's `ffwd`
    ffwd_output: torch.Tensor

In [None]:
# | export
class BlockInternalsExperiment:
    """An experiment to run a bunch of inputs through a block and save the
    intermediate values produced for each token.

    `run()` writes four files per batch into `results_folder`: the input
    tokens, the input to the block's `sa.proj` (the heads' output), the output
    of `sa.proj`, and the output of `ffwd`. `load()` reads them back batch by
    batch and `raw_results()` flattens them into one result per token.
    """

    def __init__(
        self,
        eh: EncodingHelpers,
        accessors: TransformerAccessors,
        block_idx: int,
        results_folder: Path,
    ):
        """Create an experiment for the block at index `block_idx`.

        `results_folder` is where `run()` writes its files and where `load()`
        looks for them. `block_idx` must satisfy `0 <= block_idx < n_layer`.
        """
        assert (
            0 <= block_idx < n_layer
        ), f'block_idx must be in [0, {n_layer}), got {block_idx}'

        self.eh = eh
        self.accessors = accessors
        self.block_idx = block_idx
        self.results_folder = results_folder

    def _filename_stem(self):
        """Common prefix of every file this experiment reads or writes."""
        return f'block{self.block_idx}_internals'

    # The batch index in each filename is zero-padded to at least 3 digits;
    # indices of 1000 and above simply use more digits.

    def _input_filename(self, batch: int):
        return f'{self._filename_stem()}_input_{batch:03d}.pt'

    def _head_output_filename(self, batch: int):
        return f'{self._filename_stem()}_head_output_{batch:03d}.pt'

    def _proj_output_filename(self, batch: int):
        return f'{self._filename_stem()}_proj_output_{batch:03d}.pt'

    def _ffwd_output_filename(self, batch: int):
        return f'{self._filename_stem()}_ffwd_output_{batch:03d}.pt'

    @staticmethod
    def _batch_idx_of(path: Path) -> int:
        """Parse the batch index from the trailing `_<digits>` of a results
        file name. Raises `ValueError` if the suffix is not an integer."""
        return int(path.stem.rsplit('_', 1)[-1])

    def _sorted_result_files(self, kind: str) -> list:
        """Return the saved files of the given kind (e.g. 'input',
        'head_output') ordered by numeric batch index."""
        matches = self.results_folder.glob(f'{self._filename_stem()}_{kind}_*.pt')
        # Sort numerically: a plain lexicographic sort would put batch 1000
        # ahead of batch 999 once the index outgrows its 3-digit padding.
        return sorted(matches, key=self._batch_idx_of)

    def run(self, data_batcher: DataBatcher):
        """Run every batch from `data_batcher` through the model up to and
        including the block of interest, saving the inputs and the block's
        intermediate values to `results_folder`."""
        # The outputs are only saved, never backpropagated through, so there
        # is no need for autograd to track these forward passes.
        with torch.no_grad():
            # tqdm wraps the batcher itself rather than the enumerate object
            # so that it can show a total if the batcher reports a length.
            for batch_idx, batch in enumerate(tqdm(data_batcher)):
                x = self.accessors.embed_tokens(batch)

                # Run the encoded batch through the blocks up to the one
                # we're interested in
                for i in range(self.block_idx):
                    x = self.accessors.m.blocks[i](x)

                # Copy the block we're interested in
                block, io_accessor = self.accessors.copy_block_from_model(
                    block_idx=self.block_idx
                )
                _ = block(x)  # Run the block

                # Grab the outputs of interest
                heads_output = io_accessor.input('sa.proj')
                proj_output = io_accessor.output('sa.proj')
                ffwd_output = io_accessor.output('ffwd')

                torch.save(
                    batch.clone(),
                    self.results_folder / self._input_filename(batch_idx),
                )
                torch.save(
                    heads_output,
                    self.results_folder / self._head_output_filename(batch_idx),
                )
                torch.save(
                    proj_output,
                    self.results_folder / self._proj_output_filename(batch_idx),
                )
                torch.save(
                    ffwd_output,
                    self.results_folder / self._ffwd_output_filename(batch_idx),
                )

    def load(
        self,
    ) -> Iterator[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Yield `(inputs, heads_output, proj_output, ffwd_output)` for each
        saved batch, in increasing batch-index order.

        Raises `AssertionError` if the four kinds of file do not line up
        one-to-one by batch index.
        """
        input_files = self._sorted_result_files('input')
        head_output_files = self._sorted_result_files('head_output')
        proj_output_files = self._sorted_result_files('proj_output')
        ffwd_output_files = self._sorted_result_files('ffwd_output')

        assert (
            len(input_files)
            == len(head_output_files)
            == len(proj_output_files)
            == len(ffwd_output_files)
        ), 'expected the same number of input, head, proj and ffwd files'

        for batch_files in zip(
            input_files, head_output_files, proj_output_files, ffwd_output_files
        ):
            # Equal counts alone don't guarantee the files pair up; make sure
            # all four belong to the same batch before yielding them together.
            batch_indices = {self._batch_idx_of(f) for f in batch_files}
            assert (
                len(batch_indices) == 1
            ), f'mismatched batch indices in results files: {sorted(batch_indices)}'

            input_file, head_output_file, proj_output_file, ffwd_output_file = (
                batch_files
            )
            yield (
                torch.load(input_file),
                torch.load(head_output_file),
                torch.load(proj_output_file),
                torch.load(ffwd_output_file),
            )

    def raw_results(self) -> Iterator[BlockInternalsResult]:
        """Yield one `BlockInternalsResult` per token position of every saved
        sample, pairing the substring that ends at that position with the
        block's intermediate values there."""
        for inputs, head_output, proj_output, ffwd_output in self.load():
            n_samples, s_len = inputs.shape
            for i in range(n_samples):
                for j in range(s_len):
                    # The substring covers tokens 0..j of sample i (inclusive)
                    substring = self.eh.stringify_tokens(inputs[i][: j + 1])
                    yield BlockInternalsResult(
                        substring,
                        head_output[i][j],
                        proj_output[i][j],
                        ffwd_output[i][j],
                    )

In [None]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"device is {device}")

device is cpu


In [None]:
# Load the dataset and the trained model + tokenizer from the local artifacts
ts = TinyShakespeareDataSet(cache_file='../artifacts/input.txt')
m, tokenizer = create_model_and_tokenizer(
    saved_model_filename='../artifacts/shakespeare.pt',
    dataset=ts,
    device=device,
)
# Only the second element of the split is used below
# (presumably the validation portion, given train_pct=0.9 -- confirm).
_, val_data = split_text_dataset(ts.text, tokenizer, train_pct=0.9)
# Helpers shared by the test cell that follows
encoding_helpers = EncodingHelpers(m, tokenizer, device)
accessors = TransformerAccessors(m, device)

In [None]:
# Test for BlockInternalsExperiment
max_batch_size = 64
s_len = 3
n_full_batches = 2
last_batch_size = 12
# Slice just enough tokens for `n_full_batches` full batches plus one partial
# batch of `last_batch_size` samples (assumes non-overlapping samples, which
# is what stride == sample_len is expected to give -- confirm with DataBatcher).
data_batcher = DataBatcher(
    data=val_data[
        0 : n_full_batches * max_batch_size * s_len + last_batch_size * s_len
    ],
    sample_len=s_len,
    max_batch_size=max_batch_size,
    stride=s_len,
)

# Create a temp directory to store the results that will get
# cleaned up at the end.
with tempfile.TemporaryDirectory() as tmpdirname:
    test_exp = BlockInternalsExperiment(
        eh=encoding_helpers,
        accessors=accessors,
        block_idx=0,
        results_folder=Path(tmpdirname),
    )
    test_exp.run(data_batcher=data_batcher)

    # Derive the expected shapes from the batch sizes so they stay in sync
    # with `n_full_batches`.
    expected_batch_sizes = [max_batch_size] * n_full_batches + [last_batch_size]
    expected_batch_shapes = [(size, s_len) for size in expected_batch_sizes]
    expected_output_shapes = [
        (size, s_len, n_embed) for size in expected_batch_sizes
    ]

    loaded_batches = list(test_exp.load())

    # Check the batch count explicitly: without this, the shape checks below
    # would pass vacuously if too few (or no) batches were saved or loaded.
    test_eq(len(loaded_batches), len(expected_batch_shapes))

    for i, (inputs, heads_outputs, proj_outputs, ffwd_outputs) in enumerate(
        loaded_batches
    ):
        test_eq(inputs.shape, expected_batch_shapes[i])
        test_eq(heads_outputs.shape, expected_output_shapes[i])
        test_eq(proj_outputs.shape, expected_output_shapes[i])
        test_eq(ffwd_outputs.shape, expected_output_shapes[i])

0it [00:00, ?it/s]

In [None]:
# | export
@click.command()
@click.argument(
    'model_weights_filename', type=click.Path(exists=True, dir_okay=False)
)
@click.argument(
    'dataset_cache_filename', type=click.Path(exists=True, dir_okay=False)
)
@click.argument('output_folder', type=click.Path(exists=True, file_okay=False))
@click.option('-b', '--block_idx', required=True, type=click.IntRange(min=0, max=n_layer, max_open=True))
def run(
    model_weights_filename: str,
    dataset_cache_filename: str,
    output_folder: str,
    block_idx: int,
):
    """Run the block internals experiment for one block.

    Loads the model weights from MODEL_WEIGHTS_FILENAME and the dataset from
    DATASET_CACHE_FILENAME, then writes the experiment's result files into the
    existing directory OUTPUT_FOLDER.
    """
    click.echo(f"Running block internals experiment for block {block_idx}")

    # Instantiate the model, tokenizer, and dataset
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    click.echo(f"device is {device}")

    ts = TinyShakespeareDataSet(cache_file=dataset_cache_filename)
    m, tokenizer = create_model_and_tokenizer(
        saved_model_filename=model_weights_filename,
        dataset=ts,
        device=device,
    )
    # Only the second element of the split is fed to the experiment
    _, val_data = split_text_dataset(ts.text, tokenizer, train_pct=0.9)

    encoding_helpers = EncodingHelpers(m, tokenizer, device)
    accessors = TransformerAccessors(m, device)

    # Create the experiment
    exp = BlockInternalsExperiment(
        encoding_helpers, accessors, block_idx, Path(output_folder)
    )

    # Run the experiment
    data_batcher = DataBatcher(
        data=val_data,
        sample_len=3,
        max_batch_size=64,
        stride=96,
    )
    exp.run(data_batcher=data_batcher)

In [None]:
#| hide
import nbdev

# Export the notebook's `#| export` cells to the library module
nbdev.nbdev_export()