In [1]:
import pandas as pd
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.trainer as trainer
import torch.utils.trainer.plugins
from torch.autograd import Variable
import numpy as np
import os

from imagedataset import ImageDataset

In [2]:
# show images inline
%matplotlib inline
def show(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')
# display all dataframe columns
pd.options.display.max_columns = 50

In [3]:
train_dir = "train"
# train_dir = "sample"
use_cuda = True
batch_size = 64
print('Using CUDA:', use_cuda)

Using CUDA: True


In [4]:
data_path = "data/galaxies/"

def dataframe_from_csv():
    df=pd.read_csv(data_path + "classes.csv", sep=',')
    df.set_index("GalaxyID", inplace=True)
    return df

In [7]:
import importlib
import imagedataset
importlib.reload(imagedataset)

# Data loading code
traindir = os.path.join(data_path, train_dir )
valdir = os.path.join(data_path, 'valid') 
testdir = os.path.join(data_path, 'test')
sampledir = os.path.join(data_path, 'sample')

targets = dataframe_from_csv();
num_classes = len(targets.columns)

# pytorch way of implementing fastai's get_batches, (utils.py)
def get_data_loader(dirname, shuffle=True, batch_size = 64, test_mode = False):
    image_dataset = ImageDataset(dirname, targets, test_mode)
    return torch.utils.data.DataLoader(image_dataset, batch_size=batch_size, 
                                       shuffle=shuffle, pin_memory=use_cuda), image_dataset

sample_loader, sample_dataset = get_data_loader(sampledir, batch_size=batch_size)
print('Images in sample folder:', len(sample_dataset))

Images in sample folder: 10


In [6]:
train_loader, train_dataset = get_data_loader(traindir, batch_size=batch_size)
print('Images in train folder:', len(train_dataset))

Images in train folder: 51578


In [8]:
targets.head()

Unnamed: 0_level_0,Class1.1,Class1.2,Class1.3,Class2.1,Class2.2,Class3.1,Class3.2,Class4.1,Class4.2,Class5.1,Class5.2,Class5.3,Class5.4,Class6.1,Class6.2,Class7.1,Class7.2,Class7.3,Class8.1,Class8.2,Class8.3,Class8.4,Class8.5,Class8.6,Class8.7,Class9.1,Class9.2,Class9.3,Class10.1,Class10.2,Class10.3,Class11.1,Class11.2,Class11.3,Class11.4,Class11.5,Class11.6
GalaxyID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1
100008,0.383147,0.616853,0.0,0.0,0.616853,0.038452,0.578401,0.418398,0.198455,0.0,0.104752,0.512101,0.0,0.054453,0.945547,0.201463,0.181684,0.0,0.0,0.027226,0.0,0.027226,0.0,0.0,0.0,0.0,0.0,0.0,0.279952,0.138445,0.0,0.0,0.092886,0.0,0.0,0.0,0.325512
100023,0.327001,0.663777,0.009222,0.031178,0.632599,0.46737,0.165229,0.591328,0.041271,0.0,0.236781,0.160941,0.234877,0.189149,0.810851,0.0,0.135082,0.191919,0.0,0.0,0.140353,0.0,0.048796,0.0,0.0,0.012414,0.0,0.018764,0.0,0.131378,0.45995,0.0,0.591328,0.0,0.0,0.0,0.0
100053,0.765717,0.177352,0.056931,0.0,0.177352,0.0,0.177352,0.0,0.177352,0.0,0.11779,0.059562,0.0,0.0,1.0,0.0,0.741864,0.023853,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
100078,0.693377,0.238564,0.068059,0.0,0.238564,0.109493,0.129071,0.189098,0.049466,0.0,0.0,0.113284,0.12528,0.320398,0.679602,0.408599,0.284778,0.0,0.0,0.0,0.096119,0.096119,0.0,0.128159,0.0,0.0,0.0,0.0,0.094549,0.0,0.094549,0.189098,0.0,0.0,0.0,0.0,0.0
100090,0.933839,0.0,0.066161,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.029383,0.970617,0.494587,0.439252,0.0,0.0,0.0,0.0,0.0,0.0,0.029383,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [9]:
# Load the model
model = models.resnet50(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /home/ubuntu/.torch/models/resnet50-19c8e357.pth
100.0%


In [10]:
expansion = 4 # TODO use Bottleneck.expansion instead
def get_decision_tree_fc():
    branches = nn.ModuleList([
        nn.Sequential(nn.Linear(512 * expansion, 3), nn.Softmax()), # Classes: 1.1, 1.2, 1.3
        nn.Sequential(nn.Linear(512 * expansion, 2), nn.Softmax()), # Classes: 2.1, 2.2
        nn.Sequential(nn.Linear(512 * expansion, 2), nn.Softmax()), # Classes: 3.1, 3.2, 3.3
        nn.Sequential(nn.Linear(512 * expansion, 2), nn.Softmax()), # Classes: 4.1, 4.2
        nn.Sequential(nn.Linear(512 * expansion, 4), nn.Softmax()), # Classes: 5.1, 5.2, 5.3, 5.4
        nn.Sequential(nn.Linear(512 * expansion, 2), nn.Softmax()), # Classes: 6.1, 6.2
        nn.Sequential(nn.Linear(512 * expansion, 3), nn.Softmax()), # Classes: 7.1, 7.2, 7.3
        nn.Sequential(nn.Linear(512 * expansion, 7), nn.Softmax()), # Classes: 8.1, ..., 8.7
        nn.Sequential(nn.Linear(512 * expansion, 3), nn.Softmax()), # Classes: 9.1, 9.2, 9.3
        nn.Sequential(nn.Linear(512 * expansion, 3), nn.Softmax()), # Classes: 10.1, 10.2, 10.3
        nn.Sequential(nn.Linear(512 * expansion, 6), nn.Softmax()) # Classes: 11.1, ..., 11.6
    ])
    return branches

def normalize(x):
    # x is list of 11 elements, one tensor per class
    offset = -1
    # Class 1  NOP
    x[2 + offset] = torch.mm(torch.diag(x[1 + offset][:,2 + offset]), x[2 + offset]) # Class 1.2 * Class 2
    x[3 + offset] = torch.mm(torch.diag(x[2 + offset][:,2 + offset]), x[3 + offset]) # Class 2.2 * Class 3
    x[4 + offset] = torch.mm(torch.diag(x[2 + offset][:,2 + offset]), x[4 + offset]) # Class 2.2 * Class 4
    x[5 + offset] = torch.mm(torch.diag(x[2 + offset][:,2 + offset]), x[5 + offset]) # Class 2.2 * Class 5
    # Class 6  NOP
    x[7 + offset] = torch.mm(torch.diag(x[1 + offset][:,1 + offset]), x[7 + offset]) # Class 1.1 * Class 7
    x[8 + offset] = torch.mm(torch.diag(x[6 + offset][:,1 + offset]), x[8 + offset]) # Class 6.1 * Class 8
    x[9 + offset] = torch.mm(torch.diag(x[2 + offset][:,1 + offset]), x[9 + offset]) # Class 2.1 * Class 9
    x[10 + offset] = torch.mm(torch.diag(x[4 + offset][:,1 + offset]), x[10 + offset]) # Class 4.1 * Class 10
    x[11 + offset] = torch.mm(torch.diag(x[4 + offset][:,1 + offset]), x[11 + offset]) # Class 4.1 * Class 11
    return x

In [11]:
import types

def forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = x.view(x.size(0), -1)
    
    # replace the last fc layer by a concatenation (columns = 1) of all branches defined in the resnet galaxy model
    bx = [b(x) for b in self.branches]
    return torch.cat(normalize(bx), 1)

# Finetune by replacing the last fully connected layer and freezing all network parameters
for param in model.parameters():
    param.requires_grad = False

# The fc layer is no longer needed
if hasattr(model, 'fc'):
    del model.fc
model.branches = get_decision_tree_fc()

# Monkey patch Resnet's forward layer to use our own instead
model.forward = types.MethodType(forward, model) 

In [12]:
# Replace the last fully-connected layer matching the new class count
print('Using {:d} classes: {}'.format(num_classes, targets.columns))

Using 37 classes: Index(['Class1.1', 'Class1.2', 'Class1.3', 'Class2.1', 'Class2.2', 'Class3.1',
       'Class3.2', 'Class4.1', 'Class4.2', 'Class5.1', 'Class5.2', 'Class5.3',
       'Class5.4', 'Class6.1', 'Class6.2', 'Class7.1', 'Class7.2', 'Class7.3',
       'Class8.1', 'Class8.2', 'Class8.3', 'Class8.4', 'Class8.5', 'Class8.6',
       'Class8.7', 'Class9.1', 'Class9.2', 'Class9.3', 'Class10.1',
       'Class10.2', 'Class10.3', 'Class11.1', 'Class11.2', 'Class11.3',
       'Class11.4', 'Class11.5', 'Class11.6'],
      dtype='object')


In [12]:
# Or load the model if already trained
# model = torch.load("data/galaxies/resnet-50.pth")

In [13]:
# define loss function (criterion) and optimizer
criterion = nn.MSELoss()
# enable cuda if available
if(use_cuda):
    model.cuda()
    criterion.cuda()
optimizer = optim.SGD(model.branches.parameters(), lr=1e-2, momentum=0.9)

In [14]:
def getTrainer():
    # fine-tune with new classes
    t = trainer.Trainer(model, criterion, optimizer, train_loader)
    t.register_plugin(trainer.plugins.ProgressMonitor())
    t.register_plugin(trainer.plugins.LossMonitor())
    t.register_plugin(trainer.plugins.TimeMonitor())
    t.register_plugin(trainer.plugins.Logger(['progress', 'loss', 'time']))
    
    # Requires a monkey patched version of trainery.py that calls
    # input_var = Variable(batch_input.cuda())    # Line 57
    # target_var = Variable(batch_target.cuda())  # Line 58
    return t

In [27]:
t = getTrainer()
epochs = 1
model.train()
t.run(epochs)

progress: 1/806 (0.12%)	loss: 0.0234  (0.0070)	time: 0ms  (0ms)
progress: 2/806 (0.25%)	loss: 0.0176  (0.0102)	time: 558ms  (167ms)
progress: 3/806 (0.37%)	loss: 0.0198  (0.0131)	time: 549ms  (282ms)
progress: 4/806 (0.50%)	loss: 0.0201  (0.0152)	time: 537ms  (358ms)
progress: 5/806 (0.62%)	loss: 0.0219  (0.0172)	time: 543ms  (414ms)
progress: 6/806 (0.74%)	loss: 0.0196  (0.0179)	time: 540ms  (452ms)
progress: 7/806 (0.87%)	loss: 0.0212  (0.0189)	time: 536ms  (477ms)
progress: 8/806 (0.99%)	loss: 0.0187  (0.0189)	time: 539ms  (495ms)
progress: 9/806 (1.12%)	loss: 0.0203  (0.0193)	time: 541ms  (509ms)
progress: 10/806 (1.24%)	loss: 0.0226  (0.0203)	time: 540ms  (518ms)
progress: 11/806 (1.36%)	loss: 0.0201  (0.0202)	time: 538ms  (524ms)
progress: 12/806 (1.49%)	loss: 0.0228  (0.0210)	time: 537ms  (528ms)
progress: 13/806 (1.61%)	loss: 0.0172  (0.0198)	time: 535ms  (530ms)
progress: 14/806 (1.74%)	loss: 0.0151  (0.0184)	time: 545ms  (535ms)
progress: 15/806 (1.86%)	loss: 0.0217  (0.0194)

progress: 120/806 (14.89%)	loss: 0.0196  (0.0187)	time: 543ms  (542ms)
progress: 121/806 (15.01%)	loss: 0.0211  (0.0194)	time: 539ms  (541ms)
progress: 122/806 (15.14%)	loss: 0.0190  (0.0193)	time: 553ms  (545ms)
progress: 123/806 (15.26%)	loss: 0.0193  (0.0193)	time: 540ms  (543ms)
progress: 124/806 (15.38%)	loss: 0.0214  (0.0199)	time: 544ms  (544ms)
progress: 125/806 (15.51%)	loss: 0.0189  (0.0196)	time: 541ms  (543ms)
progress: 126/806 (15.63%)	loss: 0.0221  (0.0203)	time: 548ms  (544ms)
progress: 127/806 (15.76%)	loss: 0.0191  (0.0200)	time: 546ms  (545ms)
progress: 128/806 (15.88%)	loss: 0.0149  (0.0185)	time: 541ms  (544ms)
progress: 129/806 (16.00%)	loss: 0.0196  (0.0188)	time: 541ms  (543ms)
progress: 130/806 (16.13%)	loss: 0.0189  (0.0188)	time: 543ms  (543ms)
progress: 131/806 (16.25%)	loss: 0.0182  (0.0186)	time: 543ms  (543ms)
progress: 132/806 (16.38%)	loss: 0.0195  (0.0189)	time: 541ms  (542ms)
progress: 133/806 (16.50%)	loss: 0.0222  (0.0199)	time: 539ms  (541ms)
progre

progress: 236/806 (29.28%)	loss: 0.0197  (0.0188)	time: 549ms  (548ms)
progress: 237/806 (29.40%)	loss: 0.0184  (0.0187)	time: 545ms  (547ms)
progress: 238/806 (29.53%)	loss: 0.0196  (0.0190)	time: 541ms  (545ms)
progress: 239/806 (29.65%)	loss: 0.0190  (0.0190)	time: 560ms  (550ms)
progress: 240/806 (29.78%)	loss: 0.0171  (0.0184)	time: 543ms  (548ms)
progress: 241/806 (29.90%)	loss: 0.0171  (0.0180)	time: 546ms  (547ms)
progress: 242/806 (30.02%)	loss: 0.0175  (0.0179)	time: 539ms  (545ms)
progress: 243/806 (30.15%)	loss: 0.0169  (0.0176)	time: 554ms  (548ms)
progress: 244/806 (30.27%)	loss: 0.0193  (0.0181)	time: 540ms  (545ms)
progress: 245/806 (30.40%)	loss: 0.0204  (0.0188)	time: 545ms  (545ms)
progress: 246/806 (30.52%)	loss: 0.0190  (0.0189)	time: 549ms  (546ms)
progress: 247/806 (30.65%)	loss: 0.0207  (0.0194)	time: 549ms  (547ms)
progress: 248/806 (30.77%)	loss: 0.0214  (0.0200)	time: 548ms  (547ms)
progress: 249/806 (30.89%)	loss: 0.0196  (0.0199)	time: 544ms  (546ms)
progre

progress: 352/806 (43.67%)	loss: 0.0193  (0.0195)	time: 549ms  (548ms)
progress: 353/806 (43.80%)	loss: 0.0212  (0.0200)	time: 551ms  (549ms)
progress: 354/806 (43.92%)	loss: 0.0194  (0.0198)	time: 546ms  (548ms)
progress: 355/806 (44.04%)	loss: 0.0163  (0.0188)	time: 546ms  (547ms)
progress: 356/806 (44.17%)	loss: 0.0177  (0.0185)	time: 551ms  (549ms)
progress: 357/806 (44.29%)	loss: 0.0215  (0.0194)	time: 545ms  (547ms)
progress: 358/806 (44.42%)	loss: 0.0172  (0.0187)	time: 546ms  (547ms)
progress: 359/806 (44.54%)	loss: 0.0194  (0.0189)	time: 542ms  (546ms)
progress: 360/806 (44.67%)	loss: 0.0230  (0.0202)	time: 552ms  (548ms)
progress: 361/806 (44.79%)	loss: 0.0175  (0.0194)	time: 546ms  (547ms)
progress: 362/806 (44.91%)	loss: 0.0210  (0.0198)	time: 549ms  (548ms)
progress: 363/806 (45.04%)	loss: 0.0184  (0.0194)	time: 561ms  (552ms)
progress: 364/806 (45.16%)	loss: 0.0165  (0.0185)	time: 554ms  (552ms)
progress: 365/806 (45.29%)	loss: 0.0200  (0.0190)	time: 545ms  (550ms)
progre

progress: 468/806 (58.06%)	loss: 0.0214  (0.0195)	time: 545ms  (546ms)
progress: 469/806 (58.19%)	loss: 0.0179  (0.0190)	time: 549ms  (547ms)
progress: 470/806 (58.31%)	loss: 0.0156  (0.0180)	time: 551ms  (548ms)
progress: 471/806 (58.44%)	loss: 0.0234  (0.0196)	time: 542ms  (546ms)
progress: 472/806 (58.56%)	loss: 0.0212  (0.0201)	time: 551ms  (548ms)
progress: 473/806 (58.68%)	loss: 0.0186  (0.0197)	time: 546ms  (547ms)
progress: 474/806 (58.81%)	loss: 0.0183  (0.0192)	time: 530ms  (542ms)
progress: 475/806 (58.93%)	loss: 0.0184  (0.0190)	time: 548ms  (544ms)
progress: 476/806 (59.06%)	loss: 0.0177  (0.0186)	time: 549ms  (545ms)
progress: 477/806 (59.18%)	loss: 0.0197  (0.0189)	time: 550ms  (547ms)
progress: 478/806 (59.31%)	loss: 0.0201  (0.0193)	time: 544ms  (546ms)
progress: 479/806 (59.43%)	loss: 0.0185  (0.0191)	time: 547ms  (546ms)
progress: 480/806 (59.55%)	loss: 0.0184  (0.0189)	time: 549ms  (547ms)
progress: 481/806 (59.68%)	loss: 0.0167  (0.0182)	time: 547ms  (547ms)
progre

progress: 584/806 (72.46%)	loss: 0.0216  (0.0207)	time: 545ms  (547ms)
progress: 585/806 (72.58%)	loss: 0.0206  (0.0207)	time: 546ms  (547ms)
progress: 586/806 (72.70%)	loss: 0.0206  (0.0206)	time: 551ms  (548ms)
progress: 587/806 (72.83%)	loss: 0.0198  (0.0204)	time: 543ms  (546ms)
progress: 588/806 (72.95%)	loss: 0.0196  (0.0201)	time: 544ms  (546ms)
progress: 589/806 (73.08%)	loss: 0.0193  (0.0199)	time: 540ms  (544ms)
progress: 590/806 (73.20%)	loss: 0.0213  (0.0203)	time: 556ms  (548ms)
progress: 591/806 (73.33%)	loss: 0.0210  (0.0205)	time: 542ms  (546ms)
progress: 592/806 (73.45%)	loss: 0.0188  (0.0200)	time: 547ms  (546ms)
progress: 593/806 (73.57%)	loss: 0.0182  (0.0195)	time: 544ms  (546ms)
progress: 594/806 (73.70%)	loss: 0.0194  (0.0194)	time: 548ms  (546ms)
progress: 595/806 (73.82%)	loss: 0.0184  (0.0191)	time: 549ms  (547ms)
progress: 596/806 (73.95%)	loss: 0.0187  (0.0190)	time: 542ms  (546ms)
progress: 597/806 (74.07%)	loss: 0.0192  (0.0190)	time: 548ms  (546ms)
progre

progress: 700/806 (86.85%)	loss: 0.0196  (0.0192)	time: 548ms  (549ms)
progress: 701/806 (86.97%)	loss: 0.0158  (0.0182)	time: 560ms  (552ms)
progress: 702/806 (87.10%)	loss: 0.0184  (0.0182)	time: 544ms  (550ms)
progress: 703/806 (87.22%)	loss: 0.0182  (0.0182)	time: 549ms  (549ms)
progress: 704/806 (87.34%)	loss: 0.0199  (0.0187)	time: 544ms  (548ms)
progress: 705/806 (87.47%)	loss: 0.0221  (0.0197)	time: 556ms  (550ms)
progress: 706/806 (87.59%)	loss: 0.0204  (0.0199)	time: 545ms  (548ms)
progress: 707/806 (87.72%)	loss: 0.0234  (0.0210)	time: 554ms  (550ms)
progress: 708/806 (87.84%)	loss: 0.0164  (0.0196)	time: 547ms  (549ms)
progress: 709/806 (87.97%)	loss: 0.0189  (0.0194)	time: 552ms  (550ms)
progress: 710/806 (88.09%)	loss: 0.0204  (0.0197)	time: 549ms  (550ms)
progress: 711/806 (88.21%)	loss: 0.0195  (0.0196)	time: 541ms  (547ms)
progress: 712/806 (88.34%)	loss: 0.0170  (0.0189)	time: 553ms  (549ms)
progress: 713/806 (88.46%)	loss: 0.0216  (0.0197)	time: 543ms  (547ms)
progre

In [16]:
# Load validation data
val_loader, val_dataset = get_data_loader(valdir, shuffle=False, batch_size=batch_size)

In [20]:
import sys

def get_error(val_loader):
    # Process each mini-batch and accumulate all correct classifications
    num_batches = sum(1 for b in enumerate(val_loader))
    batches = enumerate(val_loader)
    error2 = 0
    for i, (images, labels) in batches:
        sys.stdout.write('\rBatch: {:d}/{:d}'.format(i + 1, num_batches))
        sys.stdout.flush()
        if use_cuda:
            images = images.cuda()
        predictions = model(Variable(images, volatile=True))
        error2 += labels.sub(predictions.data.cpu()).pow(2).sum()
    # Avoid carriage return
    print('')
    return np.sqrt(error2 / len(val_loader.dataset.images_targets) / num_classes )

In [76]:
model.eval()
print('RMSE for validation set: {}'.format(get_error(train_loader)))

Batch: 5/5
RMSE for validation set: 0.2202402150748411


In [15]:
# Load test data
test_loader, test_dataset = get_data_loader(testdir, shuffle=False, batch_size=batch_size, test_mode = True)

Using test mode


In [17]:
import sys

def predict(loader):
    # Process each mini-batch and accumulate all correct classifications
    all_predictions_df = pd.DataFrame(data=None, columns=targets.columns,index=targets.index)
    # Drop all rows
    all_predictions_df.drop(targets.index, inplace=True)
    num_batches = sum(1 for b in enumerate(loader))
    batches = enumerate(loader)
    current_image = 0
    for i, (images, labels) in batches:
        sys.stdout.write('\rBatch: {:d}/{:d}'.format(i + 1, num_batches))
        sys.stdout.flush()
        if use_cuda:
            images = images.cuda()
        predictions = model(Variable(images, volatile=True))
        batch_predictions = predictions.data.cpu().numpy()
        for row_prediction in batch_predictions:
            image_id = loader.dataset.images_idx_to_id[current_image]
            all_predictions_df.loc[image_id] = row_prediction
            current_image += 1
    # Avoid carriage return
    print('')
    return all_predictions_df

def save_kaggle_predictions(loader):
    all_predictions_df = predict(loader)
    all_predictions_df.to_csv(data_path + "predicted-resnet.csv")
    return all_predictions_df

In [28]:
model.eval()
print('RMSE for sample set: {}'.format(get_error(sample_loader)))
predictions_sample_df = predict(sample_loader)
sample_ids = [121190, 172857, 411011, 447671, 504228]

Batch: 1/1
RMSE for sample set: 0.156487128981932
Batch: 1/1


In [29]:
# Predicted
predictions_sample_df.loc[sample_ids]

Unnamed: 0_level_0,Class1.1,Class1.2,Class1.3,Class2.1,Class2.2,Class3.1,Class3.2,Class4.1,Class4.2,Class5.1,Class5.2,Class5.3,Class5.4,Class6.1,Class6.2,Class7.1,Class7.2,Class7.3,Class8.1,Class8.2,Class8.3,Class8.4,Class8.5,Class8.6,Class8.7,Class9.1,Class9.2,Class9.3,Class10.1,Class10.2,Class10.3,Class11.1,Class11.2,Class11.3,Class11.4,Class11.5,Class11.6
GalaxyID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1
121190,0.440881,0.546587,0.012532,0.420132,0.126455,0.036843,0.089612,0.040102,0.086354,0.017623,0.046544,0.051462,0.010827,0.114819,0.885181,0.085061,0.247446,0.108374,0.013931,0.009523,0.012867,0.021893,0.029439,0.0227,0.004465,0.287488,0.032236,0.100408,0.015526,0.017589,0.006987,0.002491,0.020635,0.003678,0.001274,0.001919,0.010105
172857,0.770655,0.220244,0.009101,0.007519,0.212725,0.037775,0.17495,0.055151,0.157574,0.017355,0.069008,0.10832,0.018041,0.119357,0.880643,0.417291,0.296096,0.057268,0.015549,0.007938,0.01769,0.022559,0.032075,0.01925,0.004296,0.004283,0.000812,0.002424,0.026173,0.022148,0.00683,0.003754,0.024376,0.004976,0.002021,0.002373,0.01765
411011,0.37009,0.604242,0.025668,0.241971,0.362271,0.087366,0.274905,0.200208,0.162063,0.054428,0.161878,0.113375,0.032589,0.164762,0.835238,0.14483,0.162157,0.063103,0.026551,0.013653,0.02468,0.022608,0.041286,0.025374,0.01061,0.105148,0.035916,0.100908,0.081719,0.068488,0.05,0.017936,0.078187,0.019597,0.015143,0.01195,0.057394
447671,0.675893,0.312709,0.011398,0.04132,0.271389,0.054724,0.216665,0.070191,0.201197,0.027823,0.10461,0.115795,0.023161,0.184446,0.815554,0.244213,0.327778,0.103902,0.023049,0.011344,0.028567,0.026924,0.05097,0.034752,0.008839,0.02706,0.003503,0.010758,0.029321,0.03071,0.01016,0.004336,0.028242,0.007532,0.003072,0.004699,0.022311
504228,0.19948,0.798023,0.002496,0.004312,0.793711,0.142292,0.651418,0.557349,0.236362,0.057202,0.443629,0.238628,0.054252,0.31874,0.68126,0.100811,0.085996,0.012674,0.037738,0.013943,0.037975,0.059693,0.105615,0.054375,0.009401,0.002165,0.0004,0.001748,0.289389,0.205334,0.062625,0.020139,0.236694,0.040868,0.021114,0.014815,0.223718


In [23]:
# vs. Expected
targets.loc[sample_ids]

Unnamed: 0_level_0,Class1.1,Class1.2,Class1.3,Class2.1,Class2.2,Class3.1,Class3.2,Class4.1,Class4.2,Class5.1,Class5.2,Class5.3,Class5.4,Class6.1,Class6.2,Class7.1,Class7.2,Class7.3,Class8.1,Class8.2,Class8.3,Class8.4,Class8.5,Class8.6,Class8.7,Class9.1,Class9.2,Class9.3,Class10.1,Class10.2,Class10.3,Class11.1,Class11.2,Class11.3,Class11.4,Class11.5,Class11.6
GalaxyID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1
121190,0.917746,0.065081,0.017173,0.0,0.065081,0.0,0.065081,0.0,0.065081,0.0,0.050492,0.014589,0.0,0.0,1.0,0.819644,0.098102,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
172857,0.236824,0.672569,0.090607,0.01272,0.659849,0.382924,0.276925,0.0,0.659849,0.052058,0.251263,0.26923,0.087297,0.929137,0.070863,0.039323,0.190984,0.006517,0.0,0.0,0.022624,0.0,0.405228,0.501284,0.0,0.007332,0.0,0.005388,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
411011,0.0404,0.945151,0.014449,0.0,0.945151,0.0,0.945151,0.628501,0.31665,0.056692,0.734381,0.107272,0.046805,0.399899,0.600101,0.0404,0.0,0.0,0.061523,0.030761,0.061523,0.123046,0.092284,0.030761,0.0,0.0,0.0,0.0,0.288318,0.244064,0.096118,0.0,0.025562,0.121258,0.287405,0.032875,0.1614
447671,0.016244,0.97674,0.007016,0.0,0.97674,0.268893,0.707847,0.946116,0.030624,0.034044,0.744836,0.052886,0.144975,0.066376,0.933624,0.008179,0.008065,0.0,0.033188,0.0,0.0,0.0,0.033188,0.0,0.0,0.0,0.0,0.0,0.498293,0.352617,0.095207,0.028281,0.132814,0.691398,0.035732,0.0,0.057892
504228,0.334812,0.654743,0.010444,0.0,0.654743,0.0,0.654743,0.111746,0.542997,0.166858,0.487885,0.0,0.0,0.11716,0.88284,0.016743,0.318069,0.0,0.0,0.0,0.0,0.0,0.11716,0.0,0.0,0.0,0.0,0.0,0.111746,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.111746


In [84]:
from IPython.display import FileLink
model.eval()
predictions_df = save_kaggle_predictions(train_loader)
FileLink(data_path + "predicted-resnet.csv")

Batch: 5/5


In [85]:
predictions_df.loc[[121190]]

Unnamed: 0_level_0,Class1.1,Class1.2,Class1.3,Class2.1,Class2.2,Class3.1,Class3.2,Class4.1,Class4.2,Class5.1,Class5.2,Class5.3,Class5.4,Class6.1,Class6.2,Class7.1,Class7.2,Class7.3,Class8.1,Class8.2,Class8.3,Class8.4,Class8.5,Class8.6,Class8.7,Class9.1,Class9.2,Class9.3,Class10.1,Class10.2,Class10.3,Class11.1,Class11.2,Class11.3,Class11.4,Class11.5,Class11.6
GalaxyID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1
121190,0.370342,0.303781,0.325876,0.159851,0.14393,0.096677,0.047253,0.030769,0.113161,0.035853,0.036615,0.042174,0.029288,0.448526,0.551474,0.116921,0.140765,0.112656,0.061869,0.084487,0.06544,0.052355,0.038448,0.07991,0.066018,0.051639,0.035258,0.072954,0.014407,0.010246,0.006115,0.004381,0.005554,0.005709,0.004226,0.005747,0.005151


In [30]:
# Save model
torch.save(model, "data/galaxies/resnet-50.pth")