In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
%pip install einops
import math
from math import sqrt, ceil



Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0


In [None]:


# First we code the patch embedding module

class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each patch.

    A single ``nn.Conv2d`` whose kernel size and stride both equal
    ``patch_size`` acts as a linear projection applied independently to every
    patch. An image split into ``n`` patches therefore yields ``n`` embedding
    vectors, i.e. an ``n x embed_dim`` matrix per image.

    Args:
        img_size: side length of the input image (``n_patches`` assumes a
            square image: ``(img_size // patch_size) ** 2``).
        patch_size: side length of each square patch.
        in_channels: number of channels of the input image.
        embed_dim: length of the embedding vector produced for each patch.
    """

    def __init__(self, img_size, patch_size, in_channels=3, embed_dim=256):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.n_patches = patches_per_side * patches_per_side
        # kernel == stride: every output location sees exactly one patch, no overlap
        self.project = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map images ``x`` of shape (B, C, H, W) to tokens of shape (B, N_patches, embed_dim)."""
        grid = self.project(x)                     # (B, embed_dim, H/p, W/p)
        tokens = torch.flatten(grid, start_dim=2)  # (B, embed_dim, N_patches)
        return tokens.permute(0, 2, 1)             # (B, N_patches, embed_dim)

# next we code the multi-layer perceptron block


class MultiLayerPerceptron(nn.Module):
    """Feed-forward stack of ``nn.Linear`` layers with GELU activations.

    With ``hidden = int(input_size * hidden_size_ratio)`` the layer layout is::

        Linear(input_size -> hidden), GELU, Dropout,
        [Linear(hidden -> hidden), GELU] * (num_layers - 2),
        Linear(hidden -> out_features)

    Dropout is applied only after the first layer. For ``num_layers < 2`` the
    repeated middle section is empty, so the stack still has two Linear layers.

    Note: ``self.norm`` is created but never applied in ``forward``; it is kept
    so that parameter names (and saved checkpoints) remain unchanged.
    """

    def __init__(self, input_size, hidden_size_ratio, out_features, dropout=0., num_layers=3):
        super().__init__()
        # Unused in forward(); kept for state_dict compatibility.
        self.norm = nn.LayerNorm(input_size)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.mlp = self._create_mlp(input_size, hidden_size_ratio, out_features, num_layers)

    def _create_mlp(self, input_size, hidden_size_ratio, output_size, num_layers=3):
        """Build the ``nn.Sequential`` described in the class docstring."""
        width = int(input_size * hidden_size_ratio)
        n_middle = max(num_layers - 2, 0)
        layers = [nn.Linear(input_size, width), self.activation, self.dropout]
        for _ in range(n_middle):
            layers.extend((nn.Linear(width, width), self.activation))
        layers.append(nn.Linear(width, output_size))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stack over the last dimension of ``x`` (e.g. batches, tokens, in features)."""
        return self.mlp(x)


# We now include an attention block here, with our aim being to interleave Attention,
# Mamba, and MLP blocks

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention (Vaswani et al., 2017).

    Args:
        embed_dim: token embedding size; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        qkv_bias: whether the fused query/key/value projection has a bias term.
        attn_dropout: dropout probability applied to the attention weights.
        projection_dropout: dropout probability applied after the output projection.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, embed_dim, n_heads, qkv_bias = False, attn_dropout = 0., projection_dropout=0.):
        super().__init__()
        if embed_dim % n_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by n_heads ({n_heads})"
            )
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        self.scale = self.head_dim ** -0.5  # 1 / sqrt(d_k), from Vaswani et al.
        # One fused projection produces Q, K and V; its bias is controlled by qkv_bias.
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=qkv_bias)
        self.project = nn.Linear(embed_dim, embed_dim)
        self.project_dropout = nn.Dropout(projection_dropout)
        self.attention_dropout = nn.Dropout(attn_dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, tokens, embed_dim); returns the same shape."""
        batches, tokens, _ = x.shape

        qkv = self.qkv(x).reshape(batches, tokens, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch, heads, tokens, head_dim)
        query, key, value = qkv[0], qkv[1], qkv[2]

        scores = (query @ key.transpose(-2, -1)) * self.scale  # (batch, heads, tokens, tokens)
        weights = self.attention_dropout(scores.softmax(dim=-1))

        # Weighted sum of values, heads concatenated back: (batch, tokens, embed_dim)
        context = (weights @ value).transpose(1, 2).flatten(2)
        out = self.project(context)
        return self.project_dropout(out)





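# Next we add the squeeze-and-attend block together with its helper modules
# (token-grid downsampling, token-grid upsampling, and relative-position attention).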


class GeneralizedDownsample(nn.Module):
    """Shrink a square token grid from ``original_side_length``^2 to ``target_side_length``^2 tokens.

    Tokens are laid back out on their 2-D grid and passed through a depthwise
    (``groups=embed_dim``) ``nn.Conv2d`` with stride 1 and no padding. Its
    kernel side is ``original_side_length - target_side_length + 1``, which
    makes the output grid exactly ``target_side_length`` wide.
    """

    def __init__(self, embed_dim, original_side_length, target_side_length):
        super().__init__()
        self.original_side_length = original_side_length
        self.embed_dim = embed_dim
        k = original_side_length - target_side_length + 1
        self.conv2d = nn.Conv2d(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=(k, k),
            stride=1,
            padding=0,
            groups=embed_dim,
        )

    def forward(self, x):
        """(batch, original_side_length**2, embed_dim) -> (batch, target_side_length**2, embed_dim)."""
        side = self.original_side_length
        n = x.size(0)
        grid = x.transpose(1, 2).view(n, self.embed_dim, side, side)  # channels-first 2-D grid
        grid = self.conv2d(grid)
        return grid.view(n, self.embed_dim, -1).transpose(1, 2)


class GeneralizedUpsample(nn.Module):
    """Grow a square token grid from ``num_patches_in`` to ``num_patches_out`` tokens.

    Both counts are treated as perfect squares (side lengths are taken as
    ``int(math.sqrt(...))``). Tokens are reshaped onto their 2-D grid, resized
    with bilinear interpolation (``align_corners=True``), refined by a depthwise
    3x3 convolution (padding 1, so the grid size is preserved), and flattened back.
    """

    def __init__(self, embed_dim, num_patches_in, num_patches_out):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_patches_in = num_patches_in
        self.num_patches_out = num_patches_out
        out_side = int(math.sqrt(num_patches_out))

        self.upsample = nn.Upsample(size=(out_side, out_side), mode='bilinear', align_corners=True)
        # Depthwise refinement of the interpolated grid
        self.conv2d = nn.Conv2d(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=embed_dim,
        )

    def forward(self, x):
        """(batch, num_patches_in, embed_dim) -> (batch, num_patches_out, embed_dim)."""
        in_side = int(math.sqrt(self.num_patches_in))
        n = x.shape[0]
        grid = x.transpose(1, 2).view(n, self.embed_dim, in_side, in_side)
        grid = self.conv2d(self.upsample(grid))
        return grid.view(n, self.embed_dim, self.num_patches_out).transpose(1, 2)

class RelativeAttention(nn.Module):
    """Multi-head self-attention with a learned 2-D relative position bias.

    The bias table follows Swin Transformer's window attention. There is one
    learnable scalar per head for every possible (dy, dx) offset between two
    cells of an ``ih x iw`` token grid, where ``ih = int(image_size[0] / patch_size)``
    and ``iw = int(image_size[1] / patch_size)``. The bias is added to the scaled
    dot-product scores before the softmax::

        softmax(Q K^T / sqrt(d_head) + B) V

    The gathered bias has shape (heads, ih*iw, ih*iw) and is broadcast-added to
    the (batch, heads, tokens, tokens) scores, so ``tokens`` should equal
    ``ih * iw`` (no class token). For a 1x1 grid the bias degenerates to a
    per-head constant, which the softmax cancels.

    Args:
        inp: input embedding size (split evenly across ``heads``).
        oup: output size of the final projection.
        image_size: (height, width) of the image the tokens were cut from.
        patch_size: patch side length; together with ``image_size`` it fixes the grid.
        heads: number of attention heads.
        projection_dropout: dropout after the output projection.
        attn_dropout: dropout applied to the attention weights.
    """

    def __init__(self, inp, oup, image_size, patch_size, heads=8, projection_dropout=0., attn_dropout = 0.):
        super().__init__()

        self.embed_dim = inp
        self.n_heads = heads
        self.head_dim = self.embed_dim // self.n_heads
        self.scale = self.head_dim ** -0.5  # 1 / sqrt(d_k), from Vaswani et al.
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim)  # fused Q, K, V projection
        self.project = nn.Linear(self.embed_dim, oup)
        self.projection_dropout = nn.Dropout(projection_dropout)
        self.attention_dropout = nn.Dropout(attn_dropout)

        # Token-grid dimensions (rows, cols)
        self.ih, self.iw = image_size
        self.ih = int(self.ih / patch_size)
        self.iw = int(self.iw / patch_size)

        # One learnable bias per head for each of the (2*ih-1)*(2*iw-1) possible offsets
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

        # (2, ih*iw): (row, col) of every token in row-major order
        coords = torch.stack(torch.meshgrid(torch.arange(self.ih), torch.arange(self.iw), indexing="ij"))
        coords = torch.flatten(coords, 1)
        # (ih*iw, ih*iw, 2): pairwise (dy, dx) offsets, possibly negative
        relative_coords = (coords[:, :, None] - coords[:, None, :]).permute(1, 2, 0).contiguous()
        # Shift both offsets to be non-negative ...
        relative_coords[:, :, 0] += self.ih - 1
        relative_coords[:, :, 1] += self.iw - 1
        # ... then fold (dy, dx) into a single row-major index into the bias table
        relative_coords[:, :, 0] *= 2 * self.iw - 1
        relative_index = relative_coords.sum(-1)  # (ih*iw, ih*iw)
        self.register_buffer("relative_index", relative_index)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, tokens, inp); returns (batch, tokens, oup)."""
        batches, tokens, _ = x.shape

        qkv = self.qkv(x).reshape(batches, tokens, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch, heads, tokens, head_dim)
        query, key, value = qkv[0], qkv[1], qkv[2]

        scores = (query @ key.transpose(-2, -1)) * self.scale  # (batch, heads, tokens, tokens)

        # Gather the per-head bias for every (query, key) grid pair: (heads, ih*iw, ih*iw).
        # As in Swin, the bias is added to the already-scaled Q K^T scores.
        n_grid = self.ih * self.iw
        bias = self.relative_position_bias_table[self.relative_index.view(-1)]
        bias = bias.view(n_grid, n_grid, -1).permute(2, 0, 1).contiguous()
        scores = scores + bias.unsqueeze(0)

        # Dropout is applied to the softmax weights that are actually used below.
        weights = self.attention_dropout(scores.softmax(dim=-1))

        context = (weights @ value).transpose(1, 2).flatten(2)  # (batch, tokens, inp)
        out = self.project(context)
        return self.projection_dropout(out)



class SqueezeAndAttend(nn.Module):
    """Attend over a coarsened token grid, then restore the original resolution.

    Pipeline (each stage is pre-normalised and followed by GELU)::

        LayerNorm -> GeneralizedDownsample   (grid side S -> S - target_patch_count_reduction)
        LayerNorm -> RelativeAttention       (on the reduced grid)
        LayerNorm -> GeneralizedUpsample     (grid side back to S)

    The block's input is then added back as a residual.

    The token grid is assumed square:
    ``S = int(sqrt((H / patch_size) * (W / patch_size)))``.

    Args:
        embed_dim: token embedding size.
        target_patch_count_reduction: how much the grid side shrinks before attention.
        image_size: (H, W) of the image the tokens came from.
        patch_size: patch side length used to tokenise the image.
        heads: attention heads of the relative attention.
        projection_dropout: forwarded to RelativeAttention.
        attn_dropout: forwarded to RelativeAttention.

    Raises:
        ValueError: if the reduction leaves a grid side smaller than 1.
    """

    def __init__(self, embed_dim, target_patch_count_reduction, image_size, patch_size, heads=8, projection_dropout=0., attn_dropout=0.):
        super().__init__()
        # Pre-norm layers, one per stage
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ln3 = nn.LayerNorm(embed_dim)

        # ASSUMES A SQUARE TOKEN GRID.
        original_patch_count = int(image_size[0] / patch_size * image_size[1] / patch_size)
        original_side = int(math.sqrt(original_patch_count))
        target_side = int(math.sqrt(original_patch_count) - target_patch_count_reduction)
        if target_side < 1:
            raise ValueError(
                f"target_patch_count_reduction={target_patch_count_reduction} leaves no tokens "
                f"on a {original_side}x{original_side} grid"
            )
        target_patch_count = target_side ** 2

        self.downsample_patches = GeneralizedDownsample(
            embed_dim=embed_dim,
            original_side_length=original_side,
            target_side_length=target_side,
        )
        # The reduced tokens form a target_side x target_side grid with one token
        # per cell, hence patch_size=1. Using patch_size=target_side would shrink
        # RelativeAttention's grid to 1x1, turning its position bias into a
        # per-head constant that the softmax cancels.
        self.relative_attention = RelativeAttention(
            inp=embed_dim,
            oup=embed_dim,
            image_size=(target_side, target_side),
            patch_size=1,
            heads=heads,
            projection_dropout=projection_dropout,
            attn_dropout=attn_dropout,
        )
        self.upsample_patches = GeneralizedUpsample(
            embed_dim=embed_dim,
            num_patches_in=target_patch_count,
            num_patches_out=original_patch_count,
        )
        self.activation = nn.GELU()

    def forward(self, x):
        """(batch, S*S, embed_dim) -> (batch, S*S, embed_dim), residual included."""
        identity = x
        x = self.activation(self.downsample_patches(self.ln1(x)))
        x = self.activation(self.relative_attention(self.ln2(x)))
        x = self.activation(self.upsample_patches(self.ln3(x)))
        # Out-of-place residual add (avoids mutating an autograd intermediate)
        return x + identity



# Next we code the state space model block (based on Mamba); it includes the sequential
# selective-scan process defined inside the SSM method.
# source: https://arxiv.org/pdf/2312.00752.pdf

class MambaBlock(nn.Module):
    """A single Mamba block (Gu & Dao, 2023, fig. 3 right) with a sequential selective scan.

    Data flow for an input of shape (batch, tokens, d_model)::

        in_proj -> split into x and res                 (each d_inner wide)
        x   -> causal Conv1d -> SiLU -> selective SSM
        out = out_proj(SSM(x) * SiLU(res))

    ``selective_scan_seq`` is a plain Python loop over tokens (one step per token).

    Args:
        d_inner: width of the inner representation.
        conv_bias: whether the Conv1d has a bias.
        inout_bias: whether ``in_proj`` / ``out_proj`` have biases.
        conv_kernel: Conv1d kernel size.
        conv_groupval: ``groups`` of the Conv1d (``d_inner`` makes it depthwise).
        d_model: input/output embedding size.
        d_state: state size N per inner channel.
        dt_rank: rank of the low-rank projection that produces the step size Δ.
    """

    def __init__(self, d_inner, conv_bias, inout_bias, conv_kernel, conv_groupval, d_model, d_state, dt_rank):
        super().__init__()

        self.d_inner = d_inner
        self.conv_bias = conv_bias
        self.inout_bias = inout_bias
        self.conv_kernel = conv_kernel
        self.conv_groupval = conv_groupval
        self.d_model = d_model
        self.d_state = d_state
        self.dt_rank = dt_rank

        # d_model -> 2*d_inner; one half feeds the SSM branch, the other the SiLU gate.
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=self.inout_bias)

        # Padding of (kernel - 1) plus keeping only the first `tokens` outputs in
        # forward() makes this convolution causal.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=self.conv_kernel,
            bias=self.conv_bias,
            groups=self.conv_groupval,
            padding=self.conv_kernel - 1,
        )

        # Fused projection producing the input-dependent Δ (dt_rank wide),
        # B and C (d_state wide each).
        self.x_proj = nn.Linear(
            self.d_inner,
            self.dt_rank + self.d_state * 2,
            bias=False,
        )

        # Expands the low-rank Δ back to d_inner. The bias is a learnable
        # per-channel offset applied before the softplus in SSM().
        self.dt_proj = nn.Linear(
            self.dt_rank,
            self.d_inner,
            bias=True,
        )

        # A is stored as log(A) with A[d, n] = n + 1 at init; SSM() uses -exp(A_log),
        # so the effective A stays strictly negative whatever A_log learns.
        A = torch.arange(1, self.d_state + 1).unsqueeze(0).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        # Per-channel skip weight for the "D u" term
        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.out_proj = nn.Linear(
            self.d_inner,
            self.d_model,
            bias=self.inout_bias,
        )

    def selective_scan_seq(self, u, delta, A, B, C, D):
        """Run the discretised selective state-space recurrence token by token.

        With the hidden state h starting at zero, for every token t::

            h_t = exp(Δ_t A) * h_{t-1} + (Δ_t B_t) * u_t
            y_t = C_t · h_t + D * u_t

        ``exp(Δ A)`` is the zero-order-hold discretisation of A (eq. 4 of the
        Mamba paper); ``Δ B`` is the simplified (Euler) discretisation of B.
        A is applied elementwise, i.e. it is a diagonal state matrix per channel.

        Args:
            u:     (batch, tokens, d_inner) input sequence.
            delta: (batch, tokens, d_inner) step sizes Δ.
            A:     (d_inner, N) continuous-time state matrix.
            B:     (batch, tokens, N) input-dependent input matrix.
            C:     (batch, tokens, N) input-dependent output matrix.
            D:     (d_inner,) skip-connection weights.

        Returns:
            (batch, tokens, d_inner) output sequence.
        """
        batch, tokens, d_in = u.shape
        N = A.shape[1]

        # (b, l, d, n): per-token discretised decay factors
        deltaA = torch.exp(torch.einsum('b l d, d n -> b l d n', delta, A))
        # (b, l, d, 1) * (b, l, 1, n) -> (b, l, d, n)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)
        deltaB_u = deltaB * u.unsqueeze(-1)

        h = torch.zeros((batch, d_in, N), device=deltaA.device, dtype=deltaA.dtype)
        ys = []
        for t in range(tokens):
            h = deltaA[:, t] * h + deltaB_u[:, t]                        # state update
            ys.append(torch.einsum('b d n, b n -> b d', h, C[:, t, :]))  # C_t · h_t

        y = torch.stack(ys, dim=1)  # (b, l, d_inner)
        return y + u * D

    def SSM(self, x):
        """Selective SSM (S6, Algorithm 2 of the Mamba paper) on ``x`` of shape (batch, tokens, d_inner).

        A and D are input-independent parameters. Δ, B and C are computed from
        ``x`` itself; this input dependence is what makes the scan "selective".

        Returns:
            (batch, tokens, d_inner)
        """
        n = self.A_log.shape[1]

        # -exp(.) keeps A strictly negative; with Δ > 0 (softplus below) the
        # discrete factor exp(Δ A) lies in (0, 1), so the state decays.
        A = -torch.exp(self.A_log.float())
        D = self.D.float()

        delta, B, C = self.x_proj(x).split(split_size=[self.dt_rank, n, n], dim=-1)
        delta = F.softplus(self.dt_proj(delta))

        return self.selective_scan_seq(x, delta, A, B, C, D)

    def forward(self, x):
        """Mamba block forward pass (fig. 3, right, of the Mamba paper).

        Args:
            x: (batch, tokens, d_model)

        Returns:
            (batch, tokens, d_model)
        """
        tokens = x.shape[1]

        x, residual = self.in_proj(x).split(split_size=[self.d_inner, self.d_inner], dim=-1)

        # Conv1d expects channels first; trimming to `tokens` keeps the conv causal.
        x = self.conv1d(x.permute(0, 2, 1))[:, :, :tokens]
        x = F.silu(x.permute(0, 2, 1))

        y = self.SSM(x)
        y = y * F.silu(residual)  # multiplicative gate from the second in_proj half
        return self.out_proj(y)



# Attention itself is provided by the `Attention` and `RelativeAttention` modules defined earlier in this file.


# Vision backbone with mamba-former blocks (my interpretation)
# Inspired by source: https://arxiv.org/pdf/2402.04248.pdf and https://arxiv.org/pdf/2401.09417.pdf

class HybridMambaTransformerBlock(nn.Module):
    """One hybrid block: Mamba SSM, squeeze-and-attend, then an MLP, each pre-norm with a residual.

    forward() computes::

        x = x + Mamba(LN1(x))
        x = x + SqueezeAndAttend(LN2(x))
        x = x + MLP(LN3(x))

    Notes:
        * ``self.attention`` (vanilla multi-head attention, which could handle a
          class token) is constructed but not used in ``forward``. It is kept so
          that parameter names stay stable.
        * SqueezeAndAttend adds its own input back internally, so the second
          step adds both ``x`` and ``LN2(x)``.
        * The squeeze-and-attend path reshapes tokens onto a square 2-D grid,
          so the input must not contain a class token.

    Args:
        embedding_dim: token embedding size.
        num_heads: attention heads (vanilla and relative attention).
        d_inner, conv_bias, inout_bias, conv_kernel, conv_groupval, d_state, dt_rank:
            forwarded to MambaBlock.
        squeeze_and_attend_reduction: grid-side reduction used by SqueezeAndAttend.
        image_size: side length of the (square) input image.
        patch_size: patch side length used to tokenise the image.
        MLP_ratio: hidden-size ratio of the MLP.
        qkv_bias: bias flag for the vanilla attention's qkv projection.
        attention_dropout: attention dropout of the vanilla attention.
        projection_dropout: projection dropout of the vanilla attention and the MLP dropout.
        squeeze_and_attend_projection_dropout: projection dropout inside SqueezeAndAttend.
        squeeze_and_attend_attention_dropout: attention dropout inside SqueezeAndAttend.
    """

    def __init__(self,
                 embedding_dim,
                 num_heads,
                 d_inner,
                 conv_bias,
                 inout_bias,
                 conv_kernel,
                 conv_groupval,
                 squeeze_and_attend_reduction,
                 d_state,
                 dt_rank,
                 image_size,
                 patch_size,
                 MLP_ratio=4.0,
                 qkv_bias = True,
                 attention_dropout=0.,
                 projection_dropout=0.,
                 squeeze_and_attend_projection_dropout=0.1,
                 squeeze_and_attend_attention_dropout=0.2,
                 ):
        super().__init__()
        self.norm1 = nn.LayerNorm(embedding_dim, eps=1e-6)
        self.norm2 = nn.LayerNorm(embedding_dim, eps=1e-6)
        self.norm3 = nn.LayerNorm(embedding_dim, eps=1e-6)
        # Not used in forward(); kept so parameter names / checkpoints stay stable.
        self.attention = Attention(embedding_dim, num_heads, qkv_bias, attention_dropout, projection_dropout)
        self.squeeze_and_attend = SqueezeAndAttend(
            embed_dim=embedding_dim,
            target_patch_count_reduction=squeeze_and_attend_reduction,
            image_size=(image_size, image_size),
            patch_size=patch_size,
            heads=num_heads,
            projection_dropout=squeeze_and_attend_projection_dropout,
            attn_dropout=squeeze_and_attend_attention_dropout,
        )

        self.mlp = MultiLayerPerceptron(embedding_dim, MLP_ratio, embedding_dim, projection_dropout, num_layers=3)
        self.mamba = MambaBlock(d_inner, conv_bias, inout_bias, conv_kernel, conv_groupval, embedding_dim, d_state, dt_rank)

    def forward(self, x):
        """(batch, tokens, embedding_dim) -> (batch, tokens, embedding_dim)."""
        x = x + self.mamba(self.norm1(x))               # state-space modelling
        x = x + self.squeeze_and_attend(self.norm2(x))  # attention on a squeezed grid
        x = x + self.mlp(self.norm3(x))                 # feed-forward
        return x




# Now we code a complete encoder


class HybridMambaTransformerEncoder(nn.Module):
    """Image encoder built from hybrid Mamba + attention + MLP blocks.

    Pipeline::

        PatchEmbed -> n_blocks x HybridMambaTransformerBlock -> LayerNorm
                   -> mean over all patch tokens -> Linear head

    No class token is used (the attention inside the blocks is not written to
    handle one), so the final representation is a global average over every
    patch token.

    Args:
        image_size (int): side length of the square input image.
        patch_size (int): side length of the square, non-overlapping patches.
        in_channels (int): number of input image channels.
        embedding_dim (int): length of the embedding vector per patch.
        feature_size (int): length of the output feature vector.
        n_blocks (int): number of sequential hybrid blocks (depth).
        n_heads (int): number of attention heads per block.
        mlp_ratio (float): hidden-width expansion ratio of each block's MLP.
        qkv_bias (bool): bias flag for each block's vanilla Attention module.
        attention_dropout (float): dropout for each block's vanilla Attention
            module. NOTE(review): that module is not called in the block's
            forward and SqueezeAndAttend uses its own hard-coded dropouts, so
            ``qkv_bias`` / ``attention_dropout`` currently appear to have no
            effect on the output -- confirm.
        projection_dropout (float): dropout passed to each block's Attention
            and MLP.
        squeeze_and_attend_reduction (int | sequence of int): target
            patch-count reduction for every block's SqueezeAndAttend. A plain
            int is accepted; for backward compatibility an indexable sequence
            is also accepted, in which case only its FIRST element is used
            and shared by all blocks.
        d_inner, conv_bias, inout_bias, conv_kernel, conv_groupval, d_state,
        dt_rank: forwarded unchanged to each block's MambaBlock.
    """

    def __init__(self,
                 image_size,
                 patch_size,
                 in_channels,
                 embedding_dim,
                 feature_size,
                 n_blocks,
                 n_heads,
                 mlp_ratio,
                 qkv_bias,
                 attention_dropout,
                 projection_dropout,
                 squeeze_and_attend_reduction,
                 d_inner,
                 conv_bias,
                 inout_bias,
                 conv_kernel,
                 conv_groupval,
                 d_state,
                 dt_rank
                 ):
        super().__init__()
        self.patch_embedding = PatchEmbed(
            img_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embedding_dim,
        )

        # Learnable absolute position embedding, one row per patch (no extra
        # row because there is no class token). forward() does NOT add it;
        # it is kept so the parameter set and existing state_dicts are
        # unchanged.
        self.position_embedding = nn.Parameter(
            torch.zeros(1, self.patch_embedding.n_patches, embedding_dim)
        )

        # Accept either a scalar or (legacy) an indexable whose first element
        # is used for every block.
        try:
            reduction = squeeze_and_attend_reduction[0]
        except TypeError:  # plain int/float is not subscriptable
            reduction = squeeze_and_attend_reduction

        self.blocks = nn.ModuleList([
            HybridMambaTransformerBlock(
                image_size=image_size,
                patch_size=patch_size,
                embedding_dim=embedding_dim,
                num_heads=n_heads,
                MLP_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                attention_dropout=attention_dropout,
                projection_dropout=projection_dropout,
                squeeze_and_attend_reduction=reduction,
                d_inner=d_inner,
                conv_bias=conv_bias,
                inout_bias=inout_bias,
                conv_kernel=conv_kernel,
                conv_groupval=conv_groupval,
                d_state=d_state,
                dt_rank=dt_rank,
            )
            for _ in range(n_blocks)
        ])
        self.norm = nn.LayerNorm(embedding_dim, eps=1e-6)
        self.head = nn.Linear(embedding_dim, feature_size)

    def forward(self, x):
        """Encode a batch of images into feature vectors.

        Args:
            x: image batch of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, feature_size).

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D (B, C, H, W) image batch, got shape {tuple(x.shape)}"
            )

        x = self.patch_embedding(x)  # (B, n_patches, embedding_dim)

        # No class token and no absolute position embedding are added here;
        # positional information is left to the blocks.
        # NOTE(review): confirm SqueezeAndAttend actually encodes position.
        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        # Global average pool over ALL tokens: every token is a real image
        # patch, since no class token was prepended.
        x = x.mean(dim=1)    # (B, embedding_dim)
        return self.head(x)  # (B, feature_size)






In [None]:

# Define the parameters for the Mamba Block and Transformer Encoder
d_inner = 512
conv_bias = True
inout_bias = True
conv_kernel = 3
conv_groupval = 1
d_model = 128  # NOTE: not passed to the model below; MambaBlock receives embedding_dim instead
d_state = 32
dt_rank = 16

# Define the parameters for the Transformer Encoder
image_size = 32
patch_size = 8       # 32 // 8 = 4 -> 4 x 4 = 16 patches per image
in_channels = 3
embedding_dim = 384  # per-patch embedding width
feature_size = 10    # output vector length (10 CIFAR-10 classes in the training cell)
n_blocks = 6         # depth: number of hybrid blocks
n_heads = 8          # attention heads per block
mlp_ratio = 4.0
qkv_bias = True
attention_dropout = 0.1
projection_dropout = 0.1
# The encoder reads element [0] of this value. Write the one-element tuple
# explicitly rather than relying on a trailing comma (`= 1,`), which is easy
# to misread as a plain int.
squeeze_and_attend_reduction = (1,)

# Instantiate the Hybrid Mamba Transformer Encoder
model = HybridMambaTransformerEncoder(
    image_size=image_size,
    patch_size=patch_size,
    in_channels=in_channels,
    embedding_dim=embedding_dim,
    feature_size=feature_size,
    n_blocks=n_blocks,
    n_heads=n_heads,
    mlp_ratio=mlp_ratio,
    qkv_bias=qkv_bias,
    attention_dropout=attention_dropout,
    projection_dropout=projection_dropout,
    squeeze_and_attend_reduction=squeeze_and_attend_reduction,
    d_inner=d_inner,
    conv_bias=conv_bias,
    inout_bias=inout_bias,
    conv_kernel=conv_kernel,
    conv_groupval=conv_groupval,
    d_state=d_state,
    dt_rank=dt_rank
)

# Smoke test: a random batch must map to one feature vector per image.
test = torch.rand(4, in_channels, image_size, image_size)
with torch.no_grad():
    out = model(test)
assert out.shape == (4, feature_size), out.shape

print('Model param count:',sum(p.numel() for p in model.parameters())/1000000, ' million')

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Model param count: 37.170874  million


In [None]:
# Train the model on CIFAR-10 and report validation metrics each epoch.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

# Assumes `model` (a HybridMambaTransformerEncoder) was built in an earlier cell.

# Training-time augmentation: random flips, rotations, colour jitter, affine and
# perspective warps, occasional grayscale; then normalise [0, 1] -> [-1, 1].
transform = transforms.Compose([
    #transforms.RandomResizedCrop(32),  # Example target size, adjust as necessary
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),  # Rotates by degrees selected from (-10, 10)
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Adjust normalization parameters as needed
])

# Evaluation must be deterministic: no random augmentation, only the same
# tensor conversion + normalisation as training. Using the training `transform`
# here would score the model on randomly perturbed images.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


BATCH_SIZE = 1300

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=val_transform)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()  # default reduction='mean' -> per-batch mean loss
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95), weight_decay=0.05)

# Training and Validation Loop
n_epochs = 100  # Number of epochs

for epoch in range(n_epochs):
    # ---- Training phase ----
    model.train()
    running_loss = 0.0
    n_train = 0

    dataloader = tqdm(train_loader, desc=f"epoch {epoch+1}/{n_epochs}")
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # `loss` is a batch mean; weight it by batch size so the smaller final
        # batch does not skew the epoch average.
        batch_loss = loss.item()
        batch_n = labels.size(0)
        running_loss += batch_loss * batch_n
        n_train += batch_n

        dataloader.set_postfix(loss=batch_loss)

    avg_train_loss = running_loss / n_train

    # ---- Validation phase ----
    model.eval()
    running_val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_n = labels.size(0)
            running_val_loss += criterion(outputs, labels).item() * batch_n

            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += batch_n
    avg_val_loss = running_val_loss / total
    val_accuracy = 100 * correct / total

    print(f'Epoch [{epoch+1}/{n_epochs}], Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:13<00:00, 12188230.68it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


epoch 1/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [1/100], Train Loss: 2.0397, Validation Loss: 1.8804, Validation Accuracy: 30.46%


epoch 2/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [2/100], Train Loss: 1.8121, Validation Loss: 1.7492, Validation Accuracy: 36.76%


epoch 3/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [3/100], Train Loss: 1.7049, Validation Loss: 1.6634, Validation Accuracy: 39.97%


epoch 4/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [4/100], Train Loss: 1.6362, Validation Loss: 1.5893, Validation Accuracy: 42.40%


epoch 5/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [5/100], Train Loss: 1.5711, Validation Loss: 1.5737, Validation Accuracy: 42.54%


epoch 6/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [6/100], Train Loss: 1.5267, Validation Loss: 1.5186, Validation Accuracy: 45.07%


epoch 7/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [7/100], Train Loss: 1.4773, Validation Loss: 1.4546, Validation Accuracy: 47.88%


epoch 8/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [8/100], Train Loss: 1.4379, Validation Loss: 1.4193, Validation Accuracy: 49.21%


epoch 9/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [9/100], Train Loss: 1.4145, Validation Loss: 1.4100, Validation Accuracy: 48.52%


epoch 10/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [10/100], Train Loss: 1.3691, Validation Loss: 1.3575, Validation Accuracy: 51.39%


epoch 11/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [11/100], Train Loss: 1.3477, Validation Loss: 1.3328, Validation Accuracy: 52.81%


epoch 12/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [12/100], Train Loss: 1.3185, Validation Loss: 1.3449, Validation Accuracy: 52.18%


epoch 13/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [13/100], Train Loss: 1.2941, Validation Loss: 1.2863, Validation Accuracy: 54.08%


epoch 14/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [14/100], Train Loss: 1.2769, Validation Loss: 1.2885, Validation Accuracy: 53.97%


epoch 15/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [15/100], Train Loss: 1.2481, Validation Loss: 1.2585, Validation Accuracy: 54.68%


epoch 16/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [16/100], Train Loss: 1.2307, Validation Loss: 1.2425, Validation Accuracy: 55.57%


epoch 17/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [17/100], Train Loss: 1.2180, Validation Loss: 1.2333, Validation Accuracy: 55.84%


epoch 18/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [18/100], Train Loss: 1.1917, Validation Loss: 1.2183, Validation Accuracy: 56.61%


epoch 19/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [19/100], Train Loss: 1.1750, Validation Loss: 1.1893, Validation Accuracy: 57.08%


epoch 20/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [20/100], Train Loss: 1.1681, Validation Loss: 1.1891, Validation Accuracy: 58.00%


epoch 21/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [21/100], Train Loss: 1.1487, Validation Loss: 1.1860, Validation Accuracy: 57.69%


epoch 22/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [22/100], Train Loss: 1.1291, Validation Loss: 1.1616, Validation Accuracy: 58.87%


epoch 23/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [23/100], Train Loss: 1.1094, Validation Loss: 1.1535, Validation Accuracy: 58.94%


epoch 24/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [24/100], Train Loss: 1.1026, Validation Loss: 1.1387, Validation Accuracy: 59.57%


epoch 25/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [25/100], Train Loss: 1.0901, Validation Loss: 1.1153, Validation Accuracy: 60.27%


epoch 26/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [26/100], Train Loss: 1.0651, Validation Loss: 1.1298, Validation Accuracy: 60.47%


epoch 27/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [27/100], Train Loss: 1.0593, Validation Loss: 1.0969, Validation Accuracy: 61.55%


epoch 28/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [28/100], Train Loss: 1.0490, Validation Loss: 1.1033, Validation Accuracy: 61.20%


epoch 29/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [29/100], Train Loss: 1.0386, Validation Loss: 1.0699, Validation Accuracy: 62.01%


epoch 30/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [30/100], Train Loss: 1.0315, Validation Loss: 1.0733, Validation Accuracy: 62.11%


epoch 31/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [31/100], Train Loss: 1.0210, Validation Loss: 1.0638, Validation Accuracy: 62.65%


epoch 32/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [32/100], Train Loss: 0.9986, Validation Loss: 1.0599, Validation Accuracy: 62.83%


epoch 33/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [33/100], Train Loss: 0.9939, Validation Loss: 1.0262, Validation Accuracy: 64.08%


epoch 34/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [34/100], Train Loss: 0.9832, Validation Loss: 1.0391, Validation Accuracy: 63.33%


epoch 35/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [35/100], Train Loss: 0.9718, Validation Loss: 1.0178, Validation Accuracy: 64.20%


epoch 36/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [36/100], Train Loss: 0.9623, Validation Loss: 1.0053, Validation Accuracy: 64.46%


epoch 37/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [37/100], Train Loss: 0.9504, Validation Loss: 1.0125, Validation Accuracy: 64.99%


epoch 38/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [38/100], Train Loss: 0.9363, Validation Loss: 0.9939, Validation Accuracy: 64.87%


epoch 39/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [39/100], Train Loss: 0.9288, Validation Loss: 0.9996, Validation Accuracy: 64.33%


epoch 40/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [40/100], Train Loss: 0.9123, Validation Loss: 0.9844, Validation Accuracy: 65.32%


epoch 41/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [41/100], Train Loss: 0.9161, Validation Loss: 0.9840, Validation Accuracy: 65.56%


epoch 42/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [42/100], Train Loss: 0.9094, Validation Loss: 0.9824, Validation Accuracy: 65.46%


epoch 43/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [43/100], Train Loss: 0.8935, Validation Loss: 0.9583, Validation Accuracy: 66.29%


epoch 44/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [44/100], Train Loss: 0.8880, Validation Loss: 0.9649, Validation Accuracy: 66.19%


epoch 45/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [45/100], Train Loss: 0.8731, Validation Loss: 0.9487, Validation Accuracy: 66.87%


epoch 46/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [46/100], Train Loss: 0.8668, Validation Loss: 0.9530, Validation Accuracy: 66.88%


epoch 47/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [47/100], Train Loss: 0.8684, Validation Loss: 0.9383, Validation Accuracy: 66.61%


epoch 48/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [48/100], Train Loss: 0.8550, Validation Loss: 0.9455, Validation Accuracy: 66.62%


epoch 49/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [49/100], Train Loss: 0.8470, Validation Loss: 0.9435, Validation Accuracy: 67.09%


epoch 50/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [50/100], Train Loss: 0.8429, Validation Loss: 0.9211, Validation Accuracy: 67.93%


epoch 51/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [51/100], Train Loss: 0.8279, Validation Loss: 0.9515, Validation Accuracy: 67.43%


epoch 52/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [52/100], Train Loss: 0.8301, Validation Loss: 0.9263, Validation Accuracy: 67.46%


epoch 53/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [53/100], Train Loss: 0.8152, Validation Loss: 0.9133, Validation Accuracy: 68.00%


epoch 54/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [54/100], Train Loss: 0.8099, Validation Loss: 0.9058, Validation Accuracy: 67.99%


epoch 55/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [55/100], Train Loss: 0.8046, Validation Loss: 0.8942, Validation Accuracy: 68.91%


epoch 56/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [56/100], Train Loss: 0.7955, Validation Loss: 0.8929, Validation Accuracy: 68.71%


epoch 57/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [57/100], Train Loss: 0.7955, Validation Loss: 0.8853, Validation Accuracy: 69.06%


epoch 58/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [58/100], Train Loss: 0.7819, Validation Loss: 0.8747, Validation Accuracy: 69.57%


epoch 59/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [59/100], Train Loss: 0.7740, Validation Loss: 0.8935, Validation Accuracy: 68.84%


epoch 60/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [60/100], Train Loss: 0.7732, Validation Loss: 0.8763, Validation Accuracy: 70.07%


epoch 61/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [61/100], Train Loss: 0.7692, Validation Loss: 0.8661, Validation Accuracy: 69.87%


epoch 62/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [62/100], Train Loss: 0.7614, Validation Loss: 0.8798, Validation Accuracy: 69.29%


epoch 63/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [63/100], Train Loss: 0.7586, Validation Loss: 0.8606, Validation Accuracy: 70.06%


epoch 64/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [64/100], Train Loss: 0.7517, Validation Loss: 0.8744, Validation Accuracy: 69.55%


epoch 65/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [65/100], Train Loss: 0.7442, Validation Loss: 0.8591, Validation Accuracy: 69.68%


epoch 66/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [66/100], Train Loss: 0.7333, Validation Loss: 0.8637, Validation Accuracy: 69.85%


epoch 67/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [67/100], Train Loss: 0.7335, Validation Loss: 0.8688, Validation Accuracy: 69.54%


epoch 68/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [68/100], Train Loss: 0.7302, Validation Loss: 0.8529, Validation Accuracy: 70.05%


epoch 69/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [69/100], Train Loss: 0.7208, Validation Loss: 0.8521, Validation Accuracy: 70.13%


epoch 70/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [70/100], Train Loss: 0.7088, Validation Loss: 0.8489, Validation Accuracy: 70.81%


epoch 71/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [71/100], Train Loss: 0.7090, Validation Loss: 0.8384, Validation Accuracy: 70.98%


epoch 72/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [72/100], Train Loss: 0.6991, Validation Loss: 0.8617, Validation Accuracy: 70.14%


epoch 73/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [73/100], Train Loss: 0.6977, Validation Loss: 0.8357, Validation Accuracy: 70.78%


epoch 74/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [74/100], Train Loss: 0.6915, Validation Loss: 0.8426, Validation Accuracy: 70.63%


epoch 75/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [75/100], Train Loss: 0.6900, Validation Loss: 0.8426, Validation Accuracy: 70.98%


epoch 76/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [76/100], Train Loss: 0.6817, Validation Loss: 0.8443, Validation Accuracy: 70.88%


epoch 77/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [77/100], Train Loss: 0.6786, Validation Loss: 0.8245, Validation Accuracy: 71.39%


epoch 78/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [78/100], Train Loss: 0.6747, Validation Loss: 0.8110, Validation Accuracy: 72.34%


epoch 79/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [79/100], Train Loss: 0.6615, Validation Loss: 0.8207, Validation Accuracy: 71.70%


epoch 80/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [80/100], Train Loss: 0.6610, Validation Loss: 0.8190, Validation Accuracy: 72.21%


epoch 81/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [81/100], Train Loss: 0.6628, Validation Loss: 0.8224, Validation Accuracy: 71.61%


epoch 82/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [82/100], Train Loss: 0.6514, Validation Loss: 0.8285, Validation Accuracy: 71.83%


epoch 83/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [83/100], Train Loss: 0.6431, Validation Loss: 0.8016, Validation Accuracy: 72.19%


epoch 84/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [84/100], Train Loss: 0.6446, Validation Loss: 0.8165, Validation Accuracy: 72.03%


epoch 85/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [85/100], Train Loss: 0.6373, Validation Loss: 0.8175, Validation Accuracy: 72.00%


epoch 86/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [86/100], Train Loss: 0.6359, Validation Loss: 0.8170, Validation Accuracy: 72.37%


epoch 87/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [87/100], Train Loss: 0.6290, Validation Loss: 0.7987, Validation Accuracy: 72.85%


epoch 88/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [88/100], Train Loss: 0.6280, Validation Loss: 0.8187, Validation Accuracy: 71.65%


epoch 89/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [89/100], Train Loss: 0.6208, Validation Loss: 0.7845, Validation Accuracy: 72.96%


epoch 90/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [90/100], Train Loss: 0.6101, Validation Loss: 0.7935, Validation Accuracy: 72.43%


epoch 91/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [91/100], Train Loss: 0.6100, Validation Loss: 0.7975, Validation Accuracy: 73.37%


epoch 92/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [92/100], Train Loss: 0.6070, Validation Loss: 0.8001, Validation Accuracy: 72.76%


epoch 93/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [93/100], Train Loss: 0.6041, Validation Loss: 0.7961, Validation Accuracy: 72.61%


epoch 94/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [94/100], Train Loss: 0.6001, Validation Loss: 0.8061, Validation Accuracy: 72.43%


epoch 95/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [95/100], Train Loss: 0.5947, Validation Loss: 0.7918, Validation Accuracy: 73.05%


epoch 96/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [96/100], Train Loss: 0.5890, Validation Loss: 0.7956, Validation Accuracy: 72.69%


epoch 97/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [97/100], Train Loss: 0.5842, Validation Loss: 0.7967, Validation Accuracy: 73.28%


epoch 98/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [98/100], Train Loss: 0.5817, Validation Loss: 0.8119, Validation Accuracy: 72.56%


epoch 99/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [99/100], Train Loss: 0.5722, Validation Loss: 0.7875, Validation Accuracy: 73.42%


epoch 100/100:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [100/100], Train Loss: 0.5801, Validation Loss: 0.7890, Validation Accuracy: 73.11%
