In [None]:
import os,sys,copy
file_path = '../src/stacked_hourglass'
sys.path.append(os.path.dirname(file_path))

# https://stackoverflow.com/questions/42212810/tqdm-in-jupyter-notebook-prints-new-progress-bars-repeatedly
from tqdm.notebook import trange, tqdm 

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn import DataParallel
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter

from stacked_hourglass import hg1, hg2, hg8
from stacked_hourglass.utils.logger import Logger
from stacked_hourglass.datasets.csv import CSV
from stacked_hourglass.utils.misc import save_checkpoint, adjust_learning_rate
from stacked_hourglass.csv_train import do_training_epoch, do_validation_epoch

In [None]:
# Paths: all data lives under one root so the notebook can be pointed elsewhere
# by changing a single value.
data_root = '/home/wanglab/Data'
csv_path = os.path.join(data_root, 'image_labels_shuffled.csv')
data_folder = os.path.join(data_root, 'Licking_Data')
checkpoint = 'checkpoint_csv'
input_shape = (256, 256)
arch = 'hg2'
input_channels = 1  # channels of the input images fed to the network built below
num_classes = 1     # number of output heatmaps of the network built below

# hg2: works on the author's 12 GB GPU; 16 is also the suggested size for 8 GB cards.
train_batch = 16
test_batch = 16

workers = 2

lr = 5e-4

start_epoch = 0
epochs = 150
snapshot = 0

best_f1 = 0

# Use the current CUDA device when one exists; otherwise fall back to the CPU
# so the notebook still runs (slowly) instead of crashing.
if torch.cuda.is_available():
    device = torch.device('cuda', torch.cuda.current_device())
else:
    device = torch.device('cpu')

# Disable gradient calculations by default.
# NOTE(review): the training helper is presumably expected to re-enable them
# for its own forward/backward pass — confirm in csv_train.do_training_epoch.
torch.set_grad_enabled(False)

# create checkpoint dir
os.makedirs(checkpoint, exist_ok=True)

if arch == 'hg1':
    model = hg1(pretrained=False, input_channels=input_channels, num_classes=num_classes)
elif arch == 'hg2':
    model = hg2(pretrained=False, input_channels=input_channels, num_classes=num_classes)
elif arch == 'hg8':
    model = hg8(pretrained=False, input_channels=input_channels, num_classes=num_classes)
else:
    raise ValueError('unrecognised model architecture: ' + arch)

In [None]:
## Load previous model (resume from this notebook's own checkpoint)
filename_to_load = 'checkpoint_csv/checkpoint.pth.tar'
# map_location='cpu' makes loading independent of the device that saved the
# checkpoint; the model is moved to `device` when it is wrapped for training.
loaded_checkpoint = torch.load(filename_to_load, map_location='cpu')

# Checkpoints saved from a DataParallel-wrapped model prefix every key with
# 'module.' (repeatedly, if the model was wrapped more than once). Strip all of
# them and load into the bare network, so that it is wrapped exactly once when
# training is set up instead of being double-wrapped.
prefix = 'module.'
state_dict = {}
for key, value in loaded_checkpoint['state_dict'].items():
    while key.startswith(prefix):
        key = key[len(prefix):]
    state_dict[key] = value

# Tolerate re-running this cell after `model` has already been wrapped.
target_model = model.module if isinstance(model, DataParallel) else model
target_model.load_state_dict(state_dict)


In [None]:
## Load an MPII-pretrained network and fine-tune only its output heads
# Built with the library defaults so the MPII checkpoint's layer shapes match.
if arch == 'hg1':
    model = hg1(pretrained=False)
elif arch == 'hg2':
    model = hg2(pretrained=False)
elif arch == 'hg8':
    model = hg8(pretrained=False)
else:
    raise ValueError('unrecognised model architecture: ' + arch)

filename_to_load = 'checkpoint_mpii/checkpoint.pth.tar'
# map_location='cpu': load independent of the saving device; the model is
# moved to `device` when it is wrapped for training.
loaded_checkpoint = torch.load(filename_to_load, map_location='cpu')

# Strip DataParallel 'module.' prefixes so the weights load into the bare
# network; it is wrapped for multi-GPU use (once) when training is set up.
prefix = 'module.'
state_dict = {}
for key, value in loaded_checkpoint['state_dict'].items():
    while key.startswith(prefix):
        key = key[len(prefix):]
    state_dict[key] = value
model.load_state_dict(state_dict)

# Freeze all layers, then unfreeze the fc / fc_ layers.
for param in model.parameters():
    param.requires_grad = False

for param in model.fc.parameters():
    param.requires_grad = True

for param in model.fc_.parameters():
    param.requires_grad = True

# Replace the prediction heads so the network emits `head_classes` heatmaps.
# Newly created layers are trainable. One replacement is made per existing
# layer and channel counts are copied from the layer being replaced, so the
# heads follow the chosen architecture instead of being hard-coded.
head_classes = 1
model.score = nn.ModuleList([
    nn.Conv2d(old.in_channels, head_classes, kernel_size=1, bias=True)
    for old in model.score
])
model.score_ = nn.ModuleList([
    nn.Conv2d(head_classes, old.out_channels, kernel_size=1, bias=True)
    for old in model.score_
])

# List the parameters that will actually be trained.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

In [None]:
# Wrap for multi-GPU use exactly once: the checkpoint-loading cells may
# already have wrapped the model, and wrapping twice nests the state_dict keys
# under 'module.module.', which breaks reloading the saved checkpoints.
if not isinstance(model, DataParallel):
    model = DataParallel(model)
model = model.to(device)

# NOTE(review): input_channels=3 matches the 3-channel MPII fine-tune network;
# the network built in the config cell uses input_channels=1 — confirm which
# path this dataset is meant for.
train_dataset = CSV(csv_path, data_folder, is_train=True, inp_res=input_shape,input_channels=3)
train_loader = DataLoader(
    train_dataset,
    batch_size=train_batch, shuffle=True,
    num_workers=workers, pin_memory=True
)

# Same data with is_train switched off.
# NOTE(review): assumes CSV reads `is_train` lazily to choose the validation split.
val_dataset = copy.deepcopy(train_dataset)
val_dataset.is_train = False
val_loader = DataLoader(
    val_dataset,
    batch_size=test_batch, shuffle=False,
    num_workers=workers, pin_memory=True
)

optimizer = Adam(model.parameters(), lr=lr)

writer = SummaryWriter()

# Exact sample counts: len(loader) * batch_size over-counts a partial last batch.
print('The total size of the training set is ', len(train_dataset))
print('The total size of the validation set is ', len(val_dataset))

In [None]:
# train and eval
# Epoch numbering continues after an earlier 150-epoch run.
# TODO(review): when resuming, derive start_epoch (and best_f1) from
# loaded_checkpoint instead of hard-coding them.
start_epoch = 150
epochs = 250

for epoch in trange(start_epoch, epochs, desc='Overall', ascii=True):

    # train for one epoch
    train_loss, train_f1, train_PPV, train_sens = do_training_epoch(train_loader, model, device, optimizer)

    # evaluate on validation set
    valid_loss, predictions, valid_f1, valid_PPV, valid_sens = do_validation_epoch(val_loader, model, device)

    # print metrics
    tqdm.write(f'[{epoch + 1:3d}/{epochs:3d}] lr={lr:0.2e} '
               f'train_loss={train_loss:0.4f} '
               f'valid_loss={valid_loss:0.4f} ')

    # TensorBoard scalars, logged as train/test pairs.
    for tag, value in (
        ('Loss/train', train_loss),
        ('Loss/test', valid_loss),
        ('F1/train', train_f1),
        ('F1/test', valid_f1),
        ('PPV/train', train_PPV),
        ('PPV/test', valid_PPV),
        ('Sensitivity/train', train_sens),
        ('Sensitivity/test', valid_sens),
    ):
        writer.add_scalar(tag, value, epoch)

    # Remember the best validation F1 and save a checkpoint. best_f1 itself must
    # be updated so later epochs are compared against the best so far (and the
    # stored 'best_acc' is the real best, not the initial 0).
    is_best = valid_f1 > best_f1
    best_f1 = max(valid_f1, best_f1)
    save_checkpoint({
        'epoch': epoch + 1,
        'arch': arch,
        'state_dict': model.state_dict(),
        'best_acc': best_f1,
        'optimizer' : optimizer.state_dict(),
    }, predictions, is_best, checkpoint=checkpoint, snapshot=snapshot)

# Push the final epochs' scalars to disk without waiting for the periodic flush.
writer.flush()

In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF  # only for the alternative view noted below

# Channel 0 of the second element of training sample 0.
# imshow keeps row 0 at the top, matching how the image itself is shown, whereas
# pcolor puts row 0 at the bottom. It is also far cheaper than pcolor's one
# patch per pixel.
# To view the contrast-stretched input image instead, use:
#   ax.imshow(TF.autocontrast(train_dataset[0][0])[0, :, :])
fig, ax = plt.subplots()
heatmap_im = ax.imshow(train_dataset[0][1][0, :, :])
fig.colorbar(heatmap_im, ax=ax)
ax.set(title='train_dataset[0][1], channel 0', xlabel='column', ylabel='row');

In [None]:
# Display the raw values of channel 0 of the second element of training sample 1.
sample_target = train_dataset[1][1]
sample_target[0, :, :]

In [None]:
#https://discuss.pytorch.org/t/computing-the-mean-and-std-of-dataset/34949
# !!!I should ONLY do this for training data and not test data !!!
# NOTE(review): training_split=0.0 with is_train=False looks like it selects
# every row of the CSV (i.e. validation data too) — confirm against CSV.


def channel_mean_std(loader):
    """Return the per-channel mean and standard deviation of a loader's images.

    The forum recipe linked above averages per-image stds. That is not the
    dataset std, because it ignores brightness differences between images.
    Here the sums of x and x**2 are pooled over every pixel of every image, so
    the result is the std of the whole dataset.

    Args:
        loader: iterable of batches. The first element of each batch is used as
            the image tensor, assumed (N, C, ...) — TODO confirm for CSV.

    Returns:
        (mean, std): float32 tensors with one entry per channel.

    Raises:
        ValueError: if the loader yields no images.
    """
    pixel_sum = None
    pixel_sq_sum = None
    pixel_count = 0
    for batch in loader:
        # float64 accumulation keeps the sums accurate over millions of pixels.
        images = batch[0].double()
        per_channel = images.reshape(images.size(0), images.size(1), -1)
        batch_sum = per_channel.sum(dim=(0, 2))
        batch_sq_sum = per_channel.pow(2).sum(dim=(0, 2))
        if pixel_sum is None:
            pixel_sum, pixel_sq_sum = batch_sum, batch_sq_sum
        else:
            pixel_sum = pixel_sum + batch_sum
            pixel_sq_sum = pixel_sq_sum + batch_sq_sum
        pixel_count += per_channel.size(0) * per_channel.size(2)

    if pixel_count == 0:
        raise ValueError('loader yielded no images')

    mean = pixel_sum / pixel_count
    # Var = E[x^2] - E[x]^2; clamp tiny negative rounding errors before sqrt.
    var = (pixel_sq_sum / pixel_count - mean.pow(2)).clamp(min=0)
    return mean.float(), var.sqrt().float()


mydataset = CSV(csv_path, data_folder, is_train=False, inp_res=input_shape,training_split=0.0)
# Sample order does not affect the statistics, so no shuffling is needed.
myloader = DataLoader(
    mydataset,
    batch_size=train_batch, shuffle=False,
    num_workers=workers, pin_memory=True
)

mean, std = channel_mean_std(myloader)

print('Mean: ', mean)
print('Std: ', std)