In [1]:
import torch
import utils as ut
import numpy as np
import pandas as pd
import torch.nn as nn
from torch import optim
from time import time
from torch.utils.tensorboard import SummaryWriter
from rumex_model import RumexNet
from trainer import train, validate
from rumex_dataset import RumexDataset, train_loader, test_loader

# Learning-rate pairs noted per backbone; the shufflenet row matches the
# (base_lr, max_lr) values assigned below.
# mnasnet: 1e-3, 1e-2
# shufflenet: 5e-3, 5e-2
# mobilenet: 1e-4, 7e-3
# densenet: 1e-4, 1e-3
# resnet: 1e-4, 1e-3

# Seed the torch and NumPy generators before any dataset, loader, or model is
# built so that repeated Restart & Run All executions start from the same
# random state.
# NOTE(review): if the dataset/augmentation code draws from Python's `random`
# module, that generator needs seeding as well — confirm in rumex_dataset.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_dir = '/u/21/hiremas1/unix/postdoc/rumex/data256_for_training/'
model_name = 'shufflenet'
base_lr = 5e-3   # lower bound of the cyclic learning-rate schedule
max_lr = 5e-2    # upper bound of the cyclic learning-rate schedule
log_dir = 'logs/'+model_name
bs = 32          # batch size handed to both loaders


# Training split (train_flag=True) and validation split (train_flag=False),
# each wrapped in its project-provided loader.
dstr = RumexDataset(data_dir+'train/', train_flag=True)
dltr = train_loader(dstr, bs)
dsva = RumexDataset(data_dir+'valid/', train_flag=False)
dlva = test_loader(dsva, bs)

# Count samples labelled 1 and 0 in each split.
n1tr = dltr.dataset.rumex.targets.count(1)
n0tr = dltr.dataset.rumex.targets.count(0)
n1va = dlva.dataset.rumex.targets.count(1)
n0va = dlva.dataset.rumex.targets.count(0)

# Share of label-0 samples (train, then validation): the accuracy obtained by
# always predicting class 0, useful as a baseline for the validation accuracy.
print(n0tr/(n0tr+n1tr))
print(n0va/(n0va+n1va))

0.674591381872214
0.6727941176470589


In [2]:
# Build the model, loss, optimizer and cyclic learning-rate schedule, then run
# the train/validate loop with best-checkpoint saving and TensorBoard logging.
model = RumexNet(model_name)
# reduction="none" makes the criterion return one loss value per sample.
# NOTE(review): presumably train()/validate() reduce it themselves — confirm.
loss_fn = nn.CrossEntropyLoss(reduction="none")
optimizer = torch.optim.Adam(model.parameters(), lr=base_lr, weight_decay=1e-2)
# cycle_momentum=False is required here: Adam has no `momentum` parameter for
# CyclicLR to cycle.
scheduler = optim.lr_scheduler.CyclicLR(optimizer,
                                        step_size_up=500,
                                        cycle_momentum=False,
                                        base_lr=base_lr,
                                        max_lr=max_lr)
writer = SummaryWriter(log_dir=log_dir)

best_val_loss = np.inf
best_val_acc = 0.5  # not read anywhere in this cell
num_epochs = 20
# One row per epoch; columns:
#   0 = training loss, 1 = validation loss, 2 = validation accuracy,
#   3 = validation F1,  4 = validation AUC
history = np.zeros((num_epochs, 5))

try:
    # range() gives plain Python ints, so the epoch number stored in the
    # checkpoint is not a NumPy scalar.
    for ep in range(num_epochs):
        start = time()

        #### fit model ##########
        loss = train(model, dltr, optimizer, scheduler, loss_fn, device)

        ##### Model Validation ##########
        predictions, metrics = validate(model, dlva, loss_fn, device)

        history[ep, 0] = loss             # training loss
        history[ep, 1] = metrics["loss"]  # validation loss
        history[ep, 2] = metrics["acc"]   # validation accuracy
        history[ep, 3] = metrics["f1"]    # validation F1
        history[ep, 4] = metrics["auc"]   # validation AUC

        ##### checkpoint saving and logging ##########
        # Keep only the checkpoint with the lowest validation loss seen so far.
        if metrics['loss'] < best_val_loss:
            best_val_loss = metrics['loss']
            ckpt_dict = {'ep': ep,
                         'state_dict': model.state_dict(),
                         'optim_dict': optimizer.state_dict(),
                         'predictions': predictions,
                         'metrics': metrics}
            ut.save_ckpt(ckpt_dict, log_dir)

        # TensorBoard logging: training loss plus every validation metric.
        writer.add_scalar('train/loss', loss, ep)
        for key in metrics.keys():
            name = 'val/'+key
            writer.add_scalar(name, metrics[key], ep)

        et = time() - start
        # NOTE(review): the 're' label prints metrics['pre'] and the 'pre'
        # label prints metrics['recall']. In the recorded run, 're' is 0 and
        # 'pre' is nan while accuracy equals the majority-class share, which is
        # how recall and precision behave when nothing is predicted positive.
        # The keys returned by validate() therefore look swapped — confirm in
        # trainer.validate before changing these labels.
        print(f"ep:{ep}|et:{et:.3f}|loss_tr:{loss:.5f}|loss: {metrics['loss']:.5f}" +
              f"|acc:{metrics['acc']:.5f}|re:{metrics['pre']:.5f}" +
              f"|pre:{metrics['recall']:.5f}|f1:{metrics['f1']:.5f}|auc:{metrics['auc']:.5f}")
finally:
    # Flush and close the event file even if training is interrupted, so the
    # scalars logged so far are written to disk.
    writer.close()

np.save(log_dir+"/history.npy", history)

ep:0|et:2.747|loss_tr:0.58490|loss: 0.71094|acc:0.67279|re:0.00000|pre:nan|f1:nan|auc:0.59141
ep:1|et:1.547|loss_tr:0.51152|loss: 0.64708|acc:0.67279|re:0.00000|pre:nan|f1:nan|auc:0.60142
ep:2|et:1.557|loss_tr:0.44758|loss: 0.63946|acc:0.67279|re:0.00000|pre:nan|f1:nan|auc:0.60428
ep:3|et:1.706|loss_tr:0.46137|loss: 0.63226|acc:0.67279|re:0.00000|pre:nan|f1:nan|auc:0.54686
ep:4|et:1.691|loss_tr:0.48438|loss: 0.63873|acc:0.67279|re:0.00000|pre:nan|f1:nan|auc:0.67087
ep:5|et:1.642|loss_tr:0.46617|loss: 0.69702|acc:0.67279|re:0.00562|pre:0.50000|f1:0.01111|auc:0.78299
ep:6|et:1.539|loss_tr:0.49755|loss: 0.65622|acc:0.73713|re:0.23596|pre:0.85714|f1:0.37004|auc:0.72536
ep:7|et:1.534|loss_tr:0.46115|loss: 0.65489|acc:0.65441|re:0.01685|pre:0.18750|f1:0.03093|auc:0.64511
ep:8|et:1.571|loss_tr:0.48928|loss: 0.59682|acc:0.77022|re:0.48315|pre:0.72269|f1:0.57912|auc:0.78650
ep:9|et:1.538|loss_tr:0.45521|loss: 1.55901|acc:0.68750|re:0.11798|pre:0.61765|f1:0.19811|auc:0.66278
ep:10|et:1.569|loss_

In [None]:
import matplotlib.pyplot as plt

# Plot the per-epoch curves stored in `history` by the training loop:
# column 0 = training loss, column 1 = validation loss,
# column 2 = validation accuracy.
fig, ax = plt.subplots()
ax.plot(history[:, 0], label="train loss")
ax.plot(history[:, 1], label="val loss")
ax.plot(history[:, 2], label="val acc")
ax.set_xlabel("epoch")
ax.set_ylabel("value")
ax.set_title("Training history")
ax.legend();  # trailing ';' suppresses the Legend repr in the cell output