<a href="https://colab.research.google.com/github/vjhawar12/FreshNET-A-mobileNET-adaptation/blob/main/FreshNET.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Installs the distribution that provides the `pl_bolts` package, which this
# notebook imports for LinearWarmupCosineAnnealingLR.
!pip install lightning-bolts

In [None]:
import torch
import torch.nn as nn
from torch import optim
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

In [None]:
# Interactive login that sets up Application Default Credentials for gcloud.
# NOTE(review): presumably needed so the storage copy further down can read the bucket — confirm.
!gcloud auth application-default login

In [None]:
# Selects the project that subsequent gcloud commands run against.
!gcloud config set project freshnet-466505

In [None]:
# Create the target directory first: if it does not exist, `cd` fails and the
# `&&` chain would skip the download entirely.
!mkdir -p /content/dataset && cd /content/dataset && gcloud storage cp --recursive gs://fruit-images-freshnet .

In [None]:
# --- Dataset locations -------------------------------------------------------
# TODO: point these at the ImageFolder-style roots of the train / test images.
path_to_test_imgs = ""
path_to_train_imgs = ""

# --- Data pipeline -----------------------------------------------------------
RANDOM_SEED = 42  # seeds the train/validation split
BATCH_SIZE = 512
TRAIN_SIZE = 0.8  # fraction of the training folder used for training
VAL_SIZE = 0.2  # fraction of the training folder held out for validation
# Original images are ~400*400 px, so resizing to 320 keeps detail while
# reducing computational cost.
IMG_SIZE = 320

# --- Optimisation schedule ---------------------------------------------------
EPOCHS = 50  # total epochs; also passed to the scheduler as max_epochs
WARMUP_DUR = 20  # passed to the scheduler as warmup_epochs (length of the linear warmup)
LR_MIN = 1e-4  # minimum learning rate (scheduler eta_min)

# --- Regularisation ----------------------------------------------------------
MILD_DROPOUT_RATE = 0.1

In [None]:
# Mild preprocessing only. No MixUp, CutMix, RandomCrop, or ColorJitter because a lightweight model like this one
# is less likely to overfit. Also, this model needs to perform fine-grained classification,
# so aggressive augmentations could distort the small image regions (like signs of fungi or discoloration)
# that are crucial for accurate prediction.

# RandomRotation requires an explicit angle range; a small value keeps the augmentation mild.
# NOTE(review): 15 degrees is a chosen default, not taken from the original notebook — tune as needed.
ROTATION_DEGREES = 15

# ImageNet channel statistics, shared by the train and test pipelines.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    # Resize takes the target size as ONE argument; a (h, w) tuple forces a square output.
    # Passing two positional ints would treat the second one as the interpolation mode.
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(ROTATION_DEGREES),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Deterministic pipeline for evaluation: same resize and normalization, no random augmentation.
test_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

In [None]:
full_train_data = ImageFolder(path_to_train_imgs, transform=train_transform)
test_data = ImageFolder(path_to_test_imgs, transform=test_transform)

# Seeded split so the train/validation partition is reproducible.
# NOTE(review): fractional lengths in random_split need a reasonably recent torch — confirm the runtime version.
train_data, val_split = random_split(
    full_train_data,
    [TRAIN_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(RANDOM_SEED),
)

# Both halves of random_split share full_train_data's transform, so validation would
# otherwise be scored on randomly flipped/rotated images. Re-wrap the same indices
# around a copy of the folder that uses the deterministic evaluation transform.
val_data = torch.utils.data.Subset(
    ImageFolder(path_to_train_imgs, transform=test_transform),
    val_split.indices,
)

# random_split returns Subset objects, which have no class_to_idx; the label mapping
# lives on the underlying ImageFolder. Print both folders' mappings so a mismatch
# between train and test class ordering is visible.
print(f"{full_train_data.class_to_idx} \n {test_data.class_to_idx}")
print(f"train: {len(train_data)} \t val: {len(val_data)} \t test: {len(test_data)}")

In [None]:
# Evaluation loaders share the same settings: fixed order, 4 worker processes.
eval_loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# Only the training loader reshuffles each epoch; it also gets more workers.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=10)
test_dataloader = DataLoader(test_data, **eval_loader_kwargs)
val_dataloader = DataLoader(val_data, **eval_loader_kwargs)

In [None]:
# CUDA availability is queried through torch.cuda; torch.device is only the device
# descriptor type and has no `cuda` attribute.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")

if cuda_available:
  # Allow TF32 for matmuls and cuDNN convolutions.
  torch.backends.cuda.matmul.allow_tf32 = True
  torch.backends.cudnn.allow_tf32 = True
  # Enable all scaled-dot-product-attention backends.
  torch.backends.cuda.enable_flash_sdp(True)
  torch.backends.cuda.enable_mem_efficient_sdp(True)
  torch.backends.cuda.enable_math_sdp(True)

In [None]:
class DepthwiseSeperableConvolution(nn.Module):
  """
  Helper class for FreshNET: the "depthwise separable convolution" building block from the MobileNET papers.
  Instances of this class are stacked together and each stack constitutes a hidden layer in the NN.

  One instance is structured as follows:

  - Pointwise (1*1) convolution, applied across all channels at a single pixel, for channel expansion.
  - Depthwise convolution (one filter per channel, groups == channels) so the model can efficiently
    learn spatial features in the expanded space.
  - Pointwise (1*1) convolution to project to `out_channels`.

  Every convolution is followed by batch normalization, and a ReLU6 is applied after the final one.

  Args:
    in_channels: number of input channels.
    out_channels: number of output channels.
    exp_factor: the expanded width is in_channels * exp_factor.
    downsample_factor: stride of the depthwise convolution (2 halves height and width).
    kernel_size: depthwise kernel size. The depthwise conv is padded by kernel_size // 2,
      so odd kernels keep height and width unchanged at stride 1.

  Attributes:
    preserve: True when in_channels == out_channels. This only reflects the channel count;
      the spatial size is additionally unchanged only when downsample_factor == 1.
  """

  def __init__(self, in_channels, out_channels, exp_factor, downsample_factor, kernel_size=3):
    super().__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.exp_factor = exp_factor
    self.downsample_factor = downsample_factor
    self.kernel_size = kernel_size

    expanded = self.in_channels * self.exp_factor

    self.preserve = self.in_channels == self.out_channels

    # Batchnorm normalizes activations, which has been shown to improve training. It is typically applied
    # after the convolution output and before the activation function. The MobileNET paper also mentions
    # that, with the exception of the final fully connected layer, "all layers are followed by a
    # batchnorm and ReLU nonlinearity".
    self.block = nn.Sequential(
        nn.Conv2d(in_channels=self.in_channels, out_channels=expanded, kernel_size=1, stride=1, groups=1), # expansion
        nn.BatchNorm2d(expanded),

        # Without padding a 3*3 kernel would shrink height/width by 2 on every block, so
        # stride-1 blocks could not be added back onto their input in a skip connection.
        nn.Conv2d(
            in_channels=expanded,
            out_channels=expanded,
            kernel_size=self.kernel_size,
            stride=self.downsample_factor,
            padding=self.kernel_size // 2,
            groups=expanded,
        ), # depthwise
        nn.BatchNorm2d(expanded),

        nn.Conv2d(in_channels=expanded, out_channels=self.out_channels, kernel_size=1, stride=1, groups=1), # pointwise
        nn.BatchNorm2d(self.out_channels),

        nn.ReLU6(), # output ∈ [0, 6]
    )

  def forward(self, x):
    """Apply expansion -> depthwise -> projection and return the activated output."""
    return self.block(x)

In [None]:
# a simple skip connection implementation
class SkipConnection(nn.Module):
  """
  Residual wrapper: runs the input through a stack of modules and adds the original input back.

  FreshNET only applies skip connections around DepthwiseSeperableConvolution instances which
  preserve the channel dimension, because element-wise addition needs matching shapes.

  Args:
    module_list: an nn.ModuleList whose members each expose a boolean `preserve` attribute.

  Raises:
    ValueError: if any member reports `preserve` as False.
  """

  def __init__(self, module_list): # module_list is an nn.ModuleList instance
    super().__init__()

    # Validate before registering the stack so an invalid configuration fails fast.
    for module in module_list:
      if not module.preserve:
        raise ValueError("Cannot apply skip connection between layers of different channel dimensions")

    self.module_list = module_list

  def forward(self, x):
    """Return stack(x) + x."""
    # Keep the shortcut in a local: storing it on the module would keep the activation
    # tensor alive after the forward pass has finished.
    residual = x

    for depthwise_sep_conv in self.module_list:
      x = depthwise_sep_conv(x)

    return x + residual

In [None]:
# A MobileNET adaptation
class FreshNET(nn.Module):

  """
  FreshNET applies the concept of depthwise separable convolutions as mentioned in the MobileNet paper, but the image dimensions are slightly increased
  to better suit the dataset and leverage higher image quality. Specifically, FreshNET decouples spatial filtering and channel mixing,
  the two operations traditionally combined in standard convolutions, and performs them as separate, more efficient operations. As the authors of MobileNet noted,
  it "drastically reduc[es] computation and model size". Despite this computational efficiency, accuracy is largely preserved -- the original MobileNet paper reports
  only about a 1% drop in accuracy compared to standard convolutions.

  This separation is implemented in the DepthwiseSeperableConvolution class. Each instance consists of:
  - A pointwise 1*1 convolution for channel expansion, increasing feature dimensionality.
  - A depthwise 3*3 convolution applied independently per channel for spatial filtering.
  - Another pointwise convolution to project back to a lower-dimensional space.

  This design significantly reduces computational cost while maintaining high representational power. There are 10 layers, of which 7 are DepthwiseSeperableConvolution layers.
  Each DepthwiseSeperableConvolution layer consists of 1-4 DepthwiseSeperableConvolution instances. Within the DepthwiseSeperableConvolution layer, there are non-linear activation functions.

  Skip connections are also applied where possible to improve gradient flow and preserve fine-grained information. For a task like fresh/rotten classification,
  fine-grained information can be very valuable.

  The sequence of layers is as follows:

    - initial regular convolution
    - 4 stacks of DepthwiseSeperableConvolution
    - mild dropout for regularization
    - 3 more stacks of DepthwiseSeperableConvolution
    - 1 final regular convolution
    - global average pooling, then flatten to (batch, 1280)
    - activation function
    - first fully connected linear layer
    - activation function
    - mild dropout for regularization
    - final fully connected linear layer

  There are skip connections in blocks 3-6.

  The shape comments in __init__ are nominal sizes for a 320*320 input and assume every
  stride-1 DepthwiseSeperableConvolution keeps height and width unchanged.
  """

  def __init__(self):
    super().__init__()

    # SiLU is good for vanishing gradients and can be applied in the later layers where ReLU might not be as effective (dead neurons)
    silu = nn.SiLU()

    # Regularization
    # The original paper notes "we use less regularization and data augmentation techniques because small models
    # have less trouble with overfitting". This is why dropout is used with 10% probability, which is on the
    # lower end of the probability spectrum.
    dropout_mild = nn.Dropout(MILD_DROPOUT_RATE)

    # initial regular convolution; padding=1 so that stride 2 exactly halves 320 to 160 (unpadded it would give 159)
    initial_conv = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=1) # 320*320*3 --> 160*160*32

    # Each block is an nn.Sequential rather than an nn.ModuleList: forward() calls every
    # entry of self.layers, and a ModuleList is only a container that cannot be called.

    # inverted residual block 1: 160*160*32 --> 160*160*16, channel expansion = 1
    block_1 = nn.Sequential(
        DepthwiseSeperableConvolution(32, 16, 1, 1),
    )

    # inverted residual block 2: 160*160*16 --> 80*80*24, channel expansion = 6
    block_2 = nn.Sequential(
        DepthwiseSeperableConvolution(16, 24, 6, 2),
        DepthwiseSeperableConvolution(24, 24, 6, 1),
    )

    # inverted residual block 3: 80*80*24 --> 40*40*32, channel expansion = 6
    block_3 = nn.Sequential(
        DepthwiseSeperableConvolution(24, 32, 6, 2),
        SkipConnection(nn.ModuleList([
            DepthwiseSeperableConvolution(32, 32, 6, 1),
            DepthwiseSeperableConvolution(32, 32, 6, 1),
        ])),
    )

    # inverted residual block 4: 40*40*32 --> 20*20*64, channel expansion = 6
    block_4 = nn.Sequential(
        DepthwiseSeperableConvolution(32, 64, 6, 2),
        SkipConnection(nn.ModuleList([
            DepthwiseSeperableConvolution(64, 64, 6, 1),
            DepthwiseSeperableConvolution(64, 64, 6, 1),
            DepthwiseSeperableConvolution(64, 64, 6, 1),
        ])),
    )

    # inverted residual block 5: 20*20*64 --> 20*20*96, channel expansion = 6
    block_5 = nn.Sequential(
        DepthwiseSeperableConvolution(64, 96, 6, 1),
        SkipConnection(nn.ModuleList([
            DepthwiseSeperableConvolution(96, 96, 6, 1),
            DepthwiseSeperableConvolution(96, 96, 6, 1),
        ])),
    )

    # inverted residual block 6: 20*20*96 --> 10*10*160, channel expansion = 6
    block_6 = nn.Sequential(
        DepthwiseSeperableConvolution(96, 160, 6, 2),
        SkipConnection(nn.ModuleList([
            DepthwiseSeperableConvolution(160, 160, 6, 1),
            DepthwiseSeperableConvolution(160, 160, 6, 1),
        ])),
    )

    # inverted residual block 7: 10*10*160 --> 10*10*320, channel expansion = 6
    block_7 = nn.Sequential(
        DepthwiseSeperableConvolution(160, 320, 6, 1),
    )

    # final regular convolution
    final_conv = nn.Conv2d(in_channels=320, out_channels=1280, kernel_size=1, stride=1) # 10*10*320 --> 10*10*1280

    # MaxPooling is often preferred for highlighting dominant features, but it can be too aggressive in lightweight
    # models like MobileNet or EfficientNet, which already have low spatial resolution. It can discard valuable
    # spatial information. Average pooling preserves more context, which makes it more suitable for compact
    # architectures like this one.
    # Adaptive pooling averages over whatever spatial size arrives (10*10 for a 320*320 input),
    # so the head does not depend on a hard-coded window.
    avg_pool = nn.AdaptiveAvgPool2d(1) # 10*10*1280 --> 1*1*1280

    # nn.Linear works on the last dimension, so (batch, 1280, 1, 1) must become (batch, 1280) first.
    flatten = nn.Flatten()

    # fully connected layers
    # The MobileNET paper mapped from 1280 directly to 1000 since it was being trained on ImageNET. This dataset,
    # however, only has 2 output classes: fresh (0) or rotten (1). Mapping directly from 1280 to 2 is an abrupt
    # jump which could limit the model's ability to distinguish between classes, so an additional layer maps
    # from 1280 to 500, followed by a nonlinearity and mild dropout, before going from 500 to 2.
    fc_1 = nn.Linear(in_features=1280, out_features=500)
    fc_2 = nn.Linear(in_features=500, out_features=2)

    # SiLU and Dropout hold no parameters or buffers, so reusing the same instances is safe.
    self.layers = nn.ModuleList([
        initial_conv,

        block_1,
        block_2,
        block_3,
        block_4,
        dropout_mild,

        block_5,
        block_6,
        block_7,

        final_conv,
        avg_pool,
        flatten,
        silu,

        fc_1,
        silu,
        dropout_mild,

        fc_2
    ])


  def forward(self, x):
    """Run x through every layer in order and return the raw class logits (2 per sample)."""
    for layer in self.layers:
      x = layer(x)

    return x

In [None]:
# Instantiate the network, wrap it with torch.compile, then move it onto the selected device.
cnn = torch.compile(FreshNET())
cnn.to(device)

# Cross-entropy: the standard loss function for classification.
loss_fn = nn.CrossEntropyLoss()

# RMSprop builds upon AdaGRAD but prevents the lr from decreasing too much. Used in the original paper.
optimizer = optim.RMSprop(params=cnn.parameters())

# Not mentioned in the original paper, but this scheduler was used because linear warmup reduces
# volatility in the earlier epochs. Cosine annealing can improve training stability and convergence,
# and declining the LR steadily means the model will rely more on the features it learns early on
# -- the key differentiators -- rather than picking up potential noise and overfitting.
# The schedule is expressed in epochs: WARMUP_DUR warmup epochs out of EPOCHS in total.
scheduler = LinearWarmupCosineAnnealingLR(
    optimizer,
    warmup_epochs=WARMUP_DUR,
    max_epochs=EPOCHS,
    eta_min=LR_MIN,
)

In [None]:
def train(train_dataloader):
  """
  Run one training epoch over train_dataloader using the module-level cnn, loss_fn,
  optimizer and scheduler.

  Train mode and any autocast context are managed by the caller.

  Args:
    train_dataloader: iterable yielding (images, labels) batches.

  Returns:
    (avg_loss, avg_acc): mean per-batch loss and the fraction of correctly classified
    samples, both as Python floats. Both are 0.0 for an empty dataloader.
  """
  running_loss = 0.0
  correct = 0
  seen = 0

  for images, labels in train_dataloader:
    images = images.to(device)
    labels = labels.to(device)

    optimizer.zero_grad()
    logits = cnn(images)
    loss = loss_fn(logits, labels)

    loss.backward()
    # Apply the gradients; without this call the weights never change.
    optimizer.step()

    # .item() detaches the value; accumulating the loss tensor itself would keep every
    # batch's autograd graph alive for the whole epoch.
    running_loss += loss.item()
    correct += (torch.argmax(logits, dim=1) == labels).sum().item()
    seen += labels.size(0)

  # The scheduler is configured in epochs (warmup_epochs / max_epochs), so it advances
  # once per epoch, after the optimizer has stepped.
  scheduler.step()

  num_batches = len(train_dataloader)
  avg_loss = running_loss / num_batches if num_batches else 0.0
  # Accuracy is correct predictions over samples seen, not over the number of batches.
  avg_acc = correct / seen if seen else 0.0

  return avg_loss, avg_acc

In [None]:
def validate(val_dataloader):
  """
  Measure the module-level cnn's accuracy on val_dataloader.

  Eval mode and torch.no_grad() are managed by the caller.

  Args:
    val_dataloader: iterable yielding (images, labels) batches.

  Returns:
    Fraction of correctly classified samples as a Python float (0.0 for an empty dataloader).
  """
  correct = 0
  seen = 0

  for images, labels in val_dataloader:
    images = images.to(device)
    labels = labels.to(device)

    predicted = torch.argmax(cnn(images), dim=1)
    correct += torch.sum(predicted == labels).item()
    seen += labels.size(0)

  # Divide by the number of samples: dividing by the number of batches would report
  # the average count of correct predictions per batch instead of an accuracy.
  return correct / seen if seen else 0.0

In [None]:
def test(test_dataloader):
  """
  Measure the module-level cnn's accuracy on test_dataloader.

  Eval mode and torch.no_grad() are managed by the caller.

  Args:
    test_dataloader: iterable yielding (images, labels) batches.

  Returns:
    Fraction of correctly classified samples as a Python float (0.0 for an empty dataloader).
  """
  correct = 0
  seen = 0

  for images, labels in test_dataloader:
    images = images.to(device)
    labels = labels.to(device)

    predicted = torch.argmax(cnn(images), dim=1)
    correct += torch.sum(predicted == labels).item()
    seen += labels.size(0)

  # Divide by the number of samples: dividing by the number of batches would report
  # the average count of correct predictions per batch instead of an accuracy.
  return correct / seen if seen else 0.0

In [None]:
loop = tqdm(range(EPOCHS)) # progress bar

# Mixed precision follows the selected device instead of assuming CUDA, and is switched
# off entirely when running on CPU.
# NOTE(review): float16 autocast is used without a gradient scaler, so small gradients
# may underflow — consider adding torch.amp.GradScaler inside the training step.
use_amp = device.type == "cuda"

for epoch in loop:
  cnn.train()

  with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
    train_loss, train_acc = train(train_dataloader)

  cnn.eval()
  with torch.no_grad():
    val_acc = validate(val_dataloader)

  # float() accepts both plain numbers and 0-dim tensors, so the bar shows numbers
  # rather than tensor reprs.
  loop.set_description(
      f"Epoch {epoch + 1} \t Train Loss: {float(train_loss):.4f} "
      f"\t Train acc: {100 * float(train_acc):.2f}% \t Val acc: {100 * float(val_acc):.2f}%"
  )

In [None]:
# Final evaluation on the held-out test set: eval mode, no gradient tracking.
cnn.eval()

with torch.no_grad():
  test_acc = test(test_dataloader)

final_pct = 100 * test_acc
print(f"Final test accuracy: {final_pct}%")