In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import time

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=2)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:45<00:00, 215805.22it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 118518.45it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:09<00:00, 173388.43it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 6446879.45it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [6]:
print('Trainset size:', len(trainset))
print('Testset size:', len(testset))
print(trainset[0][0].shape)
print(testset[0][0].shape)

Trainset size: 60000
Testset size: 10000
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])


In [23]:
class OriginalModel(nn.Module):
    def __init__(self):
        super(OriginalModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.relu2 = nn.ReLU()
        self.fc1 = nn.Linear(32 * 14 * 14, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x


In [26]:
class FusedModel(nn.Module):
    def __init__(self, original_model):
        super(FusedModel, self).__init__()
        self.conv1 = self.fuse_conv_bn(original_model.conv1, original_model.bn1)
        self.relu1 = nn.ReLU()
        self.conv2 = self.fuse_conv_bn(original_model.conv2, original_model.bn2)
        self.relu2 = nn.ReLU()
        self.fc1 = original_model.fc1
        self.relu3 = nn.ReLU()
        self.fc2 = original_model.fc2
        
    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x
    
    def fuse_conv_bn(self, conv, bn):
        with torch.no_grad():
            # Initialize a new convolutional layer with the same parameters
            fused_conv = nn.Conv2d(conv.in_channels,
                                   conv.out_channels,
                                   kernel_size=conv.kernel_size,
                                   stride=conv.stride,
                                   padding=conv.padding,
                                   bias=True)
            
            # Extract the batch normalization parameters
            gamma = bn.weight
            beta = bn.bias
            mean = bn.running_mean
            var = bn.running_var
            eps = bn.eps

            # Extract the convolutional parameters
            W = conv.weight
            if conv.bias is None:
                b = torch.zeros_like(mean)
            else:
                b = conv.bias
            
            print(f"Conv weight shape: {W.shape}")
            print(f"BN gamma shape: {gamma.shape}")
            print(f"BN beta shape: {beta.shape}")
            print(f"BN mean shape: {mean.shape}")
            print(f"BN var shape: {var.shape}")

            # Reshape the batch normalization parameters for broadcasting
            gamma = gamma.view(-1, 1, 1, 1)
            beta = beta.view(-1)
            mean = mean.view(-1, 1, 1, 1)
            var = var.view(-1, 1, 1, 1)

            # Fuse the weights and biases
            W_fused = W * (gamma / torch.sqrt(var + eps))
            b_fused = beta + (b - mean.squeeze()) * (gamma.squeeze() / torch.sqrt(var.squeeze() + eps))

            # Copy the fused parameters to the new convolutional layer
            fused_conv.weight.copy_(W_fused)
            fused_conv.bias.copy_(b_fused)
            print(f"Fused weight shape: {fused_conv.weight.shape}")
            print(f"Fused bias shape: {fused_conv.bias.shape}")
        
        return fused_conv

In [28]:


# Training and evaluation functions
def train(model, trainloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(trainloader)

def evaluate(model, testloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

# Measure inference time
def measure_inference_time(model, input_tensor, iterations=100):
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        for _ in range(iterations):
            _ = model(input_tensor)
        end_time = time.time()
    return (end_time - start_time) / iterations

# Main script
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate models
original_model = OriginalModel().to(device)
fused_model = FusedModel(original_model).to(device)

# Loss function and optimizers
criterion = nn.CrossEntropyLoss()
optimizer_original = optim.Adam(original_model.parameters(), lr=0.001)
optimizer_fused = optim.Adam(fused_model.parameters(), lr=0.001)

# Training loop
num_epochs = 5

print("Training original model...")
for epoch in range(num_epochs):
    train_loss = train(original_model, trainloader, criterion, optimizer_original, device)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss:.4f}")

print("Training fused model...")
for epoch in range(num_epochs):
    train_loss = train(fused_model, trainloader, criterion, optimizer_fused, device)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss:.4f}")

# Evaluate models
original_accuracy = evaluate(original_model, testloader, device)
fused_accuracy = evaluate(fused_model, testloader, device)
print(f"Original model accuracy: {original_accuracy:.2f}%")
print(f"Fused model accuracy: {fused_accuracy:.2f}%")

# Generate random input for inference time measurement
input_tensor = torch.randn(1, 1, 28, 28).to(device)

# Measure inference time for original model
original_time = measure_inference_time(original_model, input_tensor)
print(f"Original model inference time: {original_time:.6f} seconds")

# Measure inference time for fused model
fused_time = measure_inference_time(fused_model, input_tensor)
print(f"Fused model inference time: {fused_time:.6f} seconds")

# Verify outputs are the same on a batch
input_tensor, _ = next(iter(testloader))
input_tensor = input_tensor.to(device)
original_output = original_model(input_tensor)
fused_output = fused_model(input_tensor)



Conv weight shape: torch.Size([16, 1, 3, 3])
BN gamma shape: torch.Size([16])
BN beta shape: torch.Size([16])
BN mean shape: torch.Size([16])
BN var shape: torch.Size([16])
Fused weight shape: torch.Size([16, 1, 3, 3])
Fused bias shape: torch.Size([16])
Conv weight shape: torch.Size([32, 16, 3, 3])
BN gamma shape: torch.Size([32])
BN beta shape: torch.Size([32])
BN mean shape: torch.Size([32])
BN var shape: torch.Size([32])
Fused weight shape: torch.Size([32, 16, 3, 3])
Fused bias shape: torch.Size([32])
Training original model...


Epoch 1/5, Loss: 0.1508
Epoch 2/5, Loss: 0.0489
Epoch 3/5, Loss: 0.0329
Epoch 4/5, Loss: 0.0254
Epoch 5/5, Loss: 0.0206
Training fused model...
Epoch 1/5, Loss: 0.0738
Epoch 2/5, Loss: 0.0226
Epoch 3/5, Loss: 0.0145
Epoch 4/5, Loss: 0.0114
Epoch 5/5, Loss: 0.0088
Original model accuracy: 98.66%
Fused model accuracy: 98.98%
Original model inference time: 0.000458 seconds
Fused model inference time: 0.000372 seconds
