In [2]:
%load_ext autoreload
%autoreload 2
import argparse
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
# import torch.utils.data.distributed
# import torchvision.transforms as transforms
# import torchvision.datasets as datasets
# import torchvision.models as models

from torch.utils import data
import random
import numpy as np
from itertools import product

import torch.nn.functional as F

from args import args
from train_f import *
from Dataset import Dataset



In [3]:
torch.__version__

'0.4.1'

In [4]:
def convolution_loss_prep(pred,Y):
    fixed_ones_conv = nn.Conv3d(1,1,kernel_size = 3, padding = 1)
    nn.init.constant_(fixed_ones_conv.weight,1)
    fixed_ones_conv.weight.requires_grad = False
    conv_pred = fixed_ones_conv(pred)
    conv_Y = fixed_ones_conv(Y)
    return conv_pred, conv_Y

# Convolution loss for MSE, use together with weight
def convolution_loss(pred,Y, weight_ratio):
    weight_mse_loss = weighted_nn_square_loss(weight_ratio)
    conv_pred, conv_Y = convolution_loss_prep(pred,Y)
    return weight_mse_loss(con_pred,con_Y)

In [5]:
def weighted_nn_loss(weight_ratio):
    def weighted(X,Y):
        base_loss = F.mse_loss(X,Y,reduction = 'sum')
        index = Y > 0
        plus_loss = F.mse_loss(X[index],Y[index], reduction = 'sum') if index.any() > 0 else 0
        total_loss = base_loss + (weight_ratio -1) * plus_loss
        return total_loss/X.shape[0]
    return weighted

In [12]:
a = torch.Tensor([3,2,1, 0, 0])
b = torch.Tensor([1,2,100, 0, 2])

In [13]:
criterion1 = weighted_nn_loss(1)
criterion2 = nn.MSELoss()
print(criterion1(a,b))
print(criterion2(a,b))


tensor(1961.8000)
tensor(1961.8000)


In [2]:
class Baseline(nn.Module):

    def __init__(self, in_ch, out_ch):
        
        super(Baseline, self).__init__()
        '''
        torch.nn.Conv3d: input(N,C,D,H,W)
                            output(N,C,Dout,Hout,Wout) 
        torch.nn.AvgPool3d: input(N,C,D,H,W)
                            output(N,C,Dout,Hout,Wout)     
        '''
        '''
        nn.Conv3d(in_channels, out_channels, kernel_size)
        nn.AvgPool3d()
        '''
        self.draft_model = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, stride=1, padding=1),
            nn.BatchNorm3d(out_ch),
#             nn.AvgPool3d(3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(in_ch, out_ch, 3, stride=1, padding=1),
            nn.BatchNorm3d(out_ch),
#             nn.AvgPool3d(3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, cube):
        cube = self.draft_model(cube)
        return cube
    
class SimpleUnet(nn.Module):

    def __init__(self, in_channels):
        
        super(SimpleUnet, self).__init__()
        self.in_channels = in_channels
        self.Simple_Unet = nn.Sequential(
            self.conv_layer(self.in_channels, 16),
            self.conv_layer(16, 16),
            nn.AvgPool3d(2),
            self.conv_layer(16,32),
            nn.AvgPool3d(2),
            self.conv_layer(32,64),
            self.up_conv_layer(64, 64, 3),
            self.conv_layer(64, 32),
            self.up_conv_layer(32, 32, 3),
            self.conv_layer(32, 16),
            self.up_conv_layer(16, 16, 2),
            self.conv_layer(16, 8),
            self.up_conv_layer(8, 8, 2),
            self.conv_layer(8, 4),
            self.conv_layer(4, 1)
        )
    
    def conv_layer(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias=True):
        layers = nn.Sequential(
        nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
        nn.BatchNorm3d(out_channels),
        nn.LeakyReLU())
        return layers
    
    def up_conv_layer(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, bias=True):
        layers = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            # should be feat_in*2 or feat_in
            nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.LeakyReLU())
        return layers
    
    
    def forward(self, cube):
        cube = self.Simple_Unet(cube)
        return cube

In [3]:
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()
    data_time = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)



        # add a dimension, from (1, 32, 32, 32) to (1,1,32,32,32)
        input = input.unsqueeze(dim = 1).to(device).float()
        target = target.unsqueeze(dim = 1).to(device).float()
        # compute output
        output = model(input)
        
        loss = criterion(output, target)

        # measure accuracy and record loss
        losses.update(loss.item(), input.size(0))


        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses))


In [4]:
def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()


    # switch to evaluate mode
    model.train()

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in enumerate(val_loader):
            input = input.unsqueeze(dim = 1).to(device).float()
            target = target.unsqueeze(dim = 1).to(device).float()

            # compute output
            output = model(input)
            
            loss = criterion(output, target)

            # measure accuracy and record loss
            losses.update(loss.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            
        print('Test: Time {batch_time.val:.3f} \t\t\t\t\t\t\t\t\t\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(batch_time=batch_time, loss=losses))

    return 

In [5]:
#index for the cube, each tuple corresponds to a cude
train_data = [(800, 640, 224)]
val_data = [(800, 640, 224)]
test_data = [(800, 640, 224)]

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [6]:
# pos=list(np.arange(0,1024,32))
# ranges=list(product(pos,repeat=3))
# # random.shuffle(ranges)
# train_data = ranges[:int(np.round(len(ranges)*0.6))]
# val_data=ranges[int(np.round(len(ranges)*0.6)):int(np.round(len(ranges)*0.8))]
# test_data = ranges[int(np.round(len(ranges)*0.8)):]

In [7]:
params = {'batch_size': 50,
          'shuffle': False,
          #'shuffle': True,
          'num_workers':20}
max_epochs = 100

training_set, validation_set = Dataset(train_data), Dataset(val_data)
testing_set= Dataset(test_data)
training_generator = data.DataLoader(training_set, **params)
validation_generator = data.DataLoader(validation_set, **params)
testing_generator = data.DataLoader(testing_set, **params)

for i, (dark,full) in enumerate(testing_generator):
    dark=dark
    full=full
        

In [8]:
dim = 1
model = SimpleUnet(dim).to(device)
criterion = nn.MSELoss().to(device) #yueqiu
optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

In [9]:
for epoch in range(args.start_epoch, args.epochs):
    adjust_learning_rate(optimizer, epoch)

    # train for one epoch
    train(training_generator, model, criterion, optimizer, epoch)

    # evaluate on validation set
    validate(validation_generator, model, criterion)






Epoch: [0][0/1]	Time 0.340 (0.340)	Loss 0.0651 (0.0651)	
Test: Time 0.156 										Loss 21.5622 (21.5622)	
Epoch: [1][0/1]	Time 0.281 (0.281)	Loss 21.5622 (21.5622)	
Test: Time 0.164 										Loss 1303.5524 (1303.5524)	
Epoch: [2][0/1]	Time 0.293 (0.293)	Loss 1303.5524 (1303.5524)	
Test: Time 0.161 										Loss 58937.8438 (58937.8438)	
Epoch: [3][0/1]	Time 0.285 (0.285)	Loss 58937.8438 (58937.8438)	
Test: Time 0.167 										Loss 16216511.0000 (16216511.0000)	
Epoch: [4][0/1]	Time 0.293 (0.293)	Loss 16216511.0000 (16216511.0000)	
Test: Time 0.179 										Loss 3803902208.0000 (3803902208.0000)	
Epoch: [5][0/1]	Time 0.284 (0.284)	Loss 3803902208.0000 (3803902208.0000)	
Test: Time 0.185 										Loss 658882363392.0000 (658882363392.0000)	
Epoch: [6][0/1]	Time 0.300 (0.300)	Loss 658882363392.0000 (658882363392.0000)	
Test: Time 0.189 										Loss 237700653252608.0000 (237700653252608.0000)	
Epoch: [7][0/1]	Time 0.327 (0.327)	Loss 237700653252608.0000 (237700653252608.0000)	
Test:

In [11]:
train_data

[(800, 640, 224)]