In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import tensorflow as tf
import numpy as np
from PIL import Image
import io
import torch.nn.functional as F
import numpy as np
import pandas as pd

In [None]:
import torchvision.transforms as tt
import torchvision.models as models

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing at the first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Echo the selected device as this cell's output.
device

In [None]:
# Mini-batch size handed to the DataLoaders.
batch_size = 64

# (per-channel mean, per-channel std) consumed by the Normalize step below.
stats = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Image pipeline: resize -> centre-crop to 224 -> tensor -> normalise.
# NOTE(review): `tsfm` is not referenced by any other visible cell - confirm
# whether it was meant to be applied while decoding the images.
tsfm = tt.Compose(
    [
        tt.Resize(224),
        tt.CenterCrop(224),
        tt.ToTensor(),
        tt.Normalize(stats[0], stats[1]),
    ]
)

In [None]:
import glob

# Folder holding the 224x224 JPEG TFRecord shards; each split lives in its
# own sub-directory.
tfrec_root = "/kaggle/input/tpu-getting-started/tfrecords-jpeg-224x224"

# glob.glob returns matches in arbitrary (filesystem) order; sorting keeps the
# record order - and everything derived from it - reproducible across runs.
train_files = sorted(glob.glob(f"{tfrec_root}/train/*.tfrec"))
val_files = sorted(glob.glob(f"{tfrec_root}/val/*.tfrec"))
test_files = sorted(glob.glob(f"{tfrec_root}/test/*.tfrec"))

In [None]:
# Schema of one serialized training example: a scalar int64 `class` plus
# scalar string `id` and `image` features.
train_feature_description = {
    feature_name: tf.io.FixedLenFeature([], feature_dtype)
    for feature_name, feature_dtype in (
        ('class', tf.int64),
        ('id', tf.string),
        ('image', tf.string),
    )
}

def _parse_image_function(example_proto):
    """Parse one serialized example into a dict of feature tensors.

    The schema is always `train_feature_description`, whichever split the
    record comes from.
    """
    schema = train_feature_description
    return tf.io.parse_single_example(example_proto, schema)

# Collect id / label / encoded image bytes from EVERY training shard.
# Previously only the dataset construction sat inside the loop, so just the
# last shard was parsed (and an empty file list raised NameError).
train_ids = []
train_class = []
train_images = []
for tfrec_path in train_files:
    train_image_dataset = tf.data.TFRecordDataset(tfrec_path).map(_parse_image_function)
    # One pass per shard: each record is parsed once and all three fields are
    # read from it, instead of re-reading the shard once per field.
    for record in train_image_dataset:
        # `id` is a bytes scalar; decode it rather than slicing its repr.
        train_ids.append(record['id'].numpy().decode('utf-8'))
        train_class.append(int(record['class'].numpy()))
        train_images.append(record['image'].numpy())

In [None]:
# Decode every training JPEG into a channels-first float32 array and pack the
# split into a single tensor of shape (N, 3, 224, 224).
# The buffer is pre-allocated so no intermediate list of arrays is kept and an
# empty split still yields a correctly shaped (0, 3, 224, 224) tensor.
# NOTE(review): pixel values are left in the raw 0-255 range; the `tsfm`
# pipeline defined earlier is not applied here - confirm whether it should be.
train_array_images = np.empty((len(train_images), 3, 224, 224), dtype='float32')
for idx, jpeg_bytes in enumerate(train_images):
    with Image.open(io.BytesIO(jpeg_bytes)) as img:
        # Force three channels so greyscale/paletted files do not break the layout.
        hwc = np.asarray(img.convert('RGB'))
    # PIL yields (height, width, channel). The channel axis must be MOVED to
    # the front; reshape(3, 224, 224) would only reinterpret the buffer and
    # scramble the pixels.
    train_array_images[idx] = hwc.transpose(2, 0, 1)
train_array_images = torch.from_numpy(train_array_images)
train_array_class = torch.tensor(train_class)

In [None]:
# Schema of one serialized validation example: a scalar int64 `class` plus
# scalar string `id` and `image` features.
# NOTE: `_parse_image_function` reads `train_feature_description`, so this
# dict is not consulted by the visible parsing code.
val_feature_description = {
    feature_name: tf.io.FixedLenFeature([], feature_dtype)
    for feature_name, feature_dtype in (
        ('class', tf.int64),
        ('id', tf.string),
        ('image', tf.string),
    )
}

# Collect id / label / encoded image bytes from EVERY validation shard.
# Previously only the dataset construction sat inside the loop, so just the
# last shard was parsed (and an empty file list raised NameError).
val_ids = []
val_class = []
val_images = []
for tfrec_path in val_files:
    val_image_dataset = tf.data.TFRecordDataset(tfrec_path).map(_parse_image_function)
    # One pass per shard: each record is parsed once and all three fields are
    # read from it, instead of re-reading the shard once per field.
    for record in val_image_dataset:
        # `id` is a bytes scalar; decode it rather than slicing its repr.
        val_ids.append(record['id'].numpy().decode('utf-8'))
        val_class.append(int(record['class'].numpy()))
        val_images.append(record['image'].numpy())

In [None]:
# Decode every validation JPEG into a channels-first float32 array and pack
# the split into a single tensor of shape (N, 3, 224, 224).
# The buffer is pre-allocated so no intermediate list of arrays is kept and an
# empty split still yields a correctly shaped (0, 3, 224, 224) tensor.
# NOTE(review): pixel values are left in the raw 0-255 range; the `tsfm`
# pipeline defined earlier is not applied here - confirm whether it should be.
val_array_images = np.empty((len(val_images), 3, 224, 224), dtype='float32')
for idx, jpeg_bytes in enumerate(val_images):
    with Image.open(io.BytesIO(jpeg_bytes)) as img:
        # Force three channels so greyscale/paletted files do not break the layout.
        hwc = np.asarray(img.convert('RGB'))
    # PIL yields (height, width, channel). The channel axis must be MOVED to
    # the front; reshape(3, 224, 224) would only reinterpret the buffer and
    # scramble the pixels.
    val_array_images[idx] = hwc.transpose(2, 0, 1)
val_array_images = torch.from_numpy(val_array_images)
val_array_class = torch.tensor(val_class)

In [None]:
# Pair each split's image tensor with its label tensor.
train_set, val_set = (
    TensorDataset(split_images, split_labels)
    for split_images, split_labels in (
        (train_array_images, train_array_class),
        (val_array_images, val_array_class),
    )
)

In [None]:
class DeviceDataLoader():
    """Thin wrapper that moves every batch of a dataloader onto a device."""

    def __init__(self, dl, device):
        # Keep the wrapped loader and the target device for later iteration.
        self.dl, self.device = dl, device

    def __iter__(self):
        """Lazily yield each batch of the wrapped loader after `to_device`."""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Return the number of batches in the wrapped loader."""
        n_batches = len(self.dl)
        return n_batches

In [None]:
# Training batches are reshuffled on every pass; validation keeps dataset order.
# Both loaders are wrapped so each batch is moved to `device` as it is yielded.
train_loader = DeviceDataLoader(
    DataLoader(train_set, batch_size=batch_size, shuffle=True),
    device,
)
val_loader = DeviceDataLoader(
    DataLoader(val_set, batch_size=batch_size),
    device,
)

In [None]:
# Training hyper-parameters: number of epochs, peak learning rate for the
# one-cycle schedule, and the optimizer class to instantiate.
epochs, max_lr = 15, 1e-4
opt_func = torch.optim.Adam

In [None]:
def to_device(data, device):
    """Recursively move a tensor, or a list/tuple of tensors, to `device`.

    Lists and tuples are walked element by element and always come back as a
    list; anything else is moved with a non-blocking `.to(device)` call.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

def accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring column equals the label.

    The result is a 0-dim tensor built from a Python float.
    """
    preds = torch.max(outputs, dim=1).indices
    n_correct = (preds == labels).sum().item()
    return torch.tensor(n_correct / len(preds))

class ImageClassificationBase(nn.Module):
    """Mixin-style base providing train/validation steps for a classifier.

    Subclasses supply `forward`; the methods here compute cross-entropy loss,
    per-batch accuracy, epoch-level averages, and a one-line epoch report.
    """

    def training_step(self, batch):
        """Return the cross-entropy loss of the model on one (images, labels) batch."""
        images, labels = batch
        logits = self(images)
        return F.cross_entropy(logits, labels)

    def validation_step(self, batch):
        """Return detached loss and accuracy for one (images, labels) batch."""
        images, labels = batch
        logits = self(images)
        batch_loss = F.cross_entropy(logits, labels)
        batch_acc = accuracy(logits, labels)
        return {'val_loss': batch_loss.detach(), 'val_acc': batch_acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch losses and accuracies into Python floats."""
        losses = [step['val_loss'] for step in outputs]
        accs = [step['val_acc'] for step in outputs]
        mean_loss = torch.stack(losses).mean()
        mean_acc = torch.stack(accs).mean()
        return {'val_loss': mean_loss.item(), 'val_acc': mean_acc.item()}

    def epoch_end(self, epoch, result):
        """Print the last learning rate and the train/validation metrics of an epoch."""
        template = "Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}"
        last_lr = result['lrs'][-1]
        print(template.format(
            epoch, last_lr, result['train_loss'], result['val_loss'], result['val_acc']))


def evaluate(model, val_loader):
    """Run the model over `val_loader` in eval mode with gradients disabled.

    Returns whatever `model.validation_epoch_end` produces from the list of
    per-batch `validation_step` results.
    """
    with torch.no_grad():
        model.eval()
        step_results = []
        for batch in val_loader:
            step_results.append(model.validation_step(batch))
        return model.validation_epoch_end(step_results)

def get_lr(optimizer):
    """Return the `lr` of the optimizer's first parameter group (None if there are none)."""
    groups = iter(optimizer.param_groups)
    first = next(groups, None)
    if first is None:
        return None
    return first['lr']

def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader,
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD):
    """Train `model` with a one-cycle learning-rate schedule.

    Args:
        epochs: number of passes over `train_loader`.
        max_lr: peak learning rate handed to the optimizer and to OneCycleLR.
        model: object providing `training_step`, `validation_step`,
            `validation_epoch_end`, `epoch_end`, `parameters`, `train`, `eval`.
        train_loader: sized iterable of training batches.
        val_loader: iterable of validation batches, passed to `evaluate`.
        weight_decay: forwarded to the optimizer constructor.
        grad_clip: if truthy, gradients are clipped element-wise to this value.
        opt_func: optimizer class, called as
            `opt_func(params, max_lr, weight_decay=...)`.

    Returns:
        A list with one dict per epoch holding `val_loss`, `val_acc`,
        `train_loss` and `lrs` (the learning rate recorded after each step).
    """
    torch.cuda.empty_cache()
    history = []

    # Set up custom optimizer with weight decay
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)
    # Set up one-cycle learning rate scheduler (stepped once per batch)
    sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs,
                                                steps_per_epoch=len(train_loader))

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = model.training_step(batch)
            # Keep only the detached value: storing the live loss would hold
            # autograd references for every batch until the epoch ends
            # (validation_step already detaches its loss the same way).
            train_losses.append(loss.detach())
            loss.backward()

            # Gradient clipping
            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)

            optimizer.step()
            optimizer.zero_grad()

            # Record & update learning rate
            lrs.append(get_lr(optimizer))
            sched.step()

        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        model.epoch_end(epoch, result)
        history.append(result)
    return history

In [None]:
class ResNetModel(ImageClassificationBase):
    """ResNet-101 whose final fully connected layer outputs `num_classes` scores."""

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.resnet101()
        # Swap the stock classification head for one sized to this task.
        head_inputs = backbone.fc.in_features
        backbone.fc = nn.Linear(head_inputs, num_classes)
        self.network = backbone

    def forward(self, xb):
        """Return the backbone's output for the batch `xb`."""
        scores = self.network(xb)
        return scores


# Build the classifier with a 104-way output head and move it to `device`.
model = to_device(ResNetModel(num_classes=104), device)

In [None]:
# Validation metrics of the model before any training step; shown as the cell output.
evaluate(model=model, val_loader=val_loader)

In [None]:
# Accumulates the per-epoch result dicts returned by fit_one_cycle.
history = list()

In [None]:
# Train for `epochs` epochs and append the per-epoch results to `history`.
history.extend(
    fit_one_cycle(
        epochs=epochs,
        max_lr=max_lr,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        opt_func=opt_func,
    )
)

In [None]:
# Completion marker printed once every cell above has run.
print('done')