<a href="https://colab.research.google.com/github/shkim0824/ksh-jax/blob/main/Flow_Matching_Ensemble.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title Import package
# jax/flax/optax import
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random, lax
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn
from flax.training import train_state
import optax

# torch import
import torch
import torch.nn.functional as F
from torch import optim
import torchvision
import torchvision.transforms as transforms

# Other Machine Learning Libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import sklearn
from sklearn import preprocessing
from sklearn.model_selection import train_test_split

# Python extension libraires
import math
import time
from functools import partial
import pickle
from typing import Any, Callable, Sequence, List
from collections import defaultdict

In [2]:
#@title Data Augmentations
class TransformChain(object):
    def __init__(self, transforms):
        """
        Apply transform sequentially
        """
        self.transforms = transforms

    def __call__(self, rng, image):
        for _t in self.transforms:
            image = _t(rng, image)
        return image

class ToTensorTransform(object):
    def __init__(self):
        """
        Make image values between 0. ~ 1.
        """

    def __call__(self, rng, image):
        return image / 255.

class RandomHFlipTransform(object):
    def __init__(self, prob=0.5):
        """
        Flip the image horizontally with the given probability

        Inputs:
            prob (float): probability of the flip
        """
        self.prob=prob

    def __call__(self, rng, image):
        return jnp.where(
            jax.random.bernoulli(rng, self.prob), # 1 for prob=self.prob
            jnp.flip(image, axis=1), # Flip image
            image, # Or not
        )

class RandomCropTransform(object):

    def __init__(self, size, padding):
        """
        Crop the image at a random location with given size and padding.
        Inputs:
            size (int): desired output size of the crop.
            padding (int): padding on each border of the image before cropping.
        """
        self.size = size
        self.padding = padding

    def __call__(self, rng, image):
        # Add padding
        image = jnp.pad(
            array           = image,
            pad_width       = ((self.padding, self.padding),
                               (self.padding, self.padding),
                               (           0,            0),), # No pad for RGB channel
            mode            = 'constant',
            constant_values = 0,
        )

        # Random cropping position
        rng1, rng2 = jax.random.split(rng, 2)
        h0 = jax.random.randint(rng1, shape=(1,), minval=0, maxval=2*self.padding+1)[0] # output of randint is [x]. We get item of x
        w0 = jax.random.randint(rng2, shape=(1,), minval=0, maxval=2*self.padding+1)[0]

        # Slice image
        image = jax.lax.dynamic_slice(
            operand       = image,
            start_indices = (h0, w0, 0), # We do not crop rgb channel
            slice_sizes   = (self.size, self.size, image.shape[2]), # We do not crop rgb channel
        )

        return image

In [3]:
# Cross Entropy loss
def evaluate_ce(softmax, one_hot):
    return jnp.mean(-jnp.sum(one_hot * jnp.log(softmax+1e-12), axis=-1))

# Test Accuracy
def evaluate_acc(logits, labels):
    return jnp.mean(jnp.argmax(logits, axis=1) == labels)

In [4]:
#@title Load Data
def buildLoader(images, labels, batch_size, steps_per_epoch, rng=None, shuffle=False, transform=None):
    # Shuffle Indices
    indices = jax.random.permutation(rng, len(images)) if shuffle else jnp.arange(len(images)) # Make shuffled indices
    indices = indices[:steps_per_epoch*batch_size] # Batch size may not be divisor of length of images. We drop left ones.
    indices = indices.reshape((steps_per_epoch, batch_size,))
    for batch_idx in indices:
        batch = {'images': jnp.array(images[batch_idx]), 'labels': jnp.array(labels[batch_idx])}
        if transform is not None:
            if rng is not None:
                _, rng = jax.random.split(rng)
            sub_rng = None if rng is None else jax.random.split(rng, batch['images'].shape[0])
            batch['images'] = transform(sub_rng, batch['images'])
        yield batch

# Hyper parameters
BATCH_SIZE = 128
TEST_SIZE = 10000
VAL_SIZE = 128
n_targets = 100
rng = random.PRNGKey(0)

train_set = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=None)
test_set = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=None)

# Make validation set
train_images, val_images = train_test_split(train_set.data, test_size=1000, random_state=42)
train_labels, val_labels = train_test_split(train_set.targets, test_size=1000, random_state=42) # It very boost training speed

# We use img converted to jnp dtype
train_images = np.array(train_images, dtype=jnp.float32)
val_images = np.array(val_images, dtype=jnp.float32)
test_images = np.array(test_set.data, dtype=jnp.float32)

# Make labels to numpy array
train_labels = np.array(train_labels)
val_labels = np.array(val_labels)
test_labels = np.array(test_set.targets)

# transform for CIFAR-10
transform = TransformChain([RandomHFlipTransform(0.5),
                            RandomCropTransform(size=32, padding=4),
                            ToTensorTransform()])

# Naive data loader. We should put rng later in real-time usage
trn_steps_per_epoch = len(train_images) // BATCH_SIZE
tst_steps_per_epoch = len(test_images) // TEST_SIZE
val_steps_per_epoch = len(val_images) // VAL_SIZE

trn_loader = partial(buildLoader,
                    images=train_images,
                    labels=train_labels,
                    batch_size=BATCH_SIZE,
                    steps_per_epoch=trn_steps_per_epoch,
                    shuffle=True,
                    transform=jit(vmap(ToTensorTransform())))

trn_loader_aug = partial(buildLoader,
                        images=train_images,
                        labels=train_labels,
                        batch_size=BATCH_SIZE,
                        steps_per_epoch=trn_steps_per_epoch,
                         shuffle=True,
                         transform=jit(vmap(transform)))

vl_loader = partial(buildLoader,
                    images=val_images,
                    labels=val_labels,
                    batch_size=VAL_SIZE,
                    steps_per_epoch=val_steps_per_epoch,
                    shuffle=True,
                    transform=jit(vmap(ToTensorTransform())))

tst_loader = partial(buildLoader,
                    images=test_images,
                    labels=test_labels,
                    batch_size=TEST_SIZE,
                    steps_per_epoch=tst_steps_per_epoch,
                    shuffle=False,
                    transform=jit(vmap(ToTensorTransform())))

In [5]:
#@title Denoising MLP
from einops import rearrange

def timestep_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = jnp.exp(
        -math.log(max_period) * jnp.arange(0, half, dtype=jnp.float32) / half
    )
    args = timesteps[..., None] * freqs[None]
    embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
    if len(timesteps.shape) == 2:
        embedding = rearrange(embedding, "b n d -> b (n d)")
    if dim % 2:
        embedding = jnp.concatenate(
            [embedding, jnp.zeros_like(embedding[:, :1])], axis=-1)
    return embedding

class DenoisingMLP(nn.Module):
    """3개의 ResidualBlock과 폭(채널 수) 1024를 가진 단순 MLP 예시."""
    hidden_dim: int = 1024
    num_blocks: int = 3
    num_classes: int = 100
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x, t, c):
        """
        x: 입력 노이즈(혹은 중간 상태) [batch, features]
        t: time step (예: 스칼라 혹은 [batch])
        z: AR/MAR에서 나오는 condition vector [batch, cond_dim]
        """

        x = jnp.concatenate([x, c], axis=-1)

        for i in range(self.num_blocks):
            x_skip = x

            t = jnp.log((1 - t) + 1e-12) / 4

            # (1) Adaptive Layer Normalization utilizing time-embedding
            t_emb = timestep_embedding(t, dim=64)
            t_ = nn.Dense(2 * x.shape[-1], kernel_init=nn.initializers.constant(0.))(t)
            t_ = nn.silu(t_)
            shift_t, scale_t = jnp.split(t_, 2, axis=-1)

            x = nn.LayerNorm(use_bias=False, use_scale=False)(x)
            x = x * (1 + scale_t) + shift_t

            # (2) Residual Block
            x = nn.Dense(
                    features=self.hidden_dim,
                    dtype=self.dtype,
                    kernel_init=nn.initializers.xavier_uniform(),
                    bias_init=nn.initializers.normal(stddev=1e-6)
                )(x)
            x = nn.silu(x)
            x = nn.Dense(
                    features=self.hidden_dim,
                    dtype=self.dtype,
                    kernel_init=nn.initializers.xavier_uniform(),
                    bias_init=nn.initializers.normal(stddev=1e-6)
                )(x)
            x = x_skip + x if i > 0 else x

            # x = nn.Dropout(rate=self.droprate)(x, deterministic=not kwargs["training"])

        # (3) Final Logit
        x = nn.Dense(
                features=self.num_classes,
                dtype=self.dtype,
                kernel_init=nn.initializers.xavier_uniform(),
                bias_init=nn.initializers.normal(stddev=1e-6)
            )(x)

        return x

In [6]:
#@title ResNet

class IdentityShortcut(nn.Module):
    channels: int
    strides: int
    expansion: int = 1

    @nn.compact
    def __call__(self, x):
        pad_offset = self.expansion * self.channels - x.shape[-1] # x.shape[-1] is original channel dimension
        return jnp.pad(
            array           = x[:, ::self.strides, ::self.strides, :], # get reduced shape of x
            pad_width       = ((0,0), (0,0), (0,0), (0, pad_offset)), # Add zero padding to channel
            mode            = 'constant',
            constant_values  = 0, # zero padding
        )

class ResidualBlock(nn.Module):
    channels: int
    strides: int = 1
    shortcut: nn.Module = IdentityShortcut # identity or projection

    @nn.compact
    def __call__(self, x, training=False):
        y = nn.Conv(features = self.channels, kernel_size=(3,3), strides=self.strides)(x)
        y = nn.BatchNorm(use_running_average=not training)(y)
        y = nn.relu(y)
        y = nn.Conv(features = self.channels, kernel_size=(3,3))(y)
        y = nn.BatchNorm(use_running_average=not training)(y)

        if self.strides != 1 or x.shape[-1] != self.channels * 1: # We have to modify x to match shape
            y = y + self.shortcut(channels  = self.channels,
                                  strides    = self.strides,
                                  expansion = 1)(x)
        else:
            y = y + x

        y = nn.relu(y)

        return y

class ResNet32(nn.Module):
    """ ResNet-32 Structure for CIFAR-10
    Starting Block: 3 by 3 conv, 16
    Residual Block: 16 channel * n - 32 channel * n - 64 channel * n, n=5 for ResNet-32
    """

    @nn.compact
    def __call__(self, x, training=False):
        # Starting Block: 3 by 3 conv
        x = nn.Conv(features = 16, kernel_size=(3,3))(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)

        # Intermediate Residual Blocks
        for i in range(5):
            x = ResidualBlock(channels=16)(x, training)
            # jax.debug.print("{}", x.shape)

        for channel in [32, 64]:
            x = ResidualBlock(channels=channel, strides=2)(x, training)
            for i in range(4):
                x = ResidualBlock(channels=channel)(x, training)

        # Global Average Pooling
        x = jnp.mean(x, [1, 2])
        representation = x

        # Classifier
        x = nn.Dense(features=100)(x) # Classifier, return logits

        return x, representation

In [7]:
class FlowTrainState(train_state.TrainState):
    pass

def create_dMLP_state(key, feature_dim=1024, cond_dim=64, lr=1e-4):
    model = DenoisingMLP(
        hidden_dim=feature_dim,
        num_blocks=3,
    )

    dummy_x = jnp.zeros((1, 100), dtype=jnp.float32)
    dummy_t = jnp.zeros((1, 1), dtype=jnp.float32)
    dummy_z = jnp.zeros((1, cond_dim), dtype=jnp.float32)

    variables = model.init(key, dummy_x, dummy_t, dummy_z)
    params = variables["params"]

    # (3) 옵티마 + 스케줄 설정
    warmup_epochs = 5
    epochs = 50
    trn_steps_per_epoch = len(train_images) // BATCH_SIZE  # 사용자 정의
    scheduler = optax.join_schedules(
        schedules = [
            # warmup
            optax.linear_schedule(
                init_value       = 0.01 * lr,
                end_value        = lr,
                transition_steps = warmup_epochs * trn_steps_per_epoch
            ),
            # cosine decay
            optax.cosine_decay_schedule(
                init_value       = lr,
                decay_steps      = (epochs - warmup_epochs) * trn_steps_per_epoch
            )
        ],
        boundaries = [warmup_epochs * trn_steps_per_epoch]
    )
    tx = optax.adam(scheduler)

    # (4) TrainState 생성
    state = FlowTrainState.create(
        apply_fn = model.apply,
        params   = params,
        tx       = tx
    )
    return state

def create_res32(rng, params, batch_stats, lr_fn=0.1):
    model = ResNet32()

    class TrainState(train_state.TrainState): # We don't use dropout in ResNet
        batch_stats: Any

    state = TrainState.create(
        apply_fn = model.apply,
        params = params,
        batch_stats = batch_stats, # batch state
        tx = optax.sgd(learning_rate=lr_fn, momentum=0.9, nesterov=True) # learning_rate_fn automatically determine learning rate
    )
    return state

In [8]:
#@title loss, train_step
def flow_matching_loss(params, apply_fn, rng, teacher_logits, teacher_representations):
    """
    teacher_logits: (B, K)
    """
    B, K = teacher_logits.shape

    # 노이즈 x0 ~ N(0,1)
    rng, rng_x0, rng_t = jax.random.split(rng, 3)
    x0 = jax.random.normal(rng_x0, shape=(B, K))

    # t ~ Uniform(0,1)
    t = jax.random.uniform(rng_t, shape=(B, 1))

    # OT 경로: y_t = (1 - t)*x0 + t*x1
    y_t = (1.0 - t)*x0 + t*teacher_logits

    # 목표 벡터: x1 - x0
    target_vec = teacher_logits - x0

    # 모델 출력 v_theta(t, y_t)
    pred_vec = apply_fn({'params': params}, y_t, t, teacher_representations)

    # MSE
    loss = jnp.mean((pred_vec - target_vec)**2)
    return loss, (loss, rng)

@jax.jit
def train_step(teacher, state, rng, images):
    # (1) Teacher Logit, Representation
    logit_t, rep_t = teacher.apply_fn(
        {'params': teacher.params, 'batch_stats': teacher.batch_stats}, batch['images'], training=False)

    # (2) Flow Matching Loss
    def loss_fn(p):
        return flow_matching_loss(p, state.apply_fn, rng, logit_t, rep_t)
    (loss_val, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)

    new_state = state.apply_gradients(grads=grads)
    new_rng = aux[1]
    metrics = {'loss': loss_val}
    return new_state, new_rng, metrics

@jax.jit
def test_step(teacher, state, rng, batch, steps=100):
    # (1) Teacher Logit, Representation
    logit_t, rep_t = teacher.apply_fn(
        {'params': teacher.params, 'batch_stats': teacher.batch_stats}, batch['images'], training=False)

    # (2) Flow matching Sampling
    B = batch['images'].shape[0]

    rng, key = jax.random.split(rng)
    x = jax.random.normal(rng, shape=(B, 100))

    ts = jnp.linspace(0.0, 1.0, steps+1)

    for i in range(steps):
        t0, t1 = ts[i], ts[i+1]
        v = state.apply_fn({'params': state.params}, x, jnp.full((B, 1), t0), rep_t)
        dt = (t1 - t0)
        x = x + v * dt      # Euler Method

    logits = x

    one_hot = jax.nn.one_hot(batch['labels'], n_targets)
    softmax = jax.nn.softmax(logits)

    metrics = {
        'acc': evaluate_acc(logits, batch['labels']),
        'loss': evaluate_ce(softmax, one_hot)
    }
    return rng, metrics


In [9]:
teachers = []

# 파라미터와 배치 정규화 상태 불러오기
for i in range(5):
    with open(f"ResNet32 CIFAR100 params_{i}.pickle", "rb") as fr:
        params = pickle.load(fr)
    with open(f"ResNet32 CIFAR100 batch_stats_{i}.pickle", "rb") as fr:
        batch_stats = pickle.load(fr)
    teachers.append(create_res32(jax.random.PRNGKey(i), params, batch_stats))

In [10]:
state = create_dMLP_state(rng)

In [11]:
metrics_history = defaultdict(list)
for epoch in range(50):
    start_time = time.time()

    # Make data laoder
    rng, *keys = random.split(rng, 4)
    train_loader = trn_loader(rng=keys[0])
    test_loader = tst_loader(rng=keys[1])
    val_loader = vl_loader(rng=keys[2])

    for batch in train_loader:
        # teacher selection
        rng, key = jax.random.split(rng)
        t_idx = jax.random.randint(key, shape=(), minval=0, maxval=len(teachers))
        teacher = teachers[t_idx]

        # train
        state, rng, metrics = train_step(teacher, state, rng, batch['images']) # Use key for training
    for metric, value in metrics.items(): # compute metrics
        metrics_history[f'train_{metric}'].append(value) # record metrics

    tst_metrics = defaultdict(float)
    for batch in test_loader:
        # teacher selection
        rng, key = jax.random.split(rng)
        t_idx = jax.random.randint(key, shape=(), minval=0, maxval=len(teachers))
        teacher = teachers[t_idx]

        rng, metrics = test_step(teacher, state, rng, batch) # In test, we do not need key
        for metric, value in metrics.items(): # compute metrics
            tst_metrics[f'test_{metric}'] += value
    for metric, value in tst_metrics.items():
        metrics_history[metric].append(tst_metrics[metric] / tst_steps_per_epoch)

    epoch_time = time.time() - start_time
    print(f"Epoch {epoch+1} in {epoch_time:.2f} sec")
    print(f"Train loss: {metrics_history['train_loss'][-1]:.3f}\n Test acc: {metrics_history['test_acc'][-1]:.3f}, Test loss: {metrics_history['test_loss'][-1]:.3f}, ")

Epoch 1 in 72.23 sec
Train loss: 5.508
 Test acc: 0.056, Test loss: 6.254, 
Epoch 2 in 1.20 sec
Train loss: 4.140
 Test acc: 0.123, Test loss: 5.514, 
Epoch 3 in 33.38 sec
Train loss: 2.767
 Test acc: 0.271, Test loss: 4.021, 
Epoch 4 in 33.54 sec
Train loss: 1.791
 Test acc: 0.281, Test loss: 4.142, 
Epoch 5 in 33.36 sec
Train loss: 1.458
 Test acc: 0.370, Test loss: 3.125, 
Epoch 6 in 1.27 sec
Train loss: 1.009
 Test acc: 0.439, Test loss: 2.613, 
Epoch 7 in 34.26 sec
Train loss: 0.887
 Test acc: 0.396, Test loss: 2.799, 
Epoch 8 in 1.28 sec
Train loss: 0.706
 Test acc: 0.399, Test loss: 2.708, 
Epoch 9 in 1.30 sec
Train loss: 0.669
 Test acc: 0.376, Test loss: 3.022, 
Epoch 10 in 1.14 sec
Train loss: 0.640
 Test acc: 0.463, Test loss: 2.339, 
Epoch 11 in 1.17 sec
Train loss: 0.598
 Test acc: 0.467, Test loss: 2.297, 
Epoch 12 in 1.20 sec
Train loss: 0.618
 Test acc: 0.383, Test loss: 2.933, 
Epoch 13 in 1.26 sec
Train loss: 0.594
 Test acc: 0.381, Test loss: 2.887, 
Epoch 14 in 1.34