In [2]:
# Show the interpreter version string of the running kernel.
import sys
sys.version

'3.10.10 (main, Aug 27 2024, 07:11:23) [Clang 15.0.0 (clang-1500.3.9.4)]'

In [3]:
"""Model Inference."""
import torch
import numpy as np
from PIL import Image

from models.tiny_vit import tiny_vit_21m_224
from data import build_transform, imagenet_classnames
from config import get_config

config = get_config()


# Build model
# Use tiny_vit_21m_224, the variant this cell imports. Any other variant
# name is undefined here on a fresh kernel and would raise NameError.
model = tiny_vit_21m_224(pretrained=True)
# Put layers such as BatchNorm into inference behaviour.
model.eval()

# Load Image
fname = "../../dataset/images/10001.jpg"
# Force three channels: the batch below is expected to be (1, 3, H, W),
# but Image.open may return grayscale, palette, RGBA or CMYK images.
image = Image.open(fname).convert("RGB")
transform = build_transform(is_train=False, config=config)

# (1, 3, img_size, img_size)
batch = transform(image)[None]

# No gradients are needed for a forward-only pass.
with torch.no_grad():
    logits = model(batch)

# print top-5 classification names
probs = torch.softmax(logits, -1)
scores, inds = probs.topk(5, largest=True, sorted=True)
print('=' * 30)
print(fname)
for score, ind in zip(scores[0].numpy(), inds[0].numpy()):
    print(f'{imagenet_classnames[ind]}: {score:.2f}')

../../dataset/images/10001.jpg
miniskirt: 0.60
jeans: 0.23
lampshade: 0.01
swim trunks / shorts: 0.00
knee pad: 0.00


In [18]:
"""Model Inference."""
import torch
import numpy as np
from PIL import Image

from models.tiny_vit import tiny_vit_5m_224
from data import build_transform, imagenet_classnames
from config import get_config

config = get_config()


# Build model
model = tiny_vit_5m_224(pretrained=True)
# Put layers such as BatchNorm into inference behaviour.
model.eval()

# Load Image
fname = "../../dataset/images/10001.jpg"
# Force three channels: the batch below is expected to be (1, 3, H, W),
# but Image.open may return grayscale, palette, RGBA or CMYK images.
image = Image.open(fname).convert("RGB")
transform = build_transform(is_train=False, config=config)

# (1, 3, img_size, img_size)
batch = transform(image)[None]

# No gradients are needed for a forward-only pass.
with torch.no_grad():
    logits = model(batch)

# print top-5 classification names
probs = torch.softmax(logits, -1)
scores, inds = probs.topk(5, largest=True, sorted=True)
print('=' * 30)
print(fname)
for score, ind in zip(scores[0].numpy(), inds[0].numpy()):
    print(f'{imagenet_classnames[ind]}: {score:.2f}')

../../dataset/images/10001.jpg
jeans: 0.28
miniskirt: 0.12
punching bag: 0.09
baby pacifier: 0.01
messenger bag: 0.01


In [19]:
# Print the module tree of the model to inspect its layers.
print(model)

TinyViT(
  (patch_embed): PatchEmbed(
    (seq): Sequential(
      (0): Conv2d_BN(
        (c): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): GELU(approximate='none')
      (2): Conv2d_BN(
        (c): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (layers): ModuleList(
    (0): ConvLayer(
      (blocks): ModuleList(
        (0-1): 2 x MBConv(
          (conv1): Conv2d_BN(
            (c): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (act1): GELU(approximate='none')
          (conv2): Conv2d_BN(
            (c): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), pa

In [24]:
# Replace the classification head with a new linear layer that keeps the
# existing head's input width and produces 4 outputs.
model.head = torch.nn.Linear(
    in_features=model.head.in_features,
    out_features=4,
)

In [25]:
# Print the module tree again, now that model.head has been replaced.
print(model)

TinyViT(
  (patch_embed): PatchEmbed(
    (seq): Sequential(
      (0): Conv2d_BN(
        (c): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): GELU(approximate='none')
      (2): Conv2d_BN(
        (c): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (layers): ModuleList(
    (0): ConvLayer(
      (blocks): ModuleList(
        (0-1): 2 x MBConv(
          (conv1): Conv2d_BN(
            (c): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (act1): GELU(approximate='none')
          (conv2): Conv2d_BN(
            (c): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), pa

In [None]:
from torchview import draw_graph


# Build a torchview graph of the model for a dummy (batch_size, 3, 224, 224)
# input and display it as the cell output.
batch_size = 2
# device='meta' -> no memory is consumed for visualization
model_graph = draw_graph(
    model,
    input_size=(batch_size, 3, 224, 224),
    device='meta',
)
model_graph.visual_graph

# Fine Tuning

## Gender prediction

### Reading data

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image
from torch.utils.data import DataLoader

class CustomImageDataset(Dataset):
    """Image dataset described by a CSV annotation file.

    Every CSV row is one sample. One column holds the image id, which is
    resolved to ``<img_dir>/<id><img_ext>``; another column holds the label.

    Args:
        annotations_file: CSV path or file-like object passed to
            ``pd.read_csv``.
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to the decoded image.
        target_transform: Optional callable applied to the label.
        id_col: Positional index of the column holding the image id.
        label_col: Positional index of the column holding the label.
        img_ext: Extension appended to the image id to form the file name.

    Note:
        Images are decoded with ``read_image``, so ``transform`` receives
        whatever that function returns (a tensor per the torchvision docs,
        not a PIL image).
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None,
                 *, id_col=0, label_col=2, img_ext=".jpg"):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        # Column layout and file extension default to the original
        # hard-coded values (id in column 0, label in column 2, ".jpg").
        self.id_col = id_col
        self.label_col = label_col
        self.img_ext = img_ext

    def __len__(self):
        """Return the number of annotated samples (rows in the CSV)."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``."""
        image_id = self.img_labels.iloc[idx, self.id_col]
        # Ids may be parsed as numbers by pandas, hence the str() conversion.
        img_path = os.path.join(self.img_dir, str(image_id) + self.img_ext)
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, self.label_col]
        # Compare against None so a callable that happens to be falsy
        # (e.g. an object defining __len__) is still applied.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label