In [1]:
import os
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import config
from data import OmniMNIST
from  main import train, test
from model import SphereCNN, CNN

In [2]:
# Fix the RNG seeds up front so "Restart & Run All" reproduces the same run.
# Seeding NumPy alone is not enough: torch keeps its own generator (used e.g. by nn layer
# initialisation and DataLoader(shuffle=True)), so it is seeded too.
# NOTE(review): if data.OmniMNIST draws from Python's `random` module, seed that as well;
# bit-exact GPU runs additionally need deterministic cuDNN settings.
SEED = 23
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
# Re-import config.py so edits to it are picked up without restarting the kernel,
# then build the run options from it.
import importlib

config = importlib.reload(config)
opt = config.Config()

# Restrict which GPUs torch may use. This only takes effect if CUDA has not yet been
# initialised in this kernel.
os.environ.update(CUDA_VISIBLE_DEVICES=opt.CUDA_VISIBLE_DEVICES)

In [4]:
# One augmentation spec shared by both splits, so train and test cannot silently diverge.
# Presumably: 90-degree field of view with random horizontal and vertical rotation
# (TODO confirm semantics in data.OmniMNIST).
OMNI_AUG = dict(fov=90, h_rotate=True, v_rotate=True)

train_dataset = OmniMNIST(train=True, **OMNI_AUG)
# fix_aug=True is passed only for the test split; presumably it freezes the random
# augmentation so evaluation is comparable across epochs (TODO confirm).
test_dataset = OmniMNIST(train=False, fix_aug=True, **OMNI_AUG)

In [5]:
# Batch the datasets. Only the training split is shuffled; the loader reshuffles each
# time it is iterated. The test order stays fixed so evaluations line up.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=opt.batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=opt.test_batch_size,
    shuffle=False,
)

In [6]:
# Use the GPU only when it was requested AND torch can actually see one. Otherwise fall
# back to CPU with a visible notice, rather than crashing later at `.to(device)`.
use_cuda = bool(opt.use_gpu) and torch.cuda.is_available()
if opt.use_gpu and not use_cuda:
    print('opt.use_gpu is set but CUDA is unavailable; running on CPU instead.')
device = torch.device('cuda' if use_cuda else 'cpu')

In [7]:
# Build the two networks being compared, plus one optimizer per network configured
# with the same hyper-parameters from `opt`.
sphere_model = SphereCNN().to(device)
model = CNN().to(device)


def make_optimizer(params, cfg):
    """Create the optimizer named by ``cfg.optimizer`` over ``params``.

    Args:
        params: iterable of parameters to optimise (e.g. ``model.parameters()``).
        cfg: run options; reads ``optimizer``, ``lr`` and, for SGD, ``momentum``.

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.SGD`` instance.

    Raises:
        ValueError: if ``cfg.optimizer`` is neither 'adam' nor 'sgd'. Failing loudly
        here matters: an unknown name would otherwise leave the optimizer variables
        undefined or, on a re-run, still bound to stale optimizers from a previous
        execution that update the old models' parameters.
    """
    if cfg.optimizer == 'adam':
        return torch.optim.Adam(params, lr=cfg.lr)
    if cfg.optimizer == 'sgd':
        return torch.optim.SGD(params, lr=cfg.lr, momentum=cfg.momentum)
    raise ValueError(
        "unknown opt.optimizer {!r}; expected 'adam' or 'sgd'".format(cfg.optimizer)
    )


sphere_optimizer = make_optimizer(sphere_model.parameters(), opt)
optimizer = make_optimizer(model.parameters(), opt)

In [None]:
# Train both networks in lock-step, one epoch at a time. After every epoch each network
# is evaluated on the test split and its weights are written to its checkpoint file
# (the same path every epoch, so only the latest weights are kept).
runs = [
    # (banner label, network, its optimizer, checkpoint path)
    ('Sphere CNN', sphere_model, sphere_optimizer, 'sphere_cnn.pkl'),
    ('Conventional CNN', model, optimizer, 'cnn.pkl'),
]

for epoch in range(1, opt.epochs + 1):
    for label, net, net_optimizer, ckpt_path in runs:
        print('{} {} {}'.format('=' * 10, label, '=' * 10))
        train(opt, net, device, train_loader, net_optimizer, epoch)
        test(opt, net, device, test_loader, epoch)
        torch.save(net.state_dict(), ckpt_path)