In [None]:
import random
import sys
import numpy as np
from tensorflow.keras.datasets import mnist

In [None]:
# Image height/width; a single trailing channel axis is added to every sample below
img_rows, img_cols = 28, 28
input_shape = (img_rows, img_cols, 1)

# Load the data, which arrives already split into train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Reshape to (n_samples, rows, cols, 1)
x_train = x_train.reshape(x_train.shape[0], *input_shape)
x_test = x_test.reshape(x_test.shape[0], *input_shape)

# Cast to float32 and divide by 255
# (presumably rescales 8-bit pixel intensities into [0, 1] -- TODO confirm)
full_x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# Labels are used as loaded; keep them under the "full" name as well
full_y_train = y_train

# The last 10000 training samples (and their labels) are used as the validation split
x_valid = full_x_train[-10000:]
y_valid = full_y_train[-10000:]

# Report the array shape and the number of samples in each split
print('x_train shape:', full_x_train.shape)
for split_name, split_data in (('train', full_x_train), ('valid', x_valid), ('test', x_test)):
    print(split_data.shape[0], split_name + ' samples')

In [None]:
# NOTE: `sys` is already imported in the first cell; the repeated import is harmless.
import sys
# Extend the module search path so that the `histnet` imports below can resolve
# (presumably this directory contains the `histnet` package -- verify).
# NOTE(review): this is a machine-specific absolute path, and every re-run of this
# cell appends the same entry to sys.path again.
sys.path.append('/media/nas/pgonzalez/DLquantification')
from histnet.histnet import HistNet
import torch
from histnet.featureextraction.fullyconnected import FCFeatureExtractionModule
from histnet.utils.utils import QLibPriorShiftBagGenerator
from torch.utils.data import TensorDataset
from histnet.utils.lossfunc import MRAE

# File names handed to HistNet as `save_model_path`, one per (seed, train_set_size) run.
model_files = []

# Prefer the first GPU, but fall back to the CPU so the cell still runs on
# machines without CUDA instead of failing once tensors are moved to the device.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# MRAE loss callable. It is not passed to HistNet in this cell (MSELoss is the
# quantification loss below).
# NOTE(review): eps is 1/(2*500) while bag_size below is 1000 -- confirm this is intended.
loss_mrae = MRAE(eps=1.0 / (2 * 500)).MRAE

# Size of the validation split, taken from the data rather than hard-coded so the
# index ranges given to `val_split` always match what is concatenated below.
n_valid = len(y_valid)

# One training run per seed: 0, 10, ..., 90.
for seed in range(0,100,10):
    # Seed every random number generator in use so each run is repeatable.
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    for model_idx,train_set_size in enumerate([30000]):
        # The validation samples are the tail of full_x_train; refuse a training
        # prefix that would reach into that tail (train/validation leakage) or
        # that is longer than the available data (the slice would silently truncate).
        if train_set_size + n_valid > len(full_x_train):
            raise ValueError(
                "train_set_size=%d plus %d validation samples exceeds the %d available samples"
                % (train_set_size, n_valid, len(full_x_train))
            )

        # Training subset: the first `train_set_size` samples and labels.
        x_train = full_x_train[:train_set_size]
        y_train = full_y_train[:train_set_size]
        print(y_train.shape)
        print(y_valid.shape)

        # Single dataset laid out as [training samples | validation samples];
        # `val_split` below tells HistNet which index range is which.
        train_dset = TensorDataset(
            torch.cat((torch.from_numpy(x_train),torch.from_numpy(x_valid))),
            torch.cat((torch.from_numpy(y_train),torch.from_numpy(y_valid)))
        )

        print("Using %d for training and %d for validation"%(len(y_train),len(y_valid)))

        # Separate bag generators for training and validation, configured identically.
        train_bag_generator = QLibPriorShiftBagGenerator(device, method="Dirichlet", alphas=1)
        val_bag_generator = QLibPriorShiftBagGenerator(device, method="Dirichlet", alphas=1)

        model_file = "model_quant_mnist_mse_set-"+str(train_set_size)+"_mse_seed-"+str(seed)+".h5"
        model_files.append(model_file)

        # Fully connected feature extractor: 784 inputs -> one hidden layer of 256 -> 128 outputs.
        fe = FCFeatureExtractionModule(
            input_size=784,
            output_size=128,
            hidden_sizes=[256],
            dropout=0,
            activation="relu",
            flatten=True,
        )

        # Indices [0, train_set_size) are training samples; the following n_valid are validation.
        train_indices = range(0, train_set_size)
        valid_indices = range(train_set_size, train_set_size + n_valid)

        # NOTE(review): dataset_name="minst" looks like a typo for "mnist"; it is left
        # unchanged because the string may be used by HistNet for output naming.
        model = HistNet(
            train_epochs=1000,
            test_epochs=1,
            batch_size=100,
            n_classes=10,
            start_lr=0.001,
            end_lr=0.00001,
            n_bags=5000,
            histogram="softrbf",
            bag_size=1000,
            n_bins=32,
            random_seed=seed,
            linear_sizes=[512],
            feature_extraction_module=fe,
            bag_generator=train_bag_generator,
            patience=10,
            quant_loss=torch.nn.MSELoss(),
            dropout=0.5,
            val_bag_generator=val_bag_generator,
            val_split=(train_indices, valid_indices),
            verbose=10,
            dataset_name="minst",
            device=device,
            save_model_path=model_file,
        )
        model.fit(train_dset)