# Notebook example using Kaggle GPU

In [None]:
import numpy as np
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
import torch.nn.functional as F

import matplotlib.pyplot as plt


from tqdm import tqdm
import PIL.Image as Image

# Folder that receives the model checkpoints saved during training.
os.makedirs('./experiments', exist_ok=True)

# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Detection

Skip this section when training ViT.

In [None]:
# Pretrained Faster R-CNN (ResNet-50 FPN) moved to `device` and switched to
# eval mode; in this notebook it is only called for inference.
detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True).to(device)
detector.eval()


In [None]:
# Recursively delete ./bird_dataset_cropped (e.g. a copy left by a previous run).
!rm -r ./bird_dataset_cropped

In [None]:
# Copy the input dataset to ./bird_dataset_cropped in the working directory;
# the following cells load the images from this copy.
!cp -r ../input/mva-recvis-2021/bird_dataset ./bird_dataset_cropped

In [None]:
# One ImageFolder per split of the working copy, in train / val / test order.
# No transform is attached, so indexing yields the images as loaded by ImageFolder.
_SPLIT_DIRS = (
    './bird_dataset_cropped/train_images',
    './bird_dataset_cropped/val_images',
    './bird_dataset_cropped/test_images',
)
train_dataset, val_dataset, test_dataset = map(datasets.ImageFolder, _SPLIT_DIRS)

In [None]:
def imshow(tensor):
    """Plot an image tensor with matplotlib.

    The axes are permuted with ``(1, 2, 0)`` so the first axis of ``tensor``
    becomes the last one before plotting (presumably CxHxW -> HxWxC; confirm
    against callers).
    """
    channels_last = tensor.permute(1, 2, 0)
    plt.imshow(channels_last.numpy())

def crop_bird(img_tensor):
    '''
    Crop an image (3xHxW) to the bird detection with the highest score.

    The module-level ``detector`` is run on ``img_tensor`` as a batch of one.
    Only detections whose label is 16 (the label this notebook treats as
    "bird") are kept, and the box with the largest score is used for the crop.

    Returns the cropped region of ``img_tensor``, detached and on the CPU.

    Raises ValueError when the detector returns no detection with label 16.
    '''
    # Pure inference: no autograd graph is needed for the detector forward pass.
    with torch.no_grad():
        out = detector(img_tensor.to(device).unsqueeze(0))[0]

    is_bird = out['labels'] == 16
    boxes = out['boxes'][is_bird]
    scores = out['scores'][is_bird]

    # Fail with an explicit message instead of an opaque argmax error on an
    # empty tensor when nothing was detected.
    if scores.numel() == 0:
        raise ValueError('no bird detected in image')

    x1, y1, x2, y2 = boxes[scores.argmax()].tolist()
    # Box coordinates are truncated to integer pixel indices for slicing.
    return img_tensor[:, int(y1):int(y2), int(x1):int(x2)].detach().cpu()
    
# Replace every image of each split, in place, by its bird crop.
to_tensor = transforms.ToTensor()  # built once instead of once per image

for dataset in [train_dataset, val_dataset, test_dataset]:
    for i, (path, _) in enumerate(tqdm(dataset.imgs)):
        img = to_tensor(dataset[i][0])
        try:
            cropped = crop_bird(img)
            # Saving can also fail (e.g. on an empty crop), so it stays inside
            # the guarded section; on failure the original file is left as is.
            plt.imsave(path, cropped.permute(1, 2, 0).numpy())
        except Exception as exc:
            # Best effort: show and report the image, then move on to the next
            # one. `Exception` (not a bare except) lets Ctrl-C interrupt the loop.
            imshow(img)
            print('Cannot crop image ', path, '-', repr(exc))
        
        

In [None]:
import shutil

# Zip the contents of ./bird_dataset_cropped into an archive named
# bird_dataset_cropped.zip.
shutil.make_archive(
    base_name='./bird_dataset_cropped',
    format='zip',
    root_dir='./bird_dataset_cropped',
)

# Model + optimizer

In [None]:
# Per-channel mean / std applied after ToTensor() by every pipeline below.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _resize_to_input():
    """Return a new Resize to 224x224 using PIL's bilinear filter."""
    return transforms.Resize((224, 224), Image.BILINEAR)


def _make_loader(image_dir, transform, **loader_kwargs):
    """Wrap the ImageFolder at ``image_dir`` in a single-worker DataLoader."""
    folder = datasets.ImageFolder(image_dir, transform=transform)
    return torch.utils.data.DataLoader(folder, num_workers=1, **loader_kwargs)


normalize = transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD)

# Deterministic pipeline: resize -> tensor -> normalise.
data_transforms = transforms.Compose([
    _resize_to_input(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Training pipeline: a random horizontal flip is the only augmentation.
# Disabled alternatives: RandomRotation(15), RandomResizedCrop(224).
train_transforms = transforms.Compose([
    _resize_to_input(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Validation / test pipeline. Disabled alternative: CenterCrop(224).
valid_transforms = transforms.Compose([
    _resize_to_input(),
    transforms.ToTensor(),
    normalize,
])

batch_size = 16

train_loader = _make_loader(
    '../input/bird-dataset-cropped/bird_dataset_cropped/train_images',
    train_transforms,
    batch_size=batch_size,
    shuffle=True,
)

val_loader = _make_loader(
    '../input/bird-dataset-cropped/bird_dataset_cropped/val_images',
    valid_transforms,
    batch_size=batch_size,
    shuffle=True,
)

# The test set is read one image at a time, in folder order.
test_loader = _make_loader(
    '../input/bird-dataset-cropped/bird_dataset_cropped/test_images',
    valid_transforms,
    batch_size=1,
    shuffle=False,
)

train_size = len(train_loader.dataset)
val_size = len(val_loader.dataset)
print(train_size, val_size)

In [None]:
! pip install timm

import timm

In [None]:

import gc

# Collect unreachable Python objects, then release the cached CUDA memory
# (presumably to start the next model with as much free GPU memory as possible).
gc.collect()
torch.cuda.empty_cache()

# Start from an empty checkpoint folder: delete every entry of experiments/.
for stale_name in os.listdir('experiments/'):
    stale_path = os.path.join('experiments', stale_name)
    os.remove(stale_path)



# Backbone to fine-tune. Supported values:
# 'resnet152', 'vgg19_bn', 'efficientnet_b7', 'resnext', 'vit'.
model_name = 'vit'

# Output size of every replacement classification head.
num_classes = 20


def _train_only(model, trainable_part):
    """Freeze all parameters of ``model`` except those of ``trainable_part``."""
    for param in model.parameters():
        param.requires_grad = False
    for param in trainable_part.parameters():
        param.requires_grad = True


# In every branch the new nn.Linear head is created after the freeze, so its
# parameters keep the default requires_grad=True and are trained.
if model_name == 'resnet152':
    model = torchvision.models.resnet152(pretrained=True)
    _train_only(model, model.layer4)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
elif model_name == 'vgg19_bn':
    model = torchvision.models.vgg19_bn(pretrained=True)
    # The original loop iterated with `l` but set `param.requires_grad`, so the
    # classifier layers were never unfrozen; they are now.
    _train_only(model, model.classifier)
    model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)
elif model_name == 'efficientnet_b7':
    model = torchvision.models.efficientnet_b7(pretrained=True)
    _train_only(model, model.classifier)
    model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)
elif model_name == 'resnext':
    model = torchvision.models.resnext101_32x8d(pretrained=True)
    _train_only(model, model.layer4)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
elif model_name == 'vit':
    model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=num_classes)
    _train_only(model, model.blocks[-1])
    model.head = nn.Linear(model.head.in_features, num_classes)
else:
    # Fail here rather than with a NameError on `model` in a later cell.
    raise ValueError('Unknown model_name: {!r}'.format(model_name))

In [None]:
model = model.to(device)

# Classification loss; with default arguments it averages over the batch.
criterion = nn.CrossEntropyLoss()

# SGD over all parameters of the model.
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.01,
    momentum=0.5,
    weight_decay=0.1,
)

# Lowers the learning rate when the metric passed to scheduler.step() stops
# improving for more than `patience` epochs, never going below `min_lr`.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=5,
    min_lr=1e-8,
    verbose=True,
)

In [None]:
def train(model, epoch):
    """Run one optimisation pass over ``train_loader``.

    ``epoch`` is accepted from the caller but is not used inside the function.
    The current batch loss is printed every 25 batches.

    Returns a ``(mean_loss, accuracy)`` pair, both divided by ``train_size``.
    """
    model.train()
    loss_sum = 0
    n_correct = 0

    for batch_idx, (data, labels) in enumerate(train_loader):
        data = data.to(device)
        labels = labels.to(device)

        # Forward pass, backward pass, parameter update.
        optimizer.zero_grad()
        preds = model(data)
        loss = criterion(preds, labels)
        batch_loss = loss.data.cpu().item()
        loss.backward()
        optimizer.step()

        # The criterion is a batch mean: weight it by the batch length so the
        # final division by train_size yields a per-sample average.
        loss_sum += batch_loss * len(data)

        _, preds_classes = F.softmax(preds, dim=1).max(1)
        n_correct += (preds_classes == labels).sum().item()

        if batch_idx % 25 == 0:
            print('[{:4d}/{:4d} ({:2.0f}%)]\tLoss: {:.4f}'.format(
                batch_idx * batch_size, train_size,
                100. * batch_idx * batch_size / train_size, batch_loss))

    return loss_sum / train_size, n_correct / train_size


def validation(model):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Returns a ``(mean_loss, accuracy)`` pair, both divided by ``val_size``.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0

    with torch.no_grad():
        for data, labels in val_loader:
            data = data.to(device)
            labels = labels.to(device)
            preds = model(data)

            # Batch-mean loss weighted by the batch length (summed over samples).
            batch_loss = criterion(preds, labels).item()
            loss_sum += batch_loss * len(data)

            _, preds_classes = F.softmax(preds, dim=1).max(1)
            n_correct += (preds_classes == labels).sum().item()

    return loss_sum / val_size, n_correct / val_size



In [None]:
epochs = 25

# Per-epoch history of the metrics returned by train() and validation().
training_losses = []
training_accs = []
validation_losses = []
validation_accs = []

# Highest validation accuracy seen so far.
best_acc = 0.

for epoch in range(1, epochs + 1):
    print("\n################################################# EPOCH", epoch)
    training_loss, training_acc = train(model, epoch)
    validation_loss, validation_acc = validation(model)

    training_losses.append(training_loss)
    training_accs.append(training_acc)
    validation_losses.append(validation_loss)
    validation_accs.append(validation_acc)

    # The plateau scheduler is driven by the validation loss.
    scheduler.step(validation_loss)

    print('Training set:\t Average loss: {:.4f}\t Accuracy: {:.0f}/{:.0f} ({:.0f}%)'.format(
        training_loss, training_acc*train_size, train_size, training_acc*100))

    print('Validation set:\t Average loss: {:.4f}\t Accuracy: {:.0f}/{:.0f} ({:.0f}%)'.format(
        validation_loss, validation_acc*val_size, val_size, validation_acc*100))

    # Checkpoint on every new (or equalled) best accuracy, and always on the
    # final epoch so the last weights are kept as well.
    if validation_acc >= best_acc or epoch == epochs:
        print('\n**********Saving model with accuracy {:0.4f} at epoch {:2d}'.format(validation_acc, epoch))
        # max() keeps best_acc meaningful when the forced last-epoch save
        # happens with a lower accuracy than an earlier epoch.
        best_acc = max(best_acc, validation_acc)
        model_file = 'experiments/model_{}.pth'.format(epoch)
        torch.save(model.state_dict(), model_file)

# Test

In [None]:
# Predict one class index per test image, in the order of test_loader.
model.eval()
batch_preds = []
with torch.no_grad():
    # Iterating the loader directly lets tqdm display the total batch count;
    # the labels of the test folder are not needed for prediction.
    for data, _ in tqdm(test_loader):
        logits = model(data.to(device))
        # The arg-max over the class axis gives the predicted class; it yields
        # one index per sample whatever the batch size.
        batch_preds.append(logits.argmax(dim=1).cpu().numpy())

# Single concatenation at the end instead of growing an array batch by batch.
# `preds` is a 1-D integer array (empty when the loader yields nothing).
if batch_preds:
    preds = np.concatenate(batch_preds)
else:
    preds = np.array([], dtype=np.int64)

In [None]:
# Write one "Id,Category" row per test sample. The id is the file name without
# its directory, cut at the first dot. The context manager closes the file even
# if a write fails.
with open("submission.csv", "w") as f:
    f.write("Id,Category\n")
    for (n, _), p in zip(test_loader.dataset.samples, preds):
        image_id = n.split('/')[-1].split('.')[0]
        f.write("{},{}\n".format(image_id, int(p)))