In [79]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import optuna
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from model import Iris2LayerClassifier

# Compute device shared by the cells below: CUDA when PyTorch reports it as available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [80]:
class VAE(nn.Module):
    """Variational autoencoder over a pair of flattened vectors ``(w1, w2)``.

    Encoding is staged: ``w2`` goes through ``encoder_nn1``, the result is
    concatenated with ``w1`` and goes through ``encoder_nn2``, and two linear
    heads produce the mean and log-variance of the latent code.

    Decoding is staged the other way round: ``decoder_nn3`` maps the latent
    code to ``w1_hat``, and ``decoder_nn4`` maps ``[z, w1_hat]`` to ``w2_hat``.

    Inputs are concatenated along ``dim=1``, so they are expected to be
    batched as ``(batch, features)``.

    Args:
        input_dim_w1: Feature width of ``w1``.
        input_dim_w2: Feature width of ``w2``.
        hidden_dims_encoder: Hidden widths used by *both* encoder stages.
            May be empty, in which case the stage is an identity mapping.
        hidden_dims_decoder: Hidden widths used by *both* decoder stages.
            May be empty, in which case the stage is a single linear layer.
        z_dim: Width of the latent code.
        use_batchnorm: Insert ``BatchNorm1d`` after every hidden linear layer.
    """

    def __init__(
        self,
        input_dim_w1,
        input_dim_w2,
        hidden_dims_encoder,
        hidden_dims_decoder,
        z_dim=32,
        use_batchnorm=True,
    ):
        super().__init__()

        # Modules are created in a fixed order (encoder stages, heads, decoder
        # stages) so that parameter initialisation under a given seed and the
        # resulting state_dict keys do not depend on how the layers are built.
        layers, h1_dim = self._mlp_layers(
            input_dim_w2, hidden_dims_encoder, use_batchnorm
        )
        self.encoder_nn1 = nn.Sequential(*layers)

        layers, h2_dim = self._mlp_layers(
            input_dim_w1 + h1_dim, hidden_dims_encoder, use_batchnorm
        )
        self.encoder_nn2 = nn.Sequential(*layers)

        self.fc_mu = nn.Linear(h2_dim, z_dim)
        self.fc_logvar = nn.Linear(h2_dim, z_dim)

        layers, dec_dim = self._mlp_layers(z_dim, hidden_dims_decoder, use_batchnorm)
        self.decoder_nn3 = nn.Sequential(*layers)
        self.decoder_nn3.add_module("output", nn.Linear(dec_dim, input_dim_w1))

        layers, dec_dim = self._mlp_layers(
            z_dim + input_dim_w1, hidden_dims_decoder, use_batchnorm
        )
        self.decoder_nn4 = nn.Sequential(*layers)
        self.decoder_nn4.add_module("output", nn.Linear(dec_dim, input_dim_w2))

    @staticmethod
    def _mlp_layers(in_dim, hidden_dims, use_batchnorm):
        """Build ``Linear -> [BatchNorm1d] -> ReLU`` layers for each hidden width.

        Returns:
            ``(layers, out_dim)`` where ``layers`` is a list of modules and
            ``out_dim`` is the width of the last layer's output (``in_dim``
            itself when ``hidden_dims`` is empty).
        """
        layers = []
        for out_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, out_dim))
            if use_batchnorm:
                layers.append(nn.BatchNorm1d(out_dim))
            layers.append(nn.ReLU())
            in_dim = out_dim
        return layers, in_dim

    def encode(self, w1, w2):
        """Return ``(mu, logvar)`` of the latent distribution for ``(w1, w2)``."""
        h1 = self.encoder_nn1(w2)
        h2 = self.encoder_nn2(torch.cat((w1, h1), dim=1))
        return self.fc_mu(h2), self.fc_logvar(h2)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode latent codes into ``(weight1_hat, weight2_hat)``.

        ``weight2_hat`` is conditioned on both ``z`` and the decoded
        ``weight1_hat``.
        """
        weight1_hat = self.decoder_nn3(z)
        weight2_hat = self.decoder_nn4(torch.cat((z, weight1_hat), dim=1))
        return weight1_hat, weight2_hat

    def forward(self, w1, w2):
        """Encode, sample a latent code, decode.

        Returns:
            ``(weight1_hat, weight2_hat, mu, logvar)``.
        """
        mu, logvar = self.encode(w1, w2)
        z = self.reparameterize(mu, logvar)
        weight1_hat, weight2_hat = self.decode(z)
        return weight1_hat, weight2_hat, mu, logvar


def loss_function(weight1_hat, weight2_hat, weight1, weight2, mu, logvar):
    """VAE training loss: summed squared reconstruction error plus KL divergence.

    The reconstruction term adds the sum-reduced MSE of both reconstructed
    vectors against their targets. The KL term is the closed-form divergence
    between ``N(mu, exp(logvar))`` and a standard normal, summed over every
    element.

    Returns:
        A scalar tensor (summed, not averaged, over the batch).
    """
    pairs = ((weight1_hat, weight1), (weight2_hat, weight2))
    reconstruction = sum(
        F.mse_loss(predicted, target, reduction="sum") for predicted, target in pairs
    )

    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5 * kl_terms.sum()

    return reconstruction + divergence

In [81]:
import pickle
import io
from torch.utils.data import DataLoader, Dataset


class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that loads pickled torch storages onto the CPU.

    Lets a pickle that was written on a GPU machine be read on a machine
    without CUDA. Like any unpickler it can execute arbitrary code from the
    stream, so it should only be used on trusted files.
    """

    def find_class(self, module, name):
        """Resolve ``module.name``, swapping torch's storage loader for a CPU one.

        Every lookup other than ``torch.storage._load_from_bytes`` is
        delegated unchanged to ``pickle.Unpickler.find_class``.
        """
        wants_storage_loader = (module, name) == ("torch.storage", "_load_from_bytes")
        if not wants_storage_loader:
            return super().find_class(module, name)

        def load_on_cpu(b):
            # Same deserialisation torch would do, but remapped to the CPU.
            return torch.load(io.BytesIO(b), map_location="cpu")

        return load_on_cpu


# Load the pickled models from disk; torch storages inside the pickle are remapped to the CPU.
# NOTE(review): unpickling can execute arbitrary code — only load files from a trusted source.
with open("2_layer_real_models.pickle", "rb") as f:
    real_models = CPU_Unpickler(f).load()


class WeightsDataset(Dataset):
    """Dataset of classifier parameters, flattened per layer and standardized.

    Each item is a ``(w1, w2)`` pair of 1-D tensors:

    * ``w1`` is ``classifier.0.weight`` followed by ``classifier.0.bias``;
    * ``w2`` is ``classifier.2.weight`` followed by ``classifier.2.bias``.

    All values of all items are standardized with a single global mean and
    standard deviation, which are kept on ``self.mean`` / ``self.std`` so the
    transform can be undone with :meth:`denormalize`.

    Args:
        state_dicts: Iterable of state dicts containing the four keys above.
    """

    def __init__(self, state_dicts):
        flattened = [self.flatten_state_dict(sd) for sd in state_dicts]
        self.data = self.global_normalize(flattened)

    def flatten_state_dict(self, state_dict):
        """Return ``(w1, w2)``: each layer's weight and bias flattened and joined."""
        weight1 = state_dict["classifier.0.weight"].reshape(-1)
        bias1 = state_dict["classifier.0.bias"].reshape(-1)
        weight2 = state_dict["classifier.2.weight"].reshape(-1)
        bias2 = state_dict["classifier.2.bias"].reshape(-1)
        flattened_w1 = torch.cat([weight1, bias1])
        flattened_w2 = torch.cat([weight2, bias2])
        return flattened_w1, flattened_w2

    def global_normalize(self, data):
        """Standardize every ``(w1, w2)`` pair in ``data`` with one global mean/std.

        The statistics are computed over all values of all pairs in ``data``
        and stored on ``self.mean`` and ``self.std``.

        No min-max pre-scaling is applied: standardization is invariant to a
        positive affine rescale of its input, so it would not change the
        result.

        Args:
            data: List of ``(w1, w2)`` tensor pairs.

        Returns:
            A new list of standardized ``(w1, w2)`` pairs; ``data`` is not modified.
        """
        values = torch.cat([w1 for w1, _ in data] + [w2 for _, w2 in data])
        self.mean, self.std = values.mean(), values.std()
        return [
            (
                self.normalize(w1, self.mean, self.std),
                self.normalize(w2, self.mean, self.std),
            )
            for w1, w2 in data
        ]

    def minmax(self, data, min, max):
        """Linearly rescale ``data`` from ``[min, max]`` to ``[-1, 1]``."""
        return (2 * (data - min)) / (max - min) - 1

    def normalize(self, data, mean, std):
        """Standardize ``data``: subtract ``mean`` and divide by ``std``."""
        return (data - mean) / std

    def denormalize(self, data):
        """Invert :meth:`normalize` using the statistics stored by :meth:`global_normalize`."""
        return data * self.std + self.mean

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# Wrap the loaded models in a dataset of normalized (w1, w2) vectors and serve it in shuffled batches.
train_dataset = WeightsDataset(real_models)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Lengths of the two flattened vectors, read from the first sample; later passed to the VAE as its input dims.
sample = train_dataset[0]
w1_size, w2_size = sample[0].size()[0], sample[1].size()[0]

In [82]:
# Iris features and integer class labels.
iris = load_iris()
X, y = iris.data, iris.target

# Hold out 33% of the samples for testing; the fixed random_state makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=42, test_size=0.33
)
# Convert the splits to tensors on `device`: float32 features, int64 (long) labels.
X_train, X_test, y_train, y_test = (
    torch.tensor(X_train, device=device, dtype=torch.float32),
    torch.tensor(X_test, device=device, dtype=torch.float32),
    torch.tensor(y_train, device=device, dtype=torch.long),
    torch.tensor(y_test, device=device, dtype=torch.long),
)

In [83]:
def _train_vae(trial, vae, optimizer, epochs):
    """Train ``vae`` on ``train_loader`` for ``epochs`` epochs, reporting progress to Optuna.

    Raises:
        optuna.exceptions.TrialPruned: If the study's pruner stops the trial.
    """
    for epoch in range(epochs):
        vae.train()
        train_loss = 0.0
        for batch_w1, batch_w2 in train_loader:
            # The loader yields tensors that were unpickled onto the CPU,
            # while the VAE lives on `device`; move each batch over.
            batch_w1 = batch_w1.to(device)
            batch_w2 = batch_w2.to(device)

            optimizer.zero_grad()
            weight1_hat, weight2_hat, mu, logvar = vae(batch_w1, batch_w2)
            loss = loss_function(
                weight1_hat, weight2_hat, batch_w1, batch_w2, mu, logvar
            )
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        avg_loss = train_loss / len(train_loader.dataset)
        # The study maximizes the objective and pruners rank intermediate
        # values using the study direction, so report the negated loss
        # (higher is better) instead of the raw loss.
        trial.report(-avg_loss, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()


def _score_generated_models(vae, z_dim, num_models):
    """Decode ``num_models`` latent samples into classifiers and score each on the test split.

    Returns:
        1-D tensor with one test accuracy per generated classifier.
    """
    vae.eval()
    with torch.no_grad():
        z = torch.randn(num_models, z_dim, device=device)
        weight1_hats, weight2_hats = vae.decode(z)

    # NOTE(review): the VAE is trained on weights standardized by
    # WeightsDataset.global_normalize, so the decoded vectors are in that
    # standardized space; they are loaded into the classifier below without
    # inverting the normalization — confirm this is intended.

    # One classifier instance is reused: load_state_dict overwrites all of
    # its parameters for every generated sample.
    classifier = Iris2LayerClassifier().to(device)
    classifier.eval()
    first_layer, second_layer = classifier.classifier[0], classifier.classifier[2]
    weight1_shape, bias1_shape = first_layer.weight.shape, first_layer.bias.shape
    weight2_shape, bias2_shape = second_layer.weight.shape, second_layer.bias.shape
    # Each flattened vector is the layer's weight followed by its bias.
    w1_len = first_layer.weight.numel()
    w2_len = second_layer.weight.numel()

    y_true = y_test.cpu().numpy()
    accuracies = []
    for weight1_hat, weight2_hat in zip(weight1_hats, weight2_hats):
        classifier.load_state_dict(
            {
                "classifier.0.weight": weight1_hat[:w1_len].view(weight1_shape),
                "classifier.0.bias": weight1_hat[w1_len:].view(bias1_shape),
                "classifier.2.weight": weight2_hat[:w2_len].view(weight2_shape),
                "classifier.2.bias": weight2_hat[w2_len:].view(bias2_shape),
            }
        )
        with torch.inference_mode():
            y_pred = classifier(X_test)
            _, labels = torch.max(y_pred, 1)
        accuracies.append(accuracy_score(y_true, labels.cpu().numpy()))

    return torch.tensor(accuracies)


def objective(trial):
    """Optuna objective: train a VAE on classifier weights and score what it generates.

    Samples the VAE architecture (encoder/decoder depth and widths, latent
    size, batch-norm on/off) and the training setup (learning rate, weight
    decay, epoch count), trains the VAE on ``train_loader``, then decodes 500
    latent samples into ``Iris2LayerClassifier`` weights and evaluates them on
    the Iris test split.

    Args:
        trial: The Optuna trial supplying hyperparameters and pruning decisions.

    Returns:
        Mean test accuracy of the generated classifiers minus the standard
        deviation of those accuracies (higher is better).

    Raises:
        optuna.exceptions.TrialPruned: If the pruner stops the trial during training.
    """
    num_layers_encoder = trial.suggest_int("num_layers_encoder", 1, 4)
    hidden_dims_encoder = [
        trial.suggest_int(f"hidden_dim_enc_{i}", 16, 128)
        for i in range(num_layers_encoder)
    ]
    num_layers_decoder = trial.suggest_int("num_layers_decoder", 1, 4)
    hidden_dims_decoder = [
        trial.suggest_int(f"hidden_dim_dec_{i}", 16, 128)
        for i in range(num_layers_decoder)
    ]
    z_dim = trial.suggest_int("z_dim", 8, 64)
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-1, log=True)
    epochs = trial.suggest_int("epochs", 50, 500)
    use_batchnorm = trial.suggest_categorical("use_batchnorm", [True, False])

    vae = VAE(
        input_dim_w1=w1_size,
        input_dim_w2=w2_size,
        hidden_dims_encoder=hidden_dims_encoder,
        hidden_dims_decoder=hidden_dims_decoder,
        z_dim=z_dim,
        use_batchnorm=use_batchnorm,
    ).to(device)
    optimizer = optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)

    _train_vae(trial, vae, optimizer, epochs)

    NUM_OF_MODELS = 500
    accuracies = _score_generated_models(vae, z_dim, NUM_OF_MODELS)

    # Reward high average accuracy and penalize spread across generated models.
    return accuracies.mean().item() - accuracies.std().item()

In [84]:
import optuna

# Maximize the objective's score (mean accuracy minus its standard deviation) over 50 trials.
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)

# Summarize the best trial: its score and the hyperparameters that produced it.
print("Best trial:")
trial = study.best_trial
print(f"  Value: {trial.value}")
print(f"  Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")

[I 2024-09-13 14:36:59,124] A new study created in memory with name: no-name-e12117d2-b67b-42dd-9b96-66d4ee853b2c
[I 2024-09-13 14:37:10,048] Trial 0 finished with value: 0.6037428773421186 and parameters: {'num_layers_encoder': 2, 'hidden_dim_enc_0': 56, 'hidden_dim_enc_1': 66, 'num_layers_decoder': 2, 'hidden_dim_dec_0': 91, 'hidden_dim_dec_1': 42, 'z_dim': 16, 'lr': 0.004239166411079838, 'weight_decay': 4.945092711569469e-05, 'epochs': 369, 'use_batchnorm': False}. Best is trial 0 with value: 0.6037428773421186.
[I 2024-09-13 14:37:24,874] Trial 1 finished with value: 0.5871435642026036 and parameters: {'num_layers_encoder': 2, 'hidden_dim_enc_0': 114, 'hidden_dim_enc_1': 82, 'num_layers_decoder': 3, 'hidden_dim_dec_0': 49, 'hidden_dim_dec_1': 25, 'hidden_dim_dec_2': 45, 'z_dim': 41, 'lr': 0.00025590262983184615, 'weight_decay': 0.04920107967176251, 'epochs': 282, 'use_batchnorm': True}. Best is trial 0 with value: 0.6037428773421186.
[I 2024-09-13 14:37:48,851] Trial 2 finished wit

Best trial:
  Value: 0.6037428773421186
  Params: 
    num_layers_encoder: 2
    hidden_dim_enc_0: 56
    hidden_dim_enc_1: 66
    num_layers_decoder: 2
    hidden_dim_dec_0: 91
    hidden_dim_dec_1: 42
    z_dim: 16
    lr: 0.004239166411079838
    weight_decay: 4.945092711569469e-05
    epochs: 369
    use_batchnorm: False
