In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
import os
import numpy as np
from torchinfo import summary
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from Convs_Unet import UNet

In [2]:
from torch.utils.data import random_split

In [3]:
from tqdm.notebook import trange, tqdm

In [2]:
##Final ensemble model 

In [4]:
class Im_Seg(torch.nn.Module):
    def __init__(self ,device, loc = "../models/" , trainable = set()):
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()) , requires_grad=True)
        self.b = torch.nn.Parameter(torch.randn(()) , requires_grad=True)
        self.c = torch.nn.Parameter(torch.randn(()) , requires_grad=True)
        self.d = torch.nn.Parameter(torch.randn(()) , requires_grad=True)
        self.mod1 = UNet(n_channels = 3 , n_classes = 1) 
        self.mod2 = UNet(n_channels = 3 , n_classes = 1) 
        self.mod3 = UNet(n_channels = 3 , n_classes = 1) 
        self.mod1.to(device = device)
        self.mod2.to(device = device)
        self.mod3.to(device = device)
        self.mod1.load_state_dict(torch.load(loc + "model1.pth", map_location=device))
        self.mod2.load_state_dict(torch.load(loc + "model2.pth", map_location=device))
        self.mod3.load_state_dict(torch.load(loc + "model3.pth", map_location=device))
        for k in self.mod1.named_parameters():
            k[1].requires_grad = True if k[0].split(".")[0] in trainable else False
        for k in self.mod2.named_parameters():
            k[1].requires_grad = True if k[0].split(".")[0] in trainable else False  
        for k in self.mod3.named_parameters():
            k[1].requires_grad = True if k[0].split(".")[0] in trainable else False  
            
    def check_state(self , mod):
        dic = {1 : self.mod1 , 2 : self.mod2 , 3 : self.mod3 }
        return dic[mod].named_parameters()
    
    def forward(self , inp1 , inp2 ,inp3):
        out1 = torch.sigmoid(self.mod1(inp1))
        out2 = torch.sigmoid(self.mod2(inp2))
        out3 = torch.sigmoid(self.mod3(inp3))
        return torch.sigmoid(self.a*out1 + self.b*out2 + self.c*out3 + self.d)
    
        

In [5]:
class Data_from_disk_full_tag(torch.utils.data.Dataset) :
    def __init__(self, dir_ = "../Data/Full_Aug/"): #tag indicates the type of input image to be trained on 
        super().__init__() 
        self.dir_ = dir_
        self.inp_dir_0 = sorted([x for x in os.listdir(dir_ + "Im_inp/") if x.split("_")[2] == str(0)])
        self.inp_dir_1 = sorted([x for x in os.listdir(dir_ + "Im_inp/") if x.split("_")[2] == str(1)])
        self.inp_dir_2 = sorted([x for x in os.listdir(dir_ + "Im_inp/") if x.split("_")[2] == str(2)])
        self.out_dir = sorted(os.listdir(dir_ + "Im_out/"))
        print(len(self.inp_dir_0) ,len(self.inp_dir_1) ,len(self.inp_dir_2) ,len(self.out_dir))
        
    def __getitem__(self , idx):
        #print(self.inp_dir[idx] , self.out_dir[idx])
        #assert self.inp_dir[idx] == self.out_dir[idx] , f"names dont match . given names are {self.inp_dir[idx]} and {self.out_dir[idx]}"
        #assert self.inp_dir[idx].split("_")[2] == str(self.tag) , f"tags are different , required : {self.tag} , found : {self.inp_dir[idx].split('_')[2]}"
        #print(self.inp_dir_0[idx] , self.inp_dir_1[idx] , self.inp_dir_2[idx] , self.out_dir[idx])
        ch0 = tuple([ss for i , ss in enumerate(self.inp_dir_0[idx].split("_")) if i != 2])
        ch1 = tuple([ss for i , ss in enumerate(self.inp_dir_1[idx].split("_")) if i != 2])
        ch2 = tuple([ss for i , ss in enumerate(self.inp_dir_2[idx].split("_")) if i != 2])
        ch3 = tuple([ss for i , ss in enumerate(self.out_dir[idx].split("_")) if i != 2])
        #print(ch0 , ch1 ,ch2 ,ch3)
        assert ch0 == ch1 and ch1 == ch2 and ch2 == ch3 , f"Names dont match {self.inp_dir_0[idx] , self.inp_dir_1[idx] , self.inp_dir_2[idx] , self.out_dir[idx]}"
        x1_0 = np.array(Image.open(self.dir_ + "Im_inp/" + self.inp_dir_0[idx]))
        x1_1 = np.array(Image.open(self.dir_ + "Im_inp/" + self.inp_dir_1[idx]))
        x1_2 = np.array(Image.open(self.dir_ + "Im_inp/" + self.inp_dir_2[idx]))
        x2 = np.array(Image.open(self.dir_ + "Im_out/" + self.out_dir[idx]))[:,:,0:1]
        x1_0 = np.transpose(x1_0 ,(2, 0, 1))
        x1_1 = np.transpose(x1_1 ,(2, 0, 1))
        x1_2 = np.transpose(x1_2 ,(2, 0, 1))
        x2 = np.transpose(x2 ,(2, 0, 1))
        x1_0 = x1_0/255
        x1_1 = x1_1/255
        x1_2 = x1_2/255
        x2 = x2/255
        #assert self.inp_dir_0[idx].split("_")[2] == str(0) , f"tags are different , required : {self.tag} , found : {self.inp_dir[idx].split('_')[2]}"
        #if return_names :
        #return (torch.from_numpy(x1) , torch.from_numpy(x2)) , (self.inp_dir[idx] ,self.out_dir[idx] )
        return ((torch.from_numpy(x1_0) , torch.from_numpy(x1_1) , torch.from_numpy(x1_2)) , torch.from_numpy(x2))
    
    def __len__(self):
        return (len(self.inp_dir_0))  

In [18]:
# test the iterator first

In [19]:
#dd = Data_from_disk_full_tag()

512 512 512 512


In [24]:
#for i in range(len(dd) - 2,len(dd)):
#    dd[i]
#    break

MC16_S2_0_750_700.png MC16_S2_1_750_700.png MC16_S2_2_750_700.png MC16_S2_0_750_700.png
('MC16', 'S2', '750', '700.png') ('MC16', 'S2', '750', '700.png') ('MC16', 'S2', '750', '700.png') ('MC16', 'S2', '750', '700.png')


In [26]:
#after thorough testing

In [6]:
g1 = torch.Generator().manual_seed(42)
train_ds , val_ds , test_ds = random_split(Data_from_disk_full_tag() , [0.7 , 0.2 ,0.1] , generator = g1)

512 512 512 512


In [7]:
len(train_ds) , len(val_ds) , len(test_ds)

(359, 102, 51)

In [16]:
batch = 3
EPOCHS = 30
lr = 0.01

In [17]:
train_loader = DataLoader(train_ds , batch , shuffle = True)
val_loader = DataLoader(val_ds , batch)
test_loader = DataLoader(test_ds , batch)

In [13]:
def evaluate(model , val_loader ,loss_func ):
    curr_val_loss = 0
    with torch.no_grad() :
        for i ,mm in enumerate(val_loader):
            x = mm[0]
            y = mm[1]
            for j in range(3):
                x[j] = x[j].to(device , dtype = torch.float32)
            y = y.to(device , dtype = torch.float32)
            loss = loss_func(model(*x), y)
            curr_val_loss = (curr_val_loss*(i) + loss)/(i+1)
    return curr_val_loss

In [34]:
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss*(self.min_delta + 1)):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False  

In [36]:
early_stopper = EarlyStopper(patience = 3 , min_delta = 0.05)
def fit(model , epochs , lr , train_loader , val_loader , opt_func , loss_func ):
    #Q = deque()
    history = []
    optimizer = opt_func(model.parameters() , lr)
    cur_val_loss = 0
    scheduler =  torch.optim.lr_scheduler.ExponentialLR(optimizer , gamma = 0.95 , verbose = True)
    for epoch in trange(epochs , total = epochs , desc="num_epochs") :
        iterator = tqdm(enumerate(train_loader) , total = len(train_loader) ,desc="num_batches") 
        for i ,mm in iterator:
            x = mm[0]
            y = mm[1]
            for j in range(3):
                x[j] = x[j].to(device , dtype = torch.float32)
            y = y.to(device , non_blocking = True , dtype = torch.float32)
            #scheduler.print_lr()
            #print(x.device , y.device)
            optimizer.zero_grad()
            #print("zero_grad")
            loss = loss_func(model(*x), y)
            #print("calc_loss")
            iterator.set_postfix(train_loss = loss.item())
            loss.backward()
            optimizer.step()
            
            #torch.cuda.empty_cache()
        scheduler.step()    
        val_err = evaluate(model , val_loader , loss_func)
        curr_train_loss = evaluate(model , train_loader , loss_func)
        print(f"train_loss : {curr_train_loss}  , val_loss : {val_err}")
        history.append((curr_train_loss , val_err ))
        #if len(Q) == 3 :
        #    Q.popleft()
        #Q.append(model.parameters())
        if early_stopper.early_stop(val_err) :
            #model.parameters() = Q[0]
            break
            
    return history        
            

In [10]:
device  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Im_Seg(device)

In [47]:
for ss in model.parameters() :
    print(ss , ss.is_cuda)
    if not ss.is_cuda:
        print(ss , ss.is_cuda)

Parameter containing:
tensor(-1.0764, device='cuda:0', requires_grad=True) True
Parameter containing:
tensor(1.0632, device='cuda:0', requires_grad=True) True
Parameter containing:
tensor(0.0732, device='cuda:0', requires_grad=True) True
Parameter containing:
tensor(0.5796, device='cuda:0', requires_grad=True) True
Parameter containing:
tensor([[[[ 1.8891e-38,  1.7821e-38,  2.3253e-38],
          [ 1.5670e-38,  1.4314e-38,  1.0727e-38],
          [ 1.3786e-38,  1.4174e-38,  1.2988e-38]],

         [[ 6.2313e-38,  4.3693e-38,  3.2927e-38],
          [ 6.3985e-38,  5.7717e-38,  6.7191e-38],
          [ 6.8400e-38,  6.7746e-38,  4.9936e-38]],

         [[ 6.8120e-38,  6.2146e-38,  5.7528e-38],
          [ 4.6568e-38,  6.5558e-38,  5.5100e-38],
          [ 5.0219e-38,  4.6784e-38,  3.9735e-38]]],


        [[[ 4.0789e-39,  5.2598e-40,  3.2969e-39],
          [-4.7167e-40,  3.6139e-39, -3.0623e-39],
          [ 4.3252e-39,  4.1237e-39,  5.3400e-40]],

         [[ 1.1483e-38,  7.5231e-39,  9

Parameter containing:
tensor([[[[-1.4516e-38,  1.0948e-38, -3.8050e-39],
          [ 1.6402e-38,  1.3690e-38,  7.1965e-40],
          [-4.3289e-38,  2.3272e-39,  1.1207e-38]],

         [[-6.3549e-38, -6.3535e-38,  6.3187e-38],
          [ 6.3173e-38, -6.3377e-38,  6.3448e-38],
          [ 6.3095e-38,  6.3556e-38,  6.3559e-38]],

         [[-6.3168e-38, -6.3530e-38, -6.3585e-38],
          [-6.3593e-38, -6.3655e-38, -6.3510e-38],
          [-6.3672e-38, -6.3683e-38, -6.3398e-38]],

         ...,

         [[ 6.3436e-38,  6.3265e-38,  6.3286e-38],
          [ 6.3473e-38,  6.3705e-38,  6.3427e-38],
          [ 6.3209e-38, -6.3433e-38,  6.3166e-38]],

         [[-1.8182e-38, -1.3315e-39, -9.3259e-39],
          [ 3.2665e-40, -5.7741e-39,  1.6793e-38],
          [-8.0375e-39,  1.4070e-38,  3.7116e-39]],

         [[ 6.3505e-38, -6.3752e-38, -6.3081e-38],
          [ 6.3251e-38, -6.3731e-38, -6.3595e-38],
          [ 6.3666e-38,  6.3064e-38, -6.3608e-38]]],


        [[[-2.6976e-39, -7.2802

Parameter containing:
tensor([[[[ 6.3316e-38,  6.3613e-38,  6.3372e-38],
          [-6.3297e-38, -6.3664e-38,  6.3642e-38],
          [-6.3382e-38, -6.3143e-38,  6.3235e-38]],

         [[ 6.3307e-38,  6.3097e-38,  6.3060e-38],
          [-6.3529e-38, -5.2004e-38,  6.3729e-38],
          [-6.3701e-38, -6.3353e-38, -6.3085e-38]],

         [[ 7.4490e-39, -1.0620e-38, -4.1111e-38],
          [ 9.8705e-39,  3.9389e-38, -9.3685e-39],
          [-3.9857e-39,  1.5550e-38,  1.7332e-38]],

         ...,

         [[ 6.3305e-38,  6.3691e-38, -6.3635e-38],
          [ 6.3401e-38,  6.3065e-38, -6.3563e-38],
          [ 6.3451e-38,  6.3217e-38,  6.3303e-38]],

         [[ 3.7900e-39,  2.3114e-38, -2.6708e-38],
          [ 1.3096e-39,  5.3774e-39,  2.8371e-38],
          [ 1.1578e-38,  2.3317e-38,  6.9785e-38]],

         [[ 2.3670e-38,  6.5534e-38, -6.6356e-38],
          [-1.5570e-38,  4.6242e-40, -3.6695e-40],
          [-1.1074e-38, -2.7700e-38, -4.0745e-39]]],


        [[[ 1.1464e-38, -6.4480

Parameter containing:
tensor([[[[-5.0792e-38, -5.0580e-38, -5.1962e-38],
          [-5.3317e-38, -4.7550e-38, -5.3286e-38],
          [-5.4062e-38, -4.9121e-38, -5.0170e-38]],

         [[-1.0287e-38,  5.2217e-39, -9.5759e-39],
          [-4.1545e-39, -6.6601e-39, -2.1093e-38],
          [-1.3885e-39, -1.0717e-38, -2.5224e-38]],

         [[-2.6616e-39, -9.7807e-40, -1.6534e-38],
          [ 1.1300e-38,  1.3559e-38,  4.0219e-39],
          [ 2.0701e-38,  2.5078e-38,  1.7249e-38]],

         ...,

         [[-5.9163e-39, -7.5285e-40, -4.1934e-39],
          [-8.6309e-39, -4.6225e-39, -6.1340e-39],
          [-7.5461e-39, -1.4828e-39, -9.9241e-39]],

         [[-1.5317e-38, -2.9550e-38, -3.5621e-38],
          [ 5.3839e-39, -5.4116e-39, -1.7626e-38],
          [ 4.4623e-38,  7.6208e-39, -4.6933e-39]],

         [[-2.9332e-38, -3.9954e-38, -6.8098e-39],
          [-3.1106e-38,  6.3205e-38, -1.6688e-38],
          [-1.1650e-38, -2.1125e-38, -4.5473e-39]]],


        [[[-5.7356e-38, -1.1788

Parameter containing:
tensor([[[[-6.3299e-38,  6.3271e-38, -6.3309e-38],
          [-6.3147e-38,  6.3719e-38, -6.3200e-38],
          [-6.3465e-38, -6.3280e-38, -6.3349e-38]],

         [[-6.3717e-38,  6.3649e-38, -6.3341e-38],
          [-6.3514e-38,  6.3132e-38, -6.3497e-38],
          [-6.3326e-38, -6.3658e-38, -6.3467e-38]],

         [[-6.3304e-38,  6.3742e-38, -6.3205e-38],
          [-6.3292e-38,  6.3758e-38, -6.3314e-38],
          [-6.3415e-38, -6.3358e-38, -6.3395e-38]],

         ...,

         [[-6.3342e-38,  6.3383e-38, -6.3629e-38],
          [-6.3660e-38,  6.3719e-38, -6.3168e-38],
          [-6.3201e-38, -6.3666e-38, -6.3340e-38]],

         [[-6.3413e-38, -6.3368e-38,  6.3505e-38],
          [ 6.3067e-38,  6.3330e-38,  6.3452e-38],
          [ 6.3220e-38,  6.3372e-38,  6.3291e-38]],

         [[ 1.2657e+00,  7.5042e-01, -2.5820e-01],
          [ 5.2796e-01,  1.2439e+00,  7.3013e-01],
          [-4.1433e-02,  1.3654e+00,  8.6467e-01]]],


        [[[ 6.3532e-38,  6.3695

Parameter containing:
tensor([[[[ 6.3157e-38,  6.3162e-38, -6.3755e-38],
          [ 6.3689e-38,  6.3097e-38, -6.3122e-38],
          [-6.3742e-38,  6.3682e-38,  6.3656e-38]],

         [[-6.3509e-38, -6.3497e-38, -6.3542e-38],
          [-6.3534e-38, -1.3642e-38, -6.3455e-38],
          [-6.3341e-38, -6.3237e-38, -6.3515e-38]],

         [[ 6.3475e-38, -5.2612e-38, -6.3354e-38],
          [ 6.3733e-38, -1.5797e-38, -6.3721e-38],
          [-6.3088e-38, -6.3730e-38, -6.3267e-38]],

         ...,

         [[-6.3459e-38, -6.3665e-38, -6.3476e-38],
          [-6.3286e-38,  6.3416e-38, -6.3468e-38],
          [ 6.3619e-38,  6.3505e-38,  6.3121e-38]],

         [[-2.5640e-38, -6.2266e-38, -1.6061e-38],
          [-4.7324e-38,  1.6007e-38, -3.3356e-38],
          [-4.8462e-38, -2.2828e-38, -3.5208e-38]],

         [[ 3.1097e-39, -1.3615e-38, -5.2557e-39],
          [ 4.1691e-39, -4.1083e-38, -4.5767e-38],
          [-1.9139e-38, -3.8472e-38, -1.0756e-38]]],


        [[[-6.3060e-38, -6.3193

Parameter containing:
tensor([[[[ 1.1407e-38,  1.2790e-38,  5.6310e-38],
          [ 1.6465e-38,  3.0949e-38, -1.8919e-38],
          [ 2.5598e-38,  3.1524e-38,  5.3304e-38]],

         [[ 2.2308e-38,  2.3919e-38,  1.9306e-38],
          [ 2.1684e-38,  1.8072e-38,  1.4185e-38],
          [ 1.7013e-38,  1.9921e-38,  1.2611e-38]],

         [[ 2.4453e-38,  1.8858e-38,  1.1070e-38],
          [ 2.3907e-38,  2.4355e-38,  1.8346e-38],
          [ 4.3989e-38,  4.6325e-38,  4.6094e-38]],

         ...,

         [[ 1.2785e-38,  1.6540e-38,  1.6145e-38],
          [ 1.2890e-38,  1.4029e-38,  1.4209e-38],
          [ 1.8140e-38,  1.7727e-38,  2.3972e-38]],

         [[-4.3654e-38, -5.2249e-38, -5.4619e-38],
          [-4.2699e-38, -5.6219e-38, -6.1396e-38],
          [-2.4710e-38, -3.9045e-38, -5.3829e-38]],

         [[-1.5308e-38, -2.0165e-38, -5.4711e-38],
          [-4.8309e-39, -1.4135e-38, -3.4945e-38],
          [-7.3812e-39, -7.6743e-39, -2.2554e-38]]],


        [[[-6.3152e-38, -6.3585

Parameter containing:
tensor([[[[-1.8231e-38, -3.2129e-38, -4.3172e-38],
          [-1.7954e-38, -2.5465e-38, -4.0172e-38],
          [-1.7367e-38, -1.9813e-38, -3.4410e-38]],

         [[-8.1873e-40, -2.8287e-39,  1.2207e-40],
          [-2.7790e-39, -3.0085e-39,  2.5706e-39],
          [ 3.4056e-40, -3.5925e-39, -1.9297e-39]],

         [[-2.4546e-38, -2.1016e-38, -1.7064e-38],
          [-2.5243e-38, -1.7897e-38, -8.5850e-39],
          [-2.9312e-38, -1.8150e-38, -1.5781e-38]],

         ...,

         [[-9.0722e-39,  4.7705e-38,  1.3598e-38],
          [-5.0906e-39,  3.5078e-38,  3.6166e-38],
          [-3.2845e-38, -1.7036e-38, -2.3174e-38]],

         [[-1.2957e-38, -1.0587e-38, -1.7632e-38],
          [-1.0709e-38, -2.2935e-38, -2.5164e-38],
          [-8.0637e-39, -1.7389e-38, -1.8944e-38]],

         [[ 2.8223e-38,  6.3609e-38,  4.8990e-38],
          [ 6.4170e-38,  4.5610e-38,  5.5385e-38],
          [ 9.5025e-39,  4.2160e-38,  2.8112e-38]]],


        [[[-2.2542e-38, -2.0866

Parameter containing:
tensor([[[[-6.3424e-38,  6.3305e-38,  6.3103e-38],
          [ 6.3251e-38,  6.3529e-38,  6.3407e-38],
          [ 6.3375e-38,  6.3138e-38,  6.3132e-38]],

         [[ 3.7174e-40, -2.7290e-39,  1.8184e-38],
          [-4.3621e-39, -9.7252e-39,  4.9334e-38],
          [-8.3186e-39, -1.7132e-38,  3.0294e-39]],

         [[ 3.6247e-39, -1.4140e-38, -4.7540e-39],
          [-7.8537e-39, -1.4816e-39,  2.7558e-38],
          [ 4.5444e-38,  4.8832e-38,  3.2877e-38]],

         ...,

         [[ 5.6975e-38,  1.3516e-38, -4.1402e-38],
          [ 3.5028e-38,  4.0149e-38, -6.9715e-38],
          [ 6.6493e-38,  2.6353e-38, -2.4352e-38]],

         [[ 6.3092e-38, -6.3244e-38, -6.3212e-38],
          [-6.3698e-38,  6.3242e-38, -6.3677e-38],
          [-6.3169e-38,  6.3729e-38,  6.3288e-38]],

         [[-3.5386e-38,  1.2616e-38, -5.8252e-38],
          [-1.7576e-38, -2.5798e-38,  6.8920e-40],
          [-7.8359e-39,  3.5415e-39,  1.1754e-38]]],


        [[[-5.3309e-02, -2.0281

Parameter containing:
tensor([ 3.7370e+00, -1.6608e-28, -4.4336e-30, -6.9426e-38, -3.4741e-31,
        -4.8131e-30, -4.5190e-34, -9.7877e-31,  4.6512e+00, -6.3717e-38,
        -5.3462e-33, -5.1962e-34, -4.3980e-30, -5.3043e-30, -3.2782e-32,
        -9.3956e-01, -1.1829e-32, -1.1446e-30, -1.1206e-30, -2.5500e-36,
        -1.3974e-37, -1.5357e-32, -1.7640e-11, -5.0600e-27, -1.6878e-29,
        -5.3648e-37, -1.6035e-32,  7.2239e+00,  5.9288e+00, -2.2825e-27,
        -1.3553e-29, -6.9376e-38, -1.6098e-31, -6.3672e-38, -2.3834e-01,
        -9.7840e-31, -6.8637e-38, -6.5131e-32, -1.0732e-35, -2.3307e-28,
        -6.3080e-38, -4.4541e-22, -3.8815e-33, -9.0505e-29, -3.3507e-28,
         1.4786e+00, -1.6226e-29, -6.3332e-38, -1.0472e-29, -4.6147e-36,
        -2.5632e-30, -1.5319e-28, -1.9040e-28, -1.9082e+00, -2.3032e-30,
        -6.3418e-38, -6.9417e-38, -6.3434e-38, -1.8033e-30, -1.4727e+01,
        -8.8957e-37, -1.8243e-34, -6.3196e-38, -2.3615e-25, -2.5533e-35,
        -1.4369e-31, -1.5952e

Parameter containing:
tensor([[[[-5.0792e-38, -5.0580e-38, -5.1962e-38],
          [-5.3317e-38, -4.7550e-38, -5.3286e-38],
          [-5.4062e-38, -4.9121e-38, -5.0170e-38]],

         [[-1.0287e-38,  5.2217e-39, -9.5759e-39],
          [-4.1545e-39, -6.6601e-39, -2.1093e-38],
          [-1.3885e-39, -1.0717e-38, -2.5224e-38]],

         [[-2.6616e-39, -9.7807e-40, -1.6534e-38],
          [ 1.1300e-38,  1.3559e-38,  4.0219e-39],
          [ 2.0701e-38,  2.5078e-38,  1.7249e-38]],

         ...,

         [[-5.9163e-39, -7.5285e-40, -4.1934e-39],
          [-8.6309e-39, -4.6225e-39, -6.1340e-39],
          [-7.5461e-39, -1.4828e-39, -9.9241e-39]],

         [[-1.5317e-38, -2.9550e-38, -3.5621e-38],
          [ 5.3839e-39, -5.4116e-39, -1.7626e-38],
          [ 4.4623e-38,  7.6208e-39, -4.6933e-39]],

         [[-2.9332e-38, -3.9954e-38, -6.8098e-39],
          [-3.1106e-38,  6.3205e-38, -1.6688e-38],
          [-1.1650e-38, -2.1125e-38, -4.5473e-39]]],


        [[[-5.7356e-38, -1.1788

In [52]:
for ss in model.named_parameters():
    #print(ss[0] ,"=>" , ss[1].device)
    print(ss)
    break

('a', Parameter containing:
tensor(-1.0764, device='cuda:0', requires_grad=True))


In [61]:
for ss in model.check_state(3) :
    print(ss[0] ," : ", ss[1].shape , " : " , ss[1].requires_grad)

inc.double_conv.0.weight  :  torch.Size([64, 3, 3, 3])  :  False
inc.double_conv.0.bias  :  torch.Size([64])  :  False
inc.double_conv.1.weight  :  torch.Size([64])  :  False
inc.double_conv.1.bias  :  torch.Size([64])  :  False
inc.double_conv.3.weight  :  torch.Size([64, 64, 3, 3])  :  False
inc.double_conv.3.bias  :  torch.Size([64])  :  False
inc.double_conv.4.weight  :  torch.Size([64])  :  False
inc.double_conv.4.bias  :  torch.Size([64])  :  False
down1.maxpool_conv.1.double_conv.0.weight  :  torch.Size([128, 64, 3, 3])  :  False
down1.maxpool_conv.1.double_conv.0.bias  :  torch.Size([128])  :  False
down1.maxpool_conv.1.double_conv.1.weight  :  torch.Size([128])  :  False
down1.maxpool_conv.1.double_conv.1.bias  :  torch.Size([128])  :  False
down1.maxpool_conv.1.double_conv.3.weight  :  torch.Size([128, 128, 3, 3])  :  False
down1.maxpool_conv.1.double_conv.3.bias  :  torch.Size([128])  :  False
down1.maxpool_conv.1.double_conv.4.weight  :  torch.Size([128])  :  False
down1.ma

In [62]:
summary(model)

Layer (type:depth-idx)                             Param #
Im_Seg                                             4
├─UNet: 1-1                                        --
│    └─DoubleConv: 2-1                             --
│    │    └─Sequential: 3-1                        (38,976)
│    └─DownConv: 2-2                               --
│    │    └─Sequential: 3-2                        (221,952)
│    └─DownConv: 2-3                               --
│    │    └─Sequential: 3-3                        (886,272)
│    └─DownConv: 2-4                               --
│    │    └─Sequential: 3-4                        (3,542,016)
│    └─DownConv: 2-5                               --
│    │    └─Sequential: 3-5                        (4,721,664)
│    └─UpConv: 2-6                                 --
│    │    └─Upsample: 3-6                          --
│    │    └─DoubleConv: 3-7                        (5,900,544)
│    └─UpConv: 2-7                                 --
│    │    └─Upsample: 3-8      

In [63]:
fit_dic1 = {"model" : model ,
           "epochs" : EPOCHS ,
           "lr" : lr ,
           "train_loader" : train_loader ,
           "val_loader" : val_loader ,
           "opt_func" : torch.optim.Adam  ,
           "loss_func" : torch.nn.functional.binary_cross_entropy}

In [64]:
history = fit(**fit_dic1)

Adjusting learning rate of group 0 to 1.0000e-02.


num_epochs:   0%|          | 0/30 [00:00<?, ?it/s]

num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 9.5000e-03.
train_loss : 0.2926328778266907  , val_loss : 0.30945128202438354


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 9.0250e-03.
train_loss : 0.2256525307893753  , val_loss : 0.2388162612915039


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 8.5737e-03.
train_loss : 0.18312960863113403  , val_loss : 0.19426308572292328


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 8.1451e-03.
train_loss : 0.15701332688331604  , val_loss : 0.16732068359851837


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 7.7378e-03.
train_loss : 0.14075054228305817  , val_loss : 0.1508336365222931


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 7.3509e-03.
train_loss : 0.13045820593833923  , val_loss : 0.14015363156795502


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 6.9834e-03.
train_loss : 0.12355322390794754  , val_loss : 0.13304048776626587


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 6.6342e-03.
train_loss : 0.11896174401044846  , val_loss : 0.12813889980316162


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 6.3025e-03.
train_loss : 0.11528358608484268  , val_loss : 0.12464191019535065


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.9874e-03.
train_loss : 0.11267159134149551  , val_loss : 0.1220875084400177


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.6880e-03.
train_loss : 0.11067508161067963  , val_loss : 0.12017067521810532


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.4036e-03.
train_loss : 0.10940771549940109  , val_loss : 0.11873199045658112


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.1334e-03.
train_loss : 0.10839668661355972  , val_loss : 0.1176377683877945


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 4.8767e-03.
train_loss : 0.10722614079713821  , val_loss : 0.11678332090377808


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 4.6329e-03.
train_loss : 0.1065339595079422  , val_loss : 0.11614169180393219


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 4.4013e-03.
train_loss : 0.10645602643489838  , val_loss : 0.11563754081726074


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 4.1812e-03.
train_loss : 0.1052846610546112  , val_loss : 0.11520486325025558


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 3.9721e-03.
train_loss : 0.10538236796855927  , val_loss : 0.11487032473087311


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 3.7735e-03.
train_loss : 0.1052083820104599  , val_loss : 0.11464006453752518


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 3.5849e-03.
train_loss : 0.1046314388513565  , val_loss : 0.1144440770149231


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 3.4056e-03.
train_loss : 0.10456680506467819  , val_loss : 0.11425723135471344


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 3.2353e-03.
train_loss : 0.10429511964321136  , val_loss : 0.11411351710557938


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 3.0736e-03.
train_loss : 0.10416863858699799  , val_loss : 0.11403296142816544


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 2.9199e-03.
train_loss : 0.10388115793466568  , val_loss : 0.11393269151449203


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 2.7739e-03.
train_loss : 0.10412395000457764  , val_loss : 0.11385174095630646


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 2.6352e-03.
train_loss : 0.10430403798818588  , val_loss : 0.11380381882190704


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 2.5034e-03.
train_loss : 0.10411995649337769  , val_loss : 0.11377391219139099


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 2.3783e-03.
train_loss : 0.1036229208111763  , val_loss : 0.11374207586050034


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 2.2594e-03.
train_loss : 0.10368048399686813  , val_loss : 0.11368817090988159


num_batches:   0%|          | 0/120 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 2.1464e-03.
train_loss : 0.10363360494375229  , val_loss : 0.1136724054813385


In [65]:
#torch.save(model.state_dict(), "final_model.pth")  
#print("Saved model to final_model.pth")
#history = np.array([(x111[0].item() , x111[1].item()) for x111 in history])
#np.savetxt('history4.csv', history, delimiter=',')

Saved model to final_model.pth


In [8]:
device  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Im_Seg(device)

In [9]:
#evaluate(model , test_loader ,torch.nn.functional.binary_cross_entropy )

In [10]:
model.to(device=device)
model.load_state_dict(torch.load("final_model.pth", map_location=device))

<All keys matched successfully>

In [11]:
#evaluate(model , test_loader ,torch.nn.functional.binary_cross_entropy )

In [19]:
def check_final(mm , threshold ,model):
    """
    img => torch tensor
    out => numpy image
    """
    #print(mm)
    #img = torch.unsqueeze(img , 0)
    #img = img.to(device=device, dtype=torch.float32)
    #print(img.shape)
#     with torch.no_grad():
#         x = mm[0]
#         y = mm[1]
#         PIL.show(ToPILImage()(x[0]))
#         PIL.show(ToPILImage()(x[1]))
#         PIL.show(ToPILImage()(x[2]))
#         PIL.show(ToPILImage())
#         for j in range(3):
#             x[j] = x[j].to(device , dtype = torch.float32)
#         y = y.to(device , dtype = torch.float32)
#             loss = loss_func(model(*x), y)
#         out_img = torch.sigmoid(model(img))
#         out_img = out_img.squeeze().cpu().numpy()
        
#     print(out_img.shape)    
    return mm 

In [20]:
xx = check_final(test_ds[0] , 1 , model)

In [28]:
len(xx) , len(xx[0]) , xx[1].shape , xx[0][0].shape , xx[0][1].shape , xx[0][2].shape

(2,
 3,
 torch.Size([1, 512, 512]),
 torch.Size([3, 512, 512]),
 torch.Size([3, 512, 512]),
 torch.Size([3, 512, 512]))

In [14]:
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage
def get_inputs(ds):
    im1 ,im2 ,im3 = np.array(ToPILImage()(ds[0][0])) , np.array(ToPILImage()(ds[0][1])) , np.array(ToPILImage()(ds[0][2]))
    plt.figure(figsize = (20 ,10))
    plt.subplot(1 , 3 , 1)
    plt.imshow(im1)
    plt.subplot(1,3,2)
    plt.imshow(im2)
    plt.subplot(1,3,3)
    plt.imshow(im3)
    plt.show()

In [None]:
get_inputs(test_ds[0])