In [1]:
%load_ext autoreload
%autoreload 2
import argparse
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
# import torch.utils.data.distributed
# import torchvision.transforms as transforms
# import torchvision.datasets as datasets
# import torchvision.models as models

from torch.utils import data
import random
import numpy as np
from itertools import product
from torch.autograd import Variable


from args import args
from train_f import *
from Dataset import Dataset
#from dataset import Dataset

In [2]:
class Baseline(nn.Module):

    def __init__(self, in_ch, out_ch):
        
        super(Baseline, self).__init__()
        '''
        torch.nn.Conv3d: input(N,C,D,H,W)
                            output(N,C,Dout,Hout,Wout) 
        torch.nn.AvgPool3d: input(N,C,D,H,W)
                            output(N,C,Dout,Hout,Wout)     
        '''
        '''
        nn.Conv3d(in_channels, out_channels, kernel_size)
        nn.AvgPool3d()
        '''
        self.draft_model = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, stride=1, padding=1),
            nn.BatchNorm3d(out_ch),
#             nn.AvgPool3d(3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(in_ch, out_ch, 3, stride=1, padding=1),
            nn.BatchNorm3d(out_ch),
#             nn.AvgPool3d(3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, cube):
        cube = self.draft_model(cube)
        return cube
    
class SimpleUnet(nn.Module):

    def __init__(self, in_channels):
        
        super(SimpleUnet, self).__init__()
        self.in_channels = in_channels
        self.Simple_Unet = nn.Sequential(
            self.conv_layer(self.in_channels, 16),
            self.conv_layer(16, 16),
            nn.AvgPool3d(2),
            self.conv_layer(16,32),
            nn.AvgPool3d(2),
            self.conv_layer(32,64),
            self.up_conv_layer(64, 64, 3),
            self.conv_layer(64, 32),
            self.up_conv_layer(32, 32, 3),
            self.conv_layer(32, 16),
            self.up_conv_layer(16, 16, 2),
            self.conv_layer(16, 8),
            self.up_conv_layer(8, 8, 2),
            self.conv_layer(8, 4),
            self.conv_layer(4, 1)
        )
    
    def conv_layer(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias=True):
        layers = nn.Sequential(
        nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
        nn.BatchNorm3d(out_channels),
        nn.LeakyReLU())
        return layers
    
    def up_conv_layer(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, bias=True):
        layers = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            # should be feat_in*2 or feat_in
            nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.LeakyReLU())
        return layers
    
    
    def forward(self, cube):
        cube = self.Simple_Unet(cube)
        return cube

In [12]:
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()


    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)



        # add a dimension, from (1, 32, 32, 32) to (1,1,32,32,32)
        input = input.unsqueeze(dim = 1).to(device).float()
        target = target.unsqueeze(dim = 1).to(device).float()
        # compute output
        output = model(input)
        
        loss = criterion(output, target)

        # measure accuracy and record loss
        losses.update(loss.item(), input.size(0))


        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses))


In [13]:
def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()


    # switch to evaluate mode
    model.train()

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in enumerate(val_loader):
            input = input.unsqueeze(dim = 1).to(device).float()
            target = target.unsqueeze(dim = 1).to(device).float()

            # compute output
            output = model(input)
            
            loss = criterion(output, target)

            # measure accuracy and record loss
            losses.update(loss.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            
        print('Test: Time {batch_time.val:.3f} \t\t\t\t\t\t\t\t\t\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(batch_time=batch_time, loss=losses))

    return 

In [14]:
#index for the cube, each tuple corresponds to a cude
train_data = [(800, 640, 224)]
val_data = [(800, 640, 224)]
test_data = [(800, 640, 224)]

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [15]:
params = {'batch_size': 50,
          'shuffle': False,
          #'shuffle': True,
          'num_workers':20}
max_epochs = 100

training_set, validation_set = Dataset(train_data), Dataset(val_data)
testing_set= Dataset(test_data)
training_generator = data.DataLoader(training_set, **params)
validation_generator = data.DataLoader(validation_set, **params)
testing_generator = data.DataLoader(testing_set, **params)

for i, (dark,full) in enumerate(testing_generator):
    dark=dark
    full=full
        

In [16]:
dim = 1
model = SimpleUnet(dim).to(device)
criterion = nn.MSELoss().to(device) #yueqiu
optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

In [17]:
for epoch in range(args.start_epoch, args.epochs):
#     if args.distributed:
#         train_sampler.set_epoch(epoch)
    adjust_learning_rate(optimizer, epoch)

    # train for one epoch
    #train(training_generator, model, criterion, optimizer, epoch)

    # evaluate on validation set
    validate(validation_generator, model, criterion)




Test: Time 0.379 										Loss 0.5122 (0.5122)	




Test: Time 0.427 										Loss 0.5122 (0.5122)	
Test: Time 0.400 										Loss 0.5122 (0.5122)	
Test: Time 0.396 										Loss 0.5122 (0.5122)	
Test: Time 0.429 										Loss 0.5122 (0.5122)	
Test: Time 0.412 										Loss 0.5122 (0.5122)	
Test: Time 0.407 										Loss 0.5122 (0.5122)	
Test: Time 0.372 										Loss 0.5122 (0.5122)	
Test: Time 0.427 										Loss 0.5122 (0.5122)	
Test: Time 0.402 										Loss 0.5122 (0.5122)	
