In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import argparse
import time
import copy
from tqdm import tqdm
import os.path as osp
from matplotlib import pyplot as plt
from utils import Config
from model import fashion_model
from data import get_dataloader
#import os
#os.environ['CUDA_VISIBLE_DEVICES']="0"


def train_model(get_dataloader, model, criterion, optimizer, device, num_epochs, lr_decay):
    model.to(device)
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    train_acc = []
    test_acc = []
    for epoch in range(num_epochs):
        dataloaders, classes, dataset_size = get_dataloader(debug=Config['debug'], batch_size=Config['batch_size'], num_workers=Config['num_workers'])
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 10)

        for phase in ['train', 'test']:
            if phase=='train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for input1, input2, labels in tqdm(dataloaders[phase]):
                input1 = input1.to(device)
                input2 = input2.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase=='train'):
                    outputs = model(input1,input2)
                    _, pred = torch.max(outputs, 1)
                    #return outputs, labels
                    loss = criterion(outputs, labels)

                    if phase=='train':
                        loss.backward()
                        optimizer.step()


                running_loss += loss.item() * input1.size(0)
                running_corrects += torch.sum(pred==labels.data)

            epoch_loss = running_loss / dataset_size[phase]
            epoch_acc = running_corrects.double() / dataset_size[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            if phase=='train':
                train_acc.append(epoch_acc.item())
            elif phase=='test':
                test_acc.append(epoch_acc.item())

            if phase=='test' and epoch_acc > best_acc:
                best_epoch = epoch
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(best_model_wts, osp.join(Config['root_path'], Config['checkpoint_path'], 'best_model.pth'))
                print('Best model saved at: {}'.format(osp.join(Config['root_path'], Config['checkpoint_path'], 'best_model.pth')))
                print('Save best check point at epoch %d'%(epoch+1))
            elif phase=='test':
                print('best model is save on epoch %d and best accuracy is %f'%(best_epoch, best_acc))
        torch.save(model.state_dict(),osp.join(Config['root_path'], Config['checkpoint_path'], 'model%d.pth'%(epoch+1)))
        print('Model saved at: {}'.format(osp.join(Config['root_path'], Config['checkpoint_path'], 'model%d.pth'%(epoch+1))))
        # Learning rate decay
        if epoch < num_epochs-1: 
            for param_group in optimizer.param_groups:
                print('lr: {:.6f} -> {:.6f}'.format(param_group['lr'], param_group['lr'] * lr_decay))
                param_group['lr'] *= lr_decay

    time_elapsed = time.time() - since
    print('Time taken to complete training: {:0f}m {:0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best acc: {:.4f}'.format(best_acc))
    return train_acc, test_acc


In [20]:
    classes = 2
    model = fashion_model(classes)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.RMSprop(model.parameters(), lr=Config['learning_rate'])
    device = torch.device('cuda:0' if torch.cuda.is_available() and Config['use_cuda'] else 'cpu')

    train_acc,test_acc = train_model(get_dataloader, model, criterion, optimizer, device, num_epochs=Config['num_epochs'],  lr_decay=Config['lr_decay'])
    
    epochs = np.arange(Config['num_epochs'])

    plt.figure()
    plt.plot(epochs, train_acc, label='loss')
    plt.plot(epochs, test_acc, label='val_loss')
    plt.xlabel('epochs')
    plt.ylabel('Acc')
    plt.legend()
    plt.show()
    plt.savefig('learning_acc.png', dpi=256)

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1/25
----------
tensor([1])
torch.Size([1, 2])
torch.Size([1])


  1%|          | 1/100 [00:01<03:03,  1.86s/it]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


  2%|▏         | 2/100 [00:02<02:28,  1.52s/it]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


  3%|▎         | 3/100 [00:03<02:04,  1.28s/it]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


  4%|▍         | 4/100 [00:04<01:51,  1.17s/it]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


  5%|▌         | 5/100 [00:04<01:38,  1.03s/it]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


  6%|▌         | 6/100 [00:05<01:24,  1.12it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


  7%|▋         | 7/100 [00:06<01:12,  1.28it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


  8%|▊         | 8/100 [00:06<01:04,  1.43it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


  9%|▉         | 9/100 [00:07<00:58,  1.56it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 10%|█         | 10/100 [00:07<00:53,  1.67it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 11%|█         | 11/100 [00:08<00:50,  1.77it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 12%|█▏        | 12/100 [00:08<00:47,  1.86it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 13%|█▎        | 13/100 [00:08<00:44,  1.94it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 14%|█▍        | 14/100 [00:09<00:43,  1.98it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 15%|█▌        | 15/100 [00:09<00:42,  1.98it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 16%|█▌        | 16/100 [00:10<00:45,  1.84it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 17%|█▋        | 17/100 [00:11<00:48,  1.73it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 18%|█▊        | 18/100 [00:12<00:52,  1.58it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 19%|█▉        | 19/100 [00:12<00:56,  1.44it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 20%|██        | 20/100 [00:13<00:53,  1.50it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 21%|██        | 21/100 [00:13<00:48,  1.63it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 22%|██▏       | 22/100 [00:14<00:45,  1.73it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 23%|██▎       | 23/100 [00:14<00:42,  1.80it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 24%|██▍       | 24/100 [00:15<00:41,  1.82it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 25%|██▌       | 25/100 [00:15<00:39,  1.88it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 26%|██▌       | 26/100 [00:16<00:38,  1.92it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 27%|██▋       | 27/100 [00:16<00:37,  1.95it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 28%|██▊       | 28/100 [00:17<00:36,  1.96it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 29%|██▉       | 29/100 [00:17<00:36,  1.96it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


 30%|███       | 30/100 [00:18<00:36,  1.90it/s]

tensor([1])
torch.Size([1, 2])
torch.Size([1])


KeyboardInterrupt: 

In [11]:
train_acc,test_acc


(tensor([[0.9551, 0.0449]], grad_fn=<SoftmaxBackward>), tensor([0]))

In [12]:
test_acc.dtype

torch.int64

In [13]:
criterion(train_acc,test_acc)

tensor(0.3382, grad_fn=<NllLossBackward>)