In [1]:
import torch
from torchvision.models import resnet50
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
import wandb

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

import sys
sys.path.append('../src')

from train_eval_utils import AvgMeter

  from tqdm.autonotebook import tqdm


In [2]:
train_dataset = CIFAR10(root='../input/', download=True, transform=ToTensor())
val_dataset = CIFAR10(root='../input/', train=False, transform=ToTensor())

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

data_loaders = {'train': train_loader, 'val': val_loader}

Files already downloaded and verified


In [3]:
resnet = resnet50(weights=None, num_classes=2048)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
step = 'batch'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
epochs = 25

entity = 'image-captioning-clip'
project_name = 'image-captioning-CLIP'
exp_name = 'resnet'

run = wandb.init(entity=entity, project=project_name)

resnet.to(device)

[34m[1mwandb[0m: Currently logged in as: [33mld2425[0m ([33mimage-captioning-clip[0m). Use [1m`wandb login --relogin`[0m to force relogin


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [4]:
best_loss = float('inf')
for epoch in tqdm(range(epochs), desc='Epochs'):

        resnet.train()
        train_meter = AvgMeter()
        tqdm_train = tqdm(data_loaders['train'], total=len(data_loaders['train']), desc='Train')
        for images, labels in tqdm_train:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = resnet(images)
            loss = criterion(outputs, labels)

            train_meter.update(loss.item(), images.size(0))

            loss.backward()
            optimizer.step()

            tqdm_train.set_postfix(train_loss=train_meter.avg)

        resnet.eval()
        val_meter = AvgMeter()
        val_loss = 0.0
        tqdm_val = tqdm(data_loaders['val'], total=len(data_loaders['val']), desc='Val')
        for images, labels in tqdm_val:
            images = images.to(device)
            labels = labels.to(device)

            outputs = resnet(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)

            val_meter.update(loss.item(), images.size(0))

            tqdm_val.set_postfix(val_loss=val_meter.avg)

        if  val_meter.avg < best_loss:
            best_loss = val_meter.avg
            torch.save(resnet.state_dict(), f'../models/{exp_name}.pt')
            artifact = wandb.Artifact(exp_name, type='model')
            artifact.add_file(f'../models/{exp_name}.pt')
            run.log_artifact(artifact)
            print("Saved Best Model")

        scheduler.step(val_loss)

run.finish()

Epochs:   0%|          | 0/25 [00:00<?, ?it/s]

Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Saved Best Model


Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Saved Best Model


Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Saved Best Model


Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Saved Best Model


Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Saved Best Model


Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Saved Best Model


Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Saved Best Model


Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Saved Best Model


Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Saved Best Model


Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Saved Best Model


Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

Train:   0%|          | 0/1563 [00:00<?, ?it/s]

Val:   0%|          | 0/313 [00:00<?, ?it/s]

VBox(children=(Label(value='1059.812 MB of 1059.812 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.…