# Torch Model

In [None]:
# From transformers.models.bert.modeling_bert.BertIntermediate
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py or https://huggingface.co/transformers/v2.5.0/_modules/transformers/modeling_bert.html (look for BertIntermediate there; the version below is slightly modified, e.g. GELU is hard-coded as the activation)

import torch

class BertIntermediate(torch.nn.Module):
    """BERT feed-forward expansion: a Linear projection followed by GELU.

    Maps ``config.hidden_size`` input features to ``config.intermediate_size``
    output features. The submodule is named ``dense`` so that its parameters
    are ``dense.weight`` / ``dense.bias``, which is how the functional
    implementations read them (``parameters.dense``).
    """

    def __init__(self, config):
        super().__init__()
        in_features, out_features = config.hidden_size, config.intermediate_size
        self.dense = torch.nn.Linear(in_features, out_features)

    def forward(self, hidden_states):
        """Return ``gelu(dense(hidden_states))``."""
        return torch.nn.functional.gelu(self.dense(hidden_states))


# Finally, the model can be rewritten using functional torch APIs so that the test passes:

In [None]:
# torch_functional_bert.py

def bert_intermediate(hidden_states, *, parameters):
    """Apply the BERT intermediate layer (dense projection + GELU) functionally.

    Args:
        hidden_states: input torch tensor.
        parameters: object exposing ``dense.weight`` and ``dense.bias``.
            The weight is right-multiplied (``hidden_states @ weight``), so it
            must already be in (in_features, out_features) layout.

    Returns:
        ``gelu(hidden_states @ weight + bias)``.
    """
    dense = parameters.dense
    projected = hidden_states @ dense.weight + dense.bias
    return torch.nn.functional.gelu(projected)


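# Following TDD (test-driven development), the first step is to write a test for the model:
# TDD means writing a (initially failing) test first, then implementing the code until the test passes.
# TODO: add documentation for the acronyms used in this notebook.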

In [None]:
def bert_intermediate(hidden_states, *, parameters):
    """Functional BERT intermediate layer: ``gelu(hidden_states @ W + b)``.

    ``parameters.dense.weight`` is right-multiplied, so it must already be laid
    out as (in_features, out_features). ``parameters.dense.bias`` is added with
    torch broadcasting.
    """
    weight, bias = parameters.dense.weight, parameters.dense.bias
    return torch.nn.functional.gelu(hidden_states @ weight + bias)

In [2]:
import pytest
import torch
import transformers

import ttnn
import os
import sys

# Directory holding the reference `torch_functional_bert.py`. The default is the
# original author's checkout; set FUNCTIONAL_BERT_REFERENCE_DIR to point at
# models/experimental/functional_bert/reference inside your own tt-metal tree.
REFERENCE_DIR = os.environ.get(
    "FUNCTIONAL_BERT_REFERENCE_DIR",
    "/home/dvartanians/tt-metal-v0.41.0/tt-metal/models/experimental/functional_bert/reference",
)
# Avoid piling up duplicate sys.path entries when the notebook cell is re-run.
if REFERENCE_DIR not in sys.path:
    sys.path.append(REFERENCE_DIR)
import torch_functional_bert # implemented here: https://github.com/tenstorrent-metal/tt-metal/blob/main/models/experimental/functional_bert/reference/torch_functional_bert.py

from models.utility_functions import torch_random
from tests.ttnn.utils_for_testing import assert_with_pcc

@pytest.mark.parametrize("model_name", ["phiyodr/bert-large-finetuned-squad2"])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("sequence_size", [384])
def test_bert_intermediate(model_name, batch_size, sequence_size):
    """Check the functional torch ``bert_intermediate`` against the HF module.

    The ``transformers`` ``BertIntermediate`` module produces the golden output.
    The functional implementation from ``torch_functional_bert`` must match it
    with a Pearson correlation coefficient of at least 0.9999.
    """
    torch.manual_seed(0)

    config = transformers.BertConfig.from_pretrained(model_name)
    model = transformers.models.bert.modeling_bert.BertIntermediate(config).eval()

    torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32)
    torch_output = model(torch_hidden_states)  # Golden output

    # preprocess_model_parameters comes from ttnn's model_preprocessing module.
    # It builds the `parameters` object the functional code reads as
    # parameters.dense.weight / parameters.dense.bias.
    # NOTE(review): the functional code computes `hidden_states @ weight`, which
    # presumes the preprocessing stores the Linear weight transposed to
    # (in_features, out_features) — confirm against ttnn's default preprocessor.
    from ttnn.model_preprocessing import preprocess_model_parameters

    parameters = preprocess_model_parameters(
        initialize_model=lambda: model,  # Function to initialize the model
        convert_to_ttnn=lambda *_: False,  # Keep the weights as torch tensors
    )

    output = torch_functional_bert.bert_intermediate(
        torch_hidden_states,
        parameters=parameters,
    )

    assert_with_pcc(torch_output, output, 0.9999)

# Step 2 - Switching to ttnn ops

In [None]:
import pytest
import torch
import transformers

import ttnn
#import ttnn_functional_bert

from models.utility_functions import torch_random
from tests.ttnn.utils_for_testing import assert_with_pcc

@pytest.mark.parametrize("model_name", ["phiyodr/bert-large-finetuned-squad2"])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("sequence_size", [384])
def test_bert_intermediate(device, model_name, batch_size, sequence_size):
    """Check a ttnn implementation of ``bert_intermediate`` against the HF module.

    The golden output comes from the float32 ``transformers`` module. The ttnn
    path runs with bfloat16 weights and activations, so the PCC threshold is
    relaxed to 0.999.
    """
    # Imported here because no module-level import of this name exists in this
    # cell; without it the call below raises NameError.
    from ttnn.model_preprocessing import preprocess_model_parameters

    torch.manual_seed(0)

    config = transformers.BertConfig.from_pretrained(model_name)
    model = transformers.models.bert.modeling_bert.BertIntermediate(config).eval()

    # Explicit float32 input, matching the torch-only test.
    torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32)
    # Golden output is computed before preprocessing, because Module.to() in the
    # lambda below casts `model` to bfloat16 in place.
    torch_output = model(torch_hidden_states)

    parameters = preprocess_model_parameters(
        initialize_model=lambda: model.to(torch.bfloat16),
        device=device,  # Device to put the parameters on
    )

    # torch tensor -> bfloat16 ttnn tensor in tile layout, moved to the device.
    hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16)
    hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT)
    hidden_states = ttnn.to_device(hidden_states, device)

    # NOTE(review): `bert_intermediate` is looked up in the notebook's global
    # namespace at call time, so it is whichever definition ran most recently.
    # Make sure the ttnn implementation cell has been executed before this test.
    output = bert_intermediate(
        hidden_states,
        parameters=parameters,
    )
    # Bring the result back: device -> host, tile -> row-major, ttnn -> torch.
    output = ttnn.from_device(output)
    output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT)
    output = ttnn.to_torch(output)

    assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.999)

# Then, implement the function using ttnn operations:


In [None]:
# ttnn_functional_bert.py

import ttnn

def bert_intermediate(
    hidden_states,
    *,
    parameters,
):
    """ttnn version of the BERT intermediate layer: dense projection + GELU.

    Mirrors the functional torch implementation, but operates on ttnn tensors.
    The ``@`` and ``+`` operators are ttnn's overloads (presumably dispatching
    to ttnn matmul/add — confirm), and the activation is ``ttnn.gelu``.

    Args:
        hidden_states: ttnn input tensor.
        parameters: object exposing ``dense.weight`` and ``dense.bias`` as ttnn
            tensors; the weight is right-multiplied.

    Returns:
        ``ttnn.gelu(hidden_states @ weight + bias)``.
    """
    dense = parameters.dense
    projected = hidden_states @ dense.weight + dense.bias
    return ttnn.gelu(projected)


# Step 3 - Optimizing the model

In [None]:
# ttnn_optimized_functional_bert.py

import ttnn

def bert_intermediate(
    hidden_states,
    *,
    parameters,
    num_cores_x=12,
):
    """Optimized BERT intermediate layer: dense projection with fused GELU.

    A single ``ttnn.linear`` call performs the matmul, adds the bias and applies
    GELU, writing its output to L1 memory.

    Args:
        hidden_states: ttnn input tensor whose first dimension is used as the
            batch size for the core grid.
        parameters: object exposing ``dense.weight`` and ``dense.bias`` as ttnn
            tensors.
        num_cores_x: second dimension of the manual core grid passed to
            ``ttnn.linear`` (default 12).

    Returns:
        The ttnn output tensor of the fused linear + GELU.
    """
    batch_size, *_ = hidden_states.shape

    output = ttnn.linear(
        hidden_states,
        parameters.dense.weight,
        bias=parameters.dense.bias,
        memory_config=ttnn.L1_MEMORY_CONFIG,  # Put the output into local core memory
        core_grid=(batch_size, num_cores_x),  # Manual core grid for best performance
        activation="gelu",  # Fuse GELU into the linear op
    )
    return output

# More examples