In [1]:
import pandas as pd
from cassavadata import CassavaDataset
from pathlib import Path
from augmentations import get_augmentations,get_tta
from torch.utils.data import DataLoader
from lightning import CassavaModel
import torch
from models import Resnext,get_efficientnet
from sklearn.model_selection import StratifiedKFold
import numpy as np

In [2]:
test_df = pd.read_csv('../data/sample_submission.csv')
path = Path('../data/')
batch_size,num_workers = 32,8
ssl_models = [
    "resnet18_ssl",
    "resnet50_ssl",
    "resnext50_32x4d_ssl",
    "resnext101_32x4d_ssl",
    "resnext101_32x8d_ssl",
    "resnext101_32x16d_ssl",
]

In [3]:
df = pd.read_csv(path/'train.csv')

In [4]:
df.head()

Unnamed: 0,image_id,label
0,1000015157.jpg,0
1,1000201771.jpg,3
2,100042118.jpg,1
3,1000723321.jpg,1
4,1000812911.jpg,3


In [5]:
tta_tfms = get_tta(image_size=512)
test_ds = CassavaDataset(path=path/'test_images',df=test_df,transform=tta_tfms)
test_df = pd.read_pickle('../data/valid_df.pkl')
test_ds = CassavaDataset(path=path/'train_images',df=test_df,transform=tta_tfms)


In [6]:
test_dl = DataLoader( dataset=test_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False,
            pin_memory=True,)

In [7]:
device = torch.device('cuda')
model = Resnext(model_name=ssl_models[2],num_classes=5,kaggle=True)
model = model.to(device)



In [8]:
model=get_efficientnet(model_name='tf_efficientnet_b4_ns',pretrained=False,num_classes=5)

In [9]:
s = "tf_efficientnet_b4_ns"

In [10]:
s.find('effi')

3

In [22]:
ls -al weights/resnext50_32x4d_ssl_0.pth/Cassava/28s783hr/checkpoints/

total 269900
drwxr-xr-x 2 root root        35 Dec 16 11:29  [0m[01;34m.[0m/
drwxr-xr-x 3 root root        25 Dec 16 11:13  [01;34m..[0m/
-rw-r--r-- 1 root root 276374401 Dec 16 11:29 'epoch=4-step=449.ckpt'


In [30]:
fold_id = 0
path = Path(f'weights/resnext50_32x4d_ssl_{fold_id}.pth/Cassava/28s783hr/checkpoints/')
chk = torch.load(list(path.iterdir())[0])

In [29]:
list(path.iterdir())[0]

PosixPath('weights/resnext50_32x4d_ssl_0.pth/Cassava/28s783hr/checkpoints/epoch=4-step=449.ckpt')

In [33]:
import glob

In [36]:
fold_id = 0
glob.glob(f'weights/resnext50_32x4d_ssl_{fold_id}.pth/Cassava/*/checkpoints/*')[0]

'weights/resnext50_32x4d_ssl_0.pth/Cassava/28s783hr/checkpoints/epoch=4-step=449.ckpt'

In [38]:
for fold_id in range(5):
    file = glob.glob(f'weights/resnext50_32x4d_ssl_{fold_id}.pth/Cassava/*/checkpoints/*')[0]
    chk = torch.load(file)
    model_weights = {k.replace('model.',''):v for k,v in chk['state_dict'].items()}
    torch.save(model_weights,f'model_weights_res_{fold_id}.pth')
    
    

In [39]:
ls -al


total 721676
drwxr-xr-x 1 root root     4096 Dec 16 13:42  [0m[01;34m.[0m/
drwxr-xr-x 1 root root       98 Dec 16 10:17  [01;34m..[0m/
drwxr-xr-x 1 root root       80 Dec 16 11:13  [01;34m.git[0m/
-rw-r--r-- 1 root root     1919 Dec 16 06:04  .gitignore
drwxr-xr-x 2 root root       45 Dec 15 09:08  [01;34m.ipynb_checkpoints[0m/
drwxr-xr-x 2 root root       27 Dec 14 05:40  [01;34m.vscode[0m/
drwxr-xr-x 1 root root      262 Dec 16 08:46  [01;34mCassava[0m/
drwxr-xr-x 5 root root       64 Dec 14 13:22 [01;34m'Cassava Leaf Disease'[0m/
-rw-r--r-- 1 root root     1075 Dec 14 05:40  LICENSE
-rw-r--r-- 1 root root       64 Dec 14 05:40  README.md
drwxr-xr-x 1 root root       72 Dec 16 13:30  [01;34m__pycache__[0m/
-rw-r--r-- 1 root root     2504 Dec 14 06:40  augmentations.py
-rw-r--r-- 1 root root     3723 Dec 16 05:48  cassavadata.py
-rw-r--r-- 1 root root    10072 Dec 14 14:07  hubconf_pretrained_false.py
-rw-r--r-- 1 root root     1412 Dec 14 13:36  infere

In [11]:
chk_path = '/notebooks/Cassava/Cassava/apelzc8i/checkpoints/epoch=4-step=449.ckpt'
chk = torch.load(chk_path)
model_weights = {k.replace('model.',''):v for k,v in chk['state_dict'].items()}
torch.save(model_weights,'model_weights_res.pth')

model.load_state_dict(model_weights)

FileNotFoundError: [Errno 2] No such file or directory: '/notebooks/Cassava/Cassava/apelzc8i/checkpoints/epoch=4-step=449.ckpt'

In [14]:
def get_preds():
    preds = []
    with torch.no_grad():
        for xb,_ in test_dl:
            xb = xb.to(device)
            pred = model(xb)
            preds.append(pred.to('cpu'))
    return torch.cat(preds)




In [15]:
preds = get_preds()

In [16]:
preds.shape

torch.Size([4280, 5])

In [18]:
preds = torch.zeros(len(test_ds),5)
for o in range(2):
    preds += get_preds()

In [20]:
preds /= 2

In [24]:
preds.argmax(1).tolist()

[3,
 3,
 3,
 2,
 3,
 4,
 3,
 3,
 3,
 3,
 1,
 3,
 3,
 4,
 2,
 3,
 1,
 2,
 3,
 3,
 3,
 3,
 3,
 4,
 3,
 1,
 3,
 3,
 0,
 3,
 4,
 3,
 3,
 3,
 4,
 1,
 3,
 0,
 3,
 4,
 3,
 3,
 3,
 3,
 3,
 4,
 3,
 3,
 3,
 3,
 0,
 2,
 1,
 2,
 3,
 3,
 1,
 3,
 3,
 3,
 1,
 3,
 1,
 3,
 3,
 3,
 3,
 4,
 3,
 3,
 3,
 3,
 4,
 3,
 3,
 4,
 2,
 3,
 3,
 3,
 3,
 4,
 3,
 3,
 4,
 3,
 3,
 3,
 3,
 2,
 1,
 3,
 3,
 4,
 3,
 4,
 0,
 3,
 2,
 3,
 4,
 3,
 4,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 2,
 3,
 4,
 1,
 4,
 3,
 4,
 0,
 0,
 1,
 3,
 3,
 2,
 3,
 1,
 3,
 3,
 3,
 2,
 3,
 3,
 3,
 3,
 3,
 4,
 4,
 2,
 2,
 4,
 3,
 3,
 0,
 2,
 2,
 4,
 3,
 1,
 4,
 2,
 3,
 3,
 3,
 3,
 3,
 0,
 3,
 3,
 3,
 2,
 3,
 1,
 3,
 3,
 4,
 2,
 3,
 3,
 2,
 4,
 3,
 1,
 0,
 3,
 1,
 3,
 3,
 3,
 3,
 3,
 3,
 2,
 4,
 3,
 3,
 3,
 4,
 3,
 1,
 3,
 2,
 3,
 0,
 3,
 0,
 2,
 3,
 3,
 3,
 2,
 4,
 4,
 3,
 3,
 4,
 3,
 1,
 3,
 3,
 0,
 3,
 3,
 3,
 2,
 3,
 2,
 3,
 3,
 3,
 3,
 4,
 3,
 4,
 2,
 3,
 1,
 3,
 1,
 4,
 3,
 0,
 3,
 3,
 3,
 3,
 3,
 3,
 4,
 3,
 1,
 4,
 2,
 4,
 3,
 3,
 3,
 3,
 4,
 4,
 0,
 3,


In [16]:
test_df.head()

Unnamed: 0,image_id,label
10544,288080098.jpg,3
11634,3080364100.jpg,3
578,110051175.jpg,3
9118,2612067247.jpg,1
15857,3852927202.jpg,3


In [18]:
import geffnet
import torch.nn as nn

In [15]:
eff = geffnet.create_model('tf_efficientnet_b4_ns',pretrained=True)

In [19]:
eff.classifier = nn.Linear(eff.classifier.in_features,5)

In [20]:
eff

GenEfficientNet(
  (conv_stem): Conv2dSame(3, 48, kernel_size=(3, 3), stride=(2, 2), bias=False)
  (bn1): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  (act1): SwishMe()
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)
        (bn1): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (act1): SwishMe()
        (se): SqueezeExcite(
          (avg_pool): AdaptiveAvgPool2d(output_size=1)
          (conv_reduce): Conv2d(48, 12, kernel_size=(1, 1), stride=(1, 1))
          (act1): SwishMe()
          (conv_expand): Conv2d(12, 48, kernel_size=(1, 1), stride=(1, 1))
        )
        (conv_pw): Conv2d(48, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Identity()
      )
  

In [None]:
python lightning.py --batch_size=32 --num_workers=42 --img_sz=512 --max_epochs=5 --model_name='tf_efficientnet_b4_ns'

In [7]:
ls -al Cassava/sc1xokcd/checkpoints/epoch=4-step=449.ckpt

total 206756
drwxr-xr-x 2 root root        35 Dec 15 10:02  [0m[01;34m.[0m/
drwxr-xr-x 3 root root        25 Dec 15 09:44  [01;34m..[0m/
-rw-r--r-- 1 root root 211716581 Dec 15 10:02 'epoch=4-step=449.ckpt'
