In [None]:
import os
# Print the path of every file found under the Kaggle input mount.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for name in files:
        print(os.path.join(root, name))

# Read the TFRecord files

In [None]:
import glob

train_files = glob.glob('../input/tpu-getting-started/tfrecords-jpeg-331x331/train/*.tfrec')
val_files = glob.glob('../input/tpu-getting-started/tfrecords-jpeg-331x331/val/*.tfrec')
test_files = glob.glob('../input/tpu-getting-started/tfrecords-jpeg-331x331/test/*.tfrec')

In [None]:
import tensorflow as tf

tf.random.set_seed(3407)

tf.config.set_visible_devices([], 'GPU')

In [None]:
# Feature schema of a labelled (train/val) record.
train_feature_description = {
    'class': tf.io.FixedLenFeature([], tf.int64),
    'id': tf.io.FixedLenFeature([], tf.string),
    'image': tf.io.FixedLenFeature([], tf.string),
}

# Feature schema of an unlabelled (test) record: no 'class' field.
test_feature_description = {
    'id': tf.io.FixedLenFeature([], tf.string),
    'image': tf.io.FixedLenFeature([], tf.string),
}


def _parse_image_function(example_proto):
    """Parse one serialized labelled example into a dict of scalar tensors."""
    return tf.io.parse_single_example(example_proto, train_feature_description)


def _parse_image_function_test(example_proto):
    """Parse one serialized test example into a dict of scalar tensors."""
    return tf.io.parse_single_example(example_proto, test_feature_description)


def _read_split(files, parse_fn, with_labels=True):
    """Read every record of ``files`` (in the given order) in a single pass.

    Args:
        files: iterable of TFRecord paths.
        parse_fn: function mapping a serialized example to a feature dict.
        with_labels: whether records carry an integer 'class' feature.

    Returns:
        (ids, classes, images): ids as UTF-8-decoded ``str``, classes as ``int``
        (an empty list when ``with_labels`` is False), images as the raw encoded bytes.
    """
    ids, classes, images = [], [], []
    for path in files:
        # One iteration per file: the original code re-parsed every dataset three times.
        for features in tf.data.TFRecordDataset(path).map(parse_fn):
            # Proper decode instead of slicing the bytes repr (str(b'..')[2:-1]),
            # which would leak escape sequences for non-ASCII ids.
            ids.append(features['id'].numpy().decode('utf-8'))
            images.append(features['image'].numpy())
            if with_labels:
                classes.append(int(features['class'].numpy()))
    return ids, classes, images


train_ids, train_class, train_images = _read_split(train_files, _parse_image_function)
val_ids, val_class, val_images = _read_split(val_files, _parse_image_function)
test_ids, _test_classes, test_images = _read_split(
    test_files, _parse_image_function_test, with_labels=False
)

In [None]:
import IPython.display as display

# Render one raw validation JPEG as a quick sanity check of the decoded bytes.
sample_jpeg = val_images[1]
display.display(display.Image(data=sample_jpeg))

# Augmentation

In [None]:
!pip install -qU albumentations

In [None]:
import albumentations as A

crop_h = crop_w = 224

# Training-time pipeline: random crop, flips, small rotation, and occasional
# brightness/contrast jitter and coarse dropout.
train_steps = [
    A.RandomCrop(height=crop_h, width=crop_w),
    A.HorizontalFlip(),
    A.VerticalFlip(),
    A.Rotate(limit=10),
    A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.25),
    A.CoarseDropout(max_holes=10, max_height=20, max_width=20, p=0.2),
]
transform = A.Compose(train_steps)

# Evaluation pipeline: deterministic center crop to the same size.
val_transform = A.Compose([A.CenterCrop(height=crop_h, width=crop_w)])

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
import numpy as np

nrows, ncol = 3, 3
fig, axs = plt.subplots(nrows, ncol, figsize=(14, 14))

# Decode the last training JPEG once, then show an independent random augmentation per panel.
img = np.array(Image.open(BytesIO(train_images[-1])))

for ax in axs.ravel():  # row-major, same visiting order as nested row/column loops
    ax.imshow(transform(image=img)["image"])
    ax.axis("off")

# Dataset

In [None]:
import random

import torch

SEED = 3407
# torch.manual_seed does not cover Python's `random` or NumPy's global RNG, which the
# augmentation code may draw from; seed them too so runs are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms 

class FlowerDS(Dataset):
    """Map-style dataset over encoded image bytes.

    Each item is decoded with PIL to an RGB array. If a ``transform`` is given, the array
    is passed through it (called as ``transform(image=array)["image"]``). It is then
    converted to a normalized float tensor. ``labels[idx]`` is returned unchanged next to
    the image, so it can hold class indices or, for the test split, image ids.
    """

    def __init__(self, images, labels, transform=None):
        """
        Args:
            images: sequence of encoded image bytes.
            labels: sequence of per-image targets, same length as ``images``.
            transform: optional augmentation callable; ``None`` means no augmentation.
        """
        self.images, self.labels = images, labels
        self.transform = transform
        # ToTensor + normalization with the commonly used ImageNet channel mean/std.
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Number of samples, taken from the label sequence."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for sample ``idx``."""
        with Image.open(BytesIO(self.images[idx])) as pil_img:
            # Force 3 channels so grayscale/CMYK files match the 3-channel Normalize.
            img = np.array(pil_img.convert("RGB"))
        # The constructor allows transform=None; only apply it when one was given.
        if self.transform is not None:
            img = self.transform(image=img)["image"]
        img = self.to_tensor(img)
        return img, self.labels[idx]
    
batch_size = 8

# Only the training split gets random augmentation; val and test share the deterministic crop.
train_ds = FlowerDS(train_images, train_class, transform)
val_ds = FlowerDS(val_images, val_class, val_transform)
# For the test split the "labels" slot carries the image ids needed to build the submission.
test_ds = FlowerDS(test_images, test_ids, val_transform)

train_dl, val_dl, test_dl = (
    DataLoader(ds, batch_size=batch_size, shuffle=shuffle)
    for ds, shuffle in ((train_ds, True), (val_ds, False), (test_ds, False))
)

# Model

In [None]:
!pip install -q vistrans

In [None]:
from vistrans import VisionTransformer

# Show which pretrained variants vistrans offers (the model below uses "vit_l16_224").
available_vits = VisionTransformer.list_pretrained()
available_vits

In [None]:
import torch.nn as nn

class ViT(nn.Module):
    """Thin wrapper around vistrans' pretrained ``vit_l16_224`` VisionTransformer.

    ``num_classes`` is forwarded to ``VisionTransformer.create_pretrained`` (presumably it
    sizes the classification head; confirm against vistrans docs). Every parameter is
    explicitly marked trainable so the whole network is fine-tuned.
    """

    def __init__(self, num_classes=104):
        super().__init__()

        self.model = VisionTransformer.create_pretrained("vit_l16_224", num_classes=num_classes)

        # `requires_grad_` really toggles autograd tracking; the previous
        # `param.require_grad = True` (missing "s") only set an unused attribute.
        for param in self.model.parameters():
            param.requires_grad_(True)

    def forward(self, x):
        """Return the wrapped model's output for input batch ``x``."""
        return self.model(x)
    
model = ViT()
model = model.to(device)  # nn.Module.to moves parameters in place and returns the module itself
model

# Training

In [None]:
!pip -q install madgrad

In [None]:
import torch.optim as optim
from madgrad import MADGRAD

epochs = 20

criterion = nn.CrossEntropyLoss()
optimizer = MADGRAD(model.parameters(), lr=1e-4)
# Cosine decay over `epochs` scheduler steps (stepped once per epoch), bottoming out at 1e-5.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)

In [None]:
import copy
import gc

def _run_epoch(loader, train):
    """Run one pass of the global ``model`` over ``loader``.

    With ``train=True`` the model is in train mode and each batch is back-propagated with
    gradient-norm clipping at 1.0. Otherwise the model is in eval mode and no autograd
    graph is built.

    Returns:
        (mean_loss, accuracy) over all samples. Loss is weighted by batch size, so a
        smaller final batch does not skew the mean. Returns (nan, nan) for an empty loader.
    """
    model.train(train)
    total_loss, correct, seen = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.set_grad_enabled(train):
            out = model(x)
            loss = criterion(out, y)
        if train:
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        batch = y.size(0)
        total_loss += loss.item() * batch  # criterion averages over the batch
        correct += (out.argmax(dim=1) == y).sum().item()
        seen += batch
    if seen == 0:
        return float("nan"), float("nan")
    return total_loss / seen, correct / seen


best_val_loss, best_state_dict = float("inf"), None
train_losses, val_losses = [], []
for e in range(epochs):
    print(f"Epoch {e + 1}")

    train_loss, train_acc = _run_epoch(train_dl, train=True)
    scheduler.step()
    print(f"Train loss {train_loss} accuracy {train_acc}")
    train_losses.append(train_loss)

    val_loss, val_acc = _run_epoch(val_dl, train=False)
    print(f"Validation loss {val_loss} accuracy {val_acc}")
    val_losses.append(val_loss)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Snapshot to CPU so keeping the best weights does not hold a second full copy
        # of the model in GPU memory.
        best_state_dict = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    torch.cuda.empty_cache()
    gc.collect()

In [None]:
# 1-based epoch axis sized from the recorded losses, matching the "Epoch N" training log
# and still valid if training was interrupted early.
epoch_axis = np.arange(1, len(train_losses) + 1)

fig, ax = plt.subplots(figsize=(8, 8))
ax.plot(epoch_axis, train_losses, label="Train loss")
ax.plot(epoch_axis, val_losses, label="Val loss")
ax.set(xlabel="Epoch", ylabel="Loss", title="Train vs. validation loss per epoch")
ax.legend();

In [None]:
# Persist the best checkpoint, then restore it into the in-memory model for inference.
if best_state_dict is None:
    # Happens if no epoch ran (or no validation loss was ever finite); saving None
    # would only fail later inside load_state_dict with a confusing error.
    raise RuntimeError("No best checkpoint was recorded; did the training loop run?")
torch.save(best_state_dict, "best_model.pt")
# load_state_dict copies values onto each parameter's device, so loading on CPU is safe.
model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))

# Inference

In [None]:
# Predict a class index for every test image, keeping ids aligned with predictions.
# Lists + one torch.cat avoid the quadratic np.append-in-a-loop and float/str dtype mixing.
model.eval()
test_id, pred_batches = [], []
with torch.no_grad():
    for x, ids in test_dl:  # the test dataset's "labels" are the image ids
        logits = model(x.to(device))
        pred_batches.append(logits.argmax(dim=1).cpu())
        test_id.extend(ids)
test_label = torch.cat(pred_batches).tolist() if pred_batches else []
test_label[:10]

In [None]:
import pandas as pd

submission = pd.DataFrame({"id": test_id, "label": test_label})
submission.to_csv("submission.csv", index=False)

In [None]:
from IPython.display import FileLink

# Clickable download link for the generated submission file.
submission_link = FileLink("submission.csv")
submission_link

In [None]:
# Portable, pure-Python listing of the working directory (expects best_model.pt and submission.csv).
sorted(os.listdir("."))

In [None]:
# Read the file back to confirm what was actually written to disk.
submission_check = pd.read_csv("submission.csv")
submission_check