CPU kurulumu

In [None]:
!pip install --upgrade pip
!pip install --upgrade "jax[cpu]"

GPU kurulumu (NVIDIA GPU gerektirir)

In [None]:
!pip install --upgrade "jax[cuda]"

## **Otomatik türev alma (jax.grad): bir fonksiyonun türevini, onu oluşturan temel aritmetik işlemlerin türevlerini zincir kuralıyla birleştirerek hesaplayan bir teknik.**



In [None]:
from jax import grad

def func(x):
  """Toy objective f(x) = x^2 whose derivative JAX computes for us."""
  return x * x


# grad(func) is a new function that evaluates f'(x) = 2x.
d_func = grad(func)

In [None]:
# Second derivative: apply grad twice to the original function (f''(x) = 2).
d2_func = grad(grad(func))

In [None]:
# Evaluate the derivative f'(x) = 2x at x = 2.0 (expected 4.0).
x = 2.0
result = d_func(x)

In [None]:
# Show d_func(2.0) computed in the previous cell.
print(result)

4.0


## **JIT derlemesi (jax.jit): Python fonksiyonlarını XLA ile derleyerek hızlandırıcılarda (CPU, GPU ve TPU) daha hızlı çalıştırır.**








In [None]:
from jax import jit

def funct(x):
  """Toy function f(x) = x * (x + 2), i.e. x^2 + 2x."""
  return x * (x + 2)


# jit-wrapped version of funct; JAX compiles it on the first call.
c_funct = jit(funct)

In [None]:
# Call the jit-compiled function: f(2.0) = 2.0 * (2 + 2.0) = 8.0.
x = 2.0
result = c_funct(x)

In [None]:
# Show c_funct(2.0) computed in the previous cell.
print(result)

8.0


## **Paralelleştirme (jax.pmap): Bir fonksiyonu girdinin ilk ekseni boyunca birden çok cihaza (CPU, GPU veya TPU) dağıtarak paralel olarak çalıştırır.**


In [None]:
import os

import jax

# Legacy Colab "TPU Node" runtimes expose COLAB_TPU_ADDR and need an explicit
# setup step; on CPU/GPU runtimes (see the install cells above) the variable is
# absent, so the setup is skipped instead of failing the whole cell.
# NOTE(review): assumes setup_tpu() is driven by COLAB_TPU_ADDR — confirm for
# the JAX version in use; recent JAX may not ship jax.tools.colab_tpu at all.
if "COLAB_TPU_ADDR" in os.environ:
  try:
    import jax.tools.colab_tpu
  except ImportError:
    pass  # this JAX version has no colab_tpu helper; nothing to set up
  else:
    jax.tools.colab_tpu.setup_tpu()

jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [None]:
from jax import numpy as jnp
from jax import pmap
from jax import random

key = random.PRNGKey(42)
# JAX keys must never be reused: a and b hold the same number of elements
# (3000 * 5000), so sampling both from `key` would draw the very same random
# stream and make b just a reshape of a. Split into independent keys instead.
key_a, key_b = random.split(key)
a = random.normal(key_a, shape=(3000, 5000))
b = random.normal(key_b, shape=(5000, 3000))


def matrix_mul(a, b):
  """Matrix product of ``a`` and ``b`` (named so it can be pmap-ed later)."""
  return jnp.matmul(a, b)


matrix_mul(a, b).shape

(3000, 3000)

## Birden fazla çekirdekte büyük veri kümeleriyle çalışırken, mevcut kaynaklardan en iyi şekilde yararlanmak için verileri paralelleştirmek önemlidir.

In [None]:
n_devices = jax.local_device_count()
# Use independent keys: both operands hold n_devices * 3000 * 5000 elements,
# so drawing them from the same key would make b a reshape of a.
key_a, key_b = random.split(key)
a = random.normal(key_a, shape=(n_devices, 3000, 5000))
b = random.normal(key_b, shape=(n_devices, 5000, 3000))
# pmap maps matrix_mul over the leading axis, one slice per local device.
parallel_matrix_mul = pmap(matrix_mul)
parallel_matrix_mul(a, b).shape

(8, 3000, 3000)

## Vektörleştirme (jax.vmap): Tek bir örnek için yazılmış bir fonksiyonu otomatik olarak bir toplu (batch) eksen üzerinde çalışacak şekilde vektörleştirir; böylece açık Python döngülerine gerek kalmadan tekrarlayan işlemler verimli hale gelir.


In [None]:
import jax
import jax.numpy as jnp

def v_func(x):
    """Square ``x`` element-wise."""
    return x * x


matrix = jnp.array([
    [1, 2, 3],
    [4, 5, 6],
])
# vmap applies v_func to each row (axis 0) and stacks the results.
result = jax.vmap(v_func, in_axes=0)(matrix)
print(result)

[[ 1  4  9]
 [16 25 36]]


## Performans testi: NumPy ile JAX (jit) karşılaştırması

In [None]:
import numpy as np

def fn(x):
  """Element-wise cubic x + x**2 + x**3 used as the benchmark workload.

  The same function is timed here with NumPy and, via ``jit(fn)``, with JAX
  in the next cell; the expression is kept verbatim so both runs perform the
  same arithmetic.
  """
  return x + x*x + x*x*x

x = np.random.randn(10000,10000).astype(dtype='float32')
%timeit -n5 fn(x)

446 ms ± 16.9 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [None]:
from jax import jit, random
import jax.numpy as jnp

jax_fn = jit(fn)

key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000), dtype=jnp.float32)

%timeit -n5 jax_fn(x).block_until_ready()

4.22 ms ± 802 µs per loop (mean ± std. dev. of 7 runs, 5 loops each)


## **FLAX**

kaynak: https://github.com/google/flax

In [None]:
!pip install flax

In [None]:
from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  """Fully connected network: ReLU hidden layers followed by a linear output.

  Attributes:
    features: width of every Dense layer in order; the last entry is the
      output size and gets no activation.
  """
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    hidden_widths = self.features[:-1]
    output_width = self.features[-1]
    for width in hidden_widths:
      x = nn.Dense(width)(x)
      x = nn.relu(x)
    return nn.Dense(output_width)(x)

model = MLP(features=[12, 8, 4])

# init() builds the parameters from a PRNG key and an example batch;
# apply() runs the forward pass with those parameters.
batch = jnp.ones((32, 10))
variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch)

In [None]:
class CNN(nn.Module):
  """Small convolutional classifier returning 10 log-probabilities per example.

  Two stages of 3x3 convolution (32, then 64 filters) -> ReLU -> 2x2 average
  pooling, then a 256-unit ReLU Dense layer and a 10-way log-softmax.
  """

  @nn.compact
  def __call__(self, x):
    for n_filters in (32, 64):
      x = nn.Conv(features=n_filters, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    flat = x.reshape((x.shape[0], -1))  # keep the batch axis, flatten the rest
    hidden = nn.relu(nn.Dense(features=256)(flat))
    logits = nn.Dense(features=10)(hidden)
    return nn.log_softmax(logits)

batch = jnp.ones((32, 64, 64, 10))  # (N, H, W, C): 32 images, 64x64, 10 channels
model = CNN()
variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch)

In [None]:
class AutoEncoder(nn.Module):
  """MLP auto-encoder: flatten -> encode -> decode -> sigmoid -> reshape back.

  Attributes:
    encoder_widths: Dense widths of the encoder MLP; the last one is the code
      size.
    decoder_widths: hidden Dense widths of the decoder MLP; a final layer of
      size prod(input_shape) is appended automatically.
    input_shape: per-example input shape (everything after the batch axis).
  """
  encoder_widths: Sequence[int]
  decoder_widths: Sequence[int]
  input_shape: Sequence[int]

  def setup(self):
    # The fields accept any Sequence (the example passes lists), but below they
    # are concatenated with / compared against tuples, so normalise first.
    input_dim = int(np.prod(tuple(self.input_shape)))
    self.encoder = MLP(self.encoder_widths)
    self.decoder = MLP(tuple(self.decoder_widths) + (input_dim,))

  def __call__(self, x):
    return self.decode(self.encode(x))

  def encode(self, x):
    """Flatten each example of ``x`` and map it to its latent code."""
    assert x.shape[1:] == tuple(self.input_shape)
    return self.encoder(jnp.reshape(x, (x.shape[0], -1)))

  def decode(self, z):
    """Map latent codes to sigmoid outputs shaped (batch, *input_shape)."""
    x = nn.sigmoid(self.decoder(z))
    return jnp.reshape(x, (x.shape[0],) + tuple(self.input_shape))

model = AutoEncoder(
    encoder_widths=[20, 10, 5],
    decoder_widths=[5, 10, 20],
    input_shape=(12,),
)
batch = jnp.ones((16, 12))
variables = model.init(jax.random.PRNGKey(0), batch)

# `method=` runs a single sub-method instead of __call__, so the encoder and
# decoder halves can be used separately with the same parameters.
encoded = model.apply(variables, batch, method=model.encode)
decoded = model.apply(variables, encoded, method=model.decode)