In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import time

# Synthetic 2D classification dataset
def generate_data(n_samples=1000):
    """Sample an XOR-style 2D binary classification problem.

    Features are drawn from a standard normal; the label is 1 when both
    coordinates share the same (strict) sign, i.e. x0 * x1 > 0, else 0.

    Returns:
        (features, labels): float tensor (n_samples, 2) and int64 tensor (n_samples,).
    """
    features = torch.randn(n_samples, 2)
    same_sign = features[:, 0] * features[:, 1] > 0
    return features, same_sign.long()

# Simple MLP for binary classification
class SimpleNet(nn.Module):
    """Tiny MLP classifier: 2 features -> 16 hidden units (ReLU) -> 2 logits."""

    def __init__(self):
        super().__init__()
        # Kept as a positional nn.Sequential so parameter names stay "net.0.*", "net.2.*"
        layers = [nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalised class scores for input features ``x``."""
        scores = self.net(x)
        return scores

# Prepare data
X, y = generate_data()
dataset = TensorDataset(X, y)
# This loader feeds train() below, so reshuffle every epoch: DataLoader's
# default (shuffle=False) would replay identical batches in the same order.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


In [2]:
def train(model, dataloader, epochs=5, lr=0.01):
    """Train ``model`` in place with Adam and cross-entropy loss.

    Args:
        model: classifier returning raw logits (``nn.CrossEntropyLoss`` applies
            log-softmax itself).
        dataloader: iterable of ``(inputs, integer_labels)`` mini-batches.
        epochs: number of full passes over ``dataloader``.
        lr: Adam learning rate (default 0.01, the previously hard-coded value).

    Returns:
        None. ``model``'s parameters are updated in place and the model is
        left in training mode.
    """
    # Explicitly enter training mode: a previous cell may have called
    # model.eval(), which would keep dropout off and batch-norm on running stats.
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for _ in range(epochs):
        for xb, yb in dataloader:
            loss = criterion(model(xb), yb)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

# Build a fresh model and fit it with the mini-batch loop above
model = SimpleNet()
train(model, dataloader=dataloader, epochs=5)


In [3]:
# Put model in eval mode so inference uses deterministic layer behaviour
model.eval()

# Disable gradient tracking for inference and time only the forward pass.
# time.perf_counter() is monotonic and high-resolution, unlike the wall-clock
# time.time(), so it is the right clock for sub-millisecond intervals.
# NOTE: a single un-warmed call is a rough figure, not a benchmark.
with torch.no_grad():
    start = time.perf_counter()
    outputs = model(X)
    end = time.perf_counter()

# Process outputs: raw scores per class; argmax over dim 1 gives the label
logits = outputs
preds = torch.argmax(logits, dim=1)

# Baseline metrics
print(f"Inference time: {end - start:.6f} seconds")
print(f"Sample logits:\n{logits[:5]}")
print(f"Sample predictions:\n{preds[:5]}")


Inference time: 0.000901 seconds
Sample logits:
tensor([[ 6.9300, -5.3727],
        [-0.1109,  0.0112],
        [-4.4105,  4.7811],
        [-0.7808,  0.7385],
        [ 1.3083, -1.2047]])
Sample predictions:
tensor([0, 1, 1, 1, 0])


In [9]:
# Helper: get modules in forward order
def get_forward_modules(model):
    """Return ``(name, module)`` pairs of a Sequential's children in forward order.

    ``_modules`` is read directly rather than via ``named_children()``, because
    the latter de-duplicates a module instance that appears more than once,
    which would drop steps from the forward sequence.

    Raises:
        NotImplementedError: if ``model`` is not an ``nn.Sequential``.
    """
    if not isinstance(model, nn.Sequential):
        raise NotImplementedError("This prototype assumes nn.Sequential-based models.")
    return [(child_name, child) for child_name, child in model._modules.items()]
        
# List the Sequential's children in the order forward() applies them
modules = get_forward_modules(model.net)
print("Ordered modules:")
for mod_name, mod in modules:
    print(f"{mod_name} -> {mod}")


Ordered modules:
0 -> Linear(in_features=2, out_features=16, bias=True)
1 -> ReLU()
2 -> Linear(in_features=16, out_features=2, bias=True)


In [10]:

# List every parameter name; any of these can be used as a split point below
for pname, _ in model.named_parameters():
    print(pname)


net.0.weight
net.0.bias
net.2.weight
net.2.bias


In [11]:
# Find split point from parameter name
def split_model_at_param(model, param_name, container="net"):
    """Split ``model.<container>`` into two Sequential halves at a parameter.

    The cut is placed *before* the layer that owns ``param_name``: the first
    half holds every preceding layer and the second half starts with that
    layer, so ``second(first(x)) == getattr(model, container)(x)``. This equals
    ``model(x)`` when ``model.forward`` just applies the container.

    The halves reuse the original layer objects (no copies), so they share
    weights and train/eval mode with ``model``.

    Args:
        model: module exposing an ``nn.Sequential`` attribute named ``container``.
        param_name: fully-qualified name as listed by ``model.named_parameters()``,
            e.g. ``"net.2.weight"``.
        container: attribute name of the Sequential to split (default ``"net"``).

    Returns:
        (layers1, layers2): two ``nn.Sequential`` modules. ``layers1`` is empty
        (an identity) when splitting at the first layer.

    Raises:
        KeyError: if ``param_name`` belongs to no layer of the container.
        NotImplementedError: if the container is not an ``nn.Sequential``.
    """
    modules = get_forward_modules(getattr(model, container))

    # Map every fully-qualified parameter name to the index of its owning layer
    param_to_module_idx = {
        f"{container}.{mod_name}.{pname}": idx
        for idx, (mod_name, layer) in enumerate(modules)
        for pname, _ in layer.named_parameters()
    }

    try:
        split_idx = param_to_module_idx[param_name]
    except KeyError:
        raise KeyError(
            f"{param_name!r} is not a parameter of model.{container}; "
            f"available: {sorted(param_to_module_idx)}"
        ) from None

    layers1 = nn.Sequential(*(layer for _, layer in modules[:split_idx]))
    layers2 = nn.Sequential(*(layer for _, layer in modules[split_idx:]))
    return layers1, layers2

# Example split: cut right before the second Linear layer (index 2 in model.net)
param_to_split = "net.2.weight"
model1, model2 = split_model_at_param(model, param_name=param_to_split)


In [12]:
# Eval mode everywhere (the halves reuse model's layer objects, but be explicit)
model.eval()
model1.eval()
model2.eval()

with torch.no_grad():
    full_out = model(X)
    part1_out = model1(X)
    part2_out = model2(part1_out)

# Check match: composing the halves must reproduce the full model. Print the
# result, and also assert so a mismatch stops "Run All" instead of scrolling past.
halves_match = torch.allclose(full_out, part2_out, atol=1e-6)
print(halves_match)  # Should be True
assert halves_match, "model2(model1(X)) differs from model(X)"


True


In [1]:
# In a Jupyter cell
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from model_splitter import ModelSplitter

# Define toy model
class SimpleNet(nn.Module):
    """Two-layer MLP: 2 input features -> 16 hidden (ReLU) -> 2 logits."""

    def __init__(self):
        super().__init__()
        hidden = 16
        self.net = nn.Sequential(nn.Linear(2, hidden), nn.ReLU(), nn.Linear(hidden, 2))

    def forward(self, x):
        """Map features (..., 2) to raw class logits (..., 2)."""
        logits = self.net(x)
        return logits

# Generate toy data
def generate_data(n=1000):
    """Return ``n`` standard-normal 2D points and XOR-style labels.

    A point is labelled 1 when its two coordinates have the same strict sign
    (product > 0), otherwise 0. Labels are int64 for ``CrossEntropyLoss``.
    """
    X = torch.randn(n, 2)
    y = torch.mul(X[:, 0], X[:, 1]).gt(0).long()
    return X, y

# Data first, then the model (order matters: both draw from the global RNG)
X, y = generate_data(n=1000)
model = SimpleNet()

# Quick train loop
def train(model, X, y, epochs=3):
    """Full-batch training: one Adam step on all of ``(X, y)`` per epoch.

    Switches ``model`` to training mode, then for each of ``epochs`` iterations
    computes cross-entropy over the whole dataset and takes a single optimizer
    step. Parameters are updated in place; nothing is returned.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(model.parameters(), lr=0.01)
    for _ in range(epochs):
        # Clearing grads before the forward pass is equivalent: forward never touches .grad
        opt.zero_grad()
        criterion(model(X), y).backward()
        opt.step()

train(model, X, y)

# train() leaves the model in training mode. Switch to eval mode before the
# splitter captures its baseline, so all outputs reflect inference behaviour
# (it matters for dropout / batch-norm; SimpleNet has neither, but be explicit).
model.eval()

# Initialize splitter (ModelSplitter comes from the local model_splitter module)
splitter = ModelSplitter(model, X)

# Split at "net.2.weight": the second part starts at the second Linear layer,
# i.e. right after the first Linear + ReLU
param_name = "net.2.weight"
part1_out, model2 = splitter.split_from_param(param_name)

# Feeding part1_out (presumably the first part's output on X) through model2
# should reproduce the splitter's baseline full-model output
with torch.no_grad():
    out_new = model2(part1_out)
    print(torch.allclose(out_new, splitter.baseline_output, atol=1e-6))  # True


True


In [2]:
import torch
import torch.nn as nn
import time
from model_splitter import ModelSplitter

# Define a deeper model
class LargeNet(nn.Module):
    """Deeper MLP: widths 2 -> 64 -> 128 -> 128 -> 64 -> 2, ReLU between Linears.

    Built as a flat positional ``nn.Sequential`` (Linear at even indices, ReLU
    at odd ones), so parameter names are ``net.0.*``, ``net.2.*``, ... ``net.8.*``.
    """

    def __init__(self):
        super().__init__()
        widths = [2, 64, 128, 128, 64, 2]
        layers = []
        for i, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:])):
            if i > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(d_in, d_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw 2-class logits for 2-feature inputs."""
        logits = self.net(x)
        return logits

# Generate input
X_large = torch.randn(10_000, 2)  # Large batch

# Initialize the model and switch to inference mode
model_large = LargeNet()
model_large.eval()

# Warm-up pass so first-call overheads don't skew the timed run
with torch.no_grad():
    _ = model_large(X_large)

# Baseline full inference time. time.perf_counter() is monotonic and
# high-resolution, unlike the wall-clock time.time(), so use it for intervals.
with torch.no_grad():
    start = time.perf_counter()
    baseline_output = model_large(X_large)
    end = time.perf_counter()

print(f"Full model inference time: {end - start:.6f} sec")

# Split and time only the second half (ModelSplitter is from the local
# model_splitter module; split_from_param returns (part1_out, model2))
splitter = ModelSplitter(model_large, X_large)
part1_out, model2 = splitter.split_from_param("net.4.weight")  # third Linear (128 -> 128)

# Time only model2 inference, fed with part1_out (presumably the first half's
# output on X_large). perf_counter for the same reason as the baseline timing.
with torch.no_grad():
    start = time.perf_counter()
    output_partial = model2(part1_out)
    end = time.perf_counter()

print(f"Model2 (second half) inference time: {end - start:.6f} sec")
print("Output identical to baseline:", torch.allclose(output_partial, baseline_output, atol=1e-6))


Full model inference time: 0.011029 sec
Model2 (second half) inference time: 0.007837 sec
Output identical to baseline: True


In [None]:
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Resize, ToTensor
from torch.utils.data import DataLoader
from model_splitter import ModelSplitter
import time

# --- 1. Minimal ViT block ---
class TinyViT(nn.Module):
    """Minimal Vision Transformer classifier.

    Images are cut into non-overlapping ``patch_size`` x ``patch_size`` patches.
    Each patch is linearly embedded, a learnable CLS token is prepended, and
    learned positional embeddings are added. The sequence then runs through an
    ``nn.TransformerEncoder`` stack, and the CLS token's final state goes to a
    LayerNorm + Linear head.

    Args:
        img_size: side length of the square training image size; fixes the
            number of patches (and therefore the positional-embedding length).
        patch_size: side length of each square patch; must divide ``img_size``.
        emb_dim: token embedding width.
        depth: number of transformer encoder layers.
        n_heads: attention heads per layer (must divide ``emb_dim``).
        n_classes: number of output logits.
        in_chans: number of input channels (1 for grayscale such as MNIST).
    """

    def __init__(self, img_size=28, patch_size=7, emb_dim=64, depth=2, n_heads=4, n_classes=10, in_chans=1):
        super().__init__()
        assert img_size % patch_size == 0, "patch_size must divide img_size"
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.n_patches = (img_size // patch_size) ** 2
        self.patch_dim = in_chans * patch_size * patch_size

        self.patch_embed = nn.Linear(self.patch_dim, emb_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 1, emb_dim))

        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=n_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(emb_dim),
            nn.Linear(emb_dim, n_classes)
        )

    def _patchify(self, x):
        """Turn images (B, C, H, W) into patch rows (B, n_patches, C * p * p).

        Patches are ordered row-major over the patch grid; within a patch the
        values are ordered by channel, then pixel row, then pixel column.

        Raises:
            ValueError: if the channel count or image size does not match the
                configured patch grid.
        """
        B, C, H, W = x.shape
        p = self.patch_size
        if C != self.in_chans or H % p or W % p or (H // p) * (W // p) != self.n_patches:
            raise ValueError(
                f"expected input (B, {self.in_chans}, H, W) tiling into {self.n_patches} "
                f"patches of {p}x{p}, got {tuple(x.shape)}"
            )
        x = x.unfold(2, p, p).unfold(3, p, p)  # (B, C, nH, nW, p, p)
        x = x.permute(0, 2, 3, 1, 4, 5)         # (B, nH, nW, C, p, p)
        return x.reshape(B, self.n_patches, self.patch_dim)

    def forward(self, x):
        """Return class logits (B, n_classes) for images (B, C, H, W)."""
        B = x.shape[0]
        x = self.patch_embed(self._patchify(x))  # (B, n_patches, emb_dim)

        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)           # (B, n_patches + 1, emb_dim)
        # Out-of-place add: avoids mutating an autograd intermediate in place
        x = x + self.pos_embed

        x = self.transformer(x)
        cls_out = x[:, 0]
        return self.mlp_head(cls_out)

# --- 2. Load MNIST ---
# Only the first (unshuffled) test-set batch is used for the timing demo
transform = Compose([ToTensor()])
mnist = MNIST(root="./data", train=False, download=True, transform=transform)
loader = DataLoader(mnist, batch_size=64, shuffle=False)
images, labels = next(iter(loader))

# --- 3. Create and run model ---
# NOTE: the model is untrained; this cell measures speed, not accuracy
model = TinyViT()
model.eval()

# time.perf_counter() is monotonic and high-resolution, unlike time.time()
with torch.no_grad():
    start = time.perf_counter()
    full_out = model(images)
    end = time.perf_counter()
print(f"Full TinyViT inference time: {end - start:.6f} sec")

# --- 4. Use ModelSplitter ---
splitter = ModelSplitter(model, images)

# Valid split points are the names from model.named_parameters(), e.g.
# "cls_token", "pos_embed", "patch_embed.weight", "mlp_head.1.weight".
# NOTE(review): TinyViT.forward is not a plain Sequential (patchify, CLS concat,
# indexing), so correctness here depends on how ModelSplitter builds model2.
part1_out, model2 = splitter.split_from_param("mlp_head.1.weight")  # final classifier Linear

# --- 5. Partial inference ---
# perf_counter: monotonic, high-resolution clock for short intervals
with torch.no_grad():
    start = time.perf_counter()
    part_out = model2(part1_out)
    end = time.perf_counter()
print(f"TinyViT head-only inference time: {end - start:.6f} sec")
print("Output matches:", torch.allclose(part_out, full_out, atol=1e-6))


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to C:\Users\wwden/.cache\torch\hub\checkpoints\vit_b_16-c867db91.pth
  9%|▊         | 28.8M/330M [02:05<21:56, 240kB/s]


KeyboardInterrupt: 