# Import necessary modules

In [None]:
import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from torchvision.transforms import Compose, Resize, ToTensor
from torchvision import datasets, transforms, models
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid

import os
import time
import pandas as pd
import numpy as np
from PIL import Image, ImageEnhance
import matplotlib.pyplot as plt

# Install timm to access ViT PyTorch Models

In [None]:
pip install '../input/timm034/timm-0.3.4-py3-none-any.whl'

# Import timm

In [None]:
import timm

# Display Available Vision Transformer Models

In [None]:
# Show every timm model name that matches the ViT glob pattern.
vit_name_pattern = "vit*"
print("Available ViT Models: ")
timm.list_models(vit_name_pattern)

In [None]:
# Competition data root; the train/test image folders live directly under it.
data_path = '../input/cassava-leaf-disease-classification/'
train_path = data_path + 'train_images/'
test_path = data_path + 'test_images/'

# Checkpoint loaded into the timm backbone when ViTBase16(pretrained=True).
model_path = '../input/vitbase16224/jx_vit_base_p16_224-80ecf9dd.pth'
# Checkpoint loaded into the full ViTBase16 model before inference.
Cassava_model = '../input/cassavanewaugtp95epochs3/CassavaViT_newaug_TP95_Epochs3_LR1-75e05.pt'

# Define ViTBase 16 class

In [None]:
class ViTBase16(nn.Module):
    """timm ``vit_base_patch16_224`` wrapped with a task-specific linear head.

    Args:
        n_classes: Output size of the replacement ``nn.Linear`` head.
        pretrained: When true, the backbone weights are loaded from the local
            checkpoint at the module-level ``model_path`` before the head is
            replaced.
    """

    def __init__(self, n_classes, pretrained=False):
        super().__init__()

        # Build the architecture only; weights (if requested) come from the
        # local checkpoint below rather than from timm.
        self.model = timm.create_model("vit_base_patch16_224", pretrained=False)

        if pretrained:
            # map_location="cpu" lets a checkpoint saved from a GPU be restored
            # on a CPU-only host; the caller moves the model afterwards.
            backbone_state = torch.load(model_path, map_location="cpu")
            self.model.load_state_dict(backbone_state)

        # Replace the head after loading so the checkpoint's head shape still
        # matches at load time.
        self.model.head = nn.Linear(self.model.head.in_features, n_classes)

    def forward(self, x):
        """Return the wrapped model's output for the batch ``x``."""
        return self.model(x)

# Create cassava pre-trained ViTBase16 model instance (5 classes)

In [None]:
# Build the 5-class model and restore the cassava checkpoint into it.
# NOTE(review): pretrained=True loads the checkpoint at model_path first, and
# the load_state_dict call below then loads the cassava checkpoint on top of
# it -- confirm whether the first load is still needed.
cassava_model = ViTBase16(n_classes=5, pretrained=True)
# map_location="cpu" keeps the load independent of the device the checkpoint
# was saved from; the model is moved to its target device right after.
cassava_model.load_state_dict(torch.load(Cassava_model, map_location="cpu"))

# Use the GPU when one is present, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cassava_gpu_model = cassava_model.to(device)
# This notebook only runs inference, so switch the model to evaluation mode.
cassava_gpu_model.eval()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cassava_gpu_model.parameters(), lr=1.5e-05)
cassava_gpu_model

# Examine test image folder

In [None]:
# Walk the test folder, echo every folder and file found, and collect the
# names of the JPEG files.
test_img_names = []
for folder, subfolders, filenames in os.walk(test_path):
    print('************************************************FOLDER************************************************')
    print(folder)
    print('************************************************IMAGES************************************************')
    for img in filenames:
        print(img)
        # Compare the real extension, case-insensitively, so 'X.JPG' and
        # 'x.jpeg' count while a name that merely ends in "jpg" does not.
        if img.lower().endswith(('.jpg', '.jpeg')):
            test_img_names.append(img)
print('Testing Images: ',len(test_img_names))

# Create class to load test data (adds image names)

In [None]:
class TestSet2(Dataset):
    """Cassava Disease test dataset yielding ``(image, image_name)`` pairs."""

    def __init__(self, root_dir, test_dir, transform=None):
        """
        Args:
            root_dir (string): Stored on the instance; not used when loading.
            test_dir (string): Directory with all the test images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        super().__init__()
        self.root_dir = root_dir
        self.test_dir = test_dir
        self.transform = transform
        # List the directory once and sort it: os.listdir gives no ordering
        # guarantee, so re-listing per item could pair an index with
        # different files between calls.
        self.image_names = sorted(os.listdir(self.test_dir))
        print(root_dir)
        print(test_dir)
        print("Cassava Disease Test Dataset Length = ", len(self.image_names))

    def __len__(self):
        """Return the number of entries found in ``test_dir`` at construction."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load entry ``idx`` as RGB and return ``(image, image_name)``.

        The image is passed through ``transform`` when one was supplied;
        otherwise the PIL image itself is returned.
        """
        image_name = self.image_names[idx]
        # os.path.join works whether or not test_dir ends with a separator.
        img_path = os.path.join(self.test_dir, image_name)
        image = Image.open(img_path).convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return (image, image_name)

# Create test transforms (needed for image resizing)

In [None]:
# Inference preprocessing: resize to 224x224, convert to a tensor, then
# normalise each channel with the mean / std values below.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

resize_step = transforms.Resize((224, 224))
to_tensor_step = transforms.ToTensor()
normalize_step = transforms.Normalize(channel_mean, channel_std)

test_transform = transforms.Compose([resize_step, to_tensor_step, normalize_step])

# Create test dataset

In [None]:
# Wrap the test image folder in the dataset class; root_dir is left empty.
testset = TestSet2(
    test_dir=test_path,
    root_dir='',
    transform=test_transform,
)
print(testset)

# Load the test dataset

In [None]:
# One image per batch.
test_batch_size = 1
# shuffle=False: inference gains nothing from shuffling, and a fixed order
# keeps the produced rows reproducible between runs.
test_loader = DataLoader(
    dataset=testset,
    batch_size=test_batch_size,
    shuffle=False,
    pin_memory=True,
)

# Check to see if test_loader is working

In [None]:
print(test_loader)

# Grab the first batch (test_batch_size images) together with its file names
images, names = next(iter(test_loader))

im = make_grid(images, nrow=4)  # the default nrow is 8

# Undo the normalisation applied by test_transform so pixel values return to
# the [0, 1] range imshow expects for float images; clamp removes any
# rounding overshoot.
inv_normalize = transforms.Normalize(
    mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
    std=[1/0.229, 1/0.224, 1/0.225]
)
im_inv = inv_normalize(im).clamp(0, 1)
print(names)
# Print the images (channels-first tensor -> channels-last array for matplotlib)
plt.figure(figsize=(12,4))
plt.imshow(np.transpose(im_inv.numpy(), (1, 2, 0)));

# Use the trained model to generate test data predictions
# Generate submission file

In [None]:
test_losses = []
test_correct = []
col_names =  ['image_id', 'label']
submission_df = pd.DataFrame(columns = col_names)
tic = time.time()
with torch.no_grad():
    for b, (X_test, name) in enumerate(test_loader):
        X_test = X_test.cuda()
        # Apply the model
        y_test_pred = cassava_gpu_model(X_test)
        predicted = torch.max(y_test_pred.data, 1)[1]
        label = int(predicted)
        image_name = name[0]
        new_row = {'image_id': image_name, 'label': label}
        submission_df = submission_df.append(new_row, ignore_index=True)
        
toc = time.time() - tic
print('Time for model val is ', toc)
#display(submission_df)
submission_df.to_csv('submission.csv', index=False)
df_sub_test = pd.read_csv('submission.csv')
display(df_sub_test)
!ls