In [1]:
%load_ext autoreload
%autoreload 2
import argparse
from typing import Dict

import torch
import os
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from model.metrics import Metrics
from model.rvgan import RVGAN
from model.dice_loss import DiceLoss, DiceBCELoss
from datasets.dataset import RetinaSegmentationDataset
from utils.resultPrinter import ResultPrinter
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def eval_model(model, dataloader, criterion, device):
    metrics_tracker = Metrics(device)
    model.eval()
    eval_running_loss = 0.0
    with torch.no_grad():
        for ind, (img, lbl) in enumerate(tqdm(dataloader, desc="Validation")):
            # Copy to device
            img = img.to(device)
            lbl = lbl.to(device)
            # Make the prediction
            lbl_pred = model(img)
            # Compute loss
            loss = criterion(lbl_pred, lbl)
            # compute metrics
            metrics_tracker.calculate(lbl_pred, lbl)
            # Running tally
            eval_running_loss += loss.item() * img.shape[0]

    # Compute the loss for this epoch
    eval_loss = eval_running_loss / (ind + 1)
    # Compute the metrics for this epoch
    metrics = metrics_tracker.get_mean_metrics(ind + 1)
    metrics['loss'] = eval_loss
    return metrics

In [3]:
rootdir: str = "C:/Users/shawn/Desktop/Development/CS7643/data/DATA_4D_Patches/DATA_4D_Patches"
workers: int = 8
load_encoder_weights: str = None
load_bt_checkpoint: str = None
anneal_tmax: int = 10
anneal_eta: int = 0
run_name: str = "test-drive"
checkpoint_dir: str = "C:/Users/shawn/Desktop/Development/CS7643/checkpoint/"

args = {
    "learning_rate": 0.01,
    "unet_layers": "64-128-256",
    "epochs": 10,
    "batch_size": 64,
    "scheduler": "CosineAnnealing",
    "loss_function": "DiceLoss",
    "dropout": 0.2
}

# Get the device
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"

# Initialize the model on the GPU
model = RVGAN(num_channels_in=4).to(device)

In [4]:
prev_weights = torch.load("..//checkpoint//rvgan--batch_size=16--epochs=15epoch-8.pth")
prev_weights

{'epoch': 9,
 'model': OrderedDict([('generators.fine_generator.conv_1.weight',
               tensor([[[[-5.5948e-02,  5.2397e-02,  1.9697e-02,  ..., -2.5557e-02,
                          -2.8537e-02,  3.6268e-02],
                         [ 3.1104e-02,  7.0880e-02,  3.0611e-02,  ..., -5.1871e-02,
                           2.9963e-02,  1.5707e-04],
                         [ 4.2496e-02,  4.4020e-02, -9.7736e-05,  ..., -5.2078e-02,
                           4.9402e-02, -5.3281e-02],
                         ...,
                         [-5.4072e-02, -2.7449e-02, -3.2675e-02,  ...,  5.8486e-02,
                           4.0180e-02, -1.3633e-02],
                         [-6.4187e-02,  7.1670e-02, -3.7059e-02,  ...,  4.9444e-02,
                          -2.5259e-02,  1.7161e-02],
                         [-2.7537e-03,  5.0264e-02, -3.6045e-02,  ...,  7.0404e-02,
                          -6.1138e-03,  3.5448e-02]],
               
                        [[ 6.2885e-02, -1.0912e-02,

In [5]:
model.load_state_dict(prev_weights["model"])

<All keys matched successfully>

In [6]:
from datasets.dataset import RetinaSegmentationDataset
import torchvision.transforms as transforms
from img_transform.transforms import EyeMaskCustomTransform, EyeDatasetCustomTransform


DRIVE_TRANSFORMS = transforms.Compose([
    transforms.ToTensor(),
    torch.nn.ConstantPad2d((0, 75, 0, 56), 0),
    EyeDatasetCustomTransform(mask_threshold=0.25),
])


# Load the validation datasets
test_path = os.path.join(rootdir, "Testing")
test_file_basenames = os.listdir(os.path.join(test_path, "images"))
test_dataset = RetinaSegmentationDataset(test_path, test_file_basenames, has_labels=False, img_transforms=DRIVE_TRANSFORMS)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, num_workers=1,
    pin_memory=True, shuffle=False)

In [7]:
import numpy as np
def predict_model(model, dataloader, device):
    model.eval()
    with torch.no_grad():
        for ind, (img, lbl) in enumerate(tqdm(dataloader, desc="Testing")):
            # Copy to device
            img = img.to(device)
            # Make the prediction
            coarse_generator_out, fine_generator_out = model.generators(img)
            lbl_pred = fine_generator_out
            im = np.round(lbl_pred.squeeze(0).squeeze(0).cpu().detach().numpy()[:584, :565])
            im = Image.fromarray(im.astype(np.uint8) * 255)
            im.save(f"C:/Users/shawn/Desktop/Development/CS7643/drive_predicted_rvgan/{ind}.png")

In [8]:
predict_model(model, test_dataloader, device)

Testing: 100%|█████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00,  2.02it/s]
