In [3]:
def train_model(get_dataloader, model, criterion, optimizer, device, num_epochs, lr_decay,warmup=5):
    model.to(device)
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    train_acc = []
    test_acc = []
    unfreeze_num = 13
    ratio = 0.35
    for epoch in range(num_epochs):
        dataloaders, classes, dataset_size = get_dataloader(debug=Config['debug'], batch_size=Config['batch_size'], num_workers=Config['num_workers'])
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 10)

        for phase in ['train', 'test']:
            if phase=='train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for input1, input2, labels in tqdm(dataloaders[phase]):
                input1 = input1.to(device)
                input2 = input2.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase=='train'):
                    outputs = model(input1,input2)
                    _, pred = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    print(loss)

                    if (epoch+1) in [7,8,9,17,18,19,22,23]:
                        num_inst = outputs.size(0)
                        num_hns = int(ratio * num_inst)
                        

                    if phase=='train':
                        loss.backward()
                        optimizer.step()


                running_loss += loss.item() * input1.size(0)
                running_corrects += torch.sum(pred==labels.data)

            epoch_loss = running_loss / dataset_size[phase]
            epoch_acc = running_corrects.double() / dataset_size[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            if phase=='train':
                train_acc.append(epoch_acc.item())
            elif phase=='test':
                test_acc.append(epoch_acc.item())

            if phase=='test' and epoch_acc > best_acc:
                best_epoch = epoch
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(best_model_wts, osp.join(Config['root_path'], Config['checkpoint_path'], 'best_model.pth'))
                print('Best model saved at: {}'.format(osp.join(Config['root_path'], Config['checkpoint_path'], 'best_model.pth')))
                print('Save best check point at epoch %d'%(epoch+1))
            elif phase=='test':
                print('best model is save on epoch %d and best accuracy is %f'%(best_epoch, best_acc))
        torch.save(model.state_dict(),osp.join(Config['root_path'], Config['checkpoint_path'], 'model%d.pth'%(epoch+1)))
        print('Model saved at: {}'.format(osp.join(Config['root_path'], Config['checkpoint_path'], 'model%d.pth'%(epoch+1))))

        # warmup: unfrozen model parameter gradually. 
        if (epoch+1)%warmup==0:
            model.model = unfreeze(model.model,unfreeze_num)
            unfreeze_num = unfreeze_num - 3

        # Learning rate decay
        if epoch < num_epochs-1: 
            for param_group in optimizer.param_groups:
                print('lr: {:.6f} -> {:.6f}'.format(param_group['lr'], param_group['lr'] * lr_decay))
                param_group['lr'] *= lr_decay

    time_elapsed = time.time() - since
    print('Time taken to complete training: {:0f}m {:0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best acc: {:.4f}'.format(best_acc))
    return train_acc, test_acc

def unfreeze(model,num):
    child_counter = 0
    for child in model.children():
        if child_counter < num:
            print("child ",child_counter," was frozen")
            child_counter += 1
            for param in child.parameters():
                param.requires_grad = False
        else:
            print("child ",child_counter," was not frozen")
            child_counter += 1
    return model

In [4]:
    classes = 2
    model = fashion_model(classes)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.RMSprop(model.parameters(), lr=Config['learning_rate'])
    device = torch.device('cuda:0' if torch.cuda.is_available() and Config['use_cuda'] else 'cpu')

    train_acc,test_acc = train_model(get_dataloader, model, criterion, optimizer, device, num_epochs=Config['num_epochs'],  lr_decay=Config['lr_decay'])
    
    epochs = np.arange(Config['num_epochs'])

    plt.figure()
    plt.plot(epochs, train_acc, label='loss')
    plt.plot(epochs, test_acc, label='val_loss')
    plt.xlabel('epochs')
    plt.ylabel('Acc')
    plt.legend()
    plt.show()
    plt.savefig('learning_acc.png', dpi=256)

child  0  was frozen
child  1  was frozen
child  2  was frozen
child  3  was frozen
child  4  was frozen
child  5  was frozen
child  6  was frozen
child  7  was frozen
child  8  was frozen
child  9  was frozen
child  10  was frozen
child  11  was frozen
child  12  was frozen
child  13  was frozen
child  14  was frozen
child  15  was frozen
child  16  was not frozen
child  17  was not frozen
child  18  was not frozen


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1/25
----------


  2%|▏         | 1/50 [00:03<02:33,  3.12s/it]

tensor(0.7596, grad_fn=<NllLossBackward>)


  4%|▍         | 2/50 [00:03<01:50,  2.30s/it]

tensor(1.3133, grad_fn=<NllLossBackward>)


  6%|▌         | 3/50 [00:03<01:20,  1.71s/it]

tensor(0.8133, grad_fn=<NllLossBackward>)


  8%|▊         | 4/50 [00:04<00:59,  1.30s/it]

tensor(0.3133, grad_fn=<NllLossBackward>)


 10%|█         | 5/50 [00:04<00:45,  1.00s/it]

tensor(0.3133, grad_fn=<NllLossBackward>)


 12%|█▏        | 6/50 [00:04<00:34,  1.26it/s]

tensor(0.8133, grad_fn=<NllLossBackward>)


 14%|█▍        | 7/50 [00:05<00:28,  1.53it/s]

tensor(0.8133, grad_fn=<NllLossBackward>)


 16%|█▌        | 8/50 [00:05<00:23,  1.82it/s]

tensor(0.8133, grad_fn=<NllLossBackward>)


 18%|█▊        | 9/50 [00:05<00:19,  2.10it/s]

tensor(0.3133, grad_fn=<NllLossBackward>)


 20%|██        | 10/50 [00:06<00:16,  2.37it/s]

tensor(1.3133, grad_fn=<NllLossBackward>)


 22%|██▏       | 11/50 [00:06<00:15,  2.57it/s]

tensor(0.8133, grad_fn=<NllLossBackward>)


 24%|██▍       | 12/50 [00:06<00:13,  2.77it/s]

tensor(0.8133, grad_fn=<NllLossBackward>)


 26%|██▌       | 13/50 [00:06<00:12,  2.95it/s]

tensor(0.8133, grad_fn=<NllLossBackward>)


 28%|██▊       | 14/50 [00:07<00:11,  3.08it/s]

tensor(1.3133, grad_fn=<NllLossBackward>)


 30%|███       | 15/50 [00:07<00:17,  1.97it/s]

tensor(0.8133, grad_fn=<NllLossBackward>)





KeyboardInterrupt: 

In [5]:
loss = nn.CrossEntropyLoss()

In [8]:
loss.reduction = 'none'

In [18]:
loss(torch.Tensor([[232,9],[8594,9]]),torch.Tensor([1,0]).type(torch.int64))

tensor([223.,  -0.])

In [12]:
test_acc.dtype

torch.int64

In [13]:
criterion(train_acc,test_acc)

tensor(0.3382, grad_fn=<NllLossBackward>)

In [2]:
import torch
import torch.nn as nn

In [5]:
q = nn.Sequential(nn.Linear(1280*2,100,True),nn.ReLU(),nn.Linear(100,2,True),nn.Softmax(dim=1))

In [9]:
q[2].weight

Parameter containing:
tensor([[ 0.0068, -0.0805, -0.0967, -0.0443, -0.0442, -0.0239, -0.0065, -0.0651,
         -0.0688,  0.0131,  0.0707, -0.0403, -0.0351,  0.0251, -0.0819, -0.0521,
          0.0066, -0.0737, -0.0169,  0.0479,  0.0771, -0.0376, -0.0959, -0.0220,
         -0.0233,  0.0490,  0.0869,  0.0338, -0.0646, -0.0131, -0.0684,  0.0886,
         -0.0482,  0.0973,  0.0149,  0.0089, -0.0951,  0.0243, -0.0483, -0.0090,
          0.0794,  0.0041, -0.0873,  0.0147, -0.0106, -0.0180, -0.0942, -0.0019,
         -0.0560, -0.0349,  0.0048, -0.0618, -0.0389,  0.0385, -0.0199,  0.0974,
          0.0992, -0.0668,  0.0776,  0.0446, -0.0113,  0.0721, -0.0349,  0.0111,
         -0.0374,  0.0046,  0.0744, -0.0022, -0.0621, -0.0233, -0.0213,  0.0788,
         -0.0662, -0.0414,  0.0559,  0.0960, -0.0206, -0.0470, -0.0587,  0.0456,
         -0.0466, -0.0308, -0.0333, -0.0800,  0.0798,  0.0183, -0.0541, -0.0655,
         -0.0944,  0.0978, -0.0590, -0.0472,  0.0355, -0.0996,  0.0081,  0.0308,
      