In [None]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from IPython import display
from IPython.display import Image
plt.rcParams.update({'figure.figsize': [8,10]})


## 100種高粱品種識別

In [None]:
import time
import glob
import os
import math
import cv2
import builtins
import copy
os.environ['TRIDENT_BACKEND'] = 'pytorch'
os.environ['TRIDENT_HOME'] = './trident'

!pip uninstall tridentx -y
!pip install ../input/trident/tridentx-0.7.5-py3-none-any.whl --upgrade

In [None]:
import trident as T
from trident import *
from trident.models import efficientnet


In [None]:
import pandas as pd
# Load the training CSV and drop every row that contains a missing value.
# The `image` and `cultivar` columns of this frame are used by the cells below.
df=pd.read_csv('../input/sorghum-id-fgvc-9/train_cultivar_mapping.csv').dropna()
# Bare expression: the notebook renders the frame as the cell output.
df

In [None]:
# Distinct cultivar names in sorted order. The position of a name in this list
# is used as its integer class label when the label list is built.
classnames=sorted(df.cultivar.unique().tolist())
print(classnames)

In [None]:
import glob
all_images=glob.glob('../input/sorghum-id-fgvc-9/train_images/*.*g')
print(len(all_images))

# Constant-time lookups: testing membership in the `all_images` list and calling
# `classnames.index(...)` for every CSV row made this loop quadratic.
available_images=set(all_images)
class_to_index={name: position for position, name in enumerate(classnames)}

images=[]
labels=[]

# Keep only the CSV rows whose image file is actually present on disk.
for image_name, cultivar in zip(df['image'], df['cultivar']):
    impath='../input/sorghum-id-fgvc-9/train_images/'+image_name
    if impath in available_images:
        images.append(impath)
        labels.append(class_to_index[cultivar])

print(len(images))
print(len(labels))
    

In [None]:
import cv2
import numpy as np
cl=CLAHE()
def _grid_tile(img, grid, tile_index, out_size=(240,240)):
    """Return cell `tile_index` (row-major) of a grid x grid partition of `img`, resized to `out_size`."""
    tile_h=img.shape[0]//grid
    tile_w=img.shape[1]//grid
    row, col=divmod(tile_index, grid)
    # Slicing yields a view; cv2.resize allocates its own output, so no copy is needed.
    tile=img[row*tile_h:(row+1)*tile_h, col*tile_w:(col+1)*tile_w, :]
    return cv2.resize(tile, out_size, interpolation=cv2.INTER_AREA)


def multi_scale_colors(img,spec=None):
    """Build a 3x3 mosaic of 240x240 views of `img` taken at several scales and color spaces.

    After CLAHE (`cl`) is applied, nine 240x240 panels are produced:
      * the whole image, resized;
      * one random cell of a 2x2 grid;
      * four random cells of a 4x4 grid (two unchanged, one converted with
        COLOR_RGB2HSV, one converted with COLOR_BGR2YCR_CB);
      * three random cells of an 8x8 grid (two unchanged, one converted with
        COLOR_RGB2HSV).
    The panels are shuffled and concatenated into a single 720x720 image.

    Grid cell sizes are derived from the image shape. For a 1024x1024 input this
    gives the 512/256/128 pixel cells that were previously hard-coded.

    Args:
        img: H x W x C image array (assumed 3-channel uint8 RGB -- TODO confirm
            against the transforms that run before this one).
        spec: unused; accepted so the function can be called like the other
            image transforms.

    Returns:
        np.ndarray of shape (720, 720, C) holding the shuffled mosaic.
    """
    img=cl(img)

    # Whole image at low resolution (default interpolation, as before).
    whole=cv2.resize(img,(240,240))

    # One quadrant of the 2x2 grid.
    quadrant=_grid_tile(img, 2, random.choice(list(range(4))))

    # Four cells of the 4x4 grid, drawn with replacement.
    idxes=random.choices(list(range(16)),k=4)
    mid_a=_grid_tile(img, 4, idxes[0])
    mid_b=_grid_tile(img, 4, idxes[1])
    mid_hsv=cv2.cvtColor(_grid_tile(img, 4, idxes[2]),cv2.COLOR_RGB2HSV)
    # NOTE(review): this uses a BGR->YCrCb flag while the HSV conversions use RGB flags.
    # If the input really is RGB, the Cr/Cb planes come out swapped. The flag is left
    # unchanged so the preprocessing stays identical to what saved checkpoints were
    # trained on -- confirm the intended channel order before changing it.
    mid_ycrcb=cv2.cvtColor(_grid_tile(img, 4, idxes[3]),cv2.COLOR_BGR2YCR_CB)

    # Three cells of the 8x8 grid, drawn with replacement.
    idxes=random.choices(list(range(64)),k=3)
    fine_a=_grid_tile(img, 8, idxes[0])
    fine_b=_grid_tile(img, 8, idxes[1])
    fine_hsv=cv2.cvtColor(_grid_tile(img, 8, idxes[2]),cv2.COLOR_RGB2HSV)

    # Same panel order as before the shuffle: whole, 2x2, 4x4 x2, 4x4 HSV, 8x8 x2, 8x8 HSV, 4x4 YCrCb.
    image_lists=[whole,quadrant,mid_a,mid_b,mid_hsv,fine_a,fine_b,fine_hsv,mid_ycrcb]
    random.shuffle(image_lists)
    rows=[np.concatenate(image_lists[start:start+3],axis=1) for start in (0,3,6)]
    new_img=np.concatenate(rows,axis=0)

    return new_img



    
# Render the mosaic for one training image as a visual sanity check.
# NOTE(review): CLAHE (`cl`) is applied here and again inside multi_scale_colors.
_preview_source=image2array('../input/sorghum-id-fgvc-9/train_images/2017-06-01__10-26-27-479.png').astype(np.uint8)
_preview_mosaic=multi_scale_colors(cl(_preview_source))
display.display(array2image(_preview_mosaic))
    

In [None]:
# Image paths and integer labels wrapped as trident datasets.
ds1=ImageDataset(images,object_type=ObjectType.rgb,symbol='images')
ds2=LabelDataset(labels,object_type=ObjectType.classification_label,symbol='labels')

# Attach the readable cultivar names to the label dataset.
ds2.binding_class_names(class_names=classnames)
print(ds2.class_names)

# Training provider: augmentation first, then the 3x3 mosaic, then normalization.
data_provider=DataProvider(
    traindata=Iterator(data=ds1,label=ds2,batch_size=4)
)

data_provider.image_transform_funcs = [
    # geometric augmentation
    RandomTransform(
        rotation_range=45,
        zoom_range=(0.9,1.2),
        shift_range=0.05,
        shear_range=0.1,
        random_flip=0.2,
        keep_prob=0.3,
        border_mode='zero',
    ),
    # photometric augmentation
    RandomAdjustGamma(gamma_range=(0.6,1.1)),
    RandomAdjustSaturation(value_range=(0.8, 1.6)),
    RandomAdjustContrast(value_range=(0.8, 1.4)),
    # multi-scale / multi-color-space mosaic defined in this notebook
    multi_scale_colors,
    AutoLevel(),
    # salt-and-pepper noise
    SaltPepperNoise(prob=0.002),
    Normalize(127.5, 127.5),
]

In [None]:
# Second provider over the same datasets, without random augmentation transforms.
# It is the one sampled by visualize() for the t-SNE plot.
data_provider2=DataProvider(
    traindata=Iterator(data=ds1,label=ds2,batch_size=3)
)

data_provider2.image_transform_funcs = [
    AutoLevel(),
    # the mosaic itself still picks its tiles at random
    multi_scale_colors,
    Normalize(127.5, 127.5),
]

In [None]:
def space_to_depth(x:np.ndarray, block_size=3, tile_size=None):
    """Split an image (or a batch of images) into block_size x block_size tiles.

    The tiles are taken in row-major order and stacked on a new leading axis, so
    a (C, H, W) input becomes (block_size**2, C, tile_h, tile_w) and a
    (B, 3, H, W) batch becomes (B, block_size**2, 3, tile_h, tile_w).

    Args:
        x: a 3-D image or a 4-D batch whose second axis has size 3. A 3-D input
            whose first axis is larger than its last axis is treated as
            channels-last and moved to channels-first before tiling.
        block_size: number of tiles along each spatial axis.
        tile_size: edge length of each tile in pixels. When None (default) it is
            derived per axis as `size // block_size`, which gives 240 for the
            720x720 mosaics used in this notebook.

    Returns:
        The stacked tiles, as produced by `stack`.

    Raises:
        ValueError: if `x` is neither a 3-D image nor a 4-D batch with 3 channels.
    """
    n_dims=len(x.shape)

    if n_dims==4 and x.shape[1]==3:
        # Forward block_size/tile_size so non-default values also apply to batches.
        per_image=[space_to_depth(x[i], block_size=block_size, tile_size=tile_size) for i in range(x.shape[0])]
        return stack(per_image,axis=0)

    if n_dims==3:
        if x.shape[0]>x.shape[-1]:
            # channels-last -> channels-first; tensors expose permute, arrays transpose
            x=x.permute(2,0,1) if hasattr(x,'permute') else x.transpose([2,0,1])
        tile_h=tile_size if tile_size is not None else x.shape[-2]//block_size
        tile_w=tile_size if tile_size is not None else x.shape[-1]//block_size
        tiles=[]
        for i in range(block_size*block_size):
            row, col=divmod(i, block_size)
            tiles.append(x[:, row*tile_h:(row+1)*tile_h, col*tile_w:(col+1)*tile_w])
        return stack(tiles,axis=0)

    raise ValueError('space_to_depth expects a 3-D image or a 4-D batch with 3 channels, got shape {0}'.format(tuple(x.shape)))
        


# Sanity check of the tiling: build one mosaic, split it back into tiles and show the first tile.
# NOTE(review): CLAHE (`cl`) is applied here and again inside multi_scale_colors.
arr=multi_scale_colors(cl(image2array('../input/sorghum-id-fgvc-9/train_images/2017-06-01__10-26-27-479.png').astype(np.uint8)))
print(arr.shape)
# `arr` is rebound from the mosaic image to the stacked tiles returned by space_to_depth.
arr=space_to_depth(to_tensor(image_backend_adaption(arr)), block_size=3)
print(arr.shape)
# NOTE(review): the two disabled lines below reference `arr1`/`arr0` inconsistently and would not run as written.
# arr1=space_to_depth(to_tensor(image_backend_adaption(arr)), block_size=3)
# print(arr0.shape)
# Tile 0, moved from channels-first to channels-last for display.
display.display(array2image(to_numpy(arr)[0].transpose([1,2,0]).astype(np.uint8)))


In [None]:
# Preview images from the training data provider (presumably after its transform pipeline -- TODO confirm).
data_provider.preview_images()

In [None]:
#_images,_labels=data_provider.next()

In [None]:
import torch
import gc
# Collect unreachable Python objects first, so that the CUDA blocks they hold are
# returned to the caching allocator before the cache is emptied.
gc.collect()
# The CUDA calls raise on a host without a usable GPU, so they are guarded to
# keep "Restart & Run All" working on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

In [None]:
class TanhExp(Layer):
    """Activation layer computing x * tanh(exp(x)), with the result clipped to [-2, 2]."""

    def __init__(self,keep_output=False, name=None):
        super().__init__(keep_output=keep_output,name=name)
        # Marked as built right away in the constructor.
        self._built = True

    def forward(self, x, **kwargs):
        """Apply the activation element-wise to `x`; extra keyword arguments are ignored."""
        gate = torch.tanh(torch.exp(x))
        activated = x * gate
        return clip(activated, -2, 2)
    

In [None]:


    
class SorghumNetV1(Layer):
    """EfficientNet-B1 classifier over the nine tiles of a 3x3 mosaic, with mean aggregation.

    The input is split into tiles by `space_to_depth`, every tile goes through the
    shared backbone, the per-tile feature maps are averaged and a softmax dense
    layer produces the class scores.
    """

    def __init__(self, n_classes=100):
        super().__init__()

        pretrained=efficientnet.EfficientNetB1(pretrained=True,input_shape=(3,240,240),include_top=False)
        # Mark the top conv and the last block as trainable, and swap the top activation.
        pretrained.model.top_conv.trainable=True
        pretrained.model.block7b.trainable=True
        pretrained.model.block7b.dropout_rate=0.2
        pretrained.model.top_conv.activation=TanhExp()

        self.n_classes=n_classes
        self.backbone =pretrained.model
        self.agg=Sequential(
            # regroup the flattened tiles as (tiles, channels, h, w) per sample
            Reshape((9,1280, 8, 8)),
            # average over the tile axis
            Aggregation(mode='mean',axis=1,keepdims=False),
            SeparableConv2d_Block(
                (3,3),
                depth_multiplier=1,
                strides=1,
                auto_pad=True,
                use_bias=False,
                activation=TanhExp(),
                normalization='bn',
            ),
            GlobalAvgPool2d(),
        )

        self.decoder=Dense(n_classes,activation=SoftMax())

    def forward(self, x):
        """Tile `x`, run all tiles through the backbone as one batch, aggregate and classify."""
        tiles=space_to_depth(x, block_size=3)
        batch,n_tiles,channels,height,width=tiles.shape
        # fold the tile axis into the batch axis
        flat_tiles=tiles.reshape((batch*n_tiles,channels,height,width))

        tile_features=self.backbone(flat_tiles)
        pooled=self.agg(tile_features)
        return self.decoder(pooled)
    
    
    
class SorghumNetV2(Layer):
    """EfficientNet-B1 classifier over the nine tiles of a 3x3 mosaic, with max aggregation.

    Compared with SorghumNetV1, the whole backbone is marked trainable, tiles are
    combined with a max instead of a mean, a channel-gating branch (global pooling
    followed by two 1x1 convolutions ending in a sigmoid) is combined with the
    feature map through a `ShortCut(..., mode='dot')`, and the decoder is a dense
    layer with `weight_norm='l2'` and no activation.
    """

    def __init__(self, n_classes=100):
        super().__init__()

        pretrained=efficientnet.EfficientNetB1(pretrained=True,input_shape=(3,240,240),include_top=False)
        pretrained.model.trainable=True
        pretrained.model.block7b.dropout_rate=0.2
        pretrained.model.top_conv.activation=TanhExp()

        self.n_classes=n_classes
        self.backbone =pretrained.model
        self.agg=Sequential(
            # regroup the flattened tiles as (tiles, channels, h, w) per sample
            Reshape((9,1280, 8, 8)),
            # element-wise max over the tile axis
            Aggregation(mode='max',axis=1,keepdims=False),
            SeparableConv2d_Block(
                (3,3),
                depth_multiplier=1,
                strides=1,
                auto_pad=True,
                use_bias=False,
                activation=TanhExp(),
                normalization='bn',
            ),
            ShortCut(
                Identity(),
                # gating branch: pool, then two 1x1 convolutions ending in a sigmoid
                Sequential(
                    GlobalAvgPool2d(),
                    Reshape((1280,1,1)),
                    Conv2d((1,1),num_filters=100,use_bias=False,activation=TanhExp()),
                    Conv2d((1,1),num_filters=1280,use_bias=False,activation=Sigmoid())
                ),
                mode='dot'
            ),
            GlobalAvgPool2d(),
        )

        self.decoder=Dense(n_classes,weight_norm='l2')

    def forward(self, x):
        """Tile `x`, run all tiles through the backbone as one batch, aggregate and classify."""
        tiles=space_to_depth(x, block_size=3)
        batch,n_tiles,channels,height,width=tiles.shape
        # fold the tile axis into the batch axis
        flat_tiles=tiles.reshape((batch*n_tiles,channels,height,width))

        tile_features=self.backbone(flat_tiles)
        pooled=self.agg(tile_features)
        return self.decoder(pooled)
    

In [None]:
# Wrap SorghumNetV1 (100 classes) in a trident Model that takes one 3x720x720 mosaic per sample.
sorghumnet_v1=Model(input_shape=(3,720,720),output=SorghumNetV1(100))
# Disabled here: the checkpoint is loaded later, in the training-configuration cell.
#sorghumnet_v1.load_model('../input/sorghum-100-identification/Models/sorghumnet_v1_b1.pth')
#sorghumnet_v1.load_model('./Models/sorghumnet_v1_b1.pth')
sorghumnet_v1.summary()

In [None]:
# Wrap SorghumNetV2 (100 classes) in a trident Model that takes one 3x720x720 mosaic per sample.
sorghumnet_v2=Model(input_shape=(3,720,720),output=SorghumNetV2(100))
# Disabled here: the checkpoint is loaded later, in the training-configuration cell.
#sorghumnet_v2.load_model('../input/sorghum-100-identification/Models/sorghumnet_v2_b1.pth')
#sorghumnet_v2.load_model('./Models/sorghumnet_v2_b1.pth')
# Mark the whole wrapped model as trainable before printing the summary.
sorghumnet_v2.trainable=True
sorghumnet_v2.summary()

In [None]:
# cxt=get_session()
# def get_features(training_context):
#     model=training_context['current_model']
#     data=training_context['train_data']
#     data['features']=model.avg_pool.output
    
#     steps=training_context['steps']
#     if steps>0 and steps%20==0: 
#         if hasattr(cxt,'center_loss_fn'):
#             state_dict=OrderedDict()
#             state_dict['centers']=cxt.center_loss_fn.centers.data
#             with open('./Models/centers.pth', 'wb') as f:
#                 save(state_dict, f)
    
    
# cxt.center_loss_fn=CenterLoss(num_classes=100, feat_dim=1536, reduction="mean")    
# if os.path.exists('./Models/centers.pth'):
#     state_dict=load('./Models/centers.pth')
#     cxt.center_loss_fn.centers.data.copy_(state_dict['centers'].to(get_device()))
    

In [None]:
class ArcMarginProductLoss(Layer):
    """ArcFace-style additive angular margin applied to the model output, followed by cross entropy.

    The model output is L2-normalized and each entry is treated as cos(theta).
    For the target class the entry is replaced by cos(theta + margin); all
    entries are then multiplied by `scale` and passed to `CrossEntropyLoss`.

    Args:
        scale: factor applied to the margin-adjusted cosines before cross entropy.
        margin: angular margin m, in radians, added to the target-class angle.
        easy_margin: if True, the margin is only applied where cos(theta) > 0;
            otherwise it is applied where cos(theta) > cos(pi - m) and
            cos(theta) - sin(pi - m) * m is used elsewhere.
        num_filters: stored on the instance; not used by `forward`.
        name: stored as the layer name.
    """

    def __init__(self, scale=32.0, margin=0.50, easy_margin=False, num_filters= 100,name='ArcMarginProductLoss'):
        super(ArcMarginProductLoss, self).__init__()
        self._name=name
        self.num_filters=num_filters
        self.scale = scale
        self.m = margin
        self.easy_margin = easy_margin
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)

        # make the function cos(theta+m) monotonic decreasing while theta in [0°,180°]
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin
        self.base_loss=CrossEntropyLoss(reduction='mean')


    def forward(self,output, target,**kwargs):
        """Return the mean margin-adjusted cross-entropy loss.

        Args:
            output: model output, one row per sample and one column per class.
            target: integer class indices, one per sample.

        If the margin computation raises, the exception is printed and the loss
        falls back to plain cross entropy on the unmodified `output`.
        """
        try:
            # cos(theta)
            cosine=l2_normalize(output)

            # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
            # The clip keeps rounding error from making 1 - cos^2 slightly
            # negative, which would turn the square root into NaN.
            sine = sqrt(clip(1.0 - pow(cosine, 2), 0.0, 1.0))
            phi = cosine * self.cos_m - sine * self.sin_m

            if self.easy_margin:
                phi = where(cosine > 0, phi, cosine)
            else:
                phi = where((cosine - self.th) > 0, phi, cosine - self.mm)

            # One-hot mask of the target class. `scatter` is out-of-place, so its
            # result must be kept: discarding it leaves the mask all zeros and the
            # margin is never applied. The mask is a constant and needs no gradient.
            one_hot = zeros_like(cosine,requires_grad=False)
            one_hot = one_hot.scatter(1, target.view(-1, 1).long(), 1)

            # phi for the target class, plain cosine everywhere else
            output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
            output = output * self.scale
        except Exception as e:
            print(e)
            PrintException()

        loss = self.base_loss(output, target)
        return loss.mean()

In [None]:
from sklearn import manifold
from tqdm import  tqdm

 
# Visualize the ArcFace embedding space
def visualize(training_context):
    """Plot a t-SNE projection of the model outputs and print the accuracy on a small sample.

    Used as a trident trigger action. From the second epoch on (epoch > 0) it
    draws 50 batches from `data_provider2`, runs the current model on them,
    prints the accuracy on those batches, projects the L2-normalized output
    vectors to 2-D with t-SNE and saves the scatter plot to
    'Results/epoch<N>.jpg'.

    Args:
        training_context: mapping providing 'current_model' and 'current_epoch'.
    """
    model=training_context['current_model']
    epoch=training_context['current_epoch']
    # Return before touching the model: switching to eval mode here and leaving
    # early would let the first epoch train with the model in eval mode.
    if epoch<=0:
        return

    # Disabled progressive-unfreezing schedule, kept for reference:
#         if epoch==1:
#             model.block6e.trainable=True
#             model.block6d.trainable=True
#             model.block6c.trainable=True
#             model.block6b.trainable=True
#         elif epoch==11:
#             model.block6a.trainable=True
#         elif epoch==13:
#             model.block5d.trainable=True
#         elif epoch==15:
#             model.block5c.trainable=True

    features=[]
    labels=[]
    preds=[]
    NUM_COLORS = 100
    cm = plt.get_cmap('gist_rainbow')

    model.eval()
    try:
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            for i in tqdm(range(50)):
                _images,_labels=data_provider2.next()
                # Keep the raw output: it is both the embedding that the ArcFace loss
                # normalizes and the source of the predicted class.
                _outputs=to_numpy(model(to_tensor(_images)))
                _preds=np.argmax(_outputs,axis=1)

                for k in range(len(_images)):
                    _row=_outputs[k]
                    features.append(_row/max(np.linalg.norm(_row),1e-12))
                    labels.append(_labels[k])
                    preds.append(_preds[k])
    finally:
        # Always hand the model back in training mode, even if inference failed.
        model.train()

    print('features',len(features),'labels',len(labels))
    labels=np.array(labels)
    features=np.array(features)
    preds=np.array(preds)

    print('accuracy:{0:.3%}'.format(np.equal(preds,labels).astype(np.float32).mean()))

    # Reduce the feature vectors to 2 dimensions with t-SNE and draw a scatter plot.
    fig = plt.figure(figsize=(12,12))
    ax1= fig.add_subplot(1, 1, 1)
    tsne2 = manifold.TSNE(n_components=2, init='pca', random_state=0)
    print('tsne 訓練開始')
    features_tsne2 = tsne2.fit_transform(features)
    print('tsne 訓練結束')
    for i in range(NUM_COLORS):
        x_i = features_tsne2[:,0][labels==i]
        y_i = features_tsne2[:,1][labels==i]
        # `color=` instead of `c=`: a single RGBA tuple passed as `c` is ambiguous to matplotlib.
        ax1.scatter(x_i,y_i,s=20,marker='o',color=cm(i//3*3.0/NUM_COLORS))

    plt.legend(classnames, loc = 'upper right')
    plt.title('epoch {0}'.format(epoch))
    # savefig does not create missing directories.
    os.makedirs('Results', exist_ok=True)
    plt.savefig('Results/epoch{0}.jpg'.format(epoch), bbox_inches='tight')
    plt.show()
    # One figure is created per epoch; close it so figures do not accumulate in memory.
    plt.close(fig)


In [None]:
#visualize(effb2.training_context)

優化方向:

- 彙總前加上 PReLU
- 加入 DropPath

In [None]:
def _load_checkpoint_if_present(model, path):
    """Load weights from `path` into `model` when the file exists; otherwise report and continue."""
    if os.path.exists(path):
        model.load_model(path)
    else:
        print('checkpoint not found, keeping current weights: {0}'.format(path))


def _configure_training(model, save_path, **optimizer_kwargs):
    """Attach the optimizer, loss, metrics, regularizer, save path, trigger and LR schedule shared by both models.

    Args:
        model: the trident Model to configure.
        save_path: where the model is saved during training.
        **optimizer_kwargs: extra keyword arguments forwarded to `with_optimizer`.

    Returns:
        Whatever the final call of the fluent chain returns.
    """
    return model.with_optimizer(optimizer=Adam,lr=1e-3,betas=(0.9, 0.999),**optimizer_kwargs)\
        .with_loss(ArcMarginProductLoss(scale=32.0, margin=0.50, easy_margin=True, num_filters=100))\
        .with_metric(accuracy)\
        .with_metric(accuracy,topk=3,name='top3_accuracy')\
        .with_regularizer('l2')\
        .with_model_save_path(save_path)\
        .trigger_when(when='on_epoch_start', frequency=1, unit='epoch', action=visualize)\
        .with_learning_rate_scheduler(CosineLR(min_lr=1e-5,period=1000))\
        .with_automatic_mixed_precision_training()


# Resume from earlier checkpoints when they exist, so that a fresh run without
# './Models/*.pth' does not stop here.
_load_checkpoint_if_present(sorghumnet_v1,'./Models/sorghumnet_v1_b1.pth')
_load_checkpoint_if_present(sorghumnet_v2,'./Models/sorghumnet_v2_b1.pth')
#sorghumnet_v2.model.load_state_dict(sorghumnet_v1.model.state_dict(),False)


# Earlier configuration, kept for reference:
#.with_loss(ArcMarginProductLoss(scale=32.0, margin=0.50, easy_margin=True, num_filters=100))\
# sorghumnet.with_optimizer(optimizer=DiffGrad,lr=1e-3,betas=(0.9, 0.999),gradient_centralization='all')\
#     .with_loss(CrossEntropyLoss(input_names=['classifier','labels']))\
#     .with_loss(ArcMarginProductLoss(scale=32.0, margin=0.50, easy_margin=True, num_filters=100))\
#     .with_metric(accuracy)\
#     .with_metric(accuracy,topk=3,name='top3_accuracy')\
#     .with_regularizer('l2') \
#     .with_model_save_path('./Models/sorghumnet_b1.pth')\
#     .trigger_when(when='on_epoch_start', frequency=1, unit='epoch', action=visualize)\
#     .with_learning_rate_scheduler(CosineLR(min_lr=1e-5,period=1000))\
#     .with_accumulate_grads(10)\
#     .with_automatic_mixed_precision_training()
    #.with_callbacks(CutMixCallback(alpha=1,loss_criterion=CrossEntropyLoss,save_path='Results',loss_weight=0.1))\

# The two models share every setting except the save path and v2's gradient centralization.
_configure_training(sorghumnet_v1,'./Models/sorghumnet_v1_b1.pth')

_configure_training(sorghumnet_v2,'./Models/sorghumnet_v2_b1.pth',gradient_centralization='all')

In [None]:

# Train both models from the same data provider under one plan.
plan=(
    TrainingPlan()
    .add_training_item(sorghumnet_v1)
    .add_training_item(sorghumnet_v2)
    .with_data_loader(data_provider)
    .repeat_epochs(100)
    .with_batch_size(12)
    # reporting, evaluation, curve display and checkpoint cadence, all counted in batches
    .print_progress_scheduling(5,unit='batch')
    .out_sample_evaluation_scheduling(frequency=50,unit='batch')
    .display_loss_metric_curve_scheduling(frequency=100,unit='batch',imshow=True)
    .save_model_scheduling(10,unit='batch')
)


plan.start_now()