In [None]:
import os
from PIL import Image

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from sklearn.model_selection import train_test_split
from tqdm import tqdm

class MyDataset(Dataset):
    """Image dataset driven by a CSV file of (image file name, label) rows.

    Every CSV row is unpacked into exactly two values, so the file is
    expected to have two columns: the image file name (relative to
    ``data_dir``) followed by its label.
    """

    def __init__(self, csv_path, data_dir = './', transform=None):
        """Read the CSV and remember where images live and how to transform them.

        Args:
            csv_path: path of the CSV file listing image names and labels.
            data_dir: directory that the image names are joined onto.
            transform: optional callable applied to each loaded RGB image.
        """
        super().__init__()
        self.data_dir = data_dir
        self.transform = transform
        # Keep the table as a plain array of rows for positional indexing.
        self.df = pd.read_csv(csv_path).values

    def __len__(self):
        """Return the number of rows read from the CSV."""
        return len(self.df)

    def __getitem__(self, index):
        """Load the image of row ``index`` and return ``(image, label - 1)``.

        The label is shifted down by one (presumably the CSV labels are
        1-based and are converted to 0-based class indices -- confirm
        against the data).
        """
        row = self.df[index]
        img_name, label = row
        # Convert inside the context manager so the file handle is released.
        with Image.open(os.path.join(self.data_dir, img_name)) as img:
            image = img.convert('RGB')
        if self.transform is None:
            return image, label-1
        return self.transform(image), label-1

In [None]:
# Run on the GPU only when CUDA is actually usable, so the notebook also
# works on CPU-only machines instead of failing at the first .cuda() call.
use_gpu = torch.cuda.is_available()
# Number of output classes of the replacement classifier head.
num_classes = 5
# Upper bound on the number of training epochs.
num_epochs = 200
# Number of epochs without validation-accuracy improvement that the
# training loop tolerates before stopping early.
early_stopping = 10
# DenseNet-201 backbone initialised with pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=`; kept as-is for compatibility with the version in use.
model = models.densenet201(pretrained=True)

In [None]:
# Freeze every parameter of the network first ...
for para in model.parameters():
    para.requires_grad = False

# ... then re-enable gradients for the last feature stages only, so that
# just these sub-modules are fine-tuned.
for unfrozen_module in (
    model.features.denseblock3,
    model.features.transition3,
    model.features.denseblock4,
    model.features.norm5,
):
    for para in unfrozen_module.parameters():
        para.requires_grad = True

In [None]:
# Replace the classifier head with dropout followed by a linear layer that
# maps the 1920 input features to `num_classes` outputs.
head_layers = [
    nn.Dropout(0.2),
    nn.Linear(1920, num_classes),
]
model.classifier = nn.Sequential(*head_layers)

In [None]:
# Move the model to the GPU when requested; otherwise keep it unchanged.
model = model.cuda() if use_gpu else model

In [None]:
# Load the label table ordered by file name, so the seeded split below does
# not depend on the row order of the CSV on disk.
df = pd.read_csv('../input/train/train.csv').sort_values('filename')

# Stratified 80/20 train/validation split on the ' type' column
# (the column name really does start with a space).
df_train, df_valid = train_test_split(
    df,
    test_size=0.2,
    stratify=df[' type'].values,
    shuffle=True,
    random_state=1234,
)

# Persist both subsets; MyDataset reads them back from these files.
df_train.to_csv('df_train.csv', index=False)
df_valid.to_csv('df_valid.csv', index=False)

# Per-channel normalisation shared by both pipelines (these look like the
# statistics commonly used with ImageNet-pretrained torchvision models).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training pipeline: random crop/flip/grayscale augmentation.
trans_train = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
    # transforms.ColorJitter(0.5,0, 0.5,0),
    transforms.RandomGrayscale(),
    transforms.ToTensor(),
    normalize,
])

# Validation pipeline: deterministic resize followed by a centre crop.
trans_valid = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    normalize,
])

image_dir = '../input/train/data/'
dataset_train = MyDataset(csv_path='df_train.csv', data_dir=image_dir, transform=trans_train)
dataset_valid = MyDataset(csv_path='df_valid.csv', data_dir=image_dir, transform=trans_valid)

# Only the training loader shuffles; validation keeps the file order.
loader_train = DataLoader(dataset=dataset_train, batch_size=32, shuffle=True, num_workers=0)
loader_valid = DataLoader(dataset=dataset_valid, batch_size=32, shuffle=False, num_workers=0)

In [None]:
# Collect the parameters that still require gradients (the ones handed to
# the optimizer) and print their names as a sanity check.
params_to_update = []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    print("\t", name)
    params_to_update.append(param)

In [None]:
# classifer_params_id = list(map(id, model.classifier.parameters()))
# conv_params = filter(lambda p: id(p) not in classifer_params_id, params_to_update)

In [None]:
# Loss and optimizer; only the parameters collected in `params_to_update`
# are optimised.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params_to_update)
# optimizer = optim.SGD([
#             {'params': conv_params},
#             {'params': model.classifier.parameters(), 'lr': 1e-3}
#             ], lr=1e-4, momentum=0.9)

# Halve the learning rate every 2 scheduler steps (one step per epoch).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

best_val_acc = 0.0
best_epoch = 0
epoch_since_best = 0

for epoch in range(num_epochs):
    # ---------------- training pass ----------------
    model.train()
    train_total_samples = 0
    train_acc = 0
    train_loss = 0
    for inputs, labels in loader_train:
        print('.',end='')  # one dot per batch as a progress indicator
        n_batch = labels.size(0)
        train_total_samples += n_batch
        if use_gpu:
            inputs, labels = inputs.cuda(), labels.cuda()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accumulate plain Python numbers so the totals never hold on to
        # tensors and are valid even if the loader yields no batch.
        train_pred = torch.argmax(outputs.detach(), dim=1)
        train_acc += (train_pred == labels).sum().item()
        # The criterion returns the batch mean; weight it by the batch size
        # so the epoch loss is a true per-sample average.
        train_loss += loss.item() * n_batch

    # Advance the LR schedule once per epoch, *after* the optimizer updates.
    # Calling it before the first optimizer.step() skips the first value of
    # the schedule on PyTorch >= 1.1.
    scheduler.step()

    # ---------------- validation pass ----------------
    model.eval()
    valid_total_samples = 0
    valid_acc = 0
    val_loss = 0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building graphs and reduces memory use.
    with torch.no_grad():
        for inputs, labels in loader_valid:
            n_batch = labels.size(0)
            valid_total_samples += n_batch
            if use_gpu:
                inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            valid_pred = torch.argmax(outputs, dim=1)
            valid_acc += (valid_pred == labels).sum().item()
            val_loss += loss.item() * n_batch

    # Per-sample averages; max(..., 1) guards against empty loaders.
    train_acc = train_acc / max(train_total_samples, 1)
    valid_acc = valid_acc / max(valid_total_samples, 1)
    train_loss = train_loss / max(train_total_samples, 1)
    val_loss = val_loss / max(valid_total_samples, 1)

    print()
    print('[Epoch %d] train loss %.6f train acc %.6f  valid loss %.6f valid acc %.6f' % (
        epoch, train_loss, train_acc, val_loss, valid_acc))

    # Checkpoint whenever validation accuracy strictly improves.
    if valid_acc > best_val_acc:
        best_val_acc = valid_acc
        best_epoch = epoch
        epoch_since_best = 0
        print('save model...')
        # NOTE(review): the file name says "resnet101" but the model built in
        # this notebook is DenseNet-201; the name is left unchanged in case
        # other code loads the checkpoint by this path.
        torch.save(model.state_dict(), 'tuned-resnet101.pt')
        print('saved.')
    else:
        epoch_since_best += 1

    # Stop once more than `early_stopping` consecutive epochs have passed
    # without a new best validation accuracy.
    if epoch_since_best > early_stopping:
        break

print('Finished Training')
print('best_epoch: %d, best_val_acc %.6f' % (best_epoch, best_val_acc))