In [None]:
# Install the PyTorchTrainer helper package straight from its GitHub repository.
# NOTE(review): no tag or commit is pinned here, so the installed code can change between runs.
!pip install git+https://github.com/FlorianMuellerklein/PyTorchTrainer
# Install the Colab TPU client, torch 1.11.0 and the torch_xla 1.11 wheel (the wheel filename is tagged cp37).
# NOTE(review): the torch_xla imports in the import cell are commented out — confirm this TPU install is still needed.
!pip install cloud-tpu-client==0.10 torch==1.11.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.11-cp37-cp37m-linux_x86_64.whl

# **PyTorch Trainer**

This notebook finetunes the official PyTorch implementation of ConvNext by using the [PyTorchTrainer](https://github.com/FlorianMuellerklein/PyTorchTrainer) package to keep the code simple and readable. 


In [None]:
import os
#assert os.environ['COLAB_TPU_ADDR']#, 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

In [None]:

import glob

import numpy as np
import pandas as pd

import tensorflow as tf

from typing import List, Optional, Iterable

from PIL import Image

import io
import IPython.display as display

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
import torch.optim as optim

# imports the torch_xla package
#import torch_xla
#import torch_xla.core.xla_model as xm

import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torchvision.transforms as transforms
from torchvision.utils import make_grid

import matplotlib.pyplot as plt


import warnings
warnings.filterwarnings("ignore");

from pytorchtrainer.trainers import SingleOutputTrainer

# **1. Data Loading**

This section contains code that is used in most PyTorch notebooks for this challenge, as the data comes in TFRecord format. Why reinvent the wheel here?

**Paths**

In [None]:
# Collect the TFRecord shards of every split. The matches are sorted because
# glob.glob returns them in an arbitrary, filesystem-dependent order, which
# would make the sample order differ from run to run.
train_files = sorted(glob.glob('../input/tpu-getting-started/*/train/*.tfrec'))
val_files = sorted(glob.glob('../input/tpu-getting-started/*/val/*.tfrec'))
test_files = sorted(glob.glob('../input/tpu-getting-started/*/test/*.tfrec'))

**Here we read tfrecords files in PyTorch. I recommend** https://medium.com/analytics-vidhya/how-to-read-tfrecords-files-in-pytorch-72763786743f

In [None]:
# Schema of one serialized training example: an integer label, a string id
# and the encoded image bytes.
train_feature_description = {
    'class': tf.io.FixedLenFeature([], tf.int64),
    'id': tf.io.FixedLenFeature([], tf.string),
    'image': tf.io.FixedLenFeature([], tf.string),
}

def _parse_image_function(example_proto):
    """Parse one serialized tf.Example using ``train_feature_description``."""
    return tf.io.parse_single_example(example_proto, train_feature_description)

train_ids = []
train_class = []
train_images = []

for train_file in train_files:
    train_image_dataset = tf.data.TFRecordDataset(train_file).map(_parse_image_function)

    # Read every shard exactly once and pull all three fields from the same
    # parsed example, instead of re-reading the file once per field.
    for example in train_image_dataset:
        # The id is stored as bytes; decode it rather than slicing its repr.
        train_ids.append(example['id'].numpy().decode('utf-8'))
        train_class.append(int(example['class'].numpy()))
        train_images.append(example['image'].numpy())

In [None]:
# Schema of one serialized validation example: an integer label, a string id
# and the encoded image bytes.
val_feature_description = {
    'class': tf.io.FixedLenFeature([], tf.int64),
    'id': tf.io.FixedLenFeature([], tf.string),
    'image': tf.io.FixedLenFeature([], tf.string),
}

# Named with a `_val` suffix (mirroring `_parse_image_function_test`) so this
# cell does not silently redefine the parser used for the training split.
def _parse_image_function_val(example_proto):
    """Parse one serialized tf.Example using ``val_feature_description``."""
    return tf.io.parse_single_example(example_proto, val_feature_description)

val_ids = []
val_class = []
val_images = []

for val_file in val_files:
    val_image_dataset = tf.data.TFRecordDataset(val_file).map(_parse_image_function_val)

    # Read every shard exactly once and pull all three fields from the same
    # parsed example, instead of re-reading the file once per field.
    for example in val_image_dataset:
        # The id is stored as bytes; decode it rather than slicing its repr.
        val_ids.append(example['id'].numpy().decode('utf-8'))
        val_class.append(int(example['class'].numpy()))
        val_images.append(example['image'].numpy())

In [None]:
# Schema of one serialized test example: a string id and the encoded image
# bytes (no label field is declared for this split).
test_feature_description = {
    'id': tf.io.FixedLenFeature([], tf.string),
    'image': tf.io.FixedLenFeature([], tf.string),
}

def _parse_image_function_test(example_proto):
    """Parse one serialized tf.Example using ``test_feature_description``."""
    return tf.io.parse_single_example(example_proto, test_feature_description)

test_ids = []
test_images = []
for test_file in test_files:
    test_image_dataset = tf.data.TFRecordDataset(test_file).map(_parse_image_function_test)

    # Read every shard exactly once and pull both fields from the same parsed
    # example, instead of re-reading the file once per field.
    for example in test_image_dataset:
        # The id is stored as bytes; decode it rather than slicing its repr.
        test_ids.append(example['id'].numpy().decode('utf-8'))
        test_images.append(example['image'].numpy())

In [None]:
# Render the second validation image straight from its raw encoded bytes.
preview = display.Image(data=val_images[1])
display.display(preview)

In [None]:
# Sanity check: split sizes plus the distinct labels (and their count) per split.
train_labels = np.unique(train_class)
val_labels = np.unique(val_class)
(len(train_images), len(val_images), train_labels, len(train_labels), val_labels, len(val_labels))

# **2. Data preparation**

**Using the PyTorchTrainer package requires us to use most PyTorch functionality as-is. For this task we'll use the PyTorch Dataset class and the standard dataloaders.**

In [None]:
class FlowerDataset(Dataset):
    """Labelled image dataset backed by in-memory encoded image bytes.

    Args:
        imgs: Sequence of encoded image byte strings.
        targets: Sequence of class labels aligned with ``imgs``.
        valid: When True the ``'valid'`` transform is used, otherwise
            ``'train'``.
        tforms: Mapping with ``'train'`` and ``'valid'`` entries, each a
            callable applied to the decoded PIL image. When ``None`` the
            decoded image is returned untransformed.
    """

    def __init__(
        self,
        imgs: Optional[Iterable] = None,
        targets: Optional[Iterable] = None,
        valid: bool = False,
        tforms: Optional[dict] = None,
    ):
        self.imgs = imgs
        self.targets = targets
        self.mode = 'valid' if valid else 'train'
        self.tforms = tforms

    def __getitem__(self, idx: int) -> tuple:
        """Return ``(image, label_tensor)`` for the sample at ``idx``."""
        # Decode the stored bytes into a PIL image.
        img = Image.open(io.BytesIO(self.imgs[idx]))
        # Use the transforms handed to this instance; relying on a
        # module-level `tforms` would ignore the constructor argument and
        # break as soon as that global is missing or rebound.
        if self.tforms is not None:
            img = self.tforms[self.mode](img)

        return img, torch.tensor(self.targets[idx])

    def __len__(self) -> int:
        """Number of samples in the dataset."""
        return len(self.imgs)

**Set up the transforms for training and validation/testing.**

In [None]:
# Transform pipelines keyed by dataset mode ('train' / 'valid').
tforms = {
    # Training: random crop/flip augmentation followed by tensor conversion
    # and per-channel normalization.
    'train': transforms.Compose([
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ]),
    # Validation/test: deterministic preprocessing only.
    'valid': transforms.Compose([
        # Resize BEFORE cropping. With the crop first, the output is already
        # 224x224, so the resize does nothing and larger images are reduced
        # to their central patch instead of being scaled down.
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
}

In [None]:
# Wrap the in-memory training samples; valid=False selects the 'train' transforms.
train_dataset = FlowerDataset(imgs=train_images, targets=train_class, valid=False, tforms=tforms)

# Wrap the in-memory validation samples; valid=True selects the 'valid' transforms.
valid_dataset = FlowerDataset(imgs=val_images, targets=val_class, valid=True, tforms=tforms)

**Use the GPU if one is available, otherwise fall back to the CPU (the TPU/XLA device line is left commented out)**

In [None]:
# Pick the compute device: CUDA when available, CPU otherwise.
# (For a TPU run the XLA device would be used instead: device = xm.xla_device())
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Training batches are reshuffled every epoch. Without shuffling, the model
# sees the samples in the same shard-by-shard order each epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

# Validation keeps a fixed order so the metrics are comparable across epochs.
valid_loader = DataLoader(
    valid_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

# **ConvNext**

**Initialize the official PyTorch ConvNext and replace the final layers**

In [None]:
# set up network: torchvision ConvNeXt-Base initialised with its default pretrained weights
net = torchvision.models.convnext_base(weights='ConvNeXt_Base_Weights.DEFAULT')

# Feature width feeding the stock classification layer (1024 for ConvNeXt-Base),
# read from the model instead of being hard-coded.
head_in_features = net.classifier[-1].in_features
num_classes = len(np.unique(train_class))

# Replace the classification head. Flatten has to come BEFORE nn.LayerNorm:
# the pooled features reach the head as (N, C, 1, 1), and nn.LayerNorm(C)
# normalises over the trailing dimension, so applying it to the un-flattened
# tensor raises a shape error.
net.classifier = nn.Sequential(
    nn.Flatten(start_dim=1, end_dim=-1),
    nn.LayerNorm(head_in_features),
    nn.Linear(in_features=head_in_features, out_features=num_classes, bias=True)
)
net = net.to(device)

In [None]:
num_epochs = 50

# loss function and optimizer used for fine-tuning
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(net.parameters(), lr=0.0001)

# step schedule: multiply the learning rate by 0.1 once 80% of the epochs
# have run, and again at 90%
lr_milestones = [int(num_epochs * fraction) for fraction in (0.8, 0.9)]
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=lr_milestones,
    gamma=0.1,
    verbose=True,
)

# **Train with PyTorchTrainer**

In [None]:
# custom metric passed to the trainer
def accuracy(targets, preds):
    """Fraction of samples whose highest-scoring class equals the target.

    Args:
        targets: Tensor of class indices, one per sample.
        preds: Tensor of per-class scores with the class axis last.

    Returns:
        A 0-dim tensor: number of correct predictions divided by the
        number of targets.
    """
    predicted = preds.max(dim=-1).indices
    n_correct = predicted.eq(targets).sum()
    n_samples = targets.shape[0]
    return n_correct / n_samples

# gather everything the trainer needs in one place, then build it
trainer_config = dict(
    train_loader=train_loader,
    valid_loader=valid_loader,
    net=net,
    crit=criterion,
    device=device,
    optimizer=optimizer,
    epochs=num_epochs,
    scheduler=scheduler,
    metrics=[accuracy],
    checkpoint_every=1,
    model_name='flower_convnext',
)
trainer = SingleOutputTrainer(**trainer_config)

# run the training loop
trainer.train_network()

# **4. Preparing the Submission**

In [None]:
class TestDataset(Dataset):
    """Unlabelled image dataset that yields ``(image, image_id)`` pairs.

    Args:
        imgs: Sequence of encoded image byte strings.
        img_ids: Sequence of identifiers aligned with ``imgs``.
        valid: When True the ``'valid'`` transform is used, otherwise
            ``'train'``.
        tforms: Mapping with ``'train'`` and ``'valid'`` entries, each a
            callable applied to the decoded PIL image. When ``None`` the
            decoded image is returned untransformed.
    """

    def __init__(
        self,
        imgs: Optional[Iterable] = None,
        img_ids: Optional[Iterable] = None,
        valid: bool = False,
        tforms: Optional[dict] = None,
    ):
        self.imgs = imgs
        self.img_ids = img_ids
        self.mode = 'valid' if valid else 'train'
        self.tforms = tforms

    def __getitem__(self, idx: int) -> tuple:
        """Return ``(image, image_id)`` for the sample at ``idx``."""
        # Decode the stored bytes into a PIL image.
        img = Image.open(io.BytesIO(self.imgs[idx]))
        # Use the transforms handed to this instance; relying on a
        # module-level `tforms` would ignore the constructor argument and
        # break as soon as that global is missing or rebound.
        if self.tforms is not None:
            img = self.tforms[self.mode](img)

        return img, self.img_ids[idx]

    def __len__(self) -> int:
        """Number of samples in the dataset."""
        return len(self.imgs)

# Wrap the test images with their ids; valid=True selects the 'valid' transforms.
test_dataset = TestDataset(imgs=test_images, img_ids=test_ids, valid=True, tforms=tforms)

# Loader for inference. The keyword is `pin_memory`; it had been split across
# two lines, which is a syntax error that stops the whole cell from running.
testloader = DataLoader(
    test_dataset,
    batch_size=128,
    num_workers=4,
    pin_memory=True,
    shuffle=False
)


# NOTE: `submit_df` is first created in the "Final prediction" cell further
# down, so copying it here raises NameError on a fresh kernel
# (Restart & Run All). Take the copy after that cell has run instead.
# ensemble_df = submit_df.copy()

In [None]:
predictions = []
prediction_ids = []

trainer.net.eval()
# Inference only: disable autograd so no graph is built or kept per batch.
with torch.no_grad():
    for inputs, img_id in testloader:
        # Use the notebook-level `device`; there is no `self` in this scope.
        inputs = inputs.to(device)

        preds = trainer.net(inputs)

        # Index of the highest-scoring class for every sample in the batch.
        _, pred_class = preds.max(-1)

        # Store plain Python ints (moved off the accelerator) rather than
        # 0-dim tensors, so the submission frame holds ordinary integers.
        predictions.extend(pred_class.cpu().tolist())
        prediction_ids.extend(img_id)

In [None]:
# Final prediction
# Coerce every label to a plain int so the frame (and the CSV written from
# it) holds integers even if `predictions` contains 0-dim tensors.
submit_df = pd.DataFrame({
    'label': [int(label) for label in predictions],
    'id': prediction_ids,
})
# Independent copy kept for later ensembling; it is taken here because
# `submit_df` does not exist before this cell.
ensemble_df = submit_df.copy()
submit_df.head(10)

In [None]:
# Write the predictions to the submission CSV, without the DataFrame index column.
submission_path = 'submission11062021.csv'
submit_df.to_csv(submission_path, index=False)

# **5. Learning Visualization**