In [None]:
# import necessary dependencies
import argparse
import os, sys
import time
import datetime
from tqdm import tqdm_notebook as tqdm

import os
import torch
import torch.nn as nn
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchvision

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Hyper-parameters for the rotation-prediction (pretext) training run.
LR = 0.005          # learning rate handed to the optimizer
EPOCHS = 100        # number of passes over the training loader
BATCH_SIZE = 128    # samples per mini-batch for both loaders
# SGD-only settings, currently disabled:
#DECAY = 5e-4
#MOMENTUM = 0.9

In [None]:
class RotNet(nn.Module):
    """ResNet-18 encoder followed by a bias-free 4-way linear classifier.

    The torchvision ResNet-18 is altered in three places before use:
    ``conv1`` becomes a 3x3, stride-1, padding-1 convolution, and both
    ``maxpool`` and ``fc`` are replaced by ``nn.Identity``. The classifier
    ``linear1`` maps the encoder's 512 features to 4 logits.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=False)
        backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        backbone.maxpool = nn.Identity()
        backbone.fc = nn.Identity()
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.encoder = backbone
        self.linear1 = nn.Linear(512, 4, bias=False)

    def forward(self, x):
        """Return the 4 logits produced from the encoder features of ``x``."""
        features = self.encoder(x)
        return self.linear1(features)

In [None]:
class CIFAR10Rot(Dataset):
    """Dataset holding four rotated copies of every sample of a base dataset.

    All rotations are computed once, at construction time, and kept in memory.
    Each item is ``(rotated_image, rotation_index)`` where the index 0..3
    corresponds to a rotation of 0, 90, 180 and 270 degrees respectively.
    The labels of the base dataset are discarded.
    """

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset
        # (images, labels) pair of equally long lists; see rot().
        self.transformed = self.rot()

    def rot(self):
        """Rotate every base image by each angle and return ``(images, labels)``."""
        angles = (0, 90, 180, 270)
        images, labels = [], []
        for sample, _ in tqdm(self.base_dataset):
            for label, angle in enumerate(angles):
                images.append(torchvision.transforms.functional.rotate(sample, angle))
                labels.append(label)
        return images, labels

    def __len__(self):
        # Four rotated copies per base sample.
        return 4 * len(self.base_dataset)

    def __getitem__(self, idx):
        images, labels = self.transformed
        return images[idx], labels[idx]

In [None]:
# Both splits use the same preprocessing: resize to 32x32, convert to tensor.
train_transform = transforms.Compose([
    transforms.Resize(size=(32, 32)),
    transforms.ToTensor(),
])
test_transform = transforms.Compose([
    transforms.Resize(size=(32, 32)),
    transforms.ToTensor(),
])

# Wrap each CIFAR-10 split in CIFAR10Rot so the loaders yield
# (rotated image, rotation index) pairs instead of the object classes.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
trainset = CIFAR10Rot(trainset)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

testset = CIFAR10Rot(
    torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# CIFAR-10 object class names.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 102311761.15it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


  0%|          | 0/50000 [00:00<?, ?it/s]

Files already downloaded and verified


  0%|          | 0/10000 [00:00<?, ?it/s]

In [None]:
# Cross-entropy over the logits; used for both training and validation loss.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Pretext training: teach RotNet to predict which rotation was applied.
# After every epoch the model is evaluated on the test loader, and the weights
# with the lowest validation loss so far are written to disk (so the saved
# checkpoint is the best-validation epoch, not necessarily the last one).
rot_model = RotNet().cuda()
#optimizer = torch.optim.SGD(rot_model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=DECAY)
optimizer = torch.optim.Adam(rot_model.parameters(), lr=LR)
# Learning rate is multiplied by 0.2 at epochs 30, 60 and 80.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60, 80], gamma=0.2)

best_loss = float('inf')

for epoch_idx in range(EPOCHS):
    epoch_losses = 0.0
    epoch_corrects = 0
    rot_model.train()
    for image, label in tqdm(trainloader):
        image = image.cuda()
        label = label.cuda()

        rot_model.zero_grad()
        out = rot_model(image)

        loss = criterion(out, label)
        loss.backward()
        optimizer.step()

        # Accumulate a Python float, not the tensor: adding the tensor itself
        # would keep every batch's autograd graph alive until the epoch ends.
        epoch_losses += loss.item()

        pred = torch.argmax(out, dim=1)
        epoch_corrects += torch.sum(pred == label).item()

    # Mean loss per batch and accuracy per sample.
    epoch_losses /= len(trainloader)
    epoch_corrects /= len(trainset)

    rot_model.eval()
    with torch.no_grad():
        test_epoch_losses = 0.0
        test_epoch_corrects = 0

        for image, label in tqdm(testloader):
            image = image.cuda()
            label = label.cuda()

            out = rot_model(image)
            loss = criterion(out, label)

            test_epoch_losses += loss.item()

            pred = torch.argmax(out, dim=1)
            test_epoch_corrects += torch.sum(pred == label).item()

        test_epoch_losses /= len(testloader)
        test_epoch_corrects /= len(testset)

        # Keep only the checkpoint with the lowest validation loss.
        if test_epoch_losses < best_loss:
            best_loss = test_epoch_losses
            torch.save(rot_model.state_dict(), f'rotnet_base_{EPOCHS}_{BATCH_SIZE}_{LR}.pth')

    scheduler.step()
    print(f'Train Loss {epoch_losses} Acc {epoch_corrects} ; Val Loss {test_epoch_losses} Acc {test_epoch_corrects}')

  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.9846785068511963 Acc 0.58179 ; Val Loss 0.839855432510376 Acc 0.649975


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.675491452217102 Acc 0.72758 ; Val Loss 0.6970157623291016 Acc 0.7208


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.5092211365699768 Acc 0.80211 ; Val Loss 0.4907119572162628 Acc 0.812375


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.3951680064201355 Acc 0.850355 ; Val Loss 0.4874732196331024 Acc 0.816625


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.30453747510910034 Acc 0.88571 ; Val Loss 0.3939718008041382 Acc 0.8529


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.22506281733512878 Acc 0.916635 ; Val Loss 0.3970150351524353 Acc 0.86375


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.15572518110275269 Acc 0.94244 ; Val Loss 0.43753936886787415 Acc 0.859375


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.10624672472476959 Acc 0.9612 ; Val Loss 0.47446686029434204 Acc 0.864875


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.08060764521360397 Acc 0.97089 ; Val Loss 0.5064173340797424 Acc 0.861325


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.06265072524547577 Acc 0.977425 ; Val Loss 0.5750971436500549 Acc 0.858775


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.055530350655317307 Acc 0.97986 ; Val Loss 0.5619309544563293 Acc 0.86675


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.04982493817806244 Acc 0.982115 ; Val Loss 0.535318911075592 Acc 0.86905


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.04286600649356842 Acc 0.98469 ; Val Loss 0.5756000280380249 Acc 0.871925


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.04055440425872803 Acc 0.98546 ; Val Loss 0.5963699817657471 Acc 0.8681


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.03578188642859459 Acc 0.987425 ; Val Loss 0.6557211875915527 Acc 0.8615


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.03341561183333397 Acc 0.988195 ; Val Loss 0.589906632900238 Acc 0.87235


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.031502094119787216 Acc 0.988775 ; Val Loss 0.6217256188392639 Acc 0.8702


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.02993013523519039 Acc 0.989405 ; Val Loss 0.644247829914093 Acc 0.8697


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.027453644201159477 Acc 0.990425 ; Val Loss 0.6558691263198853 Acc 0.86875


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.025786850601434708 Acc 0.990955 ; Val Loss 0.6902045607566833 Acc 0.864075


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.024755051359534264 Acc 0.991365 ; Val Loss 0.6843933463096619 Acc 0.862925


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.024008773267269135 Acc 0.99154 ; Val Loss 0.6440277099609375 Acc 0.875925


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.022888358682394028 Acc 0.991985 ; Val Loss 0.6899734735488892 Acc 0.868225


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.021415214985609055 Acc 0.99248 ; Val Loss 0.7355988025665283 Acc 0.862925


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.02095388062298298 Acc 0.992685 ; Val Loss 0.6562375426292419 Acc 0.87595


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.019641049206256866 Acc 0.993175 ; Val Loss 0.6728243231773376 Acc 0.873275


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.019428255036473274 Acc 0.99324 ; Val Loss 0.7366514801979065 Acc 0.8674


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.018472397699952126 Acc 0.993655 ; Val Loss 0.7041016221046448 Acc 0.8728


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.018337661400437355 Acc 0.993475 ; Val Loss 0.7161058783531189 Acc 0.8734


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.017644798383116722 Acc 0.99387 ; Val Loss 0.7084397673606873 Acc 0.87345


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.004254285246133804 Acc 0.998555 ; Val Loss 0.69059157371521 Acc 0.88405


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0006303308182395995 Acc 0.99989 ; Val Loss 0.7654632925987244 Acc 0.885975


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0005033563356846571 Acc 0.999915 ; Val Loss 0.834915280342102 Acc 0.884275


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.000557382998522371 Acc 0.999825 ; Val Loss 0.9045345783233643 Acc 0.885075


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.000857932202052325 Acc 0.999745 ; Val Loss 0.9237738847732544 Acc 0.88495


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0005892431945540011 Acc 0.99982 ; Val Loss 0.9559272527694702 Acc 0.883525


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0005404929397627711 Acc 0.999845 ; Val Loss 0.9982560873031616 Acc 0.882325


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0006959423772059381 Acc 0.999785 ; Val Loss 0.9915922284126282 Acc 0.88565


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0006681380327790976 Acc 0.99978 ; Val Loss 0.99798983335495 Acc 0.885525


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0006484866607934237 Acc 0.99979 ; Val Loss 0.9839022159576416 Acc 0.886625


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0006323034758679569 Acc 0.999775 ; Val Loss 1.0476943254470825 Acc 0.883425


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.000822119414806366 Acc 0.99977 ; Val Loss 1.0145412683486938 Acc 0.886825


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0005457206279970706 Acc 0.99984 ; Val Loss 1.0299352407455444 Acc 0.886075


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0005949623300693929 Acc 0.999775 ; Val Loss 1.0761955976486206 Acc 0.88325


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0006197575712576509 Acc 0.999775 ; Val Loss 1.0744221210479736 Acc 0.88455


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0005269786925055087 Acc 0.99984 ; Val Loss 1.057453989982605 Acc 0.886325


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0006668956484645605 Acc 0.9998 ; Val Loss 1.0908758640289307 Acc 0.8855


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0005355616449378431 Acc 0.9998 ; Val Loss 1.091747760772705 Acc 0.8839


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0009909016080200672 Acc 0.99967 ; Val Loss 1.0535306930541992 Acc 0.885575


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.00032913967152126133 Acc 0.99987 ; Val Loss 1.1031767129898071 Acc 0.884275


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0005095890955999494 Acc 0.999845 ; Val Loss 1.041428565979004 Acc 0.8881


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.00020769507682416588 Acc 0.999955 ; Val Loss 1.085037112236023 Acc 0.888025


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0009374250075779855 Acc 0.99969 ; Val Loss 1.1099460124969482 Acc 0.882375


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0007218930986709893 Acc 0.999765 ; Val Loss 1.0726351737976074 Acc 0.885625


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.00022643612464889884 Acc 0.999935 ; Val Loss 1.1343543529510498 Acc 0.8845


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.000709037238266319 Acc 0.999775 ; Val Loss 1.1330403089523315 Acc 0.885125


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0008676049183122814 Acc 0.999705 ; Val Loss 1.1194968223571777 Acc 0.885


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0004971072194166481 Acc 0.999825 ; Val Loss 1.0723706483840942 Acc 0.886625


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0007438999600708485 Acc 0.99977 ; Val Loss 1.1226595640182495 Acc 0.884275


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.0006454915273934603 Acc 0.99981 ; Val Loss 1.1180737018585205 Acc 0.884675


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.00019431472173891962 Acc 0.999925 ; Val Loss 1.0783722400665283 Acc 0.887275


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 6.023813330102712e-05 Acc 0.99999 ; Val Loss 1.0786712169647217 Acc 0.888475


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 3.938403824577108e-05 Acc 1.0 ; Val Loss 1.086600661277771 Acc 0.88815


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 4.340638770372607e-05 Acc 0.99999 ; Val Loss 1.084540843963623 Acc 0.889425


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 1.8365548385190777e-05 Acc 1.0 ; Val Loss 1.0991878509521484 Acc 0.889575


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 1.9192539184587076e-05 Acc 1.0 ; Val Loss 1.1160920858383179 Acc 0.88955


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 1.222112859977642e-05 Acc 1.0 ; Val Loss 1.1267904043197632 Acc 0.889475


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 1.1382333468645811e-05 Acc 1.0 ; Val Loss 1.1321969032287598 Acc 0.889625


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 3.22721571137663e-05 Acc 0.999985 ; Val Loss 1.139137864112854 Acc 0.8891


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 3.20292056130711e-05 Acc 0.99999 ; Val Loss 1.1338335275650024 Acc 0.890325


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 7.51756260797265e-06 Acc 1.0 ; Val Loss 1.1519474983215332 Acc 0.8899


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 7.3032874752243515e-06 Acc 1.0 ; Val Loss 1.1574559211730957 Acc 0.889675


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 1.9538520064088516e-05 Acc 0.999995 ; Val Loss 1.161406397819519 Acc 0.88945


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 5.193436209083302e-06 Acc 1.0 ; Val Loss 1.181259274482727 Acc 0.889275


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 5.945136763330083e-06 Acc 1.0 ; Val Loss 1.1823105812072754 Acc 0.889275


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 3.408218844924704e-06 Acc 1.0 ; Val Loss 1.2000552415847778 Acc 0.88995


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 1.1150797035952564e-05 Acc 0.999995 ; Val Loss 1.1910500526428223 Acc 0.889525


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 2.960231086035492e-06 Acc 1.0 ; Val Loss 1.2149959802627563 Acc 0.8901


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 4.161076958553167e-06 Acc 1.0 ; Val Loss 1.2278861999511719 Acc 0.8903


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 2.1400078367150854e-06 Acc 1.0 ; Val Loss 1.249379277229309 Acc 0.889825


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 1.8515403326091473e-06 Acc 1.0 ; Val Loss 1.2266454696655273 Acc 0.889375


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 1.6822012867123703e-06 Acc 1.0 ; Val Loss 1.2377396821975708 Acc 0.8902


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 3.2075563467515167e-06 Acc 1.0 ; Val Loss 1.2482436895370483 Acc 0.8901


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 1.245712724085024e-06 Acc 1.0 ; Val Loss 1.2398806810379028 Acc 0.890225


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 2.340023911528988e-06 Acc 1.0 ; Val Loss 1.2344602346420288 Acc 0.89025


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 1.8408712776363245e-06 Acc 1.0 ; Val Loss 1.2530320882797241 Acc 0.889775


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 1.032326508720871e-06 Acc 1.0 ; Val Loss 1.2579851150512695 Acc 0.88975


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 1.5107823401194764e-06 Acc 1.0 ; Val Loss 1.2377197742462158 Acc 0.889875


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 1.7723639302857919e-06 Acc 1.0 ; Val Loss 1.25593101978302 Acc 0.889475


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 1.6292590316879796e-06 Acc 1.0 ; Val Loss 1.2634341716766357 Acc 0.889425


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 8.738450105738593e-07 Acc 1.0 ; Val Loss 1.260367512702942 Acc 0.8903


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 1.7171338413390913e-06 Acc 1.0 ; Val Loss 1.2588255405426025 Acc 0.889625


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 8.003003131307196e-07 Acc 1.0 ; Val Loss 1.2510344982147217 Acc 0.889275


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 4.756384441861883e-07 Acc 1.0 ; Val Loss 1.2543609142303467 Acc 0.889775


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 8.915904459172452e-07 Acc 1.0 ; Val Loss 1.2715263366699219 Acc 0.889475


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 7.745495622657472e-07 Acc 1.0 ; Val Loss 1.2686781883239746 Acc 0.89


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 1.1117920166725526e-06 Acc 1.0 ; Val Loss 1.2699624300003052 Acc 0.890725


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 9.939565188687993e-07 Acc 1.0 ; Val Loss 1.2744115591049194 Acc 0.890225


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 5.805557066196343e-07 Acc 1.0 ; Val Loss 1.2569704055786133 Acc 0.89005


  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 5.731528176511347e-07 Acc 1.0 ; Val Loss 1.2755249738693237 Acc 0.890125


In [None]:
class RotNet_linear_eval(nn.Module):
    """Frozen encoder plus a trainable, bias-free 10-way linear classifier.

    Args:
        encoder: module mapping an input batch to 512 features per sample.
            Its parameters are marked ``requires_grad = False`` here, and the
            module is kept in eval mode even while this model is in training
            mode, so that layers such as BatchNorm neither update their
            running statistics nor normalise with per-batch statistics.
    """

    def __init__(self, encoder):
        super(RotNet_linear_eval, self).__init__()
        self.encoder = encoder
        self.linear = nn.Linear(512, 10, bias=False)
        # Not used by forward(); kept so the attribute stays available.
        self.relu = nn.ReLU()

        for param in self.encoder.parameters():
            param.requires_grad = False
        self.encoder.eval()

    def train(self, mode=True):
        """Switch the classifier's mode but always leave the encoder in eval mode."""
        super().train(mode)
        # Freezing parameters alone does not stop BatchNorm from updating its
        # running statistics in training mode.
        self.encoder.eval()
        return self

    def forward(self, x):
        """Return the 10 class logits computed from the encoder features of ``x``."""
        out = self.encoder(x)
        out = self.linear(out)
        return out

In [None]:
# Hyper-parameters for the linear-evaluation run.
LR = 0.01           # learning rate handed to the optimizer
EPOCHS = 50         # number of passes over the training loader
BATCH_SIZE = 192    # samples per mini-batch for both loaders
# Only referenced by the SGD optimizer variant:
MOMENTUM = 0.9
DECAY = 5e-4

In [None]:

# Both splits use the same preprocessing: resize to 32x32, convert to tensor.
train_transform = transforms.Compose([
    transforms.Resize(size=(32, 32)),
    transforms.ToTensor(),
])
test_transform = transforms.Compose([
    transforms.Resize(size=(32, 32)),
    transforms.ToTensor(),
])

# Plain CIFAR-10 (object-class labels) for training and evaluation.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

# Only the training loader shuffles.
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# CIFAR-10 object class names.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
# Cross-entropy over the logits; used for both training and validation loss.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Linear evaluation: load the rotation-pretrained RotNet, hand its encoder to
# RotNet_linear_eval and train the 10-way linear classifier on CIFAR-10.
# The weights with the lowest validation loss so far are written to disk.

# Checkpoint written by the pretext-training cell
# (f'rotnet_base_{EPOCHS}_{BATCH_SIZE}_{LR}.pth' with EPOCHS=100, BATCH_SIZE=128, LR=5e-3).
pretrained_path = "rotnet_base_100_128_0.005.pth"

rot_model = RotNet()
rot_model.load_state_dict(torch.load(pretrained_path))
rot_model.cuda()

rot_linear_eval_model = RotNet_linear_eval(rot_model.encoder)
#rot_linear_eval_model.load_state_dict(torch.load("models/rot_model_semi_sup_30_128_0.001_0.1.pth"))
rot_linear_eval_model.cuda()

#optimizer = torch.optim.SGD(rot_linear_eval_model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=DECAY)
optimizer = torch.optim.Adam(rot_linear_eval_model.parameters(), lr=LR)
#scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=0.2)
# Learning rate is multiplied by 0.2 every 10 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2)

best_loss = float('inf')

for epoch_idx in range(EPOCHS):
    epoch_losses = 0.0
    epoch_corrects = 0
    rot_linear_eval_model.train()
    for image, label in tqdm(trainloader):
        image = image.cuda()
        label = label.cuda()

        rot_linear_eval_model.zero_grad()
        out = rot_linear_eval_model(image)

        loss = criterion(out, label)
        loss.backward()
        optimizer.step()

        # Accumulate a Python float, not the tensor: adding the tensor itself
        # would keep every batch's autograd graph alive until the epoch ends.
        epoch_losses += loss.item()

        pred = torch.argmax(out, dim=1)
        epoch_corrects += torch.sum(pred == label).item()

    # Mean loss per batch and accuracy per sample.
    epoch_losses /= len(trainloader)
    epoch_corrects /= len(trainset)

    rot_linear_eval_model.eval()
    with torch.no_grad():
        test_epoch_losses = 0.0
        test_epoch_corrects = 0

        for image, label in tqdm(testloader):
            image = image.cuda()
            label = label.cuda()

            out = rot_linear_eval_model(image)
            loss = criterion(out, label)

            test_epoch_losses += loss.item()

            pred = torch.argmax(out, dim=1)
            test_epoch_corrects += torch.sum(pred == label).item()

        test_epoch_losses /= len(testloader)
        test_epoch_corrects /= len(testset)

        # Keep only the checkpoint with the lowest validation loss.
        if test_epoch_losses < best_loss:
            best_loss = test_epoch_losses
            torch.save(rot_linear_eval_model.state_dict(), f'rot_model_resnet18_linear_eval_{EPOCHS}_{BATCH_SIZE}_{LR}.pth')

    scheduler.step()
    print(f'Epochs {epoch_idx} Train Loss {epoch_losses} Acc {epoch_corrects} ; Val Loss {test_epoch_losses} Acc {test_epoch_corrects}')

In [None]:

# Hyper-parameters for the semi-supervised run.
LR = 0.01                  # learning rate handed to the optimizer
EPOCHS = 30                # number of passes over the training loader
BATCH_SIZE = 128           # samples per mini-batch for both loaders
label_percentage = 0.01    # fraction of the training set kept as labelled data
# SGD-only settings, currently disabled:
#DECAY = 5e-4
#MOMENTUM = 0.9

In [None]:
class RotNet_linear_eval(nn.Module):
    """Encoder followed by a bias-free 10-way linear classifier.

    This definition does not freeze anything: the parameters of ``encoder``
    keep whatever ``requires_grad`` setting they arrive with, so an optimizer
    built from ``self.parameters()`` can update the encoder as well as the
    classifier.

    Args:
        encoder: module mapping an input batch to 512 features per sample.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        self.linear = nn.Linear(512, 10, bias=False)

    def forward(self, x):
        """Return the 10 class logits computed from the encoder features of ``x``."""
        features = self.encoder(x)
        return self.linear(features)


In [None]:
# Both splits use the same preprocessing: resize to 32x32, convert to tensor.
train_transform = transforms.Compose(
    [transforms.Resize(size=(32, 32)),
    transforms.ToTensor()]
)
test_transform = transforms.Compose(
    [transforms.Resize(size=(32, 32)),
    transforms.ToTensor()]
)

# Keep only a `label_percentage` fraction of the training set as labelled data.
full_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
num_labeled = int(len(full_trainset) * label_percentage)
# The remainder is derived by subtraction: truncating both products with
# int() independently can leave the two lengths one short of len(dataset)
# for some fractions, which random_split rejects. The seeded generator makes
# the labelled subset the same on every run.
trainset, _ = torch.utils.data.random_split(
    full_trainset,
    [num_labeled, len(full_trainset) - num_labeled],
    generator=torch.Generator().manual_seed(0),
)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

# The full test split is used for validation.
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# CIFAR-10 object class names.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Cross-entropy over the logits; used for both training and validation loss.
criterion = torch.nn.CrossEntropyLoss()

In [None]:

# Semi-supervised training: start from the rotation-pretrained encoder, attach
# a 10-way linear classifier and train on the labelled subset in `trainloader`.
# The weights with the lowest validation loss so far are written to disk.
rot_model = RotNet()
rot_model.load_state_dict(torch.load("rotnet_base_100_128_0.005.pth"))
rot_model.cuda()

rot_linear_eval_model = RotNet_linear_eval(rot_model.encoder)
rot_linear_eval_model.cuda()

#optimizer = torch.optim.SGD(rot_linear_eval_model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=DECAY)
optimizer = torch.optim.Adam(rot_linear_eval_model.parameters(), lr=LR)
# Learning rate is multiplied by 0.1 at epochs 10 and 20.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20], gamma=0.1)

best_loss = float('inf')

for epoch_idx in range(EPOCHS):
    epoch_losses = 0.0
    epoch_corrects = 0
    rot_linear_eval_model.train()
    for image, label in tqdm(trainloader):
        image = image.cuda()
        label = label.cuda()

        rot_linear_eval_model.zero_grad()
        out = rot_linear_eval_model(image)

        loss = criterion(out, label)
        loss.backward()
        optimizer.step()

        # Accumulate a Python float, not the tensor: adding the tensor itself
        # would keep every batch's autograd graph alive until the epoch ends.
        epoch_losses += loss.item()

        pred = torch.argmax(out, dim=1)
        epoch_corrects += torch.sum(pred == label).item()

    # Mean loss per batch and accuracy per sample.
    epoch_losses /= len(trainloader)
    epoch_corrects /= len(trainset)

    rot_linear_eval_model.eval()
    with torch.no_grad():
        test_epoch_losses = 0.0
        test_epoch_corrects = 0

        for image, label in tqdm(testloader):
            image = image.cuda()
            label = label.cuda()

            out = rot_linear_eval_model(image)
            loss = criterion(out, label)

            test_epoch_losses += loss.item()

            pred = torch.argmax(out, dim=1)
            test_epoch_corrects += torch.sum(pred == label).item()

        test_epoch_losses /= len(testloader)
        test_epoch_corrects /= len(testset)

        # Keep only the checkpoint with the lowest validation loss.
        if test_epoch_losses < best_loss:
            best_loss = test_epoch_losses
            torch.save(rot_linear_eval_model.state_dict(), f'rot_model_res18_semi_sup_{EPOCHS}_{BATCH_SIZE}_{LR}_{label_percentage}.pth')

    scheduler.step()
    print(f'Epoch {epoch_idx} Train Loss {epoch_losses} Acc {epoch_corrects} ; Val Loss {test_epoch_losses} Acc {test_epoch_corrects}')

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 0 Train Loss 2.1365230083465576 Acc 0.284 ; Val Loss 1.9067717790603638 Acc 0.2948


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 1 Train Loss 1.4808323383331299 Acc 0.49 ; Val Loss 1.9526225328445435 Acc 0.3883


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 2 Train Loss 0.9434902667999268 Acc 0.716 ; Val Loss 3.109192371368408 Acc 0.3322


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 3 Train Loss 0.5040497779846191 Acc 0.856 ; Val Loss 1.4541062116622925 Acc 0.5183


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 4 Train Loss 0.2259141504764557 Acc 0.944 ; Val Loss 1.3188802003860474 Acc 0.5728


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 5 Train Loss 0.07310531288385391 Acc 0.99 ; Val Loss 1.6503267288208008 Acc 0.5506


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 6 Train Loss 0.021570317447185516 Acc 0.996 ; Val Loss 2.131568431854248 Acc 0.5104


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 7 Train Loss 0.01582879014313221 Acc 0.998 ; Val Loss 2.015564441680908 Acc 0.552


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 8 Train Loss 0.0023953027557581663 Acc 1.0 ; Val Loss 2.442718267440796 Acc 0.5409


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 9 Train Loss 0.009621819481253624 Acc 0.998 ; Val Loss 2.6158292293548584 Acc 0.5374


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 10 Train Loss 0.0022090664133429527 Acc 1.0 ; Val Loss 2.348290205001831 Acc 0.5578


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 11 Train Loss 0.0034959299955517054 Acc 1.0 ; Val Loss 2.220665454864502 Acc 0.5698


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 12 Train Loss 0.003858506213873625 Acc 0.998 ; Val Loss 2.1583967208862305 Acc 0.576


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 13 Train Loss 0.003765780245885253 Acc 0.998 ; Val Loss 2.1241824626922607 Acc 0.5806


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 14 Train Loss 0.0010992534225806594 Acc 1.0 ; Val Loss 2.091736316680908 Acc 0.5848


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 15 Train Loss 0.0005529229529201984 Acc 1.0 ; Val Loss 2.078033208847046 Acc 0.5859


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 16 Train Loss 0.0007044717785902321 Acc 1.0 ; Val Loss 2.0692198276519775 Acc 0.5868


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 17 Train Loss 0.0008141279104165733 Acc 1.0 ; Val Loss 2.072218894958496 Acc 0.5873


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 18 Train Loss 0.0007379016024060547 Acc 1.0 ; Val Loss 2.0736682415008545 Acc 0.587


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 19 Train Loss 0.0004705955507233739 Acc 1.0 ; Val Loss 2.0723483562469482 Acc 0.5873


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 20 Train Loss 0.00044221317511983216 Acc 1.0 ; Val Loss 2.077167272567749 Acc 0.5864


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 21 Train Loss 0.0005539213307201862 Acc 1.0 ; Val Loss 2.080005407333374 Acc 0.5863


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 22 Train Loss 0.0004206044541206211 Acc 1.0 ; Val Loss 2.081411600112915 Acc 0.5866


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 23 Train Loss 0.0007100485963746905 Acc 1.0 ; Val Loss 2.088256359100342 Acc 0.5853


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 24 Train Loss 0.001215158263221383 Acc 1.0 ; Val Loss 2.089681625366211 Acc 0.5862


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 25 Train Loss 0.00038966041756793857 Acc 1.0 ; Val Loss 2.081785202026367 Acc 0.5866


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 26 Train Loss 0.0004500274662859738 Acc 1.0 ; Val Loss 2.080951690673828 Acc 0.5863


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 27 Train Loss 0.0005069964681752026 Acc 1.0 ; Val Loss 2.081984043121338 Acc 0.5859


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 28 Train Loss 0.00038126902654767036 Acc 1.0 ; Val Loss 2.0772244930267334 Acc 0.5865


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 29 Train Loss 0.0007211256888695061 Acc 1.0 ; Val Loss 2.0805137157440186 Acc 0.5858


In [None]:
# Hyper-parameters for the second semi-supervised run (10% labelled data).
BATCH_SIZE = 128
EPOCHS = 30
#DECAY = 5e-4
#MOMENTUM = 0.9
LR = 1e-2
label_percentage = 0.1

# Rebuild the labelled subset for the new label_percentage. Without this the
# following training cell would keep using the `trainset`/`trainloader` that an
# earlier cell built for a different fraction. `train_transform` is the
# preprocessing pipeline defined in the earlier data cell.
full_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
num_labeled = int(len(full_trainset) * label_percentage)
# Remainder by subtraction so the two lengths always sum to the dataset size;
# the seeded generator makes the labelled subset the same on every run.
trainset, _ = torch.utils.data.random_split(
    full_trainset,
    [num_labeled, len(full_trainset) - num_labeled],
    generator=torch.Generator().manual_seed(0),
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

In [None]:
# Train `rot_linear_eval_model`, which is built on the encoder of the
# rotation-pretrained RotNet checkpoint, and keep the weights with the lowest
# loss on `testloader`.
# NOTE(review): every parameter of `rot_linear_eval_model` is handed to the
# optimizer; whether the encoder is frozen depends on `RotNet_linear_eval`,
# which is defined elsewhere -- confirm this matches the intended protocol.
rot_model = RotNet()
rot_model.load_state_dict(torch.load("rotnet_base_100_128_0.005.pth"))
rot_model.cuda()

rot_linear_eval_model = RotNet_linear_eval(rot_model.encoder)
rot_linear_eval_model.cuda()

#optimizer = torch.optim.SGD(rot_linear_eval_model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=DECAY)
optimizer = torch.optim.Adam(rot_linear_eval_model.parameters(), lr=LR)
# Learning rate is multiplied by 0.1 after epochs 10 and 20.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20], gamma=0.1)

# Any real loss is smaller than +inf, so the first epoch always checkpoints.
best_loss = float('inf')

for epoch_idx in range(EPOCHS):
    # ---- training pass ----
    epoch_losses = 0.0
    epoch_corrects = 0
    rot_linear_eval_model.train()
    for image, label in tqdm(trainloader):
        image = image.cuda()
        label = label.cuda()

        rot_linear_eval_model.zero_grad()
        out = rot_linear_eval_model(image)

        loss = criterion(out, label)
        loss.backward()
        optimizer.step()

        # Accumulate a plain Python float: adding the loss tensor itself would
        # keep autograd history attached to the running sum for the whole epoch.
        epoch_losses += loss.item()

        pred = torch.argmax(out, dim=1)
        epoch_corrects += torch.sum(pred == label).item()

    # Mean of per-batch losses, and fraction of correctly classified samples.
    epoch_losses /= len(trainloader)
    epoch_corrects /= len(trainset)

    # ---- evaluation pass ----
    rot_linear_eval_model.eval()
    with torch.no_grad():
        test_epoch_losses = 0.0
        test_epoch_corrects = 0

        for image, label in tqdm(testloader):
            image = image.cuda()
            label = label.cuda()

            out = rot_linear_eval_model(image)

            loss = criterion(out, label)

            test_epoch_losses += loss.item()

            pred = torch.argmax(out, dim=1)
            test_epoch_corrects += torch.sum(pred == label).item()

        test_epoch_losses /= len(testloader)
        test_epoch_corrects /= len(testset)

        # Checkpoint whenever the evaluation loss improves.
        if test_epoch_losses < best_loss:
            best_loss = test_epoch_losses
            torch.save(rot_linear_eval_model.state_dict(), f'rot_model_res18_semi_sup_{EPOCHS}_{BATCH_SIZE}_{LR}_{label_percentage}.pth')

    scheduler.step()
    print(f'Epoch {epoch_idx} Train Loss {epoch_losses} Acc {epoch_corrects} ; Val Loss {test_epoch_losses} Acc {test_epoch_corrects}')

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 0 Train Loss 1.2371128797531128 Acc 0.5592 ; Val Loss 1.348272442817688 Acc 0.5908


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 1 Train Loss 0.7078449130058289 Acc 0.7642 ; Val Loss 1.348237156867981 Acc 0.5962


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 2 Train Loss 0.4512538015842438 Acc 0.848 ; Val Loss 0.9140437841415405 Acc 0.7097


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 3 Train Loss 0.23525837063789368 Acc 0.925 ; Val Loss 1.3291115760803223 Acc 0.6536


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 4 Train Loss 0.277413547039032 Acc 0.9162 ; Val Loss 1.3018299341201782 Acc 0.667


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 5 Train Loss 0.18677911162376404 Acc 0.9364 ; Val Loss 1.0760501623153687 Acc 0.7178


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 6 Train Loss 0.0497705303132534 Acc 0.9864 ; Val Loss 1.0958517789840698 Acc 0.7334


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 7 Train Loss 0.028035784140229225 Acc 0.9932 ; Val Loss 1.1937294006347656 Acc 0.7314


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 8 Train Loss 0.05322573333978653 Acc 0.9836 ; Val Loss 1.609192132949829 Acc 0.6872


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 9 Train Loss 0.07457903027534485 Acc 0.983 ; Val Loss 1.4248936176300049 Acc 0.7024


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 10 Train Loss 0.035490844398736954 Acc 0.9882 ; Val Loss 1.2326481342315674 Acc 0.7318


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 11 Train Loss 0.009438705630600452 Acc 0.9992 ; Val Loss 1.1924899816513062 Acc 0.7368


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 12 Train Loss 0.00504964217543602 Acc 0.9996 ; Val Loss 1.168850064277649 Acc 0.7406


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 13 Train Loss 0.007337162736803293 Acc 0.9994 ; Val Loss 1.1663818359375 Acc 0.7435


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 14 Train Loss 0.004658879246562719 Acc 0.9996 ; Val Loss 1.2158180475234985 Acc 0.7372


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 15 Train Loss 0.005783163011074066 Acc 0.9998 ; Val Loss 1.206780195236206 Acc 0.7434


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 16 Train Loss 0.010339386761188507 Acc 0.9998 ; Val Loss 1.2054681777954102 Acc 0.7398


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 17 Train Loss 0.009767058305442333 Acc 0.9984 ; Val Loss 1.264399528503418 Acc 0.7351


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 18 Train Loss 0.01149457972496748 Acc 0.9996 ; Val Loss 1.2150577306747437 Acc 0.7428


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 19 Train Loss 0.027779901400208473 Acc 0.9974 ; Val Loss 1.2063475847244263 Acc 0.7384


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 20 Train Loss 0.0054213725961744785 Acc 1.0 ; Val Loss 1.2082794904708862 Acc 0.7402


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 21 Train Loss 0.005718865431845188 Acc 1.0 ; Val Loss 1.1930962800979614 Acc 0.7407


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 22 Train Loss 0.00862752553075552 Acc 0.9998 ; Val Loss 1.231791377067566 Acc 0.7374


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 23 Train Loss 0.0034207389689981937 Acc 0.9998 ; Val Loss 1.210344910621643 Acc 0.7408


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 24 Train Loss 0.0021982931066304445 Acc 1.0 ; Val Loss 1.1900681257247925 Acc 0.7415


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 25 Train Loss 0.0017340698977932334 Acc 1.0 ; Val Loss 1.2162052392959595 Acc 0.7426


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 26 Train Loss 0.025989819318056107 Acc 0.9996 ; Val Loss 1.2408090829849243 Acc 0.7414


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 27 Train Loss 0.0018684437964111567 Acc 1.0 ; Val Loss 1.1861088275909424 Acc 0.7431


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 28 Train Loss 0.006026883609592915 Acc 0.9998 ; Val Loss 1.2003147602081299 Acc 0.7424


  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 29 Train Loss 0.0018757205689325929 Acc 1.0 ; Val Loss 1.1988903284072876 Acc 0.7448


In [None]:
# Training-time augmentation: random crop resized to 32x32, horizontal flip,
# colour jitter applied with probability 0.8, and grayscale with probability 0.2.
color_jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
train_augmentations = [
    transforms.RandomResizedCrop(size=(32, 32)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([color_jitter], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
]
train_transform = transforms.Compose(train_augmentations)

# Evaluation uses no augmentation: resize to 32x32 and convert to a tensor.
test_transform = transforms.Compose([
    transforms.Resize(size=(32, 32)),
    transforms.ToTensor(),
])

In [None]:
# CIFAR-10 train/test splits (downloaded on first use) and their loaders.
# Only the training loader shuffles; both load in the main process.
cifar_root = '/usr/xtmp/zg78/cifar10_data/'
loader_options = dict(batch_size=BATCH_SIZE, num_workers=0)

trainset = torchvision.datasets.CIFAR10(
    root=cifar_root, train=True, download=True, transform=train_transform
)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_options)

testset = torchvision.datasets.CIFAR10(
    root=cifar_root, train=False, download=True, transform=test_transform
)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_options)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /usr/xtmp/zg78/cifar10_data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:13<00:00, 12966949.71it/s]


Extracting /usr/xtmp/zg78/cifar10_data/cifar-10-python.tar.gz to /usr/xtmp/zg78/cifar10_data/
Files already downloaded and verified


In [None]:
# Classification loss shared by the training loops; averaged over the batch.
criterion = nn.CrossEntropyLoss(reduction='mean')

In [None]:
# Supervised baseline: a randomly initialised torchvision ResNet-18 trained on
# the augmented training set; the weights with the lowest loss on `testloader`
# are checkpointed.
# NOTE(review): this is the stock torchvision architecture (default output size,
# original stem), not the modified encoder used by RotNet in this notebook --
# confirm that is the intended comparison.
net = torchvision.models.resnet18().cuda()
# Use LR (not a separate literal) so the rate recorded in the checkpoint
# filename below is the rate that was actually used.
optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9)

# Any real loss is smaller than +inf, so the first epoch always checkpoints.
best_loss = float('inf')

for epoch_idx in range(EPOCHS):
    # ---- training pass ----
    epoch_losses = 0.0
    epoch_correts = 0
    net.train()
    for image, label in tqdm(trainloader):
        image = image.cuda()
        label = label.cuda()

        net.zero_grad()
        out = net(image)

        loss = criterion(out, label)
        loss.backward()
        optimizer.step()

        # Accumulate a plain Python float: adding the loss tensor itself would
        # keep autograd history attached to the running sum for the whole epoch.
        epoch_losses += loss.item()

        pred = torch.argmax(out, dim=1)
        epoch_correts += torch.sum(pred == label).item()

    # Mean of per-batch losses, and fraction of correctly classified samples.
    epoch_losses /= len(trainloader)
    epoch_correts /= len(trainset)

    # ---- evaluation pass ----
    net.eval()
    with torch.no_grad():
        val_epoch_losses = 0.0
        val_epoch_correts = 0

        for image, label in tqdm(testloader):
            image = image.cuda()
            label = label.cuda()

            out = net(image)

            loss = criterion(out, label)

            val_epoch_losses += loss.item()

            pred = torch.argmax(out, dim=1)
            val_epoch_correts += torch.sum(pred == label).item()

        val_epoch_losses /= len(testloader)
        val_epoch_correts /= len(testset)

        # Checkpoint whenever the evaluation loss improves.
        if val_epoch_losses < best_loss:
            best_loss = val_epoch_losses
            torch.save(net.state_dict(), f'resent18_baseline_{EPOCHS}_{BATCH_SIZE}_{LR}.pth')

    print(f'Train Loss {epoch_losses} Acc {epoch_correts} ; Val Loss {val_epoch_losses} Acc {val_epoch_correts}')

  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 2.0843257904052734 Acc 0.26268 ; Val Loss 1.6110310554504395 Acc 0.4182


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 1.8435710668563843 Acc 0.33306 ; Val Loss 1.8581174612045288 Acc 0.3593


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 1.7347240447998047 Acc 0.37446 ; Val Loss 1.5954996347427368 Acc 0.4384


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 1.650755524635315 Acc 0.41008 ; Val Loss 1.787009358406067 Acc 0.3818


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 1.5829932689666748 Acc 0.4361 ; Val Loss 1.2467713356018066 Acc 0.5559


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 1.5237282514572144 Acc 0.45768 ; Val Loss 1.1679574251174927 Acc 0.5958


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 1.4709285497665405 Acc 0.4775 ; Val Loss 1.2659306526184082 Acc 0.5655


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 1.4410516023635864 Acc 0.48814 ; Val Loss 1.2157931327819824 Acc 0.5765


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 1.3962565660476685 Acc 0.50584 ; Val Loss 1.1883560419082642 Acc 0.5824


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 1.3647602796554565 Acc 0.51698 ; Val Loss 1.0792388916015625 Acc 0.6327


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 1.332448959350586 Acc 0.53232 ; Val Loss 1.558974027633667 Acc 0.5321


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 1.3152536153793335 Acc 0.54068 ; Val Loss 1.0293368101119995 Acc 0.6557


  0%|          | 0/391 [00:00<?, ?it/s]

KeyboardInterrupt: ignored