In [2]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import scipy.linalg
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
import torchvision
from cca_core import flatten_weights
from scipy.linalg import subspace_angles

In [16]:

class TwoBranchCNN_CIFAR(nn.Module):
    def __init__(self):
        super(TwoBranchCNN_CIFAR, self).__init__()
        # Initial common layer
        self.conv1 = nn.Conv2d(3, 20, kernel_size=5)  # Adjusted for 3-channel input
        
        # Branch 1
        self.branch1_conv1 = nn.Conv2d(20, 20, kernel_size=5)
        self.branch1_conv2 = nn.Conv2d(20, 40, kernel_size=3, padding=1)
        self.branch1_conv3 = nn.Conv2d(40, 80, kernel_size=3, padding=1)
        self.branch1_drop = nn.Dropout2d()
        self.branch1_fc1 = nn.Linear(2000, 50)  # Adjusted for flattened output size

        # Branch 2
        self.branch2_conv1 = nn.Conv2d(20, 20, kernel_size=5)
        self.branch2_conv2 = nn.Conv2d(20, 40, kernel_size=3, padding=1)
        self.branch2_conv3 = nn.Conv2d(40, 80, kernel_size=3, padding=1)
        self.branch2_drop = nn.Dropout2d()
        self.branch2_fc1 = nn.Linear(2000, 50)  # Adjusted for flattened output size
        
        # Final classifier
        self.final_fc = nn.Linear(100, 10)  # Output for CIFAR-10

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        
        # Branch 1
        b1 = F.relu(F.max_pool2d(self.branch1_conv1(x), 2))
        b1 = F.relu(self.branch1_conv2(b1))
        b1 = F.relu(self.branch1_conv3(b1))
        b1 = self.branch1_drop(b1)
        b1 = b1.view(-1, self.num_flat_features(b1))
        b1 = F.relu(self.branch1_fc1(b1))
        
        # Branch 2
        b2 = F.relu(F.max_pool2d(self.branch2_conv1(x), 2))
        b2 = F.relu(self.branch2_conv2(b2))
        b2 = F.relu(self.branch2_conv3(b2))
        b2 = self.branch2_drop(b2)
        b2 = b2.view(-1, self.num_flat_features(b2))
        b2 = F.relu(self.branch2_fc1(b2))
        
        # Combine branches
        combined = torch.cat((b1, b2), dim=1)
        output = self.final_fc(combined)
        return F.log_softmax(output, dim=1)

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

# Create the model instance
model = TwoBranchCNN_CIFAR()
print(model)


TwoBranchCNN_CIFAR(
  (conv1): Conv2d(3, 20, kernel_size=(5, 5), stride=(1, 1))
  (branch1_conv1): Conv2d(20, 20, kernel_size=(5, 5), stride=(1, 1))
  (branch1_conv2): Conv2d(20, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (branch1_conv3): Conv2d(40, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (branch1_drop): Dropout2d(p=0.5, inplace=False)
  (branch1_fc1): Linear(in_features=2000, out_features=50, bias=True)
  (branch2_conv1): Conv2d(20, 20, kernel_size=(5, 5), stride=(1, 1))
  (branch2_conv2): Conv2d(20, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (branch2_conv3): Conv2d(40, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (branch2_drop): Dropout2d(p=0.5, inplace=False)
  (branch2_fc1): Linear(in_features=2000, out_features=50, bias=True)
  (final_fc): Linear(in_features=100, out_features=10, bias=True)
)


In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize each channel
])


In [5]:
# Load training data
train_data = datasets.CIFAR10(root='data', train=True, download=True, transform=transform)

# Load testing data
test_data = datasets.CIFAR10(root='data', train=False, download=True, transform=transform)

# DataLoader setups
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100.0%


Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified


In [6]:
def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()  # Set the model to training mode
    total_loss = 0
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()  # Clear the gradients of all optimized variables
        output = model(data)  # Forward pass
        loss = criterion(output, target)  # Calculate loss
        loss.backward()  # Backpropagation
        optimizer.step()  # Update parameters
        
        total_loss += loss.item()  # Sum up batch loss
        pred = output.argmax(dim=1, keepdim=True)  # Get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()

        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

    print(f'\nTraining set: Average loss: {total_loss / len(train_loader):.4f}, Accuracy: {correct}/{len(train_loader.dataset)} ({100. * correct / len(train_loader.dataset):.0f}%)')


In [7]:
def test(model, device, test_loader, criterion):
    model.eval()  # Set the model to evaluation mode
    test_loss = 0
    correct = 0
    with torch.no_grad():  # No gradients tracked
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()  # Sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # Get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')


In [15]:
# Debugging code to check the output size
# add print statements at layers to check the output size e.g. print("branch1 before drop", b1.size())
# Assuming the input size for CIFAR-10 (3 channels, 32x32 images)
# dummy_input = torch.randn(1, 3, 32, 32)  # Batch size of 1
# output = model(dummy_input)

Branch 1 before dropout: torch.Size([1, 80, 5, 5])
Branch 1 after flattening: torch.Size([1, 2000])
Branch 2 before dropout: torch.Size([1, 80, 5, 5])
Branch 2 before dropout: torch.Size([1, 2000])


In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TwoBranchCNN_CIFAR().to(device)
optimizer = Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(1, 11):  # Run for 10 epochs
    train(model, device, train_loader, optimizer, criterion, epoch)

#torch.save(model.state_dict(), 'model_weights_two_stream_CIFAR10.pth')




Training set: Average loss: 1.5937, Accuracy: 20521/50000 (41%)

Training set: Average loss: 1.2360, Accuracy: 27768/50000 (56%)

Training set: Average loss: 1.0821, Accuracy: 30810/50000 (62%)

Training set: Average loss: 0.9840, Accuracy: 32540/50000 (65%)

Training set: Average loss: 0.9123, Accuracy: 33965/50000 (68%)

Training set: Average loss: 0.8527, Accuracy: 35016/50000 (70%)

Training set: Average loss: 0.8040, Accuracy: 35750/50000 (72%)

Training set: Average loss: 0.7628, Accuracy: 36534/50000 (73%)

Training set: Average loss: 0.7205, Accuracy: 37205/50000 (74%)

Training set: Average loss: 0.6903, Accuracy: 37669/50000 (75%)


In [18]:
test(model, device, test_loader, criterion)


Test set: Average loss: 0.8329, Accuracy: 7188/10000 (72%)



In [19]:

torch.save(model.state_dict(), 'model_weights_two_stream_CIFAR10.pth')