<a href="https://colab.research.google.com/github/sidms24/internship/blob/main/internship/june%20/week3_/Thursday.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax.numpy as jnp
import jax
from tqdm import tqdm
import pandas as pd
import numpy as np
import lpips_jax

In [2]:
class Convlayer:
  """2-D convolution layer with a per-output-channel bias.

  Inputs are NCHW and the kernel is stored as OIHW with shape
  ``(dout, din, kernel_size, kernel_size)``. Parameters live on the instance
  and are exposed through ``params`` so they can be passed explicitly to
  ``__call__(params, x)``, matching how Encoder/resblock use this layer.

  Args:
    din: number of input channels.
    dout: number of output channels.
    kernel_size: spatial size of the square kernel.
    key: PRNG key used to initialise ``W`` and ``b``.
    stride: spatial stride (same for H and W).
    padding: an int (symmetric zero padding on H and W), or a lax padding
      string such as ``'SAME'``/``'VALID'``.
  """

  _DIM_NUMBERS = ('NCHW', 'OIHW', 'NCHW')

  def __init__(self, din, dout, kernel_size, key, stride=1, padding=0):
    weights_key, bias_key = jax.random.split(key)
    # Uniform init with limit sqrt(6 / (din + dout)), computed from channel counts.
    limit = jnp.sqrt(6 / (din + dout))
    self.W = jax.random.uniform(weights_key, (dout, din, kernel_size, kernel_size), minval=-limit, maxval=limit)
    self.b = jax.random.uniform(bias_key, (dout,), minval=-limit, maxval=limit)
    self.stride = stride
    self.padding = padding

  @property
  def params(self):
    return {'W': self.W, 'b': self.b}

  def _conv_padding(self):
    """Translate ``self.padding`` into what ``lax.conv_general_dilated`` accepts."""
    if isinstance(self.padding, int):
      return [(self.padding, self.padding)] * 2
    return self.padding

  def _transpose_padding(self):
    """Padding for ``lax.conv_transpose``.

    Explicit pairs given to ``lax.conv_transpose`` pad the stride-dilated input,
    so an int ``p`` is mapped to ``kernel_size - 1 - p`` per side; the output
    size is then ``(H - 1) * stride - 2 * p + kernel_size``.
    """
    if isinstance(self.padding, int):
      pad = self.W.shape[-1] - 1 - self.padding
      return [(pad, pad)] * 2
    return self.padding

  def __call__(self, params, x):
    """Convolve NCHW ``x`` with ``params['W']`` and add ``params['b']``."""
    out = jax.lax.conv_general_dilated(
        x, params['W'], (self.stride, self.stride), self._conv_padding(),
        dimension_numbers=self._DIM_NUMBERS)
    # Bias is per output channel: broadcast over N, H, W.
    return out + params['b'].reshape(1, -1, 1, 1)

  def forward(self, x):
    """Apply the layer with its own stored parameters (caches the input in ``self.x``)."""
    self.x = x
    return self(self.params, x)

  def forward_transpose(self, x):
    """Transposed convolution with the stored kernel (caches the input in ``self.x``).

    ``x`` must have ``din`` channels; the result has ``dout`` channels.
    """
    self.x = x
    out = jax.lax.conv_transpose(
        x, self.W, (self.stride, self.stride), self._transpose_padding(),
        dimension_numbers=self._DIM_NUMBERS)
    return out + self.b.reshape(1, -1, 1, 1)


In [17]:
class Encoder:
  """Convolutional encoder producing the mean and log-variance of a latent code.

  Five 3x3 conv layers (each followed by ReLU) and a final 3x3 conv with
  ``2 * latent_dim`` output channels, which is split along the channel axis
  into ``mu`` and ``logvar``.

  Args:
    din: input image channels.
    dout: hidden channel width.
    latent_dim: channels of ``mu`` / ``logvar``.
    key: PRNG key for initialisation.
    stride, padding: forwarded to every ``Convlayer``.
  """

  _HIDDEN = ('conv1', 'conv2', 'conv3', 'conv4', 'conv5')

  def __init__(self, din, dout, latent_dim, key, stride=1, padding=0):
    keys = jax.random.split(key, 6)
    self.latent_dim = latent_dim
    self.conv1 = Convlayer(din, dout, 3, keys[0], stride, padding)
    self.conv2 = Convlayer(dout, dout, 3, keys[1], stride, padding)
    self.conv3 = Convlayer(dout, dout, 3, keys[2], stride, padding)
    self.conv4 = Convlayer(dout, dout, 3, keys[3], stride, padding)
    self.conv5 = Convlayer(dout, dout, 3, keys[4], stride, padding)
    self.conv_out = Convlayer(dout, latent_dim * 2, 3, keys[5], stride, padding)

  @property
  def params(self):
    return {'conv1': self.conv1.params, 'conv2': self.conv2.params, 'conv3': self.conv3.params,
            'conv4': self.conv4.params, 'conv5': self.conv5.params, 'conv_out': self.conv_out.params}

  def __call__(self, params, x):
    """Return ``(mu, logvar)`` for the NCHW batch ``x``."""
    # A nonlinearity between convs is required; without it the five convs
    # compose into a single linear map.
    for name in self._HIDDEN:
      x = jax.nn.relu(getattr(self, name)(params[name], x))
    x_out = self.conv_out(params['conv_out'], x)
    mu, logvar = jnp.split(x_out, 2, axis=1)
    return mu, logvar



In [12]:
class DGAE:
  """Autoencoder whose decoder is a latent-conditioned diffusion denoiser.

  The encoder maps an image to ``(mu, logvar)``. A latent ``z`` is sampled from
  them and passed, together with a noised image and its timestep, to the
  ``diffusionUnet`` decoder.

  Args:
    din: image channels.
    dout: base channel width for encoder and decoder.
    latent_dim: latent channels produced by the encoder.
    key: PRNG key for initialisation (also the default reparameterisation key).
  """

  def __init__(self, din, dout, latent_dim, key):
    enc_key, dec_key = jax.random.split(key, 2)
    self.encoder = Encoder(din, dout, latent_dim, enc_key)
    # diffusionUnet's signature is (din, dout, latent_dim, key): it takes the
    # `din`-channel image (plus the latent) and uses `dout` as its base width.
    self.decoder = diffusionUnet(din, dout, latent_dim, dec_key)

    self.noise_schedule = noiseschedule()
    self.key = key

  @property
  def params(self):
    return {'encoder': self.encoder.params, 'decoder': self.decoder.params,
            'noise_schedule': self.noise_schedule.params}

  def reparameterize(self, mu, logvar, rng=None):
    """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``.

    ``rng`` should be a fresh key per call. When omitted, the fixed
    ``self.key`` is used, which reproduces the same ``eps`` every time.
    """
    if rng is None:
      rng = self.key
    std = jnp.exp(0.5 * logvar)
    eps = jax.random.normal(rng, std.shape)
    return mu + eps * std

  def __call__(self, params, x, rng):
    """Run one stochastic forward pass.

    Returns:
      ``(x_pred, noise, mu, logvar)``. ``x_pred`` is the decoder output for the
      noised ``x``, and ``noise`` is exactly the Gaussian noise that was mixed in.
    """
    # Independent keys for the timestep, the noise and the latent sample.
    t_rng, noise_rng, z_rng = jax.random.split(rng, 3)
    mu, logvar = self.encoder(params['encoder'], x)
    z = self.reparameterize(mu, logvar, z_rng)

    t = jax.random.randint(t_rng, (x.shape[0],), 0, self.noise_schedule.num_timesteps)
    noise = jax.random.normal(noise_rng, x.shape)
    # Forward diffusion x_t = sqrt(abar_t) x + sqrt(1 - abar_t) eps.
    # abar_t is reshaped to (N, 1, ..., 1) so it broadcasts per sample.
    alpha_bar = jnp.cumprod(self.noise_schedule.alphas)[t]
    alpha_bar = alpha_bar.reshape((-1,) + (1,) * (x.ndim - 1))
    x_noisy = jnp.sqrt(alpha_bar) * x + jnp.sqrt(1 - alpha_bar) * noise
    x_pred = self.decoder(params['decoder'], x_noisy, t, z)

    return x_pred, noise, mu, logvar



In [5]:
class groupnorm:
  """Group normalisation over NCHW tensors with a learnable per-channel affine.

  Args:
    channels: number of channels ``C`` of the input.
    num_groups: requested group count. The largest divisor of ``channels``
      that is <= ``num_groups`` is used, so any channel count is accepted. This
      gives the same group count as before whenever the old choice was valid.
    key: unused; accepted so the constructor matches other layers.
    eps: numerical-stability constant added to the variance.
  """

  def __init__(self, channels, num_groups=32, key=None, eps=1e-5):
    self.gamma = jnp.ones((1, channels, 1, 1))
    self.beta = jnp.zeros((1, channels, 1, 1))
    groups = max(1, min(num_groups, channels))
    # The reshape to (N, G, C // G, H, W) needs G to divide C exactly.
    while channels % groups:
      groups -= 1
    self.num_groups = groups
    self.eps = eps

  @property
  def params(self):
    return {'gamma': self.gamma, 'beta': self.beta}

  def __call__(self, params, x):
    """Normalise each group of ``x`` (N, C, H, W) to zero mean / unit variance, then scale and shift."""
    N, C, H, W = x.shape
    G = self.num_groups
    x = x.reshape(N, G, C // G, H, W)
    mean = jnp.mean(x, axis=(2, 3, 4), keepdims=True)
    var = jnp.var(x, axis=(2, 3, 4), keepdims=True)
    x = (x - mean) / jnp.sqrt(var + self.eps)
    x = x.reshape(N, C, H, W)
    return x * params['gamma'] + params['beta']

In [6]:
class resblock:
  """Residual block: (GroupNorm -> ReLU -> 3x3 conv) twice, plus a time embedding and a skip.

  The projected time embedding ``(N, dout)`` is broadcast over H and W. When the
  channel count or stride changes, the skip path is a 1x1 conv so the residual
  matches the main path. Otherwise it is the identity. The main path is
  assumed to preserve spatial size for stride 1 (``padding=1`` with 3x3 kernels).

  Args:
    din, dout: input / output channels.
    key: PRNG key for initialisation.
    timembed_dim: feature size of the time embedding passed to ``__call__``.
    stride: stride of the first conv (and of the skip conv).
    padding: padding of the 3x3 convs.
  """

  def __init__(self, din, dout, key, timembed_dim=512, stride=1, padding=1):
    keys = jax.random.split(key, 5)
    self.conv1 = Convlayer(din, dout, 3, keys[0], stride, padding)
    # Only the first conv strides, so the skip conv (same stride) lines up.
    self.conv2 = Convlayer(dout, dout, 3, keys[1], 1, padding)
    # norm1 sees the block input (din channels), norm2 the conv1 output.
    self.norm1 = groupnorm(din)
    self.norm2 = groupnorm(dout)
    self.emb_proj = MLP(timembed_dim, dout, keys[2])
    if din != dout or stride != 1:
      self.skip = Convlayer(din, dout, 1, keys[3], stride, 0)
    else:
      self.skip = None

  @property
  def params(self):
    p = {'conv1': self.conv1.params, 'conv2': self.conv2.params,
         'norm1': self.norm1.params, 'norm2': self.norm2.params,
         'emb_proj': self.emb_proj.params}
    if self.skip is not None:
      p['skip'] = self.skip.params
    return p

  def __call__(self, params, x, t):
    """Apply the block to NCHW ``x`` with time embedding ``t`` (last axis = ``timembed_dim``)."""
    residual = x if self.skip is None else self.skip(params['skip'], x)
    h = self.norm1(params['norm1'], x)
    h = jax.nn.relu(h)
    h = self.conv1(params['conv1'], h)
    h = self.norm2(params['norm2'], h)
    h = jax.nn.relu(h)
    h = self.conv2(params['conv2'], h)
    # (N, dout) -> (N, dout, 1, 1) so it is added per channel, not along W.
    emb = self.emb_proj(params['emb_proj'], t)[..., None, None]
    return h + emb + residual

In [7]:
class timeEmbed:
  """Embed integer diffusion timesteps into ``dim``-dimensional features.

  Timesteps are first mapped to a fixed sinusoidal encoding of size ``dim``,
  then passed through ``Linear(dim, 4*dim) -> swish -> Linear(4*dim, dim)``.
  The dense layers need a feature vector per sample; raw timesteps of shape
  ``(N,)`` cannot be multiplied by a ``(dim, 4*dim)`` weight.
  """

  def __init__(self, dim, key):
    keys = jax.random.split(key, 2)
    self.dim = dim
    self.linear1 = MLP(dim, dim * 4, keys[0])
    self.linear2 = MLP(dim * 4, dim, keys[1])

  @property
  def params(self):
    return {'linear1': self.linear1.params, 'linear2': self.linear2.params}

  @staticmethod
  def _sinusoidal_embedding(t, dim, max_period=10000.0):
    """Return ``(N, dim)`` features ``[sin(t * f_k), cos(t * f_k)]`` for flattened ``t``.

    ``f_k = max_period ** (-k / (dim // 2))``. An odd ``dim`` is padded with one zero column.
    """
    t = jnp.asarray(t, dtype=jnp.float32).reshape(-1)
    half = dim // 2
    freqs = jnp.exp(-jnp.log(max_period) * jnp.arange(half, dtype=jnp.float32) / max(half, 1))
    args = t[:, None] * freqs[None, :]
    emb = jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)
    if dim % 2:
      emb = jnp.pad(emb, ((0, 0), (0, 1)))
    return emb

  def __call__(self, params, t):
    """Map timesteps ``t`` (scalar or shape ``(N,)``) to an ``(N, dim)`` embedding."""
    h = self._sinusoidal_embedding(t, self.dim)
    h = self.linear1(params['linear1'], h)
    h = jax.nn.swish(h)
    return self.linear2(params['linear2'], h)

In [13]:
class diffusionUnet:
  """Latent-conditioned denoising network built from ``resblock``s.

  The latent ``z`` is nearest-neighbour resized to the image's H x W and
  concatenated with the image along channels. A first 3x3 conv maps this to
  ``dout`` channels. Three "down" blocks widen to 2x/4x/8x, two middle blocks
  follow, and three "up" blocks consume channel-concatenated skip connections.
  A final 3x3 conv returns ``din`` channels. The blocks do not change spatial
  resolution.

  Args:
    din: image channels (input and output).
    dout: base channel width.
    latent_dim: channels of the conditioning latent ``z``.
    key: PRNG key for initialisation.
  """

  def __init__(self, din, dout, latent_dim, key):
    keys = jax.random.split(key, 11)
    time_dim = dout * 4
    # The embedding width must equal the resblocks' `timembed_dim`.
    self.time_embed = timeEmbed(time_dim, keys[0])
    self.latent_dim = latent_dim

    self.first_conv = Convlayer(din + latent_dim, dout, 3, keys[1], padding=1)

    self.down1 = resblock(dout, dout * 2, keys[2], time_dim)
    self.down2 = resblock(dout * 2, dout * 4, keys[3], time_dim)
    self.down3 = resblock(dout * 4, dout * 8, keys[4], time_dim)

    self.middle1 = resblock(dout * 8, dout * 8, keys[5], time_dim)
    self.middle2 = resblock(dout * 8, dout * 8, keys[6], time_dim)

    self.up1 = resblock(dout * 8 + dout * 8, dout * 4, keys[7], time_dim)
    self.up2 = resblock(dout * 4 + dout * 4, dout * 2, keys[8], time_dim)
    self.up3 = resblock(dout * 2 + dout * 2, dout, keys[9], time_dim)

    self.out_conv = Convlayer(dout, din, 3, keys[10], padding=1)

  @property
  def params(self):
    return {'time_embed': self.time_embed.params, 'first_conv': self.first_conv.params,
            'down1': self.down1.params, 'down2': self.down2.params, 'down3': self.down3.params,
            'middle1': self.middle1.params, 'middle2': self.middle2.params,
            'up1': self.up1.params, 'up2': self.up2.params, 'up3': self.up3.params,
            'out_conv': self.out_conv.params}

  def __call__(self, params, x, t, z):
    """Predict from noisy NCHW ``x`` at timesteps ``t`` conditioned on latent ``z`` (NCHW)."""
    # Resize only the spatial dims; z keeps its own batch and channel sizes.
    target_shape = (z.shape[0], z.shape[1], x.shape[2], x.shape[3])
    z_up = jax.image.resize(z, target_shape, method='nearest')
    x_con = jnp.concatenate([x, z_up], axis=1)
    t = self.time_embed(params['time_embed'], t)
    x = self.first_conv(params['first_conv'], x_con)

    d1 = self.down1(params['down1'], x, t)
    d2 = self.down2(params['down2'], d1, t)
    d3 = self.down3(params['down3'], d2, t)

    m1 = self.middle1(params['middle1'], d3, t)
    m2 = self.middle2(params['middle2'], m1, t)

    u1_input = jnp.concatenate([m2, d3], axis=1)
    u1 = self.up1(params['up1'], u1_input, t)

    u2_input = jnp.concatenate([u1, d2], axis=1)
    u2 = self.up2(params['up2'], u2_input, t)

    # up3 is built for (dout*2 + dout*2) channels, i.e. u2 concatenated with d1.
    u3_input = jnp.concatenate([u2, d1], axis=1)
    u3 = self.up3(params['up3'], u3_input, t)

    return self.out_conv(params['out_conv'], u3)

In [9]:
class VAE:
  """Fully connected variational autoencoder built from ``MLP`` layers.

  Args:
    din: input feature size (also the decoder's output size).
    dout: unused; kept for interface compatibility.
    encoder_lay: hidden sizes of the encoder stack (may be empty).
    decoder_lay: hidden sizes of the decoder stack (may be empty).
    latent_dim: size of the latent code.
    key: PRNG key for initialisation (also the default reparameterisation key).
  """

  def __init__(self, din, dout, encoder_lay, decoder_lay, latent_dim, key):
    # 1. keys and layer sizes
    enc_key, dec_key, mu_key, logvar_key = jax.random.split(key, 4)
    self.latent_dim = latent_dim
    enc_dim = [din] + list(encoder_lay)
    dec_dim = [latent_dim] + list(decoder_lay) + [din]
    # One independent key per layer. max(..., 1) keeps split() valid when the
    # encoder has no hidden layers; zip then simply yields nothing.
    enc_keys = jax.random.split(enc_key, max(len(enc_dim) - 1, 1))
    dec_keys = jax.random.split(dec_key, len(dec_dim) - 1)
    # 2. encoder: ReLU after every layer, because the mu/logvar heads follow.
    self.encoder_blocks = [MLP(i, o, k) for k, i, o in zip(enc_keys, enc_dim[:-1], enc_dim[1:])]
    self.encoder = sequentialNN(self._interleave_relu(self.encoder_blocks, relu_after_last=True))
    # 3. mu and logvar heads read the last encoder width.
    enc_out = enc_dim[-1]
    self.mu = MLP(enc_out, latent_dim, mu_key)
    self.logvar = MLP(enc_out, latent_dim, logvar_key)
    # 4. decoder: ReLU between layers, linear output layer.
    self.decoder_blocks = [MLP(i, o, k) for k, i, o in zip(dec_keys, dec_dim[:-1], dec_dim[1:])]
    self.decoder = sequentialNN(self._interleave_relu(self.decoder_blocks, relu_after_last=False))
    self.key = key

  @staticmethod
  def _interleave_relu(blocks, relu_after_last):
    """Return ``blocks`` with a ``ReLU()`` after each one (optionally not after the last)."""
    layers = []
    for idx, block in enumerate(blocks):
      layers.append(block)
      if relu_after_last or idx < len(blocks) - 1:
        layers.append(ReLU())
    return layers

  @property
  def params(self):
    return {'encoder': [b.params for b in self.encoder_blocks],
            'decoder': [b.params for b in self.decoder_blocks],
            'mu': self.mu.params,
            'logvar': self.logvar.params}

  def reparameterize(self, mu, logvar, rng=None):
    """Sample ``mu + eps * exp(0.5 * logvar)``; pass a fresh ``rng`` per call (defaults to ``self.key``)."""
    if rng is None:
      rng = self.key
    std = jnp.exp(0.5 * logvar)
    eps = jax.random.normal(rng, std.shape)
    return mu + eps * std

  def __call__(self, params, x, rng=None):
    """Encode ``x``, sample ``z`` and decode it; returns ``(x_pred, mu, logvar)``."""
    encoded = self.encoder(params['encoder'], x)
    mu = self.mu(params['mu'], encoded)
    logvar = self.logvar(params['logvar'], encoded)
    z = self.reparameterize(mu, logvar, rng)
    x_pred = self.decoder(params['decoder'], z)
    return x_pred, mu, logvar

In [10]:
class noiseschedule:
  """Linear beta schedule for the forward diffusion process.

  ``betas`` are evenly spaced from ``beta_start`` to ``beta_end`` over
  ``num_timesteps`` steps. ``alphas = 1 - betas`` and ``alphas_cumprod`` is
  their running product (abar_t).
  """

  def __init__(self, num_timesteps=1000, beta_start=0.001, beta_end=0.02):
    self.beta_start = beta_start
    self.beta_end = beta_end
    self.num_timesteps = num_timesteps
    self.betas = jnp.linspace(beta_start, beta_end, num_timesteps)
    self.alphas = 1 - self.betas
    self.alphas_cumprod = jnp.cumprod(self.alphas)

  @property
  def params(self):
    return {'betas': self.betas, 'alphas': self.alphas}

  def add_noise(self, rng, x, t):
    """Sample ``x_t = sqrt(abar_t) * x + sqrt(1 - abar_t) * eps`` with ``eps ~ N(0, 1)``.

    Args:
      rng: PRNG key for ``eps``.
      x: clean batch; the leading axis is the batch when ``t`` is a vector.
      t: a scalar timestep or one timestep per sample, shape ``(N,)``.
    """
    # Use the cumulative product abar_t (not the single-step alpha_t), shaped
    # (N, 1, ..., 1) so each sample's coefficient broadcasts over its features.
    alpha_bar = self.alphas_cumprod[t]
    alpha_bar = jnp.reshape(alpha_bar, (-1,) + (1,) * (x.ndim - 1))
    noise = jax.random.normal(rng, x.shape)
    return jnp.sqrt(alpha_bar) * x + jnp.sqrt(1 - alpha_bar) * noise

  def __call__(self, rng, x, t):
    return self.add_noise(rng, x, t)



In [11]:
class MLP:
  """Single affine layer computing ``x @ W.T + b``.

  ``W`` has shape ``(dout, din)`` and ``b`` has shape ``(dout,)``. Both are drawn
  uniformly from ``[-bound, bound]`` with ``bound = sqrt(6 / (din + dout))``.
  """

  def __init__(self, din, dout, key):
    w_key, b_key = jax.random.split(key)
    bound = jnp.sqrt(6 / (din + dout))
    self.W = jax.random.uniform(w_key, (dout, din), minval=-bound, maxval=bound)
    self.b = jax.random.uniform(b_key, (dout,), minval=-bound, maxval=bound)

  @property
  def params(self):
    return dict(W=self.W, b=self.b)

  def __call__(self, params, x):
    weight, bias = params['W'], params['b']
    projected = jnp.dot(x, weight.T)
    return projected + bias


In [12]:
class sequentialNN:
    """Chain blocks so each block's output feeds the next.

    A block whose ``params`` attribute is missing or ``None`` is treated as
    parameter-free and called as ``block(x)``. Every other block is called as
    ``block(params[i], x)``, consuming ``params`` in order. ``params`` is a
    snapshot of the parametric blocks' parameters taken at construction.
    """

    def __init__(self, blocks: list):
        self.blocks = blocks
        # Decide once which blocks take parameters; __call__ reuses the same
        # decision, so both agree and plain callables are accepted.
        self._is_parametric = [getattr(b, 'params', None) is not None for b in self.blocks]
        self.parametric_blocks = [b for b, flag in zip(self.blocks, self._is_parametric) if flag]
        self._params = [b.params for b in self.parametric_blocks]

    @property
    def params(self):
        return self._params

    def __call__(self, params, x):
        param_idx = 0
        for block, takes_params in zip(self.blocks, self._is_parametric):
            if takes_params:
                x = block(params[param_idx], x)
                param_idx += 1
            else:
                x = block(x)
        return x

In [13]:
# Activation function

class ReLU:
  """Parameter-free rectified linear activation for use inside ``sequentialNN``.

  ``params`` is ``None``, which marks the block as parameter-free, so it is
  invoked as ``block(x)``.
  """

  @property
  def params(self):
    # No trainable state.
    return None

  def __call__(self, x):
    # Element-wise max(x, 0).
    return jnp.maximum(x, 0)

In [19]:
# Loss functions


class MSE:
  """Mean squared error over all elements of ``y_pred`` and ``y_true``.

  ``forward`` also records the most recent inputs as ``self.y_pred`` / ``self.y_true``.
  """

  def forward(self, y_pred, y_true):
    self.y_pred, self.y_true = y_pred, y_true
    residual = y_true - y_pred
    return jnp.mean(residual ** 2)

  def __call__(self, y_pred, y_true):
    return self.forward(y_pred, y_true)


class nll:
  """Per-sample negative log-likelihood of the target class.

  ``y_pred`` is indexed as ``y_pred[row, y_true[row]]``, so it is expected to
  hold per-class probabilities with one row per sample. The result is a vector
  with one loss per row; it is not reduced.
  """

  def forward(self, y_pred, y_true):
    self.y_pred, self.y_true = y_pred, y_true
    rows = jnp.arange(y_pred.shape[0])
    picked = y_pred[rows, y_true]
    return -jnp.log(picked)

  def __call__(self, y_pred, y_true):
    return self.forward(y_pred, y_true)


class vaeloss:
  """VAE objective: mean squared reconstruction error plus a weighted KL term.

  The KL term is ``-0.5 * sum(1 + logvar - mu**2 - exp(logvar))``, summed over
  every element of ``mu`` / ``logvar``.

  Args:
    kl_weight: multiplier on the KL term (1.0 gives the plain objective).
  """

  def __init__(self, kl_weight=1.0):
    self.kl_weight = kl_weight

  def forward(self, x_pred, x_true, mu, logvar):
    # Same quantity as MSE()(x_pred, x_true), computed inline. The original
    # `MSE(x_pred, x_true)` tried to construct MSE with arguments and raised.
    recon_loss = jnp.mean((x_true - x_pred) ** 2)
    kl_loss = -0.5 * jnp.sum(1 + logvar - mu ** 2 - jnp.exp(logvar))
    return recon_loss + self.kl_weight * kl_loss

  def __call__(self, x_pred, x_true, mu, logvar):
    # forward needs the latent statistics too; the old 2-argument call always failed.
    return self.forward(x_pred, x_true, mu, logvar)


class DGAEloss:
  """DGAE objective: noise-prediction MSE + weighted KL + weighted LPIPS.

  Args:
    kl_weight: multiplier on ``-0.5 * sum(1 + logvar - mu**2 - exp(logvar))``.
    lpips_weight: multiplier on the LPIPS distance between ``y_pred`` and
      ``y_true``. When it is 0, the LPIPS network is neither loaded nor evaluated.
  """

  def __init__(self, kl_weight=0.001, lpips_weight=0.1):
    self.kl_weight = kl_weight
    self.lpips_weight = lpips_weight
    # Only load the pretrained perceptual network when its term contributes.
    self.LPIPS = lpips_jax.LPIPSEvaluator(net='vgg') if lpips_weight else None

  def forward(self, pred_noise, true_noise, y_pred, y_true, mu, logvar):
    # Same quantity as MSE()(pred_noise, true_noise). The original
    # `MSE(pred, true)` tried to construct MSE with arguments and raised.
    dsmloss = jnp.mean((true_noise - pred_noise) ** 2)
    kl_loss = -0.5 * jnp.sum(1 + logvar - mu ** 2 - jnp.exp(logvar))
    loss = dsmloss + self.kl_weight * kl_loss
    if self.LPIPS is not None:
      # NOTE(review): lpips_jax's evaluator may expect NHWC, 3-channel images in
      # [-1, 1] and a net name such as 'vgg16'. Here it receives the model's
      # tensors unchanged; confirm against the lpips_jax docs. The mean reduces
      # a per-image distance to the scalar that value_and_grad needs.
      lpips_loss = jnp.mean(self.LPIPS(y_pred, y_true))
      loss = loss + self.lpips_weight * lpips_loss
    return loss

  def __call__(self, pred_noise, true_noise, y_pred, y_true, mu, logvar):
    return self.forward(pred_noise, true_noise, y_pred, y_true, mu, logvar)


In [3]:
#Optimiser

class optimiser:
  """Plain gradient descent: ``p <- p - lr * g`` for every leaf of a parameter pytree.

  Args:
    model: stored for reference only; ``step`` does not use it.
    lr: learning rate.
  """

  def __init__(self, model, lr=0.01):
    self.model = model
    self.lr = lr

  def step(self, params, grads):
    """Return updated parameters; ``grads`` must have the same pytree structure as ``params``."""
    lr = self.lr
    # jax.tree_util.tree_map replaces the deprecated (and later removed) jax.tree_map alias.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

In [11]:
 # Training Functions

def train(model, data, optimiser, loss_function, epochs=10, batch_size=64):
  """Minibatch gradient-descent loop for models called as ``model(params, X)``.

  Args:
    model: object exposing ``params`` and ``__call__(params, x)``.
    data: mapping with inputs under ``'X'`` and targets under ``'y'``.
    optimiser: object with ``step(params, grads) -> new_params``.
    loss_function: ``loss_function(pred, y)``; must return a scalar for
      ``jax.value_and_grad``.
    epochs: number of passes over the data. Each epoch is shuffled with
      ``PRNGKey(epoch)``.
    batch_size: minibatch size; the last batch of an epoch may be smaller.

  Returns:
    A 1-tuple ``(trainingloss,)``, where ``trainingloss`` lists the mean batch
    loss of each epoch. The tuple shape is kept for existing callers.
  """
  X = data['X']
  y = data['y']
  trainingloss = []
  current_params = model.params

  def loss_fn(params, batchX, batchY):
    pred = model(params, batchX)
    return loss_function(pred, batchY)

  # Differentiates w.r.t. the first argument (params) only.
  loss_grad_fun = jax.jit(jax.value_and_grad(loss_fn))

  for epoch in tqdm(range(epochs)):
    batchloss = []
    indices = jax.random.permutation(jax.random.PRNGKey(epoch), X.shape[0])
    X_shuffled = X[indices]
    y_shuffled = y[indices]

    for i in range(0, X.shape[0], batch_size):
      batchX = X_shuffled[i:i + batch_size]
      batchY = y_shuffled[i:i + batch_size]
      loss, grad = loss_grad_fun(current_params, batchX, batchY)
      batchloss.append(float(loss))
      # Exactly one update per batch. The parameters are carried locally rather
      # than re-read from `model.params`, which most models rebuild from their
      # layers and so would ignore the update.
      current_params = optimiser.step(current_params, grad)
      # Kept for models whose `params` property returns `_params` (e.g. sequentialNN).
      model._params = current_params

    trainingloss.append(float(np.mean(batchloss)))
  return (trainingloss,)





def train_dgae(model, data, optimiser, loss_function, epochs=10, batch_size=64, seed=0):
  """Train a DGAE-style model with the denoising objective.

  Each step:
    1. Draws per-sample timesteps and Gaussian noise, and forms
       ``x_t = sqrt(abar_t) x + sqrt(1 - abar_t) eps`` with
       ``abar = cumprod(model.noise_schedule.alphas)``.
    2. Encodes the clean batch to ``(mu, logvar)`` and samples ``z``.
    3. Predicts the noise with ``model.decoder(params['decoder'], x_t, t, z)``.
    4. Reconstructs ``x0`` from the prediction.
    5. Calls ``loss_function(pred_noise, noise, recon_image, batchX, mu, logvar)``.

  Args:
    model: exposes ``params``, ``encoder``, ``decoder`` and ``noise_schedule``.
    data: mapping with the image batch array under ``'X'`` (NCHW).
    optimiser: object with ``step(params, grads) -> new_params``.
    loss_function: callable as above, returning a scalar.
    epochs, batch_size: loop sizes; the last batch of an epoch may be smaller.
    seed: seed for all randomness (shuffling, timesteps, noise, latent samples).

  Returns:
    ``(trainingloss, current_params)``: the mean batch loss per epoch (floats)
    and the final parameters.
  """
  X = data['X']
  trainingloss = []
  current_params = model.params
  rng = jax.random.PRNGKey(seed)

  alphas_cumprod = jnp.cumprod(model.noise_schedule.alphas)
  num_timesteps = model.noise_schedule.num_timesteps

  def loss_fn(params, batchX, rng):
    t_rng, noise_rng, reparam_rng = jax.random.split(rng, 3)
    t = jax.random.randint(t_rng, (batchX.shape[0],), 0, num_timesteps)

    sqrt_alphas_cumprod_t = jnp.sqrt(alphas_cumprod[t]).reshape(-1, 1, 1, 1)
    sqrt_one_minus_alphas_cumprod_t = jnp.sqrt(1 - alphas_cumprod[t]).reshape(-1, 1, 1, 1)

    noise = jax.random.normal(noise_rng, batchX.shape)
    x_noisy = sqrt_alphas_cumprod_t * batchX + sqrt_one_minus_alphas_cumprod_t * noise

    # Encode the clean image, as DGAE.__call__ does.
    mu, logvar = model.encoder(params['encoder'], batchX)
    # Reparameterise with this step's key (model.reparameterize would reuse a fixed key).
    z = mu + jnp.exp(0.5 * logvar) * jax.random.normal(reparam_rng, mu.shape)

    pred_noise = model.decoder(params['decoder'], x_noisy, t, z)

    # Invert the forward process to estimate x0 from the predicted noise.
    recon_image = (x_noisy - sqrt_one_minus_alphas_cumprod_t * pred_noise) / sqrt_alphas_cumprod_t

    return loss_function(pred_noise, noise, recon_image, batchX, mu, logvar)

  grad_fn = jax.jit(jax.value_and_grad(loss_fn))

  for epoch in tqdm(range(epochs)):
    rng, perm_rng = jax.random.split(rng)
    indices = jax.random.permutation(perm_rng, X.shape[0])
    X_shuffled = X[indices]
    batchloss = []

    for i in range(0, X.shape[0], batch_size):
      # Fresh key per batch so timesteps and noise differ between batches.
      rng, step_rng = jax.random.split(rng)
      batchX = X_shuffled[i:i + batch_size]
      loss, grad = grad_fn(current_params, batchX, step_rng)
      batchloss.append(float(loss))
      current_params = optimiser.step(current_params, grad)

    trainingloss.append(float(np.mean(batchloss)))
  return trainingloss, current_params

Array(1.759764, dtype=float32)