In [None]:
# PlantXMamba/mamba_block/pscan.py
import math

import torch
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
from tqdm import tqdm


def npo2(len):
    """
    Returns the smallest power of 2 that is greater than or equal to len
    (len itself when it already is a power of 2).

    Args:
        len : positive length. The name shadows the builtin but is kept so
              that existing keyword callers keep working.

    Returns:
        int : the power of 2

    Raises:
        ValueError : if len is smaller than 1
    """

    n = math.ceil(len)
    if n < 1:
        raise ValueError(f"npo2 expects a positive length, got {len}")

    # integer bit arithmetic is exact for any size, whereas ceil(log2(len))
    # goes through a float and rounds wrongly for large values (e.g. 2**53 + 1)
    return 1 << (n - 1).bit_length()


def pad_npo2(X):
    """
    Zero-pads the end of the length dim (dim 1) so that it reaches npo2(L).

    Args:
        X : (B, L, D, N)

    Returns:
        Y : (B, npo2(L), D, N), a new tensor produced by F.pad
    """

    missing = npo2(X.size(1)) - X.size(1)

    # F.pad lists the (before, after) widths starting from the LAST dim :
    # N -> (0, 0), D -> (0, 0), L -> (0, missing)
    return F.pad(X, (0, 0, 0, 0, 0, missing), mode="constant", value=0)


class PScan(torch.autograd.Function):
    """
    Custom autograd function computing the linear recurrence
    H[t] = A[t] * H[t-1] + X[t] along the sequence dim with a parallel scan.

    pscan / pscan_rev are the in-place kernels (forward-in-time and
    backward-in-time respectively); forward / backward wrap them for autograd
    and take care of cloning, padding to a power of two and transposing.
    """

    @staticmethod
    def pscan(A, X):
        """
        In-place forward scan. Both A and X are overwritten; nothing is returned.
        """
        # A : (B, D, L, N)
        # X : (B, D, L, N)

        # modifies X in place by doing a parallel scan.
        # more formally, X will be populated by these values :
        # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
        # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)

        # only supports L that is a power of two (mainly for a clearer code)

        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # up sweep (last 2 steps unfolded)
        # each iteration pairs neighbouring positions, folds the left element
        # into the right one, then keeps only the right elements (a view, so
        # the writes land in the original A and X)
        Aa = A
        Xa = X
        for _ in range(num_steps - 2):
            T = Xa.size(2)
            Aa = Aa.view(B, D, T // 2, 2, -1)
            Xa = Xa.view(B, D, T // 2, 2, -1)

            Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
            Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])

            Aa = Aa[:, :, :, 1]
            Xa = Xa[:, :, :, 1]

        # we have only 4, 2 or 1 nodes left
        if Xa.size(2) == 4:
            Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
            Aa[:, :, 1].mul_(Aa[:, :, 0])

            Xa[:, :, 3].add_(
                Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1]))
            )
        elif Xa.size(2) == 2:
            # L == 2 : a single combine finishes the scan, no down sweep needed
            Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
            return
        else:
            # L == 1 : X is already the result
            return

        # down sweep (first 2 steps unfolded)
        # strided slices of the original tensors select the nodes handled at
        # each level, starting from the coarsest one
        Aa = A[:, :, 2 ** (num_steps - 2) - 1 : L : 2 ** (num_steps - 2)]
        Xa = X[:, :, 2 ** (num_steps - 2) - 1 : L : 2 ** (num_steps - 2)]
        Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
        Aa[:, :, 2].mul_(Aa[:, :, 1])

        for k in range(num_steps - 3, -1, -1):
            Aa = A[:, :, 2**k - 1 : L : 2**k]
            Xa = X[:, :, 2**k - 1 : L : 2**k]

            T = Xa.size(2)
            Aa = Aa.view(B, D, T // 2, 2, -1)
            Xa = Xa.view(B, D, T // 2, 2, -1)

            # the left element of every pair (except the first pair) receives
            # the contribution of the right element of the previous pair
            Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
            Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])

    @staticmethod
    def pscan_rev(A, X):
        """
        In-place reversed scan (mirror image of pscan). Both A and X are
        overwritten; nothing is returned.
        """
        # A : (B, D, L, N)
        # X : (B, D, L, N)

        # the same function as above, but in reverse
        # (if you flip the input, call pscan, then flip the output, you get what this function outputs)
        # it is used in the backward pass

        # only supports L that is a power of two (mainly for a clearer code)

        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # up sweep (last 2 steps unfolded)
        # mirrored with respect to pscan : the right element of each pair is
        # folded into the left one, and the left elements are kept
        Aa = A
        Xa = X
        for _ in range(num_steps - 2):
            T = Xa.size(2)
            Aa = Aa.view(B, D, T // 2, 2, -1)
            Xa = Xa.view(B, D, T // 2, 2, -1)

            Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
            Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])

            Aa = Aa[:, :, :, 0]
            Xa = Xa[:, :, :, 0]

        # we have only 4, 2 or 1 nodes left
        if Xa.size(2) == 4:
            Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
            Aa[:, :, 2].mul_(Aa[:, :, 3])

            Xa[:, :, 0].add_(
                Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2])))
            )
        elif Xa.size(2) == 2:
            # L == 2 : a single combine finishes the scan, no down sweep needed
            Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
            return
        else:
            # L == 1 : X is already the result
            return

        # down sweep (first 2 steps unfolded)
        Aa = A[:, :, 0 : L : 2 ** (num_steps - 2)]
        Xa = X[:, :, 0 : L : 2 ** (num_steps - 2)]
        Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
        Aa[:, :, 1].mul_(Aa[:, :, 2])

        for k in range(num_steps - 3, -1, -1):
            Aa = A[:, :, 0 : L : 2**k]
            Xa = X[:, :, 0 : L : 2**k]

            T = Xa.size(2)
            Aa = Aa.view(B, D, T // 2, 2, -1)
            Xa = Xa.view(B, D, T // 2, 2, -1)

            # the right element of every pair (except the last pair) receives
            # the contribution of the left element of the next pair
            Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
            Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])

    @staticmethod
    def forward(ctx, A_in, X_in):
        """
        Applies the parallel scan operation, as defined above. Returns a new tensor.
        If you can, prefer sequence lengths that are powers of two (other
        lengths are zero-padded up to the next power of two first).

        Args:
            A_in : (B, L, D, N)
            X_in : (B, L, D, N)

        Returns:
            H : (B, L, D, N)
        """

        L = X_in.size(1)

        # cloning is required because of the in-place ops
        if L == npo2(L):
            A = A_in.clone()
            X = X_in.clone()
        else:
            # pad tensors (and clone btw)
            A = pad_npo2(A_in)  # (B, npo2(L), D, N)
            X = pad_npo2(X_in)  # (B, npo2(L), D, N)

        # prepare tensors
        A = A.transpose(2, 1)  # (B, D, npo2(L), N)
        X = X.transpose(2, 1)  # (B, D, npo2(L), N)

        # parallel scan (modifies X in-place)
        PScan.pscan(A, X)

        # the original (unscanned) A_in and the scanned X (i.e. the hidden
        # states) are what backward reads back
        ctx.save_for_backward(A_in, X)

        # slice [:, :L] (cut if there was padding)
        return X.transpose(2, 1)[:, :L]

    @staticmethod
    def backward(ctx, grad_output_in):
        """
        Flows the gradient from the output to the input. Returns two new tensors.

        Args:
            ctx : A_in : (B, L, D, N), X : (B, D, L, N)
            grad_output_in : (B, L, D, N)

        Returns:
            gradA : (B, L, D, N), gradX : (B, L, D, N)
        """

        A_in, X = ctx.saved_tensors

        L = grad_output_in.size(1)

        # cloning is required because of the in-place ops
        if L == npo2(L):
            grad_output = grad_output_in.clone()
            # the next padding will clone A_in
        else:
            grad_output = pad_npo2(grad_output_in)  # (B, npo2(L), D, N)
            A_in = pad_npo2(A_in)  # (B, npo2(L), D, N)

        # prepare tensors
        grad_output = grad_output.transpose(2, 1)
        A_in = A_in.transpose(2, 1)  # (B, D, npo2(L), N)
        A = torch.nn.functional.pad(
            A_in[:, :, 1:], (0, 0, 0, 1)
        )  # (B, D, npo2(L), N) shift 1 to the left (see hand derivation)

        # reverse parallel scan (modifies grad_output in-place)
        PScan.pscan_rev(A, grad_output)

        # gradient w.r.t. A : previous hidden state times the scanned gradient,
        # position 0 stays at zero
        Q = torch.zeros_like(X)
        Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:])

        return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L]


# functional entry point : pscan(A, X) runs PScan.forward through autograd
pscan = PScan.apply


In [None]:
# PlantXMamba/mamba_block/backbone.py
import math
from dataclasses import dataclass
from typing import Union


"""

This file closely follows the mamba_simple.py from the official Mamba implementation, and the mamba-minimal by @johnma2006.
The major differences are :
-the convolution is done with torch.nn.Conv1d
-the selective scan is done in PyTorch

A sequential version of the selective scan is also available for comparison. Also, it is possible to use the official Mamba implementation.

This is the structure of the torch modules :
- A Mamba model is composed of several layers, which are ResidualBlock.
- A ResidualBlock is composed of a MambaBlock, a normalization, and a residual connection : ResidualBlock(x) = mamba(norm(x)) + x
- This leaves us with the MambaBlock : its input x is (B, L, D) and its output y is also (B, L, D) (B=batch size, L=seq len, D=model dim).
First, we expand x into (B, L, 2*ED) (where E is usually 2) and split it into x and z, each (B, L, ED).
Then, we apply the short 1d conv to x, followed by an activation function (silu), then the SSM.
We then multiply it by silu(z).
See Figure 3 of the paper (page 8) for a visual representation of a MambaBlock.

"""


@dataclass
class MambaConfig:
    """
    Hyper-parameters of a Mamba backbone.

    Derived attributes set in __post_init__ :
    - d_inner = expand_factor * d_model
    - dt_rank is resolved to ceil(d_model / 16) when left to "auto"
    - mup_width_mult = d_model / mup_base_width (only when mup is True)
    """

    d_model: int  # D
    n_layers: int
    dt_rank: Union[int, str] = "auto"
    d_state: int = 16  # N in paper/comments
    expand_factor: int = 2  # E in paper/comments
    d_conv: int = 4

    dt_min: float = 0.001
    dt_max: float = 0.1
    dt_init: str = "random"  # "random" or "constant"
    dt_scale: float = 1.0

    rms_norm_eps: float = 1e-5
    base_std: float = 0.02

    dropout: float = 0.1

    bias: bool = False
    conv_bias: bool = True
    inner_layernorms: bool = False  # apply layernorms to internal activations

    mup: bool = False
    mup_base_width: float = 128  # width=d_model

    pscan: bool = True  # use parallel scan mode or sequential mode when training
    use_cuda: bool = False  # use official CUDA implementation when training (not compatible with (b)float16)

    # lower clamp applied to the sampled dt at init.
    # It used to be declared without an annotation, which made it a plain class
    # attribute that the constructor rejected. It is now a real field, declared
    # last so that the positional order of all the other fields is unchanged.
    dt_init_floor: float = 1e-4

    def __post_init__(self):
        self.d_inner = self.expand_factor * self.d_model  # E*D = ED in comments

        if self.dt_rank == "auto":
            self.dt_rank = math.ceil(self.d_model / 16)

        # muP
        if self.mup:
            self.mup_width_mult = self.d_model / self.mup_base_width


class Mamba(nn.Module):
    """Stack of config.n_layers ResidualBlock applied one after the other."""

    def __init__(self, config: MambaConfig):
        super().__init__()

        self.config = config

        blocks = [ResidualBlock(config) for _ in range(config.n_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        # x : (B, L, D)  ->  returns (B, L, D)
        hidden = x
        for block in self.layers:
            hidden = block(hidden)

        return hidden

    def step(self, x, caches):
        # single-step counterpart of forward, used for recurrent inference.
        # x : input handed to each layer's step()
        # caches : list with one cache per layer, cache : (h, inputs)
        #
        # returns the last layer's output and the same caches list, whose
        # entries have been replaced in place by the updated caches
        hidden = x
        for idx, block in enumerate(self.layers):
            hidden, updated = block.step(hidden, caches[idx])
            caches[idx] = updated

        return hidden, caches


class ResidualBlock(nn.Module):
    """Pre-norm residual wrapper around a MambaBlock : mixer(norm(x)) + x."""

    def __init__(self, config: MambaConfig):
        super().__init__()

        self.mixer = MambaBlock(config)
        self.norm = RMSNorm(
            d_model=config.d_model, eps=config.rms_norm_eps, use_mup=config.mup
        )

    def forward(self, x):
        # x : (B, L, D)  ->  returns (B, L, D)
        normed = self.norm(x)
        return self.mixer(normed) + x

    def step(self, x, cache):
        # x : (B, D)
        # cache : (h, inputs) with h : (B, ED, N) and inputs : (B, ED, d_conv-1)
        #
        # returns (output, cache) with output : (B, D) and the updated cache
        mixed, cache = self.mixer.step(self.norm(x), cache)
        return mixed + x, cache


class MambaBlock(nn.Module):
    """
    One Mamba mixer : input projection into an x branch and a z branch,
    depthwise causal conv + SiLU on x, selective SSM, gating by silu(z) and
    output projection.

    forward() processes a whole sequence (B, L, D); step() processes a single
    timestep (B, D) using an explicit (h, inputs) cache.
    """

    def __init__(self, config: MambaConfig):
        super().__init__()

        self.config = config

        # projects block input from D to 2*ED (two branches)
        self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias)

        # depthwise (groups == channels) conv over time. It is padded by
        # d_conv-1 on both sides; forward() keeps only the first L outputs,
        # which makes it causal
        self.conv1d = nn.Conv1d(
            in_channels=config.d_inner,
            out_channels=config.d_inner,
            kernel_size=config.d_conv,
            bias=config.conv_bias,
            groups=config.d_inner,
            padding=config.d_conv - 1,
        )

        # projects x to input-dependent delta, B, C
        self.x_proj = nn.Linear(
            config.d_inner, config.dt_rank + 2 * config.d_state, bias=False
        )

        # projects delta from dt_rank to d_inner
        self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)

        # dt initialization
        # dt weights
        dt_init_std = config.dt_rank**-0.5 * config.dt_scale
        if config.dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif config.dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # delta bias
        # dt is sampled log-uniformly in [dt_min, dt_max], then clamped from below
        dt = torch.exp(
            torch.rand(config.d_inner)
            * (math.log(config.dt_max) - math.log(config.dt_min))
            + math.log(config.dt_min)
        ).clamp(min=config.dt_init_floor)
        inv_dt = dt + torch.log(
            -torch.expm1(-dt)
        )  # inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # self.dt_proj.bias._no_reinit = True # initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        # todo : explain why removed

        # S4D real initialization
        # every row of A is [1, 2, ..., d_state]
        A = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat(
            config.d_inner, 1
        )
        self.A_log = nn.Parameter(
            torch.log(A)
        )  # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ?
        self.A_log._no_weight_decay = True

        self.D = nn.Parameter(torch.ones(config.d_inner))
        self.D._no_weight_decay = True

        # projects block output from ED back to D
        self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)

        # used in jamba
        if self.config.inner_layernorms:
            self.dt_layernorm = RMSNorm(
                self.config.dt_rank, config.rms_norm_eps, config.mup
            )
            self.B_layernorm = RMSNorm(
                self.config.d_state, config.rms_norm_eps, config.mup
            )
            self.C_layernorm = RMSNorm(
                self.config.d_state, config.rms_norm_eps, config.mup
            )
        else:
            self.dt_layernorm = None
            self.B_layernorm = None
            self.C_layernorm = None

        if self.config.use_cuda:
            try:
                from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

                self.selective_scan_cuda = selective_scan_fn
            except ImportError:
                print("Failed to import mamba_ssm. Falling back to mamba.py.")
                # NOTE(review): this mutates the config object that was passed
                # in, so any other user of the same config sees the change
                self.config.use_cuda = False

    def _apply_layernorms(self, dt, B, C):
        """Applies the optional inner RMSNorms to dt, B and C (no-op for each one that is None)."""
        if self.dt_layernorm is not None:
            dt = self.dt_layernorm(dt)
        if self.B_layernorm is not None:
            B = self.B_layernorm(B)
        if self.C_layernorm is not None:
            C = self.C_layernorm(C)
        return dt, B, C

    def forward(self, x):
        """Full-sequence pass of the block."""
        # x : (B, L, D)

        # y : (B, L, D)

        _, L, _ = x.shape

        xz = self.in_proj(x)  # (B, L, 2*ED)
        x, z = xz.chunk(2, dim=-1)  # (B, L, ED), (B, L, ED)

        # x branch
        x = x.transpose(1, 2)  # (B, ED, L)
        x = self.conv1d(x)[
            :, :, :L
        ]  # depthwise convolution over time, with a short filter
        x = x.transpose(1, 2)  # (B, L, ED)

        x = F.silu(x)
        y = self.ssm(x, z)

        if self.config.use_cuda:
            output = self.out_proj(y)  # (B, L, D)
            return output  # the rest of the operations are done in the ssm function (fused with the CUDA pscan)

        # z branch
        z = F.silu(z)

        output = y * z
        output = self.out_proj(output)  # (B, L, D)

        return output

    def ssm(self, x, z):
        """
        Computes the input-dependent delta, B, C and runs the selective scan
        chosen by the config (CUDA kernel, parallel scan or sequential scan).
        z is only consumed by the CUDA path, where the gating is fused.
        """
        # x : (B, L, ED)

        # y : (B, L, ED)

        A = -torch.exp(self.A_log.float())  # (ED, N)
        D = self.D.float()

        deltaBC = self.x_proj(x)  # (B, L, dt_rank+2*N)
        delta, B, C = torch.split(
            deltaBC,
            [self.config.dt_rank, self.config.d_state, self.config.d_state],
            dim=-1,
        )  # (B, L, dt_rank), (B, L, N), (B, L, N)
        delta, B, C = self._apply_layernorms(delta, B, C)
        delta = self.dt_proj.weight @ delta.transpose(
            1, 2
        )  # (ED, dt_rank) @ (B, dt_rank, L) -> (B, ED, L)
        # here we just apply the matrix mul operation of delta = softplus(dt_proj(delta))
        # the rest will be applied later (fused if using cuda)

        # choose which selective_scan function to use, according to config
        if self.config.use_cuda:
            # these are unfortunately needed for the selective_scan_cuda function
            x = x.transpose(1, 2)
            B = B.transpose(1, 2)
            C = C.transpose(1, 2)
            z = z.transpose(1, 2)

            # "softplus" + "bias" + "y * silu(z)" operations are fused
            y = self.selective_scan_cuda(
                x,
                delta,
                A,
                B,
                C,
                D,
                z=z,
                delta_softplus=True,
                delta_bias=self.dt_proj.bias.float(),
            )
            y = y.transpose(1, 2)  # (B, L, ED)

        else:
            delta = delta.transpose(1, 2)  # back to (B, L, ED)
            delta = F.softplus(delta + self.dt_proj.bias)

            if self.config.pscan:
                y = self.selective_scan(x, delta, A, B, C, D)
            else:
                y = self.selective_scan_seq(x, delta, A, B, C, D)

        return y

    def selective_scan(self, x, delta, A, B, C, D):
        """Selective scan computed with the parallel scan (pscan)."""
        # x : (B, L, ED)
        # Δ : (B, L, ED)
        # A : (ED, N)
        # B : (B, L, N)
        # C : (B, L, N)
        # D : (ED)

        # y : (B, L, ED)

        deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, ED, N)

        BX = deltaB * (x.unsqueeze(-1))  # (B, L, ED, N)

        # hidden states for every timestep : h[t] = deltaA[t] * h[t-1] + BX[t]
        hs = pscan(deltaA, BX)

        y = (hs @ C.unsqueeze(-1)).squeeze(
            3
        )  # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)

        # skip connection weighted by D
        y = y + D * x

        return y

    def selective_scan_seq(self, x, delta, A, B, C, D):
        """Selective scan computed with an explicit loop over the L timesteps."""
        # x : (B, L, ED)
        # Δ : (B, L, ED)
        # A : (ED, N)
        # B : (B, L, N)
        # C : (B, L, N)
        # D : (ED)

        # y : (B, L, ED)

        _, L, _ = x.shape

        deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, ED, N)

        BX = deltaB * (x.unsqueeze(-1))  # (B, L, ED, N)

        # NOTE(review): created with the default dtype, only the device is
        # taken from deltaA — confirm this is fine under reduced precision
        h = torch.zeros(
            x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device
        )  # (B, ED, N)
        hs = []

        for t in range(0, L):
            h = deltaA[:, t] * h + BX[:, t]
            hs.append(h)

        hs = torch.stack(hs, dim=1)  # (B, L, ED, N)

        y = (hs @ C.unsqueeze(-1)).squeeze(
            3
        )  # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)

        y = y + D * x

        return y

    # -------------------------- inference -------------------------- #
    """
    Concerning auto-regressive inference

    The cool part of using Mamba : inference is constant wrt to sequence length
    We just have to keep in cache, for each layer, two things :
    - the hidden state h (which is (B, ED, N)), as you typically would when doing inference with a RNN
    - the last d_conv-1 inputs of the layer, to be able to compute the 1D conv which is a convolution over the time dimension
      (d_conv is fixed so this doesn't incur a growing cache as we progress on generating the sequence)
      (and d_conv is usually very small, like 4, so we just have to "remember" the last 3 inputs)

    Concretely, these two quantities are put inside a cache tuple, and are named h and inputs respectively.
    h is (B, ED, N), and inputs is (B, ED, d_conv-1)
    The MambaBlock.step() receives this cache, and, along with outputing the output, alos outputs the updated cache for the next call.

    The cache object is initialized as follows : (None, torch.zeros()).
    When h is None, the selective scan function detects it and start with h=0.
    The torch.zeros() isn't a problem (it's same as just feeding the input, because the conv1d is padded)

    As we need one such cache variable per layer, we store a caches object, which is simply a list of cache object. (See mamba_lm.py)
    """

    def step(self, x, cache):
        """Single-timestep pass of the block; returns the output and the updated cache."""
        # x : (B, D)
        # cache : (h, inputs)
        # h : (B, ED, N)
        # inputs : (B, ED, d_conv-1)

        # y : (B, D)
        # cache : (h, inputs)

        h, inputs = cache

        xz = self.in_proj(x)  # (B, 2*ED)
        x, z = xz.chunk(2, dim=1)  # (B, ED), (B, ED)

        # x branch
        x_cache = x.unsqueeze(2)
        # the conv sees the d_conv-1 cached inputs followed by the new one;
        # because of the conv padding, output index d_conv-1 is the one whose
        # window covers exactly these d_conv values
        x = self.conv1d(torch.cat([inputs, x_cache], dim=2))[
            :, :, self.config.d_conv - 1
        ]  # (B, ED)

        x = F.silu(x)
        y, h = self.ssm_step(x, h)

        # z branch
        z = F.silu(z)

        output = y * z
        output = self.out_proj(output)  # (B, D)

        # prepare cache for next call
        # (drop the oldest cached input, append the newest)
        inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2)  # (B, ED, d_conv-1)
        cache = (h, inputs)

        return output, cache

    def ssm_step(self, x, h):
        """One recurrence step of the SSM; h may be None, in which case it starts from zeros."""
        # x : (B, ED)
        # h : (B, ED, N)

        # y : (B, ED)
        # h : (B, ED, N)

        A = -torch.exp(
            self.A_log.float()
        )  # (ED, N) # todo : don't recompute this on every call, since it does not depend on the timestep
        D = self.D.float()

        deltaBC = self.x_proj(x)  # (B, dt_rank+2*N)

        delta, B, C = torch.split(
            deltaBC,
            [self.config.dt_rank, self.config.d_state, self.config.d_state],
            dim=-1,
        )  # (B, dt_rank), (B, N), (B, N)
        delta, B, C = self._apply_layernorms(delta, B, C)
        delta = F.softplus(self.dt_proj(delta))  # (B, ED)

        deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (B, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(1)  # (B, ED, N)

        BX = deltaB * (x.unsqueeze(-1))  # (B, ED, N)

        if h is None:
            h = torch.zeros(
                x.size(0),
                self.config.d_inner,
                self.config.d_state,
                device=deltaA.device,
            )  # (B, ED, N)

        h = deltaA * h + BX  # (B, ED, N)

        y = (h @ C.unsqueeze(-1)).squeeze(2)  # (B, ED, N) @ (B, N, 1) -> (B, ED, 1)

        y = y + D * x

        return y, h


class RMSNorm(nn.Module):
    """
    Root-mean-square normalization over the last dim, followed by a learned
    per-feature gain unless use_mup is set (in which case there is no gain).
    """

    def __init__(self, d_model: int, eps: float = 1e-5, use_mup: bool = False):
        super().__init__()

        self.eps = eps
        self.use_mup = use_mup

        # https://arxiv.org/abs/2404.05728, RMSNorm gains prevents muTransfer (section 4.2.3)
        if use_mup:
            return
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        normed = x * torch.rsqrt(mean_square + self.eps)

        return normed if self.use_mup else normed * self.weight


In [None]:
# PlantXMamba/mamba_block/head.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class MambaHead(nn.Module):
    """Output head : LayerNorm over the feature dim followed by dropout."""

    def __init__(self, d_model: int, dropout: float = 0.0):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the shape is left untouched : (batch_size, seq_len, d_model)
        normalized = self.norm(x)
        return self.dropout(normalized)


In [None]:
# PlantXMamba/mamba_block/model.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaModule(nn.Module):
    """
    Mamba backbone followed by a MambaHead, configured from an `args` object
    exposing d_model, n_layers, d_state, d_conv, expand and dropout.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.d_model = args.d_model
        self.n_layers = args.n_layers

        backbone_config = MambaConfig(
            d_model=self.d_model,
            n_layers=self.n_layers,
            d_state=args.d_state,
            d_conv=args.d_conv,
            expand_factor=args.expand,
            dropout=args.dropout,
        )
        self.backbone = Mamba(backbone_config)
        self.head = MambaHead(d_model=self.d_model, dropout=args.dropout)

    def forward(self, x):
        # (batch_size, seq_len, d_model) in, same shape out
        features = self.backbone(x)
        return self.head(features)

In [None]:
!git clone https://github.com/sakanaowo/PlantXViT

Cloning into 'PlantXViT'...
remote: Enumerating objects: 104825, done.[K
remote: Counting objects: 100% (23/23), done.[K
remote: Compressing objects: 100% (17/17), done.[K
remote: Total 104825 (delta 6), reused 20 (delta 4), pack-reused 104802 (from 1)[K
Receiving objects: 100% (104825/104825), 2.45 GiB | 58.18 MiB/s, done.
Resolving deltas: 100% (30447/30447), done.
Updating files: 100% (104353/104353), done.


In [None]:
%cd PlantXViT

/content/PlantXViT


In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Preprocessing pipeline : random flip and colour jitter (data augmentation),
# resize to 224x224, tensor conversion, then per-channel normalization.
image_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.0),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import VGG16_Weights

import os
# dataset root; the train/val/test sub-folders are read from it below
root_dir="./data/raw/embrapa"

In [None]:
# Validation and test images must be preprocessed deterministically: reusing
# the training pipeline would apply random flips / colour jitter at evaluation
# time and make the reported metrics noisy and non-reproducible.
eval_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# One ImageFolder per split; only the training split gets the augmentations.
train_dataset = datasets.ImageFolder(os.path.join(root_dir, "train"), transform=image_transforms)
val_dataset = datasets.ImageFolder(os.path.join(root_dir, "val"), transform=eval_transforms)
test_dataset = datasets.ImageFolder(os.path.join(root_dir, "test"), transform=eval_transforms)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

In [None]:
class InceptionBlock(nn.Module):
    """
    Inception-style block with four parallel branches whose outputs are
    concatenated along the channel dim (128 + 128 + 192 + 64 = 512 channels).
    Every convolution is followed by ReLU and then BatchNorm, and the paddings
    keep the spatial size unchanged.
    """

    @staticmethod
    def _conv_relu_bn(in_ch, out_ch, kernel_size, padding=0):
        # one Conv2d -> ReLU -> BatchNorm2d unit, returned as a flat list so
        # that the layers keep their positions inside the branch Sequential
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.BatchNorm2d(out_ch),
        ]

    def __init__(self, in_channels=128):
        super().__init__()
        unit = self._conv_relu_bn

        # Branch 1: 1x1
        self.branch1x1 = nn.Sequential(*unit(in_channels, 128, 1))

        # Branch 2: 1x1 -> 3x1 + 1x3
        self.branch3x3 = nn.Sequential(
            *unit(in_channels, 96, 1),
            *unit(96, 128, (3, 1), (1, 0)),
            *unit(128, 128, (1, 3), (0, 1)),
        )

        # Branch 3: 1x1 -> 3x1 + 1x3 -> 3x1 + 1x3
        self.branch5x5 = nn.Sequential(
            *unit(in_channels, 64, 1),
            *unit(64, 96, (3, 1), (1, 0)),
            *unit(96, 96, (1, 3), (0, 1)),
            *unit(96, 192, (3, 1), (1, 0)),
            *unit(192, 192, (1, 3), (0, 1)),
        )

        # Branch 4: MaxPool -> 1x1
        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            *unit(in_channels, 64, 1),
        )

    def forward(self, x):
        branches = (self.branch1x1, self.branch3x3, self.branch5x5, self.branch_pool)
        return torch.cat([branch(x) for branch in branches], dim=1)

In [None]:
class PatchEmbedding(nn.Module):
    """
    Cuts a (B, C, H, W) feature map into non-overlapping patch_size x patch_size
    patches, flattens each patch (channel, row, column order) and projects it
    linearly to emb_size. Rows/columns that do not fill a whole patch are dropped.
    """

    def __init__(self, in_channels, patch_size=5, emb_size=16):
        super().__init__()
        self.patch_size = patch_size
        self.emb_size = emb_size
        self.proj = nn.Linear(in_channels * patch_size * patch_size, emb_size)

    def forward(self, x):
        batch, channels, _, _ = x.shape
        p = self.patch_size

        # (B, C, H/p, W/p, p, p) : tile the height first, then the width
        tiles = x.unfold(2, p, p).unfold(3, p, p)
        # bring the patch grid next to the batch dim : (B, H/p, W/p, C, p, p)
        tiles = tiles.permute(0, 2, 3, 1, 4, 5)
        tokens = tiles.reshape(batch, -1, channels * p * p)

        return self.proj(tokens)  # shape: (b,num patches,emb size)

In [None]:
class PlantXMamba(nn.Module):
    """
    Image classifier : first VGG16 feature layers -> InceptionBlock ->
    PatchEmbedding -> stack of MambaModule -> LayerNorm -> mean over the
    patches -> linear classifier.

    Args:
        num_classes : size of the output logits
        patch_size, emb_size : forwarded to PatchEmbedding (emb_size is also
            the Mamba d_model)
        d_state, d_conv, expand, n_layers, dropout : forwarded to every MambaModule
        num_blocks : number of MambaModule stacked sequentially

    Note: constructing the model loads the pretrained VGG16 weights
    (VGG16_Weights.DEFAULT).
    """

    def __init__(self, num_classes=4, patch_size=5, emb_size=16, d_state=64,d_conv=64,expand=4,n_layers=2,num_blocks=4, dropout=0.1):
        super().__init__()
        # local import : SimpleNamespace is only needed here
        from types import SimpleNamespace

        # VGG16 (2 blocks)
        vgg = models.vgg16(weights=VGG16_Weights.DEFAULT)
        self.vgg_block = nn.Sequential(*list(vgg.features[:10]))

        # Inception-like block → (B, 512, 56, 56)
        self.inception = InceptionBlock(in_channels=128)

        # Patch Embedding → (B, 121, 16)
        self.patch_embed = PatchEmbedding(in_channels=512, patch_size=patch_size, emb_size=emb_size)

        # Mamba blocks
        # Every MambaModule keeps a reference to this object (self.args). A
        # SimpleNamespace is used instead of an instance of a class created on
        # the fly with type(...), because such an instance cannot be pickled.
        # NOTE(review): d_conv defaults to 64, far above the usual Mamba value
        # of 4 — confirm this is intended.
        mamba_args = SimpleNamespace(
            d_model=emb_size,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            n_layers=n_layers,
            dropout=dropout,
        )
        self.mamba = nn.Sequential(*[MambaModule(mamba_args) for _ in range(num_blocks)])

        # Classification head
        self.norm = nn.LayerNorm(emb_size)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(emb_size, num_classes)

    def forward(self, x):
        # the shapes below are for 224x224 inputs and the default patch_size / emb_size
        x = self.vgg_block(x)  # (B, 128, 56, 56)
        x = self.inception(x)  # (B, 512, 56, 56)
        x = self.patch_embed(x)  # (B, 121, 16)
        x = self.mamba(x)  # (B, 121, 16)
        x = self.norm(x)  # (B, 121, 16)
        x = x.permute(0, 2, 1)  # (B, 16, 121)
        # average over the patch dim
        x = self.global_pool(x).squeeze(-1)  # (B, 16)
        return self.classifier(x)  # (B, num_classes)

In [None]:
# Build the classifier with 93 output classes and the cross-entropy loss used
# for both training and evaluation.
model = PlantXMamba(num_classes=93)
criterion=nn.CrossEntropyLoss()

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 202MB/s]


In [None]:
# Load the project YAML config and show the checkpoint path it declares.
# NOTE(review): the training cells below hard-code MODEL_PATH instead of
# reading it from this config — confirm the two are meant to stay in sync.
from utils.config_loader import load_config
config=load_config('./configs/config.yaml')
print(config['output']['embrapa']['model_path'])

./outputs/embrapa/models/plantxvit_best.pth


In [None]:
# Run-level constants.
DATA_DIR=root_dir
BATCH_SIZE=16
EPOCHS=50
LR=1e-4
NUM_CLASSES=93
# Prefer the GPU but fall back to the CPU, so the notebook still runs on a
# machine without CUDA instead of failing when the model is moved to DEVICE.
DEVICE=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_PATH = "./outputs/embrapa/models/plantxvit_best.pth"

# Create the checkpoint directory if it does not exist yet
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)

In [None]:
# Move the model to the training device (nn.Module.to works in place and
# returns the module itself), then build Adam over its parameters.
model = model.to(DEVICE)
optimizer = optim.Adam(params=model.parameters(), lr=LR, weight_decay=1e-4)

In [None]:
def train_one_epoch(model, loader, optimizer, criterion):
    """
    Runs one optimisation pass over `loader` with the model in train mode.

    Returns:
        (avg_loss, acc) : sample-weighted mean loss and top-1 accuracy over
        the samples seen during the pass.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    for batch, targets in tqdm(loader, desc="Training"):
        batch = batch.to(DEVICE)
        targets = targets.to(DEVICE)

        optimizer.zero_grad()
        logits = model(batch)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        # weight the batch loss by the batch size so that a smaller last
        # batch does not skew the average
        loss_sum += loss.item() * batch.size(0)
        predicted = logits.max(dim=1).indices
        n_correct += (predicted == targets).sum().item()
        n_seen += targets.size(0)

    return loss_sum / n_seen, n_correct / n_seen

In [None]:
def evaluate(model, loader, criterion):
    """Score ``model`` on ``loader`` without updating any parameters.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable yielding ``(inputs, labels)`` batches.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.

    Returns:
        ``(avg_loss, acc)``: the batch losses weighted by batch size and
        divided by the number of samples seen, and the fraction of samples
        whose highest-scoring class equals the label.

    Batches are moved to the module-level ``DEVICE`` before the forward pass.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    # Gradient tracking is disabled for the whole pass.
    with torch.no_grad():
        for batch_inputs, batch_labels in tqdm(loader, desc="Evaluating"):
            batch_inputs = batch_inputs.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_labels)

            # Weight the batch loss by its size so the final figure is a
            # per-sample average (assumes a mean-reduced criterion).
            loss_sum += batch_loss.item() * batch_inputs.size(0)
            predicted = logits.argmax(dim=1)
            n_correct += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    return loss_sum / n_seen, n_correct / n_seen

In [None]:
# Train with early stopping on validation accuracy: a checkpoint is written
# whenever the accuracy improves, and training stops after `patience`
# consecutive epochs without improvement.
best_val_acc = 0
patience = 5
wait = 0  # epochs elapsed since the last improvement

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = evaluate(model, val_loader, criterion)

    print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f}")
    print(f"Val   Loss: {val_loss:.4f} | Acc: {val_acc:.4f}")

    # No improvement: count the stale epoch and stop once patience runs out.
    if not val_acc > best_val_acc:
        wait += 1
        if wait >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break
        continue

    # New best validation accuracy: save the weights and reset the counter.
    best_val_acc = val_acc
    torch.save(model.state_dict(), MODEL_PATH)
    print(f"✅ Saved best model to {MODEL_PATH}")
    wait = 0


Epoch 1/50


Training: 100%|██████████| 1851/1851 [02:29<00:00, 12.41it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.51it/s]


Train Loss: 3.3367 | Acc: 0.3167
Val   Loss: 2.5670 | Acc: 0.4563
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 2/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.63it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.38it/s]


Train Loss: 2.2750 | Acc: 0.5058
Val   Loss: 1.8892 | Acc: 0.5426
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 3/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.62it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.32it/s]


Train Loss: 1.7443 | Acc: 0.5951
Val   Loss: 1.4785 | Acc: 0.6376
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 4/50


Training: 100%|██████████| 1851/1851 [02:25<00:00, 12.70it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.08it/s]


Train Loss: 1.4317 | Acc: 0.6518
Val   Loss: 1.2624 | Acc: 0.6795
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 5/50


Training: 100%|██████████| 1851/1851 [02:25<00:00, 12.72it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.02it/s]


Train Loss: 1.2295 | Acc: 0.6939
Val   Loss: 1.0822 | Acc: 0.7175
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 6/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.61it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.34it/s]


Train Loss: 1.0934 | Acc: 0.7191
Val   Loss: 0.9924 | Acc: 0.7390
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 7/50


Training: 100%|██████████| 1851/1851 [02:25<00:00, 12.70it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.47it/s]


Train Loss: 0.9748 | Acc: 0.7434
Val   Loss: 0.9077 | Acc: 0.7553
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 8/50


Training: 100%|██████████| 1851/1851 [02:25<00:00, 12.68it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.58it/s]


Train Loss: 0.9015 | Acc: 0.7593
Val   Loss: 0.8552 | Acc: 0.7729
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 9/50


Training: 100%|██████████| 1851/1851 [02:25<00:00, 12.69it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.06it/s]


Train Loss: 0.8244 | Acc: 0.7769
Val   Loss: 0.7661 | Acc: 0.7952
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 10/50


Training: 100%|██████████| 1851/1851 [02:25<00:00, 12.70it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.19it/s]


Train Loss: 0.7603 | Acc: 0.7921
Val   Loss: 0.7467 | Acc: 0.7928

Epoch 11/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.62it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.23it/s]


Train Loss: 0.7094 | Acc: 0.8036
Val   Loss: 0.6873 | Acc: 0.8133
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 12/50


Training: 100%|██████████| 1851/1851 [02:25<00:00, 12.68it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 29.78it/s]


Train Loss: 0.6610 | Acc: 0.8164
Val   Loss: 0.6930 | Acc: 0.8114

Epoch 13/50


Training: 100%|██████████| 1851/1851 [02:25<00:00, 12.70it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.28it/s]


Train Loss: 0.6234 | Acc: 0.8276
Val   Loss: 0.6464 | Acc: 0.8231
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 14/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.64it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.37it/s]


Train Loss: 0.5852 | Acc: 0.8352
Val   Loss: 0.5944 | Acc: 0.8317
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 15/50


Training: 100%|██████████| 1851/1851 [02:25<00:00, 12.68it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.18it/s]


Train Loss: 0.5529 | Acc: 0.8420
Val   Loss: 0.6060 | Acc: 0.8341
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 16/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.65it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.36it/s]


Train Loss: 0.5154 | Acc: 0.8535
Val   Loss: 0.5776 | Acc: 0.8401
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 17/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.59it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.31it/s]


Train Loss: 0.4965 | Acc: 0.8560
Val   Loss: 0.5532 | Acc: 0.8437
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 18/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.64it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.39it/s]


Train Loss: 0.4754 | Acc: 0.8614
Val   Loss: 0.5669 | Acc: 0.8417

Epoch 19/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.62it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.47it/s]


Train Loss: 0.4471 | Acc: 0.8706
Val   Loss: 0.4946 | Acc: 0.8601
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 20/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.65it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.10it/s]


Train Loss: 0.4221 | Acc: 0.8777
Val   Loss: 0.5320 | Acc: 0.8506

Epoch 21/50


Training: 100%|██████████| 1851/1851 [02:27<00:00, 12.56it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.24it/s]


Train Loss: 0.3978 | Acc: 0.8843
Val   Loss: 0.4912 | Acc: 0.8630
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 22/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.67it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.51it/s]


Train Loss: 0.3933 | Acc: 0.8836
Val   Loss: 0.4824 | Acc: 0.8656
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 23/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.60it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.44it/s]


Train Loss: 0.3681 | Acc: 0.8908
Val   Loss: 0.4734 | Acc: 0.8588

Epoch 24/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.63it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.20it/s]


Train Loss: 0.3496 | Acc: 0.8963
Val   Loss: 0.5244 | Acc: 0.8516

Epoch 25/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.64it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.17it/s]


Train Loss: 0.3330 | Acc: 0.9015
Val   Loss: 0.4472 | Acc: 0.8716
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 26/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.64it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 29.79it/s]


Train Loss: 0.3320 | Acc: 0.8998
Val   Loss: 0.4953 | Acc: 0.8597

Epoch 27/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.62it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.15it/s]


Train Loss: 0.3100 | Acc: 0.9083
Val   Loss: 0.4702 | Acc: 0.8661

Epoch 28/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.64it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 29.89it/s]


Train Loss: 0.2934 | Acc: 0.9139
Val   Loss: 0.4789 | Acc: 0.8640

Epoch 29/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.61it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.19it/s]


Train Loss: 0.2906 | Acc: 0.9134
Val   Loss: 0.5310 | Acc: 0.8511

Epoch 30/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.65it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 29.89it/s]


Train Loss: 0.2735 | Acc: 0.9196
Val   Loss: 0.4331 | Acc: 0.8774
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 31/50


Training: 100%|██████████| 1851/1851 [02:25<00:00, 12.72it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.37it/s]


Train Loss: 0.2591 | Acc: 0.9227
Val   Loss: 0.4545 | Acc: 0.8732

Epoch 32/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.63it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.70it/s]


Train Loss: 0.2596 | Acc: 0.9209
Val   Loss: 0.4568 | Acc: 0.8765

Epoch 33/50


Training: 100%|██████████| 1851/1851 [02:25<00:00, 12.70it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.42it/s]


Train Loss: 0.2435 | Acc: 0.9279
Val   Loss: 0.4961 | Acc: 0.8643

Epoch 34/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.68it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.32it/s]


Train Loss: 0.2367 | Acc: 0.9295
Val   Loss: 0.4604 | Acc: 0.8720

Epoch 35/50


Training: 100%|██████████| 1851/1851 [02:26<00:00, 12.66it/s]
Evaluating: 100%|██████████| 466/466 [00:15<00:00, 30.02it/s]

Train Loss: 0.2218 | Acc: 0.9323
Val   Loss: 0.4513 | Acc: 0.8743
Early stopping at epoch 35



