- https://www.kaggle.com/abhishek/very-simple-pytorch-training-0-59?scriptVersionId=16436961
- https://www.kaggle.com/abhishek/pytorch-inference-kernel-lazy-tta

In [1]:
# Debug switch. When it is on, a later cell truncates the training and
# validation tables to the sizes below so the notebook can be smoke-tested fast.
dbg = False
if dbg:
    dbgtrnsz = dbgvalsz = 100

In [2]:
# Experiment configuration: everything a reader is likely to tune lives here.
PRFX = 'devCv0630'      # run tag; names the output directory
SEED = 111              # default seed for set_torch_seed
SZ = (256, 256)         # target size handed to transforms.Resize
BSZ = 64                # batch size of the training loader
BSZ_INFER = 2 * BSZ     # batch size of the validation loader
N_EPOCHS = 15           # number of training epochs


# setup

In [3]:
import random 
import numpy as np
import torch
import os
import datetime

def set_torch_seed(seed=SEED):
    """Seed every random number generator this notebook relies on.

    Sets ``PYTHONHASHSEED`` in the environment and seeds ``random``, NumPy and
    PyTorch. When CUDA is available it also seeds the CUDA generators and puts
    cuDNN into deterministic, non-benchmarking mode.

    Parameters
    ----------
    seed : int
        Seed value applied to all generators (defaults to the notebook's SEED).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Nothing GPU-related to configure on a CPU-only host.
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_torch_seed()

In [4]:
import pandas as pd
from collections import Counter
import time
from tqdm import tqdm_notebook as tqdm
from sklearn.metrics import mean_squared_error
from sklearn.metrics import cohen_kappa_score

from torch.utils.data import Dataset
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.optim import lr_scheduler

from PIL import Image, ImageFile

# Tell PIL to load truncated image files instead of failing on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True


from pathlib import Path
# Output directory for this run (one folder per PRFX tag); created up front,
# including missing parents, and left untouched if it already exists.
p_o = f'../output/{PRFX}'
Path(p_o).mkdir(exist_ok=True, parents=True)

# preprocess

In [5]:
# Accumulator of (image_path, grade) pairs; each dataset cell below appends to it.
img2grd = []

In [6]:
# --- APTOS 2019 blindness detection ---------------------------------------
p = '../input/aptos2019-blindness-detection'
pp = Path(p)
train = pd.read_csv(pp/'train.csv')

# Number of APTOS rows; the split cell uses it to carve these rows out again.
len_blnd = len(train)

# Column 0 is used as the image file stem, column 1 as the grade.
img2grd_blnd = []
for row in train.values:
    img2grd_blnd.append((f'{p}/train_images/{row[0]}.png', row[1]))

img2grd += img2grd_blnd
display(len(img2grd))
display(Counter(grade for _path, grade in img2grd).most_common())

3662

[(0, 1805), (2, 999), (1, 370), (4, 295), (3, 193)]

In [7]:
# --- Kaggle diabetic-retinopathy-detection (train + test labels) -----------
p = '../input/diabetic-retinopathy-detection'
pp = Path(p)
train = pd.read_csv(pp/'trainLabels.csv')
test = pd.read_csv(pp/'retinopathy_solution.csv')

# Column 0 is used as the image file stem, column 1 as the grade; any further
# columns in the CSVs are ignored.
img2grd_diab_train = [(f'{p}/train_images/{o[0]}.jpeg', o[1]) for o in train.values]
img2grd_diab_test = [(f'{p}/test_images/{o[0]}.jpeg', o[1]) for o in test.values]

# Running totals after adding the training part ...
img2grd += img2grd_diab_train
display(len(img2grd))
display(Counter(o[1] for o in img2grd).most_common())

# ... and after adding the test part. A bare `len(...)` in the middle of a cell
# is evaluated and thrown away, so it has to go through display() to be shown.
img2grd += img2grd_diab_test
display(len(img2grd))
display(Counter(o[1] for o in img2grd).most_common())

38788

[(0, 27615), (2, 6291), (1, 2813), (3, 1066), (4, 1003)]

[(0, 67148), (2, 14152), (1, 6575), (3, 2280), (4, 2209)]

In [8]:
# --- IDRiD, part "B. Disease Grading" (train + test labels) ----------------
p = '../input/IDRID/B. Disease Grading'
pp = Path(p)
train = pd.read_csv(pp/'2. Groundtruths/a. IDRiD_Disease Grading_Training Labels.csv')
test = pd.read_csv(pp/'2. Groundtruths/b. IDRiD_Disease Grading_Testing Labels.csv')

# Column 0 is used as the image file stem, column 1 as the grade; any further
# columns in the CSVs are ignored.
img2grd_idrid_train = [(f'{p}/1. Original Images/a. Training Set/{o[0]}.jpg', o[1]) for o in train.values]
img2grd_idrid_test = [(f'{p}/1. Original Images/b. Testing Set/{o[0]}.jpg', o[1]) for o in test.values]

# Running totals after adding the training part ...
img2grd += img2grd_idrid_train
display(len(img2grd))
display(Counter(o[1] for o in img2grd).most_common())

# ... and after adding the test part. A bare `len(...)` in the middle of a cell
# is evaluated and thrown away, so it has to go through display() to be shown.
img2grd += img2grd_idrid_test
display(len(img2grd))
display(Counter(o[1] for o in img2grd).most_common())

92777

[(0, 67282), (2, 14288), (1, 6595), (3, 2354), (4, 2258)]

[(0, 67316), (2, 14320), (1, 6600), (3, 2373), (4, 2271)]

In [9]:
# Turn the list of pairs into an array so rows can be selected with index ranges.
# Side effect: the grades become strings (the result has a unicode dtype), which is
# why later code wraps them in int().
img2grd = np.array(img2grd)

In [10]:
# Sanity check: confirm that every referenced image exists on disk.
files_present = [Path(path).exists() for path, _grade in img2grd]
if np.all(files_present):
    print('All files are here!')

All files are here!


# dataset

In [11]:
set_torch_seed()

# The first len_blnd rows of the table form the validation set;
# every row after them is used for training.
idx_val, idx_trn = range(len_blnd), range(len_blnd, len(img2grd))

img2grd_val = img2grd[idx_val]
img2grd_trn = img2grd[idx_trn]

display(len(img2grd_trn), len(img2grd_val))

# Peek at the first rows of each split.
(img2grd_trn[:3], img2grd_val[:3])

89218

3662

(array([['../input/diabetic-retinopathy-detection/train_images/10_left.jpeg',
         '0'],
        ['../input/diabetic-retinopathy-detection/train_images/10_right.jpeg',
         '0'],
        ['../input/diabetic-retinopathy-detection/train_images/13_left.jpeg',
         '0']], dtype='<U82'),
 array([['../input/aptos2019-blindness-detection/train_images/000c1434d8d7.png',
         '2'],
        ['../input/aptos2019-blindness-detection/train_images/001639a390f0.png',
         '4'],
        ['../input/aptos2019-blindness-detection/train_images/0024cdab0c1e.png',
         '1']], dtype='<U82'))

In [12]:
# In debug mode keep only the leading rows of each split.
if dbg:
    img2grd_trn, img2grd_val = img2grd_trn[:dbgtrnsz], img2grd_val[:dbgvalsz]

In [13]:
class BlndDataset(Dataset):
    """Dataset yielding ``(image, grade)`` pairs from a table of image paths.

    Parameters
    ----------
    img2grd : sequence of (path, grade)
        One row per sample. ``path`` is opened with PIL; ``grade`` must be
        convertible with ``int()`` (strings such as ``'2'`` are fine).
    transform : callable
        Applied to the RGB PIL image; its return value is the sample's image.
    """

    def __init__(self, img2grd, transform):
        self.img2grd = img2grd
        self.transform = transform

    def __len__(self):
        return len(self.img2grd)

    def __getitem__(self, idx):
        """Return ``(transformed_image, grade_tensor)`` for row ``idx``."""
        # Read from this instance's own table. Indexing the module-level
        # `img2grd` instead would make every dataset serve rows of the full
        # combined table, regardless of the split it was built from.
        img, grd = self.img2grd[idx]
        # Force three channels so the 3-channel Normalize downstream works for
        # grayscale / RGBA files too; the context manager closes the file handle.
        with Image.open(img) as pil_img:
            image = self.transform(pil_img.convert('RGB'))
        label = torch.tensor(int(grd))
        return image, label

# Channel mean / std used for normalisation (the usual ImageNet statistics).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, random left-right flip as augmentation, normalise.
transform_train = transforms.Compose([
    transforms.Resize(SZ),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD)
])

# Validation pipeline: identical except that it has no random flip, so the
# validation loss is deterministic and comparable from epoch to epoch.
transform_val = transforms.Compose([
    transforms.Resize(SZ),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD)
])

ds_trn = BlndDataset(img2grd_trn, transform=transform_train)
ds_val = BlndDataset(img2grd_val, transform=transform_val)

# Shuffle only the training data; validation order must stay fixed because the
# training loop writes predictions into an array by batch position.
data_loader = torch.utils.data.DataLoader(ds_trn, batch_size=BSZ, shuffle=True, num_workers=0)
data_loader_val = torch.utils.data.DataLoader(ds_val, batch_size=BSZ_INFER, shuffle=False, num_workers=0)

# model

In [14]:
# Build the ResNet-50 architecture without fetching weights (pretrained=False),
# then load the weights from a local checkpoint file.
# The trailing ';' suppresses the cell's output.
model = torchvision.models.resnet50(pretrained=False)
model.load_state_dict(torch.load("../input/pytorch_models/resnet50-19c8e357.pth"));

In [15]:
# Replace the backbone's final layer with a small MLP head that outputs a single
# value per image (the grade is fitted as a regression target with MSE).
# 2048 is the input width of the layer being replaced, as in the original head.
model.fc = nn.Sequential(
    nn.BatchNorm1d(2048),
    nn.Dropout(p=0.25),
    nn.Linear(in_features=2048, out_features=2048),
    nn.ReLU(),
    nn.BatchNorm1d(2048),
    nn.Dropout(p=0.5),
    nn.Linear(in_features=2048, out_features=1),
)

# Use the GPU when there is one and fall back to the CPU otherwise, consistent
# with set_torch_seed(), which also guards on torch.cuda.is_available().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(device)

In [16]:
# Only the last residual stage and the new head are handed to the optimizer.
# The stage gets a 10x smaller learning rate than the head, plus weight decay.
# The option key must be 'weight_decay': an unrecognised key such as 'weight'
# is stored in the param group but never used, so no decay would be applied.
plist = [
    {'params': model.layer4.parameters(), 'lr': 1e-4, 'weight_decay': 0.001},
    {'params': model.fc.parameters(), 'lr': 1e-3}
]

optimizer = optim.Adam(plist, lr=0.001)
# StepLR with its default decay factor, applied every 10 scheduler steps.
scheduler = lr_scheduler.StepLR(optimizer, step_size=10)

# Training Loop

In [17]:
len_dl = len(data_loader)
len_ds = len(ds_trn)
len_dl_val = len(data_loader_val)
# Validation grades as an (N, 1) integer column, the same shape as preds_val.
y_val = np.array([int(o[1]) for o in ds_val.img2grd])[:,None]

# Progress is logged every `log_every` steps (every 2 steps in debug mode).
log_every = 2 if dbg else 1000

since = time.time()
criterion = nn.MSELoss()

set_torch_seed()
for epoch in range(N_EPOCHS):
    print(f'Epoch {epoch}/{N_EPOCHS-1}')

    ###### train #######
    model.train()
    running_loss = 0.0
    running_n = 0
    for step, d in enumerate(data_loader):
        inputs = d[0].to(device, dtype=torch.float)
        # Labels arrive as a flat integer batch; MSELoss needs float (B, 1).
        labels = d[1].view(-1, 1).to(device, dtype=torch.float)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the batch size so the running average
        # stays correct when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        running_n += inputs.size(0)
        if (step+1) % log_every == 0:
            print(f'[{datetime.datetime.now()}] epoch-{epoch} step-{step+1}/{len_dl} loss: {running_loss/running_n:.5f}')
    epoch_loss = running_loss / len_ds

    # Advance the LR schedule once per epoch, after this epoch's optimizer steps.
    # On PyTorch >= 1.1, stepping the scheduler before any optimizer.step()
    # skips the first value of the schedule.
    scheduler.step()

    ###### val #######
    model.eval()
    preds_val = np.zeros((len(ds_val), 1))
    with torch.no_grad():
        for step, d in enumerate(data_loader_val):
            if (step+1) % log_every == 0:
                print(f'[{datetime.datetime.now()}] val step-{step+1}/{len_dl_val}')
            inputs = d[0].to(device, dtype=torch.float)
            outputs = model(inputs)
            # reshape(-1, 1) keeps the column shape even for a batch of a single
            # sample, where squeeze() followed by [:, None] would raise.
            start = step * BSZ_INFER
            preds_val[start:start + inputs.size(0)] = outputs.cpu().numpy().reshape(-1, 1)

    mse_val = mean_squared_error(y_val, preds_val)
    print(f'Training Loss: {epoch_loss:.4f}; Val Loss: {mse_val:.4f}')

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
torch.save(model.state_dict(), f"{p_o}/model.bin")

Epoch 0/14


KeyboardInterrupt: 