In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from ndlinear import NdLinear
import torch.nn.functional as F


In [None]:
# Preprocessing shared by both splits: convert each image to a tensor, then
# normalize every RGB channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Arguments common to the train and test CIFAR-10 datasets.
cifar_kwargs = dict(root='./data', download=True, transform=transform)

# Training split: batches of 64, reshuffled on every pass.
train_set = torchvision.datasets.CIFAR10(train=True, **cifar_kwargs)
train_loader = torch.utils.data.DataLoader(
    dataset=train_set, batch_size=64, shuffle=True
)

# Test split: batches of 64, iterated in dataset order.
test_set = torchvision.datasets.CIFAR10(train=False, **cifar_kwargs)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set, batch_size=64, shuffle=False
)


In [None]:
# Fully-connected CIFAR-10 classifier: each (3, 32, 32) image is flattened to 3072 features
class NdLinearCIFAR10(nn.Module):
    """Three-layer CIFAR-10 classifier built from ``NdLinear`` layers.

    Each image is flattened to 3 * 32 * 32 = 3072 features, passed through
    hidden widths of 512 and 256 with ReLU activations, and projected to
    10 class scores.
    """

    def __init__(self):
        """Create the three ``NdLinear`` layers (3072 -> 512 -> 256 -> 10)."""
        super().__init__()

        # Attribute names are kept as-is: they determine the parameter names
        # in the model's state_dict.
        self.layer1 = NdLinear(input_dims=(3072,), hidden_size=(512,))  # 3072 = 3*32*32
        self.layer2 = NdLinear(input_dims=(512,), hidden_size=(256,))
        self.layer3 = NdLinear(input_dims=(256,), hidden_size=(10,))  # 10 output classes

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Tensor whose first dimension is the batch; the remaining
                dimensions must multiply to 3072, e.g. ``(B, 3, 32, 32)``.

        Returns:
            The output of ``layer3`` with no softmax applied (raw scores).
        """
        # torch.flatten copies only when needed, so non-contiguous inputs
        # (channels_last or permuted batches) work; ``x.view`` would raise
        # RuntimeError on them. Contiguous inputs still get a view.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

# Seed PyTorch's global RNG before anything random happens so that reruns
# start from the same initial weights. The shuffle order drawn later from
# the global generator should repeat too; bit-exact results on GPU are not
# guaranteed by seeding alone.
SEED = 42
torch.manual_seed(SEED)

# Create model
model = NdLinearCIFAR10()

# Move model to GPU if available, otherwise stay on CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Loss function and optimizer: cross-entropy on the model's raw outputs,
# Adam over all model parameters with learning rate 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [None]:
# Training loop: ten passes over train_loader, printing the accuracy measured
# on the training batches after each pass (raise the count to train longer).
for epoch in range(10):
    model.train()
    correct = 0
    total = 0

    for images, targets in train_loader:
        # The batch must live on the same device as the model.
        images = images.to(device)
        targets = targets.to(device)

        # One optimization step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        # Running tally of correctly classified training samples.
        hits = logits.argmax(dim=1).eq(targets)
        correct += hits.sum().item()
        total += targets.shape[0]

    acc = correct / total
    print(f"Epoch {epoch + 1}: Accuracy = {acc:.2%}")

# Persist only the learned parameters (state_dict), not the module object.
torch.save(model.state_dict(), "ndlinear_cifar_model.pth")