In [None]:
# Reference: DACON competition code-share post this notebook follows:
# https://dacon.io/competitions/official/235951/codeshare/6628?page=1&dtype=recent

In [None]:
!unzip -q /content/drive/MyDrive/이미지_3D.zip

In [None]:
import numpy as np
from tqdm import tqdm

import random
import h5py
import os
import csv

In [None]:
class H5DataLoader:
    """Sequential iterator over the samples stored in a competition ``.h5`` file.

    Each iteration yields ``(key, dataset, label)``:
      * ``key``: the sample ID as a string (``index + offset``);
      * ``dataset``: the h5py object stored under that key;
      * ``label``: the class label read from ``train.csv`` as a 0-d integer
        array in ``'train'`` mode, or the integer sample ID in ``'test'`` mode.

    Args:
        data_path: directory containing ``train.h5`` and ``train.csv`` (train
            mode) or ``test.h5`` (test mode).
        mode: ``'train'`` or ``'test'``.
        img_size: stored on the instance; not used by this class.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, data_path, mode='train', img_size=256):
        self.it = 0
        self.mode = mode
        self.label = None
        if mode == 'train':
            # Training sample IDs start at 0.
            self.off = 0
            # Read the label table inside a context manager so the file handle
            # is closed (the previous inline open() leaked it).
            with open(os.path.join(data_path, 'train.csv'), newline='') as f:
                self.label = {r['ID']: r['label'] for r in csv.DictReader(f)}
            self.data = h5py.File(os.path.join(data_path, "train.h5"), 'r')
        elif mode == 'test':
            # Test sample IDs are numbered from 50000.
            self.off = 50000
            self.data = h5py.File(os.path.join(data_path, "test.h5"), 'r')
        else:
            # Fail fast instead of producing an object without ``data``/``off``.
            raise ValueError("mode must be 'train' or 'test', got %r" % (mode,))
        self.img_size = img_size

    def __iter__(self):
        return self

    def __next__(self):
        index = self.it
        if index >= len(self.data):
            raise StopIteration
        key = str(index + self.off)
        data = self.data[key]
        label = index + self.off
        if self.mode == 'train':
            # CSV values are strings; convert the label to an integer array.
            label = np.array(self.label[key]).astype(int)
        self.it += 1
        return key, data, label

    def __len__(self):
        return len(self.data)

In [None]:
def h5_to_np(data_path, save_path):
    """Convert the competition ``train.h5`` / ``test.h5`` files into per-sample ``.npy`` files.

    Output layout:
        ``<save_path>/train/<label>/<id:05d>.npy`` for every training sample;
        ``<save_path>/test/<id:05d>.npy`` for every test sample.

    Args:
        data_path: directory holding ``train.h5``, ``train.csv`` and ``test.h5``.
        save_path: root directory for the ``.npy`` output (created if missing).
    """
    # One sub-directory per class label 0-9 (a plain loop rather than a list
    # comprehension evaluated only for its side effects).
    for class_id in range(10):
        os.makedirs("%s/train/%d" % (save_path, class_id), exist_ok=True)
    os.makedirs("%s/test" % save_path, exist_ok=True)

    train_dl = H5DataLoader(data_path, mode='train')
    for index, dd, gt in tqdm(train_dl, total=len(train_dl)):
        np.save("%s/train/%d/%05d" % (save_path, gt, int(index)), dd)
    # The loader keeps the h5py file open; release it once everything is written.
    train_dl.data.close()
    print('train data save complete')

    test_dl = H5DataLoader(data_path, mode='test')
    for index, dd, _ in tqdm(test_dl, total=len(test_dl)):
        np.save("%s/test/%05d" % (save_path, int(index)), dd)
    test_dl.data.close()

In [None]:
# Convert /content/train.h5 and /content/test.h5 into per-sample .npy files under /content/np.
h5_to_np('/content', "/content/np")

100%|██████████| 50000/50000 [02:56<00:00, 283.31it/s]


train data save complete


100%|██████████| 40000/40000 [03:00<00:00, 221.27it/s]


In [None]:
import torch
from torch import Tensor
import torch.nn as nn


class DeepHoughModel(nn.Module):
    """1-D convolutional regressor mapping a batch of point clouds to 3 outputs.

    ``forward`` expects ``(batch, n_points, 3)``. It moves the coordinate axis
    to the channel position, applies five strided ``Conv1d`` + ReLU stages and
    mean-pools over the remaining length. A ``Linear -> Tanh -> Linear`` head
    then produces 3 values per sample (used by the training code as rotation
    angles).

    ``input_size`` is stored but not used by the layers. With the kernel and
    stride settings below, an input length of 10000 shrinks to length 1 after
    ``l5``.
    """

    def __init__(self, input_size=10000):
        super().__init__()
        self.input_size = input_size
        # The attribute names l1..l5 / fc1 / fc2 are the state_dict keys of
        # saved checkpoints, so they must stay as they are.
        self.l1 = nn.Conv1d(3, 16, kernel_size=100, stride=5)
        self.l2 = nn.Conv1d(16, 64, kernel_size=100, stride=5)
        self.l3 = nn.Conv1d(64, 128, kernel_size=10, stride=5)
        self.l4 = nn.Conv1d(128, 256, kernel_size=10, stride=5)
        self.l5 = nn.Conv1d(256, 512, kernel_size=10, stride=5)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 3)
        self.active = nn.ReLU()
        self.fc_active = nn.Tanh()

    def forward(self, x: Tensor) -> Tensor:
        # (batch, n_points, 3) -> (batch, 3, n_points) so Conv1d sees xyz as channels.
        hidden = x.permute(0, 2, 1)
        for conv in (self.l1, self.l2, self.l3, self.l4, self.l5):
            hidden = self.active(conv(hidden))
        pooled = hidden.mean(dim=2)
        return self.fc2(self.fc_active(self.fc1(pooled)))

In [None]:
def derotation(a, b, c, dots):
    """Rotate point rows about the z axis, then the y axis, then the x axis.

    Each row ``p`` of ``dots`` is replaced by ``R @ p`` for the single-axis
    rotation matrices ``Rz(c)``, ``Ry(b)`` and ``Rx(a)`` (angles in radians),
    applied in that order. Calling it with the negated angles of a forward
    x -> y -> z rotation undoes that rotation.

    Returns:
        A new array with the same shape as ``dots``.
    """
    def _axis_matrix(axis, angle):
        cos_t, sin_t = np.cos(angle), np.sin(angle)
        if axis == 'x':
            return np.array([[1, 0, 0], [0, cos_t, -sin_t], [0, sin_t, cos_t]])
        if axis == 'y':
            return np.array([[cos_t, 0, sin_t], [0, 1, 0], [-sin_t, 0, cos_t]])
        return np.array([[cos_t, -sin_t, 0], [sin_t, cos_t, 0], [0, 0, 1]])

    result = dots
    for axis, angle in (('z', c), ('y', b), ('x', a)):
        result = np.dot(result, _axis_matrix(axis, angle).T)
    return result


def viz_result(data, pred, gt):
    """Build a 3-row comparison image for a batch of point clouds.

    Rows:
      1. each input cloud as given;
      2. each cloud de-rotated by the predicted angles ``pred[i]``;
      3. each cloud de-rotated by the ground-truth angles ``gt[i]``.
    Every cell is rendered with ``data2img``; columns are batch items.

    Args:
        data: batch of point clouds (tensor or array), one per item.
        pred: batch of predicted angles; ``pred[i]`` unpacks to ``(x, y, z)``.
        gt: batch of ground-truth angles, same layout as ``pred``. Tensors on
            any device, tensors requiring grad, and NumPy arrays are accepted.

    Returns:
        A 2-D uint8 image: the three rows stacked vertically.
    """
    def _as_array(value):
        # Detach and move tensors to host memory before handing them to NumPy.
        # Previously gt[i] went to np.cos as 0-d torch tensors, which only works
        # for CPU tensors that do not require grad.
        if hasattr(value, 'detach'):
            value = value.detach().cpu().numpy()
        return np.asarray(value)

    in_list, pred_list, gt_list = [], [], []
    for i in range(data.shape[0]):
        dd = _as_array(data[i])
        in_list.append(data2img(dd.copy()))

        px, py, pz = _as_array(pred[i])
        pred_list.append(data2img(derotation(-px, -py, -pz, dd.copy())))

        gx, gy, gz = _as_array(gt[i])
        gt_list.append(data2img(derotation(-gx, -gy, -gz, dd.copy())))

    rows = [np.concatenate(images, axis=1) for images in (in_list, pred_list, gt_list)]
    return np.concatenate(rows)


def data2img(data, img_size=224, w=1):
    """Render the middle x-plane of a point cloud as a binary 8-bit image.

    Each coordinate ``v`` is mapped to the voxel index
    ``int((v + 1) * (img_size // 2))``, so points are presumably in roughly
    ``[-1, 1]`` (TODO confirm). Every point stamps the block
    ``[x-w:x+w, y-w:y+w, z-w:z+w]`` of a virtual ``img_size**3`` volume, and the
    plane at ``x == img_size // 2`` is returned with every covered pixel set
    to 255.

    Only that plane is materialised. A point contributes to it exactly when
    the x-slice ``[x - w:x + w]``, under NumPy/Python slice rules (including
    negative-index wrap-around), contains the middle index. The output is
    therefore identical to building the whole volume, without allocating
    ``img_size**3`` floats (~90 MB for the default size).

    Args:
        data: array-like of points, shape ``(n_points, 3)``.
        img_size: side length of the virtual volume and of the output image.
        w: half-width of the cube stamped for each point (default 1).

    Returns:
        ``np.uint8`` array of shape ``(img_size, img_size)`` with values 0 or 255.
    """
    mid = img_size // 2
    index = np.array((np.array(data) + 1) * (img_size // 2), dtype=int)
    plane = np.zeros((img_size, img_size), dtype=float)
    for x, y, z in index:
        # Same start/stop normalisation NumPy applies to volume[x - w:x + w].
        if mid in range(*slice(x - w, x + w).indices(img_size)):
            plane[y - w:y + w, z - w:z + w] += 1
    plane[plane > 1] = 1
    plane *= 255
    return plane.astype(np.uint8)

In [None]:
from torch.utils.data import Dataset
from torchvision import transforms

class RotationLoader(Dataset):
    """Point-cloud dataset that applies a random rotation and returns its angles.

    Files are expected in the layout written by ``h5_to_np``:
    ``<data_path>/train/<label>/<id>.npy`` and ``<data_path>/test/<id>.npy``.

    Modes:
        ``'train'``: the first 90% (by sorted path) of each class folder.
        ``'val'``:   the remaining 10% of each class folder.
        ``'test'``:  every ``test/*.npy``; its "label" is the numeric file stem.

    ``__getitem__`` returns ``(points, label, angles)``:
      * ``points``: ``n_points`` float32 rows sampled from the rotated cloud;
      * ``label``: the label;
      * ``angles``: the float64 array ``[x, y, z]`` of applied rotation angles
        in radians.
    """

    # Rows drawn per sample by ``sampling`` (matches DeepHoughModel's default input_size).
    n_points = 10000

    def __init__(self, data_path, mode='train', img_size=128):
        self.mode = mode
        self.data, self.label = self.get_img_list(data_path)
        self.img_size = img_size
        # Maximum |angle| used by random_rotation; None means pi/2.
        self.off = None
        # NOTE(review): the transforms below are not used by this class's
        # __getitem__; kept so existing attribute access keeps working.
        self.transform_3d = transforms.Compose([transforms.ToTensor(),
                                                transforms.Normalize(mean=[0.5] * img_size,
                                                                     std=[0.225] * img_size)
                                                ])
        self.transform = transforms.Compose([transforms.ToTensor()])

    def get_img_list(self, data_path):
        """Return ``(file_paths, labels)`` for ``self.mode``.

        Raises:
            AssertionError: if ``self.mode`` is not ``'train'``, ``'val'`` or ``'test'``.
        """
        if self.mode == 'test':
            img_list = glob(os.path.join(data_path, 'test', '*.npy'))
            # Test "labels" are the numeric sample IDs taken from the file stem.
            label_list = [int(os.path.basename(p).split('.')[0]) for p in img_list]
            return img_list, label_list
        if self.mode not in ('train', 'val'):
            raise AssertionError("unknown mode: %r" % (self.mode,))

        img_list = []
        label_list = []
        for class_id in range(10):
            sub_path = os.path.join(data_path, 'train', '%d' % class_id)
            # Sorted so the 90/10 split is deterministic.
            sub_img_list = sorted(os.path.join(sub_path, s)
                                  for s in os.listdir(sub_path) if 'npy' in s)
            split = int(len(sub_img_list) * 0.9)
            if self.mode == 'train':
                sub_img_list = sub_img_list[:split]
            else:
                sub_img_list = sub_img_list[split:]
            img_list += sub_img_list
            label_list += [class_id] * len(sub_img_list)
        return img_list, label_list

    @staticmethod
    def _axis_rotation(axis, angle, dots):
        """Rotate each row of ``dots`` by ``angle`` radians about ``axis`` ('x', 'y' or 'z')."""
        cos_t, sin_t = np.cos(angle), np.sin(angle)
        if axis == 'x':
            m = np.array([[1, 0, 0], [0, cos_t, -sin_t], [0, sin_t, cos_t]])
        elif axis == 'y':
            m = np.array([[cos_t, 0, sin_t], [0, 1, 0], [-sin_t, 0, cos_t]])
        else:
            m = np.array([[cos_t, -sin_t, 0], [sin_t, cos_t, 0], [0, 0, 1]])
        return np.dot(dots, m.T)

    @staticmethod
    def rotation(a, b, c, dots):
        """Rotate about x by ``a``, then about y by ``b``, then about z by ``c`` (radians)."""
        dot = RotationLoader._axis_rotation('x', a, dots)
        dot = RotationLoader._axis_rotation('y', b, dot)
        return RotationLoader._axis_rotation('z', c, dot)

    @staticmethod
    def derotation(a, b, c, dots):
        """Rotate about z by ``c``, then y by ``b``, then x by ``a``.

        With negated angles this inverts ``rotation(a, b, c, ...)``.
        """
        dot = RotationLoader._axis_rotation('z', c, dots)
        dot = RotationLoader._axis_rotation('y', b, dot)
        return RotationLoader._axis_rotation('x', a, dot)

    def random_rotation(self, data):
        """Rotate ``data`` by angles drawn uniformly from ``[-off, off)`` on each axis.

        ``off`` is ``self.off``, or ``pi/2`` when unset. Returns the rotated
        points and the list ``[x, y, z]`` of angles used.
        """
        off = 0.5 * np.pi if self.off is None else self.off
        x, y, z = (2 * np.random.rand(3) - 1) * off
        return self.rotation(x, y, z, data), [x, y, z]

    def cvt_data(self, data):
        """Min-max scale ``data[:, 14, :]`` to uint8 and return it as a 3-channel image."""
        data_np = np.array(data[:, 14, :], dtype=np.float32)
        data_np -= np.min(data_np)
        peak = np.max(data_np)
        if peak > 0:  # a constant slice would otherwise divide by zero
            data_np /= peak
        img = np.array(data_np * 255, dtype=np.uint8)
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    def data2img(self, data):
        """Voxelise points into an ``img_size**3`` uint8 volume scaled so the max count is 255."""
        w = 1
        index = np.array((np.array(data) + 1) * (self.img_size // 2), dtype=int)
        img = np.zeros((self.img_size, self.img_size, self.img_size), dtype=float)
        for x, y, z in index:
            img[x - w:x + w, y - w:y + w, z - w:z + w] += 1
        peak = np.max(img)
        if peak > 0:  # no points in range would otherwise yield 0/0 = NaN
            img /= peak
        img *= 255
        return img.astype(np.uint8)

    def sampling(self, data):
        """Return ``n_points`` rows of ``data``, kept in their original row order.

        Draws without replacement when ``data`` has at least ``n_points`` rows,
        otherwise with replacement.

        Raises:
            ValueError: if ``data`` has no rows.
        """
        data_len = data.shape[0]
        # Choose ``replace`` explicitly instead of trying replace=False and
        # falling back inside a bare ``except``. NumPy raises before drawing,
        # so the random stream is the same as before.
        idx = np.random.choice(data_len, self.n_points, replace=data_len < self.n_points)
        idx.sort()
        return data[idx]

    def __getitem__(self, index):
        """Return ``(points, label, angles)`` for sample ``index``.

        If the file cannot be loaded, the error is printed and sample 0 is used
        instead, so one bad file does not stop an epoch.
        """
        path = self.data[index]
        label = self.label[index]
        try:
            data = np.load(path)
        except Exception as e:
            print('=' * 100)
            print(index, e)
            print('=' * 100)
            data = np.load(self.data[0])
            label = self.label[0]

        data, rota = self.random_rotation(data)
        data = self.sampling(data.astype(np.float32))
        return data, label, np.array(rota, dtype=float)

    def __len__(self):
        return len(self.data)


class TestRotationLoader(RotationLoader):
    """Evaluation-time variant of ``RotationLoader``.

    Differences from the parent:
      * ``random_rotation`` draws each angle from ``[0, off)`` (parent: ``[-off, off)``);
      * ``__getitem__`` always rotates with ``off = 0.2 * pi`` and returns only
        ``(points, label)``, without the angles.
    """

    def __init__(self, data_path, mode='train', img_size=128):
        # RotationLoader.__init__ already sets mode, data/label (via
        # get_img_list), img_size, off and the transforms. Repeating that here
        # only scanned the data directory a second time.
        super().__init__(data_path, mode, img_size)

    def random_rotation(self, data, off=None):
        """Rotate ``data`` by angles drawn uniformly from ``[0, off)`` per axis.

        ``off`` defaults to ``pi/2``. Returns the rotated points and ``[x, y, z]``.
        """
        if off is None:
            off = np.pi * 0.5
        x, y, z = np.random.rand(3)
        x *= off
        y *= off
        z *= off
        return self.rotation(x, y, z, data), [x, y, z]

    def __getitem__(self, index):
        """Return ``(points, label)`` for sample ``index`` after a small random rotation.

        If the file cannot be loaded, the error is printed and sample 0 is used instead.
        """
        path = self.data[index]
        label = self.label[index]
        try:
            data = np.load(path)
        except Exception as e:
            print('=' * 100)
            print(index, e)
            print('=' * 100)
            data = np.load(self.data[0])
            label = self.label[0]

        # The applied angles are not returned in test mode.
        data, _ = self.random_rotation(data, np.pi * 0.2)
        data = self.sampling(data.astype(np.float32))
        return data, label


class TrainDataLoader(RotationLoader):
    """Iterates over *all* training clouds without rotation and without a train/val split.

    ``__getitem__`` returns ``(points, label, name)``:
      * ``points``: the resampled, un-rotated cloud;
      * ``label``: the class index;
      * ``name``: the file name up to its first '.' (e.g. ``'00042'``).
    """

    def __init__(self, data_path):
        # RotationLoader.__init__ already calls self.get_img_list, which
        # dispatches to the override below, so the file list is built once.
        super().__init__(data_path)
        self.idx = 0

    def get_img_list(self, data_path):
        """Return ``(file_paths, labels)`` for every ``train/<label>/*npy*`` file, sorted per class."""
        img_list = []
        label_list = []
        for class_id in range(10):
            sub_path = os.path.join(data_path, 'train', '%d' % class_id)
            files = sorted(os.path.join(sub_path, s) for s in os.listdir(sub_path) if 'npy' in s)
            img_list.extend(files)
            label_list.extend([class_id] * len(files))
        return img_list, label_list

    def __getitem__(self, index):
        data_name = self.data[index]
        label = self.label[index]
        data = np.load(data_name)
        data = self.sampling(data.astype(np.float32))
        name = os.path.split(data_name)[-1].split('.')[0]
        return data, label, name

In [None]:
import os
import cv2
import numpy as np
from glob import glob
from tqdm import tqdm

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

In [None]:
def train_rotation(data_path, save_path, epochs=50, batch_size=8, lr=0.0001,
                   num_workers=4, viz_every=3000, val_viz_every=300):
    """Train ``DeepHoughModel`` to regress the random rotation applied by ``RotationLoader``.

    Each epoch has two passes:
      * a pass over the 'train' split, minimising ``L1 + L2`` (summed over the
        batch) between predicted and true angles;
      * a pass over the 'val' split, measuring the summed L2 error.
    After every epoch a checkpoint ``<epoch>_rotation1d_<train>_<val>.pth``
    holding ``{'weight': state_dict}`` is written to ``save_path``, along with
    periodic visualisation PNGs.

    Args:
        data_path: root of the ``.npy`` dataset written by ``h5_to_np``.
        save_path: output directory for checkpoints and PNGs (created if needed).
        epochs, batch_size, lr, num_workers: training settings (defaults are
            the previously hard-coded values).
        viz_every: write a train visualisation every N batches.
        val_viz_every: write a val visualisation every N batches.
    """
    os.makedirs(save_path, exist_ok=True)

    model = DeepHoughModel()
    model = model.cuda()

    train_dl = RotationLoader(data_path, mode='train', img_size=64)
    # get_img_list concatenates files class by class, so the training loader
    # must shuffle. Otherwise every epoch walks through digit 0, then 1, ...
    train_loader = DataLoader(dataset=train_dl, batch_size=batch_size,
                              num_workers=num_workers, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=lr, amsgrad=True, weight_decay=1e-4)
    val_dl = RotationLoader(data_path, mode='val', img_size=64)
    # Order does not affect the validation metric; keep it deterministic.
    val_loader = DataLoader(dataset=val_dl, batch_size=batch_size,
                            num_workers=num_workers, shuffle=False)

    for epoch in range(epochs):
        print('train', epoch)
        model.train()
        pbar = tqdm(train_loader, total=len(train_loader))
        total_loss = 0.0
        for i, (data, label, gt) in enumerate(pbar):
            optimizer.zero_grad()
            pred = model(data.cuda())
            distance = gt.cuda() - pred
            l2_loss = torch.sum(distance * distance)
            l1_loss = torch.sum(torch.abs(distance))
            loss = l1_loss + l2_loss
            total_loss += float(loss)
            loss.backward()
            optimizer.step()
            pbar.set_postfix({'loss': '%.4f (%.4f, %.4f)' % (loss.item(), l1_loss.item(), l2_loss.item())})
            if i % viz_every == 0:
                result_img = viz_result(data, pred, gt)
                cv2.imwrite(os.path.join(save_path, "train_%d_%d.png" % (epoch, i // viz_every)), result_img)
        total_loss /= max(len(train_loader), 1)

        # eval() so any future dropout/batch-norm layers behave correctly.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for i, (data, label, gt) in enumerate(val_loader):
                pred = model(data.cuda())
                distance = gt.cuda() - pred
                # Accumulate a Python float instead of a CUDA tensor.
                val_loss += float(torch.sum(distance * distance))
                if i % val_viz_every == 0:
                    result_img = viz_result(data, pred, gt)
                    cv2.imwrite(os.path.join(save_path, "val_%d_%d.png" % (epoch, i // val_viz_every)), result_img)
        val_loss /= max(len(val_loader), 1)

        model_save_path = os.path.join(save_path, "%d_rotation1d_%.4f_%.4f.pth" % (epoch, total_loss, val_loss))
        torch.save({'weight': model.state_dict()}, model_save_path)
        print("#" * 100)
        print("[%d] validation loss : %.4f" % (epoch, val_loss), end=' ')
        print("#" * 100)

In [None]:
# Paths shared by the training, conversion and inference cells below.
np_data_path = "/content/np"  # must match the save_path passed to h5_to_np above
# NOTE(review): "rotatio" looks like a truncated "rotation"; left unchanged so
# checkpoints already saved under this directory are still found at inference.
save_model_path = "/content/model/rotatio"
result_data_path = "/content/img"  # root for rendered PNGs (train/ and test_<k>/)

In [None]:
# Train the rotation regressor; train_rotation writes one checkpoint per epoch to save_model_path.
train_rotation(np_data_path, save_model_path)

train 0


100%|██████████| 5625/5625 [03:28<00:00, 26.93it/s, loss=0.8765 (0.7771, 0.0995)]


####################################################################################################
[0] validation loss : 1.3769 ####################################################################################################
train 1


100%|██████████| 5625/5625 [03:41<00:00, 25.37it/s, loss=0.2672 (0.2555, 0.0117)]


####################################################################################################
[1] validation loss : 0.6398 ####################################################################################################
train 2


100%|██████████| 5625/5625 [03:39<00:00, 25.61it/s, loss=0.3938 (0.3707, 0.0231)]


####################################################################################################
[2] validation loss : 0.4525 ####################################################################################################
train 3


100%|██████████| 5625/5625 [03:38<00:00, 25.74it/s, loss=1.4239 (1.0476, 0.3764)]


####################################################################################################
[3] validation loss : 0.3840 ####################################################################################################
train 4


100%|██████████| 5625/5625 [03:39<00:00, 25.60it/s, loss=0.5169 (0.4608, 0.0560)]


####################################################################################################
[4] validation loss : 0.3307 ####################################################################################################
train 5


100%|██████████| 5625/5625 [03:39<00:00, 25.61it/s, loss=0.5497 (0.4754, 0.0743)]


####################################################################################################
[5] validation loss : 0.3073 ####################################################################################################
train 6


100%|██████████| 5625/5625 [03:42<00:00, 25.25it/s, loss=0.3099 (0.2925, 0.0174)]


####################################################################################################
[6] validation loss : 0.2489 ####################################################################################################
train 7


100%|██████████| 5625/5625 [03:02<00:00, 30.75it/s, loss=0.1157 (0.1131, 0.0026)]


####################################################################################################
[7] validation loss : 0.2572 ####################################################################################################
train 8


100%|██████████| 5625/5625 [03:04<00:00, 30.50it/s, loss=0.2071 (0.2005, 0.0066)]


####################################################################################################
[8] validation loss : 0.2030 ####################################################################################################
train 9


100%|██████████| 5625/5625 [02:58<00:00, 31.49it/s, loss=0.1926 (0.1860, 0.0067)]


####################################################################################################
[9] validation loss : 0.2300 ####################################################################################################
train 10


100%|██████████| 5625/5625 [02:58<00:00, 31.43it/s, loss=0.4925 (0.4383, 0.0541)]


####################################################################################################
[10] validation loss : 0.2341 ####################################################################################################
train 11


100%|██████████| 5625/5625 [02:55<00:00, 32.00it/s, loss=0.3524 (0.3351, 0.0173)]


####################################################################################################
[11] validation loss : 0.1555 ####################################################################################################
train 12


100%|██████████| 5625/5625 [02:55<00:00, 32.11it/s, loss=0.2244 (0.2173, 0.0071)]


####################################################################################################
[12] validation loss : 0.2529 ####################################################################################################
train 13


100%|██████████| 5625/5625 [02:56<00:00, 31.82it/s, loss=0.2887 (0.2791, 0.0097)]


####################################################################################################
[13] validation loss : 0.1875 ####################################################################################################
train 14


100%|██████████| 5625/5625 [02:57<00:00, 31.75it/s, loss=2.7126 (1.5958, 1.1168)]


####################################################################################################
[14] validation loss : 0.1873 ####################################################################################################
train 15


100%|██████████| 5625/5625 [02:56<00:00, 31.92it/s, loss=0.2468 (0.2388, 0.0080)]


####################################################################################################
[15] validation loss : 0.1614 ####################################################################################################
train 16


100%|██████████| 5625/5625 [02:57<00:00, 31.75it/s, loss=0.2903 (0.2792, 0.0111)]


####################################################################################################
[16] validation loss : 0.1713 ####################################################################################################
train 17


100%|██████████| 5625/5625 [02:57<00:00, 31.72it/s, loss=0.2157 (0.2100, 0.0057)]


####################################################################################################
[17] validation loss : 0.2150 ####################################################################################################
train 18


100%|██████████| 5625/5625 [03:01<00:00, 31.06it/s, loss=0.2126 (0.2057, 0.0070)]


####################################################################################################
[18] validation loss : 0.1894 ####################################################################################################
train 19


100%|██████████| 5625/5625 [03:03<00:00, 30.66it/s, loss=0.2055 (0.1986, 0.0069)]


####################################################################################################
[19] validation loss : 0.2038 ####################################################################################################
train 20


100%|██████████| 5625/5625 [02:58<00:00, 31.45it/s, loss=0.2577 (0.2414, 0.0164)]


####################################################################################################
[20] validation loss : 0.2047 ####################################################################################################
train 21


100%|██████████| 5625/5625 [02:55<00:00, 32.04it/s, loss=0.1949 (0.1887, 0.0062)]


####################################################################################################
[21] validation loss : 0.1240 ####################################################################################################
train 22


100%|██████████| 5625/5625 [02:58<00:00, 31.58it/s, loss=0.4515 (0.4197, 0.0318)]


####################################################################################################
[22] validation loss : 0.1290 ####################################################################################################
train 23


100%|██████████| 5625/5625 [02:58<00:00, 31.43it/s, loss=0.1784 (0.1723, 0.0060)]


####################################################################################################
[23] validation loss : 0.1560 ####################################################################################################
train 24


100%|██████████| 5625/5625 [02:57<00:00, 31.66it/s, loss=0.1998 (0.1928, 0.0070)]


####################################################################################################
[24] validation loss : 0.1715 ####################################################################################################
train 25


100%|██████████| 5625/5625 [02:58<00:00, 31.43it/s, loss=0.1989 (0.1931, 0.0058)]


####################################################################################################
[25] validation loss : 0.1231 ####################################################################################################
train 26


100%|██████████| 5625/5625 [02:59<00:00, 31.36it/s, loss=0.2512 (0.2402, 0.0110)]


####################################################################################################
[26] validation loss : 0.1500 ####################################################################################################
train 27


100%|██████████| 5625/5625 [02:56<00:00, 31.78it/s, loss=0.1674 (0.1625, 0.0049)]


####################################################################################################
[27] validation loss : 0.1561 ####################################################################################################
train 28


100%|██████████| 5625/5625 [02:57<00:00, 31.68it/s, loss=0.1294 (0.1257, 0.0038)]


####################################################################################################
[28] validation loss : 0.1414 ####################################################################################################
train 29


100%|██████████| 5625/5625 [02:55<00:00, 32.03it/s, loss=0.2262 (0.2173, 0.0089)]


####################################################################################################
[29] validation loss : 0.1989 ####################################################################################################
train 30


100%|██████████| 5625/5625 [02:58<00:00, 31.46it/s, loss=0.1957 (0.1912, 0.0046)]


####################################################################################################
[30] validation loss : 0.1463 ####################################################################################################
train 31


100%|██████████| 5625/5625 [02:57<00:00, 31.60it/s, loss=0.2275 (0.2200, 0.0075)]


####################################################################################################
[31] validation loss : 0.1659 ####################################################################################################
train 32


100%|██████████| 5625/5625 [02:57<00:00, 31.63it/s, loss=0.1089 (0.1067, 0.0022)]


####################################################################################################
[32] validation loss : 0.1344 ####################################################################################################
train 33


100%|██████████| 5625/5625 [02:59<00:00, 31.35it/s, loss=0.3165 (0.3020, 0.0144)]


####################################################################################################
[33] validation loss : 0.2404 ####################################################################################################
train 34


100%|██████████| 5625/5625 [02:57<00:00, 31.66it/s, loss=0.1356 (0.1326, 0.0030)]


####################################################################################################
[34] validation loss : 0.1239 ####################################################################################################
train 35


100%|██████████| 5625/5625 [02:59<00:00, 31.37it/s, loss=0.2308 (0.2226, 0.0082)]


####################################################################################################
[35] validation loss : 0.1377 ####################################################################################################
train 36


100%|██████████| 5625/5625 [02:58<00:00, 31.47it/s, loss=0.2330 (0.2249, 0.0081)]


####################################################################################################
[36] validation loss : 0.0866 ####################################################################################################
train 37


100%|██████████| 5625/5625 [02:57<00:00, 31.64it/s, loss=0.1542 (0.1507, 0.0035)]


####################################################################################################
[37] validation loss : 0.1184 ####################################################################################################
train 38


100%|██████████| 5625/5625 [02:58<00:00, 31.54it/s, loss=0.1638 (0.1599, 0.0040)]


####################################################################################################
[38] validation loss : 0.1367 ####################################################################################################
train 39


100%|██████████| 5625/5625 [02:58<00:00, 31.54it/s, loss=0.2530 (0.2382, 0.0149)]


####################################################################################################
[39] validation loss : 0.1615 ####################################################################################################
train 40


100%|██████████| 5625/5625 [02:59<00:00, 31.40it/s, loss=0.2003 (0.1941, 0.0062)]


####################################################################################################
[40] validation loss : 0.1316 ####################################################################################################
train 41


100%|██████████| 5625/5625 [02:59<00:00, 31.37it/s, loss=0.4530 (0.4174, 0.0356)]


####################################################################################################
[41] validation loss : 0.1498 ####################################################################################################
train 42


100%|██████████| 5625/5625 [02:56<00:00, 31.84it/s, loss=0.3045 (0.2888, 0.0157)]


####################################################################################################
[42] validation loss : 0.1556 ####################################################################################################
train 43


100%|██████████| 5625/5625 [02:58<00:00, 31.55it/s, loss=0.2484 (0.2408, 0.0075)]


####################################################################################################
[43] validation loss : 0.1565 ####################################################################################################
train 44


100%|██████████| 5625/5625 [02:57<00:00, 31.61it/s, loss=0.1641 (0.1599, 0.0043)]


####################################################################################################
[44] validation loss : 0.1278 ####################################################################################################
train 45


100%|██████████| 5625/5625 [02:57<00:00, 31.67it/s, loss=0.1634 (0.1596, 0.0038)]


####################################################################################################
[45] validation loss : 0.1418 ####################################################################################################
train 46


100%|██████████| 5625/5625 [02:58<00:00, 31.52it/s, loss=0.1964 (0.1917, 0.0047)]


####################################################################################################
[46] validation loss : 0.1473 ####################################################################################################
train 47


100%|██████████| 5625/5625 [02:58<00:00, 31.45it/s, loss=0.3950 (0.3746, 0.0204)]


####################################################################################################
[47] validation loss : 0.1625 ####################################################################################################
train 48


100%|██████████| 5625/5625 [02:57<00:00, 31.67it/s, loss=0.0783 (0.0773, 0.0010)]


####################################################################################################
[48] validation loss : 0.1651 ####################################################################################################
train 49


100%|██████████| 5625/5625 [03:00<00:00, 31.22it/s, loss=0.1354 (0.1322, 0.0032)]


####################################################################################################
[49] validation loss : 0.1147 ####################################################################################################


In [None]:
def cvt_train_data_2d(data_path, save_path, num_workers=4):
    """Render every training sample to a PNG at <save_path>/<label>/<name>.png.

    Each sample from TrainDataLoader(data_path) is converted with data2img(sample, 64)
    and written with cv2.imwrite.

    Args:
        data_path: root directory passed to TrainDataLoader.
        save_path: output root; one sub-directory per digit 0-9 is created.
        num_workers: DataLoader worker count (4 in the original notebook).

    Raises:
        IOError: if cv2.imwrite reports that an image could not be written.
    """
    # makedirs also creates save_path itself as the parent of each digit folder.
    for digit in range(10):
        os.makedirs(os.path.join(save_path, '%d' % digit), exist_ok=True)

    train_loader = DataLoader(dataset=TrainDataLoader(data_path), batch_size=1,
                              num_workers=num_workers, shuffle=False)
    for data, label, name in tqdm(train_loader, total=len(train_loader)):
        # batch_size=1, so index 0 unwraps the collated batch.
        img = data2img(data[0], 64)
        out_path = os.path.join(save_path, '%d' % label[0], name[0] + '.png')
        # cv2.imwrite signals failure by returning False instead of raising.
        if not cv2.imwrite(out_path, img):
            raise IOError('cv2.imwrite failed for %s' % out_path)

In [None]:
# Render every training volume under np_data_path to PNGs in <result_data_path>/train/<label>/.
# np_data_path and result_data_path are defined outside this cell; this took ~1 h in the logged run.
cvt_train_data_2d(np_data_path, result_data_path + '/train')


100%|██████████| 50000/50000 [1:04:10<00:00, 12.98it/s]


In [None]:
def inference_cascade_result(weight_path, data_path, save_path, n_iter=9, num_workers=32):
    """Iteratively undo the predicted rotation of each test sample and save it as a PNG.

    For every test sample, the rotation model predicts three values (x, y, z). The sample
    is then passed through ``test_dl.derotation(-x, -y, -z, ...)`` and fed back into the
    model. After ``n_iter`` such passes, the sample is rendered with data2img(..., 64) and
    written to <save_path>/<label>.png.

    Args:
        weight_path: checkpoint file; its state dict must be stored under the key 'weight'.
        data_path: root directory passed to TestRotationLoader.
        save_path: output directory; any existing *.png files in it are deleted first.
        n_iter: number of predict/derotate passes per sample (9 in the original notebook).
        num_workers: DataLoader worker count.

    Raises:
        IOError: if cv2.imwrite reports that an image could not be written.
    """
    # Load onto the CPU so the checkpoint opens regardless of the device it was saved from;
    # load_state_dict copies the tensors into the CUDA model below.
    pre_weight = torch.load(weight_path, map_location='cpu')

    os.makedirs(save_path, exist_ok=True)
    for stale_png in glob(save_path + '/*.png'):
        os.remove(stale_png)

    model = DeepHoughModel().cuda()
    # Start from the model's own state dict so keys absent from the checkpoint keep their init values.
    init_weight = model.state_dict()
    init_weight.update(pre_weight['weight'])
    model.load_state_dict(init_weight)
    model.eval()

    test_dl = TestRotationLoader(data_path, mode='test', img_size=64)
    test_loader = DataLoader(dataset=test_dl, batch_size=1, num_workers=num_workers, shuffle=False)

    with torch.no_grad():
        for data, label in tqdm(test_loader, total=len(test_loader)):
            dd = data.cpu().numpy()
            for _ in range(n_iter):
                pred = model(torch.from_numpy(dd.astype(np.float32)).cuda())
                x, y, z = pred[0].cpu().numpy()
                dd = test_dl.derotation(-x, -y, -z, dd)
            result_img = data2img(dd[0], 64)
            out_path = os.path.join(save_path, '%d.png' % label)
            # cv2.imwrite signals failure by returning False instead of raising.
            if not cv2.imwrite(out_path, result_img):
                raise IOError('cv2.imwrite failed for %s' % out_path)

In [None]:
# Cascade inference with 10 rotation-model checkpoints, one every 3 epochs (22, 25, ..., 49).
# For step i, the checkpoint is a file matching "<save_model_path>/<3*i+22>_*.pth" (glob order is
# arbitrary if several match). Its derotated test images go to <result_data_path>/test_<i>.
# NOTE(review): save_model_path must still point at the rotation-model directory here. A later
# cell rebinds it to "/content/model/cnn", so re-running this cell after that one breaks it.
# glob(...)[0] raises IndexError if no checkpoint exists for an epoch.
for i in range(10):
    ww = glob("%s/%d_*.pth" % (save_model_path, 3*i + 22))[0]
    ss = "%s/test_%d" % (result_data_path, i)
    print(ww, ss)
    inference_cascade_result(ww, np_data_path, ss)

/content/model/rotatio/22_rotation1d_1.0856_0.1290.pth /content/img/test_0


100%|██████████| 40000/40000 [1:17:48<00:00,  8.57it/s]


/content/model/rotatio/25_rotation1d_1.0701_0.1231.pth /content/img/test_1


100%|██████████| 40000/40000 [1:17:57<00:00,  8.55it/s]


/content/model/rotatio/28_rotation1d_1.0161_0.1414.pth /content/img/test_2


100%|██████████| 40000/40000 [1:17:15<00:00,  8.63it/s]


/content/model/rotatio/31_rotation1d_0.9952_0.1659.pth /content/img/test_3


100%|██████████| 40000/40000 [1:17:05<00:00,  8.65it/s]


/content/model/rotatio/34_rotation1d_0.9571_0.1239.pth /content/img/test_4


100%|██████████| 40000/40000 [1:17:58<00:00,  8.55it/s]


/content/model/rotatio/37_rotation1d_0.9189_0.1184.pth /content/img/test_5


100%|██████████| 40000/40000 [1:17:24<00:00,  8.61it/s]


/content/model/rotatio/40_rotation1d_0.9027_0.1316.pth /content/img/test_6


100%|██████████| 40000/40000 [1:17:29<00:00,  8.60it/s]


/content/model/rotatio/43_rotation1d_0.9055_0.1565.pth /content/img/test_7


100%|██████████| 40000/40000 [1:15:42<00:00,  8.80it/s]


/content/model/rotatio/46_rotation1d_0.9080_0.1473.pth /content/img/test_8


100%|██████████| 40000/40000 [1:15:36<00:00,  8.82it/s]


/content/model/rotatio/49_rotation1d_0.8979_0.1147.pth /content/img/test_9


100%|██████████| 40000/40000 [1:14:20<00:00,  8.97it/s]


In [None]:
# Paths for the 2-D CNN classification stage.
img_data_path = "/content/img"          # root holding the rendered PNGs (train/<digit>/*.png)
save_model_path = "/content/model/cnn"  # NOTE(review): rebinds the name the rotation-inference cell used for its checkpoint dir
save_result_path = "/content/result"

In [None]:
class RandomRotation(object):
    """Rotate an image by an angle drawn uniformly from [-degrees, degrees].

    NOTE: the constructor reseeds Python's *global* ``random`` module with ``seed``, so
    creating an instance affects every later ``random`` call in the process.
    """

    def __init__(self, degrees, seed=1):
        random.seed(seed)  # global reseed, see class docstring
        self.degrees = (-degrees, degrees)

    @staticmethod
    def get_params(degrees):
        """Sample one rotation angle from the (low, high) pair ``degrees``."""
        low, high = degrees[0], degrees[1]
        return random.uniform(low, high)

    def __call__(self, img):
        # Trailing positional args: False, False, None, None. These are intended as
        # interpolation/resample=nearest, expand=False, center=None and fill=None.
        # NOTE(review): the name/type of the 3rd parameter differs between torchvision versions.
        sampled_angle = self.get_params(self.degrees)
        return vf.rotate(img, sampled_angle, False, False, None, None)

In [None]:
class MnistRotaDatasetLabel(torch.utils.data.Dataset):
    """Dataset over the rendered 2-D digit PNGs.

    Modes:
        'train' / 'val': reads <img_path>/train/<digit>/*.png for digits 0-9. Within each
            digit folder (sorted by filename), the first 90% go to 'train' and the rest to
            'val'. Names are the PNG basenames; labels are the digit.
        'test': expects <img_path>/<id>.png for ids 50000-89999. Names are the integer ids;
            labels are a placeholder 0.

    Items are ``(image_tensor, label, name)``. The image is read as grayscale, pixels > 1 are
    set to 255, it is resized to 28x28, the optional ``transform`` is applied to a PIL image,
    and the result is scaled by 1/255.

    Raises:
        AssertionError: if ``mode`` is not one of 'train', 'val', 'test'.
    """

    def __init__(self, img_path, mode=None, transform=None):
        # Fail fast with a clear message instead of after listing directories.
        if mode not in ('train', 'val', 'test'):
            raise AssertionError("mode must be 'train', 'val' or 'test', got %r" % (mode,))
        self.mode = mode
        if mode == 'test':
            name_list = list(range(50000, 90000))
            img_list = [os.path.join(img_path, '%s.png' % n) for n in name_list]
            label_list = [0] * len(name_list)
        else:
            name_list, img_list, label_list = self.get_img_list(img_path)
        self.name_data = name_list
        self.x_data = img_list
        self.y_data = label_list
        self.transform = transform

    def get_img_list(self, data_path):
        """Return (names, paths, labels) for the 90/10 train/val split of each digit folder."""
        img_list = []
        label_list = []
        name_list = []
        for digit in range(10):
            sub_path = os.path.join(data_path, 'train/%d' % digit)
            # Match on the extension only; a plain substring test would also pick up e.g. "x_png.txt".
            sub_img_list = sorted(os.path.join(sub_path, s)
                                  for s in os.listdir(sub_path) if s.endswith('.png'))
            split = int(len(sub_img_list) * 0.9)
            if self.mode == 'train':
                sub_img_list = sub_img_list[:split]
            elif self.mode == 'val':
                sub_img_list = sub_img_list[split:]
            else:
                raise AssertionError("get_img_list needs mode 'train' or 'val', got %r" % (self.mode,))
            name_list += [os.path.basename(p) for p in sub_img_list]
            img_list += sub_img_list
            label_list += [digit] * len(sub_img_list)

        return name_list, img_list, label_list

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        path = self.x_data[idx]
        x = cv2.imread(path, 0)  # 0 -> load as single-channel grayscale
        # cv2.imread returns None (rather than raising) for missing/unreadable files.
        if x is None:
            raise FileNotFoundError('could not read image: %s' % path)
        x[x > 1] = 255  # pixels > 1 become 255; 0 and 1 are left unchanged
        x = cv2.resize(x, (28, 28))
        x = np.reshape(x, (28, 28, 1)).astype(np.float32)
        y = self.y_data[idx]
        n = self.name_data[idx]
        x = transforms.ToPILImage()(x)
        if self.transform:
            x = self.transform(x)
        x = transforms.ToTensor()(np.array(x)/255)
        return x, y, n

In [None]:
class ModelM3(nn.Module):
    """Ten stacked unpadded 3x3 conv+BN+ReLU layers followed by a linear+BN classifier head.

    Each conv shrinks height and width by 2, so a 1x28x28 input ends as 176x8x8 features
    (176 * 8 * 8 = 11264 inputs to fc1). Submodules are registered as conv1..conv10 /
    conv1_bn..conv10_bn, in the same order as before, so checkpoints and seeded
    initialisation are unchanged.
    """

    # Channel count before conv1 and after each of conv1..conv10.
    _CHANNELS = (1, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176)

    def __init__(self):
        super(ModelM3, self).__init__()
        for layer in range(1, len(self._CHANNELS)):
            c_in, c_out = self._CHANNELS[layer - 1], self._CHANNELS[layer]
            setattr(self, 'conv%d' % layer, nn.Conv2d(c_in, c_out, 3, bias=False))
            setattr(self, 'conv%d_bn' % layer, nn.BatchNorm2d(c_out))
        self.fc1 = nn.Linear(11264, 10, bias=False)
        self.fc1_bn = nn.BatchNorm1d(10)

    def get_logits(self, x):
        """Return batch-normalised class logits of shape (N, 10)."""
        h = (x - 0.5) * 2.0  # maps [0, 1] inputs to [-1, 1]
        for layer in range(1, len(self._CHANNELS)):
            conv = getattr(self, 'conv%d' % layer)
            bn = getattr(self, 'conv%d_bn' % layer)
            h = F.relu(bn(conv(h)))
        # Flatten in (H, W, C) order; fc1's weight layout depends on it.
        features = torch.flatten(h.permute(0, 2, 3, 1), 1)
        return self.fc1_bn(self.fc1(features))

    def forward(self, x):
        """Return log-probabilities over the 10 classes."""
        return F.log_softmax(self.get_logits(x), dim=1)


In [None]:

class EMA:
    """Exponential moving average of a model's trainable parameters.

    Call the instance after each optimiser step. ``assign(model)`` swaps the averaged weights
    into the model (e.g. for evaluation), and ``resume(model)`` puts the live weights back.
    Only parameters with ``requires_grad`` are tracked; buffers such as BatchNorm running
    statistics are not averaged.
    """

    def __init__(self, model, decay):
        self.decay = decay
        self.shadow = {}
        self.original = {}

        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def __call__(self, model, num_updates):
        """Blend the model's current parameters into the running average."""
        # Warm-up: the effective decay (1 + n) / (10 + n) ramps up towards self.decay, so the
        # early average is not dominated by the initial weights.
        decay = min(self.decay, (1.0 + num_updates) / (10.0 + num_updates))
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                # The arithmetic already allocates a fresh tensor; no extra clone is needed.
                self.shadow[name] = (1.0 - decay) * param.data + decay * self.shadow[name]

    def assign(self, model):
        """Load the averaged weights into ``model``, saving the live ones for ``resume``."""
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.original[name] = param.data.clone()
                # Copy rather than alias, so in-place edits to the model while the EMA weights are
                # loaded cannot silently corrupt the running average.
                param.data = self.shadow[name].clone()

    def resume(self, model):
        """Restore the weights saved by the last ``assign``. Raises KeyError if none were saved."""
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                # pop() releases the saved copy instead of keeping a second model in memory.
                param.data = self.original.pop(name)


In [None]:

import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel session (torch /
# torchvision deprecations included); consider narrowing it with `category=` or `module=`.
warnings.filterwarnings("ignore")
def train(data_path, save_path, p_seed=0, p_epochs=150, allow_cpu=False):
    """Train one ModelM3 digit classifier and checkpoint it.

    Args:
        data_path: root passed to MnistRotaDatasetLabel (train/<digit>/*.png, 90/10 train/val split).
        save_path: checkpoint directory; created if missing.
        p_seed: seed for torch / numpy / RandomRotation; also names the checkpoint
            ``<save_path>/model<seed:03d>.pth``.
        p_epochs: number of training epochs.
        allow_cpu: when CUDA is unavailable, train on CPU if True, otherwise raise.

    Returns:
        A list with one dict per epoch, holding keys 'epoch', 'train_loss', 'train_accuracy',
        'val_loss' and 'val_accuracy'. The validation metrics use the EMA weights.

    Raises:
        RuntimeError: if CUDA is unavailable and ``allow_cpu`` is False.

    The live (non-EMA) weights are saved whenever the epoch's mean training loss is lower
    than every previous epoch's (starting threshold 10).
    NOTE(review): selection uses training loss, not validation loss, and saves the live
    weights rather than the EMA weights that are validated.
    """
    # random number generator seed ------------------------------------------------#
    SEED = p_seed
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    np.random.seed(SEED)

    # number of epochs ------------------------------------------------------------#
    NUM_EPOCHS = p_epochs

    # file names ------------------------------------------------------------------#
    os.makedirs(save_path, exist_ok=True)
    MODEL_FILE = "%s/model%03d.pth" % (save_path, SEED)

    # device ----------------------------------------------------------------------#
    use_cuda = torch.cuda.is_available()
    if not use_cuda:
        if not allow_cpu:
            # exit() would end the whole interpreter/kernel instead of just this call;
            # raising keeps the session alive and makes the failure explicit.
            raise RuntimeError("CUDA is not available; pass allow_cpu=True to train on CPU.")
        print("WARNING: CPU will be used for training.")
    device = torch.device("cuda" if use_cuda else "cpu")

    # data augmentation methods ---------------------------------------------------#
    transform = transforms.Compose([
        RandomRotation(20, seed=SEED),
        transforms.RandomAffine(0, translate=(0.2, 0.2)),
        ])

    # data loader -----------------------------------------------------------------#
    train_dataset = MnistRotaDatasetLabel(data_path, mode='train', transform=transform)
    val_dataset = MnistRotaDatasetLabel(data_path, mode='val', transform=None)
    train_loader = torch.utils.data.DataLoader(train_dataset, num_workers=4, batch_size=120, shuffle=True)
    # Shuffling does not change the summed validation metrics. It is kept because it likely
    # draws from the global torch RNG, so removing it would alter the seeded training run.
    val_loader = torch.utils.data.DataLoader(val_dataset, num_workers=4, batch_size=120, shuffle=True)

    # model, optimiser, EMA ---------------------------------------------------------#
    model = ModelM3().to(device)
    ema = EMA(model, decay=0.999)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)

    g_step = 0
    min_loss = 10
    history = []

    for epoch in tqdm(range(NUM_EPOCHS)):
        # train process ---------------------------------------------------------------#
        model.train()
        train_loss = 0
        train_corr = 0
        for data, target, _name in train_loader:
            data, target = data.to(device), target.to(device, dtype=torch.int64)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            train_pred = output.argmax(dim=1, keepdim=True)
            train_corr += train_pred.eq(target.view_as(train_pred)).sum().item()
            train_loss += F.nll_loss(output, target, reduction='sum').item()
            loss.backward()
            optimizer.step()
            g_step += 1
            ema(model, g_step)
        train_loss /= len(train_loader.dataset)
        train_accuracy = 100 * train_corr / len(train_loader.dataset)

        # Checkpoint on the best mean training loss so far (live weights, not EMA).
        if train_loss < min_loss:
            torch.save(model.state_dict(), MODEL_FILE)
            min_loss = train_loss

        # validation with EMA weights -------------------------------------------------#
        model.eval()
        ema.assign(model)
        val_loss = 0
        val_correct = 0
        try:
            with torch.no_grad():
                for data, target, _name in val_loader:
                    data, target = data.to(device), target.to(device, dtype=torch.int64)
                    output = model(data)
                    val_loss += F.nll_loss(output, target, reduction='sum').item()
                    pred = output.argmax(dim=1, keepdim=True)
                    val_correct += pred.eq(target.view_as(pred)).sum().item()
        finally:
            # Always restore the live weights, even if validation raises.
            ema.resume(model)
        val_loss /= len(val_loader.dataset)
        val_accuracy = 100 * val_correct / len(val_loader.dataset)

        history.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_accuracy': train_accuracy,
            'val_loss': val_loss,
            'val_accuracy': val_accuracy,
        })

        lr_scheduler.step()

    return history

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as vf

In [None]:
# Train 10 ModelM3 classifiers with seeds 0-9 (100 epochs each). train() saves each one to
# <save_model_path>/model<seed:03d>.pth; save_model_path is "/content/model/cnn" from the config cell.
for i in range(10):
    print('train', i)
    train(img_data_path,
          save_model_path,
          p_seed=i,
          p_epochs=100)

train 0


100%|██████████| 100/100 [24:04<00:00, 14.45s/it]


train 1


100%|██████████| 100/100 [24:01<00:00, 14.42s/it]


train 2


100%|██████████| 100/100 [24:03<00:00, 14.43s/it]


train 3


100%|██████████| 100/100 [24:05<00:00, 14.45s/it]


train 4


100%|██████████| 100/100 [24:13<00:00, 14.53s/it]


train 5


100%|██████████| 100/100 [24:05<00:00, 14.46s/it]


train 6


100%|██████████| 100/100 [24:08<00:00, 14.48s/it]


train 7


100%|██████████| 100/100 [24:10<00:00, 14.51s/it]


train 8


100%|██████████| 100/100 [24:14<00:00, 14.54s/it]


train 9


100%|██████████| 100/100 [24:12<00:00, 14.52s/it]
