In [1]:
pip install flax einops jax jaxlib optax rarfile

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 23.0.1 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [5]:
import os
import warnings

import einops
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training import train_state
from tensorflow.keras.preprocessing import image_dataset_from_directory

In [6]:
class MlpBlock(nn.Module):
    """Two dense layers with a GELU in between, applied along the last axis.

    Attributes:
        mlp_dim: width of the hidden (first) dense layer.

    The second dense layer projects back to the size of the input's last
    axis, so the output has the same shape as the input.
    """
    mlp_dim: int

    @nn.compact
    def __call__(self, x):
        out_features = x.shape[-1]
        # Submodules are created in the same order as before (expand, then
        # project) so Flax's automatic parameter names stay unchanged.
        hidden = nn.gelu(nn.Dense(features=self.mlp_dim)(x))
        return nn.Dense(features=out_features)(hidden)

In [7]:
class MixerBlock(nn.Module):
    """One mixer layer: an MLP across axis 1, then an MLP across the last axis.

    Attributes:
        tokens_mlp_dim: hidden width of the MLP applied after swapping axes 1 and 2.
        channels_mlp_dim: hidden width of the MLP applied along the last axis.

    Both sub-layers are wrapped in a residual connection and preceded by a
    LayerNorm. Within MlpMixer the input is laid out as (n, h*w, c).
    """
    tokens_mlp_dim: int
    channels_mlp_dim: int

    @nn.compact
    def __call__(self, x):
        # First residual branch: normalise, swap axes 1 and 2 so the MLP acts
        # along the former axis 1, then swap back.
        mixed = nn.LayerNorm()(x).swapaxes(1, 2)
        mixed = MlpBlock(self.tokens_mlp_dim)(mixed).swapaxes(1, 2)
        x = x + mixed

        # Second residual branch: normalise and apply the MLP on the last axis.
        normed = nn.LayerNorm()(x)
        mixed = MlpBlock(self.channels_mlp_dim)(normed)
        return x + mixed

In [8]:
class MlpMixer(nn.Module):
    """MLP-Mixer classifier: patch embedding, mixer blocks, pooled linear head.

    Attributes:
        num_classes: number of output logits.
        num_blocks: how many MixerBlock layers are stacked.
        patch_size: side length of the square patches; used as both the
            kernel size and the stride of the embedding convolution.
        hidden_dim: number of channels produced by the patch embedding.
        tokens_mlp_dim: hidden width passed to every MixerBlock's first MLP.
        channels_mlp_dim: hidden width passed to every MixerBlock's second MLP.
    """
    num_classes: int
    num_blocks: int
    patch_size: int
    hidden_dim: int
    tokens_mlp_dim: int
    channels_mlp_dim: int

    @nn.compact
    def __call__(self, x):
        patch = (self.patch_size, self.patch_size)

        # Kernel == stride, so every patch is embedded exactly once.
        tokens = nn.Conv(self.hidden_dim, patch, strides=patch)(x)
        # Flatten the spatial grid into a single axis.
        tokens = einops.rearrange(tokens, 'n h w c -> n (h w) c')

        block_idx = 0
        while block_idx < self.num_blocks:
            tokens = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(tokens)
            block_idx += 1

        # Normalise, average over axis 1, and classify.
        pooled = nn.LayerNorm()(tokens).mean(axis=1)
        return nn.Dense(self.num_classes)(pooled)

In [9]:
def load_data(batch_size=32, img_size=(224, 224), dataset_path="fold1_separated_2_classes"):
    """Load the grayscale train and test image datasets from a directory tree.

    Args:
        batch_size: number of images per batch.
        img_size: (height, width) every image is resized to.
        dataset_path: root folder that must contain a ``train`` and a ``test``
            sub-directory. Defaults to the folder the notebook was written for.

    Returns:
        A ``(train_ds, test_ds)`` tuple of datasets yielding
        ``(images, integer_labels)`` batches.

    Raises:
        FileNotFoundError: if the ``train`` or ``test`` sub-directory is
            missing (or is not a directory).
    """
    split_dirs = {
        "Train": os.path.join(dataset_path, "train"),
        "Test": os.path.join(dataset_path, "test"),
    }

    # Validate both splits before loading anything, so a missing test folder
    # is reported before time is spent indexing the training images.
    # isdir (rather than exists) also rejects a regular file with that name.
    for split_name, split_dir in split_dirs.items():
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(f"{split_name} directory not found: {split_dir}")

    # Same loader settings for both splits; train is loaded first.
    train_ds, test_ds = [
        image_dataset_from_directory(
            split_dir,
            label_mode='int',
            image_size=img_size,
            color_mode='grayscale',
            batch_size=batch_size,
        )
        for split_dir in split_dirs.values()
    ]
    return train_ds, test_ds

In [15]:

class TrainState(train_state.TrainState):
    """Flax train state extended with an optional ``batch_stats`` field."""

    # Defaults to None; nothing visible in this notebook reads or sets it.
    batch_stats: dict = None

In [11]:
@jax.jit
def train_step(state, batch):
    """Run one optimisation step and report the batch loss and accuracy.

    The function is compiled with ``jax.jit`` so the forward/backward pass is
    traced once per input shape instead of being re-executed op by op for
    every batch. ``state`` therefore has to be a JAX pytree, which Flax's
    ``TrainState`` is.

    Args:
        state: train state providing ``apply_fn``, ``params`` and
            ``apply_gradients``.
        batch: dict with ``'image'`` (model inputs) and ``'label'``
            (integer class ids).

    Returns:
        ``(new_state, loss, accuracy)`` where ``loss`` is the mean softmax
        cross-entropy over the batch and ``accuracy`` the fraction of
        arg-max predictions equal to the labels.
    """
    images = batch['image']
    labels = batch['label']

    def loss_fn(params):
        logits = state.apply_fn({'params': params}, images)
        # Size the one-hot targets from the model's output instead of a
        # hard-coded 2, so the step works for any number of classes.
        targets = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
        loss = optax.softmax_cross_entropy(logits, targets).mean()
        return loss, logits

    # has_aux=True: loss_fn returns (loss, logits); only loss is differentiated.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    accuracy = (jnp.argmax(logits, -1) == labels).mean()
    return state, loss, accuracy




In [12]:
def _to_jax_batch(images, labels):
    """Turn one (images, labels) pair into the dict train_step expects, scaling pixels by 1/255."""
    return {'image': jnp.array(images / 255.0), 'label': jnp.array(labels)}


def _evaluate(state, dataset):
    """Return the fraction of correctly classified examples in dataset, or None if it is empty."""
    correct = 0
    seen = 0
    for images, labels in dataset:
        batch = _to_jax_batch(images, labels)
        logits = state.apply_fn({'params': state.params}, batch['image'])
        correct += int((jnp.argmax(logits, -1) == batch['label']).sum())
        seen += int(batch['label'].shape[0])
    return correct / seen if seen else None


def train_model(state, train_ds, test_ds, num_epochs=10):
    """Train for ``num_epochs`` passes over ``train_ds`` and return the final state.

    Args:
        state: initial train state, passed to ``train_step`` for every batch.
        train_ds: iterable yielding ``(images, labels)`` training batches.
        test_ds: iterable yielding ``(images, labels)`` evaluation batches,
            or None to skip evaluation.
        num_epochs: number of passes over ``train_ds``.

    Returns:
        The train state after the last update. The state is immutable, so the
        caller must use this return value to get the trained parameters.

    Raises:
        ValueError: if ``train_ds`` yields no examples.
    """
    for epoch in range(num_epochs):
        loss_sum = 0.0
        accuracy_sum = 0.0
        seen = 0

        for images, labels in train_ds:
            batch = _to_jax_batch(images, labels)
            state, loss, accuracy = train_step(state, batch)
            # Weight by batch size so a smaller final batch does not skew the
            # epoch averages.
            batch_len = int(batch['label'].shape[0])
            loss_sum += loss * batch_len
            accuracy_sum += accuracy * batch_len
            seen += batch_len

        if seen == 0:
            raise ValueError("train_ds yielded no examples")

        # Report the mean over the whole epoch rather than the last batch only.
        epoch_loss = float(loss_sum / seen)
        epoch_accuracy = float(accuracy_sum / seen)
        print(f"Epoch {epoch + 1}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy * 100:.2f}%")

        if test_ds is not None:
            test_accuracy = _evaluate(state, test_ds)
            if test_accuracy is not None:
                print(f"Epoch {epoch + 1}, Test Accuracy: {test_accuracy * 100:.2f}%")

    return state

In [13]:
# Build the model, load the data and initialise the parameters.
num_classes = 2
model = MlpMixer(
    num_classes=num_classes,
    num_blocks=8,
    patch_size=4,
    hidden_dim=128,
    tokens_mlp_dim=256,
    channels_mlp_dim=512,
)

train_ds, test_ds = load_data()

# The output layer has `num_classes` logits, so label ids outside
# [0, num_classes) cannot be predicted. Make a mismatch with the number of
# class folders found on disk visible instead of letting it pass silently.
dataset_class_names = getattr(train_ds, "class_names", None)
if dataset_class_names is not None and len(dataset_class_names) != num_classes:
    warnings.warn(
        f"Dataset contains {len(dataset_class_names)} classes {list(dataset_class_names)} "
        f"but the model is built for num_classes={num_classes}."
    )

rng = jax.random.PRNGKey(0)
# One dummy image with the loader's default size and a single grayscale
# channel; only its shape matters for parameter initialisation.
dummy_input = jnp.ones((1, 224, 224, 1))
params = model.init(rng, dummy_input)['params']





Found 5005 files belonging to 4 classes.
Found 2904 files belonging to 4 classes.


In [14]:
# Bundle the model's apply function, the initial parameters and an Adam
# optimiser (learning rate 1e-3) into the train state.
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(1e-3),
)

# Bind the return value instead of discarding it, so the trained state can be
# reused by later cells when train_model returns one. Assigning also keeps the
# cell from echoing the result as output.
trained_state = train_model(state, train_ds, test_ds, num_epochs=10)

KeyboardInterrupt: 