In [None]:
%load_ext autoreload
%autoreload 2

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import time
import os

# GPU / XLA configuration. All of it is set *before* `import jax` so that none
# of it depends on when JAX gets around to initialising its backend.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"          # number GPUs by PCI bus id
os.environ['CUDA_VISIBLE_DEVICES'] = '3'                # expose a single GPU to this process
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # allocate/free device memory on demand
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # do not grab a large memory pool up front

from collections import defaultdict

import scipy
import numpy as onp
onp.set_printoptions(precision=3,suppress=True)

# NOTE: from here on `np` is `jax.numpy`; plain NumPy is available as `onp`.
import jax
import jax.numpy as np
from jax import grad, jit, vmap, device_put, random
from flax import linen as nn
from jax.scipy.stats import dirichlet
import optax

import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.tri as tri
import matplotlib.patches as mpl_patches

# Notebook-wide plot styling (thick lines, large fonts).
# https://matplotlib.org/3.1.1/gallery/style_sheets/style_sheets_reference.html
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 25
mpl.rcParams['font.family'] = 'DejaVu Sans'
mpl.rcParams['axes.linewidth'] = 3
# `plt.cm.get_cmap` is `matplotlib.cm.get_cmap`, which was deprecated in
# Matplotlib 3.7 and removed in 3.9. `pyplot.get_cmap` looks up the same
# registered colormap and is available on both old and new releases.
cmap = plt.get_cmap('bwr')

from tabulate import tabulate
from functools import partial
import copy

# Project-local modules, pulled in wholesale. Because these are star imports,
# a later module can silently rebind a name provided by an earlier one (or by
# the explicit imports above), so their order matters.
from setup_convgp import *
from plt_utils import *
from gpax import *
from dataset import *

# `jax_status` is not defined in this notebook; presumably it is provided by
# one of the star imports above — TODO confirm which module.
jax_status()

In [None]:
# Toggle between the bounding-box-cropped image folders and the uncropped ones.
crop_by_bbox = True

# Dataset root and its `images` sub-directory.
dir_cub200 = './data/CUB_200_2011'
dir_cub200_images = os.path.join(dir_cub200, 'images')

# Choose the (train, test) folder names for the selected variant and anchor
# both under the dataset root.
dir_train, dir_test = (
    os.path.join(dir_cub200, folder)
    for folder in (
        ('cropped_images_train', 'cropped_images_test') if crop_by_bbox
        else ('images_train', 'images_test')
    )
)

In [None]:
# `torchvision.datasets.ImageFolder` is referenced by its fully qualified name
# later in this notebook, so bind the package itself here instead of relying on
# one of the star imports to leak the name.
import torchvision
from torchvision import transforms

# Per-channel mean / std on the [0, 1] scale; the normalisation helpers in this
# cell multiply them by 255 before use. Presumably the usual ImageNet channel
# statistics — TODO confirm.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

def transform_normalize(im, mean, std):
    """ Channel-wise standardisation of an image.

        `mean` and `std` are given on the [0, 1] scale and are multiplied by
        255 here, so `im` is expected to hold 0-255 pixel values (documented
        as `np.uint8`, laid out as (H, W, C)). Only the number of dimensions
        (exactly 3) is actually checked.

        Returns `(im - 255*mean) / (255*std)` computed in float32. """
    assert im.ndim == 3
    pixel_scale = 255
    channel_mean = np.array(mean, dtype=np.float32) * pixel_scale
    channel_std = np.array(std, dtype=np.float32) * pixel_scale
    return (im.astype(np.float32) - channel_mean) / channel_std


def transform_normalize_undo(im, mean, std):
    """ Reverse the standardisation done by `transform_normalize` and rescale
        the result for visualisation.

        Multiplies `im` by `255*std`, adds `255*mean`, then divides by 255, so
        an image that originally held 0-255 pixel values comes back in the
        [0, 1] range. No clipping is applied. `im` must have 3 dimensions. """
    assert im.ndim == 3
    channel_mean = np.array(mean, dtype=np.float32) * 255
    channel_std = np.array(std, dtype=np.float32) * 255
    pixels = im * channel_std + channel_mean
    return pixels / 255

def normalize(im):
    """ Convert `im` to an array and standardise it with the notebook-level
        `mean` / `std`. Used as the last step of both transform pipelines.

        A named function (rather than a lambda bound to a name) so that it
        shows up as `normalize` in tracebacks and in the `Compose` repr. """
    # `np` was bound to `jax.numpy` in the first cell (unless a later star
    # import rebinds it), so this yields a JAX array rather than a NumPy one.
    return transform_normalize(np.asarray(im), mean, std)


# Training pipeline: resize, random 224 crop, random horizontal flip, normalise.
transform_train = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    normalize,
])
# Test pipeline: resize, centre 224 crop, normalise (no random augmentation).
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    normalize,
])

dataset_train = torchvision.datasets.ImageFolder(dir_train, transform_train)
dataset_test = torchvision.datasets.ImageFolder(dir_test, transform_test)
# Reverse lookup from integer label to class name, taken from the training set.
idx_to_class = {idx: name for name, idx in dataset_train.class_to_idx.items()}

# Last expression of the cell: display both dataset summaries.
dataset_train, dataset_test


In [None]:
# Sanity check: fetch the first test sample and its integer label.
im, y = dataset_test[0]

# Undo the normalisation so the pixel values are displayable, then title the
# plot with the class name. `set_title` stays last so the cell's displayed
# value is unchanged.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
ax.imshow(transform_normalize_undo(im, mean, std))
ax.set_title(f'{idx_to_class[y]}')
