In [1]:
import torch
from torch.utils.data import DataLoader
from dataset.dataset import get_cdiscount_dataset
from model.model import assemble_model
from trainer.trainer import get_trainer

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

# redirect print to file
# import sys
# sys.stdout = open("PyTorch-resnet34-log.txt", "w")

In [None]:
# Run configuration: read below when building the data loaders and the model,
# then handed as a whole to get_trainer().
config = dict(
    # mini-batch sizes used for the training / validation DataLoader
    train_batch_size=256,
    val_batch_size=256,
    # name of the torchvision architecture to instantiate
    arch='resnet34',
    # optimisation settings (presumably consumed by the trainer -- TODO confirm)
    optimizer='Adam',
    learning_rate=1e-4,
    decay_lr_freq=1e4,
    weight_decay=5e-4,
    # checkpoint to resume from; None here
    resume=None,
    # epoch range
    start_epoch=0,
    epochs=10,
    # logging / validation / checkpoint intervals (units not visible here;
    # note the last two are floats, exactly as originally written)
    print_freq=10,
    validate_freq=3e4,
    save_freq=1e3,
    # initial value for the best validation precision@1
    best_val_prec1=0,
)

In [None]:
from torchvision import models

# ---- datasets ---------------------------------------------------------------
# Both splits read the same offsets table and BSON file; only the image list
# CSV differs between training and validation.
print('getting dataset...')
dataset_kwargs = dict(
    offsets_csv="train_offsets.csv",
    bson_file_path="/mnt/data/cdiscount/train.bson",
    with_label=True,
    resize=256,
)
train_dataset = get_cdiscount_dataset(images_csv="train_images.csv", **dataset_kwargs)
val_dataset = get_cdiscount_dataset(images_csv="val_images.csv", **dataset_kwargs)

# ---- data loaders -----------------------------------------------------------
# NOTE(review): the validation loader is shuffled too; left unchanged because
# the trainer's validation logic is not visible from this notebook.
print('getting data loader...')
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config['train_batch_size'],
    shuffle=True,
    num_workers=6,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=config['val_batch_size'],
    shuffle=True,
    num_workers=6,
)

# ---- model ------------------------------------------------------------------
arch_name = config['arch']
print("=> using pre-trained model '{}'".format(arch_name))
# Look the constructor up by name in the torchvision.models namespace.
backbone = vars(models)[arch_name](pretrained=True)
# assemble_model is project code; presumably 512 is the backbone feature width
# and 5270 the number of target classes -- TODO confirm against model/model.py.
model = assemble_model(backbone, -1, 512, 5270)
# Wrap in DataParallel, then move to the GPU (requires CUDA).
model = torch.nn.DataParallel(model)
model = model.cuda()

# ---- loss -------------------------------------------------------------------
criterion = torch.nn.CrossEntropyLoss()
criterion = criterion.cuda()

# ---- training ---------------------------------------------------------------
# The trainer receives both loaders, the model, the loss and the full config.
Trainer = get_trainer(train_dataloader, val_dataloader, model, criterion, config)

# Run!
Trainer.run()

getting dataset...
