In [1]:
import torch
from torch import nn
import pickle as pkl
import numpy as np

In [2]:
# Load the pickled dataset. The context manager closes the file even when
# unpickling raises, which a bare open()/close() pair does not guarantee.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open("data.pkl", "rb") as filehandler:
    dataset = pkl.load(filehandler)

In [3]:
# Inputs: give every volume a leading sample axis and a channel axis, then
# join all samples along axis 0 and store them as float32.
X = np.concatenate(
    [volume[np.newaxis, np.newaxis, :, :, :] for (volume, _) in dataset],
    axis=0,
).astype(np.float32)
# Replace every exact 0.0 entry with -1.0 (in place).
X[X == 0.0] = -1.0
print(X.shape)

# Targets: leading sample axis plus a trailing singleton axis, so the layout
# matches the channels-last output of the model.
Y = np.concatenate(
    [target[np.newaxis, :, :, :, np.newaxis] for (_, target) in dataset],
    axis=0,
).astype(np.float32)
print(Y.shape)

(1000, 1, 27, 27, 26)
(1000, 27, 27, 26, 1)


In [58]:
class Conv3DFeatureDetector(nn.Module):
    """Per-voxel detector: one 3-D convolution followed by a linear read-out.

    The convolution produces ``outplanes`` feature channels at every voxel; a
    linear layer then reduces those channels to a single value, which is
    passed through a sigmoid.

    Args:
        inplanes: Number of input channels given to the convolution.
        outplanes: Number of feature channels the convolution produces.
        kernel_size: Integer edge length of the cubic kernel. The zero padding
            is ``(kernel_size - 1) // 2`` per side, which keeps the spatial
            size unchanged for odd kernel sizes.
    """

    def __init__(self, inplanes, outplanes, kernel_size):
        super().__init__()

        half_width = (kernel_size - 1) // 2
        self.conv = nn.Conv3d(
            in_channels=inplanes,
            out_channels=outplanes,
            kernel_size=kernel_size,
            stride=1,
            padding=half_width,
            padding_mode="zeros",
        )
        self.linear1 = nn.Linear(in_features=outplanes, out_features=1)

    def forward(self, x):
        """Map a ``(N, inplanes, D, H, W)`` batch to ``(N, D, H, W, 1)`` sigmoid scores."""
        features = self.conv(x)
        # Move channels last so the linear layer acts on each voxel's features.
        channels_last = features.permute(0, 2, 3, 4, 1)
        return torch.sigmoid(self.linear1(channels_last))

In [59]:
# Device that the model and the training tensors are moved to.
dev = torch.device(type="cpu")

In [61]:
# Detector with 1 input channel, 50 feature channels and a 5x5x5 kernel,
# placed on the chosen device (Module.to returns the module itself).
model = Conv3DFeatureDetector(1, 50, 5).to(dev)

# Binary cross-entropy, applied to the model's sigmoid outputs.
error = nn.BCELoss()

# Adam with its default learning rate and custom betas.
optimizer = torch.optim.Adam(params=model.parameters(), betas=(0.9, 0.95))

In [62]:
# Wrap the slice [1:100] of the inputs and targets as tensors on `dev`.
# NOTE(review): this skips sample 0 and yields at most 99 samples, not 100 —
# confirm that [1:100] rather than [:100] is intended.
X_torch, Y_torch = (torch.from_numpy(arr[1:100]).to(dev) for arr in (X, Y))

In [92]:
# Bookkeeping containers; this cell (re)creates them but does not fill them.
loss_list = []
iteration_list = []
accuracy_list = []

batch_size = 100

for epoch in range(1000):
    # Fixed sample order; swap in torch.randperm(X_torch.size()[0]) to
    # reshuffle the samples every epoch.
    permutation = torch.arange(X_torch.size()[0])

    print("epoch ", epoch)
    # Sum of the per-batch mean losses over this epoch.
    total_loss = 0.0
    for i in range(0, X_torch.size()[0], batch_size):
        print(i)
        optimizer.zero_grad()

        indices = permutation[i:i+batch_size]
        batch_x, batch_y = X_torch[indices], Y_torch[indices]

        # Call the module itself rather than .forward() so that any
        # registered hooks run.
        outputs = model(batch_x)
        # Inputs and targets were both moved to `dev` when X_torch / Y_torch
        # were built, so no device transfer is needed here. (A bare
        # `batch_y.cuda()` returns a copy that would be discarded, and it
        # raises on machines without CUDA.)
        loss = error(outputs, batch_y)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print("total loss: ", total_loss)

epoch  0
0
total loss:  0.09145272523164749
epoch  1
0
total loss:  0.09015856683254242
epoch  2
0
total loss:  0.08888415992259979
epoch  3
0
total loss:  0.08762956410646439
epoch  4
0
total loss:  0.08639457076787949
epoch  5
0
total loss:  0.08517926186323166
epoch  6
0
total loss:  0.08398348838090897
epoch  7
0
total loss:  0.08280691504478455
epoch  8
0
total loss:  0.08164922147989273
epoch  9
0
total loss:  0.08051007986068726
epoch  10
0
total loss:  0.07938934862613678
epoch  11
0
total loss:  0.0782870352268219
epoch  12
0
total loss:  0.0772029459476471
epoch  13
0
total loss:  0.07613690197467804
epoch  14
0
total loss:  0.07508842647075653
epoch  15
0
total loss:  0.07405740767717361
epoch  16
0
total loss:  0.07304371893405914
epoch  17
0
total loss:  0.07204721868038177
epoch  18
0
total loss:  0.07106771320104599
epoch  19
0
total loss:  0.07010496407747269
epoch  20
0
total loss:  0.06915866583585739
epoch  21
0
total loss:  0.06822860985994339
epoch  22
0
total loss

KeyboardInterrupt: 

In [93]:
# Inference only: disable autograd so no computation graph is kept alive for
# the full-batch prediction, and call the module (not .forward) so hooks run.
with torch.no_grad():
    outputs = model(X_torch)

In [94]:
# Per-axis index tensors of every target entry that equals 1.0.
torch.where(Y_torch.eq(1.0))

(tensor([ 0,  0,  0,  ..., 98, 98, 98]),
 tensor([ 1,  1,  1,  ..., 33, 33, 33]),
 tensor([23, 23, 23,  ..., 26, 26, 26]),
 tensor([12, 13, 14,  ..., 20, 21, 22]),
 tensor([0, 0, 0,  ..., 0, 0, 0]))

In [95]:
# Target value at the first positive voxel listed by the lookup above.
Y_torch[0, 1, 23, 12, 0]

tensor(1.)

In [96]:
# Model prediction at that same voxel, for comparison with its target.
outputs[0, 1, 23, 12, 0]

tensor(0.0294, grad_fn=<SelectBackward0>)

In [69]:
# Shape sanity check on a batch whose spatial size differs from the training
# data: spatial dimensions are preserved and the channel axis ends up last.
in_tensor = torch.ones((20, 1, 35, 35, 32))
out_tensor = model(in_tensor)
out_tensor.size()

torch.Size([20, 35, 35, 32, 1])