In [18]:
import torch
from torch import nn
from torch.utils import data
from torch.nn.utils import clip_grad_norm_
from torchvision import transforms, models
import os
from astropy.visualization import MinMaxInterval
from astropy.io import fits
from matplotlib import pyplot as plt
from pathlib import Path
import numpy as np
from PIL import Image
from tqdm import tqdm
import json
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [3]:
# Read the JSON file that maps logical names (e.g. "survey_path") to locations on disk.
# The context manager closes the handle as soon as parsing finishes; previously the
# file object stayed open for the lifetime of the kernel.
with Path("config/paths.json").open('r') as config_file:
    paths = json.load(config_file)

In [30]:
# Load the survey table from the CSV file whose location is stored under "survey_path".
survey_table = pd.read_csv(paths["survey_path"])

## Random Forest Classifier

In [None]:
# Input columns for the classifier: every "*_iso" column first, then the remaining ones.
_iso_columns = [
    "u_iso", "J0378_iso", "J0395_iso", "J0410_iso",
    "J0430_iso", "g_iso", "J0515_iso", "r_iso",
    "J0660_iso", "i_iso", "J0861_iso", "z_iso",
]
_other_columns = ["FWHM_n", "A", "B", "KRON_RADIUS"]
features = _iso_columns + _other_columns

# Rows flagged "test" in the "sampling_1" column form the test set; all others are training rows.
test_idx = survey_table["sampling_1"].eq("test")
train_idx = ~test_idx

_train_rows = survey_table.loc[train_idx]
_test_rows = survey_table.loc[test_idx]

X_train, y_train = _train_rows[features], _train_rows["target"]
X_test, y_test = _test_rows[features], _test_rows["target"]

In [64]:
# Forest trained with bootstrap resampling switched off and a fixed random_state.
rf = RandomForestClassifier(random_state=42, bootstrap=False)
rf.fit(X=X_train, y=y_train)

In [74]:
# Score the fitted forest on the held-out rows.
y_pred = rf.predict(X_test)

acc = accuracy_score(y_test, y_pred)
# Precision, recall and F1 all use the same macro averaging over classes.
precision, recall, f1 = (
    scorer(y_test, y_pred, average='macro')
    for scorer in (precision_score, recall_score, f1_score)
)

# Report each score as a percentage with four decimals.
for _label, _score in (("Accuracy", acc), ("Precision", precision), ("Recall", recall), ("F1", f1)):
    print(f"{_label}: {_score*100:.4f}%")

Accuracy: 94.6373%
Precision: 92.6535%
Recall: 93.3424%
F1: 92.9722%


## CNN Model

In [15]:
# Fix torch's global RNG seed for this session.
torch.manual_seed(seed=42)

# Maps a class directory name to the integer label used by the network.
dict_encode = dict(quasar=0, star=1, galaxy=2)

class FITSDataset(data.Dataset):
    """Dataset of FITS images laid out as ``<data_path>/<class_name>/*.fits``.

    The label of an image is ``dict_encode[<class_name>]``, where ``<class_name>``
    is the name of the directory that directly contains the file.

    Args:
        data_path: Root directory. Only ``*.fits`` files exactly one directory
            level below it are collected.
        transforms: Callable applied to the float32 image array, or ``None`` to
            return the array unchanged.
    """

    def __init__(self, data_path, transforms):
        # glob yields entries in file-system order, which can differ between runs
        # and machines. Sorting pins the index -> file mapping so that seeded
        # shuffling and the evaluation order are reproducible.
        self.img_files = sorted(Path(data_path).glob("*/*.fits"))
        self.transforms = transforms

    def __getitem__(self, index):
        """Return ``(image, label)`` for the file at position ``index``.

        Raises:
            KeyError: If the file's parent directory name is not a key of
                ``dict_encode``.
        """
        img_file = self.img_files[index]
        _img = fits.getdata(img_file).astype(np.float32)
        _label = dict_encode[img_file.parent.name]

        if self.transforms is not None:
            return self.transforms(_img), _label

        return _img, _label

    def __len__(self):
        """Number of FITS files found under ``data_path``."""
        return len(self.img_files)

def ToImage(arr):
    """Build a PIL image from ``arr`` with ``Image.fromarray`` and return it."""
    pil_image = Image.fromarray(arr)
    return pil_image

def Norm(img):
    """Pass ``img`` through a freshly constructed ``MinMaxInterval`` and return the result.

    NOTE(review): astropy documents this interval as min-max scaling to [0, 1] — confirm.
    """
    return MinMaxInterval()(img)

# Per-sample preprocessing, applied in list order:
# array -> PIL image -> 32x32 resize -> Norm -> tensor.
_fits_steps = [
    transforms.Lambda(ToImage),
    transforms.Resize((32, 32)),
    transforms.Lambda(Norm),
    transforms.ToTensor(),
]
fits_transform = transforms.Compose(_fits_steps)

BATCH_SIZE = 4

# Options shared by both loaders; only `shuffle` differs between train and test.
_loader_options = dict(batch_size=BATCH_SIZE, num_workers=8, pin_memory=True)

train_dataset = FITSDataset(paths["training_path"], fits_transform)
test_dataset = FITSDataset(paths["testing_path"], fits_transform)

# Training batches are reshuffled; test batches keep the dataset order.
train_dataloader = data.DataLoader(train_dataset, shuffle=True, **_loader_options)
test_dataloader = data.DataLoader(test_dataset, shuffle=False, **_loader_options)

print(f"{len(train_dataloader)} train batches of {BATCH_SIZE} and {len(test_dataloader)} test batches of {BATCH_SIZE}")

# _data_tensor, _label_tensor = next(iter(train_dataloader))

# plt.imshow(_data_tensor[0].squeeze().numpy(), cmap='gray')
# plt.title(f"Class: {_label_tensor.numpy()[0]}");

3 train batches of 4 and 3 test batches of 4


In [5]:
# Build torchvision's VGG16 and load the weights identified by 'IMAGENET1K_V1'.
VGG16 = models.vgg16(weights='IMAGENET1K_V1')

In [6]:
# Swap the first convolution for one that accepts a single input channel
# (same 64 outputs, 3x3 kernel, stride 1, padding 1). The new layer is freshly initialised.
VGG16.features[0] = nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=(3, 3),
    stride=(1, 1),
    padding=(1, 1),
)

# Swap the final classifier layer so it emits one logit per entry of dict_encode.
VGG16.classifier[6] = nn.Linear(in_features=4096, out_features=len(dict_encode))

In [16]:
EPOCHS = 9
LR = 0.001
# WD = 0.0
MAX_NORM = 2.0

#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'

model = VGG16.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.RAdam(params=model.parameters(), lr=LR)


def _run_epoch(model, dataloader, criterion, device, optimizer=None, max_norm=None):
    """Run one full pass over ``dataloader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given the model is put in train mode and updated after
    every batch; otherwise it is put in eval mode and run under
    ``torch.inference_mode``. Both returned values are averaged over samples,
    so a shorter final batch is not over-weighted.

    Args:
        model: Network mapping an image batch to class logits.
        dataloader: Iterable yielding ``(imgs, labels)`` batches.
        criterion: Loss returning the mean loss of a batch.
        device: Device the batches are moved to.
        optimizer: Optimizer to step after each batch, or ``None`` to evaluate only.
        max_norm: If not ``None``, gradients are clipped to this L2 norm before
            each optimizer step.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    n_correct = 0
    n_samples = 0

    # inference_mode(mode=False) is a no-op, so gradients stay enabled while training.
    with torch.inference_mode(mode=not training):
        for imgs, labels in dataloader:
            imgs = imgs.to(device=device, non_blocking=True)
            labels = labels.to(device=device, non_blocking=True)

            logits = model(imgs)
            loss = criterion(logits, labels)

            if training:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                # Clipping must sit between backward() and step(): that is the only
                # point where the gradients exist and have not been consumed yet.
                if max_norm is not None:
                    clip_grad_norm_(parameters=model.parameters(),
                                    max_norm=max_norm,
                                    norm_type=2)
                optimizer.step()

            batch_size = labels.size(0)
            # criterion returns the batch mean; scale it back to a per-sample sum.
            loss_sum += loss.item() * batch_size
            # argmax of the logits equals argmax of their softmax, so softmax is skipped.
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("dataloader yielded no samples")

    return loss_sum / n_samples, n_correct / n_samples


# One entry per epoch, in epoch order.
train_loss = []
test_loss = []
train_acc = []
test_acc = []

for epoch in tqdm(range(EPOCHS)):
    epoch_train_loss, epoch_train_acc = _run_epoch(
        model, train_dataloader, criterion, device,
        optimizer=optimizer, max_norm=MAX_NORM,
    )
    train_loss.append(epoch_train_loss)
    train_acc.append(epoch_train_acc)

    epoch_test_loss, epoch_test_acc = _run_epoch(model, test_dataloader, criterion, device)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)


100%|██████████| 9/9 [01:01<00:00,  6.81s/it]


In [None]:
# Side-by-side panels: loss on the left, accuracy on the right, one point per epoch.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 7))

epoch_range = range(EPOCHS)

_panels = (
    (loss_ax, "Loss", ((train_loss, "Train loss"), (test_loss, "Test loss"))),
    (acc_ax, "Accuracy", ((train_acc, "Train accuracy"), (test_acc, "Test accuracy"))),
)

for _ax, _title, _curves in _panels:
    for _values, _curve_label in _curves:
        _ax.plot(epoch_range, _values, label=_curve_label)
    _ax.set_title(_title)
    _ax.set_xlabel("Epochs")
    _ax.legend()

In [73]:
# File name encodes the hyper-parameters. replace('.', '_', 1) rewrites only the first
# dot, i.e. the decimal point of LR, so the ".pth" extension is kept
# (LR=0.001, MAX_NORM=2.0, EPOCHS=9 -> "vgg16-0_0010-2-9.pth").
MODEL_NAME = f"vgg16-{LR:.4f}-{int(MAX_NORM)}-{EPOCHS}.pth".replace('.', '_', 1)
MODEL_DIR = "models"
# torch.save does not create missing directories; make sure the target exists first.
os.makedirs(MODEL_DIR, exist_ok=True)
MODEL_FILEPATH = os.path.join(MODEL_DIR, MODEL_NAME)
torch.save(obj=model.state_dict(), f=MODEL_FILEPATH)