In [1]:
import os

os.chdir("../")

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import random_split, DataLoader
import timm
import numpy as np

import argparse
import torch_optimizer as optimizer
import wandb
from pathlib import Path
from config import settings

import models.spinalnet_resnet as spinalnet_resnet
import models.effnet as effnet
import models.densenet as densenet
import models.spinalnet_vgg as spinalnet_vgg
import models.vitL16 as vitL16
import models.alexnet_vgg as alexnet_vgg
import models.resnet18 as resnet18

import  data
# import data.segmentation as segmentation
# import metrics.metrics as metrics
from data import DataPart
from train import Trainer
import metrics


all_models = [
    ('ResNet18', resnet18),
    ('EfficientNet', effnet),
    # ('DenseNet', densenet),
    # ('SpinalNet_ResNet', spinalnet_resnet),
    # ('SpinalNet_VGG', spinalnet_vgg),
    # ('ViTL16', vitL16),
    # ('AlexNet_VGG', alexnet_vgg)
]

all_optimizers = [
    ('SGD', optim.SGD),
    ('Rprop', optim.Rprop),
    ('Adam', optim.Adam),
    ('NAdam', optim.NAdam),
    ('RAdam', optim.RAdam),
    ('AdamW', optim.AdamW),
    #('Adagrad', optim.Adagrad),
    ('RMSprop', optim.RMSprop),
    #('Adadelta', optim.Adadelta),
    ('DiffGrad', optimizer.DiffGrad),
    # ('LBFGS', optim.LBFGS)
]

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
datasets, dataloaders = data.create_dataloaders()

train_loader = dataloaders[DataPart.TRAIN]
val_loader = dataloaders[DataPart.VALIDATE]
test_loader = dataloaders[DataPart.TEST_DR5]


  return bound(*args, **kwds)


INFO: Query finished. [astroquery.utils.tap.core]
252
84
84
244
85


In [3]:

# parser = argparse.ArgumentParser(description='Model training')
# parser.add_argument('--models', nargs='+', default=['ResNet18', 'EfficientNet', 'DenseNet', 'SpinalNet_ResNet', 'SpinalNet_VGG', 'ViTL16', 'AlexNet_VGG'],
#                     help='List of models to train (default: all)')
# parser.add_argument('--epochs', type=int, default=5, help='Number of epochs to train (default: 5)')
# parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate for optimizer (default: 0.0001)')
# parser.add_argument('--mm', type=float, default=0.9, help='Momentum for optimizer (default: 0.9)')
# parser.add_argument('--optimizer', choices=[name for name, _ in all_optimizers], default='Adam', help='Optimizer to use (default: Adam)')

# args = parser.parse_args()

# selected_models = [(model_name, model) for model_name, model in models if model_name in args.models]

# num_epochs = args.epochs
# lr = args.lr
# momentum = args.mm
# optimizer_name = args.optimizer


In [4]:
selected_models = all_models[:2]

num_epochs = 1
lr = 0.0001
momentum = 0.9
optimizer_name = "Adam"




In [5]:
if settings.wandb_api_token:
    wandb.login(key=settings.wandb_api_token)
    wandb.init(project='cluster-search', config={}, reinit=True)
else:
    wandb.init(project='cluster-search', config={}, reinit=True)


wandb.config.models = [name for name, _ in selected_models]
wandb.config.num_epochs = num_epochs
wandb.config.lr = lr
wandb.config.momentum = momentum
wandb.config.optimizer = optimizer_name

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mzehov1[0m ([33mmzekhov[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/mszekhov/.netrc


In [6]:

# criterion = nn.CrossEntropyLoss()
criterion = nn.BCELoss()

results = {}
val_results = {}

classes = ('random', 'clusters')


In [7]:
for model_name, model in selected_models:

     model = model.load_model()

     optimizer_class = dict(all_optimizers)[optimizer_name]

     if optimizer_name in ['SGD', 'RMSprop']:
          optimizer = optimizer_class(model.parameters(), lr=lr, momentum=momentum) 
     else:
          optimizer = optimizer_class(model.parameters(), lr=lr)
         
     trainer = Trainer(
          model=model,
          criterion=criterion,
          optimizer=optimizer,
          train_dataloader=train_loader,
          val_dataloader=val_loader,

     )

     trainer.train(num_epochs)

     for step in range(trainer.global_step):
          wandb.log(
               {
                    f'{model_name}_{optimizer_name}_train_loss': trainer.history['train_loss'][step], 
                    f'{model_name}_{optimizer_name}_train_accuracy':trainer.history['train_acc'][step], 
                    'global_step': step + 1})
          
     for epoch in range(num_epochs):
          wandb.log(
               {
                    f'{model_name}_{optimizer_name}_val_loss': trainer.history['val_loss'][epoch], 
                    f'{model_name}_{optimizer_name}_val_accuracy': trainer.history['val_acc'][epoch], 
                    'epoch': epoch})

     
     train_table = wandb.Table(
          data=[
               [
                    step, 
                    trainer.history['train_loss'][step], 
                    trainer.history['train_acc'][step]
               ] for step in range(trainer.global_step)],
          columns=["Epoch", "Loss", "Accuracy"])

     val_table = wandb.Table(
          data=[
               [
                    epoch, 
                    trainer.history['val_loss'][epoch], 
                    trainer.history['val_acc'][epoch]
               ] for epoch in range(num_epochs)],
          columns=["Epoch", "Loss", "Accuracy"])

     wandb.log({"Train Metrics": train_table, "Validation Metrics": val_table})

     y_pred, y_probs, y_true, *_ = trainer.test(test_loader)

     metrics.modelPerformance(model_name, optimizer_name, y_true, y_pred, y_probs, classes)

metrics.combine_metrics(selected_models, optimizer_name)

100%|██████████| 1/1 [00:09<00:00,  9.68s/batch]      | 0/1 [00:00<?, ?epoch/s]
100%|██████████| 1/1 [00:04<00:00,  4.19s/it]
100%|██████████| 1/1 [00:03<00:00,  3.39s/it]                                          
100%|██████████| 1/1 [00:32<00:00, 32.63s/batch]          | 0/1 [00:00<?, ?epoch/s]
100%|██████████| 1/1 [00:12<00:00, 12.13s/it]
100%|██████████| 1/1 [00:10<00:00, 10.49s/it]                                              


In [None]:
wandb.finish()

wandb_run = wandb.run
if wandb_run:
    logged_metrics = wandb_run.history()
    print("Logged Metrics:")
    for key, value in logged_metrics.items():
        print(key, ":", value)
else:
    print("No wandb run found.")

VBox(children=(Label(value='0.004 MB of 0.009 MB uploaded\r'), FloatProgress(value=0.4286352967475131, max=1.0…

0,1
ResNet18_Adam_train_accuracy,▆▆▅▆▆▆▇▇▇█▇▇▇▇█▁
ResNet18_Adam_train_loss,█▇█▇▇▇▅▅▅▄▆▅▃▂▁▇
ResNet18_Adam_val_accuracy,▁
ResNet18_Adam_val_loss,▁
epoch,▁
global_step,▁▁▂▂▃▃▄▄▅▅▆▆▇▇██

0,1
ResNet18_Adam_train_accuracy,0.0
ResNet18_Adam_train_loss,0.6936
ResNet18_Adam_val_accuracy,0.72852
ResNet18_Adam_val_loss,0.62251
epoch,0.0
global_step,16.0


No wandb run found.


In [13]:
import segmentation

model_name, model = selected_models[0]
segmentation.create_segmentation_plots(
    model,
    model_name,
    optimizer_name=optimizer_name
)

  loaded_model = torch.load(weights_path, map_location=device)
100%|██████████| 1/1 [00:00<00:00,  3.11it/s]


23
23
23
23
23
23
23
23
23
23


100%|██████████| 7/7 [00:19<00:00,  2.79s/it]
100%|██████████| 7/7 [00:20<00:00,  2.98s/it]
100%|██████████| 7/7 [00:21<00:00,  3.12s/it]
100%|██████████| 7/7 [00:20<00:00,  2.91s/it]
100%|██████████| 7/7 [00:19<00:00,  2.85s/it]
100%|██████████| 7/7 [00:20<00:00,  2.98s/it]
100%|██████████| 7/7 [00:19<00:00,  2.81s/it]
100%|██████████| 7/7 [00:20<00:00,  2.90s/it]
100%|██████████| 7/7 [00:22<00:00,  3.25s/it]
100%|██████████| 7/7 [00:22<00:00,  3.23s/it]
100%|██████████| 1/1 [00:00<00:00,  7.92it/s]


23
23
23
23
23


100%|██████████| 7/7 [00:18<00:00,  2.66s/it]
100%|██████████| 7/7 [00:19<00:00,  2.84s/it]
100%|██████████| 7/7 [00:19<00:00,  2.78s/it]
100%|██████████| 7/7 [00:28<00:00,  4.08s/it]
100%|██████████| 7/7 [00:18<00:00,  2.62s/it]
100%|██████████| 1/1 [00:00<00:00,  8.08it/s]


23
23
23
23
23


100%|██████████| 7/7 [00:18<00:00,  2.66s/it]
100%|██████████| 7/7 [00:22<00:00,  3.25s/it]
100%|██████████| 7/7 [00:24<00:00,  3.56s/it]
100%|██████████| 7/7 [00:25<00:00,  3.66s/it]
100%|██████████| 7/7 [00:26<00:00,  3.74s/it]
100%|██████████| 1/1 [00:00<00:00, 14.72it/s]


49


100%|██████████| 16/16 [00:50<00:00,  3.13s/it]
100%|██████████| 1/1 [00:00<00:00, 34.24it/s]


49
Failed attempt 0 to download /Users/mszekhov/Desktop/current_projects/galaxyHackers/storage/segmentation/samples/dr5_big/758/679.fits with an HTTPError
Failed attempt 0 to download /Users/mszekhov/Desktop/current_projects/galaxyHackers/storage/segmentation/samples/dr5_big/758/728.fits with an HTTPError


100%|██████████| 16/16 [00:43<00:00,  2.74s/it]
