In [52]:
from PIL import Image
from PIL import ImageOps
import numpy as np
from pathlib import Path
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim

root = Path(".")

In [66]:
class ConvNet(nn.Module):
    def __init__(self, num_classes=256):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv3d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2))
        self.layer3 = nn.Sequential(
            nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2))
        self.layer4 = nn.Sequential(
            nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2))
        self.layer5 = nn.Sequential(
            nn.Conv3d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2))
        self.layer6 = nn.Sequential(
            nn.MaxPool3d(kernel_size=3, stride=3))
        self.fc = nn.Linear(256, 256)
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
#         out = self.layer6(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out

In [42]:
class dataset:
    def __init__(self, directory):
        self.directory = directory
        self.name = str(directory)[-3:]
        self.num_images = len([file for file in os.listdir(directory) if ".tif" in file])
        self.resolution = Image.open(self.directory / os.listdir(self.directory)[0]).convert('RGB').size
        self.filepaths = [directory / file for file in os.listdir(directory) if ".tif" in file]
        self.ground_truth = np.loadtxt(directory / ".." / ".." / "output" / "002" / "indices.txt", dtype=int)
    def sample(self, coord):
        x, y = coord
        x = x*16 + 8# range(4, 256, 8) 
        y = y*16 + 8
        assert 0 <= x <= 248, "x out of bounds: %d"%x
        assert 0 <= y <= 248, "y out of bounds: %d"%y
        process_img = lambda img_t: ImageOps.autocontrast(Image.open(img_t).convert('L').resize((256,256))).crop(box=(y-8, x-8, y+8, x+8))
        process_label = lambda img_t: Image.fromarray(np.uint8(img_t*255)).crop(box=(y-8, x-8, y+8, x+8))
        return (process_label(self.ground_truth), [process_img(img_t) for img_t in self.filepaths])
    
def image_buffer_to_tensor(image_buffer, section_index):
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(.5, .5)])
    tag_map_image, frame_list = image_buffer[section_index]
    totensor = transforms.ToTensor()
    tag_tensor = totensor(tag_map_image)
    video_tensor = torch.cat([transform(frame) for frame in frame_list], dim=0)
    video_tensor = video_tensor[None, :, :, :]
    return tag_tensor, video_tensor

def batch_generator(image_buffer, batch_size):
    all_data = image_buffer
    assert len(all_data)%batch_size == 0
    n_datapoints_used = 0
    def get_batch():
        nonlocal all_data
        nonlocal n_datapoints_used
        nonlocal batch_size
        if len(all_data)-n_datapoints_used==0:
            return []
        else:
            video_tensor_batch = torch.stack([image_buffer_to_tensor(all_data, n)[1] for n in range(n_datapoints_used, n_datapoints_used+batch_size)])
            tag_tensor_batch = torch.stack([image_buffer_to_tensor(all_data, n)[0] for n in range(n_datapoints_used, n_datapoints_used+batch_size)])
            n_datapoints_used += batch_size
            return tag_tensor_batch, video_tensor_batch
    return get_batch

In [3]:
# get_batch returns tuple of (tag_tensor_batch, video_tensor_batch)
# tag_tensor_batch = (N, C, W, H)
# video_tensor_batch = (N, C, D, W, H)

data = dataset(root / "data" / "002")
image_buffer = []
for n in range(16):
    for m in range(16):
        print("%d %d "%(n,m), end='\r')
        image_buffer += [data.sample((n, m))]

15 15 

In [67]:
batch_size=4
training_set = image_buffer # for now
net = ConvNet()
net.cuda()
criterion = nn.MSELoss(reduction='none')
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [68]:
for epoch in range(200):
    get_batch = batch_generator(training_set, batch_size)
    trainloader = get_batch()
    i=0
    loss_ax = []
    while len(trainloader) > 0: 
        optimizer.zero_grad()

        # get the inputs; data is a list of [inputs, labels]
        labels, inputs = trainloader
        labels, inputs = labels.cuda(), inputs.cuda()

        # forward + backward + optimize
#         a = list(net.parameters())[0].clone() 
        
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
#         b = list(net.parameters())[0].clone()
#         print("Model updating? ", not(torch.equal(a,b)))

        # print statistics
        loss_ax += [loss.item()]
        print(epoch, i, loss.item())
        trainloader = get_batch()
        i += 1
    print('Finished Training')

RuntimeError: Given input size: (256x37x1x1). Calculated output size: (256x18x0x0). Output size is too small