In [None]:
import os
import pickle
import numpy as np
from torch.utils.data import Dataset
import torch
import torchvision
from torchvision import datasets, transforms,models
from torch import nn, optim
import torch.nn.functional as F
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
import pandas as pd
from torch.optim.lr_scheduler import LambdaLR
import copy

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
epoch = 1  # NOTE(review): not read anywhere in the visible notebook; policy_0 uses its own loop variable.

# Seed torch (all devices) and NumPy so augmentation, DataLoader shuffling and dropout
# are reproducible when the notebook is re-run from a fresh kernel.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)


***transform***

In [None]:
# Training pipeline: random rotation, fixed resize + centre crop to 224x224, brightness
# jitter, tensor conversion, random erasing, then per-channel normalisation to [-1, 1].
transform = {'train': transforms.Compose([
    transforms.RandomRotation(30),
    transforms.Resize([240, 320]),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ColorJitter(brightness=0.2),
    transforms.ToTensor(),
    # Erasing happens on the [0, 1] tensor; value=0 becomes -1 after Normalize below.
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.1), ratio=(0.3, 3.3), value=0, inplace=False),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])}

# Validation and test share the same deterministic recipe (no augmentation); each split
# still gets its own Compose instance.
for _split in ('valid', 'test'):
    transform[_split] = transforms.Compose([
        transforms.Resize([240, 320]),
        transforms.CenterCrop(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
del _split

***dataset and loader***

In [None]:
class CASSAVA(Dataset):
    """Cassava leaf images listed in train.csv, split into contiguous row chunks.

    The CSV rows are cut into five chunks at the offsets in ``_SPLIT_TABLE``; chunk k is
    rows ``[_SPLIT_TABLE[k-1], _SPLIT_TABLE[k])``. Chunk 5 is the held-out test split and
    chunks 1-4 form the 4-fold cross-validation pool.

    Args:
        train: which split to build - 'train', 'valid' or 'test'.
        k: fold selector. For 'valid', k in 1..4 selects chunk k. For 'train', k in 1..4
            selects chunks 1-4 minus chunk k, and k == 5 selects all of chunks 1-4.
            Ignored for 'test'.
        transform: optional callable applied to the PIL image.
        target_transform: optional callable applied to the label.
        root: directory containing train.csv and the train_images/ folder.

    Raises:
        ValueError: for an unknown split name or a k that does not fit the split.
    """

    # Row offsets of the chunk boundaries in train.csv (five roughly equal chunks).
    _SPLIT_TABLE = {0: 0, 1: 4279, 2: 8559, 3: 12838, 4: 17118, 5: 21397}

    def __init__(self, train='train', k = 0, transform = None, target_transform=None,
                 root='../input/cassava-leaf-disease-classification'):
        super(CASSAVA, self).__init__()

        self.train = train
        self.k = k
        self.transform = transform
        self.target_transform = target_transform

        data_ex = pd.read_csv(os.path.join(root, 'train.csv'))
        image_dir = os.path.join(root, 'train_images')
        # Column 0 holds the image file name, column 1 the integer label.
        full_id_list = [os.path.join(image_dir, image_id) for image_id in data_ex.iloc[:, 0]]
        label = data_ex.iloc[:, 1]

        self.data_id_list = []
        self.labels = []
        for start, stop in self._row_ranges(train, k):
            self.data_id_list.extend(full_id_list[start:stop])
            self.labels.extend(label.iloc[start:stop].tolist())

    @classmethod
    def _row_ranges(cls, split, k):
        """Return the [start, stop) row ranges of train.csv that make up the requested split."""
        bounds = cls._SPLIT_TABLE
        if split == 'test':
            return [(bounds[4], bounds[5])]
        if split == 'valid' and 1 <= k <= 4:
            return [(bounds[k - 1], bounds[k])]
        if split == 'train' and 1 <= k <= 4:
            # Stop at the test boundary bounds[4]; using bounds[5] here put the held-out
            # test rows into every cross-validation training fold.
            return [(bounds[0], bounds[k - 1]), (bounds[k], bounds[4])]
        if split == 'train' and k == 5:
            return [(bounds[0], bounds[4])]
        raise ValueError(f"unsupported split {split!r} with k={k!r}")

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``; the label is an int64 tensor."""
        img_path, target = self.data_id_list[index], self.labels[index]
        # The context manager releases the file handle; convert('RGB') guarantees the
        # three channels that a 3-channel Normalize expects.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img).float()

        if self.target_transform is not None:
            target = self.target_transform(target)

        target = torch.from_numpy(np.asarray(target)).long()
        return img, target

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.labels)

In [None]:

# Datasets: four CV training folds (k = 1..4), the full training split (k = 5),
# the four matching validation folds, and the test split (built with k = 0).
(train_dataset_1, train_dataset_2, train_dataset_3,
 train_dataset_4, train_dataset_w) = [
    CASSAVA(train='train', k=fold, transform=transform['train'], target_transform=None)
    for fold in range(1, 6)
]
(valid_dataset_1, valid_dataset_2,
 valid_dataset_3, valid_dataset_4) = [
    CASSAVA(train='valid', k=fold, transform=transform['valid'], target_transform=None)
    for fold in range(1, 5)
]
test_dataset = CASSAVA(train='test', k=0, transform=transform['test'], target_transform=None)

# Loaders follow the dataset order above: training loaders shuffle, evaluation loaders do not.
data = {
    'train_loader': [
        torch.utils.data.DataLoader(ds, batch_size, shuffle=True)
        for ds in (train_dataset_1, train_dataset_2, train_dataset_3,
                   train_dataset_4, train_dataset_w)
    ],
    'valid_loader': [
        torch.utils.data.DataLoader(ds, batch_size, shuffle=False)
        for ds in (valid_dataset_1, valid_dataset_2, valid_dataset_3, valid_dataset_4)
    ],
    'test_loader': torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False),
}

***model***

In [None]:
# Backbone: a whole pickled nn.Module loaded from the Kaggle dataset (its conv1 ... layer4 / fc
# attributes are accessed below).
# NOTE(review): torch.load unpickles arbitrary objects - only load files you trust. On PyTorch
# releases where torch.load defaults to weights_only=True, loading a whole module like this may
# additionally need weights_only=False.
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only machine; the training
# copies are moved to `device` inside policy_0.
RES50 = torch.load('../input/resnet50-for-pytorch/Resnet-50', map_location='cpu')

# Main layers: conv1, bn1, relu, maxpool, layer1..layer4, avgpool, fc.
# Freeze everything up to and including layer3; relu and maxpool own no parameters and are
# listed only to mirror the layer order.
for _frozen in ('conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3'):
    for param in getattr(RES50, _frozen).parameters():
        param.requires_grad = False

# New trainable classifier head producing log-probabilities over 5 classes; the LogSoftmax
# output pairs with the nn.NLLLoss used for training.
RES50.fc = nn.Sequential(
    nn.Linear(2048, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, 32),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(32, 5),
    nn.LogSoftmax(dim=1),
)

***lambda: multiplicative learning-rate factor passed to LambdaLR***

In [None]:
def lambdax(e):
    """LambdaLR multiplier: 0.9 raised to the scheduler's step count ``e``."""
    return 0.9 ** e

***training loop (policy_0)***

In [None]:
# k indexes data['train_loader']: 0..3 are the CV training folds, 4 is the full training split.
def policy_0(lr_base = 1e0,num_epochs = 1,lambdax = lambdax,k = 0,model = None):
    """Train a model on ``data['train_loader'][k]`` with Adam and an optional LambdaLR schedule.

    Args:
        lr_base: initial learning rate for Adam.
        num_epochs: number of passes over the selected training loader.
        lambdax: callable mapping the epoch count to an LR multiplier for LambdaLR,
            or None for a constant learning rate.
        k: index into ``data['train_loader']``.
        model: model to continue training; when None a deep copy of the global RES50 is used.

    Returns:
        ``(model, per_iteration_losses, last_epoch_loss)``. ``last_epoch_loss`` is the
        mean batch loss of the final epoch (0.0 if no batch was seen). The original loop
        reported only the last batch's loss.
    """
    res50 = model if model is not None else copy.deepcopy(RES50)
    res50 = res50.to(device)
    # Explicitly enable training behaviour (dropout, batch-norm statistics) regardless of
    # the mode the model was saved or last used in.
    res50.train()

    loss_fun = nn.NLLLoss()
    # Only the unfrozen parameters (requires_grad=True) are optimised.
    optimizer = optim.Adam([p for p in res50.parameters() if p.requires_grad], lr=lr_base)
    scheduler = LambdaLR(optimizer, lr_lambda=lambdax) if lambdax is not None else None
    loader = data['train_loader'][k]

    y = []
    last_epoch_loss = 0.0
    for epoch in range(num_epochs):
        loss_total = 0.0
        n_batches = 0
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = res50(images)
            loss = loss_fun(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            y.append(batch_loss)
            loss_total += batch_loss
            n_batches += 1

        # Decay once per epoch. Stepping inside the batch loop multiplied the LR by 0.9
        # every iteration and drove it to ~0 long before the first epoch finished.
        if scheduler is not None:
            scheduler.step()

        last_epoch_loss = loss_total / n_batches if n_batches else 0.0
        print('[%d,%d] loss: %.06f' % (epoch + 1, n_batches, last_epoch_loss))

    return res50, y, last_epoch_loss

***accuracy on a validation or test loader (policy_1)***

In [None]:
def policy_1(new_model,data):
    """Return the top-1 accuracy of ``new_model`` over every batch yielded by ``data``.

    The model is moved to the global ``device`` and evaluated in eval mode, with dropout
    off and batch-norm using running statistics. Its previous train/eval mode is restored
    afterwards, so a caller that keeps training the same model is unaffected.

    Args:
        new_model: classifier whose output has one score per class along dim 1.
        data: iterable of ``(images, labels)`` batches, e.g. a DataLoader.

    Returns:
        Fraction of correctly classified samples, as a float.

    Raises:
        ValueError: if ``data`` yields no samples.
    """
    was_training = new_model.training
    new_model = new_model.to(device)  # once, not per batch
    new_model.eval()

    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in data:
                images = images.to(device)
                labels = labels.to(device)

                output = new_model(images)
                _, predicted = torch.max(output, 1)

                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        new_model.train(was_training)

    if total == 0:
        raise ValueError("policy_1 received no samples; accuracy is undefined")
    # Divide by the real sample count: len(data) * batch_size over-counts whenever the
    # last batch is only partially filled.
    acc = correct / total
    print("The accuracy of that lr is: {0}".format(acc))
    return acc

***plot functions***

In [None]:
def plot_iteration(list_y):
    """Open a new figure showing ``list_y`` as a red loss curve over iterations 1..len(list_y)."""
    iterations = [idx + 1 for idx in range(len(list_y))]
    plt.figure()
    plt.plot(iterations, list_y, "r")
    plt.xlabel("iteration")
    plt.ylabel("loss")

    
def plot_acc(list_y,lr):
    """Open a new figure plotting one accuracy per fold (x = 1..n) for base rate ``lr``.

    Every point is annotated with its raw value.
    """
    folds = [fold for fold in range(1, len(list_y) + 1)]
    plt.figure()
    plt.plot(folds, list_y, color="m", marker='o')
    for point in zip(folds, list_y):
        plt.text(point[0], point[1], str(point[1]), fontsize=10)
    plt.title(f"the Lr = {lr}")
    plt.xlabel("kfold")
    plt.ylabel("accuracy")
    
    
def plot_av_acc(list_x,list_y):
    """Open a new figure with one evenly spaced bar of mean accuracy per learning rate.

    Args:
        list_x: learning rates, used as the bar tick labels.
        list_y: matching mean accuracies; each bar is annotated with its value.
    """
    # Bars sit at 0..n-1 with the rates as tick labels. Plotting at the raw rates
    # (1, 0.1, 0.01, ...) piled the 0.8-wide bars on top of each other, and the
    # annotations placed at 1..n did not line up with any bar.
    positions = list(range(len(list_x)))
    plt.figure()
    plt.bar(positions, list_y, tick_label=[str(rate) for rate in list_x])
    for pos, acc in zip(positions, list_y):
        plt.text(pos, acc, str(acc), fontsize=10, ha='center')
    plt.xlabel("learning rate")
    plt.ylabel("accuracy")
    
    
def plot_acc_loss(list_y1,list_y2):
    """Open a new figure with loss (magenta) and accuracy (yellow) on a shared 1-based x axis.

    Args:
        list_y1: loss values; each point is annotated with its value.
        list_y2: accuracy values, same length as ``list_y1``.
    """
    list_x = list(range(1,len(list_y1)+1))

    plt.figure()
    plt.plot(list_x,list_y1,label = 'loss',color = "m",marker = 'o')
    for xx,yy in zip(list_x,list_y1):
        plt.text(xx,yy,str(yy),fontsize = 10)
    plt.plot(list_x,list_y2,label = 'accuracy',color = "y",marker = 'o')
    # The curves carry labels, but without a legend nothing tells the reader which is which.
    plt.legend()
    plt.title("loss_acc")
    plt.xlabel("iteration")
    plt.ylabel('')

***experiment: cross-validated search for the best base learning rate***

In [None]:
def experiment_and_plot(lr = 1e0, train_num = 1, num_epochs = 3):
    """Choose a base learning rate by cross-validation over a decade grid.

    Starting at ``lr``, ``train_num`` rates are tried, each one tenth of the previous.
    For every rate and every fold k (one per entry of ``data['valid_loader']``), a fresh
    model is trained on ``data['train_loader'][k]`` for ``num_epochs`` epochs (policy_0)
    and scored on ``data['valid_loader'][k]`` (policy_1). Per-rate fold accuracies and
    the mean accuracy per rate are plotted.

    Args:
        lr: first (largest) learning rate to try.
        train_num: number of rates to try.
        num_epochs: training epochs per fold.

    Returns:
        The rate with the highest mean fold accuracy. Ties go to the later, smaller rate.

    Raises:
        ValueError: if ``train_num`` is less than 1.
    """
    if train_num < 1:
        raise ValueError("train_num must be at least 1")

    n_folds = len(data['valid_loader'])
    lr_list = []
    av_acc_list = []
    best_av_acc = 0.0
    av_acc_max_lr = lr

    for _trial in range(train_num):
        print(f"Now the base lr is: {lr}")
        acc_list = []
        for k in range(n_folds):
            new_model, _, _ = policy_0(lr_base=lr, num_epochs=num_epochs, lambdax=lambdax, k=k, model=None)
            acc_list.append(policy_1(new_model=new_model, data=data['valid_loader'][k]))

        plot_acc(list_y=acc_list, lr=lr)
        av_acc = sum(acc_list) / n_folds
        av_acc_list.append(av_acc)

        if av_acc >= best_av_acc:
            av_acc_max_lr = lr
            best_av_acc = av_acc

        lr_list.append(lr)
        lr = lr / 10

    plot_av_acc(lr_list, av_acc_list)
    print(f"The best lr is: {av_acc_max_lr}")
    return av_acc_max_lr

***retrain with the best learning rate and evaluate on the test split***

In [None]:
def best_lr_test(lr = 1e0, n_num_epochs = 10):
    """Train on the full training split with base rate ``lr``, scoring the test split after each epoch.

    Each epoch is a separate one-epoch policy_0 call on ``data['train_loader'][4]``. It
    continues from the model returned by the previous call; the first call starts from a
    fresh copy of RES50. Afterwards the per-epoch loss/accuracy and the per-iteration loss
    are plotted.
    NOTE(review): every policy_0 call builds a new Adam optimizer and LambdaLR scheduler,
    so optimizer state and the LR schedule restart at ``lr`` each epoch.

    Args:
        lr: base learning rate.
        n_num_epochs: number of training epochs.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``n_num_epochs`` is less than 1.
    """
    if n_num_epochs < 1:
        raise ValueError("n_num_epochs must be at least 1")

    loss_list = []
    acc_list = []
    y_list = []
    fi_model = None  # None on the first pass -> policy_0 deep-copies RES50
    acc = None

    print(f"\nfor the best lr above is: {lr}")
    for _epoch in range(n_num_epochs):
        fi_model, y, loss = policy_0(lr_base=lr, num_epochs=1, lambdax=lambdax, k=4, model=fi_model)
        acc = policy_1(new_model=fi_model, data=data['test_loader'])

        loss_list.append(loss)
        acc_list.append(acc)
        y_list.extend(y)

    plot_acc_loss(loss_list, acc_list)
    plot_iteration(y_list)
    print(f"The final accuracy of this lr(= {lr}) is: {acc}")
    return fi_model

***run***

In [None]:
# Try base rates 1, 0.1, 0.01 and 0.001 (each a tenth of the previous) with cross-validation.
acc_max_lr = experiment_and_plot(
    lr=1e0,
    train_num=4,
)
# Retrain on the full training split with the chosen rate, tracking test accuracy per epoch.
final_model = best_lr_test(
    lr=acc_max_lr,
)

In [None]:
# Build submission.csv with one prediction for every image in the competition's test folder,
# rather than for a single hard-coded file name.
test_image_dir = '../input/cassava-leaf-disease-classification/test_images'
transform_test = transform['test']
image_ids = sorted(os.listdir(test_image_dir))

final_model = final_model.to(device)
# Eval mode switches off the head's dropout so predictions are deterministic.
final_model.eval()

predictions = []
with torch.no_grad():
    for start in range(0, len(image_ids), batch_size):
        batch_ids = image_ids[start:start + batch_size]
        batch_tensors = []
        for img_id in batch_ids:
            # Close each file promptly and force 3 channels for the 3-channel Normalize.
            with Image.open(os.path.join(test_image_dir, img_id)) as raw:
                batch_tensors.append(transform_test(raw.convert('RGB')).float())
        img_test = torch.stack(batch_tensors).to(device)
        output_test = final_model(img_test)
        _, y_pred = torch.max(output_test, 1)
        predictions.extend(y_pred.cpu().tolist())

cc = pd.DataFrame({'image_id': image_ids, 'label': predictions})
cc.to_csv('submission.csv', index=False)

In [None]:
#ans = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')
#print(ans)

In [None]:
#####################################################  debugging divider  ###########################################################