In [None]:
%load_ext autoreload
%autoreload 2
from torch.optim import Adam
from dataloader import *
from torch.utils.data.dataset import random_split
from training_utils import *
from torchvision.transforms import transforms
import torch
import os
import pickle

# Use every available CPU core for intra-op parallelism. os.cpu_count() returns
# None when the count cannot be determined, so fall back to a single thread
# rather than passing None to torch.
torch.set_num_threads(os.cpu_count() or 1)

Parse the dataset and split it into training (80%), validation (10%) and test (the remaining ~10%) sets.

In [None]:
trashset = TrashNetDataset("trashnet/data/dataset-resized")
# 80% train / 10% validation / remainder test (absorbs rounding from int()).
valsize = int(len(trashset) * 0.1)
trainsize = int(len(trashset) * 0.8)
testsize = len(trashset) - valsize - trainsize
torch.manual_seed(0) # Ensure dataset is randomly split the same way each time
train_dataset, val_dataset, test_dataset = random_split(trashset, [trainsize, valsize, testsize])

# Reshuffle the training data every epoch so batches differ between epochs;
# evaluation loaders keep a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32)
val_dataloader = DataLoader(val_dataset, batch_size=32)

Create the transform pipeline used for data augmentation (random vertical and horizontal flips).

In [None]:
# Data-augmentation pipeline passed to train_model and used by the preview cells.
# ToPILImage converts the raw sample (presumably a tensor/ndarray from
# TrashNetDataset — confirm) so the PIL-based random flips can run, and ToTensor
# converts the flipped image back to a tensor.
# Disabled augmentations, kept for experimentation:
#   transforms.Pad(300, padding_mode='reflect')
#   transforms.RandomRotation(20, expand=True)
transform = transforms.Compose(
    [transforms.ToPILImage()]
    + [transforms.RandomVerticalFlip(p=0.5), transforms.RandomHorizontalFlip(p=0.5)]
    + [transforms.ToTensor()]
)

Visualize the effect of the transform on a single sample (optional).

In [None]:
def _show_with_shape(caption, image, label):
    """Print `caption` followed by the image's shape, then plot the image."""
    print(caption, image.shape)
    plot(image, label)

# Compare the first sample before and after augmentation; the flips are random,
# so re-running this cell can show different results.
x, y = trashset[0]
_show_with_shape("Before Transform Shape:", x, y)
xT = transform(x)
_show_with_shape("After Transform Shape:", xT, y)

Visualize one batch of augmented training samples (optional).

In [None]:
# Plot every sample of one training batch after augmentation, one figure each.
x, y = next(iter(train_dataloader))
# NOTE(review): the return value is ignored, so transform_batch presumably
# augments `x` in place — confirm in training_utils.
transform_batch(x, transform)
for image, label in zip(x, y):
    plot(image, label.item())

Initialize the model and optimizer.

In [None]:
# Replace with desired network
# NOTE(review): newer torchvision versions deprecate `pretrained=True` in favour
# of `weights=...`; kept here for compatibility with the installed version.
from torchvision.models.resnet import *
model = resnet18(pretrained=True)

# Freeze the pretrained backbone so that initially only the new head is trained.
for param in model.parameters():
    param.requires_grad = False

# New classification head sized to the dataset's classes (CLASSES presumably
# comes from the dataloader star import — confirm).
model.fc = torch.nn.Linear(model.fc.in_features, len(CLASSES), bias=True)
# requires_grad_() marks the head's parameters as trainable; assigning
# `model.fc.requires_grad = True` would only set a plain attribute on the module.
model.fc.requires_grad_(True)
# Device is stored on the model itself (presumably read by training_utils — confirm).
model.device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(model.device)
opt = Adam(model.parameters(), lr=0.001)
# NOTE(review): not passed to train_model in this notebook; confirm whether it is used.
loss = torch.nn.CrossEntropyLoss()


# Running training history; the training cell appends to these on every run.
total_loss = []
total_acc = []
total_learning_rate = []
total_batch_size = []
total_train_full = []
num_epochs = 0
epochs_per_run = 1

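Train the model and fine-tune it. Re-run the training cell for more epochs; run the cell after it to unfreeze all layers and change hyperparameters between runs.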

In [None]:
# Run this cell to start a new training epoch
# (trains for `epochs_per_run` epochs and appends to the running history).
losses, acc = train_model(model, train_dataloader, val_dataloader, opt, epochs_per_run, transform=transform)
# Count the epochs only once train_model has returned, so a failed or
# interrupted run does not inflate the counter.
num_epochs += epochs_per_run
total_loss += losses
total_acc += acc
# Record the hyperparameters in effect for this run, one entry per element of `acc`.
# param_groups holds the learning rate actually in use; `defaults` only keeps
# the value given to the constructor.
total_learning_rate += [opt.param_groups[0]['lr']] * len(acc)
total_batch_size += [train_dataloader.batch_size] * len(acc)
# True only when every parameter is trainable, i.e. the backbone is unfrozen.
train_full = all(param.requires_grad for param in model.parameters())
total_train_full += [train_full] * len(acc)

fig, axes = plt.subplots(1, 4, figsize=(20, 3))
panels = [
    (total_loss, "Training Loss"),
    (total_acc, "Validation Accuracy"),
    (total_batch_size, "Batch Size"),
    # Learning rates differ by orders of magnitude, so they are plotted as log10.
    (torch.log10(torch.tensor(total_learning_rate)), "log10(Learning Rate)"),
]
for ax, (series, ylabel) in zip(axes, panels):
    ax.plot(series)
    ax.set_xlabel("Iteration")
    ax.set_ylabel(ylabel)
plt.show()

In [None]:
# Run this cell to fine tune hyperparameters in between epochs:
# unfreeze every layer, lower the learning rate and shrink the batch size.
# Unfreeze first so the new optimizer is clearly built for the full network.
model.requires_grad_(True)
# A fresh Adam instance also resets the moment estimates from earlier runs.
opt = Adam(model.parameters(), lr=0.0001)
# Reshuffle the training data every epoch.
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

Evaluate on the held-out test set.

In [None]:
# Score the current model once on the held-out test split.
# NOTE(review): this rebinds the module-level `acc` that the training cell also uses.
acc = evaluate(model, test_dataloader)
print("Accuracy:", acc)

Save the weights and training metadata. Set `model_name` before running this cell; an existing model with the same name is overwritten.

In [None]:
model_name = "test"

# Everything for this model is written under models/<model_name>/.
model_dir = os.path.join("models", model_name)
os.makedirs(model_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(model_dir, "weights.checkpoint"))
training_metadata = {"total_loss" : total_loss, "total_acc" : total_acc, "total_learning_rate" : total_learning_rate, "total_batch_size" : total_batch_size, "total_train_full" : total_train_full, "num_epochs" : num_epochs}
# `with` closes the file even if pickling fails.
# Only load this pickle back from trusted sources: unpickling can run arbitrary code.
with open(os.path.join(model_dir, "training_metadata.pickle"), "wb") as pickle_out:
    pickle.dump(training_metadata, pickle_out)