# PyTorch TPU + TF Records

In [None]:
# Google Cloud project that the gcloud CLI is pointed at below.
PROJECT_ID = 'hybrid-vertex'  # <--- TODO: CHANGE THIS
# Region setting; NOTE(review): not referenced by any other cell visible in this notebook.
LOCATION = 'us-central1' 
# Notebook shell magic: make PROJECT_ID the active gcloud project.
!gcloud config set project {PROJECT_ID}

In [None]:
import sys
# The Colab auth helper is only imported and invoked when the google.colab
# module is already loaded, i.e. when this notebook runs inside Colab.
if 'google.colab' in sys.modules:
  from google.colab import auth
  auth.authenticate_user()

In [None]:
# Empty string when running inside Colab, '--user' everywhere else.
# NOTE(review): presumably meant to be appended to `pip install`; the install
# commands visible in this notebook do not reference it.
USER_FLAG = '' if 'google.colab' in sys.modules else '--user'

### pip install

In [None]:
# Install pinned PyTorch 1.11 plus the matching torch_xla wheel. The wheel
# filename is tagged cp37, i.e. it targets CPython 3.7 on linux x86_64.
!pip install cloud-tpu-client==0.10 torch==1.11.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.11-cp37-cp37m-linux_x86_64.whl

# TensorFlow is installed unpinned (used in this notebook for tf.data / TFRecord reading).
! pip install tensorflow

# Pinned Cloud Storage client, installed quietly (-q).
! pip -q install google-cloud-storage==1.44.0

# Shut the kernel down with restart=True; presumably so the freshly installed
# packages are the ones picked up by the import cell that follows.
import IPython
app = IPython.Application.instance()
app.kernel.do_shutdown(True)

### Import packages

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, IterableDataset, DataLoader

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
import torchvision
from torchvision import datasets, transforms

import os
import sys
from itertools import cycle, islice, chain, count
import random 
import time

import numpy as np

from google.cloud import storage

import warnings
warnings.filterwarnings("ignore")

### TF vs Torch Tensors

* N - batch size
* H - height of image
* W - width of image
* C - number of channels (usually 3 for RGB)

TensorFlow
* `shape=(N, H, W, C)`

PyTorch
* `torch.Size([N, C, H, W])`

[source](https://towardsdatascience.com/convert-images-to-tensors-in-pytorch-and-tensorflow-f0ab01383a03)

# Read TF Records

In [None]:
# Batch sizes for the train and test splits.
batch_size = 128
test_batch_size = 64

# Sentinel that lets tf.data choose parallelism levels at runtime.
AUTOTUNE = tf.data.AUTOTUNE

# Dataset options that turn off deterministic element ordering.
ignore_order = tf.data.Options()
ignore_order.experimental_deterministic = False

## ImageNet

In [None]:
def read_tfrecord(self, data_=None):
  """Parse one serialized ImageNet example into an ``(image, label)`` pair.

  This is a module-level function, yet its first parameter is named ``self``.
  ``dataset.map(read_tfrecord)`` calls it with a single positional argument,
  which previously failed with a missing-argument ``TypeError``. Both call
  forms are now accepted:

  * ``read_tfrecord(serialized)``        -- the ``tf.data`` map form
  * ``read_tfrecord(anything, serialized)`` -- the original two-argument form

  Args:
    self: Ignored when ``data_`` is given; otherwise treated as the
      serialized record.
    data_: Scalar string tensor holding one serialized ``tf.train.Example``.

  Returns:
    A tuple ``(image, label)`` where ``image`` has shape ``(1, 3, 128, 128)``
    (N, C, H, W layout) and ``label`` is the ``int64`` value stored under
    ``image/class/label``.
  """
  # Single-argument call: the record arrived in the first positional slot.
  if data_ is None:
    data_ = self

  features = {
      'image/class/label': tf.io.FixedLenFeature([], tf.int64),
      'image/encoded': tf.io.FixedLenFeature([], tf.string),
  }
  # Decode the TFRecord into a dict of tensors.
  tf_record = tf.io.parse_single_example(data_, features)

  # JPEG bytes --> (H, W, 3)
  image = tf.io.decode_jpeg(tf_record['image/encoded'], channels=3)

  # --> (128, 128, 3)
  image = tf.image.resize(image, [128, 128])

  # --> (1, 128, 128, 3): add a leading batch axis
  image = tf.expand_dims(image, axis=0)

  # Convert to the PyTorch layout: [N, H, W, C] --> [N, C, H, W]
  image = tf.transpose(image, perm=[0, 3, 1, 2])

  label = tf_record['image/class/label']

  return image, label

In [None]:
# One ImageNet training shard is enough for a quick look at the parsed output.
filenames = ['gs://imagenet-jt/train/train-00000-of-01024']

# Read the shard(s), relax ordering, then decode each serialized record.
dataset = (
    tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
    .with_options(ignore_order)
    .map(read_tfrecord, num_parallel_calls=AUTOTUNE)
)

# Pull a single decoded example and report its shape and label.
for image, label in dataset.take(1):
    print("Image shape {}, label={}".format(image.numpy().shape, label))

### Torch Iterable Dataset

Create a custom iterable dataset for PyTorch `DataLoader`

In [None]:
class My_TF_Torch_Dataset(IterableDataset):
    """Iterable PyTorch dataset that streams samples out of TFRecord files.

    Records are read and decoded with ``tf.data``; each decoded example is
    converted to a ``torch.Tensor`` image plus a Python ``int`` label. When
    used with a multi-worker ``DataLoader`` the file list is sharded across
    workers (worker ``i`` of ``n`` reads ``data_file_list[i::n]``).

    Args:
        data_file_list: Sequence of TFRecord file names/patterns.
        batch_size: Stored on the instance; no method of this class reads it.
        length: Value reported by ``len(dataset)``. It is supplied by the
            caller and is not checked against the actual record count.
    """

    def __init__(self, data_file_list, batch_size, length):
        super().__init__()
        self.data_file_list = data_file_list  # list of filenames
        self.batch_size = batch_size
        self.length = length

    def __len__(self):
        """Return the caller-supplied number of samples per epoch."""
        # ``len()`` passes no extra arguments, so the signature must be
        # ``__len__(self)``; an additional parameter makes ``len(ds)`` raise.
        return self.length

    def read_tfrecord(self, data_):
        """Parse one serialized example into an ``(image, label)`` pair.

        Args:
            data_: Scalar string tensor holding one serialized
                ``tf.train.Example``.

        Returns:
            ``(image, label)``: ``image`` has shape ``(1, 3, 128, 128)``
            (N, C, H, W layout); ``label`` is the ``int64`` value stored
            under ``image/class/label``.
        """
        features = {
            'image/class/label': tf.io.FixedLenFeature([], tf.int64),
            'image/encoded': tf.io.FixedLenFeature([], tf.string),
        }
        # Decode the TFRecord into a dict of tensors.
        tf_record = tf.io.parse_single_example(data_, features)

        # JPEG bytes --> (H, W, 3)
        image = tf.io.decode_jpeg(tf_record['image/encoded'], channels=3)

        # --> (128, 128, 3)
        image = tf.image.resize(image, [128, 128])

        # --> (1, 128, 128, 3): add a leading batch axis
        image = tf.expand_dims(image, axis=0)

        # Convert to the PyTorch layout: [N, H, W, C] --> [N, C, H, W]
        image = tf.transpose(image, perm=[0, 3, 1, 2])

        label = tf_record['image/class/label']

        return image, label

    def create_dataset(self, data_file_list):
        """Build the ``tf.data`` pipeline over ``data_file_list``.

        Args:
            data_file_list: File names/patterns handed to
                ``tf.data.Dataset.list_files``.

        Returns:
            A ``tf.data.Dataset`` of ``(image, label)`` pairs as produced by
            ``read_tfrecord``, with deterministic ordering disabled.
        """
        autotune = tf.data.AUTOTUNE
        ignore_order = tf.data.Options()
        ignore_order.experimental_deterministic = False

        dataset = tf.data.Dataset.list_files(data_file_list, shuffle=False)
        dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=autotune)
        dataset = dataset.with_options(ignore_order)
        dataset = dataset.map(self.read_tfrecord, num_parallel_calls=autotune)

        return dataset

    def get_torch_tensors(self, train_files):
        """Yield ``(image_tensor, label_int)`` for every record in ``train_files``.

        Args:
            train_files: File names/patterns this iterator is responsible for.

        Yields:
            ``(torch.Tensor, int)``: the decoded image converted via NumPy and
            the label as a Python ``int``.
        """
        # A worker can be handed an empty shard slice (more workers than
        # files). Yield nothing instead of asking tf.data to list zero files.
        if len(train_files) == 0:
            return

        for image, label in self.create_dataset(train_files):
            yield torch.from_numpy(image.numpy()), int(label.numpy())

    def __iter__(self):
        """Return a sample iterator over this process's share of the files."""
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is None:
            # Single-process loading: behave like worker 0 of 1 so the log
            # line below always has values to print.
            wid, num_workers = 0, 1
            worker_train_files = self.data_file_list
        else:
            wid, num_workers = worker_info.id, worker_info.num_workers
            # Strided split: worker i takes files i, i+n, i+2n, ...
            worker_train_files = self.data_file_list[wid::num_workers]

        print(f"In __iter__ : Worker_id: {wid} of {num_workers}; files: {worker_train_files}")

        return self.get_torch_tensors(worker_train_files)

In [None]:
# TFRecord shards to stream from (a single training shard here).
filenames = ['gs://imagenet-jt/train/train-00000-of-01024']

# 1252 is presumably the number of records per shard -- TODO confirm.
epoch_length = 1252 * len(filenames)

dataset = My_TF_Torch_Dataset(
    data_file_list=filenames,
    batch_size=2,
    length=epoch_length,
)

# batch_size=None turns off the DataLoader's automatic batching, so each item
# is passed through exactly as the dataset yields it.
loader = DataLoader(dataset=dataset, batch_size=None, num_workers=2)

In [None]:
# Pull the first item from the loader. The loader above is built with
# batch_size=None, so this is one (image, label) item as yielded by the
# dataset rather than a collated batch.
train_features, train_labels = next(iter(loader))

print(f"Feature batch shape: {train_features.size()}")
# print(f"Labels batch shape: {len(train_labels)}")

# Dump the raw label and image values for a visual sanity check.
print(f"Label: {train_labels}")
print(f"Image: {train_features}")

In [None]:
# Look at the first item only, then stop.
# NOTE(review): `data` looks like an (image, label) pair here, so len(data)
# would be the number of fields rather than a batch size -- confirm.
for data in loader:
  print(f"Batch size: {len(data)}")
  print(data)
  break