In [None]:
import torch
# Sanity check before training: True when PyTorch can see a CUDA device.
# As the last expression of the cell, the boolean is displayed as the output.
torch.cuda.is_available()

In [None]:
# Launch the training script as a shell command. `--model_dir` names the
# experiment directory; the plotting cell further down reads
# experiments/base_resnet18/train.log, presumably written by main.py — confirm.
!python main.py --model_dir experiments/base_resnet18

In [None]:
import tensorflow as tf


# Input-pipeline and schedule hyperparameters (adapted from the BiT colab).
BATCH_SIZE = 64
BIGGER = 160          # side length images are resized to before the random train crop
RESIZE = 128          # final square side length of every image fed to the model
CENTRAL_FRAC = 0.875  # fraction of the image kept by the eval-time central crop

# The schedule length is specified for a reference batch size of 512 and
# rescaled so that (steps * batch size) stays constant for BATCH_SIZE.
REFERENCE_BATCH_SIZE = 512
BASE_SCHEDULE_LENGTH = 500
# Integer division: this is presumably a number of training steps, so keep it
# an int (plain `/` produced the float 4000.0, which APIs expecting an integer
# count reject). The value is unchanged whenever the division is exact.
SCHEDULE_LENGTH = BASE_SCHEDULE_LENGTH * REFERENCE_BATCH_SIZE // BATCH_SIZE

In [None]:
# Augmentation adapted from the BiT (Big Transfer) TF2 fine-tuning colab:
# https://github.com/google-research/big_transfer/blob/master/colabs/big_transfer_tf2.ipynb
def augment(image):
    """Apply training-time augmentation to a single 3-channel image.

    The image is upscaled to BIGGER x BIGGER, mirrored left/right at random,
    and then a random RESIZE x RESIZE x 3 patch is cropped out of it.

    Args:
        image: one image tensor with 3 channels.

    Returns:
        The randomly cropped RESIZE x RESIZE x 3 tensor (float-valued, because
        tf.image.resize returns float32).
    """
    crop_shape = [RESIZE, RESIZE, 3]
    upscaled = tf.image.resize(image, (BIGGER, BIGGER))
    mirrored = tf.image.random_flip_left_right(upscaled)
    return tf.image.random_crop(mirrored, crop_shape)

    
def read_tfrecord(example, train):
    """Parse one serialized TFRecord example into an (image, label) pair.

    Args:
        example: a serialized tf.train.Example holding a JPEG-encoded image
            under "image" and an int64 class id under "class".
        train: expected to be a Python bool. When truthy, the random `augment`
            pipeline is applied; otherwise a deterministic central crop
            (CENTRAL_FRAC) is taken and resized to RESIZE x RESIZE.

    Returns:
        Tuple (image, class_label): a float32 RESIZE x RESIZE x 3 image whose
        pixel values are divided by 255, and the class id cast to int32.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "class": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    decoded = tf.image.decode_jpeg(parsed["image"], channels=3)

    if not train:
        # Deterministic eval preprocessing: keep the central region, then resize.
        cropped = tf.image.central_crop(decoded, central_fraction=CENTRAL_FRAC)
        processed = tf.image.resize(cropped, (RESIZE, RESIZE))
    else:
        processed = augment(decoded)

    # Pin a static shape so downstream batching/model code sees known dims.
    processed = tf.reshape(processed, (RESIZE, RESIZE, 3))
    pixels = tf.cast(processed, tf.float32) / 255.0
    label = tf.cast(parsed["class"], tf.int32)
    return pixels, label



In [None]:
# Define optimizer and loss.

# Base learning rate, scaled linearly with BATCH_SIZE relative to a batch of 512.
lr = (1e-5 * BATCH_SIZE / 512)

# Optimizer-step boundaries at which the learning rate changes.
# NOTE(review): SCHEDULE_LENGTH is rescaled for BATCH_SIZE but these
# boundaries are not — confirm that is intended.
SCHEDULE_BOUNDARIES = [200, 300, 400]
# Piecewise-constant schedule relative to the base rate: x1, then x0.1, then
# x0.001, then x0.0001 after the last boundary — i.e. drops of 10x, 100x and
# 10x, not a uniform factor of 10 at every boundary.
lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=SCHEDULE_BOUNDARIES,
    values=[lr, lr * 0.1, lr * 0.001, lr * 0.0001],
)
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
# from_logits=True: the model's outputs are expected to be unnormalized logits;
# labels are integer class ids (sparse), matching read_tfrecord's int32 label.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [None]:
# Extract accuracy/loss from the training log and plot them.
import re

import matplotlib.pyplot as plt

filename = '/content/drive/MyDrive/CS762/experiments/base_resnet18/train.log'

# Decimal numbers such as 0.912 or 12.345.
_FLOAT_RE = re.compile(r"\d+\.\d+")


def parse_metrics_line(line):
    """Return (accuracy, loss) parsed from one metrics line of the log.

    Assumes the line ends with two decimal numbers, accuracy first and loss
    last — the layout the earlier fixed-width slicing relied on
    (TODO confirm against main.py's logging format). Matching the numbers by
    pattern instead of character offsets keeps parsing correct when a value's
    printed width changes (e.g. a loss of 12.345).

    Raises:
        ValueError: if the line contains fewer than two decimal numbers.
    """
    numbers = _FLOAT_RE.findall(line)
    if len(numbers) < 2:
        raise ValueError(f"expected at least two numbers in metrics line: {line!r}")
    return float(numbers[-2]), float(numbers[-1])


train_acc = []
loss = []
eval_acc = []

with open(filename) as log_file:
    for raw_line in log_file:
        line = raw_line.rstrip()
        if 'Train metrics' in line:
            acc, train_loss = parse_metrics_line(line)
            train_acc.append(acc)
            loss.append(train_loss)
        if 'Eval metrics' in line:
            eval_acc.append(parse_metrics_line(line)[0])

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))
ax_acc.plot(train_acc, label='train')
ax_acc.plot(eval_acc, label='eval')
ax_acc.set(title='Accuracy', xlabel='Log entry index', ylabel='Accuracy')
ax_acc.legend()
ax_loss.plot(loss, label='train')
ax_loss.set(title='Training loss', xlabel='Log entry index', ylabel='Loss')
ax_loss.legend()
plt.show()


In [None]:
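# Optional: uncomment at the end of a Colab session to flush pending writes and
# unmount Google Drive (requires `from google.colab import drive`).
# drive.flush_and_unmount()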