In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from glob import glob
from imageio import imread
import torch.nn as nn

In [2]:
class TableDataset(Dataset):
    """The training table dataset.
    """
    def __init__(self, x_path=None, y_path=None):
        if x_path is None or y_path is None:
            raise ValueError("No data source specified.")
        
        x_filenames = glob(x_path + '*.png')
        y_filenames = glob(y_path + '*.png')
        
        self.x_data = [torch.from_numpy(imread(filename).transpose(2, 0, 1)) for filename in x_filenames]
        self.y_data = [torch.from_numpy(imread(filename).reshape(1, *imread(filename).shape)) for filename in y_filenames]
        self.len = len(self.x_data)
        
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]
        
    def __len__(self):
        return self.len

In [3]:
dataset = TableDataset('/Users/calvinku/Dropbox/ThoroughAI/Side Projects/YangZhe/data/cell01/',
                       '/Users/calvinku/Dropbox/ThoroughAI/Side Projects/YangZhe/data/xu_label_cell01/')

train_loader = DataLoader(dataset=dataset,
                         batch_size=1,
                         shuffle=True,
                         num_workers=2)

In [4]:
img = imread('/Users/calvinku/Dropbox/ThoroughAI/Side Projects/YangZhe/data/cell01/1.png')
label = imread('/Users/calvinku/Dropbox/ThoroughAI/Side Projects/YangZhe/data/xu_label_cell01/1.png')

print(img.shape, label.shape)
print(img.transpose(2, 0, 1).shape, label.reshape(1, *label.shape).shape)

(332, 1413, 3) (332, 1413)
(3, 332, 1413) (1, 332, 1413)


In [16]:
class FCN(nn.Module):

    def __init__(self):
        super().__init__()        
        self.conv11 = nn.Conv2d(3, 64, 3, stride=1, padding=0)
        self.conv12 = nn.Conv2d(64, 64, 3, stride=1, padding=0)
        self.pool1 = nn.MaxPool2d(2, stride=2)
        self.relu1 = nn.ReLU(inplace=True)
        
        self.conv21 = nn.Conv2d(64, 128, 3, stride=1, padding=0)
        self.conv22 = nn.Conv2d(128, 128, 3, stride=1, padding=0)
        self.pool2 = nn.MaxPool2d(2, stride=2)
        self.relu2 = nn.ReLU(inplace=True)
        
        self.conv31 = nn.Conv2d(128, 512, 3, stride=1, padding=0)
        self.conv32 = nn.Conv2d(512, 512, 3, stride=1, padding=0)
        self.pool3 = nn.MaxPool2d(2, stride=2)
        self.relu3 = nn.ReLU(inplace=True)
                
        self.deconv11 = nn.ConvTranspose2d(512, 128, kernel_size=3, stride=1, padding=0)
        self.deconv12 = nn.ConvTranspose2d(128, 128, kernel_size=3, stride=1, padding=0)
        self.unpool1 = nn.MaxUnpool2d(2, stride=2)
        self.relu4 = nn.ReLU(inplace=True)
        
        self.deconv21 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=0)
        self.deconv22 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1, padding=0)
        self.unpool2 = nn.MaxUnpool2d(2, stride=2)
        self.relu5 = nn.ReLU(inplace=True)
        
        self.deconv31 = nn.ConvTranspose2d(64, 1, kernel_size=3, stride=1, padding=0)
        self.deconv32 = nn.ConvTranspose2d(1, 1, kernel_size=3, stride=1, padding=0)
        self.unpool3 = nn.MaxUnpool2d(2, stride=2)
    
    def forward(self, x):
        x = self.relu1(self.pool1(self.conv12(self.conv11(x))))
        print(x.shape)
#         x = self.relu2(self.pool2(self.conv22(self.conv21(x))))
#         x = self.relu3(self.pool3(self.conv32(self.conv31(x))))
        
#         x = self.relu4(self.unpool1(self.deconv12(self.deconv11(x))))
#         x = self.relu5(self.unpool2(self.deconv22(self.deconv21(x))))
        x = self.unpool3(self.deconv32(self.deconv31(x)))
        print(x.shape)

        return x

In [6]:
class FCN32s(nn.Module):

    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu    = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn1     = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn2     = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn3     = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn4     = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn5     = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

    def forward(self, x):
        output = self.pretrained_net(x)
        x5 = output['x5']  # size=(N, 512, x.H/32, x.W/32)
        print(x5.shape)
        x = self.bn1(self.relu(self.deconv1(x5)))     # size=(N, 512, x.H/16, x.W/16)
        print(x.shape)
        x = self.bn2(self.relu(self.deconv2(x)))  # size=(N, 256, x.H/8, x.W/8)
        print(x.shape)
        x = self.bn3(self.relu(self.deconv3(x)))  # size=(N, 128, x.H/4, x.W/4)
        print(x.shape)
        x = self.bn4(self.relu(self.deconv4(x)))  # size=(N, 64, x.H/2, x.W/2)
        print(x.shape)
        x = self.bn5(self.relu(self.deconv5(x)))  # size=(N, 32, x.H, x.W)
        print(x.shape)
        x = self.classifier(x)                   # size=(N, n_class, x.H/1, x.W/1)
        print(x.shape)

        return x  # size=(N, n_class, x.H/1, x.W/1)

In [7]:
from torchvision.models.vgg import VGG

class VGGNet(VGG):
    def __init__(self, pretrained=True, model='vgg16', requires_grad=True, remove_fc=True, show_params=False):
        super().__init__(make_layers(cfg[model]))
        self.ranges = ranges[model]

        if pretrained:
            exec("self.load_state_dict(models.%s(pretrained=True).state_dict())" % model)

        if not requires_grad:
            for param in super().parameters():
                param.requires_grad = False

        if remove_fc:  # delete redundant fully-connected layer params, can save memory
            del self.classifier

        if show_params:
            for name, param in self.named_parameters():
                print(name, param.size())

    def forward(self, x):
        output = {}

        # get the output of each maxpooling layer (5 maxpool in VGG net)
        for idx in range(len(self.ranges)):
            for layer in range(self.ranges[idx][0], self.ranges[idx][1]):
                x = self.features[layer](x)
            output["x%d"%(idx+1)] = x

        return output

In [11]:
ranges = {
    'vgg11': ((0, 3), (3, 6),  (6, 11),  (11, 16), (16, 21)),
    'vgg13': ((0, 5), (5, 10), (10, 15), (15, 20), (20, 25)),
    'vgg16': ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)),
    'vgg19': ((0, 5), (5, 10), (10, 19), (19, 28), (28, 37))
}

# cropped version from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
cfg = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

batch_size, n_class, h, w = 10, 20, 160, 160

In [17]:
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable

if __name__ == "__main__":
    num_epochs = 3
    
#     # test output size
#     vgg_model = VGGNet(requires_grad=True)
#     input = torch.autograd.Variable(torch.randn(batch_size, 3, 224, 224))
#     output = vgg_model(input)
#     assert output['x5'].size() == torch.Size([batch_size, 512, 7, 7])

#     fcn_model = FCN32s(pretrained_net=vgg_model, n_class=n_class)
#     input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w))
#     output = fcn_model(input)
#     assert output.size() == torch.Size([batch_size, n_class, h, w])

#     print("Pass size check")

#     # test a random batch, loss should decrease

    pretrained_net = VGGNet()
#     fcn_model = FCN32s(pretrained_net=pretrained_net, n_class=1)
    fcn_model = FCN()
    criterion = nn.BCELoss()
    optimizer = optim.SGD(fcn_model.parameters(), lr=1e-3, momentum=0.9)

    
    for epoch in range(num_epochs):
        for i, data in enumerate(train_loader):
            x, y = data
            x, y = Variable(x.type(torch.FloatTensor), requires_grad=False), Variable(y.type(torch.FloatTensor), requires_grad=False)
            
            print("x shape: {}, y shape: {}".format(x.shape, y.shape))
            optimizer.zero_grad()
            output = fcn_model(x)
            output = nn.functional.sigmoid(output)
            loss = criterion(output, y)
            loss.backward()
            print("iter{}, loss {}".format(iter, loss.data[0]))
            optimizer.step()

x shape: torch.Size([1, 3, 522, 1385]), y shape: torch.Size([1, 1, 522, 1385])
torch.Size([1, 64, 259, 690])


TypeError: forward() missing 1 required positional argument: 'indices'