# Normalized Voxels: Align Planes, Adjust Contrast, and Crop

As shown in several notebooks, MRI plane type (Axial, Coronal, and Sagittal) is not consistent among patients or MRI scan types (FLAIR, T1w, T1wCE, T2w).
While augmentations might alleviate this inconsistency, it is better to train models using MRI voxels that are consistent in terms of plane type.
This notebook shows that normalized voxels can be obtained by appropriately rotating (and flipping) the MRI voxels.
I found that simply rotating the MRI voxels is not enough, because in some cases the order of the planes is also inconsistent.
For example, even within the same Sagittal plane type, some of the scans were in left-to-right order while others were in the opposite order.
As Instance Number does not help, Image Position (Patient) is used in this notebook to reorder stacked images.
After the voxels are normalized with respect to plane orientation, their contrast is adjusted and they are cropped.
Finally, each voxel is resized to an arbitrary fixed size.

## Normalized Voxel Datasets

The normalized voxels created by the above procedure were stored as a dataset; see the second half of this notebook.

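# Because the scanning distances (slice spacing) differ between scans, I think 3D convolutions are not suitable, so I use 2D convolutions on the voxels instead, stacking the slices as input channels.
I use the MCR² (Maximal Coding Rate Reduction) loss because there are so few labels.
MCR²: https://github.com/Ma-Lab-Berkeley/ReduNet
However, submitting the notebook failed seven times with the error "Notebook Threw Exception".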

In [None]:
from pathlib import Path
import numpy as np
import cv2
import pydicom
import matplotlib.pyplot as plt

# Default split name; it does not appear to be referenced by the code in this notebook.
DATASET = 'train'
# MRI sequence folders read for every study, in this order.
scan_types = ['FLAIR','T1w','T1wCE','T2w']
# Root of the competition DICOM data; get_voxel reads <data_root>/<split>/<study_id>/<scan_type>/*.dcm.
data_root = Path("../input/rsna-miccai-brain-tumor-radiogenomic-classification")

In [None]:
def get_image_plane(data):
    """Return the ImageOrientationPatient values of a DICOM dataset, rounded to ints.

    These are the direction cosines of the image rows and columns, so for
    axis-aligned acquisitions each rounded value is -1, 0 or 1.
    """
    return list(map(round, data.ImageOrientationPatient))

def get_voxel(study_id, scan_type, split="train"):
    """Load one DICOM series and reorient it to a consistent plane layout.

    Slices are read from ``data_root/<split>/<study_id>/<scan_type>/*.dcm``,
    ordered by the number after the last ``-`` in each file stem, and stacked
    along axis 0. The stack is then rotated/flipped using the rounded
    ImageOrientationPatient of the last slice and the direction between the
    first and last ImagePositionPatient.

    Parameters:
        study_id (str): zero-padded study folder name, e.g. "00000"
        scan_type (str): series folder name, e.g. "FLAIR"
        split (str): dataset split folder, e.g. "train" or "test"

    Returns:
        voxel (np.ndarray): reoriented volume (may be a non-contiguous view)
        plane (list[int]): rounded ImageOrientationPatient values

    Raises:
        FileNotFoundError: if the series folder contains no ``.dcm`` files.
    """
    dcm_dir = data_root.joinpath(split, study_id, scan_type)
    # Sort numerically by the trailing number in the stem, not lexically.
    dcm_paths = sorted(dcm_dir.glob("*.dcm"), key=lambda x: int(x.stem.split("-")[-1]))
    if not dcm_paths:
        raise FileNotFoundError(f"No .dcm files found in {dcm_dir}")

    imgs = []
    positions = []
    for dcm_path in dcm_paths:
        img = pydicom.dcmread(str(dcm_path))
        imgs.append(img.pixel_array)
        positions.append(img.ImagePositionPatient)

    # Orientation is read from the last slice of the series.
    plane = get_image_plane(img)
    voxel = np.stack(imgs)

    # Direction from the first to the last slice position, scaled so that its
    # largest component is +/-1 and rounded to -1/0/1 per component.
    rotDir = np.array([positions[-1][k] - positions[0][k] for k in range(3)], dtype=float)
    max_abs = np.max(np.absolute(rotDir))
    if max_abs > 0:
        rotDir = np.around(rotDir / max_abs)
    # When all positions coincide (e.g. a single slice) rotDir stays all-zero,
    # so the two position-based flips below are skipped.

    voxel = np.rot90(voxel, plane[1], (2, 0))
    voxel = np.rot90(voxel, plane[2], (0, 1))
    voxel = np.rot90(voxel, plane[3], (1, 2))
    voxel = np.rot90(voxel, plane[5], (0, 1))

    if plane[0] == 0:
        voxel = np.flip(voxel, 1)
    if plane[4] == 0:
        voxel = np.flip(voxel, 0)
    # Flip according to the slice direction so that stacking order is consistent.
    if rotDir[1] == 1:
        voxel = np.flip(voxel, 1)
    if rotDir[2] == -1:
        voxel = np.flip(voxel, 0)

    return voxel, plane

In [None]:
def normalize_contrast(voxel):
    """Linearly rescale a volume to the full 0-255 uint8 range.

    The minimum maps to 0 and the maximum to 255. An all-zero volume is
    returned unchanged (same object and dtype). A constant non-zero volume has
    no range to stretch and maps to all zeros.

    Parameters:
        voxel (np.ndarray): intensity volume of any numeric dtype

    Returns:
        np.ndarray: uint8 volume (or the input itself when it is all zero)
    """
    if not np.any(voxel):
        return voxel
    lo = float(np.min(voxel))
    hi = float(np.max(voxel))
    if hi == lo:
        # Avoid dividing by a zero range (would produce NaNs cast to uint8).
        return np.zeros(voxel.shape, dtype=np.uint8)
    # Work in float64 so subtracting the minimum cannot overflow signed ints.
    scaled = (voxel.astype(np.float64) - lo) / (hi - lo)
    return (scaled * 255).astype(np.uint8)

In [None]:
def crop_voxel(voxel):
    """Drop planes whose mean intensity is not positive, one axis at a time.

    Axes are processed in the order 2, 1, 0, each mask being computed on the
    already-cropped array. Every such plane is removed, interior ones
    included, so this equals a bounding-box crop only when the empty planes
    lie at the borders. A volume whose sum is 0 is returned unchanged.
    """
    if voxel.sum() == 0:
        return voxel
    for axis in (2, 1, 0):
        other_axes = tuple(a for a in range(3) if a != axis)
        keep = voxel.mean(axis=other_axes) > 0
        voxel = np.compress(keep, voxel, axis=axis)
    return voxel

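Sample planes along the longest axis and resize the sampled planes.
Sampling along the longest axis minimizes the degradation caused by sampling.
The best way is to resize twice (e.g. first in the (x, y) plane, then in the (y, z) plane), but that is computationally expensive.
The longest-axis sampling is implemented by `resize_orlay_voxel_forconv2d000`. `resize_orlay_voxel_forconv2d`, which the datasets below use, instead treats the shortest axis as the slice axis and moves it to the front so that the slices become 2D-convolution input channels.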

In [None]:
def resize_orlay_voxel_forconv2d(voxel, sz=256, NUM_IMAGES=16):
    """Resample a 3-D volume to a (sz, sz, sz) uint8 torch tensor, slice axis first.

    The shortest axis of ``voxel`` (``np.argmin(voxel.shape)``, the first one
    on a tie) is treated as the slice axis. The returned tensor is permuted so
    that this axis comes first, letting each slice act as a 2-D conv input
    channel. The two in-plane axes are always resized to ``sz`` with
    ``cv2.resize``. Along the slice axis:

    * fewer than ``NUM_IMAGES`` slices: slices are not interpolated; they are
      written in order into roughly centred slots of a zero-filled block of
      ``NUM_IMAGES`` slices (zero padding);
    * otherwise: a second ``cv2.resize`` pass interpolates the slice axis to
      ``sz`` slices.

    NOTE(review): the first statement overwrites ``NUM_IMAGES`` with ``sz``, so
    the argument has no effect and both branches produce ``sz`` slices.
    NOTE(review): ``torch`` is not imported in this cell; it appears to come
    from a later ``import torch`` cell that runs before this is called.

    Parameters:
        voxel (np.ndarray): 3-D volume (crop_voxel output in this notebook)
        sz (int): edge length of the output cube
        NUM_IMAGES (int): ignored, see note above

    Returns:
        torch.Tensor: uint8 tensor of shape (sz, sz, sz)
    """
    NUM_IMAGES = sz  # NOTE(review): overrides the NUM_IMAGES argument
    if np.argmin(voxel.shape) == 0:
        # Slice axis 0 (already first).
        if voxel.shape[0] < NUM_IMAGES:
            # Pad: resize each slice in-plane and place it in a centred slot.
            output = np.zeros((NUM_IMAGES, sz, sz), dtype=np.uint8)
            for i,s  in enumerate(np.linspace(0, voxel.shape[0] - 1, voxel.shape[0])):
                b = int(abs(int(voxel.shape[0] / 2) - NUM_IMAGES / 2) + i)
                output[b] = cv2.resize(voxel[i], (sz, sz))
        else:
            # Pass 1: resize every slice in-plane to (sz, sz).
            output = np.zeros((voxel.shape[0], sz, sz), dtype=np.uint8)
            for i, s in enumerate(np.linspace(0, voxel.shape[0] - 1, voxel.shape[0])):
                output[i] = cv2.resize(voxel[i], (sz, sz))
            voxel = output
            # Pass 2: interpolate the slice axis (axis 0) to sz.
            output = np.zeros((sz, sz, sz), dtype=np.uint8)
            for i, s in enumerate(np.linspace(0, sz - 1, sz)):
                output[:, i] = cv2.resize(voxel[:, i], (sz, sz))
        output = torch.tensor(output).permute(0, 1, 2)
    elif np.argmin(voxel.shape) == 1:
        # Slice axis 1; moved to the front by the final permute.
        if voxel.shape[1] < NUM_IMAGES:
            # Pad: resize each slice in-plane and place it in a centred slot.
            output = np.zeros((sz, NUM_IMAGES, sz), dtype=np.uint8)
            for i,s  in enumerate(np.linspace(0, voxel.shape[1] - 1, voxel.shape[1])):
                b = int(abs(int(voxel.shape[1] / 2) - NUM_IMAGES / 2) + i)
                output[:, b] = cv2.resize(voxel[:, i], (sz, sz))
        else:
            # Pass 1: resize every slice in-plane to (sz, sz).
            output = np.zeros((sz, voxel.shape[1], sz), dtype=np.uint8)
            for i, s in enumerate(np.linspace(0, voxel.shape[1] - 1, voxel.shape[1])):
                output[:, i] = cv2.resize(voxel[:, i], (sz, sz))
            voxel = output
            # Pass 2: interpolate the slice axis (axis 1) to sz.
            output = np.zeros((sz, sz, sz), dtype=np.uint8)
            for i, s in enumerate(np.linspace(0, sz - 1, sz)):
                output[i] = cv2.resize(voxel[i], (sz, sz))
        output = torch.tensor(output).permute(1, 0, 2)
    elif np.argmin(voxel.shape) == 2:
        # Slice axis 2; moved to the front by the final permute.
        if voxel.shape[2] < NUM_IMAGES:
            # Pad: resize each slice in-plane and place it in a centred slot.
            output = np.zeros((sz, sz, NUM_IMAGES), dtype=np.uint8)
            for i,s  in enumerate(np.linspace(0, voxel.shape[2] - 1, voxel.shape[2])):
                b = int(abs(int(voxel.shape[2] / 2) - NUM_IMAGES / 2) + i)
                output[:, :, b] = cv2.resize(voxel[:, :, i], (sz, sz))
        else:
            # Pass 1: resize every slice in-plane to (sz, sz).
            output = np.zeros((sz, sz, voxel.shape[2]), dtype=np.uint8)
            for i, s in enumerate(np.linspace(0, voxel.shape[2] - 1, voxel.shape[2])):
                output[:, :, i] = cv2.resize(voxel[:, :, i], (sz, sz))
            voxel = output
            # Pass 2: interpolate the slice axis (axis 2, the image columns here) to sz.
            output = np.zeros((sz, sz, sz), dtype=np.uint8)
            for i, s in enumerate(np.linspace(0, sz - 1, sz)):
                output[:, i] = cv2.resize(voxel[:, i], (sz, sz))
        output = torch.tensor(output).permute(2, 0, 1)
    return output


In [None]:
def resize_orlay_voxel_forconv2d000(voxel, sz=256, NUM_IMAGES=16):
    """Resample a volume to a ``(sz, sz, sz)`` uint8 cube by sampling its longest axis.

    ``sz`` evenly spaced positions along the longest axis are truncated to
    integer slice indices. Each picked slice is resized to ``sz x sz`` with
    ``cv2.resize`` and written at the same axis of the output. ``NUM_IMAGES``
    is accepted for signature compatibility but not used.
    """
    cube = np.zeros((sz, sz, sz), dtype=np.uint8)
    long_axis = np.argmax(voxel.shape)
    if long_axis not in (0, 1, 2):
        return cube
    # View of `cube` with the sampled axis first; writing into it fills `cube`.
    target = np.moveaxis(cube, long_axis, 0)
    positions = np.linspace(0, voxel.shape[long_axis] - 1, sz)
    for k, pos in enumerate(positions):
        picked = np.take(voxel, int(pos), axis=long_axis)
        target[k] = cv2.resize(picked, (sz, sz))
    return cube

In [None]:
import glob
from torch.utils import data as torch_data
from torch.utils.data import Dataset
import os

#scan_types = ['FLAIR'] #,'T1w','T1wCE','T2w']

#DataRetriever7voxelconv
def _voxel_cache_name(case, vsz):
    """File name under which the stacked voxels of one study are cached."""
    return case + ".size." + str(vsz) + ".voxel4type.npy"


def _build_study_voxels(case, split, vsz, num_images):
    """Preprocess every scan type of one study and concatenate them along axis 0.

    Each scan goes through get_voxel -> normalize_contrast -> crop_voxel ->
    resize_orlay_voxel_forconv2d, in the order given by ``scan_types``.
    """
    stacks = []
    for scan_type in scan_types:
        voxel, _plane = get_voxel(case, scan_type, split)
        voxel = normalize_contrast(voxel)
        voxel = crop_voxel(voxel)
        stacks.append(resize_orlay_voxel_forconv2d(voxel, vsz, num_images))
    return np.concatenate(stacks, axis=0)


class DataRetriever(torch_data.Dataset):
    """Training dataset yielding ``(voxels, label)`` for BraTS21 study ids.

    Voxels are loaded from the precomputed cache in ``cache_dir`` if present,
    otherwise from (or computed into) the writable local cache ``./<split>``.
    Labels are returned as ``torch.int8`` tensors; voxels as float tensors.
    """

    # Precomputed voxels attached as an input dataset. ../input is typically
    # read-only on Kaggle, so newly computed voxels are saved under ./<split>.
    cache_dir = '../input/newvox96'

    def __init__(self, paths, targets,  split='train', vsz = 96):
        self.paths = paths
        self.targets = targets
        self.vsz = vsz
        self.split = split
        self.NUM_IMAGES = vsz
        os.makedirs(os.path.join('./', self.split), exist_ok=True)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        case = str(self.paths[index]).zfill(5)
        name = _voxel_cache_name(case, self.vsz)
        precomputed_path = os.path.join(self.cache_dir, self.split, name)
        local_path = os.path.join('./', self.split, name)
        if Path(precomputed_path).exists():
            voxels = np.load(precomputed_path)
        elif Path(local_path).exists():
            voxels = np.load(local_path)
        else:
            voxels = _build_study_voxels(case, self.split, self.vsz, self.NUM_IMAGES)
            np.save(local_path, voxels)
        y = torch.tensor(self.targets[index], dtype=torch.int8)
        return torch.tensor(voxels).float(), y


class DataRetrievertest(torch_data.Dataset):
    """Inference dataset yielding ``(voxels, study_id)`` pairs.

    ``targets`` is stored but not used. Voxels are cached under ``./<split>``.
    """

    def __init__(self, paths, targets,  split='test', vsz = 96):
        self.paths = paths
        self.targets = targets
        self.vsz = vsz
        self.split = split
        self.NUM_IMAGES = vsz
        os.makedirs(os.path.join('./', self.split), exist_ok=True)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        _id = self.paths[index]
        case = str(_id).zfill(5)
        save_path = os.path.join('./', self.split, _voxel_cache_name(case, self.vsz))
        if Path(save_path).exists():
            voxels = np.load(save_path)
        else:
            voxels = _build_study_voxels(case, self.split, self.vsz, self.NUM_IMAGES)
            np.save(save_path, voxels)
        return torch.tensor(voxels).float(), _id

# Normalized Voxel Datasets
The normalized voxels created by the above procedure were stored as a dataset:

- [64x64x64 voxel](https://www.kaggle.com/ren4yu/rsna-miccai-voxel-64-dataset)
- [128x128x128 voxel](https://www.kaggle.com/ren4yu/rsna-miccai-voxel-128-dataset)
- [256x256x256 voxel](https://www.kaggle.com/ren4yu/rsna-miccai-voxel-256-dataset)

The directory structure is as follows:

```
voxel
├── train
│   ├── 00000
│   │   ├── FLAIR.npy
│   │   ├── T1w.npy
│   │   ├── T1wCE.npy
│   │   └── T2w.npy
│   ├── 00002
│   │   ├── FLAIR.npy
...
├── test
│   ├── 00001
│   │   ├── FLAIR.npy
│   │   ├── T1w.npy
│   │   ├── T1wCE.npy
│   │   └── T2w.npy
│   ├── 00013
│   │   ├── FLAIR.npy
```

Some voxels do not exist because all images belonging to these scans are completely black:

- ('train', '00109', 'FLAIR.npy')
- ('train', '00123', 'T1w.npy')
- ('train', '00123', 'T2w.npy')
- ('train', '00709', 'FLAIR.npy')

Let's get one voxel and visualize it.

In [None]:
! ls ../input

In [None]:
def load_voxel(study_id, scan_type="FLAIR", split="train", sz=256):
    """Load one precomputed voxel from the published voxel datasets.

    Reads ``../input/rsna-miccai-voxel-<sz>-dataset/voxel/<split>/<study_id>/<scan_type>.npy``.
    Only the sizes 64, 128 and 256 exist; other values fail the assertion.
    """
    assert sz in (64, 128, 256)
    dataset_dir = Path(f"../input/rsna-miccai-voxel-{sz}-dataset")
    npy_file = dataset_dir / "voxel" / split / study_id / f"{scan_type}.npy"
    return np.load(str(npy_file))

In [None]:
import os
import logging
import json
import numpy as np
import torch




def save_params(model_dir, params):
    """Write the parameter dictionary ``params`` to ``<model_dir>/params.json``.

    Keys are sorted and the JSON is indented by two spaces; an existing file
    is overwritten.
    """
    target = os.path.join(model_dir, 'params.json')
    with open(target, 'w') as handle:
        json.dump(params, handle, sort_keys=True, indent=2)

def update_params(model_dir, pretrain_dir):
    """Copy the 'arch' and 'fd' entries from the pretrain directory's
    params.json into the params.json of ``model_dir``; other keys are kept."""
    params = load_params(model_dir)
    pretrained = load_params(pretrain_dir)
    params.update(arch=pretrained["arch"], fd=pretrained['fd'])
    save_params(model_dir, params)

def load_params(model_dir):
    """Return the dictionary stored in ``<model_dir>/params.json``."""
    source = os.path.join(model_dir, "params.json")
    with open(source, 'r') as handle:
        return json.load(handle)

def save_state(model_dir, *entries, filename='losses.csv'):
    """Append ``entries`` as one comma-separated row to ``<model_dir>/<filename>``.

    The row is written as a newline followed by the ``str()`` of each entry
    joined with commas. The CSV must already exist (AssertionError otherwise).
    """
    csv_path = os.path.join(model_dir, filename)
    assert os.path.exists(csv_path), 'CSV file is missing in project directory.'
    row = ','.join(str(entry) for entry in entries)
    with open(csv_path, 'a') as handle:
        handle.write('\n' + row)

def save_ckpt(model_dir, net, epoch):
    """Save ``net.state_dict()`` to ``<model_dir>/checkpoints/model-epoch<epoch>.pt``.

    The ``checkpoints`` directory must already exist.
    """
    weights = net.state_dict()
    target = os.path.join(model_dir, 'checkpoints', f'model-epoch{epoch}.pt')
    torch.save(weights, target)

def save_labels(model_dir, labels, epoch):
    """Save ``labels`` with ``np.save`` to ``<model_dir>/plabels/epoch<epoch>.npy``."""
    np.save(os.path.join(model_dir, 'plabels', 'epoch{}.npy'.format(epoch)), labels)

def compute_accuracy0(y_pred, y_true):
    """Compute accuracy by counting correct classification.

    Returns the fraction of elements where ``y_pred`` equals ``y_true``; both
    arrays must have the same shape and every element counts. Comparing with
    ``!=`` instead of subtracting also supports boolean and string labels.
    """
    assert y_pred.shape == y_true.shape
    return 1 - np.count_nonzero(y_pred != y_true) / y_true.size

def compute_accuracy(y_pred, y_true):
    """Compute accuracy by counting correct classification.

    Returns the fraction of elements where ``y_pred`` equals ``y_true``.
    Accepts torch tensors or numpy arrays (subclasses included) of equal
    shape; every element counts, so for 1-D inputs this is the usual
    classification accuracy and it stays within [0, 1] for N-D inputs.

    Raises:
        TypeError: if ``y_pred`` is neither a tensor nor an ndarray.
    """
    assert y_pred.shape == y_true.shape
    if isinstance(y_pred, torch.Tensor):
        # `!=` (not subtraction) so boolean tensors work too.
        n_wrong = torch.count_nonzero(y_pred != y_true).item()
        n_samples = y_pred.numel()
    elif isinstance(y_pred, np.ndarray):
        n_wrong = np.count_nonzero(y_pred != y_true)
        n_samples = y_pred.size
    else:
        raise TypeError("Not Tensor nor Array type.")
    return 1 - n_wrong / n_samples


def clustering_accuracy(labels_true, labels_pred):
    """Compute clustering accuracy under the best one-to-one label matching.

    Builds the contingency table between true classes and predicted clusters,
    finds the class-to-cluster assignment that maximizes the number of matched
    samples (Hungarian algorithm), and returns the matched fraction.

    The table is built with numpy rather than through
    ``sklearn.metrics.cluster.supervised``, a module newer scikit-learn
    releases no longer provide.

    Raises:
        ValueError: if the label arrays are not 1-D or differ in length.
    """
    from scipy.optimize import linear_sum_assignment

    labels_true = np.asarray(labels_true)
    labels_pred = np.asarray(labels_pred)
    if labels_true.ndim != 1 or labels_pred.ndim != 1:
        raise ValueError("labels_true and labels_pred must be 1-D")
    if labels_true.shape[0] != labels_pred.shape[0]:
        raise ValueError("labels_true and labels_pred must have the same length")
    classes, class_idx = np.unique(labels_true, return_inverse=True)
    clusters, cluster_idx = np.unique(labels_pred, return_inverse=True)
    value = np.zeros((classes.size, clusters.size), dtype=np.int64)
    # Count each (class, cluster) pair; add.at accumulates repeated indices.
    np.add.at(value, (class_idx.ravel(), cluster_idx.ravel()), 1)
    # Negate to turn the maximum-matching problem into a minimum-cost one.
    r, c = linear_sum_assignment(-value)
    return value[r, c].sum() / len(labels_true)

In [None]:
from sklearn.linear_model import SGDClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.svm import LinearSVC,SVC
import time
import os
from sklearn.calibration import CalibratedClassifierCV as CalibratedClassifierCV
#from simple_log import get_log, add_file

# my_log = get_log('acc_log')
# timelog= time.time()
# print("timelog is : ",timelog)
# #add_file('acc_log', os.path.join(time.time() +  'acc_log.log'))
# add_file('acc_log', os.path.join(str(time.time()) +  'acc_log.log'))
# # my_log.info("acc is {},{}:".format(,))

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB


def _fit_score_predict(model, name, train_features, train_labels, test_features,
                       test_labels, maxtest, subtest, test):
    """Fit ``model``, score it and optionally predict ``subtest`` (shared by bys/svm/sgd).

    Prints ``test``, then "<name> acc is <maxtest>,<acc_train>,<acc_test>:".

    Returns:
        (acc_train, acc_test, maxtest, pred). In 'test' mode ``pred`` is always
        ``model.predict_proba(subtest)`` and ``maxtest`` is returned unchanged.
        Otherwise ``maxtest`` becomes ``acc_test`` when that is higher, and
        ``pred`` holds the probabilities only if ``acc_test > 0.51`` (else ``[]``).
    """
    model.fit(train_features, train_labels)
    acc_train = model.score(train_features, train_labels)
    acc_test = model.score(test_features, test_labels)
    pred = []
    print(test)
    if test == 'test':
        pred = model.predict_proba(subtest)
    else:
        if acc_test > maxtest:
            maxtest = acc_test
        if acc_test > 0.51:
            pred = model.predict_proba(subtest)
    print("{} acc is {},{},{}:".format(name, maxtest, acc_train, acc_test))
    return acc_train, acc_test, maxtest, pred


def bys(train_features, train_labels, test_features, test_labels, maxtest, subtest, test='train'):
    """Gaussian naive Bayes classifier; see _fit_score_predict for the outputs."""
    return _fit_score_predict(GaussianNB(), "gnb", train_features, train_labels,
                              test_features, test_labels, maxtest, subtest, test)


def svm(train_features, train_labels, test_features, test_labels, maxtest, subtest, test='train'):
    """SVC with probability estimates; see _fit_score_predict for the outputs."""
    model = SVC(verbose=0, random_state=10, max_iter=2100, probability=True)
    return _fit_score_predict(model, "svm", train_features, train_labels,
                              test_features, test_labels, maxtest, subtest, test)


def sgd(train_features, train_labels, test_features, test_labels, maxtest, subtest, test='train'):
    """Probability-calibrated SGDClassifier; see _fit_score_predict for the outputs.

    The SGD model is passed positionally: CalibratedClassifierCV renamed its
    first parameter from ``base_estimator`` to ``estimator`` (scikit-learn 1.2)
    and later dropped the old keyword, while the position is the same in both.
    """
    # NOTE(review): validation_fraction=0.8 reserves 80% of each training set
    # for early stopping, leaving 20% to fit on; confirm this is intended.
    base = SGDClassifier(max_iter=5000, tol=1e-3, early_stopping=True,
                         validation_fraction=0.8, n_iter_no_change=120)
    return _fit_score_predict(CalibratedClassifierCV(base), "sgd", train_features, train_labels,
                              test_features, test_labels, maxtest, subtest, test)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions with an optional projection shortcut.

    ``stride`` is applied by the first convolution. When the stride is not 1
    or the channel count changes, the shortcut is a strided 1x1 conv followed by
    BatchNorm so it matches the main branch; otherwise it is the identity.
    Output channels: ``expansion * planes``.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        width = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        if stride != 1 or in_planes != width:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width),
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return F.relu(h + self.shortcut(x))


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block with an optional projection shortcut.

    ``stride`` is applied by the 3x3 convolution and the block outputs
    ``expansion * planes`` (= 4 * planes) channels. A strided 1x1 conv +
    BatchNorm shortcut is used when the stride or the channel count changes;
    otherwise the shortcut is the identity.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        width = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        if stride != 1 or in_planes != width:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width),
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        return F.relu(h + self.shortcut(x))



class ResNetControl(nn.Module):
    """ResNet-style encoder for multi-channel 2-D inputs with L2-normalized output.

    Parameters:
        block: residual block class with an ``expansion`` attribute
            (e.g. BasicBlock or Bottleneck)
        num_blocks (list[int]): number of blocks in each of the four stages
        feature_dim (int): planes of the last stage (output channels are
            ``feature_dim * block.expansion``)
        in_channels (int): channels of the input image. The default 384 equals
            4 scan types x 96 slices, the sample layout DataRetriever produces
            with vsz=96.

    The output is ``avg_pool2d(last_stage, 4)`` flattened, so its length is
    ``feature_dim * expansion`` only when the last stage is 4-7 pixels wide.
    With BasicBlock and a 96x96 input the stages are 97, 49, 25 and 13 pixels
    wide, so pooling leaves 3x3 cells (9 * feature_dim values per sample).
    """

    def __init__(self, block, num_blocks, feature_dim=512, in_channels=384):
        super(ResNetControl, self).__init__()
        self.in_planes = 256
        self.feature_dim = feature_dim
        # 2x2 kernel with padding 1 grows each spatial side by one pixel.
        self.conv1 = nn.Conv2d(in_channels, 256, kernel_size=2, stride=1,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(256)
        self.layer1 = self._make_layer(block, 256, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 256, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, feature_dim, num_blocks[3], stride=2)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` of shape (batch, in_channels, H, W) into unit-norm rows."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return F.normalize(out)


def ResNet18(feature_dim=512):
    """Build the ResNet-18 layout ([2, 2, 2, 2] BasicBlocks) with ``ResNet``.

    NOTE(review): no ``ResNet`` class is defined in the visible notebook, so
    this raises NameError unless ``ResNet`` is provided elsewhere.
    """
    stages = [2] * 4
    return ResNet(BasicBlock, stages, feature_dim)

def ResNet18Control(feature_dim=512):
    """Build the ResNet-18 layout ([2, 2, 2, 2] BasicBlocks) with ``ResNetControl``."""
    stages = [2] * 4
    return ResNetControl(BasicBlock, stages, feature_dim)


In [None]:
import os
from tqdm import tqdm

import cv2
import numpy as np
import torch
import torch.nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# from cluster import ElasticNetSubspaceClustering, clustering_accuracy
# import utils





def load_checkpoint(model_dir, epoch=None, eval_=False):
    """Load checkpoint from model directory. Checkpoints should be stored in
    `model_dir/checkpoints/model-epochX.pt` (the name save_ckpt writes),
    where `X` is the epoch number.

    Parameters:
        model_dir (str): path to model directory; its params.json must hold
            the 'arch' and 'fd' entries passed to load_architectures
        epoch (int): epoch number; set to None for last available epoch
        eval_ (bool): PyTorch evaluation mode. set to True for testing

    Returns:
        net (torch.nn.Module): PyTorch checkpoint at `epoch`
        epoch (int): epoch number

    Raises:
        FileNotFoundError: if `epoch` is None and no checkpoint file exists.

    NOTE(review): load_architectures is not defined in the visible notebook.
    """
    import re

    if epoch is None:  # get last epoch
        ckpt_dir = os.path.join(model_dir, 'checkpoints')
        name_re = re.compile(r'model-epoch(\d+)\.pt')
        matches = (name_re.fullmatch(name) for name in os.listdir(ckpt_dir))
        epochs = [int(m.group(1)) for m in matches if m]
        if not epochs:
            raise FileNotFoundError('No model-epoch<N>.pt checkpoints found in {}'.format(ckpt_dir))
        epoch = max(epochs)
    ckpt_path = os.path.join(model_dir, 'checkpoints', 'model-epoch{}.pt'.format(epoch))
    params = load_params(model_dir)
    print('Loading checkpoint: {}'.format(ckpt_path))
    # Load onto the CPU so GPU-saved checkpoints also open without CUDA;
    # load_state_dict copies the values onto the network's own device.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    net = load_architectures(params['arch'], params['fd'])
    net.load_state_dict(state_dict)
    del state_dict
    if eval_:
        net.eval()
    return net, epoch

    
def get_features(net, trainloader, verbose=True):
    '''Extract all features out into one single batch.

    Batches are moved to the device of the network's parameters (CUDA when the
    network has none) and evaluated without building an autograd graph.

    Parameters:
        net (torch.nn.Module): get features using this model
        trainloader (torchvision.dataloader): iterable of (images, labels) batches
        verbose (bool): shows loading status bar

    Returns:
        features (torch.tensor): with dimension (num_samples, feature_dimension), on CPU
        labels (torch.tensor): with dimension (num_samples, )
    '''
    features = []
    labels = []
    if verbose:
        train_bar = tqdm(trainloader, desc="extracting all features from dataset")
    else:
        train_bar = trainloader
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')
    # Features are detached anyway; skipping the graph saves a lot of memory.
    with torch.no_grad():
        for batch_imgs, batch_lbls in train_bar:
            batch_features = net(batch_imgs.to(device))
            features.append(batch_features.cpu().detach())
            labels.append(batch_lbls)
    return torch.cat(features), torch.cat(labels)
    

def corrupt_labels(mode="default"):
    """Return the label-corruption function for ``mode`` from the ``corrupt`` module.

    Supported modes: "default" (corrupt.default_corrupt), "asymmetric_noise",
    "noisify_pairflip" and "noisify_multiclass_symmetric" (same-named
    functions). The ``corrupt`` module is imported only for a valid mode.

    Raises:
        ValueError: for an unknown ``mode``.
    """
    if mode == "default":
        from corrupt import default_corrupt
        return default_corrupt
    elif mode == "asymmetric_noise":
        from corrupt import asymmetric_noise
        return asymmetric_noise
    elif mode == "noisify_pairflip":
        from corrupt import noisify_pairflip
        return noisify_pairflip
    elif mode == "noisify_multiclass_symmetric":
        from corrupt import noisify_multiclass_symmetric
        return noisify_multiclass_symmetric
    # Fail here instead of returning None and failing later at the call site.
    raise ValueError("Unknown label corruption mode: {!r}".format(mode))



def label_to_membership(targets, num_classes=None):
    """Generate a true membership matrix from integer class labels.

    Parameters:
        targets (array-like of int): class label of each sample, shape (num_samples,)
        num_classes (int, optional): number of classes; inferred as
            ``max(targets) + 1`` when None. A 0-d integer tensor is accepted.

    Return:
        Pi: float64 membership matrix, shape (num_classes, num_samples, num_samples),
            with Pi[k, j, j] = 1 when sample j has label k and 0 elsewhere
    """
    labels = np.asarray(targets).astype(np.int64)
    if num_classes is None:
        num_classes = int(labels.max()) + 1 if labels.size else 0
    num_classes = int(num_classes)
    num_samples = labels.shape[0]
    Pi = np.zeros(shape=(num_classes, num_samples, num_samples))
    diag = np.arange(num_samples)
    Pi[labels, diag, diag] = 1.
    return Pi


def membership_to_label(membership):
    """Turn a membership matrix into a list of labels.

    Parameters:
        membership (np.ndarray): shape (num_classes, num_samples, num_samples),
            as returned by label_to_membership

    Returns:
        np.ndarray: float64 labels with labels[i] = argmax_k membership[k, i, i]
    """
    # Axis -2 is the sample axis for the 3-D layout (and axis 2 for 4-D input).
    num_samples = membership.shape[-2]
    labels = np.zeros(num_samples)
    for i in range(num_samples):
        labels[i] = np.argmax(membership[:, i, i])
    return labels

def one_hot(labels_int, n_classes):
    """Return a float32 tensor of shape (len(labels_int), n_classes) holding a 1
    in each sample's label column and 0 elsewhere."""
    encoded = torch.zeros((len(labels_int), n_classes), dtype=torch.float32)
    row = 0
    for label in labels_int:
        encoded[row][label] = 1.
        row += 1
    return encoded


## Additional Augmentations
class GaussianBlur:
    """Randomly blur an image with a Gaussian kernel (SimCLR-style augmentation).

    With probability 0.5 the input is blurred by ``cv2.GaussianBlur`` using a
    square ``kernel_size`` kernel and a sigma drawn uniformly from
    ``[min, max)``; otherwise it is returned unblurred. The result is always
    an ndarray copy of the input.
    """

    def __init__(self, kernel_size, min=0.1, max=2.0):
        # `min`/`max` shadow builtins but are part of the public signature.
        self.min = min
        self.max = max
        # Side length of the square kernel, supplied by the caller.
        self.kernel_size = kernel_size

    def __call__(self, sample):
        image = np.array(sample)
        if np.random.random_sample() >= 0.5:
            return image
        sigma = self.min + (self.max - self.min) * np.random.random_sample()
        return cv2.GaussianBlur(image, (self.kernel_size, self.kernel_size), sigma)

def sparse2coarse(targets):
    """Map CIFAR-100 fine labels (0-99) to their 20 coarse superclass labels.

    ``targets`` may be anything numpy accepts as an index (int, list, array);
    the result is the corresponding numpy indexing of the lookup table.
    """
    fine_to_coarse = np.array([
        4, 1, 14, 8, 0, 6, 7, 7, 18, 3,
        3, 14, 9, 18, 7, 11, 3, 9, 7, 11,
        6, 11, 5, 10, 7, 6, 13, 15, 3, 15,
        0, 11, 1, 10, 12, 14, 16, 9, 11, 5,
        5, 19, 8, 8, 15, 13, 14, 17, 18, 10,
        16, 4, 17, 4, 2, 0, 17, 4, 18, 17,
        10, 3, 2, 12, 12, 16, 12, 1, 9, 19,
        2, 10, 0, 1, 16, 12, 9, 13, 15, 13,
        16, 19, 2, 4, 6, 19, 5, 5, 8, 19,
        18, 1, 2, 15, 6, 0, 17, 8, 14, 13,
    ])
    return fine_to_coarse[targets]


In [None]:
import numpy as np
import torch
from itertools import combinations


class MaximalCodingRateReduction(torch.nn.Module):
    """Maximal Coding Rate Reduction (MCR^2) loss, as in
    https://github.com/Ma-Lab-Berkeley/ReduNet.

    ``forward`` takes features ``X`` of shape (num_samples, feature_dim), which
    are transposed to ``W`` of shape (p, m) = (feature_dim, num_samples), and
    integer labels ``Y`` (a tensor). All intermediate tensors are created on
    the device of the features, so the loss runs on CPU or GPU.
    """

    def __init__(self, gam1=1.0, gam2=1.0, eps=0.01):
        super(MaximalCodingRateReduction, self).__init__()
        self.gam1 = gam1
        self.gam2 = gam2
        self.eps = eps

    def compute_discrimn_loss_empirical(self, W):
        """Empirical discriminative term: logdet(I + gam1 * p/(m*eps) * W W^T) / 2."""
        p, m = W.shape
        I = torch.eye(p, device=W.device)
        scalar = p / (m * self.eps)
        logdet = torch.logdet(I + self.gam1 * scalar * W.matmul(W.T))
        return logdet / 2.

    def compute_compress_loss_empirical(self, W, Pi):
        """Empirical compressive term:
        sum_j tr(Pi_j)/(2m) * logdet(I + p/(tr(Pi_j)*eps) * W Pi_j W^T)."""
        p, m = W.shape
        k, _, _ = Pi.shape
        I = torch.eye(p, device=W.device)
        compress_loss = 0.
        for j in range(k):
            # Small offset keeps empty classes from dividing by zero.
            trPi = torch.trace(Pi[j]) + 1e-8
            scalar = p / (trPi * self.eps)
            log_det = torch.logdet(I + scalar * W.matmul(Pi[j]).matmul(W.T))
            compress_loss += log_det * trPi / m
        return compress_loss / 2.

    def compute_discrimn_loss_theoretical(self, W):
        """Theoretical discriminative term: logdet(I + p/(m*eps) * W W^T) / 2 (no gam1)."""
        p, m = W.shape
        I = torch.eye(p, device=W.device)
        scalar = p / (m * self.eps)
        logdet = torch.logdet(I + scalar * W.matmul(W.T))
        return logdet / 2.

    def compute_compress_loss_theoretical(self, W, Pi):
        """Theoretical compressive term (same formula as the empirical one)."""
        p, m = W.shape
        k, _, _ = Pi.shape
        I = torch.eye(p, device=W.device)
        compress_loss = 0.
        for j in range(k):
            trPi = torch.trace(Pi[j]) + 1e-8
            scalar = p / (trPi * self.eps)
            log_det = torch.logdet(I + scalar * W.matmul(Pi[j]).matmul(W.T))
            compress_loss += trPi / (2 * m) * log_det
        return compress_loss

    def forward(self, X, Y, num_classes=None):
        """Return (total, [discrimn, compress] empirical, [discrimn, compress] theoretical).

        ``total = -gam2 * discrimn_empirical + compress_empirical`` is the only
        value carrying gradients; the other four are Python floats.
        """
        if num_classes is None:
            num_classes = Y.max() + 1
        W = X.T
        Pi = label_to_membership(Y.cpu().numpy(), num_classes)
        Pi = torch.tensor(Pi, dtype=torch.float32, device=X.device)

        discrimn_loss_empi = self.compute_discrimn_loss_empirical(W)
        compress_loss_empi = self.compute_compress_loss_empirical(W, Pi)
        # The theoretical terms are only reported via .item(); no graph needed.
        with torch.no_grad():
            discrimn_loss_theo = self.compute_discrimn_loss_theoretical(W)
            compress_loss_theo = self.compute_compress_loss_theoretical(W, Pi)

        total_loss_empi = self.gam2 * -discrimn_loss_empi + compress_loss_empi
        return (total_loss_empi,
                [discrimn_loss_empi.item(), compress_loss_empi.item()],
                [discrimn_loss_theo.item(), compress_loss_theo.item()])

In [None]:
import torch.optim.lr_scheduler as lr_scheduler
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler  # duplicate of the first import (harmless)



# Encoder producing L2-normalized features; ResNetControl's first conv expects
# 384 input channels (4 scan types x 96 slices per DataRetriever sample).
# NOTE(review): the network is not moved to a GPU in this cell; confirm where
# that happens if training on CUDA.
net = ResNet18Control(512)


# MCR^2 criterion (see the MaximalCodingRateReduction cell); eps=0.001 enters
# the coding-rate terms as p / (m * eps).
criterion = MaximalCodingRateReduction(gam1=1, gam2=1, eps=0.001)
#optimizer = optim.SGD(net.parameters(), lr=0.00001, momentum=0.9, weight_decay=5e-4)
optimizer = torch.optim.Adam(net.parameters(),lr=0.000001)
# Multiply the learning rate by 0.1 when the scheduler step count reaches 200, 400 and 600.
scheduler = lr_scheduler.MultiStepLR(optimizer, [200, 400, 600], gamma=0.1)
# utils.save_params(model_dir, vars(args))

from sklearn import model_selection as sk_model_selection
from torch.utils import data as torch_data
import torch
import pandas as pd

# import DataRetriever7voxelconv as DataRetriever
# Training labels: BraTS21ID (study id) and MGMT_value (binary target).
df = pd.read_csv("../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv")

# Drop the studies listed above as having a completely black scan
# (123 appears twice, which is harmless for isin).
to_exclude = [109, 123, 123, 709] # [514, 11, 658, 537, 544, 688, 692, 694, 703, 709, 710, 711, 713, 719, 720, 722, 631, 109, 750, 621, 501,
             #         118, 503, 630, 121, 123, 125, 126, 127, 551, 605, 539, 505]
df = df[~df['BraTS21ID'].isin(to_exclude)]

# 85/15 train/validation split, stratified on the label, reproducible via random_state.
df_train, df_valid = sk_model_selection.train_test_split(df,  test_size=0.15 ,     random_state=42,     stratify=df["MGMT_value"],  )
# df_train = df_train.head(10)
# df_valid = df_valid.head(10)

# Both retrievers read studies from the 'train' split (DataRetriever's default).
train_data_retriever = DataRetriever(    df_train["BraTS21ID"].values,       df_train["MGMT_value"].values,  )
valid_data_retriever = DataRetriever(     df_valid["BraTS21ID"].values,     df_valid["MGMT_value"].values, )
# Shuffled loader used for training.
train_loader = torch_data.DataLoader(
    train_data_retriever,
    batch_size=16, #args.samples,  #100
    shuffle=True,
    num_workers=2,
)
# Same training data in a fixed order (presumably for feature extraction; TODO confirm).
train_loader2 = torch_data.DataLoader(
    train_data_retriever,
    batch_size=16, #200, #args.samples,  #100
    shuffle=False,
    num_workers=1,
)
valid_loader = torch_data.DataLoader(
    valid_data_retriever,
    batch_size=8,
    shuffle=False,
    num_workers=1,
)

# Test studies come from sample_submission.csv; its MGMT_value column is passed
# as targets, which DataRetrievertest stores but does not return.
submission = pd.read_csv(f"../input/rsna-miccai-brain-tumor-radiogenomic-classification/sample_submission.csv")
sub_data_retriever = DataRetrievertest(    submission["BraTS21ID"].values,       submission["MGMT_value"].values, split='test',  )
sub_loader = torch_data.DataLoader(
    sub_data_retriever,
    batch_size=8, #args.bs, #200, #args.samples,  #100
    shuffle=False,
    num_workers=2,
)

# Notebook display of the first training rows.
df_train.head()

In [None]:
# Import os before its first use; previously it was imported only after os.system().
import os

os.system('ls -la ./train/')
torch.cuda.empty_cache()

# List any checkpoint files already present from a previous training run.
for dirname, _, filenames in os.walk('./checkpoints/'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
ckpt_dir = os.path.join('../', 'input/checkpt')


def _ensemble_predict(train_features, train_labels, test_features, test_labels,
                      maxtest, sub_features, mode, allsubpre):
    """Run the svm, sgd and bys classifier helpers on extracted features.

    Each helper receives the running best validation score ``maxtest`` and returns an
    updated one. Non-empty prediction arrays are appended to ``allsubpre`` in that order.

    Returns:
        The ``maxtest`` value after all three classifiers have run.
    """
    for classifier in (svm, sgd, bys):
        _, _, maxtest, pred = classifier(train_features, train_labels, test_features,
                                         test_labels, maxtest, sub_features, mode)
        if len(pred) > 0:
            print()
            allsubpre.append(pred)
    return maxtest


def _write_submission(sample, sub_ids, preds):
    """Average the collected predictions and write ``submission.csv``.

    If no predictions were collected, the sample submission is written unchanged
    instead of crashing, so the notebook still produces a submission file.

    Returns:
        The DataFrame that was written.
    """
    if not preds:
        print('no classifier predictions were collected; writing the sample submission')
        sample.to_csv("submission.csv", index=False)
        return sample
    mean_pred = np.mean(preds, axis=0)
    print(mean_pred)
    print(sample.head())
    result = pd.DataFrame({"BraTS21ID": sub_ids, "MGMT_value": mean_pred[:, 0]})
    print(result.head())
    result.to_csv("submission.csv", index=False)
    return result


maxtest = 0
maxtestsave = 0
allsubpre = []
_sub_labels = None

if os.path.exists(ckpt_dir):
    # Inference path: restore the newest "model-epoch<N>.pt" and ensemble classifiers on its features.
    epochs = [int(name[11:-3]) for name in os.listdir(ckpt_dir) if name.endswith(".pt")]
    epoch = max(epochs)
    ckpt_path = os.path.join(ckpt_dir, 'model-epoch{}.pt'.format(epoch))
    print('Loading checkpoint: {}'.format(ckpt_path))
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine too.
    state_dict = torch.load(ckpt_path, map_location=device)
    net = ResNet18Control(512)
    net.to(device)
    net.load_state_dict(state_dict)
    del state_dict
    net.eval()

    train_features, train_labels = get_features(net, train_loader2)
    test_features, test_labels = get_features(net, valid_loader)
    sub_features, _sub_labels = get_features(net, sub_loader)
    maxtest = _ensemble_predict(train_features, train_labels, test_features, test_labels,
                                maxtest, sub_features, 'test', allsubpre)
else:
    # Training path: train with the MCR^2 loss, evaluating the classifiers after every epoch.
    os.makedirs(os.path.join('./', 'checkpoints'), exist_ok=True)
    net.to(device)
    diverged = False
    for epoch in range(25):
        # get_features() below may put the net in eval mode; restore training mode
        # each epoch (a no-op if it did not) so BatchNorm/Dropout train correctly.
        net.train()
        for step, (batch_imgs, batch_lbls) in enumerate(train_loader):
            features = net(batch_imgs.to(device))
            loss, loss_empi, loss_theo = criterion(features, batch_lbls, num_classes=2)
            # Stop on a NaN/inf loss instead of calling exit(), which ended the run
            # before any submission file was written.
            if not np.all(np.isfinite(loss_empi + loss_theo)):
                print(epoch, step, loss.item(), *loss_empi, *loss_theo)
                print("loss is not finite; stopping training")
                diverged = True
                break
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % 6 == 0:
                print(epoch, step, loss.item(), *loss_empi, *loss_theo)
        if diverged:
            break
        scheduler.step()

        train_features, train_labels = get_features(net, train_loader2)
        test_features, test_labels = get_features(net, valid_loader)
        sub_features, _sub_labels = get_features(net, sub_loader)
        maxtest = _ensemble_predict(train_features, train_labels, test_features, test_labels,
                                    maxtest, sub_features, 'train', allsubpre)
        if maxtest > maxtestsave:
            maxtestsave = maxtest
            print('save ck ....')
            save_ckpt('./', net, epoch)

submission = _write_submission(submission, _sub_labels, allsubpre)


