In [1]:
import pandas as pd
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.trainer as trainer
import torch.utils.trainer.plugins
from torch.autograd import Variable
import numpy as np
import os

from imagedataset import ImageDataset

In [2]:
%matplotlib inline
def show(img):
    """Display an image tensor with matplotlib.

    The tensor is converted to a NumPy array and its axes are reordered
    (1, 2, 0) -- i.e. from an assumed channels-first (C, H, W) layout to the
    (H, W, C) layout that ``plt.imshow`` expects. Nearest-neighbour
    interpolation is used so individual pixels stay crisp.
    """
    hwc_array = img.numpy().transpose(1, 2, 0)
    plt.imshow(hwc_array, interpolation='nearest')

In [3]:
# Sub-folder of data_path holding the training images.
# Switch to "sample" for a quicker run (presumably a smaller subset -- confirm).
train_dir = "train"
# Use the GPU only when one is actually present, so the notebook also runs on
# CPU-only machines instead of failing on the first .cuda() call.
use_cuda = torch.cuda.is_available()
batch_size = 64
print('Using CUDA:', use_cuda)

Using CUDA: True


In [4]:
data_path = "data/galaxies/"

def dataframe_from_csv():
    """Load the training targets table, indexed by galaxy id.

    Reads ``classes.csv`` from ``data_path`` and promotes its ``GalaxyID``
    column to the index; every remaining column is one target class.
    """
    csv_path = data_path + "classes.csv"
    return pd.read_csv(csv_path, sep=',', index_col="GalaxyID")

In [13]:
import importlib
import imagedataset
importlib.reload(imagedataset)
# Re-bind the class after reloading: `from imagedataset import ImageDataset` in
# the imports cell captured the class object of the *old* module, so without
# this line the reload above would have no effect on get_data_loader.
ImageDataset = imagedataset.ImageDataset

# Data loading code
traindir = os.path.join(data_path, train_dir)
valdir = os.path.join(data_path, 'valid')
testdir = os.path.join(data_path, 'test')

targets = dataframe_from_csv()
num_classes = len(targets.columns)


def get_data_loader(dirname, shuffle=True, batch_size=64, test_mode=False):
    """Build a DataLoader over an ``ImageDataset`` rooted at ``dirname``.

    PyTorch counterpart of fastai's ``get_batches`` (utils.py).

    Args:
        dirname: Directory containing the images.
        shuffle: Whether the loader reshuffles samples every epoch.
        batch_size: Number of samples per mini-batch.
        test_mode: Forwarded to ``ImageDataset`` (used for the test folder;
            exact semantics are defined in imagedataset.py).

    Returns:
        ``(loader, dataset)`` -- the DataLoader and the dataset it wraps.
    """
    image_dataset = ImageDataset(dirname, targets, test_mode)
    # Pinned host memory only pays off when batches are copied to the GPU.
    loader = torch.utils.data.DataLoader(image_dataset, batch_size=batch_size,
                                         shuffle=shuffle, pin_memory=use_cuda)
    return loader, image_dataset

In [6]:
# Shuffled training loader; the dataset object is kept for inspection.
train_loader, train_dataset = get_data_loader(traindir, batch_size=batch_size)
n_train_images = len(train_dataset)
print('Images in train folder:', n_train_images)

Images in train folder: 51578


In [23]:
# Load a ResNet-50 with torchvision's pretrained weights (pretrained=True makes
# torchvision fetch/cache them); its final layer is replaced in the next cell.
model = models.resnet50(pretrained=True)

In [24]:
# Fine-tune: freeze every pretrained parameter, then replace the last
# fully-connected layer (the new layer's parameters are trainable by default).
for param in model.parameters():
    param.requires_grad = False

# Replace the last fully-connected layer so it outputs one value per target column.
print('Using {:d} classes: {}'.format(num_classes, targets.columns))
# Take the input width from the existing layer instead of hard-coding
# 512 * Bottleneck.expansion.
in_features = model.fc.in_features
# Sigmoid instead of Softmax: each target column is regressed independently in
# [0, 1]. In the Galaxy Zoo challenge, answers sum to (at most) 1 *per question*,
# so the 37 columns together generally sum to more than 1 -- a softmax across
# all outputs (which always sums to 1) cannot fit them.
# NOTE: a checkpoint trained with the previous Softmax head needs retraining.
model.fc = nn.Sequential(nn.Linear(in_features, num_classes), nn.Sigmoid())

Using 37 classes: Index(['Class1.1', 'Class1.2', 'Class1.3', 'Class2.1', 'Class2.2', 'Class3.1',
       'Class3.2', 'Class4.1', 'Class4.2', 'Class5.1', 'Class5.2', 'Class5.3',
       'Class5.4', 'Class6.1', 'Class6.2', 'Class7.1', 'Class7.2', 'Class7.3',
       'Class8.1', 'Class8.2', 'Class8.3', 'Class8.4', 'Class8.5', 'Class8.6',
       'Class8.7', 'Class9.1', 'Class9.2', 'Class9.3', 'Class10.1',
       'Class10.2', 'Class10.3', 'Class11.1', 'Class11.2', 'Class11.3',
       'Class11.4', 'Class11.5', 'Class11.6'],
      dtype='object')


In [7]:
# Or load the model if already trained.
# map_location keeps every tensor on the CPU while unpickling, so a checkpoint
# saved from a GPU session also loads on a CPU-only machine; the following cell
# moves the model to the GPU when use_cuda is set.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
model = torch.load("data/galaxies/resnet-50.pth",
                   map_location=lambda storage, loc: storage)

In [8]:
# Loss: mean squared error between predicted and target class values.
criterion = nn.MSELoss()
# Move both the network and the loss module to the GPU when use_cuda is set.
if use_cuda:
    for module in (model, criterion):
        module.cuda()
# Optimise only the replaced fully-connected head (the backbone parameters are
# expected to have been frozen in the fine-tuning cell).
fc_parameters = model.fc.parameters()
optimizer = optim.SGD(fc_parameters, lr=1e-2, momentum=0.9)

In [9]:
def getTrainer():
    """Create a ``torch.utils.trainer`` Trainer for fine-tuning the global model.

    The trainer drives ``model`` with ``criterion`` and ``optimizer`` over
    ``train_loader``. Progress, loss and time monitors are attached, plus a
    Logger that prints those three fields.

    NOTE: relies on a locally monkey-patched trainer module (presumably
    torch/utils/trainer/trainer.py) whose lines 57-58 move each batch to the GPU:
        input_var = Variable(batch_input.cuda())
        target_var = Variable(batch_target.cuda())
    """
    plugins = trainer.plugins
    new_trainer = trainer.Trainer(model, criterion, optimizer, train_loader)
    for plugin in (plugins.ProgressMonitor(),
                   plugins.LossMonitor(),
                   plugins.TimeMonitor(),
                   plugins.Logger(['progress', 'loss', 'time'])):
        new_trainer.register_plugin(plugin)
    return new_trainer

In [33]:
# Train for a single epoch over train_loader.
epochs = 1
model.train()
t = getTrainer()
t.run(epochs)

progress: 1/806 (0.12%)	loss: 0.0449  (0.0135)	time: 0ms  (0ms)
progress: 2/806 (0.25%)	loss: 0.0532  (0.0254)	time: 545ms  (163ms)
progress: 3/806 (0.37%)	loss: 0.0458  (0.0315)	time: 526ms  (272ms)
progress: 4/806 (0.50%)	loss: 0.0488  (0.0367)	time: 528ms  (349ms)
progress: 5/806 (0.62%)	loss: 0.0490  (0.0404)	time: 549ms  (409ms)
progress: 6/806 (0.74%)	loss: 0.0446  (0.0417)	time: 540ms  (448ms)
progress: 7/806 (0.87%)	loss: 0.0384  (0.0407)	time: 522ms  (470ms)
progress: 8/806 (0.99%)	loss: 0.0462  (0.0423)	time: 522ms  (486ms)
progress: 9/806 (1.12%)	loss: 0.0471  (0.0438)	time: 523ms  (497ms)
progress: 10/806 (1.24%)	loss: 0.0533  (0.0467)	time: 524ms  (505ms)
progress: 11/806 (1.36%)	loss: 0.0499  (0.0476)	time: 530ms  (513ms)
progress: 12/806 (1.49%)	loss: 0.0505  (0.0485)	time: 531ms  (518ms)
progress: 13/806 (1.61%)	loss: 0.0425  (0.0467)	time: 529ms  (521ms)
progress: 14/806 (1.74%)	loss: 0.0490  (0.0474)	time: 525ms  (523ms)
progress: 15/806 (1.86%)	loss: 0.0452  (0.0467)

progress: 120/806 (14.89%)	loss: 0.0458  (0.0455)	time: 527ms  (526ms)
progress: 121/806 (15.01%)	loss: 0.0435  (0.0449)	time: 531ms  (528ms)
progress: 122/806 (15.14%)	loss: 0.0563  (0.0483)	time: 527ms  (527ms)
progress: 123/806 (15.26%)	loss: 0.0510  (0.0491)	time: 525ms  (527ms)
progress: 124/806 (15.38%)	loss: 0.0526  (0.0501)	time: 522ms  (525ms)
progress: 125/806 (15.51%)	loss: 0.0461  (0.0489)	time: 527ms  (526ms)
progress: 126/806 (15.63%)	loss: 0.0545  (0.0506)	time: 553ms  (534ms)
progress: 127/806 (15.76%)	loss: 0.0506  (0.0506)	time: 525ms  (531ms)
progress: 128/806 (15.88%)	loss: 0.0444  (0.0487)	time: 527ms  (530ms)
progress: 129/806 (16.00%)	loss: 0.0432  (0.0471)	time: 528ms  (529ms)
progress: 130/806 (16.13%)	loss: 0.0437  (0.0460)	time: 529ms  (529ms)
progress: 131/806 (16.25%)	loss: 0.0412  (0.0446)	time: 525ms  (528ms)
progress: 132/806 (16.38%)	loss: 0.0472  (0.0454)	time: 526ms  (527ms)
progress: 133/806 (16.50%)	loss: 0.0467  (0.0458)	time: 534ms  (529ms)
progre

progress: 236/806 (29.28%)	loss: 0.0460  (0.0460)	time: 521ms  (525ms)
progress: 237/806 (29.40%)	loss: 0.0487  (0.0468)	time: 521ms  (524ms)
progress: 238/806 (29.53%)	loss: 0.0498  (0.0477)	time: 527ms  (525ms)
progress: 239/806 (29.65%)	loss: 0.0454  (0.0470)	time: 526ms  (525ms)
progress: 240/806 (29.78%)	loss: 0.0483  (0.0474)	time: 525ms  (525ms)
progress: 241/806 (29.90%)	loss: 0.0504  (0.0483)	time: 524ms  (525ms)
progress: 242/806 (30.02%)	loss: 0.0416  (0.0463)	time: 525ms  (525ms)
progress: 243/806 (30.15%)	loss: 0.0492  (0.0472)	time: 532ms  (527ms)
progress: 244/806 (30.27%)	loss: 0.0485  (0.0476)	time: 526ms  (527ms)
progress: 245/806 (30.40%)	loss: 0.0479  (0.0477)	time: 540ms  (531ms)
progress: 246/806 (30.52%)	loss: 0.0441  (0.0466)	time: 529ms  (530ms)
progress: 247/806 (30.65%)	loss: 0.0465  (0.0466)	time: 527ms  (529ms)
progress: 248/806 (30.77%)	loss: 0.0435  (0.0457)	time: 524ms  (527ms)
progress: 249/806 (30.89%)	loss: 0.0485  (0.0465)	time: 528ms  (528ms)
progre

progress: 352/806 (43.67%)	loss: 0.0459  (0.0490)	time: 524ms  (526ms)
progress: 353/806 (43.80%)	loss: 0.0489  (0.0490)	time: 530ms  (527ms)
progress: 354/806 (43.92%)	loss: 0.0485  (0.0488)	time: 529ms  (528ms)
progress: 355/806 (44.04%)	loss: 0.0504  (0.0493)	time: 528ms  (528ms)
progress: 356/806 (44.17%)	loss: 0.0424  (0.0472)	time: 527ms  (528ms)
progress: 357/806 (44.29%)	loss: 0.0513  (0.0485)	time: 524ms  (527ms)
progress: 358/806 (44.42%)	loss: 0.0475  (0.0482)	time: 526ms  (526ms)
progress: 359/806 (44.54%)	loss: 0.0474  (0.0479)	time: 525ms  (526ms)
progress: 360/806 (44.67%)	loss: 0.0412  (0.0459)	time: 530ms  (527ms)
progress: 361/806 (44.79%)	loss: 0.0448  (0.0456)	time: 526ms  (527ms)
progress: 362/806 (44.91%)	loss: 0.0445  (0.0452)	time: 531ms  (528ms)
progress: 363/806 (45.04%)	loss: 0.0498  (0.0466)	time: 526ms  (527ms)
progress: 364/806 (45.16%)	loss: 0.0468  (0.0467)	time: 522ms  (526ms)
progress: 365/806 (45.29%)	loss: 0.0497  (0.0476)	time: 525ms  (526ms)
progre

progress: 468/806 (58.06%)	loss: 0.0474  (0.0474)	time: 526ms  (528ms)
progress: 469/806 (58.19%)	loss: 0.0480  (0.0476)	time: 523ms  (526ms)
progress: 470/806 (58.31%)	loss: 0.0445  (0.0467)	time: 523ms  (525ms)
progress: 471/806 (58.44%)	loss: 0.0508  (0.0479)	time: 524ms  (525ms)
progress: 472/806 (58.56%)	loss: 0.0477  (0.0478)	time: 526ms  (525ms)
progress: 473/806 (58.68%)	loss: 0.0466  (0.0475)	time: 522ms  (524ms)
progress: 474/806 (58.81%)	loss: 0.0467  (0.0472)	time: 523ms  (524ms)
progress: 475/806 (58.93%)	loss: 0.0478  (0.0474)	time: 530ms  (526ms)
progress: 476/806 (59.06%)	loss: 0.0477  (0.0475)	time: 524ms  (525ms)
progress: 477/806 (59.18%)	loss: 0.0468  (0.0473)	time: 523ms  (525ms)
progress: 478/806 (59.31%)	loss: 0.0503  (0.0482)	time: 523ms  (524ms)
progress: 479/806 (59.43%)	loss: 0.0471  (0.0479)	time: 524ms  (524ms)
progress: 480/806 (59.55%)	loss: 0.0491  (0.0482)	time: 530ms  (526ms)
progress: 481/806 (59.68%)	loss: 0.0443  (0.0471)	time: 526ms  (526ms)
progre

progress: 584/806 (72.46%)	loss: 0.0483  (0.0473)	time: 527ms  (526ms)
progress: 585/806 (72.58%)	loss: 0.0510  (0.0484)	time: 526ms  (526ms)
progress: 586/806 (72.70%)	loss: 0.0437  (0.0470)	time: 540ms  (530ms)
progress: 587/806 (72.83%)	loss: 0.0419  (0.0455)	time: 528ms  (529ms)
progress: 588/806 (72.95%)	loss: 0.0487  (0.0464)	time: 527ms  (529ms)
progress: 589/806 (73.08%)	loss: 0.0484  (0.0470)	time: 523ms  (527ms)
progress: 590/806 (73.20%)	loss: 0.0456  (0.0466)	time: 524ms  (526ms)
progress: 591/806 (73.33%)	loss: 0.0531  (0.0486)	time: 532ms  (528ms)
progress: 592/806 (73.45%)	loss: 0.0480  (0.0484)	time: 528ms  (528ms)
progress: 593/806 (73.57%)	loss: 0.0492  (0.0486)	time: 526ms  (527ms)
progress: 594/806 (73.70%)	loss: 0.0407  (0.0463)	time: 523ms  (526ms)
progress: 595/806 (73.82%)	loss: 0.0458  (0.0461)	time: 527ms  (526ms)
progress: 596/806 (73.95%)	loss: 0.0523  (0.0480)	time: 525ms  (526ms)
progress: 597/806 (74.07%)	loss: 0.0517  (0.0491)	time: 523ms  (525ms)
progre

progress: 700/806 (86.85%)	loss: 0.0426  (0.0454)	time: 524ms  (525ms)
progress: 701/806 (86.97%)	loss: 0.0515  (0.0472)	time: 528ms  (526ms)
progress: 702/806 (87.10%)	loss: 0.0445  (0.0464)	time: 525ms  (526ms)
progress: 703/806 (87.22%)	loss: 0.0467  (0.0465)	time: 525ms  (526ms)
progress: 704/806 (87.34%)	loss: 0.0478  (0.0469)	time: 527ms  (526ms)
progress: 705/806 (87.47%)	loss: 0.0464  (0.0468)	time: 536ms  (529ms)
progress: 706/806 (87.59%)	loss: 0.0427  (0.0456)	time: 533ms  (530ms)
progress: 707/806 (87.72%)	loss: 0.0443  (0.0452)	time: 523ms  (528ms)
progress: 708/806 (87.84%)	loss: 0.0486  (0.0462)	time: 525ms  (527ms)
progress: 709/806 (87.97%)	loss: 0.0435  (0.0454)	time: 527ms  (527ms)
progress: 710/806 (88.09%)	loss: 0.0463  (0.0457)	time: 522ms  (525ms)
progress: 711/806 (88.21%)	loss: 0.0485  (0.0465)	time: 530ms  (527ms)
progress: 712/806 (88.34%)	loss: 0.0472  (0.0467)	time: 526ms  (527ms)
progress: 713/806 (88.46%)	loss: 0.0502  (0.0478)	time: 529ms  (527ms)
progre

In [10]:
# Validation loader: no shuffling, so batches come in a fixed order.
val_loader, val_dataset = get_data_loader(valdir, batch_size=batch_size, shuffle=False)

In [11]:
import sys

def get_error(val_loader):
    """Compute the root-mean-square error of the global ``model`` on a loader.

    Runs ``model`` on every ``(images, labels)`` mini-batch, accumulates the
    squared difference between labels and predictions, and returns the square
    root of the mean over every label element (image x class).

    Args:
        val_loader: DataLoader yielding ``(images, labels)`` batches; labels are
            assumed to have the same shape as the model output.

    Returns:
        The RMSE over all label elements.

    Raises:
        ValueError: if the loader yields no samples.
    """
    # len(DataLoader) is the batch count; no need for an extra full data pass.
    num_batches = len(val_loader)
    squared_error = 0.0
    num_elements = 0
    for i, (images, labels) in enumerate(val_loader):
        sys.stdout.write('\rBatch: {:d}/{:d}'.format(i + 1, num_batches))
        sys.stdout.flush()
        if use_cuda:
            images = images.cuda()
        # volatile=True: inference only, no autograd graph (pre-0.4 PyTorch API).
        predictions = model(Variable(images, volatile=True))
        squared_error += float(labels.sub(predictions.data.cpu()).pow(2).sum())
        num_elements += labels.numel()
    # Terminate the carriage-return progress line.
    print('')
    if num_elements == 0:
        raise ValueError('get_error: loader yielded no samples')
    return np.sqrt(squared_error / num_elements)

In [14]:
model.eval()
val_rmse = get_error(val_loader)
print('RMSE for validation set: {}'.format(val_rmse))

Batch: 157/157
RMSE for validation set: 0.21740611259410117


In [15]:
# Test loader: fixed order, with ImageDataset's test_mode enabled for the test folder.
test_loader, test_dataset = get_data_loader(testdir, batch_size=batch_size, shuffle=False, test_mode=True)

Using test mode


In [16]:
import sys

def predict(loader):
    """Run the global ``model`` over ``loader`` and collect per-image predictions.

    Sample ``k`` (in iteration order) is labelled with
    ``loader.dataset.images_idx_to_id[k]``, so the loader must NOT shuffle.

    Args:
        loader: DataLoader yielding ``(images, _)`` batches.

    Returns:
        DataFrame with one row per image, indexed by image id (index named like
        ``targets.index``) and one column per entry of ``targets.columns``.
    """
    # len(DataLoader) is the batch count; no need for an extra full data pass.
    num_batches = len(loader)
    batch_outputs = []
    for i, (images, _) in enumerate(loader):
        sys.stdout.write('\rBatch: {:d}/{:d}'.format(i + 1, num_batches))
        sys.stdout.flush()
        if use_cuda:
            images = images.cuda()
        # volatile=True: inference only, no autograd graph (pre-0.4 PyTorch API).
        predictions = model(Variable(images, volatile=True))
        batch_outputs.append(predictions.data.cpu().numpy())
    # Terminate the carriage-return progress line.
    print('')

    # Build the frame once instead of growing it row by row with .loc.
    if batch_outputs:
        values = np.concatenate(batch_outputs)
    else:
        values = np.empty((0, len(targets.columns)))
    image_ids = [loader.dataset.images_idx_to_id[k] for k in range(len(values))]
    index = pd.Index(image_ids, name=targets.index.name)
    return pd.DataFrame(values, index=index, columns=targets.columns)

def save_kaggle_predictions(loader, filename="predicted-resnet.csv"):
    """Predict on ``loader`` and write the result as a Kaggle submission CSV.

    Args:
        loader: Unshuffled DataLoader over the test images (see ``predict``).
        filename: Output file name, written under ``data_path``. The default
            matches the file the notebook links to afterwards.

    Returns:
        The predictions DataFrame that was written.
    """
    all_predictions_df = predict(loader)
    # Name the id column explicitly so the CSV header always starts with the
    # targets table's id column name, rather than relying on the index name
    # surviving however the frame was built.
    all_predictions_df.to_csv(data_path + filename, index_label=targets.index.name)
    return all_predictions_df
    

In [17]:
from IPython.display import FileLink
submission_csv = data_path + "predicted-resnet.csv"
# Put layers such as batch-norm into evaluation behaviour before predicting.
model.eval()
save_kaggle_predictions(test_loader)
FileLink(submission_csv)

Batch: 1250/1250


In [38]:
# Persist the whole pickled model object (architecture + weights); the
# "load the model if already trained" cell reads it back with torch.load.
checkpoint_path = "data/galaxies/resnet-50.pth"
torch.save(model, checkpoint_path)