In [3]:
import h5py # .h5 파일을 읽기 위한 패키지
import random
import pandas as pd
import numpy as np
import os
import glob
import math
import timm
from itertools import permutations

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
from tqdm.auto import tqdm

from sklearn.metrics import accuracy_score

import warnings

from cnn_3d import effi, resnet, resneXt

warnings.filterwarnings(action='ignore') 

In [4]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), and configures cuDNN for reproducible execution.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # NOTE: the interpreter reads PYTHONHASHSEED at start-up, so this only
    # influences child processes spawned after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Also seed every other visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune (and therefore vary) the convolution
    # algorithm between runs, which defeats the determinism requested above.
    torch.backends.cudnn.benchmark = False


In [5]:
class CustomDataset(Dataset):
    """Point-cloud dataset that yields voxel occupancy-count grids.

    Each item is looked up in ``point_list`` by ``str(id)``. When labels are
    supplied the cloud is randomly rotated before voxelisation and the item is
    ``(tensor, label)``; without labels the item is just the tensor.
    The grid resolution is read from the notebook-level ``CFG['voxel_grid']``.
    """

    def __init__(self, id_list, label_list, point_list):
        self.id_list = id_list
        self.label_list = label_list
        self.point_list = point_list

    def __len__(self):
        return len(self.id_list)

    def __getitem__(self, index):
        sample_key = str(self.id_list[index])

        # Reading directly from the h5 file on every access can bottleneck
        # training and make it considerably slower.
        cloud = self.point_list[sample_key][:]

        # Unlabelled data: voxelise the cloud as-is.
        if self.label_list is None:
            grid = self.get_vector(cloud, x_y_z=CFG['voxel_grid'])
            return torch.Tensor(grid).unsqueeze(0)

        # Labelled data: draw one angle per axis (with replacement) and rotate.
        angles = np.random.choice([-np.pi/12, -np.pi/8, -np.pi/6, -np.pi/4, -np.pi/3], 3)
        cloud = self.rotate(angles[0], angles[1], angles[2], cloud)
        grid = self.get_vector(cloud, x_y_z=CFG['voxel_grid'])
        return torch.Tensor(grid).unsqueeze(0), self.label_list[index]

    def rotate(self, a, b, c, dots):
        """Rotate row-vector points by angle a about x, b about y, c about z.

        The combined matrix is Rx @ Ry @ Rz and points are multiplied by its
        transpose, so ``dots`` is expected to have three columns.
        """
        cos_a, sin_a = np.cos(a), np.sin(a)
        cos_b, sin_b = np.cos(b), np.sin(b)
        cos_c, sin_c = np.cos(c), np.sin(c)

        rot_x = np.array([[1, 0, 0], [0, cos_a, -sin_a], [0, sin_a, cos_a]])
        rot_y = np.array([[cos_b, 0, sin_b], [0, 1, 0], [-sin_b, 0, cos_b]])
        rot_z = np.array([[cos_c, -sin_c, 0], [sin_c, cos_c, 0], [0, 0, 1]])

        combined = rot_x.dot(rot_y).dot(rot_z)
        return np.dot(dots, combined.T)

    def get_vector(self, points, x_y_z=[16, 16, 16]):
        """Count points per voxel on a cubic grid around the cloud.

        Args:
            points: array with one point per row and x, y, z in columns 0-2.
            x_y_z: number of voxels along x, y and z; each must be a plain int.

        Returns:
            Float array of shape (n_z, n_y, n_x) holding per-voxel point counts.

        Raises:
            TypeError: if any entry of ``x_y_z`` is not exactly ``int``.
        """
        # Bounding box with a small margin so no point sits on an outer edge.
        lower = np.min(points, axis=0) - 0.001
        upper = np.max(points, axis=0) + 0.001

        # Grow the shorter sides symmetrically until the box is a cube.
        extent = upper - lower
        padding = max(extent) - extent
        lower = lower - padding / 2
        upper = upper + padding / 2

        # n voxels along an axis need n + 1 edges.
        edges = []
        for axis in range(3):
            if type(x_y_z[axis]) is not int:
                raise TypeError("x_y_z[{}] must be int".format(axis))
            edges.append(np.linspace(lower[axis], upper[axis], num=(x_y_z[axis] + 1)))

        n_x, n_y, n_z = x_y_z[0], x_y_z[1], x_y_z[2]

        # Columns 0-2: per-axis voxel index; column 3: flattened voxel index.
        cell = np.zeros((len(points), 4), dtype=int)
        for axis in range(3):
            cell[:, axis] = np.searchsorted(edges[axis], points[:, axis]) - 1

        # flat = x + y * n_x + z * (n_x * n_y)
        cell[:, 3] = ((cell[:, 1] * n_x) + cell[:, 0]) + (cell[:, 2] * (n_x * n_y))

        occupancy = np.zeros(n_x * n_y * n_z)
        hits = np.bincount(cell[:, 3])
        occupancy[:len(hits)] = hits

        return occupancy.reshape(n_z, n_y, n_x)

In [6]:
class _CustomDataset(Dataset):
    def __init__(self, id_list, label_list, point_list, dim='2d', shape=224):
        self.id_list = id_list
        self.label_list = label_list
        self.point_list = point_list
        self.dim = dim
        self._shape = shape
        self.per = list(permutations([0,1,2], 2))

    def __getitem__(self, index):
        image_id = self.id_list[index]
        
        # h5파일을 바로 접근하여 사용하면 학습 속도가 병목 현상으로 많이 느릴 수 있습니다.
        points = self.point_list[str(image_id)][:]
        
        # training
        if self.label_list is not None:
            rand_degree = np.random.choice([-np.pi/12, -np.pi/8, -np.pi/6, -np.pi/4, -np.pi/3], 3)
            rotated_points = self.rotate(rand_degree[0], rand_degree[1], rand_degree[2], points)
            
            label = self.label_list[index]
            
            # image
            if self.dim == '2d' :
                image = self.sliced_section_6ch(rotated_points)  
                return torch.Tensor(image), label
            
            # vector
            else :
                image = self.get_vector(rotated_points, x_y_z=CFG['voxel_grid'])
                return torch.Tensor(image).unsqueeze(0), label
        
        # test
        else:
            # image
            if self.dim == '2d' :
                image = self.sliced_section_6ch(points)  
                return torch.Tensor(image)
            
            # vector
            else :
                image = self.get_vector(points, x_y_z=CFG['voxel_grid'])
                return torch.Tensor(image).unsqueeze(0)
    
    def rotate(self, a, b, c, dots):
        mx = np.array([[1, 0, 0], [0, np.cos(a), -np.sin(a)], [0, np.sin(a), np.cos(a)]])
        my = np.array([[np.cos(b), 0, np.sin(b)], [0, 1, 0], [-np.sin(b), 0, np.cos(b)]])
        mz = np.array([[np.cos(c), -np.sin(c), 0], [np.sin(c), np.cos(c), 0], [0, 0, 1]])
        m = np.dot(np.dot(mx,my),mz)
        dots = np.dot(dots, m.T)
        return dots

    
    def get_vector(self, points, x_y_z=[16, 16, 16]):
        # 3D Points -> [16,16,16]
        xyzmin = np.min(points, axis=0) - 0.001
        xyzmax = np.max(points, axis=0) + 0.001

        diff = max(xyzmax-xyzmin) - (xyzmax-xyzmin)
        xyzmin = xyzmin - diff / 2
        xyzmax = xyzmax + diff / 2

        segments = []
        shape = []

        for i in range(3):
            # note the +1 in num 
            if type(x_y_z[i]) is not int:
                raise TypeError("x_y_z[{}] must be int".format(i))
            s, step = np.linspace(xyzmin[i], xyzmax[i], num=(x_y_z[i] + 1), retstep=True)
            segments.append(s)
            shape.append(step)

        n_voxels = x_y_z[0] * x_y_z[1] * x_y_z[2]
        n_x = x_y_z[0]
        n_y = x_y_z[1]
        n_z = x_y_z[2]

        structure = np.zeros((len(points), 4), dtype=int)
        structure[:,0] = np.searchsorted(segments[0], points[:,0]) - 1
        structure[:,1] = np.searchsorted(segments[1], points[:,1]) - 1
        structure[:,2] = np.searchsorted(segments[2], points[:,2]) - 1

        # i = ((y * n_x) + x) + (z * (n_x * n_y))
        structure[:,3] = ((structure[:,1] * n_x) + structure[:,0]) + (structure[:,2] * (n_x * n_y)) 

        vector = np.zeros(n_voxels)
        count = np.bincount(structure[:,3])
        vector[:len(count)] = count

        vector = vector.reshape(n_z, n_y, n_x)
        return vector
    
    def sliced_section_6ch(self, points) :
                # 3D Points -> [16,16,16]
        xyzmin = np.min(points, axis=0) - 0.001
        xyzmax = np.max(points, axis=0) + 0.001

        diff = max(xyzmax-xyzmin) - (xyzmax-xyzmin)
        xyzmin = xyzmin - diff / 2
        xyzmax = xyzmax + diff / 2

        segments = []
        shape = []

        for i in range(3):
            # note the +1 in num 
#             print(type(self._shape))
            if type(self._shape) is not int:
                raise TypeError("shape must be int")
            s, step = np.linspace(xyzmin[i], xyzmax[i], num=(self._shape + 1), retstep=True)
            segments.append(s)
            shape.append(step)

        structure = np.zeros((len(points), 3), dtype=int)
        structure[:, 0] = np.searchsorted(segments[0], points[:,0]) - 1        
        structure[:, 1] = np.searchsorted(segments[1], points[:,1]) - 1
        structure[:, 2] = np.searchsorted(segments[2], points[:,2]) - 1
        
        # shape = C, H, W
        imgs = np.zeros((6, self._shape, self._shape))
        for idx, (i,j) in enumerate(self.per) :
            imgs[idx, -structure[:, i], structure[:, j]] = 1
        
        return imgs
    
    def __len__(self):
        return len(self.id_list)

In [7]:
def train(model, optimizer, train_loader, val_loader, scheduler, device):
    """Train ``model`` for ``CFG['EPOCHS']`` epochs with cross-entropy loss.

    After every epoch the model is evaluated with ``validation`` and, whenever
    validation accuracy improves on the best seen so far, its ``state_dict``
    is written to ``./model/<epoch>E-val<acc>-<CFG['output']>``.

    Args:
        model: network to optimise; moved to ``device`` in place.
        optimizer: optimiser already bound to ``model``'s parameters.
        train_loader: iterable of ``(data, label)`` training batches.
        val_loader: iterable of ``(data, label)`` validation batches.
        scheduler: learning-rate scheduler stepped once per epoch, or ``None``.
        device: torch device used for the model and every batch.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss().to(device)

    # Create the checkpoint directory up front: torch.save does not create
    # missing parents, so a fresh checkout would otherwise crash only after
    # the first full epoch has already been trained.
    checkpoint_dir = './model/'
    os.makedirs(checkpoint_dir, exist_ok=True)

    best_score = 0
    for epoch in range(1, CFG['EPOCHS']+1):
        model.train()
        train_loss = []
        for data, label in tqdm(iter(train_loader)):
            data, label = data.float().to(device), label.long().to(device)
            optimizer.zero_grad()

            output = model(data)
            loss = criterion(output, label)

            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())

        # Scheduler is stepped per epoch, not per batch.
        if scheduler is not None:
            scheduler.step()

        val_loss, val_acc = validation(model, criterion, val_loader, device)
        print(f'Epoch : [{epoch}] Train Loss : [{np.mean(train_loss)}] Val Loss : [{val_loss}] Val ACC : [{val_acc}]')

        # Keep a checkpoint only when validation accuracy strictly improves.
        if best_score < val_acc:
            best_score = val_acc
            torch.save(model.state_dict(), checkpoint_dir+str(epoch)+'E-val'+str(best_score)+'-'+CFG['output'])

In [8]:
def validation(model, criterion, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: network to evaluate; switched to eval mode.
        criterion: loss callable applied to ``(model output, label)``.
        val_loader: iterable of ``(data, label)`` batches.
        device: torch device the batches are moved to.

    Returns:
        Tuple ``(mean of per-batch losses, accuracy over all samples)``.
    """
    model.eval()
    batch_losses, predictions, targets = [], [], []

    with torch.no_grad():
        for data, label in tqdm(iter(val_loader)):
            data = data.float().to(device)
            label = label.long().to(device)

            logits = model(data)
            batch_losses.append(criterion(logits, label).item())

            # Predicted class = index of the largest score along dim 1.
            predictions.extend(logits.argmax(1).detach().cpu().numpy().tolist())
            targets.extend(label.detach().cpu().numpy().tolist())

    return np.mean(batch_losses), accuracy_score(targets, predictions)

In [9]:
def test(model, test_loader, device):
    """Return the accuracy of ``model`` over a labelled loader.

    Args:
        model: network to evaluate; switched to eval mode.
        test_loader: iterable of ``(data, label)`` batches.
        device: torch device the batches are moved to.

    Returns:
        Accuracy over every sample yielded by ``test_loader``.
    """
    model.eval()
    predictions, targets = [], []

    with torch.no_grad():
        for data, label in tqdm(iter(test_loader)):
            data = data.float().to(device)
            label = label.long().to(device)

            logits = model(data)

            # Predicted class = index of the largest score along dim 1.
            predictions.extend(logits.argmax(1).detach().cpu().numpy().tolist())
            targets.extend(label.detach().cpu().numpy().tolist())

    return accuracy_score(targets, predictions)

In [10]:
class dim_change(nn.Module):
    """Adapter that turns a 3D volume batch into a 3-channel 2D feature map.

    Pipeline: a 3x3x3 convolution down to one channel, removal of that channel
    axis, adaptive average pooling to ``shape_2d`` along each of the three
    remaining axes, then a 2D convolution that treats the first pooled axis as
    ``shape_2d`` input channels and emits 3 channels.

    Args:
        shape_2d: pooled size per axis, and input channel count of the 2D conv.
        in_channel: channel count of the incoming volume.
    """

    def __init__(self, shape_2d=224, in_channel=1):
        super().__init__()
        self.conv3d = nn.Conv3d(in_channels=in_channel, out_channels=1, kernel_size=3)
        # A single int requests the same output size on all three pooled axes.
        self.pool = nn.AdaptiveAvgPool3d(shape_2d)
        self.conv2d = nn.Conv2d(
            in_channels=shape_2d,
            out_channels=3,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # assumes x is (batch, in_channel, D, H, W) — TODO confirm with callers
        volume = self.conv3d(x).squeeze(1)
        return self.conv2d(self.pool(volume))

In [11]:
class CNN(nn.Module):
    """10-class classifier wrapping either a timm 2D model or a 3D backbone.

    Args:
        model_name: backbone selector. With ``dim_changer=True`` it is passed
            to ``timm.create_model``. Otherwise it must be ``'resneXt'``,
            ``'resnet'`` or an ``'efficientnet-<variant>'`` name for the 3D
            EfficientNet.
        shape_2d: forwarded to ``dim_change`` (only used when
            ``dim_changer=True``).
        in_channel: forwarded to ``dim_change`` (only used when
            ``dim_changer=True``).
        dim_changer: when true, inputs are first passed through a
            ``dim_change`` adapter and then through the pretrained timm model.

    Raises:
        ValueError: if ``dim_changer`` is false and ``model_name`` does not
            select any known 3D backbone.
    """

    def __init__(self, model_name, shape_2d=224, in_channel=1, dim_changer=False):
        super().__init__()

        if dim_changer:
            self.dim_changer = dim_change(shape_2d, in_channel)
            self.model = timm.create_model(model_name=model_name, num_classes=10, pretrained=True)
            return

        # No adapter: the 3D backbone consumes the voxel grid directly.
        self.dim_changer = None

        # 3D EfficientNet names look like 'efficientnet-b3', so the family is
        # the part BEFORE the dash. Comparing the last part ('b3') against
        # 'efficientnet' could never match and made this branch unreachable.
        if model_name.split('-')[0] == 'efficientnet':
            self.model = effi.EfficientNet3D.from_name(model_name,
                                                       override_params={'num_classes': 10},
                                                       in_channels=1)
        elif model_name == 'resneXt':
            self.model = resneXt.resnet101(
                num_classes=10,
                shortcut_type="B",
                cardinality=CFG['voxel_grid'][0] * 2,
                spatial_size=CFG['voxel_grid'][0],
                sample_duration=1)
        elif model_name == 'resnet':
            # Device placement is left to the caller, like every other branch,
            # instead of reaching for the notebook-level `device` global here.
            self.model = resnet.ResNet(num_layers=CFG['num_layers'],
                                       in_channels=CFG['in_channels'],
                                       stride=CFG['stride'],
                                       num_classes=10)
        else:
            # Fail at construction time rather than with an AttributeError on
            # the first forward pass because self.model was never assigned.
            raise ValueError(
                "Unknown 3D model_name {!r}; expected 'efficientnet-<variant>', "
                "'resneXt' or 'resnet'".format(model_name))

    def forward(self, x):
        if self.dim_changer is not None:
            x = self.dim_changer(x)
        return self.model(x)

In [12]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Notebook-wide settings, read by the datasets, CNN and train() defined above.
CFG = {
    # optimisation
    'EPOCHS': 5,
    'LEARNING_RATE': 1e-3,
    'BATCH_SIZE': 64,
    'SEED': 41,
    # suffix of the checkpoint file name written by train()
    'output': '6ch_img-effi_b0.pth',
    # backbone name handed to timm.create_model
    'model_name': 'efficientnet_b0',
    # only read by the 'resnet' branch of CNN
    'num_layers': [3, 4, 6, 3],
    'in_channels': [8, 32, 64, 128],
    'stride': [1, 1, 1, 1],
    # voxels along x, y, z for the datasets' get_vector calls
    'voxel_grid': [128, 128, 128],
}

seed_everything(CFG['SEED'])  # fix the random seed

In [13]:
# Labels come from the CSV; point clouds from the HDF5 file, keyed by str(ID).
all_df = pd.read_csv('./data/train.csv')
all_points = h5py.File('./data/train.h5', 'r')
# Alternative source: np.load('./data/train.npy', allow_pickle=True)

# Positional 80 / 10 / 10 split into train / validation / held-out test.
n_rows = len(all_df)
train_end = int(n_rows * 0.8)
val_end = int(n_rows * 0.9)

train_df = all_df.iloc[:train_end]
val_df = all_df.iloc[train_end:val_end]
test_df = all_df.iloc[val_end:]

# 6-channel 224x224 projection images. For the voxel pipeline, build the same
# three datasets with CustomDataset(ids, labels, all_points) instead.
dataset_kwargs = dict(dim='2d', shape=224)

train_dataset = _CustomDataset(train_df['ID'].values, train_df['label'].values, all_points, **dataset_kwargs)
val_dataset = _CustomDataset(val_df['ID'].values, val_df['label'].values, all_points, **dataset_kwargs)
test_dataset = _CustomDataset(test_df['ID'].values, test_df['label'].values, all_points, **dataset_kwargs)

train_loader = DataLoader(train_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=True, num_workers=0)

In [14]:
# Alternative (3D input + adapter):
#   model = CNN(CFG['model_name'], shape_2d=288, in_channel=1, dim_changer=True).to(device)
# Pretrained timm backbone with a 6-channel stem, matching the six projection
# channels produced by _CustomDataset.sliced_section_6ch.
model = timm.create_model(
    model_name=CFG['model_name'],
    num_classes=10,
    pretrained=True,
    in_chans=6,
)

In [11]:
# model = effi.EfficientNet3D.from_name(
#             'efficientnet-b3', 
#             override_params={'num_classes': 10}, 
#             in_channels=1)
# summary(model.to("cuda"), input_size=(1,128,128,128))

In [12]:
# model = resneXt.resnet101(
#                 num_classes=10,
#                 shortcut_type="B",
#                 cardinality=64,
#                 spatial_size=32,
#                 sample_duration=1)
# summary(model.to("cuda"), input_size=(1,32,32,32))

In [13]:
# model = resnet.ResNet(num_layers=CFG['num_layers'],
#                in_channels=CFG['in_channels'],
#                stride=CFG['stride'],
#                num_classes=10).to(device)

In [15]:
# AdamW with cosine annealing from CFG['LEARNING_RATE'] down to 1e-4 over the
# configured number of epochs (train() steps the scheduler once per epoch).
# The previous model.eval() call here was dead code: train() switches the model
# to train mode at the start of every epoch before any forward pass.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=CFG["LEARNING_RATE"])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=CFG['EPOCHS'],
                                                       eta_min=1e-4)

train(model, optimizer, train_loader, val_loader, scheduler, device)

  0%|          | 0/625 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : [1] Train Loss : [0.5088406831264496] Val Loss : [0.33130504167344] Val ACC : [0.8976]


  0%|          | 0/625 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : [2] Train Loss : [0.1901017424225807] Val Loss : [0.19719855463768862] Val ACC : [0.9358]


  0%|          | 0/625 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : [3] Train Loss : [0.1354466692507267] Val Loss : [0.09708018498377333] Val ACC : [0.9682]


  0%|          | 0/625 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : [4] Train Loss : [0.09176970025151968] Val Loss : [0.07865223127617678] Val ACC : [0.9754]


  0%|          | 0/625 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : [5] Train Loss : [0.06536214575134218] Val Loss : [0.060010814331943475] Val ACC : [0.9794]


In [16]:
# Accuracy on the held-out last 10% of train.csv. That dataset was built with
# labels, so _CustomDataset still applies its random rotation to every sample
# here and the reported figure can vary slightly between runs.
test_acc = test(model, test_loader, device)
print(f"Test Accuracy : {round(test_acc, 4)}")

  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy : 0.9744


## submission

In [23]:
# NOTE: test_df / test_dataset / test_loader are rebound here from the labelled
# hold-out split to the unlabelled submission set.
test_df = pd.read_csv('./data/sample_submission.csv')
test_points = h5py.File('./data/test.h5', 'r')

# label_list=None -> no random rotation and items are bare image tensors.
test_dataset = _CustomDataset(test_df['ID'].values, None, test_points, dim='2d', shape=224)
# test_dataset = CustomDataset(test_df['ID'].values, None, test_points)
test_loader = DataLoader(test_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=False, num_workers=0)

# map_location lets a checkpoint saved from a GPU run load on whatever device
# this session uses (including CPU-only machines).
checkpoint = torch.load('./model/5E-val0.9794-6ch_img-effi_b0.pth', map_location=device)
# model = CNN(CFG['model_name'], shape_2d=288, in_channel=1, dim_changer=True).to(device)
# pretrained=False: every weight is overwritten by the checkpoint just below,
# so fetching the ImageNet weights first would be wasted work.
model = timm.create_model(model_name=CFG['model_name'], num_classes=10, pretrained=False, in_chans=6)
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [24]:
def predict(model, test_loader, device, dim_changer=None):
    """Return predicted class indices for every sample of an unlabelled loader.

    Args:
        model: network to run; moved to ``device`` and switched to eval mode.
        test_loader: iterable of data-only batches.
        device: torch device the model and batches are moved to.
        dim_changer: optional callable applied to each batch before the model.

    Returns:
        Flat Python list with one predicted class index per sample.
    """
    model.to(device)
    model.eval()

    predictions = []
    with torch.no_grad():
        for data in tqdm(iter(test_loader)):
            batch = data.float().to(device)
            # Run the optional adapter first when one was supplied.
            logits = model(dim_changer(batch)) if dim_changer else model(batch)
            predictions.extend(logits.argmax(1).detach().cpu().numpy().tolist())

    return predictions

In [25]:
# Predicted class index per submission row.
# NOTE(review): relies on model/test_loader being the ones rebound in the
# submission cell (checkpoint weights, unlabelled test.h5 loader) — run in order.
preds = predict(model, test_loader, device)

  0%|          | 0/625 [00:00<?, ?it/s]

In [26]:
# Fill the sample-submission frame with the predicted labels.
test_df['label'] = preds

# DataFrame.to_csv does not create missing parent directories, so make sure
# the output folder exists before writing.
os.makedirs('./submission', exist_ok=True)
test_df.to_csv('./submission/5E-val0.9794-6ch_img-effi_b0.csv', index=False)