In [4]:
import time
import math
import os

import jax
import jax.numpy as jnp
import flax
import matplotlib.pyplot as plt
%matplotlib inline

from matplotlib.colors import to_rgba
from IPython.display import set_matplotlib_formats
import seaborn as sns
from tqdm.auto import tqdm

In [5]:
try:
    import flax
except ModuleNotFoundError: # Install flax if missing
    !pip install --quiet flax
    import flax

from flax import linen as nn

In [2]:
!pip install jax

Collecting jax
  Downloading jax-0.4.6.tar.gz (1.2 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m0:01[0m:01[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting opt_einsum
  Downloading opt_einsum-3.3.0-py3-none-any.whl (65 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.5/65.5 kB[0m [31m883.5 kB/s[0m eta [36m0:00:00[0m31m92.3 MB/s[0m eta [36m0:00:01[0m
Building wheels for collected packages: jax
  Building wheel for jax (setup.py) ... [?25ldone
[?25h  Created wheel for jax: filename=jax-0.4.6-py3-none-any.whl size=1432714 sha256=2ddeefcf18e5bcf9f0ecd02be51b2f95f4230b7c988d88ea074cad47b39f926d
  Stored in directory: /home/iiserbpr/.cache/pip/wheels/68/2c/93/17deec4d117dc0675ed79e8e2af1e62fb1c41ed3955c540de0
Successfully built jax
Installing collected packages: opt_einsum, jax
Successfully installed jax-0.4.6 opt_einsu

In [3]:
import jax
import jax.numpy as jnp

ModuleNotFoundError: jax requires jaxlib to be installed. See https://github.com/google/jax#installation for installation instructions.

In [4]:
!pip install --upgrade "jax[cpu]"

Collecting jaxlib==0.4.6
  Downloading jaxlib-0.4.6-cp39-cp39-manylinux2014_x86_64.whl (62.0 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.0/62.0 MB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m[36m0:00:01[0m
[?25hInstalling collected packages: jaxlib
Successfully installed jaxlib-0.4.6


In [8]:
import jax
import jax.numpy as jnp
rng = jax.random.PRNGKey(42)


random1 = jax.random.normal(rng)
random2 = jax.random.normal(rng)
random3 = jax.random.normal(rng)
print("Random Number: 1", random1)
print("Random Number: 2", random2)
print("Random Number: 3", random3)

Random Number: 1 -0.18471177
Random Number: 2 -0.18471177
Random Number: 3 -0.18471177


In [9]:
def simple_graph(x):
    x = x + 2
    x = x ** 2
    x = x + 3
    y = x.mean()
    return y

inp = jnp.arange(3, dtype=jnp.float32)
print('Input', inp)
print('Output', simple_graph(inp))

Input [0. 1. 2.]
Output 12.666667


In [13]:
def another_graph(x):
    x = x**2
    x = x+3
    x = x.mean()
    y = float(x)
    return y

inpt = jnp.arange(4, dtype = jnp.float32)
print('Input', inpt)
print('Output', another_graph(inpt))
    

Input [0. 1. 2. 3.]
Output 6.5


In [15]:
jax.make_jaxpr(another_graph)(inpt)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.
The error occurred while tracing the function another_graph at /tmp/ipykernel_4263/1389991900.py:1 for make_jaxpr. This concrete value was not available in Python because it depends on the value of the argument x.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [16]:
global_list = []

# Invalid function with side-effect
def norm(x):
    global_list.append(x)
    x = x ** 2
    n = x.sum()
    n = jnp.sqrt(n)
    return n

jax.make_jaxpr(norm)(inp)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[3][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f32[3][39m = integer_pow[y=2] a
    c[35m:f32[][39m = reduce_sum[axes=(0,)] b
    d[35m:f32[][39m = sqrt c
  [34m[22m[1min [39m[22m[22m(d,) }

In [17]:
grad_function = jax.grad(simple_graph)
gradients = grad_function(inp)
print('Gradient', gradients)

Gradient [1.3333334 2.        2.6666667]


# DataLoaders for Generative Models in Torch

In [None]:
import torch
from torch.utils.data import Data
from torch.utils.data import DataLoader

In [1]:
from torchvision import datasets


In [None]:
MNIST_dataset = torchvision.datasets.MNIST()
CIFAR_10_dataset = torchvision.datasets.CIFAR10()

In [18]:
def load_data(data:str, batch_size, shuffle= True, num_workers:int =0):
    '''
    Arguments:
    data_set: Give the name of the dataset 
    options available: MNIST or CIFAR10 or custom
    batch_size: Number of mini-batches to create
    shuffle: True, if shuffling of the dataset is enabled. False if not
    Default shuffle = true
    num_workers: Number of threads to use for preparing the dataset
    returns Pytorch dataloader for the dataset
    '''
    if data == 'MNIST':
        data_loader = torch.utils.data.DataLoader(MNIST_dataset,
                                                 batch_size = batch_size,
                                                 shuffle = shuffle,
                                                 num_workers = args.nThreads)
    elif data == 'CIFAR10':
        data_loader = torch.utils.data.DataLoader(CIFAR_10_dataset,
                                                 batch_size = batch_size,
                                                 shuffle = shuffle,
                                                 num_workers = args.nThreads)
    elif data == 'CelebA64':
        data_loader = torch.utils.data.DataLoader(celebA64_dataset,
                                                 batch_size = batch_size,
                                                 shuffle = shuffle,
                                                 num_workers = args.nThreads)
    elif data == 'CelebA128':
        data_loader = torch.utils.data.DataLoader(celebA128_dataset,
                                                 batch_size = batch_size,
                                                 shuffle = shuffle,
                                                 num_workers = args.nThreads)
    elif data == 'custom':
        data_loader = torch.utils.data.DataLoader(data,
                                                 batch_size = batch_size,
                                                 shuffle = shuffle,
                                                 num_workers = args.nThreads)
                  
    return data_loader

#### Pro Tip:
Implement some basic training of models on standard datasets
Examples:
MNIST
CIFAR100


In [None]:
import os
import functools import partial
import jax
import jax.numpy as jnp
from typing import Any,Tuple

PRNGKey = jnp.ndarray

num_devices = jax.device_count()



## Configuration settings for individual model components

In [None]:
args = {
    'z_dim': 64,
    'seed': 41,
    'batch_size':,
    'epochs':,
    'num_layers':,
    
}

data_args= {
    
    
}

model_args= {
    
}

util_args = {
    
    
    
}

hyper_parameters= {
    
}


### Device configuration: CPU/ GPU

In [7]:
import torch

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"***************{device} found on the device******************" .format(device))
print(f"************Running the Model Processes on {device}******************" .format(device))

***************cpu found on the device******************
************Running the Model Processes on cpu******************


SEND TRAINING DATA TO GPU USING THE CODE:

'''X_train = X_train.to(device)'''

Testing and validation will take place on CPU

MODEL DECLARATION ON GPU
model MyAwesomeneuralnetwork()
model.to_device()



## Vanilla-GAN Model Architecture

## Model Declaration

In [None]:
import haiku as hk


class Generator_network(hk.module):
    """Model declaration of Generator Network"""
    def __init__(self, output_channels = output_channels, name = None):
        super().__init__(name=name)
        self.output_channels = output_channels
    
    def __call__(self,x):
        x = hk.Linear

## GAN Loss Function

In [None]:
import optax

@partial(jax.pmap, axis_name = 'num_devices')

def generator_step(generator_state:TrainState, 
                   discriminator_state:TrainState,
                   key: PRNGKey):
    r"""
    """
    input_noise = jax.random.normal(key, (args['batch_size'], args['z_dim']))
    
    
    def loss_fn(params):
        
        
                                    

In [16]:
import jax
print(f"Number of devices: {jax.device_count()}")
print("Device:", jax.devices()[0].device_kind)
print("")

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Number of devices: 1
Device: cpu

