In [65]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import os
from glob import glob
from tqdm import tqdm
import albumentations as A
from sklearn.model_selection import train_test_split, StratifiedKFold


# PlatVillage

In [62]:
village_data = os.listdir('./data/public/PlantVillage')
label_encoder = {}
for idx, data_name in enumerate(village_data) :
    label_encoder[idx] = data_name

label_decoder = {val:key for key, val in label_encoder.items()}
display(label_decoder)
display(label_encoder)

{'Pepper__bell___Bacterial_spot': 0,
 'Pepper__bell___healthy': 1,
 'Potato___Early_blight': 2,
 'Potato___healthy': 3,
 'Potato___Late_blight': 4,
 'Tomato_Bacterial_spot': 5,
 'Tomato_Early_blight': 6,
 'Tomato_healthy': 7,
 'Tomato_Late_blight': 8,
 'Tomato_Leaf_Mold': 9,
 'Tomato_Septoria_leaf_spot': 10,
 'Tomato_Spider_mites_Two_spotted_spider_mite': 11,
 'Tomato__Target_Spot': 12,
 'Tomato__Tomato_mosaic_virus': 13,
 'Tomato__Tomato_YellowLeaf__Curl_Virus': 14}

{0: 'Pepper__bell___Bacterial_spot',
 1: 'Pepper__bell___healthy',
 2: 'Potato___Early_blight',
 3: 'Potato___healthy',
 4: 'Potato___Late_blight',
 5: 'Tomato_Bacterial_spot',
 6: 'Tomato_Early_blight',
 7: 'Tomato_healthy',
 8: 'Tomato_Late_blight',
 9: 'Tomato_Leaf_Mold',
 10: 'Tomato_Septoria_leaf_spot',
 11: 'Tomato_Spider_mites_Two_spotted_spider_mite',
 12: 'Tomato__Target_Spot',
 13: 'Tomato__Tomato_mosaic_virus',
 14: 'Tomato__Tomato_YellowLeaf__Curl_Virus'}

# custom dataset 정의

In [63]:
class VillageDataset(Dataset) :
    def __init__(self, files, transform, mode='train') :
        super(VillageDataset, self).__init__()
        self.files = files
        self.transform = transform
        self.mode = mode
        
    def __len__(self) :
        return len(self.files)
    
    def __getitem__(self, idx) :
        file_path = self.files[idx]
        
        label = label_decoder[file_path.split('\\')[-2]]
        
        img = cv2.imread(file_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transform(image=img)['image']
        img = img.transpose(2, 0, 1)
        
        return torch.tensor(img, dtype=torch.float32) / 255.0, torch.tensor(label, dtype=torch.long)

# Dataloader w/o kfold

In [85]:
train_transforms = A.Compose([
                A.Resize(224 ,224),
                A.OneOf([
                    A.Rotate(),
                    A.HorizontalFlip(),
                    A.VerticalFlip()
                ], p=1)
            ])

val_transforms = A.Compose([
    A.Resize(224,224)
])

train = glob('./data/public/PlantVillage/*/*.JPG')
print("total : ", len(train))
label_list = [label_decoder[img_path.split('\\')[-2]] for img_path in img_list]

train, val = train_test_split(train, test_size=0.2, shuffle=True, stratify=label_list)
print("train : ", len(train))
print("val : ", len(val))
display(train)
train_dataset = VillageDataset(train, train_transforms)
val_dataset = VillageDataset(val, val_transforms)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

total :  20636
train :  16508
val :  4128


['./data/public/PlantVillage\\Tomato_Late_blight\\7d8f07cb-cb5b-4c4a-97f9-224258af8343___GHLB2 Leaf 8791.JPG',
 './data/public/PlantVillage\\Tomato_Bacterial_spot\\01a46cb5-d354-4f59-868e-e56186701541___GCREC_Bact.Sp 5638.JPG',
 './data/public/PlantVillage\\Tomato_Early_blight\\c63ff9c7-0422-4a71-8727-7c2581fdee6e___RS_Erly.B 9406.JPG',
 './data/public/PlantVillage\\Tomato__Tomato_YellowLeaf__Curl_Virus\\28d69631-a796-49b8-8213-2e17e48a09c4___YLCV_GCREC 5221.JPG',
 './data/public/PlantVillage\\Tomato_Spider_mites_Two_spotted_spider_mite\\8cfb09a9-81f5-4d7f-b5d2-459ae2cead76___Com.G_SpM_FL 8536.JPG',
 './data/public/PlantVillage\\Tomato__Target_Spot\\e2c3ecfd-9637-4158-9400-315f5949b7f4___Com.G_TgS_FL 7962.JPG',
 './data/public/PlantVillage\\Tomato_Septoria_leaf_spot\\619ccfba-56c4-4750-b920-7c7314f78b7c___Matt.S_CG 6908.JPG',
 './data/public/PlantVillage\\Tomato__Tomato_YellowLeaf__Curl_Virus\\100e0500-25e7-4ecf-bbff-47b59472f911___UF.GRC_YLCV_Lab 01282.JPG',
 './data/public/PlantVilla

# DataLoader w/ kfold

In [83]:
img_list = glob('./data/public/PlantVillage/*/*.JPG')
label_list = [label_decoder[img_path.split('\\')[-2]] for img_path in img_list]

kfold = StratifiedKFold(n_splits=4, random_state=13, shuffle=True)
for idx, (kfold_train, kfold_val) in enumerate(kfold.split(img_list, label_list), 1) :
    print(idx, len(kfold_train), len(kfold_val))
    print(kfold_train)
    print(kfold_val)

1 15477 5159
[    1     2     3 ... 20632 20634 20635]
[    0    16    25 ... 20616 20630 20633]
2 15477 5159
[    0     1     3 ... 20633 20634 20635]
[    2     5    10 ... 20615 20617 20625]
3 15477 5159
[    0     1     2 ... 20633 20634 20635]
[    6    12    14 ... 20610 20622 20632]
4 15477 5159
[    0     2     5 ... 20630 20632 20633]
[    1     3     4 ... 20631 20634 20635]


# Plant Doc

### image 이름에 , 있는거 제거

In [117]:
with open('./data/public/PlantDoc/train/_classes.txt', 'r') as f :
    classese = f.readlines()

before_name = []
modi_name = []
for i in range(1, len(classese)) :
    if ',' in classese[i].split(', ')[0] :
        split_name = classese[i].split(', ')
        before_name.append(split_name[0])
        modi_name.append(split_name[0].replace(',',''))
        
        classese[i] = classese[i].replace(before_name[-1], modi_name[-1])
        print(before_name[-1])
        

flat,1000x1000,075,f_u2_jpg.rf.66e349ed247b0dabe796bf7bc4821505.jpg
8_--Virus-Damaged-foliage-at-top-versus-healthy,-non-infected-bottom-foliage_jpg.rf.b88f190e1dc2af80677face44a88c31f.jpg
autumn,+blueberry+leaves_jpg.rf.e4e581e9845ac2581d4395de63e617e0.jpg


### 새로운 label txt 파일 생성

In [118]:
with open('./data/public/PlantDoc/train/_modi_classes.txt', 'w') as f :
    for data ina classese:
        f.write(data)

### image 이름도 변경해주기

In [120]:
root_path = './data/public/PlantDoc/train'
img_list = glob(root_path+'/*.JPG')

for i in range(0, len(before_name)) :
    before_index = img_list.index(root_path+'\\'+before_name[i])
    img_list[before_index] = root_path+'\\'+modi_name[i]
    print(root_path+'\\'+before_name[i])
    print(root_path+'\\'+modi_name[i])
    print()
    os.rename(root_path+'\\'+before_name[i], root_path+'\\'+modi_name[i])

./data/public/PlantDoc/train\flat,1000x1000,075,f_u2_jpg.rf.66e349ed247b0dabe796bf7bc4821505.jpg
./data/public/PlantDoc/train\flat1000x1000075f_u2_jpg.rf.66e349ed247b0dabe796bf7bc4821505.jpg

./data/public/PlantDoc/train\8_--Virus-Damaged-foliage-at-top-versus-healthy,-non-infected-bottom-foliage_jpg.rf.b88f190e1dc2af80677face44a88c31f.jpg
./data/public/PlantDoc/train\8_--Virus-Damaged-foliage-at-top-versus-healthy-non-infected-bottom-foliage_jpg.rf.b88f190e1dc2af80677face44a88c31f.jpg

./data/public/PlantDoc/train\autumn,+blueberry+leaves_jpg.rf.e4e581e9845ac2581d4395de63e617e0.jpg
./data/public/PlantDoc/train\autumn+blueberry+leaves_jpg.rf.e4e581e9845ac2581d4395de63e617e0.jpg



# data list 만들기

In [162]:
import pandas as pd

label_csv = pd.read_csv('./data/public/PlantDoc/train/_modi_classes.txt')
columns_name = list(label_csv.columns)
columns_name.remove('filename')
columns_name.insert(0, 'None')

label_encoder = {idx : val for idx, val in enumerate(columns_name)}
label_decoder = {val : key for key, val in label_encoder.items()}
display(label_encoder)
display(label_decoder)

{0: 'None',
 1: ' Apple Scab Leaf',
 2: ' Apple leaf',
 3: ' Apple rust leaf',
 4: ' Bell_pepper leaf',
 5: ' Bell_pepper leaf spot',
 6: ' Blueberry leaf',
 7: ' Cherry leaf',
 8: ' Corn Gray leaf spot',
 9: ' Corn leaf blight',
 10: ' Corn rust leaf',
 11: ' Peach leaf',
 12: ' Potato leaf',
 13: ' Potato leaf early blight',
 14: ' Potato leaf late blight',
 15: ' Raspberry leaf',
 16: ' Soyabean leaf',
 17: ' Soybean leaf',
 18: ' Squash Powdery mildew leaf',
 19: ' Strawberry leaf',
 20: ' Tomato Early blight leaf',
 21: ' Tomato Septoria leaf spot',
 22: ' Tomato leaf',
 23: ' Tomato leaf bacterial spot',
 24: ' Tomato leaf late blight',
 25: ' Tomato leaf mosaic virus',
 26: ' Tomato leaf yellow virus',
 27: ' Tomato mold leaf',
 28: ' Tomato two spotted spider mites leaf',
 29: ' grape leaf',
 30: ' grape leaf black rot'}

{'None': 0,
 ' Apple Scab Leaf': 1,
 ' Apple leaf': 2,
 ' Apple rust leaf': 3,
 ' Bell_pepper leaf': 4,
 ' Bell_pepper leaf spot': 5,
 ' Blueberry leaf': 6,
 ' Cherry leaf': 7,
 ' Corn Gray leaf spot': 8,
 ' Corn leaf blight': 9,
 ' Corn rust leaf': 10,
 ' Peach leaf': 11,
 ' Potato leaf': 12,
 ' Potato leaf early blight': 13,
 ' Potato leaf late blight': 14,
 ' Raspberry leaf': 15,
 ' Soyabean leaf': 16,
 ' Soybean leaf': 17,
 ' Squash Powdery mildew leaf': 18,
 ' Strawberry leaf': 19,
 ' Tomato Early blight leaf': 20,
 ' Tomato Septoria leaf spot': 21,
 ' Tomato leaf': 22,
 ' Tomato leaf bacterial spot': 23,
 ' Tomato leaf late blight': 24,
 ' Tomato leaf mosaic virus': 25,
 ' Tomato leaf yellow virus': 26,
 ' Tomato mold leaf': 27,
 ' Tomato two spotted spider mites leaf': 28,
 ' grape leaf': 29,
 ' grape leaf black rot': 30}

In [169]:
root_path = './data/public/PlantDoc/train'

with open('./data/public/PlantDoc/train/_modi_classes.txt', 'r') as f :
    classese = f.readlines()

label_list = []
img_list = []
for i in range(1, len(classese)) :
    img_list.append(root_path + '/' + classese[i].split(', ')[0])
    if '1' in classese[i].split(', ') :
        label_list.append(classese[i].split(', ').index('1'))
    elif '1\n' in classese[i].split(', ') :
        label_list.append(classese[i].split(', ').index('1\n'))
    else : 
        label_list.append(0)

print(img_list)
print(label_list)

['./data/public/PlantDoc/train/2017-08-27%2B-%2BLate%2BBlight%2Bon%2BGypsy%2B1_jpg.rf.00584c9d14b6fda3086f9a482675d610.jpg', './data/public/PlantDoc/train/Slide1_preview_jpg.rf.00f57cfc188631c10e10e2b93a38d296.jpg', './data/public/PlantDoc/train/d-to-tylv-fo005-14496A31BE70553130A_jpg.rf.017e4558f1b4baeadc32a13beb178a6e.jpg', './data/public/PlantDoc/train/2-3%20Gray%20leaf%20spot%20BRUCE_jpg.rf.0040c2e382cf11216938e6f49b7d65ed.jpg', './data/public/PlantDoc/train/blueberry-leaf-isolated-white-background-blueberry-leaf-isolated-99227523_jpg.rf.00a825930e9396a34d6c3fc5cf86686e.jpg', './data/public/PlantDoc/train/Apple-Scab-image-02_jpg.rf.00cbc9a108dbdaadf4232b5392e3d3c8.jpg', './data/public/PlantDoc/train/Faske%20Southern%20rust%20of%20corn_jpg.rf.016116db58617610bb259b2965f41589.jpg', './data/public/PlantDoc/train/14456_img_jpg.rf.0052b88d849e614c6037a2834800b4f8.jpg', './data/public/PlantDoc/train/24154194959_fd4b42edde_b_jpg.rf.0012f9978a300c5635840700fa784280.jpg', './data/public/Pla