In [3]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from glob import glob
from imageio import imread
import torch.nn as nn

In [4]:
class TableDataset(Dataset):
    """The training table dataset.
    """
    def __init__(self, x_path=None, y_path=None):
        if x_path is None or y_path is None:
            raise ValueError("No data source specified.")
        
        x_filenames = glob(x_path + '*.png')
        y_filenames = glob(y_path + '*.png')
        
        self.x_data = [torch.from_numpy(imread(filename).transpose(2, 0, 1)) for filename in x_filenames]
        self.y_data = [torch.from_numpy(imread(filename).reshape(1, *imread(filename).shape)) for filename in y_filenames]
        self.len = len(self.x_data)
        
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]
        
    def __len__(self):
        return self.len

In [5]:
dataset = TableDataset('/Users/calvinku/Dropbox/ThoroughAI/Side Projects/YangZhe/data/cell01/',
                       '/Users/calvinku/Dropbox/ThoroughAI/Side Projects/YangZhe/data/xu_label_cell01/')

train_loader = DataLoader(dataset=dataset,
                         batch_size=1,
                         shuffle=True,
                         num_workers=2)

In [6]:
img = imread('/Users/calvinku/Dropbox/ThoroughAI/Side Projects/YangZhe/data/cell01/1.png')
label = imread('/Users/calvinku/Dropbox/ThoroughAI/Side Projects/YangZhe/data/xu_label_cell01/1.png')

print(img.shape, label.shape)
print(img.transpose(2, 0, 1).shape, label.reshape(1, *label.shape).shape)

(332, 1413, 3) (332, 1413)
(3, 332, 1413) (1, 332, 1413)


In [29]:
class FCN(nn.Module):

    def __init__(self):
        super().__init__()        
        self.conv11 = nn.Conv2d(3, 64, 3, stride=1, padding=0)
        self.conv12 = nn.Conv2d(64, 64, 3, stride=1, padding=0)
        self.pool1 = nn.MaxPool2d(2, stride=1)
        self.relu1 = nn.ReLU(inplace=True)
        
        self.conv21 = nn.Conv2d(64, 128, 3, stride=1, padding=0)
        self.conv22 = nn.Conv2d(128, 128, 3, stride=1, padding=0)
        self.pool2 = nn.MaxPool2d(2, stride=1)
        self.relu2 = nn.ReLU(inplace=True)
        
        self.conv31 = nn.Conv2d(128, 512, 3, stride=1, padding=0)
        self.conv32 = nn.Conv2d(512, 512, 3, stride=1, padding=0)
        self.pool3 = nn.MaxPool2d(2, stride=1)
        self.relu3 = nn.ReLU(inplace=True)
                
        self.deconv11 = nn.ConvTranspose2d(512, 128, kernel_size=3, stride=1, padding=0)
        self.deconv12 = nn.ConvTranspose2d(128, 128, kernel_size=3, stride=1, padding=0)
        self.deconv13 = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=1, padding=0)
        self.relu4 = nn.ReLU(inplace=True)
        
        self.deconv21 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=0)
        self.deconv22 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1, padding=0)
        self.deconv23 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=1, padding=0)
        self.relu5 = nn.ReLU(inplace=True)
        
        self.deconv31 = nn.ConvTranspose2d(64, 1, kernel_size=3, stride=1, padding=0)
        self.deconv32 = nn.ConvTranspose2d(1, 1, kernel_size=3, stride=1, padding=0)
        self.deconv33 = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=1, padding=0)
    
    def forward(self, x):
        x = self.relu1(self.pool1(self.conv12(self.conv11(x))))
        print(x.shape)
#         x = self.relu2(self.pool2(self.conv22(self.conv21(x))))
#         x = self.relu3(self.pool3(self.conv32(self.conv31(x))))
        
#         x = self.relu4(self.unpool1(self.deconv12(self.deconv11(x))))
#         x = self.relu5(self.unpool2(self.deconv22(self.deconv21(x))))
        x = self.deconv33(self.deconv32(self.deconv31(x)))
        print(x.shape)

        return x

In [19]:
class FCN32s(nn.Module):

    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu    = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn1     = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn2     = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn3     = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn4     = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn5     = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

    def forward(self, x):
        output = self.pretrained_net(x)
        x5 = output['x5']  # size=(N, 512, x.H/32, x.W/32)
        print(x5.shape)
        x = self.bn1(self.relu(self.deconv1(x5)))     # size=(N, 512, x.H/16, x.W/16)
        print(x.shape)
        x = self.bn2(self.relu(self.deconv2(x)))  # size=(N, 256, x.H/8, x.W/8)
        print(x.shape)
        x = self.bn3(self.relu(self.deconv3(x)))  # size=(N, 128, x.H/4, x.W/4)
        print(x.shape)
        x = self.bn4(self.relu(self.deconv4(x)))  # size=(N, 64, x.H/2, x.W/2)
        print(x.shape)
        x = self.bn5(self.relu(self.deconv5(x)))  # size=(N, 32, x.H, x.W)
        print(x.shape)
        x = self.classifier(x)                   # size=(N, n_class, x.H/1, x.W/1)
        print(x.shape)

        return x  # size=(N, n_class, x.H/1, x.W/1)

In [30]:
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable

if __name__ == "__main__":
    num_epochs = 3
    
#     # test output size
#     vgg_model = VGGNet(requires_grad=True)
#     input = torch.autograd.Variable(torch.randn(batch_size, 3, 224, 224))
#     output = vgg_model(input)
#     assert output['x5'].size() == torch.Size([batch_size, 512, 7, 7])

#     fcn_model = FCN32s(pretrained_net=vgg_model, n_class=n_class)
#     input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w))
#     output = fcn_model(input)
#     assert output.size() == torch.Size([batch_size, n_class, h, w])

#     print("Pass size check")

#     # test a random batch, loss should decrease

#     pretrained_net = VGGNet()
#     fcn_model = FCN32s(pretrained_net=pretrained_net, n_class=1)
    fcn_model = FCN()
    criterion = nn.BCELoss()
    optimizer = optim.SGD(fcn_model.parameters(), lr=1e-3, momentum=0.9)

    
    for epoch in range(num_epochs):
        for i, data in enumerate(train_loader):
            x, y = data
            x, y = Variable(x.type(torch.FloatTensor), requires_grad=False), Variable(y.type(torch.FloatTensor), requires_grad=False)
            
            print("x shape: {}, y shape: {}".format(x.shape, y.shape))
            optimizer.zero_grad()
            output = fcn_model(x)
            output = nn.functional.sigmoid(output)
            loss = criterion(output, y)
            loss.backward()
            print("iter{}, loss {}".format(iter, loss.data[0]))
            optimizer.step()

x shape: torch.Size([1, 3, 556, 1135]), y shape: torch.Size([1, 1, 556, 1135])
torch.Size([1, 64, 551, 1130])
torch.Size([1, 1, 556, 1135])




iter<built-in function iter>, loss -75.0948715209961
x shape: torch.Size([1, 3, 1277, 1604]), y shape: torch.Size([1, 1, 1277, 1604])
torch.Size([1, 64, 1272, 1599])
torch.Size([1, 1, 1277, 1604])
iter<built-in function iter>, loss -5660.197265625
x shape: torch.Size([1, 3, 246, 1403]), y shape: torch.Size([1, 1, 246, 1403])
torch.Size([1, 64, 241, 1398])
torch.Size([1, 1, 246, 1403])
iter<built-in function iter>, loss -5374.37548828125
x shape: torch.Size([1, 3, 891, 1536]), y shape: torch.Size([1, 1, 891, 1536])
torch.Size([1, 64, 886, 1531])
torch.Size([1, 1, 891, 1536])
iter<built-in function iter>, loss -5545.23095703125
x shape: torch.Size([1, 3, 406, 1101]), y shape: torch.Size([1, 1, 406, 1101])
torch.Size([1, 64, 401, 1096])
torch.Size([1, 1, 406, 1101])
iter<built-in function iter>, loss -4652.00439453125
x shape: torch.Size([1, 3, 495, 1570]), y shape: torch.Size([1, 1, 495, 1570])
torch.Size([1, 64, 490, 1565])
torch.Size([1, 1, 495, 1570])
iter<built-in function iter>, los

Process Process-16:
Process Process-15:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/calvinku/anaconda/envs/pytorch/lib/python3.6/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/Users/calvinku/anaconda/envs/pytorch/lib/python3.6/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/Users/calvinku/anaconda/envs/pytorch/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/calvinku/anaconda/envs/pytorch/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/calvinku/anaconda/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
  File "/Users/calvinku/anaconda/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r

KeyboardInterrupt: 