In [2]:
pip install flax einops jax jaxlib optax rarfile

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 23.0.1 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [5]:
# Import necessary libraries
import einops
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
from tensorflow.keras.preprocessing import image_dataset_from_directory
import numpy as np
import os

class MlpBlock(nn.Module):
    mlp_dim: int

    @nn.compact
    def __call__(self, x):
        y = nn.Dense(self.mlp_dim)(x)
        y = nn.gelu(y)
        return nn.Dense(x.shape[-1])(y)

class MixerBlock(nn.Module):
    tokens_mlp_dim: int
    channels_mlp_dim: int

    @nn.compact
    def __call__(self, x):
        y = nn.LayerNorm()(x)
        y = jnp.swapaxes(y, 1, 2)
        y = MlpBlock(self.tokens_mlp_dim)(y)
        y = jnp.swapaxes(y, 1, 2)
        x = x + y
        y = nn.LayerNorm()(x)
        return x + MlpBlock(self.channels_mlp_dim)(y)



In [6]:
class MlpMixer(nn.Module):
    num_classes: int
    num_blocks: int
    patch_size: int
    hidden_dim: int
    tokens_mlp_dim: int
    channels_mlp_dim: int

    @nn.compact
    def __call__(self, x):
        s = self.patch_size
        x = nn.Conv(self.hidden_dim, (s, s), strides=(s, s))(x)
        x = einops.rearrange(x, 'n h w c -> n (h w) c')
        
        for _ in range(self.num_blocks):
            x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
        
        x = nn.LayerNorm()(x)
        x = jnp.mean(x, axis=1)
        return nn.Dense(self.num_classes)(x)


In [8]:
def load_data(batch_size=32, img_size=(224, 224)):
    dataset_path = "breast-ultrasound-dataset"

    train_dir = os.path.join(dataset_path, "train_dir")
    valid_dir = os.path.join(dataset_path, "valid_dir")
    test_dir = os.path.join(dataset_path, "test_dir")
    
    if not all(os.path.exists(d) for d in [train_dir, valid_dir, test_dir]):
        raise FileNotFoundError("One or more dataset directories not found.")
    
    train_ds = image_dataset_from_directory(
        train_dir,
        label_mode='int',
        image_size=img_size,
        color_mode='grayscale',
        batch_size=batch_size
    )
    valid_ds = image_dataset_from_directory(
        valid_dir,
        label_mode='int',
        image_size=img_size,
        color_mode='grayscale',
        batch_size=batch_size
    )
    test_ds = image_dataset_from_directory(
        test_dir,
        label_mode='int',
        image_size=img_size,
        color_mode='grayscale',
        batch_size=batch_size
    )
    return train_ds, valid_ds, test_ds


In [20]:
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['image'])
        loss = optax.softmax_cross_entropy(logits, jax.nn.one_hot(batch['label'], num_classes=3)).mean()
        return loss, logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    
    state = state.apply_gradients(grads=grads)
    
    accuracy = (jnp.argmax(logits, -1) == batch['label']).mean()
    
    return state, loss, accuracy


def train_model(state, train_ds, num_epochs=10):
    for epoch in range(num_epochs):
        for batch in train_ds:
            batch = {'image': jnp.array(batch[0] / 255.0), 'label': jnp.array(batch[1])}
            
            state, loss, accuracy = train_step(state, batch)
        
        print(f"Epoch {epoch + 1}, Loss: {loss:.4f}, Accuracy: {accuracy * 100:.2f}%")


In [21]:
class TrainState(train_state.TrainState):
    pass 

num_classes = 3
model = MlpMixer(num_classes=num_classes, num_blocks=8, patch_size=4, hidden_dim=128, tokens_mlp_dim=256, channels_mlp_dim=512)

train_ds, valid_ds, test_ds = load_data()

rng = jax.random.PRNGKey(0)
dummy_input = jnp.ones((1, 224, 224, 1))  
params = model.init(rng, dummy_input)['params']



Found 606 files belonging to 3 classes.
Found 187 files belonging to 3 classes.
Found 150 files belonging to 3 classes.


In [22]:
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(1e-3),
)

train_model(state, train_ds, num_epochs=10)

Epoch 1, Loss: 0.9502, Accuracy: 60.00%
Epoch 2, Loss: 1.0499, Accuracy: 43.33%
Epoch 3, Loss: 1.0727, Accuracy: 43.33%
Epoch 4, Loss: 0.8539, Accuracy: 70.00%
Epoch 5, Loss: 0.8295, Accuracy: 66.67%
Epoch 6, Loss: 0.9497, Accuracy: 50.00%
Epoch 7, Loss: 0.8348, Accuracy: 63.33%
Epoch 8, Loss: 0.6737, Accuracy: 73.33%
Epoch 9, Loss: 0.5973, Accuracy: 83.33%
Epoch 10, Loss: 0.3570, Accuracy: 90.00%
