In [182]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import os
from glob import glob
from tqdm import tqdm
import albumentations as A
from sklearn.model_selection import train_test_split, StratifiedKFold


# PlatVillage

In [183]:
village_data = os.listdir('./data/public/PlantVillage')
label_encoder = {}
for idx, data_name in enumerate(village_data) :
    label_encoder[idx] = data_name

label_decoder = {val:key for key, val in label_encoder.items()}
display(label_decoder)
display(label_encoder)

{'Pepper__bell___Bacterial_spot': 0,
 'Pepper__bell___healthy': 1,
 'Potato___Early_blight': 2,
 'Potato___healthy': 3,
 'Potato___Late_blight': 4,
 'Tomato_Bacterial_spot': 5,
 'Tomato_Early_blight': 6,
 'Tomato_healthy': 7,
 'Tomato_Late_blight': 8,
 'Tomato_Leaf_Mold': 9,
 'Tomato_Septoria_leaf_spot': 10,
 'Tomato_Spider_mites_Two_spotted_spider_mite': 11,
 'Tomato__Target_Spot': 12,
 'Tomato__Tomato_mosaic_virus': 13,
 'Tomato__Tomato_YellowLeaf__Curl_Virus': 14}

{0: 'Pepper__bell___Bacterial_spot',
 1: 'Pepper__bell___healthy',
 2: 'Potato___Early_blight',
 3: 'Potato___healthy',
 4: 'Potato___Late_blight',
 5: 'Tomato_Bacterial_spot',
 6: 'Tomato_Early_blight',
 7: 'Tomato_healthy',
 8: 'Tomato_Late_blight',
 9: 'Tomato_Leaf_Mold',
 10: 'Tomato_Septoria_leaf_spot',
 11: 'Tomato_Spider_mites_Two_spotted_spider_mite',
 12: 'Tomato__Target_Spot',
 13: 'Tomato__Tomato_mosaic_virus',
 14: 'Tomato__Tomato_YellowLeaf__Curl_Virus'}

# custom dataset 정의

In [184]:
class VillageDataset(Dataset) :
    def __init__(self, files, transform, mode='train') :
        super(VillageDataset, self).__init__()
        self.files = files
        self.transform = transform
        self.mode = mode
        
    def __len__(self) :
        return len(self.files)
    
    def __getitem__(self, idx) :
        file_path = self.files[idx]
        
        label = label_decoder[file_path.split('\\')[-2]]
        
        img = cv2.imread(file_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transform(image=img)['image']
        img = img.transpose(2, 0, 1)
        
        return torch.tensor(img, dtype=torch.float32) / 255.0, torch.tensor(label, dtype=torch.long)

# Dataloader w/o kfold

In [188]:
train_transforms = A.Compose([
                A.Resize(224 ,224),
                A.OneOf([
                    A.Rotate(),
                    A.HorizontalFlip(),
                    A.VerticalFlip()
                ], p=1)
            ])

val_transforms = A.Compose([
    A.Resize(224,224)
])

train = glob('./data/public/PlantVillage/*/*.JPG')
print("total : ", len(train))
label_list = [label_decoder[img_path.split('\\')[-2]] for img_path in train]

train, val = train_test_split(train, test_size=0.2, shuffle=True, stratify=label_list)
print("train : ", len(train))
print("val : ", len(val))
display(train)
train_dataset = VillageDataset(train, train_transforms)
val_dataset = VillageDataset(val, val_transforms)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

total :  20636
train :  16508
val :  4128


['./data/public/PlantVillage\\Pepper__bell___Bacterial_spot\\94c2cf91-17c3-4ee1-b8f2-b20103e1c14f___NREC_B.Spot 9094.JPG',
 './data/public/PlantVillage\\Tomato_Early_blight\\983a04ad-2cc6-4d78-9365-a67012b4ffdc___RS_Erly.B 7661.JPG',
 './data/public/PlantVillage\\Potato___Early_blight\\17a06d03-8a7b-48a8-aaf9-3300741c65de___RS_Early.B 7625.JPG',
 './data/public/PlantVillage\\Tomato_Spider_mites_Two_spotted_spider_mite\\f88f405a-176b-479c-9028-913957456752___Com.G_SpM_FL 9638.JPG',
 './data/public/PlantVillage\\Tomato__Target_Spot\\ec7551da-66db-48cf-a1ab-2f9589f947d6___Com.G_TgS_FL 8095.JPG',
 './data/public/PlantVillage\\Tomato_healthy\\4f5dde42-a6ac-4886-8b9e-a944568ace95___RS_HL 0080.JPG',
 './data/public/PlantVillage\\Tomato__Tomato_YellowLeaf__Curl_Virus\\63476e74-7aae-4e1c-b550-89e4be34149f___UF.GRC_YLCV_Lab 08428.JPG',
 './data/public/PlantVillage\\Tomato__Target_Spot\\bb4a96a1-3f00-4c1c-aa56-c88620a64aa4___Com.G_TgS_FL 7938.JPG',
 './data/public/PlantVillage\\Tomato_healthy\\bb

# DataLoader w/ kfold

In [83]:
img_list = glob('./data/public/PlantVillage/*/*.JPG')
label_list = [label_decoder[img_path.split('\\')[-2]] for img_path in img_list]

kfold = StratifiedKFold(n_splits=4, random_state=13, shuffle=True)
for idx, (kfold_train, kfold_val) in enumerate(kfold.split(img_list, label_list), 1) :
    print(idx, len(kfold_train), len(kfold_val))
    print(kfold_train)
    print(kfold_val)

1 15477 5159
[    1     2     3 ... 20632 20634 20635]
[    0    16    25 ... 20616 20630 20633]
2 15477 5159
[    0     1     3 ... 20633 20634 20635]
[    2     5    10 ... 20615 20617 20625]
3 15477 5159
[    0     1     2 ... 20633 20634 20635]
[    6    12    14 ... 20610 20622 20632]
4 15477 5159
[    0     2     5 ... 20630 20632 20633]
[    1     3     4 ... 20631 20634 20635]


# Plant Doc

### image 이름에 , 있는거 제거

In [117]:
with open('./data/public/PlantDoc/train/_classes.txt', 'r') as f :
    classese = f.readlines()

before_name = []
modi_name = []
for i in range(1, len(classese)) :
    if ',' in classese[i].split(', ')[0] :
        split_name = classese[i].split(', ')
        before_name.append(split_name[0])
        modi_name.append(split_name[0].replace(',',''))
        
        classese[i] = classese[i].replace(before_name[-1], modi_name[-1])
        print(before_name[-1])
        

flat,1000x1000,075,f_u2_jpg.rf.66e349ed247b0dabe796bf7bc4821505.jpg
8_--Virus-Damaged-foliage-at-top-versus-healthy,-non-infected-bottom-foliage_jpg.rf.b88f190e1dc2af80677face44a88c31f.jpg
autumn,+blueberry+leaves_jpg.rf.e4e581e9845ac2581d4395de63e617e0.jpg


### 새로운 label txt 파일 생성

In [118]:
with open('./data/public/PlantDoc/train/_modi_classes.txt', 'w') as f :
    for data ina classese:
        f.write(data)

### image 이름도 변경해주기

In [120]:
root_path = './data/public/PlantDoc/train'
img_list = glob(root_path+'/*.JPG')

for i in range(0, len(before_name)) :
    before_index = img_list.index(root_path+'\\'+before_name[i])
    img_list[before_index] = root_path+'\\'+modi_name[i]
    print(root_path+'\\'+before_name[i])
    print(root_path+'\\'+modi_name[i])
    print()
    os.rename(root_path+'\\'+before_name[i], root_path+'\\'+modi_name[i])

./data/public/PlantDoc/train\flat,1000x1000,075,f_u2_jpg.rf.66e349ed247b0dabe796bf7bc4821505.jpg
./data/public/PlantDoc/train\flat1000x1000075f_u2_jpg.rf.66e349ed247b0dabe796bf7bc4821505.jpg

./data/public/PlantDoc/train\8_--Virus-Damaged-foliage-at-top-versus-healthy,-non-infected-bottom-foliage_jpg.rf.b88f190e1dc2af80677face44a88c31f.jpg
./data/public/PlantDoc/train\8_--Virus-Damaged-foliage-at-top-versus-healthy-non-infected-bottom-foliage_jpg.rf.b88f190e1dc2af80677face44a88c31f.jpg

./data/public/PlantDoc/train\autumn,+blueberry+leaves_jpg.rf.e4e581e9845ac2581d4395de63e617e0.jpg
./data/public/PlantDoc/train\autumn+blueberry+leaves_jpg.rf.e4e581e9845ac2581d4395de63e617e0.jpg



# data list 만들기
- none, 12, 17, 28번 클래스 삭제

In [283]:
import pandas as pd

label_csv = pd.read_csv('./data/public/PlantDoc/train/_modi_classes.txt')
columns_name = list(label_csv.columns)
columns_name.remove('filename')
columns_name.insert(0, 'None')

label_encoder = {idx : val.lstrip() for idx, val in enumerate(columns_name)}
label_decoder = {val : key for key, val in label_encoder.items()}
display(label_encoder)
display(label_decoder)

{0: 'None',
 1: 'Apple Scab Leaf',
 2: 'Apple leaf',
 3: 'Apple rust leaf',
 4: 'Bell_pepper leaf',
 5: 'Bell_pepper leaf spot',
 6: 'Blueberry leaf',
 7: 'Cherry leaf',
 8: 'Corn Gray leaf spot',
 9: 'Corn leaf blight',
 10: 'Corn rust leaf',
 11: 'Peach leaf',
 12: 'Potato leaf',
 13: 'Potato leaf early blight',
 14: 'Potato leaf late blight',
 15: 'Raspberry leaf',
 16: 'Soyabean leaf',
 17: 'Soybean leaf',
 18: 'Squash Powdery mildew leaf',
 19: 'Strawberry leaf',
 20: 'Tomato Early blight leaf',
 21: 'Tomato Septoria leaf spot',
 22: 'Tomato leaf',
 23: 'Tomato leaf bacterial spot',
 24: 'Tomato leaf late blight',
 25: 'Tomato leaf mosaic virus',
 26: 'Tomato leaf yellow virus',
 27: 'Tomato mold leaf',
 28: 'Tomato two spotted spider mites leaf',
 29: 'grape leaf',
 30: 'grape leaf black rot'}

{'None': 0,
 'Apple Scab Leaf': 1,
 'Apple leaf': 2,
 'Apple rust leaf': 3,
 'Bell_pepper leaf': 4,
 'Bell_pepper leaf spot': 5,
 'Blueberry leaf': 6,
 'Cherry leaf': 7,
 'Corn Gray leaf spot': 8,
 'Corn leaf blight': 9,
 'Corn rust leaf': 10,
 'Peach leaf': 11,
 'Potato leaf': 12,
 'Potato leaf early blight': 13,
 'Potato leaf late blight': 14,
 'Raspberry leaf': 15,
 'Soyabean leaf': 16,
 'Soybean leaf': 17,
 'Squash Powdery mildew leaf': 18,
 'Strawberry leaf': 19,
 'Tomato Early blight leaf': 20,
 'Tomato Septoria leaf spot': 21,
 'Tomato leaf': 22,
 'Tomato leaf bacterial spot': 23,
 'Tomato leaf late blight': 24,
 'Tomato leaf mosaic virus': 25,
 'Tomato leaf yellow virus': 26,
 'Tomato mold leaf': 27,
 'Tomato two spotted spider mites leaf': 28,
 'grape leaf': 29,
 'grape leaf black rot': 30}

In [284]:
root_path = './data/public/PlantDoc/train'

with open('./data/public/PlantDoc/train/_modi_classes.txt', 'r') as f :
    classese = f.readlines()

label_list = []
img_list = []
for i in range(1, len(classese)) :
    img_list.append(root_path + '/' + classese[i].split(', ')[0])
    if '1' in classese[i].split(', ') :
        label_list.append(classese[i].split(', ').index('1'))
    elif '1\n' in classese[i].split(', ') :
        label_list.append(classese[i].split(', ').index('1\n'))
    else : 
        label_list.append(0)

print(img_list)
print(label_list)

['./data/public/PlantDoc/train/2017-08-27%2B-%2BLate%2BBlight%2Bon%2BGypsy%2B1_jpg.rf.00584c9d14b6fda3086f9a482675d610.jpg', './data/public/PlantDoc/train/Slide1_preview_jpg.rf.00f57cfc188631c10e10e2b93a38d296.jpg', './data/public/PlantDoc/train/d-to-tylv-fo005-14496A31BE70553130A_jpg.rf.017e4558f1b4baeadc32a13beb178a6e.jpg', './data/public/PlantDoc/train/2-3%20Gray%20leaf%20spot%20BRUCE_jpg.rf.0040c2e382cf11216938e6f49b7d65ed.jpg', './data/public/PlantDoc/train/blueberry-leaf-isolated-white-background-blueberry-leaf-isolated-99227523_jpg.rf.00a825930e9396a34d6c3fc5cf86686e.jpg', './data/public/PlantDoc/train/Apple-Scab-image-02_jpg.rf.00cbc9a108dbdaadf4232b5392e3d3c8.jpg', './data/public/PlantDoc/train/Faske%20Southern%20rust%20of%20corn_jpg.rf.016116db58617610bb259b2965f41589.jpg', './data/public/PlantDoc/train/14456_img_jpg.rf.0052b88d849e614c6037a2834800b4f8.jpg', './data/public/PlantDoc/train/24154194959_fd4b42edde_b_jpg.rf.0012f9978a300c5635840700fa784280.jpg', './data/public/Pla

# class 별 갯수
- 10개 이하 class는 삭제 해줌
- 삭제해야할 label - 0, 12, 17, 28 

In [285]:
root_path = './data/public/PlantDoc/train'

with open('./data/public/PlantDoc/train/_modi_classes.txt', 'r') as f :
    classese = f.readlines()

label_list = []
img_list = []
for i in range(1, len(classese)) :
    img_list.append(root_path + '/' + classese[i].split(', ')[0])
    if '1' in classese[i].split(', ') :
        label_list.append(classese[i].split(', ').index('1'))
    elif '1\n' in classese[i].split(', ') :
        label_list.append(classese[i].split(', ').index('1\n'))
    else : 
        label_list.append(0)
label_cnt = {k : 0 for k, v in label_decoder.items()}

for i in label_list :
    label_cnt[label_encoder[i]] += 1

display(label_cnt)

{'None': 10,
 'Apple Scab Leaf': 83,
 'Apple leaf': 82,
 'Apple rust leaf': 78,
 'Bell_pepper leaf': 53,
 'Bell_pepper leaf spot': 62,
 'Blueberry leaf': 103,
 'Cherry leaf': 47,
 'Corn Gray leaf spot': 61,
 'Corn leaf blight': 178,
 'Corn rust leaf': 106,
 'Peach leaf': 101,
 'Potato leaf': 3,
 'Potato leaf early blight': 103,
 'Potato leaf late blight': 94,
 'Raspberry leaf': 112,
 'Soyabean leaf': 57,
 'Soybean leaf': 1,
 'Squash Powdery mildew leaf': 124,
 'Strawberry leaf': 87,
 'Tomato Early blight leaf': 79,
 'Tomato Septoria leaf spot': 137,
 'Tomato leaf': 56,
 'Tomato leaf bacterial spot': 101,
 'Tomato leaf late blight': 99,
 'Tomato leaf mosaic virus': 44,
 'Tomato leaf yellow virus': 68,
 'Tomato mold leaf': 86,
 'Tomato two spotted spider mites leaf': 0,
 'grape leaf': 57,
 'grape leaf black rot': 56}

### class 삭제

In [286]:
label_cnt_ = {label_decoder[k] : v for k, v in label_cnt.items()}
display(label_cnt_)
remove_list = []
for k, v in label_cnt_.items() :
    for i in range(0, len(label_list)) :
        if label_list[i] == k and v <= 10 :
            remove_list.append(i)
            print(label_list[i])
            print(img_list[i])
            
display(remove_list)
remove_list.sort(reverse=True)
display(remove_list)

print(len(img_list))
print(len(label_list))
for remove_num in remove_list :
    del img_list[remove_num]
    del label_list[remove_num]

print(len(img_list))
print(len(label_list))


### 0, 12, 17, 28 삭제해야할 label

{0: 10,
 1: 83,
 2: 82,
 3: 78,
 4: 53,
 5: 62,
 6: 103,
 7: 47,
 8: 61,
 9: 178,
 10: 106,
 11: 101,
 12: 3,
 13: 103,
 14: 94,
 15: 112,
 16: 57,
 17: 1,
 18: 124,
 19: 87,
 20: 79,
 21: 137,
 22: 56,
 23: 101,
 24: 99,
 25: 44,
 26: 68,
 27: 86,
 28: 0,
 29: 57,
 30: 56}

0
./data/public/PlantDoc/train/270412tglr-wild-strawberry-flowers-and-leaf-patch_jpg.rf.11e3d6fa8e0828ac716e6e40f92ebd6b.jpg
0
./data/public/PlantDoc/train/Hydrangea+%2527Claudie%2527%252C+Powdery+Mildew_JPG_jpg.rf.241a2ba81fae4dd8f4d42a639bf4c542.jpg
0
./data/public/PlantDoc/train/tomato_D4a-TobRingspotVirus-1000077_zoom_jpg.rf.2e3bf56fca79319caad79756a5a3489f.jpg
0
./data/public/PlantDoc/train/Downy%20mildew_JPG_jpg.rf.5cd0d81448f536b690c687fe43801ad8.jpg
0
./data/public/PlantDoc/train/raspberries_jpg.rf.6f265fc0cc6e20357544cda8c1e08794.jpg
0
./data/public/PlantDoc/train/powdery-mildew-erysiphe-plantani-on-young-sycamore-leaf-b372wn_jpg.rf.9f2fd157b98d7080ee62fda831e913e9.jpg
0
./data/public/PlantDoc/train/Black+Raspberry+Leaves+3_jpg.rf.a5118930e06e0d086213803cfdda04b3.jpg
0
./data/public/PlantDoc/train/raspberry-db_jpg.rf.b3eabfd7a40d314ed0bd68e867f3bf6b.jpg
0
./data/public/PlantDoc/train/aquilegia-powdery-mildew-erysiphe-aquilegiae-on-columbine-leaves-bga5xr_jpg.rf.bd049474ad6eb0e

[151, 324, 415, 827, 996, 1433, 1500, 1629, 1717, 1847, 1471, 1706, 1916, 36]

[1916, 1847, 1717, 1706, 1629, 1500, 1471, 1433, 996, 827, 415, 324, 151, 36]

2328
2328
2314
2314


### 삭제된 라벨만큼 라벨 번호 당기기

In [287]:
for i in range(0, len(label_list)) :
    print(label_list[i])
    if label_list[i] > 0 and label_list[i] < 12 :
        label_list[i] -= 1
    elif label_list[i] > 12 and label_list[i] < 17 :
        label_list[i] -= 2
    elif label_list[i] > 17 and label_list[i] < 28 :
        label_list[i] -= 3
    elif label_list[i] > 28 :
        label_list[i] -= 4
    print(label_list[i])
    print()

24
21

9
8

27
24

9
8

6
5

1
0

10
9

26
23

9
8

21
18

30
26

5
4

5
4

9
8

4
3

5
4

1
0

13
11

23
20

3
2

15
13

25
22

9
8

18
15

14
12

18
15

9
8

20
17

9
8

13
11

14
12

20
17

21
18

20
17

24
21

11
10

16
14

1
0

8
7

15
13

26
23

27
24

18
15

23
20

15
13

20
17

2
1

18
15

6
5

24
21

10
9

18
15

3
2

19
16

18
15

11
10

26
23

11
10

27
24

13
11

1
0

7
6

3
2

2
1

15
13

26
23

14
12

19
16

21
18

13
11

20
17

14
12

25
22

15
13

23
20

26
23

24
21

25
22

6
5

11
10

6
5

2
1

1
0

15
13

15
13

1
0

14
12

6
5

10
9

15
13

14
12

18
15

16
14

14
12

2
1

6
5

21
18

26
23

13
11

15
13

22
19

16
14

13
11

21
18

13
11

18
15

21
18

25
22

19
16

16
14

21
18

21
18

6
5

24
21

16
14

9
8

9
8

9
8

3
2

26
23

18
15

11
10

13
11

30
26

13
11

10
9

9
8

6
5

11
10

8
7

21
18

5
4

23
20

2
1

15
13

24
21

26
23

30
26

24
21

18
15

4
3

27
24

14
12

26
23

7
6

6
5

13
11

9
8

18
15

6
5

26
23

2
1

9
8

27
24

8
7

14
12

27
24

1
0



### label_decoder encoder 재선언

In [288]:
import pandas as pd

label_csv = pd.read_csv('./data/public/PlantDoc/train/_modi_classes.txt')
columns_name = list(label_csv.columns)
columns_name.remove('filename')
columns_name.remove(' Potato leaf') # 12
columns_name.remove(' Soybean leaf') # 17
columns_name.remove(' Tomato two spotted spider mites leaf') # 28

label_encoder = {idx : val.lstrip() for idx, val in enumerate(columns_name)}
label_decoder = {val : key for key, val in label_encoder.items()}
display(label_encoder)
display(label_decoder)

{0: 'Apple Scab Leaf',
 1: 'Apple leaf',
 2: 'Apple rust leaf',
 3: 'Bell_pepper leaf',
 4: 'Bell_pepper leaf spot',
 5: 'Blueberry leaf',
 6: 'Cherry leaf',
 7: 'Corn Gray leaf spot',
 8: 'Corn leaf blight',
 9: 'Corn rust leaf',
 10: 'Peach leaf',
 11: 'Potato leaf early blight',
 12: 'Potato leaf late blight',
 13: 'Raspberry leaf',
 14: 'Soyabean leaf',
 15: 'Squash Powdery mildew leaf',
 16: 'Strawberry leaf',
 17: 'Tomato Early blight leaf',
 18: 'Tomato Septoria leaf spot',
 19: 'Tomato leaf',
 20: 'Tomato leaf bacterial spot',
 21: 'Tomato leaf late blight',
 22: 'Tomato leaf mosaic virus',
 23: 'Tomato leaf yellow virus',
 24: 'Tomato mold leaf',
 25: 'grape leaf',
 26: 'grape leaf black rot'}

{'Apple Scab Leaf': 0,
 'Apple leaf': 1,
 'Apple rust leaf': 2,
 'Bell_pepper leaf': 3,
 'Bell_pepper leaf spot': 4,
 'Blueberry leaf': 5,
 'Cherry leaf': 6,
 'Corn Gray leaf spot': 7,
 'Corn leaf blight': 8,
 'Corn rust leaf': 9,
 'Peach leaf': 10,
 'Potato leaf early blight': 11,
 'Potato leaf late blight': 12,
 'Raspberry leaf': 13,
 'Soyabean leaf': 14,
 'Squash Powdery mildew leaf': 15,
 'Strawberry leaf': 16,
 'Tomato Early blight leaf': 17,
 'Tomato Septoria leaf spot': 18,
 'Tomato leaf': 19,
 'Tomato leaf bacterial spot': 20,
 'Tomato leaf late blight': 21,
 'Tomato leaf mosaic virus': 22,
 'Tomato leaf yellow virus': 23,
 'Tomato mold leaf': 24,
 'grape leaf': 25,
 'grape leaf black rot': 26}

### 라벨 카운트

In [289]:
label_cnt = {k : 0 for k, v in label_decoder.items()}

for i in label_list :
    label_cnt[label_encoder[i]] += 1

display(label_cnt)

{'Apple Scab Leaf': 83,
 'Apple leaf': 82,
 'Apple rust leaf': 78,
 'Bell_pepper leaf': 53,
 'Bell_pepper leaf spot': 62,
 'Blueberry leaf': 103,
 'Cherry leaf': 47,
 'Corn Gray leaf spot': 61,
 'Corn leaf blight': 178,
 'Corn rust leaf': 106,
 'Peach leaf': 101,
 'Potato leaf early blight': 103,
 'Potato leaf late blight': 94,
 'Raspberry leaf': 112,
 'Soyabean leaf': 57,
 'Squash Powdery mildew leaf': 124,
 'Strawberry leaf': 87,
 'Tomato Early blight leaf': 79,
 'Tomato Septoria leaf spot': 137,
 'Tomato leaf': 56,
 'Tomato leaf bacterial spot': 101,
 'Tomato leaf late blight': 99,
 'Tomato leaf mosaic virus': 44,
 'Tomato leaf yellow virus': 68,
 'Tomato mold leaf': 86,
 'grape leaf': 57,
 'grape leaf black rot': 56}

# Custom dataset 정의

In [290]:
class DocDataset(Dataset) :
    def __init__(self, imgs, labels, transform) :
        super(VillageDataset, self).__init__()
        self.imgs = imgs
        self.labels = labels
        self.transform = transform
        
    def __len__(self) :
        assert len(self.imgs) == len(self.labels)
        return len(self.imgs)
    
    def __getitem__(self, idx) :
        
        label = self.labels[idx]
        
        img = cv2.imread(self.imgs[idx])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transform(image=img)['image']
        img = img.transpose(2, 0, 1)
        
        return torch.tensor(img, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

# DataLoader w/o kfold

In [291]:
train_transforms = A.Compose([
                A.Resize(224 ,224),
                A.OneOf([
                    A.Rotate(),
                    A.HorizontalFlip(),
                    A.VerticalFlip()
                ], p=1)
            ])

val_transforms = A.Compose([
    A.Resize(224,224)
])


train_img, val_img, train_label, val_label = train_test_split(img_list, label_list, test_size=0.3, shuffle=True, stratify=label_list)

print("train_img : ", len(train_img))
print("train_lable : ", len(train_label))

print("val_img : ", len(val_img))
print("val_lable : ", len(val_label))

print(train_img[572])
print(train_label[572])

# train_dataset = VillageDataset(train, train_transforms)
# val_dataset = VillageDataset(val, val_transforms)

# train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

train_img :  1619
train_lable :  1619
val_img :  695
val_lable :  695
./data/public/PlantDoc/train/fpls-08-01602-g001_jpg.rf.cb2535a202d5ec940624e190d81f3526.jpg
19


# Doc +Village = 라벨 합치기
- Doc == Village
- 11: 'Potato leaf early blight', ==  2: 'Potato___Early_blight',
- 12: 'Potato leaf late blight', ==  4: 'Potato___Late_blight',
- 17: 'Tomato Early blight leaf', ==  6: 'Tomato_Early_blight',
- 18: 'Tomato Septoria leaf spot', ==  10: 'Tomato_Septoria_leaf_spot',
- 19: 'Tomato leaf', ==  7: 'Tomato_healthy',
- 20: 'Tomato leaf bacterial spot', ==  5: 'Tomato_Bacterial_spot',
- 21: 'Tomato leaf late blight', ==  8: 'Tomato_Late_blight',
- 22: 'Tomato leaf mosaic virus', ==  13: 'Tomato__Tomato_mosaic_virus',
- 23: 'Tomato leaf yellow virus', ==  14: 'Tomato__Tomato_YellowLeaf__Curl_Virus'

## Village img, label data

In [328]:
village_data = os.listdir('./data/public/PlantVillage')
vill_label_encoder = {}
for idx, data_name in enumerate(village_data) :
    vill_label_encoder[idx] = data_name

vill_label_decoder = {val:key for key, val in vill_label_encoder.items()}
display(vill_label_decoder)
display(vill_label_encoder)

vill_img_list = glob('./data/public/PlantVillage/*/*.JPG')
vill_label_list = [vill_label_decoder[img_path.split('\\')[-2]] for img_path in vill_img_list]

print(len(vill_img_list))
print(len(vill_label_list))

{'Pepper__bell___Bacterial_spot': 0,
 'Pepper__bell___healthy': 1,
 'Potato___Early_blight': 2,
 'Potato___healthy': 3,
 'Potato___Late_blight': 4,
 'Tomato_Bacterial_spot': 5,
 'Tomato_Early_blight': 6,
 'Tomato_healthy': 7,
 'Tomato_Late_blight': 8,
 'Tomato_Leaf_Mold': 9,
 'Tomato_Septoria_leaf_spot': 10,
 'Tomato_Spider_mites_Two_spotted_spider_mite': 11,
 'Tomato__Target_Spot': 12,
 'Tomato__Tomato_mosaic_virus': 13,
 'Tomato__Tomato_YellowLeaf__Curl_Virus': 14}

{0: 'Pepper__bell___Bacterial_spot',
 1: 'Pepper__bell___healthy',
 2: 'Potato___Early_blight',
 3: 'Potato___healthy',
 4: 'Potato___Late_blight',
 5: 'Tomato_Bacterial_spot',
 6: 'Tomato_Early_blight',
 7: 'Tomato_healthy',
 8: 'Tomato_Late_blight',
 9: 'Tomato_Leaf_Mold',
 10: 'Tomato_Septoria_leaf_spot',
 11: 'Tomato_Spider_mites_Two_spotted_spider_mite',
 12: 'Tomato__Target_Spot',
 13: 'Tomato__Tomato_mosaic_virus',
 14: 'Tomato__Tomato_YellowLeaf__Curl_Virus'}

20636
20636


## Doc img, label data

In [329]:
import copy
doc_img_list = copy.deepcopy(img_list)
doc_label_list = copy.deepcopy(label_list)

doc_label_decoder = copy.deepcopy(label_decoder) 
doc_label_encoder = copy.deepcopy(label_encoder)

display(doc_label_decoder)
display(doc_label_encoder)
print(len(doc_img_list))
print(len(doc_label_list))

{'Apple Scab Leaf': 0,
 'Apple leaf': 1,
 'Apple rust leaf': 2,
 'Bell_pepper leaf': 3,
 'Bell_pepper leaf spot': 4,
 'Blueberry leaf': 5,
 'Cherry leaf': 6,
 'Corn Gray leaf spot': 7,
 'Corn leaf blight': 8,
 'Corn rust leaf': 9,
 'Peach leaf': 10,
 'Potato leaf early blight': 11,
 'Potato leaf late blight': 12,
 'Raspberry leaf': 13,
 'Soyabean leaf': 14,
 'Squash Powdery mildew leaf': 15,
 'Strawberry leaf': 16,
 'Tomato Early blight leaf': 17,
 'Tomato Septoria leaf spot': 18,
 'Tomato leaf': 19,
 'Tomato leaf bacterial spot': 20,
 'Tomato leaf late blight': 21,
 'Tomato leaf mosaic virus': 22,
 'Tomato leaf yellow virus': 23,
 'Tomato mold leaf': 24,
 'grape leaf': 25,
 'grape leaf black rot': 26}

{0: 'Apple Scab Leaf',
 1: 'Apple leaf',
 2: 'Apple rust leaf',
 3: 'Bell_pepper leaf',
 4: 'Bell_pepper leaf spot',
 5: 'Blueberry leaf',
 6: 'Cherry leaf',
 7: 'Corn Gray leaf spot',
 8: 'Corn leaf blight',
 9: 'Corn rust leaf',
 10: 'Peach leaf',
 11: 'Potato leaf early blight',
 12: 'Potato leaf late blight',
 13: 'Raspberry leaf',
 14: 'Soyabean leaf',
 15: 'Squash Powdery mildew leaf',
 16: 'Strawberry leaf',
 17: 'Tomato Early blight leaf',
 18: 'Tomato Septoria leaf spot',
 19: 'Tomato leaf',
 20: 'Tomato leaf bacterial spot',
 21: 'Tomato leaf late blight',
 22: 'Tomato leaf mosaic virus',
 23: 'Tomato leaf yellow virus',
 24: 'Tomato mold leaf',
 25: 'grape leaf',
 26: 'grape leaf black rot'}

2314
2314


## Village -> Doc

### Village Label 번호 변경

In [330]:
 # {vill : doc}
same_label_list_encoder = {0:27, 1:28, 2:11, 3:29, 4:12, 5:20, 6:17, 7:19, 8:21, 9:30, 10:18, 11:31, 12:32, 13:22, 14:23}
for i in range(0, len(vill_label_list)) :
    vill_label_list[i] = same_label_list_encoder[vill_label_list[i]]

## add

In [331]:
sum_img_list = doc_img_list + vill_img_list
sum_label_list = doc_label_list + vill_label_list

print(len(sum_img_list))
print(len(sum_label_list))

22950
22950


## sum_label_encoder decoder 선언

In [342]:
same_label_list_encoder = {0:27, 1:28, 2:11, 3:29, 4:12, 5:20, 6:17, 7:19, 8:21, 9:30, 10:18, 11:31, 12:32, 13:22, 14:23}
same_label_list_decoder = {v : k for k, v in same_label_list_encoder.items()}

additional_list = [0, 1, 3, 9, 11, 12]

sum_label_encoder = {}
for i in range(0, 33) :
    if i in same_label_list_decoder and i in doc_label_encoder :
        sum_label_encoder[i] = doc_label_encoder[i]
#         sum_label_encoder[vill_label_encoder[i]] = same_label_list_encoder[i]

    elif i in same_label_list_decoder and i not in doc_label_encoder :
        sum_label_encoder[i] = vill_label_encoder[same_label_list_decoder[i]]
#         sum_label_encoder[doc_label_encoder[i]] = same_label_list_encoder[i]
        
    elif i not in same_label_list_decoder and i in doc_label_encoder :
        sum_label_encoder[i] = doc_label_encoder[i]
        
sum_label_decoder = {v : k for k, v in sum_label_encoder.items()}

display(sum_label_encoder)
display(sum_label_decoder)

{0: 'Apple Scab Leaf',
 1: 'Apple leaf',
 2: 'Apple rust leaf',
 3: 'Bell_pepper leaf',
 4: 'Bell_pepper leaf spot',
 5: 'Blueberry leaf',
 6: 'Cherry leaf',
 7: 'Corn Gray leaf spot',
 8: 'Corn leaf blight',
 9: 'Corn rust leaf',
 10: 'Peach leaf',
 11: 'Potato leaf early blight',
 12: 'Potato leaf late blight',
 13: 'Raspberry leaf',
 14: 'Soyabean leaf',
 15: 'Squash Powdery mildew leaf',
 16: 'Strawberry leaf',
 17: 'Tomato Early blight leaf',
 18: 'Tomato Septoria leaf spot',
 19: 'Tomato leaf',
 20: 'Tomato leaf bacterial spot',
 21: 'Tomato leaf late blight',
 22: 'Tomato leaf mosaic virus',
 23: 'Tomato leaf yellow virus',
 24: 'Tomato mold leaf',
 25: 'grape leaf',
 26: 'grape leaf black rot',
 27: 'Pepper__bell___Bacterial_spot',
 28: 'Pepper__bell___healthy',
 29: 'Potato___healthy',
 30: 'Tomato_Leaf_Mold',
 31: 'Tomato_Spider_mites_Two_spotted_spider_mite',
 32: 'Tomato__Target_Spot'}

{'Apple Scab Leaf': 0,
 'Apple leaf': 1,
 'Apple rust leaf': 2,
 'Bell_pepper leaf': 3,
 'Bell_pepper leaf spot': 4,
 'Blueberry leaf': 5,
 'Cherry leaf': 6,
 'Corn Gray leaf spot': 7,
 'Corn leaf blight': 8,
 'Corn rust leaf': 9,
 'Peach leaf': 10,
 'Potato leaf early blight': 11,
 'Potato leaf late blight': 12,
 'Raspberry leaf': 13,
 'Soyabean leaf': 14,
 'Squash Powdery mildew leaf': 15,
 'Strawberry leaf': 16,
 'Tomato Early blight leaf': 17,
 'Tomato Septoria leaf spot': 18,
 'Tomato leaf': 19,
 'Tomato leaf bacterial spot': 20,
 'Tomato leaf late blight': 21,
 'Tomato leaf mosaic virus': 22,
 'Tomato leaf yellow virus': 23,
 'Tomato mold leaf': 24,
 'grape leaf': 25,
 'grape leaf black rot': 26,
 'Pepper__bell___Bacterial_spot': 27,
 'Pepper__bell___healthy': 28,
 'Potato___healthy': 29,
 'Tomato_Leaf_Mold': 30,
 'Tomato_Spider_mites_Two_spotted_spider_mite': 31,
 'Tomato__Target_Spot': 32}