In [None]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from IPython import display

In [None]:
!pip  uninstall tridentx -y 
!pip install ../input/trident/tridentx-0.7.5-py3-none-any.whl --upgrade


In [None]:
import os
# Select trident's PyTorch backend. This is set before `import trident`,
# presumably because the backend is read at import time — keep this order.
os.environ['TRIDENT_BACKEND']='pytorch'
# NOTE(review): the alias `T` is not used in the cells shown here.
import trident as T
# The star import is presumably what provides the unqualified names used below
# (ImageDataset, LabelDataset, DataProvider, Iterator, the Random*/Normalize
# transforms, DiffGrad, CrossEntropyLoss, accuracy, CosineLR, TrainingPlan).
from trident import *

In [None]:
import pandas as pd

# Training label table: one row per training image, with the image file name in
# column 'image' and its cultivar name in column 'cultivar'. Rows containing any
# missing value are discarded.
CULTIVAR_MAPPING_CSV = '../input/sorghum-id-fgvc-9/train_cultivar_mapping.csv'
df = pd.read_csv(CULTIVAR_MAPPING_CSV)
df = df.dropna()
df

In [None]:
# Alphabetically sorted distinct cultivar names; an image's integer label is the
# position of its cultivar in this list.
classnames = sorted(df['cultivar'].unique().tolist())
print(classnames)

In [None]:
import glob

TRAIN_IMAGE_DIR = '../input/sorghum-id-fgvc-9/train_images/'

# Every file actually present in the training-image folder. The '*.*g' pattern
# matches names whose extension ends in 'g' (e.g. .png, .jpg, .jpeg).
all_images = glob.glob(TRAIN_IMAGE_DIR + '*.*g')
print(len(all_images))

# A set gives O(1) membership tests. Checking against the list made the loop
# below O(rows x files).
available_images = set(all_images)
# Cultivar name -> integer label, built once instead of a linear
# classnames.index(...) search for every row. The ids are identical because
# classnames holds unique values.
class_to_index = {name: idx for idx, name in enumerate(classnames)}

images = []
labels = []

# Keep only mapping rows whose image file exists on disk, so `images[i]` and
# `labels[i]` always describe the same sample. The path is built by plain string
# concatenation so it matches glob's output for the same directory prefix exactly.
for image_name, cultivar in zip(df['image'], df['cultivar']):
    impath = TRAIN_IMAGE_DIR + image_name
    if impath in available_images:
        images.append(impath)
        labels.append(class_to_index[cultivar])

print(len(images))
print(len(labels))

In [None]:
# Wrap the aligned image paths / integer labels as trident datasets and pair
# them into a single training iterator (no validation/test data is supplied).
ds1=ImageDataset(images,symbol='images')
ds2=LabelDataset(labels,symbol='labels')
data_provider=DataProvider(traindata=Iterator(data=ds1,label=ds2))

# Augmentation / preprocessing applied to training images, in list order.
data_provider.image_transform_funcs = [
    # Random geometric augmentation (rotation, zoom, shift, shear, flip).
    # NOTE(review): keep_prob=0.3 presumably controls how often the input is
    # left untransformed — confirm against trident's RandomTransform docs.
    RandomTransform(rotation_range=45, zoom_range=(0.9,1.2), shift_range=0.05, shear_range=0.1, random_flip=0.2,keep_prob=0.3,border_mode='zero'), 
    # Random photometric jitter: gamma, saturation, contrast.
    RandomAdjustGamma(gamma_range=(0.6,1.1)),
    RandomAdjustSaturation(value_range=(0.8, 1.6)),
    RandomAdjustContrast(value_range=(0.8, 1.4)),
    AutoLevel(),
    SaltPepperNoise(prob=0.002),  # salt-and-pepper noise
    # Presumably (x - 127.5) / 127.5, i.e. 0..255 pixels -> roughly [-1, 1].
    # TODO confirm this matches the input range the pretrained EfficientNet expects.
    Normalize(127.5, 127.5)]

In [None]:
# Visual sanity check: show sample training images, presumably after the
# transform pipeline configured above has been applied.
data_provider.preview_images()

In [None]:
from trident.models import efficientnet

# EfficientNet-B0 with pretrained weights. The classifier size is derived from
# the cultivar list (the label ids come from classnames), rather than the
# previously hard-coded 100, so head size and labels cannot drift apart.
# NOTE(review): confirm trident re-initialises the classification head when
# `classes` differs from the pretrained checkpoint's class count.
effb0 = efficientnet.EfficientNetB0(pretrained=True, classes=len(classnames))
effb0.summary()


In [None]:


# Training configuration for the model (fluent builder chain):
# - DiffGrad optimizer, initial learning rate 1e-3, cross-entropy loss;
# - track top-1 and top-3 accuracy;
# - L2 weight regularization with weight 1e-5;
# - checkpoints written to Models/effb0_1.pth;
# - cosine learning-rate schedule decaying to 1e-5 with period 3000
#   (presumably measured in batches — confirm in trident's CosineLR);
# - unfreeze module 'top_conv' at epoch 1 and 'block7a' at epoch 2. This
#   presumably implies the pretrained backbone starts partially frozen — verify;
# - automatic mixed-precision training.
effb0.with_optimizer(optimizer=DiffGrad, lr=1e-3)\
    .with_loss(CrossEntropyLoss) \
    .with_metric(accuracy, name='accuracy')\
    .with_metric(accuracy,topk=3, name='top3 accuracy')\
    .with_regularizer('l2', reg_weight=1e-5)\
    .with_model_save_path('Models/effb0_1.pth')\
    .with_learning_rate_scheduler(CosineLR(min_lr=1e-5,period=3000))\
    .unfreeze_model_scheduling(1,'epoch',module_name='top_conv')\
    .unfreeze_model_scheduling(2,'epoch',module_name='block7a')\
    .with_automatic_mixed_precision_training()



In [None]:
# Training plan: 10 epochs over data_provider with batch size 32, training the
# configured effb0 model. Periodic actions (in batches):
#   every 10  -> print progress
#   every 50  -> save the model
#   every 100 -> out-of-sample evaluation and gradient printout
#   every 200 -> plot loss/metric curves
# NOTE(review): data_provider was built with traindata only, so it is unclear
# what out-of-sample evaluation runs on — verify, or supply a validation split.
plan=TrainingPlan() \
    .add_training_item(effb0) \
    .with_data_loader(data_provider)\
    .repeat_epochs(10)\
    .with_batch_size(32)\
    .out_sample_evaluation_scheduling(frequency=100,unit='batch')\
    .print_gradients_scheduling(frequency=100,unit='batch')\
    .print_progress_scheduling(10,unit='batch') \
    .display_loss_metric_curve_scheduling(frequency=200, unit='batch', imshow=True) \
    .save_model_scheduling(50,unit='batch')

In [None]:
# Launch the training plan defined above (long-running cell).
plan.start_now()