In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import torch.optim as optim
import cv2
import numpy as np
import math
import time
import random
from matplotlib import pyplot as plt

In [2]:
resnet18_pre = models.resnet18(pretrained=True)

In [3]:
googlenet = models.googlenet(pretrained=True)

In [4]:
class Residual(nn.Module):  
    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
        super(Residual, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        return F.relu(Y + X)

In [5]:
def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    if first_block:
        assert in_channels == out_channels 
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))
        else:
            blk.append(Residual(out_channels, out_channels))
    return nn.Sequential(*blk)

In [6]:
class FlattenLayer(torch.nn.Module):
    def __init__(self):
        super(FlattenLayer, self).__init__()
    def forward(self, x): # x shape: (batch, *, *, ...)
        return x.view(-1, x.shape[1])

In [7]:
class GlobalAvgPool2d(nn.Module):
    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()
    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=x.size()[2:])

In [8]:
def new_resnet(resnet_block_num, resnet_block, FlattenLayer, GlobalAvgPool2d):
    
    net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64), 
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

    resnet_block_num = 1

    if resnet_block_num == 1:
        net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
        net.add_module("global_avg_pool", GlobalAvgPool2d()) 
        net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(64, 1000))) 
    elif resnet_block_num == 2:
        net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
        net.add_module("resnet_block2", resnet_block(64, 128, 2))
        net.add_module("global_avg_pool", GlobalAvgPool2d()) 
        net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(128, 1000))) 
    elif resnet_block_num == 3:
        net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
        net.add_module("resnet_block2", resnet_block(64, 128, 2))
        net.add_module("resnet_block3", resnet_block(128, 256, 2))
        net.add_module("global_avg_pool", GlobalAvgPool2d()) 
        net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(256, 1000))) 
    else:
        net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
        net.add_module("resnet_block2", resnet_block(64, 128, 2))
        net.add_module("resnet_block3", resnet_block(128, 256, 2))
        net.add_module("resnet_block4", resnet_block(256, 512, 2))   
    return net

In [9]:
resnet = models.resnet18()

In [10]:
pretrained_dict = resnet18_pre.state_dict()
nonpretrained_dict = resnet.state_dict()

In [11]:
for name, param in pretrained_dict.items():
    if name == 'layer3.0.conv1.weight':
        break
    else:
        nonpretrained_dict[name] = param

In [12]:
resnet.load_state_dict(nonpretrained_dict)  

<All keys matched successfully>

In [13]:
class Graph_RNN1(nn.Module):
    def __init__(self, extra): #extra = NUM_PIECES
        super(Graph_RNN1, self).__init__()
        self.linear1 = nn.Linear(1000 + extra, 64)
        self.relu = nn.ReLU()
        self.lstm = nn.LSTM(64, 128, 4)
        self.linear2 = nn.Linear(128, 32)
        self.linear3 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()
        self.state = None
    def forward(self, x, state):
        linear1_output = self.relu(self.linear1(x))
        lstm_output, self.state = self.lstm(linear1_output.view(1, 1, 64), state)
        linear2_output = self.relu(self.linear2(lstm_output.view(-1, lstm_output.shape[-1])))
        linear3_output = self.sigmoid(self.linear3(linear2_output))
        return linear3_output, self.state

In [14]:
class Graph_RNN2(nn.Module):
    def __init__(self):
        super(Graph_RNN2, self).__init__()
        self.linear1 = nn.Linear(1, 8)
        self.relu = nn.ReLU()
        self.lstm = nn.LSTM(8, 16, 4)
        self.linear2 = nn.Linear(16, 8)
        self.linear3 = nn.Linear(8, 1)
        self.sigmoid = nn.Sigmoid()
        self.state = None
    def forward(self, x, state):
        linear1_output = self.relu(self.linear1(x))
        lstm_output, self.state = self.lstm(linear1_output.view(1, 1, 8), state)
        linear2_output = self.relu(self.linear2(lstm_output.view(-1, lstm_output.shape[-1])))
        linear3_output = self.sigmoid(self.linear3(linear2_output))
        return linear3_output, self.state

In [15]:
extra = 128
RNN1 = Graph_RNN1(extra) 
RNN2 = Graph_RNN2()

In [16]:
initialize = True

if initialize == True:
    for name, param in RNN1.named_parameters(): 
        if 'bias' in name:
            nn.init.constant_(param, 0.25)
        else:
            if 'lstm' in name:
                nn.init.xavier_uniform_(param, gain=nn.init.calculate_gain('sigmoid'))
            else:
                nn.init.xavier_uniform_(param, gain=nn.init.calculate_gain('relu'))
    for name, param in RNN2.named_parameters(): 
        if 'bias' in name:
            nn.init.constant_(param, 0.25)
        else:
            if 'lstm' in name:
                nn.init.xavier_uniform_(param, gain=nn.init.calculate_gain('sigmoid'))
            else:
                nn.init.xavier_uniform_(param, gain=nn.init.calculate_gain('relu'))

In [17]:
pre_train = False
epoch_num = 10
NUM_PIECES_load = 8

if pre_train == True:
    resnet.load_state_dict(torch.load(r'/home/heu/wmh/model/CNN_' + str(NUM_PIECES_load) + r'_' 
                                        + str(epoch_num) + r'.pt'))
    RNN1.load_state_dict(torch.load(r'/home/heu/wmh/model/RNN1_' + str(NUM_PIECES_load) + r'_' 
                                        + str(epoch_num) + r'.pt'))
    RNN2.load_state_dict(torch.load(r'/home/heu/wmh/model/RNN2_' + str(NUM_PIECES_load) + r'_' 
                                        + str(epoch_num) + r'.pt'))

In [18]:
def erode(image1, image2): 
    kernel = np.ones((10, 10))
    erosion = cv2.erode(image2, kernel, iterations = 3)
    target = ((image2 - erosion) / 255) * image1
    target.resize(256, 256, 3)
    transform_GY = transforms.ToTensor()
    batch_size = 1
    ans = transform_GY(target).view(1, 3, 256, 256)
    return ans

In [19]:
def ajm(NUM_PIECES, ajmatrix): 
    ans = torch.zeros(NUM_PIECES, NUM_PIECES)
    cur_tar = 0
    for i in range(NUM_PIECES):
        for j in range(NUM_PIECES):
            if ajmatrix[cur_tar] == '1':
                ans[j][i] = 1
            cur_tar += 1
    return ans

In [20]:
def shuffle_ajm(temp_index, adjacency_matrix, NUM_PIECES):
    ans = torch.zeros(NUM_PIECES, NUM_PIECES)
    cur_tar = 0
    j = 0
    for i in range(NUM_PIECES):
        if i == 0:
            continue
        for j in range(i):
            temp_list = [temp_index[i]]
            temp_list.append(temp_index[j])
            temp_list.sort()
            ans[j][i] = adjacency_matrix[temp_list[0]][temp_list[1]]
    return ans

In [21]:
def dataset_pieces(path, NUM_PIECES, train):
    if train == True: 
        start = 1
        end = 45
    else:
        start = 45
        end = 51
    
    shuffle_list = list(range(start, end))
    random.shuffle(shuffle_list)
    shuffle = False
    if train == True and shuffle == True:
        iteration = 4 
        shuffle_dict = []
        for i in range(iteration):
            list_a = list(range(NUM_PIECES))
            list_b = list(range(NUM_PIECES))
            list_c = list(range(NUM_PIECES))
            list_d = list(range(NUM_PIECES))
            random.shuffle(list_b)
            random.shuffle(list_c)
            random.shuffle(list_d)
            shuffle_dict.append(list_a)
            shuffle_dict.append(list_b)
            shuffle_dict.append(list_c)
            shuffle_dict.append(list_d)
        
    for SHEET_INDEX in shuffle_list:
        if SHEET_INDEX == 7 or SHEET_INDEX == 27 or SHEET_INDEX == 48:  
            continue
        SIDE = ['back', 'front']
        for side in SIDE:
            dataset_pieces_data = []
            for image_index in range(1, NUM_PIECES + 1):
                if image_index >= 10:
                    index = '00' + str(image_index)
                else:
                    index = '000' + str(image_index)
                # path = r'E:\data\data'
                final_path = path + r'/sheet' + str(SHEET_INDEX) + '//' + str(NUM_PIECES) + r'pieces'\
                + '//' + side + r'/final'
                image = cv2.imread(final_path + r'/IMG_' + index + r'_erode.png')
                image.resize(256, 256, 3)
                transform_GY = transforms.ToTensor()
                batch_size = 1
                ans = transform_GY(image).view(1, 3, 256, 256)
                dataset_pieces_data.append(ans)
            fid = open(final_path + '/groundtruth.txt', 'r')
            ajmatrix = fid.read()
            fid.close()
            adjacency_matrix = ajm(NUM_PIECES, ajmatrix)
            yield dataset_pieces_data, adjacency_matrix

In [22]:
path = r'/home/heu/wmh/data/data'
dataset8_train = dataset_pieces(path, 8, True)
dataset8_test = dataset_pieces(path, 8, False)
dataset16_train = dataset_pieces(path, 16, True)
dataset16_test = dataset_pieces(path, 16, False)
dataset_train = {'8':dataset8_train, '16':dataset16_train}
dataset_test = {'8':dataset8_test, '16':dataset16_test}

In [23]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
loss = nn.CrossEntropyLoss()  
lr_1 = 0.0001    #0.0001
lr_2 = 0.00001  #0.000001
lr_3 = 0.0001    #0.0001
resnetoptim = torch.optim.AdamW(resnet.parameters(), lr_1)  
RNN1optim = torch.optim.AdamW(RNN1.parameters(), lr_2)
RNN2optim = torch.optim.AdamW(RNN2.parameters(), lr_3)
optim = [resnetoptim, RNN1optim, RNN2optim]
resnet_lr_scheduler = torch.optim.lr_scheduler.CyclicLR(resnetoptim, lr_1, lr_1 * 10, cycle_momentum = False)
RNN1_lr_scheduler = torch.optim.lr_scheduler.CyclicLR(RNN1optim, lr_2, lr_2 * 10, cycle_momentum = False)
RNN2_lr_scheduler = torch.optim.lr_scheduler.CyclicLR(RNN2optim, lr_3, lr_3 * 10, cycle_momentum = False)
lr_scheduler = [resnet_lr_scheduler, RNN1_lr_scheduler, RNN2_lr_scheduler]

In [24]:
def grad_clipping(params, theta, device): 
    norm = torch.tensor([0.0], device=device)
    for param in params:
        norm += (param.grad.data ** 2).sum()
    norm = norm.sqrt().item()
    if norm > theta:
        for param in params:
            param.grad.data *= (theta / norm)

In [25]:
def copy_tensor(rnn2_output):
    ans = torch.tensor(rnn2_output.shape, dtype = torch.float, device = 
                       torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    ans = ans.copy(rnn2_output)
    return ans

In [26]:
def evaluate_model(dataset_pieces, cnn, rnn1, rnn2, NUM_PIECES, device, extra, path):
    cnn = cnn.to(device)
    rnn1 = rnn1.to(device)
    rnn2 = rnn2.to(device)
    acc_sum = 0
    data_sum = 0
    TP = 0
    FP = 0
    TN = 0
    FN = 0
    with torch.no_grad():
        train = False
        for x, y in dataset_pieces:
            n = 0
            for i in range(1, NUM_PIECES):
                n += i
            data_sum += n
            pre_ajm = torch.zeros(NUM_PIECES, NUM_PIECES, device = device)
            y = y.to(device)
            rnn1_state = None 
            rnn2_state = None
            for i in range(NUM_PIECES):
                cnn1_output = cnn(x[i].type(torch.float).to(device))
                if i == 0:
                    cnn_ajm = torch.zeros(1, extra, dtype = torch.float, device = device)
                    cnn1_output = torch.cat((cnn1_output, cnn_ajm), 1)
                else:
                    cnn_ajm = torch.zeros(1, extra, dtype = torch.float, device = device)
                    for j in range(NUM_PIECES):  
                        cnn_ajm[0][j] = pre_ajm[j][i - 1]
                    cnn1_output = torch.cat((cnn1_output, cnn_ajm), 1)
                rnn1_output, rnn1_state = rnn1(cnn1_output, rnn1_state)
                rnn2_output, rnn2_state = rnn2(rnn1_output, rnn2_state)
                if i != 0:
                    first = True
                    count = i
                    while count > 0:
                        count -= 1
                        if first:
                            first = False
                            pre_ajm[count][i] = 1 if rnn2_output.sum().item() >= 0.5 else 0
                        else:
                            rnn2_output, rnn2_state = rnn2(rnn2_output, rnn2_state)
                            pre_ajm[count][i] = 1 if rnn2_output.sum().item() >= 0.5 else 0
            for i in range(NUM_PIECES - 1):
                for j in range(i + 1, NUM_PIECES):
                    if pre_ajm[i][j] == y[i][j]:
                        acc_sum += 1
                    if pre_ajm[i][j] == 1:
                        if y[i][j] == 1:
                            TP += 1  #TP
                        else:
                            FP += 1  #FN
                    else:
                        if y[i][j] == 1:
                            FN += 1  #FP
                        else:
                            TN += 1  #TN
    return acc_sum / data_sum, TP, FP, TN, FN

In [27]:
def train_model(dataset_pieces, cnn, rnn1, rnn2, num_epochs, device, opitm, lr_scheduler, extra, theta, path, evaluate_model):
    cnn = cnn.to(device)
    rnn1 = rnn1.to(device)
    rnn2 = rnn2.to(device)
    loss = nn.BCELoss()
    for epoch in range(1, num_epochs + 1):
        #if epoch % 10 == 0:
            #torch.save(cnn.state_dict(), r'/home/heu/wmh/model/CNN_' + str(NUM_PIECES) + r'_' + str(epoch) + r'.pt')
            #torch.save(rnn1.state_dict(), r'/home/heu/wmh/model/RNN1_' + str(NUM_PIECES) + r'_' + str(epoch) + r'.pt')
            #torch.save(rnn2.state_dict(), r'/home/heu/wmh/model/RNN2_' + str(NUM_PIECES) + r'_' + str(epoch) + r'.pt')

        for NUM_PIECES in [8, 16]:
            start = time.time()
            sum_loss = 0
            n = 0
            for i in range(1, NUM_PIECES):
                n += i
            if NUM_PIECES == 8:
                dataset_cur = dataset_pieces(path, 8, True)
            else:
                dataset_cur = dataset_pieces(path, 16, True)
            for x, y in dataset_cur:
                y = y.to(device)
                rnn1_state = None 
                rnn2_state = None
                for i in range(NUM_PIECES):
                    cnn_input = x[i].type(torch.float).to(device)
                    cnn_output = cnn(cnn_input)
                    if i == 0:
                        cnn_ajm = torch.zeros(1, extra, dtype = torch.float, device = device)
                        cnn_output2 = torch.cat((cnn_output, cnn_ajm), 1)
                    else:
                        cnn_ajm = torch.zeros(1, extra, dtype = torch.float, device = device)
                        for j in range(NUM_PIECES):
                            cnn_ajm[0][j] = y[j][i - 1]
                        cnn_output2 = torch.cat((cnn_output, cnn_ajm), 1)
                    if rnn1_state != None:
                        rnn1_state = (rnn1_state[0].detach(), rnn1_state[1].detach())
                    if rnn2_state != None:
                        rnn2_state = (rnn2_state[0].detach(), rnn2_state[1].detach())
                    rnn1_output, rnn1_state = rnn1(cnn_output2, rnn1_state)
                    rnn2_output, rnn2_state = rnn2(rnn1_output, rnn2_state)
                    if i != 0:
                        first = True
                        count = i
                        while count > 0:
                            count -= 1
                            if first:
                                first = False
                                cur_y = torch.tensor([y[count][i].item()], dtype = torch.float, device = device)
                                l = loss(rnn2_output.view(-1), cur_y)
                                sum_loss += l.item()
                                for cur_optim in optim: 
                                    cur_optim.zero_grad()
                                l.backward(retain_graph=True)
                                grad_clipping(RNN1.parameters(), theta, device) 
                                grad_clipping(RNN2.parameters(), theta, device)
                                for cur_optim in optim: 
                                    cur_optim.step()
                                for scheduler in lr_scheduler:
                                    scheduler.step()
                            else:
                                rnn2_input = torch.zeros(rnn2_output.shape, dtype = torch.float, device = device)
                                rnn2_input[0][0] = rnn2_output[0][0].item()
                                if rnn2_state != None:
                                    rnn2_state = (rnn2_state[0].detach(), rnn2_state[1].detach())
                                rnn2_output, rnn2_state = rnn2(rnn2_input, rnn2_state)
                                cur_y = torch.tensor([y[count][i].item()], dtype = torch.float, device = device)
                                l = loss(rnn2_output.view(-1), cur_y)
                                sum_loss += l.item()
                                optim[-1].zero_grad()
                                l.backward()
                                grad_clipping(RNN2.parameters(), theta, device)
                                optim[-1].step()
                                lr_scheduler[-1].step()
                        rnn2_output, rnn2_state = rnn2(rnn2_output, rnn2_state)
            if NUM_PIECES == 8:
                loss_plt_8.append(sum_loss / n)
                acc, TP, FP, TN, FN = evaluate_model(dataset_pieces(path, 8, False), cnn, rnn1, rnn2, NUM_PIECES, device, 
                                                 extra, path)
                acc_plt_8.append(acc)
            else:
                loss_plt_16.append(sum_loss / n)
                acc, TP, FP, TN, FN = evaluate_model(dataset_pieces(path, 16, False), cnn, rnn1, rnn2, NUM_PIECES, device, 
                                                 extra, path)
                acc_plt_16.append(acc)

            print('epoch: %d, NUM_PIECES: %d, loss: %.4f, time: %.1f, accuracy: %4f, TP: %d, FP: %d, TN: %d, FN: %d,' %
                  (epoch, NUM_PIECES, sum_loss / n, time.time() - start, acc, TP, FP, TN, FN))
            print('learning_rate1 : %.4f, learning_rate2 : %.6f, learning_rate3 : %.4f' % 
                  (optim[0].param_groups[0]['lr'], optim[1].param_groups[0]['lr'], optim[2].param_groups[0]['lr']))

In [28]:
#resnet = models.resnet()
#RNN1 = Graph_RNN1(extra) 
#RNN2 = Graph_RNN2()
#path = r'E:\data\data'
#dataset8_train = dataset_pieces(path, 8, True)
#dataset8_test = dataset_pieces(path, 8, False)
#dataset16_train = dataset_pieces(path, 16, True)
#dataset16_test = dataset_pieces(path, 16, False)
#dataset_train = {'8':dataset8_train, '16':dataset16_train}
#dataset_test = {'8':dataset8_test, '16':dataset16_test}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
#extra = 128
#optim = [resnetoptim, RNN1optim, RNN2optim]
#resnetoptim = torch.optim.Adam(resnet.parameters(), 0.0001)  
#RNN1optim = torch.optim.Adam(RNN1.parameters(), 0.00001)
#RNN2optim = torch.optim.Adam(RNN2.parameters(), 0.0001)
theta = 1e-3
NUM_PIECES = 8
num_epochs = 500
loss_plt_8 = []
loss_plt_16 = []
acc_plt_8 = []
acc_plt_16 = []
device

device(type='cuda')

In [29]:
with torch.autograd.set_detect_anomaly(True):
    train_model(dataset_pieces, resnet, RNN1, RNN2, num_epochs, device, optim, lr_scheduler, extra, theta, path, evaluate_model)

epoch: 1, NUM_PIECES: 8, loss: 56.3699, time: 44.0, accuracy: 0.678571, TP: 66, FP: 24, TN: 124, FN: 66,
learning_rate1 : 0.0004, learning_rate2 : 0.000036, learning_rate3 : 0.0008


KeyboardInterrupt: 

In [None]:
x = list(range(1, num_epochs + 1))
y1 = loss_plt_8
y2 = loss_plt_16
plt.plot(x, y1)
plt.plot(x, y2)
plt.legend(['loss_plt_8', 'loss_plt_16'])
plt.xlabel('epoch')
plt.ylabel('loss')
plt.savefig(r'/home/heu/wmh/1.png')

In [None]:
y1 = acc_plt_8
y2 = acc_plt_16
plt.plot(x, y1)
plt.plot(x, y2)
plt.legend(['acc_plt_8', 'acc_plt_16'])
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.savefig(r'/home/heu/wmh/2.png')