In [None]:
# Download the PyTorch/XLA env-setup script, run it for version 1.7 (with the libomp5 / libopenblas-dev apt packages), then install PyTorch Lightning.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev
!pip install pytorch-lightning --quiet

In [None]:
import torch
import torchvision
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid
import torchvision.transforms as T
from torch import nn, optim
import torch.nn.functional as F

import glob
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import warnings
import cv2
from PIL import Image
from pathlib import Path

import io
import tensorflow as tf
from tqdm.auto import tqdm
import logging


# --- Notebook-wide runtime configuration ---
# Silence every warning category for the rest of the session.
warnings.simplefilter("ignore")
# Never truncate wide text columns when pandas renders a frame.
pd.options.display.max_colwidth = None
# Handle on the "lightning" logger.
logger = logging.getLogger(name="lightning")
# Fix the global seed through Lightning's helper; kept last so the cell still echoes the seed.
pl.seed_everything(seed=123)

In [None]:
# Locations of the train / validation / test TFRecord shards (512x512 JPEG variant).
TRAIN_PATH, VALID_PATH, TEST_PATH = map(
    Path,
    (
        "../input/tpu-getting-started/tfrecords-jpeg-512x512/train/",
        "../input/tpu-getting-started/tfrecords-jpeg-512x512/val/",
        "../input/tpu-getting-started/tfrecords-jpeg-512x512/test/",
    ),
)

In [None]:
# Human-readable flower names (104 entries). Only len(CLASSES) is used below, to size
# the classifier output.
# NOTE(review): presumably the list index is the integer `class` label stored in the
# TFRecords — confirm against the dataset description before using it for lookups.
# Names are stored without stray trailing whitespace so they can be matched by name.
CLASSES = [
    'pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'wild geranium',
    'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle',
    'snapdragon', "colt's foot", 'king protea', 'spear thistle', 'yellow iris',
    'globe-flower', 'purple coneflower', 'peruvian lily', 'balloon flower', 'giant white arum lily',
    'fire lily', 'pincushion flower', 'fritillary', 'red ginger', 'grape hyacinth',
    'corn poppy', 'prince of wales feathers', 'stemless gentian', 'artichoke', 'sweet william',
    'carnation', 'garden phlox', 'love in the mist', 'cosmos', 'alpine sea holly',
    'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip', 'lenten rose',
    'barberton daisy', 'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue',
    'wallflower', 'marigold', 'buttercup', 'daisy', 'common dandelion',
    'petunia', 'wild pansy', 'primula', 'sunflower', 'lilac hibiscus',
    'bishop of llandaff', 'gaura', 'geranium', 'orange dahlia', 'pink-yellow dahlia',
    'cautleya spicata', 'japanese anemone', 'black-eyed susan', 'silverbush', 'californian poppy',
    'osteospermum', 'spring crocus', 'iris', 'windflower', 'tree poppy',
    'gazania', 'azalea', 'water lily', 'rose', 'thorn apple',
    'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium',
    'frangipani', 'clematis', 'hibiscus', 'columbine', 'desert-rose',
    'tree mallow', 'magnolia', 'cyclamen', 'watercress', 'canna lily',
    'hippeastrum', 'bee balm', 'pink quill', 'foxglove', 'bougainvillea',
    'camellia', 'mallow', 'mexican petunia', 'bromelia', 'blanket flower',
    'trumpet creeper', 'blackberry lily', 'common tulip', 'wild rose',
]

In [None]:
# Parsing schemas for the serialized examples. Labelled records (train / validation)
# carry an int64 "class" next to the string "id" and "image" fields; test records
# carry only "id" and "image".
train_feature = {
    "class": tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
    "id": tf.io.FixedLenFeature(shape=[], dtype=tf.string),
    "image": tf.io.FixedLenFeature(shape=[], dtype=tf.string),
}

test_feature = {
    "id": tf.io.FixedLenFeature(shape=[], dtype=tf.string),
    "image": tf.io.FixedLenFeature(shape=[], dtype=tf.string),
}


def parse_train_image(example_proto):
    """Decode one serialized labelled example using the `train_feature` schema.

    Returns the mapping produced by `tf.io.parse_single_example`, keyed by
    "class", "id" and "image".
    """
    parsed_example = tf.io.parse_single_example(
        serialized=example_proto,
        features=train_feature,
    )
    return parsed_example

def parse_test_image(example_proto):
    """Decode one serialized unlabelled example using the `test_feature` schema.

    Returns the mapping produced by `tf.io.parse_single_example`, keyed by
    "id" and "image".
    """
    parsed_example = tf.io.parse_single_example(
        serialized=example_proto,
        features=test_feature,
    )
    return parsed_example

In [None]:
# Load every training shard into three parallel in-memory lists (ids, labels, encoded images).
# Shards are sorted because glob returns paths in arbitrary order; sorting keeps the
# sample order reproducible across runs.
train_files  = sorted(glob.glob(str(TRAIN_PATH/"*.tfrec")))
train_class  = []
train_images = []
train_ids    = []

print(f"Reading tfrecords from {TRAIN_PATH}")
for tfrec_file in tqdm(train_files):
    train_image_dataset = tf.data.TFRecordDataset(tfrec_file).map(parse_train_image)

    # Walk each shard exactly once and pull all three fields from the same record,
    # instead of re-reading and re-parsing the shard once per field.
    for example in train_image_dataset:
        # ids are stored as bytes; decode them properly rather than slicing the bytes repr
        train_ids.append(example['id'].numpy().decode("utf-8"))
        train_class.append(int(example['class'].numpy()))
        # keep the image as encoded bytes; it is decoded lazily by the Dataset
        train_images.append(example['image'].numpy())

In [None]:
# Load every validation shard into three parallel in-memory lists (ids, labels, encoded images).
# Shards are sorted because glob returns paths in arbitrary order; sorting keeps the
# sample order reproducible across runs.
val_files  = sorted(glob.glob(str(VALID_PATH/"*.tfrec")))
val_ids    = []
val_class  = []
val_images = []

print(f"Reading tfrecords from {VALID_PATH}")
for tfrec_file in tqdm(val_files):
    # Validation records are labelled, so they share the training schema.
    val_image_dataset = tf.data.TFRecordDataset(tfrec_file).map(parse_train_image)

    # Walk each shard exactly once and pull all three fields from the same record,
    # instead of re-reading and re-parsing the shard once per field.
    for example in val_image_dataset:
        # ids are stored as bytes; decode them properly rather than slicing the bytes repr
        val_ids.append(example['id'].numpy().decode("utf-8"))
        val_class.append(int(example['class'].numpy()))
        # keep the image as encoded bytes; it is decoded lazily by the Dataset
        val_images.append(example['image'].numpy())

In [None]:
# Load every test shard into two parallel in-memory lists (ids, encoded images).
# Shards are sorted because glob returns paths in arbitrary order; sorting keeps the
# sample (and therefore prediction) order reproducible across runs.
test_files  = sorted(glob.glob(str(TEST_PATH/"*.tfrec")))
test_ids    = []
test_images = []

# Report the directory actually being read (the test split, not the validation split).
print(f"Reading tfrecords from {TEST_PATH}")
for tfrec_file in tqdm(test_files):
    test_image_dataset = tf.data.TFRecordDataset(tfrec_file).map(parse_test_image)

    # Walk each shard exactly once and pull both fields from the same record,
    # instead of re-reading and re-parsing the shard once per field.
    for example in test_image_dataset:
        # ids are stored as bytes; decode them properly rather than slicing the bytes repr
        test_ids.append(example['id'].numpy().decode("utf-8"))
        # keep the image as encoded bytes; it is decoded lazily by the Dataset
        test_images.append(example['image'].numpy())

In [None]:
# The per-split lists are parallel: every list of a split must have the same length.
assert len({len(train_ids), len(train_class), len(train_images)}) == 1
assert len({len(val_ids), len(val_class), len(val_images)}) == 1
assert len({len(test_ids), len(test_images)}) == 1

In [None]:
class FlowersDataset(Dataset):
    """Map-style dataset over in-memory, still-encoded flower images.

    Each item is decoded with PIL from the stored bytes, converted to RGB and
    passed through ``augments``.

    Args:
        ids: sequence of sample identifiers; its length defines the dataset size.
        images: sequence of encoded image bytes, parallel to ``ids``.
        augments: callable applied to the decoded PIL image.
        classes: sequence of integer labels parallel to ``ids``. Required
            unless ``is_test`` is True.
        is_test: when True, items carry no label.

    Raises:
        ValueError: if ``is_test`` is False and ``classes`` is None.
    """
    def __init__(self, ids, images, augments, classes = None, is_test = False):
        # Fail at construction time instead of with an opaque TypeError on first access.
        if not is_test and classes is None:
            raise ValueError("classes must be provided when is_test is False")

        self.ids     = ids
        self.images  = images
        self.classes = classes
        self.augs    = augments
        self.is_test = is_test

    def __len__(self):
        """Number of samples, taken from the id list."""
        return len(self.ids)

    def __getitem__(self, index):
        """Return ``(id, image)`` for test data, ``(id, image, label)`` otherwise."""
        idx   = self.ids[index]
        # Force three channels: a grayscale or CMYK file would otherwise not match the
        # 3-channel mean/std normalisation applied by the transform pipeline.
        image = Image.open(io.BytesIO(self.images[index])).convert("RGB")
        image = self.augs(image)

        if self.is_test:
            return idx, image

        clas = torch.tensor(int(self.classes[index]))
        return idx, image, clas

In [None]:
class FlowersDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping the in-memory train / validation / test splits.

    Args:
        train_vars: ``(ids, images, classes)`` for the training split.
        test_vars: ``(ids, images)`` for the test split.
        valid_vars: ``(ids, images, classes)`` for the validation split.
        batch_size: batch size used by every dataloader.
        input_dims: images are resized to ``input_dims x input_dims``.
    """
    def __init__(self, train_vars:tuple, test_vars:tuple, valid_vars:tuple, batch_size:int = 256, input_dims:int = 512):
        super().__init__()
        (self.train_ids, self.train_images, self.train_classes) = train_vars
        (self.valid_ids, self.valid_images, self.valid_classes) = valid_vars
        (self.test_ids, self.test_images) = test_vars

        self.batch_size = batch_size

        # Imagenet means and stds
        imagenet_stats = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        target_size = (input_dims, input_dims)

        def deterministic_pipeline():
            # resize -> tensor -> normalise, with no random augmentation
            return T.Compose([
                T.Resize(size=target_size),
                T.ToTensor(),
                T.Normalize(**imagenet_stats),
            ])

        # Training additionally flips images horizontally at random.
        self.train_augs = T.Compose([
            T.Resize(size=target_size),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(**imagenet_stats),
        ])
        self.valid_augs = deterministic_pipeline()
        self.test_augs  = deterministic_pipeline()

    def setup(self, stage=None):
        """Build the datasets for the requested stage ('fit', 'test', or both when None)."""
        wants_fit  = stage is None or stage == 'fit'
        wants_test = stage is None or stage == 'test'

        if wants_fit:
            self.flowers_train = FlowersDataset(
                self.train_ids, self.train_images, self.train_augs, self.train_classes
            )
            self.flowers_valid = FlowersDataset(
                self.valid_ids, self.valid_images, self.valid_augs, self.valid_classes
            )

        if wants_test:
            self.flowers_test = FlowersDataset(
                self.test_ids, self.test_images, augments=self.test_augs, is_test = True
            )

    def _build_loader(self, dataset, shuffle):
        # Every loader shares the batch size and pinned-memory setting.
        return DataLoader(dataset, shuffle=shuffle, batch_size=self.batch_size, pin_memory=True)

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return self._build_loader(self.flowers_train, shuffle=True)

    def val_dataloader(self):
        """Ordered loader over the validation dataset."""
        return self._build_loader(self.flowers_valid, shuffle=False)

    def test_dataloader(self):
        """Ordered loader over the test dataset."""
        return self._build_loader(self.flowers_test, shuffle=False)

In [None]:
class FlowerClassifier(pl.LightningModule):
    """Pretrained ResNet-34 followed by a small fully-connected classification head.

    The backbone is used as-is, including its own ``fc`` layer; the head consumes
    that layer's output and maps it to ``output_dims`` logits.

    Args:
        output_dims: number of target classes (size of the final linear layer).
        learning_rate: learning rate passed to AdamW.
        weight_decay: weight decay passed to AdamW.

    After ``trainer.test`` the predictions are available in ``self.results``, a
    DataFrame with columns ``id`` and ``label``.
    """
    def __init__(self, output_dims: int, learning_rate:float, weight_decay:float):
        super().__init__()
        # Stores the constructor arguments under self.hparams.
        self.save_hyperparameters()

        self.classifier  = torchvision.models.resnet34(pretrained=True, progress=True)
        base_output_dims = self.classifier.fc.out_features

        self.lin1   = nn.Sequential(nn.BatchNorm1d(base_output_dims),  nn.Dropout(0.2), nn.ReLU(inplace=True))
        self.lin2   = nn.Sequential(nn.Linear(base_output_dims, 1024), nn.BatchNorm1d(1024), nn.Dropout(0.5), nn.ReLU())
        self.lin3   = nn.Sequential(nn.Linear(1024, 512),  nn.BatchNorm1d(512),  nn.Dropout(0.5), nn.ReLU())
        self.output = nn.Sequential(nn.Linear(512, self.hparams.output_dims))

        self.accuracy = pl.metrics.Accuracy()

        # Test-time accumulators; reset at the start of every test run.
        self.results    = pd.DataFrame()
        self.test_idxs  = []
        self.test_preds = []

    def forward(self, x):
        """Return class logits for a batch of images."""
        out = self.classifier(x)
        out = self.lin3(self.lin2(self.lin1(out)))
        out = self.output(out)
        return out

    def training_step(self, batch, batch_idx, *args, **kwargs):
        """Cross-entropy loss on one ``(id, image, label)`` batch; logged per epoch as ``loss``."""
        _, image, clas = batch
        y_hat = self(image)
        loss  = F.cross_entropy(y_hat, clas)
        self.log('loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx, *args, **kwargs):
        """Log per-epoch ``val_loss`` and ``accuracy`` for one ``(id, image, label)`` batch."""
        _, image, clas = batch
        logits = self(image)
        loss   = F.cross_entropy(logits, clas)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        metric = self.accuracy(logits, clas)
        self.log("accuracy", metric, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def on_test_epoch_start(self, *args, **kwargs):
        """Start each test run from empty accumulators.

        Without this, calling ``trainer.test`` a second time keeps appending to
        the predictions of the previous run.
        """
        self.test_idxs  = []
        self.test_preds = []

    def test_step(self, batch, batch_idx, *args, **kwargs):
        """Predict the arg-max class for one ``(id, image)`` batch and remember it."""
        idx, image = batch
        logits     = self(image)
        # compute the output from the logits
        _, preds   = logits.max(dim=1)

        self.test_preds.extend(list(preds.cpu().numpy()))
        self.test_idxs.extend(list(idx))

    def test_epoch_end(self, *args, **kwargs):
        """Publish the collected predictions as a fresh ``id`` / ``label`` frame.

        A new DataFrame is built each time: assigning columns into the frame of
        a previous run would fail when the number of collected rows differs.
        """
        self.results = pd.DataFrame({"id": self.test_idxs, "label": self.test_preds})

    def configure_optimizers(self, *args, **kwargs):
        """AdamW plus a ReduceLROnPlateau scheduler stepped per epoch on ``val_loss``."""
        opt = optim.AdamW(self.parameters(), lr = self.hparams.learning_rate, weight_decay = self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min")

        pl_scheduler = {
            "scheduler": lr_scheduler,
            "interval": "epoch",
            "frequency": 1,
            "reduce_on_plateau": True,
            "monitor": "val_loss",
            "strict": True
        }

        return [opt], [pl_scheduler]

In [None]:
# Cross-Check DataModule
# Bundle the parallel lists per split; these tuples are reused when building the real datamodule.
train_vars = (train_ids, train_images, train_class)
valid_vars = (val_ids, val_images, val_class)
test_vars  = (test_ids, test_images)

# Throw-away datamodule used only to eyeball one batch from each split.
fake_dm = FlowersDataModule(train_vars, test_vars, valid_vars, batch_size=8, input_dims=224)

fake_dm.setup("fit")
trn_ds = fake_dm.train_dataloader()
val_ds = fake_dm.val_dataloader()
fake_dm.setup("test")
test_ds = fake_dm.test_dataloader()

# Position 1 of a batch holds the image tensor for labelled and unlabelled splits alike.
trn_image  = next(iter(trn_ds))[1]
val_image  = next(iter(val_ds))[1]
test_image = next(iter(test_ds))[1]

# One image grid per split, rearranged from (C, H, W) to (H, W, C) for matplotlib.
trn_grid, val_grid, test_grid = (
    make_grid(split_images, normalize=True, nrow=4).permute(1, 2, 0).numpy()
    for split_images in (trn_image, val_image, test_image)
)

# Stack the three grids vertically: train, validation, test.
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(15, 15))
for axis, grid in zip(axes, (trn_grid, val_grid, test_grid)):
    axis.axis("off")
    axis.imshow(grid)

In [None]:
# CSV metrics logger rooted at the working directory.
logger = pl.loggers.CSVLogger(
    save_dir="/kaggle/working/",
    name="kaggle_tpu_flowers",
    version="001",
)

# Stop once val_loss has not improved for 5 checks, and record the learning rate at every step.
early_stopping = pl.callbacks.EarlyStopping(monitor="val_loss", patience=5)
lr_monitor     = pl.callbacks.LearningRateMonitor("step")
callbacks      = [early_stopping, lr_monitor]

trainer = pl.Trainer(
    tpu_cores=8,
    precision=16,
    gradient_clip_val=0.5,
    callbacks=callbacks,
    logger=logger,
)

In [None]:
# Real datamodule and model used for training; one output unit per entry in CLASSES.
dataModule = FlowersDataModule(
    train_vars=train_vars,
    test_vars=test_vars,
    valid_vars=valid_vars,
    batch_size=8,
    input_dims=224,
)
model = FlowerClassifier(
    output_dims=len(CLASSES),
    learning_rate=0.002,
    weight_decay=0.01,
)

In [None]:
# Train on the datamodule's train / validation loaders.
trainer.fit(model=model, datamodule=dataModule)

In [None]:
# Run the test loop; predictions are collected on the model itself.
trainer.test(model=model, datamodule=dataModule)

In [None]:
# Preview the id / label predictions gathered during trainer.test.
# The trained instance is named `model`; no `classifier` variable exists in this notebook.
model.results.head()