# Bridging PyTorch with MAX🧑‍🚀 and Mojo🔥

Accompanying code for [workshops/pytorch-max-bridge](https://github.com/modular/workshops/tree/main/pytorch-max-bridge).

**We will learn how to gradually replace parts of a PyTorch model with MAX and Mojo.**

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# nn.Module.to() moves parameters in place, so there is no need to reassign `model`
model.to(device)
model.eval()
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=3072, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=1024)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=4096, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=4096)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=50257, bias=False)
)

In [3]:
import numpy as np
import torch

def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)

class DialogueBot:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.chat_history_ids = None

    def generate_response(self, user_input):
        with torch.no_grad():
            set_seed(42)
            new_user_input_ids = self.tokenizer.encode(
                user_input + self.tokenizer.eos_token,
                return_tensors='pt'
            ).to(device)

            if self.chat_history_ids is not None:
                bot_input_ids = torch.cat([self.chat_history_ids, new_user_input_ids], dim=-1).to(device)
            else:
                bot_input_ids = new_user_input_ids

            # Create attention mask (1 for real tokens, 0 for padding)
            # Since we're not using padding here, all tokens are real
            attention_mask = torch.ones_like(bot_input_ids).to(device)

            self.chat_history_ids = self.model.generate(
                bot_input_ids,
                attention_mask=attention_mask,
                max_length=1000,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.1,
                no_repeat_ngram_size=2,
            )

            response = self.tokenizer.decode(
                self.chat_history_ids[:, bot_input_ids.shape[-1]:][0],
                skip_special_tokens=True
            )
            return response

    def reset_conversation(self):
        self.chat_history_ids = None
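`generate_response` keeps the whole conversation as one growing sequence of token ids. Each turn appends the new user message (terminated by EOS) to the history, `model.generate` returns that prompt followed by the reply, and the reply is recovered by slicing off the prompt length. The list-based sketch below traces only that bookkeeping; `fake_generate`, `EOS = 0`, and the token ids are hypothetical stand-ins, not the real tokenizer or model.

```python
# Hypothetical stand-ins: token ids are plain ints, EOS is id 0.
EOS = 0

def fake_generate(prompt_ids, reply_ids):
    # Like model.generate: the output is the prompt followed by the continuation.
    return prompt_ids + reply_ids + [EOS]

def chat_turn(history, user_ids, reply_ids):
    new_input = user_ids + [EOS]            # user_input + eos_token
    prompt = (history or []) + new_input    # torch.cat([history, new_input], dim=-1)
    output = fake_generate(prompt, reply_ids)
    response = output[len(prompt):]         # chat_history_ids[:, prompt_len:]
    return output, response                 # the full output becomes the new history

history, resp1 = chat_turn(None, [11, 12], [21, 22])
history, resp2 = chat_turn(history, [13], [23])
print(resp1, resp2)
print(history)
```

In the notebook, `decode(..., skip_special_tokens=True)` then drops the trailing EOS from the reply text.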

In [4]:
test_inputs = [
    "Hello, how are you?",
    "What's your favorite programming language?",
    "Tell me about artificial intelligence",
    "What's the meaning of life?",
    "Tell me a joke",
]

print("=== STAGE 1: Original PyTorch Implementation ===")
bot_original = DialogueBot(model, tokenizer)

print("Original PyTorch Results:")
original_responses = []
for i, user_input in enumerate(test_inputs, 1):
    response = bot_original.generate_response(user_input)
    original_responses.append(response)
    print(f"Turn {i}")
    print(f"User: {user_input}")
    print(f"Bot: {response}")
    print("=" * 80)

=== STAGE 1: Original PyTorch Implementation ===
Original PyTorch Results:
Turn 1
User: Hello, how are you?
Bot: I'm good, thanks. How about you?
Turn 2
User: What's your favorite programming language?
Bot: Java, Python, C, and C.
Turn 3
User: Tell me about artificial intelligence
Bot: It's a thing. It's called AI. I don't know what it is, but it's something. And it can be used to make things. Like computers. Or robots. Whatever. You want it to be. :D
Turn 4
User: What's the meaning of life?
Bot: Life is a computer program that makes things
Turn 5
User: Tell me a joke
Bot: A joke!


## Custom Op with [max.torch.graph_op](https://docs.modular.com/max/api/python/torch/#max.torch.graph_op)

* This approach writes composable custom ops with the MAX Python Graph [max.graph.ops](https://docs.modular.com/max/api/python/graph/ops/) APIs.
* The MAX graph compiler applies optimizations such as kernel fusion.
* Wrapping the call in `torch.compile` helps with memory planning.

In [5]:
import max.torch
from max.graph import TensorValue, ops

@max.torch.graph_op
def max_layer_norm(x: TensorValue, weight: TensorValue, bias: TensorValue) -> TensorValue:
    return ops.layer_norm(x, weight, bias, epsilon=1e-5)

@torch.compile  
def custom_layer_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # same as torch.empty with correct dtype and device mapping
    output = x.new_empty(x.shape)
    # NOTE: `output` is added as an argument to enable highly efficient Destination-Passing Style
    max_layer_norm(output, x, weight, bias)
    return output
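For reference, layer normalization computes `(x - mean) / sqrt(var + eps) * weight + bias` over the last dimension, using the biased variance. The pure-Python sketch below follows the same destination-passing pattern as `custom_layer_norm`: the caller allocates `out` and the op only fills it. It illustrates the math and the calling convention; it is not how MAX implements `ops.layer_norm`, and `layer_norm_dps` is a hypothetical name.

```python
import math

def layer_norm_dps(out, x, weight, bias, eps=1e-5):
    """Normalize a 1-D list x and write the result into the preallocated list out."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # biased variance, as nn.LayerNorm uses
    inv_std = 1.0 / math.sqrt(var + eps)
    for i in range(n):
        out[i] = (x[i] - mean) * inv_std * weight[i] + bias[i]

x = [1.0, 2.0, 3.0, 4.0]
out = [0.0] * len(x)  # caller allocates, like x.new_empty(x.shape)
layer_norm_dps(out, x, weight=[1.0] * 4, bias=[0.0] * 4)
print([round(v, 4) for v in out])
```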

In [6]:
# quick test
import torch.nn as nn

batch_size, seq_len, hidden_size = 1, 8, 32
test_input = torch.randn(batch_size, seq_len, hidden_size, device=device)
test_weight = torch.randn(hidden_size, device=device)
test_bias = torch.randn(hidden_size, device=device)

pret = nn.functional.layer_norm(test_input, test_weight.shape, test_weight, test_bias)
mret = custom_layer_norm(test_input, test_weight, test_bias)

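# rtol bounds the relative error; a small atol keeps near-zero entries from failing spuriously
assert torch.allclose(pret, mret, rtol=1e-4, atol=1e-5), f"didn't match {pret[0, :3, :3]}, {mret[0, :3, :3]}"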
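`torch.allclose(a, b, rtol, atol)` accepts the pair when `|a - b| <= atol + rtol * |b|` holds elementwise. With the default `atol=1e-8`, entries whose reference value is near zero must match almost exactly, which is worth knowing when comparing two kernels in float32. A pure-Python sketch of the rule (assuming equal-length inputs):

```python
def allclose(a, b, rtol=1e-5, atol=1e-8):
    # Elementwise rule used by torch.allclose: |a - b| <= atol + rtol * |b|
    # (zip assumes a and b have the same length)
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))

ref = [1.0, 0.0]
got = [1.00005, 1e-6]
print(allclose(got, ref, rtol=1e-4))             # False: 1e-6 vs 0.0 exceeds atol=1e-8
print(allclose(got, ref, rtol=1e-4, atol=1e-5))  # True once atol covers near-zero entries
```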

## Perform "Model Surgery"

Iteratively replace each `LayerNorm` with our custom op, then test the model.
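The replacement function defines `max_forward` inside a loop and passes `name=name, module=module` as default arguments. Python closures look up loop variables when they are called, not when they are defined, so without that binding every patched `forward` would see the last `LayerNorm` of the loop; any other loop variable the closure uses, such as `original_forward` in the fallback path, needs the same treatment. A minimal sketch of the pitfall:

```python
def make_handlers(names):
    late, bound = [], []
    for name in names:
        late.append(lambda: name)             # looks up `name` when called
        bound.append(lambda name=name: name)  # default argument captures the current value
    return late, bound

late, bound = make_handlers(["ln_1", "ln_2", "ln_f"])
print([f() for f in late])   # ['ln_f', 'ln_f', 'ln_f']
print([f() for f in bound])  # ['ln_1', 'ln_2', 'ln_f']
```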

In [7]:
def replace_layer_norm_with_max(model):
    """Replace LayerNorm with MAX implementation for inference"""
    replaced_count = 0
    for name, module in model.named_modules():
        if isinstance(module, nn.LayerNorm):
            print(f"Found LayerNorm: {name}, weight shape: {module.weight.shape}, bias: {module.bias is not None}")

            if module.weight.shape[0] == 1024:  # DialoGPT-medium hidden dimension
                original_forward = module.forward

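                # Bind loop variables as defaults so each closure keeps its own module and fallback
                def max_forward(x, name=name, module=module, original_forward=original_forward):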
                    try:
                        weight_detached = module.weight.detach()
                        if module.bias is not None:
                            bias_detached = module.bias.detach()
                            return custom_layer_norm(x, weight_detached, bias_detached)
                        else:
                            bias_tensor = torch.zeros_like(weight_detached)
                            return custom_layer_norm(x, weight_detached, bias_tensor)
                    except Exception as e:
                        print(f"Error in MAX LayerNorm for {name}: {e}")
                        return original_forward(x)

                module.forward = max_forward
                replaced_count += 1
                print(f"Replaced {name} with MAX implementation")
            else:
                print(f"Skipping {name} - unexpected shape {module.weight.shape}")

    print(f"Total replaced: {replaced_count} LayerNorm operations")
    model.to(device)
    return replaced_count

print("=== STAGE 2: Replacing LayerNorm with MAX ===")
replaced_count = replace_layer_norm_with_max(model)

=== STAGE 2: Replacing LayerNorm with MAX ===
Found LayerNorm: transformer.h.0.ln_1, weight shape: torch.Size([1024]), bias: True
Replaced transformer.h.0.ln_1 with MAX implementation
Found LayerNorm: transformer.h.0.ln_2, weight shape: torch.Size([1024]), bias: True
Replaced transformer.h.0.ln_2 with MAX implementation
Found LayerNorm: transformer.h.1.ln_1, weight shape: torch.Size([1024]), bias: True
Replaced transformer.h.1.ln_1 with MAX implementation
Found LayerNorm: transformer.h.1.ln_2, weight shape: torch.Size([1024]), bias: True
Replaced transformer.h.1.ln_2 with MAX implementation
Found LayerNorm: transformer.h.2.ln_1, weight shape: torch.Size([1024]), bias: True
Replaced transformer.h.2.ln_1 with MAX implementation
Found LayerNorm: transformer.h.2.ln_2, weight shape: torch.Size([1024]), bias: True
Replaced transformer.h.2.ln_2 with MAX implementation
Found LayerNorm: transformer.h.3.ln_1, weight shape: torch.Size([1024]), bias: True
Replaced transformer.h.3.ln_1 with MAX imp

In [8]:
bot_max_layernorm = DialogueBot(model, tokenizer)

print("MAX LayerNorm Results:")
max_layernorm_responses = []
for i, user_input in enumerate(test_inputs, 1):
    response = bot_max_layernorm.generate_response(user_input)
    max_layernorm_responses.append(response)
    print(f"Turn {i}")
    print(f"User: {user_input}")
    print(f"Bot: {response}")
    print("=" * 80)

MAX LayerNorm Results:
Turn 1
User: Hello, how are you?
Bot: I'm good, thanks. How about you?
Turn 2
User: What's your favorite programming language?
Bot: Java, Python, C, and C.
Turn 3
User: Tell me about artificial intelligence
Bot: It's a thing. It's called AI. I don't know what it is, but it's something. And it can be used to make things. Like computers. Or robots. Whatever. You want it to be. :D
Turn 4
User: What's the meaning of life?
Bot: Life is a computer program that makes things
Turn 5
User: Tell me a joke
Bot: A joke!


## Load Mojo Custom op with [max.torch.CustomOpLibrary](https://docs.modular.com/max/api/python/torch/#max.torch.CustomOpLibrary)

Another approach is to:

1. Write a GPU kernel in Mojo.
2. Package it as a [Mojo package](https://docs.modular.com/mojo/cli/package/).
3. Use `max.torch.CustomOpLibrary` to load the package in our Python code.

In [9]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=3072, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=1024)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=4096, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=4096)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=50257, bias=False)
)

## Mojo NewGELU Kernel

Next, we write a simple CPU/GPU kernel for the `NewGELU` activation function and use the Jupyter magic `%%mojo` to build a Mojo package (`.mojopkg`) that we load later.
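Before writing the kernel, it helps to have a scalar reference to test against. `NewGELU` is the tanh approximation of GELU, `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`, while exact GELU is `x * Phi(x)` with `Phi` the standard normal CDF. The sketch below implements both with Python's `math` module; the function names are illustrative only.

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)

def new_gelu(x):
    # Tanh approximation, the same formula as new_gelu_computation in the Mojo kernel
    return 0.5 * x * (1.0 + math.tanh(SQRT_2_OVER_PI * (x + 0.044715 * x ** 3)))

def exact_gelu(x):
    # Exact GELU: x * Phi(x), written with the error function
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

for v in [-3.0, -1.0, 0.0, 1.0, 3.0]:
    print(f"{v:5.1f}  new={new_gelu(v):.6f}  exact={exact_gelu(v):.6f}")
```

The two curves differ by well under 1e-3 over this range, so a kernel test should compare against `new_gelu`, not exact GELU, when tight tolerances are used.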

In [10]:
import max.support.notebook

In [11]:
%%mojo

@fieldwise_init
struct Foo(Writable):
    var s: String

    fn write_to[W: Writer](self, mut writer: W):
        writer.write("[ ", self.s, " ]")

def main():
    print("Hello, world!")
    foo = Foo("Hello, jupyter!")
    print(foo)

Hello, world!
[ Hello, jupyter! ]



In [12]:
%%mojo package -o ops.mojopkg

import math
import compiler
from algorithm import parallelize, vectorize
from sys import simdwidthof
from layout import LayoutTensor, Layout, UNKNOWN_VALUE
from gpu import thread_idx, block_idx, block_dim
from gpu.host import DeviceContext, DeviceBuffer
from runtime.asyncrt import DeviceContextPtr
from tensor import InputTensor, OutputTensor
from math import ceildiv

alias dtype = DType.float32
alias BLOCK_SIZE = 256

# Core NewGELU computation - can be used in both CPU and GPU contexts
@always_inline
fn new_gelu_computation[dtype: DType](x: Scalar[dtype]) -> Scalar[dtype]:
    """
    Core NewGELU computation for a single scalar value.

    NewGELU(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    """
    alias SQRT_2_OVER_PI = Scalar[dtype](0.7978845608028654)  # sqrt(2/pi) for float32
    alias GELU_COEFF = Scalar[dtype](0.044715)

    return 0.5 * x * (1.0 + math.tanh(SQRT_2_OVER_PI * (x + GELU_COEFF * x * x * x)))

# Simple GPU kernel for demonstration
fn new_gelu_gpu_kernel[dtype: DType](
    output_ptr: UnsafePointer[Scalar[dtype]],
    input_ptr: UnsafePointer[Scalar[dtype]],
    num_elements: Int,
):
    idx = block_idx.x * block_dim.x + thread_idx.x

    if idx < num_elements:
        x_val = input_ptr[idx]
        result_val = new_gelu_computation[dtype](x_val)
        output_ptr[idx] = result_val

# MAX custom operation using @compiler.register
@compiler.register("new_gelu")
struct NewGELU:
    @staticmethod
    fn execute[target: StaticString](
        # Outputs
        result: OutputTensor[dtype=DType.float32, rank=3], # (batch_size, seq_len, hidden_size)
        # Inputs
        x: InputTensor[dtype=DType.float32, rank=3],
        # Context
        ctx: DeviceContextPtr,
    ) raises:
        batch_size = x.dim_size(0)
        seq_len = x.dim_size(1)
        hidden_size = x.dim_size(2)

        @parameter
        if target == "cpu":
            # Apply NewGELU element-wise across all dimensions
            for b in range(batch_size):
                for s in range(seq_len):
                    for h in range(hidden_size):
                        var x_val = x[b, s, h]
                        var result_val = new_gelu_computation[dtype](x_val)
                        result[b, s, h] = result_val

        elif target == "gpu":
            gpu_ctx = ctx.get_device_context()
            num_elements = batch_size * seq_len * hidden_size

            # Calculate grid and block dimensions
            grid_size = ceildiv(num_elements, BLOCK_SIZE)
            block_size = BLOCK_SIZE

            # Get tensor data and convert to layout tensors
            x_tensor = x.to_layout_tensor()
            output_tensor = result.to_layout_tensor()
            input_ptr = x_tensor.ptr
            output_ptr = output_tensor.ptr

            # Launch GPU kernel
            gpu_ctx.enqueue_function[new_gelu_gpu_kernel[dtype]](
                output_ptr,
                input_ptr,
                num_elements,
                grid_dim=grid_size,
                block_dim=block_size,
            )
        else:
            raise Error("Unsupported target: " + target)
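The GPU branch treats the tensor as a flat array and launches `ceildiv(num_elements, BLOCK_SIZE)` blocks of `BLOCK_SIZE` threads. Each thread computes `idx = block_idx.x * block_dim.x + thread_idx.x`, and the `idx < num_elements` guard idles the extra threads in the last block. The Python simulation below checks that this mapping touches every element exactly once; it assumes a contiguous tensor, as the kernel itself does, and the `(1, 8, 4096)` shape is a hypothetical `c_fc` output with `seq_len=8`.

```python
BLOCK_SIZE = 256

def ceildiv(a, b):
    # Rounding-up integer division for positive ints, like Mojo's math.ceildiv
    return -(-a // b)

def simulate_launch(num_elements, block_size=BLOCK_SIZE):
    """Return the grid size and how often each flat index is visited by the 1-D launch."""
    grid_size = ceildiv(num_elements, block_size)
    visits = [0] * num_elements
    for block_idx in range(grid_size):
        for thread_idx in range(block_size):
            idx = block_idx * block_size + thread_idx
            if idx < num_elements:  # tail guard, as in new_gelu_gpu_kernel
                visits[idx] += 1
    return grid_size, visits

grid, visits = simulate_launch(1 * 8 * 4096)
print(grid, all(v == 1 for v in visits))  # 128 True

grid, visits = simulate_launch(1000)      # ragged size: 4 blocks, 24 idle threads
print(grid, all(v == 1 for v in visits))  # 4 True
```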




## Load the Custom Mojo Kernel

* Load our packaged Mojo kernel using [max.torch.CustomOpLibrary](https://docs.modular.com/max/api/python/torch/#max.torch.CustomOpLibrary).
* Perform another surgery to replace the activation.
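The surgery walks the module tree and swaps each child whose class is `NewGELUActivation`. Matching the class name exactly (rather than with a substring test) matters here, because the replacement class `MaxNewGELUActivation` also contains the string `NewGELUActivation` and would otherwise be matched again if the cell were re-run. The sketch below uses a hypothetical minimal `Module` stand-in (a name-to-child mapping), not `torch.nn.Module`.

```python
class Module:
    """Hypothetical stand-in for nn.Module: just a name -> child mapping."""
    def __init__(self, **children):
        self.children = dict(children)

class NewGELUActivation(Module): pass
class MaxNewGELUActivation(Module): pass

def replace_by_class_name(module, target, factory):
    count = 0
    for name, child in module.children.items():
        if type(child).__name__ == target:  # exact match: replacements are never re-matched
            module.children[name] = factory()  # replacing a value does not resize the dict
            count += 1
        else:
            count += replace_by_class_name(child, target, factory)
    return count

model = Module(h0=Module(mlp=Module(act=NewGELUActivation())),
               h1=Module(mlp=Module(act=NewGELUActivation())))
print(replace_by_class_name(model, "NewGELUActivation", MaxNewGELUActivation))  # 2
print(replace_by_class_name(model, "NewGELUActivation", MaxNewGELUActivation))  # 0
```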

In [13]:
# Load our custom NewGELU operation from ops package
from max.torch import CustomOpLibrary
from pathlib import Path

try:
    op_library = CustomOpLibrary(Path("ops.mojopkg"))
    new_gelu_op = op_library.new_gelu
    print("Successfully loaded new_gelu operation")
except Exception as e:
    print(f"Error loading custom operations: {e}")
    print("Make sure ops.mojopkg is properly created and contains new_gelu operation")
    raise

# Compile the custom op once and reuse it in every forward call
compiled_new_gelu = torch.compile(new_gelu_op)

class MaxNewGELUActivation(nn.Module):
    """MAX implementation of NewGELUActivation using our custom operation"""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # new_gelu is registered for float32, rank-3 tensors (batch, seq_len, hidden)
        output = x.new_empty(x.shape)

        # The custom op writes into `output` (destination-passing style)
        compiled_new_gelu(output, x)

        return output

def replace_new_gelu_with_max(model):
    """Replace NewGELUActivation with MAX implementation"""
    replaced_count = 0

    def replace_in_module(module):
        nonlocal replaced_count
        for name, child in module.named_children():
            # Exact class-name match, so an already-replaced MaxNewGELUActivation is not matched again
            if type(child).__name__ == 'NewGELUActivation':
                print(f"  Replacing NewGELUActivation in {name}")
                setattr(module, name, MaxNewGELUActivation())
                replaced_count += 1
            else:
                replace_in_module(child)

    replace_in_module(model)
    model.to(device)
    return replaced_count

gelu_replaced_count = replace_new_gelu_with_max(model)
print(f"Successfully replaced {gelu_replaced_count} NewGELUActivation operations with MAX")

Successfully loaded new_gelu operation
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
  Replacing NewGELUActivation in act
Successfully replaced 24 NewGELUActivation operations with MAX


In [14]:
bot_max_newgelu = DialogueBot(model, tokenizer)

print("MAX LayerNorm + MAX NewGELU Results:")
max_newgelu_responses = []
for i, user_input in enumerate(test_inputs, 1):
    response = bot_max_newgelu.generate_response(user_input)
    max_newgelu_responses.append(response)
    print(f"Turn {i}")
    print(f"User: {user_input}")
    print(f"Bot: {response}")
    print("=" * 80)

MAX LayerNorm + MAX NewGELU Results:
Turn 1
User: Hello, how are you?
Bot: I'm good, thanks. How about you?
Turn 2
User: What's your favorite programming language?
Bot: Java, Python, C, and C.
Turn 3
User: Tell me about artificial intelligence
Bot: It's a thing. It's called AI. I don't know what it is, but it's something. And it can be used to make things. Like computers. Or robots. Whatever. You want it to be. :D
Turn 4
User: What's the meaning of life?
Bot: Life is a computer program that makes things
Turn 5
User: Tell me a joke
Bot: A joke!


## Surgically replace the whole `GPT2MLP` layer with MAX

Next, we implement the `GPT2MLP` layer in MAX and replace every `GPT2MLP` layer in the PyTorch model with it, so the whole MLP block benefits from MAX graph compiler optimizations.
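One detail drives the MAX version: HuggingFace's `Conv1D` is a linear layer whose weight is stored as `(nx, nf)`, i.e. transposed relative to `nn.Linear`'s `(out_features, in_features)`. So `Conv1D` computes `x @ W + b`, and the same weight can be fed to a plain matmul without transposing. The nested-list sketch below shows the two conventions agree once the weight is transposed; the helper names and values are illustrative only.

```python
def matmul(a, b):
    # (n x k) @ (k x m) with plain nested lists
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    return [list(col) for col in zip(*m)]

def add_bias(rows, bias):
    return [[v + bi for v, bi in zip(row, bias)] for row in rows]

def conv1d_forward(x, weight, bias):
    # HuggingFace Conv1D: weight shape (nx, nf); y = x @ W + b
    return add_bias(matmul(x, weight), bias)

def linear_forward(x, weight, bias):
    # nn.Linear: weight shape (out_features, in_features); y = x @ W.T + b
    return add_bias(matmul(x, transpose(weight)), bias)

x = [[1.0, 2.0]]                               # (tokens=1, nx=2)
w_conv = [[1.0, 0.0, 2.0], [0.5, 1.0, -1.0]]   # (nx=2, nf=3)
bias = [0.1, 0.2, 0.3]
print(conv1d_forward(x, w_conv, bias))
print(linear_forward(x, transpose(w_conv), bias))  # same result with the transposed weight
```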

In [15]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=3072, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=1024)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=4096, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=4096)
          (act): MaxNewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=50257, bias=False)
)

In [16]:
import max.torch
import torch.nn as nn
from max.graph import TensorValue, ops

@max.torch.graph_op
def max_gpt2_mlp(
    x: TensorValue,
    c_fc_weight: TensorValue,
    c_fc_bias: TensorValue,
    c_proj_weight: TensorValue,
    c_proj_bias: TensorValue,
):
    """
    GPT2MLP using MAX graph operations to replicate HuggingFace Conv1D behavior
    Architecture: Conv1D(nx=1024, nf=4096) -> NewGELU -> Conv1D(nx=4096, nf=1024) -> Dropout
    Note: HuggingFace Conv1D is actually a linear layer with transposed weights
    """
    # c_fc: Conv1D expansion (nx=1024 -> nf=4096)
    expanded = x @ c_fc_weight + c_fc_bias

    # NewGELU activation (using built-in GELU with tanh approximation)
    gelu_output = ops.gelu(expanded, approximate="tanh")

    # c_proj: Conv1D projection (nx=4096 -> nf=1024)
    return gelu_output @ c_proj_weight + c_proj_bias


# Compile the graph op once and reuse it in every forward call
compiled_gpt2_mlp = torch.compile(max_gpt2_mlp)


class MaxGPT2MLP(nn.Module):
    def __init__(self, c_fc, c_proj):
        super().__init__()

        # Copy weights from the original Conv1D layers. HuggingFace Conv1D stores
        # weights as (nx, nf), which is already the layout `x @ W` expects, so no
        # transpose is needed. Buffers (not parameters) suffice for inference, and
        # registering them lets `model.to(device)` move them along with the model.

        # c_fc: Conv1D(nx=1024, nf=4096) - expand hidden size
        self.register_buffer("c_fc_weight", c_fc.weight.detach().clone())  # (1024, 4096)
        self.register_buffer("c_fc_bias", c_fc.bias.detach().clone())

        # c_proj: Conv1D(nx=4096, nf=1024) - project back to hidden size
        self.register_buffer("c_proj_weight", c_proj.weight.detach().clone())  # (4096, 1024)
        self.register_buffer("c_proj_bias", c_proj.bias.detach().clone())

        # No dropout: this module is inference-only

    def forward(self, x):
        # The MLP maps 1024 -> 4096 -> 1024, so the output has the input's shape:
        # (batch_size, seq_len, 1024)
        output = torch.empty_like(x)

        # Use the MAX graph op with destination-passing style
        compiled_gpt2_mlp(
            output,
            x,
            self.c_fc_weight,
            self.c_fc_bias,
            self.c_proj_weight,
            self.c_proj_bias,
        )

        return output


def replace_gpt2_mlp_with_max(model):
    mlp_replaced_count = 0

    @torch.no_grad()
    def _replace_gpt2_mlp_with_max(parent):
        """Replace GPT2MLP layers with MAX Linear implementation that replicates Conv1D behavior"""
        nonlocal mlp_replaced_count
        for name, module in parent.named_children():
            if (
                "mlp" in name.lower()
                and hasattr(module, "c_fc")
                and hasattr(module, "c_proj")
            ):
                print(f"Found GPT2MLP: {name}")
                print(
                    f"c_fc shape: {module.c_fc.weight.shape} (nx={module.c_fc.weight.shape[0]}, nf={module.c_fc.weight.shape[1]})"
                )
                print(
                    f"c_proj shape: {module.c_proj.weight.shape} (nx={module.c_proj.weight.shape[0]}, nf={module.c_proj.weight.shape[1]})"
                )
    
                max_mlp = MaxGPT2MLP(module.c_fc, module.c_proj)
                setattr(parent, name, max_mlp)
                mlp_replaced_count += 1
                print("-> Replaced with MAX implementation")
    
    model.apply(_replace_gpt2_mlp_with_max)
    model.to(device)
    return mlp_replaced_count

mlp_replaced_count = replace_gpt2_mlp_with_max(model)
print(f"Replaced {mlp_replaced_count} GPT2MPL layers")

Found GPT2MLP: mlp
c_fc shape: torch.Size([1024, 4096]) (nx=1024, nf=4096)
c_proj shape: torch.Size([4096, 1024]) (nx=4096, nf=1024)
-> Replaced with MAX implementation
Found GPT2MLP: mlp
c_fc shape: torch.Size([1024, 4096]) (nx=1024, nf=4096)
c_proj shape: torch.Size([4096, 1024]) (nx=4096, nf=1024)
-> Replaced with MAX implementation
Found GPT2MLP: mlp
c_fc shape: torch.Size([1024, 4096]) (nx=1024, nf=4096)
c_proj shape: torch.Size([4096, 1024]) (nx=4096, nf=1024)
-> Replaced with MAX implementation
Found GPT2MLP: mlp
c_fc shape: torch.Size([1024, 4096]) (nx=1024, nf=4096)
c_proj shape: torch.Size([4096, 1024]) (nx=4096, nf=1024)
-> Replaced with MAX implementation
Found GPT2MLP: mlp
c_fc shape: torch.Size([1024, 4096]) (nx=1024, nf=4096)
c_proj shape: torch.Size([4096, 1024]) (nx=4096, nf=1024)
-> Replaced with MAX implementation
Found GPT2MLP: mlp
c_fc shape: torch.Size([1024, 4096]) (nx=1024, nf=4096)
c_proj shape: torch.Size([4096, 1024]) (nx=4096, nf=1024)
-> Replaced with MAX i

In [17]:
bot_max_full = DialogueBot(model, tokenizer)

print("MAX LayerNorm + MAX NewGELU + New GPT2MLP Results:")
max_full_responses = []
for i, user_input in enumerate(test_inputs, 1):
    response = bot_max_full.generate_response(user_input)
    max_full_responses.append(response)
    print(f"Turn {i}")
    print(f"User: {user_input}")
    print(f"Bot: {response}")
    print("=" * 80)

MAX LayerNorm + MAX NewGELU + New GPT2MLP Results:
Turn 1
User: Hello, how are you?
Bot: I'm good, thanks. How about you?
Turn 2
User: What's your favorite programming language?
Bot: Java, Python, C, and C.
Turn 3
User: Tell me about artificial intelligence
Bot: It's a thing. It's called AI. I don't know what it is, but it's something. And it can be used to make things. Like computers. Or robots. Whatever. You want it to be. :D
Turn 4
User: What's the meaning of life?
Bot: Life is a computer program that makes things
Turn 5
User: Tell me a joke
Bot: A joke!


In [18]:
print("Stage 1: Original PyTorch model")
print(f"Stage 2: Replaced {replaced_count} LayerNorm operations with MAX")
print(f"Stage 3: Replaced additional {gelu_replaced_count} NewGELUActivation operations with MAX")
print(f"Stage 4: Replaced additional {mlp_replaced_count} with custom GPT2MLP layer")
print(f"Total replacements with MAX: {replaced_count + gelu_replaced_count + mlp_replaced_count}")

Stage 1: Original PyTorch model
Stage 2: Replaced 49 LayerNorm operations with MAX
Stage 3: Replaced additional 24 NewGELUActivation operations with MAX
Stage 4: Replaced additional 24 with custom GPT2MLP layer
Total replacements with MAX: 97
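The counts follow from the architecture printed earlier: each of the 24 `GPT2Block`s has two `LayerNorm`s (`ln_1`, `ln_2`) and one MLP with one activation, plus the final `ln_f`. The arithmetic below also shows that "total replacements" counts swaps performed rather than MAX modules still active: Stage 4 replaces each `GPT2MLP` wholesale, so the `MaxNewGELUActivation` modules from Stage 3 leave the forward pass (the MAX MLP computes its GELU internally with `ops.gelu`).

```python
NUM_BLOCKS = 24  # GPT2Block count in DialoGPT-medium

layer_norms = 2 * NUM_BLOCKS + 1  # ln_1 and ln_2 per block, plus the final ln_f
gelus = NUM_BLOCKS                # one act per GPT2MLP
mlps = NUM_BLOCKS                 # one mlp per GPT2Block

print(layer_norms, gelus, mlps, layer_norms + gelus + mlps)  # 49 24 24 97
# After Stage 4 the replaced GELU modules are gone along with their parent MLPs,
# so the MAX-backed modules left in the model are the LayerNorms plus the MLPs.
print(layer_norms + mlps)  # 73
```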
