In [1]:
import torch 
import numpy as np 
import matplotlib.pyplot as plt
import pandas as pd
import os
from tqdm import  tqdm
import GPUtil as GPU
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
import cv2
from torchvision.transforms import *
#modelling
from torch import nn 
from torch.nn import functional as F
import torch 
from torchvision import models
from pathlib import Path

In [2]:
#define the unet architecture 
def conv3x3(in_, out):
    return nn.Conv2d(in_, out, 3, padding=1)


def concat(xs):
    return torch.cat(xs, 1)


class Conv3BN(nn.Module):
    def __init__(self, in_: int, out: int, bn=False):
        super().__init__()
        self.conv = conv3x3(in_, out)
        self.bn = nn.BatchNorm2d(out) if bn else None
        self.activation = nn.SELU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        x = self.activation(x)
        return x

    
class DecoderBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()

        self.block = nn.Sequential(
            ConvRelu(in_channels, middle_channels),
            nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)
    

class ConvRelu(nn.Module):
    def __init__(self, in_: int, out: int):
        super().__init__()
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.activation(x)
        return x


In [3]:

class UNet11(nn.Module):
    def __init__(self, num_classes=1, num_filters=32):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.encoder = models.vgg11(pretrained=True).features
        self.relu = self.encoder[1]
        self.conv1 = self.encoder[0]
        self.conv2 = self.encoder[3]
        self.conv3s = self.encoder[6]
        self.conv3 = self.encoder[8]
        self.conv4s = self.encoder[11]
        self.conv4 = self.encoder[13]
        self.conv5s = self.encoder[16]
        self.conv5 = self.encoder[18]

        self.center = DecoderBlock(num_filters * 8 * 2, num_filters * 8 * 2, num_filters * 8)
        self.dec5 = DecoderBlock(num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 8)
        self.dec4 = DecoderBlock(num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 4)
        self.dec3 = DecoderBlock(num_filters * (8 + 4), num_filters * 4 * 2, num_filters * 2)
        self.dec2 = DecoderBlock(num_filters * (4 + 2), num_filters * 2 * 2, num_filters)
        self.dec1 = ConvRelu(num_filters * (2 + 1), num_filters)

        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        conv1 = self.relu(self.conv1(x))
        conv2 = self.relu(self.conv2(self.pool(conv1)))
        conv3s = self.relu(self.conv3s(self.pool(conv2)))
        conv3 = self.relu(self.conv3(conv3s))
        conv4s = self.relu(self.conv4s(self.pool(conv3)))
        conv4 = self.relu(self.conv4(conv4s))
        conv5s = self.relu(self.conv5s(self.pool(conv4)))
        conv5 = self.relu(self.conv5(conv5s))

        center = self.center(self.pool(conv5))

        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(torch.cat([dec2, conv1], 1))
        return F.sigmoid(self.final(dec1))

    
#writing the dice based loss weights for the neural network 

class Loss:
    #init script for the loss
    def __init__(self, dice_weight=1):
        self.nll_loss = nn.BCELoss()
        self.dice_weight = dice_weight

    #call called if parameters are given to the network 
    def __call__(self, outputs, targets):
        loss = self.nll_loss(outputs, targets)
        if self.dice_weight:
            eps = 1e-15
            dice_target = (targets == 1).float()
            dice_output = outputs
            intersection = (dice_output * dice_target).sum()
            union = dice_output.sum() + dice_target.sum() + eps

            loss -= torch.log(2 * intersection / union)

        return loss


In [4]:
########Normal Image Transforms#########################################

img_transform = Compose([
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def load_image(path):
    img=cv2.imread(path)
    img = cv2.copyMakeBorder(img, 0, 0, 1, 1, cv2.BORDER_REFLECT_101)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    #img=cv2.resize(img,(img.shape[1]//2,img.shape[0]//2))
    return img.astype(np.uint8)

#dataset to load the prediction images
class PredictionDatasetPure:
    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx % len(self.paths)]
        #print(path)
        image = load_image(str(path))
        #path.stem returns the path 
        return img_transform(image),path.stem


In [5]:
#create the prediction directory here
local_data_path=Path('.').absolute()
prediction_path=(local_data_path/'predictions')
prediction_path.mkdir(exist_ok=True, parents=True)
#load the test images
test_images=sorted(list((local_data_path/'test_hq').glob('*.jpg')))
len_test=len(test_images)

In [12]:
#function to load images from_path, make preductions to to_path

def predict(model, from_path,to_path,batch_size=2):
    loader = DataLoader(
            dataset=PredictionDatasetPure(from_path),
            shuffle=False,
            batch_size=batch_size,
            num_workers=1,
            pin_memory=True
        )
    for i,(inputs, stems) in enumerate(tqdm(loader,desc='Prediction Progress')):
        print(i)
        #inputs=inputs.to(device)
        outputs=model(inputs)
        with torch.no_grad():
            mask=(outputs.detach().cpu().numpy()*255).astype(np.uint8)
            #print(mask)
            #print(mask.shape)
            #print("path is",str(to_path/(stems[0]+'.png')))
            cv2.imwrite(str(to_path/(stems[0]+'.png')),mask[0,0,:,:])

In [13]:
class WrappedModel(nn.Module):
	def __init__(self, module):
		super(WrappedModel, self).__init__()
		self.module = module # that I actually define.
	def forward(self, x):
		return self.module(x)



In [14]:
device = torch.device('cpu')
fold_no=0
#create test outputs corresponding to each fold in the prediction directory
fold_path=prediction_path/str(fold_no)
fold_path.mkdir(exist_ok=True, parents=True)

test_path=fold_path/'test'
test_path.mkdir(exist_ok=True,parents=True)

#predict from test_images, copy to test_path
#load the model here 

model=UNet11()
model=WrappedModel(model)
device = torch.device('cpu')
#device = torch.device('cpu')
#model=UNet11()
#model = nn.DataParallel(model, device_ids=[])
model_checkpoint='model_'+str(fold_no)+'.pt'
model_checkpoint_path=local_data_path/'bestmodelweights'/model_checkpoint

#load_the_model_weights

state=torch.load(model_checkpoint_path,map_location=device)
model.load_state_dict(state['model'])
model.eval()

predict(model,test_images,test_path,1)



Prediction Progress:   0%|          | 0/100064 [00:00<?, ?it/s][A[A

0




Prediction Progress:   0%|          | 1/100064 [00:13<372:19:29, 13.40s/it][A[A

1




Prediction Progress:   0%|          | 2/100064 [00:26<370:38:45, 13.33s/it][A[A

2




Prediction Progress:   0%|          | 3/100064 [00:40<372:58:26, 13.42s/it][A[A

3




Prediction Progress:   0%|          | 4/100064 [00:53<369:37:14, 13.30s/it][A[A

4




Prediction Progress:   0%|          | 5/100064 [01:05<364:50:18, 13.13s/it][A[A

5




Prediction Progress:   0%|          | 6/100064 [01:19<370:06:35, 13.32s/it][A[A

6




Prediction Progress:   0%|          | 7/100064 [01:33<372:51:53, 13.42s/it][A[A

7




Prediction Progress:   0%|          | 8/100064 [01:45<362:49:44, 13.05s/it][A[A

8




Prediction Progress:   0%|          | 9/100064 [01:58<361:56:05, 13.02s/it][A[A

9




Prediction Progress:   0%|          | 10/100064 [02:13<374:10:21, 13.46s/it][A[A

10




Prediction Progress:   0%|          | 11/100064 [02:25<367:58:30, 13.24s/it][A[A

11




Prediction Progress:   0%|          | 12/100064 [02:38<365:33:24, 13.15s/it][A[A

12




Prediction Progress:   0%|          | 13/100064 [02:51<359:58:33, 12.95s/it][A[A

13




Prediction Progress:   0%|          | 14/100064 [03:04<361:03:42, 12.99s/it][A[A

14




Prediction Progress:   0%|          | 15/100064 [03:15<342:37:04, 12.33s/it][A[A

15




Prediction Progress:   0%|          | 16/100064 [03:25<329:13:01, 11.85s/it][A[A

16




Prediction Progress:   0%|          | 17/100064 [03:37<324:32:30, 11.68s/it][A[A

17




Prediction Progress:   0%|          | 18/100064 [03:47<314:33:07, 11.32s/it][A[A

18




Prediction Progress:   0%|          | 19/100064 [03:58<311:36:23, 11.21s/it][A[A

19




Prediction Progress:   0%|          | 20/100064 [04:08<304:03:44, 10.94s/it][A[A

20




Prediction Progress:   0%|          | 21/100064 [04:18<294:31:06, 10.60s/it][A[A

21




Prediction Progress:   0%|          | 22/100064 [04:28<291:41:19, 10.50s/it][A[A

22




Prediction Progress:   0%|          | 23/100064 [04:38<288:17:40, 10.37s/it][A[A

23




Prediction Progress:   0%|          | 24/100064 [04:48<285:41:16, 10.28s/it][A[A

24




Prediction Progress:   0%|          | 25/100064 [04:58<278:41:46, 10.03s/it][A[A

25




Prediction Progress:   0%|          | 26/100064 [05:07<272:25:18,  9.80s/it][A[A

26




Prediction Progress:   0%|          | 27/100064 [05:17<271:22:25,  9.77s/it][A[A

27




Prediction Progress:   0%|          | 28/100064 [05:26<268:09:27,  9.65s/it][A[A

28




Prediction Progress:   0%|          | 29/100064 [05:36<266:06:22,  9.58s/it][A[A

29




Prediction Progress:   0%|          | 30/100064 [05:45<263:26:42,  9.48s/it][A[A

30




Prediction Progress:   0%|          | 31/100064 [05:54<262:52:54,  9.46s/it][A[A

31




Prediction Progress:   0%|          | 32/100064 [06:04<265:20:16,  9.55s/it][A[A

32




Prediction Progress:   0%|          | 33/100064 [06:13<260:30:00,  9.38s/it][A[A

33




Prediction Progress:   0%|          | 34/100064 [06:23<261:15:38,  9.40s/it][A[A

34




Prediction Progress:   0%|          | 35/100064 [06:32<260:57:53,  9.39s/it][A[A

35




Prediction Progress:   0%|          | 36/100064 [06:41<257:34:28,  9.27s/it][A[A

36




Prediction Progress:   0%|          | 37/100064 [06:50<257:49:53,  9.28s/it][A[A

37




Prediction Progress:   0%|          | 38/100064 [07:00<262:11:05,  9.44s/it][A[A

38




Prediction Progress:   0%|          | 39/100064 [07:09<261:14:18,  9.40s/it][A[A

39




Prediction Progress:   0%|          | 40/100064 [07:19<260:36:10,  9.38s/it][A[A

40




Prediction Progress:   0%|          | 41/100064 [07:28<260:28:52,  9.38s/it][A[A

41




Prediction Progress:   0%|          | 42/100064 [07:37<259:57:59,  9.36s/it][A[A

42




Prediction Progress:   0%|          | 43/100064 [07:47<260:30:34,  9.38s/it][A[A

43




Prediction Progress:   0%|          | 44/100064 [07:56<260:10:05,  9.36s/it][A[A

44




Prediction Progress:   0%|          | 45/100064 [08:05<257:14:14,  9.26s/it][A[A

45




Prediction Progress:   0%|          | 46/100064 [08:15<261:04:47,  9.40s/it][A[A

46




Prediction Progress:   0%|          | 47/100064 [08:24<261:02:45,  9.40s/it][A[A

47




Prediction Progress:   0%|          | 48/100064 [08:34<261:11:38,  9.40s/it][A[A

48




Prediction Progress:   0%|          | 49/100064 [08:43<261:14:08,  9.40s/it][A[A

49




Prediction Progress:   0%|          | 50/100064 [08:53<264:51:51,  9.53s/it][A[A

50




Prediction Progress:   0%|          | 51/100064 [09:02<263:52:58,  9.50s/it][A[A

51




Prediction Progress:   0%|          | 52/100064 [09:12<266:13:31,  9.58s/it][A[A

52




Prediction Progress:   0%|          | 53/100064 [09:21<262:44:43,  9.46s/it][A[A

53




Prediction Progress:   0%|          | 54/100064 [09:31<261:49:16,  9.42s/it][A[A

54




Prediction Progress:   0%|          | 55/100064 [09:40<259:14:32,  9.33s/it][A[A

55


KeyboardInterrupt: 

In [None]:
test_images

In [None]:
print(state['best_valid_loss'])

In [None]:
device

In [None]:
device