In [1]:
from PIL import Image
import os
import glob
import random
random.seed(4)

import pandas as pd
import numpy as np
import tqdm
from sklearn.model_selection import train_test_split

import remo
remo.set_viewer('jupyter')

├── cass
    ├── train_images
        ├── image_1.jpg
        ├── image_2.jpg
        ├── ...
    ├── train.csv
    ├── images_tags.csv (to generate)

In [10]:
# Integer label id -> human-readable class name. The list position is the id,
# so the order below must not change. Passed to remo as `class_encoding`.
mapping = dict(enumerate([
    "Cassava Bacterial Blight (CBB)",
    "Cassava Brown Streak Disease (CBSD)",
    "Cassava Green Mottle (CGM)",
    "Cassava Mosaic Disease (CMD)",
    "Healthy",
]))

In [3]:
# Input/output locations, all relative to the notebook's working directory.
images_path = './train_images/'           # folder holding the image files
annotations_file_path = './train.csv'     # label CSV (read with pandas below)
tags_path = './images_tags.csv'           # train/valid tag file written by remo.generate_image_tags

In [8]:
annotations = pd.read_csv(annotations_file_path)

# Stratified 80/20 split on the label column.
# random_state is pinned explicitly: the `random.seed(4)` call in the first cell
# only seeds Python's `random` module, not the NumPy generator scikit-learn uses,
# so without it every run would produce a different train/valid partition.
temp_train, valid = train_test_split(
    annotations,
    stratify=annotations["label"],
    test_size=0.2,
    random_state=4,
)

# Tag name -> list of image ids carrying that tag.
tags_dict = {
    'train': temp_train["image_id"].to_list(),
    'valid': valid["image_id"].to_list(),
}

# os.path.join() with a single argument returns it unchanged, so use the path directly.
train_test_split_file_path = tags_path
remo.generate_image_tags(tags_dictionary=tags_dict,
                         output_file_path=train_test_split_file_path)

'./images_tags.csv'

In [12]:
# Build one remo dataset from the image folder, the generated tag file and the
# label CSV; `mapping` tells remo how to turn the integer labels into class names.
dataset_sources = [images_path, tags_path, annotations_file_path]
cass = remo.create_dataset(
    name='cassava_kaggle_dataset',
    local_files=dataset_sources,
    annotation_task='Image classification',
    class_encoding=mapping,
)

Acquiring data - completed                                                                           
Processing annotation files: 2 of 2 files                                                            Processing data - completed                                                                          
Data upload completed


In [14]:
# Export one CSV per split. Both exports share every option except the file name
# and the tag they filter on.
# NOTE(review): append_path=True presumably makes remo write openable file paths;
# the dataset class below passes the 'file_name' column straight to Image.open — confirm.
export_options = dict(annotation_format='csv', append_path=True, export_tags=False)
cass.export_annotations_to_file('remo_train.csv', filter_by_tags=['train'], **export_options)
cass.export_annotations_to_file('remo_valid.csv', filter_by_tags=['valid'], **export_options)

In [15]:
%matplotlib inline
# PyTorch Lightning Imports
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.loggers import TensorBoardLogger

# PyTorch Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import cross_entropy
from torch.utils.data import Dataset, DataLoader
import torchvision as tv
from torch.optim import Adam, AdamW
import torchvision.models as models


# Python Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import remo
remo.set_viewer('jupyter')

# Default Python
import os
from PIL import Image

In [29]:
class CustomDataset(Dataset):
    """Image-classification dataset backed by a DataFrame of file paths and class names.

    Args:
        data_path: DataFrame with a 'file_name' column (path passed to
            ``Image.open``) and a 'classes' column (key looked up in ``mapping``).
        transforms: Callable applied to the loaded PIL image, or a falsy value
            to return the PIL image untouched.
        mapping: Dict from class name to integer label. When omitted, the
            notebook-level ``cat_to_idx`` dict is used.

    Each item is an ``(image, label)`` pair where ``label`` is an ``int``.
    """

    def __init__(self, data_path, transforms, mapping=None):
        self.data_path = data_path
        self.transforms = transforms
        # Resolve the fallback here rather than in the signature: a default of
        # `cat_to_idx` is evaluated when the class is defined and raises
        # NameError if that cell has not run yet. Also honour an explicitly
        # supplied mapping instead of always overwriting it with the global.
        self.mapping = cat_to_idx if mapping is None else mapping

    def __len__(self):
        return len(self.data_path)

    def __getitem__(self, idx):
        # Samplers hand out positions in range(len(self)), so index by position;
        # label-based .loc only works while the frame keeps its default index.
        file_name = self.data_path['file_name'].iloc[idx]
        class_name = self.data_path['classes'].iloc[idx]

        # Close the file promptly (matters with many DataLoader workers) and
        # force three channels so a 3-value Normalize never sees L/RGBA/CMYK input.
        with Image.open(file_name) as raw_image:
            im = raw_image.convert('RGB')
        label = int(self.mapping[class_name])

        if self.transforms:
            im = self.transforms(im)
        return im, label

train_data = pd.read_csv('remo_train.csv')
validation_data = pd.read_csv('remo_valid.csv')

# Per-channel mean / std used for normalisation (the standard ImageNet statistics
# that torchvision's pretrained models expect).
means =  [0.485, 0.456, 0.406]
stds  =  [0.229, 0.224, 0.225]

# Training pipeline: random augmentation, then tensor conversion + normalisation.
tv_transforms = tv.transforms.Compose([
    tv.transforms.RandomRotation(30),
    tv.transforms.RandomResizedCrop(224),
    tv.transforms.RandomHorizontalFlip(p=0.5),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize(means, stds)])

# Validation pipeline: deterministic resize + centre crop to the same 224x224
# input size, so metrics are comparable between epochs instead of varying with
# random rotations/crops/flips.
val_transforms = tv.transforms.Compose([
    tv.transforms.Resize(256),
    tv.transforms.CenterCrop(224),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize(means, stds)])

# shuffle=True so each epoch sees the training rows in a new order rather than
# the fixed order of the exported CSV; validation order is left untouched.
train_dl = DataLoader(CustomDataset(data_path=train_data, transforms=tv_transforms),
                      batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
val_dl = DataLoader(CustomDataset(data_path=validation_data, transforms=val_transforms),
                    batch_size=10, num_workers=4, pin_memory=True)

In [25]:
# Class name -> integer label, used by CustomDataset to turn the 'classes'
# column into a training target. The list position is the label id.
# NOTE: CustomDataset falls back to this dict, so this cell has to be executed
# before that class is defined/used.
cat_to_idx = {
    class_name: class_idx
    for class_idx, class_name in enumerate([
        'Cassava bacterial blight (cbb)',
        'Cassava brown streak disease (cbsd)',
        'Cassava green mottle (cgm)',
        'Cassava mosaic disease (cmd)',
        'Healthy',
    ])
}