In [1]:
"bye"

'bye'

In [2]:
import torch
from torch import nn
import torchvision 
from torchvision import datasets

In [63]:
def create_effnetv2M_model(num_classes:int=3,
                          seed:int=42):
    """Creates an EfficientNetV2-M feature extractor model and its transforms.

    The pretrained base network is frozen and its classifier head is replaced
    with a newly initialised ``Dropout`` + ``Linear`` pair, so only the head
    has ``requires_grad=True``.

    Args:
        num_classes (int, optional): number of classes in the classifier head.
            Defaults to 3.
        seed (int, optional): random seed set (globally, via
            ``torch.manual_seed``) right before the new head is created.
            Defaults to 42.

    Returns:
        model (torch.nn.Module): EfficientNetV2-M feature extractor model.
        transforms: the preprocessing transforms returned by the pretrained
            weights' ``transforms()``.
    """
    # Create EfficientNetV2-M pretrained weights, transforms and model
    weights = torchvision.models.EfficientNet_V2_M_Weights.DEFAULT
    transforms = weights.transforms()
    model = torchvision.models.efficientnet_v2_m(weights=weights)

    # Freeze all layers in base model
    for param in model.parameters():
        param.requires_grad = False

    # Read the head's input width from the pretrained classifier instead of
    # hard-coding it, so the new Linear layer always matches the backbone.
    in_features = model.classifier[-1].in_features

    # Change classifier head with random seed for reproducibility
    torch.manual_seed(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes),
    )

    return model, transforms

In [58]:
# Number of target classes. class_names is assigned from train_data.classes in a
# cell further down, so this only works on a kernel where that cell already ran.
len(class_names)

101

In [64]:
# Build the frozen EfficientNetV2-M feature extractor with a 101-class head.
# NOTE(review): despite the "v1" in its name, effnetv1_transforms holds the
# transforms returned by create_effnetv2M_model (the V2-M weights' preprocessing).
effnetv2_model, effnetv1_transforms = create_effnetv2M_model(num_classes=101)
effnetv2_model, effnetv1_transforms

(EfficientNet(
   (features): Sequential(
     (0): Conv2dNormActivation(
       (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
       (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
       (2): SiLU(inplace=True)
     )
     (1): Sequential(
       (0): FusedMBConv(
         (block): Sequential(
           (0): Conv2dNormActivation(
             (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
             (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
             (2): SiLU(inplace=True)
           )
         )
         (stochastic_depth): StochasticDepth(p=0.0, mode=row)
       )
       (1): FusedMBConv(
         (block): Sequential(
           (0): Conv2dNormActivation(
             (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
             (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track

In [65]:
# Summarise the EfficientNetV2-M feature extractor built above (101 output classes).
# NOTE(review): the summary is traced with 224x224 inputs, while the transforms shown
# in the next cell resize/crop to 480 — confirm which input size is intended.
from torchinfo import summary
summary(effnetv2_model, 
        input_size=(1, 3, 224, 224),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

  action_fn=lambda data: sys.getsizeof(data.storage()),


Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
EfficientNet (EfficientNet)                                  [1, 3, 224, 224]     [1, 101]             --                   Partial
├─Sequential (features)                                      [1, 3, 224, 224]     [1, 1280, 7, 7]      --                   False
│    └─Conv2dNormActivation (0)                              [1, 3, 224, 224]     [1, 24, 112, 112]    --                   False
│    │    └─Conv2d (0)                                       [1, 3, 224, 224]     [1, 24, 112, 112]    (648)                False
│    │    └─BatchNorm2d (1)                                  [1, 24, 112, 112]    [1, 24, 112, 112]    (48)                 False
│    │    └─SiLU (2)                                         [1, 24, 112, 112]    [1, 24, 112, 112]    --                   --
│    └─Sequential (1)                                        [1, 24, 112, 112]    [1, 2

In [66]:
# Display the preprocessing returned alongside the model (crop/resize size,
# normalisation mean/std, interpolation mode).
effnetv1_transforms

ImageClassification(
    crop_size=[480]
    resize_size=[480]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)

In [67]:
# Training pipeline: TrivialAugmentWide augmentation first, then the same
# preprocessing that is applied to the test images.
train_transforms = torchvision.transforms.Compose(
    transforms=[torchvision.transforms.TrivialAugmentWide(), effnetv1_transforms]
)
train_transforms

Compose(
    TrivialAugmentWide(num_magnitude_bins=31, interpolation=InterpolationMode.NEAREST, fill=None)
    ImageClassification(
    crop_size=[480]
    resize_size=[480]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)
)

In [12]:
from pathlib import Path

# Root folder the Food101 archive is stored in.
data_dir = Path("data")

# Training split: augmented transform pipeline.
train_data = datasets.Food101(
    root=data_dir, split='train', transform=train_transforms, download=True
)
# Test split: plain model preprocessing only, no augmentation.
test_data = datasets.Food101(
    root=data_dir, split='test', transform=effnetv1_transforms, download=True
)

Downloading https://data.vision.ee.ethz.ch/cvl/food-101.tar.gz to data/food-101.tar.gz


100%|██████████| 4996278331/4996278331 [56:15<00:00, 1479976.78it/s]  


Extracting data/food-101.tar.gz to data


In [68]:
# Number of samples in the Food101 train and test splits.
len(train_data), len(test_data)

(75750, 25250)

In [69]:
# Class labels exposed by the training dataset object; displayed for inspection.
class_names = train_data.classes
class_names

['apple_pie',
 'baby_back_ribs',
 'baklava',
 'beef_carpaccio',
 'beef_tartare',
 'beet_salad',
 'beignets',
 'bibimbap',
 'bread_pudding',
 'breakfast_burrito',
 'bruschetta',
 'caesar_salad',
 'cannoli',
 'caprese_salad',
 'carrot_cake',
 'ceviche',
 'cheese_plate',
 'cheesecake',
 'chicken_curry',
 'chicken_quesadilla',
 'chicken_wings',
 'chocolate_cake',
 'chocolate_mousse',
 'churros',
 'clam_chowder',
 'club_sandwich',
 'crab_cakes',
 'creme_brulee',
 'croque_madame',
 'cup_cakes',
 'deviled_eggs',
 'donuts',
 'dumplings',
 'edamame',
 'eggs_benedict',
 'escargots',
 'falafel',
 'filet_mignon',
 'fish_and_chips',
 'foie_gras',
 'french_fries',
 'french_onion_soup',
 'french_toast',
 'fried_calamari',
 'fried_rice',
 'frozen_yogurt',
 'garlic_bread',
 'gnocchi',
 'greek_salad',
 'grilled_cheese_sandwich',
 'grilled_salmon',
 'guacamole',
 'gyoza',
 'hamburger',
 'hot_and_sour_soup',
 'hot_dog',
 'huevos_rancheros',
 'hummus',
 'ice_cream',
 'lasagna',
 'lobster_bisque',
 'lobster

In [70]:
from pytorch_modules.modules import data_preprocess

# Keep the first (20%) split of each dataset and discard the remainder.
# NOTE(review): whether split_dataset is seeded is not visible here — confirm
# if the subsets need to be reproducible across runs.
train_data_20perc, _ = data_preprocess.split_dataset(dataset=train_data, split_size=0.2)
test_data_20perc, _ = data_preprocess.split_dataset(dataset=test_data, split_size=0.2)
len(train_data_20perc), len(test_data_20perc)


[INFO] Splitting dataset of length 75750 into splits of size: 15150 (20%), 60600 (80%)
[INFO] Splitting dataset of length 25250 into splits of size: 5050 (20%), 20200 (80%)


(15150, 5050)

In [71]:
# Spot-check one sample: element [1] is presumably the integer class label of
# the (image, label) pair — verify against split_dataset's return type.
train_data_20perc[1238][1]

91

In [72]:
import os
from torch.utils.data import DataLoader

BATCH_SIZE = 32
# os.cpu_count() may return None, which DataLoader does not accept for
# num_workers; fall back to 0 (load batches in the main process).
NUM_WORKERS = os.cpu_count() or 0

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(train_data_20perc,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=NUM_WORKERS)
# Evaluation gains nothing from shuffling; a fixed order keeps test runs
# comparable from epoch to epoch.
test_dataloader = DataLoader(test_data_20perc,
                             batch_size=BATCH_SIZE,
                             shuffle=False,
                             num_workers=NUM_WORKERS)
len(train_dataloader), len(test_dataloader)

(474, 158)

In [73]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [74]:
from pytorch_modules.modules import train_engine

# Hand the optimizer only the parameters that still require gradients (the new
# classifier head); the frozen backbone parameters never receive updates, so
# there is no reason to register them.
trainable_params = [param for param in effnetv2_model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(params=trainable_params,
                             lr=1e-3)
# Label smoothing of 0.1 softens the one-hot targets across the 101 classes.
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
effnetv2_results = train_engine.train(model=effnetv2_model,
                                      train_dataloader=train_dataloader,
                                      test_dataloader=test_dataloader,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      epochs=3,
                                      device=device)

  0%|          | 0/3 [00:00<?, ?it/s]

KeyboardInterrupt: 