In [None]:
!pip install ../input/easydict/easydict-1.9-py2.py3-none-any.whl


In [None]:
from sklearn.model_selection import GroupKFold, StratifiedKFold
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import random

import os
import sys


sys.path.insert(1, '../input/snapmix/')
import glob
from utils import load_checkpoint,get_train_setting
from trainer.comm_test import validate
from trainer.comm_train import train
import networks.resnet_ft as resnet_ft
from easydict import EasyDict as edict 
import torch
import torch.nn as nn
from datasets.cassava import ImageLoader
from datasets.tfs import get_cassava_transform
from torch.utils import data
import time
from tqdm import tqdm
import copy
from collections import Counter

def set_env(seed=0):
    """Seed every random number generator this notebook touches.

    Args:
        seed: integer seed handed unchanged to each seeding function.
    """
    seeders = (
        random.seed,                 # Python stdlib RNG
        np.random.seed,              # NumPy global RNG
        torch.manual_seed,           # PyTorch RNG
        torch.cuda.manual_seed,      # PyTorch CUDA RNG, current device
        torch.cuda.manual_seed_all,  # PyTorch CUDA RNG, all devices
    )
    for seeder in seeders:
        seeder(seed)
    
def predict(model,testloader,midlevel=False):
    """Return the predicted class index for every sample served by `testloader`.

    The model is called as `model(batch)` and must return a 3-tuple; the first
    element (`output`) and the third (`moutput`) are used, the second is ignored.

    NOTE: this function reads the notebook-level global `conf.tta`. When it is
    not None, each batch is unpacked as (batch, ncrops, c, h, w), the crops are
    flattened into the batch dimension for the forward pass, and the outputs
    are averaged over the crops.

    Args:
        model: network to evaluate; it is switched to eval mode here.
        testloader: iterable with a length, yielding `(images, target)` pairs;
            the target is ignored.
        midlevel: when True the prediction is the argmax of
            `output + moutput`, otherwise the argmax of `output` alone.

    Returns:
        1-D tensor of predicted class indices in loader order. An empty
        int64 CPU tensor is returned when the loader yields no batches.
    """
    model.eval()
    pres = []
    # Inference only: skip autograd bookkeeping even if the caller did not
    # wrap the call in torch.no_grad() (nesting no_grad contexts is harmless).
    with torch.no_grad():
        for images, _ in tqdm(testloader, dynamic_ncols=True, total=len(testloader)):
            images = images.cuda()

            if conf.tta is None:
                output, _, moutput = model(images)
            else:
                bs, ncrops, c, h, w = images.size()
                output, _, moutput = model(images.view(-1, c, h, w))
                # Average the per-crop outputs back to one row per sample.
                output = output.view(bs, ncrops, -1).mean(1)
                moutput = moutput.view(bs, ncrops, -1).mean(1)

            foutput = output + moutput if midlevel else output
            pres.append(torch.argmax(foutput, dim=1))

    if not pres:
        # torch.cat raises on an empty list; give back an empty result instead.
        return torch.empty(0, dtype=torch.long)
    return torch.cat(pres)

def get_dataset(conf):
    """Build the train / validation / test datasets for the cassava data.

    Args:
        conf: mapping-like config with attribute access. Keys read here:
            `datadir` (optional, defaults to 'data/cassava'), `foldid`
            (optional) and `tta` (required). It is also passed to
            `get_cassava_transform`. Although `datadir` is looked up
            defensively, `conf` itself must not be None because `foldid`
            and `tta` are read unconditionally.

    Returns:
        Tuple `(ds_train, ds_val, ds_test, testpd)`. `testpd` is a DataFrame
        with one row per `*.jpg` found in `<datadir>/test_images`, with
        columns `image_id` and a placeholder `label` of 0.

    When `foldid` is absent, `None` is passed as `pdata` for the train and
    validation datasets; how `ImageLoader` handles that is not visible here.
    """
    datadir = 'data/cassava'
    if conf and 'datadir' in conf:
        datadir = conf.datadir

    trainpd, valpd = None, None

    testimgdir = datadir + '/test_images'
    # glob order depends on the filesystem; sort so the row order of the
    # returned frame (and of anything written from it) is reproducible.
    testfile = sorted(os.path.basename(fn) for fn in glob.glob(testimgdir + '/*.jpg'))
    testpd = pd.DataFrame(testfile, columns=['image_id'])
    testpd['label'] = 0

    if 'foldid' in conf:
        traindata = pd.read_csv(datadir + '/train.csv')
        folds = StratifiedKFold(n_splits=5).split(np.arange(traindata.shape[0]), traindata.label.values)
        trainidx, validx = list(folds)[conf.foldid]
        # split() yields integer positions, so select rows positionally.
        trainpd = traindata.iloc[trainidx].reset_index(drop=True)
        valpd = traindata.iloc[validx].reset_index(drop=True)

    imgdir = datadir + '/train_images'
    transform_train, transform_test = get_cassava_transform(conf)
    ds_train = ImageLoader(imgdir, train=True, transform=transform_train, pdata=trainpd)

    # tta of None or 2 evaluates with the test transform; any other tta value
    # evaluates with the train transform.
    if conf.tta is None or conf.tta == 2:
        eval_transform = transform_test
    else:
        eval_transform = transform_train

    ds_val = ImageLoader(imgdir, train=False, transform=eval_transform, pdata=valpd, tta=conf.tta)
    ds_test = ImageLoader(testimgdir, train=False, transform=eval_transform, pdata=testpd, tta=conf.tta)

    return ds_train, ds_val, ds_test, testpd

def get_params(model,conf=None):
    """Return the parameters (or per-group parameter dicts) for an optimizer.

    Args:
        model: network, either bare or wrapped in a container that exposes the
            real network as `.module` (e.g. `nn.DataParallel`). When parameter
            groups are requested, the underlying network must provide
            `get_params(name)`.
        conf: optional mapping-like config with attribute access. When it
            contains `prams_group`, `conf.prams_group` and `conf.lr_group`
            are paired element-wise.

    Returns:
        A list of `{'params': ..., 'lr': ...}` dicts, one per entry of
        `conf.prams_group`, when groups are configured; otherwise
        `model.parameters()`.

    Raises:
        ValueError: if `prams_group` and `lr_group` differ in length.
    """
    if conf is not None and 'prams_group' in conf:
        prams_group = conf.prams_group
        lr_group = conf.lr_group
        if len(prams_group) != len(lr_group):
            # zip() would silently drop the unmatched entries, leaving some
            # parameters out of the optimizer without any warning.
            raise ValueError(
                'prams_group and lr_group must have the same length, got {} and {}'.format(
                    len(prams_group), len(lr_group)))

        # Unwrap a DataParallel-style container if present, else use the model itself.
        net = getattr(model, 'module', model)
        return [{'params': net.get_params(pram), 'lr': lr}
                for pram, lr in zip(prams_group, lr_group)]

    return model.parameters()

 

In [None]:
# Seed all RNGs before any dataset, loader or model is created.
set_env(seed=0)
# Which of the 5 stratified folds is held out for validation (index into the fold list).
foldid = 2

# Experiment configuration.
# Read directly in this notebook: datadir, tta, foldid, midlevel, netname,
# net_type, prams_group, lr_group, lrstep, lrgamma, lr, momentum,
# weight_decay, epochs.
# The remaining keys are not read here; presumably they are consumed by the
# imported snapmix modules that receive `conf` (get_net, get_cassava_transform,
# train, validate) -- verify there before changing them.
# The original literal listed 'pretrained' twice with the same value; the
# duplicate entry has been removed.
conf = edict({
    'depth':50,
    'pretrained':True,
    'num_class':5,
    'midlevel':True,
    'datadir':'../input/cassava-leaf-disease-classification',
    'dataset':'cassava',
    'testing':False,
    'tta': 2,
    'foldid':foldid,
    'cropsize':448,
    'netname':'resnet50',
    'net_type':'resnet_ft',
    'prams_group':['ftlayer','freshlayer'],
    'lr_group':[0.001,0.01],
    'lrstep':[20],
    'lr': 0.001,
    'epochs':30,
    'lrgamma':0.1,
    'criterion':'CrossEntropyLoss',
    'reduction':'none',
    'momentum':0.9,
    'weight_decay':1e-4,
    'mixmethod': 'snapmix',
    'prob': 1,
    'beta': 5}
   )



# Datasets and loaders. Only the training loader shuffles; validation and test
# keep dataset order so predictions line up with the rows of `testpd`.
ds_train, ds_val, ds_test, testpd = get_dataset(conf)
train_loader = data.DataLoader(ds_train, batch_size=16, shuffle=True, num_workers=8, pin_memory=True)
val_loader = data.DataLoader(ds_val, batch_size=32, shuffle=False, num_workers=8, pin_memory=True)
test_loader = data.DataLoader(ds_test, batch_size=32, shuffle=False, num_workers=16, pin_memory=True)

# Resolve the network module through an explicit table rather than eval() on a
# config string, so an unexpected `net_type` fails loudly and nothing arbitrary
# is evaluated.
net_modules = {'resnet_ft': resnet_ft}
if conf.net_type not in net_modules:
    raise ValueError('unknown net_type: {!r}'.format(conf.net_type))
model = net_modules[conf.net_type].get_net(conf)
model = nn.DataParallel(model).cuda()

# SGD with Nesterov momentum; per-group learning rates come from get_params().
optimizer = torch.optim.SGD(get_params(model, conf), conf.lr, momentum=conf.momentum,
                            weight_decay=conf.weight_decay, nesterov=True)
# Unreduced (per-sample) loss values are handed to train()/validate().
criterion = nn.CrossEntropyLoss(reduction='none').cuda()

# Multiply the learning rate by `lrgamma` at each epoch listed in `lrstep`.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=conf.lrstep, gamma=conf.lrgamma, last_epoch=-1)

# Checkpoint file name: '_mf' marks a mid-level run, '_f' a plain run.
if conf.midlevel:
    outfile = './'+conf.netname+'_mf'+str(foldid)+'.pt'
else:
    outfile = './'+conf.netname+'_f'+str(foldid)+'.pt'

# A checkpoint with the same file name under ../input/resfold2/ takes
# precedence: when it exists, training is skipped and that file is used.
pretrained_ckpt = '../input/resfold2/' + outfile

if not os.path.isfile(pretrained_ckpt):
    # Start below any achievable score so the first epoch always writes a
    # checkpoint; load_checkpoint() after this block needs `outfile` to exist.
    best_score = float('-inf')
    ## ------main loop-----
    for epoch in range(conf.epochs):
        lr = optimizer.param_groups[0]['lr']
        print("Epoch: [{} | {}] LR: {}".format(epoch+1, conf.epochs, lr))
        tmp_loss = train(train_loader, model, criterion, optimizer, conf)
        scheduler.step()
        # Plain strings: the original wrapped these in {...}, which built a
        # one-element set and printed its repr.
        infostr = 'Epoch:  {}   train_loss: {}'.format(epoch+1, tmp_loss)
        print(infostr)

        with torch.no_grad():
            val_score, val_loss, mscore, ascore = validate(val_loader, model, criterion, conf)

        # Model selection uses `ascore` for mid-level runs, `val_score` otherwise.
        comscore = ascore if conf.midlevel else val_score
        is_best = comscore > best_score
        best_score = max(comscore, best_score)
        infostr = 'Epoch:  {}   loss: {:.4f},gs: {:.4f},bs:{:.4f},ms: {:.4f},as:{:.4f}'.format(
            epoch+1, val_loss, val_score, best_score, mscore, ascore)
        print(infostr)

        if is_best:
            # Save the unwrapped network's weights (not the DataParallel wrapper).
            mdict = {'state_dict': model.module.state_dict(), 'epoch': epoch}
            torch.save(mdict, outfile)
else:
    outfile = pretrained_ckpt

# Restore the chosen checkpoint: either the best one written by the training
# loop or the pre-existing file under ../input/resfold2/.
load_checkpoint(model,outfile)

# Predict the test set without tracking gradients, then move the result to host memory.
with torch.no_grad():
    test_pred = predict(model, test_loader, conf.midlevel)
pres = test_pred.cpu().numpy()

# Replace the placeholder labels and write the submission file.
testpd['label'] = pres
testpd.to_csv('submission.csv', index=False)
print(testpd.head())
                    
                    