In [1]:
%load_ext autoreload
%autoreload 2
import datasets
from datasets import inspect_dataset
import torch
from torch.utils.data import DataLoader
from torchvision.transforms.v2 import Compose, ToImage, ToDtype
from fmnist_models import MLP, train
from transformers import AutoImageProcessor, TrainingArguments



In [3]:
dataset = datasets.load_dataset("fashion_mnist", data_dir="fashion_mnist")
transform = Compose([ToImage(), ToDtype(torch.float32, scale=True)])

def transforms(examples):
    examples["pixel_values"] = [torch.tensor(transform(image)) for image in examples["image"]]
    del examples["image"]
    return examples

dataset.set_transform(transforms)

In [37]:
train_args = TrainingArguments(
    output_dir="fashion_mnist",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=128,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=128,
    num_train_epochs=1000,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    use_mps_device=True,
)
train(MLP, dataset, train_args, d_in=28*28, d_hidden=28*28, d_out=10)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


  0%|          | 0/117000 [00:00<?, ?it/s]

  examples["pixel_values"] = [torch.tensor(transform(image)) for image in examples["image"]]


{'loss': 2.3026, 'grad_norm': 0.04654799401760101, 'learning_rate': 4.2735042735042736e-08, 'epoch': 0.09}
{'loss': 2.3028, 'grad_norm': 0.04594659432768822, 'learning_rate': 8.547008547008547e-08, 'epoch': 0.17}
{'loss': 2.3027, 'grad_norm': 0.046945370733737946, 'learning_rate': 1.282051282051282e-07, 'epoch': 0.26}
{'loss': 2.3026, 'grad_norm': 0.04504389315843582, 'learning_rate': 1.7094017094017095e-07, 'epoch': 0.34}
{'loss': 2.3026, 'grad_norm': 0.04321511462330818, 'learning_rate': 2.136752136752137e-07, 'epoch': 0.43}
{'loss': 2.3027, 'grad_norm': 0.04516351968050003, 'learning_rate': 2.564102564102564e-07, 'epoch': 0.51}
{'loss': 2.3026, 'grad_norm': 0.04494961351156235, 'learning_rate': 2.991452991452992e-07, 'epoch': 0.6}
{'loss': 2.3024, 'grad_norm': 0.043484680354595184, 'learning_rate': 3.418803418803419e-07, 'epoch': 0.68}
{'loss': 2.3024, 'grad_norm': 0.045083675533533096, 'learning_rate': 3.846153846153847e-07, 'epoch': 0.77}
{'loss': 2.3024, 'grad_norm': 0.0439924560

  0%|          | 0/79 [00:00<?, ?it/s]

Checkpoint destination directory fashion_mnist/checkpoint-117 already exists and is non-empty. Saving will proceed but saved results may be invalid.


{'eval_loss': 2.3020806312561035, 'eval_accuracy': 0.1275, 'eval_runtime': 1.1665, 'eval_samples_per_second': 8572.625, 'eval_steps_per_second': 67.724, 'epoch': 1.0}


  examples["pixel_values"] = [torch.tensor(transform(image)) for image in examples["image"]]


{'loss': 2.3019, 'grad_norm': 0.04503161087632179, 'learning_rate': 5.128205128205128e-07, 'epoch': 1.02}
{'loss': 2.3021, 'grad_norm': 0.04608452320098877, 'learning_rate': 5.555555555555556e-07, 'epoch': 1.11}
{'loss': 2.3019, 'grad_norm': 0.04588427394628525, 'learning_rate': 5.982905982905984e-07, 'epoch': 1.19}
{'loss': 2.3018, 'grad_norm': 0.04352067783474922, 'learning_rate': 6.41025641025641e-07, 'epoch': 1.28}
{'loss': 2.3016, 'grad_norm': 0.046033281832933426, 'learning_rate': 6.837606837606838e-07, 'epoch': 1.36}


KeyboardInterrupt: 

In [19]:
torch.flatten(torch.ones(4, 1, 28, 28), start_dim=1).shape

torch.Size([4, 784])