In [1]:
import os
import sys
import pickle
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms

from PIL import Image

from model import MyModel
from utils import score, load_checkpoint, reset, count_parameters

In [2]:
# Locations of the serialized train / validation splits (NumPy .npz archives).
train_data_fp, val_data_fp = 'train.npz', 'val.npz'

# Run on the GPU whenever CUDA is usable; otherwise stay on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Number of samples per mini-batch.
batch_size = 64

In [None]:
class FMnistDataset(Dataset):
    """Map-style dataset over 28x28 single-channel images stored in an ``.npz`` archive.

    The archive must provide two arrays: ``"data"`` (one image per entry, each
    reshapeable to 28x28) and ``"labels"`` (one label per entry, convertible
    to ``int``).

    Args:
        npz_fp: Path to the ``.npz`` archive.
        transform: Optional callable applied to each PIL image before it is
            returned. ``None`` means the raw PIL image is returned.

    Raises:
        ValueError: If ``"data"`` and ``"labels"`` differ in length.
    """

    def __init__(self, npz_fp, transform=None):
        # NOTE(security): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code. Only point this at trusted archives;
        # if the arrays are plain numeric, allow_pickle=False would be safer.
        with np.load(npz_fp, allow_pickle=True) as archive:
            # Indexing materializes the arrays, so they remain valid after the
            # archive is closed at the end of the ``with`` block.
            self.data = archive["data"]
            self.labels = archive["labels"]

        # Fail fast: a mismatch would otherwise surface as an IndexError in the
        # middle of an epoch (or as silently ignored labels).
        if len(self.data) != len(self.labels):
            raise ValueError(
                f"data/labels length mismatch in {npz_fp!r}: "
                f"{len(self.data)} images vs {len(self.labels)} labels"
            )

        self.transform = transform

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image is a PIL image built from the 28x28 ``uint8`` pixel grid, or
        whatever ``transform`` returns for it when a transform was supplied.
        The label is a plain Python ``int``.
        """
        pixels = self.data[idx].astype("uint8").reshape((28, 28))
        label = int(self.labels[idx])

        image = Image.fromarray(pixels)

        # Explicit None check: a callable that defines __len__/__bool__ could be
        # falsy and would otherwise be skipped by a bare truthiness test.
        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [None]:
# Normalization statistics shared by both splits, applied after ToTensor().
# NOTE(review): presumably the training-set mean/std on the [0, 1] scale — confirm.
NORM_MEAN = [0.286]
NORM_STD = [0.353]

# Training preprocessing: PIL image -> tensor -> normalized.
train_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
)

train_dataset = FMnistDataset(train_data_fp, transform=train_transforms)
# shuffle=True: draw a fresh sample order every epoch so mini-batches are not
# identical from epoch to epoch (and not dominated by any ordering present in
# the archive). Seed torch (torch.manual_seed) beforehand if a reproducible
# order is required.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)

# Validation preprocessing mirrors training so the model sees identically
# scaled inputs; kept as a separate object so the two can diverge later.
val_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
)

val_dataset = FMnistDataset(val_data_fp, transform=val_transforms)
# Validation order does not affect the metrics, so keep it deterministic.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8)

In [None]:
def train(model, sample):
    model.train()

    