# In [31]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tensorflow.keras.datasets import cifar10
from ndlinear import NdLinear

# Load CIFAR-10 and convert the NumPy arrays to PyTorch tensors.
# Images: permuted with (0, 3, 1, 2) -- presumably channels-last (N, H, W, C)
# from Keras, giving (N, C, H, W) -- then cast to float32 and scaled by 1/255.
# Labels: extra size-1 axes squeezed away, cast to int64 class ids.
(X_train, y_train), (X_test, y_test) = cifar10.load_data()

X_train = torch.tensor(X_train).permute(0, 3, 1, 2).float().div(255.0)
X_test = torch.tensor(X_test).permute(0, 3, 1, 2).float().div(255.0)

y_train = torch.tensor(y_train).squeeze().long()
y_test = torch.tensor(y_test).squeeze().long()

# Creating fake bounding boxes
def simulate_boxes(labels):
    """Generate one synthetic ``(x, y, w, h)`` box per class label.

    The box origin drifts with the class id, so the box target is (noisily)
    predictable from the label:

    * ``x, y = 0.1 + 0.05 * label + U(0, 0.05)``
    * ``w, h = 0.2 + U(0, 0.1)``

    Fresh uniform jitter is drawn on every call.

    Args:
        labels: integer class ids -- a tensor on any device, or a sequence.
            Any shape is accepted and flattened, so ``(N,)`` and ``(N, 1)``
            both yield ``N`` boxes.

    Returns:
        A float32 CPU tensor of shape ``(N, 4)``; callers move it to their
        device. Empty input yields an empty ``(0, 4)`` tensor.
    """
    # Work on CPU regardless of where the labels live, so the random jitter
    # (drawn on CPU) never mixes devices, and vectorize instead of looping
    # per sample with .item() host syncs.
    label_vals = torch.as_tensor(labels).detach().to("cpu", torch.float32).reshape(-1)
    count = label_vals.shape[0]

    origin = 0.1 + 0.05 * label_vals
    min_size = torch.full_like(origin, 0.2)
    base = torch.stack([origin, origin, min_size, min_size], dim=1)

    # Per-column jitter range: +U(0, 0.05) for x/y, +U(0, 0.1) for w/h.
    jitter_scale = torch.tensor([0.05, 0.05, 0.1, 0.1])
    return base + jitter_scale * torch.rand(count, 4)

# NdLinear-based Model
class CNNWithNdLinear(nn.Module):
    """Three-conv CNN feeding two NdLinear layers, with class and box heads.

    Sized for 32x32 RGB input. The two 2x2 max-pools shrink the 32x32 grid to
    8x8, so the conv backbone emits a (batch, 128, 8, 8) map. ``forward``
    merges the two spatial axes into one axis of length 64 and feeds
    (batch, 128, 64) to the NdLinear stack.

    NOTE(review): assumes ``NdLinear`` applies one ``nn.Linear`` per non-batch
    axis, so ``input_dims`` excludes the batch axis and has as many entries as
    ``hidden_size``. Confirm against the installed ndlinear version.

    Args:
        num_classes: width of the classification head (default 10, as for
            CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # NdLinear layers. After two max-pools the grid is 8x8 = 64 positions,
        # so the first layer sees (channels=128, positions=64). Its output
        # (256, 32) is exactly what the second layer was sized for.
        self.ndlayer1 = NdLinear(input_dims=(128, 8 * 8), hidden_size=(256, 32))
        self.ndlayer2 = NdLinear(input_dims=(256, 32), hidden_size=(128, 32))

        self.flatten = nn.Flatten()
        self.class_head = nn.Linear(128 * 32, num_classes)
        self.box_head = nn.Linear(128 * 32, 4)

    def forward(self, x):
        """Return ``(class_logits, box_preds)`` for a batch of images.

        Args:
            x: image batch, assumed (batch, 3, 32, 32). The NdLinear sizes
                hard-code the resulting 8x8 feature grid.

        Returns:
            Class logits of shape (batch, num_classes) and box predictions
            of shape (batch, 4).
        """
        x = F.relu(self.conv1(x))                   # (B, 32, 32, 32)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # (B, 64, 16, 16)
        x = F.max_pool2d(F.relu(self.conv3(x)), 2)  # (B, 128, 8, 8)

        # Merge H and W so NdLinear receives a (channels, positions) layout.
        x = torch.flatten(x, start_dim=2)           # (B, 128, 64)
        x = F.relu(self.ndlayer1(x))                # (B, 256, 32) -- see NOTE
        x = F.relu(self.ndlayer2(x))                # (B, 128, 32) -- see NOTE

        x = self.flatten(x)                         # (B, 128 * 32)
        class_out = self.class_head(x)
        bbox_out = self.box_head(x)

        return class_out, bbox_out

import os

# Initialize model on CUDA when available, otherwise on CPU.
# `device` is used by every `.to(device)` below, so it is defined here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNNWithNdLinear().to(device)

# Define optimizer and loss functions
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn_class = nn.CrossEntropyLoss()
loss_fn_box = nn.MSELoss()

# Training configuration
num_epochs = 50
batch_size = 128
eval_every = 10  # checkpoint + evaluate every N epochs and on the last epoch
best_accuracy = 0

# torch.save does not create missing parent directories.
os.makedirs("saved_models", exist_ok=True)

for epoch in range(num_epochs):
    model.train()
    total_class_loss, total_box_loss = 0.0, 0.0

    # Visit the training set in a fresh random order every epoch.
    order = torch.randperm(len(X_train))
    for i in range(0, len(X_train), batch_size):
        batch_idx = order[i:i + batch_size]
        cpu_labels = y_train[batch_idx]
        # Simulate boxes from the CPU labels before moving tensors to the device.
        boxes = simulate_boxes(cpu_labels).to(device)
        imgs = X_train[batch_idx].to(device)
        labels = cpu_labels.to(device)

        optimizer.zero_grad()
        class_logits, box_preds = model(imgs)
        loss_class = loss_fn_class(class_logits, labels)
        loss_box = loss_fn_box(box_preds, boxes)
        loss = loss_class + loss_box
        loss.backward()
        optimizer.step()

        total_class_loss += loss_class.item()
        total_box_loss += loss_box.item()

    print(f"Epoch {epoch}: Class Loss = {total_class_loss:.4f}, Box Loss = {total_box_loss:.4f}")

    # Save model and evaluate only at intervals (every `eval_every` epochs)
    # and on the final epoch.
    if epoch % eval_every != 0 and epoch != num_epochs - 1:
        continue

    torch.save(model.state_dict(), f'saved_models/model_epoch_{epoch}.pt')
    print(f"Model saved at epoch {epoch}")

    # Model eval: accumulate counts/sums on the fly instead of collecting
    # per-sample NumPy arrays.
    model.eval()
    num_correct, num_samples = 0, 0
    box_sq_err_sum, box_values = 0.0, 0

    with torch.no_grad():
        for i in range(0, len(X_test), batch_size):
            cpu_labels = y_test[i:i + batch_size]
            boxes = simulate_boxes(cpu_labels).to(device)
            imgs = X_test[i:i + batch_size].to(device)
            labels = cpu_labels.to(device)

            class_logits, box_preds = model(imgs)
            preds = class_logits.argmax(dim=1)

            num_correct += (preds == labels).sum().item()
            num_samples += labels.numel()
            box_sq_err_sum += F.mse_loss(box_preds, boxes, reduction="sum").item()
            box_values += boxes.numel()

    # Fraction of correct predictions, and the mean squared error over every
    # box coordinate. The latter equals the uniform average of the per-coordinate
    # MSEs, because every coordinate has the same sample count.
    acc = num_correct / num_samples
    bbox_mse = box_sq_err_sum / box_values

    print(f"\nEpoch {epoch} | Test Accuracy: {acc*100:.2f}% | Test BBox MSE: {bbox_mse:.4f}")

    # Save the best model
    if acc > best_accuracy:
        best_accuracy = acc
        torch.save(model.state_dict(), "saved_models/best_model.pt")
        print(f"Best model saved with accuracy {acc*100:.2f}%")

print("Training Complete!")



# Cell output: KeyboardInterrupt (execution of this cell was interrupted)

