In [1]:
!pip3 install -U -q gdown
!gdown 1TF1muxFC0sEYMHNW89tr_skk3oUiK4y_

[0mDownloading...
From: https://drive.google.com/uc?id=1TF1muxFC0sEYMHNW89tr_skk3oUiK4y_
To: /kaggle/working/mobile_net_256_large.pth
100%|██████████████████████████████████████| 10.9M/10.9M [00:00<00:00, 19.4MB/s]


## Setup

In [2]:
import json
import os
import random
from typing import Dict
from warnings import filterwarnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torchvision
from PIL import Image, ImageDraw
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, ConcatDataset, Dataset
from torchvision import transforms
from tqdm.auto import tqdm

filterwarnings("ignore", category=DeprecationWarning)
filterwarnings("ignore", category=FutureWarning)
filterwarnings("ignore", category=UserWarning)
sns.set_theme()

In [3]:
def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True


seed_everything(20212022)

In [4]:
test_files = ['../input/shuffle-csvs/train_k99.csv.gz']
print("test_files : ", test_files)

def encode_labels():
    sample_file = "../input/shuffle-csvs/train_k0.csv.gz"
    sample_df = pd.read_csv(sample_file, usecols=['word', 'y']).drop_duplicates('word', keep='first')
    label2idx = {x: y for x, y in zip(sample_df['word'], sample_df['y'])}
    idx2label = {v: k for k, v in label2idx.items()}
    with open('classes.json', 'w') as f:
        dump_data = {"en_dict": label2idx, "dec_dict": idx2label}
        json.dump(dump_data, f, ensure_ascii=False)

    return label2idx, idx2label


en_dict, dec_dict = encode_labels()

test_files :  ['../input/shuffle-csvs/train_k99.csv.gz']


## Custom Dataset

In [5]:
class DoodleDataset(Dataset):
    def __init__(self,
                 filepath: str,
                 nrows=None,
                 skiprows=None,
                 size: int = 256,
                 transforms=None):

        self.size = size
        self.transforms = transforms
        self.data = pd.read_csv(filepath, usecols=['drawing', 'word', 'y'], nrows=nrows, skiprows=skiprows)

    @staticmethod
    def _draw(strokes, size, lw=6):
        BASE_SIZE = 256
        pil_img = Image.new('P', (BASE_SIZE, BASE_SIZE), color=255)
        img_draw = ImageDraw.Draw(pil_img)

        for stroke in strokes:
            for i in range(len(stroke[0]) - 1):
                img_draw.line((stroke[0][i], stroke[1][i], stroke[0][i + 1], stroke[1][i + 1]),
                              fill=0,
                              width=lw)
        return pil_img

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        stroke = json.loads(self.data['drawing'][idx])
        img = self._draw(stroke, self.size)

        if self.transforms:
            img = self.transforms(img)

        return img, self.data['y'][idx]

## Model

In [6]:
class DrawClassifier(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()

        self.net = torchvision.models.mobilenet_v2(pretrained=True)
        self.net.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(1280, num_classes)
        )

    def forward(self, x):
        return self.net(x)

## Data preparation

In [7]:
data_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

IN_SIZE = 224

test_ds = ConcatDataset([
    DoodleDataset(fn, size=IN_SIZE, transforms=data_transforms)
    for fn in test_files
])

BATCH_SIZE = 256

test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
print("Test set: ", len(test_ds))

Test set:  102042


In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_path = "./mobile_net_256_large.pth"
n_classes = len(en_dict.keys())
model = DrawClassifier(n_classes)
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth


  0%|          | 0.00/13.6M [00:00<?, ?B/s]

In [9]:
from sklearn.metrics import classification_report

def test_model(model, test_loader):
    model.eval()

    test_correct = 0
    total = 0
    stream = tqdm(test_loader)
    rs = {
        "y_true":[],
        "y_pred":[]
    }

    with torch.no_grad():
        for inputs, labels in stream:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            test_correct += torch.sum(preds == labels.data)

            rs['y_true'].extend(labels.data.cpu().detach().numpy())
            rs['y_pred'].extend(preds.cpu().detach().numpy())
            total += labels.size(0)

        print("Test accuracy: {:.4f}".format(test_correct.double() / total))
        return rs

In [10]:
test_summary = test_model(model, test_loader)
y_true = test_summary['y_true']
y_pred = test_summary['y_pred']
sorted_en_dict = {k: en_dict[k] for k in sorted(en_dict, key=en_dict.get)}
target_names = list(sorted_en_dict.keys())

clf_report = classification_report(y_true, y_pred, output_dict=True)

  0%|          | 0/399 [00:00<?, ?it/s]

Test accuracy: 0.7739


In [11]:
df = pd.DataFrame.from_dict(clf_report)
final_report = df.filter(['macro avg', 'accuracy', 'weighted avg'])
final_report

Unnamed: 0,macro avg,accuracy,weighted avg
precision,0.779544,0.773868,0.779706
recall,0.774303,0.773868,0.773868
f1-score,0.77471,0.773868,0.774565
support,102042.0,0.773868,102042.0
