In [None]:
import sys

# Directory appended to the module search path; presumably it holds the timm
# sources that the later `import timm` resolves against.
package_path = '../input/timm-pytorch-image-models/pytorch-image-models-master'
sys.path.append(package_path)

In [None]:
import gc
import os
import time
import torch
import albumentations

import numpy as np
import pandas as pd

import cv2
from PIL import Image

import torch.nn as nn
from sklearn import metrics
from sklearn import model_selection
from torch.nn import functional as F
from torch.optim import Adam

import timm

import warnings
warnings.filterwarnings("ignore")

In [None]:
class TimmModels(nn.Module):
    """A timm model flattened into ``nn.Sequential`` with a new linear head.

    The direct children of the timm model are kept in their original order,
    except that the last child is replaced by an ``nn.Linear`` producing
    ``num_classes`` outputs. The resulting stack is stored as ``self.m``.

    :param model_name: name passed to ``timm.create_model``
    :param pretrained: forwarded to ``timm.create_model``
    :param num_classes: output size of the replacement head
    """

    def __init__(self, model_name, pretrained=True, num_classes=5):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        layers = [*backbone.children()]
        # The last child is expected to expose `in_features` (a linear head).
        head_inputs = layers[-1].in_features
        layers[-1] = nn.Linear(head_inputs, num_classes, bias=True)
        self.m = nn.Sequential(*layers)

    def forward(self, image):
        """Feed ``image`` through the layer stack and return its output."""
        return self.m(image)

In [None]:
# Run configuration looked up by key in the cells below.
FLAGS = dict(
    fold=0,
    model='resnext50_32x4d',
    pretrained=True,
    batch_size=8,
    num_workers=4,
    lr=3e-4,
    epochs=10,
    device='cuda:0',
)

In [None]:
class ImageDataset:
    """Indexable dataset that loads one image from disk per item.

    Each item is a dict with an ``"image"`` tensor and a ``"targets"`` tensor.
    Images are read with either PIL or OpenCV, optionally resized, optionally
    passed through an augmentation callable, and optionally transposed to
    channel-first float32.
    """

    def __init__(self,
                 image_paths,
                 targets,
                 resize,
                 augmentations=None,
                 backend='pil',
                 channel_first=True
                ):
        """
        :param image_paths: list of paths to images
        :param targets: numpy array (or any sequence) indexed like image_paths
        :param resize: tuple or None; indexed as (height, width), since
            ``(resize[1], resize[0])`` is what is handed to the resize call
        :param augmentations: albumentations augmentations (a callable taking
            ``image=...`` and returning a dict with an ``'image'`` entry), or
            None to skip augmentation
        :param backend: ``'pil'`` or ``'cv2'``
        :param channel_first: if True, transpose the image with axes (2, 0, 1)
            and cast it to float32
        """
        self.image_paths = image_paths
        self.targets = targets
        self.resize = resize
        self.augmentations = augmentations
        self.backend = backend
        self.channel_first = channel_first

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.image_paths)

    def _read_image(self, path):
        """Load ``path`` with the configured backend and apply the resize."""
        if self.backend == 'pil':
            image = Image.open(path)
            if self.resize is not None:
                image = image.resize((self.resize[1], self.resize[0]),
                                     resample=Image.BILINEAR)
            return np.array(image)

        if self.backend == 'cv2':
            image = cv2.imread(path)
            if image is None:
                # cv2.imread signals failure by returning None instead of
                # raising; fail here with the offending path in the message.
                raise OSError(f"cv2 could not read image: {path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            if self.resize is not None:
                image = cv2.resize(image,
                                   (self.resize[1], self.resize[0]),
                                   interpolation=cv2.INTER_CUBIC)
            return image

        raise NotImplementedError("Backend not implemented")

    def __getitem__(self, item):
        """Return ``{"image": tensor, "targets": tensor}`` for index ``item``.

        :raises NotImplementedError: if ``backend`` is neither 'pil' nor 'cv2'
        :raises OSError: if the cv2 backend cannot read the file
        """
        targets = self.targets[item]
        image = self._read_image(self.image_paths[item])

        # Shared by both backends. The cv2 path used to read the augmentation
        # result even when no augmentation had been configured.
        if self.augmentations is not None:
            image = self.augmentations(image=image)['image']

        if self.channel_first:
            image = np.transpose(image, (2, 0, 1)).astype(np.float32)

        return {"image": torch.tensor(image),
                "targets": torch.tensor(targets)
               }

In [None]:
# Directory the test images are read from.
TEST_PATH = '../input/cassava-leaf-disease-classification/test_images/'

# Entry names inside TEST_PATH (names only, no directory prefix yet).
test_images = os.listdir(TEST_PATH)

In [None]:
# Turn the entries into full paths under TEST_PATH. Taking the basename first
# keeps the cell safe to re-run: an entry that already carries the directory
# prefix is reduced to its file name again instead of being prefixed twice.
test_images = [
    os.path.join(TEST_PATH, os.path.basename(name)) for name in test_images
]

In [None]:
# Preview only the first few paths; displaying the bare list would print every
# entry into the cell output.
test_images[:5]

In [None]:
# One placeholder label (-1) per test image.
test_targets = [-1] * len(test_images)

In [None]:
# Per-channel mean and standard deviation handed to the Normalize transform.
# NOTE(review): these look like the usual ImageNet statistics — confirm they
# match what the checkpoint was trained with.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

In [None]:
# Evaluation-time pipeline: a single Normalize step using `mean` and `std`.
valid_aug = albumentations.Compose([
    albumentations.Normalize(
        mean=mean,
        std=std,
        max_pixel_value=255.0,
        always_apply=True,
    ),
])

In [None]:
# Dataset over the test images: no resizing, normalisation only. `backend` and
# `channel_first` are spelled out with the values ImageDataset uses by default.
test_dataset = ImageDataset(
    test_images,
    test_targets,
    resize=None,
    augmentations=valid_aug,
    backend='pil',
    channel_first=True,
)


In [None]:
# Batches are produced in dataset order (no shuffling) and the final partial
# batch is kept, so every test image gets exactly one prediction.
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    shuffle=False,
    drop_last=False,
    batch_size=FLAGS['batch_size'],
    num_workers=FLAGS['num_workers'],
)

In [None]:
# Build the network without pretrained weights; the weights come from the
# checkpoint loaded in the next cell.
model = TimmModels(model_name=FLAGS['model'], pretrained=False, num_classes=5)

In [None]:
# Map the checkpoint tensors to the CPU while loading, so loading does not
# depend on the device they were saved from being available here; the model is
# moved to its target device in a later cell.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
model.load_state_dict(
    torch.load(
        '../input/cassava-torchxla2-model/xla_trained_model_10_epochs_fold_0.pth',
        map_location='cpu',
    )
)

In [None]:
# Use the configured device, but fall back to the CPU when a CUDA device is
# requested and none is available, so the notebook still runs end to end.
device = FLAGS['device']
if device.startswith('cuda') and not torch.cuda.is_available():
    device = 'cpu'

In [None]:
# Assign the result so the cell does not end in a bare expression; otherwise
# the notebook prints the full module repr as the cell output.
model = model.to(device)

In [None]:
def infer(data_loader,
          model,
          device):
    """Run ``model`` over every batch and collect its raw outputs.

    The model is switched to evaluation mode for the duration of the call, so
    that layers such as dropout and batch normalisation behave
    deterministically. Its previous train/eval mode is restored afterwards.

    :param data_loader: iterable of batches; each batch is a mapping with an
        ``'image'`` tensor
    :param model: ``torch.nn.Module`` to evaluate
    :param device: device the images are moved to before the forward pass
    :return: list with one entry per sample, each being that sample's model
        output converted to a plain Python list
    """
    fin_outputs = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in data_loader:
                outputs = model(batch['image'].to(device))
                fin_outputs.extend(outputs.cpu().detach().numpy().tolist())
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    # A single collection after the loop instead of one per batch.
    gc.collect()

    return fin_outputs

In [None]:
# Raw model outputs for the test set, one row per image in loader order.
predictions = infer(data_loader=test_loader, model=model, device=device)

In [None]:
# Empty frame; the submission columns are added in the following cells.
test = pd.DataFrame()

In [None]:
# Take the file names from the same list the predictions were computed from,
# so that row i of the frame always belongs to prediction i. Listing the
# directory a second time does not guarantee the same order.
test['image_id'] = [os.path.basename(path) for path in test_images]

In [None]:
# Predicted class = index of the largest output in each row.
test['label'] = np.asarray(predictions).argmax(axis=1)

In [None]:
# Write the submission file without the index column.
test.to_csv(path_or_buf='submission.csv', index=False)

In [None]:
# First five rows of the submission frame.
test.head(n=5)

In [None]:
# Sample submission file, loaded for comparison with the frame built above.
sample = pd.read_csv(
    filepath_or_buffer='../input/cassava-leaf-disease-classification/sample_submission.csv'
)

In [None]:
# Submission frame as a raw NumPy array.
test.to_numpy()

In [None]:
# `image_id` column of the sample submission.
sample.loc[:, 'image_id']

In [None]:
# `label` column of the sample submission.
sample.loc[:, 'label']

In [None]:
# `label` column of the generated submission.
test.loc[:, 'label']

In [None]:
# `image_id` column of the generated submission.
test.loc[:, 'image_id']

In [None]:
# Drop any earlier `train_images` / `train_targets` before they are rebuilt in
# the cells below. On a fresh kernel neither name exists yet, so an
# unconditional `del` would raise NameError; popping with a default is a no-op
# in that case.
for _name in ('train_images', 'train_targets'):
    globals().pop(_name, None)

In [None]:
# Entry names in the training image directory (os.listdir already returns a
# new list, so no extra copy is made here).
train_images = os.listdir(
    '../input/cassava-leaf-disease-classification/train_images'
)

In [None]:
training_data_path = '../input/cassava-leaf-disease-classification/train_images'

# Turn the entries into full paths. Taking the basename first keeps the cell
# safe to re-run: an entry that already carries the directory prefix is
# reduced to its file name again instead of being prefixed twice.
train_images = [
    os.path.join(training_data_path, os.path.basename(name))
    for name in train_images
]

# One placeholder label (-1) per training image.
train_targets = [-1] * len(train_images)


In [None]:
# Subset made of the first 15000 training images and their placeholder labels.
valid_images = train_images[:15000]
valid_targets = train_targets[:15000]

In [None]:
# Dataset over the validation subset: no resizing, normalisation only.
# `backend` and `channel_first` are spelled out with the values ImageDataset
# uses by default.
valid_dataset = ImageDataset(
    valid_images,
    valid_targets,
    resize=None,
    augmentations=valid_aug,
    backend='pil',
    channel_first=True,
)

# Same loader settings as for the test set: dataset order, keep the final
# partial batch.
valid_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset,
    shuffle=False,
    drop_last=False,
    batch_size=FLAGS['batch_size'],
    num_workers=FLAGS['num_workers'],
)

In [None]:
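# Left disabled: uncommenting this line runs inference over the validation
# subset and overwrites `predictions` from the test-set run.
#predictions = infer(valid_loader, model, device)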

In [None]:
# Preview only the first few output rows; displaying the bare list would print
# every row into the cell output.
predictions[:5]

In [None]:
# Preview the first few entries of the training directory listing instead of
# printing 15000 file names into the cell output.
os.listdir('../input/cassava-leaf-disease-classification/train_images')[:5]

In [None]:
# Preview only the first few validation paths; displaying the bare list would
# print every entry into the cell output.
valid_images[:5]