In [None]:
import copy
import os
import sys

# Make the local `stacked_hourglass` package importable: os.path.dirname(file_path)
# is '../src', which is appended to the module search path before the imports below.
file_path = '../src/stacked_hourglass'
sys.path.append(os.path.dirname(file_path))

# https://stackoverflow.com/questions/42212810/tqdm-in-jupyter-notebook-prints-new-progress-bars-repeatedly
from tqdm.notebook import trange, tqdm 

import torch
from torch.nn import DataParallel
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from stacked_hourglass import hg1, hg2, hg8
from stacked_hourglass.utils.logger import Logger
from stacked_hourglass.datasets.csv import CSV
from stacked_hourglass.utils.misc import save_checkpoint, adjust_learning_rate
from stacked_hourglass.csv_train import do_training_epoch, do_validation_epoch

In [None]:
# ---------------------------------------------------------------------------
# Run configuration. Everything a reader might want to tune lives in this cell.
# ---------------------------------------------------------------------------

# Label file and image folder handed to the CSV dataset.
csv_path = '/home/wanglab/Data/image_labels.csv'
data_folder = '/home/wanglab/Data/Licking_Data'

# Directory for checkpoints (created at the end of this cell).
checkpoint = 'checkpoint'

# Passed to the dataset as `inp_res`.
input_shape = (256, 256)

# Model architecture name; one of 'hg1', 'hg2', 'hg8'.
arch = 'hg2'

# Batch sizes. The original author's note says this setting was chosen for the
# hg2 model on a 12 GB GPU; lower it if a smaller card runs out of memory.
# NOTE(review): the original note also suggested 16 for an 8 GB card, which is
# the same value used here -- confirm which figure is stale.
train_batch = 16
test_batch = 16

# Number of DataLoader worker processes.
workers = 2

# Learning rate given to the optimizer.
lr = 5e-4

# The training loop iterates over range(start_epoch, epochs).
start_epoch = 0
epochs = 10

# `snapshot` and `best_acc` are only referenced by currently commented-out code.
snapshot = 0
best_acc = 0

# Use the current CUDA device when one is available; otherwise fall back to the
# CPU instead of raising, so the notebook can still be run on machines
# without a GPU.
if torch.cuda.is_available():
    device = torch.device('cuda', torch.cuda.current_device())
else:
    device = torch.device('cpu')

# Disable gradient calculations by default.
# NOTE(review): training can only update weights if do_training_epoch re-enables
# gradients internally -- confirm against its implementation.
torch.set_grad_enabled(False)

# create checkpoint dir
os.makedirs(checkpoint, exist_ok=True)

# Map each supported architecture name to its constructor. All three are built
# with the same keyword arguments, so a lookup table replaces the if/elif chain.
model_builders = {'hg1': hg1, 'hg2': hg2, 'hg8': hg8}

if arch not in model_builders:
    # ValueError is a subclass of Exception, so existing handlers still catch it.
    raise ValueError('unrecognised model architecture: ' + arch)

# Build the network from scratch (pretrained=False) with input_channels=1 and
# num_classes=1.
model = model_builders[arch](pretrained=False, input_channels=1, num_classes=1)

# Wrap in DataParallel and move the parameters to the configured device.
model = DataParallel(model).to(device)


# Training dataset built from the CSV label file and the image folder.
train_dataset = CSV(
    csv_path,
    data_folder,
    is_train=True,
    inp_res=input_shape,
)

# The validation dataset is an independent deep copy of the training dataset
# with its `is_train` flag switched off.
# NOTE(review): whether CSV serves different (held-out) samples when is_train is
# False cannot be determined from this notebook -- confirm, otherwise the
# validation loss is measured on the training data.
val_dataset = copy.deepcopy(train_dataset)
val_dataset.is_train = False

# Shuffled batches for training.
train_loader = DataLoader(
    train_dataset,
    batch_size=train_batch,
    shuffle=True,
    num_workers=workers,
    pin_memory=True,
)

# Fixed-order batches for validation.
val_loader = DataLoader(
    val_dataset,
    batch_size=test_batch,
    shuffle=False,
    num_workers=workers,
    pin_memory=True,
)

# Adam over all model parameters, using the learning rate from the config cell.
optimizer = Adam(model.parameters(), lr=lr)

# Report the real number of samples. Multiplying the number of batches by the
# batch size over-counts whenever the last batch is only partially filled.
print('The total size of the training set is ', len(train_dataset))
print('The total size of the validation set is ', len(val_dataset))

In [None]:
# train and eval
# Uncomment to override the epoch range from the config cell:
#start_epoch = 10
#epochs = 40

# TensorBoard writer for the loss curves logged by this cell.
writer = SummaryWriter()

try:
    for epoch in trange(start_epoch, epochs, desc='Overall', ascii=True):

        # train for one epoch
        train_loss = do_training_epoch(train_loader, model, device, optimizer)

        # evaluate on validation set
        # (`predictions` is only consumed by the commented-out checkpoint call below)
        valid_loss, predictions = do_validation_epoch(val_loader, model, device)

        # print metrics
        tqdm.write(f'[{epoch + 1:3d}/{epochs:3d}] lr={lr:0.2e} '
                   f'train_loss={train_loss:0.4f} '
                   f'valid_loss={valid_loss:0.4f} ')

        # append logger file
        #logger.append([epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])
        #logger.plot_to_file(os.path.join(checkpoint, 'log.svg'), ['Train Acc', 'Val Acc'])

        # Log both losses against the epoch index.
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/test', valid_loss, epoch)

        # remember best acc and save checkpoint
        # NOTE(review): checkpointing is disabled, so the trained weights exist
        # only in kernel memory after this cell finishes.
        #save_checkpoint({
        #    'epoch': epoch + 1,
        #    'arch': arch,
        #    'state_dict': model.state_dict(),
        #    'optimizer' : optimizer.state_dict(),
        #}, predictions, checkpoint=checkpoint, snapshot=snapshot)
finally:
    # Always flush and close the event file, even if an epoch raises or the
    # run is interrupted, so the logged scalars are not lost.
    writer.close()

In [None]:
# Fetch the item at index 1 of the training dataset and hand it to pcolor.
# NOTE(review): no import of `pcolor` is visible in this notebook -- presumably
# it is matplotlib's pcolor, brought in via %pylab or an earlier session;
# confirm, since this cell would fail on a fresh kernel otherwise.
sample = train_dataset[1]
pcolor(sample)