In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import LabelEncoder
import torch

In [2]:
if torch.backends.mps.is_available():
    torch.mps.manual_seed(42)

torch.backends.mps.deterministic = True
torch.backends.mps.benchmark = False

device = (
    torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
)
print("Device", device)

Device mps


In [3]:
data = pd.read_csv("StarClassificationDataset.csv")
data = data.dropna()
data.shape
data

  data = pd.read_csv("StarClassificationDataset.csv")


Unnamed: 0,object_ID,alpha,delta,UV_filter,green_filter,red_filter,near_IR_filter,IR_filter,run_ID,rerun_ID,cam_col,field_ID,spec_obj_ID,red_shift,plate_ID,MJD,fiber_ID,class
0,1.240000e+18,135.6891066,32.494632,23.87882,22.27530,20.39501,19.16573,18.79371,3606,301,2,79,6.540000e+18,0.634794,5812,56354,171,GALAXY
1,1.240000e+18,144.8261006,31.274185,24.77759,22.83188,22.58444,21.16812,21.61427,4518,301,5,119,1.180000e+19,0.779136,10445,58158,427,GALAXY
2,1.240000e+18,142.1887896,35.582444,25.26307,22.66389,20.60976,19.34857,18.94827,3606,301,2,120,5.150000e+18,0.644195,4576,55592,299,GALAXY
3,1.240000e+18,338.7410378,-0.402828,22.13682,23.77656,21.61162,20.50454,19.25010,4192,301,3,214,1.030000e+19,0.932346,9149,58039,775,GALAXY
5,1.240000e+18,340.9951205,20.589476,23.48827,23.33776,21.32195,20.25615,19.54544,8102,301,3,110,5.660000e+18,1.424659,5026,55855,741,QSO
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99995,1.240000e+18,39.620709,-2.594074,22.16759,22.97586,21.90404,21.30548,20.73569,7778,301,2,581,1.060000e+19,0.000000,9374,57749,438,GALAXY
99996,1.240000e+18,29.493819,19.798874,22.69118,22.38628,20.45003,19.75759,19.41526,7917,301,1,289,8.590000e+18,0.404895,7626,56934,866,GALAXY
99997,1.240000e+18,224.587407,15.700707,21.16916,19.26997,18.20428,17.69034,17.35221,5314,301,4,308,3.110000e+18,0.143366,2764,54535,74,GALAXY
99998,1.240000e+18,212.268621,46.660365,25.35039,21.63757,19.91386,19.07254,18.62482,3650,301,4,131,7.600000e+18,0.455040,6751,56368,470,GALAXY


In [4]:
constant_columns = data.columns[data.nunique() == 1]
print(constant_columns)
data.drop(columns=constant_columns, inplace=True)
data.shape

Index(['object_ID', 'rerun_ID'], dtype='object')


(99991, 16)

In [5]:
df = data.copy()
# fill ' ' with nan
df[["alpha", "run_ID"]] = df[["alpha", "run_ID"]].replace(r"^\s*$", np.nan, regex=True)
# drop new nan
df.dropna(inplace=True)
df.shape
df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 99989 entries, 0 to 99999
Data columns (total 16 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   alpha           99989 non-null  object 
 1   delta           99989 non-null  float64
 2   UV_filter       99989 non-null  float64
 3   green_filter    99989 non-null  float64
 4   red_filter      99989 non-null  float64
 5   near_IR_filter  99989 non-null  float64
 6   IR_filter       99989 non-null  float64
 7   run_ID          99989 non-null  object 
 8   cam_col         99989 non-null  int64  
 9   field_ID        99989 non-null  int64  
 10  spec_obj_ID     99989 non-null  float64
 11  red_shift       99989 non-null  float64
 12  plate_ID        99989 non-null  int64  
 13  MJD             99989 non-null  int64  
 14  fiber_ID        99989 non-null  int64  
 15  class           99989 non-null  object 
dtypes: float64(8), int64(5), object(3)
memory usage: 13.0+ MB


In [6]:
df['run_ID'] = df['run_ID'].astype('int')
df["alpha"] = df["alpha"].astype("float")
df.info()
X = df.drop("class", axis=1)

<class 'pandas.core.frame.DataFrame'>
Index: 99989 entries, 0 to 99999
Data columns (total 16 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   alpha           99989 non-null  float64
 1   delta           99989 non-null  float64
 2   UV_filter       99989 non-null  float64
 3   green_filter    99989 non-null  float64
 4   red_filter      99989 non-null  float64
 5   near_IR_filter  99989 non-null  float64
 6   IR_filter       99989 non-null  float64
 7   run_ID          99989 non-null  int64  
 8   cam_col         99989 non-null  int64  
 9   field_ID        99989 non-null  int64  
 10  spec_obj_ID     99989 non-null  float64
 11  red_shift       99989 non-null  float64
 12  plate_ID        99989 non-null  int64  
 13  MJD             99989 non-null  int64  
 14  fiber_ID        99989 non-null  int64  
 15  class           99989 non-null  object 
dtypes: float64(9), int64(6), object(1)
memory usage: 13.0+ MB


In [7]:
pd.set_option("display.float_format", "{:.6f}".format)
print("Mean", X.mean(axis=0))
print()
print("Std", X.std(axis=0))

Mean alpha                            177.622734
delta                             24.135449
UV_filter                         21.980473
green_filter                      20.531376
red_filter                        19.645774
near_IR_filter                    19.084869
IR_filter                         18.668818
run_ID                          4481.413946
cam_col                            3.511626
field_ID                         186.129254
spec_obj_ID      5784028913180449792.000000
red_shift                          0.576663
plate_ID                        5137.102581
MJD                            55588.708128
fiber_ID                         449.303253
dtype: float64

Std alpha                             96.501108
delta                             19.644584
UV_filter                         31.771029
green_filter                      31.752030
red_filter                         1.854751
near_IR_filter                     1.757889
IR_filter                         31.729891
run_ID 

In [8]:
scaler = MinMaxScaler()
X = scaler.fit_transform(X)

In [9]:
pd.set_option("display.float_format", "{:.6f}".format)
print("Mean", X.mean(axis=0))
print()
print("Std", X.std(axis=0))

Mean [0.49338897 0.42167726 0.99892333 0.99889629 0.49740803 0.42409801
 0.99893154 0.54295467 0.50232526 0.17906877 0.3973934  0.08355161
 0.39663729 0.54351558 0.44875201]

Std [0.26806155 0.19299821 0.00316702 0.0031655  0.09391195 0.07753678
 0.00316399 0.24398131 0.31739026 0.15236723 0.24086552 0.10407314
 0.24039151 0.24691938 0.27275998]


In [10]:
X = torch.tensor(X, dtype=torch.float32)

In [30]:
encoder = LabelEncoder()
y = encoder.fit_transform(df["class"])
classes = list(df["class"].unique())

y = torch.tensor(y, dtype=torch.long)

In [12]:
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, random_state=42
)

X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.20, random_state=42
)

In [13]:
from torch.utils.data import Dataset, DataLoader

class ClassData(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_data = ClassData(X_train, y_train)
val_data = ClassData(X_val, y_val)
test_data = ClassData(X_test, y_test)

In [14]:
train_loader = DataLoader(train_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)
val_loader = DataLoader(val_data, batch_size=32)

In [15]:
import torch.nn as nn

class Classifier(nn.Module):
    def __init__(self):
        super().__init__() 
        self.model = nn.Sequential(
            nn.Linear(15, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 3)
        )

    def forward(self, x):
        return self.model(x)

In [16]:
model = Classifier()

In [17]:
loss_function = torch.nn.CrossEntropyLoss()

In [18]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [19]:
model.to(device)  # Push the model to the device

Classifier(
  (model): Sequential(
    (0): Linear(in_features=15, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=64, bias=True)
    (5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.3, inplace=False)
    (8): Linear(in_features=64, out_features=3, bias=True)
  )
)

In [20]:
def train_model(
    model, train_loader, val_loader, loss_function, optimizer, num_epochs, patience=5
):
    best_val_loss = float("inf")  
    patience_counter = 0  

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_accuracy = 100 * correct / total

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = loss_function(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss /= len(val_loader)
        val_accuracy = 100 * val_correct / val_total

        print(
            f"Epoch [{epoch+1}/{num_epochs}], "
            f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, "
            f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%"
        )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0  # Reset the patience counter
            # Save the best model
            torch.save(model.state_dict(), "./models/best_model.tar")
        else:
            patience_counter += 1
            print(f"Early stopping patience: {patience_counter}/{patience}")

        if patience_counter >= patience:
            print("Early stopping triggered. Stopping training.")
            break


train_model(model=model, train_loader=train_loader, val_loader=val_loader, loss_function=loss_function, optimizer=optimizer, num_epochs=15)


Epoch [1/15], Train Loss: 0.4026, Train Accuracy: 84.30%, Val Loss: 0.2037, Val Accuracy: 93.52%
Epoch [2/15], Train Loss: 0.2621, Train Accuracy: 90.76%, Val Loss: 0.1713, Val Accuracy: 94.49%
Epoch [3/15], Train Loss: 0.2316, Train Accuracy: 91.86%, Val Loss: 0.1612, Val Accuracy: 94.79%
Epoch [4/15], Train Loss: 0.2191, Train Accuracy: 92.55%, Val Loss: 0.1594, Val Accuracy: 95.56%
Epoch [5/15], Train Loss: 0.2063, Train Accuracy: 93.15%, Val Loss: 0.1630, Val Accuracy: 95.46%
Early stopping patience: 1/5
Epoch [6/15], Train Loss: 0.1965, Train Accuracy: 93.47%, Val Loss: 0.1594, Val Accuracy: 96.05%
Epoch [7/15], Train Loss: 0.1893, Train Accuracy: 93.78%, Val Loss: 0.1607, Val Accuracy: 95.10%
Early stopping patience: 1/5
Epoch [8/15], Train Loss: 0.1846, Train Accuracy: 93.95%, Val Loss: 0.1480, Val Accuracy: 95.81%
Epoch [9/15], Train Loss: 0.1814, Train Accuracy: 94.07%, Val Loss: 0.1898, Val Accuracy: 93.05%
Early stopping patience: 1/5
Epoch [10/15], Train Loss: 0.1795, Train

In [21]:
def test(dataloader, model, loss_function, device):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_function(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    return correct, test_loss

In [26]:
model.load_state_dict(torch.load("./models/best_model.tar", weights_only=True))
model.to(device)

Classifier(
  (model): Sequential(
    (0): Linear(in_features=15, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=64, bias=True)
    (5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.3, inplace=False)
    (8): Linear(in_features=64, out_features=3, bias=True)
  )
)

In [27]:
correct, test_loss = test(test_loader, model, loss_function, device)
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

Test Error: 
 Accuracy: 95.4%, Avg loss: 0.157214 



In [24]:
correct, test_loss = test(test_loader, model, loss_function, device)
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

Test Error: 
 Accuracy: 95.4%, Avg loss: 0.157214 



In [None]:
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predictions = torch.max(outputs, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label.item()]] += 1
            total_pred[classes[label.item()]] += 1

for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f"Accuracy for class: {classname:5s} is {accuracy:.1f} %")

Accuracy for class: GALAXY is 98.3 %
Accuracy for class: QSO   is 87.5 %
Accuracy for class: STAR  is 94.4 %
