In [1]:
import torch
import time
import torch.optim as optim
import numpy as np
import os
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from pytorch_model_summary import summary
import torch.nn as nn
import torchvision
from efficientnet_pytorch import EfficientNet
import random
import cv2
import pandas as pd
import matplotlib.pyplot as plt
import albumentations as A 
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from albumentations.pytorch import ToTensorV2
import warnings
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from albumentations.pytorch import ToTensorV2

ModuleNotFoundError: No module named 'torch'

In [None]:
# From path_constants.py
#
# Every location used by the notebook is derived from `dir_path`.
# NOTE: '__file__' is a quoted string here, not the module attribute, so
# `notebook_dir` resolves to the parent of the current working directory.

notebook_dir = os.path.dirname(os.path.dirname(os.path.abspath('__file__')))
dir_path = f'{notebook_dir}/DummyDatabase'

# Raw Marmot downloads
marmot_v1 = f'{dir_path}/marmot_v1'
marmot_extended = f'{dir_path}/marmot_extended'
ORIG_DATA_PATH = marmot_v1 + '/marmot_dataset_v1.0/data/English'
DATA_PATH = 'Marmot_data'

# Derived artefacts (processed data, predictions, test images, models)
PROCESSED_DATA = dir_path + '/marmot_processed'
PREDICTIONS = dir_path + '/predictions'
TEST_IMAGES = dir_path + '/test_images'
MODELS = dir_path + '/models'

# Sub-folders of the processed data set
IMAGE_PATH = os.path.join(PROCESSED_DATA, 'image')
TABLE_MASK_PATH = os.path.join(PROCESSED_DATA, 'table_mask')
COL_MASK_PATH = os.path.join(PROCESSED_DATA, 'col_mask')

Marmot_data = dir_path + '/' + DATA_PATH
POSITIVE_DATA_LBL = os.path.join(ORIG_DATA_PATH, 'Positive', 'Labeled')

print(dir_path)

In [None]:
# From configurations.py
#
# Run-wide settings, grouped by purpose.

# Reproducibility
SEED = 0

# Optimisation hyper-parameters
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 3e-4
EPOCHS = 100
BATCH_SIZE = 2

# Inputs / outputs
DATAPATH = PROCESSED_DATA + '/processed_data.csv'
MODEL_NAME = "densenet_configuration_4_model_checkpoint.pth.tar"

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# From dataset.py

class ImageFolder(nn.Module):
    """Map-style dataset yielding a page image together with its two masks.

    Rows of ``df`` are read positionally: column 0 holds the image path,
    column 1 the table-mask path and column 2 the column-mask path.

    Args:
        df: DataFrame indexed with ``iloc`` as described above.
        transform: Callable invoked as ``transform(image=array)`` that returns
            a mapping with an ``"image"`` entry. When ``None``, a default
            normalise-then-``ToTensorV2`` pipeline is used.

    NOTE(review): this subclasses ``nn.Module`` rather than
    ``torch.utils.data.Dataset``. The base class is left untouched so existing
    callers are unaffected; ``DataLoader`` only needs ``__len__`` and
    ``__getitem__``.
    """
    def __init__(self, df, transform = None):
        super(ImageFolder, self).__init__()
        self.df = df
        if transform is None:
            transform = A.Compose([
                A.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225], max_pixel_value = 255,),
                ToTensorV2()
            ])
        # Always store the transform. Previously a caller-supplied transform was
        # never assigned, so __getitem__ failed with AttributeError.
        self.transform = transform

    def __len__(self):
        """Number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            dict with keys ``"image"`` (output of the transform),
            ``"table_image"`` and ``"column_image"`` (float tensors reshaped
            to ``(1, 1024, 1024)`` after dividing the mask values by 255).
        """
        image_path, table_mask_path, column_mask_path = self.df.iloc[index, 0], self.df.iloc[index, 1], self.df.iloc[index, 2]
        image = np.array(Image.open(image_path))
        # The reshape fixes the mask size: masks must contain 1024 * 1024 values.
        table_image = torch.FloatTensor(np.array(Image.open(table_mask_path)) / 255.0).reshape(1, 1024, 1024)
        column_image = torch.FloatTensor(np.array(Image.open(column_mask_path)) / 255.0).reshape(1, 1024, 1024)
        image = self.transform(image = image)['image']
        return {"image": image, "table_image": table_image, "column_image": column_image}

def get_mean_std(train_data, transform):
    """Estimate per-channel mean and std of the images in ``train_data``.

    For every image the mean and standard deviation of each channel are taken
    over all pixels; those per-image values are then averaged over the data
    set (so the std is an average of per-image stds, not a global std).

    Args:
        train_data: DataFrame handed to ``ImageFolder``.
        transform: Transform handed to ``ImageFolder``.

    Returns:
        ``(mean, std)`` — one value per channel. Both are also printed, as
        before.
    """
    dataset = ImageFolder(train_data , transform)
    train_loader = DataLoader(dataset, batch_size = 128)
    mean = 0.
    std = 0.
    # `tqdm` is imported as the callable itself (`from tqdm import tqdm`), so
    # it is called directly; `tqdm.tqdm(...)` raised AttributeError.
    for img_dict in tqdm(train_loader):
        images = img_dict["image"]
        batch_samples = images.size(0)
        # Flatten the spatial dimensions: (batch, channels, pixels).
        images = images.view(batch_samples, images.size(1), -1)
        mean += images.mean(2).sum(0)
        std += images.std(2).sum(0)
    mean /= len(train_loader.dataset)
    std /= len(train_loader.dataset)
    print(mean)
    print(std)
    return mean, std

# Read the referencing csv file and dump the first six table-bearing samples
# (image, table mask, column mask) as PNG files in the working directory.
df = pd.read_csv(PROCESSED_DATA + '/processed_data.csv')
dataset = ImageFolder(df[df['hasTable'] == 1])
img_num = 0
for img_dict in dataset:
    # The file-name prefix is identical to the sample's dictionary key.
    for part in ("image", "table_image", "column_image"):
        save_image(img_dict[part], f'{part}_{img_num}.png')
    img_num += 1
    if img_num == 6:
        break

In [None]:
# From encoder.py

class VGG19(nn.Module):
    """VGG-19 feature extractor split into three consecutive stages.

    Layers 0-18, 19-27 and 28-36 of ``vgg19().features`` become
    ``vgg_pool3``, ``vgg_pool4`` and ``vgg_pool5``.

    Args:
        pretrained: Forwarded to ``torchvision.models.vgg19``.
        requires_grad: When false, every parameter of the encoder is frozen.
    """
    def __init__(self, pretrained = True, requires_grad = True):
        super(VGG19, self).__init__()
        features = torchvision.models.vgg19(pretrained = pretrained).features
        # Slicing an nn.Sequential keeps each layer registered under its
        # existing key, so the sub-module names stay "0".."36".
        self.vgg_pool3 = features[:19]
        self.vgg_pool4 = features[19:28]
        self.vgg_pool5 = features[28:37]
        if not requires_grad:
            for weight in self.parameters():
                weight.requires_grad = False

    def forward(self, x):
        """Return the outputs of the three stages, shallowest first."""
        stage_outputs = []
        for stage in (self.vgg_pool3, self.vgg_pool4, self.vgg_pool5):
            x = stage(x)
            stage_outputs.append(x)
        return tuple(stage_outputs)

class ResNet(nn.Module):
    """ResNet-34 feature extractor exposing three feature maps.

    Args:
        pretrained: Forwarded to ``torchvision.models.resnet34``. It used to
            be ignored (weights were always loaded).
        requires_grad: When false, every parameter of the encoder is frozen.
    """
    def __init__(self, pretrained = True, requires_grad = True):
        super(ResNet, self).__init__()
        # The backbone is a ResNet-34 (the old local name said "resnet18").
        backbone = torchvision.models.resnet34(pretrained = pretrained)
        self.layer_1 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool, backbone.layer1)
        self.layer_2 = backbone.layer2
        self.layer_3 = backbone.layer3
        self.layer_4 = backbone.layer4
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Return the outputs of ``layer_2``, ``layer_3`` and ``layer_4``."""
        out_1 = self.layer_2(self.layer_1(x))
        out_2 = self.layer_3(out_1)
        out_3 = self.layer_4(out_2)
        return out_1, out_2, out_3

class DenseNet(nn.Module):
    """DenseNet-121 feature extractor split into three consecutive stages.

    Entries 0-7, 8-9 and 10 of ``densenet121().features`` become
    ``densenet_out_1``, ``densenet_out_2`` and ``densenet_out_3``.

    Args:
        pretrained: Forwarded to ``torchvision.models.densenet121``. It used
            to be ignored (weights were always loaded).
        requires_grad: When false, every parameter of the encoder is frozen.
    """
    def __init__(self, pretrained = True, requires_grad = True):
        super(DenseNet, self).__init__()
        denseNet = torchvision.models.densenet121(pretrained = pretrained).features
        self.densenet_out_1 = torch.nn.Sequential()
        self.densenet_out_2 = torch.nn.Sequential()
        self.densenet_out_3 = torch.nn.Sequential()
        # Layers are registered under their positional index (as strings);
        # these names are part of the state-dict keys, so they must not change.
        for x in range(8):
            self.densenet_out_1.add_module(str(x), denseNet[x])
        for x in range(8,10):
            self.densenet_out_2.add_module(str(x), denseNet[x])
        self.densenet_out_3.add_module(str(10), denseNet[10])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Return the outputs of the three stages, shallowest first."""
        out_1 = self.densenet_out_1(x)
        out_2 = self.densenet_out_2(out_1)
        out_3 = self.densenet_out_3(out_2)
        return out_1, out_2, out_3

class efficientNet_B0(nn.Module):
    """EfficientNet-B0 feature extractor split into three consecutive stages.

    ``eNet_out_1`` holds the stem convolution, its batch norm and blocks
    0-13; ``eNet_out_2`` holds block 14 and ``eNet_out_3`` block 15.

    Args:
        pretrained: Load pretrained weights (``EfficientNet.from_pretrained``)
            or build an untrained network (``EfficientNet.from_name``). The
            flag used to be ignored.
        requires_grad: When false, every parameter of the encoder is frozen.
            The flag used to be ignored.

    NOTE(review): no activation is registered between ``_bn0`` and the first
    block here — confirm against the reference EfficientNet forward pass.
    """
    def __init__(self, pretrained = True, requires_grad = True):
        super(efficientNet_B0, self).__init__()
        if pretrained:
            eNet = EfficientNet.from_pretrained('efficientnet-b0')
        else:
            eNet = EfficientNet.from_name('efficientnet-b0')
        self.eNet_out_1 = torch.nn.Sequential()
        self.eNet_out_2 = torch.nn.Sequential()
        self.eNet_out_3 = torch.nn.Sequential()
        blocks = eNet._blocks
        self.eNet_out_1.add_module('_conv_stem', eNet._conv_stem)
        self.eNet_out_1.add_module('_bn0', eNet._bn0)
        for x in range(14):
            self.eNet_out_1.add_module(str(x), blocks[x])
        self.eNet_out_2.add_module(str(14), blocks[14])
        self.eNet_out_3.add_module(str(15), blocks[15])
        # Honour requires_grad the same way the other encoders in this file do.
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Return the outputs of the three stages, shallowest first."""
        out_1 = self.eNet_out_1(x)
        out_2 = self.eNet_out_2(out_1)
        out_3 = self.eNet_out_3(out_2)
        return out_1, out_2, out_3

class efficientNet(nn.Module):
    """EfficientNet feature extractor split into three consecutive stages.

    With ``n`` blocks in the backbone, ``eNet_out_1`` holds the stem
    convolution, its batch norm and blocks ``0 .. n-3``; ``eNet_out_2`` holds
    block ``n-2`` and ``eNet_out_3`` block ``n-1``.

    Args:
        model_type: Model name understood by ``efficientnet_pytorch``,
            e.g. ``'efficientnet-b0'``.
        pretrained: Load pretrained weights (``EfficientNet.from_pretrained``)
            or build an untrained network (``EfficientNet.from_name``). The
            flag used to be ignored.
        requires_grad: When false, every parameter of the encoder is frozen.
            The flag used to be ignored.

    NOTE: the first stage used to stop at block ``n-4``, silently dropping
    block ``n-3`` from the network (unlike ``efficientNet_B0``, which keeps
    blocks 0-13 of 16). Checkpoints saved with the old layout lack that
    block's weights and need ``strict=False`` or retraining.

    NOTE(review): no activation is registered between ``_bn0`` and the first
    block here — confirm against the reference EfficientNet forward pass.
    """
    def __init__(self, model_type = 'efficientnet-b0',  pretrained = True, requires_grad = True):
        super(efficientNet, self).__init__()
        if pretrained:
            eNet = EfficientNet.from_pretrained(model_type)
        else:
            eNet = EfficientNet.from_name(model_type)
        self.eNet_out_1 = torch.nn.Sequential()
        self.eNet_out_2 = torch.nn.Sequential()
        self.eNet_out_3 = torch.nn.Sequential()
        blocks = eNet._blocks
        n_blocks = len(blocks)
        self.eNet_out_1.add_module('_conv_stem', eNet._conv_stem)
        self.eNet_out_1.add_module('_bn0', eNet._bn0)
        # Everything except the last two blocks belongs to the first stage.
        for x in range(n_blocks - 2):
            self.eNet_out_1.add_module(str(x), blocks[x])
        self.eNet_out_2.add_module(str(n_blocks - 2), blocks[n_blocks - 2])
        self.eNet_out_3.add_module(str(n_blocks - 1), blocks[n_blocks - 1])
        # Honour requires_grad the same way the other encoders in this file do.
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Return the outputs of the three stages, shallowest first."""
        out_1 = self.eNet_out_1(x)
        out_2 = self.eNet_out_2(out_1)
        out_3 = self.eNet_out_3(out_2)
        return out_1, out_2, out_3

# model = DenseNet()
# x = torch.randn(1, 3, 1024, 1024)
# model(x)

In [None]:
# From tablenet_model.py


class TableDecoder(nn.Module):
    """Decoder branch that produces the single-channel table-mask logits.

    Args:
        channels: ``channels[0]`` is the channel count of ``pool4_out`` and
            ``channels[1]`` that of ``pool3_out`` (the two skip inputs).
        kernels: Four kernel sizes: the 1st for the convolution, the 2nd-4th
            for the three transposed convolutions.
        strides: Four strides, used like ``kernels``.
    """
    def __init__(self, channels, kernels, strides):
        super(TableDecoder, self).__init__()
        # Layers are created in the same order as before so that seeded
        # weight initialisation is unchanged.
        self.conv_7_table = nn.Conv2d(256, 256, kernel_size = kernels[0], stride = strides[0])
        self.upsample_1_table = nn.ConvTranspose2d(256, 128, kernel_size = kernels[1], stride = strides[1])
        self.upsample_2_table = nn.ConvTranspose2d(channels[0] + 128, 256, kernel_size = kernels[2], stride = strides[2])
        self.upsample_3_table = nn.ConvTranspose2d(channels[1] + 256, 1, kernel_size = kernels[3], stride = strides[3])

    def forward(self, x, pool3_out, pool4_out):
        """Upsample ``x``, concatenating ``pool4_out`` and then ``pool3_out`` on the channel axis."""
        merged = self.upsample_1_table(self.conv_7_table(x))
        merged = self.upsample_2_table(torch.cat((merged, pool4_out), dim=1))
        return self.upsample_3_table(torch.cat((merged, pool3_out), dim=1))

class ColumnDecoder(nn.Module):
    """Decoder branch that produces the single-channel column-mask logits.

    Args:
        channels: ``channels[0]`` is the channel count of ``pool4_out`` and
            ``channels[1]`` that of ``pool3_out`` (the two skip inputs).
        kernels: Four kernel sizes: the 1st for both convolutions, the
            2nd-4th for the three transposed convolutions.
        strides: Four strides, used like ``kernels``.
    """
    def __init__(self, channels, kernels, strides):
        super(ColumnDecoder, self).__init__()
        # Layers are created in the same order as before so that seeded
        # weight initialisation is unchanged.
        conv_kernel, conv_stride = kernels[0], strides[0]
        self.conv_8_column = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size = conv_kernel, stride = conv_stride),
            nn.ReLU(inplace=True),
            nn.Dropout(0.8),
            nn.Conv2d(256, 256, kernel_size = conv_kernel, stride = conv_stride),
        )
        self.upsample_1_column = nn.ConvTranspose2d(256, 128, kernel_size = kernels[1], stride = strides[1])
        self.upsample_2_column = nn.ConvTranspose2d(channels[0] + 128, 256, kernel_size = kernels[2], stride = strides[2])
        self.upsample_3_column = nn.ConvTranspose2d(channels[1] + 256, 1, kernel_size = kernels[3], stride = strides[3])

    def forward(self, x, pool3_out, pool4_out):
        """Upsample ``x``, concatenating ``pool4_out`` and then ``pool3_out`` on the channel axis."""
        merged = self.upsample_1_column(self.conv_8_column(x))
        merged = self.upsample_2_column(torch.cat((merged, pool4_out), dim=1))
        return self.upsample_3_column(torch.cat((merged, pool3_out), dim=1))

class TableNet(nn.Module):
    """Encoder plus two decoders predicting table-mask and column-mask logits.

    Args:
        encoder: ``'vgg'``, ``'resnet'``, ``'densenet'`` or a name containing
            ``'efficientnet'`` together with ``'b0'``, ``'b1'`` or ``'b2'``.
        use_pretrained_model: Forwarded to the encoder as ``pretrained``.
        basemodel_requires_grad: Forwarded to the encoder as ``requires_grad``.

    Raises:
        ValueError: If ``encoder`` is not one of the supported names. This
            used to surface later as an AttributeError on a missing attribute.
    """
    def __init__(self,encoder = 'vgg', use_pretrained_model = True, basemodel_requires_grad = True):
        super(TableNet, self).__init__()
        # Defaults shared by the VGG and ResNet encoders.
        self.kernels = [(1,1), (2,2), (2,2),(8,8)]
        self.strides = [(1,1), (2,2), (2,2),(8,8)]
        self.in_channels = 512
        if encoder == 'vgg':
            self.base_model = VGG19(pretrained = use_pretrained_model, requires_grad = basemodel_requires_grad)
            self.pool_channels = [512, 256]
        elif encoder == 'resnet':
            self.base_model = ResNet(pretrained = use_pretrained_model, requires_grad = basemodel_requires_grad)
            self.pool_channels = [256, 128]
        elif encoder == 'densenet':
            self.base_model = DenseNet(pretrained = use_pretrained_model, requires_grad = basemodel_requires_grad)
            self.pool_channels = [512, 256]
            self.in_channels = 1024
            self.kernels = [(1,1), (1,1), (2,2),(16,16)]
            self.strides = [(1,1), (1,1), (2,2),(16,16)]
        elif 'efficientnet' in encoder:
            # variant tag -> (skip-connection channels, encoder output channels)
            efficientnet_channels = {
                'b0': ([192, 192], 320),
                'b1': ([320, 192], 320),
                'b2': ([352, 208], 352),
            }
            variant = next((tag for tag in efficientnet_channels if tag in encoder), None)
            if variant is None:
                # Fail before building (and possibly downloading) the backbone.
                raise ValueError(
                    f"Unsupported EfficientNet variant {encoder!r}; "
                    "only b0, b1 and b2 are configured"
                )
            self.base_model = efficientNet(model_type = encoder, pretrained = use_pretrained_model, requires_grad = basemodel_requires_grad)
            self.pool_channels, self.in_channels = efficientnet_channels[variant]
            self.kernels = [(1,1), (1,1), (1,1),(32,32)]
            self.strides = [(1,1), (1,1), (1,1),(32,32)]
        else:
            raise ValueError(
                f"Unsupported encoder {encoder!r}; expected 'vgg', 'resnet', "
                "'densenet' or 'efficientnet-b0/b1/b2'"
            )
        self.conv6 = nn.Sequential(
            nn.Conv2d(in_channels = self.in_channels, out_channels = 256, kernel_size=(1,1)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.8),
            nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size=(1,1)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.8)
        )
        self.table_decoder = TableDecoder(self.pool_channels, self.kernels, self.strides)
        self.column_decoder = ColumnDecoder(self.pool_channels, self.kernels, self.strides)

    def forward(self, x):
        """Return ``(table_out, column_out)`` logits for the input batch ``x``."""
        pool3_out, pool4_out, pool5_out = self.base_model(x)
        conv_out = self.conv6(pool5_out)
        table_out = self.table_decoder(conv_out, pool3_out, pool4_out)
        column_out = self.column_decoder(conv_out, pool3_out, pool4_out)
        return table_out, column_out

In [None]:
# From model_loss.py

class TableNetLoss(nn.Module):
    """Binary cross-entropy (on logits) for the table and column masks."""
    def __init__(self):
        super(TableNetLoss, self).__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, table_prediction, table_target, column_prediction = None, column_target = None,):
        """Compute both losses.

        Args:
            table_prediction: Table-mask logits.
            table_target: Table-mask ground truth.
            column_prediction: Column-mask logits, or ``None``.
            column_target: Column-mask ground truth, or ``None``.

        Returns:
            ``(table_loss, column_loss)``. When both column arguments are
            omitted, ``column_loss`` is a zero scalar so the two values can
            still be summed; previously the ``None`` defaults were passed to
            the loss and crashed.

        Raises:
            ValueError: If exactly one of the two column arguments is given.
        """
        table_loss = self.bce(table_prediction, table_target)
        if column_prediction is None and column_target is None:
            return table_loss, table_loss.new_zeros(())
        if column_prediction is None or column_target is None:
            raise ValueError("column_prediction and column_target must be given together")
        column_loss = self.bce(column_prediction, column_target)
        return table_loss, column_loss

In [None]:
# From general_utilities.py

# Default preprocessing used by get_TableMasks: per-channel normalisation
# followed by conversion to a torch tensor. The mean/std values are the same
# as in ImageFolder's default transform.
TRANSFORM = A.Compose([
    A.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225], max_pixel_value = 255,),
    ToTensorV2()
])
# Apply the SEED
def seed_all(SEED_VALUE = SEED):
    """Seed Python, NumPy and torch and request deterministic cuDNN kernels.

    Args:
        SEED_VALUE: Integer seed applied to every random number generator.
    """
    random.seed(SEED_VALUE)
    # Only affects interpreters started after this point (sub-processes);
    # the hash seed of the running interpreter is already fixed.
    os.environ['PYTHONHASHSEED'] = str(SEED_VALUE)
    np.random.seed(SEED_VALUE)
    torch.manual_seed(SEED_VALUE)
    torch.cuda.manual_seed(SEED_VALUE)
    # Also covers every other visible GPU.
    torch.cuda.manual_seed_all(SEED_VALUE)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may select different algorithms from run to run,
    # which defeats the determinism requested above, so it is switched off
    # (it used to be enabled).
    torch.backends.cudnn.benchmark = False

def get_data_loaders(data_path = DATAPATH):
    """Build the train and test loaders from the csv at ``data_path``.

    The rows are split 80/20, stratified on the ``hasTable`` column and
    seeded with ``SEED``.

    Returns:
        ``(train_loader, test_loader)``: the train loader shuffles and uses
        ``BATCH_SIZE``; the test loader keeps order and uses a batch size of 8.
    """
    frame = pd.read_csv(data_path)
    train_rows, test_rows = train_test_split(
        frame,
        test_size = 0.2,
        random_state = SEED,
        stratify = frame.hasTable,
    )
    train_dataset = ImageFolder(train_rows, transform = None)
    test_dataset = ImageFolder(test_rows, transform = None)
    # Settings common to both loaders.
    worker_options = dict(num_workers = 4, pin_memory = True)
    train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True, **worker_options)
    test_loader = DataLoader(test_dataset, batch_size = 8, shuffle = False, **worker_options)
    return train_loader, test_loader

# Save Checkpoint
def save_checkpoint(state, filename = f"{PROCESSED_DATA}/model_checkpoint.pth.tar"):
    """Serialise ``state`` to ``filename`` with ``torch.save`` and report the path."""
    torch.save(obj = state, f = filename)
    print("Checkpoint Saved at: ", filename)

# Load the checkpoint we saved
def load_checkpoint(checkpoint, model, optimizer = None):
    """Restore ``model`` (and optionally ``optimizer``) from a checkpoint dict.

    Args:
        checkpoint: Mapping with the keys ``'state_dict'``, ``'epoch'``,
            ``'train_metrics'`` and ``'test_metrics'``; ``'optimizer'`` is
            read only when ``optimizer`` is given.
        model: Module whose state is overwritten.
        optimizer: Optional optimizer whose state is overwritten.

    Returns:
        ``(epoch, train_metrics, test_metrics)`` as stored in the checkpoint.
    """
    print("Loading checkpoint...")
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    return tuple(checkpoint[field] for field in ('epoch', 'train_metrics', 'test_metrics'))

def write_summary(writer, tr_metrics, te_metrics, epoch):
    """Log every train/test metric of one epoch to ``writer``.

    For each metric a ``"<Tag>/Train"`` and a ``"<Tag>/Test"`` scalar is
    written with ``global_step = epoch``.
    """
    # (tag prefix, key in the metrics dictionaries), in logging order.
    tagged_keys = (
        ("Table Loss", 'table_loss'),
        ("Table Acc", 'table_acc'),
        ("Table F1", 'table_f1'),
        ("Table Precision", 'table_precision'),
        ("Table Recall", 'table_recall'),
        ("Column Loss", 'column_loss'),
        ("Column Acc", 'col_acc'),
        ("Column F1", 'col_f1'),
        ("Column Precision", 'col_precision'),
        ("Column Recall", 'col_recall'),
    )
    for tag, key in tagged_keys:
        writer.add_scalar(f"{tag}/Train", tr_metrics[key], global_step = epoch)
        writer.add_scalar(f"{tag}/Test", te_metrics[key], global_step = epoch)

def display_metrics(epoch, tr_metrics, te_metrics):
    """Print the train/test metrics of one epoch, three decimals each.

    The table metrics come first, then a spacer line, then the column metrics.
    """
    indent = " " * 8

    def metric_row(label, key):
        return f"{indent}{label} -- Train: {tr_metrics[key]:.3f} Test: {te_metrics[key]:.3f}\n"

    table_rows = (
        ("Table Loss", 'table_loss'),
        ("Table Acc", 'table_acc'),
        ("Table F1", 'table_f1'),
        ("Table Precision", 'table_precision'),
        ("Table Recall", 'table_recall'),
    )
    column_rows = (
        ("Col Loss", 'column_loss'),
        ("Col Acc", 'col_acc'),
        ("Col F1", 'col_f1'),
        ("Col Precision", 'col_precision'),
        ("Col Recall", 'col_recall'),
    )
    report = f"Epoch: {epoch} \n"
    report += "".join(metric_row(label, key) for label, key in table_rows)
    report += f"{indent}\n"
    report += "".join(metric_row(label, key) for label, key in column_rows)
    print(report)

def compute_metrics(ground_truth, prediction, threshold = 0.5):
    """Pixel-wise accuracy, F1, precision and recall for a binary mask.

    Args:
        ground_truth: Tensor of 0/1 mask values (cast to int).
        prediction: Tensor of logits with the same shape; a pixel counts as
            positive when ``sigmoid(logit) > threshold``.
        threshold: Probability cut-off for a positive pixel.

    Returns:
        dict with the float entries ``'acc'``, ``'f1'``, ``'precision'`` and
        ``'recall'``.
    """
    # Ref: https://stackoverflow.com/a/56649983
    ground_truth = ground_truth.int()
    prediction = (torch.sigmoid(prediction) > threshold).int()
    actual_positive = ground_truth == 1
    actual_negative = ground_truth == 0
    TP = torch.sum(prediction[actual_positive] == 1)
    TN = torch.sum(prediction[actual_negative] == 0)
    # A false positive is a predicted 1 on a true 0; a false negative is a
    # predicted 0 on a true 1. The two were swapped before, which exchanged
    # the reported precision and recall.
    FP = torch.sum(prediction[actual_negative] == 1)
    FN = torch.sum(prediction[actual_positive] == 0)
    acc = (TP + TN) / (TP + TN + FP+ FN)
    # The small constants avoid division by zero when a class is absent.
    precision = TP / (FP + TP + 1e-4)
    recall = TP / (FN + TP + 1e-4)
    f1 = 2 * precision * recall / (precision + recall + 1e-4)
    metrics = {
        'acc': acc.item(),
        'f1': f1.item(),
        'precision':precision.item(),
        'recall': recall.item()
    }
    return metrics

def display(image, table, column, title = 'Original'):
    """Show an image, its table mask and its column mask side by side.

    ``title`` is used as the prefix of each of the three panel titles.
    """
    _, axes = plt.subplots(1, 3, figsize = (15, 8))
    panels = (
        (image, f'{title} Image'),
        (table, f'{title} Table Mask'),
        (column, f'{title} Column Mask'),
    )
    for axis, (content, caption) in zip(axes, panels):
        axis.imshow(content)
        axis.set_title(caption)
    plt.show()

def display_prediction(image, table = None, table_image = None, no_: bool = False):
    """Plot a prediction result.

    With ``no_`` set, only ``image`` is shown under a 'No Tables Detected'
    heading; otherwise ``image``, ``table`` and ``table_image`` are shown in
    three panels.
    """
    if not no_:
        _, axes = plt.subplots(1, 3, figsize = (15, 8))
        panels = (
            (image, 'Original Image'),
            (table, 'Image with Predicted Table'),
            (table_image, 'Predicted Table Example'),
        )
        for axis, (content, caption) in zip(axes, panels):
            axis.imshow(content)
            axis.set_title(caption)
    else:
        figure, axis = plt.subplots(1, 1, figsize = (7, 5))
        axis.imshow(image)
        axis.set_title('Original Image')
        figure.suptitle('No Tables Detected')
    plt.show()

def get_TableMasks(test_image, model, transform = TRANSFORM, device = DEVICE):
    """Predict binary table and column masks for a single image.

    Args:
        test_image: Image passed to ``transform(image=...)``.
        model: Network returning ``(table_logits, column_logits)``; it is
            switched to eval mode and left there.
        transform: Preprocessing pipeline returning a dict with ``"image"``.
        device: Device the input tensor is moved to.

    Returns:
        ``(table_mask, column_mask)``: integer numpy arrays, channel-last,
        holding 1 where the sigmoid probability exceeds 0.5 and 0 elsewhere.
    """
    def binarize(probabilities):
        # Drop the batch axis, go channel-last and threshold at 0.5.
        array = probabilities.cpu().detach().numpy().squeeze(0).transpose(1, 2, 0)
        return (array > 0.5).astype(int)

    tensor = transform(image = test_image)["image"]
    model.eval()
    with torch.no_grad():
        batch = tensor.to(device).unsqueeze(0)
        table_logits, column_logits = model(batch)
        table_prob = torch.sigmoid(table_logits)
        column_prob = torch.sigmoid(column_logits)
    table_mask = binarize(table_prob)
    column_mask = binarize(column_prob)
    return table_mask, column_mask

def fixMasks(image, table_mask, column_mask):
    """Turn predicted masks into table and column bounding boxes.

    Table contours with an area of at most 2000 are discarded. Each remaining
    contour is replaced by its bounding rectangle, and column rectangles are
    then searched for inside every table rectangle.

    Args:
        image: Array whose ``[..., 0]`` slice holds 1024 * 1024 values; the
            table rectangles are drawn onto that slice.
        table_mask: Table mask holding 1024 * 1024 values.
        column_mask: Column mask holding 1024 * 1024 values.

    Returns:
        ``None`` when no table contour is large enough, otherwise
        ``(image, table_bound_rect, column_bound_rects)`` where
        ``table_bound_rect`` is a sorted list of ``(x, y, w, h)`` tuples and
        ``column_bound_rects`` holds, per table, a list of ``(x, y, w, h)``
        tuples in full-image coordinates.
    """
    table_mask = table_mask.reshape(1024, 1024).astype(np.uint8)
    column_mask = column_mask.reshape(1024, 1024).astype(np.uint8)
    # Outer contours of the table mask: one candidate per connected region
    contours, table_heirarchy = cv2.findContours(table_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    table_contours = []
    # Ref: https://www.pyimagesearch.com/2015/02/09/removing-contours-image-using-python-opencv/
    # Remove bad contours
    for c in contours:
        # Keep only contours whose area exceeds 2000; smaller ones are dropped
        if cv2.contourArea(c) > 2000:
            table_contours.append(c)
    if len(table_contours) == 0:
        return None
    # Ref : https://docs.opencv.org/4.5.2/da/d0c/tutorial_bounding_rects_circles.html
    # Get bounding box for the contour
    table_bound_rect = [None] * len(table_contours)
    for i, c in enumerate(table_contours):
        polygon = cv2.approxPolyDP(c, 3, True)
        table_bound_rect[i] = cv2.boundingRect(polygon)
    # Sort the table boxes (tuple order, i.e. by x first)
    table_bound_rect.sort()
    column_bound_rects = []
    for x, y, w, h in table_bound_rect:
        # Restrict the column mask to the current table box
        column_mask_crop = column_mask[y : y + h, x : x + w]
        # Contours of the cropped column mask: the column candidates of this table
        contours, column_heirarchy = cv2.findContours(column_mask_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # Get bounding box for the contour
        bound_rect = [None] * len(contours)
        for i, c in enumerate(contours):
            polygon = cv2.approxPolyDP(c, 3, True)
            bound_rect[i] = cv2.boundingRect(polygon)
            # Shift from crop coordinates back to full-image coordinates
            bound_rect[i] = (bound_rect[i][0] + x, bound_rect[i][1] + y, bound_rect[i][2], bound_rect[i][3])
        column_bound_rects.append(bound_rect)
    # Only index 0 of the last axis is kept, so the drawing target is 2-D
    image = image[...,0].reshape(1024, 1024).astype(np.uint8)
    # Draw bounding boxes
    # NOTE(review): `color` is a 3-tuple but the image is single-channel here —
    # confirm the rectangles are drawn with the intended intensity.
    color = (0, 255, 0)
    thickness = 4
    for x, y, w, h in table_bound_rect:
        image = cv2.rectangle(image, (x, y),(x + w, y + h), color, thickness)
    return image, table_bound_rect, column_bound_rects

In [None]:
# Suppress every Python warning from here on; this also hides deprecation
# notices, so disable it while debugging.
warnings.filterwarnings("ignore")


def train_on_epoch(data_loader, model, optimizer, loss, scaler, threshold = 0.5):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        data_loader: Yields dicts with ``"image"``, ``"table_image"`` and
            ``"column_image"`` tensors.
        model: Network returning ``(table_out, column_out)`` logits.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        loss: Callable returning ``(table_loss, column_loss)``.
        scaler: Gradient scaler used for mixed-precision training.
        threshold: Probability cut-off passed to ``compute_metrics``.

    Returns:
        dict with the epoch means of the combined/table/column loss and of
        accuracy, F1, precision and recall for both masks. For an empty
        loader the means are NaN (previously this raised because the result
        was only built inside the batch loop).
    """
    combined_loss = []
    table_loss, table_acc, table_precision, table_recall, table_f1 = [], [], [], [], []
    column_loss, column_acc, column_precision, column_recall, column_f1 = [], [], [], [], []
    # Make sure dropout/batch-norm run in training mode regardless of what
    # happened to the model before this call.
    model.train()
    loop = tqdm(data_loader, leave = True)
    for image_dict in loop:
        image            = image_dict["image"].to(DEVICE)
        table_image      = image_dict["table_image"].to(DEVICE)
        column_image     = image_dict["column_image"].to(DEVICE)
        with torch.cuda.amp.autocast():
            table_out, column_out = model(image)
            i_table_loss, i_column_loss = loss(table_out, table_image, column_out, column_image)
        total_loss = i_table_loss + i_column_loss
        table_loss.append(i_table_loss.item())
        column_loss.append(i_column_loss.item())
        combined_loss.append(total_loss.item())
        # Backward
        optimizer.zero_grad()
        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        mean_loss = sum(combined_loss) / len(combined_loss)
        loop.set_postfix(loss = mean_loss)
        # Metrics need no gradient, so the outputs are detached first.
        cal_metrics_table = compute_metrics(table_image, table_out.detach(), threshold)
        cal_metrics_col = compute_metrics(column_image, column_out.detach(), threshold)
        table_f1.append(cal_metrics_table['f1'])
        table_precision.append(cal_metrics_table['precision'])
        table_acc.append(cal_metrics_table['acc'])
        table_recall.append(cal_metrics_table['recall'])
        column_f1.append(cal_metrics_col['f1'])
        column_acc.append(cal_metrics_col['acc'])
        column_precision.append(cal_metrics_col['precision'])
        column_recall.append(cal_metrics_col['recall'])
    # Aggregate once, after the last batch.
    metrics = {
        'combined_loss': np.mean(combined_loss),
        'table_loss': np.mean(table_loss),
        'column_loss': np.mean(column_loss),
        'table_acc': np.mean(table_acc),
        'col_acc': np.mean(column_acc),
        'table_f1': np.mean(table_f1),
        'col_f1': np.mean(column_f1),
        'table_precision': np.mean(table_precision),
        'col_precision': np.mean(column_precision),
        'table_recall': np.mean(table_recall),
        'col_recall': np.mean(column_recall)
    }
    return metrics

def test_on_epoch(data_loader, model, loss, threshold = 0.5, device = DEVICE):
    """Evaluate ``model`` on ``data_loader`` without updating any weights.

    The model is put in eval mode for the pass and switched back to train
    mode before returning.

    Returns:
        dict with the means over all batches of the combined/table/column
        loss and of accuracy, F1, precision and recall for both masks.
    """
    # One list per reported metric; the key order is the order of the result.
    history = {
        name: []
        for name in (
            'combined_loss', 'table_loss', 'column_loss',
            'table_acc', 'col_acc',
            'table_f1', 'col_f1',
            'table_precision', 'col_precision',
            'table_recall', 'col_recall',
        )
    }
    model.eval()
    with torch.no_grad():
        progress = tqdm(data_loader, leave = True)
        for batch in progress:
            image = batch["image"].to(device)
            table_target = batch["table_image"].to(device)
            column_target = batch["column_image"].to(device)
            with torch.cuda.amp.autocast():
                table_out, column_out = model(image)
                batch_table_loss, batch_column_loss = loss(table_out, table_target, column_out, column_target)
            history['table_loss'].append(batch_table_loss.item())
            history['column_loss'].append(batch_column_loss.item())
            history['combined_loss'].append((batch_table_loss + batch_column_loss).item())
            # Running mean of the combined loss shown on the progress bar.
            progress.set_postfix(loss=sum(history['combined_loss']) / len(history['combined_loss']))
            table_scores = compute_metrics(table_target, table_out, threshold)
            column_scores = compute_metrics(column_target, column_out, threshold)
            for stat in ('f1', 'precision', 'acc', 'recall'):
                history['table_' + stat].append(table_scores[stat])
            for stat in ('f1', 'acc', 'precision', 'recall'):
                history['col_' + stat].append(column_scores[stat])
    metrics = {name: np.mean(values) for name, values in history.items()}
    model.train()
    return metrics

seed_all(SEED_VALUE = SEED)
checkpoint_name = f'{PROCESSED_DATA}/{MODEL_NAME}'
model = TableNet(encoder = 'densenet', use_pretrained_model = True, basemodel_requires_grad = True)

print("Model Architecture and Trainable Parameters")
print("="*50)
print(summary(model, torch.zeros((1, 3, 1024, 1024)), show_input = False, show_hierarchical = True))

model = model.to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE, weight_decay = WEIGHT_DECAY)
loss = TableNetLoss()
scaler = torch.cuda.amp.GradScaler()
train_loader, test_loader = get_data_loaders(data_path = DATAPATH)

# Load checkpoint
if os.path.exists(checkpoint_name):
    # map_location lets a checkpoint written on a GPU machine load on CPU.
    saved_state = torch.load(checkpoint_name, map_location = DEVICE)
    # Pass the optimizer as well: its state is stored in the checkpoint and
    # was previously discarded on resume (Adam restarted from scratch).
    last_epoch, train_metrics, test_metrics = load_checkpoint(saved_state, model, optimizer)
    last_table_f1 = test_metrics['table_f1']
    last_column_f1 = test_metrics['col_f1']
    print("Loading Checkpoint...")
    display_metrics(last_epoch, train_metrics, test_metrics)
    print()
else:
    last_epoch = 0
    last_table_f1 = 0.
    last_column_f1 = 0.

# Train Network
print("Training Model\n")
writer = SummaryWriter(f"{PROCESSED_DATA}/runs/TableNet/densenet/configuration_4_batch_{BATCH_SIZE}_learningrate_{LEARNING_RATE}_encoder_train")
# For early stopping (NOTE: counter is currently unused)
i = 0

# Epochs are numbered from 1, so the upper bound is EPOCHS + 1; the previous
# bound of EPOCHS stopped one epoch short.
for epoch in range(last_epoch + 1, EPOCHS + 1):
    print("="*30)
    start = time.time()
    train_metrics = train_on_epoch(train_loader, model, optimizer, loss, scaler, threshold = 0.5)
    test_metrics = test_on_epoch(test_loader, model, loss, threshold = 0.5)
    write_summary(writer, train_metrics, test_metrics, epoch)
    end = time.time()
    display_metrics(epoch, train_metrics, test_metrics)
    print(f"Epoch time: {end - start:.1f}s")
    # A checkpoint is written when EITHER test F1 improves; both reference
    # values are then replaced, so the other one may go down.
    if last_table_f1 < test_metrics['table_f1'] or last_column_f1 < test_metrics['col_f1']:
        last_table_f1 = test_metrics['table_f1']
        last_column_f1 = test_metrics['col_f1']
        checkpoint = {
            'epoch': epoch, 
            'state_dict': model.state_dict(), 
            'optimizer': optimizer.state_dict(), 
            'train_metrics': train_metrics, 
            'test_metrics': test_metrics
        }
        save_checkpoint(checkpoint, checkpoint_name)

# Flush pending TensorBoard events and release the log file.
writer.close()