In [1]:
import torch

from torch.utils.data import DataLoader

from transformers import AutoTokenizer
from datasets import load_dataset
from transformers import AutoModelForCausalLM

from accelerate import Accelerator

from matplotlib import pyplot as plt

In [2]:
# Notebook-wide settings: the checkpoint and dataset to load, and the maximum
# number of tokens kept per review when the text is tokenized.
model_name = "tiiuae/falcon-rw-1b"
dataset_name = "yelp_review_full"
max_input_seq_length = 17

# Accelerator instance; used later to prepare the model and to look up the device.
accelerator = Accelerator()

In [3]:
# Load only the "train" split of the dataset named by `dataset_name`.
# Different datasets have different splits, but having a train split is common.
dataset = load_dataset(dataset_name, split="train")

In [4]:
# Keep only a small subset for this example: shuffle with a fixed seed, then
# take the first 1000 rows.
# Slicing (dataset[:1000]) would seem to work, but it returns a dictionary
# rather than a dataset object, so `select` is used to stay with a dataset.
dataset = (
    dataset
    .shuffle(seed=42)
    .select(range(1000))
)

In [5]:
# The tokenizer converts the dataset to a format that the model can understand.
# In particular, it takes the words and converts them to numbers/tokens.
# Note, the padding side is left since that is what the CausalLM model expects.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
# NOTE: tokenizer.pad_token is the special token used to pad sequences to the same length.
# Here it is set to the tokenizer's end-of-sequence token.
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    """Tokenize the "text" field of one batch of examples.

    When used with ``dataset.map(..., batched=True)`` this function is called
    once per batch of rows, not once for the whole dataset.  Padding to the
    longest sequence *of that batch* could therefore give rows from different
    batches different lengths, and such rows cannot be stacked into a single
    tensor by the DataLoader.  Padding to ``max_length`` gives every row
    exactly ``max_input_seq_length`` tokens regardless of how the rows were
    batched.

    Args:
        examples: a batch of rows; only ``examples["text"]`` is read.

    Returns:
        The tokenizer output for the batch, truncated and padded to
        ``max_input_seq_length`` tokens per row.
    """
    return tokenizer(
        examples["text"],
        padding='max_length',
        truncation=True,
        max_length=max_input_seq_length,
    )

# Tokenize the dataset in batches, then drop the raw "label" and "text" columns,
# which are no longer needed once the token ids exist.
# NOTE: map does some fancy caching, so the first run takes a while and later
# runs are much faster.
tokenized_datasets = (
    dataset
    .map(tokenize_function, batched=True)
    .remove_columns(["label", "text"])
)

# Make __getitem__ return torch tensors instead of python objects (the default).
# The conversion is applied on-the-fly when items are fetched.  See
# https://huggingface.co/docs/datasets/v2.15.0/en/package_reference/main_classes#datasets.Dataset.set_format
tokenized_datasets.set_format("torch")

train_dataloader = DataLoader(
    tokenized_datasets,
    batch_size=8,
    shuffle=True,
)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [6]:
# A little sanity check: show the first example as raw text, as token ids, and
# decoded back from the token ids.
first_example = dataset[0]
first_tokenized = tokenized_datasets[0]

print('example text')
print(first_example)
print('example tokenized text')
print(first_tokenized)
print('example decoded tokenized text')
print(tokenizer.decode(first_tokenized['input_ids']))

example text
{'label': 4, 'text': "I stalk this truck.  I've been to industrial parks where I pretend to be a tech worker standing in line, strip mall parking lots, and of course the farmer's market.  The bowls are so so absolutely divine.  The owner is super friendly and he makes each bowl by hand with an incredible amount of pride.  You gotta eat here guys!!!"}
example tokenized text
{'input_ids': tensor([   40, 31297,   428,  7779,    13,   220,   314,  1053,   587,   284,
         7593, 14860,   810,   314, 16614,   284,   307]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}
example decoded tokenized text
I stalk this truck.  I've been to industrial parks where I pretend to be


In [7]:
class Spy(torch.nn.Module):
    """Pass-through wrapper that records every call made to a wrapped module.

    Install a Spy in place of a sub-module: each call is forwarded unchanged
    and its arguments and result are kept for later inspection.

    Attributes:
        model: the wrapped module.
        debug: when True, every call prints its args, kwargs and output.
        inputs: one tuple of positional arguments per call.
        input_kwargs: one dict of keyword arguments per call, in the same
            order as ``inputs``.
        outputs: one return value per call that completed.
        last_size: number of calls already reported by ``print_last_input``.

    Note: recorded objects are stored by reference and never cleared, so they
    stay alive for as long as the spy does.
    """

    def __init__(self, model, debug=False):
        """Wrap ``model``; set ``debug=True`` to print every call."""
        super().__init__()
        self.model = model
        self.debug = debug
        self.inputs = []
        self.input_kwargs = []
        self.outputs = []
        self.last_size = 0

    def forward(self, *args, **kwargs):
        """Record the call, delegate to the wrapped module, return its result."""
        self.inputs.append(args)
        # Keyword arguments are recorded too, so a call made mostly (or only)
        # through keywords is not lost.
        self.input_kwargs.append(kwargs)
        output = self.model(*args, **kwargs)
        self.outputs.append(output)
        if self.debug:
            print(f'args {args}')
            print(f'kwargs {kwargs}')
            print(f'output {output}')
        return output

    def print_last_input(self):
        """prints the shapes of all the inputs that have not been printed yet

        For each unreported call the shape of the first positional argument is
        printed.  A call without positional arguments, or whose first argument
        has no ``shape``, is reported with a placeholder instead of raising.
        """
        print(f'{self.last_size} {len(self.inputs)}')
        for i in range(self.last_size, len(self.inputs)):
            call_args = self.inputs[i]
            if not call_args:
                # The call was made with keyword arguments only.
                print(f'{i} <no positional arguments>')
                continue
            first = call_args[0]
            if hasattr(first, 'shape'):
                print(f'{i} {first.shape}')
            else:
                print(f'{i} <{type(first).__name__}>')
        self.last_size = len(self.inputs)


In [8]:
# Load the pretrained causal language model named by `model_name` and print its
# module tree; the attribute names it shows (transformer.word_embeddings,
# transformer.h, transformer.ln_f, lm_head) are what the spies are attached to.
model = AutoModelForCausalLM.from_pretrained(model_name)
print(model)

FalconForCausalLM(
  (transformer): FalconModel(
    (word_embeddings): Embedding(50304, 2048)
    (h): ModuleList(
      (0-23): 24 x FalconDecoderLayer(
        (self_attention): FalconAttention(
          (query_key_value): FalconLinear(in_features=2048, out_features=6144, bias=True)
          (dense): FalconLinear(in_features=2048, out_features=2048, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp): FalconMLP(
          (dense_h_to_4h): FalconLinear(in_features=2048, out_features=8192, bias=True)
          (act): GELU(approximate='none')
          (dense_4h_to_h): FalconLinear(in_features=8192, out_features=2048, bias=True)
        )
        (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln_f): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=2048, out_feature

In [9]:
# Add a spy to the embedding layer.
embedding_spy = Spy(model.transformer.word_embeddings)
model.transformer.word_embeddings = embedding_spy

# Add a spy to each of the transformer layers.
transformer_layer_spies = []
for i, layer in enumerate(model.transformer.h):
    transformer_layer_spies.append(Spy(layer))
    model.transformer.h[i] = transformer_layer_spies[i]

# Add a spy to the final layer norm.
layer_norm_spy = Spy(model.transformer.ln_f)
model.transformer.ln_f = layer_norm_spy

# Add a spy to the output layer.
output_spy = Spy(model.lm_head)
model.lm_head = output_spy

In [10]:
# Tokenize a single prompt into a tensor of token ids with one row.
# NOTE: `input` shadows the Python builtin of the same name; the name is kept
# because later cells refer to it.
prompt = "This is a review of a restaurant.  The food was"
encoded_prompt = tokenizer(prompt, return_tensors="pt")
input = encoded_prompt["input_ids"]
print(input)

tensor([[1212,  318,  257, 2423,  286,  257, 7072,   13,  220,  383, 2057,  373]])


In [11]:
# Hand the model (with the spies already installed) to the accelerator.
model = accelerator.prepare(model)
# Note, the accelerator is cool, but only handles dataloaders.  So, for this example, we need to
# move the input tensor to the accelerator's device ourselves.
input = input.to(accelerator.device)

In [12]:
# We are now ready for the simplest example of running the model: a single
# forward pass on the prompt.  Each spy records what its module received and returned.
# NOTE(review): autograd is enabled here (the recorded tensors shown below carry
# a grad_fn); consider torch.no_grad() if gradients are not needed.
output = model(input)

In [13]:
# Positional arguments the embedding layer was called with: one call, holding
# the prompt's token ids.
embedding_spy.inputs

[(tensor([[1212,  318,  257, 2423,  286,  257, 7072,   13,  220,  383, 2057,  373]],
         device='cuda:0'),)]

In [14]:
# The embedding layer returned a tensor of shape (1, 12, 2048): one 2048-wide
# vector for each of the 12 prompt tokens.
print(embedding_spy.outputs[0].shape)
embedding_spy.outputs

torch.Size([1, 12, 2048])


[tensor([[[ 0.0102,  0.0081, -0.0007,  ..., -0.0074,  0.0126, -0.0124],
          [ 0.0084,  0.0072,  0.0198,  ..., -0.0066, -0.0183,  0.0156],
          [-0.0005, -0.0045,  0.0157,  ...,  0.0152, -0.0175, -0.0374],
          ...,
          [ 0.0065,  0.0034,  0.0212,  ...,  0.0381,  0.0072, -0.0047],
          [-0.0339, -0.0381,  0.0215,  ...,  0.0017, -0.0064,  0.0088],
          [ 0.0311,  0.0049,  0.0199,  ..., -0.0162,  0.0062,  0.0064]]],
        device='cuda:0', grad_fn=<EmbeddingBackward0>)]

In [15]:
# First positional argument of the first recorded call to transformer layer 0.
# The values match the embedding output shown above.
transformer_layer_spies[0].inputs[0][0]

tensor([[[ 0.0102,  0.0081, -0.0007,  ..., -0.0074,  0.0126, -0.0124],
         [ 0.0084,  0.0072,  0.0198,  ..., -0.0066, -0.0183,  0.0156],
         [-0.0005, -0.0045,  0.0157,  ...,  0.0152, -0.0175, -0.0374],
         ...,
         [ 0.0065,  0.0034,  0.0212,  ...,  0.0381,  0.0072, -0.0047],
         [-0.0339, -0.0381,  0.0215,  ...,  0.0017, -0.0064,  0.0088],
         [ 0.0311,  0.0049,  0.0199,  ..., -0.0162,  0.0062,  0.0064]]],
       device='cuda:0', grad_fn=<EmbeddingBackward0>)

In [16]:
# Element 0 of what transformer layer 0 returned on its first call.
# The result is still 3-D, so the layer presumably returns a tuple-like object
# whose first element is the hidden-state tensor.
transformer_layer_spies[0].outputs[0][0]

tensor([[[-0.5418, -0.0461,  0.2198,  ...,  0.0378,  0.3550, -0.3398],
         [-0.3592,  0.0346,  0.0228,  ..., -0.0202, -0.0376, -0.1926],
         [-0.0531,  0.0472, -0.0923,  ...,  0.0454,  0.2463, -0.0051],
         ...,
         [ 0.0854,  0.0670,  0.0658,  ...,  0.1739, -0.0161, -0.1793],
         [-0.1771, -0.1400,  0.3323,  ...,  0.4190, -0.2321,  0.0036],
         [-0.1761,  0.0589, -0.1985,  ...,  0.0311,  0.0280,  0.0314]]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [17]:
# First positional argument passed to transformer layer 1; the values match the
# output of layer 0 shown above.
transformer_layer_spies[1].inputs[0][0]

tensor([[[-0.5418, -0.0461,  0.2198,  ...,  0.0378,  0.3550, -0.3398],
         [-0.3592,  0.0346,  0.0228,  ..., -0.0202, -0.0376, -0.1926],
         [-0.0531,  0.0472, -0.0923,  ...,  0.0454,  0.2463, -0.0051],
         ...,
         [ 0.0854,  0.0670,  0.0658,  ...,  0.1739, -0.0161, -0.1793],
         [-0.1771, -0.1400,  0.3323,  ...,  0.4190, -0.2321,  0.0036],
         [-0.1761,  0.0589, -0.1985,  ...,  0.0311,  0.0280,  0.0314]]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [18]:
# Element 0 of what transformer layer 22 (the second-to-last of the 24 layers) returned.
transformer_layer_spies[22].outputs[0][0]

tensor([[[-0.5607,  3.2704,  1.5336,  ...,  0.1166,  0.9541, -3.9348],
         [ 1.3911,  0.5455,  3.2251,  ...,  1.8272,  1.5874, -2.5345],
         [ 0.3320, -0.5699,  2.4946,  ...,  1.6166,  1.9061,  0.1447],
         ...,
         [-0.3299, -0.7877,  2.4874,  ...,  1.0747, -0.7235,  1.0040],
         [ 1.5230,  0.3460,  4.8440,  ...,  1.5909,  0.9555,  0.8341],
         [-1.7289, -1.2445,  2.1949,  ...,  1.5398,  1.5457,  0.0147]]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [19]:
# First positional argument passed to transformer layer 23 (the last layer); the
# values match the output of layer 22 shown above.
transformer_layer_spies[23].inputs[0][0]

tensor([[[-0.5607,  3.2704,  1.5336,  ...,  0.1166,  0.9541, -3.9348],
         [ 1.3911,  0.5455,  3.2251,  ...,  1.8272,  1.5874, -2.5345],
         [ 0.3320, -0.5699,  2.4946,  ...,  1.6166,  1.9061,  0.1447],
         ...,
         [-0.3299, -0.7877,  2.4874,  ...,  1.0747, -0.7235,  1.0040],
         [ 1.5230,  0.3460,  4.8440,  ...,  1.5909,  0.9555,  0.8341],
         [-1.7289, -1.2445,  2.1949,  ...,  1.5398,  1.5457,  0.0147]]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [20]:
# Element 0 of what the last transformer layer returned.
transformer_layer_spies[23].outputs[0][0]

tensor([[[ 1.9322,  1.7793,  2.3793,  ...,  1.7575,  0.6972, -0.0619],
         [ 2.3636,  2.3833,  5.2714,  ...,  4.1948, -1.3329, -5.2314],
         [ 1.7173,  0.3608,  4.3707,  ...,  3.4157, -0.2490, -2.8648],
         ...,
         [ 2.2031,  1.2303,  4.7532,  ...,  3.7923, -3.5564, -1.9627],
         [ 2.9893,  2.0272,  6.1852,  ...,  3.1881, -1.5452, -1.3266],
         [-0.4662,  0.7712,  4.5110,  ...,  3.3452, -0.3409, -2.7155]]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [21]:
# First positional argument passed to the final layer norm; the values match the
# output of the last transformer layer shown above.
layer_norm_spy.inputs[0][0]

tensor([[[ 1.9322,  1.7793,  2.3793,  ...,  1.7575,  0.6972, -0.0619],
         [ 2.3636,  2.3833,  5.2714,  ...,  4.1948, -1.3329, -5.2314],
         [ 1.7173,  0.3608,  4.3707,  ...,  3.4157, -0.2490, -2.8648],
         ...,
         [ 2.2031,  1.2303,  4.7532,  ...,  3.7923, -3.5564, -1.9627],
         [ 2.9893,  2.0272,  6.1852,  ...,  3.1881, -1.5452, -1.3266],
         [-0.4662,  0.7712,  4.5110,  ...,  3.3452, -0.3409, -2.7155]]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [22]:
# Output of the final layer norm.
# NOTE: unlike the transformer layers, outputs[0] is already a tensor here, so
# the trailing [0] selects the first batch element rather than a tuple member.
# That is why the result below is 2-D and its grad_fn is SelectBackward0.
layer_norm_spy.outputs[0][0]

tensor([[ 1.6132,  1.4738,  3.1858,  ...,  1.3919, -1.6163, -3.5076],
        [ 1.2990,  1.5379,  4.2701,  ...,  3.0219, -2.1128, -5.5462],
        [ 0.9190, -0.1588,  3.9960,  ...,  2.7389, -1.2626, -3.8932],
        ...,
        [ 0.9481,  0.4120,  3.4575,  ...,  2.3792, -3.9171, -2.6639],
        [ 2.1224,  1.4004,  5.8361,  ...,  2.5096, -2.6661, -2.4685],
        [-1.0004,  0.2386,  3.6133,  ...,  2.3228, -1.1924, -3.2386]],
       device='cuda:0', grad_fn=<SelectBackward0>)

In [23]:
# First positional argument passed to the output layer (lm_head); the values
# match the layer norm output shown above, with the batch dimension kept.
output_spy.inputs[0][0]

tensor([[[ 1.6132,  1.4738,  3.1858,  ...,  1.3919, -1.6163, -3.5076],
         [ 1.2990,  1.5379,  4.2701,  ...,  3.0219, -2.1128, -5.5462],
         [ 0.9190, -0.1588,  3.9960,  ...,  2.7389, -1.2626, -3.8932],
         ...,
         [ 0.9481,  0.4120,  3.4575,  ...,  2.3792, -3.9171, -2.6639],
         [ 2.1224,  1.4004,  5.8361,  ...,  2.5096, -2.6661, -2.4685],
         [-1.0004,  0.2386,  3.6133,  ...,  2.3228, -1.1924, -3.2386]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>)

In [24]:
# Shape of the output layer's result: (1, 12, 50304), i.e. one score per
# vocabulary entry for each of the 12 prompt positions.
output_spy.outputs[0].shape

torch.Size([1, 12, 50304])

In [25]:
# For every position of the prompt, print the position, the id of the
# highest-scoring token at that position, and that token decoded to text.
print(prompt)
logits = output_spy.outputs[0]

# logits[0] is the single batch element; iterating it yields one row of
# vocabulary scores per prompt position.
for i, position_logits in enumerate(logits[0]):
    highest = torch.argmax(position_logits)
    print(f'{i} {highest} {tokenizer.decode([highest])}')

This is a review of a restaurant.  The food was
0 318  is
1 257  a
2 24659  continuation
3 286  of
4 262  the
5 1492  book
6 326  that
7 198 

8 198 

9 2057  food
10 373  was
11 922  good


In [26]:
# Decode the highest-scoring token at the last prompt position, i.e. the model's
# top choice for the word that follows the prompt.
# The id is recomputed from the recorded logits instead of relying on a loop
# variable left over from an earlier cell, so this cell does not depend on that
# loop having run last.
next_token_id = torch.argmax(output_spy.outputs[0][0, -1, :])
tokenizer.decode(next_token_id)

' good'