In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms

from src.data import CIFAR10
from src.models import LeNet
from src.training.config import TrainingConfig
from src.training import ModelTrainer

In [2]:
# Tunable settings for the run, kept together so a reader can change them in one place.
SEED = 42
BATCH_SIZE = 100
LEARNING_RATE = 0.0001

# Seed PyTorch's random number generators before the model is constructed in the next
# cell, so torch-driven randomness is repeatable across Restart & Run All.
# NOTE(review): some GPU kernels can remain nondeterministic even with a fixed seed.
torch.manual_seed(SEED)

# Bundle the batch size and learning rate into the config object handed to ModelTrainer.
training_config = TrainingConfig(batch_size=BATCH_SIZE, lr=LEARNING_RATE)

In [3]:
# Build the network using LeNet's default constructor arguments; this instance is the
# one passed to ModelTrainer further down.
model = LeNet()

In [4]:
# Preprocessing pipeline applied to each image by both dataset splits:
# resize to 32x32, convert to a tensor, then normalise with mean 0.5 and std 0.5.
# NOTE(review): the mean/std tuples have a single entry; if the images are 3-channel,
# confirm the installed torchvision broadcasts it across channels as intended.
IMAGE_SIZE = (32, 32)
preprocessing_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# Both splits share the same storage location and preprocessing; only `train` differs.
# The recorded output shows the CIFAR-10 archive being downloaded into ./data.
dataset_kwargs = dict(root='./data', transform=transform)
training_data = CIFAR10(train=True, **dataset_kwargs)
testing_data = CIFAR10(train=False, **dataset_kwargs)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


82.4%

In [None]:
# Display the number of samples in each split as a (training, testing) tuple.
tuple(len(split) for split in (training_data, testing_data))

In [None]:
# Request the GPU only when one is actually available, so the notebook also runs on
# CPU-only machines instead of asking for CUDA unconditionally.
use_cuda = torch.cuda.is_available()

# Assemble the trainer from the model, optimiser, loss, config and data.
# Note that the optimiser is passed as the class `optim.Adam` (not an instance), while
# the loss is passed as an already-constructed `nn.CrossEntropyLoss` object.
# The test split is supplied as the validation data.
model_trainer = ModelTrainer(
    model,
    optim.Adam,
    nn.CrossEntropyLoss(),
    training_config,
    training_data,
    cuda=use_cuda,
    validation_data=testing_data,
)

In [None]:
# Run training and unpack the two metric objects it returns (training first, then
# testing); their `.losses` and `.accuracies` attributes are plotted in the cells below.
# NOTE(review): the positional arguments (10, True) are presumably the number of epochs
# and a boolean option — confirm against ModelTrainer.train and prefer keyword arguments.
training_metrics, testing_metrics = model_trainer.train(10, True)

In [None]:
# Loss curves for both splits on shared axes, drawn through the explicit Axes interface
# and given a title and y label so the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(training_metrics.losses, label='training loss')
ax.plot(testing_metrics.losses, label='testing loss')
# NOTE(review): the x axis is the position within each `losses` sequence; add an x label
# (e.g. epoch) once the recording granularity of ModelTrainer.train is confirmed.
ax.set(title='Training vs. testing loss', ylabel='loss')
ax.grid()
# The trailing semicolon keeps the Legend repr out of the cell output.
ax.legend();

In [None]:
# Accuracy curves for both splits on shared axes, drawn through the explicit Axes
# interface and given a title and y label so the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(training_metrics.accuracies, label='training accuracy')
ax.plot(testing_metrics.accuracies, label='testing accuracy')
# NOTE(review): the x axis is the position within each `accuracies` sequence; add an
# x label (e.g. epoch) once the recording granularity of ModelTrainer.train is confirmed.
ax.set(title='Training vs. testing accuracy', ylabel='accuracy')
ax.grid()
# The trailing semicolon keeps the Legend repr out of the cell output.
ax.legend();