In [1]:
import torchvision
from torchvision import models
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

import os, subprocess
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

In [2]:
data_dir = "./data"
dir_lst = os.listdir(data_dir)
for dir in dir_lst:
    dir_path = os.listdir(f"./data/{dir}")

    for name in dir_path:
        if name == ".ipynb_checkpoints":
            print(name)
            os.rmdir(f"./data/{dir}/{name}")

In [3]:
IMG_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((512,512), antialias = True),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
    ])

In [4]:
train_imgfolder = ImageFolder(root="./data/train",
                              transform=IMG_TRANSFORM,
                              target_transform=None
                             )

test_imgfolder = ImageFolder(root="./data/test",
                              transform=IMG_TRANSFORM
                             )

In [5]:
Image.MAX_IMAGE_PIXELS = None

train_dataloader = DataLoader(dataset=train_imgfolder,
							  batch_size = 32,
                              num_workers=os.cpu_count(),
                              shuffle=True,
                              drop_last=True 
                             )

test_dataloader = DataLoader(dataset=test_imgfolder,
							  batch_size = 32,
                              num_workers=os.cpu_count(),
                              shuffle=True,
                              drop_last=True 
                            )

In [6]:
import torch
import torch.nn as nn

from tqdm import tqdm

pre_trained_model = models.vgg16(pretrained=True)




In [7]:
pre_trained_model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [8]:
pre_trained_model.classifier[6].out_features=20
pre_trained_model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [9]:
# model setting
model = pre_trained_model

EPOCH = 50
LEARNING_RATE = 1e-3 
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE, weight_decay = 0.01)
criterion = nn.CrossEntropyLoss()
print("Using device: ",DEVICE)

Using device:  cuda:0


In [10]:
def loss_fn(class_output, label, criterion):
    CE_loss = criterion(class_output, label)

    return CE_loss 

In [11]:
def train(train_loader, model, optimizer, criterion, epoch):
    model.train()
    correct = 0
    total = 0
    
    epoch_loss = []
    for index, (img_data,labels) in enumerate(train_loader):
        img_data, labels = img_data.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        
        outputs = model(img_data)
        loss = loss_fn(outputs, labels, criterion)
        
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
        
        
        loss.backward()
        optimizer.step()

        epoch_loss.append(loss.item())
        if index%100==0:
            print(f'train loss : {np.mean(epoch_loss):>7f}, [epoch:{epoch}, iter:{index}]')
            
    accuracy = correct / total
    print(f"TRAIN ACCURACY : {(100*accuracy):>0.1f}% [{correct}/{total}]")
    return np.mean(epoch_loss)

In [12]:
def test(test_loader, model, criterion, epoch):
    model.eval()

    epoch_loss = []
    correct = 0
    total = 0
    
    with torch.no_grad():
        for index, (img_data,labels) in enumerate(test_loader):
            img_data, labels = img_data.to(DEVICE), labels.to(DEVICE)
            outputs = model(img_data)
    
            loss = loss_fn(outputs, labels, criterion)

            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            
            epoch_loss.append(loss.item())
    
    accuracy = correct / total
    
    print(f"TEST ACCURACY : {(100*accuracy):>0.1f}% [{correct}/{total}]")
    return np.mean(epoch_loss)

In [13]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
train_loss_lst, test_loss_lst = [], []

model_version = 1


test_best_loss = 100
test_current_loss = 100
early_stop_threshold = 5
early_stop_trigger = 0

model = model.to(DEVICE)

for i in tqdm(range(EPOCH), desc = 'Train'):
    print(f"EPOCH {i+1} \n-------------------------")
    train_loss = train(train_dataloader, model, optimizer, criterion, i+1)
    test_loss = test(test_dataloader, model, criterion, i+1)

    train_loss_lst.append(train_loss)
    test_loss_lst.append(test_loss)

    if test_loss < test_best_loss :
        print("...MODEL SAVE...")
        test_best_loss = test_loss 
        
        if model_version == 0:
            torch.save(model.state_dict(),'contrastive_model_for_style_weight_v2_without_target_style.pth')
        else:
            torch.save(model.state_dict(),'contrastive_model_for_style_weight_vgg.pth')
        
    if test_current_loss < test_loss:
        early_stop_trigger += 1 
    else:
        early_stop_trigger = 0 
    test_current_loss  = test_loss 

    print(f'\nEPOCH:{i+1} | train loss : {train_loss}, test loss : {test_loss}\n')
    
    if early_stop_trigger >= early_stop_threshold:
        break
        
print("DONE!")

Train:   0%|                                       | 0/50 [00:00<?, ?it/s]

EPOCH 1 
-------------------------
train loss : 9.335195, [epoch:1, iter:0]
train loss : 3.616809, [epoch:1, iter:100]
train loss : 3.269554, [epoch:1, iter:200]
train loss : 3.123299, [epoch:1, iter:300]
train loss : 3.031203, [epoch:1, iter:400]
train loss : 2.974612, [epoch:1, iter:500]




train loss : 2.935123, [epoch:1, iter:600]
train loss : 2.904851, [epoch:1, iter:700]
train loss : 2.876264, [epoch:1, iter:800]
train loss : 2.859765, [epoch:1, iter:900]
train loss : 2.838943, [epoch:1, iter:1000]
train loss : 2.824859, [epoch:1, iter:1100]
train loss : 2.809382, [epoch:1, iter:1200]
train loss : 2.796275, [epoch:1, iter:1300]
train loss : 2.785020, [epoch:1, iter:1400]
train loss : 2.774746, [epoch:1, iter:1500]
train loss : 2.764865, [epoch:1, iter:1600]




train loss : 2.755371, [epoch:1, iter:1700]
train loss : 2.748129, [epoch:1, iter:1800]
train loss : 2.741427, [epoch:1, iter:1900]
TRAIN ACCURACY : 16.0% [9879/61664]
TEST ACCURACY : 18.7% [2179/11680]
...MODEL SAVE...


Train:   2%|▌                         | 1/50 [27:09<22:10:50, 1629.60s/it]


EPOCH:1 | train loss : 2.7387776316666024, test loss : 2.5965626644761595

EPOCH 2 
-------------------------
train loss : 2.649712, [epoch:2, iter:0]
train loss : 2.598808, [epoch:2, iter:100]
train loss : 2.593772, [epoch:2, iter:200]
train loss : 2.595323, [epoch:2, iter:300]
train loss : 2.589351, [epoch:2, iter:400]
train loss : 2.583602, [epoch:2, iter:500]




train loss : 2.575858, [epoch:2, iter:600]
train loss : 2.572456, [epoch:2, iter:700]
train loss : 2.572790, [epoch:2, iter:800]
train loss : 2.567370, [epoch:2, iter:900]
train loss : 2.563753, [epoch:2, iter:1000]
train loss : 2.557678, [epoch:2, iter:1100]
train loss : 2.555543, [epoch:2, iter:1200]
train loss : 2.556446, [epoch:2, iter:1300]
train loss : 2.552682, [epoch:2, iter:1400]
train loss : 2.549940, [epoch:2, iter:1500]




train loss : 2.547929, [epoch:2, iter:1600]
train loss : 2.547897, [epoch:2, iter:1700]
train loss : 2.545545, [epoch:2, iter:1800]
train loss : 2.542919, [epoch:2, iter:1900]
TRAIN ACCURACY : 20.2% [12478/61664]
TEST ACCURACY : 16.7% [1956/11680]
...MODEL SAVE...


Train:   4%|█                         | 2/50 [54:28<21:48:05, 1635.11s/it]


EPOCH:2 | train loss : 2.541893950199054, test loss : 2.525979761881371

EPOCH 3 
-------------------------
train loss : 2.290074, [epoch:3, iter:0]
train loss : 2.477750, [epoch:3, iter:100]
train loss : 2.486967, [epoch:3, iter:200]
train loss : 2.490834, [epoch:3, iter:300]
train loss : 2.496138, [epoch:3, iter:400]
train loss : 2.490900, [epoch:3, iter:500]
train loss : 2.489587, [epoch:3, iter:600]
train loss : 2.485671, [epoch:3, iter:700]
train loss : 2.483252, [epoch:3, iter:800]




train loss : 2.485235, [epoch:3, iter:900]
train loss : 2.484604, [epoch:3, iter:1000]
train loss : 2.480889, [epoch:3, iter:1100]
train loss : 2.481131, [epoch:3, iter:1200]
train loss : 2.479177, [epoch:3, iter:1300]
train loss : 2.476302, [epoch:3, iter:1400]
train loss : 2.473942, [epoch:3, iter:1500]
train loss : 2.475841, [epoch:3, iter:1600]
train loss : 2.473823, [epoch:3, iter:1700]




train loss : 2.473737, [epoch:3, iter:1800]
train loss : 2.472976, [epoch:3, iter:1900]
TRAIN ACCURACY : 22.7% [13983/61664]
TEST ACCURACY : 17.8% [2079/11680]
...MODEL SAVE...


Train:   6%|█▍                      | 3/50 [1:21:33<21:17:01, 1630.24s/it]


EPOCH:3 | train loss : 2.4716905505779683, test loss : 2.513133588882342

EPOCH 4 
-------------------------
train loss : 2.566462, [epoch:4, iter:0]




train loss : 2.473447, [epoch:4, iter:100]
train loss : 2.472657, [epoch:4, iter:200]
train loss : 2.460264, [epoch:4, iter:300]
train loss : 2.453531, [epoch:4, iter:400]
train loss : 2.453367, [epoch:4, iter:500]
train loss : 2.453296, [epoch:4, iter:600]
train loss : 2.452635, [epoch:4, iter:700]




train loss : 2.450531, [epoch:4, iter:800]
train loss : 2.452687, [epoch:4, iter:900]
train loss : 2.451036, [epoch:4, iter:1000]
train loss : 2.445256, [epoch:4, iter:1100]
train loss : 2.444453, [epoch:4, iter:1200]
train loss : 2.446622, [epoch:4, iter:1300]
train loss : 2.447741, [epoch:4, iter:1400]
train loss : 2.447864, [epoch:4, iter:1500]
train loss : 2.447858, [epoch:4, iter:1600]
train loss : 2.445645, [epoch:4, iter:1700]
train loss : 2.447052, [epoch:4, iter:1800]
train loss : 2.447453, [epoch:4, iter:1900]
TRAIN ACCURACY : 23.3% [14386/61664]
TEST ACCURACY : 19.4% [2267/11680]
...MODEL SAVE...


Train:   8%|█▉                      | 4/50 [1:48:34<20:47:03, 1626.61s/it]


EPOCH:4 | train loss : 2.4483495372901882, test loss : 2.504375258210587

EPOCH 5 
-------------------------
train loss : 2.602230, [epoch:5, iter:0]
train loss : 2.435166, [epoch:5, iter:100]




train loss : 2.421673, [epoch:5, iter:200]
train loss : 2.432977, [epoch:5, iter:300]
train loss : 2.438147, [epoch:5, iter:400]
train loss : 2.434146, [epoch:5, iter:500]
train loss : 2.440158, [epoch:5, iter:600]
train loss : 2.441230, [epoch:5, iter:700]
train loss : 2.439090, [epoch:5, iter:800]
train loss : 2.438056, [epoch:5, iter:900]
train loss : 2.440694, [epoch:5, iter:1000]
train loss : 2.438369, [epoch:5, iter:1100]




train loss : 2.434948, [epoch:5, iter:1200]
train loss : 2.436320, [epoch:5, iter:1300]
train loss : 2.434322, [epoch:5, iter:1400]
train loss : 2.435652, [epoch:5, iter:1500]
train loss : 2.435426, [epoch:5, iter:1600]
train loss : 2.436043, [epoch:5, iter:1700]
train loss : 2.436678, [epoch:5, iter:1800]
train loss : 2.437255, [epoch:5, iter:1900]
TRAIN ACCURACY : 23.8% [14692/61664]
TEST ACCURACY : 21.9% [2558/11680]
...MODEL SAVE...


Train:  10%|██▍                     | 5/50 [2:15:16<20:13:18, 1617.74s/it]


EPOCH:5 | train loss : 2.4370303459593097, test loss : 2.498797262531437

EPOCH 6 
-------------------------
train loss : 2.285460, [epoch:6, iter:0]
train loss : 2.426437, [epoch:6, iter:100]




train loss : 2.417673, [epoch:6, iter:200]
train loss : 2.435453, [epoch:6, iter:300]
train loss : 2.436734, [epoch:6, iter:400]
train loss : 2.434243, [epoch:6, iter:500]
train loss : 2.437068, [epoch:6, iter:600]
train loss : 2.438974, [epoch:6, iter:700]
train loss : 2.438088, [epoch:6, iter:800]
train loss : 2.435297, [epoch:6, iter:900]
train loss : 2.435025, [epoch:6, iter:1000]
train loss : 2.432050, [epoch:6, iter:1100]
train loss : 2.431620, [epoch:6, iter:1200]




train loss : 2.431630, [epoch:6, iter:1300]
train loss : 2.430251, [epoch:6, iter:1400]
train loss : 2.426921, [epoch:6, iter:1500]
train loss : 2.428697, [epoch:6, iter:1600]
train loss : 2.427318, [epoch:6, iter:1700]
train loss : 2.427418, [epoch:6, iter:1800]
train loss : 2.426499, [epoch:6, iter:1900]
TRAIN ACCURACY : 23.9% [14739/61664]


Train:  12%|██▉                     | 6/50 [2:41:59<19:42:51, 1612.98s/it]

TEST ACCURACY : 17.9% [2087/11680]

EPOCH:6 | train loss : 2.427998814362594, test loss : 2.548241824319918

EPOCH 7 
-------------------------
train loss : 2.447647, [epoch:7, iter:0]
train loss : 2.434177, [epoch:7, iter:100]
train loss : 2.429394, [epoch:7, iter:200]




train loss : 2.413185, [epoch:7, iter:300]




train loss : 2.422419, [epoch:7, iter:400]
train loss : 2.419181, [epoch:7, iter:500]
train loss : 2.417802, [epoch:7, iter:600]
train loss : 2.416406, [epoch:7, iter:700]
train loss : 2.417007, [epoch:7, iter:800]
train loss : 2.417133, [epoch:7, iter:900]
train loss : 2.414654, [epoch:7, iter:1000]
train loss : 2.415371, [epoch:7, iter:1100]
train loss : 2.417462, [epoch:7, iter:1200]
train loss : 2.417953, [epoch:7, iter:1300]
train loss : 2.420403, [epoch:7, iter:1400]
train loss : 2.425523, [epoch:7, iter:1500]
train loss : 2.426542, [epoch:7, iter:1600]
train loss : 2.426111, [epoch:7, iter:1700]
train loss : 2.425564, [epoch:7, iter:1800]
train loss : 2.425829, [epoch:7, iter:1900]
TRAIN ACCURACY : 24.1% [14857/61664]


Train:  14%|███▎                    | 7/50 [3:08:39<19:12:54, 1608.70s/it]

TEST ACCURACY : 18.3% [2137/11680]

EPOCH:7 | train loss : 2.4253127123942075, test loss : 2.5023584133958163

EPOCH 8 
-------------------------
train loss : 2.316508, [epoch:8, iter:0]
train loss : 2.384916, [epoch:8, iter:100]
train loss : 2.417537, [epoch:8, iter:200]
train loss : 2.419526, [epoch:8, iter:300]
train loss : 2.419771, [epoch:8, iter:400]
train loss : 2.417613, [epoch:8, iter:500]
train loss : 2.418323, [epoch:8, iter:600]
train loss : 2.415319, [epoch:8, iter:700]
train loss : 2.419961, [epoch:8, iter:800]
train loss : 2.424488, [epoch:8, iter:900]




train loss : 2.419894, [epoch:8, iter:1000]
train loss : 2.420206, [epoch:8, iter:1100]
train loss : 2.421542, [epoch:8, iter:1200]
train loss : 2.421953, [epoch:8, iter:1300]
train loss : 2.419054, [epoch:8, iter:1400]
train loss : 2.419771, [epoch:8, iter:1500]
train loss : 2.420249, [epoch:8, iter:1600]
train loss : 2.419430, [epoch:8, iter:1700]




train loss : 2.418988, [epoch:8, iter:1800]
train loss : 2.418718, [epoch:8, iter:1900]
TRAIN ACCURACY : 24.2% [14928/61664]
TEST ACCURACY : 16.8% [1957/11680]
...MODEL SAVE...


Train:  16%|███▊                    | 8/50 [3:35:44<18:49:35, 1613.70s/it]


EPOCH:8 | train loss : 2.4184123341382695, test loss : 2.472295115092029

EPOCH 9 
-------------------------
train loss : 2.796915, [epoch:9, iter:0]
train loss : 2.396260, [epoch:9, iter:100]
train loss : 2.397210, [epoch:9, iter:200]
train loss : 2.399718, [epoch:9, iter:300]
train loss : 2.399959, [epoch:9, iter:400]
train loss : 2.397930, [epoch:9, iter:500]
train loss : 2.404484, [epoch:9, iter:600]
train loss : 2.407096, [epoch:9, iter:700]
train loss : 2.406426, [epoch:9, iter:800]
train loss : 2.406417, [epoch:9, iter:900]
train loss : 2.409799, [epoch:9, iter:1000]




train loss : 2.412263, [epoch:9, iter:1100]
train loss : 2.415342, [epoch:9, iter:1200]




train loss : 2.417424, [epoch:9, iter:1300]
train loss : 2.415938, [epoch:9, iter:1400]
train loss : 2.415784, [epoch:9, iter:1500]
