# Import

In [None]:
import os
import pandas as pd
import numpy as np
import cv2
import torch
import torch.nn as nn
import albumentations as A
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import PIL.Image as Image

from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F

import albumentations as A
from albumentations.core.composition import Compose, OneOf
from albumentations.augmentations.transforms import CLAHE, GaussNoise, ISONoise
from albumentations.pytorch import ToTensorV2

from sklearn.model_selection import StratifiedKFold
from sklearn import preprocessing

# Config

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class CFG:
    """Hyper-parameters and run settings, read everywhere as class attributes."""

    # Not referenced elsewhere in the visible notebook.
    seed = 42
    pretrained = False

    # Side length (pixels) every image is resized to by get_transform.
    img_size = 299
    # Output width of the replacement classification head.
    num_classes = 6

    # Learning rate handed to the Adam optimizer.
    lr = 1e-5
    # presumably scheduler settings (lower bound / period) -- unused in the
    # visible code, TODO confirm
    min_lr = 1e-6
    t_max = 20

    # Number of passes over the training loader in fitModel.
    num_epochs = 10
    # Batch size shared by the train and test DataLoaders.
    batch_size = 16
    # Probability applied to each random augmentation in get_transform.
    augmentation_probability = 0.25

    # presumably gradient-accumulation steps, AMP precision and CV fold count --
    # unused in the visible code, TODO confirm
    accum = 1
    precision = 16
    n_fold = 5

    # Weight decay handed to the Adam optimizer.
    weight_decay = 0.05
    # Default CUDA device when available, otherwise CPU.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Root directory of the competition data (keeps its trailing slash because
# other cells build paths by plain string concatenation).
PATH = "../input/plant-pathology-2021-fgvc8/"

# Directory that holds the test images.
TEST_DIR = os.path.join(PATH, 'test_images/')

In [None]:
# Load the training annotations and show (rows, columns) as the cell output.
df_all = pd.read_csv(os.path.join(PATH, "train.csv"))
df_all.shape

In [None]:
from collections import defaultdict


# Map each category name to the row positions whose space-separated `labels`
# string contains it.
dct = defaultdict(list)

for row_position, label_string in enumerate(df_all.labels):
    for category in label_string.split():
        dct[category].append(row_position)

# Freeze the per-category position lists into NumPy arrays.
dct = {category: np.array(positions) for category, positions in dct.items()}
dct

In [None]:
# One int8 indicator column per category, initially all zeros.
new_df = pd.DataFrame(
    np.zeros((len(df_all), len(dct)), dtype=np.int8),
    columns=list(dct),
)

# Switch on the indicator for every row that carries the category.
for category, row_positions in dct.items():
    new_df.loc[row_positions, category] = 1

# Append the indicator columns to the right of the original annotations.
df_all = pd.concat([df_all, new_df], axis=1)
df_all.head()

In [None]:
# Column index of the indicator frame, i.e. the category names in column order.
# Reused later to name the prediction columns and to turn active columns back
# into label strings.
multi_labels = new_df.columns
multi_labels

In [None]:
# Load the sample submission; it serves as the list of test images to score.
sub = pd.read_csv(os.path.join(PATH, "sample_submission.csv"))
sub.head()

In [None]:
# Give the submission frame one zero-filled column per category so it has the
# same "columns from the third onward are targets" layout the dataset expects.
tmp = pd.DataFrame(np.zeros((len(sub), new_df.shape[1])), columns=multi_labels)
sub = pd.concat([sub, tmp], axis=1)
sub.head()

# Define Dataset

In [None]:
class PlantDataset(Dataset):
    """Dataset yielding an RGB image together with its target vector.

    Args:
        df: DataFrame with an ``'image'`` column of file names. Every column
            from the third onward is taken as the target vector.
        directory: Path prefix concatenated directly in front of each file
            name (no separator is inserted, so include the trailing slash).
        transform: Optional callable invoked as ``transform(image=array)`` and
            expected to return a mapping with an ``'image'`` entry. When
            omitted, the decoded RGB array is returned unchanged.
    """

    def __init__(self, df, directory, transform=None):
        self.image_id = df['image'].values
        self.labels = df.iloc[:, 2:].values
        self.directory = directory
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'target': float32 tensor}`` for row ``idx``.

        Raises:
            FileNotFoundError: If the image cannot be read from disk.
        """
        image_id = self.image_id[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.float32)

        image_path = self.directory + image_id
        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if image is None:
            raise FileNotFoundError("Could not read image: " + str(image_path))
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # transform defaults to None, so only apply it when one was supplied.
        if self.transform is not None:
            image = self.transform(image=image)['image']
        return {'image': image, 'target': label}

In [None]:
def get_transform(phase: str):
    """Build the albumentations pipeline for the given phase.

    ``'train'`` yields resize + random augmentations + normalize + tensor
    conversion; any other value yields resize + normalize + tensor conversion.
    Image size and augmentation probability are read from ``CFG`` at call time.
    """
    resize = A.Resize(height=CFG.img_size, width=CFG.img_size)

    if phase != 'train':
        return Compose([resize, A.Normalize(), ToTensorV2()])

    # Every random augmentation fires with the same configured probability.
    prob = CFG.augmentation_probability

    # NOTE(review): IAAAffine, RandomContrast and RandomBrightness appear to be
    # deprecated in later albumentations releases -- confirm against the
    # installed version before upgrading.
    pipeline = [
        resize,
        # geometric augmentations
        A.HorizontalFlip(p=prob),
        A.VerticalFlip(p=prob),
        A.ShiftScaleRotate(p=prob),
        A.Rotate(p=prob, limit=90),
        # colour shift
        A.RGBShift(p=prob),
        # affine rotations
        A.IAAAffine(rotate=90., p=prob),
        A.IAAAffine(rotate=180., p=prob),
        # photometric augmentations
        A.RandomBrightnessContrast(p=prob),
        A.RandomContrast(limit=0.5, p=prob),
        A.RandomSunFlare(p=prob),
        A.RandomBrightness(p=prob),
        # final normalisation and tensor conversion
        A.Normalize(),
        ToTensorV2(),
    ]
    return Compose(pipeline)

In [None]:
# Training data: augmented, reshuffled on every pass.
train_dataset = PlantDataset(
    df=df_all,
    directory=PATH + "train_images/",
    transform=get_transform('train'),
)
dataset_loader = DataLoader(
    dataset=train_dataset,
    batch_size=CFG.batch_size,
    shuffle=True,
    num_workers=5,
)

# Test data: deterministic preprocessing, order kept so predictions line up
# with the rows of `sub`.
test_dataset = PlantDataset(
    df=sub,
    directory=PATH + "test_images/",
    transform=get_transform('valid'),
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=CFG.batch_size,
    shuffle=False,
    num_workers=5,
)

# Define Model

In [None]:
class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    ``conv1`` convolves each input channel on its own (``groups=in_channels``)
    with the requested kernel, stride, padding and dilation; ``pointwise`` then
    mixes channels with a 1x1 kernel to produce ``out_channels``. ``bias``
    applies to both layers.
    """

    def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False):
        super().__init__()

        # Per-channel spatial filtering.
        self.conv1 = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        # Cross-channel mixing.
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )

    def forward(self,x):
        return self.pointwise(self.conv1(x))

In [None]:
class Block(nn.Module):
    """Residual stack of ReLU -> SeparableConv2d(3x3) -> BatchNorm units.

    Args:
        in_filters: Channels of the incoming tensor.
        out_filters: Channels of the outgoing tensor.
        reps: Number of ReLU/conv/BN units in the stack.
        strides: When not 1, a 3x3 max-pool with this stride ends the stack and
            the shortcut uses the same stride.
        start_with_relu: If False the leading ReLU is dropped; if True it is
            replaced by a non-in-place ReLU.
        grow_first: If True the channel change happens in the first unit,
            otherwise in the last one.
    """

    def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True):
        super().__init__()

        # A 1x1 conv + BN shortcut is only needed when the shape changes.
        shape_changes = out_filters != in_filters or strides != 1
        if shape_changes:
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None

        self.relu = nn.ReLU(inplace=True)

        def unit(channels_in, channels_out):
            # One ReLU -> separable 3x3 conv -> batch-norm triple.
            return [
                self.relu,
                SeparableConv2d(channels_in, channels_out, 3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(channels_out),
            ]

        layers = []
        width = in_filters
        if grow_first:
            layers.extend(unit(in_filters, out_filters))
            width = out_filters

        for _ in range(reps - 1):
            layers.extend(unit(width, width))

        if not grow_first:
            layers.extend(unit(in_filters, out_filters))

        if start_with_relu:
            # The block input is still needed by the shortcut, so the first
            # activation must not overwrite it in place.
            layers[0] = nn.ReLU(inplace=False)
        else:
            layers = layers[1:]

        if strides != 1:
            layers.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*layers)

    def forward(self,inp):
        out = self.rep(inp)

        if self.skip is None:
            shortcut = inp
        else:
            shortcut = self.skipbn(self.skip(inp))

        out += shortcut
        return out

In [None]:
class Xception(nn.Module):
    """Xception classifier assembled from ``Block`` and ``SeparableConv2d``.

    Layout: two plain conv+BN+ReLU layers, twelve ``Block`` stages
    (``block1`` .. ``block12``), two separable conv+BN layers, global average
    pooling and a linear head ``fc`` with ``num_classes`` outputs.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        self.num_classes = num_classes

        # Plain convolutions at the input.
        self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        # Strided blocks that widen the representation to 728 channels.
        self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)

        # block4 .. block11: eight identical 728-channel blocks. They are
        # registered under the same attribute names, in the same order, as
        # individually written assignments would be.
        for index in range(4, 12):
            setattr(
                self,
                "block{}".format(index),
                Block(728, 728, 3, 1, start_with_relu=True, grow_first=True),
            )

        # Final strided block; the channel growth happens in its last unit.
        self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)

        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)

        self.fc = nn.Linear(2048, num_classes)

    def features(self, input):
        """Run the convolutional trunk; the result is not yet ReLU-activated."""
        x = self.relu(self.bn1(self.conv1(input)))
        x = self.relu(self.bn2(self.conv2(x)))

        for index in range(1, 13):
            x = getattr(self, "block{}".format(index))(x)

        x = self.relu(self.bn3(self.conv3(x)))
        return self.bn4(self.conv4(x))

    def logits(self, features):
        """Activate, globally average-pool, flatten and apply the linear head."""
        x = self.relu(features)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        return self.fc(x)

    def forward(self, input):
        return self.logits(self.features(input))

In [None]:
from collections import OrderedDict

def fix_model_state_dict(state_dict):
    """Return a new OrderedDict with one leading ``'model.'`` stripped from keys.

    Keys without that prefix are copied unchanged and values are passed through
    as-is; the input mapping is not modified.
    """
    # NOTE(review): an earlier comment attributed this prefix to DataParallel,
    # which normally prepends 'module.' -- confirm which wrapper produced it.
    prefix = 'model.'
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

In [None]:
# Build the network with its default 1000-way head so that a checkpoint saved
# with that head can be loaded before the head is swapped out.
model = Xception()

# First '.pth' file found under the checkpoint dataset directory, or None.
# (Previously the hit was stored in a different variable, so this path stayed
# None and the checkpoint was never loaded.)
xceptionModelPath = next(
    (
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk('/kaggle/input/xceptionimagenetcheckpoint')
        for name in names
        if name.endswith('.pth')
    ),
    None,
)

if xceptionModelPath:
    # map_location belongs to torch.load: it remaps the stored tensors onto the
    # active device. load_state_dict stays strict, so a checkpoint that does
    # not match the architecture fails loudly.
    model.load_state_dict(torch.load(xceptionModelPath, map_location=DEVICE))

# Replace the classification head with one sized for this task.
model.fc = nn.Linear(2048, CFG.num_classes)
model = model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)

# BCELoss expects probabilities, so callers apply a sigmoid to the logits first.
criterion = nn.BCELoss()
criterion = criterion.to(DEVICE)

In [None]:
def saveModel(model):
    """Write ``model``'s state dict to ``<ClassName>.pth`` in the working directory."""
    checkpoint_name = "{}.pth".format(model.__class__.__name__)
    torch.save(model.state_dict(), checkpoint_name)

# Training and Inference

In [None]:
def fitModel(model):
    """Train ``model`` on ``dataset_loader`` for ``CFG.num_epochs`` epochs.

    Uses the module-level ``optimizer``, ``criterion`` and ``DEVICE``. A
    checkpoint is written through ``saveModel`` after every epoch, and the
    model is switched to eval mode at the end of each epoch.

    Args:
        model: Network to optimise; it must already live on ``DEVICE``.

    Returns:
        None. Progress is printed once per batch.
    """
    sigmoid = nn.Sigmoid()

    for epoch in range(CFG.num_epochs):
        model = model.train()

        for i, batch in enumerate(dataset_loader, start=1):
            # Move the batch to the configured device rather than calling
            # .cuda(), so the CPU fallback in DEVICE actually works.
            image = batch['image'].to(DEVICE)
            labels = batch['target'].to(DEVICE)

            logits = model(image)
            # The loss is computed on sigmoid probabilities, not raw logits.
            output = sigmoid(logits)
            loss = criterion(output, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss = loss.detach().item()
            print("Epoch: {0}/{1}, Current Epoch Progress: {2:.2f}%, Loss: {3:.4f}".format(epoch+1, CFG.num_epochs, 100*i/(len(dataset_loader)), train_loss))

        saveModel(model)

        model.eval()

In [None]:
# First '.pth' file anywhere under /kaggle, skipping any file whose name
# contains "xception-43020ad28.pth"; None when nothing is found.
modelPath = next(
    (
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk('/kaggle')
        for name in names
        if name.endswith('.pth') and "xception-43020ad28.pth" not in name
    ),
    None,
)

if modelPath:
    print("Using pretrained model: " + modelPath)
    # The device goes to torch.load as map_location. Passing it to
    # load_state_dict (as before) only filled that method's `strict` argument.
    model.load_state_dict(torch.load(modelPath, map_location=DEVICE))
else:
    # No saved checkpoint available: train from the current weights.
    fitModel(model)

In [None]:
# Use the configured device instead of an unconditional .cuda(), so the cell
# also runs on hosts where DEVICE resolved to the CPU.
model.to(DEVICE)
model.eval()

sigmoid = nn.Sigmoid()

# One boolean array per batch: True where the sigmoid probability exceeds 0.5.
predictions = []
with torch.no_grad():
    for batch in test_loader:
        image = batch['image'].to(DEVICE)
        outputs = model(image)
        preds = outputs.detach().cpu()
        predictions.append(sigmoid(preds).numpy() > 0.5)

In [None]:
# Stack the per-batch boolean arrays and store them as 0/1 integers, one column
# per category. The builtin int replaces the np.int alias, which newer NumPy
# releases no longer provide.
predictions = pd.DataFrame(np.concatenate(predictions).astype(int), columns=new_df.columns)

In [None]:
# Overwrite every column from the third onward (the per-category placeholder
# columns) with the 0/1 predictions, then display the frame.
# NOTE(review): assumes `predictions` has the same row count and column order
# as those placeholder columns -- confirm if the label set changes.
sub.iloc[:, 2:] = predictions
sub

In [None]:
# Turn each row of 0/1 predictions into the space-separated label string the
# submission expects. 'healthy' wins whenever it is predicted, and is also the
# fallback when no category at all is predicted.
labels = []
for _, row in sub.iloc[:, 2:].iterrows():
    predicted_healthy = row['healthy'] == 1
    if predicted_healthy or all(
        row[column] == 0
        for column in (
            'healthy',
            'scab',
            'frog_eye_leaf_spot',
            'complex',
            'rust',
            'powdery_mildew',
        )
    ):
        tmp = 'healthy'
    else:
        # Join the names of all columns that share the row maximum.
        tmp = ' '.join(multi_labels[row==row.max()])
    labels.append(tmp)

In [None]:
# Store the label strings, then write only the image name and labels columns
# to the submission file.
sub['labels'] = labels
sub.loc[:, ['image', 'labels']].to_csv('submission.csv', index=False)
sub.head()