In [1]:
import os
from GetDataLoaders import get_dataloaders, get_short_dataloaders
from ImageNet_RotNet_AlexNet.AlexNet import AlexNet as AlexNetFeature
from architectures.TransferLearningNet import Flatten
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch import nn
import time
from torch import optim
from torch.nn import functional as F
from tqdm import tqdm

In [2]:
'''
# we skip the probs for now
gama = 2.0
with open(os.path.join("./PUprobs", 'prob.dat'), 'r') as file_input:
    train_prob_str = file_input.readlines()
    train_prob = [float(i_prob_str.rstrip('\n')) for i_prob_str in train_prob_str]
    print(len(train_prob)/4.0)
    train_weight = [1.0 if 0==i%4 else 1-train_prob[i]**gama for i in range(len(train_prob))]
'''

'\n# we skip the probs for now\ngama = 2.0\nwith open(os.path.join("./PUprobs", \'prob.dat\'), \'r\') as file_input:\n    train_prob_str = file_input.readlines()\n    train_prob = [float(i_prob_str.rstrip(\'\n\')) for i_prob_str in train_prob_str]\n    print(len(train_prob)/4.0)\n    train_weight = [1.0 if 0==i%4 else 1-train_prob[i]**gama for i in range(len(train_prob))]\n'

In [3]:
class Classifier(nn.Module):
    def __init__(self, nChannels=256, num_classes=4, pool_size=6, pool_type='max'):
        super(Classifier, self).__init__()
        nChannelsAll = nChannels * pool_size * pool_size

        layers = []
        if pool_type == 'max':
            layers.append(nn.AdaptiveMaxPool2d((pool_size, pool_size)))
            #layers.append(nn.MaxPool2d(kernel_size=3, stride=2))
        elif pool_type == 'avg':
            layer.append(nn.AdaptiveAvgPool2d((pool_size, pool_size)))
        layers.append(nn.BatchNorm2d(nChannels, affine=False))
        layers.append(Flatten())
        layers.append(nn.Linear(nChannelsAll, num_classes))
        self.classifier = nn.Sequential(*layers)
        self.initilize()
    
    def forward(self, feat):
        return self.classifier(feat)
    def initilize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                fin = m.in_features
                fout = m.out_features
                std_val = np.sqrt(2.0/fout)
                m.weight.data.normal_(0.0, std_val)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)

In [4]:
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
batch_size = 192
lr = 0.1
LUT_lr = [(5, 0.01),(25, 0.002),(45, 0.0004),(65,0.00008)]
num_epochs = 65
momentum = 0.9
weight_decay = 5e-4
nesterov = True
num_classes = 200
loaders = get_dataloaders('imagenet', batch_size=batch_size, num_workers=2, unsupervised=False, simclr=False)

In [5]:
feature_net = AlexNetFeature().to(device)
#load pretrained weights in feature_net
state_dict = torch.load("ImageNet_RotNet_AlexNet/model_net_epoch50")
feature_net.load_state_dict(state_dict['network'], strict=False)

feature_net.eval()
for param in feature_net.parameters():
    param.requires_grad = False

classifier_net = Classifier().to(device)
classifier_optimizer = optim.SGD(classifier_net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
Networks =   {'classifier':classifier_net}
Optimizers = {'classifier':classifier_optimizer}

Criterions = {'CE': nn.CrossEntropyLoss()}

In [6]:
def adjust_learning_rates(epoch):
    # filter out the networks that are not trainable and that do
    # not have a learning rate Look Up Table (LUT_lr) in their optim_params
    lr = next((lr for (max_epoch, lr) in LUT_lr if max_epoch>epoch), LUT_lr[-1][1])
    for key in Optimizers:
        for g in Optimizers[key].param_groups:
            prev_lr = g['lr']
            if prev_lr != lr:
                print("Learning Rate Updated from {} to {}".format(prev_lr, lr))
                g['lr'] = lr

def train_step(batch, train=True):
    data, targets = batch
    batch_size = targets.size(0)
    
    targets = torch.zeros(batch_size*4, dtype=torch.long, requires_grad=False)
    targets[0 : batch_size] = 0
    targets[batch_size*1 : batch_size*2] = 1
    targets[batch_size*2 : batch_size*3] = 2
    targets[batch_size*3 : batch_size*4] = 3
    
    #print(targets)
    
    if train is True:
        for key in Optimizers:
            Optimizers[key].zero_grad()
    
    #to cuda
    
    data, targets = data.to(device), targets.to(device)
    
    data_90 = torch.flip(torch.transpose(data,2,3),[2])
    data_180 = torch.flip(torch.flip(data,[2]),[3])
    data_270 = torch.transpose(torch.flip(data,[2]),2,3)
    data = torch.stack([data, data_90, data_180, data_270], dim=1)
    batch_size, rotations, channels, height, width = data.size()
    data = data.view([batch_size*rotations, channels, height, width])
   
    data.requires_grad = False
    
    
    #collect features
    with torch.no_grad():
        features = feature_net(data, ['conv5'])
    
    if train is False:
        with torch.no_grad():
            pred = Networks['classifier'](features)
            #calculate loss
            loss_cls =  Criterions['CE'](pred, targets)
            
    else:
        pred = Networks['classifier'](features)
        #calculate loss
        loss_cls =  Criterions['CE'](pred, targets)
    
    if train is True:
        loss_cls.backward()
        for key in Optimizers:
            Optimizers[key].step()
    
    #calculate classification accuracy
    pred = F.softmax(pred, dim=1)
    pred = pred.argmax(dim=1, keepdim=True)
    correct = pred.eq(targets.view_as(pred)).sum().item()
   
    
    return loss_cls.item(), correct
    

In [7]:
def train_validate(data_loader, epoch, train=True):
    mode = "Train" if train else "Valid"
    if train is True:
        for key in Networks:
            Networks[key].train()
    else:
        for key in Networks:
            Networks[key].eval()
    
    losses = []
    correct = 0
    
    if train:
        adjust_learning_rates(epoch)
    
    start_time = time.time()
    
    tqdm_bar = tqdm(data_loader)
    total_number = 0
    for batch_idx, sample in enumerate(tqdm_bar):
        
        loss, correct_step = train_step(sample, train=train)
        losses.append(loss)
        correct += correct_step
        total_number += sample[0].size(0)*4
        tqdm_bar.set_description('{} Epoch: {} Loss: {:.6f}, Accuracy: {}/{} [{:.4f}%]'.format(mode, epoch, loss, correct, total_number, 100.0*(correct/total_number)))
    
    end_time = time.time()
    print("Time for epoch pass {}".format(end_time-start_time))
    
    loss = float(np.mean(losses))
    acc =  float(correct / len(data_loader.dataset)*4)
    print('{} set: Average loss: {:.4f}, Accuracy:{}/{} ({:.4f}%)\n'.format(mode, loss, correct, len(data_loader.dataset)*4, 100.0*acc))
    return loss, acc


def run_main_loop(loaders, num_epochs):
    writer = SummaryWriter('logs/PUprobsfortinyimagenet')
    save_path = "weights/puclassifierfortinyimagenet.pth"
    best_acc = 0
    for epoch in range(num_epochs):
        #print("Performing {}th epoch".format(epoch))
        
        train_loss, train_acc = train_validate(loaders['train_loader'], epoch, train=True)
        val_loss, val_acc = train_validate(loaders['valid_loader'], epoch, train=False)
    
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/Valid', val_loss, epoch)
        writer.add_scalar('Accuracy/train', train_acc, epoch)
        writer.add_scalar('Accuracy/Valid', val_acc, epoch)
    
        
        if val_acc > best_acc  :
            best_acc = val_acc
            #save model
            states = {
                'epoch': epoch + 1,
                'best_accuracy': best_acc
            }
            for key in Networks:
                states[key+"model"] = Networks[key].state_dict()
            for key in Optimizers:
                states[key+"optimizer"] = Optimizers[key].state_dict()
            torch.save(states, save_path)
            print('Model Saved')

In [8]:
run_main_loop(loaders, num_epochs)

  0%|          | 0/417 [00:00<?, ?it/s]

Learning Rate Updated from 0.1 to 0.01


Train Epoch: 0 Loss: 35.087143, Accuracy: 80143/320000 [25.0447%]: 100%|██████████| 417/417 [01:43<00:00,  4.03it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 103.55574607849121
Train set: Average loss: 45.2970, Accuracy:80143/320000 (400.7150%)



Valid Epoch: 0 Loss: 25.340919, Accuracy: 9959/40000 [24.8975%]: 100%|██████████| 53/53 [00:12<00:00,  4.17it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 12.701408863067627
Valid set: Average loss: 36.1554, Accuracy:9959/40000 (398.3600%)

Model Saved


Train Epoch: 1 Loss: 27.519411, Accuracy: 79992/320000 [24.9975%]: 100%|██████████| 417/417 [01:45<00:00,  3.97it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 105.14005279541016
Train set: Average loss: 30.0772, Accuracy:79992/320000 (399.9600%)



Valid Epoch: 1 Loss: 21.002329, Accuracy: 10027/40000 [25.0675%]: 100%|██████████| 53/53 [00:12<00:00,  4.13it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 12.826058626174927
Valid set: Average loss: 26.1861, Accuracy:10027/40000 (401.0800%)

Model Saved


Train Epoch: 2 Loss: 18.932745, Accuracy: 80452/320000 [25.1412%]: 100%|██████████| 417/417 [01:45<00:00,  3.97it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 105.04572200775146
Train set: Average loss: 22.1658, Accuracy:80452/320000 (402.2600%)



Valid Epoch: 2 Loss: 15.775415, Accuracy: 10053/40000 [25.1325%]: 100%|██████████| 53/53 [00:13<00:00,  4.03it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 13.164636850357056
Valid set: Average loss: 19.6141, Accuracy:10053/40000 (402.1200%)

Model Saved


Train Epoch: 3 Loss: 14.218469, Accuracy: 79961/320000 [24.9878%]: 100%|██████████| 417/417 [01:45<00:00,  3.96it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 105.17873477935791
Train set: Average loss: 16.7865, Accuracy:79961/320000 (399.8050%)



Valid Epoch: 3 Loss: 12.126490, Accuracy: 10027/40000 [25.0675%]: 100%|██████████| 53/53 [00:12<00:00,  4.13it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 12.820189476013184
Valid set: Average loss: 14.8106, Accuracy:10027/40000 (401.0800%)



Train Epoch: 4 Loss: 11.115479, Accuracy: 79872/320000 [24.9600%]: 100%|██████████| 417/417 [01:43<00:00,  4.02it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 103.80945563316345
Train set: Average loss: 12.6930, Accuracy:79872/320000 (399.3600%)



Valid Epoch: 4 Loss: 9.012389, Accuracy: 10034/40000 [25.0850%]: 100%|██████████| 53/53 [00:13<00:00,  4.01it/s] 
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 13.232436180114746
Valid set: Average loss: 11.2040, Accuracy:10034/40000 (401.3600%)

Learning Rate Updated from 0.01 to 0.002


Train Epoch: 5 Loss: 10.050145, Accuracy: 80250/320000 [25.0781%]: 100%|██████████| 417/417 [01:44<00:00,  4.00it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 104.29106378555298
Train set: Average loss: 10.4885, Accuracy:80250/320000 (401.2500%)



Valid Epoch: 5 Loss: 7.869442, Accuracy: 10021/40000 [25.0525%]: 100%|██████████| 53/53 [00:12<00:00,  4.13it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 12.841731071472168
Valid set: Average loss: 10.2905, Accuracy:10021/40000 (400.8400%)



Train Epoch: 6 Loss: 9.368963, Accuracy: 80201/320000 [25.0628%]: 100%|██████████| 417/417 [01:44<00:00,  4.00it/s] 
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 104.3473379611969
Train set: Average loss: 9.8288, Accuracy:80201/320000 (401.0050%)



Valid Epoch: 6 Loss: 7.550991, Accuracy: 10051/40000 [25.1275%]: 100%|██████████| 53/53 [00:12<00:00,  4.09it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 12.94724154472351
Valid set: Average loss: 9.6639, Accuracy:10051/40000 (402.0400%)



Train Epoch: 7 Loss: 9.435275, Accuracy: 80080/320000 [25.0250%]: 100%|██████████| 417/417 [01:44<00:00,  3.97it/s] 
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 104.99618744850159
Train set: Average loss: 9.2533, Accuracy:80080/320000 (400.4000%)



Valid Epoch: 7 Loss: 6.939579, Accuracy: 10074/40000 [25.1850%]: 100%|██████████| 53/53 [00:12<00:00,  4.14it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 12.819048404693604
Valid set: Average loss: 9.1088, Accuracy:10074/40000 (402.9600%)

Model Saved


Train Epoch: 8 Loss: 8.871571, Accuracy: 79701/320000 [24.9066%]: 100%|██████████| 417/417 [01:43<00:00,  4.03it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 103.5072557926178
Train set: Average loss: 8.6989, Accuracy:79701/320000 (398.5050%)



Valid Epoch: 8 Loss: 6.249327, Accuracy: 10022/40000 [25.0550%]: 100%|██████████| 53/53 [00:12<00:00,  4.16it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 12.757431268692017
Valid set: Average loss: 8.5217, Accuracy:10022/40000 (400.8800%)



Train Epoch: 9 Loss: 8.306579, Accuracy: 80182/320000 [25.0569%]: 100%|██████████| 417/417 [01:44<00:00,  4.00it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 104.31588077545166
Train set: Average loss: 8.1760, Accuracy:80182/320000 (400.9100%)



Valid Epoch: 9 Loss: 6.087802, Accuracy: 9980/40000 [24.9500%]: 100%|██████████| 53/53 [00:12<00:00,  4.13it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 12.837442398071289
Valid set: Average loss: 8.0630, Accuracy:9980/40000 (399.2000%)



Train Epoch: 10 Loss: 7.460752, Accuracy: 80032/320000 [25.0100%]: 100%|██████████| 417/417 [01:44<00:00,  4.01it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 104.10119199752808
Train set: Average loss: 7.6656, Accuracy:80032/320000 (400.1600%)



Valid Epoch: 10 Loss: 5.410100, Accuracy: 10087/40000 [25.2175%]: 100%|██████████| 53/53 [00:12<00:00,  4.12it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 12.872437477111816
Valid set: Average loss: 7.5353, Accuracy:10087/40000 (403.4800%)

Model Saved


Train Epoch: 11 Loss: 6.983691, Accuracy: 80127/320000 [25.0397%]: 100%|██████████| 417/417 [01:45<00:00,  3.95it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 105.55942964553833
Train set: Average loss: 7.1810, Accuracy:80127/320000 (400.6350%)



Valid Epoch: 11 Loss: 5.538237, Accuracy: 10066/40000 [25.1650%]: 100%|██████████| 53/53 [00:12<00:00,  4.12it/s]
  0%|          | 0/417 [00:00<?, ?it/s]

Time for epoch pass 12.872727394104004
Valid set: Average loss: 7.0986, Accuracy:10066/40000 (402.6400%)



Train Epoch: 12 Loss: 6.841916, Accuracy: 80160/320000 [25.0500%]: 100%|██████████| 417/417 [01:45<00:00,  3.96it/s]
  0%|          | 0/53 [00:00<?, ?it/s]

Time for epoch pass 105.17408347129822
Train set: Average loss: 6.7404, Accuracy:80160/320000 (400.8000%)



Valid Epoch: 12 Loss: 6.626887, Accuracy: 5559/22272 [24.9596%]:  55%|█████▍    | 29/53 [00:07<00:05,  4.32it/s]

KeyboardInterrupt: 

In [None]:
loss, acc = train_validate(loaders['test_loader'], 1, train=False)