In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [3]:
class TwistedMNISTModel(nn.Module):
    """Two-path fully connected classifier over 784 flattened input features.

    Data flow (widths are the layer sizes declared in ``__init__``):

    * main path:  784 -> 64 -> 32              (``fc1``, ``fc2``)
    * skip path:  784 -> 32                    (``skip_fc1``)
    * the two 32-wide activations are concatenated (64) and projected to 128
      by ``concat_fc``; the skip activation alone is also projected to 128 by
      ``skip_fc2``; the two 128-wide results are summed.
    * ``final_fc`` maps the sum to 10 output logits (no softmax is applied).
    """

    def __init__(self):
        super().__init__()

        # First path (main)
        self.fc1 = nn.Linear(784, 64)
        self.fc2 = nn.Linear(64, 32)

        # Second path (skip connection), fed directly from the flattened input
        self.skip_fc1 = nn.Linear(784, 32)

        # Processing after concatenation: 32 (main) + 32 (skip) = 64 inputs
        self.concat_fc = nn.Linear(64, 128)
        # Separate projection of the skip activation, added to concat_fc's output
        self.skip_fc2 = nn.Linear(32, 128)

        # Final output layer
        self.final_fc = nn.Linear(128, 10)

    def forward(self, x):
        """Compute class logits for a batch.

        Args:
            x: tensor whose first dimension is the batch and whose remaining
               dimensions multiply to 784 (e.g. ``(N, 1, 28, 28)`` or ``(N, 784)``).

        Returns:
            Tensor of shape ``(N, 10)`` holding unnormalised logits.
        """
        # Flatten everything but the batch dimension. `reshape` (rather than
        # `view`) also accepts non-contiguous inputs such as transposed or
        # sliced batches; for contiguous inputs it is equivalent to `view`.
        x = x.reshape(x.size(0), -1)

        # First path (main)
        z = torch.relu(self.fc1(x))
        h = torch.relu(self.fc2(z))

        # Second path (skip connection)
        u = torch.relu(self.skip_fc1(x))

        # Concatenate along the feature dimension
        q = torch.cat((h, u), dim=1)

        # Processing concatenated output
        v = torch.relu(self.concat_fc(q))

        # Further processing of skip connection
        k = torch.relu(self.skip_fc2(u))

        # Add the processed outputs (both are 128-wide)
        d = v + k

        # Final output layer
        hat_y = self.final_fc(d)

        return hat_y

In [4]:
# --- Run configuration -------------------------------------------------------
# Seed PyTorch's RNG before any weights are initialised or data is shuffled so
# that a fresh "Restart & Run All" starts from the same state every time.
# (Some GPU kernels may still be non-deterministic.)
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
B = 64  # Batch size
H, W = 28, 28  # MNIST image dimensions
# NOTE(review): H, W, reduced_dim and seq_output_dim are not referenced by any
# of the cells visible here; kept in case other cells rely on them.
reduced_dim = 16  # Reduced row dimension
seq_output_dim = 10  # Final output dimension (for classification)
lr = 0.001  # Optimizer learning rate
epochs = 30  # Number of passes over the training set

In [5]:
# Preprocessing pipeline: convert each image to a tensor, then normalise the
# single channel with mean 0.1307 and standard deviation 0.3081.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([to_tensor, normalize])

# Both splits share the same storage location, preprocessing and download flag.
mnist_kwargs = dict(root='./data', transform=transform, download=True)
train_dataset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

# Only the training batches are shuffled; the test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=B, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=B, shuffle=False)

In [6]:
# Build the network on the selected device, with a cross-entropy objective.
model = TwistedMNISTModel().to(device)
criterion = nn.CrossEntropyLoss()
# Take the learning rate from the `lr` constant defined in the configuration
# cell instead of repeating the literal, so changing `lr` there takes effect.
optimizer = optim.Adam(model.parameters(), lr=lr)

In [8]:
# Train for `epochs` passes over `train_loader`, reporting the per-sample mean
# loss and the training accuracy after each pass.
for epoch in range(epochs):
    model.train()
    total_loss = 0.0  # Sum of per-sample losses accumulated over the epoch
    correct = 0       # Number of correctly classified training samples
    seen = 0          # Number of samples actually processed this epoch

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # `criterion` (CrossEntropyLoss with its default reduction) returns the
        # batch mean, so weight it by the batch size. Averaging the batch means
        # directly would over-weight a smaller final batch.
        n_batch = labels.size(0)
        total_loss += loss.item() * n_batch
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += n_batch

    # Both metrics use the same denominator: the samples seen in this epoch.
    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/seen:.4f}, Accuracy: {correct/seen:.4f}")


Epoch 1/30, Loss: 0.0362, Accuracy: 0.9879
Epoch 2/30, Loss: 0.0303, Accuracy: 0.9900
Epoch 3/30, Loss: 0.0265, Accuracy: 0.9912
Epoch 4/30, Loss: 0.0239, Accuracy: 0.9914
Epoch 5/30, Loss: 0.0243, Accuracy: 0.9917
Epoch 6/30, Loss: 0.0197, Accuracy: 0.9936
Epoch 7/30, Loss: 0.0197, Accuracy: 0.9932
Epoch 8/30, Loss: 0.0188, Accuracy: 0.9937
Epoch 9/30, Loss: 0.0167, Accuracy: 0.9945
Epoch 10/30, Loss: 0.0192, Accuracy: 0.9934
Epoch 11/30, Loss: 0.0146, Accuracy: 0.9952
Epoch 12/30, Loss: 0.0157, Accuracy: 0.9949
Epoch 13/30, Loss: 0.0128, Accuracy: 0.9959
Epoch 14/30, Loss: 0.0153, Accuracy: 0.9951
Epoch 15/30, Loss: 0.0173, Accuracy: 0.9944
Epoch 16/30, Loss: 0.0124, Accuracy: 0.9958
Epoch 17/30, Loss: 0.0131, Accuracy: 0.9959
Epoch 18/30, Loss: 0.0145, Accuracy: 0.9952
Epoch 19/30, Loss: 0.0100, Accuracy: 0.9967
Epoch 20/30, Loss: 0.0173, Accuracy: 0.9945
Epoch 21/30, Loss: 0.0108, Accuracy: 0.9968
Epoch 22/30, Loss: 0.0109, Accuracy: 0.9962


KeyboardInterrupt: 

In [9]:
# Measure classification accuracy on the held-out test split.
model.eval()
total = 0
correct = 0
with torch.no_grad():  # No gradients are needed for evaluation
    for batch_x, batch_y in test_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        # Predicted class = index of the largest logit in each row
        predicted = model(batch_x).argmax(dim=1)
        correct += predicted.eq(batch_y).sum().item()
        total += batch_y.shape[0]
accuracy = 100 * correct / total  # Percentage of correct predictions

print(f"Test Accuracy: {accuracy:.2f}%")

Test Accuracy: 97.92%
