<a href="https://colab.research.google.com/github/zhangou888/NN/blob/main/Full%20code_80_10_10_customized_NN_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Full Code with 80/10/10 Custom Split for Neural Network using PyTorch

## Summary of Key Changes:
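- Custom 80/10/10 split: replaces the built-in MNIST test set with train/validation/test subsets carved out of the 60,000-image training set.
- `random_split` used: assigns samples to the three subsets at random (class balance holds only approximately, because the split is not stratified).
- Test evaluation: now runs on the new held-out 10% test split.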

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt

In [3]:
# Step 2: Pick the compute device (CUDA GPU when one is visible, CPU otherwise)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cpu


In [4]:
# Step 3: Define CNN
class CNN(nn.Module):
    """Small two-block convolutional classifier for 1x28x28 inputs.

    Each block is conv(3x3, padding=1) -> ReLU -> 2x2 max-pool, so a 28x28
    input is reduced to 14x14 and then 7x7 while the channel count grows
    1 -> 16 -> 32. The flattened 32*7*7 features go through a 128-unit
    hidden layer and a 10-unit output layer.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return raw (un-normalised) class scores for a batch of images.

        Args:
            x: tensor of shape (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 10); no softmax is applied here.

        Raises:
            RuntimeError: if the spatial size is not 28x28, because the
                flattened feature count then no longer matches ``fc1``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample rather than with view(-1, 32*7*7): the latter
        # silently folds surplus features into the batch dimension when the
        # input has the wrong spatial size, yielding the wrong number of rows.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [5]:
# Step 4: Load MNIST and create a reproducible 80/10/10 split from the full train set
SPLIT_SEED = 42  # fixed so that re-running this cell always yields the same partition

transform = transforms.ToTensor()
full_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

total_size = len(full_dataset)  # 60,000
train_size = int(0.8 * total_size)              # 48,000
val_size = int(0.1 * total_size)                # 6,000
test_size = total_size - train_size - val_size  # 6,000 (remainder, so the sizes always sum to total_size)

# random_split draws from the global RNG unless given a generator. Without a
# seeded one, each re-run reshuffles the subsets, and a model trained on an
# earlier split could be "tested" on samples it was trained on.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_data, val_data, test_data = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=split_generator,
)

# Only the training loader shuffles; evaluation order does not affect the metrics.
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=1000)
test_loader = DataLoader(test_data, batch_size=1000)


100%|██████████| 9.91M/9.91M [00:00<00:00, 138MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 31.2MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 91.4MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.01MB/s]


In [6]:
# Step 5: Build the network, the loss function, and the optimizer
model = CNN()
model.to(device)  # nn.Module.to moves the parameters in place and returns the same module
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Step 6: Training with validation
# Losses are reported as per-sample averages. Summing the per-batch means
# (as before) made the train and validation numbers incomparable, because the
# two loaders use different batch sizes and therefore different batch counts.
epochs = 5
for epoch in range(epochs):
    # ---- Training pass ----
    model.train()
    train_loss_sum = 0.0
    train_seen = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Assumes criterion returns the batch mean (the CrossEntropyLoss
        # default); weighting by batch size turns it back into a per-sample sum.
        batch_count = labels.size(0)
        train_loss_sum += loss.item() * batch_count
        train_seen += batch_count

    # ---- Validation pass (no gradient tracking, eval-mode layers) ----
    model.eval()
    val_loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_count = labels.size(0)
            val_loss_sum += criterion(outputs, labels).item() * batch_count
            # argmax over the class dimension; no need for the discouraged `.data`.
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += batch_count

    avg_train_loss = train_loss_sum / train_seen
    avg_val_loss = val_loss_sum / total
    val_acc = 100 * correct / total
    print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.2f}%")


In [None]:
# Step 7: Final test evaluation (on held-out test split)
# Collects every prediction and input image so a later cell can plot a sample.
model.eval()
correct, total = 0, 0
all_preds, all_images = [], []
with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Index of the highest score along the class dimension.
        predicted = model(batch_images).argmax(dim=1)

        correct += (predicted == batch_labels).sum().item()
        total += batch_labels.size(0)

        # `+=` with an ndarray appends one entry per item along its first axis.
        all_preds += predicted.cpu().numpy()
        all_images += batch_images.cpu().numpy()

test_acc = 100 * correct / total
print(f"Test Accuracy on held-out test set: {test_acc:.2f}%")

In [None]:
# Step 8: Save model
# Only the state_dict (parameter tensors) is saved, not the pickled module.
MODEL_PATH = "mnist_cnn_80_10_10.pth"
torch.save(model.state_dict(), MODEL_PATH)
print(f"Model saved as {MODEL_PATH}")

# Step 9: Reload model (optional)
model_loaded = CNN().to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU runtime still loads on a CPU-only one.
state_dict = torch.load(MODEL_PATH, map_location=device)
model_loaded.load_state_dict(state_dict)
model_loaded.eval()

In [None]:
# Step 10: Show predictions
def show_predictions(images, preds, n=6):
    """Plot the first ``n`` images in one row, each titled with its prediction.

    Args:
        images: indexable sequence of image arrays; each item must support
            ``.squeeze()`` and be drawable by ``plt.imshow`` afterwards
            (presumably 1x28x28 arrays here -- confirm against the caller).
        preds: indexable sequence of predicted labels, aligned with ``images``.
        n: maximum number of panels to draw. It is clamped to the number of
            available image/prediction pairs, so short inputs no longer raise
            ``IndexError``.
    """
    count = max(0, min(n, len(images), len(preds)))
    plt.figure(figsize=(12, 2))
    for i in range(count):
        plt.subplot(1, count, i + 1)
        plt.imshow(images[i].squeeze(), cmap="gray")
        plt.title(f"Pred: {preds[i]}")
        plt.axis("off")
    plt.show()

# Preview the first six collected test images with their predicted labels.
show_predictions(images=all_images, preds=all_preds, n=6)