In [1]:
import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.utils.random as four

In [2]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import torchvision.transforms.v2 as transforms
from torchvision import tv_tensors
import matplotlib.pyplot as plt
import matplotlib.patches as plt_patches
from PIL import Image

In [3]:
import mnist_training



In [4]:
# Use the 'forkserver' start method for any worker processes spawned later.
# The start method can only be set once per process, so re-running this cell in
# a live kernel would raise "RuntimeError: context has already been set".
# Guarding on the current value keeps the cell idempotent.
if torch.multiprocessing.get_start_method(allow_none=True) != 'forkserver':
    torch.multiprocessing.set_start_method('forkserver', force=True)

# Basic Training Example on MNIST

Now we will look at an actual training script with `FiftyOneTorchDataset`.

In [5]:
# Load the MNIST dataset from the FiftyOne Dataset Zoo. As the cell output
# shows, an already-downloaded / already-loaded copy is reused.
mnist = foz.load_zoo_dataset("mnist")
# Flag the dataset as persistent (presumably so that it is kept in the FiftyOne
# database across sessions -- confirm against the FiftyOne docs).
mnist.persistent = True

Split 'train' already downloaded
Split 'test' already downloaded
Loading existing dataset 'mnist'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use


In [6]:
# Launch the FiftyOne App without automatically opening it in the cell output.
# Keep a handle on the returned session: the launch message tells the reader to
# run `session.show()`, which requires the session to be bound to a name.
session = fo.launch_app(mnist, auto=False)

# Display the session summary as the cell output.
session

Session launched. Run `session.show()` to open the App in a cell output.


Dataset:          mnist
Media type:       image
Num samples:      70000
Selected samples: 0
Selected labels:  0
Session URL:      http://localhost:5151/

Now suppose that, for training, we want to hold out a random subset of the training set as a validation set. We can easily do this with FiftyOne.

In [7]:
# Remove existing 'train' or 'validation' tags if they exist, so that re-running
# this cell starts from a clean slate instead of stacking tags.
mnist.untag_samples(['train', 'validation'])

# Create a random 90/10 split, just on the samples not tagged 'test'.
# A fixed seed makes the train/validation membership reproducible across runs;
# without it every execution would produce a different split.
not_test = mnist.match_tags('test', bool=False)
four.random_split(not_test, {'train' : 0.9, 'validation' : 0.1}, seed=51)
print(mnist.count_sample_tags())

{'train': 54000, 'test': 10000, 'validation': 6000}


In [8]:
# Build a smaller working subset: 1000 randomly chosen training samples plus
# the full 'test' and 'validation' splits.
# A fixed seed makes the 1000-sample draw reproducible across runs.
samples = mnist.match_tags('train').take(1000, seed=51).values('id')
for tag in ['test', 'validation']:
    samples += mnist.match_tags(tag).values('id')

# View over `mnist` restricted to the collected sample IDs.
subset = mnist.select(samples)

In [9]:
# Choose a device that exists on this machine instead of hardcoding 'cuda:1':
# keep the second GPU when there is one, otherwise fall back to the first GPU,
# and finally to the CPU.
# NOTE(review): assumes `mnist_training.main` accepts any torch device string,
# including 'cpu' -- confirm against mnist_training.py.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Where the trained weights are saved. A path relative to the notebook's working
# directory replaces the previous absolute path under one user's home directory,
# so the notebook runs on other machines.
# NOTE(review): assumes the notebook is executed from its own directory.
weights_path = 'mnist_weights'

# NOTE(review): the meaning of the two integer arguments (10, 10) is defined by
# `mnist_training.main` and is not visible here -- confirm before tuning them.
mnist_training.main(subset, 10, 10, device, weights_path)

  torch.nn.init.xavier_uniform(linear_head.weight)
Average Train Loss =   4.497025: 100%|██████████| 63/63 [00:01<00:00, 56.22it/s]
Average Validation Loss =   1.276664: 100%|██████████| 375/375 [00:01<00:00, 215.46it/s]


New best lost achieved : 1.2778225633303324. Saving model...


Average Train Loss =   1.999823: 100%|██████████| 63/63 [00:00<00:00, 121.53it/s]
Average Validation Loss =   0.297409: 100%|██████████| 375/375 [00:01<00:00, 301.09it/s]


New best lost achieved : 0.2974911735057831. Saving model...


Average Train Loss =   0.279180: 100%|██████████| 63/63 [00:00<00:00, 118.09it/s]
Average Validation Loss =   0.955379: 100%|██████████| 375/375 [00:01<00:00, 292.05it/s]
Average Train Loss =   0.648509: 100%|██████████| 63/63 [00:00<00:00, 112.46it/s]
Average Validation Loss =   0.205903: 100%|██████████| 375/375 [00:01<00:00, 293.48it/s]


New best lost achieved : 0.20948163237112263. Saving model...


Average Train Loss =   0.453988: 100%|██████████| 63/63 [00:00<00:00, 113.36it/s]
Average Validation Loss =   0.364783: 100%|██████████| 375/375 [00:01<00:00, 286.46it/s]
Average Train Loss =   0.099397: 100%|██████████| 63/63 [00:00<00:00, 134.04it/s]
Average Validation Loss =   0.167176: 100%|██████████| 375/375 [00:01<00:00, 289.95it/s]


New best lost achieved : 0.17174978681871048. Saving model...


Average Train Loss =   0.714478: 100%|██████████| 63/63 [00:00<00:00, 107.72it/s]
Average Validation Loss =   0.514132: 100%|██████████| 375/375 [00:01<00:00, 294.92it/s]
Average Train Loss =   0.788760: 100%|██████████| 63/63 [00:00<00:00, 114.11it/s]
Average Validation Loss =   0.155441: 100%|██████████| 375/375 [00:01<00:00, 296.51it/s]


New best lost achieved : 0.15452848815266043. Saving model...


Average Train Loss =   0.274881: 100%|██████████| 63/63 [00:00<00:00, 112.40it/s]
Average Validation Loss =   0.131823: 100%|██████████| 375/375 [00:01<00:00, 296.25it/s]


New best lost achieved : 0.13296589381527155. Saving model...


Average Train Loss =   0.020280: 100%|██████████| 63/63 [00:00<00:00, 122.89it/s]
Average Validation Loss =   0.239573: 100%|██████████| 375/375 [00:01<00:00, 299.08it/s]
Average Validation Loss =   0.119549: 100%|██████████| 625/625 [00:09<00:00, 68.25it/s]


Final Test Results:
Loss = 0.12128328085292596
              precision    recall  f1-score   support

    0 - zero       0.95      0.99      0.97       980
     1 - one       0.99      0.99      0.99      1135
     2 - two       0.95      0.97      0.96      1032
   3 - three       0.97      0.95      0.96      1010
    4 - four       0.96      0.98      0.97       982
    5 - five       0.97      0.95      0.96       892
     6 - six       0.99      0.95      0.97       958
   7 - seven       0.99      0.95      0.97      1028
   8 - eight       0.95      0.93      0.94       974
    9 - nine       0.94      0.98      0.96      1009

    accuracy                           0.96     10000
   macro avg       0.96      0.96      0.96     10000
weighted avg       0.96      0.96      0.96     10000

