In [2]:
import torch
import torch.nn as nn
from model_splitter import ModelSplitter

# --- Models ---

class SimpleNN(nn.Module):
    """Two-layer MLP: 20 input features -> 30 hidden units (ReLU) -> 10 outputs.

    The layers live in ``self.net`` so their parameters are named
    ``net.0.*`` and ``net.2.*`` (index 1 is the parameter-free ReLU).
    """

    def __init__(self):
        super().__init__()
        hidden_units = 30
        layers = [
            nn.Linear(20, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 10),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        return logits

class SimpleCNN(nn.Module):
    """Small conv net for single-channel images, producing 5 outputs.

    Two 3x3 convolutions (padding 1, so spatial size is kept) with a ReLU in
    between, global average pooling down to 1x1, then a linear head over the
    4 pooled channels.
    """

    def __init__(self):
        super().__init__()
        feature_layers = (
            nn.Conv2d(1, 8, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 4, 3, padding=1),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.conv = nn.Sequential(*feature_layers)
        self.fc = nn.Linear(4, 5)

    def forward(self, x):
        pooled = self.conv(x)                   # (batch, 4, 1, 1)
        flat = pooled.view(pooled.size(0), -1)  # (batch, 4)
        return self.fc(flat)

class SimpleRNN(nn.Module):
    """Single-layer LSTM (10 input features, 20 hidden units) with a 5-way linear head.

    Inputs are batch-first ``(batch, seq_len, 10)``; only the LSTM output at the
    final time step is passed to the head.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
        self.fc = nn.Linear(in_features=20, out_features=5)

    def forward(self, x):
        sequence_out, _hidden_state = self.lstm(x)  # sequence_out: (batch, seq_len, 20)
        last_step = sequence_out[:, -1, :]
        return self.fc(last_step)

class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    ``x -> x + SelfAttention(LayerNorm(x))`` followed by
    ``x -> x + MLP(LayerNorm(x))`` where the MLP is Linear -> GELU -> Linear.

    Args:
        dim: token embedding size (``nn.MultiheadAttention`` requires it to be
            divisible by ``heads``).
        heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.

    Input and output are batch-first ``(batch, tokens, dim)`` (``batch_first=True``).
    """

    def __init__(self, dim, heads=2, mlp_ratio=2):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),
        )

    def forward(self, x):
        # Normalise once and use the same tensor as query, key and value (self-attention).
        h = self.norm1(x)
        # Only the attended output ([0]) is used, so skip building the
        # averaged attention-weight matrix.
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x

class MiniViT(nn.Module):
    """Minimal Vision Transformer classifier.

    Pipeline:
    1. A strided convolution cuts the image into non-overlapping
       ``patch_size`` x ``patch_size`` patches and embeds each one to ``dim``.
    2. A learnable class token is prepended and learnable positional
       embeddings are added.
    3. The tokens pass through ``depth`` ``TransformerBlock``s.
    4. The class-token output is LayerNorm-ed and fed to a linear head.

    Args:
        image_size: side length of the square input image. With the defaults, a
            28x28 image becomes a 4x4 grid of 16 patches. If ``image_size`` is
            not a multiple of ``patch_size``, the strided conv drops the
            trailing pixels.
        patch_size: side length of each square patch.
        dim: token embedding size.
        depth: number of transformer blocks.
        num_classes: number of output logits.
        in_channels: number of input image channels (default 1).
        heads: attention heads per block, forwarded to ``TransformerBlock``.
        mlp_ratio: MLP expansion ratio per block, forwarded to ``TransformerBlock``.

    Input: ``(batch, in_channels, image_size, image_size)``.
    Output: ``(batch, num_classes)``.
    """

    def __init__(self, image_size=28, patch_size=7, dim=64, depth=2, num_classes=10,
                 in_channels=1, heads=2, mlp_ratio=2):
        super().__init__()
        self.patch_embed = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
        num_patches = (image_size // patch_size) ** 2

        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # One position per patch plus one for the class token.
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, dim))

        self.blocks = nn.Sequential(
            *[TransformerBlock(dim, heads=heads, mlp_ratio=mlp_ratio) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)  # [B, dim, H', W']
        x = x.flatten(2).transpose(1, 2)  # [B, N, dim]

        cls = self.cls_token.expand(B, -1, -1)  # [B, 1, dim]
        x = torch.cat([cls, x], dim=1) + self.pos_embed  # [B, N+1, dim]

        x = self.blocks(x)
        x = self.norm(x[:, 0])  # take cls token
        return self.head(x)

# --- Inputs ---
# Seed once, before any random tensor or model is created, so both these inputs
# and the model weights initialised later in this cell are reproducible.
SEED = 0
torch.manual_seed(SEED)

nn_in = torch.randn(5, 20)          # SimpleNN: (batch, 20 features)
cnn_in = torch.randn(5, 1, 8, 8)    # SimpleCNN: (batch, 1 channel, 8, 8)
rnn_in = torch.randn(5, 7, 10)      # SimpleRNN: (batch, seq_len=7, 10 features), batch-first
vit_in = torch.randn(5, 1, 28, 28)  # MiniViT defaults: 28x28 image -> 4x4 grid of 7x7 patches

# --- Test function ---
def test_model(model, sample_input, name):
    print(f"\nTesting {name}:")
    splitter = ModelSplitter(model, sample_input)
    print("Legal split points:", sorted(splitter.legal_splits))

    safe_points = splitter.get_safe_split_points()
    print("Safe split points:", sorted(safe_points))

    for point in safe_points:
        print(f"  Splitting at {point} ... ", end="")
        try:
            _, _ = splitter.split_from_param(point)
            print("Success.")
        except Exception as e:
            print(f"Failed. Error: {e}")


# --- Run tests ---
# Each model class is instantiated only when its turn comes, so model
# construction and testing interleave in the same order as one call per line.
for model_cls, sample, label in (
    (SimpleNN, nn_in, "SimpleNN"),
    (SimpleCNN, cnn_in, "SimpleCNN"),
    (SimpleRNN, rnn_in, "SimpleRNN"),
    (MiniViT, vit_in, "MiniViT"),
):
    test_model(model_cls(), sample, label)



Testing SimpleNN:
Legal split points: ['net.0.bias', 'net.0.weight', 'net.2.bias', 'net.2.weight']
Safe split points: ['net.0.bias', 'net.0.weight', 'net.2.bias', 'net.2.weight']
  Splitting at net.0.weight ... Success.
  Splitting at net.2.weight ... Success.
  Splitting at net.0.bias ... Success.
  Splitting at net.2.bias ... Success.

Testing SimpleCNN:
Legal split points: ['conv.0.bias', 'conv.0.weight', 'conv.2.bias', 'conv.2.weight', 'fc.bias', 'fc.weight']
Safe split points: ['fc.bias', 'fc.weight']
  Splitting at fc.bias ... Success.
  Splitting at fc.weight ... Success.

Testing SimpleRNN:
Legal split points: ['fc.bias', 'fc.weight', 'lstm.bias_hh_l0', 'lstm.bias_ih_l0', 'lstm.weight_hh_l0', 'lstm.weight_ih_l0']
Safe split points: []

Testing MiniViT:
Legal split points: ['blocks.0.mlp.0.bias', 'blocks.0.mlp.0.weight', 'blocks.0.mlp.2.bias', 'blocks.0.mlp.2.weight', 'blocks.0.norm1.bias', 'blocks.0.norm1.weight', 'blocks.0.norm2.bias', 'blocks.0.norm2.weight', 'blocks.1.mlp.0

In [3]:
import torch
import torch.nn as nn
from model_splitter_vit import ViT32BModelSplitter

# Example minimal ViT32B model for testing
class MiniViT32B(nn.Module):
    """Toy ViT-style classifier over pre-extracted patch vectors.

    Architecture:
    1. Each 16-dim patch vector is linearly embedded to 64 dims.
    2. A learnable class token, initialised to zeros, is prepended.
    3. The tokens go through 4 blocks of LayerNorm -> Linear -> ReLU -> Linear.
       These blocks have no residual connections and no attention.
    4. A final LayerNorm is applied.
    5. A 10-way head reads the class token.

    Input is assumed to be ``(batch, num_patches, 16)``; output is ``(batch, 10)``.
    NOTE(review): there is no positional embedding — confirm that is intended.
    """

    def __init__(self):
        super().__init__()
        width = 64
        self.patch_embed = nn.Linear(16, width)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, width))
        stages = []
        for _ in range(4):
            stages.append(nn.Sequential(
                nn.LayerNorm(width),
                nn.Linear(width, width),
                nn.ReLU(),
                nn.Linear(width, width),
            ))
        self.blocks = nn.ModuleList(stages)
        self.norm = nn.LayerNorm(width)
        self.head = nn.Linear(width, 10)

    def forward(self, x):
        tokens = self.patch_embed(x)
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)
        for stage in self.blocks:
            tokens = stage(tokens)
        cls_out = self.norm(tokens)[:, 0]
        return self.head(cls_out)


# Setup model and sample input
# Seed so the model weights and sample input are reproducible on Restart & Run All.
torch.manual_seed(0)
model = MiniViT32B().eval()
sample_input = torch.randn(20, 4, 16)  # batch 20, 4 patches, dim 16

# Initialize splitter
splitter = ViT32BModelSplitter(model, sample_input)

# Query the split points once and reuse the same list for the listing and the loop.
split_points = list(splitter.available_split_points())

print("Available split points:")
for sp in split_points:
    print(sp)

# Test splitting at each split point
for sp in split_points:
    print(f"\nTesting split at: {sp}")
    part1_out, _part2_model = splitter.split_from_param(sp)
    print(f"Split successful. Output shape part1: {part1_out.shape}")


Available split points:
patch_embed.weight
patch_embed.bias
blocks.0.0.weight
blocks.0.0.bias
blocks.0.1.weight
blocks.0.1.bias
blocks.0.3.weight
blocks.0.3.bias
blocks.1.0.weight
blocks.1.0.bias
blocks.1.1.weight
blocks.1.1.bias
blocks.1.3.weight
blocks.1.3.bias
blocks.2.0.weight
blocks.2.0.bias
blocks.2.1.weight
blocks.2.1.bias
blocks.2.3.weight
blocks.2.3.bias
blocks.3.0.weight
blocks.3.0.bias
blocks.3.1.weight
blocks.3.1.bias
blocks.3.3.weight
blocks.3.3.bias
norm.weight
norm.bias

Testing split at: patch_embed.weight
Split successful. Output shape part1: torch.Size([20, 5, 64])

Testing split at: patch_embed.bias
Split successful. Output shape part1: torch.Size([20, 5, 64])

Testing split at: blocks.0.0.weight
Split successful. Output shape part1: torch.Size([20, 5, 64])

Testing split at: blocks.0.0.bias
Split successful. Output shape part1: torch.Size([20, 5, 64])

Testing split at: blocks.0.1.weight
Split successful. Output shape part1: torch.Size([20, 5, 64])

Testing split at