# Shape-Net

A simple neural network that can distinguish shapes.

## Binary Shape Net
First, we create a simple neural net that can distinguish L-shapes from non-L shapes in 3x3 grayscale images.

### 📐 Data Setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Training data for the binary task: hand-drawn 3x3 grids of 0/1 pixels.
# Positive examples (label 1): four orientations of an L.
L_shapes = [
    torch.tensor(grid)
    for grid in (
        [[1, 0, 0], [1, 0, 0], [1, 1, 0]],  # L
        [[1, 1, 1], [1, 0, 0], [0, 0, 0]],  # rotated L
        [[0, 0, 1], [0, 0, 1], [0, 1, 1]],  # another rotation
        [[0, 0, 0], [0, 0, 1], [1, 1, 1]],  # another rotation
    )
]

# Negative examples (label 0): shapes that are close to, but not, an L.
non_L_shapes = [
    torch.tensor(grid)
    for grid in (
        [[1, 1, 0], [1, 1, 0], [0, 0, 0]],  # square
        [[1, 0, 0], [1, 1, 0], [0, 0, 0]],  # corner
        [[1, 1, 1], [0, 1, 0], [0, 0, 0]],  # T-shape
        [[0, 1, 0], [0, 1, 0], [0, 1, 0]],  # column
        [[1, 1, 1], [0, 0, 0], [0, 0, 0]],  # sideways column
        [[0, 0, 1], [0, 0, 1], [0, 0, 1]],  # column
        [[1, 0, 0], [1, 1, 0], [1, 0, 0]],  # T
        [[1, 0, 0], [0, 0, 0], [1, 1, 0]],  # Near-L
    )
]

# Positives first, then negatives: (B, 3, 3) -> float -> (B, 1, 3, 3) with one grayscale channel.
X = torch.stack(L_shapes + non_L_shapes).float().unsqueeze(1)
# Matching float targets as a column vector (B, 1): ones for L, zeros for non-L.
y = torch.cat([torch.ones(len(L_shapes), 1), torch.zeros(len(non_L_shapes), 1)])

### 🧠 Model

In [None]:
from typing import Any


class ShapeNet(nn.Module):
    """Tiny CNN that scores a single-channel 3x3 image as L-shaped or not.

    One 2x2 convolution with four filters feeds a single linear unit; the
    result is passed through a sigmoid, so each output lies between 0 and 1.
    """

    def __init__(self):
        super().__init__()
        # Four 2x2 filters over one grayscale input channel.
        self.conv = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=2)
        # 4 channels x 2x2 positions = 16 flattened features -> 1 output.
        self.fc = nn.Linear(4 * 2 * 2, 1)

    def forward(self, x) -> torch.Tensor:
        """Return one sigmoid score per sample, shape (B, 1).

        The linear layer is sized for a (B, 1, 3, 3) input, which the
        convolution turns into (B, 4, 2, 2).
        """
        feature_maps = torch.relu(self.conv(x))                      # (B, 4, 2, 2)
        flattened = feature_maps.reshape(feature_maps.shape[0], -1)  # (B, 16)
        return torch.sigmoid(self.fc(flattened))

    # Exists only so IntelliSense sees the call result as torch.Tensor.
    def __call__(self, *args: Any, **kwds: Any) -> torch.Tensor:
        result = super().__call__(*args, **kwds)
        if not isinstance(result, torch.Tensor):
            raise TypeError(f"Expected torch.Tensor, got {type(result)}")
        return result

### ⚙️ Training

In [None]:
# Fit the binary classifier with full-batch updates: every epoch uses all of X.
model = ShapeNet()
loss_fn = nn.BCELoss()  # binary cross-entropy on the model's sigmoid outputs
optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(1000):
    optimizer.zero_grad()  # drop gradients left over from the previous epoch
    y_pred = model(X)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()

# y_pred comes from the last forward pass of the loop; round scores to 0/1 labels.
print("Final predictions:", torch.round(y_pred.detach()).squeeze())

### 🧪 Create Test Set

In [None]:
# Held-out grids for the binary task, written as nested lists and converted
# to integer tensors.

# Positive class: L shapes.
test_L = [
    torch.tensor(grid)
    for grid in (
        [[1, 0, 0], [1, 0, 0], [1, 1, 0]],  # L
        [[0, 0, 0], [0, 0, 1], [1, 1, 1]],  # rotated L
    )
]

# Negative class: anything that is not an L.
test_non_L = [
    torch.tensor(grid)
    for grid in (
        [[1, 1, 0], [0, 1, 1], [0, 0, 0]],  # diagonal blob
        [[1, 0, 0], [1, 0, 0], [1, 0, 0]],  # column
    )
]

# Same layout as the training tensors: images (B, 1, H, W), targets (B, 1).
X_test = torch.stack(test_L + test_non_L).float().unsqueeze(1)
y_test = torch.cat([torch.ones(len(test_L), 1), torch.zeros(len(test_non_L), 1)])

### ✅ Test Accuracy

In [None]:
# Score the held-out shapes; no_grad skips gradient tracking during inference.
with torch.no_grad():
    y_pred = model(X_test)
    print("Test predictions (raw):", y_pred.squeeze())
    # Scores at or above 0.5 become label 1.0, everything else 0.0.
    y_pred_labels = y_pred.ge(0.5).float()
    print("Test predictions:", y_pred_labels.squeeze())
    correct = y_pred_labels.eq(y_test).sum().item()

# Fraction of test samples whose thresholded label matches the target.
accuracy = correct / y_test.size(0)
print(f"Accuracy: {accuracy*100:.2f}%")


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Plain Python lists of the actual and predicted labels for scikit-learn.
y_true = y_test.cpu().tolist()
y_pred = y_pred_labels.cpu().tolist()

# Count how often each actual label was assigned each predicted label.
cm = confusion_matrix(y_true, y_pred)

# Draw the counts as an annotated heatmap (y axis: actual, x axis: predicted).
fig, ax = plt.subplots(figsize=(4, 3))
class_names = ["Not L", "L"]
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")
ax.set_title("Confusion Matrix")
fig.tight_layout()
plt.show()

## Multi Shape Net
Now, we create a simple neural net that can distinguish different types of shapes in 4x4 grayscale images.

Shapes: "L", "box", "zigzag", "T", "column"

### 📐 Data Setup

In [None]:
# Class names in index order: position i is the name of class index i.
SHAPE_CLASSES = ["L", "Box", "Zigzag", "T", "Column"]
# Lookup tables in both directions between class index and class name.
label_from_idx = dict(enumerate(SHAPE_CLASSES))
idx_from_label = {name: idx for idx, name in label_from_idx.items()}

In [None]:
import torch

# Hand-drawn seed images, four per class (grayscale binary, 4x4).
# Each entry is an (image tensor, class name) pair; the class name must be one
# of the SHAPE_CLASSES strings.
samples = [
    # L shapes
    (torch.tensor([[1,0,0,0],[1,0,0,0],[1,1,0,0],[0,0,0,0]]), "L"),
    (torch.tensor([[0,1,0,0],[0,1,0,0],[0,1,1,0],[0,0,0,0]]), "L"),
    (torch.tensor([[0,0,0,0],[1,1,1,0],[0,0,1,0],[0,0,0,0]]), "L"),
    (torch.tensor([[0,0,0,0],[0,0,0,0],[1,0,0,0],[1,1,1,0]]), "L"),

    # Boxes (2x2)
    (torch.tensor([[1,1,0,0],[1,1,0,0],[0,0,0,0],[0,0,0,0]]), "Box"),
    (torch.tensor([[0,0,0,0],[1,1,0,0],[1,1,0,0],[0,0,0,0]]), "Box"),
    (torch.tensor([[0,0,0,0], [0,0,0,0],[0,1,1,0],[0,1,1,0]]), "Box"),
    (torch.tensor([[0,0,0,0], [0,1,1,0],[0,1,1,0],[0,0,0,0]]), "Box"),

    # Zigzags
    (torch.tensor([[1,0,0,0],[1,1,0,0],[0,1,0,0],[0,0,0,0]]), "Zigzag"),
    (torch.tensor([[0,0,0,0],[0,1,1,0],[1,1,0,0],[0,0,0,0]]), "Zigzag"),
    (torch.tensor([[0,0,0,0],[1,1,0,0],[0,1,1,0],[0,0,0,0]]), "Zigzag"),
    (torch.tensor([[0,0,0,0],[0,1,0,0],[1,1,0,0],[1,0,0,0]]), "Zigzag"),

    # Ts
    (torch.tensor([[1,0,0,0],[1,1,0,0],[1,0,0,0],[0,0,0,0]]), "T"),
    (torch.tensor([[0,0,0,0],[1,1,1,0],[0,1,0,0],[0,0,0,0]]), "T"),
    # NOTE(review): identical to the previous T sample; possibly a different
    # orientation or position was intended here. Confirm.
    (torch.tensor([[0,0,0,0],[1,1,1,0],[0,1,0,0],[0,0,0,0]]), "T"),
    (torch.tensor([[0,0,0,0],[0,0,0,0],[0,0,1,0],[0,1,1,1]]), "T"),

    # Columns (three cells in a straight line, vertical or horizontal)
    (torch.tensor([[0,1,0,0],[0,1,0,0],[0,1,0,0],[0,0,0,0]]), "Column"),
    (torch.tensor([[1,0,0,0],[1,0,0,0],[1,0,0,0],[0,0,0,0]]), "Column"),
    (torch.tensor([[0,0,0,0],[0,0,0,0],[0,0,0,0],[1,1,1,0]]), "Column"),
    (torch.tensor([[0,0,0,0],[0,0,0,0],[1,1,1,0], [0,0,0,0]]), "Column"),
]

def augment_rotations(img: torch.Tensor):
    """Return the four quarter-turn rotations (0°, 90°, 180°, 270°) of an image tensor.

    Rotation happens in the plane of dims 0 and 1; the list is ordered by
    increasing number of quarter turns, so the first entry equals ``img``.
    """
    return [torch.rot90(img, k=quarter_turns, dims=(0, 1)) for quarter_turns in (0, 1, 2, 3)]


def augment_with_flips_and_rotations(img):
    """Return six variants of ``img``: its four rotations, then two mirror images.

    The mirrored copies reverse dim 0 (vertical flip) and dim 1 (horizontal
    flip) respectively, in that order.
    """
    vertical_flip = img.flip(0)
    horizontal_flip = img.flip(1)
    return [*augment_rotations(img), vertical_flip, horizontal_flip]


# Expand every labelled sample into its augmented variants, keeping images and
# class indices aligned position by position.
augmented_X = []
augmented_y = []

for img, label in samples:
    variants = augment_with_flips_and_rotations(img)
    augmented_X.extend(variant.float() for variant in variants)
    augmented_y.extend([idx_from_label[label]] * len(variants))

# (B, H, W) -> (B, 1, H, W): insert the single grayscale channel.
X_train = torch.stack(augmented_X).unsqueeze(1)
# Integer class indices, shape (B,).
y_train = torch.tensor(augmented_y)
print(f"X_train shape: {X_train.shape}\ny_train shape: {y_train.shape}")

In [None]:
# Switch checked by the plotting cells; set to False to skip all figures.
visualize = True
if visualize:
    import matplotlib.pyplot as plt

    cols = 12
    # Round up so a sample count that is not a multiple of `cols` still gets
    # enough rows (floor division would overrun the grid on the last images).
    rows = max(1, (len(augmented_X) + cols - 1) // cols)
    # squeeze=False keeps `axes` two-dimensional even when there is one row.
    fig, axes = plt.subplots(rows, cols, figsize=(12, 9), squeeze=False)
    for i, (img, label) in enumerate(zip(augmented_X, augmented_y)):
        ax = axes[i // cols, i % cols]
        ax.imshow(img.numpy(), cmap='Greens', vmin=0, vmax=1)
        ax.set_title(label_from_idx[label], fontsize=8)
        ax.set_xticks([])
        ax.set_yticks([])
    # Hide grid cells that have no image to show.
    for unused_ax in axes.ravel()[len(augmented_X):]:
        unused_ax.axis('off')
    plt.tight_layout()
    plt.show()


### 🧠 Model

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class MultiShapeNet(nn.Module):
    """Small CNN that maps a single-channel 4x4 image to per-class logits.

    Layout: 3x3 convolution -> batch norm -> ReLU -> 2x2 max pool ->
    1x1 convolution -> batch norm -> ReLU -> dropout -> linear classifier.
    The layer sizes assume a (B, 1, 4, 4) input, which is reduced to a
    single spatial position before the linear layer.

    Because the network contains batch norm and dropout, call ``eval()``
    before using it for inference.

    Args:
        num_classes: Number of output logits. Defaults to 5.
    """

    def __init__(self, num_classes: int = 5):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3)   # 8 x [1, 3, 3] filters
        self.bn1 = nn.BatchNorm2d(8)                  # normalize per channel
        self.pool1 = nn.MaxPool2d(kernel_size=2)      # downsample spatially

        self.conv2 = nn.Conv2d(8, 16, kernel_size=1)  # 16 x [8, 1, 1] filters
        self.bn2 = nn.BatchNorm2d(16)                 # normalize per channel
        self.dropout = nn.Dropout(0.2)                # randomly zero out features

        # After pooling the feature map is [B, 16, 1, 1], i.e. 16 features.
        self.fc = nn.Linear(16 * 1 * 1, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) class logits of shape [B, num_classes]."""
        x = F.relu(self.bn1(self.conv1(x)))   # [B, 8, 2, 2]
        x = self.pool1(x)                     # [B, 8, 1, 1]
        x = F.relu(self.bn2(self.conv2(x)))   # [B, 16, 1, 1]
        x = self.dropout(x)
        x = x.flatten(1)                      # [B, 16]
        return self.fc(x)


### ⚙️ Training

In [None]:
# Train the multi-class network on the full augmented set each epoch and keep
# the loss history for plotting.
model = MultiShapeNet()
loss_fn = nn.CrossEntropyLoss()  # takes raw logits and integer class targets
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

losses = []
for epoch in range(250):
    optimizer.zero_grad()  # drop gradients left over from the previous epoch
    out = model(X_train)
    loss = loss_fn(out, y_train)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())  # loss measured on this epoch's forward pass

In [None]:
def visualize_filters(layer: nn.Conv2d, title="Conv Filters"):
    """Show every output filter of ``layer`` as a small image in one row.

    Only the first input-channel slice of each filter is drawn, so the
    picture is complete only for layers with ``in_channels == 1``.

    Args:
        layer: Convolution whose ``weight`` tensor is plotted.
        title: Title placed above the row of images.
    """
    # Work on a detached CPU copy of the weights.
    kernels = layer.weight.detach().cpu().clone()

    n_filters = kernels.shape[0]  # out_channels
    fig, axes = plt.subplots(1, n_filters, figsize=(n_filters * 2, 2), squeeze=False)
    fig.suptitle(title)

    for kernel, ax in zip(kernels, axes[0]):
        ax.imshow(kernel[0], cmap='viridis')  # slice 0: assumes in_channels=1
        ax.axis('off')

    plt.tight_layout()
    plt.show()

if visualize:
    # First epoch at which the lowest recorded loss occurred, and that loss.
    min_loss_epoch, min_loss_value = min(enumerate(losses), key=lambda pair: pair[1])

    # Loss curve with the minimum marked in red.
    plt.plot(losses, label='Loss')
    plt.scatter(min_loss_epoch, min_loss_value, color='red', label=f'Min Loss: {min_loss_value:.4f}')
    plt.title('Training Loss with Min Loss Highlighted')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    # What the first convolution layer learned.
    visualize_filters(model.conv1, title="Conv1 Filters")


### 🧪 Create Test Set

In [None]:
# Evaluation images for the multi-class model: three 4x4 binary grids per
# class, each paired with its class name.
# NOTE(review): several entries are verbatim copies of entries in the training
# `samples` list (for example the first L and all three Columns), so accuracy
# measured on this set is not a clean held-out estimate. Confirm whether
# that is intended.
test_samples = [
    # True label: "L"
    (torch.tensor([[1,0,0,0],[1,0,0,0],[1,1,0,0],[0,0,0,0]]), "L"),
    (torch.tensor([[0,0,0,0],[1,0,0,0],[1,1,1,0],[0,0,0,0]]), "L"),
    (torch.tensor([[0,0,1,0],[1,1,1,0],[0,0,0,0],[0,0,0,0]]), "L"),
    # True label: "Box"
    (torch.tensor([[0,1,1,0],[0,1,1,0],[0,0,0,0],[0,0,0,0]]), "Box"),
    (torch.tensor([[0,0,0,0],[0,1,1,0],[0,1,1,0],[0,0,0,0]]), "Box"),
    (torch.tensor([[0,0,0,0],[1,1,0,0],[1,1,0,0],[0,0,0,0]]), "Box"),
    # True label: "Zigzag"
    (torch.tensor([[1,1,0,0],[0,1,1,0],[0,0,0,0],[0,0,0,0]]), "Zigzag"),
    (torch.tensor([[0,0,0,0],[0,1,1,0],[1,1,0,0],[0,0,0,0]]), "Zigzag"),
    (torch.tensor([[0,1,0,0],[0,1,1,0],[0,0,1,0],[0,0,0,0]]), "Zigzag"),
    # True label: "T"
    (torch.tensor([[0,0,1,0],[0,1,1,0],[0,0,1,0],[0,0,0,0]]), "T"),
    (torch.tensor([[0,0,0,0],[0,1,0,0],[1,1,1,0],[0,0,0,0]]), "T"),
    (torch.tensor([[0,0,0,0],[1,1,1,0],[0,1,0,0],[0,0,0,0]]), "T"),
    # True label: "Column"
    (torch.tensor([[0,1,0,0],[0,1,0,0],[0,1,0,0],[0,0,0,0]]), "Column"),
    (torch.tensor([[1,0,0,0],[1,0,0,0],[1,0,0,0],[0,0,0,0]]), "Column"),
    (torch.tensor([[0,0,0,0],[0,0,0,0],[1,1,1,0],[0,0,0,0]]), "Column"),
]

# Stack the float images to (B, H, W), then insert the grayscale channel -> (B, 1, H, W).
X_test = torch.stack([img.float() for img, _ in test_samples]).unsqueeze(1)
# Class indices looked up by name, shaped as a column vector -> (B, 1).
y_test = torch.tensor([idx_from_label[label] for _, label in test_samples]).unsqueeze(1)

In [None]:
if visualize:
    import matplotlib.pyplot as plt

    num_images = len(X_test)
    num_cols = 5
    # Ceiling division gives enough rows for every image; keep at least one row.
    num_rows = max(1, (num_images + num_cols - 1) // num_cols)
    # squeeze=False keeps `axes` two-dimensional even when there is only one
    # row, so the [row, col] indexing below always works.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(8, 3), squeeze=False)

    for i, (img, label) in enumerate(zip(X_test, y_test)):
        ax = axes[i // num_cols, i % num_cols]
        # Each img is (1, H, W); drop the channel dimension for imshow.
        ax.imshow(img.numpy().squeeze(), cmap='Greens', vmin=0, vmax=1)
        ax.set_title(label_from_idx[int(label.item())])
        ax.set_xticks([])
        ax.set_yticks([])

    # Hide grid cells that have no image to show.
    for unused_ax in axes.ravel()[num_images:]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.show()

### ✅ Test Accuracy

In [None]:
# The model contains Dropout and BatchNorm layers, which behave differently in
# training mode (random feature dropping, per-batch statistics that also update
# the running averages). Switch to eval mode so the test predictions are
# deterministic and use the statistics learned during training.
model.eval()

with torch.no_grad():  # no gradient tracking needed for inference
    logits = model(X_test)
    preds = torch.argmax(logits, dim=1)  # predicted class index per sample, (B,)
    targets = y_test.squeeze(1)          # (B, 1) -> (B,) so it lines up with preds
    print("Test predictions:", [label_from_idx[int(idx.item())] for idx in preds])
    print("True labels     :", [label_from_idx[int(idx.item())] for idx in targets])
    correct = (preds == targets).sum().item()
    accuracy = correct / len(y_test)

print(f"Accuracy: {accuracy*100:.2f}%")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Count how often each actual class was assigned each predicted class.
cm = confusion_matrix(y_test.tolist(), preds.tolist())

# Annotated heatmap with class names on both axes (y: actual, x: predicted).
fig, ax = plt.subplots(figsize=(5, 4))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=SHAPE_CLASSES,
    yticklabels=SHAPE_CLASSES,
    ax=ax,
)
ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")
ax.set_title("Confusion Matrix")
fig.tight_layout()
plt.show()

In [None]:
import torch
import torch.nn as nn

# Demo input: a batch of 2 images, each with 3 channels of 2x2 values.
image_1 = torch.tensor([
    [[1.0, 2.0], [3.0, 4.0]],      # Channel 1
    [[5.0, 6.0], [7.0, 8.0]],      # Channel 2
    [[9.0, 10.0], [11.0, 12.0]],   # Channel 3
])
image_2 = torch.tensor([
    [[2.0, 3.0], [4.0, 5.0]],      # Channel 1
    [[6.0, 7.0], [8.0, 9.0]],      # Channel 2
    [[10.0, 11.0], [12.0, 13.0]],  # Channel 3
])
x = torch.stack([image_1, image_2])  # shape = (2, 3, 2, 2)

# BatchNorm2d keeps one set of statistics per channel (3 here).
bn = nn.BatchNorm2d(3)

# Training mode normalizes with the current batch's statistics; eval mode
# would use the running statistics instead.
bn.train()
normalized = bn(x)

print("Output after BatchNorm2d:\n", normalized)
print("Shape of normalized output:", normalized.shape)