In [1]:
import os
# Cap OpenMP worker threads; set before torch is imported so it takes effect.
os.environ.update(OMP_NUM_THREADS='8')

In [2]:
from IPython.display import clear_output, display
from PIL import Image
from sklearn import metrics

import numpy as np

import time
import torch
import torchvision as tv

import gc

In [3]:
# Prefer CUDA when present; otherwise run on CPU with 8 intra-op threads.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cpu':
    torch.set_num_threads(8)
print(device)

cpu


In [None]:
# Open the raw .npz archive (arrays are read lazily on key access).
arc = np.load(file='box_sphere_00.npz')

In [None]:
# Keep only the first channel *before* converting to float32, so we never
# materialise a float copy of every channel, then move the channel axis to dim 1.
# NOTE(review): assumes 'ds' is stored channels-last (N, H, W, C) — confirm.
# transpose(1, 3) then yields (N, C, W, H), i.e. the spatial axes are swapped
# relative to NCHW; the display cell undoes this with transpose(0, 2), so keep
# the two in sync if either changes.
ds_raw = torch.from_numpy(arc['ds'][:, :, :, :1]).float().transpose(1, 3)
arc.close()  # release the .npz file handle explicitly instead of relying on GC
del arc

In [None]:
# Scale 0..255 pixel values into the range [-0.5, 0.5].
ds_norm = ds_raw.div(255.0).sub(0.5)

In [None]:
# Random permutation of sample indices used for the train/test split.
# A fixed seed makes the split, and therefore every reported metric,
# reproducible under Restart & Run All.
SPLIT_SEED = 0
ind = np.random.default_rng(SPLIT_SEED).permutation(ds_norm.shape[0])

In [None]:
train_part = 0.8  # fraction of shuffled samples that go to the training set
ind_thr = int(len(ind) * train_part)
ind_train = ind[:ind_thr]
ind_test = ind[ind_thr:]

In [None]:
# Class label per sample: 0 for even indices, 1 for odd.
# NOTE(review): assumes the archive stores the two shape classes interleaved — confirm.
# Built directly as torch.long because CrossEntropyLoss requires int64 targets;
# np.arange's default integer is 32-bit on some platforms (e.g. Windows with
# NumPy < 2), which made the original tensor int32 there and broke the loss.
label = torch.arange(ds_norm.shape[0], dtype=torch.long) % 2

In [None]:
# Gather the train and test images/labels by the shuffled index split.
pt_train_ds = ds_norm[ind_train]
pt_train_lbl = label[ind_train]
pt_test_ds = ds_norm[ind_test]
pt_test_lbl = label[ind_test]

In [None]:
# Move all data onto the selected compute device (no-op if already there).
pt_train_ds = pt_train_ds.to(device)
pt_train_lbl = pt_train_lbl.to(device)
pt_test_ds = pt_test_ds.to(device)
pt_test_lbl = pt_test_lbl.to(device)

In [None]:
class Descriminate(torch.nn.Module):
    """Small CNN producing two-class logits for a single-channel image.

    (The name is a misspelling of "Discriminate"; it is kept so existing cells
    that instantiate it keep working.)

    Architecture: three [conv 5x5 -> ReLU -> 2x2 max-pool] stages, followed by
    three linear layers. ``lin1`` expects 36 * 9 * 9 flattened features, which
    is what a 100x100 input produces (100 -> 96 -> 48 -> 44 -> 22 -> 18 -> 9).

    Input:  float tensor of shape (batch, 1, 100, 100).
    Output: (batch, 2) unnormalised logits, suitable for ``CrossEntropyLoss``.
    """

    def __init__(self):
        super(Descriminate, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 9, 5)
        self.mp1 = torch.nn.MaxPool2d(2)
        self.conv2 = torch.nn.Conv2d(9, 18, 5)
        self.mp2 = torch.nn.MaxPool2d(2)
        self.conv3 = torch.nn.Conv2d(18, 36, 5)
        self.mp3 = torch.nn.MaxPool2d(2)
        self.lin1 = torch.nn.Linear(36 * 9 * 9, 100)
        self.lin2 = torch.nn.Linear(100, 20)
        self.lin3 = torch.nn.Linear(20, 2)

    def forward(self, x):
        """Return (batch, 2) raw logits for a (batch, 1, 100, 100) input."""
        x = self.mp1(torch.nn.functional.relu(self.conv1(x)))
        x = self.mp2(torch.nn.functional.relu(self.conv2(x)))
        x = self.mp3(torch.nn.functional.relu(self.conv3(x)))
        # flatten(1) keeps the batch dimension intact. view(-1, ...) would
        # silently fold a mis-sized input into a different batch size.
        x = x.flatten(1)
        x = torch.nn.functional.relu(self.lin1(x))
        x = torch.nn.functional.relu(self.lin2(x))
        # No ReLU on the output layer: CrossEntropyLoss expects raw logits.
        # Clamping at 0 can leave both outputs stuck at 0 with zero gradient.
        return self.lin3(x)

In [None]:
# Shape sanity check (run only after the model cell below has defined `m`): m(pt_train_ds[:1]).shape

In [None]:
crit = torch.nn.CrossEntropyLoss()  # takes raw logits and integer class targets
m = Descriminate().to(device)

In [None]:
# Plain SGD with light momentum over all model parameters.
optim = torch.optim.SGD(
    m.parameters(),
    lr=1e-5,
    momentum=0.5,
)

In [None]:
# Sanity check: images and labels should share the same first (sample) dimension.
(pt_train_ds.shape, pt_train_lbl.shape)

In [None]:
def mpp(msg, a):
    """Return ``msg`` followed by each value of ``a`` formatted to 4 decimals.

    The values are space-separated. No separator is inserted between ``msg``
    and the first value.
    """
    formatted = [f'{value:.04f}' for value in a]
    return msg + ' '.join(formatted)

In [None]:
# Train until interrupted (Kernel -> Interrupt). Each outer iteration draws a
# random 10% subset of the training set and takes `meter` SGD steps on that
# same subset. `meter` grows by 5 per iteration until it passes 100.
# Labels never change, so convert them to CPU numpy once: sklearn cannot
# consume CUDA tensors.
train_lbl_np = pt_train_lbl.cpu().numpy()
test_lbl_np = pt_test_lbl.cpu().numpy()
try:
    meter = 1
    while True:
        minibatch_percent = 0.1
        # max(1, ...) avoids an empty minibatch (NaN loss) on very small datasets.
        n_sample = max(1, int(pt_train_ds.shape[0] * minibatch_percent))
        random_sample_ind = torch.randperm(pt_train_ds.shape[0])[:n_sample]
        sampled_ds = pt_train_ds[random_sample_ind]
        sampled_lbl = pt_train_lbl[random_sample_ind]
        for _ in range(meter):
            optim.zero_grad()
            estim = m(sampled_ds)
            loss = crit(estim, sampled_lbl)
            loss.backward()
            optim.step()

        if meter < 100:
            meter += 5

        # Evaluate under no_grad with one forward pass per split. The old code
        # ran four full-dataset passes per iteration with autograd enabled.
        with torch.no_grad():
            train_out = m(pt_train_ds).cpu().numpy()
            test_out = m(pt_test_ds).cpu().numpy()
        model_cube_train_score = train_out[:, 0]
        model_cube_test_score = test_out[:, 0]
        model_sphere_train_score = train_out[:, 1]
        model_sphere_test_score = test_out[:, 1]
        # NOTE(review): roc_auc_score treats label 1 as the positive class. If
        # label 1 means "sphere", the cube-column AUC is expected to fall towards 0
        # (not rise to 1) as the model improves — confirm the label convention.
        to_show = (loss.item(),
                   metrics.roc_auc_score(train_lbl_np, model_cube_train_score),
                   metrics.roc_auc_score(test_lbl_np, model_cube_test_score),
                   metrics.roc_auc_score(train_lbl_np, model_sphere_train_score),
                   metrics.roc_auc_score(test_lbl_np, model_sphere_test_score),
                   mpp('cube_train_scores   :', model_cube_train_score[:15]),
                   mpp('sphere_train_scores :', model_sphere_train_score[:15]),
                   mpp('labels              :', train_lbl_np[:15]),
                  )
        # clear_output(wait=True)  # uncomment to keep only the latest metrics on screen
        display(to_show)
except KeyboardInterrupt:
    print("iterations stopped!")

In [None]:
# Show the first few test images alongside the model's raw two-class outputs.
N_SHOW = 20
with torch.no_grad():
    # min(...) avoids an IndexError when the test set has fewer than N_SHOW images.
    for t in range(min(N_SHOW, pt_test_ds.shape[0])):
        # Undo the [-0.5, 0.5] normalisation. transpose(0, 2) reverses the
        # transpose(1, 3) applied at load time, giving (H, W, 1) for PIL.
        pic = (pt_test_ds[t].transpose(0, 2).cpu().numpy() + 0.5) * 255
        # Round before casting: astype(uint8) truncates, so float round-trip
        # error like 254.9999 would otherwise become 254. Clip guards overflow.
        # Repeat the single channel to RGB for display.
        int_pic = np.clip(np.rint(pic), 0, 255).astype(np.uint8).repeat(3, axis=2)
        display(Image.fromarray(int_pic), m(pt_test_ds[t:t+1]).cpu().numpy(), '', '')