In [1]:
"""
   This section of the code simply sets up all possible variables we might want to change during training.
"""
import os, sys, random, gzip, optparse
import numpy as np                     # Math and Deep Learning libraries
import torch                
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm                  # Pretty status bars
from collections import defaultdict
np.seterr(divide='ignore')             # Ignore divide by zero errors
# `np.warnings` was only the stdlib module leaking through NumPy's namespace and
# was removed in NumPy 1.24 (AttributeError); use the stdlib module directly.
import warnings
warnings.filterwarnings('ignore')      # Also silences NumPy's remaining 'invalid' (0/0) warnings

# Use a GPU when possible
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

parser = optparse.OptionParser(description='FIND')
parser.add_option('--name',       type=str,            default='')
parser.add_option('--batch-size', type=int,            default=32,    help='batch size')
parser.add_option('--epochs',     type=int,            default=100,   help='number of epochs')
parser.add_option('--hidden-dim', type=int,            default=256,   help='hidden dim')
parser.add_option('--log',        type=str,            default='./logs/')
parser.add_option('--load',       type=str,            default='',    help='Load model')
parser.add_option('--load_test',  type=str,            default='',    help='Load specific test file')
parser.add_option('--binary',     type=str,            default='',    help='Binary classifier based on argument')
parser.add_option('--lr',         type=float,          default=5e-3)
parser.add_option('--nitro',      action='store_true', default=False, help='Train/Test on nitrogenase')
parser.add_option('--bd',         action='store_true', default=False, help='Train/Test on cytochrome bd oxidase')
parser.add_option('--test',       action='store_true', default=False, help='Swap validation for test set')
parser.add_option('--confusion',  action='store_true', default=False, help='If test, generate confusion matrix')
parser.add_option('--pconfusion', action='store_true', default=False, help='sum probabilities for confusion')
parser.add_option('--reweight',   action='store_true', default=False)


# The first three argv entries are assumed to belong to the kernel launcher
# ("Remove pykernel launcher"). Everything after them is parsed. The list is
# passed explicitly because a bare parse_args() reads sys.argv[1:] and would
# silently drop one more (real) argument.
sys.argv = sys.argv[3:]  # Remove pykernel launcher
args, _ = parser.parse_args(sys.argv)

# BD is small and performs better with a low learning rate
# Example run below
args.bd = True
args.lr = 1e-3

In [2]:
###############################################################################
##    Infrastructure to process Data into Numpy Arrays of Integers
##    Here we specify the dataset (indicated by args.nitro, args.bd or default HCO)
##    Then we convert letters to numbers via the acids dictionary.
##    We also compute the set of function labels (optionally training with binary)
###############################################################################
DATA_DIR = "data/"
# Amino-acid vocabulary: character -> integer id. Id 0 (' ') is also the value
# that pad_data fills unused positions with.
acids = {' ':0, 'A':1, 'C':2, 'E':3, 'D':4, 'G':5, 'F':6, 'I':7, 'H':8, \
         'K':9, 'M':10, 'L':11, 'N':12, 'Q':13, 'P':14, 'S':15, 'R':16, \
         'T':17, 'W':18, 'V':19, 'Y':20, 'X':21 }
# Inverse vocabulary: integer id -> character (used by to_string).
ints = {index: acid for acid, index in acids.items()}

lbls = {}   # label name -> class id
ilbls = {}  # class id -> display name ("OTHER" for the negative class in binary mode)

if args.nitro:
    labels_file = "nitro.labels.txt"
elif args.bd:
    labels_file = "bd.labels.txt"
else:
    labels_file = "hco.labels.txt"

# One label name per line. Blank lines are skipped so a trailing empty line
# cannot become a phantom '' class; `with` closes the file handle.
with open(DATA_DIR + labels_file, 'r') as labels_fh:
    L = [line.strip() for line in labels_fh if line.strip()]

# For logging we will store details of our training regime in the file name
run = "binary" if args.binary else "multi"
run += ".e{}".format(args.epochs)
run += ".h{}".format(args.hidden_dim)
run += ".b{}".format(args.batch_size)
if args.reweight:
    run += ".reweight"

if not args.binary:
    # Multi-class: ids follow the order of the labels file.
    for v in L:
        lbls[v] = len(lbls)
        ilbls[lbls[v]] = v
    num_labels = len(lbls)
else:
    # One-vs-rest: the requested label is class 1, every other label is class 0.
    # A label missing from the file would otherwise give an all-"OTHER" task and
    # a later KeyError on ilbls[1].
    if args.binary not in L:
        raise ValueError("--binary label {!r} not found in {}; known labels: {}".format(
            args.binary, labels_file, ", ".join(L)))
    for v in L:
        if v == args.binary:
            lbls[v] = 1
            ilbls[lbls[v]] = v
        else:
            lbls[v] = 0
            ilbls[lbls[v]] = "OTHER"
    print(lbls, ilbls)
    num_labels = 2
    run += "." + args.binary

In [3]:
###############################################################################
##    This block introduces helper functions
##    to_int and to_string convert AAs back and forth between representations
##    sort_data and pad_data help create batches of data of a fixed length to pass
##      to the network.
###############################################################################

def to_int(seq):
    """Map an amino-acid string to a 1-D int64 array of vocabulary ids.

    '*' characters are removed first. Every remaining character must be a key
    of the module-level ``acids`` vocabulary.

    Args:
        seq: amino-acid string.
    Returns:
        np.ndarray of dtype int64 with one id per (non-'*') character.
    Raises:
        KeyError: if ``seq`` contains a character not in ``acids``. The message
            names the character, its position (after '*' removal) and the sequence.
    """
    seq = seq.replace("*", "")
    conv = []
    for pos, acid in enumerate(seq):
        if acid not in acids:
            raise KeyError("unknown residue {!r} at position {} in {}".format(acid, pos, seq))
        conv.append(acids[acid])
    # Explicit dtype keeps empty sequences integer (np.array([]) would be float64)
    # and matches the int64 matrix built by pad_data.
    return np.array(conv, dtype=np.int64)

def to_string(seq):
    """Inverse of to_int: turn an iterable of vocabulary ids back into a string.

    Raises KeyError for an id that is missing from the module-level ``ints`` table.
    """
    pieces = []
    for token in seq:
        pieces.append(ints[token])
    return "".join(pieces)

def sort_data(inputs, outputs, strs=None):
    """Co-sort parallel inputs, outputs and raw strings by input length (ascending).

    The sort is stable, so items of equal input length keep their original
    relative order. There is no secondary sort key.

    Args:
        inputs: sequences (anything with ``len``), e.g. int arrays from ``to_int``.
        outputs: labels aligned with ``inputs``.
        strs: raw strings aligned with ``inputs``. ``None`` or an empty sequence
            means "not available" and is replaced by empty strings.
    Returns:
        (sorted_inputs, sorted_outputs, sorted_strs) as three lists.
    Raises:
        ValueError: if ``outputs`` or ``strs`` differs in length from ``inputs``
            (zip would otherwise silently drop the extra items).
    """
    if strs is None or len(strs) == 0:
        strs = [""] * len(inputs)
    if not (len(inputs) == len(outputs) == len(strs)):
        raise ValueError(
            "sort_data: length mismatch (inputs={}, outputs={}, strs={})".format(
                len(inputs), len(outputs), len(strs)))

    # sorted() is stable, so equal-length items stay in input order.
    ordered = sorted(zip(inputs, outputs, strs), key=lambda item: len(item[0]))
    sorted_inputs = [item[0] for item in ordered]
    sorted_outputs = [item[1] for item in ordered]
    sorted_strs = [item[2] for item in ordered]
    return sorted_inputs, sorted_outputs, sorted_strs

def pad_data(inputs, pad_value=0):
    """Right-pad variable-length integer sequences into one 2-D int64 matrix.

    Args:
        inputs: sequence of 1-D integer arrays or lists.
        pad_value: fill value for positions past the end of a sequence.
            The default 0 is the ' ' entry of the ``acids`` vocabulary.
    Returns:
        np.ndarray of shape (len(inputs), longest length) and dtype int64.
        An empty ``inputs`` yields an array of shape (0, 0) instead of raising.
    """
    max_len = max((len(row) for row in inputs), default=0)
    padded = np.full((len(inputs), max_len), pad_value, dtype=np.int64)
    for row_idx, row in enumerate(inputs):
        # Slice assignment copies the values, so no explicit np.copy is needed.
        padded[row_idx, :len(row)] = row
    return padded

In [4]:
###############################################################################
##    Load the train split and the validation (or test) split, count labels,
##    convert sequences to integers and store everything in numpy arrays.
###############################################################################
""" Data & Parameters """
if args.nitro:
    prefix = "nitro.labeled"
elif args.bd:
    prefix = "bd.labeled"
else:
    prefix = "hco.labeled"


def _read_labeled(path):
    """Read a whitespace-separated ``<label> <sequence>`` file into [label, sequence] rows.

    Blank lines are skipped. Any other line that does not split into exactly
    two fields raises ValueError naming the file and 1-based line number.
    """
    rows = []
    with open(path, 'r') as fh:
        for line_no, line in enumerate(fh, 1):
            fields = line.strip().split()
            if not fields:
                continue
            if len(fields) != 2:
                raise ValueError("{}:{}: expected '<label> <sequence>', got {!r}".format(
                    path, line_no, line.rstrip("\n")))
            rows.append(fields)
    return rows


data = _read_labeled(DATA_DIR + prefix + ".train")
if args.test:
    val = _read_labeled(DATA_DIR + prefix + ".test")
elif args.load_test != '':
    val = _read_labeled(args.load_test)
else:
    val = _read_labeled(DATA_DIR + prefix + ".val")

strs = np.array([sequence for label, sequence in data])
inputs = [to_int(sequence) for label, sequence in data]
outputs = np.array([lbls[label] for label, sequence in data])

print("Training counts\t")
l_c = defaultdict(int)      # display label -> number of training examples
for v in outputs:
    l_c[ilbls[v]] += 1
V = sorted(((l_c[v], v) for v in l_c), reverse=True)
print("    ".join(["{}: {}".format(lbl, cnt) for cnt, lbl in V]))

# Per-class training counts, indexed by class id.
count = np.zeros(len(ilbls), dtype=np.float32)
for v in range(len(ilbls)):
    if ilbls[v] in l_c:
        count[v] += l_c[ilbls[v]]

# Inverse-frequency class weights: total / (n_classes * count_c). A class with
# no training examples would get inf, so it gets weight 0 instead; it never
# occurs as a training target anyway.
distr = np.zeros_like(count)
seen = count > 0
distr[seen] = np.sum(count) / (np.size(count) * count[seen])
weight = torch.from_numpy(100*distr).to(device)

inps, outs, strs = sort_data(inputs, outputs, strs)
outs = np.array(outs)
strs = np.array(strs)

t_strs = np.array([sequence for label, sequence in val])
t_inps = [to_int(sequence) for label, sequence in val]
t_outs = np.array([lbls[label] for label, sequence in val])
t_inps, t_outs, t_strs = sort_data(t_inps, t_outs, t_strs)
t_outs = np.array(t_outs)
t_strs = np.array(t_strs)

print("Train Inps: ", len(inputs))
print("Train Outs: ", outputs.shape)
print("Test  Inps: ", len(t_inps))
print("Test  Outs: ", t_outs.shape)
print("Labels\t",lbls)

Training counts	
E1: 373    C: 206    E2: 123    A: 47    E4: 41    B: 33    E3: 26
Train Inps:  849
Train Outs:  (849,)
Test  Inps:  117
Test  Outs:  (117,)
Labels	 {'B': 0, 'E4': 1, 'E3': 2, 'E2': 3, 'C': 4, 'E1': 5, 'A': 6}


In [5]:
###############################################################################
##    Model definition + Helper Functions
###############################################################################
class Net(nn.Module):
    """Convolutional n-gram classifier with attention pooling over positions.

    Layer stack (every Conv1d uses PyTorch defaults: stride 1, no padding):
      embedding -> [Conv(k=width), ReLU, BN] -> [Conv(k=2*width), ReLU, BN]
      -> [Conv(k=4*width), ReLU, BN] -> [Conv(k=1), ReLU] -> [Conv(k=1), ReLU]
    followed by two pointwise heads: ``pred`` (per-position class scores) and
    ``att`` (per-position attention score). The sequence-level output is the
    sum over positions of softmax(att) * pred.

    NOTE(review): padded positions (id 0) are not masked out of the attention
    softmax. Batches are length-sorted, which keeps the amount of padding small.
    """

    def __init__(self, width=3, RF=None, hidden_dim=None, vocab_size=None, n_labels=None):
        """Build the network.

        Args:
            width: base kernel size. The three wide convolutions use ``width``,
                ``2*width`` and ``4*width``.
            RF: receptive field (residues seen by one output position), used by
                aggregate_predictors to cut n-grams. ``None`` derives it from
                ``width`` as 7*width - 2, which is 19 for the default width of 3.
            hidden_dim: channel count. ``None`` uses the global ``args.hidden_dim``.
            vocab_size: number of embedding rows. ``None`` uses ``len(acids)``.
            n_labels: number of output classes. ``None`` uses the global ``num_labels``.
        """
        super().__init__()
        if hidden_dim is None:
            hidden_dim = args.hidden_dim
        if vocab_size is None:
            vocab_size = len(acids)
        if n_labels is None:
            n_labels = num_labels

        self.width = width
        # An unpadded stride-1 conv with kernel k widens the receptive field by
        # k - 1: 1 + (w - 1) + (2w - 1) + (4w - 1) = 7w - 2. The 1x1 convs add 0.
        self.RF = 7 * width - 2 if RF is None else RF

        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        layers = [
          nn.Conv1d(hidden_dim, hidden_dim, width),
          nn.ReLU(),
          nn.BatchNorm1d(hidden_dim),
          nn.Conv1d(hidden_dim, hidden_dim, width*2),
          nn.ReLU(),
          nn.BatchNorm1d(hidden_dim),
          nn.Conv1d(hidden_dim, hidden_dim, width*4),
          nn.ReLU(),
          nn.BatchNorm1d(hidden_dim),
          nn.Conv1d(hidden_dim, hidden_dim, 1),
          nn.ReLU(),
          nn.Conv1d(hidden_dim, hidden_dim, 1),
          nn.ReLU(),
        ]

        self.conv_stack = nn.Sequential(*layers)

        self.pred = nn.Conv1d(hidden_dim, n_labels, 1)
        self.att = nn.Conv1d(hidden_dim, 1, 1)

    def forward(self, x):
        """Score a batch of token-id sequences.

        Args:
            x: LongTensor of shape (batch, seq_len) holding vocabulary ids.
        Returns:
            (collapsed, att, logits):
              collapsed -- (batch, n_labels) attention-pooled class scores (unnormalised)
              att       -- (batch, 1, L') raw, un-normalised attention scores
              logits    -- (batch, n_labels, L') per-position class scores
            where L' is the number of positions left after the convolutions.
        """
        embed = self.embedding(x).permute(0, 2, 1)   # (batch, hidden, seq_len)
        embed = self.conv_stack(embed)

        # Unnormalised class scores (logits) for every n-gram position
        logits = self.pred(embed)

        # Un-normalised weight of a given n-gram, shape (batch, 1, L')
        att = self.att(embed)
        # Softmax over positions (the view keeps the (batch, 1, L') layout)
        re_att = F.softmax(att.view(x.size(0), 1, -1), dim=-1)
        # Rescale logits by attention weight
        joint = re_att * logits
        # Class distribution
        collapsed = torch.sum(joint, 2)
        return collapsed, att, logits

    def reset_counts(self, epoch):
        """Zero the per-class gold / predicted / correct tallies for a new epoch."""
        # Taken from the prediction head so it always matches this model's classes.
        n_labels = self.pred.out_channels
        self.gold_counts = np.zeros(n_labels)
        self.pred_counts = np.zeros(n_labels)
        self.corr_counts = np.zeros(n_labels)
        self.epoch = epoch

In [6]:
###############################################################################
##    Functions for evaluation and visualization
###############################################################################

def run_evaluation(net, v_inputs, v_outputs, v_strings=None, aggregate=False, verbose=False, showTrain=True):
    """Evaluate ``net`` on a labelled set and return (loss, accuracy %).

    Args:
        net: the model. It is switched to eval mode and left there.
        v_inputs: list of 1-D int arrays (as produced by ``to_int``).
        v_outputs: label ids aligned with ``v_inputs``.
        v_strings: raw sequences aligned with ``v_inputs``. Only used when
            ``aggregate`` is True. ``None`` or empty means "not available".
        aggregate: collect attention-weighted n-gram predictors into
            ``net.predictors`` / ``net.top_predictor`` via aggregate_predictors.
        verbose: print a per-class precision / recall / F1 table.
        showTrain: with ``verbose`` and not ``aggregate``, print the training
            tallies stored on ``net`` next to the evaluation tallies.
    Returns:
        (val_loss, accuracy). val_loss is the sum of per-batch mean
        cross-entropy; accuracy is a percentage.

    Side effects: resets ``net.top_predictor`` and ``net.predictors``. When
    ``args.confusion`` is set, writes ``confusion.csv``.
    """
    net.train(mode=False)
    net.top_predictor = []
    net.predictors = defaultdict(int)

    val_loss = 0.0
    val_acc = []
    gold_counts = np.zeros(num_labels)
    pred_counts = np.zeros(num_labels)
    corr_counts = np.zeros(num_labels)

    if args.confusion:
        pairs = np.zeros((num_labels, num_labels))   # pairs[gold, predicted]

    v_inps, v_outs, v_strs = sort_data(v_inputs, v_outputs, [] if v_strings is None else v_strings)
    v_outs = np.array(v_outs)
    v_strs = np.array(v_strs)
    # Consecutive (start, size) slices over the length-sorted data.
    batches = [(start, min(args.batch_size, len(v_inps) - start))
               for start in range(0, len(v_inps), args.batch_size)]

    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for start, b_size in tqdm(batches, ncols=80):
            vals = list(range(start, start + b_size))

            inputs = torch.from_numpy(pad_data(v_inps[start:start + b_size])).to(device)
            labels = torch.from_numpy(v_outs[vals]).to(device)
            logits, att, full = net(inputs)
            att = F.softmax(torch.squeeze(att), dim=-1)
            val_loss += F.cross_entropy(logits, labels).item()
            _, preds = torch.max(logits, 1)

            preds = preds.cpu().numpy()
            gold = v_outs[vals]
            hits = preds == gold
            val_acc.extend(list(hits))

            np.add.at(pred_counts, preds, 1)
            np.add.at(gold_counts, gold, 1)
            np.add.at(corr_counts, preds[hits], 1)

            if args.confusion:
                if not args.pconfusion:
                    # The index must be a tuple (rows, cols). Modern NumPy treats
                    # a *list* of arrays as a single index into the first axis,
                    # which would increment whole rows instead of single cells.
                    np.add.at(pairs, (gold, preds), 1)
                else:
                    # NOTE(review): despite the option's help text ("sum
                    # probabilities"), this tallies the runner-up class per example.
                    dists = F.softmax(logits, -1)
                    for i in range(len(vals)):
                        ranked = sorted((dists[i, j].item(), j) for j in range(len(ilbls)))
                        _, second = ranked[-2]
                        pairs[gold[i], second] += 1

            if aggregate:
                aggregate_predictors(net, v_strs[vals], gold, full, att)

    if verbose:
        if aggregate or not showTrain:
            print_eval((gold_counts, pred_counts, corr_counts))
        else:
            print_eval((net.gold_counts, net.pred_counts, net.corr_counts),
                       (gold_counts, pred_counts, corr_counts))

    if args.confusion:
        names = [ilbls[i] for i in range(len(ilbls))]
        with open("confusion.csv", 'w') as out:
            # Header has an empty corner cell; every row has the same field count.
            out.write("," + ",".join(names) + "\n")
            for i, name in enumerate(names):
                cells = ["{}".format(pairs[i, j]) for j in range(len(names))]
                out.write(name + "," + ",".join(cells) + "\n")
    return val_loss, 100*np.array(val_acc).mean()

def aggregate_predictors(net, seqs, outs, full, att):
    """Accumulate attention-weighted class evidence for every n-gram of a batch.

    For sequence ``b`` and output position ``i``, the n-gram
    ``seqs[b][i:i + net.RF]`` is credited, per class ``c``, with
    softmax(full)[b, c, i] * att[b, i] in ``net.predictors[(ngram, c)]``.
    The highest-scoring (position, class) pair of each sequence is appended to
    ``net.top_predictor`` as (score, ngram, predicted label, gold label, sequence).

    Args:
        net: model whose ``predictors`` (defaultdict) and ``top_predictor`` (list)
            are updated in place. ``net.RF`` is the n-gram length.
        seqs: raw sequence strings of the batch.
        outs: gold label ids of the batch.
        full: per-position class logits, shape (batch, n_labels, L').
        att: attention already softmax-normalised over positions, shape
            (batch, L'), or (L',) when the caller squeezed a batch of one.
    """
    dists = F.softmax(full.permute(0, 2, 1), dim=-1)   # (batch, L', n_labels)
    if dists.shape[0] == 1:
        att = att.unsqueeze(0)  # batch size of 1 needs to be unsqueezed
    # Copy the scores to host memory once. Per-element .item() calls on a GPU
    # tensor each force a device synchronisation inside this triple loop.
    vals = (dists * att.unsqueeze(2)).detach().cpu().numpy()
    n_positions = len(att[0])
    for b in range(len(seqs)):
        max_val = -1e10
        max_predictor = "NONE"
        max_class = -1
        for i in range(n_positions):
            predictor = seqs[b][i:i + net.RF]
            scores = vals[b, i]

            for c in range(num_labels):
                # float() keeps the running sums in Python (double) precision.
                net.predictors[(predictor, c)] += float(scores[c])

            best = int(np.argmax(scores))
            cval = float(scores[best])
            if cval > max_val:
                max_val = cval
                max_predictor = predictor
                max_class = best
        net.top_predictor.append((max_val, max_predictor, ilbls[max_class], ilbls[outs[b]], seqs[b].strip()))


def print_predictors(net, epoch):
    """Write the n-gram predictor tables gathered by aggregate_predictors to gzip files.

    The file prefix encodes dataset, ``epoch``, receptive field, binary/multi,
    reweighting, hidden size and batch size. Two files are written:
      * ``<prefix>.predictors.joint.gz``: for each label, every (n-gram, score)
        pair in descending score order, with a blank line after each label.
      * ``<prefix>.top_predictors.txt.gz``: one row per evaluated sequence with
        its best n-gram, in descending score order.

    Side effect: ``net.top_predictor`` is sorted in place (descending).
    """
    start = "nitro" if args.nitro else "bd" if args.bd else "hco"
    rtype = "binary" if args.binary else "multi"
    rewht = "reweight" if args.reweight else "orig"
    fname = "{}.{}.{}.{}.{}.h{}.b{}".format(start, epoch, net.RF, rtype, rewht, args.hidden_dim, args.batch_size)

    # Group accumulated scores by label name: label -> [(score, ngram), ...]
    joint = defaultdict(list)
    for (seq, lbl), score in net.predictors.items():
        joint[ilbls[lbl]].append((score, seq))

    # `with` flushes and closes the gzip stream even on error; an unclosed gzip
    # file can be left without its trailer and be unreadable.
    with gzip.open("{}.predictors.joint.gz".format(fname), 'wt') as g:
        for lbl, vals in joint.items():
            vals.sort(reverse=True)
            for val, seq in vals:
                g.write("{:5} {:30} {}\n".format(lbl, seq, val))
            g.write("\n")

    net.top_predictor.sort(reverse=True)
    with gzip.open("{}.top_predictors.txt.gz".format(fname), 'wt') as g:
        g.write("{:10} {:30} {:5} {:5} {}\n".format("Val", "Predictor", "Pred", "Gold", "Seq"))
        for val, predictor, pred, gold, seq in net.top_predictor:
            g.write("{:10.9f} {:30} {:5} {:5} {}\n".format(val, predictor, pred, gold, seq))


def _precision_recall_f1(counts):
    """Return (gold, precision, recall, F1) arrays from a (gold, pred, corr) count triple.

    Zero denominators give nan/inf. The floating-point warnings they raise are
    silenced locally instead of relying on global NumPy/warnings settings.
    """
    gold, pred, corr = counts
    with np.errstate(divide='ignore', invalid='ignore'):
        p = corr / pred
        r = corr / gold
        f = 2 * p * r / (p + r)
    return gold, p, r, f


def print_eval(train, test=None):
    """Print a per-class precision / recall / F1 table, most frequent gold class first.

    Args:
        train: (gold, pred, corr) per-class count arrays, printed in the left columns.
        test: optional second (gold, pred, corr) triple printed alongside.
    Classes with no predictions or no gold examples show ``nan``.
    """
    gold, p, r, f = _precision_recall_f1(train)
    if test is not None:
        t_gold, t_p, t_r, t_f = _precision_recall_f1(test)

    order = sorted(((gold[lab], lab) for lab in range(num_labels)), reverse=True)
    for count, i in order:
        train_str = "{:<10} {:<5} {:5.3f} {:5.3f} {:5.3f}   ".format(ilbls[i], int(count), p[i], r[i], f[i])
        if test is not None:
            test_str = "{:<5} {:5.3f} {:5.3f} {:5.3f}".format(int(t_gold[i]), t_p[i], t_r[i], t_f[i])
        else:
            test_str = ""
        print(train_str + test_str)

In [7]:
###############################################################################
##    Training Loop
###############################################################################
#torch.cuda.manual_seed(20180119)  <-- set a value for consistency
if args.load != '':
    # NOTE: torch.load unpickles a whole nn.Module, which can execute arbitrary
    # code; only load checkpoints you trust.
    # map_location lets a checkpoint saved on a GPU be evaluated on a CPU-only
    # machine. Unpickling a full module requires weights_only=False on PyTorch
    # releases where it defaults to True; older releases lack the argument.
    import inspect
    load_kwargs = {"map_location": device}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    net = torch.load(args.load, **load_kwargs)             # Load Saved Model
    net.to(device)
    # run_evaluation writes confusion.csv itself when args.confusion is set.
    loss, acc = run_evaluation(net, t_inps, t_outs, t_strs, aggregate=args.load_test != '')
    print("Acc: {:5.3f}".format(acc))
    if args.load_test != '':
        print_predictors(net, "test")
    sys.exit()
else:
    net = Net()
    net.to(device)

optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)

"""
  Perform training
"""
# Fixed (start, size) slices over the length-sorted training set, so each batch
# holds sequences of similar length. Only the batch order is shuffled per epoch.
batches = [(start, min(args.batch_size, len(inps) - start))
           for start in range(0, len(inps), args.batch_size)]

# Best validation accuracy so far. It starts below 0 so that the first epoch
# is always checkpointed, even at 0% accuracy.
prev_acc = -1.0
# NOTE(review): range(0, epochs + 1) runs args.epochs + 1 passes (0..epochs).
for epoch in range(0, args.epochs + 1):
    epoch_batches = list(batches)
    random.shuffle(epoch_batches)

    total_loss = 0.0
    train_acc = []

    net.reset_counts(epoch)
    # run_evaluation at the end of every epoch leaves the model in eval mode.
    net.train(mode=True)
    for start, b_size in tqdm(epoch_batches, ncols=80):
        rows = list(range(start, start + b_size))

        # Setup
        optimizer.zero_grad()
        inputs = torch.from_numpy(pad_data(inps[start:start + b_size])).to(device)
        labels = torch.from_numpy(outs[rows]).to(device)

        # Predict, compute loss and update
        logits, _att, _full = net(inputs)
        if args.reweight:
            loss = F.cross_entropy(logits, labels, weight=weight)
        else:
            loss = F.cross_entropy(logits, labels)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()

        # Look at predictions
        _, preds = torch.max(logits, 1)
        preds = preds.cpu().numpy()
        gold = outs[rows]
        hits = preds == gold
        np.add.at(net.pred_counts, preds, 1)
        np.add.at(net.gold_counts, gold, 1)
        np.add.at(net.corr_counts, preds[hits], 1)

        train_acc.extend(list(hits))

    # Evaluate on validation (during training)
    val_loss, val_acc = run_evaluation(net, t_inps, t_outs, t_strs, verbose=(epoch % 10 == 0 or epoch == args.epochs))

    print("Epoch: {}  Train Loss: {:8.4f}  Acc: {:5.2f}  Val  Loss {:8.4f}  Acc: {:5.2f}".format(epoch,
          total_loss, 100*np.array(train_acc).mean(), val_loss, val_acc))

    # Save best validation model for optimal generalization
    if val_acc > prev_acc:
        prev_acc = val_acc
        pref = "nitro" if args.nitro else "bd" if args.bd else "hco"
        torch.save(net, "{}.{}.model".format(pref, run))

# Embedding matrix with an extra all-ones column appended.
# NOTE(review): presumably kept for exporting/visualising embeddings; confirm its use.
out = torch.cat((net.embedding.weight.data, torch.ones(len(acids), 1).to(device)), 1)

100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 22.98it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 88.79it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 23.28it/s]

E1         373   0.514 0.804 0.627   58    0.972 0.603 0.745
C          206   0.376 0.272 0.315   24    0.296 1.000 0.457
E2         123   0.219 0.057 0.090   17      nan 0.000   nan
A          47      nan 0.000   nan   6       nan 0.000   nan
E4         41    0.048 0.098 0.064   3       nan 0.000   nan
B          33      nan 0.000   nan   1       nan 0.000   nan
E3         26      nan 0.000   nan   8       nan 0.000   nan
Epoch: 0  Train Loss:  42.0500  Acc: 43.23  Val  Loss   5.7246  Acc: 50.43


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 24.13it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 87.46it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 23.02it/s]

Epoch: 1  Train Loss:  36.9096  Acc: 48.53  Val  Loss   3.9526  Acc: 68.38


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.86it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 85.97it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 25.10it/s]

Epoch: 2  Train Loss:  30.2535  Acc: 59.01  Val  Loss   3.9170  Acc: 58.12


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.98it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 86.78it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 23.74it/s]

Epoch: 3  Train Loss:  30.2013  Acc: 57.71  Val  Loss   3.1095  Acc: 70.09


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.99it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 85.86it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.63it/s]

Epoch: 4  Train Loss:  23.2013  Acc: 66.55  Val  Loss   2.3916  Acc: 70.94


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.84it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 86.09it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 24.04it/s]

Epoch: 5  Train Loss:  16.7196  Acc: 73.62  Val  Loss   2.2541  Acc: 77.78


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.72it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 86.81it/s]
  7%|███▎                                        | 2/27 [00:00<00:01, 19.09it/s]

Epoch: 6  Train Loss:  18.6337  Acc: 73.62  Val  Loss   1.5663  Acc: 83.76


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.63it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 85.13it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.99it/s]

Epoch: 7  Train Loss:  18.4645  Acc: 75.27  Val  Loss   2.4827  Acc: 76.92


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.62it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 87.16it/s]
  7%|███▎                                        | 2/27 [00:00<00:01, 19.21it/s]

Epoch: 8  Train Loss:  13.2933  Acc: 80.68  Val  Loss   1.2693  Acc: 83.76


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.73it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 87.27it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 25.01it/s]

Epoch: 9  Train Loss:   8.5742  Acc: 85.63  Val  Loss   2.0158  Acc: 84.62


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.75it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 85.78it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.03it/s]

E1         373   0.898 0.968 0.932   58    0.983 1.000 0.991
C          206   0.995 1.000 0.998   24    1.000 1.000 1.000
E2         123   0.704 0.772 0.736   17    0.654 1.000 0.791
A          47    0.687 0.979 0.807   6     1.000 0.667 0.800
E4         41    0.833 0.488 0.615   3     0.667 0.667 0.667
B          33    0.143 0.061 0.085   1     0.000 0.000   nan
E3         26      nan 0.000   nan   8       nan 0.000   nan
Epoch: 10  Train Loss:  11.6247  Acc: 85.98  Val  Loss   1.0039  Acc: 89.74


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.68it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 86.36it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 26.58it/s]

Epoch: 11  Train Loss:   5.0701  Acc: 92.70  Val  Loss   0.6640  Acc: 91.45


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.53it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 85.19it/s]
  7%|███▎                                        | 2/27 [00:00<00:01, 19.29it/s]

Epoch: 12  Train Loss:   4.1493  Acc: 95.29  Val  Loss   0.5270  Acc: 94.02


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.69it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 87.41it/s]
  7%|███▎                                        | 2/27 [00:00<00:01, 19.29it/s]

Epoch: 13  Train Loss:   3.9656  Acc: 93.88  Val  Loss   0.5509  Acc: 93.16


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.68it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 86.50it/s]
  7%|███▎                                        | 2/27 [00:00<00:01, 19.99it/s]

Epoch: 14  Train Loss:   2.9894  Acc: 96.94  Val  Loss   0.4651  Acc: 94.02


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.50it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 84.01it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 24.84it/s]

Epoch: 15  Train Loss:   2.6244  Acc: 96.47  Val  Loss   0.3633  Acc: 92.31


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.69it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 85.94it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 24.84it/s]

Epoch: 16  Train Loss:   1.9727  Acc: 97.17  Val  Loss   0.3135  Acc: 96.58


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.13it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 85.79it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 24.29it/s]

Epoch: 17  Train Loss:   1.7215  Acc: 98.70  Val  Loss   0.9505  Acc: 91.45


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.25it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 87.52it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 24.72it/s]

Epoch: 18  Train Loss:  37.7761  Acc: 75.03  Val  Loss   6.9621  Acc: 70.09


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.48it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 85.64it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 19.91it/s]

Epoch: 19  Train Loss:  11.7668  Acc: 86.45  Val  Loss   1.2610  Acc: 88.89


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.59it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 86.84it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 24.51it/s]

E1         373   0.963 0.987 0.975   58    0.934 0.983 0.958
C          206   0.967 0.995 0.981   24    1.000 1.000 1.000
E2         123   0.842 0.951 0.893   17    1.000 0.824 0.903
A          47    0.667 0.979 0.793   6     0.750 1.000 0.857
E4         41    0.550 0.537 0.543   3     0.333 0.667 0.444
B          33    0.500 0.030 0.057   1     0.000 0.000   nan
E3         26    0.400 0.077 0.129   8     0.667 0.250 0.364
Epoch: 20  Train Loss:   7.5138  Acc: 89.63  Val  Loss   0.8394  Acc: 89.74


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.36it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 85.34it/s]
  7%|███▎                                        | 2/27 [00:00<00:01, 19.92it/s]

Epoch: 21  Train Loss:   4.8344  Acc: 93.40  Val  Loss   0.5728  Acc: 92.31


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.70it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 85.34it/s]
  7%|███▎                                        | 2/27 [00:00<00:01, 19.50it/s]

Epoch: 22  Train Loss:   3.4308  Acc: 96.00  Val  Loss   0.5690  Acc: 92.31


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.17it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 85.33it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 25.16it/s]

Epoch: 23  Train Loss:   2.6116  Acc: 96.70  Val  Loss   0.4448  Acc: 94.87


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.24it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 85.14it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 24.42it/s]

Epoch: 24  Train Loss:   2.6027  Acc: 96.11  Val  Loss   0.5734  Acc: 93.16


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.14it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 85.93it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 24.25it/s]

Epoch: 25  Train Loss:   2.2993  Acc: 97.17  Val  Loss   0.5111  Acc: 94.02


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.39it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 86.29it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 24.47it/s]

Epoch: 26  Train Loss:   1.9792  Acc: 97.88  Val  Loss   0.3555  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.21it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 84.87it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 23.63it/s]

Epoch: 27  Train Loss:   1.8016  Acc: 98.00  Val  Loss   0.5488  Acc: 94.87


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.13it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 85.91it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 25.30it/s]

Epoch: 28  Train Loss:   2.2394  Acc: 96.94  Val  Loss   0.3930  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.02it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 85.13it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 24.98it/s]

Epoch: 29  Train Loss:   1.4302  Acc: 98.59  Val  Loss   0.3406  Acc: 96.58


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 24.75it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 88.47it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 27.44it/s]

E1         373   0.984 0.992 0.988   58    0.935 1.000 0.967
C          206   1.000 1.000 1.000   24    1.000 1.000 1.000
E2         123   0.992 0.992 0.992   17    1.000 0.882 0.938
A          47    1.000 1.000 1.000   6     1.000 1.000 1.000
E4         41    1.000 0.951 0.975   3     0.750 1.000 0.857
B          33    0.943 1.000 0.971   1     0.500 1.000 0.667
E3         26    0.826 0.731 0.776   8     1.000 0.500 0.667
Epoch: 30  Train Loss:   1.4349  Acc: 98.47  Val  Loss   0.6377  Acc: 94.87


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 24.64it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 86.68it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 25.09it/s]

Epoch: 31  Train Loss:   1.0198  Acc: 99.18  Val  Loss   0.2983  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 24.75it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 86.21it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 27.11it/s]

Epoch: 32  Train Loss:   0.8379  Acc: 99.18  Val  Loss   0.5294  Acc: 94.87


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 24.28it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 86.34it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 24.56it/s]

Epoch: 33  Train Loss:   0.9769  Acc: 98.94  Val  Loss   0.3053  Acc: 96.58


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.57it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 86.22it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 26.57it/s]

Epoch: 34  Train Loss:   0.8398  Acc: 99.18  Val  Loss   0.3706  Acc: 94.87


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.61it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 86.51it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 25.33it/s]

Epoch: 35  Train Loss:   0.6447  Acc: 99.41  Val  Loss   0.3435  Acc: 96.58


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.75it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 86.90it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 24.56it/s]

Epoch: 36  Train Loss:   0.7317  Acc: 99.06  Val  Loss   0.3925  Acc: 94.87


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.29it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 85.27it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 25.30it/s]

Epoch: 37  Train Loss:   0.4079  Acc: 99.76  Val  Loss   0.3203  Acc: 96.58


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 22.88it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 83.48it/s]
  7%|███▎                                        | 2/27 [00:00<00:01, 19.71it/s]

Epoch: 38  Train Loss:   0.3628  Acc: 99.41  Val  Loss   0.5148  Acc: 94.87


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 22.89it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 84.46it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 23.81it/s]

Epoch: 39  Train Loss:   0.4450  Acc: 99.65  Val  Loss   0.3125  Acc: 97.44


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 23.10it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 85.72it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 23.56it/s]

E1         373   0.995 0.997 0.996   58    0.983 0.983 0.983
C          206   1.000 1.000 1.000   24    1.000 1.000 1.000
E2         123   1.000 0.992 0.996   17    1.000 0.941 0.970
A          47    1.000 1.000 1.000   6     1.000 0.833 0.909
E4         41    1.000 1.000 1.000   3     0.750 1.000 0.857
B          33    1.000 1.000 1.000   1     0.500 1.000 0.667
E3         26    0.962 0.962 0.962   8     0.875 0.875 0.875
Epoch: 40  Train Loss:   0.3085  Acc: 99.65  Val  Loss   0.2785  Acc: 96.58


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 22.92it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 84.49it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 24.74it/s]

Epoch: 41  Train Loss:   0.3859  Acc: 99.53  Val  Loss   0.3006  Acc: 98.29


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 22.71it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 82.13it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 20.11it/s]

Epoch: 42  Train Loss:   0.4022  Acc: 99.53  Val  Loss   0.6556  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 22.84it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 78.26it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 23.69it/s]

Epoch: 43  Train Loss:   0.9096  Acc: 98.94  Val  Loss   0.3594  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 22.31it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.73it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 25.56it/s]

Epoch: 44  Train Loss:   0.3275  Acc: 99.65  Val  Loss   0.3342  Acc: 97.44


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 22.48it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 79.07it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 19.83it/s]

Epoch: 45  Train Loss:   0.2978  Acc: 99.76  Val  Loss   0.4344  Acc: 96.58


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 22.50it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 79.18it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 24.01it/s]

Epoch: 46  Train Loss:   0.2845  Acc: 99.65  Val  Loss   0.4269  Acc: 96.58


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 21.42it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 79.19it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 24.22it/s]

Epoch: 47  Train Loss:   0.2210  Acc: 99.65  Val  Loss   0.4292  Acc: 96.58


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 21.64it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.27it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.92it/s]

Epoch: 48  Train Loss:   0.3591  Acc: 99.76  Val  Loss   0.2863  Acc: 97.44


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 22.09it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 79.03it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.80it/s]

Epoch: 49  Train Loss:   0.1859  Acc: 99.76  Val  Loss   0.5325  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 22.38it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 79.08it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 20.26it/s]

E1         373   0.997 0.995 0.996   58    0.951 1.000 0.975
C          206   1.000 1.000 1.000   24    1.000 1.000 1.000
E2         123   1.000 0.992 0.996   17    1.000 0.882 0.938
A          47    1.000 1.000 1.000   6     1.000 1.000 1.000
E4         41    1.000 1.000 1.000   3     0.750 1.000 0.857
B          33    1.000 1.000 1.000   1     0.500 1.000 0.667
E3         26    0.893 0.962 0.926   8     1.000 0.625 0.769
Epoch: 50  Train Loss:   0.2471  Acc: 99.53  Val  Loss   0.6796  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 22.64it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 80.17it/s]
  7%|███▎                                        | 2/27 [00:00<00:01, 17.93it/s]

Epoch: 51  Train Loss:   0.5513  Acc: 99.65  Val  Loss   0.3997  Acc: 96.58


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 21.65it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.11it/s]
  7%|███▎                                        | 2/27 [00:00<00:01, 17.25it/s]

Epoch: 52  Train Loss:   0.3010  Acc: 99.65  Val  Loss   0.5098  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 22.63it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 78.99it/s]
 11%|████▉                                       | 3/27 [00:00<00:00, 24.08it/s]

Epoch: 53  Train Loss:   0.2431  Acc: 99.76  Val  Loss   0.2938  Acc: 96.58


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 21.32it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.72it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 23.58it/s]

Epoch: 54  Train Loss:   0.3047  Acc: 99.65  Val  Loss   0.6367  Acc: 94.87


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 21.26it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.57it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 19.98it/s]

Epoch: 55  Train Loss:   0.1938  Acc: 99.76  Val  Loss   0.3798  Acc: 96.58


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 21.21it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 76.35it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.46it/s]

Epoch: 56  Train Loss:   0.2186  Acc: 99.65  Val  Loss   0.4571  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 21.19it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.20it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.46it/s]

Epoch: 57  Train Loss:   0.2037  Acc: 99.65  Val  Loss   0.8392  Acc: 94.02


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 21.60it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 78.49it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.87it/s]

Epoch: 58  Train Loss:   0.2701  Acc: 99.76  Val  Loss   0.4342  Acc: 97.44


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 21.40it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.27it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.90it/s]

Epoch: 59  Train Loss:   0.5540  Acc: 99.53  Val  Loss   0.4722  Acc: 97.44


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 21.99it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.36it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.29it/s]

E1         373   0.995 0.997 0.996   58    0.983 1.000 0.991
C          206   1.000 1.000 1.000   24    1.000 1.000 1.000
E2         123   1.000 0.992 0.996   17    1.000 0.882 0.938
A          47    1.000 1.000 1.000   6     1.000 0.833 0.909
E4         41    1.000 1.000 1.000   3     0.750 1.000 0.857
B          33    1.000 1.000 1.000   1     0.333 1.000 0.500
E3         26    0.962 0.962 0.962   8     1.000 0.875 0.933
Epoch: 60  Train Loss:   0.2999  Acc: 99.65  Val  Loss   0.6592  Acc: 96.58


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 22.06it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 78.64it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.81it/s]

Epoch: 61  Train Loss:   0.6550  Acc: 99.41  Val  Loss   0.6179  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 21.77it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.61it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 20.33it/s]

Epoch: 62  Train Loss:  22.3530  Acc: 87.40  Val  Loss   6.2048  Acc: 85.47


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 22.05it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 79.17it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.78it/s]

Epoch: 63  Train Loss:  29.8717  Acc: 69.61  Val  Loss   2.7103  Acc: 73.50


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 22.23it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 76.03it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.79it/s]

Epoch: 64  Train Loss:  12.6776  Acc: 82.69  Val  Loss   1.5909  Acc: 87.18


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.91it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.69it/s]
  7%|███▎                                        | 2/27 [00:00<00:01, 19.85it/s]

Epoch: 65  Train Loss:   8.2922  Acc: 89.63  Val  Loss   1.0018  Acc: 93.16


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.91it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.36it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.07it/s]

Epoch: 66  Train Loss:   5.4809  Acc: 92.93  Val  Loss   0.5921  Acc: 94.02


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.89it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 76.94it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 23.62it/s]

Epoch: 67  Train Loss:   3.6576  Acc: 94.94  Val  Loss   0.5280  Acc: 94.02


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.89it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.05it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.59it/s]

Epoch: 68  Train Loss:   2.6619  Acc: 97.29  Val  Loss   0.6660  Acc: 94.02


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.92it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 75.89it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.35it/s]

Epoch: 69  Train Loss:   2.2299  Acc: 97.88  Val  Loss   0.5126  Acc: 94.02


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.91it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.37it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.36it/s]

E1         373   0.981 0.995 0.988   58    0.905 0.983 0.942
C          206   1.000 1.000 1.000   24    1.000 1.000 1.000
E2         123   1.000 0.984 0.992   17    1.000 0.941 0.970
A          47    1.000 1.000 1.000   6     1.000 0.833 0.909
E4         41    1.000 0.976 0.988   3     0.750 1.000 0.857
B          33    0.868 1.000 0.930   1     0.000 0.000   nan
E3         26    0.895 0.654 0.756   8     0.500 0.250 0.333
Epoch: 70  Train Loss:   1.6867  Acc: 98.35  Val  Loss   0.6842  Acc: 91.45


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.84it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.46it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.11it/s]

Epoch: 71  Train Loss:   1.4390  Acc: 98.70  Val  Loss   0.4810  Acc: 97.44


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.92it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 75.95it/s]
  7%|███▎                                        | 2/27 [00:00<00:01, 16.58it/s]

Epoch: 72  Train Loss:   1.6725  Acc: 97.76  Val  Loss   0.6995  Acc: 94.02


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.86it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.63it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.42it/s]

Epoch: 73  Train Loss:   1.2896  Acc: 98.23  Val  Loss   0.7824  Acc: 94.02


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.83it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.34it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.87it/s]

Epoch: 74  Train Loss:   0.8105  Acc: 99.18  Val  Loss   0.7358  Acc: 93.16


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.82it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.39it/s]
  7%|███▎                                        | 2/27 [00:00<00:01, 17.39it/s]

Epoch: 75  Train Loss:   0.6896  Acc: 99.41  Val  Loss   0.9933  Acc: 94.02


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 21.02it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.52it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.30it/s]

Epoch: 76  Train Loss:   0.4736  Acc: 99.53  Val  Loss   0.9344  Acc: 94.02


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.88it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 79.37it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.36it/s]

Epoch: 77  Train Loss:   1.5379  Acc: 98.47  Val  Loss   1.4016  Acc: 93.16


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 21.10it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.82it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.19it/s]

Epoch: 78  Train Loss:  13.8145  Acc: 89.16  Val  Loss   1.5833  Acc: 90.60


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.87it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.82it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.12it/s]

Epoch: 79  Train Loss:   5.5561  Acc: 94.82  Val  Loss   0.8362  Acc: 93.16


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.97it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.67it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.74it/s]

E1         373   0.987 0.992 0.989   58    0.935 1.000 0.967
C          206   1.000 1.000 1.000   24    1.000 1.000 1.000
E2         123   0.992 0.976 0.984   17    1.000 1.000 1.000
A          47    0.959 1.000 0.979   6     1.000 0.833 0.909
E4         41    0.919 0.829 0.872   3     0.750 1.000 0.857
B          33    0.756 0.939 0.838   1     0.000 0.000   nan
E3         26    0.800 0.615 0.696   8     1.000 0.500 0.667
Epoch: 80  Train Loss:   2.8474  Acc: 97.06  Val  Loss   0.5603  Acc: 94.87


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.93it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 76.07it/s]
  7%|███▎                                        | 2/27 [00:00<00:01, 17.59it/s]

Epoch: 81  Train Loss:   1.8780  Acc: 97.88  Val  Loss   0.5661  Acc: 94.87


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.98it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 75.77it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.89it/s]

Epoch: 82  Train Loss:   1.2118  Acc: 98.59  Val  Loss   0.6590  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.91it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.36it/s]
  7%|███▎                                        | 2/27 [00:00<00:01, 17.50it/s]

Epoch: 83  Train Loss:   0.8665  Acc: 99.18  Val  Loss   0.7440  Acc: 93.16


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.96it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 76.54it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.32it/s]

Epoch: 84  Train Loss:   0.9579  Acc: 98.47  Val  Loss   0.7999  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.87it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.76it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 23.13it/s]

Epoch: 85  Train Loss:   0.7049  Acc: 99.41  Val  Loss   0.7685  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.93it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.70it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.44it/s]

Epoch: 86  Train Loss:   0.4351  Acc: 99.65  Val  Loss   0.8386  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.96it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 76.76it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.27it/s]

Epoch: 87  Train Loss:   0.4237  Acc: 99.41  Val  Loss   0.8310  Acc: 94.87


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.89it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.41it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.39it/s]

Epoch: 88  Train Loss:   0.5584  Acc: 99.53  Val  Loss   0.6749  Acc: 94.87


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.90it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.64it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 20.24it/s]

Epoch: 89  Train Loss:   0.4439  Acc: 99.65  Val  Loss   0.3900  Acc: 96.58


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 21.00it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.11it/s]
  7%|███▎                                        | 2/27 [00:00<00:01, 17.24it/s]

E1         373   0.995 1.000 0.997   58    0.935 1.000 0.967
C          206   1.000 1.000 1.000   24    1.000 1.000 1.000
E2         123   1.000 0.992 0.996   17    1.000 1.000 1.000
A          47    1.000 1.000 1.000   6     1.000 0.833 0.909
E4         41    1.000 0.976 0.988   3     1.000 1.000 1.000
B          33    1.000 1.000 1.000   1     0.500 1.000 0.667
E3         26    0.962 0.962 0.962   8     1.000 0.500 0.667
Epoch: 90  Train Loss:   0.2915  Acc: 99.65  Val  Loss   0.5961  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.88it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.06it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.31it/s]

Epoch: 91  Train Loss:   0.2130  Acc: 99.76  Val  Loss   0.6980  Acc: 96.58


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.87it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 76.27it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.25it/s]

Epoch: 92  Train Loss:   0.3306  Acc: 99.65  Val  Loss   1.3250  Acc: 94.02


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 21.11it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.44it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.47it/s]

Epoch: 93  Train Loss:   0.4340  Acc: 99.76  Val  Loss   0.6485  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.94it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.05it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.94it/s]

Epoch: 94  Train Loss:   0.2643  Acc: 99.76  Val  Loss   0.5922  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 21.06it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.63it/s]
  7%|███▎                                        | 2/27 [00:00<00:01, 16.17it/s]

Epoch: 95  Train Loss:   0.3256  Acc: 99.65  Val  Loss   0.9464  Acc: 94.87


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.87it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.53it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 22.90it/s]

Epoch: 96  Train Loss:   0.3974  Acc: 99.65  Val  Loss   0.5844  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 21.09it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.49it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.75it/s]

Epoch: 97  Train Loss:   0.2842  Acc: 99.76  Val  Loss   0.4406  Acc: 96.58


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.90it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 76.55it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 23.51it/s]

Epoch: 98  Train Loss:   0.1945  Acc: 99.76  Val  Loss   0.7454  Acc: 95.73


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.99it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.23it/s]
 11%|████▉                                       | 3/27 [00:00<00:01, 21.40it/s]

Epoch: 99  Train Loss:   0.2564  Acc: 99.65  Val  Loss   0.6213  Acc: 96.58


100%|███████████████████████████████████████████| 27/27 [00:01<00:00, 20.87it/s]
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 77.00it/s]

E1         373   0.995 1.000 0.997   58    0.951 1.000 0.975
C          206   1.000 1.000 1.000   24    1.000 1.000 1.000
E2         123   1.000 0.992 0.996   17    1.000 1.000 1.000
A          47    1.000 1.000 1.000   6     1.000 0.833 0.909
E4         41    1.000 1.000 1.000   3     1.000 1.000 1.000
B          33    1.000 1.000 1.000   1     0.500 1.000 0.667
E3         26    1.000 0.962 0.980   8     1.000 0.625 0.769
Epoch: 100  Train Loss:   0.1587  Acc: 99.76  Val  Loss   0.7035  Acc: 96.58





Above is an example 100-epoch run. Every 10 epochs we print, for each class, the number of examples and the precision, recall and F1 on the training set (left) and on the validation set (right). After every epoch we also print the overall loss and accuracy for both sets. Note that these do not improve monotonically: for example, the training loss spikes at epochs 62–63 and 78. In general, however, the model recovers and performs better after such setbacks. For most use cases a shorter training regime is probably sufficient.

In [8]:
# Load the saved model and evaluate it on the training set.
# NOTE(review): this checkpoint's validation accuracy below (98.291) matches the best epoch
# in the log above (epoch 41), so it is presumably the best-validation checkpoint — confirm.
args.load = "bd.multi.e100.h256.b32.model"
# map_location lets a checkpoint that was saved on the GPU load on a CPU-only machine
# (`device` falls back to "cpu" when CUDA is unavailable).
# torch.load unpickles arbitrary Python objects: only load model files you trust.
net = torch.load(args.load, map_location=device)         # Load Saved Model
net.to(device)

# Run evaluation with best validation model on training
with open(DATA_DIR + prefix + ".train", 'r') as train_file:
    train_rows = [line.strip().split() for line in train_file]   # one (label, sequence) pair per line
t_strs = np.array([sequence for label, sequence in train_rows])
t_inps = [to_int(sequence) for label, sequence in train_rows]
t_outs = np.array([lbls[label] for label, sequence in train_rows])

args.confusion = True
args.pconfusion = False
loss, acc = run_evaluation(net, t_inps, t_outs, t_strs, aggregate=True, verbose=True)
print_predictors(net, "final")
print("Acc: {:5.3f}".format(acc))

100%|███████████████████████████████████████████| 27/27 [01:13<00:00,  2.72s/it]


E1         373   0.997 0.997 0.997   
C          206   1.000 1.000 1.000   
E2         123   0.992 0.992 0.992   
A          47    1.000 1.000 1.000   
E4         41    1.000 1.000 1.000   
B          33    1.000 1.000 1.000   
E3         26    0.962 0.962 0.962   
Acc: 99.647


In [9]:
# Print confusion matrix on training data
# NOTE(review): assumes the previous cell's run_evaluation (run with args.confusion = True)
# wrote "confusion.csv"; confirm which axis holds true vs. predicted labels before labelling axes.
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt

with open("confusion.csv") as confusion_file:
    # Strip list brackets so every line becomes a flat list of comma-separated fields.
    csv_rows = [line.strip().replace("[", "").replace("]", "").split(",") for line in confusion_file]
labels = csv_rows[0][1:]                                   # header row; its first field is skipped
# Drop each row's leading label and its final field
# (presumably an empty trailing field — verify against the CSV writer).
counts = [[float(v) for v in row[1:-1]] for row in csv_rows[1:]]
df = pd.DataFrame(counts, index=labels, columns=labels)
# Divide each row by its own total so every row sums to 1 (row normalisation, not column).
df_norm_row = df.div(df.sum(axis=1), axis=0)

fig, ax = plt.subplots(figsize=(10, 7))
sn.heatmap(df_norm_row, annot=True, ax=ax)
ax.set_title("Confusion matrix on training data (row-normalised)")
plt.show()

In [10]:
# Run evaluation with best validation model on validation
with open(DATA_DIR + prefix + ".val", 'r') as val_file:
    val_rows = [line.strip().split() for line in val_file]   # one (label, sequence) pair per line
# NOTE(review): t_strs is built but, unlike the training cell (aggregate=True),
# not passed to run_evaluation here — confirm this is intended.
t_strs = np.array([sequence for label, sequence in val_rows])
t_inps = [to_int(sequence) for label, sequence in val_rows]
t_outs = np.array([lbls[label] for label, sequence in val_rows])
args.confusion = True
args.pconfusion = False
loss, acc = run_evaluation(net, t_inps, t_outs, showTrain=False, verbose=True)
print("Acc: {:5.3f}".format(acc))

100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 73.58it/s]

E1         58    0.967 1.000 0.983   
C          24    1.000 1.000 1.000   
E2         17    1.000 1.000 1.000   
E3         8     1.000 0.750 0.857   
A          6     1.000 1.000 1.000   
E4         3     1.000 1.000 1.000   
B          1     1.000 1.000 1.000   
Acc: 98.291





In [11]:
# Run evaluation with best validation model on test
with open(DATA_DIR + prefix + ".test", 'r') as test_file:
    test_rows = [line.strip().split() for line in test_file]   # one (label, sequence) pair per line
t_strs = np.array([sequence for label, sequence in test_rows])
t_inps = [to_int(sequence) for label, sequence in test_rows]
t_outs = np.array([lbls[label] for label, sequence in test_rows])
# Set the evaluation flags explicitly so this cell does not depend on state
# left behind by the validation cell (same values as that cell uses).
args.confusion = True
args.pconfusion = False
loss, acc = run_evaluation(net, t_inps, t_outs, showTrain=False, verbose=True)
print("Acc: {:5.3f}".format(acc))

100%|█████████████████████████████████████████████| 8/8 [00:00<00:00, 80.60it/s]

E1         118   0.992 1.000 0.996   
C          51    0.981 1.000 0.990   
E2         33    1.000 1.000 1.000   
A          21    1.000 1.000 1.000   
E3         10    1.000 0.900 0.947   
E4         9     0.900 1.000 0.947   
B          7     1.000 0.714 0.833   
Acc: 98.795



