In [None]:
%%capture
!pip install datasets
!pip install git+https://github.com/huggingface/transformers.git
!pip install flax
!pip install git+https://github.com/deepmind/optax.git

In [None]:
!pip install -q datasets

In [None]:
import jax
# Show the devices JAX sees locally, to confirm which backend the notebook runs on.
jax.local_devices()

# Beans dataset

In [None]:
from transformers import FlaxResNetModel, AutoImageProcessor
from PIL import Image
import requests
from flax.training import train_state
import optax
import jax.numpy as jnp
from datasets import load_dataset

In [None]:
# Load the 'beans' image dataset through Hugging Face `datasets`;
# the cells below read its 'train' split.
ds = load_dataset('beans')

In [None]:
# Inspect one raw training record.
ds['train'][200]

In [None]:
# Display only the image of that record.
ds['train'][200]['image']

In [None]:
# Feature descriptor of the 'labels' column; used in the next cell to turn an
# integer label id into its class name.
labels = ds['train'].features['labels']
print(labels)

In [None]:
# Class name corresponding to the integer label of training record 200.
labels.int2str(ds['train'][200]['labels'])

In [None]:
from transformers.utils.dummy_vision_objects import ImageGPTFeatureExtractor
import random
from PIL import ImageDraw, ImageFont, Image

In [None]:
def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):
  """Build one image showing a few shuffled training examples for every class.

  The result is a grid with one row per class and `examples_per_class`
  columns; each cell is an example resized to `size` with the class name
  drawn in white at its top-left corner.

  Args:
    ds: Dataset dict with a 'train' split. The split must expose
      `features['labels'].names` and rows with 'image' and 'labels' entries.
    seed: Seed passed to `shuffle` when picking the examples of each class.
    examples_per_class: Number of columns, i.e. examples drawn per class.
      If a class has fewer examples, its remaining cells are left blank.
    size: `(width, height)` of a single grid cell in pixels.

  Returns:
    The assembled RGB grid image.
  """
  w, h = size
  train = ds['train']
  labels = train.features['labels'].names
  grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
  draw = ImageDraw.Draw(grid)
  try:
    font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)
  except OSError:
    # The Liberation font is not installed everywhere; fall back to Pillow's
    # built-in font instead of failing.
    font = ImageFont.load_default()

  for label_id, label in enumerate(labels):
    matches = train.filter(lambda ex: ex['labels'] == label_id)
    # Never request more examples than the class actually has.
    n_examples = min(examples_per_class, len(matches))
    ds_slice = matches.shuffle(seed).select(range(n_examples))
    for i, example in enumerate(ds_slice):
      # Column = position within the class, row = class index.
      box = (i * w, label_id * h)
      grid.paste(example['image'].resize(size), box=box)
      draw.text(box, label, (255, 255, 255), font=font)
  return grid

In [None]:
# Render the grid with a randomly drawn seed, so each run can show different examples.
show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)

In [None]:
# Image processor associated with the 'microsoft/resnet-50' checkpoint;
# `transform` below uses it to preprocess images.
image_processor = AutoImageProcessor.from_pretrained('microsoft/resnet-50')

In [None]:
# Take the image of training record 200 as a single worked example.
image_example = ds['train'][200]['image']

In [None]:
# Run the processor on that one image to see what it produces.
inputs_example = image_processor(image_example)

In [None]:
# Names of the fields returned by the processor.
inputs_example.keys()

While we could call `ds.map` and apply this to every example at once, this can be very slow, especially on a larger dataset. Instead, we'll attach a transform to the dataset with `ds.with_transform`. Transforms are applied lazily, only to the examples you actually index.

In [None]:
def transform(example_batch):
  """Preprocess a batch of examples for the model.

  Passes the batch's 'image' entries, as a list, through the module-level
  `image_processor` and attaches the batch's 'labels' unchanged to the result.

  Args:
    example_batch: Mapping with an 'image' sequence and a 'labels' entry.

  Returns:
    The processor output with an added 'labels' entry.
  """
  images = list(example_batch['image'])
  encoded = image_processor(images)
  encoded['labels'] = example_batch['labels']
  return encoded

In [None]:
# Attach `transform` to the dataset; as noted above, it runs only on the
# examples that are indexed rather than on the whole dataset up front.
prepared_ds = ds.with_transform(transform)

In [None]:
# dataset = load_dataset('cifar100')

In [None]:
# image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
# preprocessed_dataset = dataset.map(lambda x: image_processor(x['img']), batched=True)

In [None]:
# train_dataset, test_dataset = dataset['train'], dataset['test']
# print(f'train: {len(train_dataset)}, test: {len(test_dataset)}')

In [None]:
# model = FlaxResNetModel.from_pretrained('microsoft/resnet-50')

In [None]:
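# # ResNet parameters
# # (`jax.tree_map` has been removed from recent JAX releases; use `jax.tree_util.tree_map`)
# jax.tree_util.tree_map(lambda x: x.shape, model.params)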

In [None]:
# state = train_state.TrainState.create(
#     apply_fn=model.__call__,
#     params=model.params,
#     tx=optax.adam(1e-3),
# )