# Main Code

In [31]:
import numpy as np
from sklearn.model_selection import train_test_split
from heapq import nlargest
from sklearn.metrics import accuracy_score

from tqdm import tqdm

import time

import torch
import torch.nn as nn
import torch.nn.functional as F
# !pip install torch-optimizer
import torch_optimizer as optim   
from torch.utils.data import Dataset, DataLoader
from torch.nn import MSELoss

In [28]:
import import_ipynb
import MLP
from MLP import MLP_Network

## Get Data

In [3]:
print('loading...', flush=True)
# The X files need allow_pickle: they hold object arrays (presumably ragged
# per-line clue lists); the Y label files are plain numeric arrays.
X_train = np.load('./data/X_train.npy', allow_pickle=True)
X_test = np.load('./data/X_test.npy', allow_pickle=True)
Y_train = np.load('./data/Y_train.npy')
Y_test = np.load('./data/Y_test.npy')
print('done')

loading...
done


## Preprocessing - converting conditions into 'possibility maps'

In [4]:
tuple(arr.shape for arr in (X_train, Y_train, X_test, Y_test))

((60000, 2, 8), (60000, 8, 8), (10000, 2, 8), (10000, 8, 8))

In [5]:
def searchcombinationsUtil(k, n):
    """
    Return every ordered list of k non-negative integers that sum to n.

    k: number of elements (must be >= 1)
    n: total sum of the elements
    Order matters, so [0, 2] and [2, 0] are distinct results; lists are
    produced in lexicographic order. Returns [] when n < 0, because no
    non-negative solution exists. Raises ValueError when k < 1.
    """
    if k < 1:
        # Without this guard the recursion never reaches k == 1.
        raise ValueError(f"k must be >= 1, got {k}")
    if n < 0:
        return []
    if k == 1:
        return [[n]]

    output = []
    for first in range(n + 1):
        output += [[first] + rest for rest in searchcombinationsUtil(k - 1, n - first)]
    return output

In [6]:
def pixel_val_calculator(constraint, N):
    """
    constraint: the clue of a single row/column, i.e. the lengths of its
        colored blocks in order (e.g. [7, 1]). Zero-length entries are ignored,
        so a zero-padded clue, or [0] for an empty line, is accepted.
    N: the length of the row/column (width/height of the nonogram)
    Returns a float64 numpy vector of length N. Its i-th value is the fraction
    of all valid block placements in which pixel i is colored.
    Raises ValueError if the blocks cannot fit into N pixels.
    """
    blocks = [int(b) for b in constraint if b > 0]
    if not blocks:
        return np.zeros(N, dtype=np.float64)

    num_blocks = len(blocks)
    # Free cells left after placing every block plus the mandatory one-cell
    # gap between neighbouring blocks.
    slack = N - sum(blocks) - (num_blocks - 1)
    if slack < 0:
        raise ValueError(f"constraint {list(constraint)} does not fit in {N} pixels")

    # Distribute the slack over num_blocks + 1 gaps: before the first block,
    # between blocks (on top of the mandatory 1 cell), and after the last.
    combinations = searchcombinationsUtil(k=num_blocks + 1, n=slack)

    output = []
    for gaps in combinations:
        line = []
        for idx, gap in enumerate(gaps):
            interior = 0 < idx < num_blocks
            line += [0] * (gap + 1 if interior else gap)
            if idx < num_blocks:
                line += [1] * blocks[idx]
        output.append(line)

    # Averaging the 0/1 placements gives each pixel's colored probability.
    return np.array(output, dtype=np.float64).mean(axis=0)

In [7]:
pixel_val_calculator(constraint=[7, 1], N=10)

array([0.66666667, 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 0.33333333, 0.33333333, 0.66666667])

In [44]:
def possibility_map_generator(X):
    """
    Turn each puzzle's row and column clues into per-pixel colored probabilities.

    X: (batch_size, 2, N), or (batch_size, 2, N, t) only when every clue has
       exactly t blocks. Index 0 of axis 1 holds the row clues and index 1 the
       column clues; N is the number of rows/columns.
    Returns an array of shape (batch_size, 2, N, N). Entry [b, 0] is the row-wise
    map. Entry [b, 1] is the column-wise map, transposed so both maps share the
    same (row, column) pixel layout.
    """
    assert X.ndim in [3, 4]

    N = X.shape[-1] if X.ndim == 3 else X.shape[-2]

    def _line_map(conditions):
        # One probability vector per clue, stacked into an (N, N) matrix.
        return np.array([pixel_val_calculator(constraint=c, N=N) for c in conditions])

    maps = [
        np.array([_line_map(puzzle[0]), _line_map(puzzle[1]).T])
        for puzzle in tqdm(X)
    ]
    return np.asarray(maps)

In [9]:
# One 2x2 puzzle: row clues [2], [1] and column clues [1], [2].
a = np.array([[[2], [1]], [[1], [2]]])
print(a.shape)
possibility_map_generator(a[np.newaxis, ...])

100%|██████████| 1/1 [00:00<00:00, 964.65it/s]

(2, 2, 1)





array([[[[1. , 1. ],
         [0.5, 0.5]],

        [[0.5, 1. ],
         [0.5, 1. ]]]])

## Neural Network

In [55]:
class NN_Dataset(Dataset):
    """Dataset pairing nonogram clues with possibility maps (and optional labels).

    X:        (batch, 2, N) object array of per-line clue lists, or a
              (batch, 2, N, t) array when every clue has exactly t blocks.
    X_mapped: optional precomputed (batch, 2, N, N) possibility maps; computed
              with possibility_map_generator when omitted.
    y:        optional labels. When given, items are (inputs, mapped, labels);
              otherwise they are (inputs, mapped).
    """

    def __init__(self, X, X_mapped=None, y=None):
        if X.ndim not in (3, 4):
            raise ValueError(f"X must be 3-D or 4-D, got ndim={X.ndim}")

        if X.ndim == 3:
            N = X.shape[-1]
            # Zero-pad each ragged clue list to length N so all clues stack
            # into one regular (batch, 2, N, N) integer array.
            self.X = np.array(
                [[np.array([list(k) + [0] * (N - len(k)) for k in j]) for j in i] for i in X]
            )
        else:
            # 4-D clues are already rectangular; store them unchanged.
            self.X = X

        # Use the pre-calculated possibility map when provided (it is expensive).
        if X_mapped is not None:
            self.X_mapped = X_mapped
        else:
            self.X_mapped = possibility_map_generator(X)

        self.y = y

    def __getitem__(self, idx):
        inputs = np.array(self.X[idx])
        mapped_inputs = np.array(self.X_mapped[idx], dtype=np.float64)

        if self.y is not None:    # Train / validation
            labels = np.array(self.y[idx], dtype=np.float64)
            return inputs, mapped_inputs, labels

        return inputs, mapped_inputs    # Test

    def __len__(self):
        return len(self.X)


In [56]:
def get_loader(batch_size, shuffle, num_workers, X, X_mapped=None, y=None):
    """Wrap (X, X_mapped, y) in an NN_Dataset and return a DataLoader over it.

    The dataset length is printed as a quick sanity check.
    """
    dataset = NN_Dataset(X, X_mapped, y)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    print(f'length : {len(dataset)}')
    return loader

## Training

In [57]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # prefer the GPU when present

In [90]:
def _batch_accuracies(inputs, preds, labels):
    """Return the per-sample pixel accuracy after top-k binarisation of preds.

    For each sample, k is the number of colored cells implied by its row clues
    (inputs[i][0] summed). The k highest-scoring pixels are predicted colored;
    ties at the threshold are included. All other pixels are predicted empty.

    inputs: batch of clue arrays, where inputs[i][0] holds the row clues.
    preds:  (batch, num_pixels) numpy array of scores.
    labels: (batch, num_pixels) numpy array of 0/1 targets.
    """
    accuracies = []
    for clues, scores, target in zip(inputs, preds, labels):
        k = int(np.sum(np.asarray(clues[0])))
        if k <= 0:
            # A blank puzzle has nothing to select; predict every pixel empty.
            binary = np.zeros_like(target)
        else:
            k = min(k, scores.size)
            threshold = np.sort(scores)[-k]  # k-th largest score
            binary = (scores >= threshold).astype(target.dtype)
        accuracies.append(float(np.mean(binary == target)))
    return accuracies


def train(epochs, model, optimizer, train_loader, valid_loader, device, save_path='./weights/MLP.pth'):
    """Train model with MSE loss and checkpoint the best-validation-accuracy weights.

    Each loader yields (inputs, mapped_inputs, labels) batches. mapped_inputs is
    fed to the model, labels are flattened to (batch, num_pixels) to match the
    predictions, and inputs (the clues) are used only for accuracy thresholding.

    save_path: file the best state_dict is written to.
    Returns (best_epoch, best_acc, best_val_loss): the 1-based epoch with the
    highest validation accuracy, that accuracy, and that epoch's mean validation
    loss. best_epoch and best_val_loss are None if no epoch beat 0 accuracy.
    """
    start = time.time()
    criterion = MSELoss()  # stateless, so create it once

    best_acc = 0
    best_epoch = None
    best_val_loss = None
    for epoch in range(1, epochs + 1):

        # Train
        model.train()
        train_accuracies = []
        train_losses = []
        for inputs, mapped_inputs, labels in tqdm(train_loader):
            optimizer.zero_grad()
            # as_tensor converts without the copy-construct warning torch.tensor(tensor) emits.
            mapped_inputs = torch.as_tensor(mapped_inputs, dtype=torch.float32, device=device)
            labels = torch.as_tensor(labels, dtype=torch.float32, device=device)
            labels = torch.flatten(labels, start_dim=1)

            preds = model(mapped_inputs)
            batch_loss = criterion(preds, labels)
            batch_loss.backward()
            optimizer.step()

            train_losses.append(batch_loss.item())
            train_accuracies += _batch_accuracies(
                np.asarray(inputs), preds.detach().cpu().numpy(), labels.cpu().numpy()
            )

        # Validation
        model.eval()
        val_accuracies = []
        val_losses = []
        with torch.no_grad():
            for inputs, mapped_inputs, labels in valid_loader:
                mapped_inputs = torch.as_tensor(mapped_inputs, dtype=torch.float32, device=device)
                labels = torch.as_tensor(labels, dtype=torch.float32, device=device)
                labels = torch.flatten(labels, start_dim=1)

                preds = model(mapped_inputs)
                val_losses.append(criterion(preds, labels).item())
                val_accuracies += _batch_accuracies(
                    np.asarray(inputs), preds.cpu().numpy(), labels.cpu().numpy()
                )

        train_loss = np.mean(train_losses)
        val_loss = np.mean(val_losses)
        train_acc = np.mean(train_accuracies)
        val_acc = np.mean(val_accuracies)

        if best_acc < val_acc:
            best_acc = val_acc
            best_epoch = epoch
            best_val_loss = val_loss
            print('saving model...', flush = True)
            torch.save(model.state_dict(), save_path)
            print('done')

        print(f'Epoch:{epoch}  Train Loss:{train_loss:.3f} | Valid Loss:{val_loss:.3f}')
        print(f'Train  Acc:{train_acc:.3f}')
        print(f'Valid  Acc:{val_acc:.3f}')

    end = time.time()
    print(f'\nEpoch Process Time: {(end-start)/60:.2f}Minute')
    print(f'Best Epoch:{best_epoch}, Best Acc:{best_acc:.3f}')
    return best_epoch, best_acc, best_val_loss

In [76]:
# Precompute the possibility maps once, so the datasets built below can reuse them.
X_train_mapped = possibility_map_generator(X_train)
X_test_mapped = possibility_map_generator(X_test)

# Hold out 20% of the training puzzles, with their maps and labels, for validation.
(X_train, X_val,
 X_train_mapped, X_val_mapped,
 Y_train, Y_val) = train_test_split(
    X_train, X_train_mapped, Y_train, test_size=0.2, random_state=42
)

100%|██████████| 48000/48000 [03:09<00:00, 253.04it/s]
100%|██████████| 10000/10000 [00:47<00:00, 211.23it/s]


In [91]:
print(*(arr.shape for arr in (X_train, X_val, X_train_mapped, X_val_mapped, Y_train, Y_val)))

(38400, 2, 8) (9600, 2, 8) (38400, 2, 8, 8) (9600, 2, 8, 8) (38400, 8, 8) (9600, 8, 8)


In [92]:
train_loader = get_loader(batch_size=32, shuffle=True, num_workers=0, X=X_train, X_mapped=X_train_mapped, y=Y_train)
valid_loader = get_loader(batch_size=32, shuffle=False, num_workers=0, X=X_val, X_mapped=X_val_mapped, y=Y_val)
epochs = 100
lr = 0.001

# N is the last axis of 3-D clues (batch, 2, N) and the second-to-last of 4-D
# clues (batch, 2, N, t). Use == rather than `is`: identity tests on int
# literals are unreliable and raise a SyntaxWarning.
num_pixels = X_train.shape[-2] if X_train.ndim == 4 else X_train.shape[-1]

model = MLP_Network(num_pixels=num_pixels).to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)

length : 38400
length : 9600


In [93]:
best_epoch, best_auprc, best_auc = train(
    epochs=epochs,
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    valid_loader=valid_loader,
    device=device,
)

  
  from ipykernel import kernelapp as app
  0%|          | 2/1200 [00:00<01:02, 19.25it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  0%|          | 5/1200 [00:00<01:06, 18.06it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  1%|          | 9/1200 [00:00<01:09, 17.07it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  1%|          | 14/1200 [00:00<01:03, 18.54it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64])

  1%|▏         | 17/1200 [00:00<01:00, 19.59it/s]

 torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  2%|▏         | 21/1200 [00:01<01:02, 18.83it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  2%|▏         | 25/1200 [00:01<01:05, 18.04it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  2%|▏         | 29/1200 [00:01<01:03, 18.38it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  3%|▎         | 33/1200 [00:01<01:10, 16.64it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  3%|▎         | 37/1200 [00:02<01:11, 16.16it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64])

  3%|▎         | 41/1200 [00:02<01:06, 17.42it/s]

 torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  4%|▎         | 44/1200 [00:02<01:02, 18.50it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64])

  4%|▍         | 47/1200 [00:02<01:03, 18.17it/s]

 torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  4%|▍         | 53/1200 [00:02<00:56, 20.47it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  5%|▍         | 56/1200 [00:03<00:54, 21.13it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  5%|▌         | 62/1200 [00:03<00:58, 19.42it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  5%|▌         | 65/1200 [00:03<00:56, 19.98it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  6%|▌         | 71/1200 [00:03<00:55, 20.29it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64])

  6%|▌         | 74/1200 [00:03<00:52, 21.57it/s]

 torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64])

  7%|▋         | 80/1200 [00:04<00:49, 22.57it/s]

 torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64])

  7%|▋         | 83/1200 [00:04<00:48, 22.94it/s]

 torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  7%|▋         | 89/1200 [00:04<00:49, 22.29it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64])

  8%|▊         | 92/1200 [00:04<00:51, 21.69it/s]

 torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  8%|▊         | 95/1200 [00:04<00:54, 20.27it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  8%|▊         | 100/1200 [00:05<00:58, 18.69it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  9%|▉         | 106/1200 [00:05<00:58, 18.84it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  9%|▉         | 108/1200 [00:05<01:01, 17.67it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


  9%|▉         | 112/1200 [00:05<01:01, 17.82it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


 10%|▉         | 116/1200 [00:06<01:00, 18.04it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


 10%|█         | 121/1200 [00:06<00:56, 19.25it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64])

 10%|█         | 124/1200 [00:06<00:51, 20.92it/s]

 torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64])

 11%|█         | 127/1200 [00:06<00:53, 19.96it/s]

 torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


 11%|█         | 133/1200 [00:06<00:53, 19.84it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


 12%|█▏        | 139/1200 [00:07<00:49, 21.47it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64])

 12%|█▏        | 142/1200 [00:07<00:51, 20.51it/s]

 torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


 12%|█▏        | 148/1200 [00:07<00:48, 21.65it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64])

 13%|█▎        | 151/1200 [00:07<00:52, 20.08it/s]

 torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64])

 13%|█▎        | 154/1200 [00:07<00:53, 19.70it/s]

 torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64])

 13%|█▎        | 157/1200 [00:08<00:50, 20.45it/s]

 torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


 14%|█▎        | 163/1200 [00:08<00:49, 20.79it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


 14%|█▍        | 166/1200 [00:08<00:50, 20.60it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


 14%|█▍        | 169/1200 [00:08<00:47, 21.56it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


 15%|█▍        | 175/1200 [00:08<00:50, 20.20it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64])

 15%|█▌        | 181/1200 [00:09<00:43, 23.30it/s]

 torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64])

 16%|█▌        | 187/1200 [00:09<00:40, 24.74it/s]

 torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64])

 16%|█▌        | 193/1200 [00:09<00:40, 24.74it/s]

 torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


 16%|█▋        | 196/1200 [00:09<00:41, 24.31it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


 17%|█▋        | 199/1200 [00:10<00:56, 17.82it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


 17%|█▋        | 205/1200 [00:10<00:54, 18.17it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64])

 17%|█▋        | 207/1200 [00:10<00:55, 17.75it/s]

 torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


 18%|█▊        | 212/1200 [00:10<00:50, 19.70it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


 18%|█▊        | 215/1200 [00:10<00:48, 20.45it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


 18%|█▊        | 218/1200 [00:10<00:50, 19.52it/s]

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([32, 64]) torch.Size([32, 64])


 18%|█▊        | 221/1200 [00:11<02:05,  7.83it/s]

torch.Size([32, 64]) torch.Size([32, 64])


 18%|█▊        | 221/1200 [00:12<00:54, 17.98it/s]


KeyboardInterrupt: 

## Inference

In [None]:
def inference(model, test_loader, device):
    """Run model over test_loader and return every prediction in one flat array.

    Each batch must be a tuple/list whose element 1 is the mapped (possibility
    map) input. This matches the (inputs, mapped_inputs[, labels]) items that
    NN_Dataset yields. The model is fed those mapped inputs, as in training.
    Predictions from all batches are flattened and concatenated in loader order.
    """
    pred_list = []

    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            # as_tensor avoids the warning torch.tensor(tensor) emits.
            mapped_inputs = torch.as_tensor(batch[1], dtype=torch.float32, device=device)
            preds = model(mapped_inputs)
            pred_list.extend(preds.reshape(-1).cpu().numpy())

    return np.array(pred_list)

In [None]:
# Rebuild the MLP and load the best checkpoint that `train` saves to ./weights/MLP.pth.
# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
model = MLP_Network(num_pixels=num_pixels)
model.load_state_dict(torch.load('./weights/MLP.pth', map_location=device))
model = model.to(device)

In [None]:
# Reuse the precomputed test possibility maps instead of regenerating them.
test_loader = get_loader(batch_size=256, shuffle=False, num_workers=6, X=X_test, X_mapped=X_test_mapped, y=None)
CNN_single_prediction = inference(model=model, test_loader=test_loader, device=device)