In [3]:
import torch
import torch.nn as nn
from model_splitter import ModelSplitter

# --- Models ---

class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(20, 30),
            nn.ReLU(),
            nn.Linear(30, 10)
        )
    def forward(self, x):
        return self.net(x)

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 8, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 4, 3, padding=1),
            nn.AdaptiveAvgPool2d((1,1))
        )
        self.fc = nn.Linear(4, 5)
    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

class SimpleRNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(10, 20, batch_first=True)
        self.fc = nn.Linear(20, 5)
    def forward(self, x):
        out, _ = self.lstm(x)
        return self.fc(out[:, -1])

class TinyViT(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embed = nn.Linear(4, 32)  # corrected input dim to 4
        self.cls_token = nn.Parameter(torch.zeros(1,1,32))
        self.pos_embed = nn.Parameter(torch.zeros(1, 5, 32))
        encoder_layer = nn.TransformerEncoderLayer(d_model=32, nhead=4, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.mlp_head = nn.Sequential(nn.LayerNorm(32), nn.Linear(32, 3))

    def forward(self, x):
        B = x.size(0)
        x = x.view(B, 4, 4)  # 4 patches, each patch dim 4
        x = self.patch_embed(x)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = x + self.pos_embed
        x = self.transformer(x)
        return self.mlp_head(x[:, 0])

# --- Test function ---

# --- Sample inputs ---

nn_in = torch.randn(5, 20)
cnn_in = torch.randn(5, 1, 8, 8)
rnn_in = torch.randn(5, 7, 10)
vit_in = torch.randn(5, 1, 4, 4)

def test_model(model, sample_input, name):
    print(f"\nTesting {name}:")
    splitter = ModelSplitter(model, sample_input)
    print("Legal split points:", sorted(splitter.legal_splits))
    # We test on all legal splits, regardless of safe check
    for param_name in sorted(splitter.legal_splits):
        print(f"  Trying to split at {param_name} ...", end="")
        try:
            part1_out, model2 = splitter.split_from_param(param_name)
            print(" Success.")
        except Exception as e:
            print(f" Failed. Error: {e}")

# Then run as before:
test_model(SimpleNN(), nn_in, "SimpleNN")
test_model(SimpleCNN(), cnn_in, "SimpleCNN")
test_model(SimpleRNN(), rnn_in, "SimpleRNN")
test_model(TinyViT(), vit_in, "TinyViT")




Testing SimpleNN:
Legal split points: ['net.0.bias', 'net.0.weight', 'net.2.bias', 'net.2.weight']
  Trying to split at net.0.bias ... Success.
  Trying to split at net.0.weight ... Success.
  Trying to split at net.2.bias ... Success.
  Trying to split at net.2.weight ... Success.

Testing SimpleCNN:
Legal split points: ['conv.0.bias', 'conv.0.weight', 'conv.2.bias', 'conv.2.weight', 'fc.bias', 'fc.weight']
  Trying to split at conv.0.bias ... Failed. Error: mat1 and mat2 shapes cannot be multiplied (20x1 and 4x5)
  Trying to split at conv.0.weight ... Failed. Error: mat1 and mat2 shapes cannot be multiplied (20x1 and 4x5)
  Trying to split at conv.2.bias ... Failed. Error: mat1 and mat2 shapes cannot be multiplied (20x1 and 4x5)
  Trying to split at conv.2.weight ... Failed. Error: mat1 and mat2 shapes cannot be multiplied (20x1 and 4x5)
  Trying to split at fc.bias ... Failed. Error: mat1 and mat2 shapes cannot be multiplied (20x1 and 4x5)
  Trying to split at fc.weight ... Failed.