In [None]:
import trax

# Display which fastmath backend trax is running on in this kernel.
backend = trax.fastmath.backend_name()
backend

In [None]:
from pathlib import Path

# Data directory, relative to the notebook's working directory.
datadir = Path("..") / "data"

In [None]:
# cats_vs_dogs only provides a 'train' split. trax's TFDS needs a validation or
# test split unless eval_holdout_size > 0, so hold out 10% of 'train' for
# evaluation. Keep this value identical to the one used for the eval stream so
# the two streams are disjoint slices of the same split.
train_stream = trax.data.TFDS(
    'cats_vs_dogs',
    keys=('image', 'label'),
    train=True,
    eval_holdout_size=0.1,
)()

In [None]:
# Evaluation examples are the held-out 10% tail of cats_vs_dogs' only ('train')
# split. This must match the eval_holdout_size used for the training stream.
eval_stream = trax.data.TFDS(
    'cats_vs_dogs',
    keys=('image', 'label'),
    train=False,
    eval_holdout_size=0.1,
)()

In [None]:
from trax import layers as tl
import numpy as onp  # plain NumPy for host-side preprocessing of examples

# cats_vs_dogs images come in different sizes. Batch() pads each batch to its
# own largest image, so without a fixed size every batch has a different
# spatial shape and the model's Flatten -> Dense head cannot be built
# consistently. Resize every image to IMAGE_SIZE x IMAGE_SIZE first.
IMAGE_SIZE = 224  # side length in pixels (conventional ResNet input size)


def _resize_nearest(image, size):
    """Nearest-neighbour resize of an (H, W, ...) array to (size, size, ...).

    Trailing axes (e.g. channels) and the dtype are kept; the aspect ratio is
    not preserved.
    """
    height, width = image.shape[:2]
    rows = onp.arange(size) * height // size
    cols = onp.arange(size) * width // size
    # rows[:, None] with cols broadcasts to a (size, size) index grid.
    return image[rows[:, None], cols]


def ResizeImages(size=IMAGE_SIZE):
    """trax.data stage resizing the image of every (image, label) example."""
    def _resize(generator):
        for image, label in generator:
            yield _resize_nearest(image, size), label
    return _resize


# Resize before shuffling so the shuffle buffer holds the small images.
train_data_pipeline = trax.data.Serial(
    ResizeImages(),
    trax.data.Shuffle(),
    trax.data.Batch(8),
)
train_batches_stream = train_data_pipeline(train_stream)

eval_data_pipeline = trax.data.Serial(
    ResizeImages(),
    trax.data.Batch(8),
)
eval_batches_stream = eval_data_pipeline(eval_stream)

In [None]:
# Draw one (images, labels) batch from the training stream for inspection.
example_batch = next(iter(train_batches_stream))

In [None]:
# Shapes of each component of the example batch, in (image, label) order.
batch_shapes = [part.shape for part in example_batch]
print(f'batch shape (image, label) = {batch_shapes}')

In [None]:
# Image component of the (image, label) example batch.
X = example_batch[0]
X.shape

In [None]:
# Position within the example batch of the image to visualise.
i = 2

In [None]:
%matplotlib inline
import seaborn as sns
img = X[i, :, :, 0]
sns.heatmap(img)

In [None]:
# Small CNN classifier: two Conv -> LayerNorm -> ReLU -> MaxPool stages, then a
# linear head producing one score per label value.
# NOTE: Flatten -> Dense requires every batch to have the same image size.
N_CLASSES = 2  # cats_vs_dogs is a binary task (label values 0 and 1)

model = tl.Serial(
    tl.ToFloat(),  # cast (presumably uint8) pixel arrays to float before conv

    tl.Conv(32, (3, 3), (1, 1), 'SAME'),
    tl.LayerNorm(),
    tl.Relu(),
    tl.MaxPool(),

    tl.Conv(64, (3, 3), (1, 1), 'SAME'),
    tl.LayerNorm(),
    tl.Relu(),
    tl.MaxPool(),

    tl.Flatten(),
    tl.Dense(N_CLASSES),
)

In [None]:
from trax.supervised import training

# Training: cross-entropy loss, Adam with learning rate 0.01, and a checkpoint
# every 100 steps.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.CategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=100,
)

# Evaluation: cross-entropy and accuracy computed over 20 eval batches.
eval_metrics = [tl.CategoryCrossEntropy(), tl.CategoryAccuracy()]
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=eval_metrics,
    n_eval_batches=20,
)

In [None]:
# Train the CNN for 1000 steps, evaluating with eval_task and writing outputs
# to ./cnn_model2.
training_loop = training.Loop(
    model,
    train_task,
    eval_tasks=[eval_task],
    output_dir='./cnn_model2',
)
training_loop.run(1000)

In [None]:
# Take a fresh (images, labels) batch from the eval stream for a spot check.
eval_batch = next(eval_batches_stream)
X, y = eval_batch
y

In [None]:
# Shapes of the eval images and their labels.
(X.shape, y.shape)

In [None]:
# Run the trained CNN on the eval batch: one row of class scores per image.
yhat = model.forward(X)
yhat.shape

In [None]:
import trax.fastmath.numpy as np

# yhat has shape (batch, n_classes). Take the arg-max over the class axis (the
# last one) to get one predicted label per example, directly comparable to y.
# axis=0 would instead pick, for each class, the batch index with the top score.
np.argmax(yhat, axis=-1)

In [None]:
# Ground-truth labels of the same eval batch, to compare with the predictions.
y

In [None]:
from trax.models.resnet import Resnet50

# cats_vs_dogs is a binary task, so the classifier head needs exactly 2
# outputs (one per label value), not 10.
resn = Resnet50(n_output_classes=2)

In [None]:
# Train the ResNet-50 on the same tasks for 1000 steps; outputs go to
# ./resnet50 so they do not mix with the CNN run.
training_loop = training.Loop(
    resn,
    train_task,
    eval_tasks=[eval_task],
    output_dir='./resnet50',
)
training_loop.run(1000)