## Import Packages

In [None]:
# model
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models

# dataset
import os
import math
import glob
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
from torchvision import transforms
from torch.utils.data import Dataset, Subset, DataLoader
import matplotlib.pyplot as plt

# save result
import pickle

In [None]:
torch.manual_seed(2022)

# Pick the best available accelerator. torch.device("mps") never raises, even on
# machines without Apple-silicon support, so availability must be checked
# explicitly; the getattr guard keeps this working on PyTorch builds that
# predate the MPS backend.
if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Show which device was selected (rendered as the cell's output).
device

## Prepare Dataset

### Get All Training Images

In [None]:
# List the training images. glob() order is filesystem-dependent, so sort it to
# get the same listing on every machine and run.
img_names = sorted(glob.glob("../input/dog-breed-identification/train/*.jpg"))

In [None]:
# Number of training images found on disk (rendered as the cell's output).
len(img_names)

In [None]:
# Peek at the first ten image paths.
img_names[:10]

### Create Custom Dataset

In [None]:
class DogDataset(Dataset):
    """Dataset over every ``*.jpg`` in a directory, optionally labelled from a CSV.

    Args:
        img_path: directory containing the ``<id>.jpg`` images.
        csv_path: optional CSV with ``id`` and ``breed`` columns. When falsy,
            every sample gets the placeholder label ``-1``.
        transform: optional callable applied to each RGB PIL image. It may also
            be assigned later via the ``transform`` attribute. While it is
            ``None`` the PIL image is returned unchanged.

    Attributes:
        img_names: sorted list of image paths.
        label_idx2name: breed names in order of first appearance in the CSV;
            position ``i`` holds the name of class index ``i``.
        label_name2idx: inverse mapping, breed name -> class index.
        img2label: image path (``f"{img_path}/{id}.jpg"``) -> class index.
    """

    def __init__(self, img_path, csv_path, transform=None):
        self.csv_path = csv_path
        self.transform = transform

        # glob() returns files in arbitrary, filesystem-dependent order. Sorting
        # makes the index -> image mapping (and any split over indices)
        # reproducible.
        self.img_names = sorted(glob.glob(f"{img_path}/*.jpg"))

        if csv_path:
            label_df = pd.read_csv(csv_path)
            self.label_idx2name = label_df["breed"].unique()
            self.label_name2idx = {
                name: idx for idx, name in enumerate(self.label_idx2name)
            }
            # Zipping the two columns avoids the per-row overhead of iterrows().
            self.img2label = {
                f"{img_path}/{img_id}.jpg": self.label_name2idx[breed]
                for img_id, breed in zip(label_df["id"], label_df["breed"])
            }

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, index):
        """Return ``(image, label)``; label is a scalar tensor, or ``-1`` without a CSV."""
        path = self.img_names[index]

        if self.csv_path:
            label = torch.tensor(self.img2label[path])
        else:
            label = -1

        # The with-block releases the file handle. convert() loads the pixels
        # into a new image, so the result stays valid after the file is closed.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return (img, label)

In [None]:
# transform_fn for pretrained ViT
channel_mean = torch.Tensor([0.485, 0.456, 0.406])
channel_std = torch.Tensor([0.229, 0.224, 0.225])

vit_train_transform_fn = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),

    transforms.RandomHorizontalFlip(p=0.6),
    transforms.RandomRotation(degrees=(30)),

    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
])

vit_valid_transform_fn = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
])

In [None]:
# Labelled training images. No transform yet; one is assigned in a later cell.
dataset = DogDataset(
    "../input/dog-breed-identification/train",
    "../input/dog-breed-identification/labels.csv",
)

### Split `dataset` into `train_dataset` and `valid_dataset`

In [None]:
# Hold out 10% of the images for validation. torch.manual_seed() does not seed
# scikit-learn, so random_state is what makes the split identical across runs.
indexes = list(range(len(dataset)))
train_indexes, valid_indexes = train_test_split(indexes, test_size=0.1, random_state=2022)
train_dataset = Subset(dataset, train_indexes)
valid_dataset = Subset(dataset, valid_indexes)

print(f"number of samples in train_dataset: {len(train_dataset)}")
print(f"number of samples in valid_dataset: {len(valid_dataset)}")

In [None]:
import copy

# A Subset has no transform of its own: it indexes into `.dataset`, and that
# dataset's transform is applied. Setting `.transform` on a Subset object is
# silently ignored, which would make validation images get the random
# train-time augmentation too. Give the validation split its own shallow copy
# of the base dataset so it can use the deterministic transform, while the
# training split keeps the augmenting one.
dataset.transform = vit_train_transform_fn
train_dataset.dataset = dataset
valid_dataset.dataset = copy.copy(dataset)
valid_dataset.dataset.transform = vit_valid_transform_fn

### Create DataLoader

In [None]:
# Loader over every labelled image (train + validation together). The epoch
# loop in this notebook trains on `train_dataloader` instead.
train_valid_dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
)

# Validation needs no shuffling. With a fixed order the batch-averaged
# validation metrics are comparable from epoch to epoch, because the contents
# of the final partial batch no longer change.
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=32,
    shuffle=False,
)

### Show Some Samples from a Batch

In [None]:
def show_samples(batch_img, batch_label=None, num_samples=16, label_names=None):
    """Plot the first ``num_samples`` images of a normalized batch in a 4-column grid.

    Args:
        batch_img: batch normalized with ``channel_mean``/``channel_std``;
            ``batch_img[i]`` is a (3, H, W) image tensor.
        batch_label: optional class indices. When given, each subplot is
            titled with the matching name from ``label_names``.
        num_samples: number of images to draw (capped at the batch size).
        label_names: optional index -> name lookup. Defaults to
            ``dataset.label_idx2name``.
    """
    total_col = 4
    num_samples = min(num_samples, len(batch_img))
    total_row = max(math.ceil(num_samples / total_col), 1)
    if batch_label is not None and label_names is None:
        label_names = dataset.label_idx2name

    # squeeze=False keeps `axs` 2-D even when there is a single row. Otherwise
    # axs[row, col] would fail for num_samples <= 4.
    fig, axs = plt.subplots(total_row, total_col, figsize=(15, 15), squeeze=False)

    # Per-channel statistics shaped to broadcast over (3, H, W) images of any size.
    mean = channel_mean.view(3, 1, 1)
    std = channel_std.view(3, 1, 1)

    for sample_idx in range(num_samples):
        row_idx, col_idx = divmod(sample_idx, total_col)
        ax = axs[row_idx, col_idx]

        # Undo Normalize and move channels last for imshow. Clamp because
        # round-off can push values slightly outside [0, 1].
        img = (batch_img[sample_idx].cpu() * std + mean).permute(1, 2, 0).clamp(0, 1)
        ax.imshow(img)

        if batch_label is not None:
            ax.set_title(label_names[int(batch_label[sample_idx])])

In [None]:
# Pull one shuffled batch (training transform applied) from the training loader for inspection.
batch_img, batch_label = next(iter(train_dataloader))

In [None]:
# Display the first eight images of the batch, titled with their breed names.
show_samples(batch_img, batch_label, num_samples=8)

## Build Model

In [None]:
class PretrainViT(nn.Module):
    """Pretrained ViT-B/16 backbone with a new linear head; only the head is trainable.

    Args:
        num_classes: number of logits produced by the replacement head
            (defaults to 120).
    """

    def __init__(self, num_classes=120):
        super().__init__()
        # NOTE(review): newer torchvision deprecates `pretrained=` in favour of
        # `weights=`; kept here for compatibility with older releases.
        model = models.vit_b_16(pretrained=True)
        num_classifier_feature = model.heads.head.in_features
        # Kept inside nn.Sequential so state_dict keys ("model.heads.head.0.*")
        # stay compatible with checkpoints saved from this module.
        model.heads.head = nn.Sequential(
            nn.Linear(num_classifier_feature, num_classes)
        )
        self.model = model

        # Freeze the backbone: every parameter outside `heads` stops receiving gradients.
        for name, param in self.model.named_parameters():
            if "heads" not in name:
                param.requires_grad = False

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        return self.model(x)

In [None]:
net = PretrainViT()
net.to(device)
# Count only parameters that will be updated; the frozen backbone is excluded.
print(f"number of trainable parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad)}")

## Train Model

### Loss Function and Optimizer

In [None]:
# Multi-class cross-entropy on the raw logits, optimized with momentum SGD.
# Frozen parameters never get a gradient, so the optimizer step skips them.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=0.003)

### Training Loop 

In [None]:
def get_accuracy(output, label):
    """Return the fraction of rows whose highest-scoring class equals ``label``.

    Args:
        output: (batch, num_classes) logits or scores.
        label: (batch,) integer class indices, on any device.

    Returns:
        A 0-dim float tensor on the CPU (NaN for an empty batch).
    """
    # softmax is monotonic, so argmax over the raw logits gives the same
    # prediction. Comparing on the logits' device avoids copying the whole
    # output tensor to the CPU.
    predictions = output.argmax(dim=1)
    correct = predictions == label.to(predictions.device)
    return correct.float().mean().cpu()

In [None]:
def train(model, dataloader, log_every=100):
    """Run one training epoch of ``model`` over ``dataloader``.

    Uses the notebook-level ``optimizer``, ``criterion`` and ``device``.

    Args:
        model: network to train (switched to train mode here).
        dataloader: yields ``(batch_img, batch_label)`` pairs.
        log_every: print the mean loss/accuracy of the most recent
            ``log_every`` batches at this interval.

    Returns:
        ``(mean_loss, mean_acc)`` averaged over batches, or ``(nan, nan)`` if
        the dataloader is empty.
    """
    model.train()
    running_loss = 0.0
    total_loss = 0.0
    running_acc = 0.0
    total_acc = 0.0

    for batch_idx, (batch_img, batch_label) in enumerate(dataloader):
        batch_img = batch_img.to(device)
        batch_label = batch_label.to(device)

        optimizer.zero_grad()
        # Forward through the `model` argument (not a global network).
        output = model(batch_img)
        loss = criterion(output, batch_label)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        running_loss += loss_value
        total_loss += loss_value

        # detach(): the accuracy computation needs no autograd history.
        acc = get_accuracy(output.detach(), batch_label)
        running_acc += acc
        total_acc += acc

        # batch_idx is 0-based. Logging on (batch_idx + 1) makes each window
        # cover exactly `log_every` batches.
        if (batch_idx + 1) % log_every == 0:
            print(
                f"[step: {batch_idx + 1:4d}/{len(dataloader)}] "
                f"loss: {running_loss / log_every:.3f}, acc: {running_acc / log_every:.3f}"
            )
            running_loss = 0.0
            running_acc = 0.0

    num_batches = len(dataloader)
    if num_batches == 0:
        return float("nan"), float("nan")
    return total_loss / num_batches, total_acc / num_batches

In [None]:
def validate(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Uses the notebook-level ``criterion`` and ``device``.

    Args:
        model: network to evaluate (switched to eval mode here).
        dataloader: yields ``(batch_img, batch_label)`` pairs.

    Returns:
        ``(mean_loss, mean_acc)`` averaged over batches, or ``(nan, nan)`` if
        the dataloader is empty.
    """
    model.eval()
    total_loss = 0.0
    total_acc = 0.0

    # Evaluation never calls backward(), so skip recording an autograd graph.
    # The trainable head would otherwise build one for every batch.
    with torch.no_grad():
        for batch_img, batch_label in dataloader:
            batch_img = batch_img.to(device)
            batch_label = batch_label.to(device)

            output = model(batch_img)
            total_loss += criterion(output, batch_label).item()
            total_acc += get_accuracy(output, batch_label)

    num_batches = len(dataloader)
    if num_batches == 0:
        return float("nan"), float("nan")
    return total_loss / num_batches, total_acc / num_batches

In [None]:
EPOCHS = 3
train_loss_history, valid_loss_history = [], []
train_acc_history, valid_acc_history = [], []

for epoch in range(EPOCHS):
    # One training pass, then score the held-out split with the updated weights.
    train_loss, train_acc = train(net, train_dataloader)
    valid_loss, valid_acc = validate(net, valid_dataloader)
    print(f"Epoch: {epoch:2d}, training loss: {train_loss:.3f}, training acc: {train_acc:.3f} validation loss: {valid_loss:.3f}, validation acc: {valid_acc:.3f}")

    for history, value in (
        (train_loss_history, train_loss),
        (valid_loss_history, valid_loss),
        (train_acc_history, train_acc),
        (valid_acc_history, valid_acc),
    ):
        history.append(value)

    # The current loss is already in the history, so this holds exactly when
    # this epoch reached the lowest validation loss so far; checkpoint it.
    if valid_loss <= min(valid_loss_history):
        torch.save(net.state_dict(), "net.pt")

## Predict on Test Dataset and Submit to Kaggle

In [None]:
# Rebuild the architecture and restore the checkpoint written by the training
# loop. map_location="cpu" lets it load regardless of the device it was saved
# from, before the model is moved to `device`.
net = PretrainViT()
net.load_state_dict(state_dict=torch.load("./net.pt", map_location="cpu"))
net.to(device=device)
net.eval()

In [None]:
# Test image ids come from the sample submission. The output columns follow the
# training label order, so column i holds the probability of class index i.
submit_df = pd.read_csv(filepath_or_buffer="../input/dog-breed-identification/sample_submission.csv")
test_names = submit_df["id"].values
columns = list(dataset.label_idx2name)

In [None]:
class TestDataset(Dataset):
    """Unlabelled images addressed by id, returned as ``(image, id)`` pairs.

    Args:
        test_names: sequence of image ids (file names without ``.jpg``).
        transform_fn: callable applied to each RGB PIL image.
        img_dir: directory holding the ``<id>.jpg`` files.
    """

    def __init__(self, test_names, transform_fn, img_dir="../input/dog-breed-identification/test"):
        self.test_names = test_names
        self.transform = transform_fn
        self.img_dir = img_dir

    def __len__(self):
        return len(self.test_names)

    def __getitem__(self, idx):
        name = self.test_names[idx]
        path = os.path.join(self.img_dir, name + ".jpg")
        # Convert to RGB so grayscale/RGBA/CMYK files still yield the three
        # channels the normalization step expects. The with-block closes the
        # file once convert() has loaded the pixels into a new image.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        img = self.transform(img)
        return (img, name)

In [None]:
# Wrap the test ids with the deterministic (validation) preprocessing.
# Note: this rebinds `dataset`, replacing the training DogDataset from here on.
dataset = TestDataset(test_names, vit_valid_transform_fn)

In [None]:
# Fixed-order batches of 64 test images; each batch also carries its image ids.
test_dataloader = DataLoader(dataset, batch_size=64, shuffle=False)

In [None]:
# Inference only: no autograd graph is needed.
with torch.no_grad():

    dfs = []

    for step, (images, ids) in enumerate(test_dataloader):
        # Per-class probabilities for this batch, copied back to host memory.
        probs = F.softmax(net(images.to(device)), dim=1).cpu().numpy()

        # One row per image: its id, then one probability column per breed.
        frame = pd.DataFrame(columns=["id"] + columns)
        frame["id"] = ids
        frame[columns] = probs
        dfs.append(frame)

        print(f"step: {step}/{len(test_dataloader)}")

In [None]:
# Stack the per-batch frames. ignore_index renumbers rows 0..N-1 instead of
# repeating each batch's 0..63 labels, which would make .loc lookups ambiguous.
my_submit = pd.concat(dfs, ignore_index=True)

In [None]:
# Preview the assembled submission table (rendered as the cell's output).
my_submit

In [None]:
# Write the submission without the DataFrame index column.
my_submit.to_csv(path_or_buf="submit.csv", index=False)