In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
import time
import os
import copy
import shutil# **Testing**

In [2]:
len(os.listdir('/kaggle/input/pokemon-generation-one/dataset'))

150

In [3]:
basePath = '/kaggle/'

modelSave = basePath+'working/weights.pth'
data_dir = basePath+'input/pokemon-generation-one/dataset/'
working_dir = basePath+'working/'
model_file = '/kaggle/input/alexnet/alexnet.pth'

num_classes = 150
batch_size = 8
num_epochs = 50
input_size = 224

In [4]:
# Detect if we have a GPU available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
'''import zipfile
import os
for file_name in os.listdir('/content/'):
  if file_name.endswith('.zip'):
    with zipfile.ZipFile(file_name,'r') as zip_dir:
      zip_dir.extractall(path='/content/')'''

"import zipfile\nimport os\nfor file_name in os.listdir('/content/'):\n  if file_name.endswith('.zip'):\n    with zipfile.ZipFile(file_name,'r') as zip_dir:\n      zip_dir.extractall(path='/content/')"

In [6]:
# torch.utils.model_zoo.load_url('https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', model_dir='/kaggle/working')

In [7]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25):
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs,labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

In [8]:
def set_parameter_requires_grad(model):
    for param in model.parameters():
        param.requires_grad = False

In [9]:
def initialize_model(num_classes):
    alexnet = models.alexnet(pretrained=True)
    set_parameter_requires_grad(alexnet)
    num_ftrs = alexnet.classifier[6].in_features
    alexnet.classifier[6] = nn.Linear(num_ftrs,num_classes)
    return alexnet

In [10]:
def getTrainDataLoaders():
    # Data augmentation and normalization for training
    # Just normalization for validation
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    }

    print("Initializing Datasets and Dataloaders...")

    # Create training and validation datasets
    image_datasets = {x: datasets.ImageFolder(os.path.join(working_dir, x), data_transforms[x]) for x in ['train', 'val']}
    # Create training and validation dataloaders
    dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in ['train', 'val']}

    return dataloaders_dict

In [11]:
def getUpdatablePara(model):
    params_to_update = model.parameters()
    print("Params to learn:")
    params_to_update = []
    for name,param in model.named_parameters():
        if  param.requires_grad == True:
            params_to_update.append(param)
            print("\t",name)
    return params_to_update

In [12]:
def split_train_test_data(test_size = 0.75):
    for pokemon_class in os.listdir(data_dir):
        if pokemon_class == 'dataset':
            continue
        print(pokemon_class)
        class_path = os.path.join(data_dir, pokemon_class)
        img_name_list = os.listdir(class_path)
        train_list = img_name_list[:int(test_size * len(img_name_list))]
        val_list = img_name_list[int(test_size * len(img_name_list)):]
        
        destination_folder = os.path.join(working_dir, 'train', pokemon_class)
        os.makedirs(destination_folder)            
        for img in train_list:
            source = os.path.join(class_path, img)
            destination = os.path.join(destination_folder, img)
            dest = shutil.copyfile(source, destination) 
        
        destination_folder = os.path.join(working_dir, 'val', pokemon_class)
        os.makedirs(destination_folder)
        for img in val_list:
            source = os.path.join(class_path, img)            
            destination = os.path.join(destination_folder, img)
            dest = shutil.copyfile(source, destination) 

In [13]:
def copy_model_to_cache():
    cache_dir = os.path.expanduser(os.path.join('~', '.torch'))
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
    models_dir = os.path.join(cache_dir, 'models')
    if not os.path.exists(models_dir):
        os.makedirs(models_dir)
    
    dest = shutil.copyfile(model_file, os.path.join(models_dir, os.path.basename(model_file)))

# **Prepare Model**

In [14]:
# copy_model_to_cache()

# **Prepare Train and Test Data**

In [15]:
split_train_test_data(0.8)

Metapod
Weepinbell
Pidgeot
Zubat
Voltorb
Dratini
Sandslash
Wartortle
Charmander
Flareon
Exeggcute
Nidorino
Spearow
Drowzee
Electabuzz
Poliwrath
Dewgong
Mankey
Jolteon
Weedle
Dragonair
Paras
Hypno
Shellder
Starmie
MrMime
Electrode
Weezing
Arcanine
Clefable
Omanyte
Ekans
Venonat
Kingler
Butterfree
Machop
Venusaur
Bellsprout
Cloyster
Scyther
Raichu
Meowth
Wigglytuff
Kabuto
Ponyta
Lickitung
Machoke
Magnemite
Onix
Grimer
Tentacruel
Parasect
Machamp
Squirtle
Haunter
Tangela
Arbok
Dragonite
Kangaskhan
Gengar
Poliwag
Vaporeon
Hitmonlee
Graveler
Primeape
Rhyhorn
Charmeleon
Charizard
Magneton
Ninetales
Mew
Gloom
Hitmonchan
Slowbro
Rhydon
Pidgey
Cubone
Marowak
Seel
Victreebel
Golem
Horsea
Lapras
Jigglypuff
Pinsir
Goldeen
Diglett
Beedrill
Growlithe
Chansey
Tauros
Kakuna
Exeggutor
Tentacool
Nidoking
Ditto
Abra
Slowpoke
Clefairy
Poliwhirl
Zapdos
Venomoth
Aerodactyl
Fearow
Nidorina
Muk
Caterpie
Blastoise
Eevee
Geodude
Alakazam
Doduo
Snorlax
Gastly
Nidoqueen
Raticate
Articuno
Dugtrio
Vulpix
Oddish
Gol

In [16]:
# count = 0
# for pokemon_class in os.listdir(data_dir):
#     if pokemon_class == 'dataset':
#         continue
#     class_path = os.path.join(data_dir, pokemon_class)
#     class_imgs = len(os.listdir(class_path))
#     train_imgs = len(os.listdir(os.path.join(working_dir, 'train', pokemon_class)))
#     val_imgs = len(os.listdir(os.path.join(working_dir, 'val', pokemon_class)))
#     if(class_imgs != (train_imgs + val_imgs)):
#         print(pokemon_class, class_imgs, train_imgs, val_imgs)
#         count = count + 1
# print(count)

# **Training**

In [17]:
alexnet = initialize_model(num_classes)
alexnet = alexnet.to(device)
#print(alexnet)

Downloading: "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth" to /root/.cache/torch/checkpoints/alexnet-owt-4df8aa71.pth


HBox(children=(FloatProgress(value=0.0, max=244418560.0), HTML(value='')))




In [18]:
dataloaders = getTrainDataLoaders()

Initializing Datasets and Dataloaders...


In [19]:
params_to_update = getUpdatablePara(alexnet)

Params to learn:
	 classifier.6.weight
	 classifier.6.bias


In [20]:
# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(params_to_update, lr=0.0001, momentum=0.9)

#loss function 
criterion = nn.CrossEntropyLoss()

alexnet, hist = train_model(alexnet, dataloaders, criterion, optimizer_ft, num_epochs=num_epochs)

torch.save(alexnet.state_dict(), modelSave)

Epoch 0/49
----------


  'to RGBA images')


train Loss: 4.1139 Acc: 0.1715


  'to RGBA images')
  'to RGBA images')


val Loss: 3.2206 Acc: 0.3088

Epoch 1/49
----------


  'to RGBA images')


train Loss: 3.1035 Acc: 0.3235


  'to RGBA images')


val Loss: 2.6846 Acc: 0.4052

Epoch 2/49
----------


  'to RGBA images')


train Loss: 2.6873 Acc: 0.4048


  'to RGBA images')
  'to RGBA images')


val Loss: 2.4128 Acc: 0.4582

Epoch 3/49
----------


  'to RGBA images')


train Loss: 2.4793 Acc: 0.4413


  'to RGBA images')
  'to RGBA images')


val Loss: 2.2697 Acc: 0.4742

Epoch 4/49
----------


  'to RGBA images')


train Loss: 2.2989 Acc: 0.4759


  'to RGBA images')
  'to RGBA images')


val Loss: 2.1557 Acc: 0.4966

Epoch 5/49
----------


  'to RGBA images')


train Loss: 2.2056 Acc: 0.4863


  'to RGBA images')


val Loss: 2.0781 Acc: 0.5121

Epoch 6/49
----------


  'to RGBA images')


train Loss: 2.1034 Acc: 0.5104


  'to RGBA images')


val Loss: 2.0270 Acc: 0.5304

Epoch 7/49
----------


  'to RGBA images')


train Loss: 2.0346 Acc: 0.5187


  'to RGBA images')


val Loss: 1.9818 Acc: 0.5381

Epoch 8/49
----------


  'to RGBA images')


train Loss: 1.9679 Acc: 0.5332


  'to RGBA images')


val Loss: 1.9501 Acc: 0.5386

Epoch 9/49
----------


  'to RGBA images')


train Loss: 1.9201 Acc: 0.5443


  'to RGBA images')
  'to RGBA images')


val Loss: 1.9281 Acc: 0.5395

Epoch 10/49
----------


  'to RGBA images')


train Loss: 1.8667 Acc: 0.5590


  'to RGBA images')
  'to RGBA images')


val Loss: 1.9214 Acc: 0.5464

Epoch 11/49
----------


  'to RGBA images')


train Loss: 1.8248 Acc: 0.5631


  'to RGBA images')
  'to RGBA images')


val Loss: 1.8875 Acc: 0.5528

Epoch 12/49
----------


  'to RGBA images')


train Loss: 1.7847 Acc: 0.5705


  'to RGBA images')
  'to RGBA images')


val Loss: 1.8640 Acc: 0.5519

Epoch 13/49
----------


  'to RGBA images')


train Loss: 1.7367 Acc: 0.5861


  'to RGBA images')


val Loss: 1.8498 Acc: 0.5624

Epoch 14/49
----------


  'to RGBA images')


train Loss: 1.7133 Acc: 0.5825


  'to RGBA images')
  'to RGBA images')


val Loss: 1.8477 Acc: 0.5646

Epoch 15/49
----------


  'to RGBA images')


train Loss: 1.7128 Acc: 0.5913


  'to RGBA images')
  'to RGBA images')


val Loss: 1.8353 Acc: 0.5665

Epoch 16/49
----------


  'to RGBA images')


train Loss: 1.6529 Acc: 0.5918


  'to RGBA images')


val Loss: 1.8285 Acc: 0.5569

Epoch 17/49
----------


  'to RGBA images')


train Loss: 1.6613 Acc: 0.5991


  'to RGBA images')


val Loss: 1.8188 Acc: 0.5742

Epoch 18/49
----------


  'to RGBA images')


train Loss: 1.6122 Acc: 0.6111


  'to RGBA images')
  'to RGBA images')


val Loss: 1.8237 Acc: 0.5715

Epoch 19/49
----------


  'to RGBA images')


train Loss: 1.5857 Acc: 0.6122


  'to RGBA images')


val Loss: 1.8179 Acc: 0.5605

Epoch 20/49
----------


  'to RGBA images')


train Loss: 1.5775 Acc: 0.6163


  'to RGBA images')
  'to RGBA images')


val Loss: 1.8055 Acc: 0.5815

Epoch 21/49
----------


  'to RGBA images')


train Loss: 1.5229 Acc: 0.6301


  'to RGBA images')
  'to RGBA images')


val Loss: 1.8057 Acc: 0.5761

Epoch 22/49
----------


  'to RGBA images')


train Loss: 1.5345 Acc: 0.6266


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7961 Acc: 0.5697

Epoch 23/49
----------


  'to RGBA images')


train Loss: 1.5015 Acc: 0.6358


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7967 Acc: 0.5788

Epoch 24/49
----------


  'to RGBA images')


train Loss: 1.4804 Acc: 0.6318


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7960 Acc: 0.5724

Epoch 25/49
----------


  'to RGBA images')


train Loss: 1.4937 Acc: 0.6333


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7961 Acc: 0.5774

Epoch 26/49
----------


  'to RGBA images')


train Loss: 1.4398 Acc: 0.6443


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7870 Acc: 0.5779

Epoch 27/49
----------


  'to RGBA images')


train Loss: 1.4548 Acc: 0.6409


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7885 Acc: 0.5738

Epoch 28/49
----------


  'to RGBA images')


train Loss: 1.4447 Acc: 0.6490


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7818 Acc: 0.5720

Epoch 29/49
----------


  'to RGBA images')


train Loss: 1.4040 Acc: 0.6542


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7813 Acc: 0.5665

Epoch 30/49
----------


  'to RGBA images')


train Loss: 1.4363 Acc: 0.6424


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7687 Acc: 0.5802

Epoch 31/49
----------


  'to RGBA images')


train Loss: 1.4048 Acc: 0.6548


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7706 Acc: 0.5806

Epoch 32/49
----------


  'to RGBA images')


train Loss: 1.3768 Acc: 0.6659


  'to RGBA images')


val Loss: 1.7641 Acc: 0.5811

Epoch 33/49
----------


  'to RGBA images')


train Loss: 1.3796 Acc: 0.6645


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7547 Acc: 0.5793

Epoch 34/49
----------


  'to RGBA images')


train Loss: 1.3427 Acc: 0.6698


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7508 Acc: 0.5829

Epoch 35/49
----------


  'to RGBA images')


train Loss: 1.3576 Acc: 0.6684


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7663 Acc: 0.5774

Epoch 36/49
----------


  'to RGBA images')


train Loss: 1.3367 Acc: 0.6716


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7670 Acc: 0.5829

Epoch 37/49
----------


  'to RGBA images')


train Loss: 1.3227 Acc: 0.6711


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7647 Acc: 0.5811

Epoch 38/49
----------


  'to RGBA images')


train Loss: 1.3278 Acc: 0.6748


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7710 Acc: 0.5847

Epoch 39/49
----------


  'to RGBA images')


train Loss: 1.3330 Acc: 0.6691


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7732 Acc: 0.5875

Epoch 40/49
----------


  'to RGBA images')


train Loss: 1.3036 Acc: 0.6794


  'to RGBA images')


val Loss: 1.7612 Acc: 0.5783

Epoch 41/49
----------


  'to RGBA images')


train Loss: 1.2609 Acc: 0.6868


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7614 Acc: 0.5806

Epoch 42/49
----------


  'to RGBA images')


train Loss: 1.3089 Acc: 0.6808


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7676 Acc: 0.5742

Epoch 43/49
----------


  'to RGBA images')


train Loss: 1.2800 Acc: 0.6815


  'to RGBA images')


val Loss: 1.7689 Acc: 0.5802

Epoch 44/49
----------


  'to RGBA images')


train Loss: 1.2649 Acc: 0.6863


  'to RGBA images')


val Loss: 1.7702 Acc: 0.5802

Epoch 45/49
----------


  'to RGBA images')


train Loss: 1.2440 Acc: 0.6897


  'to RGBA images')


val Loss: 1.7720 Acc: 0.5857

Epoch 46/49
----------


  'to RGBA images')


train Loss: 1.2890 Acc: 0.6815


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7605 Acc: 0.5870

Epoch 47/49
----------


  'to RGBA images')


train Loss: 1.2403 Acc: 0.6915


  'to RGBA images')


val Loss: 1.7588 Acc: 0.5934

Epoch 48/49
----------


  'to RGBA images')


train Loss: 1.2369 Acc: 0.6938


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7568 Acc: 0.5866

Epoch 49/49
----------


  'to RGBA images')


train Loss: 1.2130 Acc: 0.6950


  'to RGBA images')
  'to RGBA images')


val Loss: 1.7660 Acc: 0.5829

Training complete in 99m 23s
Best val Acc: 0.593422


# **Testing**

In [21]:
'''alexnet = initialize_model(num_classes)

alexnet.load_state_dict(torch.load(modelSave))
alexnet.eval()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
alexnet = alexnet.to(device)
'''

'alexnet = initialize_model(num_classes)\n\nalexnet.load_state_dict(torch.load(modelSave))\nalexnet.eval()\ndevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")\nalexnet = alexnet.to(device)\n'