In [1]:
import torch 
from dataset import Pic_to_Pic_dataset
from models import UNET, U2NET
from torch.utils.data import DataLoader
from loss import SSIM_DICE_BCE, DiceScore
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd 
from PIL import Image 
from torchvision.transforms import ToTensor
import os 
from tqdm import tqdm 
import cv2

In [2]:
# Restore the UNET weights stored under './ckpts/quantum_noise/56/' and put the model on the GPU.
CKPT_PATH = './ckpts/quantum_noise/56/best_unet.pth'
ckpt = torch.load(CKPT_PATH)  # NOTE(review): assumes a dict holding the weights under 'model_state'
model = UNET().cuda()
model.load_state_dict(ckpt['model_state'])

# Metric object; called later as dice_score(mask, logits).
dice_score = DiceScore()
print(model)


UNET(
  (ch): DoubleConv(
    (net): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): PReLU(num_parameters=64)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): PReLU(num_parameters=64)
    )
  )
  (down1): DownBlock(
    (net): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (net): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): PReLU(num_parameters=128)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, momentum

In [3]:
def imgs_to_vid(path, fps=10, cleanup=True):
    """Stitch the numbered ``<n>.png`` frames in ``path`` into ``<path>/vid.mp4``.

    Frames are ordered by the integer in their file name (``0.png, 1.png, ..., 10.png``)
    rather than lexicographically.

    Args:
        path: directory holding the ``.png`` frames; the video is written there too.
        fps: frame rate of the output video (10 by default, as before).
        cleanup: when True, delete the source ``.png`` frames once the video is written.

    Returns:
        The path of the written video, or None when ``path`` contains no ``.png`` frames.
    """
    frame_paths = sorted(
        (os.path.join(path, name) for name in os.listdir(path) if name.endswith('.png')),
        key=lambda p: int(os.path.splitext(os.path.basename(p))[0]),
    )
    if not frame_paths:
        return None

    # The first frame fixes the video size (any frame would do; index 0 always exists).
    height, width = cv2.imread(frame_paths[0]).shape[:2]
    video_path = os.path.join(path, 'vid.mp4')
    writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    try:
        for frame_path in tqdm(frame_paths):
            frame = cv2.imread(frame_path)
            # VideoWriter silently drops frames whose size differs from the declared one.
            if frame.shape[:2] != (height, width):
                frame = cv2.resize(frame, (width, height))
            writer.write(frame)
    finally:
        writer.release()

    # Remove the frames from *this* directory (previously relied on a global patient_id).
    if cleanup:
        for frame_path in frame_paths:
            os.remove(frame_path)
    return video_path



In [4]:
# For every validation patient, save one figure per slice (input | mask | model output)
# titled with the slice's dice score, then stitch the figures into a per-patient video.
DATA_ROOT = '/home/shivac/qml-data/'  # the CSV's img_path / mask_path are relative to this
whole_df = pd.read_csv(os.path.join(DATA_ROOT, 'csv_files/val_10_org.csv'))
patient_ids = np.unique(whole_df.patient_id)
to_tensor = ToTensor()
model.eval()
for patient_id in patient_ids:
    out_dir = 'plots/{}'.format(patient_id)
    os.makedirs(out_dir, exist_ok=True)
    # Rows of this patient ordered by 'idx'; drop=True keeps the old index out of the frame.
    df = (whole_df[whole_df.patient_id == patient_id]
          .sort_values('idx')
          .reset_index(drop=True))
    for i, row in enumerate(tqdm(df.itertuples(index=False), total=len(df))):
        img = to_tensor(Image.open(DATA_ROOT + row.img_path).convert('L')).unsqueeze(0)
        mask = to_tensor(Image.open(DATA_ROOT + row.mask_path)).unsqueeze(0)
        # Inference only: no_grad avoids building an autograd graph for every slice.
        with torch.no_grad():
            logits = model(img.cuda())
            dice = round(dice_score(mask.cuda(), logits).item(), 2)

        fig, axes = plt.subplots(1, 3, figsize=(10, 6), facecolor='gray')
        # suptitle sits above all panels instead of on a hidden axes overlapping them.
        fig.suptitle('Depth: ' + str(i) + ' dice_score: ' + str(dice))
        panels = (('img', img), ('mask', mask), ('logits', logits.detach().cpu()))
        for ax, (title, tensor) in zip(axes, panels):
            ax.set_title(title)
            ax.axis('off')
            ax.imshow(tensor[0].permute(1, 2, 0), cmap='gray')
        fig.savefig('{}/{}.png'.format(out_dir, i))
        plt.close(fig)  # free the figure; hundreds are created per patient
    imgs_to_vid('./plots/{}/'.format(patient_id))
        
        


  0%|          | 0/301 [00:00<?, ?it/s]

  return F.conv2d(input, weight, bias, self.stride,
100%|██████████| 301/301 [00:56<00:00,  5.37it/s]
100%|██████████| 301/301 [00:03<00:00, 99.72it/s] 
100%|██████████| 301/301 [00:55<00:00,  5.46it/s]
100%|██████████| 301/301 [00:02<00:00, 103.30it/s]
100%|██████████| 301/301 [00:55<00:00,  5.44it/s]
100%|██████████| 301/301 [00:02<00:00, 103.33it/s]
100%|██████████| 301/301 [00:55<00:00,  5.42it/s]
100%|██████████| 301/301 [00:03<00:00, 100.11it/s]
100%|██████████| 301/301 [00:56<00:00,  5.32it/s]
100%|██████████| 301/301 [00:03<00:00, 96.92it/s] 
100%|██████████| 299/299 [00:55<00:00,  5.41it/s]
100%|██████████| 299/299 [00:02<00:00, 101.56it/s]
100%|██████████| 301/301 [00:55<00:00,  5.42it/s]
100%|██████████| 301/301 [00:03<00:00, 99.40it/s] 
100%|██████████| 301/301 [00:55<00:00,  5.40it/s]
100%|██████████| 301/301 [00:02<00:00, 104.16it/s]
100%|██████████| 301/301 [00:56<00:00,  5.32it/s]
100%|██████████| 301/301 [00:03<00:00, 95.70it/s] 
100%|██████████| 301/301 [00:55<00:00, 

In [5]:
# Shows the last value bound to the loop variable of the previous cell, i.e. the patient
# being processed when that cell finished (or was interrupted).
patient_id

'MEDVID0085_M_20211202_111644_0001_IMAGES'