In [1]:
import os
from GetDataLoaders import get_dataloaders, get_short_dataloaders
from architectures.AlexNetFeature import AlexNetFeature
from architectures.AlexNetClassifierModified import AlexNetClassifier
from architectures.LinearTransformationNormModified import LinearTransformationNorm, Normalize
from architectures.ContrastiveLoss import ContrastiveLoss
#from architectures.NCEAverage import NCEAverage
#from architectures.NCELoss import NCELoss
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch import nn
import time
from torch import optim

In [2]:
'''
# we skip the probs for now
gama = 2.0
with open(os.path.join("./PUprobs", 'prob.dat'), 'r') as file_input:
    train_prob_str = file_input.readlines()
    train_prob = [float(i_prob_str.rstrip('\n')) for i_prob_str in train_prob_str]
    print(len(train_prob)/4.0)
    train_weight = [1.0 if 0==i%4 else 1-train_prob[i]**gama for i in range(len(train_prob))]
'''

'\n# we skip the probs for now\ngama = 2.0\nwith open(os.path.join("./PUprobs", \'prob.dat\'), \'r\') as file_input:\n    train_prob_str = file_input.readlines()\n    train_prob = [float(i_prob_str.rstrip(\'\n\')) for i_prob_str in train_prob_str]\n    print(len(train_prob)/4.0)\n    train_weight = [1.0 if 0==i%4 else 1-train_prob[i]**gama for i in range(len(train_prob))]\n'

In [3]:
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
batch_size = 192
lr = 0.1
LUT_lr = [(90,0.01), (130,0.001), (190,0.0001), (210,0.00001), (230,0.0001), (245,0.00001)]
num_epochs = 245
momentum = 0.9
weight_decay = 5e-4
nesterov = True
Lambdas = {'CE':1.0, 'MSE':1.0, 'NCE':1.0}

loaders = get_dataloaders('imagenet', batch_size=batch_size, num_workers=2)
ndata_train = len(loaders['train_loader'].dataset)
ndata_valid = len(loaders['valid_loader'].dataset)
t = 0.07
m = 4096
gamma = 2

In [4]:
ndata_train, ndata_valid

(80000, 10000)

In [5]:
#from torch.optim.lr_scheduler import ExponentialLR

feature_net = AlexNetFeature().to(device)
classifier_net = AlexNetClassifier().to(device)
transformation_net = LinearTransformationNorm().to(device)

feature_optimizer = optim.SGD(feature_net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
classifier_optimizer = optim.SGD(classifier_net.parameters() ,lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
transformation_optimizer = optim.SGD(transformation_net.parameters() ,lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)

#Schedulers = {'feature':ExponentialLR(feature_optimizer, gamma=1e-6), 'classifier':ExponentialLR(classifier_optimizer, gamma=1e-6), 'transformation':ExponentialLR(transformation_optimizer, gamma=1e-6)}

Networks = {'feature':feature_net, 'classifier':classifier_net, 'transformation':transformation_net}
Optimizers = {'feature':feature_optimizer, 'classifier':classifier_optimizer, 'transformation':transformation_optimizer}

ContrastiveCriterion = ContrastiveLoss(tau=0.1)
#NCE = {'NCEAverage_train': NCEAverage(outputSize=ndata_train, K=m, T=t).cuda(), 'NCEAverage_valid': NCEAverage(outputSize=ndata_valid, K=m, T=t).cuda() , 'NCECriterion_train': NCELoss(ndata_train), 'NCECriterion_valid': NCELoss(ndata_valid)}
Criterions = {'CE': nn.CrossEntropyLoss(reduction='none'), 'MSE':nn.MSELoss() }

In [6]:
#for sample in loaders['train_loader']:
    #print(sample[0].size(), sample[1].size(), sample[2].size())

In [7]:
def adjust_learning_rates(epoch):
    # filter out the networks that are not trainable and that do
    # not have a learning rate Look Up Table (LUT_lr) in their optim_params
    lr = next((lr for (max_epoch, lr) in LUT_lr if max_epoch>epoch), LUT_lr[-1][1])
    for key in Optimizers:
        for g in Optimizers[key].param_groups:
            g['lr'] = lr

def AlexNetDecoupling(batch_idx, batch, PUWeights=None, train=True, accumulation_steps=4):
    data, targets, indices = batch
    
    if train is True:
        Optimizers['feature'].zero_grad()
        Optimizers['classifier'].zero_grad()
        Optimizers['transformation'].zero_grad()
    
    #to cuda
    
    data = data.to(device)
    targets = targets.to(device)
    indices = indices.to(device)

    
    #perform rotations & adjust data shape
    data_90 = torch.flip(torch.transpose(data,2,3),[2])
    data_180 = torch.flip(torch.flip(data,[2]),[3])
    data_270 = torch.transpose(torch.flip(data,[2]),2,3)
    data = torch.stack([data, data_90, data_180, data_270], dim=1)
    batch_size, rotations, channels, height, width = data.size()
    data = data.view([batch_size*rotations, channels, height, width])
    
    #debug for backward
    data.requires_grad = False 
    targets.requires_grad = False
    indices.requires_grad = False
    
    #collect features
    features = Networks['feature'](data)
    features_rot, features_invariance = torch.split(features, 2048, dim=1)
    
    #collect rotation prediction
    pred = Networks['classifier'](features_rot)
    
    
    #average features across 4 rotations
    features_invariance_instance = features_invariance[0::4,:] + features_invariance[1::4,:] + features_invariance[2::4,:] + features_invariance[3::4,:]
    features_invariance_instance = torch.mul(features_invariance_instance, 0.25) #fbar 192x2048
    
    #downsample to 128 & perform normalization of vector
    
    features_128_norm = Networks['transformation'](features_invariance_instance)#192x128
    features_128_norm_0 = Networks['transformation'](features_invariance[0::4,:])
    features_128_norm_90 = Networks['transformation'](features_invariance[1::4,:])
    features_128_norm_180 = Networks['transformation'](features_invariance[2::4,:])
    features_128_norm_270 = Networks['transformation'](features_invariance[3::4,:])
    
    features_128_list = [features_128_norm_0, features_128_norm_90, features_128_norm_180, features_128_norm_270]
    
    with torch.no_grad():
        #stack 192x2048 4 times to be  = 192x4x2048 = 768x2048
        features_invariance_instance_mean = torch.unsqueeze(features_invariance_instance,1).expand(-1,4,-1).clone()
        features_invariance_instance_mean = features_invariance_instance_mean.view(4*len(features_invariance_instance), 2048) #2048
    
    #calculate rotation loss ignore PU for now
    loss_cls_each = Criterions['CE'](pred, targets)
    loss_cls = torch.sum(loss_cls_each)/loss_cls_each.shape[0]
    
    #calculate rotation invariance by MSE
    with torch.no_grad():
        loss_mse = Criterions['MSE'](features_invariance, features_invariance_instance_mean)
    
    #calculate instance loss using NT-xent
    loss_nce = 0.0
    loss_nce = ContrastiveCriterion(features_128_norm_0, features_128_norm) + ContrastiveCriterion(features_128_norm_90, features_128_norm) + ContrastiveCriterion(features_128_norm_180, features_128_norm) + ContrastiveCriterion(features_128_norm_270, features_128_norm)
    
    loss_total = Lambdas['CE']*loss_cls + Lambdas['NCE']*loss_nce
    
    if train is True:
        loss_total.backward()
        Optimizers['feature'].step()
        Optimizers['classifier'].step()
        Optimizers['transformation'].step()
            
    
    #calculate classification accuracy
    pred = pred.argmax(dim=1, keepdim=True)
    correct = pred.eq(targets.view_as(pred)).sum().item()
    
    losses = {'ce':loss_cls.item(), 'mse':loss_mse.item(), 'nce':loss_nce.item(), 'correct':correct}
    
    return losses
    

In [8]:
def train(data_loader, epoch, log_interval=50):
    
    Networks['feature'].train()
    Networks['classifier'].train()
    Networks['transformation'].train()
    
    losses = {'ce':[], 'mse':[], 'nce':[]}
    correct = 0
    train_loss = np.Inf
    train_acc = 0.0
    
    start_time = time.time()
    for batch_idx, sample in enumerate(data_loader):
        
        lossesdict = AlexNetDecoupling(batch_idx, sample, train=True)
        
        losses['ce'].append(lossesdict['ce'])
        losses['mse'].append(lossesdict['mse'])
        losses['nce'].append(lossesdict['nce'])
        correct += lossesdict['correct']
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: CE {:.6f}, MSE {:.6f}, NT-Xent {:.6f}'.format(epoch, batch_idx*len(sample[0]), len(data_loader.dataset), 100. * batch_idx / len(data_loader), lossesdict['ce'], lossesdict['mse'], lossesdict['nce']))
    adjust_learning_rates(epoch)
    end_time = time.time()
    print("Time for epoch pass {}".format(end_time-start_time))
    train_loss = {'ce': float(np.mean(losses['ce'])), 'mse': float(np.mean(losses['mse'])), 'nce':float(np.mean(losses['nce']))}
    train_acc = correct / float(len(data_loader.dataset)*4)
    print('Train set: Average loss: CE {:.4f}, MSE {:.4f}, NT-Xent {:.4f}, Accuracy: {}/{} ({:.4f}%)\n'.format(train_loss['ce'], train_loss['mse'], train_loss['nce'], correct, len(data_loader.dataset)*4, 100.0*train_acc))
    return train_loss, train_acc

def validate(data_loader, epoch, log_interval=50):
    
    Networks['feature'].eval()
    Networks['classifier'].eval()
    Networks['transformation'].eval()
    
    losses = {'ce':[], 'mse':[], 'nce':[]}
    correct = 0
    valid_loss = np.Inf
    valid_acc = 0.0
    start_time = time.time()
    for batch_idx, sample in enumerate(data_loader): 
        with torch.no_grad():
            lossesdict = AlexNetDecoupling(batch_idx, sample, train=False)
        
        losses['ce'].append(lossesdict['ce'])
        losses['mse'].append(lossesdict['mse'])
        losses['nce'].append(lossesdict['nce'])
        correct += lossesdict['correct']
        if batch_idx % log_interval == 0:
            print('Valid Epoch: {} [{}/{} ({:.0f}%)]\tLoss: CE {:.6f}, MSE {:.6f}, NT-Xent {:.6f}'.format(epoch, batch_idx*len(sample[0]), len(data_loader.dataset), 100. * batch_idx / len(data_loader), lossesdict['ce'], lossesdict['mse'], lossesdict['nce']))
    
    end_time = time.time()
    print("Time for epoch pass {}".format(end_time-start_time))
    valid_loss = {'ce': float(np.mean(losses['ce'])), 'mse': float(np.mean(losses['mse'])), 'nce':float(np.mean(losses['nce']))}
    valid_acc = correct / float(len(data_loader.dataset)*4)
    print('Valid set: Average loss: CE {:.4f}, MSE {:.4f}, NT-Xent {:.4f}, Accuracy: {}/{} ({:.4f}%)\n'.format(valid_loss['ce'], valid_loss['mse'], valid_loss['nce'], correct, len(data_loader.dataset)*4, 100.0*valid_acc))
    return valid_loss, valid_acc

def run_main_loop(loaders, num_epochs):
    writer = SummaryWriter('./logs/AlexNet_Unsupervised_Decoupling_Contrastive')
    save_path = "weights/AlexNet_Decoupling_Contrastive.pth"
    best_acc = 0.0
    for epoch in range(num_epochs):
        print("Performing {}th epoch".format(epoch))
        train_loss, train_acc = train(loaders['train_loader'], epoch)
        val_loss, val_acc = validate(loaders['valid_loader'], epoch)
        
        writer.add_scalar('CELoss/train', train_loss['ce'], epoch)
        writer.add_scalar('MSELoss/train', train_loss['mse'], epoch)
        writer.add_scalar('NT-XENTLoss/train', train_loss['nce'], epoch)
        
        writer.add_scalar('CELoss/Valid', val_loss['ce'], epoch)
        writer.add_scalar('MSELoss/Valid', val_loss['mse'], epoch)
        writer.add_scalar('NT-XENTLoss/Valid', val_loss['nce'], epoch)
        
        writer.add_scalar('Accuracy/train', train_acc, epoch)
        writer.add_scalar('Accuracy/Valid', val_acc, epoch)
    
        writer.add_scalar('LR', Optimizers['feature'].param_groups[0]['lr'], epoch)
        
        if val_acc > best_acc :
            best_acc = val_acc
            #save model
            states = {
                'epoch': epoch + 1,
                'feature_net':Networks['feature'].state_dict(),
                'classifier_net':Networks['classifier'].state_dict(),
                'transformation_net':Networks['transformation'].state_dict(),
                'feature_optimizer': Optimizers['feature'].state_dict(),
                'classifier_optimizer': Optimizers['classifier'].state_dict(),
                'transformation_optimizer': Optimizers['transformation'].state_dict(),
                'best_accuracy': best_acc
            }
            torch.save(states, save_path)
            print('Model Saved')

In [9]:
run_main_loop(loaders, num_epochs)

Performing 0th epoch
Time for epoch pass 245.0652515888214
Train set: Average loss: CE 1.2866, MSE 0.0033, NT-Xent 15.5817, Accuracy: 129314/320000 (40.4106%)

Time for epoch pass 16.432915449142456
Valid set: Average loss: CE 1.2650, MSE 0.0003, NT-Xent 8.6648, Accuracy: 16791/40000 (41.9775%)

Model Saved
Performing 1th epoch
Time for epoch pass 246.45639777183533
Train set: Average loss: CE 1.2477, MSE 0.0004, NT-Xent 7.1947, Accuracy: 136797/320000 (42.7491%)

Time for epoch pass 13.805333614349365
Valid set: Average loss: CE 1.2358, MSE 0.0003, NT-Xent 7.5638, Accuracy: 17502/40000 (43.7550%)

Model Saved
Performing 2th epoch
Time for epoch pass 249.0085892677307
Train set: Average loss: CE 1.2348, MSE 0.0004, NT-Xent 7.0921, Accuracy: 139233/320000 (43.5103%)

Time for epoch pass 13.0483238697052
Valid set: Average loss: CE 1.2288, MSE 0.0004, NT-Xent 7.5557, Accuracy: 17620/40000 (44.0500%)

Model Saved
Performing 3th epoch
Time for epoch pass 249.5250265598297
Train set: Averag

Time for epoch pass 12.899427652359009
Valid set: Average loss: CE 1.2080, MSE 0.0007, NT-Xent 3.5117, Accuracy: 18204/40000 (45.5100%)

Model Saved
Performing 7th epoch
Time for epoch pass 250.29778814315796
Train set: Average loss: CE 1.2047, MSE 0.0006, NT-Xent 2.6987, Accuracy: 145716/320000 (45.5363%)

Time for epoch pass 14.285152435302734
Valid set: Average loss: CE 1.1944, MSE 0.0006, NT-Xent 3.1178, Accuracy: 18472/40000 (46.1800%)

Model Saved
Performing 8th epoch
Time for epoch pass 249.97857284545898
Train set: Average loss: CE 1.1953, MSE 0.0006, NT-Xent 2.4434, Accuracy: 147420/320000 (46.0688%)

Time for epoch pass 15.346858263015747
Valid set: Average loss: CE 1.1888, MSE 0.0006, NT-Xent 3.0428, Accuracy: 18696/40000 (46.7400%)

Model Saved
Performing 9th epoch
Time for epoch pass 277.01947236061096
Train set: Average loss: CE 1.1879, MSE 0.0006, NT-Xent 2.1785, Accuracy: 148850/320000 (46.5156%)

Time for epoch pass 13.606786012649536
Valid set: Average loss: CE 1.1886

Time for epoch pass 249.0567741394043
Train set: Average loss: CE 1.1531, MSE 0.0004, NT-Xent 1.4463, Accuracy: 156368/320000 (48.8650%)

Time for epoch pass 13.312241554260254
Valid set: Average loss: CE 1.1557, MSE 0.0004, NT-Xent 1.7938, Accuracy: 19519/40000 (48.7975%)

Model Saved
Performing 14th epoch
Time for epoch pass 245.6960391998291
Train set: Average loss: CE 1.1430, MSE 0.0004, NT-Xent 1.3170, Accuracy: 157818/320000 (49.3181%)

Time for epoch pass 14.918301105499268
Valid set: Average loss: CE 1.1512, MSE 0.0004, NT-Xent 1.7317, Accuracy: 19656/40000 (49.1400%)

Model Saved
Performing 15th epoch
Time for epoch pass 249.8580276966095
Train set: Average loss: CE 1.1330, MSE 0.0004, NT-Xent 1.2007, Accuracy: 160326/320000 (50.1019%)

Time for epoch pass 14.858738660812378
Valid set: Average loss: CE 1.1446, MSE 0.0004, NT-Xent 1.6516, Accuracy: 19847/40000 (49.6175%)

Model Saved
Performing 16th epoch
Time for epoch pass 250.17377829551697
Train set: Average loss: CE 1.1227

Time for epoch pass 248.54260110855103
Train set: Average loss: CE 1.0743, MSE 0.0004, NT-Xent 0.8800, Accuracy: 171527/320000 (53.6022%)

Time for epoch pass 15.178144216537476
Valid set: Average loss: CE 1.0880, MSE 0.0004, NT-Xent 1.2492, Accuracy: 21197/40000 (52.9925%)

Model Saved
Performing 21th epoch
Time for epoch pass 249.04976797103882
Train set: Average loss: CE 1.0624, MSE 0.0004, NT-Xent 0.8356, Accuracy: 173371/320000 (54.1784%)

Time for epoch pass 13.94554877281189
Valid set: Average loss: CE 1.0732, MSE 0.0004, NT-Xent 1.1788, Accuracy: 21329/40000 (53.3225%)

Model Saved
Performing 22th epoch
Time for epoch pass 249.40667462348938
Train set: Average loss: CE 1.0486, MSE 0.0004, NT-Xent 0.8126, Accuracy: 176018/320000 (55.0056%)

Time for epoch pass 13.950123310089111
Valid set: Average loss: CE 1.0626, MSE 0.0004, NT-Xent 1.1535, Accuracy: 21843/40000 (54.6075%)

Model Saved
Performing 23th epoch
Time for epoch pass 249.71903276443481
Train set: Average loss: CE 1.03

Time for epoch pass 250.9131190776825
Train set: Average loss: CE 0.9902, MSE 0.0004, NT-Xent 0.7100, Accuracy: 185921/320000 (58.1003%)

Time for epoch pass 13.803361892700195
Valid set: Average loss: CE 1.0136, MSE 0.0004, NT-Xent 0.9969, Accuracy: 22999/40000 (57.4975%)

Model Saved
Performing 27th epoch
Time for epoch pass 245.38855814933777
Train set: Average loss: CE 0.9740, MSE 0.0004, NT-Xent 0.6796, Accuracy: 188781/320000 (58.9941%)

Time for epoch pass 13.544970750808716
Valid set: Average loss: CE 1.0172, MSE 0.0004, NT-Xent 0.9646, Accuracy: 22857/40000 (57.1425%)

Performing 28th epoch
Time for epoch pass 250.32291984558105
Train set: Average loss: CE 0.9585, MSE 0.0004, NT-Xent 0.6593, Accuracy: 191194/320000 (59.7481%)

Time for epoch pass 14.03407096862793
Valid set: Average loss: CE 0.9909, MSE 0.0004, NT-Xent 0.9631, Accuracy: 23405/40000 (58.5125%)

Model Saved
Performing 29th epoch
Time for epoch pass 249.427237033844
Train set: Average loss: CE 0.9419, MSE 0.0004,

Time for epoch pass 245.39498829841614
Train set: Average loss: CE 0.8868, MSE 0.0004, NT-Xent 0.5585, Accuracy: 202770/320000 (63.3656%)

Time for epoch pass 13.369072437286377
Valid set: Average loss: CE 0.9464, MSE 0.0004, NT-Xent 0.8603, Accuracy: 24393/40000 (60.9825%)

Model Saved
Performing 34th epoch
Time for epoch pass 245.44848608970642
Train set: Average loss: CE 0.8715, MSE 0.0004, NT-Xent 0.5349, Accuracy: 205283/320000 (64.1509%)

Time for epoch pass 14.700320482254028
Valid set: Average loss: CE 0.9346, MSE 0.0004, NT-Xent 0.8070, Accuracy: 24475/40000 (61.1875%)

Model Saved
Performing 35th epoch
Time for epoch pass 245.06179428100586
Train set: Average loss: CE 0.8597, MSE 0.0004, NT-Xent 0.5244, Accuracy: 207606/320000 (64.8769%)

Time for epoch pass 14.638298749923706
Valid set: Average loss: CE 0.9173, MSE 0.0004, NT-Xent 0.8253, Accuracy: 24755/40000 (61.8875%)

Model Saved
Performing 36th epoch
Time for epoch pass 250.2929322719574
Train set: Average loss: CE 0.84

Model Saved
Performing 40th epoch
Time for epoch pass 248.7535412311554
Train set: Average loss: CE 0.7913, MSE 0.0004, NT-Xent 0.4702, Accuracy: 217208/320000 (67.8775%)

Time for epoch pass 15.08474087715149
Valid set: Average loss: CE 0.8796, MSE 0.0003, NT-Xent 0.7889, Accuracy: 25650/40000 (64.1250%)

Model Saved
Performing 41th epoch
Time for epoch pass 245.86799907684326
Train set: Average loss: CE 0.7808, MSE 0.0004, NT-Xent 0.4606, Accuracy: 219077/320000 (68.4616%)

Time for epoch pass 14.467799425125122
Valid set: Average loss: CE 0.8986, MSE 0.0004, NT-Xent 0.7934, Accuracy: 25296/40000 (63.2400%)

Performing 42th epoch
Time for epoch pass 250.52836060523987
Train set: Average loss: CE 0.7649, MSE 0.0004, NT-Xent 0.4530, Accuracy: 221357/320000 (69.1741%)

Time for epoch pass 13.13620924949646
Valid set: Average loss: CE 0.8908, MSE 0.0004, NT-Xent 0.8078, Accuracy: 25468/40000 (63.6700%)

Performing 43th epoch
Time for epoch pass 250.82507967948914
Train set: Average loss:

Time for epoch pass 249.7010633945465
Train set: Average loss: CE 0.7118, MSE 0.0003, NT-Xent 0.4162, Accuracy: 228771/320000 (71.4909%)

Time for epoch pass 14.76507830619812
Valid set: Average loss: CE 0.8690, MSE 0.0003, NT-Xent 0.6682, Accuracy: 26027/40000 (65.0675%)

Model Saved
Performing 47th epoch
Time for epoch pass 248.63002061843872
Train set: Average loss: CE 0.6977, MSE 0.0003, NT-Xent 0.4102, Accuracy: 231198/320000 (72.2494%)

Time for epoch pass 14.154652118682861
Valid set: Average loss: CE 0.8763, MSE 0.0003, NT-Xent 0.7000, Accuracy: 26045/40000 (65.1125%)

Model Saved
Performing 48th epoch
Time for epoch pass 249.03373861312866
Train set: Average loss: CE 0.6840, MSE 0.0003, NT-Xent 0.4023, Accuracy: 232770/320000 (72.7406%)

Time for epoch pass 14.814438819885254
Valid set: Average loss: CE 0.9183, MSE 0.0003, NT-Xent 0.7581, Accuracy: 25222/40000 (63.0550%)

Performing 49th epoch
Time for epoch pass 249.75106716156006
Train set: Average loss: CE 0.6678, MSE 0.000

Time for epoch pass 250.5673999786377
Train set: Average loss: CE 0.6133, MSE 0.0003, NT-Xent 0.3619, Accuracy: 242706/320000 (75.8456%)

Time for epoch pass 13.625788927078247
Valid set: Average loss: CE 0.8680, MSE 0.0003, NT-Xent 0.6023, Accuracy: 26333/40000 (65.8325%)

Performing 54th epoch
Time for epoch pass 245.26786470413208
Train set: Average loss: CE 0.5993, MSE 0.0003, NT-Xent 0.3919, Accuracy: 244717/320000 (76.4741%)

Time for epoch pass 13.629765272140503
Valid set: Average loss: CE 0.8797, MSE 0.0003, NT-Xent 0.6159, Accuracy: 26364/40000 (65.9100%)

Performing 55th epoch
Time for epoch pass 248.2932403087616
Train set: Average loss: CE 0.5839, MSE 0.0003, NT-Xent 0.3439, Accuracy: 246757/320000 (77.1116%)

Time for epoch pass 13.679483890533447
Valid set: Average loss: CE 0.9025, MSE 0.0003, NT-Xent 0.6373, Accuracy: 26190/40000 (65.4750%)

Performing 56th epoch
Time for epoch pass 248.27594113349915
Train set: Average loss: CE 0.5711, MSE 0.0003, NT-Xent 0.3393, Accur

Time for epoch pass 250.16544389724731
Train set: Average loss: CE 0.5119, MSE 0.0003, NT-Xent 0.3119, Accuracy: 256350/320000 (80.1094%)

Time for epoch pass 14.975419998168945
Valid set: Average loss: CE 0.8962, MSE 0.0003, NT-Xent 0.5345, Accuracy: 26377/40000 (65.9425%)

Performing 61th epoch
Time for epoch pass 250.25404596328735
Train set: Average loss: CE 0.4937, MSE 0.0003, NT-Xent 0.3097, Accuracy: 259219/320000 (81.0059%)

Time for epoch pass 13.537293195724487
Valid set: Average loss: CE 0.9132, MSE 0.0003, NT-Xent 0.5478, Accuracy: 26472/40000 (66.1800%)

Performing 62th epoch
Time for epoch pass 250.3390028476715
Train set: Average loss: CE 0.4811, MSE 0.0003, NT-Xent 0.3069, Accuracy: 260590/320000 (81.4344%)

Time for epoch pass 13.978749990463257
Valid set: Average loss: CE 0.9131, MSE 0.0003, NT-Xent 0.5330, Accuracy: 26443/40000 (66.1075%)

Performing 63th epoch
Time for epoch pass 248.88970017433167
Train set: Average loss: CE 0.4656, MSE 0.0003, NT-Xent 0.3048, Accu

Time for epoch pass 14.063148975372314
Valid set: Average loss: CE 0.9441, MSE 0.0002, NT-Xent 0.4939, Accuracy: 26524/40000 (66.3100%)

Performing 67th epoch
Time for epoch pass 249.55094242095947
Train set: Average loss: CE 0.4062, MSE 0.0003, NT-Xent 0.2921, Accuracy: 270321/320000 (84.4753%)

Time for epoch pass 14.124279975891113
Valid set: Average loss: CE 0.9993, MSE 0.0003, NT-Xent 0.5081, Accuracy: 26212/40000 (65.5300%)

Performing 68th epoch
Time for epoch pass 244.9351167678833
Train set: Average loss: CE 0.3922, MSE 0.0003, NT-Xent 0.2893, Accuracy: 272275/320000 (85.0859%)

Time for epoch pass 13.044202327728271
Valid set: Average loss: CE 0.9998, MSE 0.0002, NT-Xent 0.5284, Accuracy: 26256/40000 (65.6400%)

Performing 69th epoch
Time for epoch pass 255.12275099754333
Train set: Average loss: CE 0.3803, MSE 0.0002, NT-Xent 0.2868, Accuracy: 273819/320000 (85.5684%)

Time for epoch pass 14.280521154403687
Valid set: Average loss: CE 0.9879, MSE 0.0002, NT-Xent 0.4768, Accu

Time for epoch pass 248.2335913181305
Train set: Average loss: CE 0.3233, MSE 0.0002, NT-Xent 0.2760, Accuracy: 281314/320000 (87.9106%)

Time for epoch pass 13.618389129638672
Valid set: Average loss: CE 1.0447, MSE 0.0002, NT-Xent 0.5084, Accuracy: 26539/40000 (66.3475%)

Performing 74th epoch
Time for epoch pass 249.42920112609863
Train set: Average loss: CE 0.3132, MSE 0.0002, NT-Xent 0.2729, Accuracy: 282345/320000 (88.2328%)

Time for epoch pass 13.061443567276001
Valid set: Average loss: CE 1.0520, MSE 0.0002, NT-Xent 0.4716, Accuracy: 26378/40000 (65.9450%)

Performing 75th epoch
Time for epoch pass 249.46642637252808
Train set: Average loss: CE 0.3038, MSE 0.0002, NT-Xent 0.2700, Accuracy: 283742/320000 (88.6694%)

Time for epoch pass 14.456207990646362
Valid set: Average loss: CE 1.0813, MSE 0.0002, NT-Xent 0.4678, Accuracy: 26335/40000 (65.8375%)

Performing 76th epoch
Time for epoch pass 249.41836786270142
Train set: Average loss: CE 0.2945, MSE 0.0002, NT-Xent 0.2653, Accu

Time for epoch pass 249.1072292327881
Train set: Average loss: CE 0.2541, MSE 0.0002, NT-Xent 0.2489, Accuracy: 289782/320000 (90.5569%)

Time for epoch pass 13.813164472579956
Valid set: Average loss: CE 1.1178, MSE 0.0002, NT-Xent 0.5032, Accuracy: 25934/40000 (64.8350%)

Performing 81th epoch
Time for epoch pass 245.81565141677856
Train set: Average loss: CE 0.2464, MSE 0.0002, NT-Xent 0.2468, Accuracy: 290757/320000 (90.8616%)

Time for epoch pass 13.626309633255005
Valid set: Average loss: CE 1.1311, MSE 0.0002, NT-Xent 0.4462, Accuracy: 26267/40000 (65.6675%)

Performing 82th epoch
Time for epoch pass 246.1366035938263
Train set: Average loss: CE 0.2371, MSE 0.0002, NT-Xent 0.2441, Accuracy: 292081/320000 (91.2753%)

Time for epoch pass 14.99554991722107
Valid set: Average loss: CE 1.1564, MSE 0.0002, NT-Xent 0.4510, Accuracy: 26319/40000 (65.7975%)

Performing 83th epoch
Time for epoch pass 249.18159246444702
Train set: Average loss: CE 0.2285, MSE 0.0002, NT-Xent 0.2414, Accura

Time for epoch pass 13.779422283172607
Valid set: Average loss: CE 1.1663, MSE 0.0002, NT-Xent 0.4338, Accuracy: 26554/40000 (66.3850%)

Performing 87th epoch
Time for epoch pass 244.06331491470337
Train set: Average loss: CE 0.2021, MSE 0.0002, NT-Xent 0.2287, Accuracy: 296600/320000 (92.6875%)

Time for epoch pass 14.843094825744629
Valid set: Average loss: CE 1.1879, MSE 0.0002, NT-Xent 0.4264, Accuracy: 26263/40000 (65.6575%)

Performing 88th epoch
Time for epoch pass 247.76409673690796
Train set: Average loss: CE 0.1941, MSE 0.0002, NT-Xent 0.2274, Accuracy: 297329/320000 (92.9153%)

Time for epoch pass 13.679208517074585
Valid set: Average loss: CE 1.2239, MSE 0.0002, NT-Xent 0.4060, Accuracy: 26264/40000 (65.6600%)

Performing 89th epoch
Time for epoch pass 248.49028515815735
Train set: Average loss: CE 0.1916, MSE 0.0002, NT-Xent 0.2254, Accuracy: 297762/320000 (93.0506%)

Time for epoch pass 13.302693843841553
Valid set: Average loss: CE 1.2048, MSE 0.0001, NT-Xent 0.3873, Acc

Time for epoch pass 249.92821288108826
Train set: Average loss: CE 0.0679, MSE 0.0001, NT-Xent 0.2105, Accuracy: 313966/320000 (98.1144%)

Time for epoch pass 14.850669622421265
Valid set: Average loss: CE 1.0989, MSE 0.0001, NT-Xent 0.3394, Accuracy: 27622/40000 (69.0550%)

Model Saved
Performing 94th epoch
Time for epoch pass 249.96076035499573
Train set: Average loss: CE 0.0606, MSE 0.0001, NT-Xent 0.2096, Accuracy: 314816/320000 (98.3800%)

Time for epoch pass 13.068204164505005
Valid set: Average loss: CE 1.1131, MSE 0.0001, NT-Xent 0.3386, Accuracy: 27649/40000 (69.1225%)

Model Saved
Performing 95th epoch
Time for epoch pass 244.3839259147644
Train set: Average loss: CE 0.0566, MSE 0.0001, NT-Xent 0.2088, Accuracy: 315256/320000 (98.5175%)

Time for epoch pass 13.748661994934082
Valid set: Average loss: CE 1.1228, MSE 0.0001, NT-Xent 0.3382, Accuracy: 27708/40000 (69.2700%)

Model Saved
Performing 96th epoch
Time for epoch pass 248.9021441936493
Train set: Average loss: CE 0.051

Time for epoch pass 249.6324167251587
Train set: Average loss: CE 0.0394, MSE 0.0001, NT-Xent 0.2058, Accuracy: 316887/320000 (99.0272%)

Time for epoch pass 13.740588426589966
Valid set: Average loss: CE 1.1754, MSE 0.0001, NT-Xent 0.3311, Accuracy: 27683/40000 (69.2075%)

Performing 101th epoch
Time for epoch pass 249.82156038284302
Train set: Average loss: CE 0.0381, MSE 0.0001, NT-Xent 0.2055, Accuracy: 317042/320000 (99.0756%)

Time for epoch pass 13.411169528961182
Valid set: Average loss: CE 1.1831, MSE 0.0001, NT-Xent 0.3304, Accuracy: 27672/40000 (69.1800%)

Performing 102th epoch
Time for epoch pass 249.7460823059082
Train set: Average loss: CE 0.0360, MSE 0.0001, NT-Xent 0.2048, Accuracy: 317292/320000 (99.1538%)

Time for epoch pass 15.36743950843811
Valid set: Average loss: CE 1.1858, MSE 0.0001, NT-Xent 0.3290, Accuracy: 27660/40000 (69.1500%)

Performing 103th epoch
Time for epoch pass 249.97894287109375
Train set: Average loss: CE 0.0341, MSE 0.0001, NT-Xent 0.2039, Acc

Time for epoch pass 14.716797590255737
Valid set: Average loss: CE 1.2241, MSE 0.0001, NT-Xent 0.3267, Accuracy: 27680/40000 (69.2000%)

Performing 107th epoch
Time for epoch pass 245.3345024585724
Train set: Average loss: CE 0.0289, MSE 0.0001, NT-Xent 0.2024, Accuracy: 317999/320000 (99.3747%)

Time for epoch pass 13.018553018569946
Valid set: Average loss: CE 1.2275, MSE 0.0001, NT-Xent 0.3251, Accuracy: 27682/40000 (69.2050%)

Performing 108th epoch
Time for epoch pass 249.14860844612122
Train set: Average loss: CE 0.0286, MSE 0.0001, NT-Xent 0.2019, Accuracy: 317913/320000 (99.3478%)

Time for epoch pass 14.087097406387329
Valid set: Average loss: CE 1.2270, MSE 0.0001, NT-Xent 0.3250, Accuracy: 27661/40000 (69.1525%)

Performing 109th epoch
Time for epoch pass 245.18302130699158
Train set: Average loss: CE 0.0271, MSE 0.0001, NT-Xent 0.2018, Accuracy: 318130/320000 (99.4156%)

Time for epoch pass 13.392793655395508
Valid set: Average loss: CE 1.2397, MSE 0.0001, NT-Xent 0.3246, A

Time for epoch pass 249.4979808330536
Train set: Average loss: CE 0.0240, MSE 0.0001, NT-Xent 0.2005, Accuracy: 318404/320000 (99.5012%)

Time for epoch pass 13.653282165527344
Valid set: Average loss: CE 1.2655, MSE 0.0001, NT-Xent 0.3222, Accuracy: 27723/40000 (69.3075%)

Model Saved
Performing 114th epoch
Time for epoch pass 247.87795162200928
Train set: Average loss: CE 0.0234, MSE 0.0001, NT-Xent 0.2006, Accuracy: 318510/320000 (99.5344%)

Time for epoch pass 13.402904987335205
Valid set: Average loss: CE 1.2676, MSE 0.0001, NT-Xent 0.3210, Accuracy: 27664/40000 (69.1600%)

Performing 115th epoch
Time for epoch pass 248.63235807418823
Train set: Average loss: CE 0.0230, MSE 0.0001, NT-Xent 0.1999, Accuracy: 318479/320000 (99.5247%)

Time for epoch pass 14.05826449394226
Valid set: Average loss: CE 1.2697, MSE 0.0001, NT-Xent 0.3209, Accuracy: 27671/40000 (69.1775%)

Performing 116th epoch
Time for epoch pass 249.42172145843506
Train set: Average loss: CE 0.0215, MSE 0.0001, NT-Xen

Time for epoch pass 250.36848187446594
Train set: Average loss: CE 0.0193, MSE 0.0001, NT-Xent 0.1987, Accuracy: 318794/320000 (99.6231%)

Time for epoch pass 13.252129793167114
Valid set: Average loss: CE 1.2888, MSE 0.0001, NT-Xent 0.3177, Accuracy: 27708/40000 (69.2700%)

Performing 121th epoch
Time for epoch pass 250.14043641090393
Train set: Average loss: CE 0.0190, MSE 0.0001, NT-Xent 0.1981, Accuracy: 318871/320000 (99.6472%)

Time for epoch pass 13.799894094467163
Valid set: Average loss: CE 1.2909, MSE 0.0001, NT-Xent 0.3165, Accuracy: 27642/40000 (69.1050%)

Performing 122th epoch
Time for epoch pass 244.94880270957947
Train set: Average loss: CE 0.0191, MSE 0.0001, NT-Xent 0.1975, Accuracy: 318793/320000 (99.6228%)

Time for epoch pass 12.950609922409058
Valid set: Average loss: CE 1.3014, MSE 0.0001, NT-Xent 0.3148, Accuracy: 27656/40000 (69.1400%)

Performing 123th epoch
Time for epoch pass 246.6922881603241
Train set: Average loss: CE 0.0181, MSE 0.0001, NT-Xent 0.1974, A

Time for epoch pass 249.8655252456665
Train set: Average loss: CE 0.0170, MSE 0.0001, NT-Xent 0.1960, Accuracy: 318981/320000 (99.6816%)

Time for epoch pass 13.959014415740967
Valid set: Average loss: CE 1.3134, MSE 0.0001, NT-Xent 0.3114, Accuracy: 27628/40000 (69.0700%)

Performing 127th epoch
Time for epoch pass 249.98911142349243
Train set: Average loss: CE 0.0165, MSE 0.0001, NT-Xent 0.1954, Accuracy: 319083/320000 (99.7134%)

Time for epoch pass 15.021501064300537
Valid set: Average loss: CE 1.3132, MSE 0.0001, NT-Xent 0.3114, Accuracy: 27731/40000 (69.3275%)

Model Saved
Performing 128th epoch
Time for epoch pass 249.9748499393463
Train set: Average loss: CE 0.0160, MSE 0.0001, NT-Xent 0.1955, Accuracy: 319111/320000 (99.7222%)

Time for epoch pass 13.153163194656372
Valid set: Average loss: CE 1.3201, MSE 0.0001, NT-Xent 0.3101, Accuracy: 27669/40000 (69.1725%)

Performing 129th epoch
Time for epoch pass 250.36520504951477
Train set: Average loss: CE 0.0160, MSE 0.0001, NT-Xen

Time for epoch pass 249.42291021347046
Train set: Average loss: CE 0.0144, MSE 0.0001, NT-Xent 0.1935, Accuracy: 319216/320000 (99.7550%)

Time for epoch pass 13.54158616065979
Valid set: Average loss: CE 1.3209, MSE 0.0001, NT-Xent 0.3086, Accuracy: 27660/40000 (69.1500%)

Performing 134th epoch
Time for epoch pass 249.6006634235382
Train set: Average loss: CE 0.0142, MSE 0.0001, NT-Xent 0.1935, Accuracy: 319245/320000 (99.7641%)

Time for epoch pass 13.51943039894104
Valid set: Average loss: CE 1.3169, MSE 0.0001, NT-Xent 0.3075, Accuracy: 27680/40000 (69.2000%)

Performing 135th epoch
Time for epoch pass 249.5199432373047
Train set: Average loss: CE 0.0142, MSE 0.0001, NT-Xent 0.1936, Accuracy: 319247/320000 (99.7647%)

Time for epoch pass 13.680839538574219
Valid set: Average loss: CE 1.3150, MSE 0.0001, NT-Xent 0.3091, Accuracy: 27660/40000 (69.1500%)

Performing 136th epoch
Time for epoch pass 244.5707414150238
Train set: Average loss: CE 0.0145, MSE 0.0001, NT-Xent 0.1935, Accur

Time for epoch pass 14.588085889816284
Valid set: Average loss: CE 1.3198, MSE 0.0001, NT-Xent 0.3072, Accuracy: 27692/40000 (69.2300%)

Performing 140th epoch
Time for epoch pass 248.70902037620544
Train set: Average loss: CE 0.0141, MSE 0.0001, NT-Xent 0.1934, Accuracy: 319280/320000 (99.7750%)

Time for epoch pass 13.909303426742554
Valid set: Average loss: CE 1.3223, MSE 0.0001, NT-Xent 0.3077, Accuracy: 27644/40000 (69.1100%)

Performing 141th epoch
Time for epoch pass 249.04858827590942
Train set: Average loss: CE 0.0135, MSE 0.0001, NT-Xent 0.1933, Accuracy: 319372/320000 (99.8038%)

Time for epoch pass 13.42787480354309
Valid set: Average loss: CE 1.3157, MSE 0.0001, NT-Xent 0.3065, Accuracy: 27676/40000 (69.1900%)

Performing 142th epoch
Time for epoch pass 249.69112372398376
Train set: Average loss: CE 0.0139, MSE 0.0001, NT-Xent 0.1931, Accuracy: 319306/320000 (99.7831%)

Time for epoch pass 15.124470472335815
Valid set: Average loss: CE 1.3166, MSE 0.0001, NT-Xent 0.3066, A

Time for epoch pass 249.58897519111633
Train set: Average loss: CE 0.0136, MSE 0.0001, NT-Xent 0.1928, Accuracy: 319328/320000 (99.7900%)

Time for epoch pass 14.038239002227783
Valid set: Average loss: CE 1.3262, MSE 0.0001, NT-Xent 0.3072, Accuracy: 27704/40000 (69.2600%)

Performing 147th epoch
Time for epoch pass 248.0299835205078
Train set: Average loss: CE 0.0135, MSE 0.0001, NT-Xent 0.1929, Accuracy: 319324/320000 (99.7888%)

Time for epoch pass 14.279763221740723
Valid set: Average loss: CE 1.3211, MSE 0.0001, NT-Xent 0.3067, Accuracy: 27693/40000 (69.2325%)

Performing 148th epoch
Time for epoch pass 247.89854645729065
Train set: Average loss: CE 0.0137, MSE 0.0001, NT-Xent 0.1925, Accuracy: 319299/320000 (99.7809%)

Time for epoch pass 13.656662464141846
Valid set: Average loss: CE 1.3242, MSE 0.0001, NT-Xent 0.3067, Accuracy: 27696/40000 (69.2400%)

Performing 149th epoch
Time for epoch pass 247.97035670280457
Train set: Average loss: CE 0.0138, MSE 0.0001, NT-Xent 0.1925, A

Time for epoch pass 250.23069286346436
Train set: Average loss: CE 0.0132, MSE 0.0001, NT-Xent 0.1925, Accuracy: 319351/320000 (99.7972%)

Time for epoch pass 13.195565700531006
Valid set: Average loss: CE 1.3240, MSE 0.0001, NT-Xent 0.3065, Accuracy: 27695/40000 (69.2375%)

Performing 154th epoch
Time for epoch pass 249.43707513809204
Train set: Average loss: CE 0.0134, MSE 0.0001, NT-Xent 0.1928, Accuracy: 319331/320000 (99.7909%)

Time for epoch pass 14.76328706741333
Valid set: Average loss: CE 1.3232, MSE 0.0001, NT-Xent 0.3059, Accuracy: 27696/40000 (69.2400%)

Performing 155th epoch
Time for epoch pass 247.93387579917908
Train set: Average loss: CE 0.0130, MSE 0.0001, NT-Xent 0.1924, Accuracy: 319441/320000 (99.8253%)

Time for epoch pass 13.716233253479004
Valid set: Average loss: CE 1.3251, MSE 0.0001, NT-Xent 0.3065, Accuracy: 27687/40000 (69.2175%)

Performing 156th epoch
Time for epoch pass 248.4121232032776
Train set: Average loss: CE 0.0133, MSE 0.0001, NT-Xent 0.1924, Ac

Time for epoch pass 13.562434434890747
Valid set: Average loss: CE 1.3220, MSE 0.0001, NT-Xent 0.3054, Accuracy: 27694/40000 (69.2350%)

Performing 160th epoch
Time for epoch pass 250.26774430274963
Train set: Average loss: CE 0.0134, MSE 0.0001, NT-Xent 0.1920, Accuracy: 319345/320000 (99.7953%)

Time for epoch pass 13.544274806976318
Valid set: Average loss: CE 1.3293, MSE 0.0001, NT-Xent 0.3054, Accuracy: 27716/40000 (69.2900%)

Performing 161th epoch
Time for epoch pass 250.48159909248352
Train set: Average loss: CE 0.0129, MSE 0.0001, NT-Xent 0.1921, Accuracy: 319434/320000 (99.8231%)

Time for epoch pass 13.286922931671143
Valid set: Average loss: CE 1.3243, MSE 0.0001, NT-Xent 0.3049, Accuracy: 27679/40000 (69.1975%)

Performing 162th epoch
Time for epoch pass 249.63604855537415
Train set: Average loss: CE 0.0130, MSE 0.0001, NT-Xent 0.1919, Accuracy: 319406/320000 (99.8144%)

Time for epoch pass 13.85911512374878
Valid set: Average loss: CE 1.3240, MSE 0.0001, NT-Xent 0.3058, A

Time for epoch pass 249.44499564170837
Train set: Average loss: CE 0.0124, MSE 0.0001, NT-Xent 0.1918, Accuracy: 319468/320000 (99.8337%)

Time for epoch pass 13.800668478012085
Valid set: Average loss: CE 1.3278, MSE 0.0001, NT-Xent 0.3060, Accuracy: 27712/40000 (69.2800%)

Performing 167th epoch
Time for epoch pass 250.05238318443298
Train set: Average loss: CE 0.0129, MSE 0.0001, NT-Xent 0.1918, Accuracy: 319367/320000 (99.8022%)

Time for epoch pass 13.763346195220947
Valid set: Average loss: CE 1.3281, MSE 0.0001, NT-Xent 0.3054, Accuracy: 27708/40000 (69.2700%)

Performing 168th epoch
Time for epoch pass 250.18766570091248
Train set: Average loss: CE 0.0128, MSE 0.0001, NT-Xent 0.1914, Accuracy: 319391/320000 (99.8097%)

Time for epoch pass 13.729069232940674
Valid set: Average loss: CE 1.3322, MSE 0.0001, NT-Xent 0.3041, Accuracy: 27687/40000 (69.2175%)

Performing 169th epoch
Time for epoch pass 250.09694600105286
Train set: Average loss: CE 0.0127, MSE 0.0001, NT-Xent 0.1914, 

Time for epoch pass 249.06423234939575
Train set: Average loss: CE 0.0127, MSE 0.0001, NT-Xent 0.1914, Accuracy: 319413/320000 (99.8166%)

Time for epoch pass 13.780777215957642
Valid set: Average loss: CE 1.3332, MSE 0.0001, NT-Xent 0.3043, Accuracy: 27712/40000 (69.2800%)

Performing 174th epoch
Time for epoch pass 248.8522505760193
Train set: Average loss: CE 0.0125, MSE 0.0001, NT-Xent 0.1913, Accuracy: 319444/320000 (99.8263%)

Time for epoch pass 14.91546630859375
Valid set: Average loss: CE 1.3324, MSE 0.0001, NT-Xent 0.3045, Accuracy: 27678/40000 (69.1950%)

Performing 175th epoch
Time for epoch pass 244.5576844215393
Train set: Average loss: CE 0.0125, MSE 0.0001, NT-Xent 0.1915, Accuracy: 319405/320000 (99.8141%)

Time for epoch pass 12.99723219871521
Valid set: Average loss: CE 1.3323, MSE 0.0001, NT-Xent 0.3035, Accuracy: 27645/40000 (69.1125%)

Performing 176th epoch
Time for epoch pass 244.5243787765503
Train set: Average loss: CE 0.0125, MSE 0.0001, NT-Xent 0.1914, Accur

Time for epoch pass 247.782888174057
Train set: Average loss: CE 0.0125, MSE 0.0001, NT-Xent 0.1914, Accuracy: 319405/320000 (99.8141%)

Time for epoch pass 13.714428424835205
Valid set: Average loss: CE 1.3288, MSE 0.0001, NT-Xent 0.3040, Accuracy: 27698/40000 (69.2450%)

Performing 180th epoch
Time for epoch pass 248.77330231666565
Train set: Average loss: CE 0.0126, MSE 0.0001, NT-Xent 0.1909, Accuracy: 319425/320000 (99.8203%)

Time for epoch pass 14.557491064071655
Valid set: Average loss: CE 1.3300, MSE 0.0001, NT-Xent 0.3037, Accuracy: 27639/40000 (69.0975%)

Performing 181th epoch
Time for epoch pass 249.21757221221924
Train set: Average loss: CE 0.0124, MSE 0.0001, NT-Xent 0.1913, Accuracy: 319428/320000 (99.8212%)

Time for epoch pass 14.910550832748413
Valid set: Average loss: CE 1.3279, MSE 0.0001, NT-Xent 0.3046, Accuracy: 27688/40000 (69.2200%)

Performing 182th epoch
Time for epoch pass 249.52028250694275
Train set: Average loss: CE 0.0122, MSE 0.0001, NT-Xent 0.1909, Ac

Time for epoch pass 249.74819350242615
Train set: Average loss: CE 0.0124, MSE 0.0001, NT-Xent 0.1908, Accuracy: 319399/320000 (99.8122%)

Time for epoch pass 13.892106056213379
Valid set: Average loss: CE 1.3272, MSE 0.0001, NT-Xent 0.3045, Accuracy: 27682/40000 (69.2050%)

Performing 187th epoch
Time for epoch pass 248.0968782901764
Train set: Average loss: CE 0.0123, MSE 0.0001, NT-Xent 0.1910, Accuracy: 319439/320000 (99.8247%)

Time for epoch pass 13.454894542694092
Valid set: Average loss: CE 1.3285, MSE 0.0001, NT-Xent 0.3030, Accuracy: 27664/40000 (69.1600%)

Performing 188th epoch
Time for epoch pass 247.97842693328857
Train set: Average loss: CE 0.0123, MSE 0.0001, NT-Xent 0.1909, Accuracy: 319396/320000 (99.8113%)

Time for epoch pass 14.337158441543579
Valid set: Average loss: CE 1.3296, MSE 0.0001, NT-Xent 0.3025, Accuracy: 27681/40000 (69.2025%)

Performing 189th epoch
Time for epoch pass 248.55598902702332
Train set: Average loss: CE 0.0120, MSE 0.0001, NT-Xent 0.1907, A

Time for epoch pass 14.530308723449707
Valid set: Average loss: CE 1.3295, MSE 0.0001, NT-Xent 0.3027, Accuracy: 27702/40000 (69.2550%)

Performing 193th epoch
Time for epoch pass 250.24615693092346
Train set: Average loss: CE 0.0121, MSE 0.0001, NT-Xent 0.1908, Accuracy: 319474/320000 (99.8356%)

Time for epoch pass 13.232393264770508
Valid set: Average loss: CE 1.3336, MSE 0.0001, NT-Xent 0.3031, Accuracy: 27698/40000 (69.2450%)

Performing 194th epoch
Time for epoch pass 249.91026782989502
Train set: Average loss: CE 0.0119, MSE 0.0001, NT-Xent 0.1905, Accuracy: 319458/320000 (99.8306%)

Time for epoch pass 13.56683349609375
Valid set: Average loss: CE 1.3337, MSE 0.0001, NT-Xent 0.3031, Accuracy: 27661/40000 (69.1525%)

Performing 195th epoch
Time for epoch pass 248.61250853538513
Train set: Average loss: CE 0.0120, MSE 0.0001, NT-Xent 0.1906, Accuracy: 319487/320000 (99.8397%)

Time for epoch pass 13.652667045593262
Valid set: Average loss: CE 1.3336, MSE 0.0001, NT-Xent 0.3024, A

Time for epoch pass 249.95415115356445
Train set: Average loss: CE 0.0123, MSE 0.0001, NT-Xent 0.1906, Accuracy: 319447/320000 (99.8272%)

Time for epoch pass 14.400623559951782
Valid set: Average loss: CE 1.3252, MSE 0.0001, NT-Xent 0.3039, Accuracy: 27677/40000 (69.1925%)

Performing 200th epoch
Time for epoch pass 249.9066662788391
Train set: Average loss: CE 0.0123, MSE 0.0001, NT-Xent 0.1907, Accuracy: 319417/320000 (99.8178%)

Time for epoch pass 12.946860790252686
Valid set: Average loss: CE 1.3282, MSE 0.0001, NT-Xent 0.3035, Accuracy: 27694/40000 (69.2350%)

Performing 201th epoch


KeyboardInterrupt: 

In [10]:
save_path = "weights/AlexNet_Decoupling_Contrastive200.pth"
states = {
                'feature_net':Networks['feature'].state_dict(),
                'classifier_net':Networks['classifier'].state_dict(),
                'transformation_net':Networks['transformation'].state_dict(),
                'feature_optimizer': Optimizers['feature'].state_dict(),
                'classifier_optimizer': Optimizers['classifier'].state_dict(),
                'transformation_optimizer': Optimizers['transformation'].state_dict(), 
            }
torch.save(states, save_path)