In [1]:
from __future__ import print_function

from epinet_fun.func_generate_traindata import generate_traindata_for_train
from epinet_fun.func_generate_traindata import data_augmentation_for_train
from epinet_fun.func_generate_traindata import generate_traindata512
from epinet_fun.func_makeinput import make_multiinput
from epinet_fun.func_pfm import read_pfm
from epinet_fun.func_savedata import display_current_output
from epinet_fun.util import load_LFdata

from network.model_student import *
from network.model_teacher import *

import numpy as np
import matplotlib.pyplot as plt

import h5py
import os
import time
import datetime
import threading
import configparser
import json
from PIL import Image, ImageEnhance, ImageOps
#from epinet_fun.func_middle_output import middle_layer_output
import imageio

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils import data
import torch.nn.functional as F

import matplotlib.pyplot as plt
import time
from PIL import Image

In [2]:
inifile = configparser.ConfigParser()
inifile.read('./config.ini', 'UTF-8')

train_dataset_list = json.loads(inifile.get('dataset_list','train_dataset_list'))
test_dataset_list = json.loads(inifile.get('dataset_list','test_dataset_list'))

dataset_path = inifile.get('PATH','dataset_path')
boolmask_img4_path = inifile.get('PATH','boolmask_img4')
boolmask_img6_path = inifile.get('PATH','boolmask_img6')
boolmask_img15_path = inifile.get('PATH','boolmask_img15')

image_width = int(inifile.get('model_1371','image_width'))
image_height = int(inifile.get('model_1371','image_height'))

batch_size_training = int(inifile.get('training_general','batch_size_training'))
batch_size_validation = int(inifile.get('training_general','batch_size_validation'))
batch_num_in_1epoch_for_training = int(inifile.get('training_general','batch_num_in_1epoch_for_training'))
training_img_size = int(inifile.get('training_general','training_img_size'))
validation_img_size = int(inifile.get('training_general','validation_img_size'))

learning_rate = float(inifile.get('training_general','learning_rate'))
LR_scheduler_change_point_iteration = int(inifile.get('training_general','LR_scheduler_change_point_iteration'))
validation_frequency = int(inifile.get('training_general','validation_frequency'))
save_model_frequency = int(inifile.get('training_general','save_model_frequency'))

input_ch = int(inifile.get('training_general','input_ch'))
filter_num = int(inifile.get('training_general','filter_num'))
stream_num = int(inifile.get('training_general','stream_num'))
input_size_train = int(inifile.get('training_general','input_size_train'))
input_size_valid = int(inifile.get('training_general','input_size_valid'))
label_size_train = int(inifile.get('training_general','label_size_train'))
label_size_valid = int(inifile.get('training_general','label_size_valid'))

lambda_mae = float(inifile.get('knowledgedistillation','lambda_mae'))
lambda_pairwise = float(inifile.get('knowledgedistillation','lambda_pairwise'))

seed = int(inifile.get('training_general','seed'))

In [3]:
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
def save_validation_tensor_as_png(tensor,save_path):

    directory_path = os.path.dirname(save_path)
    if not os.path.exists(directory_path):
        os.makedirs(directory_path)

    tensor = tensor.detach().cpu().numpy()
    normalized_image = (tensor - tensor.min()) / (tensor.max() - tensor.min())
    image_uint8 = np.uint8(normalized_image * 255)

    concatenated_images = np.hstack(image_uint8)
    imageio.imsave(save_path, np.squeeze(concatenated_images))

def save_tensor_as_png(tensor,save_path):

    tensor = tensor.detach().cpu().numpy()
    imageio.imsave(save_path, np.squeeze(tensor))

In [5]:
def save_model_and_optimizer(model, optimizer, save_path):

    directory_path = os.path.dirname(save_path)
    if not os.path.exists(directory_path):
        os.makedirs(directory_path)

    state = {
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    }

    torch.save(state, save_path)

In [6]:
def display_current_output(train_output, traindata_label, save_path,save_img_flag):
    '''
        display current results from EPINET
        and save results in /current_output
    '''

    directory_path = os.path.dirname(save_path)
    if not os.path.exists(directory_path):
        os.makedirs(directory_path)

    sz=len(traindata_label)
    train_output=np.squeeze(train_output)
    if(len(traindata_label.shape)>3 and traindata_label.shape[-1]==9): # traindata
        pad1_half=int(0.5*(np.size(traindata_label,1)-np.size(train_output,1)))
        train_label482=traindata_label[:,15:-15,15:-15,4,4]
    else: # valdata
        pad1_half=int(0.5*(np.size(traindata_label,1)-np.size(train_output,1)))
        train_label482=traindata_label[:,15:-15,15:-15]

    train_output482=train_output[:,15-pad1_half:482+15-pad1_half,15-pad1_half:482+15-pad1_half]

    train_diff=np.abs(train_output482-train_label482)
    train_bp=(train_diff>=0.07)

    if save_img_flag:
        condition = train_bp > 0

        bp_img=np.zeros_like(train_bp)
        bp_img[condition] = 1
        bp_img[~condition] = 0
        

        train_output482_all=np.zeros((3*482,sz*482),np.uint8)
        train_output482_all[0:482,:]=np.uint8(25*np.reshape(np.transpose(train_label482,(1,0,2)),(482,sz*482))+100)
        train_output482_all[482:2*482,:]=np.uint8(25*np.reshape(np.transpose(train_output482,(1,0,2)),(482,sz*482))+100)
        train_output482_all[2*482:3*482,:]=np.uint8(25*np.reshape(np.transpose(bp_img,(1,0,2)),(482,sz*482))+100)

        imageio.imsave(save_path, np.squeeze(train_output482_all))

    return train_diff, train_bp

In [7]:
class CustomDataset:
    def __init__(self,traindata_all, traindata_label, input_size,label_size,batch_size,Setting02_AngualrViews,
                                                boolmask_img4,boolmask_img6,boolmask_img15, batch_num_in_1epoch, mode):
        self.traindata_all = traindata_all
        self.traindata_label = traindata_label
        self.input_size = input_size
        self.label_size = label_size
        self.batch_size = batch_size
        self.Setting02_AngualrViews = Setting02_AngualrViews
        self.boolmask_img4 = boolmask_img4
        self.boolmask_img6 = boolmask_img6
        self.boolmask_img15 = boolmask_img15
        self.batch_num_in_1epoch = batch_num_in_1epoch
        self.mode = mode

    def __getitem__(self, index):
        (traindata_batch_90d, traindata_batch_0d,traindata_batch_45d, traindata_batch_m45d,
        traindata_label_batchNxN)= generate_traindata_for_train(self.traindata_all,self.traindata_label,
                                                                self.input_size,self.label_size,1,
                                                                self.Setting02_AngualrViews,
                                                                self.boolmask_img4,self.boolmask_img6,self.boolmask_img15,self.mode)
        """
         traindata_batch_0d : (1, 25, 25, 9) nd.array
         traindata_label_batchNxN : (1, 3, 3) nd.array
        """

        (traindata_batch_90d_aug, traindata_batch_0d_aug,traindata_batch_45d_aug,traindata_batch_m45d_aug,
        traindata_label_batchNxN_aug) =  data_augmentation_for_train(traindata_batch_90d,traindata_batch_0d,
                                                                traindata_batch_45d,traindata_batch_m45d,
                                                                traindata_label_batchNxN, 1)

        traindata_batch_90d = torch.from_numpy(traindata_batch_90d_aug).squeeze(0).to(torch.float32).permute(2, 0, 1)
        traindata_batch_0d = torch.from_numpy(traindata_batch_0d_aug).squeeze(0).to(torch.float32).permute(2, 0, 1)
        traindata_batch_45d = torch.from_numpy(traindata_batch_45d_aug).squeeze(0).to(torch.float32).permute(2, 0, 1)
        traindata_batch_m45d = torch.from_numpy(traindata_batch_m45d_aug).squeeze(0).to(torch.float32).permute(2, 0, 1)
        traindata_label_batchNxN = torch.from_numpy(traindata_label_batchNxN_aug).squeeze(0).to(torch.float32)

        """
         traindata_batch_0d : (9, 25, 25) tensor
         traindata_label_batchNxN : (3, 3) tensor
        """

        return traindata_batch_90d, traindata_batch_0d, traindata_batch_45d, traindata_batch_m45d, traindata_label_batchNxN

    def __len__(self):
        return self.batch_size * self.batch_num_in_1epoch


In [8]:
def PairwiseLoss(hint_student, hint_teacher):
    batchsize, channelsize, h, w = hint_student.size()

    feature_student = hint_student.view(batchsize, channelsize, -1)
    feature_teacher = hint_teacher.view(batchsize, channelsize, -1)

    affinity_student = torch.einsum('bci,bcj->bij', feature_student, feature_student)
    affinity_teacher = torch.einsum('bci,bcj->bij', feature_teacher, feature_teacher)

    norm_student = torch.linalg.norm(feature_student, dim=1).unsqueeze(2)
    norm_teacher = torch.linalg.norm(feature_teacher, dim=1).unsqueeze(2)

    affinity_student /= (norm_student * norm_student.transpose(1, 2))
    affinity_teacher /= (norm_teacher * norm_teacher.transpose(1, 2))
    
    loss = F.mse_loss(affinity_student, affinity_teacher, reduction='sum') / batchsize

    return loss /w/h

load train data

In [9]:
traindata_all,traindata_label=load_LFdata(train_dataset_list)
testdata_all,testdata_label=load_LFdata(test_dataset_list)
"""
 traindata_all  :  (16, 512, 512, 9, 9, 3) nd.array
 traindata_label : (16, 512, 512) nd.array
 testdata_all : (8, 512, 512, 9, 9, 3)  nd.array
 testdata_label : (8, 512, 512)   nd.array
"""

additional/antinous
additional/boardgames
additional/dishes
additional/greek
additional/kitchen
additional/medieval2
additional/museum
additional/pens
additional/pillows
additional/platonic
additional/rosemary
additional/table
additional/tomb
additional/tower
additional/town
additional/vinyl
stratified/backgammon
stratified/dots
stratified/pyramids
stratified/stripes
training/boxes
training/cotton
training/dino
training/sideboard


'\n traindata_all  :  (16, 512, 512, 9, 9, 3) nd.array\n traindata_label : (16, 512, 512) nd.array\n testdata_all : (8, 512, 512, 9, 9, 3)  nd.array\n testdata_label : (8, 512, 512)   nd.array\n'

setup size information

In [10]:
Setting02_AngualrViews = np.array([0,1,2,3,4,5,6,7,8])  # number of views ( 0~8 for 9x9 )

make training tensor for visualization

In [11]:
training_full_90d = torch.zeros((batch_size_training, training_img_size, training_img_size, 9))
training_full_0d = torch.zeros((batch_size_training, training_img_size, training_img_size, 9))
training_full_45d = torch.zeros((batch_size_training, training_img_size, training_img_size, 9))
training_full_M45d = torch.zeros((batch_size_training, training_img_size, training_img_size, 9))

for batch, image_path in enumerate(train_dataset_list):
    
    image_path = os.path.join(dataset_path, image_path)
    (train_90d_np, train_0d_np, train_45d_np, train_M45d_np) = make_multiinput(image_path,
                                                                              training_img_size,
                                                                              training_img_size,
                                                                              Setting02_AngualrViews)
    train_90d = torch.from_numpy(np.squeeze(train_90d_np))
    train_0d = torch.from_numpy(np.squeeze(train_0d_np))
    train_45d = torch.from_numpy(np.squeeze(train_45d_np))
    train_M45d = torch.from_numpy(np.squeeze(train_M45d_np))

    training_full_90d[batch, :, :, :] = train_90d
    training_full_0d[batch, :, :, :] = train_0d
    training_full_45d[batch, :, :, :] = train_45d
    training_full_M45d[batch, :, :, :] = train_M45d

training_full_90d = training_full_90d.permute(0, 3, 1, 2)
training_full_0d = training_full_0d.permute(0, 3, 1, 2)
training_full_45d = training_full_45d.permute(0, 3, 1, 2)
training_full_M45d = training_full_M45d.permute(0, 3, 1, 2)

print(f"training dataset tensor size : {training_full_90d.size()}")

training dataset tensor size : torch.Size([16, 9, 512, 512])


make validation image

In [12]:
validation_full_90d = torch.zeros((batch_size_validation,validation_img_size,validation_img_size,9))
validation_full_0d = torch.zeros((batch_size_validation,validation_img_size,validation_img_size,9))
validation_full_45d = torch.zeros((batch_size_validation,validation_img_size,validation_img_size,9))
validation_full_M45d = torch.zeros((batch_size_validation,validation_img_size,validation_img_size,9))

for batch, image_path in enumerate(test_dataset_list):

    image_path = os.path.join(dataset_path,image_path)
    (val_90d_np , val_0d_np, val_45d_np, val_M45d_np)=make_multiinput(image_path,
                                                            validation_img_size,
                                                            validation_img_size,
                                                            Setting02_AngualrViews)
    val_90d = torch.from_numpy(np.squeeze(val_90d_np))
    val_0d = torch.from_numpy(np.squeeze(val_0d_np))
    val_45d = torch.from_numpy(np.squeeze(val_45d_np))
    val_M45d = torch.from_numpy(np.squeeze(val_M45d_np))

    validation_full_90d[batch, :, :, :] = val_90d
    validation_full_0d[batch, :, :, :] = val_0d
    validation_full_45d[batch, :, :, :] = val_45d
    validation_full_M45d[batch, :, :, :] = val_M45d

validation_full_90d = validation_full_90d.permute(0, 3, 1, 2)
validation_full_0d = validation_full_0d.permute(0, 3, 1, 2)
validation_full_45d = validation_full_45d.permute(0, 3, 1, 2)
validation_full_M45d = validation_full_M45d.permute(0, 3, 1, 2)

print(f"validation dataset tensor size : {validation_full_90d.size()}")

validation dataset tensor size : torch.Size([8, 9, 512, 512])


setup boolmask

In [13]:
print(f"boolmask_img4_path : {boolmask_img4_path}")
boolmask_img4 = np.array(Image.open(boolmask_img4_path))
boolmask_img6 = np.array(Image.open(boolmask_img6_path))
boolmask_img15 = np.array(Image.open(boolmask_img15_path))

boolmask_img4  = 1.0*boolmask_img4[:,:,3]>0
boolmask_img6  = 1.0*boolmask_img6[:,:,3]>0
boolmask_img15 = 1.0*boolmask_img15[:,:,3]>0

boolmask_img4_path : ../../hci_dataset/additional_invalid_area/kitchen/input_Cam040_invalid_ver2.png


initialize loss txt file

In [14]:
if not os.path.exists("loss"):
    os.makedirs("loss")

with open("./loss/loss_training.txt", "w") as file:
    file.write(f"==================\n")
with open("./loss/loss_validation.txt", "w") as file:
    file.write(f"==================\n")

In [15]:
train_dataset = CustomDataset(traindata_all, traindata_label, input_size_train,label_size_train,batch_size_training,
                              Setting02_AngualrViews,boolmask_img4,boolmask_img6,boolmask_img15,batch_num_in_1epoch_for_training,mode="training")
test_dataset = CustomDataset(testdata_all, testdata_label, input_size_valid,label_size_valid,batch_size_validation,
                              Setting02_AngualrViews,boolmask_img4,boolmask_img6,boolmask_img15,1,mode="validation")

train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size_training,\
                                shuffle=True, num_workers=0, pin_memory=False, drop_last=True)
test_dataloader = data.DataLoader(test_dataset, batch_size=batch_size_validation,\
                                shuffle=True, num_workers=0, pin_memory=False, drop_last=True)

In [16]:
net_student = EPINET_student(input_ch = input_ch, filter_num = filter_num, stream_num =stream_num).to("cuda")
net_teacher = EPINET_teacher(input_ch = input_ch, filter_num = filter_num, stream_num =stream_num).to("cuda")
def mae_loss(output, target):
    return F.l1_loss(output, target, reduction='mean')
optimizer = optim.RMSprop(net_student.parameters(), lr=learning_rate)

In [17]:
checkpoint = torch.load('./checkpoint_load/BESTmodel_epoch15638_MSE_1p8188_BP4p7553.pth')
net_teacher.load_state_dict(checkpoint['model_state'])

<All keys matched successfully>

In [18]:
best_BadPixel = 100
best_MSE = 100
best_validation_loss = 100

In [19]:
print("Training start")
net_teacher.eval()

for epoch in range(0, 120000):

    net_student.train()
    torch.set_grad_enabled(True)

    train_loss_stack = 0
    train_MAEloss_stack = 0
    train_PAIRWISEloss_stack = 0

    test_loss = 0
    loss_division_train = 0
    loss_division_test = 0

    #===============================#
    #           Training            #
    #===============================#
    print(f"epoch : {epoch}")

    for nBatch, (x_90d, x_0d, x_45d, x_m45d, labels) in enumerate(train_dataloader):

        loss_division_train += 1
        train_loss_Batch = 0.0

        optimizer.zero_grad()

        x_0d_T = x_0d.clone().to("cuda")
        x_90d_T = x_90d.clone().to("cuda")
        x_45d_T = x_45d.clone().to("cuda")
        x_m45d_T = x_m45d.clone().to("cuda")
        x_0d_S = x_0d[:, :, 1:24, 1:24].clone().to("cuda")
        x_90d_S = x_90d[:, :, 1:24, 1:24].clone().to("cuda")
        x_45d_S = x_45d[:, :, 1:24, 1:24].clone().to("cuda")
        x_m45d_S = x_m45d[:, :, 1:24, 1:24].clone().to("cuda")
        labels = labels.clone().to("cuda")

        outputs, hint_student = net_student(x_0d_S,x_90d_S,x_45d_S,x_m45d_S)

        with torch.no_grad():
            _, hint_teacher = net_teacher(x_0d_T, x_90d_T, x_45d_T, x_m45d_T)

        loss_mae = mae_loss(labels,outputs.squeeze(1))
        loss_pairwise = PairwiseLoss(hint_student, hint_teacher)

        loss = lambda_mae * loss_mae + lambda_pairwise * loss_pairwise

        train_loss_stack += loss.item()
        train_MAEloss_stack += loss_mae.item()
        train_PAIRWISEloss_stack += loss_pairwise.item()

        loss.backward()

        optimizer.step()

        #===============================#
        #          Validation           #
        #===============================#
    
    for nBatch, (x_90d, x_0d, x_45d, x_m45d, labels) in enumerate(test_dataloader):
        net_student.eval()
        with torch.no_grad():
            loss_division_test += 1
            train_loss_Batch = 0.0

            optimizer.zero_grad()

            x_0d = x_0d.clone().to("cuda")
            x_90d = x_90d.clone().to("cuda")
            x_45d = x_45d.clone().to("cuda")
            x_m45d = x_m45d.clone().to("cuda")
            labels = labels.clone().to("cuda")

            outputs,_ = net_student(x_0d, x_90d, x_45d, x_m45d)
            loss = mae_loss(labels,outputs.squeeze(1))
            test_loss += loss.item()

    # validation lossが改善されたときのみ実行
    
    validatioin_loss = test_loss / loss_division_test

    if best_validation_loss > validatioin_loss:
        best_validation_loss = validatioin_loss

    if epoch % 10 == 0: #10回に1回BPとMSEを計測
        print('Validating')
        net_student.eval()
        with torch.no_grad():

            x_0d = validation_full_0d.clone().to("cuda")
            x_90d = validation_full_90d.clone().to("cuda")
            x_45d = validation_full_45d.clone().to("cuda")
            x_m45d = validation_full_M45d.clone().to("cuda")

            outputs, _= net_student(x_0d, x_90d, x_45d, x_m45d)

            train_error, train_bp=display_current_output(outputs.detach().cpu().numpy(), \
                                    testdata_label, f"./validation_output2/val_{epoch}.png",False)

            training_mean_squared_error_x100=100*np.average(np.square(train_error))
            training_bad_pixel_ratio=100*np.average(train_bp)

            #save_validation_tensor_as_png(outputs.squeeze(1),f"./images/va_e_{epoch}.png")
    
        with open("./loss/loss_validation.txt", "a") as file:
            file.write(f"Epoch {epoch}/, MSE : {training_mean_squared_error_x100} BP : {training_bad_pixel_ratio}\n")

        net_student.train()

    #scheduler.step()
    with open("./loss/loss_training.txt", "a") as file:
        file.write(f"Epoch {epoch}/, total Loss: {train_loss_stack/loss_division_train}, mae Loss: {train_MAEloss_stack/loss_division_train}, \
                   pairwise Loss: {train_PAIRWISEloss_stack/loss_division_train}\n")


    if best_BadPixel > training_bad_pixel_ratio:
        MSE_record = f"{training_mean_squared_error_x100:.4f}".replace(".","p")
        BP_record = f"{training_bad_pixel_ratio:.4f}".replace(".","p")

        filename = f"epoch_{epoch}_MSE_{MSE_record}_BP_{BP_record}.pth"
        save_model_and_optimizer(net_student, optimizer, f"./model_checkpoint/{filename}")
        best_BadPixel = training_bad_pixel_ratio


    if best_BadPixel > training_bad_pixel_ratio:
        MSE_record = f"{training_mean_squared_error_x100:.4f}".replace(".","p")
        BP_record = f"{training_bad_pixel_ratio:.4f}".replace(".","p")

        filename = f"epoch_{epoch}_MSE_{MSE_record}_BP_{BP_record}.pth"
        save_model_and_optimizer(net_student, optimizer, f"./model_checkpoint/{filename}")
        best_MSE = training_mean_squared_error_x100

Training start
epoch : 0
Validating
epoch : 1
epoch : 2
epoch : 3
epoch : 4
epoch : 5
epoch : 6
epoch : 7
epoch : 8
epoch : 9
epoch : 10
Validating
epoch : 11
epoch : 12
epoch : 13
epoch : 14
epoch : 15
epoch : 16
epoch : 17
epoch : 18
epoch : 19
epoch : 20
Validating
epoch : 21
epoch : 22
epoch : 23
epoch : 24
epoch : 25
epoch : 26
epoch : 27
epoch : 28
epoch : 29
epoch : 30
Validating
epoch : 31
epoch : 32
epoch : 33
epoch : 34
epoch : 35
epoch : 36
epoch : 37
epoch : 38
epoch : 39
epoch : 40
Validating
epoch : 41
epoch : 42
epoch : 43
epoch : 44
epoch : 45
epoch : 46
epoch : 47
epoch : 48
epoch : 49
epoch : 50
Validating
epoch : 51
epoch : 52
epoch : 53
epoch : 54
epoch : 55
epoch : 56
epoch : 57
epoch : 58
epoch : 59
epoch : 60
Validating
epoch : 61
epoch : 62
epoch : 63
epoch : 64
epoch : 65
epoch : 66
epoch : 67
epoch : 68
epoch : 69
epoch : 70
Validating
epoch : 71
epoch : 72
epoch : 73
epoch : 74
epoch : 75
epoch : 76
epoch : 77
epoch : 78
epoch : 79
epoch : 80
Validating
epoch

In [None]:
save_model_and_optimizer(net_student, optimizer, f"./model_checkpoint/Stopped_model_20240109_epoch_108479.pth")