In [1]:
WANDB=True

In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
import wandb
from fastai.vision.all import *
from self_supervised.augmentations import *
from self_supervised.layers import *
from self_supervised.vision.swav import *
from fastai.callback.wandb import WandbCallback
import random

from sklearn.model_selection import StratifiedKFold
from sklearn.utils import shuffle

In [4]:
path = Path("/home/mu/.fastai/data/rice-disease-classification/")
image_path = path/"images"
files = get_image_files(image_path)

In [5]:
train_df = pd.read_csv(path/"Train.csv")
test_df = pd.read_csv(path/"Test.csv")
sample_df = pd.read_csv(path/"SampleSubmission.csv")

In [6]:
train_df.Label.unique()

array(['blast', 'brown', 'healthy'], dtype=object)

In [7]:
test_df.head()

Unnamed: 0,Image_id
0,id_00vl5wvxq3.jpg
1,id_00vl5wvxq3_rgn.jpg
2,id_01hu05mtch.jpg
3,id_01hu05mtch_rgn.jpg
4,id_030ln10ewn.jpg


In [8]:
sample_df.head()

Unnamed: 0,Image_id,blast,brown,healthy
0,id_00vl5wvxq3.jpg,0.0,0.0,0.0
1,id_01hu05mtch.jpg,0.0,0.0,0.0
2,id_030ln10ewn.jpg,0.0,0.0,0.0
3,id_03z57m8xht.jpg,0.0,0.0,0.0
4,id_04ngep1w4b.jpg,0.0,0.0,0.0


In [9]:
size=224
bs=128

In [10]:
clas_block = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       splitter=RandomSplitter(seed=42),
                       get_x=ColReader(0, pref=image_path),
                       get_y=ColReader("Label"),
                       item_tfms=Resize(size),
                       batch_tfms=[*aug_transforms(), Normalize])

In [11]:
shuf_df = shuffle(train_df, random_state=42)
dls = clas_block.dataloaders(shuf_df, bs=bs)

In [12]:
# dls.show_batch()

In [13]:
test_dl = dls.test_dl(test_df, bs=64)

In [14]:
test_dl.n

2290

In [15]:
dls.c

3

In [16]:
dls.cuda()

<fastai.data.core.DataLoaders at 0x7f8b7da8fa00>

In [17]:
# train_dl.show_batch()

In [18]:
# %pdb

In [19]:
optdict = dict(sqr_mom=0.99,mom=0.95,beta=0.,eps=1e-4)
opt_func = partial(ranger, **optdict)

In [20]:
arch = "xresnet34"

In [21]:
swav_encoder = "models/run-mwalimu-128-swav-xresnet34-pretrain-rice-disease-epc447-sz128px-bs32_encoder.pth"

In [22]:
WANDB=False
if WANDB:
    xtra_config = {"Arch": arch, "Resize": size, "Algorithm": "Pretrained SWAV", "Epochs": 100, "Size": size, "Pretrained": True, "Batch Size": bs}
    wandb.init(project="rice-disease-classification", config=xtra_config);

In [23]:
def split_func(m): return L(m[0], m[1]).map(params)

def create_learner(dls, arch='xresnet34', encoder_path="models/swav-pretrain-rice-disease-epc37_encoder.pth"):
    pretrained_encoder = torch.load(encoder_path)
    encoder = create_encoder(arch, pretrained=False, n_in=3)
    encoder.load_state_dict(pretrained_encoder)
    nf = encoder(torch.randn(2,3,size,size)).size(-1) # size=128,size=224
    classifier = create_cls_module(nf, dls.c, ps=0.5)
    cbs = [SaveModelCallback(fname=f"classifier-best-{arch}-sz{size}"),
           EarlyStoppingCallback(patience=100),
           MixUp()]
    if WANDB: cbs += [WandbCallback(log_preds=True,log_model=True)]
    model = nn.Sequential(encoder, classifier)
    learn = Learner(dls, model, opt_func=opt_func, splitter=split_func,
                    metrics=[accuracy, F.cross_entropy], loss_func=LabelSmoothingCrossEntropy(),
                    cbs=cbs)
    return learn

In [24]:
learn = create_learner(dls, encoder_path=swav_encoder)

In [25]:
# rice-disease-classification-africa/models/classifier-comic-firebrand-3-xresnet34-swav-pretrain-rice-disease-epc127-sz128px-bs256.pth
learn.load("rice-disease-classifier-best-xresnet34-epc187-sz224")

  elif with_opt: warn("Saved filed doesn't contain an optimizer state.")


<fastai.learner.Learner at 0x7f8b7da1f5e0>

In [27]:
learn.to_fp16()

<fastai.learner.Learner at 0x7f8b7da1f5e0>

In [None]:
learn.lr_find()

In [None]:
learn.fit_one_cycle(600, 1e-4, wd=1e-2, moms=(0.95, 0.85, 0.95))

epoch,train_loss,valid_loss,accuracy,cross_entropy,time
0,0.511475,0.523205,0.876405,0.348689,00:41
1,0.507988,0.51879,0.882959,0.344424,00:40
2,0.509016,0.521682,0.881086,0.345753,00:40
3,0.50399,0.52064,0.88015,0.346734,00:41
4,0.504275,0.521165,0.878277,0.346501,00:41
5,0.506875,0.51908,0.881086,0.344741,00:40
6,0.505144,0.521177,0.879214,0.34649,00:40
7,0.504783,0.520491,0.882023,0.345954,00:41
8,0.504168,0.521771,0.879214,0.347198,00:40
9,0.503388,0.516471,0.885768,0.34145,00:40


Better model found at epoch 0 with valid_loss value: 0.5232045650482178.
Better model found at epoch 1 with valid_loss value: 0.5187902450561523.
Better model found at epoch 9 with valid_loss value: 0.5164709091186523.
Better model found at epoch 10 with valid_loss value: 0.5163695812225342.
Better model found at epoch 11 with valid_loss value: 0.5155891180038452.
Better model found at epoch 28 with valid_loss value: 0.5155863165855408.


In [28]:
# learn.fit_one_cycle(800, 1e-3, wd=1e-2, moms=(0.95, 0.85, 0.95))

In [29]:
# learn.recorder.plot_loss()

In [28]:
item_tfms = [ToTensor(), RandomResizedCrop(size, min_scale=0.75, ratio=(1.,1.))]
batch_tfms = [IntToFloatTensor(), *aug_transforms(size=int(size*0.6), max_warp=0), Normalize.from_stats(*imagenet_stats)]

In [29]:
preds, targs = learn.tta(dl=test_dl, item_tfms=item_tfms, batch_tfms=batch_tfms)

In [30]:
learn.export()

In [31]:
sample_df.head()

Unnamed: 0,Image_id,blast,brown,healthy
0,id_00vl5wvxq3.jpg,0.0,0.0,0.0
1,id_01hu05mtch.jpg,0.0,0.0,0.0
2,id_030ln10ewn.jpg,0.0,0.0,0.0
3,id_03z57m8xht.jpg,0.0,0.0,0.0
4,id_04ngep1w4b.jpg,0.0,0.0,0.0


In [36]:
sample_df.count()

Image_id    1145
blast       1145
brown       1145
healthy     1145
dtype: int64

In [37]:
test_df.count()

Image_id    2290
dtype: int64

In [32]:
dls.vocab

['blast', 'brown', 'healthy']

In [33]:
preds

TensorBase([[0.9098, 0.0234, 0.0668],
        [0.9679, 0.0125, 0.0196],
        [0.1191, 0.8313, 0.0496],
        ...,
        [0.9696, 0.0127, 0.0177],
        [0.8439, 0.0203, 0.1357],
        [0.9324, 0.0342, 0.0334]])

In [40]:
preds[0]

TensorBase([0.9098, 0.0234, 0.0668])

In [38]:
test_df['blast'] = pd.series(preds[])

KeyError: 'blast'

In [35]:
run_name = "visionary-dew-7" #wandb.run.name
family = "swav-pretrain"
epc=187
save_name = f'classifier-{run_name}-{arch}-{family}-rice-disease-epc{epc}-sz{size}px-bs{bs}'
learn.save(save_name)

Path('models/classifier-visionary-dew-7-xresnet34-swav-pretrain-rice-disease-epc187-sz224px-bs128.pth')

In [30]:
if WANDB: wandb.finish()

VBox(children=(Label(value='86.357 MB of 86.357 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
accuracy,▁▁▂▃▅▅▄▅▇▅▇▅▆▆▆▇▇▅▅▅▅▅▅▇▆▆▇▇▆▆▇▆▇▆▆▆█▇▇▇
beta_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
beta_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr_0,▁▁▂▂▃▄▅▅▆▇▇████████▇▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▃▃▃▃▂
lr_1,▁▁▂▂▃▄▅▅▆▇▇████████▇▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▃▃▃▃▂
mom_0,██▇▇▆▅▄▄▃▂▂▁▁▁▁▁▁▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆
mom_1,██▇▇▆▅▄▄▃▂▂▁▁▁▁▁▁▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆

0,1
accuracy,0.89045
beta_0,0.0
beta_1,0.0
epoch,614.0
eps_0,0.0001
eps_1,0.0001
lr_0,0.00022
lr_1,0.00022
mom_0,0.9281
mom_1,0.9281


In [31]:
learn.tta?

[0;31mSignature:[0m
[0mlearn[0m[0;34m.[0m[0mtta[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mds_idx[0m[0;34m=[0m[0;36m1[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdl[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mn[0m[0;34m=[0m[0;36m4[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mitem_tfms[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbatch_tfms[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbeta[0m[0;34m=[0m[0;36m0.25[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0muse_max[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m Return predictions on the `ds_idx` dataset or `dl` using Test Time Augmentation
[0;31mFile:[0m      ~/miniconda3/envs/rave/lib/python3.9/site-packages/fastai/learner.py
[0;31mType:[0m      method


In [None]:
def finetune(size, epochs, arch, encoder_path, lr=1e-2, wd=1e-2):
    learn = create_learner(size, arch, swav_encoder)
    learn.unfreeze()
    learn.fit_flat_cos(epochs, lr, wd=wd)
    final_acc = learn.recorder.values[-1][-2]
    return final_acc