In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from PIL import Image
import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold
from glob import glob

import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import Callback
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.models import resnet18

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2


# Input data files are available in the read-only "../input/" directory
# (/kaggle/input); the paths configured below point into it.

import os

# --- Reproducibility / data split ---
random_seed = 42
k_fold_number = 0

# --- Batching & schedule ---
train_batch = 16
batch_size = 32
epoch_num = 20

# --- Paths ---
working_dir = "/kaggle/input/plant-pathology-2021-fgvc8/"
DIR_MODELS = "/kaggle/working"

# --- Model ---
num_classes = 6

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Build the table of test images to predict on.
test_img_dir = working_dir + "test_images/"
if not os.path.isdir(test_img_dir):
    # next(os.walk(missing_dir)) would otherwise fail with a bare StopIteration.
    raise FileNotFoundError(f"Test image directory not found: {test_img_dir}")

# next(os.walk(d)) yields (dirpath, dirnames, filenames) for the top level only.
# os.walk does not guarantee an order, so sort to make the submission deterministic.
img_names = sorted(next(os.walk(test_img_dir))[2])

# Build the frame directly instead of .loc-assigning a list into an empty frame.
test_df = pd.DataFrame({"image": img_names})
print(len(img_names))
    


In [None]:
# Peek at the first few test file names.
test_df.head(5)

In [None]:
class PlantDataset(Dataset):
    """Dataset over the unique image file names in ``df["image"]``.

    Each item is ``{"image": <transformed image>, "img_id": <file name>}``. The
    image is read from ``dir_path``, converted BGR -> RGB float32 and passed
    through :meth:`transform` (resize to 512x512, ImageNet mean/std
    normalisation, tensor conversion).

    Args:
        df: DataFrame with an ``image`` column of file names; duplicates are
            collapsed via ``Series.unique``.
        dir_path: directory containing the image files.
    """

    def __init__(self, df, dir_path):
        self.dir_path = dir_path
        self.df = df
        self.img_ids = self.df.image.unique()
        # Transform pipeline, built on first use and reused instead of per item.
        self._transform = None

    def __getitem__(self, index):
        img_id = self.img_ids[index]
        # Read from the configured directory rather than a hard-coded test folder.
        img_path = os.path.join(self.dir_path, img_id)

        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread returns None instead of raising on missing/unreadable files.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # No manual /255 here; presumably A.Normalize's max_pixel_value handles
        # the scaling, matching training — confirm.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)

        if self._transform is None:
            self._transform = self.transform()
        image = self._transform(image=image)["image"]

        return {"image": image, "img_id": img_id}

    def get_by_id(self, img_id):
        """Return the item for file name ``img_id``; raises IndexError if it is absent."""
        matches = np.flatnonzero(self.img_ids == img_id)
        if matches.size == 0:
            raise IndexError(f"Unknown img_id: {img_id!r}")
        return self.__getitem__(matches[0])

    def transform(self):
        """Build the albumentations pipeline applied to every image."""
        return A.Compose([
            A.Resize(512, 512),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(p=1.0)
        ])

    def encode_target(self, target):
        """Multi-hot encode a space-separated label string.

        Order: scab, healthy, frog_eye_leaf_spot, rust, complex, powdery_mildew.
        Matching is by substring, e.g. ``"scab rust"`` -> ``[1, 0, 0, 1, 0, 0]``.
        """
        labels = ["scab", "healthy", "frog_eye_leaf_spot", "rust", "complex", "powdery_mildew"]
        return [int(label in target) for label in labels]

    def __len__(self):
        return self.img_ids.shape[0]
       

In [None]:
test_path = working_dir + "test_images/"
test_dataset = PlantDataset(test_df, test_path)

# Sanity-check one sample without dumping the whole image tensor. Index 0 (not 1)
# so this also works when the test folder holds a single image.
if len(test_dataset) > 0:
    sample = test_dataset[0]
    print(sample["img_id"], tuple(sample["image"].shape))

In [None]:
# Inference loader. shuffle=False keeps batches in dataset order so every
# prediction stays paired with its batch["img_id"]; the default collate_fn is used.
test_data_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=4,
    shuffle=False,
)


In [None]:
# ResNet-18 backbone with no pretrained download: the fine-tuned weights are
# restored from the Lightning checkpoint loaded further down.
model = resnet18(pretrained=False)

In [None]:
# Replace the backbone's classification head with one logit per label.
# in_features is read from the existing head instead of hard-coding 512, and the
# output size comes from num_classes instead of a magic 6. Re-running is idempotent.
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=num_classes, bias=True)

In [None]:
### Lightning usage

class LitModel(pl.LightningModule):
    """Lightning wrapper around a backbone for multi-label classification.

    ``training_step`` / ``validation_step`` optimise ``BCEWithLogitsLoss`` on the
    backbone's raw logits and log loss and F1. ``forward`` is the inference path:
    it returns hard 0/1 predictions, not logits.

    Args:
        model: backbone mapping an image batch to one logit per class
            (presumably ``num_classes`` outputs — TODO confirm against checkpoint).
    """

    def __init__(self, model):
        super(LitModel, self).__init__()
        self.model = model
        self.metric = pl.metrics.F1(num_classes=num_classes)
        self.criterion = torch.nn.BCEWithLogitsLoss()
        self.lr = 5e-3
        # Sigmoid probability above which ``forward`` predicts a label.
        self.threshold = 0.5

    def forward(self, x, *args, **kwargs):
        """Predict labels for batch ``x``.

        Returns a tensor with the logits' shape and dtype holding 1.0 where
        ``sigmoid(logit) > self.threshold`` and 0.0 elsewhere. Extra arguments
        are accepted and ignored.
        """
        logits = self.model(x)
        # Vectorised threshold; avoids Tensor.apply_, which calls a Python function
        # per element and only works on CPU tensors.
        return (torch.sigmoid(logits.detach()) > self.threshold).to(logits.dtype)

    def configure_optimizers(self):
        """Adam with cosine annealing; T_max follows ``epoch_num``.

        This assumes the scheduler is stepped once per epoch (Lightning's default
        interval) — confirm if the schedule config changes.
        """
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=epoch_num, eta_min=1e-6
        )
        return {"optimizer": self.optimizer, "lr_scheduler": self.scheduler}

    def _loss_and_metric(self, batch):
        # Shared by training/validation: loss and F1 for one batch, on raw logits.
        # NOTE(review): F1 is fed logits, not probabilities — confirm this matches
        # how pl.metrics.F1 applies its threshold.
        image = batch["image"]
        target = batch["target"]
        output = self.model(image)
        return self.criterion(output, target), self.metric(output, target)

    def training_step(self, batch, batch_idx):
        loss, metric = self._loss_and_metric(batch)
        logs = {"training_loss": loss, "train_f1": metric, "lr": self.optimizer.param_groups[0]["lr"]}
        self.log_dict(logs, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, metric = self._loss_and_metric(batch)
        logs = {"valid_loss": loss, "valid_f1": metric}
        self.log_dict(logs, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

In [None]:
# Fine-tuned checkpoint from epoch 10 (its file name records valid_loss / valid_f1).
ckpt_path = "../input/plant-pathology-2021/logs/Resnet/version_0/checkpoints/checkpoint/epoch=10-valid_loss=0.1101-valid_f1=0.8661.ckpt"
# ckpt_path = "../input/plant-pathology-2021/logs/Resnet/version_0/checkpoints/last.ckpt"

# load_from_checkpoint is a classmethod that constructs the module itself
# (extra kwargs such as model=... go to __init__), so no throwaway instance is needed.
lit_model = LitModel.load_from_checkpoint(ckpt_path, model=model)

In [None]:
# img = next(iter(test_data_loader))
# img["image"]

In [None]:
def decode_target(target):
    """Convert a multi-hot vector into a space-separated label string.

    Positions map, in order, to: scab, healthy, frog_eye_leaf_spot, rust,
    complex, powdery_mildew. Truthy entries are kept in that order, and an
    all-falsy (or empty) vector yields "".
    """
    class_names = ("scab", "healthy", "frog_eye_leaf_spot", "rust", "complex", "powdery_mildew")
    active = [class_names[pos] for pos, flag in enumerate(target) if flag]
    return " ".join(active)

In [None]:
# target = [1, 0, 1, 0, 1, 0]
# decoded = decode_target(target)
# decoded

In [None]:
lit_model.eval()

# Gather (image, labels) rows and build the frame once at the end:
# DataFrame.append is deprecated (removed in pandas 2.0) and copies the whole
# frame on every call.
records = []
with torch.no_grad():
    for batch in test_data_loader:
        output = lit_model(batch["image"])
        preds = [decode_target(row) for row in output.tolist()]
        records.extend(zip(batch["img_id"], preds))

output_df = pd.DataFrame(records, columns=["image", "labels"])

In [None]:
# Preview the predictions that are written to submission.csv below.
output_df

In [None]:
# list(zip(batch["img_id"],preds))

In [None]:
# Write the submission (image, labels columns; no index) to the current directory.
output_df.to_csv(path_or_buf="./submission.csv", index=False)