In [None]:
import numpy as np
import torch

from tqdm import tqdm
from time import time
from matplotlib import pyplot as plt

from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from torchvision.models import resnet18
import torch.nn as nn

# Custom imports
from data import GalaxiesDataset
from metrics import compute_accuracy, compute_confusion_matrix, plot_confusion_matrix

# Seed PyTorch's global RNG so that the randomly initialised classification head
# and the shuffling order of the training DataLoader are reproducible across runs.
# (The train/test split uses its own explicitly seeded generator.)
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(DEVICE)

In [None]:
# *** Load the data ***
# GalaxiesDataset keeps every sample in RAM, so this cell is slow (about one
# minute when reading from an HDD), but fetching batches afterwards is very fast.
dataset_path = 'Galaxy10_DECals.h5'
dataset = GalaxiesDataset(dataset_path)

In [None]:
# *** Split into training and test data ***
# Fractions of the dataset assigned to (train, test).
train_test_ratios = [0.8, 0.2]
# Dedicated generator with a fixed seed: the split is identical on every run
# and independent of PyTorch's global RNG.
generator = torch.Generator()
generator.manual_seed(42)
train_set, test_set = random_split(dataset, train_test_ratios, generator=generator)

In [None]:
# *** ResNet18 with a 10-neuron final layer ***
class GalaxiesResNet(nn.Module):
    """ResNet-18 whose final fully connected layer is resized for galaxy classification.

    Args:
        pretrained: Forwarded to ``torchvision.models.resnet18``. When True, every
            parameter except the new final ``fc`` layer is frozen
            (``requires_grad = False``), so only the classification head is trained.
        num_classes: Number of output logits. Defaults to 10, the number of galaxy
            classes used by this notebook.

    NOTE(review): pretrained weights usually expect inputs preprocessed the same way
    as their training data (3 channels, specific normalisation) — confirm that
    GalaxiesDataset's preprocessing matches.
    """

    def __init__(self, pretrained=False, num_classes=10):
        super().__init__()
        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of ``weights=``;
        # kept here for compatibility with older torchvision versions.
        self.model = resnet18(pretrained=pretrained, progress=False)

        if pretrained:
            # Freeze the whole backbone *before* swapping in the new head: the
            # freshly created fc layer below is trainable by default, so it is the
            # only part left with requires_grad=True.
            for param in self.model.parameters():
                param.requires_grad = False

        # Replace the final fully connected layer so it outputs one logit per
        # class, keeping its input size (the backbone's feature dimension).
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """Return the raw (un-normalised) class logits for the input batch ``x``."""
        return self.model(x)

In [None]:
nb_epoch = 1
learning_rate = 0.01
momentum = 0.9
batch_size = 64

# Shuffle only the training data. The test set is evaluated in a fixed order:
# accuracy does not depend on order, and this keeps inference deterministic.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

# Pretrained backbone, frozen except for the new classification head.
model = GalaxiesResNet(pretrained=True)
model.to(DEVICE)

criterion = torch.nn.CrossEntropyLoss()
# Hand only the trainable (non-frozen) parameters to the optimizer.
unfrozen_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.SGD(params=unfrozen_params, lr=learning_rate, momentum=momentum)

In [None]:
# *** Training loop ***
model.train()
# NOTE(review): with pretrained=True the backbone parameters are frozen, but train()
# still puts its BatchNorm layers in training mode (batch statistics are used and
# running statistics are updated). Confirm this is intended; otherwise call .eval()
# on the BatchNorm modules after model.train().

total_batch = len(train_loader)

for i_epoch in range(nb_epoch):

    train_losses, start_time = [], time()
    # Progress bar over the batches of this epoch.
    progress = tqdm(train_loader, total=total_batch, desc=f'epoch {i_epoch + 1}/{nb_epoch}')
    for images, targets in progress:
        images = images.to(DEVICE)
        targets = targets.to(DEVICE)

        optimizer.zero_grad()
        predictions = model(images)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()

        # Store a Python float (not the tensor) so no autograd graph is kept alive.
        train_losses.append(loss.item())

    print(' [-] epoch {:4}/{:}, train loss {:.6f} in {:.2f}s'.format(
        i_epoch+1, nb_epoch, np.mean(train_losses), time()-start_time))

In [None]:
# *** Save the model weights ***
# Only the state_dict (parameter and buffer tensors) is written, not the class
# itself, so GalaxiesResNet must be available to rebuild the model when reloading.
checkpoint_path = 'galaxies_resnet.pth'
torch.save(model.state_dict(), checkpoint_path)

In [None]:
# *** Load the model ***
# Rebuild the architecture on DEVICE, then restore the saved weights into it.
# - Moving the model to DEVICE keeps it on the same device as the images the
#   inference loop sends there; a CPU model with CUDA inputs would raise.
# - map_location lets a checkpoint written on a GPU be loaded on a CPU-only
#   machine (and places the tensors directly on DEVICE otherwise).
model = GalaxiesResNet().to(DEVICE)
model.load_state_dict(torch.load('galaxies_resnet.pth', map_location=DEVICE))

In [None]:
# *** Inference loop ***
model.eval()

batch_outputs = []
batch_targets = []
# Disable autograd for the whole pass: nothing here is back-propagated.
with torch.no_grad():
    for images, targets in test_loader:
        # Only the images need to be on DEVICE. The targets stay where the
        # DataLoader put them, avoiding a host -> device -> host round trip
        # (.cpu() is a no-op for tensors already on the CPU).
        outputs = model(images.to(DEVICE))
        batch_outputs.append(outputs.cpu().numpy())
        batch_targets.append(targets.cpu().numpy())
# Stack the per-batch results, then take the highest-scoring class per sample.
outputs = np.concatenate(batch_outputs, axis=0)
predictions = outputs.argmax(axis=1)
targets = np.concatenate(batch_targets, axis=0)

In [None]:
# *** Compute the accuracy and plot a confusion matrix ***

# Presumably compute_accuracy returns a fraction in [0, 1]; it is scaled by 100
# to print a percentage — verify against the metrics module.
test_acc = compute_accuracy(targets, predictions)
print(' [-] test acc. {:.6f}%'.format(test_acc * 100))

# Human-readable name for each class index.
label_dict = {
    0: "Disturbed",
    1: "Merging",
    2: "Round Smooth",
    3: "In-between Round Smooth",
    4: "Cigar Shaped Smooth",
    5: "Barred Spiral",
    6: "Unbarred Tight Spiral",
    7: "Unbarred Loose Spiral",
    8: "Edge-on without Bulge",
    9: "Edge-on with Bulge"
}

# Label names ordered by class index.
labels = [label_dict[i] for i in range(len(label_dict))]

# Derive the class count from the label names instead of repeating a magic 10.
confusion_matrix = compute_confusion_matrix(targets, predictions, len(labels))
plot_confusion_matrix(confusion_matrix, labels, "Confusion matrix")

plt.show()