In [97]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"

In [98]:
class VAE(nn.Module):
    def __init__(self, input_dim_w1, input_dim_w2, hidden_dim=32):
        super(VAE, self).__init__()

        # Encoder
        self.encoder_nn1 = nn.Sequential(
            nn.Linear(input_dim_w2, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
        )
        self.encoder_nn2 = nn.Sequential(
            nn.Linear(input_dim_w1 + hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(hidden_dim, hidden_dim)
        self.fc_logvar = nn.Linear(hidden_dim, hidden_dim)

        # Decoder
        self.decoder_nn3 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, input_dim_w1),
            # nn.ReLU(),
        )
        self.decoder_nn4 = nn.Sequential(
            nn.Linear(hidden_dim + input_dim_w1, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, input_dim_w2),
            # nn.Sigmoid(),
        )

    def encode(self, w1, w2):
        h1 = self.encoder_nn1(w2)
        h2 = self.encoder_nn2(torch.cat((w1, h1), dim=1))
        mu = self.fc_mu(h2)
        logvar = self.fc_logvar(h2)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        weight1_hat = self.decoder_nn3(z)
        weight2_hat = self.decoder_nn4(torch.cat((z, F.relu(weight1_hat)), dim=1))
        return weight1_hat, weight2_hat

    def forward(self, w1, w2):
        mu, logvar = self.encode(w1, w2)
        z = self.reparameterize(mu, logvar)
        weight1_hat, weight2_hat = self.decode(z)
        return weight1_hat, weight2_hat, mu, logvar


def loss_function(weight1_hat, weight2_hat, weight1, weight2, mu, logvar):
    # Reconstruction loss (MSE)
    recon_loss1 = F.mse_loss(weight1_hat, weight1, reduction="sum")
    recon_loss2 = F.mse_loss(weight2_hat, weight2, reduction="sum")
    recon_loss = recon_loss1 + recon_loss2

    # KL divergence loss
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Total loss
    return recon_loss + kl_loss

In [99]:
import pickle
import io
from torch.utils.data import DataLoader, Dataset

class CPU_Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == "torch.storage" and name == "_load_from_bytes":
            return lambda b: torch.load(io.BytesIO(b), map_location="cpu")
        else:
            return super().find_class(module, name)


with open('2_layer_real_models.pickle', 'rb') as f:
    real_models = CPU_Unpickler(f).load()

class WeightsDataset(Dataset):
    def __init__(self, state_dicts):
        self.data = [self.flatten_state_dict(sd) for sd in state_dicts]
        self.mean_w1, self.std_w1, self.mean_w2, self.std_w2 = (
            self.calculate_statistics()
        )
        self.data = [
            (
                self.normalize(w1, self.mean_w1, self.std_w1),
                self.normalize(w2, self.mean_w2, self.std_w2),
            )
            for w1, w2 in self.data
        ]

    def flatten_state_dict(self, state_dict):
        flattened_params = []
        for key, param in state_dict.items():
            if isinstance(param, torch.Tensor):
                if "weight" in key:
                    bias_key = key.replace("weight", "bias")
                    if bias_key in state_dict:
                        bias_param = state_dict[bias_key]
                        concatenated = torch.cat([param.view(-1), bias_param.view(-1)])
                        flattened_params.append(concatenated)
        return tuple(flattened_params)

    def calculate_statistics(self):
        all_w1 = torch.cat([data[0] for data in self.data])
        all_w2 = torch.cat([data[1] for data in self.data])
        mean_w1, std_w1 = all_w1.mean(), all_w1.std()
        mean_w2, std_w2 = all_w2.mean(), all_w2.std()
        return mean_w1, std_w1, mean_w2, std_w2

    def normalize(self, data, mean, std):
        return (data - mean) / std

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

train_dataset = WeightsDataset(real_models)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

sample = train_dataset[0]
w1_size, w2_size = sample[0].size()[0], sample[1].size()[0]
# # Normalize the data
# data_min = data.min()
# data_max = data.max()
# data_normalized = (data - data_min) / (data_max - data_min)
# data -= data.mean()

In [100]:
hidden_dim = 32
vae = VAE(input_dim_w1=w1_size, input_dim_w2=w2_size, hidden_dim=hidden_dim).to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

epochs = 1000
for epoch in range(epochs):
    vae.train()
    train_loss = 0
    for i, batch in enumerate(train_loader):
        batch_w1, batch_w2 = batch[0], batch[1]
        weight1_hat, weight2_hat, mu, logvar = vae(batch_w1, batch_w2)
        loss = loss_function(weight1_hat, weight2_hat, batch_w1, batch_w2, mu, logvar)
        optimizer.zero_grad()
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print(
        "Epoch: {}, Loss: {:.4f}".format(epoch, train_loss / len(train_loader.dataset))
    )

Epoch: 0, Loss: 44.5811
Epoch: 1, Loss: 43.2903
Epoch: 2, Loss: 42.7145
Epoch: 3, Loss: 42.0227
Epoch: 4, Loss: 40.7362
Epoch: 5, Loss: 38.4143
Epoch: 6, Loss: 35.0550
Epoch: 7, Loss: 30.8955
Epoch: 8, Loss: 27.8461
Epoch: 9, Loss: 25.8368
Epoch: 10, Loss: 24.7073
Epoch: 11, Loss: 25.0961
Epoch: 12, Loss: 23.9774
Epoch: 13, Loss: 23.6926
Epoch: 14, Loss: 23.6619
Epoch: 15, Loss: 23.7744
Epoch: 16, Loss: 23.4359
Epoch: 17, Loss: 23.2717
Epoch: 18, Loss: 23.3277
Epoch: 19, Loss: 23.0831
Epoch: 20, Loss: 22.7938
Epoch: 21, Loss: 22.7719
Epoch: 22, Loss: 22.8663
Epoch: 23, Loss: 22.5914
Epoch: 24, Loss: 22.8374
Epoch: 25, Loss: 22.8059
Epoch: 26, Loss: 22.7514
Epoch: 27, Loss: 22.3906
Epoch: 28, Loss: 22.4318
Epoch: 29, Loss: 22.4061
Epoch: 30, Loss: 21.8470
Epoch: 31, Loss: 21.7666
Epoch: 32, Loss: 21.9967
Epoch: 33, Loss: 21.4856
Epoch: 34, Loss: 21.8796
Epoch: 35, Loss: 21.6825
Epoch: 36, Loss: 21.6938
Epoch: 37, Loss: 21.1708
Epoch: 38, Loss: 21.5960
Epoch: 39, Loss: 21.4467
Epoch: 40,

In [101]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
X, y = iris.data, iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42, test_size=0.33)
X_train, X_test, y_train, y_test = (
    torch.tensor(X_train, device=device, dtype=torch.float32),
    torch.tensor(X_test, device=device, dtype=torch.float32),
    torch.tensor(y_train, device=device, dtype=torch.long),
    torch.tensor(y_test, device=device, dtype=torch.long),
)

In [102]:
from sklearn.metrics import accuracy_score
from model import Iris2LayerClassifier


NUM_OF_MODELS = 500


# Generate weight matrices from the VAE model
generated_tensors = []
with torch.no_grad():
    z = torch.randn(NUM_OF_MODELS, hidden_dim).to(device)
    vae.eval()
    weight1_hats, weight2_hats = vae.decode(z)

temp_model = Iris2LayerClassifier()
weight1_shape = temp_model.classifier[0].weight.shape
bias1_shape = temp_model.classifier[0].bias.shape
weight2_shape = temp_model.classifier[2].weight.shape
bias2_shape = temp_model.classifier[2].bias.shape
w1_len = torch.prod(torch.tensor(weight1_shape)).item()
w2_len = torch.prod(torch.tensor(weight2_shape)).item()
generated_models = []
for i in range(NUM_OF_MODELS):
    weight1_hat = weight1_hats[i]
    weight2_hat = weight2_hats[i]
    w1 = weight1_hat[:w1_len].view(weight1_shape)
    b1 = weight1_hat[w1_len:].view(bias1_shape)
    w2 = weight2_hat[:w2_len].view(weight2_shape)
    b2 = weight2_hat[w2_len:].view(bias2_shape)
    new_state_dict = {
        "classifier.0.weight": w1,
        "classifier.0.bias": b1,
        "classifier.2.weight": w2,
        "classifier.2.bias": b2,
    }
    generated_models.append(new_state_dict)


for state_dict in generated_models:
    model = Iris2LayerClassifier().to(device)
    model.load_state_dict(state_dict)
    model.eval()
    with torch.inference_mode():
        y_pred = model(X_test)
        _, labels = torch.max(y_pred, 1)
        accuracy = accuracy_score(y_test.cpu().numpy(), labels.cpu().numpy())
        print(accuracy)

0.66
0.96
0.68
0.54
0.98
0.7
0.42
0.92
0.52
0.94
0.7
0.7
0.7
0.98
0.7
0.4
0.9
0.74
0.86
0.68
0.7
0.8
0.7
0.7
0.7
0.84
0.86
0.72
0.38
0.38
0.7
0.9
0.8
0.56
0.98
0.68
0.7
0.92
0.7
0.38
0.7
0.78
0.7
0.7
0.7
0.94
0.9
0.7
0.68
0.7
0.7
0.78
0.7
0.7
0.7
0.7
0.74
0.7
0.38
0.84
0.78
0.44
0.98
0.46
0.7
0.7
0.68
0.76
0.4
0.38
0.32
0.7
0.7
0.38
0.98
0.7
0.32
0.7
0.7
0.7
0.72
0.98
0.88
0.7
0.88
0.7
0.64
0.68
0.7
0.34
0.7
0.68
0.96
0.64
0.7
0.84
0.8
0.38
0.38
0.7
0.64
0.7
0.7
0.96
0.92
0.7
0.7
0.64
0.4
0.52
0.88
0.68
0.7
0.94
0.7
0.7
0.72
0.98
0.7
0.72
0.38
0.7
0.82
0.98
0.76
0.82
0.7
0.98
0.84
0.52
0.78
0.86
0.4
0.7
0.74
0.7
0.7
0.66
0.8
0.3
0.94
0.7
0.82
0.7
0.7
0.7
0.7
0.7
0.6
0.88
0.84
0.7
0.86
0.7
0.4
0.68
0.7
0.98
0.3
0.7
0.66
0.9
0.38
0.98
0.92
0.38
0.7
0.7
0.52
0.32
0.7
0.54
0.96
0.7
0.32
0.8
0.7
0.68
0.7
0.84
0.68
0.7
0.7
0.7
0.7
0.68
0.7
0.7
0.98
0.82
0.82
0.7
0.7
0.7
0.7
0.7
0.7
0.7
0.84
0.86
0.38
0.98
0.82
0.7
0.7
0.58
0.7
0.32
0.92
0.9
0.7
0.78
0.72
0.7
0.7
0.8
0.68
0.68
0.32
0.58
0.7
0

In [103]:
with open("2_layer_generated_models.pickle", "wb") as f:
    pickle.dump(generated_models, f)