In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import numpy as np
import pandas as pd

import seaborn as sns
from matplotlib import pyplot as plt

from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import cluster, pairwise_distances, normalized_mutual_info_score, silhouette_score

from tqdm import tqdm

from FlagRep0 import FlagRep, truncate_svd, chordal_distance



def set_seed(seed):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's built-in ``random`` module, NumPy's global generator and
    PyTorch (CPU plus the CUDA generators), then puts cuDNN into
    deterministic mode with benchmarking (autotuning) switched off.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Local import: keeps this helper self-contained without touching the
    # notebook's import cell.
    import random

    # Python-level randomness was previously left unseeded.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Prefer repeatable cuDNN kernels over the fastest auto-selected ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def purity_score(y_true, y_pred):
    """Return the clustering purity of ``y_pred`` with respect to ``y_true``.

    The contingency table of the two labelings is built, the largest count in
    each column (reduction over axis 0) is taken, and the sum of those maxima
    is divided by the total number of counted samples.

    Args:
        y_true: Reference labels.
        y_pred: Predicted cluster labels.

    Returns:
        The purity as a NumPy floating-point scalar.
    """
    table = cluster.contingency_matrix(y_true, y_pred)
    # Best-matching count per column of the table.
    dominant_counts = table.max(axis=0)
    return dominant_counts.sum() / table.sum()


def make_Bs(fl_type):
    """Split ``range(fl_type[-1])`` into consecutive index blocks.

    Block 0 is ``arange(fl_type[0])``; block ``i`` (for ``i >= 1``) is
    ``arange(fl_type[i-1], fl_type[i])``.

    Args:
        fl_type: Non-empty sequence of block end indices.

    Returns:
        A list of 1-D NumPy index arrays, one per entry of ``fl_type``.
    """
    blocks = [np.arange(fl_type[0])]
    # Pair every boundary with its successor to get (start, stop) of each
    # remaining block.
    blocks.extend(np.arange(lo, hi) for lo, hi in zip(fl_type, fl_type[1:]))
    return blocks

In [2]:
# Define the neural network
class NeuralNetwork(nn.Module):
    """Fully connected network: three ReLU hidden layers and a linear output.

    All hidden layers share the same width. The output layer applies no
    activation, so ``forward`` returns raw logits.

    Args:
        input_size: Number of features per input sample.
        hidden_size: Width of each of the three hidden layers.
        output_size: Number of output units.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # input -> hidden 1 -> hidden 2 -> hidden 3 -> output
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, output_size)
        # Shared non-linearity applied after every hidden layer.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch of flattened inputs to logits."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.relu(layer(hidden))
        # The final layer stays linear.
        return self.fc4(hidden)


In [3]:

# MNIST preprocessing: convert each image to a tensor, then normalize its
# single channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Train / test splits stored under ../data (download=True fetches them if needed).
train_dataset = torchvision.datasets.MNIST(
    root='../data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(
    root='../data', train=False, download=True, transform=transform
)

# Mini-batch iterators: the training loader shuffles, the test loader does not.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# Hyperparameters
input_size = 28 * 28   # each 28x28 image flattened into one vector
hidden_size = 128      # width shared by all hidden layers
output_size = 10       # one output per digit class (0-9)
num_epochs = 10
learning_rate = 0.001

In [4]:
# Rebuild the two networks and restore their weights from saved checkpoints.
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code (it also avoids the
# torch.load FutureWarning). map_location='cpu' lets a checkpoint that was
# saved on a GPU load on a CPU-only machine.
model0 = NeuralNetwork(input_size, hidden_size, output_size)
model0.load_state_dict(
    torch.load('../models/mnist_model0.pth', map_location='cpu', weights_only=True)
)

model1 = NeuralNetwork(input_size, hidden_size, output_size)
model1.load_state_dict(
    torch.load('../models/mnist_model1.pth', map_location='cpu', weights_only=True)
)

  model0.load_state_dict(torch.load('../models/mnist_model0.pth'))
  model1.load_state_dict(torch.load('../models/mnist_model1.pth'))


<All keys matched successfully>

In [17]:
# Numerical rank of the weight matrix of model1's second hidden layer (fc2).
# The tensor is detached from autograd and moved to CPU before NumPy conversion.
np.linalg.matrix_rank(model1.fc2.weight.detach().cpu().numpy())


np.int64(128)