In [288]:
# import os
import importlib
import sys
import time

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets
from torchvision import models
from torchvision import transforms

import captcha_setting
import my_dataset
import one_hot_encoding


# Pick up edits to the local dataset module without restarting the kernel.
importlib.reload(my_dataset)

# Hyperparameters
RANDOM_SEED = 1
LEARNING_RATE = 0.0002
# NOTE(review): BATCH_SIZE is not used by the visible cells; the loaders come
# from my_dataset, which presumably sets its own batch size — confirm.
BATCH_SIZE = 64
NUM_EPOCH = 30

# Architecture
# NOTE(review): presumably image width * height (160x60 captchas) — not used by
# the visible cells.
NUM_FEATURES = 160 * 60
NUM_CLASSES = captcha_setting.CLASS_NUM

# Other
# Fall back to the CPU so the notebook still runs on machines without a GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# NOTE(review): not referenced by the visible cells.
GRAYSCALE = True


144


In [289]:
def compute_accuracy(model, data_loader, device, num_chars=4):
    """Return the whole-captcha accuracy (in percent) of ``model`` on ``data_loader``.

    Each model output row and each target row is split into ``num_chars``
    equal groups, one per captcha character. The arg-max of each group is that
    character's predicted/true class. A sample counts as correct only if every
    one of its characters matches.

    Args:
        model: callable mapping a batch of features to a 2-D tensor whose second
            dimension is divisible by ``num_chars``.
        data_loader: iterable of ``(features, targets)`` batches. Targets are
            presumably one-hot per character, with the same layout as the model
            output — confirm against my_dataset.
        device: device that features and targets are moved to before the forward pass.
        num_chars: number of characters per captcha (default 4, the value that
            was previously hard-coded).

    Returns:
        Accuracy in percent as a float, or ``nan`` if the loader yields no samples.

    The caller is responsible for putting the model in eval mode.
    """
    correct_pred, num_examples = 0, 0
    # Inference only: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        for features, targets in data_loader:
            features = features.to(device)
            targets = targets.to(device)

            preds = model(features)
            # (batch, num_chars * C) -> (batch, num_chars, C), then arg-max per character.
            predicted_labels = preds.reshape(preds.size(0), num_chars, -1).argmax(dim=2)
            true_labels = targets.reshape(targets.size(0), num_chars, -1).argmax(dim=2)

            num_examples += targets.size(0)
            # A captcha is only correct if all of its characters are correct.
            correct_pred += predicted_labels.eq(true_labels).all(dim=1).sum().item()

    if num_examples == 0:
        # Avoid ZeroDivisionError on an empty loader; accuracy is undefined.
        return float("nan")
    return correct_pred / num_examples * 100
        
       

In [270]:
def train_model(num_epochs=None, save_path='./capReg.pth'):
    """Fine-tune an ImageNet-pretrained ResNet-34 on the captcha training set.

    The final fully connected layer is replaced by a single linear layer with
    ``NUM_CLASSES`` outputs. The network is trained with
    ``MultiLabelSoftMarginLoss`` and Adam (``LEARNING_RATE``) on the loader
    returned by ``my_dataset.get_train_data_loader()``. The cost is printed
    every 50 batches, and the trained ``state_dict`` is saved to ``save_path``.

    Args:
        num_epochs: number of passes over the training loader; ``None`` means
            use the current value of ``NUM_EPOCH``.
        save_path: file the model's ``state_dict`` is written to.
    """
    if num_epochs is None:
        # Resolve at call time so edits to NUM_EPOCH are picked up without
        # re-running this cell.
        num_epochs = NUM_EPOCH

    torch.manual_seed(RANDOM_SEED)

    # NOTE(review): `pretrained=` is deprecated in newer torchvision in favour of `weights=`.
    model = models.resnet34(pretrained=True)
    model.fc = nn.Sequential(nn.Linear(model.fc.in_features, NUM_CLASSES))
    # .to(DEVICE) is enough; an extra .cuda() would override DEVICE.
    model = model.to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # The loss module holds no state, so build it once rather than per batch.
    criterion = nn.MultiLabelSoftMarginLoss()

    start_time = time.time()
    train_loader = my_dataset.get_train_data_loader()
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            features = features.to(DEVICE)
            targets = targets.to(DEVICE)

            ### FORWARD AND BACK PROP
            optimizer.zero_grad()
            logits = model(features)
            cost = criterion(logits, targets)
            cost.backward()

            ### UPDATE MODEL PARAMETERS
            optimizer.step()

            ### LOGGING
            if not batch_idx % 50:
                print('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.6f'
                      % (epoch + 1, num_epochs, batch_idx, len(train_loader), cost.item()))

    print('Total Training Time: %.2f min' % ((time.time() - start_time) / 60))

    torch.save(model.state_dict(), save_path)
    

In [None]:
# Train for NUM_EPOCH epochs and save the weights to ./capReg.pth.
train_model()

In [292]:
def test_model(checkpoint_path='capReg.pth'):
    """Load the trained ResNet-34 from ``checkpoint_path`` and print its accuracy.

    The model is rebuilt with the same architecture as ``train_model``. Its
    weights are loaded from the checkpoint, and whole-captcha accuracy from
    ``compute_accuracy`` is printed for the test split.

    Args:
        checkpoint_path: path to a ``state_dict`` saved by ``train_model``.
    """
    torch.manual_seed(RANDOM_SEED)

    # Every weight is overwritten by the checkpoint, so skip downloading ImageNet weights.
    model = models.resnet34(pretrained=False)
    model.fc = nn.Sequential(nn.Linear(model.fc.in_features, NUM_CLASSES))
    model = model.to(DEVICE)

    # Load strictly: the checkpoint comes from this exact architecture, so a
    # missing or unexpected key means the wrong file. strict=False would hide
    # that and leave the replaced fc layer randomly initialised.
    state_dict = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.eval()

    print("load resnet.")

    # Evaluate on held-out data; the training loader overstates accuracy.
    # NOTE(review): relies on my_dataset exposing get_test_data_loader() — confirm.
    test_dataloader = my_dataset.get_test_data_loader()

    with torch.no_grad():  # save memory during inference
        print('test accuracy: %.3f%%' % compute_accuracy(model, test_dataloader, device=DEVICE))


test_model()

load resnet.
0 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
1 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
2 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
3 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
4 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
5 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
6 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
7 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
8 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
9 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
10 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
11 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
12 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
13 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
14 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
15 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
16 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
17 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
18 64 torch.Size([64, 4, 36]) torch.Size([64, 4])
19 64 torch.Size([64, 4, 36]) torch.Size([64, 4