In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [None]:
# Pick the compute device: use CUDA when it is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Define CNN model
class CNN(nn.Module):
    """Three conv/ReLU/max-pool stages followed by a single linear classifier.

    Every convolution uses a 3x3 kernel with stride 1 and padding 1, so only
    the 2x2 max-pools shrink the feature map. For 28x28 inputs the spatial
    size goes 28 -> 14 -> 7 -> 3, which is where the ``64*3*3`` input width of
    the classifier comes from.

    Args:
        in_channels: Number of channels of the input images (default 1).
        num_classes: Number of output scores per sample (default 10).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys and the child-module names other cells look up.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.fc = nn.Linear(in_features=64*3*3, out_features=num_classes)

    def forward(self, x):
        """Map a batch of images ``(batch, in_channels, H, W)`` to class scores.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised scores.

        Raises:
            RuntimeError: If the input's spatial size does not reduce to 3x3
                after the three pooling stages (i.e. it is not 28x28-like).
        """
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.relu3(x)
        x = self.pool3(x)
        # Flatten per sample instead of reshape(-1, 64*3*3): the old form could
        # silently regroup features across samples (changing the batch size)
        # when the input resolution was wrong; now the mismatch raises in fc.
        x = x.flatten(start_dim=1)
        x = self.fc(x)
        return x

In [None]:
# Training configuration
num_epochs = 5          # number of passes over the training loader
batch_size = 64         # samples per mini-batch for both data loaders
learning_rate = 1e-3    # step size handed to the optimizer

In [None]:
# MNIST train and test splits (downloaded into ./data when missing), with
# images converted to tensors
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Mini-batch iterators: only the training batches are shuffled
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Build the network, move it to the selected device, and attach an Adam
# optimizer to its parameters
model = CNN()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Train the model
# The loss module is built once here instead of once per batch.
criterion = nn.CrossEntropyLoss()

# Put the model in training mode explicitly, so re-running this cell after an
# evaluation pass (which switches to eval mode) still trains correctly.
model.train()

for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # Forward pass
        scores = model(data)
        loss = criterion(scores, targets)

        # Backward pass and optimization
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()
        optimizer.step()

        # Log every 100th batch; the epoch is shown 1-based so the final
        # epoch reads "[num_epochs/num_epochs]" rather than stopping one short.
        if batch_idx % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}')

In [None]:
# Define evaluate function
def evaluate(model, loader):
    """Return the classification accuracy of ``model`` on ``loader`` in percent.

    Args:
        model: Module that maps a batch of inputs to per-class scores with the
            class dimension at index 1.
        loader: Iterable yielding ``(data, targets)`` tensor batches.

    Returns:
        Percentage (0-100) of samples whose highest-scoring class equals the
        target, or ``0.0`` when ``loader`` yields no samples.

    The model is switched to eval mode for the pass and put back into the mode
    it was in beforehand, so evaluating in the middle of training is safe.
    """
    # Feed batches to the device that holds the model's weights; fall back to
    # the notebook-level `device` only for models without parameters.
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else device

    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():  # no gradients needed for inference
            for data, targets in loader:
                data = data.to(eval_device)
                targets = targets.to(eval_device)
                predictions = model(data).argmax(dim=1)
                total += targets.size(0)
                correct += (predictions == targets).sum().item()
    finally:
        # Restore the caller's train/eval mode even if a batch raised.
        model.train(was_training)

    if total == 0:
        # An empty loader would otherwise raise ZeroDivisionError.
        return 0.0
    return 100 * correct / total

In [None]:
# Accuracy of the trained model on the test loader
accuracy = evaluate(model=model, loader=test_loader)
print(f'Test accuracy: {accuracy:.2f}%')

In [None]:
from decomposition.decomposition import cp_decomposition_conv_layer
from decomposition.CPDLayers import CPDLayer

In [None]:
import copy

In [None]:
import tensorly as tl
# Switch TensorLy to its PyTorch backend (the layer decomposed further down
# is a torch nn.Conv2d)
tl.set_backend("pytorch")

In [None]:
# Work on a deep copy so the trained `model` itself is not modified when a
# layer of `net` is replaced later on
net = copy.deepcopy(model)

In [None]:
# Show the copied network's layers; the child-module names printed here are
# what the decomposition cell matches against
print(net)

In [None]:
# Replace conv3 of the copied network with its rank-2 CP decomposition and
# measure the test accuracy of the modified network.
# list(...) snapshots the children first, so swapping a layer inside the loop
# never touches the container being iterated.
for name, module in list(net.named_children()):
    if name == 'conv3' and isinstance(module, nn.Conv2d):
        print(name, module)
        # NOTE(review): cp_decomposition_conv_layer is a project helper; it is
        # assumed to return an nn.Module that is a drop-in replacement for the
        # convolution. Confirm against its implementation.
        cpd_layer = cp_decomposition_conv_layer(module, rank=2)
        # setattr goes through nn.Module's registration logic instead of
        # writing into the private `_modules` dict directly.
        setattr(net, name, cpd_layer)
        accuracy = evaluate(net, test_loader)
        print(f'Test accuracy: {accuracy:.2f}%')