In [None]:
!nvidia-smi

NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.



In [None]:
!pip install wandb

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.utils.data as data
from tqdm import tqdm


In [None]:
import wandb
wandb.login()

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [5]:
%cd 'drive/MyDrive/transfer_learning'

/content/drive/MyDrive/transfer_learning


In [6]:
sweep_config = {
    'method': 'random'
    }
metric = {
    'name': 'loss',
    'goal': 'minimize'   
    }
sweep_config['metric'] = metric

parameters_dict = {
    'optimizer': {
        'values': ['adam', 'sgd']
        }
    }
sweep_config['parameters'] = parameters_dict

parameters_dict.update({
    'epochs': {
        'value': 30}
    })

import math

parameters_dict.update({
    'learning_rate': {
        # a flat distribution between 0 and 0.1
        'distribution': 'uniform',
        'min': 0.001,
        'max': 0.1
      },
    'batch_size': {
        # integers between 32 and 256
        # with evenly-distributed logarithms 
        'distribution': 'q_log_uniform',
        'q': 1,
        'min': math.log(2),
        'max': math.log(4),
      }
    })

In [7]:
sweep_config

{'method': 'random',
 'metric': {'goal': 'minimize', 'name': 'loss'},
 'parameters': {'batch_size': {'distribution': 'q_log_uniform',
   'max': 1.3862943611198906,
   'min': 0.6931471805599453,
   'q': 1},
  'epochs': {'value': 30},
  'learning_rate': {'distribution': 'uniform', 'max': 0.1, 'min': 0.001},
  'optimizer': {'values': ['adam', 'sgd']}}}

In [None]:
sweep_id = wandb.sweep(sweep_config, project="pytorch-sweeps-demo")

In [9]:
loss_function = nn.CrossEntropyLoss()

In [11]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from torchvision import datasets, transforms

def train(config=None):
    # Initialize a new wandb run
    with wandb.init(config=config):
        # If called by wandb.agent, as below,
        # this config will be set by Sweep Controller
        config = wandb.config
        train_data_loader, val_data_loader = build_dataset(config.batch_size)
        network = build_network()
        optimizer = build_optimizer(network, config.optimizer, config.learning_rate)
        for epoch in range(config.epochs):
          print('epoch:',epoch)
          avg_loss, avg_acc = train_epoch(network, train_data_loader, optimizer)
          wandb.log({"loss_train": avg_loss, "accuracy_train": avg_acc, "epoch": epoch})

          avg_loss_eval, avg_acc_eval = test_epoch(network, val_data_loader)    
          wandb.log({"loss_eval": avg_loss_eval, "accuracy_eval": avg_acc_eval})      

In [12]:
def calc_acc(preds, labels):
    preds_max = torch.argmax(preds, 1)
    acc = torch.sum(preds_max == labels.data, dtype=torch.float16) / len(preds)
    return acc

In [13]:
def build_dataset(batch_size):
    
  transform =transforms.Compose([
                                  transforms.Resize((64, 64)),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225))
  ])

  dataset = torchvision.datasets.ImageFolder(root='/content/drive/MyDrive/MNIST_persian/MNIST_persian', transform=transform)
  train_size = int(len(dataset)*0.8)
  val_size = len(dataset)-train_size
  train_data ,val_data = data.random_split(dataset,[train_size,val_size])
  train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True,num_workers=1)
  val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False,num_workers=1)
  return train_data_loader, val_data_loader


def build_network():
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model = torchvision.models.resnet50(pretrained=True)
  return model.to(device)
        

def build_optimizer(network, optimizer, learning_rate):
    if optimizer == "sgd":
        optimizer = torch.optim.SGD(network.parameters(),
                              lr=learning_rate, momentum=0.9)
    elif optimizer == "adam":
        optimizer = torch.optim.Adam(network.parameters(),
                               lr=learning_rate)
    return optimizer


def train_epoch(network, train_data_loader, optimizer):
  network.train(True)
  train_loss=0.0
  train_acc=0.0
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  for images,labels in tqdm(train_data_loader):
    images=images.to(device)
    labels=labels.to(device)
    optimizer.zero_grad()
    
    preds_train = network(images)

    loss_train=loss_function(preds_train,labels) # loss_train
    loss_train.backward()

    optimizer.step()

    train_loss += loss_train
    train_acc += calc_acc(preds_train,labels)
    
  total_loss = train_loss/len(train_data_loader)
  total_acc = train_acc/len(train_data_loader)
  print(f"loss_train:{total_loss},accuracy_train:{total_acc}")
  
  return total_loss, total_acc

def test_epoch(network, val_data_loader):
  network.eval()
  test_loss=0.0
  test_acc=0.0
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  for images,labels in tqdm(val_data_loader):    
    images=images.to(device)
    labels=labels.to(device)
    
    preds_test = network(images)

    loss_test=loss_function(preds_test,labels) 

    test_loss += loss_test
    test_acc += calc_acc(preds_test,labels)

  total_loss = test_loss/len(val_data_loader)
  total_acc = test_acc/len(val_data_loader)

  print(f"loss_eval:{total_loss},accuracy_eval:{total_acc}")

  return total_loss, total_acc


In [14]:
wandb.agent(sweep_id, train, count=5)

[34m[1mwandb[0m: Agent Starting Run: qmhx0o3h with config:
[34m[1mwandb[0m: 	batch_size: 3
[34m[1mwandb[0m: 	epochs: 30
[34m[1mwandb[0m: 	learning_rate: 0.043075675165062874
[34m[1mwandb[0m: 	optimizer: adam


epoch: 0


100%|██████████| 320/320 [00:23<00:00, 13.40it/s]


loss_train:3.754723310470581,accuracy_train:0.09942626953125


100%|██████████| 80/80 [00:02<00:00, 35.76it/s]


loss_eval:2.4350335597991943,accuracy_eval:0.09576416015625
epoch: 1


100%|██████████| 320/320 [00:23<00:00, 13.50it/s]


loss_train:2.0306954383850098,accuracy_train:0.2496337890625


100%|██████████| 80/80 [00:02<00:00, 33.83it/s]


loss_eval:1.726707100868225,accuracy_eval:0.300048828125
epoch: 2


100%|██████████| 320/320 [00:23<00:00, 13.34it/s]


loss_train:1.6459190845489502,accuracy_train:0.41259765625


100%|██████████| 80/80 [00:02<00:00, 36.48it/s]


loss_eval:1.5682371854782104,accuracy_eval:0.43359375
epoch: 3


100%|██████████| 320/320 [00:24<00:00, 13.32it/s]


loss_train:1.0716047286987305,accuracy_train:0.59130859375


100%|██████████| 80/80 [00:02<00:00, 37.82it/s]


loss_eval:1.3242113590240479,accuracy_eval:0.41748046875
epoch: 4


100%|██████████| 320/320 [00:24<00:00, 13.28it/s]


loss_train:1.0480657815933228,accuracy_train:0.6103515625


100%|██████████| 80/80 [00:02<00:00, 34.94it/s]


loss_eval:0.655882716178894,accuracy_eval:0.7744140625
epoch: 5


100%|██████████| 320/320 [00:24<00:00, 13.30it/s]


loss_train:0.9674946665763855,accuracy_train:0.64697265625


100%|██████████| 80/80 [00:02<00:00, 35.85it/s]


loss_eval:0.7923521399497986,accuracy_eval:0.69873046875
epoch: 6


100%|██████████| 320/320 [00:24<00:00, 13.24it/s]


loss_train:0.7548975944519043,accuracy_train:0.720703125


100%|██████████| 80/80 [00:02<00:00, 36.25it/s]


loss_eval:0.6660678386688232,accuracy_eval:0.6962890625
epoch: 7


100%|██████████| 320/320 [00:24<00:00, 13.28it/s]


loss_train:0.7692030668258667,accuracy_train:0.7119140625


100%|██████████| 80/80 [00:02<00:00, 34.13it/s]


loss_eval:0.8345792889595032,accuracy_eval:0.63720703125
epoch: 8


100%|██████████| 320/320 [00:24<00:00, 13.20it/s]


loss_train:0.7077062129974365,accuracy_train:0.75


100%|██████████| 80/80 [00:02<00:00, 36.24it/s]


loss_eval:0.6255318522453308,accuracy_eval:0.7451171875
epoch: 9


100%|██████████| 320/320 [00:24<00:00, 13.28it/s]


loss_train:0.7175092101097107,accuracy_train:0.75341796875


100%|██████████| 80/80 [00:02<00:00, 35.45it/s]


loss_eval:0.8503103256225586,accuracy_eval:0.67724609375
epoch: 10


100%|██████████| 320/320 [00:24<00:00, 13.27it/s]


loss_train:0.6844592690467834,accuracy_train:0.7685546875


100%|██████████| 80/80 [00:02<00:00, 34.78it/s]


loss_eval:1.7881149053573608,accuracy_eval:0.412109375
epoch: 11


100%|██████████| 320/320 [00:24<00:00, 13.29it/s]


loss_train:0.6817416548728943,accuracy_train:0.7607421875


100%|██████████| 80/80 [00:02<00:00, 35.28it/s]


loss_eval:0.5867863893508911,accuracy_eval:0.794921875
epoch: 12


100%|██████████| 320/320 [00:24<00:00, 13.30it/s]


loss_train:0.5133391618728638,accuracy_train:0.8115234375


100%|██████████| 80/80 [00:02<00:00, 36.47it/s]


loss_eval:0.6190091371536255,accuracy_eval:0.7744140625
epoch: 13


100%|██████████| 320/320 [00:24<00:00, 13.13it/s]


loss_train:0.4692014753818512,accuracy_train:0.83056640625


100%|██████████| 80/80 [00:02<00:00, 36.79it/s]


loss_eval:0.46594539284706116,accuracy_eval:0.83349609375
epoch: 14


100%|██████████| 320/320 [00:24<00:00, 13.29it/s]


loss_train:0.4477194845676422,accuracy_train:0.859375


100%|██████████| 80/80 [00:02<00:00, 35.89it/s]


loss_eval:0.6308197975158691,accuracy_eval:0.7939453125
epoch: 15


100%|██████████| 320/320 [00:24<00:00, 13.33it/s]


loss_train:0.6784534454345703,accuracy_train:0.783203125


100%|██████████| 80/80 [00:02<00:00, 37.54it/s]


loss_eval:0.7766324877738953,accuracy_eval:0.7412109375
epoch: 16


100%|██████████| 320/320 [00:23<00:00, 13.61it/s]


loss_train:0.5275207161903381,accuracy_train:0.82177734375


100%|██████████| 80/80 [00:02<00:00, 38.59it/s]


loss_eval:0.5795013308525085,accuracy_eval:0.74951171875
epoch: 17


100%|██████████| 320/320 [00:23<00:00, 13.63it/s]


loss_train:0.47764459252357483,accuracy_train:0.8349609375


100%|██████████| 80/80 [00:02<00:00, 38.02it/s]


loss_eval:0.6686856746673584,accuracy_eval:0.79541015625
epoch: 18


100%|██████████| 320/320 [00:23<00:00, 13.61it/s]


loss_train:0.4891701638698578,accuracy_train:0.83203125


100%|██████████| 80/80 [00:02<00:00, 36.35it/s]


loss_eval:0.7495593428611755,accuracy_eval:0.7998046875
epoch: 19


100%|██████████| 320/320 [00:23<00:00, 13.50it/s]


loss_train:0.39310428500175476,accuracy_train:0.8759765625


100%|██████████| 80/80 [00:02<00:00, 37.80it/s]


loss_eval:0.4696139395236969,accuracy_eval:0.83203125
epoch: 20


100%|██████████| 320/320 [00:23<00:00, 13.58it/s]


loss_train:0.39463645219802856,accuracy_train:0.87353515625


100%|██████████| 80/80 [00:02<00:00, 37.67it/s]


loss_eval:0.6224752068519592,accuracy_eval:0.81005859375
epoch: 21


100%|██████████| 320/320 [00:23<00:00, 13.48it/s]


loss_train:0.38877037167549133,accuracy_train:0.86572265625


100%|██████████| 80/80 [00:02<00:00, 38.91it/s]


loss_eval:0.36988329887390137,accuracy_eval:0.8740234375
epoch: 22


100%|██████████| 320/320 [00:23<00:00, 13.57it/s]


loss_train:0.3645612299442291,accuracy_train:0.88525390625


100%|██████████| 80/80 [00:02<00:00, 38.59it/s]


loss_eval:0.5621528029441833,accuracy_eval:0.85400390625
epoch: 23


100%|██████████| 320/320 [00:23<00:00, 13.43it/s]


loss_train:0.2757287919521332,accuracy_train:0.921875


100%|██████████| 80/80 [00:02<00:00, 36.73it/s]


loss_eval:0.4926529824733734,accuracy_eval:0.85791015625
epoch: 24


100%|██████████| 320/320 [00:23<00:00, 13.44it/s]


loss_train:0.37893861532211304,accuracy_train:0.88427734375


100%|██████████| 80/80 [00:02<00:00, 36.60it/s]


loss_eval:0.5594025254249573,accuracy_eval:0.798828125
epoch: 25


100%|██████████| 320/320 [00:23<00:00, 13.49it/s]


loss_train:0.43082496523857117,accuracy_train:0.8662109375


100%|██████████| 80/80 [00:02<00:00, 38.13it/s]


loss_eval:0.3242792785167694,accuracy_eval:0.8662109375
epoch: 26


100%|██████████| 320/320 [00:23<00:00, 13.59it/s]


loss_train:0.2850453555583954,accuracy_train:0.9140625


100%|██████████| 80/80 [00:02<00:00, 37.95it/s]


loss_eval:0.5972259044647217,accuracy_eval:0.798828125
epoch: 27


100%|██████████| 320/320 [00:23<00:00, 13.53it/s]


loss_train:0.3199602961540222,accuracy_train:0.91552734375


100%|██████████| 80/80 [00:02<00:00, 36.69it/s]


loss_eval:0.2833866477012634,accuracy_eval:0.91259765625
epoch: 28


100%|██████████| 320/320 [00:23<00:00, 13.48it/s]


loss_train:0.20474648475646973,accuracy_train:0.93115234375


100%|██████████| 80/80 [00:02<00:00, 37.82it/s]


loss_eval:0.6189097166061401,accuracy_eval:0.8037109375
epoch: 29


100%|██████████| 320/320 [00:23<00:00, 13.46it/s]


loss_train:0.30695486068725586,accuracy_train:0.91162109375


100%|██████████| 80/80 [00:02<00:00, 35.22it/s]


loss_eval:0.5509801506996155,accuracy_eval:0.8193359375


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
accuracy_eval,▁▃▄▄▇▆▆▆▇▆▄▇▇▇▇▇▇▇▇▇▇█▇█▇█▇█▇▇
accuracy_train,▁▂▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇██▇███▇████
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
loss_eval,█▆▅▄▂▃▂▃▂▃▆▂▂▂▂▃▂▂▃▂▂▁▂▂▂▁▂▁▂▂
loss_train,█▅▄▃▃▃▂▂▂▂▂▂▂▂▁▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy_eval,0.81934
accuracy_train,0.91162
epoch,29.0
loss_eval,0.55098
loss_train,0.30695


[34m[1mwandb[0m: Agent Starting Run: 28smz49e with config:
[34m[1mwandb[0m: 	batch_size: 3
[34m[1mwandb[0m: 	epochs: 30
[34m[1mwandb[0m: 	learning_rate: 0.03723199039835957
[34m[1mwandb[0m: 	optimizer: adam


epoch: 0


100%|██████████| 320/320 [00:23<00:00, 13.66it/s]


loss_train:3.5174472332000732,accuracy_train:0.149658203125


100%|██████████| 80/80 [00:02<00:00, 36.48it/s]


loss_eval:2.164724826812744,accuracy_eval:0.220947265625
epoch: 1


100%|██████████| 320/320 [00:23<00:00, 13.61it/s]


loss_train:2.1308116912841797,accuracy_train:0.223388671875


100%|██████████| 80/80 [00:02<00:00, 36.71it/s]


loss_eval:1.7139486074447632,accuracy_eval:0.312255859375
epoch: 2


100%|██████████| 320/320 [00:23<00:00, 13.55it/s]


loss_train:1.584629774093628,accuracy_train:0.40625


100%|██████████| 80/80 [00:02<00:00, 38.38it/s]


loss_eval:1.4114201068878174,accuracy_eval:0.442138671875
epoch: 3


100%|██████████| 320/320 [00:23<00:00, 13.57it/s]


loss_train:1.2846568822860718,accuracy_train:0.54296875


100%|██████████| 80/80 [00:02<00:00, 38.41it/s]


loss_eval:0.7797461748123169,accuracy_eval:0.7490234375
epoch: 4


100%|██████████| 320/320 [00:23<00:00, 13.40it/s]


loss_train:0.9561538100242615,accuracy_train:0.671875


100%|██████████| 80/80 [00:02<00:00, 38.14it/s]


loss_eval:1.6510772705078125,accuracy_eval:0.49658203125
epoch: 5


100%|██████████| 320/320 [00:23<00:00, 13.51it/s]


loss_train:0.922902524471283,accuracy_train:0.6767578125


100%|██████████| 80/80 [00:02<00:00, 36.44it/s]


loss_eval:1.1930683851242065,accuracy_eval:0.5830078125
epoch: 6


100%|██████████| 320/320 [00:23<00:00, 13.47it/s]


loss_train:0.7519305944442749,accuracy_train:0.75537109375


100%|██████████| 80/80 [00:02<00:00, 37.34it/s]


loss_eval:0.49471554160118103,accuracy_eval:0.8359375
epoch: 7


100%|██████████| 320/320 [00:23<00:00, 13.40it/s]


loss_train:0.650959312915802,accuracy_train:0.7724609375


100%|██████████| 80/80 [00:02<00:00, 36.73it/s]


loss_eval:0.8233828544616699,accuracy_eval:0.6904296875
epoch: 8


100%|██████████| 320/320 [00:23<00:00, 13.46it/s]


loss_train:0.5553634166717529,accuracy_train:0.79345703125


100%|██████████| 80/80 [00:02<00:00, 36.82it/s]


loss_eval:1.0505262613296509,accuracy_eval:0.6298828125
epoch: 9


100%|██████████| 320/320 [00:24<00:00, 13.32it/s]


loss_train:0.5272261500358582,accuracy_train:0.8173828125


100%|██████████| 80/80 [00:02<00:00, 33.26it/s]


loss_eval:0.9188949465751648,accuracy_eval:0.7041015625
epoch: 10


100%|██████████| 320/320 [00:23<00:00, 13.39it/s]


loss_train:0.4618544280529022,accuracy_train:0.83837890625


100%|██████████| 80/80 [00:02<00:00, 37.56it/s]


loss_eval:5.075261116027832,accuracy_eval:0.183349609375
epoch: 11


100%|██████████| 320/320 [00:23<00:00, 13.43it/s]


loss_train:0.40712738037109375,accuracy_train:0.8662109375


100%|██████████| 80/80 [00:02<00:00, 35.10it/s]


loss_eval:0.35001489520072937,accuracy_eval:0.88671875
epoch: 12


100%|██████████| 320/320 [00:23<00:00, 13.44it/s]


loss_train:0.4261256158351898,accuracy_train:0.8603515625


100%|██████████| 80/80 [00:02<00:00, 35.83it/s]


loss_eval:0.42865100502967834,accuracy_eval:0.8623046875
epoch: 13


100%|██████████| 320/320 [00:23<00:00, 13.42it/s]


loss_train:0.4264601767063141,accuracy_train:0.86572265625


100%|██████████| 80/80 [00:02<00:00, 36.39it/s]


loss_eval:0.6490603089332581,accuracy_eval:0.8408203125
epoch: 14


100%|██████████| 320/320 [00:23<00:00, 13.44it/s]


loss_train:0.36812183260917664,accuracy_train:0.87744140625


100%|██████████| 80/80 [00:02<00:00, 36.92it/s]


loss_eval:0.45835208892822266,accuracy_eval:0.85400390625
epoch: 15


100%|██████████| 320/320 [00:23<00:00, 13.45it/s]


loss_train:0.2868257761001587,accuracy_train:0.8974609375


100%|██████████| 80/80 [00:02<00:00, 36.81it/s]


loss_eval:0.9976545572280884,accuracy_eval:0.69580078125
epoch: 16


100%|██████████| 320/320 [00:23<00:00, 13.45it/s]


loss_train:0.2854788899421692,accuracy_train:0.89990234375


100%|██████████| 80/80 [00:02<00:00, 37.94it/s]


loss_eval:0.36925187706947327,accuracy_eval:0.86083984375
epoch: 17


100%|██████████| 320/320 [00:23<00:00, 13.49it/s]


loss_train:0.3504241406917572,accuracy_train:0.8818359375


100%|██████████| 80/80 [00:02<00:00, 37.15it/s]


loss_eval:0.33888188004493713,accuracy_eval:0.9072265625
epoch: 18


100%|██████████| 320/320 [00:23<00:00, 13.43it/s]


loss_train:0.30290859937667847,accuracy_train:0.8994140625


100%|██████████| 80/80 [00:02<00:00, 37.17it/s]


loss_eval:0.5927570462226868,accuracy_eval:0.80322265625
epoch: 19


100%|██████████| 320/320 [00:23<00:00, 13.37it/s]


loss_train:0.22051873803138733,accuracy_train:0.93994140625


100%|██████████| 80/80 [00:02<00:00, 34.53it/s]


loss_eval:2.205188751220703,accuracy_eval:0.4921875
epoch: 20


100%|██████████| 320/320 [00:24<00:00, 13.31it/s]


loss_train:0.2819678485393524,accuracy_train:0.90869140625


100%|██████████| 80/80 [00:02<00:00, 37.76it/s]


loss_eval:2.39050555229187,accuracy_eval:0.42041015625
epoch: 21


100%|██████████| 320/320 [00:23<00:00, 13.41it/s]


loss_train:0.25988903641700745,accuracy_train:0.91552734375


100%|██████████| 80/80 [00:02<00:00, 36.92it/s]


loss_eval:0.3589889407157898,accuracy_eval:0.90771484375
epoch: 22


100%|██████████| 320/320 [00:23<00:00, 13.33it/s]


loss_train:0.24333634972572327,accuracy_train:0.9287109375


100%|██████████| 80/80 [00:02<00:00, 35.48it/s]


loss_eval:0.6908765435218811,accuracy_eval:0.7607421875
epoch: 23


100%|██████████| 320/320 [00:24<00:00, 13.22it/s]


loss_train:0.24252836406230927,accuracy_train:0.92041015625


100%|██████████| 80/80 [00:02<00:00, 35.57it/s]


loss_eval:0.35427045822143555,accuracy_eval:0.90771484375
epoch: 24


100%|██████████| 320/320 [00:24<00:00, 13.24it/s]


loss_train:0.21395035088062286,accuracy_train:0.9345703125


100%|██████████| 80/80 [00:02<00:00, 36.02it/s]


loss_eval:0.2855319082736969,accuracy_eval:0.8955078125
epoch: 25


100%|██████████| 320/320 [00:24<00:00, 13.29it/s]


loss_train:0.23865000903606415,accuracy_train:0.9326171875


100%|██████████| 80/80 [00:02<00:00, 36.15it/s]


loss_eval:0.37498021125793457,accuracy_eval:0.8916015625
epoch: 26


100%|██████████| 320/320 [00:24<00:00, 13.30it/s]


loss_train:0.1647997349500656,accuracy_train:0.94384765625


100%|██████████| 80/80 [00:02<00:00, 36.76it/s]


loss_eval:0.48536762595176697,accuracy_eval:0.85791015625
epoch: 27


100%|██████████| 320/320 [00:24<00:00, 13.33it/s]


loss_train:0.19214119017124176,accuracy_train:0.9345703125


100%|██████████| 80/80 [00:02<00:00, 35.31it/s]


loss_eval:0.6512923240661621,accuracy_eval:0.7646484375
epoch: 28


100%|██████████| 320/320 [00:24<00:00, 13.30it/s]


loss_train:0.1841019093990326,accuracy_train:0.94677734375


100%|██████████| 80/80 [00:02<00:00, 38.08it/s]


loss_eval:0.6490859389305115,accuracy_eval:0.85791015625
epoch: 29


100%|██████████| 320/320 [00:24<00:00, 13.33it/s]


loss_train:0.25596874952316284,accuracy_train:0.9306640625


100%|██████████| 80/80 [00:02<00:00, 34.51it/s]


loss_eval:0.5552111864089966,accuracy_eval:0.8701171875


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
accuracy_eval,▁▂▄▆▄▅▇▆▅▆▁██▇▇▆██▇▄▃█▇████▇██
accuracy_train,▁▂▃▄▆▆▆▆▇▇▇▇▇▇▇██▇████████████
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
loss_eval,▄▃▃▂▃▂▁▂▂▂█▁▁▂▁▂▁▁▁▄▄▁▂▁▁▁▁▂▂▁
loss_train,█▅▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy_eval,0.87012
accuracy_train,0.93066
epoch,29.0
loss_eval,0.55521
loss_train,0.25597


[34m[1mwandb[0m: Agent Starting Run: 0tv0fh3d with config:
[34m[1mwandb[0m: 	batch_size: 2
[34m[1mwandb[0m: 	epochs: 30
[34m[1mwandb[0m: 	learning_rate: 0.05191451695198839
[34m[1mwandb[0m: 	optimizer: sgd


epoch: 0


100%|██████████| 480/480 [00:27<00:00, 17.45it/s]


loss_train:4.240832328796387,accuracy_train:0.1156005859375


100%|██████████| 120/120 [00:02<00:00, 41.66it/s]


loss_eval:2.8471810817718506,accuracy_eval:0.2208251953125
epoch: 1


100%|██████████| 480/480 [00:27<00:00, 17.57it/s]


loss_train:2.5172250270843506,accuracy_train:0.12396240234375


100%|██████████| 120/120 [00:02<00:00, 41.40it/s]


loss_eval:3.835145950317383,accuracy_eval:0.13330078125
epoch: 2


100%|██████████| 480/480 [00:27<00:00, 17.50it/s]


loss_train:2.1799983978271484,accuracy_train:0.231201171875


100%|██████████| 120/120 [00:02<00:00, 41.62it/s]


loss_eval:1.6369389295578003,accuracy_eval:0.395751953125
epoch: 3


100%|██████████| 480/480 [00:27<00:00, 17.55it/s]


loss_train:1.632797122001648,accuracy_train:0.41552734375


100%|██████████| 120/120 [00:03<00:00, 39.64it/s]


loss_eval:2.37272310256958,accuracy_eval:0.591796875
epoch: 4


100%|██████████| 480/480 [00:27<00:00, 17.58it/s]


loss_train:1.2604024410247803,accuracy_train:0.5625


100%|██████████| 120/120 [00:02<00:00, 40.63it/s]


loss_eval:1.0633535385131836,accuracy_eval:0.66259765625
epoch: 5


100%|██████████| 480/480 [00:27<00:00, 17.59it/s]


loss_train:1.077789545059204,accuracy_train:0.6240234375


100%|██████████| 120/120 [00:02<00:00, 41.82it/s]


loss_eval:1.1560578346252441,accuracy_eval:0.64990234375
epoch: 6


100%|██████████| 480/480 [00:26<00:00, 17.81it/s]


loss_train:0.9223328232765198,accuracy_train:0.6884765625


100%|██████████| 120/120 [00:03<00:00, 38.73it/s]


loss_eval:1.2537667751312256,accuracy_eval:0.495849609375
epoch: 7


100%|██████████| 480/480 [00:27<00:00, 17.73it/s]


loss_train:0.8329365849494934,accuracy_train:0.71240234375


100%|██████████| 120/120 [00:02<00:00, 40.79it/s]


loss_eval:0.6168798804283142,accuracy_eval:0.779296875
epoch: 8


100%|██████████| 480/480 [00:27<00:00, 17.72it/s]


loss_train:0.7374338507652283,accuracy_train:0.73974609375


100%|██████████| 120/120 [00:02<00:00, 41.41it/s]


loss_eval:0.5495001673698425,accuracy_eval:0.7998046875
epoch: 9


100%|██████████| 480/480 [00:27<00:00, 17.55it/s]


loss_train:0.6165119409561157,accuracy_train:0.78955078125


100%|██████████| 120/120 [00:02<00:00, 41.13it/s]


loss_eval:0.6785656809806824,accuracy_eval:0.77490234375
epoch: 10


100%|██████████| 480/480 [00:27<00:00, 17.52it/s]


loss_train:0.6073135733604431,accuracy_train:0.8115234375


100%|██████████| 120/120 [00:02<00:00, 40.76it/s]


loss_eval:0.7651383280754089,accuracy_eval:0.6875
epoch: 11


100%|██████████| 480/480 [00:27<00:00, 17.49it/s]


loss_train:0.5039456486701965,accuracy_train:0.82177734375


100%|██████████| 120/120 [00:02<00:00, 41.06it/s]


loss_eval:0.964609682559967,accuracy_eval:0.69189453125
epoch: 12


100%|██████████| 480/480 [00:27<00:00, 17.61it/s]


loss_train:0.5369629263877869,accuracy_train:0.83251953125


100%|██████████| 120/120 [00:02<00:00, 42.11it/s]


loss_eval:0.5683205723762512,accuracy_eval:0.845703125
epoch: 13


100%|██████████| 480/480 [00:27<00:00, 17.60it/s]


loss_train:0.4249316453933716,accuracy_train:0.8623046875


100%|██████████| 120/120 [00:03<00:00, 37.93it/s]


loss_eval:0.47665825486183167,accuracy_eval:0.85400390625
epoch: 14


100%|██████████| 480/480 [00:27<00:00, 17.54it/s]


loss_train:0.4648404121398926,accuracy_train:0.85009765625


100%|██████████| 120/120 [00:02<00:00, 40.61it/s]


loss_eval:0.46586722135543823,accuracy_eval:0.875
epoch: 15


100%|██████████| 480/480 [00:27<00:00, 17.56it/s]


loss_train:0.43653371930122375,accuracy_train:0.861328125


100%|██████████| 120/120 [00:02<00:00, 41.55it/s]


loss_eval:0.7218329906463623,accuracy_eval:0.81689453125
epoch: 16


100%|██████████| 480/480 [00:27<00:00, 17.55it/s]


loss_train:0.39540258049964905,accuracy_train:0.86865234375


100%|██████████| 120/120 [00:02<00:00, 41.22it/s]


loss_eval:0.34111180901527405,accuracy_eval:0.89599609375
epoch: 17


100%|██████████| 480/480 [00:27<00:00, 17.62it/s]


loss_train:0.30938687920570374,accuracy_train:0.89892578125


100%|██████████| 120/120 [00:02<00:00, 41.83it/s]


loss_eval:0.5410152077674866,accuracy_eval:0.87939453125
epoch: 18


100%|██████████| 480/480 [00:27<00:00, 17.60it/s]


loss_train:0.3387436866760254,accuracy_train:0.8935546875


100%|██████████| 120/120 [00:02<00:00, 41.01it/s]


loss_eval:0.43912336230278015,accuracy_eval:0.8916015625
epoch: 19


100%|██████████| 480/480 [00:27<00:00, 17.58it/s]


loss_train:0.25262534618377686,accuracy_train:0.9228515625


100%|██████████| 120/120 [00:02<00:00, 41.52it/s]


loss_eval:0.43996962904930115,accuracy_eval:0.86669921875
epoch: 20


100%|██████████| 480/480 [00:27<00:00, 17.56it/s]


loss_train:0.20275774598121643,accuracy_train:0.92724609375


100%|██████████| 120/120 [00:03<00:00, 39.00it/s]


loss_eval:0.5950037837028503,accuracy_eval:0.87060546875
epoch: 21


100%|██████████| 480/480 [00:27<00:00, 17.57it/s]


loss_train:0.2564195990562439,accuracy_train:0.935546875


100%|██████████| 120/120 [00:02<00:00, 41.07it/s]


loss_eval:0.20133104920387268,accuracy_eval:0.9375
epoch: 22


100%|██████████| 480/480 [00:27<00:00, 17.48it/s]


loss_train:0.20139625668525696,accuracy_train:0.9375


100%|██████████| 120/120 [00:02<00:00, 41.13it/s]


loss_eval:0.19634060561656952,accuracy_eval:0.94189453125
epoch: 23


100%|██████████| 480/480 [00:27<00:00, 17.24it/s]


loss_train:0.1964942216873169,accuracy_train:0.9404296875


100%|██████████| 120/120 [00:03<00:00, 39.90it/s]


loss_eval:0.23005305230617523,accuracy_eval:0.91259765625
epoch: 24


100%|██████████| 480/480 [00:27<00:00, 17.28it/s]


loss_train:0.1997431218624115,accuracy_train:0.93310546875


100%|██████████| 120/120 [00:02<00:00, 41.10it/s]


loss_eval:0.3855433762073517,accuracy_eval:0.8916015625
epoch: 25


100%|██████████| 480/480 [00:27<00:00, 17.57it/s]


loss_train:0.09952887147665024,accuracy_train:0.9716796875


100%|██████████| 120/120 [00:02<00:00, 40.89it/s]


loss_eval:0.15287774801254272,accuracy_eval:0.9501953125
epoch: 26


100%|██████████| 480/480 [00:27<00:00, 17.70it/s]


loss_train:0.1519538164138794,accuracy_train:0.96044921875


100%|██████████| 120/120 [00:02<00:00, 40.78it/s]


loss_eval:0.2691046893596649,accuracy_eval:0.9248046875
epoch: 27


100%|██████████| 480/480 [00:27<00:00, 17.65it/s]


loss_train:0.20074260234832764,accuracy_train:0.94189453125


100%|██████████| 120/120 [00:02<00:00, 40.63it/s]


loss_eval:0.20179659128189087,accuracy_eval:0.93310546875
epoch: 28


100%|██████████| 480/480 [00:27<00:00, 17.62it/s]


loss_train:0.18482284247875214,accuracy_train:0.94677734375


100%|██████████| 120/120 [00:03<00:00, 37.90it/s]


loss_eval:0.19283844530582428,accuracy_eval:0.94189453125
epoch: 29


100%|██████████| 480/480 [00:27<00:00, 17.64it/s]


loss_train:0.07847405225038528,accuracy_train:0.97607421875


100%|██████████| 120/120 [00:03<00:00, 38.84it/s]


loss_eval:0.1609918475151062,accuracy_eval:0.9375


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
accuracy_eval,▂▁▃▅▆▅▄▇▇▆▆▆▇▇▇▇█▇▇▇▇███▇█████
accuracy_train,▁▁▂▃▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇███████████
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
loss_eval,▆█▄▅▃▃▃▂▂▂▂▃▂▂▂▂▁▂▂▂▂▁▁▁▁▁▁▁▁▁
loss_train,█▅▅▄▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy_eval,0.9375
accuracy_train,0.97607
epoch,29.0
loss_eval,0.16099
loss_train,0.07847


[34m[1mwandb[0m: Agent Starting Run: e8jpd9m4 with config:
[34m[1mwandb[0m: 	batch_size: 3
[34m[1mwandb[0m: 	epochs: 30
[34m[1mwandb[0m: 	learning_rate: 0.06413977433952679
[34m[1mwandb[0m: 	optimizer: adam


epoch: 0


100%|██████████| 320/320 [00:23<00:00, 13.35it/s]


loss_train:3.9414615631103516,accuracy_train:0.10272216796875


100%|██████████| 80/80 [00:02<00:00, 37.09it/s]


loss_eval:2.362734794616699,accuracy_eval:0.1124267578125
epoch: 1


100%|██████████| 320/320 [00:23<00:00, 13.51it/s]


loss_train:2.3622632026672363,accuracy_train:0.09759521484375


100%|██████████| 80/80 [00:02<00:00, 37.15it/s]


loss_eval:2.380697011947632,accuracy_eval:0.05413818359375
epoch: 2


100%|██████████| 320/320 [00:23<00:00, 13.36it/s]


loss_train:2.469727039337158,accuracy_train:0.08709716796875


100%|██████████| 80/80 [00:02<00:00, 36.34it/s]


loss_eval:3.0101284980773926,accuracy_eval:0.13330078125
epoch: 3


100%|██████████| 320/320 [00:23<00:00, 13.40it/s]


loss_train:2.400555372238159,accuracy_train:0.09136962890625


100%|██████████| 80/80 [00:02<00:00, 37.67it/s]


loss_eval:2.3235344886779785,accuracy_eval:0.08331298828125
epoch: 4


100%|██████████| 320/320 [00:23<00:00, 13.38it/s]


loss_train:2.329110860824585,accuracy_train:0.09747314453125


100%|██████████| 80/80 [00:02<00:00, 35.62it/s]


loss_eval:2.3870961666107178,accuracy_eval:0.07489013671875
epoch: 5


100%|██████████| 320/320 [00:23<00:00, 13.39it/s]


loss_train:2.3237569332122803,accuracy_train:0.12158203125


100%|██████████| 80/80 [00:02<00:00, 36.38it/s]


loss_eval:2.3625195026397705,accuracy_eval:0.1124267578125
epoch: 6


100%|██████████| 320/320 [00:24<00:00, 13.29it/s]


loss_train:2.3343868255615234,accuracy_train:0.08819580078125


100%|██████████| 80/80 [00:02<00:00, 35.16it/s]


loss_eval:2.3483598232269287,accuracy_eval:0.09979248046875
epoch: 7


100%|██████████| 320/320 [00:23<00:00, 13.34it/s]


loss_train:2.3297672271728516,accuracy_train:0.08819580078125


100%|██████████| 80/80 [00:02<00:00, 35.92it/s]


loss_eval:2.3836419582366943,accuracy_eval:0.09161376953125
epoch: 8


100%|██████████| 320/320 [00:24<00:00, 13.26it/s]


loss_train:2.333132028579712,accuracy_train:0.0975341796875


100%|██████████| 80/80 [00:02<00:00, 36.69it/s]


loss_eval:2.3436992168426514,accuracy_eval:0.05413818359375
epoch: 9


100%|██████████| 320/320 [00:24<00:00, 13.30it/s]


loss_train:2.335918426513672,accuracy_train:0.09234619140625


100%|██████████| 80/80 [00:02<00:00, 34.20it/s]


loss_eval:2.328477621078491,accuracy_eval:0.09149169921875
epoch: 10


100%|██████████| 320/320 [00:24<00:00, 13.24it/s]


loss_train:2.335827350616455,accuracy_train:0.09539794921875


100%|██████████| 80/80 [00:02<00:00, 36.00it/s]


loss_eval:2.329035520553589,accuracy_eval:0.05413818359375
epoch: 11


100%|██████████| 320/320 [00:24<00:00, 13.27it/s]


loss_train:2.332159996032715,accuracy_train:0.0985107421875


100%|██████████| 80/80 [00:02<00:00, 35.69it/s]


loss_eval:2.3463053703308105,accuracy_eval:0.0958251953125
epoch: 12


100%|██████████| 320/320 [00:24<00:00, 13.32it/s]


loss_train:2.3273072242736816,accuracy_train:0.10906982421875


100%|██████████| 80/80 [00:02<00:00, 34.49it/s]


loss_eval:2.339569091796875,accuracy_eval:0.09979248046875
epoch: 13


100%|██████████| 320/320 [00:24<00:00, 13.24it/s]


loss_train:2.328871726989746,accuracy_train:0.08087158203125


100%|██████████| 80/80 [00:02<00:00, 35.89it/s]


loss_eval:2.36572265625,accuracy_eval:0.09564208984375
epoch: 14


100%|██████████| 320/320 [00:24<00:00, 13.28it/s]


loss_train:2.3351986408233643,accuracy_train:0.095458984375


100%|██████████| 80/80 [00:02<00:00, 36.88it/s]


loss_eval:2.305875778198242,accuracy_eval:0.09161376953125
epoch: 15


100%|██████████| 320/320 [00:24<00:00, 13.30it/s]


loss_train:2.327707529067993,accuracy_train:0.1004638671875


100%|██████████| 80/80 [00:02<00:00, 36.03it/s]


loss_eval:2.3182756900787354,accuracy_eval:0.13330078125
epoch: 16


100%|██████████| 320/320 [00:24<00:00, 13.28it/s]


loss_train:2.332519769668579,accuracy_train:0.09417724609375


100%|██████████| 80/80 [00:02<00:00, 36.33it/s]


loss_eval:2.2993273735046387,accuracy_eval:0.11248779296875
epoch: 17


100%|██████████| 320/320 [00:24<00:00, 13.33it/s]


loss_train:2.3303775787353516,accuracy_train:0.0811767578125


100%|██████████| 80/80 [00:02<00:00, 33.62it/s]


loss_eval:2.34714937210083,accuracy_eval:0.08331298828125
epoch: 18


100%|██████████| 320/320 [00:24<00:00, 13.29it/s]


loss_train:2.340857982635498,accuracy_train:0.08111572265625


100%|██████████| 80/80 [00:02<00:00, 35.63it/s]


loss_eval:2.374396800994873,accuracy_eval:0.05413818359375
epoch: 19


100%|██████████| 320/320 [00:23<00:00, 13.35it/s]


loss_train:2.329714298248291,accuracy_train:0.1004638671875


100%|██████████| 80/80 [00:02<00:00, 36.61it/s]


loss_eval:2.3490161895751953,accuracy_eval:0.05413818359375
epoch: 20


100%|██████████| 320/320 [00:23<00:00, 13.40it/s]


loss_train:2.333177089691162,accuracy_train:0.0882568359375


100%|██████████| 80/80 [00:02<00:00, 36.05it/s]


loss_eval:2.3419253826141357,accuracy_eval:0.05413818359375
epoch: 21


100%|██████████| 320/320 [00:24<00:00, 13.31it/s]


loss_train:2.3289573192596436,accuracy_train:0.0911865234375


100%|██████████| 80/80 [00:02<00:00, 36.72it/s]


loss_eval:2.358935832977295,accuracy_eval:0.07489013671875
epoch: 22


100%|██████████| 320/320 [00:23<00:00, 13.36it/s]


loss_train:2.3361124992370605,accuracy_train:0.1058349609375


100%|██████████| 80/80 [00:02<00:00, 35.66it/s]


loss_eval:2.3728697299957275,accuracy_eval:0.05413818359375
epoch: 23


100%|██████████| 320/320 [00:24<00:00, 13.29it/s]


loss_train:2.3301937580108643,accuracy_train:0.1058349609375


100%|██████████| 80/80 [00:02<00:00, 35.86it/s]


loss_eval:2.324800729751587,accuracy_eval:0.07489013671875
epoch: 24


100%|██████████| 320/320 [00:24<00:00, 13.31it/s]


loss_train:2.3310446739196777,accuracy_train:0.09234619140625


100%|██████████| 80/80 [00:02<00:00, 36.38it/s]


loss_eval:2.363896131515503,accuracy_eval:0.05413818359375
epoch: 25


100%|██████████| 320/320 [00:23<00:00, 13.36it/s]


loss_train:2.635436773300171,accuracy_train:0.1004638671875


100%|██████████| 80/80 [00:02<00:00, 35.95it/s]


loss_eval:2.3421883583068848,accuracy_eval:0.10400390625
epoch: 26


100%|██████████| 320/320 [00:23<00:00, 13.34it/s]


loss_train:2.327915906906128,accuracy_train:0.09442138671875


100%|██████████| 80/80 [00:02<00:00, 37.44it/s]


loss_eval:2.3268959522247314,accuracy_eval:0.05413818359375
epoch: 27


100%|██████████| 320/320 [00:24<00:00, 13.30it/s]


loss_train:2.333015203475952,accuracy_train:0.083984375


100%|██████████| 80/80 [00:02<00:00, 36.55it/s]


loss_eval:2.33577299118042,accuracy_eval:0.08331298828125
epoch: 28


100%|██████████| 320/320 [00:24<00:00, 13.32it/s]


loss_train:2.3301544189453125,accuracy_train:0.1025390625


100%|██████████| 80/80 [00:02<00:00, 34.14it/s]


loss_eval:2.295684576034546,accuracy_eval:0.13330078125
epoch: 29


100%|██████████| 320/320 [00:23<00:00, 13.37it/s]


loss_train:2.331974506378174,accuracy_train:0.08001708984375


100%|██████████| 80/80 [00:02<00:00, 35.89it/s]


loss_eval:2.3297064304351807,accuracy_eval:0.08331298828125


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
accuracy_eval,▆▁█▄▃▆▅▄▁▄▁▅▅▅▄█▆▄▁▁▁▃▁▃▁▅▁▄█▄
accuracy_train,▅▄▂▃▄█▂▂▄▃▄▄▆▁▄▄▃▁▁▄▂▃▅▅▃▄▃▂▅▁
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
loss_eval,▂▂█▁▂▂▂▂▁▁▁▁▁▂▁▁▁▂▂▂▁▂▂▁▂▁▁▁▁▁
loss_train,█▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁

0,1
accuracy_eval,0.08331
accuracy_train,0.08002
epoch,29.0
loss_eval,2.32971
loss_train,2.33197


[34m[1mwandb[0m: Agent Starting Run: mh1rh15d with config:
[34m[1mwandb[0m: 	batch_size: 4
[34m[1mwandb[0m: 	epochs: 30
[34m[1mwandb[0m: 	learning_rate: 0.025795316500007347
[34m[1mwandb[0m: 	optimizer: sgd


epoch: 0


100%|██████████| 240/240 [00:15<00:00, 15.47it/s]


loss_train:4.903987407684326,accuracy_train:0.1312255859375


100%|██████████| 60/60 [00:01<00:00, 34.79it/s]


loss_eval:8.943643569946289,accuracy_eval:0.120849609375
epoch: 1


100%|██████████| 240/240 [00:15<00:00, 15.68it/s]


loss_train:2.1134819984436035,accuracy_train:0.2447509765625


100%|██████████| 60/60 [00:01<00:00, 34.12it/s]


loss_eval:2.899571180343628,accuracy_eval:0.333251953125
epoch: 2


100%|██████████| 240/240 [00:15<00:00, 15.62it/s]


loss_train:1.5585490465164185,accuracy_train:0.435302734375


100%|██████████| 60/60 [00:01<00:00, 34.30it/s]


loss_eval:1.3549531698226929,accuracy_eval:0.52099609375
epoch: 3


100%|██████████| 240/240 [00:15<00:00, 15.48it/s]


loss_train:1.0648053884506226,accuracy_train:0.60498046875


100%|██████████| 60/60 [00:01<00:00, 33.21it/s]


loss_eval:0.7676481604576111,accuracy_eval:0.70849609375
epoch: 4


100%|██████████| 240/240 [00:15<00:00, 15.55it/s]


loss_train:0.7352511286735535,accuracy_train:0.7333984375


100%|██████████| 60/60 [00:01<00:00, 33.79it/s]


loss_eval:0.872505247592926,accuracy_eval:0.64990234375
epoch: 5


100%|██████████| 240/240 [00:15<00:00, 15.56it/s]


loss_train:0.6429131627082825,accuracy_train:0.76123046875


100%|██████████| 60/60 [00:01<00:00, 33.05it/s]


loss_eval:0.537294328212738,accuracy_eval:0.7998046875
epoch: 6


100%|██████████| 240/240 [00:15<00:00, 15.54it/s]


loss_train:0.49702081084251404,accuracy_train:0.828125


100%|██████████| 60/60 [00:01<00:00, 34.41it/s]


loss_eval:0.5769896507263184,accuracy_eval:0.79150390625
epoch: 7


100%|██████████| 240/240 [00:15<00:00, 15.66it/s]


loss_train:0.5107338428497314,accuracy_train:0.81689453125


100%|██████████| 60/60 [00:01<00:00, 35.18it/s]


loss_eval:0.6538604497909546,accuracy_eval:0.7958984375
epoch: 8


100%|██████████| 240/240 [00:15<00:00, 15.65it/s]


loss_train:0.3982994258403778,accuracy_train:0.8564453125


100%|██████████| 60/60 [00:01<00:00, 35.60it/s]


loss_eval:1.350955605506897,accuracy_eval:0.7626953125
epoch: 9


100%|██████████| 240/240 [00:15<00:00, 15.67it/s]


loss_train:0.3091282248497009,accuracy_train:0.892578125


100%|██████████| 60/60 [00:01<00:00, 34.85it/s]


loss_eval:0.363490492105484,accuracy_eval:0.8623046875
epoch: 10


100%|██████████| 240/240 [00:15<00:00, 15.65it/s]


loss_train:0.3666568398475647,accuracy_train:0.86376953125


100%|██████████| 60/60 [00:01<00:00, 33.11it/s]


loss_eval:0.402763694524765,accuracy_eval:0.88330078125
epoch: 11


100%|██████████| 240/240 [00:15<00:00, 15.69it/s]


loss_train:0.2687295079231262,accuracy_train:0.89990234375


100%|██████████| 60/60 [00:01<00:00, 34.53it/s]


loss_eval:0.31061723828315735,accuracy_eval:0.904296875
epoch: 12


100%|██████████| 240/240 [00:15<00:00, 15.59it/s]


loss_train:0.2312351018190384,accuracy_train:0.921875


100%|██████████| 60/60 [00:01<00:00, 34.28it/s]


loss_eval:0.32491207122802734,accuracy_eval:0.87060546875
epoch: 13


100%|██████████| 240/240 [00:15<00:00, 15.62it/s]


loss_train:0.1847241371870041,accuracy_train:0.93017578125


100%|██████████| 60/60 [00:01<00:00, 34.14it/s]


loss_eval:0.23290683329105377,accuracy_eval:0.92919921875
epoch: 14


100%|██████████| 240/240 [00:15<00:00, 15.58it/s]


loss_train:0.1484023779630661,accuracy_train:0.94873046875


100%|██████████| 60/60 [00:01<00:00, 33.93it/s]


loss_eval:0.34672632813453674,accuracy_eval:0.904296875
epoch: 15


100%|██████████| 240/240 [00:15<00:00, 15.59it/s]


loss_train:0.10421867668628693,accuracy_train:0.96435546875


100%|██████████| 60/60 [00:01<00:00, 34.80it/s]


loss_eval:0.3516830801963806,accuracy_eval:0.91650390625
epoch: 16


100%|██████████| 240/240 [00:15<00:00, 15.45it/s]


loss_train:0.1396486610174179,accuracy_train:0.94873046875


100%|██████████| 60/60 [00:01<00:00, 34.22it/s]


loss_eval:0.39872053265571594,accuracy_eval:0.8876953125
epoch: 17


100%|██████████| 240/240 [00:15<00:00, 15.49it/s]


loss_train:0.09314872324466705,accuracy_train:0.966796875


100%|██████████| 60/60 [00:01<00:00, 35.45it/s]


loss_eval:0.37407588958740234,accuracy_eval:0.904296875
epoch: 18


100%|██████████| 240/240 [00:15<00:00, 15.65it/s]


loss_train:0.0892687737941742,accuracy_train:0.96875


100%|██████████| 60/60 [00:01<00:00, 34.30it/s]


loss_eval:0.3481219708919525,accuracy_eval:0.8916015625
epoch: 19


100%|██████████| 240/240 [00:15<00:00, 15.66it/s]


loss_train:0.12400639057159424,accuracy_train:0.96044921875


100%|██████████| 60/60 [00:01<00:00, 34.39it/s]


loss_eval:0.22733739018440247,accuracy_eval:0.93310546875
epoch: 20


100%|██████████| 240/240 [00:15<00:00, 15.53it/s]


loss_train:0.0909605622291565,accuracy_train:0.970703125


100%|██████████| 60/60 [00:01<00:00, 34.50it/s]


loss_eval:0.22938726842403412,accuracy_eval:0.9375
epoch: 21


100%|██████████| 240/240 [00:15<00:00, 15.63it/s]


loss_train:0.08344448357820511,accuracy_train:0.97802734375


100%|██████████| 60/60 [00:01<00:00, 35.32it/s]


loss_eval:0.26757335662841797,accuracy_eval:0.91650390625
epoch: 22


100%|██████████| 240/240 [00:15<00:00, 15.50it/s]


loss_train:0.12717628479003906,accuracy_train:0.9541015625


100%|██████████| 60/60 [00:01<00:00, 34.45it/s]


loss_eval:0.37021762132644653,accuracy_eval:0.88330078125
epoch: 23


100%|██████████| 240/240 [00:15<00:00, 15.58it/s]


loss_train:0.07842901349067688,accuracy_train:0.9716796875


100%|██████████| 60/60 [00:01<00:00, 35.61it/s]


loss_eval:0.2094448357820511,accuracy_eval:0.9501953125
epoch: 24


100%|██████████| 240/240 [00:15<00:00, 15.73it/s]


loss_train:0.04064582660794258,accuracy_train:0.9853515625


100%|██████████| 60/60 [00:01<00:00, 32.42it/s]


loss_eval:0.22671885788440704,accuracy_eval:0.9375
epoch: 25


100%|██████████| 240/240 [00:15<00:00, 15.65it/s]


loss_train:0.03802723065018654,accuracy_train:0.986328125


100%|██████████| 60/60 [00:01<00:00, 35.32it/s]


loss_eval:0.2112153172492981,accuracy_eval:0.9375
epoch: 26


100%|██████████| 240/240 [00:15<00:00, 15.65it/s]


loss_train:0.02854051999747753,accuracy_train:0.99169921875


100%|██████████| 60/60 [00:01<00:00, 34.85it/s]


loss_eval:0.31469693779945374,accuracy_eval:0.92919921875
epoch: 27


100%|██████████| 240/240 [00:15<00:00, 15.58it/s]


loss_train:0.013218376785516739,accuracy_train:0.998046875


100%|██████████| 60/60 [00:01<00:00, 33.55it/s]


loss_eval:0.2004440724849701,accuracy_eval:0.94580078125
epoch: 28


100%|██████████| 240/240 [00:15<00:00, 15.26it/s]


loss_train:0.012881346978247166,accuracy_train:0.9970703125


100%|██████████| 60/60 [00:01<00:00, 32.99it/s]


loss_eval:0.2715218663215637,accuracy_eval:0.9375
epoch: 29


100%|██████████| 240/240 [00:15<00:00, 15.45it/s]


loss_train:0.008373342454433441,accuracy_train:0.9970703125


100%|██████████| 60/60 [00:01<00:00, 33.56it/s]


loss_eval:0.21060937643051147,accuracy_eval:0.94580078125


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
accuracy_eval,▁▃▄▆▅▇▇▇▆▇▇█▇███▇█████▇███████
accuracy_train,▁▂▃▅▆▆▇▇▇▇▇▇▇▇████████████████
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
loss_eval,█▃▂▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_train,█▄▃▃▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy_eval,0.9458
accuracy_train,0.99707
epoch,29.0
loss_eval,0.21061
loss_train,0.00837


In [20]:
model = build_network()
torch.save(model.state_dict(), "/content/drive/MyDrive/transfer_learning/model_mnist.pth")

In [None]:
import pprint

pprint.pprint(sweep_config)

{'method': 'random',
 'metric': {'goal': 'minimize', 'name': 'loss'},
 'parameters': {'batch_size': {'distribution': 'q_log_uniform',
                               'max': 5.545177444479562,
                               'min': 3.4657359027997265,
                               'q': 1},
                'epochs': {'values': 3},
                'learning_rate': {'distribution': 'uniform',
                                  'max': 0.1,
                                  'min': 0},
                'optimizer': {'values': ['adam', 'sgd']}}}


In [None]:
wandb.init(project="transferlearning_persion_mnist")

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torchvision.models.resnet50(pretrained=True)

In [None]:
# Transfer Learning
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, 10)

In [None]:
model=model.to(device)

In [None]:
model

In [None]:
# freeze
# ct = 0
# for child in model.children():
#   ct += 1
#   if ct < 6:
#     for param in child.parameters():
#         param.requires_grad = False

In [None]:
config = wandb.config
config.learning_rate = 0.01
config.batch_size = 1
config.epochs = 15

In [None]:
transform =transforms.Compose([
                                transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225))
])

dataset = torchvision.datasets.ImageFolder(root='/content/drive/MyDrive/MNIST_persian/MNIST_persian', transform=transform)
torch.manual_seed(0)
train_size = int(len(dataset)*0.8)
val_size = len(dataset)-train_size
train_data ,val_data = data.random_split(dataset,[train_size,val_size])

train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, shuffle=True,num_workers=32)
val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=config.batch_size, shuffle=False,num_workers=32)

  cpuset_checked))


In [None]:
# compile
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
loss_function = nn.CrossEntropyLoss()

In [None]:
def calc_acc(preds, labels):
    preds_max = torch.argmax(preds, 1)
    acc = torch.sum(preds_max == labels.data, dtype=torch.float64) / len(preds)
    return acc

In [None]:
def train(model, train_data_loader,epoch):
  wandb.watch(model)
  model.train(True)
  train_loss=0.0
  train_acc=0.0
  for images,labels in tqdm(train_data_loader):
    images=images.to(device)
    labels=labels.to(device)
    optimizer.zero_grad()
    
    preds_train = model(images)

    loss_train=loss_function(preds_train,labels) # loss_train
    loss_train.backward()

    optimizer.step()

    train_loss += loss_train
    train_acc += calc_acc(preds_train,labels)
    
  total_loss = train_loss/len(train_data_loader)
  total_acc = train_acc/len(train_data_loader)

  if epoch % 2 == 0:
    wandb.log({"epoch": epoch})
    wandb.log({"loss_train": total_loss})
    wandb.log({"acc_train": total_acc})

  print(f"loss_train:{total_loss},accuracy_train:{total_acc}")

In [None]:
def test(model, val_data_loader,epoch):
  model.eval()
  test_loss=0.0
  test_acc=0.0
  for images,labels in tqdm(val_data_loader):    
    images=images.to(device)
    labels=labels.to(device)
    
    preds_test = model(images)

    loss_test=loss_function(preds_test,labels) 

    test_loss += loss_test
    test_acc += calc_acc(preds_test,labels)

  total_loss = test_loss/len(val_data_loader)
  total_acc = test_acc/len(val_data_loader)

  if epoch % 2 == 0:
        wandb.log({"loss_eval": total_loss})
        wandb.log({"acc_eval": total_acc})

  print(f"loss_eval:{total_loss},accuracy_eval:{total_acc}")

In [None]:
for epoch in range(config.epochs):
  print(f'Epoch:{epoch}')
  train(model, train_data_loader, epoch)
  test(model, val_data_loader, epoch)

In [23]:
transform =transforms.Compose([
                                  transforms.Resize((64, 64)),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225))
  ])
import cv2
import numpy as np
model = build_network()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.load_state_dict(torch.load('/content/drive/MyDrive/transfer_learning/model_mnist.pth' ,map_location=torch.device(device)))

model.eval()
img = cv2.imread('/content/drive/MyDrive/MNIST_persian/MNIST_persian/2/2_13.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# img = cv2.resize(img, (64, 64))

from PIL import Image
PIL_image = Image.fromarray(img)

tensor = transform(PIL_image).unsqueeze(0).to(device)

# process
pred = model(tensor)

# postprocess
pred = pred.cpu().detach().numpy()
pred = np.argmax(pred)

print('prediction:', pred)

prediction: 549
