In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
# Define a simple CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)  # 10 output classes for MNIST

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [3]:
def structurise_reconstructions(decoder_output, bs=16):
    '''Convert list of flattened batches into list of reshaped images'''
    structurised = []
    for batch in decoder_output:
        structurised.append(batch.reshape(bs, 1, 28, 28))
    return structurised

In [4]:
def train(flav='mnist'):
       # Loading MNIST dataset
    transform = transforms.Compose([transforms.ToTensor()])

    if flav == 'mnist':
        train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
        test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    
    elif flav == 'fashion':
        train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
        test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform) 

    batch_size = 16
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Instantiate the model
    model = CNN()

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    num_epochs = 5

    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:
                # print(f'Epoch: {epoch + 1}, Batch: {i + 1}, Loss: {running_loss / 100}')
                running_loss = 0.0

    # print('Finished Training')

    # Evaluation on test data
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the network on the 10000 test images: {(100 * correct / total):.2f}%')
    return (model, (100 * correct / total))

In [5]:
# Evaluation
def batch_test(reconstructions, model):
    labels = 7 * torch.ones(16)
    model.eval()
    with torch.no_grad():
        batch_accuracy = []
        for batch in reconstructions:
            correct = 0
            total = 0
            outputs = model(batch)
            _, predicted = torch.max(outputs.data, 1)
            print(predicted)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            batch_accuracy.append(100 * correct/total)
    return batch_accuracy

    # print(f"Accuracy on the test set: {100 * correct / total:.2f}%")

## MNIST Cumulants

In [6]:
reconstructions_1_2_3_cumul = torch.load('/Users/as/Desktop/reconstructions/MNIST/bs16/lr_0.0001/2pt-1.0_3pt-0.017/mixed_reconstr_avg_approx_all_1pt_1.0-2pt_0.017-3pt_simpleout')[1000:]
reconstructions_1_2_cumul = torch.load('/Users/as/Desktop/reconstructions/MNIST/bs16/lr_0.0001/2pt-1-3pt_0.0_doover/mixed_reconstr_avg_approx_all_1pt_1.0-2pt_0.0-3pt_simpleout')[1000:]
reconstructions_1_cumul = torch.load('/Users/as/Desktop/reconstructions/MNIST/bs16/lr_0.0001/2pt-0.0_3pt-0.0/mixed_reconstr_avg_approx_all_1pt_0.0-2pt_0.0-3pt_simpleout')[1000:]

reconstructions_1_cumul = structurise_reconstructions(reconstructions_1_cumul)
reconstructions_1_2_cumul = structurise_reconstructions(reconstructions_1_2_cumul)
reconstructions_1_2_3_cumul = structurise_reconstructions(reconstructions_1_2_3_cumul)

In [7]:
accuracies_1_cumul_mnist = []
accuracies_12_cumul_mnist = []
accuracies_123_cumul_mnist = []

for i in range(5):
    model, train_acc = train()
    batch_accuracy_1_2_3 = batch_test(reconstructions_1_2_3_cumul, model)
    accuracies_123_cumul_mnist.append(torch.as_tensor(batch_accuracy_1_2_3).mean())

    batch_accuracy_1_2 = batch_test(reconstructions_1_2_cumul, model)
    accuracies_12_cumul_mnist.append(torch.as_tensor(batch_accuracy_1_2).mean())

    batch_accuracy_1 = batch_test(reconstructions_1_cumul, model)
    accuracies_1_cumul_mnist.append(torch.as_tensor(batch_accuracy_1).mean())

Accuracy of the network on the 10000 test images: 99.24%
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
tensor([8, 8, 8, 8, 8, 8, 8, 8,

In [12]:
print('Avg 1 pt reconstr: ', float(sum(accuracies_1_cumul_mnist)/len(accuracies_1_cumul_mnist)))
print('Avg 1 and 2pt reconstr: ', float(sum(accuracies_12_cumul_mnist)/len(accuracies_12_cumul_mnist)))
print('Avg 1,2,3 pt reconst: ', float(sum(accuracies_123_cumul_mnist)/len(accuracies_123_cumul_mnist)))

Avg 1 pt reconstr:  0.34285715222358704
Avg 1 and 2pt reconstr:  20.842857360839844
Avg 1,2,3 pt reconst:  99.19999694824219


## MNIST Moments

In [13]:
reconstructions_1_2_3_moment_mnist = torch.load('/Users/as/Desktop/reconstructions/moments/MNIST/2pt-1_3pt-0.017/mixed_reconstr_avg_approx_all_1pt_1.0-2pt_0.017-3pt_simpleout')[1000:]
reconstructions_1_2_moment_mnist = torch.load('/Users/as/Desktop/reconstructions/moments/MNIST/2pt-1_3pt-0/mixed_reconstr_avg_approx_all_1pt_1.0-2pt_0.0-3pt_simpleout')[1000:]
reconstructions_1_moment_mnist = torch.load('/Users/as/Desktop/reconstructions/moments/MNIST/2pt-0_3pt-0/mixed_reconstr_avg_approx_all_1pt_0.0-2pt_0.0-3pt_simpleout')[1000:]

reconstructions_1_moment_mnist = structurise_reconstructions(reconstructions_1_moment_mnist)
reconstructions_1_2_moment_mnist = structurise_reconstructions(reconstructions_1_2_moment_mnist)
reconstructions_1_2_3_moment_mnist = structurise_reconstructions(reconstructions_1_2_3_moment_mnist)

In [14]:
accuracies_1_moment_mnist = []
accuracies_12_moment_mnist = []
accuracies_123_moment_mnist = []

for i in range(5):
    model, train_acc = train()
    batch_accuracy_1_2_3 = batch_test(reconstructions_1_2_3_moment_mnist, model)
    accuracies_123_moment_mnist.append(torch.as_tensor(batch_accuracy_1_2_3).mean())

    batch_accuracy_1_2 = batch_test(reconstructions_1_2_moment_mnist, model)
    accuracies_12_moment_mnist.append(torch.as_tensor(batch_accuracy_1_2).mean())

    batch_accuracy_1 = batch_test(reconstructions_1_moment_mnist, model)
    accuracies_1_moment_mnist.append(torch.as_tensor(batch_accuracy_1).mean())

Accuracy of the network on the 10000 test images: 99.01%
Accuracy of the network on the 10000 test images: 99.14%
Accuracy of the network on the 10000 test images: 98.88%
Accuracy of the network on the 10000 test images: 98.99%
Accuracy of the network on the 10000 test images: 98.45%


In [15]:
print('Avg 1 pt reconstr: ', float(sum(accuracies_1_moment_mnist)/len(accuracies_1_moment_mnist)))
print('Avg 1 and 2pt reconstr: ', float(sum(accuracies_12_moment_mnist)/len(accuracies_12_moment_mnist)))
print('Avg 1,2,3 pt reconst: ', float(sum(accuracies_123_moment_mnist)/len(accuracies_123_moment_mnist)))

Avg 1 pt reconstr:  0.29333335161209106
Avg 1 and 2pt reconstr:  15.822221755981445
Avg 1,2,3 pt reconst:  11.84889030456543


## FashionMNIST

### FashionMNIST (cumulants)

In [None]:
reconstructions_1_2_3_cumul_fashion = torch.load('/Users/as/Desktop/reconstructions/FashionMNIST/bs16/lr_0.0001/larger_init_batch/2pt-0.8_3pt-0.001_variable/mixed_reconstr_avg_approx_all_1pt_0.8-2pt_0.001-3pt_simpleout')[1000:]
reconstructions_1_2_cumul_fashion = torch.load('/Users/as/Desktop/reconstructions/FashionMNIST/bs16/lr_0.0001/larger_init_batch/2pt-0.80_3pt-0.0_doover/mixed_reconstr_avg_approx_all_1pt_0.8-2pt_0.0-3pt_simpleout')[1000:]
reconstructions_1_cumul_fashion = torch.load('/Users/as/Desktop/reconstructions/FashionMNIST/bs16/lr_0.0001/larger_init_batch/2pt-0.0_3pt-0.0/mixed_reconstr_avg_approx_all_1pt_0.0-2pt_0.0-3pt_simpleout')[1000:]

reconstructions_1_cumul_fashion = structurise_reconstructions(reconstructions_1_cumul_fashion)
reconstructions_1_2_3_cumul_fashion = structurise_reconstructions(reconstructions_1_2_cumul_fashion)
reconstructions_1_2_3_cumul_fashion = structurise_reconstructions(reconstructions_1_2_3_cumul_fashion)

In [17]:
accuracies_1_cumul_fashion = []
accuracies_12_cumul_fashion = []
accuracies_123_cumul_fashion = []

for i in range(5):
    model, train_acc = train()
    batch_accuracy_1_2_3 = batch_test(reconstructions_1_2_3_cumul, model)
    accuracies_123_cumul_fashion.append(torch.as_tensor(batch_accuracy_1_2_3).mean())

    batch_accuracy_1_2 = batch_test(reconstructions_1_2_cumul, model)
    accuracies_12_cumul_fashion.append(torch.as_tensor(batch_accuracy_1_2).mean())

    batch_accuracy_1 = batch_test(reconstructions_1_cumul, model)
    accuracies_1_cumul_fashion.append(torch.as_tensor(batch_accuracy_1).mean())

print(torch.mean(accuracies_123_cumul_fashion), torch.mean(accuracies_12_cumul_fashion), torch.mean(accuracies_1_cumul_fashion))

Accuracy of the network on the 10000 test images: 99.13%
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([7, 7, 7, 7, 7, 7, 7, 7,

KeyboardInterrupt: 

### FashionMNIST (moments)

In [None]:
reconstructions_1_2_3_moment_fashion = torch.load('/Users/as/Desktop/reconstructions/moments/FashionMNIST/2pt-0.8_3pt-0.001/mixed_reconstr_avg_approx_all_1pt_0.8-2pt_0.001-3pt_simpleout')[-500:]
reconstructions_1_2_moment_fashion = torch.load('/Users/as/Desktop/reconstructions/moments/FashionMNIST/2pt-0.8_3pt-0/mixed_reconstr_avg_approx_all_1pt_0.8-2pt_0.0-3pt_simpleout')[-500:]
reconstructions_1_moment_fashion = torch.load('/Users/as/Desktop/reconstructions/FashionMNIST/bs16/lr_0.0001/larger_init_batch/2pt-0.0_3pt-0.0/mixed_reconstr_avg_approx_all_1pt_0.0-2pt_0.0-3pt_simpleout')[-500:]

reconstructions_1_moment_fashion = structurise_reconstructions(reconstructions_1_moment_fashion)
reconstructions_1_2_moment_fashion = structurise_reconstructions(reconstructions_1_2_moment_fashion)
reconstructions_1_2_3_moment_fashion = structurise_reconstructions(reconstructions_1_2_3_moment_fashion)

In [None]:
accuracies_1_moment_fashion = []
accuracies_12_moment_fashion = []
accuracies_123_moment_fashion = []

for i in range(5):
    model, train_acc = train()
    batch_accuracy_1_2_3 = batch_test(reconstructions_1_2_3_moment_fashion, model)
    accuracies_123_moment_fashion.append(torch.as_tensor(batch_accuracy_1_2_3).mean())

    batch_accuracy_1_2 = batch_test(reconstructions_1_2_moment_fashion, model)
    accuracies_12_moment_fashion.append(torch.as_tensor(batch_accuracy_1_2).mean())

    batch_accuracy_1 = batch_test(reconstructions_1_moment_fashion, model)
    accuracies_1_moment_fashion.append(torch.as_tensor(batch_accuracy_1).mean())

torch.mean(accuracies_123_moment_fashion, accuracies_12_moment_fashion, accuracies_1_moment_fashion)