In [None]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import ast
import json
from PIL import Image,ImageDraw,ImageDraw2
import  io
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torchvision
from torchvision import transforms, utils
import torchvision.transforms as T
import os
import cv2
import glob
import time
import tqdm
import warnings
warnings.filterwarnings("ignore")


%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [None]:
# Label lookup tables; both are filled in once the CSV listing is known.
en_dict = dict()   # class name -> integer label
dec_dict = dict()  # integer label -> class name

In [None]:
def encode_files(csv):
    """Assign consecutive integer labels (0, 1, 2, ...) to CSV file names.

    Each name is reduced to its base name (text after the last '/'), with
    the last 4 characters (the extension) dropped and spaces replaced by
    underscores. The result is stored as a key of the global ``en_dict``.

    Args:
        csv: iterable of CSV file names, labelled in iteration order.
    """
    for index, file_name in enumerate(csv):
        class_key = file_name[:-4].split('/')[-1].replace(' ', '_')
        en_dict[class_key] = index
        

def decode_labels(label):
    """Map an integer label back to its class name using the global ``dec_dict``.

    Raises:
        KeyError: if ``label`` has no entry in ``dec_dict``.
    """
    class_name = dec_dict[label]
    return class_name

def get_label(nfile):
    """Return the integer label of a CSV file name via the global ``en_dict``.

    The lookup key is the name without its last 4 characters (the extension)
    and with spaces replaced by underscores.

    Raises:
        KeyError: if the derived key was never encoded.
    """
    class_key = nfile[:-4].replace(' ', '_')
    return en_dict[class_key]

def get_csv(path):
    """Return the names of the ``.csv`` files directly inside ``path``.

    Only bare file names are returned (no directory prefix), in
    ``os.listdir`` order, which is arbitrary; callers sort when they need a
    stable order.

    Args:
        path: directory to list.

    Returns:
        list of file names ending in ``.csv``.
    """
    # Match the full ".csv" suffix: callers strip exactly 4 characters as the
    # extension, so names merely ending in "csv" (e.g. "backupcsv") must not
    # be picked up.
    return [name for name in os.listdir(path) if name.endswith('.csv')]

def strokes_to_arr_1Channel(arr):
    """Render a stroke-list string as a 224x224 grayscale image.

    Args:
        arr: string holding a Python literal list of strokes. For each
            stroke, ``stroke[0]`` is the list of x coordinates and
            ``stroke[1]`` the list of y coordinates.

    Returns:
        np.ndarray of shape (224, 224) with float dtype. It holds the
        luminance of the rendered PNG, rounded up. With default matplotlib
        styling the background is ~255 and the black strokes are near 0.
    """
    # literal_eval only evaluates Python literals, so CSV text cannot run code.
    strokes = ast.literal_eval(arr)

    # Draw on a private figure rather than the global pyplot "current figure",
    # and always close it, so calls never draw onto (or leave behind) a figure
    # the notebook is using.
    fig, ax = plt.subplots()
    try:
        # One plot call per stroke: concatenating every stroke into a single
        # polyline would add spurious segments joining consecutive strokes.
        for stroke in strokes:
            ax.plot(stroke[0], stroke[1], color='black')
        ax.axis('off')
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)

    buf.seek(0)
    with Image.open(buf) as png:
        pixels = np.array(png)
    buf.close()

    # pixels is (H, W, channels); weight the first three (RGB) channels into
    # a single luminance plane.
    red, green, blue = pixels[..., 0], pixels[..., 1], pixels[..., 2]
    gray = np.ceil(0.2989 * red + 0.5870 * green + 0.1140 * blue)
    return cv2.resize(gray, (224, 224), interpolation=cv2.INTER_LINEAR)


def strokes_to_arr_3Channels(arr):
    """Render a stroke-list string as a channel-first RGB image array.

    Args:
        arr: string holding a Python literal list of strokes. For each
            stroke, ``stroke[0]`` is the list of x coordinates and
            ``stroke[1]`` the list of y coordinates.

    Returns:
        np.ndarray of shape (3, H, W), with the dtype decoded from the PNG.
        H and W are the saved figure size in pixels; no resizing is done here.
    """
    # literal_eval only evaluates Python literals, so CSV text cannot run code.
    strokes = ast.literal_eval(arr)

    # Private figure, always closed: avoids sharing global pyplot state.
    fig, ax = plt.subplots()
    try:
        # One plot call per stroke so strokes are not joined by extra segments.
        for stroke in strokes:
            ax.plot(stroke[0], stroke[1], color='black')
        ax.axis('off')
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)

    buf.seek(0)
    with Image.open(buf) as png:
        pixels = np.array(png)
    buf.close()

    # (H, W, C) -> (C, H, W), keeping only the first three (RGB) channels;
    # copy() yields a standalone contiguous array.
    return np.transpose(pixels, (2, 0, 1))[:3].copy()

def validation(lossf, scoref, loader=None, net=None, dev=None):
    """Average loss and score of a model over a validation loader.

    Args:
        lossf: loss callable ``lossf(output, target)`` returning a scalar tensor.
        scoref: metric callable ``scoref(output, target)`` whose first element
            is a one-element tensor (e.g. ``accuracy`` or ``mapk``).
        loader: iterable of ``(x, y)`` batches that supports ``len``. Defaults
            to the global ``valid_loader``, which is not defined in this
            notebook's visible cells, so pass it explicitly.
        net: model to evaluate; defaults to the global ``model``.
        dev: device batches are moved to; defaults to the global ``device``.

    Returns:
        ``(mean_loss, mean_score)``: per-batch values averaged over the number
        of batches. Each batch is weighted equally, regardless of its size.

    Raises:
        ValueError: if the loader yields no batches.
    """
    # Resolve notebook globals lazily so explicit arguments work without them.
    if loader is None:
        loader = valid_loader
    if net is None:
        net = model
    if dev is None:
        dev = device

    n_batches = len(loader)
    if n_batches == 0:
        raise ValueError("validation loader is empty")

    was_training = net.training
    net.eval()
    loss, score = 0.0, 0.0
    try:
        # Evaluation needs no autograd graph; building one only wastes memory.
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(dev), y.to(dev)
                output = net(x)
                loss += lossf(output, y).item()
                score += scoref(output, y)[0].item()
    finally:
        # Restore whatever mode the caller had, even if evaluation fails.
        net.train(was_training)
    return loss / n_batches, score / n_batches

def accuracy(output, target, topk=(3,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: (N, C) tensor of class scores.
        target: (N,) tensor of true class indices.
        topk: tuple of k values to evaluate; ``max(topk)`` must not exceed C.

    Returns:
        list with one one-element tensor per k. Each tensor holds the
        percentage of rows whose target is among the k highest scores.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, N): row i holds every sample's i-th best guess
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape, not view: `correct` may inherit the transposed
            # (non-contiguous) layout of `pred`, and view() raises on such slices.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def mapk(output, target, k=3):
    """
    Computes the mean average precision at k for single-label targets.

    With exactly one relevant class per sample, AP@k is 1/rank when the true
    class appears at position ``rank`` (1-based) among the top-k predictions,
    and 0 otherwise. MAP@k is the mean of that over the batch.

    Parameters
    ----------
    output (torch.Tensor): A Tensor of predicted scores.
                           Shape: (N,C)  where C = number of classes, N = batch size
    target (torch.Tensor): A Tensor of true class indices.
                           Shape: (N) where each value is  0<=targets[i]<=C-1
    k (int, optional): The maximum number of predicted elements (k <= C)

    Returns
    -------
    score (torch.Tensor): One-element float tensor, shape (1,), with the mean
                          average precision at k over the batch (0.0 for an
                          empty batch).
    """
    with torch.no_grad():
        batch_size = target.size(0)
        if batch_size == 0:
            return torch.zeros(1, device=output.device)

        _, pred = output.topk(k, 1, True, True)   # (N, k), best guess first
        # topk indices are distinct, so each row has at most one hit.
        hits = pred.eq(target.view(-1, 1))        # (N, k) bool
        # Weight a hit at 1-based position r by 1/r. Work in float: writing
        # weighted values back into the bool `hits` tensor would lose them.
        rank_weights = 1.0 / torch.arange(
            1, k + 1, device=output.device, dtype=torch.float32
        )
        score = (hits.float() * rank_weights).sum().reshape(1)
        return score / batch_size
    
def squeeze_weights(m):
    """Collapse a Conv2d's input-channel axis so it accepts 1-channel input.

    Intended for ``Module.apply``. For an ungrouped ``nn.Conv2d``, the kernel
    is summed over its input channels and ``in_channels`` is set to 1. All
    other modules are left untouched.
    """
    # Module.apply visits every submodule, so skip anything that is not a
    # plain (groups == 1) Conv2d instead of crashing on or corrupting it.
    if not isinstance(m, nn.Conv2d) or m.groups != 1:
        return
    # Summing the kernel over input channels makes a 1-channel image produce
    # the same activations as that image replicated to every input channel.
    m.weight.data = m.weight.data.sum(dim=1, keepdim=True)
    m.in_channels = 1

In [None]:
path = '/kaggle/input/quickdraw-doodle-recognition/train_raw/'
# Sort so labels are assigned deterministically (os.listdir order is arbitrary).
csv = sorted(get_csv(path))
encode_files(csv)
# Reverse mapping: integer label -> class name.
dec_dict = {label: name for name, label in en_dict.items()}

In [None]:
class DoodleDataset(Dataset):
    """Rendered drawings from a single per-class CSV file.

    Every row read from the CSV gets the same label, derived from the file
    name via ``get_label``.

    Args:
        csv: CSV file name inside ``directory``; also used to derive the label.
        directory: directory prefix, concatenated directly with ``csv``, so it
            must end with a path separator.
        function: callable turning one raw ``drawing`` string into an image
            array, either (H, W) (e.g. ``strokes_to_arr_1Channel``) or
            (C, H, W) (e.g. ``strokes_to_arr_3Channels``).
        mode: ``'train'`` makes items ``(image, label)``; any other value
            makes items just ``image``.
        nrows: maximum number of rows to read from the CSV.
        transform: optional callable applied to the float image tensor.
    """

    def __init__(self, csv, directory, function, mode, nrows, transform=None):
        self.csv = csv
        self.directory = directory
        self.function = function
        self.transform = transform
        self.mode = mode
        self.image = pd.read_csv(directory + csv, usecols=['drawing'], nrows=nrows)
        self.length = len(self.image)
        self.image = self.image.reset_index()
        self.label = get_label(csv)
        # Built once here rather than on every __getitem__ call.
        self._resize = T.Resize((224, 224))

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        image = torch.tensor(self.function(self.image['drawing'][idx]))
        # 2-D (H, W) output needs a channel axis. (C, H, W) output already has
        # one, and adding another would make batched tensors 5-D.
        if image.dim() == 2:
            image = image.unsqueeze(0)
        image = self._resize(image).float()  # (C, 224, 224)

        if self.transform:
            image = self.transform(image)

        if self.mode == 'train':
            return image, self.label
        return image

In [None]:
# NOTE(review): torch.hub.load fetches this GitHub repo and runs its hubconf;
# pin a tag/commit if reproducibility or trust matters.
model = torch.hub.load(
    'NVIDIA/DeepLearningExamples:torchhub',
    'nvidia_se_resnext101_32x4d',
)

In [None]:
# Freeze the parameters of every Conv2d layer; parameters of all other module
# types stay trainable. The type check must run on modules: model.parameters()
# yields Parameter tensors, which are never nn.Conv2d instances.
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        for param in module.parameters():
            param.requires_grad = False

In [None]:
# One single-channel DoodleDataset per class file: the first 100 CSVs of the
# sorted listing, at most 10000 drawings each.
dataset = ConcatDataset([
    DoodleDataset(csv[class_idx], path, strokes_to_arr_1Channel, 'train', 10000, None)
    for class_idx in range(100)
])
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)

In [None]:
# Adapt the network to 1-channel input and a 340-way classification head.
model.conv1.apply(squeeze_weights)

num_classes = 340
model.fc = nn.Linear(2048, num_classes)  # bias=True is the default
criterion = nn.CrossEntropyLoss()
# Built after replacing fc so the optimizer also receives the new head's parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, amsgrad=True)
# Halve the learning rate when scheduler.step() has been called 5000/12000/18000 times.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[5000, 12000, 18000], gamma=0.5
)

In [None]:
# Prefer the GPU whenever one is visible to PyTorch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
model.to(device)
torch.cuda.empty_cache()

In [None]:
# Training loop. Every `p_itr` iterations it reports the mean loss and the
# accuracy over that interval, and saves a checkpoint whenever the interval
# accuracy beats the best seen so far.
epochs = 10
lsize = len(loader)  # batches per epoch (informational)
itr = 1
p_itr = 100  # print every N iterations
model.train()
tloss, correct, seen = 0.0, 0, 0
best_acc = 0.0
for epoch in range(epochs):
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        scheduler.step()  # milestones count iterations, so step once per batch

        preds = output.argmax(dim=1)
        tloss += loss.item()
        correct += (preds == y).sum().item()
        # Count the actual batch size: the last batch of an epoch can be
        # smaller than the loader's batch_size.
        seen += y.size(0)

        if itr % p_itr == 0:
            acc = correct / seen
            print('Iteration {} -> Train Loss: {:.4f}, acc: {:.3f}'.format(itr, tloss / p_itr, acc))
            # Checkpoint only on improvement over the best interval so far.
            if acc > best_acc:
                best_acc = acc
                torch.save(model.state_dict(), "resnext_101_{:.4f}.pth".format(best_acc))
            tloss, correct, seen = 0.0, 0, 0

        itr += 1
        
        

In [None]:
# Persist the final weights (the state_dict only, not the module object).
filename_pth = 'resnext101.pth'
torch.save(obj=model.state_dict(), f=filename_pth)