In [None]:
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import LearningRateScheduler, ReduceLROnPlateau
from tensorflow.keras.models import  Model
import numpy as np

from data.id_dataloader import load_cifar10, load_intel_image, load_mnist, load_cifar100
from data.classes import cifar10_classes, mnist_classes, intel_image_classes, cifar100_classes

from models.models import resnet50, wideresnet2810, vgg16, inceptionv3, efficientnetb2
from models.pretrained_models import pretrained_resnet50, pretrained_vgg16

from rsnn_functions.budgeting import train_embeddings, fit_gmm, ellipse, overlaps
from rsnn_functions.bf_encoding_gt import groundtruthmod
from rsnn_functions.belief_mass_betp import belief_to_mass, mass_coeff, final_betp
from rsnn_functions.rsnn_loss import BinaryCrossEntropy

from utils.train_utils import lr_schedule, train_val_split, data_generator, lr_callbacks, save_model_and_weights

In [None]:
# Discover the physical GPUs TensorFlow can see on this machine.
gpus = tf.config.experimental.list_physical_devices('GPU')

# Upper bound on how many of the discovered GPUs this notebook will use.
max_gpus = 3

if gpus:
    num_gpus = len(gpus)
    print(f"Number of GPUs available: {num_gpus}")
    try:
        # Restrict TensorFlow to (at most) the first `max_gpus` GPUs.
        tf.config.experimental.set_visible_devices(gpus[:max_gpus], 'GPU')
    except RuntimeError as err:
        # The visible-device list can only be changed before TensorFlow has
        # initialized its devices. Re-running this cell in a live kernel would
        # otherwise abort here; report the problem and keep the current setup.
        print(err)

# Create a MirroredStrategy for multi-GPU use
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

In [None]:
# Training hyperparameters
batch_size = 128
epochs = 100

# Number of non-singleton focal sets
k = 20

In [None]:
# Number of target classes for each known dataset name
num_classes = {
    "cifar10": 10,
    "mnist": 10,
    "intel_image": 6,
    "cifar100": 100,
    "svhn": 10,
    "fmnist": 10,
    "kmnist": 10,
}

# Dataset name -> loader function (called with no arguments further down,
# where its result is unpacked into train / test arrays)
dataset_loader = {
    "cifar10": load_cifar10,
    "mnist": load_mnist,
    "intel_image": load_intel_image,
    "cifar100": load_cifar100,
}

# Dataset name -> function returning that dataset's list of class names
class_list_functions = {
    "cifar10": cifar10_classes,
    "mnist": mnist_classes,
    "intel_image": intel_image_classes,
    "cifar100": cifar100_classes,
}

# Architecture name -> model builder; builders are invoked later with
# input_shape, num_classes and final_activation keyword arguments
models = {
    "resnet50": resnet50,
    "wideresnet_28_10": wideresnet2810,
    "vgg16": vgg16,
    "inception_v3": inceptionv3,
    "efficientnet_b2": efficientnetb2,
}

# Same idea as `models`, for the builders from models.pretrained_models
pretrained_models = {
    "pretrained_resnet50": pretrained_resnet50,
    "pretrained_vgg16": pretrained_vgg16,
}

In [None]:
# Experiment configuration

# Dataset to train on (a key of the dataset registries)
selected_dataset = "cifar10"

# Architecture to train (a key of `models` or `pretrained_models`)
selected_model = "resnet50"

# Training hyperparameters
batch_size = 128
epochs = 100

# Class names of the selected dataset
classes = class_list_functions[selected_dataset]()
print("Classes:", classes)

# One cluster per class, plus lookup tables in both directions:
# class name -> index and index -> class name
num_clusters = len(classes)
classes_dict = {name: idx for idx, name in enumerate(classes)}
classes_dict_inverse = dict(enumerate(classes))

# Load the arrays for the selected dataset
x_train, y_train, x_test_org, x_test, y_test = dataset_loader[selected_dataset]()

# Per-sample input shape: every axis of x_train after the first
input_shape = x_train.shape[1:]

# Train-validation split (also yields one-hot encoded labels).
# NOTE(review): val_samples is negative; presumably train_val_split uses it
# as a slice offset (i.e. the last 10000 samples) -- confirm in train_utils.
(x_train, y_train, y_train_one_hot,
 x_val, y_val, y_val_one_hot) = train_val_split(
    x_train,
    y_train,
    num_classes[selected_dataset],
    val_samples=-10000,
)

print("Shape of x_train:", x_train.shape)
print("Shape of x_test:", x_test.shape)
print("Shape of x_val:", x_val.shape)

# Learning-rate callbacks built from the project's schedule function
callbacks = lr_callbacks(lr_schedule)

# Data augmentation generator, constructed from the training images
datagen = data_generator(x_train)

## CNN

In [None]:
# Multi-GPU run: build and compile the baseline CNN inside the strategy scope
with strategy.scope():
    # Pretrained builders take precedence; any other name is looked up in `models`
    cnn_registry = pretrained_models if selected_model in pretrained_models else models
    model = cnn_registry[selected_model](
        input_shape=input_shape,
        num_classes=num_classes[selected_dataset],
        final_activation='softmax',
    )

    # Compile for single-label classification against one-hot targets
    model.compile(
        loss='categorical_crossentropy',
        optimizer="adam",
        metrics=['accuracy'],
    )

model.summary()

In [None]:
# Train the CNN on batches drawn from the augmentation generator and
# validate on the held-out split.
# NOTE(review): unlike the RS-NN training call, no `callbacks` are passed
# here -- confirm the learning-rate schedule is intentionally skipped.
cnn_train_flow = datagen.flow(x_train, y_train_one_hot, batch_size=batch_size)
history = model.fit(
    cnn_train_flow,
    validation_data=(x_val, y_val_one_hot),
    epochs=epochs,
    verbose=1,
    workers=2,
)

In [None]:
# # Save model and weights
# save_model_and_weights(model, selected_model, selected_dataset, model_type='CNN')

## BUDGETING

In [None]:
# Extracting features from the penultimate layer:
# aux_model maps the CNN's input to the output of its second-to-last layer.
aux_model = Model(model.input, model.layers[-2].output)

# 3D feature space representation of class embeddings
# (the variable name suggests t-SNE is used -- TODO confirm in train_embeddings)
train_embedded_tsne = train_embeddings(aux_model, x_train, batch_size)

# Fitting Gaussian Mixture Models (GMM) to individual classes
individual_gms = fit_gmm(classes, train_embedded_tsne, y_train)

# Calculating clusters for each class
regions, means, max_len = ellipse(individual_gms, num_classes[selected_dataset])

# Compute the overlap and choose the sets of classes with highest overlap.
# `k` is the number of non-singleton focal sets configured at the top of the notebook.
new_classes = overlaps(k, classes, num_clusters, classes_dict, regions, means, max_len)

# np.save('new_classes.npy', new_classes)
print(new_classes)

In [None]:
## Load saved new_classes
# new_classes = np.load('new_classes.npy', allow_pickle=True)

In [None]:
# Belief-encoding of the ground truth: apply the same encoding to the
# train, validation and test labels (evaluated in that order)
y_train_modified, y_val_modified, y_test_modified = (
    groundtruthmod(labels, classes, new_classes, classes_dict_inverse)
    for labels in (y_train, y_val, y_test)
)

## RS-NN

In [None]:
# Multi-GPU run: build and compile the RS-NN inside the strategy scope
with strategy.scope():
    # Pretrained builders take precedence; any other name is looked up in `models`
    rsnn_registry = pretrained_models if selected_model in pretrained_models else models
    # One sigmoid output per entry of `new_classes`
    new_model = rsnn_registry[selected_model](
        input_shape=input_shape,
        num_classes=len(new_classes),
        final_activation='sigmoid',
    )

    # Compile with the project's BinaryCrossEntropy loss from rsnn_functions.rsnn_loss
    new_model.compile(
        loss=BinaryCrossEntropy,
        optimizer="adam",
        metrics=['binary_accuracy'],
    )

new_model.summary()

In [None]:
# Train the RS-NN on the belief-encoded targets, using the learning-rate
# callbacks, and validate on the held-out split
rsnn_train_flow = datagen.flow(x_train, y_train_modified, batch_size=batch_size)
history_new = new_model.fit(
    rsnn_train_flow,
    validation_data=(x_val, y_val_modified),
    epochs=epochs,
    verbose=1,
    workers=2,
    callbacks=callbacks,
)

In [None]:
# NOTE(review): this repeats the fit call of the previous cell. Running both
# trains `new_model` for a further `epochs` epochs and rebinds `history_new`
# to the second run's history only -- confirm this is intentional.
rsnn_train_flow = datagen.flow(x_train, y_train_modified, batch_size=batch_size)
history_new = new_model.fit(
    rsnn_train_flow,
    validation_data=(x_val, y_val_modified),
    epochs=epochs,
    verbose=1,
    workers=2,
    callbacks=callbacks,
)

In [None]:
# # Save model and weights
# save_model_and_weights(new_model, selected_model, selected_dataset, model_type='RSNN')