In [None]:
from torchvision.models import resnext50_32x4d as resnext
from torchvision.models import vgg16
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd 
import numpy as np
import cv2
from PIL import Image
import os
import random
from tqdm import tqdm

from matplotlib import cm
import matplotlib.pyplot as plt
from torchvision import transforms
from sklearn.model_selection import train_test_split
# Report whether CUDA is visible, then pick the first GPU if so, otherwise the CPU.
gpu_available = torch.cuda.is_available()
print('gpu,', gpu_available)
device = torch.device("cuda:0") if gpu_available else torch.device("cpu")

In [None]:
class LeafNet(nn.Module):
    """Leaf-disease classifier: ResNeXt-50 (32x4d) backbone plus a 3-layer MLP head.

    The backbone's 1000-dim output feeds two ``Linear -> ReLU -> Dropout(0.5)``
    layers of width ``hidden_size``. A final ``Linear`` then produces ``num_cls`` logits.
    """

    def __init__(self, hidden_size=1280, num_cls=5):
        """
        Args:
            hidden_size: width of the two hidden layers of the MLP head.
            num_cls: number of output classes (size of the logit vector).
        """
        super(LeafNet, self).__init__()
        # Keep the backbone in the default float32, matching the head and the
        # float32 images the datasets produce. The former `.double()` made a freshly
        # built LeafNet() fail with a dtype mismatch unless the whole model was cast
        # afterwards. load_state_dict copies values into the existing parameters,
        # so checkpoint loading is unaffected.
        self.model_base = resnext()
        # Alternative backbone that was tried: vgg16()
        self.linear1 = nn.Linear(1000, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, num_cls)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()
        # Cross-entropy on integer class indices is the right loss here;
        # BCEWithLogitsLoss should not be used.
        self.loss = nn.CrossEntropyLoss()
        # NOTE: bn0/bn1/bn2 are never used in forward/inference. They are still
        # registered in the state_dict, and checkpoints are loaded with strict=True,
        # so removing them would presumably break loading existing checkpoints.
        self.bn0 = nn.BatchNorm2d(1000)
        self.bn1 = nn.BatchNorm2d(hidden_size)
        self.bn2 = nn.BatchNorm2d(hidden_size)

    def _logits(self, img):
        """Return the unnormalised class scores (B x num_cls) for an image batch."""
        x = self.model_base(img)
        x = self.dropout(self.relu(self.linear1(x)))
        x = self.dropout(self.relu(self.linear2(x)))
        return self.linear3(x)

    def forward(self, inputs):
        """Return the (mean-reduced) cross-entropy loss for a labelled batch.

        Args:
            inputs: tuple ``(img, labels)``. ``img`` is a float image batch of shape
                B x C x H x W (channel-first, as produced by ToTensorV2). ``labels``
                is a length-B integer tensor of class indices (not one-hot).
        """
        img, labels = inputs
        return self.loss(self._logits(img), labels)

    @torch.no_grad()
    def inference(self, inputs):
        """Return the predicted class index for every image in a batch.

        Args:
            inputs: float image batch of shape B x C x H x W.

        Returns:
            int64 tensor of shape (B,) holding the argmax class index per image.

        Gradients are not tracked. Dropout is only disabled if the caller has put
        the model in eval mode (``model.eval()``); otherwise predictions are random.
        """
        return torch.argmax(self._logits(inputs), dim=1)

In [None]:
# Disabled helper, kept for reference only. Labels are passed to CrossEntropyLoss
# as integer class indices (see Leafdataset.__getitem__), so no one-hot encoding is needed.
'''
def one_hot(lst,num_cls):
    tables = np.zeros((len(lst),num_cls))
    for i,line in enumerate(lst):
        tables[i,line] = 1
    return tables
'''
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)

from albumentations.pytorch import ToTensorV2
import albumentations as A
class Leafdataset(Dataset):
    """Labelled cassava-leaf images listed in ``<path>/train.csv``.

    Both modes read ``train.csv`` and ``train_images/``. The split is positional,
    in CSV order:

    * ``mode_train=True``: augmented training items. These are all rows by default,
      or only rows before the split boundary when ``exclude_holdout=True``.
    * ``mode_train=False``: deterministic preprocessing of the rows from
      ``int(split_frac * N)`` onward, i.e. the hold-out part.

    Each item is ``(img, label)``. ``img`` is a float32 C x H x W tensor (256 x 256)
    and ``label`` is a scalar integer tensor holding the class index.

    Args:
        path: dataset root containing ``train.csv`` and ``train_images/``.
        mode_train: select the augmented training split (True) or the hold-out split (False).
        num_cls: unused; kept for interface compatibility.
        split_frac: fraction of rows that precede the hold-out part (default 0.8).
        exclude_holdout: in training mode, drop the hold-out rows so the two splits
            do not overlap. The default False keeps the historical behaviour of
            training on every row.
    """

    def __init__(self, path, mode_train=False, num_cls=5, split_frac=0.8, exclude_holdout=False):
        self.mode_train = mode_train
        image_base = os.path.join(path, 'train_images/')
        csv_path = os.path.join(path, 'train.csv')
        info = pd.read_csv(csv_path)

        labels = info['label'].tolist()
        imgs = [os.path.join(image_base, img_name) for img_name in info['image_id']]

        # Rows [split:] form the hold-out part; no shuffling is applied.
        split = int(split_frac * len(imgs))
        if self.mode_train:
            # NOTE: with exclude_holdout=False the training split also contains the
            # hold-out rows, so accuracy measured on mode_train=False data is optimistic.
            end = split if exclude_holdout else len(imgs)
            self.imgs = imgs[:end]
            self.labels = labels[:end]
        else:
            self.imgs = imgs[split:]
            self.labels = labels[split:]

        # NOTE(review): CenterCrop(256) runs before Resize(256), so Resize is a no-op.
        # Evaluation therefore sees only a central 256x256 patch at native resolution,
        # whereas training's RandomResizedCrop rescales larger regions. Confirm this
        # mismatch is intended.
        self.preprocess = A.Compose([
            CenterCrop(256, 256, p=1.),
            Resize(256, 256),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ])
        self.augmentation = A.Compose([
            RandomResizedCrop(256, 256),
            Transpose(p=0.5),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            ShiftScaleRotate(p=0.5),
            HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
            RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
            CoarseDropout(p=0.5),
            Cutout(p=0.5),
            ToTensorV2(p=1.0),
        ])
        print(len(self.labels), len(self.imgs))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_name = self.imgs[idx]
        # The context manager closes the file handle deterministically.
        with Image.open(img_name) as pil_img:
            img = np.array(pil_img)
        pipeline = self.augmentation if self.mode_train else self.preprocess
        img = pipeline(image=img)['image'].float()
        label = torch.tensor(self.labels[idx])
        return img, label

class Leafdataset_val(Dataset):
    """Unlabelled images ``<path>/test_images/*.jpg`` used to build the submission.

    Each item is ``(file_name, img)``. ``file_name`` is the image's base name (the
    submission ``image_id``) and ``img`` is a float32 C x H x W tensor after
    center-crop and normalisation.

    Args:
        path: dataset root containing ``test_images/``.
        num_cls: unused; kept for interface compatibility.
    """

    def __init__(self, path, num_cls=5):
        import glob
        image_base = os.path.join(path, 'test_images/')
        # glob returns files in filesystem-dependent order. Sort so the item order
        # (and hence the submission row order) is deterministic.
        self.imgs = sorted(glob.glob(image_base + '*.jpg'))

        self.preprocess = A.Compose([
            CenterCrop(256, 256, p=1.),
            Resize(256, 256),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ])
        print(len(self.imgs))

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img_name = self.imgs[idx]
        with Image.open(img_name) as pil_img:
            img = np.array(pil_img)
        img = self.preprocess(image=img)['image'].float()
        # basename handles the platform's path separators, unlike split('/').
        return os.path.basename(img_name), img

In [None]:

def validate_test(model, dataset_test, out_path='./submission.csv', device=None):
    """Predict a class for every unlabelled image and write a submission CSV.

    Args:
        model: a module exposing ``inference(img_batch)`` that returns one class
            index per image (e.g. LeafNet).
        dataset_test: iterable of ``(image_id, img)`` pairs where ``img`` is a single,
            un-batched tensor (e.g. Leafdataset_val).
        out_path: destination of the CSV with columns ``image_id`` and ``label``.
        device: device to run on. Defaults to the device of the model's parameters.

    Returns:
        The submission ``pd.DataFrame`` that was written.
    """
    if device is None:
        device = next(model.parameters()).device
    # Without eval mode, dropout and train-mode BatchNorm make predictions random.
    model.eval()
    predictions = []
    ids = []
    with torch.no_grad():
        for img_name, img in dataset_test:
            # Add the batch dimension the model expects.
            pred = model.inference(img.unsqueeze(0).to(device)).item()
            predictions.append(pred)
            ids.append(img_name)
    sub = pd.DataFrame({'image_id': ids, 'label': predictions})
    sub.to_csv(out_path, index=False)
    return sub
        
def validate(model, dataset_test, device=None):
    """Compute, print and return the top-1 accuracy of ``model`` on labelled data.

    Args:
        model: a module exposing ``inference(img_batch)`` that returns one class
            index per image (e.g. LeafNet).
        dataset_test: sized iterable of ``(img, label)`` pairs, where ``img`` is an
            un-batched tensor and ``label`` a scalar tensor (e.g. Leafdataset).
        device: device to run on. Defaults to the device of the model's parameters.

    Returns:
        Accuracy as a float, or ``nan`` when the dataset is empty.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()

    num_corr = 0
    num_total = len(dataset_test)
    # No gradients are needed; this avoids building an autograd graph per sample.
    with torch.no_grad():
        for img, label in dataset_test:
            pred = model.inference(img.unsqueeze(0).to(device)).item()
            if pred == label.item():
                num_corr += 1

    # Guard against ZeroDivisionError on an empty dataset.
    accuracy = num_corr * 1.0 / num_total if num_total else float('nan')
    print("accuracy is", accuracy, num_corr, '/', num_total)
    return accuracy

In [None]:
def inference(model, test_loader, device, out_path='./submission.csv'):
    """Run ``model`` on every test image and write the submission CSV.

    Args:
        model: a module exposing ``inference(img_batch)`` that returns one class
            index per image (e.g. LeafNet). It is moved to ``device`` and put in
            eval mode.
        test_loader: despite the name, an iterable of *un-batched*
            ``(image_id, img)`` pairs (e.g. Leafdataset_val). A batch dimension is
            added to each image here.
        device: device to run inference on.
        out_path: destination of the CSV with columns ``image_id`` and ``label``.

    Returns:
        The submission ``pd.DataFrame`` that was written.
    """
    model.to(device)
    # Without eval mode, dropout stays active and BatchNorm uses per-sample
    # statistics, which makes predictions random.
    model.eval()

    predictions = []
    ids = []
    with torch.no_grad():
        for img_name, img in test_loader:
            pred = model.inference(img.unsqueeze(0).to(device)).item()
            predictions.append(pred)
            ids.append(img_name)

    sub = pd.DataFrame({'image_id': ids, 'label': predictions})
    sub.to_csv(out_path, index=False)
    # Return the frame so the caller can inspect it. A bare `sub.head()` inside the
    # function was never displayed.
    return sub

In [None]:
# Seed every RNG that can affect a run. torch drives model initialisation and
# DataLoader shuffling. The albumentations transforms presumably draw from
# Python's `random` and/or NumPy, depending on version (TODO confirm for the
# installed version), so seed those too.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

lr = 1e-4
batch_size = 50
num_epochs = 10
path = "../input/cassava-leaf-disease-classification/"
# dataset_train: augmented rows of train.csv.
# dataset_test: labelled hold-out (last 20% of train.csv). With the default
#   Leafdataset settings it overlaps dataset_train.
# dataset_val: unlabelled images from test_images/ used for the submission.
dataset_train = Leafdataset(path, mode_train=True)
dataset_test = Leafdataset(path, mode_train=False)
dataset_val = Leafdataset_val(path)
train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

In [None]:
# Build the ResNeXt-based model and load the trained weights for submission inference.
model = LeafNet().float()
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only
# kernel; `inference` below moves the model to `device` afterwards.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
state_dict = torch.load('../input/first-2/ckpt_best.pt', map_location='cpu')
model.load_state_dict(state_dict, strict=True)
# Disable dropout and use BatchNorm running statistics; otherwise predictions are random.
model.eval()
inference(model, dataset_val, device)