In [1]:
# Imports
import torch
import os
import numpy as np
import pandas as pd
import copy
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader, ConcatDataset
from sklearn.metrics import precision_score, accuracy_score, recall_score, f1_score
from tqdm import tqdm 
from sklearn import metrics

from utils.dataGen import Patches
from utils.datasets import *
from ResNet.ResNet_3D_3 import ResNet_3D as CNN, ResidualBlock

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.device_count(), torch.cuda.get_device_name()

(1, 'NVIDIA RTX A4000')

In [3]:
# Hyperparameters

in_channels = 1
classes = ["background", "leaf", "diseased"]
num_classes = len(classes)
weights = [1.0, 1.0, 1.0]
rgb_bands = [7, 15, 32]
# rgb_bands = [0]
img_size = 20
learning_rate = 1e-06
lr_step = 30
batch_size = 128
num_epochs = 50
frac = .7
save_dataset = True
augment = True
count = False
kfold = 5

In [4]:
# Init dataset

p_14 = Patches("D:\\gyeongsang_22_10_14\\PATCHES\\STACK_75_BANDS_14")
p_14_h = Patches("D:\\gyeongsang_22_10_14\\PATCHES\\STACK_75_BANDS_14_2")
p_21 = Patches("D:\\gyeongsang_22_10_21_1\\PATCHES\\STACK_75_BANDS_21_1")
p_25_h = Patches("D:\\gyeongsang_22_10_25\\PATCHES\\STACK_75_BANDS_25_healthy")
p_25 = Patches("D:\\gyeongsang_22_10_25\\PATCHES\\STACK_75_BANDS_25")
p_28 = Patches("D:\\gyeongsang_22_10_28\\PATCHES\\STACK_75_BANDS_28")
p_28_h = Patches("D:\\gyeongsang_22_10_28\\PATCHES\\STACK_75_BANDS_28_healthy")

In [5]:
print(p_14.describe(),
      p_14_h.describe(),
      p_21.describe(),
      p_25_h.describe(),
      p_25.describe(),
      p_28.describe(),
      p_28_h.describe())

For hdr image, there are: 
 background    551
leaf          489
diseased       48
Name: class, dtype: int64 

For hdr image, there are: 
 leaf        668
diseased     87
Name: class, dtype: int64 

For hdr image, there are: 
 background    617
leaf          576
diseased       61
Name: class, dtype: int64 

For hdr image, there are: 
 leaf    337
Name: class, dtype: int64 

For hdr image, there are: 
 leaf          663
background    650
diseased      117
Name: class, dtype: int64 

For hdr image, there are: 
 background    1148
leaf           710
diseased        96
Name: class, dtype: int64 

For hdr image, there are: 
 leaf    271
Name: class, dtype: int64 

None None None None None None None


In [6]:
# Get dataset

patch_14_hd = p_14_h.generateDataset([0, 300, 87])
patch_14_b = p_14.generateDataset([300, 0, 0])
patch_14 = pd.concat([patch_14_hd, patch_14_b])
           
# Shuffle dataset

patch_14 = patch_14.sample(frac=1)

Dataset has 387 samples [1 2] class indices included
   Example: D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_14_2\hdr\leaf\89_275.HDR
   Unique: True


Dataset has 300 samples [0] class indices included
   Example: D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_14\hdr\background\103_404.HDR
   Unique: True




In [7]:
def splitDatasetKfold(patch, kfold, dtype='hdr', train_frac=0.7):   
    patch_kfold = []
    kfold_recs  = patch.copy()
    
    for k in range(kfold):
        tr = pd.DataFrame(columns=["path", "type", "class"])
        vl = tr.copy()
        for i in range(num_classes):
            tmp = patch.loc[(patch['class'] == i) & 
                             (patch['type'] == dtype)]
            tmp_tr = tmp.sample(frac=train_frac, replace=False, random_state=k)
            tmp_vl = tmp.drop(tmp_tr.index)

            tr = pd.concat([tr, tmp_tr])
            vl = pd.concat([vl, tmp_vl])
            
        kfold_recs.loc[tr.index, f'k{k}'] = 0
        kfold_recs.loc[vl.index, f'k{k}'] = 1
            
        patch_kfold.append([tr, vl])

    # print((kfold_recs.loc[:, [f'k{i}' for i in range(kfold)]].sum(axis=1) == 0).sum())
    
    return patch_kfold, kfold_recs

In [8]:
def evaluate(loader, model):
    actuals = np.array([])
    predictions = np.array([])
    evaluations = np.array([])
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)

            scores = model(x)
            _, model_pred = scores.max(1)
            actuals = np.append(actuals, y.cpu().detach().numpy())
            predictions = np.append(predictions, model_pred.cpu().detach().numpy())
    
    for i in range(num_classes):
        ac = (actuals == i)
        pr = (predictions == i)

        ca = (ac & pr).sum() / ac.sum()
        evaluations = np.append(evaluations, ca)
        
    oa = (actuals == predictions).sum() / len(actuals)
    evaluations = np.append(evaluations, oa)
    
    model.train()

    return evaluations



In [9]:
# Predict 

def predictDataset(loader, model, savefig=False):
    actual = np.array([])
    prediction = np.array([])
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)

            scores = model(x)
            _, predictions = scores.max(1)
            actual = np.append(actual, y.cpu().detach().numpy())
            prediction = np.append(prediction, predictions.cpu().detach().numpy())
    
    return actual, prediction

In [10]:
patch_14_kfold, kfold_recs = splitDatasetKfold(patch_14, kfold)

In [11]:
for kfold_count, p in enumerate(patch_14_kfold):

    # Patch wrapper
    img_14 = wrapPatch(p)

    # Define training images
    train_img = img_14[0]
    val_img = img_14[1]

    if count:
        countImg(train_img, classes)
        countImg(val_img, classes)
        
    # Augment images 
    if augment:
        img_14_aug1 = augmentPatch2(img_14, [0, 0, 5])

        train_img = ConcatDataset([train_img] + img_14_aug1[0])
        val_img = ConcatDataset([val_img] + img_14_aug1[1])   

        if count:
            countImg(train_img, classes)
            countImg(val_img, classes)
            
    # Load images
    loaders = imagesLoader([[train_img, True],
                            [val_img  , False],
                            [img_14[1], False],
                            [img_14   , True]], batch_size)

    train_loader = loaders[0]
    test_loader = loaders[1]

    print(next(iter(train_loader))[0].shape)
    
    # Initialize network
    # model = CNN(in_channels=in_channels, num_classes=num_classes).to(device)
    # model = CNN(ResidualBlock, [1, 2, 2, 2]).to(device)
    model = CNN(ResidualBlock).to(device)
    print("Num of trainable param:", sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Loss and optimizer
    class_weights = torch.FloatTensor(weights).cuda()
    criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=0.1)
    
    # Check the model
    with torch.no_grad():
        for x, y in train_loader:
            x = x.to(device=device)
            y = y.to(device=device)
            x_ = model(x)
            print(x_[0])
            break
            
    # Evaluate the model
    print(evaluate(train_loader, model), evaluate(test_loader, model))
    
    # Make a list to record training progress
    eval_ent = [test_loader, train_loader, loaders[2]]
    eval_hist = [[[] for i in range(2 + num_classes)]for i in range(len(eval_ent))]
    highest_acc = 0
    
    # Train the network
    for epoch in range(num_epochs):
        batch_loss = np.array([])
        
        print("Epoch: ", epoch)
        
        for batch_idx, (data, targets) in loop:
            data = data.to(device=device)
            targets = targets.to(device=device)

            scores = model(data)
            loss = criterion(scores, targets)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            predictions = torch.argmax(scores, dim=1)

            # loop.set_description(f"Epoch[{epoch}/{num_epochs - 1}]")

            batch_loss = np.append(batch_loss, loss.item())

        scheduler.step()
        eval_hist[0][0].append(np.average(batch_loss))

        for ee, eh in zip(eval_ent, eval_hist):
            ev = evaluate(ee, model)
            for i in range(len(ev)):
                eh[i + 1].append(ev[i])

        if eval_hist[2][-1][-1] > highest_acc:
            torch.save(model.state_dict(), os.getcwd() + f"\\ResNet3D_노균병_k{kfold_count}_best.pt")
            highest_acc = eval_hist[2][-1][-1]        
            
    # Load weights
    model.load_state_dict(torch.load(os.getcwd() + f"\\ResNet3D_노균병_k{kfold_count}_best.pt"))
    
    # Predict
    for dataset in loaders:
        print(evaluate(dataset, model), "\n")

Training set
class index: 2
	 Rotated 60 deg
	 Rotated 120 deg
	 Rotated 180 deg
	 Rotated 240 deg
	 Rotated 300 deg
Validation set
class index: 2
	 Rotated 60 deg
	 Rotated 120 deg
	 Rotated 180 deg
	 Rotated 240 deg
	 Rotated 300 deg
torch.Size([128, 1, 55, 20, 20])
Num of trainable param: 18946371
tensor([ 0.2367,  0.7296, -0.4634], device='cuda:0')
[0.         1.         0.         0.26717557] [0.         1.         0.         0.26785714]


100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:17<00:00,  2.54s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:12<00:00,  1.72s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.70s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:12<00:00,  1.74s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:12<00:00,  1.73s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.69s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.68s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.67s/it]
100%|███████████████████████████████████

[0.99047619 0.89047619 0.87431694 0.90966921] 

[0.97777778 0.91111111 0.83974359 0.89583333] 

[0.97777778 0.91111111 0.15384615 0.84466019] 

[0.98666667 0.89666667 0.2183908  0.85007278] 

Training set
class index: 2
	 Rotated 60 deg
	 Rotated 120 deg
	 Rotated 180 deg
	 Rotated 240 deg
	 Rotated 300 deg
Validation set
class index: 2
	 Rotated 60 deg
	 Rotated 120 deg
	 Rotated 180 deg
	 Rotated 240 deg
	 Rotated 300 deg
torch.Size([128, 1, 55, 20, 20])
Num of trainable param: 18946371
tensor([-0.0020,  0.3714, -0.0098], device='cuda:0')
[0.         0.87142857 0.20765027 0.32951654] [0.         0.81111111 0.18589744 0.30357143]


100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.67s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.69s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.69s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.67s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.69s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.68s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.67s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.68s/it]
100%|███████████████████████████████████

[1.         0.86666667 0.80874317 0.87531807] 

[0.98888889 0.9        0.79487179 0.875     ] 

[0.98888889 0.9        0.19230769 0.84951456] 

[0.99666667 0.87666667 0.16091954 0.83842795] 

Training set
class index: 2
	 Rotated 60 deg
	 Rotated 120 deg
	 Rotated 180 deg
	 Rotated 240 deg
	 Rotated 300 deg
Validation set
class index: 2
	 Rotated 60 deg
	 Rotated 120 deg
	 Rotated 180 deg
	 Rotated 240 deg
	 Rotated 300 deg
torch.Size([128, 1, 55, 20, 20])
Num of trainable param: 18946371
tensor([ 0.2791, -0.2581,  0.4196], device='cuda:0')
[0.         1.         0.         0.26717557] [0.         1.         0.         0.26785714]


100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.67s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.67s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.67s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.67s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.67s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.67s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.67s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.67s/it]
100%|███████████████████████████████████

[0.99047619 0.88571429 0.86885246 0.90585242] 

[0.97777778 0.9        0.91666667 0.92857143] 

[0.97777778 0.9        0.5        0.88349515] 

[0.98666667 0.89       0.31034483 0.8588064 ] 

Training set
class index: 2
	 Rotated 60 deg
	 Rotated 120 deg
	 Rotated 180 deg
	 Rotated 240 deg
	 Rotated 300 deg
Validation set
class index: 2
	 Rotated 60 deg
	 Rotated 120 deg
	 Rotated 180 deg
	 Rotated 240 deg
	 Rotated 300 deg
torch.Size([128, 1, 55, 20, 20])
Num of trainable param: 18946371
tensor([-0.0501,  0.3423, -0.8187], device='cuda:0')
[0.13333333 0.45714286 0.25409836 0.27608142] [0.18888889 0.55555556 0.21794872 0.30059524]


100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:12<00:00,  1.74s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:12<00:00,  1.73s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:12<00:00,  1.81s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:12<00:00,  1.77s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.69s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.68s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.67s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.71s/it]
100%|███████████████████████████████████

[0.98095238 0.87142857 0.8989071  0.91348601] 

[1.         0.78888889 0.88461538 0.88988095] 

[1.         0.78888889 0.30769231 0.82038835] 

[0.98666667 0.84666667 0.36781609 0.84716157] 

Training set
class index: 2
	 Rotated 60 deg
	 Rotated 120 deg
	 Rotated 180 deg
	 Rotated 240 deg
	 Rotated 300 deg
Validation set
class index: 2
	 Rotated 60 deg
	 Rotated 120 deg
	 Rotated 180 deg
	 Rotated 240 deg
	 Rotated 300 deg
torch.Size([128, 1, 55, 20, 20])
Num of trainable param: 18946371
tensor([-0.0821, -0.1007, -0.0284], device='cuda:0')
[0.03333333 0.         1.         0.47455471] [0.03333333 0.         1.         0.47321429]


100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.66s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.67s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.67s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.68s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.66s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.67s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.66s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:11<00:00,  1.68s/it]
100%|███████████████████████████████████

[0.98095238 0.9        0.8715847  0.90839695] 

[1.         0.92222222 0.83974359 0.9047619 ] 

[1.         0.92222222 0.11538462 0.85436893] 

[0.98666667 0.90666667 0.1954023  0.85152838] 



In [20]:
kfold_preds = patch_14.copy()

# to img
img = wrapPatch([patch_14, patch_14])

# load
loader = imagesLoader([[img[0], False]], batch_size)[0]

for k in range(kfold):
    # init model
    model = CNN(ResidualBlock).to(device)
    model.load_state_dict(torch.load(os.getcwd() + f"\\ResNet3D_노균병_k{k}_best.pt"))
    
    # predict
    actual, prediction = predictDataset(loader, model)

    kfold_preds.loc[:, f'k{k}'] = prediction    

In [25]:
kfold_preds.to_csv("kfold_preds.csv")
kfold_recs.to_csv("kfold_recs.csv")

In [26]:
kfold_recs

Unnamed: 0,path,type,class,k0,k1,k2,k3,k4
7,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,2,0.0,1.0,0.0,1.0,1.0
747,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,1,0.0,0.0,0.0,1.0,1.0
12,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,2,1.0,1.0,1.0,1.0,1.0
82,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,2,0.0,0.0,0.0,1.0,1.0
56,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,2,1.0,0.0,1.0,1.0,1.0
...,...,...,...,...,...,...,...,...
474,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,0,0.0,0.0,0.0,0.0,0.0
301,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,0,1.0,0.0,0.0,0.0,1.0
21,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,2,1.0,0.0,0.0,1.0,1.0
386,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,0,0.0,0.0,1.0,1.0,0.0


In [27]:
kfold_preds

Unnamed: 0,path,type,class,k0,k1,k2,k3,k4
7,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,2,1.0,1.0,1.0,1.0,1.0
747,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,1,1.0,1.0,1.0,1.0,1.0
12,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,2,1.0,1.0,1.0,1.0,1.0
82,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,2,1.0,1.0,1.0,1.0,1.0
56,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,2,2.0,2.0,2.0,2.0,2.0
...,...,...,...,...,...,...,...,...
474,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,0,0.0,0.0,0.0,0.0,0.0
301,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,0,0.0,0.0,0.0,0.0,0.0
21,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,2,1.0,1.0,2.0,2.0,1.0
386,D:\gyeongsang_22_10_14\PATCHES\STACK_75_BANDS_...,hdr,0,0.0,0.0,0.0,0.0,0.0
