In [1]:
import os
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.init as init
import torch.optim as optim
import torch.utils.data
from torch.utils.data import *
import numpy as np
import time
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import easydict
import sys
import pickle
import re
import six
import math
import torchvision.transforms as transforms

import torch.distributed as dist
from albumentations import GaussNoise, IAAAdditiveGaussianNoise, Compose, OneOf
from albumentations.pytorch import ToTensor
import albumentations
import cv2
from efficientnet_pytorch import EfficientNet
# import PositionEnhancement

# Target device for the model, loss and every batch. No index is given, so it
# refers to the current CUDA device rather than a fixed GPU number.
device = torch.device("cuda")

In [2]:
# Load the raw sample list; each element is later unpacked as (img_path, label)
# by CustomDataset_clf.
# NOTE(review): pickle.load can execute arbitrary code — only load this file from
# a trusted source.
dataset_file = '../dataset_syllable_180'
with open(dataset_file, 'rb') as file:
    data = pickle.load(file)

In [3]:
class CustomDataset_clf(Dataset):
    """Image dataset yielding ``(augmented image tensor, len(label))`` pairs.

    Each element of ``dataset`` is unpacked as ``(img_path, label)``. The target
    returned for a sample is ``len(label)``, so a model trained on this dataset
    predicts the length of the label, not the label itself.

    Args:
        dataset: Indexable sequence of ``(img_path, label)`` pairs.
        resize_shape: ``(height, width)`` that every image is resized to.
        input_channel: Currently unused; kept for interface compatibility.
    """

    def __init__(self, dataset, resize_shape = (64, 256), input_channel = 3):
        self.dataset = dataset
        self.resize_H = resize_shape[0]
        self.resize_W = resize_shape[1]
        # NOTE: the random augmentations below are applied to every sample; there
        # is no separate evaluation-mode pipeline.
        self.transform = albumentations.Compose([
            albumentations.RandomBrightnessContrast(p=0.5),
            albumentations.RandomFog(fog_coef_lower=0.1, fog_coef_upper=0.8, p=0.5 ),
            albumentations.Resize(self.resize_H, self.resize_W),
            # ImageNet per-channel mean/std.
            albumentations.Normalize( mean=[0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]),
            albumentations.pytorch.transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load, convert BGR->RGB and transform one image.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image at ``img_path``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path, label = self.dataset[idx]
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports missing/unreadable files by returning None instead
            # of raising; fail here with the offending path rather than with an
            # opaque error from cvtColor.
            raise FileNotFoundError(f'Could not read image: {img_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        image = self.transform(image=image)['image']
        return image, len(label)

def get_accuracy(pred, label):
    """Return the fraction of rows whose arg-max prediction equals the label.

    Args:
        pred: Score tensor whose last dimension indexes classes (logits or
            probabilities — the arg-max is the same either way).
        label: Target class indices, one per row of ``pred`` (tensor or sequence).

    Returns:
        Accuracy as a Python float rounded to 3 decimals; 0.0 for an empty batch.
    """
    # softmax preserves the ordering within a row, so arg-max of the raw scores
    # selects the same classes without the extra computation.
    pred_max = torch.argmax(pred, -1)
    batch_size = pred_max.shape[0]
    if batch_size == 0:
        return 0.0
    target = torch.as_tensor(label, device=pred_max.device).reshape(pred_max.shape)
    match = (pred_max == target).sum().item()
    return round(match / batch_size, 3)

def get_latest_model(name, model_dir='./models'):
    """Return the largest iteration number among checkpoints saved for ``name``.

    Checkpoints are expected to be named ``{name}_{iteration}.pth``, the format
    the training loop in this notebook writes. Files in ``model_dir`` that do not
    follow that pattern are ignored.

    Args:
        name: Model name prefix, matched literally (not interpreted as a regex).
        model_dir: Directory to scan for checkpoints.

    Returns:
        The highest iteration number found, as an int.

    Raises:
        FileNotFoundError: If ``model_dir`` does not exist.
        ValueError: If no checkpoint for ``name`` exists in ``model_dir``.
    """
    pattern = re.compile(rf'{re.escape(name)}_(\d+)\.pth')
    iterations = [
        int(found.group(1))
        for found in map(pattern.fullmatch, os.listdir(model_dir))
        if found is not None
    ]
    if not iterations:
        raise ValueError(f'No checkpoint named {name}_<iter>.pth found in {model_dir!r}')
    return max(iterations)

In [4]:
# Wrap the loaded samples in the augmenting dataset and batch them for training.
albu_loader = CustomDataset_clf(data)
dataloader = DataLoader(
    albu_loader,
    batch_size=768,
    shuffle=True,      # reshuffled on every pass over the data
    pin_memory=True,   # page-locked host memory speeds up host-to-GPU copies
    num_workers=5,
    drop_last=True,    # discard the trailing partial batch; all batches are full
)

In [5]:
name = 'efficientnet-b0'
model = EfficientNet.from_name(name, include_top=True)
# Replace the classification head with a 23-way classifier. in_features is read
# from the existing head so the cell keeps working if `name` changes variant.
model._fc = torch.nn.Linear(in_features=model._fc.in_features, out_features=23, bias=True)

# Resume from the newest checkpoint when one exists. get_latest_model raises when
# ./models or a matching checkpoint is missing; on a fresh run that just means
# "start from scratch".
try:
    previous_iter = get_latest_model(name)
except (FileNotFoundError, ValueError):
    previous_iter = None

if previous_iter is not None:
    load_path = f'./models/{name}_{previous_iter}.pth'
    # The model is still on the CPU here; map_location avoids materialising the
    # checkpoint on whichever GPU it was saved from.
    model.load_state_dict(torch.load(load_path, map_location='cpu'))
else:
    load_path = None
    previous_iter = 0  # checkpoint numbering in the training loop starts here

# Replicate across GPUs 0 and 1 (assumes at least two visible CUDA devices).
model = torch.nn.DataParallel(model, device_ids = [0,1]).to(device)
_ = model.train()

In [6]:
# Optimise only the parameters that require gradients.
filtered_parameters = [p for p in model.parameters() if p.requires_grad]
params_num = [p.numel() for p in filtered_parameters]
print('Tranable params : ', sum(params_num))

optimizer = optim.Adadelta(filtered_parameters)
# Plain cross-entropy over class indices. No ignore_index is passed, so every
# target contributes to the loss.
criterion = torch.nn.CrossEntropyLoss().to(device)

Tranable params :  4037011


In [None]:
# Best single-minibatch accuracy seen so far; a checkpoint is written only when it
# improves. Initialised once, before the loop, so the comparison spans the run
# (resetting it inside the loop made every logged step with acc > 0 a "new best").
prev_acc = 0
os.makedirs('./models', exist_ok=True)
for i, (img, label) in enumerate(dataloader):
    img, label = img.to(device) , label.to(device)
    pred = model(img)
    loss = criterion(pred, label)
    model.zero_grad()
    loss.backward()
    # Clip the global gradient norm to 5 before the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
    optimizer.step()

    # Every 500 minibatches (including the first): log and possibly checkpoint.
    if (i % 500 ==0) :
        acc = get_accuracy(pred, label)
        print(f'{i}th batch, loss : {round(loss.item(), 4)}, last minibatch accuracy : {acc}')
        with open('train_log.txt', 'a+') as f:
            line = '-'*100 + '\n'
            log = f'{i}th minibatch, last minibatch accuracy : {acc}'
            f.write(line + log+'\n')
        # NOTE(review): accuracy on one minibatch is a noisy selection signal.
        if prev_acc < acc:
            torch.save(model.module.state_dict(), f'./models/{name}_{previous_iter + i}.pth')
            prev_acc = acc
             

In [None]:
# torch.save(model.module.state_dict(), f'./Nchar_clf_{i}.pth')