# Creation of figures for the IEEE J-STARS paper

## Setup

### Import the required libraries

In [None]:
import matplotlib.pyplot as plt
import numpy as np

### Default settings

In [None]:
# Base font size for text and tick labels, and a larger size for axis labels.
default_size, label_size = 14, 16

# Number of decimal places kept when rounding the reported training times.
decimals = 0

### LaTeX style configuration

In [None]:
# Typeset all figure text with LaTeX in a serif font at the default size.
# Note: 'text.usetex' requires a working LaTeX installation at render time.
plt.rcParams.update({
    'text.usetex': True,
    'font.family': 'serif',
    'font.size': default_size,
})

## Figures

### Training time

In [None]:
# Number of epochs used to extrapolate the total training time of each model.
total_pretraining_epochs = 500

# Model names: used as dictionary keys and as the bar labels in the figure.
model_1 = 'SSL Barlow Twins'
model_2 = 'Fully-supervised'

# Measured per-epoch durations for each model. They are converted to hours
# below by dividing by 3600, so they are presumably in seconds.
epoch_duration = {
    model_1: [
        463.38, 480.70, 461.12,
        459.01, 457.66, 459.29,
    ],
    # NOTE(review): this row was tagged "Update." - confirm whether these
    # measurements are final.
    model_2: [
        10650.33, 9796.01, 11453.84,
        10345.85, 9635.58, 9936.32,
    ],
}

In [None]:
# Extrapolate each model's total training time in hours:
# mean epoch duration * number of epochs / 3600.
training_time = {}
for model, durations in epoch_duration.items():
    hours = np.round(np.mean(durations) * total_pretraining_epochs / 3600, decimals)
    # When decimals <= 0 the rounded value is integral, so store a plain int.
    # Otherwise keep the requested fractional precision; casting to int
    # would silently discard it.
    training_time[model] = int(hours) if decimals <= 0 else float(hours)
print(training_time)

In [None]:
width = 0.5

# Fix the plotting order: labels, heights and colours below all follow the
# insertion order of `training_time`.
labels = list(training_time.keys())
hours = list(training_time.values())

# One colour per bar, in the same order as `labels`. This must be an ordered
# sequence: a set of strings has no stable iteration order across Python
# sessions (hash randomisation), so the colours could swap between models.
bar_colors = ['blue', 'orange']

# Create a bar plot comparing both training times.
plt.figure(figsize=(6, 2.25))
plt.bar(labels, hours, width=width, color=bar_colors)

# Add labels to the plot.
plt.ylabel('Training time (h)', labelpad=15, fontsize=label_size)
plt.xlabel('Model', labelpad=15, fontsize=label_size)

# Add the values on top of the bars. With string x-values the bars are
# placed at x = 0, 1, ..., so the enumeration index is the bar's x position.
for i, v in enumerate(hours):
    plt.text(i, v, str(v), ha='center', va='bottom', fontsize=default_size)

# Configure the y-axis: 25% headroom above the tallest bar (room for the
# value labels), with a tick every `y_tick_step` hours.
max_y = max(hours)
margin = 0.25
y_tick_step = 500
plt.ylim(0, max_y + max_y*margin)
plt.yticks(np.arange(0, max_y + max_y*margin, y_tick_step))

# Adjust the plot margins.
plt.gcf().subplots_adjust(bottom=0.15, left=0.15)

# Save the plot as PNG and PDF, then show it.
plt.savefig('training_time.png', dpi=600, bbox_inches='tight')
plt.savefig('training_time.pdf', dpi=600, bbox_inches='tight')
plt.show()


In [None]:
# NOTE: Disabled alternative version of the training-time figure. It draws
# narrower bars (width 0.1) at explicit x positions (0.45, 0.6), with model_1
# in blue and model_2 in orange, sets the tick labels from model_1/model_2,
# and saves to training_time_v2.png / training_time_v2.pdf.
# Uncomment to use.
# width = 0.1

# # Create a bar plot comparing both training times.
# fig, ax = plt.subplots(figsize=(6, 2.25))

# positions = (0.45, 0.6)

# ax.bar(positions[0], training_time[model_1], width, color='blue')
# ax.bar(positions[1], training_time[model_2], width, color='orange')

# # Add the values on top of the bars.
# for i, v in enumerate(training_time.values()):
#     plt.text(positions[i], v, str(v), ha='center', va='bottom', fontsize=default_size)

# # Add labels to the plot.
# plt.ylabel('Training time (h)', labelpad=15, fontsize=label_size)
# plt.xlabel('Model', labelpad=15, fontsize=label_size)

# # Now set the ticks and the corresponding labels
# plt.xticks(positions, (model_1, model_2))

# # Configure the y-axis.
# max_y = max(training_time.values())
# margin = 0.25
# plt.ylim(0, max_y + max_y*margin)
# plt.yticks(np.arange(0, max_y + max_y*margin, 500))

# # Adjust the plot margins and show.
# plt.gcf().subplots_adjust(bottom=0.15)
# plt.gcf().subplots_adjust(left=0.15)

# # Save the plot.
# plt.savefig(f'training_time_v2.png', dpi=600, bbox_inches='tight')
# plt.savefig(f'training_time_v2.pdf', dpi=600, bbox_inches='tight')
# plt.show()