In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import networks
import data_handler
import trainer
from utils import check_log_dir, make_log_name, set_seed
from arguments import get_args
import time
import os 
import argparse
# args = get_args()


## Hyperparameters

In [3]:
"""
Hyperparameters
"""

args = argparse.Namespace(
    seed = 0,
    dataset = 'waterbird',
    batch_size = 128,
    epochs = 1,
    device = 0,
    n_workers = 1,
    balSampling = False,
    model = 'resnet18',
    pretrained = True,
    method = 'scratch',
    optim = 'SGD',
    lr = 0.01,
    weight_decay = 0.0001,
    cuda = True,
    term = 20,
    record = False,
    log_dir = './logs/',
    date = '20230828',
    save_dir = './trained_models/',
)

torch.backends.cudnn.enabled = True
set_seed(args.seed)

np.set_printoptions(precision=4)
torch.set_printoptions(precision=4)


## Load dataset & network model

In [4]:
########################## get dataloader ################################
tmp = data_handler.DataloaderFactory.get_dataloader(args.dataset, 
                                                    batch_size=args.batch_size,
                                                    seed=args.seed,
                                                    n_workers=args.n_workers,
                                                    balSampling=args.balSampling,
                                                    args=args
                                                    )
n_classes, n_groups, train_loader, test_loader = tmp
########################## get model ##################################
model = networks.ModelFactory.get_model(args.model, n_classes, 224,
                                        pretrained=args.pretrained, n_groups=n_groups)

model.cuda('cuda:{}'.format(args.device))
print('successfully call the model')
#     set_seed(seed)
scheduler=None
########################## get trainer ##################################
if args.optim == 'Adam':
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
elif args.optim == 'SGD':
    optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
elif args.optim == 'AdamW':
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

trainer_ = trainer.TrainerFactory.get_trainer(args.method, model=model, args=args,
                                            optimizer=optimizer, scheduler=scheduler)



KeyError: 256

## Start training

In [4]:
####################### start training & evaluation ####################
start_t = time.time()
trainer_.train(train_loader, test_loader, args.epochs)
end_t = time.time()
train_t = int((end_t - start_t)/60)  # to minutes

print('Training Time : {} hours {} minutes'.format(int(train_t/60), (train_t % 60)))  


[1/1,    20] Method: scratch Train Loss: 0.365 Train Acc: 0.85 [0.49 s/batch]
[1/1] Method: scratch Test Loss: 0.476 Test Acc: 0.78 Test DCAM 0.67 [19.73 s]
Training Finished!
Training Time : 0 hours 0 minutes


## Evaluation

In [5]:
####################### Evaluation ####################
criterion = torch.nn.CrossEntropyLoss(reduction='none')

loss, acc, dcaM, dcaA, group_acc, group_loss = trainer_.evaluate(trainer_.model, test_loader, criterion, train=False)
print('Test')
# print('Loss \t Accuracy \t DCA-M \t DCA-A \t Group Accuracy \t Group Loss \t')
print('Loss: {:.3f}'.format(loss.item()))
acc = (group_acc[0,0] * 0.95 + group_acc[0,1] * 0.05 + group_acc[1,0]*0.05 + group_acc[1,1]) / 2
print('Accuracy: {:.3f}'.format(acc.item()))
print('DCA-M: {:.3f}'.format(dcaM))
print('DCA-A: {:.3f}'.format(dcaA))
print('Group Accuracy:')
print(group_acc.cpu().numpy())
print('Group Loss:')
print(group_loss.cpu().numpy())

Test
Loss: 0.476
Accuracy: 0.780
DCA-M: 0.665
DCA-A: 0.455
Group Accuracy:
[[0.9978 0.1121]
 [0.7534 0.7773]]
Group Loss:
[[0.0435 1.9802]
 [0.4823 0.466 ]]
