In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import networks
import data_handler
import trainer
from utils import check_log_dir, make_log_name, set_seed
from arguments import get_args
import time
import os
import argparse
# args = get_args()


ModuleNotFoundError: ignored

## Hyperparameters

In [None]:
# Hyperparameters.
# Built by hand as an argparse.Namespace so the notebook does not depend on
# command-line parsing (the get_args() call in the imports cell is commented out).

# Select the GPU first: CUDA_VISIBLE_DEVICES only takes effect if it is set before
# CUDA is initialized in this process, and set_seed() below may touch CUDA state.
# Re-running this cell in a kernel that has already used CUDA will NOT switch GPUs;
# restart the kernel instead.
# With only physical GPU 1 visible (devices ordered by PCI bus id), that GPU is
# addressed as device index 0, which is why args.device is 0 below.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

args = argparse.Namespace(
    seed=0,
    dataset='waterbird',
    batch_size=128,
    epochs=1,
    device=0,              # index among the *visible* CUDA devices; used as 'cuda:{device}'
    n_workers=1,
    balSampling=False,     # forwarded to the dataloader factory; presumably balanced sampling — confirm
    model='resnet18',
    pretrained=True,
    method='gdro',         # trainer name passed to TrainerFactory
    optim='SGD',           # one of 'Adam', 'SGD', 'AdamW'
    lr=0.01,
    gamma=0.1,             # NOTE(review): consumer not visible here (trainer?) — confirm meaning
    weight_decay=0.0001,
    cuda=True,
    term=20,               # presumably the logging interval in batches (train log prints at batch 20) — confirm
    record=False,
    log_dir='./logs/',
    date='20230828',
    save_dir='./trained_models/',
)

torch.backends.cudnn.enabled = True
set_seed(args.seed)

np.set_printoptions(precision=4)
torch.set_printoptions(precision=4)

## Load dataset & network model

In [None]:
########################## get dataloader ################################
# The factory returns (n_classes, n_groups, train_loader, test_loader).
n_classes, n_groups, train_loader, test_loader = data_handler.DataloaderFactory.get_dataloader(
    args.dataset,
    batch_size=args.batch_size,
    seed=args.seed,
    n_workers=args.n_workers,
    balSampling=args.balSampling,
    args=args,
)
########################## get model ##################################
# 224 is passed positionally; presumably the input image size — confirm in networks.
model = networks.ModelFactory.get_model(args.model, n_classes, 224,
                                        pretrained=args.pretrained, n_groups=n_groups)

model.cuda('cuda:{}'.format(args.device))
print('successfully call the model')

scheduler = None  # no learning-rate schedule in this run
########################## get trainer ##################################
# Map the optimizer name to its torch.optim class. An unknown name fails loudly
# instead of leaving `optimizer` undefined on a fresh kernel, or silently reusing
# a stale optimizer left over from a previous run of this cell.
OPTIMIZERS = {
    'Adam': optim.Adam,
    'SGD': optim.SGD,
    'AdamW': optim.AdamW,
}
if args.optim not in OPTIMIZERS:
    raise ValueError("Unknown optimizer '{}'; expected one of {}".format(
        args.optim, sorted(OPTIMIZERS)))
optimizer = OPTIMIZERS[args.optim](model.parameters(), lr=args.lr,
                                   weight_decay=args.weight_decay)

trainer_ = trainer.TrainerFactory.get_trainer(args.method, model=model, args=args,
                                            optimizer=optimizer, scheduler=scheduler)

  "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "


mode : test
# of 0 group data :  [2255  642]
# of 1 group data :  [2255  642]
mode : train
# of 0 group data :  [3498   56]
# of 1 group data :  [ 184 1057]
# of test data : 5794
# of train data : 4795
Dataset loaded.
# of classes, # of groups : 2, 2
successfully call the model


## Start training

In [None]:
####################### start training & evaluation ####################
# perf_counter is monotonic, so the measured duration cannot be skewed (or go
# negative) if the system wall clock is adjusted while training runs.
start_t = time.perf_counter()
trainer_.train(train_loader, test_loader, args.epochs)
end_t = time.perf_counter()
train_t = int((end_t - start_t) / 60)  # elapsed whole minutes (truncated)

hours, minutes = divmod(train_t, 60)
print('Training Time : {} hours {} minutes'.format(hours, minutes))


[1/1,    20] Method: gdro Train Loss: 0.530 Train Acc: 0.73 [0.49 s/batch]
[1/1] Method: gdro Test Loss: 0.336 Test Acc: 0.86 Test DEOM 0.04 [19.68 s]
Training Finished!
Training Time : 0 hours 0 minutes


## Evaluation

In [None]:
####################### Evaluation ####################
criterion = torch.nn.CrossEntropyLoss(reduction='none')

loss, acc, dcaM, dcaA, group_acc, group_loss = trainer_.evaluate(trainer_.model, test_loader, criterion, train=False)
print('Test')
print('Loss: {:.3f}'.format(loss.item()))

# Replace the plain test accuracy with the Waterbirds "average accuracy": the
# per-group test accuracies are reweighted by the training-set group proportions.
# Within each class the majority group gets weight 0.95 and the minority group
# 0.05 (matching the training counts 3498/184 and 1057/56 reported by the
# dataloader), and the two classes are then averaged. The majority groups sit on
# the diagonal; the matrix is symmetric, so it does not matter whether group_acc
# is indexed [group, class] or [class, group]. Assumes group_acc is the 2x2
# matrix printed below.
group_weights = torch.tensor([[0.95, 0.05],
                              [0.05, 0.95]], device=group_acc.device)
# The weights sum to 2 (one per class), so this is the mean of the per-class weighted accuracies.
acc = (group_acc.float() * group_weights).sum() / group_weights.sum()
print('Accuracy: {:.3f}'.format(acc.item()))
print('DCA-M: {:.3f}'.format(dcaM))
print('DCA-A: {:.3f}'.format(dcaA))
print('Group Accuracy:')
print(group_acc.cpu().numpy())
print('Group Loss:')
print(group_loss.cpu().numpy())

Test
Loss: 0.336
Accuracy: 0.863
DCA-M: 0.039
DCA-A: 0.023
Group Accuracy:
[[0.8812 0.8645]
 [0.8421 0.8723]]
Group Loss:
[[0.3011 0.3578]
 [0.3671 0.3247]]
