In [1]:
%load_ext autoreload
%autoreload 2
import argparse
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
# import torch.utils.data.distributed
# import torchvision.transforms as transforms
# import torchvision.datasets as datasets
# import torchvision.models as models

from torch.utils import data
import random
import numpy as np
from itertools import product
from torch.autograd import Variable


from args import args
from train_f import *
from Dataset import Dataset
#from dataset import Dataset

In [2]:
class Baseline(nn.Module):
    """Two stacked 3-D convolution blocks (Conv3d -> BatchNorm3d -> ReLU).

    Both convolutions use a kernel size of 3 with stride 1 and padding 1, so
    the spatial extent (D, H, W) of the input is preserved and only the
    channel count changes from ``in_ch`` to ``out_ch``.

    Args:
        in_ch: number of channels of the input volume.
        out_ch: number of channels produced by each convolution.
    """

    def __init__(self, in_ch, out_ch):
        super(Baseline, self).__init__()
        # nn.Conv3d and nn.BatchNorm3d both operate on (N, C, D, H, W) tensors.
        self.draft_model = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, stride=1, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
            # The second convolution consumes the out_ch feature maps produced
            # by the first block, so its input width must be out_ch (using
            # in_ch here only works when in_ch == out_ch).
            nn.Conv3d(out_ch, out_ch, 3, stride=1, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, cube):
        """Apply both convolution blocks to ``cube`` and return the result."""
        return self.draft_model(cube)
    
class SimpleUnet(nn.Module):
    """Sequential 3-D encoder/decoder assembled from conv, pooling and upsampling stages.

    The whole network is one ``nn.Sequential`` (there are no skip
    connections) whose last stage emits a single channel.

    Args:
        in_channels: number of channels of the input volume.
    """

    def __init__(self, in_channels):
        super(SimpleUnet, self).__init__()
        self.in_channels = in_channels

        stages = [
            # Down path: unpadded convolutions interleaved with 2x average pooling.
            self.conv_layer(self.in_channels, 16),
            self.conv_layer(16, 16),
            nn.AvgPool3d(2),
            self.conv_layer(16, 32),
            nn.AvgPool3d(2),
            self.conv_layer(32, 64),
            # Up path: 2x upsampling blocks interleaved with unpadded convolutions.
            self.up_conv_layer(64, 64, 3),
            self.conv_layer(64, 32),
            self.up_conv_layer(32, 32, 3),
            self.conv_layer(32, 16),
            self.up_conv_layer(16, 16, 2),
            self.conv_layer(16, 8),
            self.up_conv_layer(8, 8, 2),
            self.conv_layer(8, 4),
            # Final projection to one output channel.
            self.conv_layer(4, 1),
        ]
        self.Simple_Unet = nn.Sequential(*stages)

    def conv_layer(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias=True):
        """Build a Conv3d -> BatchNorm3d -> LeakyReLU block.

        All arguments are forwarded to ``nn.Conv3d``; with the default
        ``padding=0`` the convolution is unpadded.
        """
        conv = nn.Conv3d(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding, bias=bias)
        norm = nn.BatchNorm3d(out_channels)
        activation = nn.LeakyReLU()
        return nn.Sequential(conv, norm, activation)

    def up_conv_layer(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, bias=True):
        """Build an Upsample(x2, nearest) -> Conv3d -> BatchNorm3d -> LeakyReLU block.

        Only ``in_channels`` and ``out_channels`` are used. ``kernel_size``,
        ``stride``, ``padding``, ``output_padding`` and ``bias`` are accepted
        but ignored: the convolution is always built with kernel size 3,
        stride 1, padding 1 and no bias.
        """
        upsample = nn.Upsample(scale_factor=2, mode='nearest')
        # Open question from the original author: should the input width be
        # feat_in*2 or feat_in?
        conv = nn.Conv3d(in_channels, out_channels, kernel_size=3,
                         stride=1, padding=1, bias=False)
        norm = nn.BatchNorm3d(out_channels)
        activation = nn.LeakyReLU()
        return nn.Sequential(upsample, conv, norm, activation)

    def forward(self, cube):
        """Run ``cube`` through the sequential stack and return the output."""
        return self.Simple_Unet(cube)

In [3]:
# Cube indices: each tuple identifies one cube.
# NOTE(review): the train, validation and test lists currently hold the same
# tuple - confirm this is intentional.
train_data = [(800, 640, 224)]
val_data = [(800, 640, 224)]
test_data = [(800, 640, 224)]

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


In [4]:
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch over ``train_loader``.

    Every batch gets a channel axis inserted at dim 1, is cast to float,
    passed through ``model``, scored with ``criterion`` and used for a single
    optimizer step. A progress line is printed every ``args.print_freq``
    batches.

    Args:
        train_loader: iterable yielding ``(input, target)`` batches.
        model: network to train; switched to training mode here.
        criterion: loss, called as ``criterion(output, target)``.
        optimizer: optimizer that updates the model parameters.
        epoch: epoch index, used only in the progress line.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    # These two meters are never updated in this function; they are kept so
    # the progress line keeps its existing format.
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (inputs, targets) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            inputs = inputs.cuda(args.gpu, non_blocking=True)
            # The target has to live on the same device as the model output,
            # otherwise the loss computation fails with a device mismatch.
            targets = targets.cuda(args.gpu, non_blocking=True)

        # Insert the channel axis expected by the 3-D convolutions and
        # compute the output and the loss in float precision.
        inputs = inputs.unsqueeze(dim=1)
        targets = targets.unsqueeze(dim=1)
        output = model(inputs.float())
        loss = criterion(output, targets.float())

        # record loss, weighted by the batch size
        losses.update(loss.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))


In [5]:
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Batches are prepared the same way as in ``train``: moved to ``args.gpu``
    when it is set, given a channel axis at dim 1 and cast to float.

    Args:
        val_loader: iterable yielding ``(input, target)`` batches.
        model: network to evaluate; switched to evaluation mode here.
        criterion: loss, called as ``criterion(output, target)``.

    Returns:
        The running average held by the Prec@1 meter (``top1.avg``).
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (inputs, targets) in enumerate(val_loader):
            # Move both tensors together, and only when a GPU was requested;
            # an unconditional .cuda() on the target breaks CPU-only runs.
            if args.gpu is not None:
                inputs = inputs.cuda(args.gpu, non_blocking=True)
                targets = targets.cuda(args.gpu, non_blocking=True)

            # Same channel axis and float cast that train() applies.
            inputs = inputs.unsqueeze(dim=1).float()
            targets = targets.unsqueeze(dim=1).float()

            # compute output
            output = model(inputs)
            loss = criterion(output, targets)

            # measure accuracy and record loss
            # NOTE(review): accuracy() is defined outside this notebook; a
            # top-k precision may not be meaningful (or may fail) for a
            # single-channel regression output - confirm before re-enabling
            # validation in the training loop.
            prec1, prec5 = accuracy(output, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1[0], inputs.size(0))
            top5.update(prec5[0], inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       i, len(val_loader), batch_time=batch_time, loss=losses,
                       top1=top1, top5=top5))

        print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg

In [6]:
# DataLoader settings shared by the training, validation and testing loaders.
params = {'batch_size': 3,
          'shuffle': False,
          'num_workers': 20}
# NOTE(review): the epoch loop in this notebook reads args.start_epoch and
# args.epochs rather than max_epochs - confirm whether this is still needed.
max_epochs = 100

training_set = Dataset(train_data)
validation_set = Dataset(val_data)
testing_set = Dataset(test_data)

training_generator = data.DataLoader(training_set, **params)
validation_generator = data.DataLoader(validation_set, **params)
testing_generator = data.DataLoader(testing_set, **params)

# Walk the testing loader once; afterwards `dark` and `full` hold its last
# batch. (The former loop body only re-assigned each name to itself.)
for dark, full in testing_generator:
    pass
        

In [7]:
best_prec1 = 0
dim = 1  # number of input channels passed to SimpleUnet
model = SimpleUnet(dim)
criterion = nn.MSELoss()
if args.gpu is not None:
    # train() moves every batch to args.gpu, so the model (and the loss
    # module) must be placed on that device too. Without a requested GPU
    # everything stays on the CPU and no CUDA call is made.
    model = model.cuda(args.gpu)
    criterion = criterion.cuda(args.gpu)
# Build the optimizer after the model has been placed on its final device.
optimizer = torch.optim.SGD(model.parameters(), args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

In [8]:
# Main loop: one pass over the training loader per epoch, from
# args.start_epoch up to (but not including) args.epochs.
for epoch in range(args.start_epoch, args.epochs):
    # adjust_learning_rate is defined outside this notebook; it is called
    # once per epoch with the optimizer and the epoch index.
    adjust_learning_rate(optimizer, epoch)

    train(train_loader=training_generator,
          model=model,
          criterion=criterion,
          optimizer=optimizer,
          epoch=epoch)

    # Disabled steps: distributed sampler epoch update, validation on
    # validation_generator, and best_prec1 tracking.




Epoch: [0][0/1]	Time 0.400 (0.400)	Data 0.076 (0.076)	Loss 0.1626 (0.1626)	Prec@1 0.000 (0.000)	Prec@5 0.000 (0.000)
Epoch: [1][0/1]	Time 0.388 (0.388)	Data 0.082 (0.082)	Loss 0.0669 (0.0669)	Prec@1 0.000 (0.000)	Prec@5 0.000 (0.000)
Epoch: [2][0/1]	Time 0.391 (0.391)	Data 0.087 (0.087)	Loss 0.0090 (0.0090)	Prec@1 0.000 (0.000)	Prec@5 0.000 (0.000)
Epoch: [3][0/1]	Time 0.408 (0.408)	Data 0.089 (0.089)	Loss 0.0040 (0.0040)	Prec@1 0.000 (0.000)	Prec@5 0.000 (0.000)
Epoch: [4][0/1]	Time 0.407 (0.407)	Data 0.095 (0.095)	Loss 0.0040 (0.0040)	Prec@1 0.000 (0.000)	Prec@5 0.000 (0.000)
Epoch: [5][0/1]	Time 0.396 (0.396)	Data 0.089 (0.089)	Loss 0.0040 (0.0040)	Prec@1 0.000 (0.000)	Prec@5 0.000 (0.000)
Epoch: [6][0/1]	Time 0.405 (0.405)	Data 0.096 (0.096)	Loss 0.0040 (0.0040)	Prec@1 0.000 (0.000)	Prec@5 0.000 (0.000)
Epoch: [7][0/1]	Time 0.403 (0.403)	Data 0.091 (0.091)	Loss 0.0040 (0.0040)	Prec@1 0.000 (0.000)	Prec@5 0.000 (0.000)
Epoch: [8][0/1]	Time 0.392 (0.392)	Data 0.089 (0.089)	Loss 0.004