# PyTorch TPU + TF Records

In [None]:
# Google Cloud project ID; set as the gcloud CLI's default project below.
PROJECT_ID = 'hybrid-vertex'  # <--- TODO: CHANGE THIS
# Region setting. NOTE(review): not referenced by any other cell shown here.
LOCATION = 'us-central1' 
# IPython shell escape; `{PROJECT_ID}` is substituted with the Python variable's value.
!gcloud config set project {PROJECT_ID}

In [None]:
import sys

# Only Colab runtimes need the interactive Google user sign-in; other notebook
# hosts are assumed to already carry credentials for gcloud / GCS access.
if 'google.colab' in sys.modules:
    from google.colab import auth

    auth.authenticate_user()

In [None]:
# Extra flag for `pip install`: empty on Colab, `--user` (install into the
# user site-packages) on any other notebook host.
USER_FLAG = '' if 'google.colab' in sys.modules else '--user'

### pip install

In [None]:
!pip install cloud-tpu-client==0.10 torch==1.11.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.11-cp37-cp37m-linux_x86_64.whl

! pip install tensorflow

! pip -q install google-cloud-storage==1.44.0

import IPython
app = IPython.Application.instance()
app.kernel.do_shutdown(True)

### Import packages

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds 

# import webdataset as wds

import torchvision
import torch
from torch.utils.data import Dataset, IterableDataset, DataLoader
import torch_xla
import os

import sys
from itertools import cycle, islice, chain, count
import random 
import time

import numpy as np

import warnings

# Silence NumPy's VisibleDeprecationWarning. The class lives in `np.exceptions`
# since NumPy 1.25, and the top-level `np.VisibleDeprecationWarning` alias was
# removed in NumPy 2.0, so resolve it in a way that works on either.
try:
    _VisibleDeprecationWarning = np.exceptions.VisibleDeprecationWarning
except AttributeError:  # NumPy < 1.25 has no `np.exceptions`
    _VisibleDeprecationWarning = np.VisibleDeprecationWarning
warnings.filterwarnings("ignore", category=_VisibleDeprecationWarning)

### TF vs Torch Tensors

* N - batch size
* H - height of image
* W - width of image
* C - number of channels (usually 3 for RGB)

TensorFlow (channels-last, "NHWC")
* `shape=(N, H, W, C)`

PyTorch (channels-first, "NCHW")
* `torch.Size([N, C, H, W])`

[source](https://towardsdatascience.com/convert-images-to-tensors-in-pytorch-and-tensorflow-f0ab01383a03)

# Read TF Records

In [None]:
# tf.data settings shared by the input pipelines in this notebook.
AUTOTUNE = tf.data.AUTOTUNE  # let tf.data tune parallelism / buffer sizes at runtime

# Allow tf.data to emit elements out of order when that is faster (parallel
# reads / maps); presumably acceptable because this is shuffled training input.
# `deterministic` replaces the deprecated `experimental_deterministic` option.
ignore_order = tf.data.Options()
ignore_order.deterministic = False

# NOTE(review): neither batch size is referenced by the cells shown here.
batch_size = 128
test_batch_size = 64

## ImageNet

**TODO:** convert the TensorFlow tensors produced below into PyTorch tensors (including the NHWC → NCHW layout change).

In [None]:
def read_tfrecord(data, image_size=(128, 128)):
    """Parse one serialized TFRecord example into an ``(image, label)`` pair.

    Intended for ``tf.data.Dataset.map`` over a ``tf.data.TFRecordDataset``.

    Args:
        data: scalar ``tf.string`` tensor holding one serialized example with
            ``image/encoded`` (JPEG bytes) and ``image/class/label`` features.
        image_size: ``(height, width)`` the decoded image is resized to.
            Defaults to ``(128, 128)``.

    Returns:
        ``(image, label)`` as TensorFlow tensors. ``image`` has shape
        ``(height, width, 3)`` in channels-last (HWC) layout and dtype float32;
        ``tf.image.resize`` returns float32 without rescaling the 0-255 pixel
        values. ``label`` is the scalar int64 ``image/class/label`` feature.

    Note:
        Conversion to a PyTorch tensor cannot happen here. ``Dataset.map``
        traces this function into a graph, where ``Tensor.numpy()`` is not
        available. Convert eagerly after iterating the dataset instead, e.g.
        ``torch.from_numpy(image.numpy()).permute(2, 0, 1)`` for CHW layout.
    """
    features = {
        # tf.string = byte string (not text string): the encoded JPEG bytes.
        'image/class/label': tf.io.FixedLenFeature([], tf.int64),
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
    }

    # Decode the serialized example into the two features declared above.
    tf_record = tf.io.parse_single_example(data, features)

    # channels=3 requests RGB output regardless of the JPEG's own channel count.
    image = tf.io.decode_jpeg(tf_record['image/encoded'], channels=3)
    image = tf.image.resize(image, image_size)

    label = tf_record['image/class/label']
    return image, label

In [None]:
filenames = [
    'gs://imagenet-jt/train/train-00000-of-01024',
]

# Read the shard(s), relax element order for throughput, then decode each
# serialized example in parallel.
dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
dataset = dataset.with_options(ignore_order)
dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

# Peek at one decoded example. Outside `map` we are in eager mode, so `.numpy()`
# works here; a PyTorch CHW tensor would be
# `torch.from_numpy(image.numpy()).permute(2, 0, 1)`.
for image, label in dataset.take(1):
    print("Image shape {}, label={}".format(image.numpy().shape, label.numpy()))

### Torch Iterable Dataset

Create a custom `IterableDataset` that streams batches from the TFRecord files into a PyTorch `DataLoader`.

In [None]:
class MyIterable_tf_Dataset(IterableDataset):
    """Stream batches from TFRecord files as PyTorch tensors.

    Each item produced by iteration is one ``(images, labels)`` batch:

    * ``images`` is a float32 ``torch.Tensor`` of shape
      ``(batch_size, 3, height, width)``, i.e. PyTorch's NCHW layout,
      converted from TensorFlow's NHWC.
    * ``labels`` is an int64 ``torch.Tensor`` of shape ``(batch_size,)``.

    Batching already happens inside the ``tf.data`` pipeline, so use this with
    ``DataLoader(..., batch_size=None)``.

    Iteration stops after ``length`` batches. If the files run out first, they
    are read again from the start.

    Args:
        data_list: TFRecord file path(s) or glob pattern(s) accepted by
            ``tf.data.Dataset.list_files``.
        batch_size: examples per batch; a trailing partial batch is dropped.
        length: number of batches yielded per iteration; also what ``len()``
            reports.
        image_size: ``(height, width)`` each decoded image is resized to.

    NOTE(review): nothing here shards by ``torch.utils.data.get_worker_info()``,
    so with ``num_workers > 1`` every DataLoader worker streams its own copy of
    the same ``length`` batches.
    """

    def __init__(self, data_list, batch_size, length, image_size=(128, 128)):
        self.data_list = data_list
        self.batch_size = batch_size
        self.length = length
        self.image_size = image_size

    def __len__(self):
        """Number of batches yielded per iteration."""
        return self.length

    def read_tfrecord(self, data):
        """Parse one serialized example into ``(image, label)`` TF tensors.

        ``image`` is ``(height, width, 3)`` float32 in channels-last layout
        (``tf.image.resize`` keeps the 0-255 value range); ``label`` is the
        scalar int64 ``image/class/label`` feature.
        """
        features = {
            # tf.string = byte string (the encoded JPEG), not text.
            'image/class/label': tf.io.FixedLenFeature([], tf.int64),
            'image/encoded': tf.io.FixedLenFeature([], tf.string),
        }
        tf_record = tf.io.parse_single_example(data, features)

        image = tf.io.decode_jpeg(tf_record['image/encoded'], channels=3)
        image = tf.image.resize(image, self.image_size)

        label = tf_record['image/class/label']
        return image, label

    def create_dataset(self, data_list):
        """Build the batched ``tf.data`` pipeline over the files in ``data_list``."""
        dataset = tf.data.Dataset.list_files(data_list)
        dataset = tf.data.TFRecordDataset(dataset)
        # `ignore_order` is the module-level tf.data.Options defined earlier.
        dataset = dataset.with_options(ignore_order)
        # Order is already relaxed, so decode examples in parallel.
        dataset = dataset.map(self.read_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        # Overlap TF-side decoding with consumption on the PyTorch side.
        return dataset.prefetch(tf.data.AUTOTUNE)

    def process_data(self, data):
        """Yield each ``(images, labels)`` batch of ``data`` as torch tensors.

        ``data`` is an iterable of ``(images, labels)`` batches, normally the
        ``tf.data.Dataset`` returned by ``create_dataset``. Images are permuted
        from NHWC to NCHW.
        """
        for images, labels in data:
            # np.array copies eager TF tensors (or NumPy arrays) into owned,
            # writable memory that torch.from_numpy can safely wrap.
            images = torch.from_numpy(np.array(images)).permute(0, 3, 1, 2).contiguous()
            labels = torch.from_numpy(np.array(labels))
            yield images, labels

    def get_stream(self, data_list):
        """Yield converted batches from ``data_list`` forever, restarting at its end.

        Returns early if a full pass over ``data_list`` produces no batches, so
        empty input cannot loop forever. Unlike ``itertools.cycle``, this
        re-iterates ``data_list`` rather than caching every batch in memory.
        """
        while True:
            produced = False
            for batch in self.process_data(data_list):
                produced = True
                yield batch
            if not produced:
                return

    def get_streams(self):
        """Return an iterator over ``self.length`` batches (none if the data is empty)."""
        return islice(self.get_stream(self.create_dataset(self.data_list)), self.length)

    def __iter__(self):
        return self.get_streams()

In [None]:
# Smoke-test the iterable dataset on a single training shard.
filenames = ['gs://imagenet-jt/train/train-00000-of-01024']

dataset = MyIterable_tf_Dataset(
    filenames,
    batch_size=2,
    length=4,
)

# batch_size=None turns off the DataLoader's automatic batching; the tf.data
# pipeline inside the dataset already batches.
# NOTE(review): num_workers=1 runs the TF pipeline in a forked worker process;
# TensorFlow is not generally fork-safe, so try num_workers=0 if this hangs.
loader = DataLoader(dataset, batch_size=None, num_workers=1)

In [None]:
# Bound the loop to `dataset.length` items: the dataset's stream repeats its
# input, so an unbounded `for data in loader` may never finish.
for data in islice(loader, dataset.length):
    print(data)