In [1]:
import os
import random

import torch

from trainers.map_based_trainer import MapBasedTrainer
from dataset.map_based.map_based_jacquard import MapBasedJacquardDataset
from loss_functions.map_loss import MapLoss
from models.grconvnet4 import GRConvNet4

In [4]:
# Seed the RNGs up front so weight init, random augmentation and the
# train/test split are reproducible across Restart & Run All.
# NOTE(review): which RNG(s) MapBasedJacquardDataset / MapBasedTrainer draw
# from (Python `random`, torch, numpy) is not visible here — confirm, and seed
# numpy too if it is used.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Resolve the cache location from the user's home directory instead of a
# hard-coded /Users/<name> prefix (same path on the original machine).
CACHE_PATH = os.path.expanduser("~/Desktop/jacquard/cache")

dataset = MapBasedJacquardDataset(
    image_size=224,
    precision=torch.float32,
    cache_path=CACHE_PATH,
    random_augment=True,
)

model = GRConvNet4(clip=True)
loss_fn = MapLoss()
lr = 1e-4
# This is the optimizer *class*, not an instance; presumably MapBasedTrainer
# instantiates it with the model parameters and `lr` — TODO confirm.
optimizer = torch.optim.Adam

def scheduler(lr, step, every=25, factor=0.5):
    """Step-decay learning-rate schedule.

    Multiplies ``lr`` by ``factor`` whenever ``step + 1`` is a multiple of
    ``every``, i.e. at steps 24, 49, 74, ... with the defaults, and returns
    ``lr`` unchanged otherwise. The defaults (every=25, factor=0.5) halve the
    learning rate every 25 steps.

    NOTE(review): whether ``step`` counts epochs or optimizer updates, whether
    it is 0- or 1-indexed, and whether the returned value replaces the current
    learning rate all depend on how MapBasedTrainer calls this — confirm.

    Args:
        lr: Current learning rate.
        step: Current step index.
        every: Decay period in steps; must be positive.
        factor: Multiplicative decay applied on each decay step.

    Returns:
        The learning rate to use from this step on.

    Raises:
        ValueError: If ``every`` is not positive.
    """
    if every <= 0:
        raise ValueError(f"every must be a positive integer, got {every!r}")
    if (step + 1) % every == 0:
        return lr * factor
    return lr

# Use Apple's Metal (MPS) backend when present, otherwise fall back to CPU so
# the notebook still runs on machines without MPS.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
# Resolve from the home directory rather than a hard-coded /Users/<name> path.
CHECKPOINT_DIR = os.path.expanduser("~/Desktop/new_trials")

trainer = MapBasedTrainer(
    training_mode="cls",
    model=model,
    device=DEVICE,
    loss_fn=loss_fn,
    dataset=dataset,
    optimizer=optimizer,
    lr=lr,
    train_batch_size=8,
    test_split_ratio=0.2,
    checkpoint_dir=CHECKPOINT_DIR,
    log_dir="logs",
    scheduler=scheduler,
    # NOTE(review): presumably gradients are accumulated over 8 batches before
    # each optimizer step (effective batch size 8 * 8 = 64) — confirm.
    num_accumulate_batches=8,
)

In [5]:
# Launch training. The argument is presumably the number of epochs (each
# "Cls training step" in the log is a full pass over the training batches)
# — TODO confirm against MapBasedTrainer.run.
NUM_EPOCHS = 100
trainer.run(NUM_EPOCHS)

--------------------------------------------------


Cls training step 1: 100%|██████████| 111/111 [00:23<00:00,  4.63it/s, loss=-.114] 
Cls test step 1: 100%|██████████| 14/14 [00:02<00:00,  4.75it/s]


Average Loss: -0.2642645931669644 | Accuracy: 0.3963963963963964
--------------------------------------------------


Cls training step 2: 100%|██████████| 111/111 [00:23<00:00,  4.63it/s, loss=-.987] 
Cls test step 2: 100%|██████████| 14/14 [00:02<00:00,  4.75it/s]


Average Loss: -0.6476673611572811 | Accuracy: 0.1891891891891892
--------------------------------------------------


Cls training step 3: 100%|██████████| 111/111 [00:23<00:00,  4.65it/s, loss=-.522]
Cls test step 3: 100%|██████████| 14/14 [00:02<00:00,  4.74it/s]


Average Loss: -0.7617331104619163 | Accuracy: 0.1891891891891892
--------------------------------------------------


Cls training step 4: 100%|██████████| 111/111 [00:23<00:00,  4.71it/s, loss=-1.87]
Cls test step 4: 100%|██████████| 14/14 [00:02<00:00,  4.73it/s]


Average Loss: -0.7893122072730746 | Accuracy: 0.23423423423423423
--------------------------------------------------


Cls training step 5: 100%|██████████| 111/111 [00:23<00:00,  4.69it/s, loss=-1.89]
Cls test step 5: 100%|██████████| 14/14 [00:03<00:00,  4.66it/s]


Average Loss: -0.7981768122741154 | Accuracy: 0.23873873873873874
--------------------------------------------------


Cls training step 6: 100%|██████████| 111/111 [00:23<00:00,  4.69it/s, loss=-.863]
Cls test step 6: 100%|██████████| 14/14 [00:03<00:00,  4.52it/s]


Average Loss: -0.8032948332173484 | Accuracy: 0.24774774774774774
--------------------------------------------------


Cls training step 7:  65%|██████▍   | 72/111 [00:15<00:08,  4.64it/s, loss=-1.07]


KeyboardInterrupt: 