In [9]:
import torch
from torch.utils.data import DataLoader, random_split

from datasets.CustomImageDataset import CustomImageDataset
from models.cnn import CnnModel
from train import train_classifier
from utils import plot_model_metrics

# --- Configuration: values a reader may want to tune -------------------------
SEED = 42                     # fixes the train/test split, weight init, dropout and shuffling
DATA_DIR = 'data/DeepHP'
CLASS_SIZE = 100_000          # passed as CustomImageDataset(class_size=...); presumably a per-class cap — TODO confirm
SPLIT_FRACTIONS = [0.7, 0.3]  # train / test
batch_size = 512              # lower-case name kept because later cells may reference it

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

# Seed torch's global RNGs so model initialisation, dropout and the training
# shuffle are repeatable on Restart & Run All.
# NOTE(review): if CustomImageDataset samples with `random`/numpy, seed those too.
torch.manual_seed(SEED)

dataset = CustomImageDataset(DATA_DIR, class_size=CLASS_SIZE)

# A dedicated generator keeps the split identical across runs regardless of how
# much of the global RNG stream the dataset constructor may have consumed.
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, test_dataset = random_split(dataset, SPLIT_FRACTIONS, generator=split_generator)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation does not need shuffling; a fixed order makes test passes repeatable.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

model = CnnModel()
print(model)


cuda:0
CnnModel(
  (model): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=valid)
    (1): ReLU()
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=valid)
    (4): ReLU()
    (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=valid)
    (7): ReLU()
    (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=valid)
    (9): ReLU()
    (10): AvgPool2d(kernel_size=2, stride=1, padding=0)
    (11): Flatten(start_dim=1, end_dim=-1)
    (12): Linear(in_features=512, out_features=256, bias=True)
    (13): ReLU()
    (14): Dropout(p=0.2, inplace=False)
    (15): Linear(in_features=256, out_features=128, bias=True)
    (16): ReLU()
    (17): Dropout(p=0.2, inplace=False)
    (18): Linear(in_features=128, out_features=1, bias=True)
  )
)


In [None]:

# Move the model's parameters onto the selected device before training.
model = model.to(device)

# Train the classifier. The trainer hands back the model and a dict of metric
# histories (the next cell reads the last entry of each "test_<metric>" list).
# checkpoint_every=10 presumably saves a checkpoint every 10 epochs — TODO confirm.
model, model_metrics = train_classifier(
    model,
    train_loader,
    test_loader,
    device,
    learning_rate=1e-3,
    weight_decay=1e-4,
    max_epochs=50,
    checkpoint_every=10,
)

In [None]:
# Report the final (last-recorded) value of each test-set metric.
for metric_name in ("accuracy", "precision", "recall", "f1"):
    final_value = model_metrics["test_" + metric_name][-1]
    print("Test " + metric_name + ":", final_value)

# Plot the recorded metric histories.
plot_model_metrics(model_metrics)

In [None]:

# Persist only the learned parameters (the state_dict), not the module object.
# NOTE(review): saved tensors keep the device they live on; on a machine without
# that device, load with torch.load(..., map_location=...).
trained_weights = model.state_dict()
torch.save(trained_weights, "./model5.bin")
