In [62]:
%load_ext autoreload
%autoreload 2
import datasets
from datasets import inspect_dataset
import torch
from torch.utils.data import DataLoader
from torchvision.transforms.v2 import Compose, Normalize, ToImage, ToDtype
from fmnist_models import MLP, CNN, train
from transformers import AutoImageProcessor, TrainingArguments

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


TODO:
- Figure out how to normalize images.

In [70]:
dataset = datasets.load_dataset("fashion_mnist")
transform = Compose([ToImage(), ToDtype(torch.float32, scale=True), Normalize(mean=[0.2860], std=[0.3530])])

def transforms(examples):
    examples["pixel_values"] = [torch.tensor(transform(image)) for image in examples["image"]]
    del examples["image"]
    return examples

dataset.set_transform(transforms)

In [71]:
train_args = TrainingArguments(
    output_dir="fashion_mnist",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=128,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=128,
    num_train_epochs=10,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    use_mps_device=True,
)
trainer = train(MLP, dataset, train_args, d_in=28*28, d_hidden=28*28, d_out=10)
# trainer = train(CNN, dataset, train_args)

Total trainable parameters: 106506


  0%|          | 0/1170 [00:00<?, ?it/s]

  examples["pixel_values"] = [torch.tensor(transform(image)) for image in examples["image"]]


{'loss': 2.3019, 'grad_norm': 0.18692916631698608, 'learning_rate': 4.273504273504274e-06, 'epoch': 0.09}
{'loss': 2.3015, 'grad_norm': 0.17187148332595825, 'learning_rate': 8.547008547008548e-06, 'epoch': 0.17}
{'loss': 2.2992, 'grad_norm': 0.2027575969696045, 'learning_rate': 1.282051282051282e-05, 'epoch': 0.26}
{'loss': 2.2957, 'grad_norm': 0.19074523448944092, 'learning_rate': 1.7094017094017095e-05, 'epoch': 0.34}
{'loss': 2.2902, 'grad_norm': 0.20945096015930176, 'learning_rate': 2.1367521367521368e-05, 'epoch': 0.43}
{'loss': 2.2833, 'grad_norm': 0.23835505545139313, 'learning_rate': 2.564102564102564e-05, 'epoch': 0.51}
{'loss': 2.2731, 'grad_norm': 0.28182336688041687, 'learning_rate': 2.9914529914529915e-05, 'epoch': 0.6}
{'loss': 2.2547, 'grad_norm': 0.3428078591823578, 'learning_rate': 3.418803418803419e-05, 'epoch': 0.68}
{'loss': 2.2285, 'grad_norm': 0.4427930414676666, 'learning_rate': 3.846153846153846e-05, 'epoch': 0.77}
{'loss': 2.1879, 'grad_norm': 0.546143293380737

  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 2.0552117824554443, 'eval_accuracy': 0.489, 'eval_runtime': 1.3952, 'eval_samples_per_second': 7167.251, 'eval_steps_per_second': 56.621, 'epoch': 1.0}


  examples["pixel_values"] = [torch.tensor(transform(image)) for image in examples["image"]]


{'loss': 2.078, 'grad_norm': 0.7238411903381348, 'learning_rate': 4.985754985754986e-05, 'epoch': 1.02}
{'loss': 2.018, 'grad_norm': 0.6510031223297119, 'learning_rate': 4.938271604938271e-05, 'epoch': 1.11}
{'loss': 1.9579, 'grad_norm': 0.5378130674362183, 'learning_rate': 4.890788224121557e-05, 'epoch': 1.19}
{'loss': 1.9196, 'grad_norm': 0.4793243110179901, 'learning_rate': 4.8433048433048433e-05, 'epoch': 1.28}
{'loss': 1.8939, 'grad_norm': 0.5457066297531128, 'learning_rate': 4.7958214624881294e-05, 'epoch': 1.36}
{'loss': 1.8776, 'grad_norm': 0.6576974987983704, 'learning_rate': 4.7483380816714154e-05, 'epoch': 1.45}
{'loss': 1.857, 'grad_norm': 0.6575126051902771, 'learning_rate': 4.700854700854701e-05, 'epoch': 1.54}
{'loss': 1.8491, 'grad_norm': 0.548642098903656, 'learning_rate': 4.653371320037987e-05, 'epoch': 1.62}
{'loss': 1.831, 'grad_norm': 0.5755109190940857, 'learning_rate': 4.605887939221273e-05, 'epoch': 1.71}
{'loss': 1.8272, 'grad_norm': 0.474594384431839, 'learnin

  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 1.7909363508224487, 'eval_accuracy': 0.71, 'eval_runtime': 1.4251, 'eval_samples_per_second': 7017.165, 'eval_steps_per_second': 55.436, 'epoch': 2.0}


  examples["pixel_values"] = [torch.tensor(transform(image)) for image in examples["image"]]


{'loss': 1.791, 'grad_norm': 0.6842477321624756, 'learning_rate': 4.415954415954416e-05, 'epoch': 2.05}
{'loss': 1.7951, 'grad_norm': 0.462130606174469, 'learning_rate': 4.368471035137702e-05, 'epoch': 2.13}
{'loss': 1.7825, 'grad_norm': 0.4745323657989502, 'learning_rate': 4.3209876543209875e-05, 'epoch': 2.22}
{'loss': 1.7759, 'grad_norm': 0.6558297872543335, 'learning_rate': 4.2735042735042735e-05, 'epoch': 2.3}
{'loss': 1.7738, 'grad_norm': 0.5929506421089172, 'learning_rate': 4.2260208926875595e-05, 'epoch': 2.39}
{'loss': 1.764, 'grad_norm': 0.5352249145507812, 'learning_rate': 4.1785375118708455e-05, 'epoch': 2.47}
{'loss': 1.7702, 'grad_norm': 0.5660427212715149, 'learning_rate': 4.131054131054131e-05, 'epoch': 2.56}
{'loss': 1.7617, 'grad_norm': 0.49604523181915283, 'learning_rate': 4.083570750237417e-05, 'epoch': 2.64}
{'loss': 1.7565, 'grad_norm': 0.5551466941833496, 'learning_rate': 4.036087369420703e-05, 'epoch': 2.73}
{'loss': 1.7567, 'grad_norm': 0.5086965560913086, 'lea

  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 1.7427347898483276, 'eval_accuracy': 0.7415, 'eval_runtime': 1.4065, 'eval_samples_per_second': 7110.089, 'eval_steps_per_second': 56.17, 'epoch': 2.99}


  examples["pixel_values"] = [torch.tensor(transform(image)) for image in examples["image"]]


{'loss': 1.7432, 'grad_norm': 0.4468410313129425, 'learning_rate': 3.846153846153846e-05, 'epoch': 3.07}
{'loss': 1.7463, 'grad_norm': 0.6533463597297668, 'learning_rate': 3.798670465337132e-05, 'epoch': 3.16}
{'loss': 1.7481, 'grad_norm': 0.4890691041946411, 'learning_rate': 3.7511870845204176e-05, 'epoch': 3.24}
{'loss': 1.7433, 'grad_norm': 0.4925867021083832, 'learning_rate': 3.7037037037037037e-05, 'epoch': 3.33}
{'loss': 1.7355, 'grad_norm': 0.6657668948173523, 'learning_rate': 3.65622032288699e-05, 'epoch': 3.41}
{'loss': 1.739, 'grad_norm': 0.5627389550209045, 'learning_rate': 3.608736942070276e-05, 'epoch': 3.5}
{'loss': 1.7444, 'grad_norm': 0.4726216495037079, 'learning_rate': 3.561253561253561e-05, 'epoch': 3.58}
{'loss': 1.7362, 'grad_norm': 0.4759889543056488, 'learning_rate': 3.513770180436847e-05, 'epoch': 3.67}
{'loss': 1.728, 'grad_norm': 0.5613358616828918, 'learning_rate': 3.466286799620133e-05, 'epoch': 3.75}
{'loss': 1.7304, 'grad_norm': 0.4005991816520691, 'learni

  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 1.7229762077331543, 'eval_accuracy': 0.7537, 'eval_runtime': 1.3839, 'eval_samples_per_second': 7225.899, 'eval_steps_per_second': 57.085, 'epoch': 4.0}
{'loss': 1.7186, 'grad_norm': 0.47856467962265015, 'learning_rate': 3.323836657169991e-05, 'epoch': 4.01}


  examples["pixel_values"] = [torch.tensor(transform(image)) for image in examples["image"]]


{'loss': 1.727, 'grad_norm': 0.5356930494308472, 'learning_rate': 3.2763532763532764e-05, 'epoch': 4.09}
{'loss': 1.7254, 'grad_norm': 0.5568737387657166, 'learning_rate': 3.2288698955365625e-05, 'epoch': 4.18}
{'loss': 1.7138, 'grad_norm': 0.5096388459205627, 'learning_rate': 3.181386514719848e-05, 'epoch': 4.26}
{'loss': 1.7268, 'grad_norm': 0.5396839380264282, 'learning_rate': 3.133903133903134e-05, 'epoch': 4.35}
{'loss': 1.7189, 'grad_norm': 0.48535335063934326, 'learning_rate': 3.08641975308642e-05, 'epoch': 4.43}
{'loss': 1.727, 'grad_norm': 0.5055564641952515, 'learning_rate': 3.0389363722697055e-05, 'epoch': 4.52}
{'loss': 1.7257, 'grad_norm': 0.6019004583358765, 'learning_rate': 2.9914529914529915e-05, 'epoch': 4.61}
{'loss': 1.7097, 'grad_norm': 0.4594396948814392, 'learning_rate': 2.9439696106362775e-05, 'epoch': 4.69}
{'loss': 1.7151, 'grad_norm': 0.5734356045722961, 'learning_rate': 2.8964862298195632e-05, 'epoch': 4.78}
{'loss': 1.7175, 'grad_norm': 0.5260999202728271, '

  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 1.7110164165496826, 'eval_accuracy': 0.7644, 'eval_runtime': 1.4225, 'eval_samples_per_second': 7029.773, 'eval_steps_per_second': 55.535, 'epoch': 5.0}


  examples["pixel_values"] = [torch.tensor(transform(image)) for image in examples["image"]]


{'loss': 1.7147, 'grad_norm': 0.4632071852684021, 'learning_rate': 2.754036087369421e-05, 'epoch': 5.03}
{'loss': 1.7124, 'grad_norm': 0.5628951191902161, 'learning_rate': 2.706552706552707e-05, 'epoch': 5.12}
{'loss': 1.7155, 'grad_norm': 0.548023521900177, 'learning_rate': 2.6590693257359926e-05, 'epoch': 5.2}
{'loss': 1.7111, 'grad_norm': 0.6006879210472107, 'learning_rate': 2.611585944919278e-05, 'epoch': 5.29}
{'loss': 1.707, 'grad_norm': 0.6153220534324646, 'learning_rate': 2.564102564102564e-05, 'epoch': 5.37}
{'loss': 1.7125, 'grad_norm': 0.4452708065509796, 'learning_rate': 2.51661918328585e-05, 'epoch': 5.46}
{'loss': 1.7174, 'grad_norm': 0.5169723629951477, 'learning_rate': 2.4691358024691357e-05, 'epoch': 5.54}
{'loss': 1.7113, 'grad_norm': 0.4908588230609894, 'learning_rate': 2.4216524216524217e-05, 'epoch': 5.63}
{'loss': 1.7054, 'grad_norm': 0.537900984287262, 'learning_rate': 2.3741690408357077e-05, 'epoch': 5.71}
{'loss': 1.7126, 'grad_norm': 0.5564411282539368, 'learn

  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 1.7029484510421753, 'eval_accuracy': 0.774, 'eval_runtime': 1.4201, 'eval_samples_per_second': 7041.648, 'eval_steps_per_second': 55.629, 'epoch': 6.0}


  examples["pixel_values"] = [torch.tensor(transform(image)) for image in examples["image"]]


{'loss': 1.7074, 'grad_norm': 0.4416179955005646, 'learning_rate': 2.184235517568851e-05, 'epoch': 6.06}
{'loss': 1.7096, 'grad_norm': 0.522028923034668, 'learning_rate': 2.1367521367521368e-05, 'epoch': 6.14}
{'loss': 1.7057, 'grad_norm': 0.44702380895614624, 'learning_rate': 2.0892687559354228e-05, 'epoch': 6.23}
{'loss': 1.6962, 'grad_norm': 0.4843115508556366, 'learning_rate': 2.0417853751187084e-05, 'epoch': 6.31}
{'loss': 1.6995, 'grad_norm': 0.46359020471572876, 'learning_rate': 1.9943019943019945e-05, 'epoch': 6.4}
{'loss': 1.7048, 'grad_norm': 0.5041469931602478, 'learning_rate': 1.9468186134852805e-05, 'epoch': 6.48}
{'loss': 1.7119, 'grad_norm': 0.6114515662193298, 'learning_rate': 1.899335232668566e-05, 'epoch': 6.57}
{'loss': 1.6986, 'grad_norm': 0.6555495858192444, 'learning_rate': 1.8518518518518518e-05, 'epoch': 6.65}
{'loss': 1.7029, 'grad_norm': 0.7611258029937744, 'learning_rate': 1.804368471035138e-05, 'epoch': 6.74}
{'loss': 1.6956, 'grad_norm': 0.502175509929657, 

  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 1.6975152492523193, 'eval_accuracy': 0.7777, 'eval_runtime': 1.4424, 'eval_samples_per_second': 6933.117, 'eval_steps_per_second': 54.772, 'epoch': 6.99}


  examples["pixel_values"] = [torch.tensor(transform(image)) for image in examples["image"]]


{'loss': 1.7, 'grad_norm': 0.5099855065345764, 'learning_rate': 1.6144349477682812e-05, 'epoch': 7.08}
{'loss': 1.7009, 'grad_norm': 0.45134732127189636, 'learning_rate': 1.566951566951567e-05, 'epoch': 7.16}
{'loss': 1.6907, 'grad_norm': 0.45179829001426697, 'learning_rate': 1.5194681861348528e-05, 'epoch': 7.25}
{'loss': 1.6967, 'grad_norm': 0.47423309087753296, 'learning_rate': 1.4719848053181388e-05, 'epoch': 7.33}
{'loss': 1.7089, 'grad_norm': 0.4582500159740448, 'learning_rate': 1.4245014245014246e-05, 'epoch': 7.42}
{'loss': 1.7035, 'grad_norm': 0.503612756729126, 'learning_rate': 1.3770180436847105e-05, 'epoch': 7.51}
{'loss': 1.6977, 'grad_norm': 0.4390881061553955, 'learning_rate': 1.3295346628679963e-05, 'epoch': 7.59}
{'loss': 1.6933, 'grad_norm': 0.504296064376831, 'learning_rate': 1.282051282051282e-05, 'epoch': 7.68}
{'loss': 1.6935, 'grad_norm': 0.458760142326355, 'learning_rate': 1.2345679012345678e-05, 'epoch': 7.76}
{'loss': 1.7014, 'grad_norm': 0.4728432893753052, '

  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 1.6938382387161255, 'eval_accuracy': 0.7809, 'eval_runtime': 1.3608, 'eval_samples_per_second': 7348.591, 'eval_steps_per_second': 58.054, 'epoch': 8.0}


  examples["pixel_values"] = [torch.tensor(transform(image)) for image in examples["image"]]


{'loss': 1.703, 'grad_norm': 0.431794673204422, 'learning_rate': 1.0921177587844255e-05, 'epoch': 8.02}
{'loss': 1.7031, 'grad_norm': 0.6875536441802979, 'learning_rate': 1.0446343779677114e-05, 'epoch': 8.1}
{'loss': 1.7016, 'grad_norm': 0.5967649221420288, 'learning_rate': 9.971509971509972e-06, 'epoch': 8.19}
{'loss': 1.7019, 'grad_norm': 0.583902895450592, 'learning_rate': 9.49667616334283e-06, 'epoch': 8.27}
{'loss': 1.6916, 'grad_norm': 0.45443660020828247, 'learning_rate': 9.02184235517569e-06, 'epoch': 8.36}
{'loss': 1.6927, 'grad_norm': 0.5751610398292542, 'learning_rate': 8.547008547008548e-06, 'epoch': 8.44}
{'loss': 1.6972, 'grad_norm': 0.5105316042900085, 'learning_rate': 8.072174738841406e-06, 'epoch': 8.53}
{'loss': 1.6929, 'grad_norm': 0.5470250248908997, 'learning_rate': 7.597340930674264e-06, 'epoch': 8.61}
{'loss': 1.7001, 'grad_norm': 0.5834125876426697, 'learning_rate': 7.122507122507123e-06, 'epoch': 8.7}
{'loss': 1.6873, 'grad_norm': 0.7156885266304016, 'learning

  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 1.6918400526046753, 'eval_accuracy': 0.7838, 'eval_runtime': 1.4409, 'eval_samples_per_second': 6940.01, 'eval_steps_per_second': 54.826, 'epoch': 9.0}


  examples["pixel_values"] = [torch.tensor(transform(image)) for image in examples["image"]]


{'loss': 1.6896, 'grad_norm': 0.5594640970230103, 'learning_rate': 5.223171889838557e-06, 'epoch': 9.04}
{'loss': 1.6955, 'grad_norm': 0.5319828987121582, 'learning_rate': 4.748338081671415e-06, 'epoch': 9.13}
{'loss': 1.6956, 'grad_norm': 0.5736562609672546, 'learning_rate': 4.273504273504274e-06, 'epoch': 9.21}
{'loss': 1.6884, 'grad_norm': 0.4775325059890747, 'learning_rate': 3.798670465337132e-06, 'epoch': 9.3}
{'loss': 1.6991, 'grad_norm': 0.47366777062416077, 'learning_rate': 3.3238366571699908e-06, 'epoch': 9.38}
{'loss': 1.6861, 'grad_norm': 0.6080653071403503, 'learning_rate': 2.8490028490028492e-06, 'epoch': 9.47}
{'loss': 1.6938, 'grad_norm': 0.45435380935668945, 'learning_rate': 2.3741690408357077e-06, 'epoch': 9.55}
{'loss': 1.6976, 'grad_norm': 0.701070249080658, 'learning_rate': 1.899335232668566e-06, 'epoch': 9.64}
{'loss': 1.6991, 'grad_norm': 0.5925618410110474, 'learning_rate': 1.4245014245014246e-06, 'epoch': 9.72}
{'loss': 1.6947, 'grad_norm': 0.5760884881019592, '

  0%|          | 0/79 [00:00<?, ?it/s]

{'eval_loss': 1.6911110877990723, 'eval_accuracy': 0.7833, 'eval_runtime': 1.3585, 'eval_samples_per_second': 7360.946, 'eval_steps_per_second': 58.151, 'epoch': 9.98}
{'train_runtime': 120.0377, 'train_samples_per_second': 4998.429, 'train_steps_per_second': 9.747, 'train_loss': 1.7856239286243407, 'epoch': 9.98}


In [46]:
trainer.state.log_history

[{'loss': 2.3022,
  'grad_norm': 0.020899033173918724,
  'learning_rate': 4.273504273504274e-06,
  'epoch': 0.09,
  'step': 10},
 {'loss': 2.3019,
  'grad_norm': 0.020860455930233,
  'learning_rate': 8.547008547008548e-06,
  'epoch': 0.17,
  'step': 20},
 {'loss': 2.301,
  'grad_norm': 0.02217506244778633,
  'learning_rate': 1.282051282051282e-05,
  'epoch': 0.26,
  'step': 30},
 {'loss': 2.2995,
  'grad_norm': 0.024014804512262344,
  'learning_rate': 1.7094017094017095e-05,
  'epoch': 0.34,
  'step': 40},
 {'loss': 2.2968,
  'grad_norm': 0.030760394409298897,
  'learning_rate': 2.1367521367521368e-05,
  'epoch': 0.43,
  'step': 50},
 {'loss': 2.2921,
  'grad_norm': 0.04679757356643677,
  'learning_rate': 2.564102564102564e-05,
  'epoch': 0.51,
  'step': 60},
 {'loss': 2.2826,
  'grad_norm': 0.07243793457746506,
  'learning_rate': 2.9914529914529915e-05,
  'epoch': 0.6,
  'step': 70},
 {'loss': 2.2612,
  'grad_norm': 0.13746219873428345,
  'learning_rate': 3.418803418803419e-05,
  'epo