In [1]:
import torch
from utils.networks import return_resnet18
import os
from utils.loops import train, valid, test
import wandb
from utils.dataloader import return_dataloaders
from utils.utils import GeneralizedCELoss
import torch.optim as optim
from utils.seed import fix

In [2]:
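# NOTE: command-line parsing is disabled in this notebook; the seed, GPU and
# device are set directly in the configuration cell instead. `argparse` is not
# imported here, so re-enabling this block also requires `import argparse`.
# parser = argparse.ArgumentParser(description='Training BCD')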
# parser.add_argument('-seed', required=True, type=int, help="random seed")
# parser.add_argument('-cuda', required=True, type=int, help="gpu")
# parser.add_argument('-device', required=False, type=str, help="device")
# args = parser.parse_args()

In [3]:
###############################################
# Experiment configuration: every tunable value for this run is set here.

# Hardware: restrict CUDA to device index 0 and run on the GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = 'cuda'

# Reproducibility
seed = 0

# Optimisation hyper-parameters
epochs, batch_size = 3, 64
learning_rate = 1e-4

# Weights & Biases logging (`remote` gates the wandb.init call)
remote = True
project_name = 'Deboost - mar 2'
run_name = f'BCD {seed} | seed: {seed}'

# Model saving
save = True
save_path = './BCD_models/'
###############################################

In [4]:
# Seed the run through the project's `fix` helper (utils.seed); per the
# recorded output it prints a confirmation together with the seed value.
fix(seed)


Random seed has been fixed.
Seed: 0



In [5]:
# Build the network with the project's `return_resnet18` helper, requesting
# num_classes=2. NOTE(review): the model is not moved to `device` in this
# cell; presumably the loops in utils.loops handle that. Confirm.
model = return_resnet18(num_classes=2)

In [6]:
# Data pipeline, training objective and optimiser for the run.
dataloaders = return_dataloaders(
    dataset='bffhq',
    root='../../dataset/',
    batch_size=batch_size,
)

# Generalized cross-entropy objective from the project's utils.
loss_fn = GeneralizedCELoss()

# Adam over every model parameter at the configured learning rate.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [7]:
# Weights & Biases run setup; skipped entirely when `remote` is False.
if remote:
    # Hyper-parameters recorded alongside the run.
    run_config = {
        "random seed": seed,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "epochs": epochs,
        "note": '',
    }
    wandb.init(project=project_name, name=run_name, config=run_config)

    # Each metric family is plotted against its own custom step axis.
    step_axes = (
        ("Train/*", "Batch step"),
        ("Valid/*", "Epoch step"),
        ("Accuracy/*", "Epoch step"),
    )
    for metric_glob, step_name in step_axes:
        wandb.define_metric(metric_glob, step_metric=step_name)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhydrated-kapri[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
# Main loop: one training pass, then validation and test evaluation, per epoch.
#
# The `remote` flag from the configuration cell is forwarded to every loop
# instead of a hard-coded True, so that turning W&B off (which skips
# wandb.init) also tells train/valid/test not to log remotely.
#
# NOTE(review): `save` / `save_path` from the configuration cell are not
# passed to any of these calls; confirm whether checkpoints are written
# elsewhere.
# NOTE(review): the recorded output shows the seed being re-fixed at every
# epoch (train receives `seed`); verify that this is intended.

# Keyword arguments shared by the three loop functions.
loop_kwargs = dict(
    remote=remote,
    model=model,
    device=device,
    dataloaders=dataloaders,
    loss_fn=loss_fn,
    loss_type='GCE',
)

for epoch in range(epochs):
    # Only the training pass needs the optimizer and the seed.
    train(epoch=epoch, optimizer=optimizer, seed=seed, **loop_kwargs)
    valid(epoch=epoch, **loop_kwargs)
    test(epoch=epoch, **loop_kwargs)

Epochs: 0

Random seed has been fixed.
Seed: 0

Epochs: 1

Random seed has been fixed.
Seed: 0

Epochs: 2

Random seed has been fixed.
Seed: 0

