In [16]:
# ---- Dependencies ----
# Each package is handled in the listed order: installed from CRAN when it is
# not available, then attached. Keeping the order keeps the attach (masking)
# order of the search path stable.
deps <- c(
    "ISLR", "ndjson", "ggplot2", "dplyr", "caret", "tidyr", "gridExtra",
    "data.table", "repr", "reshape2", "grid", "ggpubr", "patchwork", "cowplot"
)
invisible(lapply(deps, function(pkg) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
        install.packages(pkg, repos = "https://cran.rstudio.com")
    }
    library(pkg, character.only = TRUE)
}))
# Prefer fixed notation over scientific notation when printing numbers.
options(scipen = 999)

# ---- Load and validate data ----
df <- ndjson::stream_in("./baseline_cifar10_subset.jsonl")

# The input file should hold only this experiment configuration: keep the
# matching rows, drop the columns this figure doesn't use, and then assert that
# the filter removed no rows.
init_len <- nrow(df)
df <- df %>%
    filter(
        dataset == "cifar10",
        mask_sides == 3,
        mask_num_concentric == 2,
        mask_colors == TRUE
    ) %>%
    select(-c(
        # constant after the filter above
        "dataset", "mask_sides", "mask_num_concentric", "mask_colors",
        # we only look at the same plots as the self-ensembled resnet
        "acc_0", "acc_2", "acc_4", "acc_8",
        # plain accuracy is not shown in this plot
        "acc"
    ))
stopifnot(nrow(df) == init_len)
# print(names(df))

# ---- Plot configuration ----
# Experiment settings described by this figure.
dataset <- "cifar10"
mask_sides <- 3
mask_num_concentric <- 2
mask_colors <- TRUE

# Inline (repr) canvas size for the full multi-row grid.
# Earlier, shorter canvas kept for reference:
# options(repr.plot.width=11, repr.plot.height=3)
options(repr.plot.width = 11, repr.plot.height = 13)

# ---- Heatmap grid ----
# The figure has one section per mask density (shapes per row/col). Each
# section adds two entries to `plot_grid_rows`: a bold title grob, followed by
# a patchwork with one accuracy heatmap per mask opacity.
mask_row_cols <- c(2, 4, 10)
mask_opacities <- c(16, 32, 64, 128) # same for train and eval
ratio_range <- c(0.0, 0.5, 1.0)

# One fill range shared by every plotted accuracy column. Identical scales keep
# colours comparable across panels, and they let `guides = "collect"` merge the
# per-panel legends into a single "Accuracy" legend.
acc_cols <- paste0("acc_", mask_opacities)
fill_limits <- range(
    unlist(lapply(acc_cols, function(col) df[[col]])),
    na.rm = TRUE
)

# Heatmap of acc_<opacity> over train AdvX ratio (x) and train epochs (y).
# It is built inside a function because aes() quosures are evaluated lazily at
# render time. In a plain for-loop, every panel would otherwise see the loop
# variable's last value.
plot_opacity_heatmap <- function(opacity) {
    ggplot(df, aes(
        x = train_hcaptcha_ratio,
        y = factor(train_epochs),
        fill = .data[[paste0("acc_", opacity)]]
    )) +
        geom_tile() +
        scale_fill_gradient(limits = fill_limits) +
        scale_x_continuous(breaks = ratio_range) +
        scale_y_discrete() +
        theme_minimal() +
        labs(
            title = paste0("Opacity: ", opacity),
            x = "Train AdvX Ratio",
            y = "Train Epochs",
            fill = "Accuracy"
        )
}

# Returns list(title grob, row of opacity heatmaps) for one mask density.
build_mask_row <- function(row_col) {
    # NOTE(review): `df` is not filtered by `row_col`, so every section plots
    # the same data and only the title differs. If `df` holds several
    # densities, their tiles also overlap. If the data has a shapes-per-row/col
    # column, filter on it here (TODO: confirm the column name).
    title <- textGrob(
        paste0(
            "Attack: Geometric Mask\n(", mask_sides, " sides, ", row_col,
            " per row/col, ", mask_num_concentric, " concentric shapes, colors ",
            if (isTRUE(mask_colors)) "enabled" else "disabled", ")"
        ),
        gp = gpar(fontsize = 15, fontface = "bold"),
        vjust = 0.25,
        y = 0.2
    )
    heatmaps <- wrap_plots(
        lapply(mask_opacities, plot_opacity_heatmap),
        ncol = 4,
        guides = "collect"
    )
    list(title, heatmaps)
}

# Flatten one level so the result alternates title, heatmaps, title, heatmaps, ...
plot_grid_rows <- unlist(lapply(mask_row_cols, build_mask_row), recursive = FALSE)

# ---- Compose and save ----
# NOTE(review): `twidth`/`pwidth` are not used below. They look like intended
# relative heights for the title and heatmap rows, e.g.
# `heights = rep(c(twidth, pwidth), length.out = length(plot_grid_rows))`.
# Check the title grob placement (its `y`/`vjust`) before wiring them in.
twidth <- 0.2
pwidth <- 5

# Stack the alternating title / heatmap rows in one column and add the overall
# dataset title. The list is passed positionally; a trailing comma here would
# inject an empty argument into wrap_plots()'s `...`.
p <- wrap_plots(
    plot_grid_rows,
    ncol = 1
) + plot_annotation(
    title = paste("Dataset:", dataset),
    theme = theme(
        plot.title = element_text(size = 20, hjust = 0, face = "bold", margin = margin(b = 2, t = 2)),
        plot.margin = unit(c(0.1, 0.1, 0.1, 0.1), "cm")
    )
)

# Same 11 x 13 canvas as the inline notebook plot.
ggsave("baseline_cifar10.pdf", p, width = 11, height = 13, units = "in")
