In [1]:
import os
from GetDataLoaders import get_dataloaders, get_short_dataloaders
from architectures.AlexNetFeatureModified import AlexNetFeature
from architectures.NonLinearClassifier import Classifier
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch import nn
import time
from torch import optim
from torch.nn import functional as F
from tqdm import tqdm

In [2]:
'''
# we skip the probs for now
gama = 2.0
with open(os.path.join("./PUprobs", 'prob.dat'), 'r') as file_input:
    train_prob_str = file_input.readlines()
    train_prob = [float(i_prob_str.rstrip('\n')) for i_prob_str in train_prob_str]
    print(len(train_prob)/4.0)
    train_weight = [1.0 if 0==i%4 else 1-train_prob[i]**gama for i in range(len(train_prob))]
'''

'\n# we skip the probs for now\ngama = 2.0\nwith open(os.path.join("./PUprobs", \'prob.dat\'), \'r\') as file_input:\n    train_prob_str = file_input.readlines()\n    train_prob = [float(i_prob_str.rstrip(\'\n\')) for i_prob_str in train_prob_str]\n    print(len(train_prob)/4.0)\n    train_weight = [1.0 if 0==i%4 else 1-train_prob[i]**gama for i in range(len(train_prob))]\n'

In [3]:
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
batch_size = 192
lr = 1e-3
LUT_lr = [(5, 0.1),(25, 0.02),(45, 0.0004),(65,0.00008)]
num_epochs = 200
momentum = 0.9
weight_decay = 1e-6
nesterov = True
num_classes = 200
loaders = get_dataloaders('imagenet', batch_size=batch_size, num_workers=2, unsupervised=False, simclr=False)

In [4]:
feature_net = AlexNetFeature().to(device)
#load pretrained weights in feature_net
state_dict = torch.load("weights/AlexNet_Decoupling_Contrastive_SimCLR_Features.pth")
feature_net.load_state_dict(state_dict['featurenet'], strict=False)

feature_net.eval()
for param in feature_net.parameters():
    param.requires_grad = False


classifier_net = Classifier().to(device)
classifier_optimizer = optim.Adam(classifier_net.parameters(), lr=lr, weight_decay=weight_decay)
#feature_optimizer = optim.Adam(feature_net.parameters(), lr=lr, weight_decay=weight_decay)
Networks =   {'classifier':classifier_net, 'feature':feature_net}
Optimizers = {'classifier':classifier_optimizer} #'feature':feature_optimizer}

Criterions = {'CE': nn.CrossEntropyLoss()}

In [5]:
classifier_net

Classifier(
  (classifier): Sequential(
    (Pool5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Flatten): Flatten()
    (Linear1): Linear(in_features=9216, out_features=4096, bias=False)
    (BatchNorm1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (ReLU1): ReLU(inplace=True)
    (Liniear2): Linear(in_features=4096, out_features=4096, bias=False)
    (BatchNorm2): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (ReLU2): ReLU(inplace=True)
    (LinearF): Linear(in_features=4096, out_features=200, bias=True)
  )
)

In [6]:
def train_validate(data_loader, epoch, train=True):
    
    mode = "Train" if train else "Valid"
    if train is True:
        #for key in Networks:
        Networks['classifier'].train()
    else:
        #for key in Networks:
        Networks['classifier'].eval()
    
    
    losses = []
    correct = 0
    
    overallloss = None
    
    
    start_time = time.time()
    tqdm_bar = tqdm(data_loader)
    batch_sizes = 0
    for batch_idx, batch in enumerate(tqdm_bar):
        data, targets = batch
        
        data, targets = data.to(device), targets.to(device)
        
        with torch.no_grad():
            features = Networks['feature'](data, ['conv5'])
        
        if train is False:
            with torch.no_grad():
                output =  Networks['classifier'](features)
        else:
            #features = Networks['feature'](data, ['conv5'])
            output = Networks['classifier'](features)
            
    
        loss_ce = Criterions['CE'](output, targets)
        

        if train is True:
            loss_ce.backward()
            Optimizers['classifier'].zero_grad()
            #Optimizers['feature'].zero_grad()
            Optimizers['classifier'].step()
            #Optimizers['feature'].step()
               
        losses.append(loss_ce.item())
        output = F.softmax(output, dim=1)
        pred = output.argmax(dim=1, keepdim=True).squeeze_(dim=1)
        #print(pred.size(), targets.size())
        correct_iter = pred.eq(targets.view_as(pred)).sum().item()
        correct += correct_iter
        batch_sizes += data.size(0)
        tqdm_bar.set_description('{} Epoch: [{}] Loss: CE {:.4f}, Correct: {}/{}'.format(mode, epoch, loss_ce.item(), correct, batch_sizes))
        
    
    end_time = time.time()
    print("Time for epoch pass {}".format(end_time-start_time))
    acc = float(correct/len(data_loader.dataset))
    averageloss = float(np.mean(losses))
    overallloss = {'ce':averageloss, 'acc':acc*100.0}
    print('{} set: Average loss: CE {:.4f}, Accuracy {}/{} {:.4f}%\n'.format(mode, overallloss['ce'], correct, len(data_loader.dataset), overallloss['acc']))
    return overallloss


In [7]:
def run_main_loop(loaders, num_epochs):
    writer = SummaryWriter('./logs/AlexNet_SimCLR_NonLinearClassifier')
    save_path = "weights/AlexNet_Decoupling_Contrastive_SimCLR_NonLinearClassifier.pth"
    best_acc = 0
    for epoch in range(num_epochs):
        #print("Performing {}th epoch".format(epoch))
        train_loss = train_validate(loaders['train_loader'], epoch, train=True)
        val_loss = train_validate(loaders['valid_loader'], epoch, train=False)
        
        
        writer.add_scalar('CELoss/train', train_loss['ce'], epoch)
        writer.add_scalar('Accuracy/train', train_loss['acc'], epoch)
        writer.add_scalar('CELoss/Valid', val_loss['ce'], epoch)
        writer.add_scalar('Accuracy/Valid', val_loss['acc'], epoch)
        
        if val_loss['acc'] > best_acc :
            best_acc = val_loss['acc']
            #save model
            states = {
                'epoch': epoch + 1,
                'best_accuracy': best_acc
            }
            for key in Networks:
                states[key+"net"] = Networks[key].state_dict()
            for key in Optimizers:
                states[key+"optimizer"] = Optimizers[key].state_dict()
            torch.save(states, save_path)
            print('Model Saved')

In [None]:
run_main_loop(loaders, num_epochs)

Train Epoch: [0] Loss: CE 5.2983, Correct: 369/80000: 100%|██████████| 417/417 [01:43<00:00,  4.04it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 103.09648180007935
Train set: Average loss: CE 6.4199, Accuracy 369/80000 0.4613%



Valid Epoch: [0] Loss: CE 5.2983, Correct: 47/10000: 100%|██████████| 53/53 [00:11<00:00,  4.65it/s]


Time for epoch pass 11.410181999206543
Valid set: Average loss: CE 5.2983, Accuracy 47/10000 0.4700%



  0%|          | 0/417 [00:00<?, ?it/s]

Model Saved


Train Epoch: [1] Loss: CE 5.2983, Correct: 402/80000: 100%|██████████| 417/417 [01:27<00:00,  4.77it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 87.50876784324646
Train set: Average loss: CE 5.2983, Accuracy 402/80000 0.5025%



Valid Epoch: [1] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.41it/s]


Time for epoch pass 9.802339315414429
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



  0%|          | 0/417 [00:00<?, ?it/s]

Model Saved


Train Epoch: [2] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:27<00:00,  4.76it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 87.58167171478271
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [2] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.42it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 9.787522315979004
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [3] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:27<00:00,  4.79it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 87.06472086906433
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [3] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.52it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 9.601924896240234
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [4] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:27<00:00,  4.76it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 87.62943577766418
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [4] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.45it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 9.733651638031006
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [5] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:27<00:00,  4.78it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 87.15405058860779
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [5] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.53it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 9.578474283218384
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [6] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:28<00:00,  4.73it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 88.1937198638916
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [6] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.45it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 9.732505083084106
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [7] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:26<00:00,  4.80it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 86.83552193641663
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [7] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.51it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 9.623085975646973
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [8] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:27<00:00,  4.74it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 87.93145227432251
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [8] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:10<00:00,  5.28it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 10.043013095855713
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [9] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:28<00:00,  4.72it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 88.4179208278656
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [9] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.44it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 9.751720666885376
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [10] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:28<00:00,  4.73it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 88.24071645736694
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [10] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.45it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 9.726225137710571
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [11] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:28<00:00,  4.72it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 88.42006516456604
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [11] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.37it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 9.86524486541748
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [12] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:27<00:00,  4.77it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 87.38693904876709
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [12] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.43it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 9.759761810302734
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [13] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:28<00:00,  4.73it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 88.13458323478699
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [13] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.43it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 9.753756761550903
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [14] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:27<00:00,  4.76it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 87.53376531600952
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [14] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.48it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 9.667107820510864
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [15] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:27<00:00,  4.74it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 87.96247053146362
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [15] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.34it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 9.934108018875122
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [16] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:27<00:00,  4.76it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 87.66706585884094
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [16] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.33it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 9.937958240509033
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [17] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:28<00:00,  4.73it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 88.13276958465576
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [17] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:10<00:00,  5.23it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 10.140543699264526
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [18] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:27<00:00,  4.76it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 87.61696600914001
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [18] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.54it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 9.5669264793396
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [19] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:28<00:00,  4.70it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 88.63193702697754
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [19] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.51it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 9.616174936294556
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [20] Loss: CE 5.2983, Correct: 400/80000: 100%|██████████| 417/417 [01:27<00:00,  4.77it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 87.38092064857483
Train set: Average loss: CE 5.2983, Accuracy 400/80000 0.5000%



Valid Epoch: [20] Loss: CE 5.2983, Correct: 50/10000: 100%|██████████| 53/53 [00:09<00:00,  5.32it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 9.968754053115845
Valid set: Average loss: CE 5.2983, Accuracy 50/10000 0.5000%



Train Epoch: [21] Loss: CE 5.2983, Correct: 328/64512:  81%|████████  | 336/417 [01:10<00:15,  5.26it/s]

In [None]:
test_loss = train_validate(loaders['test_loader'], 1, train=False)
print("Test Average Loss is {:.4f}, and Accuracy is {:.4f}".format(test_loss['ce'], test_loss['acc']*100.0))