# Setup

In [None]:
# Silence WARNING:root:The use of `check_types` is deprecated and does not have any effect.
# https://github.com/tensorflow/probability/issues/1523
import logging

# Root logger: the deprecation warning shows up as "WARNING:root:...".
logger = logging.getLogger()


class CheckTypesFilter(logging.Filter):
    """Reject log records whose rendered message mentions ``check_types``.

    A filter attached to a logger is consulted for records logged on that
    logger; records propagated up from descendant loggers are not filtered
    by it. Attaching it to the root logger therefore targets messages that
    are logged via the root logger itself.
    """

    def filter(self, record):
        text = record.getMessage()
        if "check_types" in text:
            return False
        return True


logger.addFilter(CheckTypesFilter())

In [None]:
# Restrict this process to a subset of the host's TPU chips, for both notebook
# and command-line runs. Run this cell before JAX touches the TPU.
# https://docs.google.com/document/d/1sbRFVSPePq_8oGBntSOmG0V5gqxyNiuDn-4_ph8eoBc/edit#heading=h.y89aert1620u
# NOTE(review): this heading originally said "2 TPU cores per process", while the
# bounds comment below said "2 chips (4 cores)" — confirm which is intended.

import os

# Chip bounds "1,2,1" per process (original note: "2x 2 chips (4 cores) per process"):
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] = "1,2,1"
os.environ["TPU_HOST_BOUNDS"] = "1,1,1"
# Different per process:
os.environ["TPU_VISIBLE_DEVICES"] = "0,1"  # Change to "2,3" for the second process
# Pick a unique port per process. Address and port are derived from one value
# so the two variables cannot drift apart when the port is changed.
TPU_MESH_CONTROLLER_PORT = 8476
os.environ["TPU_MESH_CONTROLLER_ADDRESS"] = f"localhost:{TPU_MESH_CONTROLLER_PORT}"
os.environ["TPU_MESH_CONTROLLER_PORT"] = str(TPU_MESH_CONTROLLER_PORT)

print('done')

In [None]:
import os

cpu_count = os.cpu_count()
print(cpu_count)

# Run jax on multiple CPU cores by asking XLA to expose N_HOST_DEVICES virtual
# host devices. This only takes effect if it runs before JAX initialises its
# backends (the jax.devices() call in the next cell).
# https://github.com/google/jax/issues/5506
# https://stackoverflow.com/questions/72328521/jax-pmap-with-multi-core-cpu
N_HOST_DEVICES = 2
_DEVICE_COUNT_FLAG = "--xla_force_host_platform_device_count="
# Keep any XLA flags set elsewhere instead of overwriting them, and drop a stale
# device-count flag so re-running this cell does not stack duplicates.
_other_xla_flags = [
    flag
    for flag in os.environ.get("XLA_FLAGS", "").split()
    if not flag.startswith(_DEVICE_COUNT_FLAG)
]
os.environ["XLA_FLAGS"] = " ".join(_other_xla_flags + [f"{_DEVICE_COUNT_FLAG}{N_HOST_DEVICES}"])


In [None]:
import jax
# Show the devices JAX initialised with. After the XLA_FLAGS cell above on a
# CPU-only host this is expected to list 2 CPU devices — confirm in the output.
print(jax.devices())

In [None]:
%reload_ext autoreload

In [None]:
# Project module. In the visible cells, this star import is the only possible
# source of names used later without an explicit import (np, plt, jnp, jr,
# nfactors) — presumably it re-exports them; TODO confirm and switch to
# explicit imports.
from shifty.label_shift.labelshift import *

print(nfactors)  # defined by the star import above; its meaning is not visible here

# MNIST

In [None]:
import torchvision

# MNIST training split, downloaded into ~/data on first use.
mnist_train = torchvision.datasets.MNIST(
    root="~/data",
    train=True,
    download=True,
)
print(mnist_train.data.shape)
print(mnist_train.targets.shape)

# Pixel values scaled by 1/255 (presumably uint8 pixels -> floats in [0, 1]).
images = np.divide(np.array(mnist_train.data), 255.0)

In [None]:

import skimage
import skimage.util

print(skimage.__version__)

# Tile the first nine images into a single 2-D array and show it in grayscale.
m = skimage.util.montage(images[:9, :, :])
print(m.shape)
img = plt.imshow(m, cmap=plt.cm.gray)
plt.axis('off');

In [None]:
# Redraw the montage of the first nine images (same montage as the previous
# cell, without its version/shape prints).
m = skimage.util.montage(images[:9, :, :])
img = plt.imshow(m, cmap="gray")
plt.axis('off');

In [None]:
from augly import image

def processor(X, angle, out_shape=None):
    """Rotate one image with augly and zero-pad the result back to a fixed size.

    Args:
        X: image as a NumPy array (the MNIST cells use 28x28 digits).
        angle: rotation angle, passed as ``degrees=`` to ``augly.image.rotate``.
        out_shape: optional ``(height, width)`` to pad to. Defaults to
            ``X.shape[:2]``, which is 28x28 for MNIST — the size the previous
            version hard-coded.

    Returns:
        The rotated image, centred and zero-padded to ``out_shape``. When the
        padding along an axis is odd, the extra row/column goes after
        (bottom/right), as before.

    Raises:
        ValueError: if the rotated image is larger than ``out_shape`` along
            either axis (padding cannot shrink it).
    """
    X_shift = image.aug_np_wrapper(X, image.rotate, degrees=angle)
    target_h, target_w = X.shape[:2] if out_shape is None else out_shape
    pad_h = target_h - X_shift.shape[0]
    pad_w = target_w - X_shift.shape[1]
    if pad_h < 0 or pad_w < 0:
        raise ValueError(
            f"rotated image of shape {X_shift.shape} does not fit in {(target_h, target_w)}"
        )
    # Pad each axis independently so non-square rotated outputs still come back
    # at out_shape (the old code reused the row padding for the columns too).
    pad_width = [(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)]
    # Leave any trailing (e.g. channel) axes unpadded.
    pad_width += [(0, 0)] * (X_shift.ndim - 2)
    return np.pad(X_shift, pad_width)




# Parallel

In [None]:
from math import sqrt
from joblib import Parallel, delayed, parallel_backend
from itertools import repeat
from time import time

In [None]:
import itertools

# Toy argument lists for make_arg_combo: the rows of a 3x2 array, two scale
# factors, and a single string tag. (The earlier three-tag list was assigned
# and then immediately overwritten, so it had no effect.)
arg1 = np.reshape(np.arange(0, 6), (3, 2))
arg2 = [1, 2]
# arg3 = ['foo', 'bar', 'baz']  # alternative: three tags
arg3 = ['foo']

def make_arg_combo(*args):
    """Return the Cartesian product of ``args`` as a 2-D object array.

    Each row is one combination and each column one argument, so the result
    has shape ``(number_of_combinations, len(args))``. Elements are stored
    as-is; for example, a row of a 2-D array stays a 1-D array inside one cell.

    The array is filled cell by cell. ``np.array(combos, dtype=object)`` would
    silently add extra dimensions when every element is an equal-length
    sequence, and would lose the 2-D shape when there are no combinations.
    """
    combos = list(itertools.product(*args))
    table = np.empty((len(combos), len(args)), dtype=object)
    for i, combo in enumerate(combos):
        for j, value in enumerate(combo):
            table[i, j] = value
    return table

# One row per (arg1 row, arg2, arg3) combination.
arg_combo = make_arg_combo(arg1, arg2, arg3)
print(arg_combo)

# Split the rows into `nproc` contiguous blocks whose sizes differ by at most one.
nproc = 2
arg_blocks = np.array_split(arg_combo, indices_or_sections=nproc, axis=0)

print(len(arg_blocks))
print(arg_blocks[0])

In [None]:
args = [*range(10)]
print(args)


def f(a):
    """Return sqrt(a**2) (i.e. |a|) computed with jax.numpy."""
    squared = jnp.power(a, 2)
    return jnp.sqrt(squared)

# Same computation twice: a plain serial map, then joblib with 2 workers.
out1 = list(map(f, args))
print(out1)

out2 = Parallel(n_jobs=2)(map(delayed(f), args))
print(out2)

In [None]:


KEY = jr.PRNGKey(42)


def f(arg, n=6000):
    """Benchmark kernel: max of ``z * z * x`` for an ``n x n`` standard-normal matrix ``z``.

    Note: this redefines the ``f`` from the earlier cell.

    Args:
        arg: ``(key, x)`` pair — a JAX PRNG key and a scalar multiplier.
        n: side length of the random matrix. Defaults to 6000, the size
            previously hard-coded (~36M entries per call).

    Returns:
        0-d JAX array ``max(z * z * x)`` where ``z = jr.normal(key, (n, n))``.
    """
    key, x = arg
    # Use the key passed in; previously it was unpacked but ignored in favour
    # of the global KEY.
    mat = jr.normal(key, (n, n))
    return jnp.max(mat * mat * x)

# 200 tasks that all share KEY and differ in the multiplier x = 0..199.
args = [(KEY, x) for x in np.arange(200)]

# Serial baseline.
init_time = time()
out1 = np.array([f(arg) for arg in args])
end_time = time()
print(f"Serial Time elapsed: {end_time - init_time:.2f}s")

# Same work through joblib with 10 workers, preferring threads.
# (Process-based alternative: Parallel(n_jobs=-2, verbose=1).)
init_time = time()
runner = Parallel(n_jobs=10, prefer="threads", verbose=1)
out2 = np.array(runner(delayed(f)(arg) for arg in args))
end_time = time()
print(f"Parallel Time elapsed: {end_time - init_time:.2f}s")

# Both paths must agree element-wise.
assert jnp.allclose(out1, out2)


In [None]:
import pandas as pd

import itertools

arg1 = np.arange(6).reshape(3, 2)  # rows [0, 1], [2, 3], [4, 5]
arg2 = [1, 2]
arg3 = ['foo']  # alternative: ['foo', 'bar', 'baz']

def make_arg_combo(*args):
    """Return the Cartesian product of ``args`` as a 2-D object array.

    Each row is one combination and each column one argument, so the result
    has shape ``(number_of_combinations, len(args))``. Elements are stored
    as-is; for example, a row of a 2-D array stays a 1-D array inside one cell.

    The array is filled cell by cell. ``np.array(combos, dtype=object)`` would
    silently add extra dimensions when every element is an equal-length
    sequence, and would lose the 2-D shape when there are no combinations.
    """
    combos = list(itertools.product(*args))
    table = np.empty((len(combos), len(args)), dtype=object)
    for i, combo in enumerate(combos):
        for j, value in enumerate(combo):
            table[i, j] = value
    return table

# Rebuild the combination table from this cell's arguments.
arg_combo = make_arg_combo(arg1, arg2, arg3)
print(arg_combo)

# One DataFrame column per argument position of the combination table.
df = pd.DataFrame({name: arg_combo[:, j] for j, name in enumerate(("data", "sf", "str"))})
print(df)

In [None]:
def process_single(arg1, arg2, arg3):
    """Toy per-combination work: return ``arg1 * arg2`` (``arg3`` is accepted but unused)."""
    return arg1 * arg2


def process_batch(args_list):
    """Apply ``process_single`` to each argument row and stack the results.

    Args:
        args_list: iterable of ``(arg1, arg2, arg3)`` rows, e.g. a 2-D object
            array from ``make_arg_combo`` or one block from ``np.array_split``.

    Returns:
        ``np.stack`` of the per-row outputs along a new leading axis.

    Raises:
        ValueError: from ``np.stack`` if ``args_list`` is empty.
    """
    return np.stack([process_single(*arg_tuple) for arg_tuple in args_list], axis=0)


def process_all(args_blocks):
    """Run ``process_batch`` on every block and concatenate the results.

    Args:
        args_blocks: sequence of row blocks (e.g. from ``np.array_split``).
            Empty blocks — which ``np.array_split`` yields when there are more
            sections than rows — are skipped.

    Returns:
        The per-block outputs concatenated along axis 0.

    Raises:
        ValueError: if every block is empty (nothing to concatenate).
    """
    # Iterate the parameter itself; the loop previously read the global
    # ``arg_blocks`` and silently ignored its argument.
    outputs = [process_batch(batch) for batch in args_blocks if len(batch) > 0]
    return np.concatenate(outputs, axis=0)
    
# Reference result: every combination processed as a single batch.
output1 = process_batch(arg_combo)
print(len(output1))
print(output1)

# Re-split *this* cell's arg_combo. arg_blocks was last built from an earlier
# cell's arg_combo, so reusing it compared results on possibly different inputs.
arg_blocks = np.array_split(arg_combo, nproc)
output2 = process_all(arg_blocks)
print(len(output2))
print(output2)

# Blocked processing must reproduce the single-batch result exactly.
assert np.array_equal(output1, output2), "blocked processing disagrees with the single-batch result"