In [1]:
import os
from GetDataLoaders import get_dataloaders, get_short_dataloaders
from architectures.AlexNetFeature import AlexNetFeature
from architectures.AlexNetClassifier import AlexNetClassifier
from architectures.LinearTransformationNorm import LinearTransformationNorm
from architectures.NCEAverage import NCEAverage
from architectures.NCELoss import NCELoss
import torch
import numpy as np
#from torch.utils.tensorboard import SummaryWriter
from torch import nn
import time
from torch import optim

In [2]:
'''
# we skip the probs for now
gama = 2.0
with open(os.path.join("./PUprobs", 'prob.dat'), 'r') as file_input:
    train_prob_str = file_input.readlines()
    train_prob = [float(i_prob_str.rstrip('\n')) for i_prob_str in train_prob_str]
    print(len(train_prob)/4.0)
    train_weight = [1.0 if 0==i%4 else 1-train_prob[i]**gama for i in range(len(train_prob))]
'''

'\n# we skip the probs for now\ngama = 2.0\nwith open(os.path.join("./PUprobs", \'prob.dat\'), \'r\') as file_input:\n    train_prob_str = file_input.readlines()\n    train_prob = [float(i_prob_str.rstrip(\'\n\')) for i_prob_str in train_prob_str]\n    print(len(train_prob)/4.0)\n    train_weight = [1.0 if 0==i%4 else 1-train_prob[i]**gama for i in range(len(train_prob))]\n'

In [3]:
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
batch_size = 192
lr = 0.1
LUT_lr = [(90,0.01), (130,0.001), (190,0.0001), (210,0.00001), (230,0.0001), (245,0.00001)]
num_epochs = 245
momentum = 0.9
weight_decay = 5e-4
nesterov = True
Lambdas = {'CE':1.0, 'MSE':1.0, 'NCE':1.0}

loaders = get_dataloaders('imagenet', batch_size=batch_size, num_workers=2)
ndata_train = len(loaders['train_loader'].dataset)
ndata_valid = len(loaders['valid_loader'].dataset)
t = 0.07
m = 4096
gamma = 2

In [4]:
ndata_train, ndata_valid

(80000, 10000)

In [5]:
feature_net = AlexNetFeature().to(device)
classifier_net = AlexNetClassifier().to(device)
transformation_net = LinearTransformationNorm().to(device)

feature_optimizer = optim.SGD(feature_net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
classifier_optimizer = optim.SGD(classifier_net.parameters() ,lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
transformation_optimizer = optim.SGD(transformation_net.parameters() ,lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)

Networks = {'feature':feature_net, 'classifier':classifier_net, 'transformation':transformation_net}
Optimizers = {'feature':feature_optimizer, 'classifier':classifier_optimizer, 'transformation':transformation_optimizer}

NCE = {'NCEAverage_train': NCEAverage(outputSize=ndata_train, K=m, T=t).cuda(), 'NCEAverage_valid': NCEAverage(outputSize=ndata_valid, K=m, T=t).cuda() , 'NCECriterion_train': NCELoss(ndata_train), 'NCECriterion_valid': NCELoss(ndata_valid)}
Criterions = {'CE': nn.CrossEntropyLoss(reduction='none'), 'MSE':nn.MSELoss() }

Moving NCE To CUDA.
Moving AliasMethod To CUDA.
Moving NCE To CUDA.
Moving AliasMethod To CUDA.


In [6]:
#for sample in loaders['train_loader']:
    #print(sample[0].size(), sample[1].size(), sample[2].size())

In [6]:
def adjust_learning_rates(epoch):
    # filter out the networks that are not trainable and that do
    # not have a learning rate Look Up Table (LUT_lr) in their optim_params
    lr = next((lr for (max_epoch, lr) in LUT_lr if max_epoch>epoch), LUT_lr[-1][1])
    for key in Optimizers:
        for g in Optimizers[key].param_groups:
            g['lr'] = lr

def AlexNetDecoupling(batch, NCEAverage, NCECriterion, PUWeights=None, train=True):
    data, targets, indices = batch
    
    if train is True:
        Optimizers['feature'].zero_grad()
        Optimizers['classifier'].zero_grad()
        Optimizers['transformation'].zero_grad()
    
    #to cuda
    
    data = data.to(device)
    targets = targets.to(device)
    indices = indices.to(device)

    
    #perform rotations & adjust data shape
    data_90 = torch.flip(torch.transpose(data,2,3),[2])
    data_180 = torch.flip(torch.flip(data,[2]),[3])
    data_270 = torch.transpose(torch.flip(data,[2]),2,3)
    data = torch.stack([data, data_90, data_180, data_270], dim=1)
    batch_size, rotations, channels, height, width = data.size()
    data = data.view([batch_size*rotations, channels, height, width])
    
    #debug for backward
    data.requires_grad = False 
    targets.requires_grad = False
    indices.requires_grad = False
    
    #collect features
    features = Networks['feature'](data)
    features_rot, features_invariance = torch.split(features, 2048, dim=1)
    
    #collect rotation prediction
    pred = Networks['classifier'](features_rot)
    
    
    #average features across 4 rotations
    features_invariance_instance = features_invariance[0::4,:] + features_invariance[1::4,:] + features_invariance[2::4,:] + features_invariance[3::4,:]
    features_invariance_instance = torch.mul(features_invariance_instance, 0.25) #fbar 192x2048
    
    #downsample to 128 & perform normalization of vector
    features_128_norm = Networks['transformation'](features_invariance_instance) #192x128
    
    with torch.no_grad():
        #stack 192x2048 4 times to be  = 192x4x2048 = 768x2048
        features_invariance_instance_mean = torch.unsqueeze(features_invariance_instance,1).expand(-1,4,-1).clone()
        features_invariance_instance_mean = features_invariance_instance_mean.view(4*len(features_invariance_instance),2048)
    
    #calculate rotation loss ignore PU for now
    loss_cls_each = Criterions['CE'](pred, targets)
    loss_cls = torch.sum(loss_cls_each)/loss_cls_each.shape[0]
    
    #calculate rotation invariance by MSE
    loss_mse = Criterions['MSE'](features_invariance, features_invariance_instance_mean)
    
    #calculate instance loss using NCE
    output_nce = NCEAverage(features_128_norm, indices)
    loss_nce = NCECriterion(output_nce, indices)
    
    #loss_nce = torch.FloatTensor([0.0]).cuda() #NCECriterion(features_128_norm, indices)
    
    loss_total = Lambdas['CE']*loss_cls + Lambdas['MSE']*loss_mse + Lambdas['NCE']*loss_nce
    
    if train is True:
        loss_total.backward()
        Optimizers['feature'].step()
        Optimizers['classifier'].step()
        Optimizers['transformation'].step()
    
    #calculate classification accuracy
    pred = pred.argmax(dim=1, keepdim=True)
    correct = pred.eq(targets.view_as(pred)).sum().item()
    
    losses = {'ce':loss_cls.item(), 'mse':loss_mse.item(), 'nce':loss_nce.item(), 'correct':correct}
    
    return losses
    

In [7]:
def train(data_loader, epoch, log_interval=50):
    
    Networks['feature'].train()
    Networks['classifier'].train()
    Networks['transformation'].train()
    
    losses = {'ce':[], 'mse':[], 'nce':[]}
    correct = 0
    train_loss = np.Inf
    train_acc = 0.0
    adjust_learning_rates(epoch)
    start_time = time.time()
    for batch_idx, sample in enumerate(data_loader):
        
        lossesdict = AlexNetDecoupling(sample, NCE['NCEAverage_train'], NCE['NCECriterion_train'] ,train=True)
        
        losses['ce'].append(lossesdict['ce'])
        losses['mse'].append(lossesdict['mse'])
        losses['nce'].append(lossesdict['nce'])
        correct += lossesdict['correct']
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: CE {:.6f}, MSE {:.6f}, NCE {:.6f}'.format(epoch, batch_idx*len(sample[0]), len(data_loader.dataset), 100. * batch_idx / len(data_loader), lossesdict['ce'], lossesdict['mse'], lossesdict['nce']))
    
    end_time = time.time()
    print("Time for epoch pass {}".format(end_time-start_time))
    train_loss = {'ce': float(np.mean(losses['ce'])), 'mse': float(np.mean(losses['mse'])), 'nce':float(np.mean(losses['nce']))}
    train_acc = correct / float(len(data_loader.dataset)*4)
    print('Train set: Average loss: CE {:.4f}, MSE {:.4f}, NCE {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(train_loss['ce'], train_loss['mse'], train_loss['nce'], correct, len(data_loader.dataset)*4, 100*train_acc))
    return train_loss, train_acc

def validate(data_loader, epoch, log_interval=50):
    
    Networks['feature'].eval()
    Networks['classifier'].eval()
    Networks['transformation'].eval()
    
    losses = {'ce':[], 'mse':[], 'nce':[]}
    correct = 0
    valid_loss = np.Inf
    valid_acc = 0.0
    start_time = time.time()
    for batch_idx, sample in enumerate(data_loader): 
        with torch.no_grad():
            lossesdict = AlexNetDecoupling(sample, NCE['NCEAverage_valid'], NCE['NCECriterion_valid'], train=False)
        
        losses['ce'].append(lossesdict['ce'])
        losses['mse'].append(lossesdict['mse'])
        losses['nce'].append(lossesdict['nce'])
        correct += lossesdict['correct']
        if batch_idx % log_interval == 0:
            print('Valid Epoch: {} [{}/{} ({:.0f}%)]\tLoss: CE {:.6f}, MSE {:.6f}, NCE {:.6f}'.format(epoch, batch_idx*len(sample[0]), len(data_loader.dataset), 100. * batch_idx / len(data_loader), lossesdict['ce'], lossesdict['mse'], lossesdict['nce']))
    
    end_time = time.time()
    print("Time for epoch pass {}".format(end_time-start_time))
    valid_loss = {'ce': float(np.mean(losses['ce'])), 'mse': float(np.mean(losses['mse'])), 'nce':float(np.mean(losses['nce']))}
    valid_acc = correct / float(len(data_loader.dataset)*4)
    print('Valid set: Average loss: CE {:.4f}, MSE {:.4f}, NCE {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(valid_loss['ce'], valid_loss['mse'], valid_loss['nce'], correct, len(data_loader.dataset)*4, 100*valid_acc))
    return valid_loss, valid_acc

def run_main_loop(loaders, num_epochs):
    #writer = SummaryWriter('./logs/AlexNet_Unsupervised_Decoupling')
    save_path = "weights/AlexNet_Decoupling.pth"
    best_acc = 0.0
    for epoch in range(num_epochs):
        print("Performing {}th epoch".format(epoch))
        train_loss, train_acc = train(loaders['train_loader'], epoch)
        val_loss, val_acc = validate(loaders['valid_loader'], epoch)
        '''
        writer.add_scalar('CELoss/train', train_loss['ce'], epoch)
        writer.add_scalar('MSELoss/train', train_loss['mse'], epoch)
        writer.add_scalar('NCELoss/train', train_loss['nce'], epoch)
        
        writer.add_scalar('CELoss/Valid', val_loss['ce'], epoch)
        writer.add_scalar('MSELoss/Valid', val_loss['mse'], epoch)
        writer.add_scalar('NCELoss/Valid', val_loss['nce'], epoch)
        
        writer.add_scalar('Accuracy/train', train_acc, epoch)
        writer.add_scalar('Accuracy/Valid', val_acc, epoch)
    
        writer.add_scalar('LR', Optimizers['feature'].param_groups[0]['lr'], epoch)
        '''
        if val_acc > best_acc :
            best_acc = val_acc
            #save model
            states = {
                'epoch': epoch + 1,
                'feature_net':Networks['feature'].state_dict(),
                'classifier_net':Networks['classifier'].state_dict(),
                'transformation_net':Networks['transformation'].state_dict(),
                'feature_optimizer': Optimizers['feature'].state_dict(),
                'classifier_optimizer': Optimizers['classifier'].state_dict(),
                'transformation_optimizer': Optimizers['transformation'].state_dict(),
                'best_accuracy': best_acc
            }
            torch.save(states, save_path)
            print('Model Saved')

In [8]:
run_main_loop(loaders, num_epochs)

Performing 0th epoch
normalization constant Z is set to 175837.7
Time for epoch pass 245.66650128364563
Train set: Average loss: CE 1.3250, MSE 0.0016, NCE 24.9360, Accuracy: 120272/320000 (38%)

normalization constant Z is set to 22520.2
Time for epoch pass 14.504526138305664
Valid set: Average loss: CE 1.3065, MSE 0.0002, NCE 10.1196, Accuracy: 15837/40000 (40%)

Model Saved
Performing 1th epoch
Time for epoch pass 249.9048135280609
Train set: Average loss: CE 1.2990, MSE 0.0001, NCE 25.6782, Accuracy: 125723/320000 (39%)

Time for epoch pass 15.155174970626831
Valid set: Average loss: CE 1.2796, MSE 0.0000, NCE 10.1196, Accuracy: 16338/40000 (41%)

Model Saved
Performing 2th epoch
Time for epoch pass 263.853444814682
Train set: Average loss: CE 1.2851, MSE 0.0001, NCE 21.4281, Accuracy: 128598/320000 (40%)

Time for epoch pass 15.388926982879639
Valid set: Average loss: CE 1.2982, MSE 0.0000, NCE 10.0997, Accuracy: 15890/40000 (40%)

Performing 3th epoch
Time for epoch pass 268.3335

Time for epoch pass 15.062572717666626
Valid set: Average loss: CE 1.2148, MSE 0.0001, NCE 10.1207, Accuracy: 18180/40000 (45%)

Model Saved
Performing 7th epoch
Time for epoch pass 245.7385036945343
Train set: Average loss: CE 1.2047, MSE 0.0001, NCE 8.4611, Accuracy: 146494/320000 (46%)

Time for epoch pass 14.645542860031128
Valid set: Average loss: CE 1.1995, MSE 0.0001, NCE 10.1137, Accuracy: 18434/40000 (46%)

Model Saved
Performing 8th epoch
Time for epoch pass 246.84732389450073
Train set: Average loss: CE 1.1932, MSE 0.0001, NCE 7.0231, Accuracy: 148610/320000 (46%)

Time for epoch pass 15.206623315811157
Valid set: Average loss: CE 1.1889, MSE 0.0001, NCE 10.1303, Accuracy: 18820/40000 (47%)

Model Saved
Performing 9th epoch
Time for epoch pass 248.75216221809387
Train set: Average loss: CE 1.1823, MSE 0.0001, NCE 6.0779, Accuracy: 150616/320000 (47%)

Time for epoch pass 14.989362955093384
Valid set: Average loss: CE 1.1841, MSE 0.0001, NCE 10.1105, Accuracy: 18867/40000 (47

Time for epoch pass 15.293487787246704
Valid set: Average loss: CE 1.1477, MSE 0.0002, NCE 10.1167, Accuracy: 19584/40000 (49%)

Model Saved
Performing 14th epoch
Time for epoch pass 250.38363671302795
Train set: Average loss: CE 1.1325, MSE 0.0002, NCE 3.6185, Accuracy: 160345/320000 (50%)

Time for epoch pass 15.627100467681885
Valid set: Average loss: CE 1.1427, MSE 0.0002, NCE 10.1209, Accuracy: 19921/40000 (50%)

Model Saved
Performing 15th epoch
Time for epoch pass 250.50816249847412
Train set: Average loss: CE 1.1242, MSE 0.0003, NCE 3.3445, Accuracy: 162434/320000 (51%)

Time for epoch pass 15.345122814178467
Valid set: Average loss: CE 1.1344, MSE 0.0003, NCE 10.1170, Accuracy: 19993/40000 (50%)

Model Saved
Performing 16th epoch
Time for epoch pass 249.08281230926514
Train set: Average loss: CE 1.1135, MSE 0.0003, NCE 3.1101, Accuracy: 164142/320000 (51%)

Time for epoch pass 15.322603940963745
Valid set: Average loss: CE 1.1260, MSE 0.0003, NCE 10.1187, Accuracy: 20242/40000

Time for epoch pass 14.368187189102173
Valid set: Average loss: CE 1.0897, MSE 0.0004, NCE 10.1116, Accuracy: 21138/40000 (53%)

Model Saved
Performing 21th epoch
Time for epoch pass 245.7217562198639
Train set: Average loss: CE 1.0479, MSE 0.0005, NCE 2.3369, Accuracy: 176106/320000 (55%)

Time for epoch pass 14.716146230697632
Valid set: Average loss: CE 1.0788, MSE 0.0005, NCE 10.1097, Accuracy: 21460/40000 (54%)

Model Saved
Performing 22th epoch
Time for epoch pass 250.2693693637848
Train set: Average loss: CE 1.0304, MSE 0.0005, NCE 2.2436, Accuracy: 179280/320000 (56%)

Time for epoch pass 15.241191387176514
Valid set: Average loss: CE 1.0740, MSE 0.0005, NCE 10.1149, Accuracy: 21416/40000 (54%)

Performing 23th epoch
Time for epoch pass 250.94348883628845
Train set: Average loss: CE 1.0104, MSE 0.0006, NCE 2.1502, Accuracy: 183023/320000 (57%)

Time for epoch pass 15.183342456817627
Valid set: Average loss: CE 1.0570, MSE 0.0005, NCE 10.1121, Accuracy: 21869/40000 (55%)

Model 

Time for epoch pass 15.302263498306274
Valid set: Average loss: CE 0.9919, MSE 0.0006, NCE 10.0998, Accuracy: 23342/40000 (58%)

Model Saved
Performing 28th epoch
Time for epoch pass 249.23116946220398
Train set: Average loss: CE 0.8947, MSE 0.0007, NCE 1.8171, Accuracy: 202101/320000 (63%)

Time for epoch pass 14.853137493133545
Valid set: Average loss: CE 0.9978, MSE 0.0006, NCE 10.1090, Accuracy: 23515/40000 (59%)

Model Saved
Performing 29th epoch
Time for epoch pass 249.99253869056702
Train set: Average loss: CE 0.8743, MSE 0.0007, NCE 1.7715, Accuracy: 205224/320000 (64%)

Time for epoch pass 14.64015817642212
Valid set: Average loss: CE 0.9599, MSE 0.0006, NCE 10.1171, Accuracy: 23960/40000 (60%)

Model Saved
Performing 30th epoch
Time for epoch pass 250.35735630989075
Train set: Average loss: CE 0.8536, MSE 0.0007, NCE 1.7274, Accuracy: 208271/320000 (65%)

Time for epoch pass 14.899359941482544
Valid set: Average loss: CE 0.9482, MSE 0.0006, NCE 10.1074, Accuracy: 24382/40000 

Time for epoch pass 14.738312482833862
Valid set: Average loss: CE 0.9150, MSE 0.0007, NCE 10.1120, Accuracy: 25035/40000 (63%)

Model Saved
Performing 35th epoch
Time for epoch pass 249.66518664360046
Train set: Average loss: CE 0.7581, MSE 0.0006, NCE 1.5722, Accuracy: 222368/320000 (69%)

Time for epoch pass 14.809694528579712
Valid set: Average loss: CE 0.8971, MSE 0.0006, NCE 10.1065, Accuracy: 25479/40000 (64%)

Model Saved
Performing 36th epoch
Time for epoch pass 280.8803482055664
Train set: Average loss: CE 0.7407, MSE 0.0006, NCE 1.5430, Accuracy: 224611/320000 (70%)

Time for epoch pass 14.768786430358887
Valid set: Average loss: CE 0.9663, MSE 0.0006, NCE 10.1041, Accuracy: 24395/40000 (61%)

Performing 37th epoch
Time for epoch pass 248.54080963134766
Train set: Average loss: CE 0.7217, MSE 0.0006, NCE 1.5250, Accuracy: 227320/320000 (71%)

Time for epoch pass 14.779814004898071
Valid set: Average loss: CE 0.8804, MSE 0.0006, NCE 10.0942, Accuracy: 25936/40000 (65%)

Model

Time for epoch pass 15.053376913070679
Valid set: Average loss: CE 0.8686, MSE 0.0006, NCE 10.1023, Accuracy: 26280/40000 (66%)

Model Saved
Performing 42th epoch
Time for epoch pass 250.13913655281067
Train set: Average loss: CE 0.6298, MSE 0.0006, NCE 1.4406, Accuracy: 240422/320000 (75%)

Time for epoch pass 15.1327965259552
Valid set: Average loss: CE 0.8992, MSE 0.0005, NCE 10.1152, Accuracy: 25723/40000 (64%)

Performing 43th epoch
Time for epoch pass 250.49264216423035
Train set: Average loss: CE 0.6101, MSE 0.0006, NCE 1.4283, Accuracy: 243264/320000 (76%)

Time for epoch pass 16.15111804008484
Valid set: Average loss: CE 0.8651, MSE 0.0006, NCE 10.1085, Accuracy: 26481/40000 (66%)

Model Saved
Performing 44th epoch
Time for epoch pass 250.35830736160278
Train set: Average loss: CE 0.5957, MSE 0.0006, NCE 1.4214, Accuracy: 245258/320000 (77%)

Time for epoch pass 15.327064752578735
Valid set: Average loss: CE 0.8838, MSE 0.0006, NCE 10.1177, Accuracy: 26223/40000 (66%)

Perform

Time for epoch pass 14.84180760383606
Valid set: Average loss: CE 0.8643, MSE 0.0006, NCE 10.1322, Accuracy: 26550/40000 (66%)

Model Saved
Performing 49th epoch
Time for epoch pass 249.1657440662384
Train set: Average loss: CE 0.5020, MSE 0.0006, NCE 1.3804, Accuracy: 258281/320000 (81%)

Time for epoch pass 14.617973327636719
Valid set: Average loss: CE 0.9471, MSE 0.0006, NCE 10.1210, Accuracy: 26236/40000 (66%)

Performing 50th epoch
Time for epoch pass 249.89550352096558
Train set: Average loss: CE 0.4842, MSE 0.0006, NCE 1.3681, Accuracy: 260434/320000 (81%)

Time for epoch pass 14.684577226638794
Valid set: Average loss: CE 0.8896, MSE 0.0005, NCE 10.1130, Accuracy: 26599/40000 (66%)

Model Saved
Performing 51th epoch
Time for epoch pass 249.8493776321411
Train set: Average loss: CE 0.4652, MSE 0.0006, NCE 1.3629, Accuracy: 262746/320000 (82%)

Time for epoch pass 15.770360708236694
Valid set: Average loss: CE 0.9167, MSE 0.0006, NCE 10.1252, Accuracy: 26567/40000 (66%)

Perform

Time for epoch pass 15.23548412322998
Valid set: Average loss: CE 0.9527, MSE 0.0006, NCE 10.1237, Accuracy: 26481/40000 (66%)

Performing 56th epoch
Time for epoch pass 250.81341552734375
Train set: Average loss: CE 0.3805, MSE 0.0006, NCE 1.3379, Accuracy: 274216/320000 (86%)

Time for epoch pass 14.698760747909546
Valid set: Average loss: CE 0.9521, MSE 0.0006, NCE 10.1176, Accuracy: 26534/40000 (66%)

Performing 57th epoch
Time for epoch pass 248.53568863868713
Train set: Average loss: CE 0.3649, MSE 0.0006, NCE 1.3374, Accuracy: 276268/320000 (86%)

Time for epoch pass 15.718398094177246
Valid set: Average loss: CE 0.9649, MSE 0.0006, NCE 10.1234, Accuracy: 26758/40000 (67%)

Model Saved
Performing 58th epoch
Time for epoch pass 245.32661032676697
Train set: Average loss: CE 0.3492, MSE 0.0006, NCE 1.3339, Accuracy: 278272/320000 (87%)

Time for epoch pass 14.815999269485474
Valid set: Average loss: CE 1.1084, MSE 0.0006, NCE 10.1298, Accuracy: 25578/40000 (64%)

Performing 59th e

Time for epoch pass 15.066775560379028
Valid set: Average loss: CE 1.0291, MSE 0.0006, NCE 10.1320, Accuracy: 26354/40000 (66%)

Performing 63th epoch
Time for epoch pass 247.38429522514343
Train set: Average loss: CE 0.2821, MSE 0.0006, NCE 1.3213, Accuracy: 286693/320000 (90%)

Time for epoch pass 14.79113221168518
Valid set: Average loss: CE 1.0521, MSE 0.0006, NCE 10.1287, Accuracy: 26782/40000 (67%)

Model Saved
Performing 64th epoch
Time for epoch pass 247.85897493362427
Train set: Average loss: CE 0.2715, MSE 0.0006, NCE 1.3166, Accuracy: 288057/320000 (90%)

Time for epoch pass 15.034172058105469
Valid set: Average loss: CE 1.0429, MSE 0.0006, NCE 10.1287, Accuracy: 26506/40000 (66%)

Performing 65th epoch
Time for epoch pass 273.7488558292389
Train set: Average loss: CE 0.2611, MSE 0.0006, NCE 1.3157, Accuracy: 289565/320000 (90%)

Time for epoch pass 15.330137252807617
Valid set: Average loss: CE 1.1141, MSE 0.0006, NCE 10.1284, Accuracy: 25833/40000 (65%)

Performing 66th ep

Time for epoch pass 14.642661333084106
Valid set: Average loss: CE 1.1408, MSE 0.0006, NCE 10.1270, Accuracy: 25917/40000 (65%)

Performing 70th epoch
Time for epoch pass 249.841450214386
Train set: Average loss: CE 0.2154, MSE 0.0006, NCE 1.3085, Accuracy: 295301/320000 (92%)

Time for epoch pass 15.776302814483643
Valid set: Average loss: CE 1.1191, MSE 0.0006, NCE 10.1280, Accuracy: 26585/40000 (66%)

Performing 71th epoch
Time for epoch pass 249.5428273677826
Train set: Average loss: CE 0.2075, MSE 0.0005, NCE 1.3034, Accuracy: 296060/320000 (93%)

Time for epoch pass 15.692427158355713
Valid set: Average loss: CE 1.1005, MSE 0.0005, NCE 10.1347, Accuracy: 26653/40000 (67%)

Performing 72th epoch
Time for epoch pass 247.89784789085388
Train set: Average loss: CE 0.1985, MSE 0.0005, NCE 1.3015, Accuracy: 297384/320000 (93%)

Time for epoch pass 15.098030805587769
Valid set: Average loss: CE 1.1425, MSE 0.0006, NCE 10.1281, Accuracy: 26471/40000 (66%)

Performing 73th epoch
Time for 

Time for epoch pass 15.27805495262146
Valid set: Average loss: CE 1.1407, MSE 0.0005, NCE 10.1355, Accuracy: 26564/40000 (66%)

Performing 77th epoch
Time for epoch pass 249.74462962150574
Train set: Average loss: CE 0.1724, MSE 0.0005, NCE 1.2940, Accuracy: 300464/320000 (94%)

Time for epoch pass 14.504590511322021
Valid set: Average loss: CE 1.1422, MSE 0.0005, NCE 10.1326, Accuracy: 26559/40000 (66%)

Performing 78th epoch
Time for epoch pass 249.70467710494995
Train set: Average loss: CE 0.1691, MSE 0.0005, NCE 1.2969, Accuracy: 300940/320000 (94%)

Time for epoch pass 14.739376068115234
Valid set: Average loss: CE 1.1642, MSE 0.0005, NCE 10.1328, Accuracy: 26593/40000 (66%)

Performing 79th epoch
Time for epoch pass 249.62563753128052
Train set: Average loss: CE 0.1647, MSE 0.0005, NCE 1.2910, Accuracy: 301307/320000 (94%)

Time for epoch pass 14.535629987716675
Valid set: Average loss: CE 1.1567, MSE 0.0005, NCE 10.1423, Accuracy: 26698/40000 (67%)

Performing 80th epoch
Time fo

Time for epoch pass 15.067648649215698
Valid set: Average loss: CE 1.2021, MSE 0.0005, NCE 10.1423, Accuracy: 26714/40000 (67%)

Performing 84th epoch
Time for epoch pass 249.29510641098022
Train set: Average loss: CE 0.1482, MSE 0.0005, NCE 1.2932, Accuracy: 303387/320000 (95%)

Time for epoch pass 15.69714879989624
Valid set: Average loss: CE 1.2173, MSE 0.0005, NCE 10.1316, Accuracy: 26680/40000 (67%)

Performing 85th epoch
Time for epoch pass 249.5522620677948
Train set: Average loss: CE 0.1412, MSE 0.0005, NCE 1.2856, Accuracy: 304302/320000 (95%)

Time for epoch pass 15.116466283798218
Valid set: Average loss: CE 1.1673, MSE 0.0005, NCE 10.1376, Accuracy: 26613/40000 (67%)

Performing 86th epoch
Time for epoch pass 250.09863114356995
Train set: Average loss: CE 0.1417, MSE 0.0005, NCE 1.2918, Accuracy: 304200/320000 (95%)

Time for epoch pass 15.040379524230957
Valid set: Average loss: CE 1.2059, MSE 0.0005, NCE 10.1352, Accuracy: 26639/40000 (67%)

Performing 87th epoch
Time for

Time for epoch pass 14.901733160018921
Valid set: Average loss: CE 1.1116, MSE 0.0004, NCE 10.1375, Accuracy: 27544/40000 (69%)

Model Saved
Performing 91th epoch
Time for epoch pass 248.70060801506042
Train set: Average loss: CE 0.0586, MSE 0.0004, NCE 1.2971, Accuracy: 315002/320000 (98%)

Time for epoch pass 14.969276189804077
Valid set: Average loss: CE 1.1173, MSE 0.0004, NCE 10.1400, Accuracy: 27682/40000 (69%)

Model Saved
Performing 92th epoch
Time for epoch pass 249.17767763137817
Train set: Average loss: CE 0.0508, MSE 0.0004, NCE 1.2930, Accuracy: 315962/320000 (99%)

Time for epoch pass 14.539621114730835
Valid set: Average loss: CE 1.1238, MSE 0.0004, NCE 10.1364, Accuracy: 27702/40000 (69%)

Model Saved
Performing 93th epoch
Time for epoch pass 249.32107877731323
Train set: Average loss: CE 0.0451, MSE 0.0004, NCE 1.2879, Accuracy: 316488/320000 (99%)

Time for epoch pass 14.428200244903564
Valid set: Average loss: CE 1.1324, MSE 0.0004, NCE 10.1313, Accuracy: 27729/40000

Time for epoch pass 14.639148235321045
Valid set: Average loss: CE 1.1550, MSE 0.0004, NCE 10.1372, Accuracy: 27805/40000 (70%)

Performing 98th epoch
Time for epoch pass 247.57343196868896
Train set: Average loss: CE 0.0308, MSE 0.0004, NCE 1.2829, Accuracy: 317961/320000 (99%)

Time for epoch pass 14.936846494674683
Valid set: Average loss: CE 1.1583, MSE 0.0004, NCE 10.1335, Accuracy: 27779/40000 (69%)

Performing 99th epoch
Time for epoch pass 248.2142333984375
Train set: Average loss: CE 0.0297, MSE 0.0004, NCE 1.2805, Accuracy: 318049/320000 (99%)

Time for epoch pass 14.983694314956665
Valid set: Average loss: CE 1.1623, MSE 0.0004, NCE 10.1378, Accuracy: 27783/40000 (69%)

Performing 100th epoch
Time for epoch pass 248.99423551559448
Train set: Average loss: CE 0.0283, MSE 0.0004, NCE 1.2836, Accuracy: 318152/320000 (99%)

Time for epoch pass 15.120772123336792
Valid set: Average loss: CE 1.1682, MSE 0.0004, NCE 10.1371, Accuracy: 27809/40000 (70%)

Performing 101th epoch
Time 

Time for epoch pass 14.564223051071167
Valid set: Average loss: CE 1.1858, MSE 0.0004, NCE 10.1317, Accuracy: 27837/40000 (70%)

Performing 105th epoch
Time for epoch pass 249.1065468788147
Train set: Average loss: CE 0.0233, MSE 0.0004, NCE 1.2847, Accuracy: 318647/320000 (100%)

Time for epoch pass 15.581726312637329
Valid set: Average loss: CE 1.1901, MSE 0.0004, NCE 10.1370, Accuracy: 27867/40000 (70%)

Model Saved
Performing 106th epoch
Time for epoch pass 247.5387351512909
Train set: Average loss: CE 0.0222, MSE 0.0004, NCE 1.2740, Accuracy: 318675/320000 (100%)

Time for epoch pass 15.14485478401184
Valid set: Average loss: CE 1.1903, MSE 0.0004, NCE 10.1338, Accuracy: 27865/40000 (70%)

Performing 107th epoch
Time for epoch pass 247.99124884605408
Train set: Average loss: CE 0.0209, MSE 0.0004, NCE 1.2873, Accuracy: 318832/320000 (100%)

Time for epoch pass 14.772210597991943
Valid set: Average loss: CE 1.1990, MSE 0.0004, NCE 10.1351, Accuracy: 27915/40000 (70%)

Model Saved
P

Time for epoch pass 14.899242639541626
Valid set: Average loss: CE 1.2159, MSE 0.0004, NCE 10.1338, Accuracy: 27851/40000 (70%)

Performing 112th epoch
Time for epoch pass 249.84858751296997
Train set: Average loss: CE 0.0186, MSE 0.0004, NCE 1.2809, Accuracy: 318971/320000 (100%)

Time for epoch pass 14.948431491851807
Valid set: Average loss: CE 1.2132, MSE 0.0004, NCE 10.1335, Accuracy: 27867/40000 (70%)

Performing 113th epoch
Time for epoch pass 249.87631797790527
Train set: Average loss: CE 0.0173, MSE 0.0004, NCE 1.2731, Accuracy: 319115/320000 (100%)

Time for epoch pass 14.84443211555481
Valid set: Average loss: CE 1.2183, MSE 0.0004, NCE 10.1324, Accuracy: 27869/40000 (70%)

Performing 114th epoch
Time for epoch pass 249.4764347076416
Train set: Average loss: CE 0.0174, MSE 0.0004, NCE 1.2729, Accuracy: 319087/320000 (100%)

Time for epoch pass 15.594929456710815
Valid set: Average loss: CE 1.2183, MSE 0.0004, NCE 10.1359, Accuracy: 27913/40000 (70%)

Performing 115th epoch
T

Time for epoch pass 15.379642248153687
Valid set: Average loss: CE 1.2270, MSE 0.0004, NCE 10.1351, Accuracy: 27857/40000 (70%)

Performing 119th epoch
Time for epoch pass 249.5175757408142
Train set: Average loss: CE 0.0155, MSE 0.0004, NCE 1.2738, Accuracy: 319229/320000 (100%)

Time for epoch pass 14.905559778213501
Valid set: Average loss: CE 1.2278, MSE 0.0004, NCE 10.1339, Accuracy: 27874/40000 (70%)

Performing 120th epoch
Time for epoch pass 249.5138921737671
Train set: Average loss: CE 0.0152, MSE 0.0004, NCE 1.2709, Accuracy: 319242/320000 (100%)

Time for epoch pass 14.742491245269775
Valid set: Average loss: CE 1.2351, MSE 0.0004, NCE 10.1367, Accuracy: 27876/40000 (70%)

Performing 121th epoch
Time for epoch pass 250.08604907989502
Train set: Average loss: CE 0.0148, MSE 0.0004, NCE 1.2660, Accuracy: 319303/320000 (100%)

Time for epoch pass 14.557133197784424
Valid set: Average loss: CE 1.2349, MSE 0.0004, NCE 10.1355, Accuracy: 27867/40000 (70%)

Performing 122th epoch
T

Time for epoch pass 15.15276050567627
Valid set: Average loss: CE 1.2429, MSE 0.0004, NCE 10.1334, Accuracy: 27887/40000 (70%)

Performing 126th epoch
Time for epoch pass 248.78944611549377
Train set: Average loss: CE 0.0138, MSE 0.0004, NCE 1.2722, Accuracy: 319349/320000 (100%)

Time for epoch pass 14.659951210021973
Valid set: Average loss: CE 1.2467, MSE 0.0004, NCE 10.1341, Accuracy: 27907/40000 (70%)

Performing 127th epoch
Time for epoch pass 249.32333302497864
Train set: Average loss: CE 0.0133, MSE 0.0004, NCE 1.2738, Accuracy: 319397/320000 (100%)

Time for epoch pass 14.79986834526062
Valid set: Average loss: CE 1.2459, MSE 0.0004, NCE 10.1334, Accuracy: 27898/40000 (70%)

Performing 128th epoch
Time for epoch pass 249.88666605949402
Train set: Average loss: CE 0.0130, MSE 0.0004, NCE 1.2644, Accuracy: 319437/320000 (100%)

Time for epoch pass 15.181530475616455
Valid set: Average loss: CE 1.2468, MSE 0.0004, NCE 10.1384, Accuracy: 27966/40000 (70%)

Model Saved
Performing 1

Time for epoch pass 14.78565788269043
Valid set: Average loss: CE 1.2473, MSE 0.0004, NCE 10.1320, Accuracy: 27952/40000 (70%)

Performing 133th epoch
Time for epoch pass 248.27972054481506
Train set: Average loss: CE 0.0116, MSE 0.0004, NCE 1.2950, Accuracy: 319561/320000 (100%)

Time for epoch pass 14.541817426681519
Valid set: Average loss: CE 1.2459, MSE 0.0004, NCE 10.1363, Accuracy: 27959/40000 (70%)

Performing 134th epoch
Time for epoch pass 249.07976412773132
Train set: Average loss: CE 0.0115, MSE 0.0004, NCE 1.2833, Accuracy: 319553/320000 (100%)

Time for epoch pass 14.963905334472656
Valid set: Average loss: CE 1.2457, MSE 0.0004, NCE 10.1337, Accuracy: 27962/40000 (70%)

Performing 135th epoch
Time for epoch pass 249.491797208786
Train set: Average loss: CE 0.0118, MSE 0.0004, NCE 1.2804, Accuracy: 319497/320000 (100%)

Time for epoch pass 14.678524017333984
Valid set: Average loss: CE 1.2493, MSE 0.0004, NCE 10.1358, Accuracy: 27983/40000 (70%)

Model Saved
Performing 13

Time for epoch pass 14.940871953964233
Valid set: Average loss: CE 1.2483, MSE 0.0004, NCE 10.1338, Accuracy: 27921/40000 (70%)

Performing 140th epoch
Time for epoch pass 249.10780429840088
Train set: Average loss: CE 0.0112, MSE 0.0004, NCE 1.2907, Accuracy: 319521/320000 (100%)

Time for epoch pass 14.725669622421265
Valid set: Average loss: CE 1.2493, MSE 0.0004, NCE 10.1385, Accuracy: 27945/40000 (70%)

Performing 141th epoch
Time for epoch pass 247.5449299812317
Train set: Average loss: CE 0.0113, MSE 0.0004, NCE 1.2762, Accuracy: 319549/320000 (100%)

Time for epoch pass 15.190996408462524
Valid set: Average loss: CE 1.2526, MSE 0.0004, NCE 10.1342, Accuracy: 27948/40000 (70%)

Performing 142th epoch
Time for epoch pass 247.3822112083435
Train set: Average loss: CE 0.0114, MSE 0.0004, NCE 1.2726, Accuracy: 319523/320000 (100%)

Time for epoch pass 14.599154233932495
Valid set: Average loss: CE 1.2519, MSE 0.0004, NCE 10.1375, Accuracy: 27935/40000 (70%)

Performing 143th epoch
T

Time for epoch pass 15.731042385101318
Valid set: Average loss: CE 1.2498, MSE 0.0004, NCE 10.1352, Accuracy: 27989/40000 (70%)

Performing 147th epoch
Time for epoch pass 249.85174131393433
Train set: Average loss: CE 0.0111, MSE 0.0004, NCE 1.2737, Accuracy: 319576/320000 (100%)

Time for epoch pass 14.62201738357544
Valid set: Average loss: CE 1.2501, MSE 0.0004, NCE 10.1390, Accuracy: 27944/40000 (70%)

Performing 148th epoch
Time for epoch pass 249.343670129776
Train set: Average loss: CE 0.0112, MSE 0.0004, NCE 1.2800, Accuracy: 319563/320000 (100%)

Time for epoch pass 15.20651650428772
Valid set: Average loss: CE 1.2514, MSE 0.0004, NCE 10.1326, Accuracy: 27994/40000 (70%)

Model Saved
Performing 149th epoch
Time for epoch pass 247.69167137145996
Train set: Average loss: CE 0.0108, MSE 0.0004, NCE 1.2754, Accuracy: 319587/320000 (100%)

Time for epoch pass 15.51503849029541
Valid set: Average loss: CE 1.2482, MSE 0.0004, NCE 10.1379, Accuracy: 27961/40000 (70%)

Performing 150t

Time for epoch pass 15.179227352142334
Valid set: Average loss: CE 1.2512, MSE 0.0004, NCE 10.1353, Accuracy: 27985/40000 (70%)

Performing 154th epoch
Time for epoch pass 249.50252223014832
Train set: Average loss: CE 0.0109, MSE 0.0004, NCE 1.2719, Accuracy: 319587/320000 (100%)

Time for epoch pass 14.783946990966797
Valid set: Average loss: CE 1.2553, MSE 0.0004, NCE 10.1387, Accuracy: 27967/40000 (70%)

Performing 155th epoch
Time for epoch pass 249.47020435333252
Train set: Average loss: CE 0.0108, MSE 0.0004, NCE 1.2711, Accuracy: 319584/320000 (100%)

Time for epoch pass 14.738401889801025
Valid set: Average loss: CE 1.2532, MSE 0.0004, NCE 10.1341, Accuracy: 27981/40000 (70%)

Performing 156th epoch
Time for epoch pass 245.405042886734
Train set: Average loss: CE 0.0108, MSE 0.0004, NCE 1.2726, Accuracy: 319624/320000 (100%)

Time for epoch pass 14.789666891098022
Valid set: Average loss: CE 1.2548, MSE 0.0004, NCE 10.1385, Accuracy: 27963/40000 (70%)

Performing 157th epoch
T

Time for epoch pass 15.086047887802124
Valid set: Average loss: CE 1.2553, MSE 0.0004, NCE 10.1344, Accuracy: 27966/40000 (70%)

Performing 161th epoch
Time for epoch pass 249.3591628074646
Train set: Average loss: CE 0.0108, MSE 0.0004, NCE 1.2686, Accuracy: 319562/320000 (100%)

Time for epoch pass 15.299856424331665
Valid set: Average loss: CE 1.2547, MSE 0.0004, NCE 10.1381, Accuracy: 27994/40000 (70%)

Performing 162th epoch
Time for epoch pass 249.6866021156311
Train set: Average loss: CE 0.0104, MSE 0.0004, NCE 1.2734, Accuracy: 319642/320000 (100%)

Time for epoch pass 15.154332160949707
Valid set: Average loss: CE 1.2556, MSE 0.0004, NCE 10.1336, Accuracy: 27943/40000 (70%)

Performing 163th epoch
Time for epoch pass 249.64747381210327
Train set: Average loss: CE 0.0108, MSE 0.0004, NCE 1.2765, Accuracy: 319568/320000 (100%)

Time for epoch pass 14.694899559020996
Valid set: Average loss: CE 1.2517, MSE 0.0004, NCE 10.1394, Accuracy: 27985/40000 (70%)

Performing 164th epoch
T

Time for epoch pass 14.508657932281494
Valid set: Average loss: CE 1.2584, MSE 0.0004, NCE 10.1351, Accuracy: 27965/40000 (70%)

Performing 168th epoch
Time for epoch pass 248.6991786956787
Train set: Average loss: CE 0.0106, MSE 0.0004, NCE 1.2720, Accuracy: 319598/320000 (100%)

Time for epoch pass 15.282838106155396
Valid set: Average loss: CE 1.2592, MSE 0.0004, NCE 10.1384, Accuracy: 27956/40000 (70%)

Performing 169th epoch
Time for epoch pass 249.48671007156372
Train set: Average loss: CE 0.0105, MSE 0.0004, NCE 1.2709, Accuracy: 319619/320000 (100%)

Time for epoch pass 14.606862545013428
Valid set: Average loss: CE 1.2601, MSE 0.0004, NCE 10.1333, Accuracy: 28005/40000 (70%)

Model Saved
Performing 170th epoch
Time for epoch pass 244.78806614875793
Train set: Average loss: CE 0.0103, MSE 0.0004, NCE 1.2700, Accuracy: 319636/320000 (100%)

Time for epoch pass 15.197943925857544
Valid set: Average loss: CE 1.2563, MSE 0.0004, NCE 10.1382, Accuracy: 27999/40000 (70%)

Performing 

Time for epoch pass 14.606724739074707
Valid set: Average loss: CE 1.2614, MSE 0.0004, NCE 10.1360, Accuracy: 27970/40000 (70%)

Performing 175th epoch
Time for epoch pass 247.67649269104004
Train set: Average loss: CE 0.0103, MSE 0.0004, NCE 1.2735, Accuracy: 319642/320000 (100%)

Time for epoch pass 14.560242176055908
Valid set: Average loss: CE 1.2601, MSE 0.0004, NCE 10.1369, Accuracy: 27958/40000 (70%)

Performing 176th epoch
Time for epoch pass 248.83107590675354
Train set: Average loss: CE 0.0102, MSE 0.0004, NCE 1.2704, Accuracy: 319643/320000 (100%)

Time for epoch pass 15.650834321975708
Valid set: Average loss: CE 1.2627, MSE 0.0004, NCE 10.1344, Accuracy: 27962/40000 (70%)

Performing 177th epoch
Time for epoch pass 249.30939030647278
Train set: Average loss: CE 0.0104, MSE 0.0004, NCE 1.2674, Accuracy: 319630/320000 (100%)

Time for epoch pass 15.077457189559937
Valid set: Average loss: CE 1.2592, MSE 0.0004, NCE 10.1390, Accuracy: 28000/40000 (70%)

Performing 178th epoch

Time for epoch pass 14.560480833053589
Valid set: Average loss: CE 1.2604, MSE 0.0004, NCE 10.1352, Accuracy: 27982/40000 (70%)

Performing 182th epoch
Time for epoch pass 248.10415148735046
Train set: Average loss: CE 0.0101, MSE 0.0004, NCE 1.2685, Accuracy: 319662/320000 (100%)

Time for epoch pass 14.879882335662842
Valid set: Average loss: CE 1.2560, MSE 0.0004, NCE 10.1378, Accuracy: 27977/40000 (70%)

Performing 183th epoch
Time for epoch pass 246.21627736091614
Train set: Average loss: CE 0.0101, MSE 0.0004, NCE 1.2677, Accuracy: 319658/320000 (100%)

Time for epoch pass 14.910043716430664
Valid set: Average loss: CE 1.2588, MSE 0.0004, NCE 10.1330, Accuracy: 28003/40000 (70%)

Performing 184th epoch
Time for epoch pass 245.2946183681488
Train set: Average loss: CE 0.0101, MSE 0.0004, NCE 1.2676, Accuracy: 319639/320000 (100%)

Time for epoch pass 15.076895236968994
Valid set: Average loss: CE 1.2595, MSE 0.0004, NCE 10.1390, Accuracy: 27968/40000 (70%)

Performing 185th epoch


Time for epoch pass 14.730249643325806
Valid set: Average loss: CE 1.2591, MSE 0.0004, NCE 10.1361, Accuracy: 27981/40000 (70%)

Performing 189th epoch
Time for epoch pass 249.6375744342804
Train set: Average loss: CE 0.0099, MSE 0.0004, NCE 1.2638, Accuracy: 319691/320000 (100%)

Time for epoch pass 14.450082778930664
Valid set: Average loss: CE 1.2609, MSE 0.0004, NCE 10.1379, Accuracy: 27997/40000 (70%)

Performing 190th epoch
Time for epoch pass 248.56309175491333
Train set: Average loss: CE 0.0098, MSE 0.0004, NCE 1.3722, Accuracy: 319679/320000 (100%)

Time for epoch pass 14.802314519882202
Valid set: Average loss: CE 1.2588, MSE 0.0004, NCE 10.1389, Accuracy: 28008/40000 (70%)

Performing 191th epoch
Time for epoch pass 247.75056147575378
Train set: Average loss: CE 0.0098, MSE 0.0004, NCE 1.3702, Accuracy: 319680/320000 (100%)

Time for epoch pass 15.40983772277832
Valid set: Average loss: CE 1.2607, MSE 0.0004, NCE 10.1377, Accuracy: 27991/40000 (70%)

Performing 192th epoch
T

Time for epoch pass 14.843295097351074
Valid set: Average loss: CE 1.2609, MSE 0.0004, NCE 10.1390, Accuracy: 27959/40000 (70%)

Performing 196th epoch
Time for epoch pass 249.3142650127411
Train set: Average loss: CE 0.0101, MSE 0.0004, NCE 1.3230, Accuracy: 319634/320000 (100%)

Time for epoch pass 15.368137121200562
Valid set: Average loss: CE 1.2585, MSE 0.0004, NCE 10.1365, Accuracy: 27993/40000 (70%)

Performing 197th epoch
Time for epoch pass 245.09973287582397
Train set: Average loss: CE 0.0100, MSE 0.0004, NCE 1.2849, Accuracy: 319659/320000 (100%)

Time for epoch pass 14.681443452835083
Valid set: Average loss: CE 1.2620, MSE 0.0004, NCE 10.1371, Accuracy: 27988/40000 (70%)

Performing 198th epoch
Time for epoch pass 248.62728905677795
Train set: Average loss: CE 0.0099, MSE 0.0004, NCE 1.2959, Accuracy: 319685/320000 (100%)

Time for epoch pass 15.503095149993896
Valid set: Average loss: CE 1.2621, MSE 0.0004, NCE 10.1363, Accuracy: 27982/40000 (70%)

Performing 199th epoch


Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send

KeyboardInterrupt: 

In [9]:
#save model again with different name
save_path = "weights/AlexNet_Decoupling_200.pth"
states = {
                'epoch': 200,
                'feature_net':Networks['feature'].state_dict(),
                'classifier_net':Networks['classifier'].state_dict(),
                'transformation_net':Networks['transformation'].state_dict(),
                'feature_optimizer': Optimizers['feature'].state_dict(),
                'classifier_optimizer': Optimizers['classifier'].state_dict(),
                'transformation_optimizer': Optimizers['transformation'].state_dict()
            }
torch.save(states, save_path)