In [1]:
import torch
from tqdm import tqdm

# Pick the compute device once: use the GPU when CUDA is available,
# otherwise fall back to the CPU. Every later cell moves models and
# tensors to this device.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# 1. Data

In [2]:
from tools.Data import muffin_chihuahua
from torchvision import transforms
from torch.utils.data import DataLoader

In [3]:
# Training-time augmentation: random flip, small rotation, colour jitter and
# a random crop covering 80-100% of the image, resized to 224x224.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
])
BATCH_SIZE = 128

# Training split: augmented and reshuffled every epoch.
train_data = muffin_chihuahua(root='data/train', transform=transform)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

# Held-out split: no augmentation is passed.
# NOTE(review): presumably muffin_chihuahua applies its own resize / tensor
# conversion when no transform is given -- confirm that test images go through
# the same deterministic preprocessing as the training images.
test_data = muffin_chihuahua(root='data/test')
# Evaluation does not benefit from shuffling; a fixed order keeps per-batch
# results reproducible between runs.
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)


muffin: 2174
chihuahua: 2559
muffin: 544
chihuahua: 640


# 2. ResNet18

In [10]:
from tools.Model import resnet18
import torch.nn as nn

In [12]:
# Train ResNet18 on the muffin/chihuahua training split and evaluate on the
# test split after every epoch, keeping the best-scoring weights on disk.
model = resnet18().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
EPOCH = 10

# Best test accuracy seen so far. It is initialised once, before the epoch
# loop, so that the checkpoint is only overwritten when accuracy improves.
max_acc = 0

for epoch in range(EPOCH):
    print(f'Epoch {epoch+1}')

    # ---- training pass ----
    model.train()
    train_total = 0  # running sum of per-batch losses
    # tqdm wraps the loader itself (not the enumerate object) so it can read
    # the number of batches and show a real progress bar.
    for i, (data, target) in enumerate(tqdm(train_loader)):
        data, target = data.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_total += loss.item()
    # Mean loss over the batches of this epoch.
    print(f'Train_Loss {train_total/(i+1)}')

    # ---- evaluation pass ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            # Predicted class = index of the largest logit.
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    acc = correct/total
    print(f'Accuracy {acc}')

    # NOTE(review): the checkpoint is selected on the test split; consider a
    # separate validation split if the reported accuracy must be unbiased.
    if acc > max_acc:
        max_acc = acc
        torch.save(model.state_dict(), 'resnet.pth')
    

Epoch 1


37it [00:43,  1.17s/it]


Train_Loss 0.6621213603664089
Accuracy 0.5388513513513513
Epoch 2


37it [00:44,  1.21s/it]


Train_Loss 0.5113930347803477
Accuracy 0.5447635135135135
Epoch 3


37it [00:44,  1.19s/it]


Train_Loss 0.40652468397810654
Accuracy 0.5413851351351351
Epoch 4


37it [00:43,  1.19s/it]


Train_Loss 0.3467430222678829
Accuracy 0.5405405405405406
Epoch 5


37it [00:45,  1.23s/it]


Train_Loss 0.2954799315413913
Accuracy 0.5413851351351351
Epoch 6


37it [00:44,  1.21s/it]


Train_Loss 0.26432858246403773
Accuracy 0.5405405405405406
Epoch 7


37it [00:44,  1.20s/it]


Train_Loss 0.2383278479447236
Accuracy 0.5413851351351351
Epoch 8


37it [00:44,  1.20s/it]


Train_Loss 0.21664690810280876
Accuracy 0.5405405405405406
Epoch 9


37it [00:44,  1.21s/it]


Train_Loss 0.20302298342859423
Accuracy 0.5413851351351351
Epoch 10


37it [00:45,  1.22s/it]


Train_Loss 0.1910324966585314
Accuracy 0.5405405405405406


# 3. Transformer (ViT)

In [8]:
from tools.Model import vit_base

In [9]:
# Train the ViT-base model on the muffin/chihuahua training split and evaluate
# on the test split after every epoch, keeping the best-scoring weights on disk.
model = vit_base().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
EPOCH = 10

# Best test accuracy seen so far. It is initialised once, before the epoch
# loop, so that the checkpoint is only overwritten when accuracy improves.
max_acc = 0

for epoch in range(EPOCH):
    print(f'Epoch {epoch+1}')

    # ---- training pass ----
    model.train()
    train_total = 0  # running sum of per-batch losses
    # tqdm wraps the loader itself (not the enumerate object) so it can read
    # the number of batches and show a real progress bar.
    for i, (data, target) in enumerate(tqdm(train_loader)):
        data, target = data.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_total += loss.item()
    # Mean loss over the batches of this epoch.
    print(f'Train_Loss {train_total/(i+1)}')

    # ---- evaluation pass ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            # Predicted class = index of the largest logit.
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    acc = correct/total
    print(f'Accuracy {acc}')

    # NOTE(review): the checkpoint is selected on the test split; consider a
    # separate validation split if the reported accuracy must be unbiased.
    if acc > max_acc:
        max_acc = acc
        torch.save(model.state_dict(), 'vit.pth')

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Epoch 1


37it [00:59,  1.61s/it]


Train_Loss 0.1105087737275942
Accuracy 0.9864864864864865
Epoch 2


37it [00:59,  1.61s/it]


Train_Loss 0.022877235418638668
Accuracy 0.9907094594594594
Epoch 3


37it [00:59,  1.61s/it]


Train_Loss 0.020039074185832933
Accuracy 0.9923986486486487
Epoch 4


37it [00:59,  1.61s/it]


Train_Loss 0.015588700444110343
Accuracy 0.9932432432432432
Epoch 5


37it [00:59,  1.61s/it]


Train_Loss 0.01386269909405225
Accuracy 0.9949324324324325
Epoch 6


37it [00:59,  1.61s/it]


Train_Loss 0.011699183422417657
Accuracy 0.995777027027027
Epoch 7


37it [00:59,  1.61s/it]


Train_Loss 0.011282906351560677
Accuracy 0.995777027027027
Epoch 8


37it [00:59,  1.61s/it]


Train_Loss 0.009559130768065114
Accuracy 0.995777027027027
Epoch 9


37it [00:59,  1.61s/it]


Train_Loss 0.008933989198979092
Accuracy 0.9949324324324325
Epoch 10


37it [00:59,  1.61s/it]


Train_Loss 0.009181340530237838
Accuracy 0.9949324324324325


# 4. CLIP classification

In [13]:
import clip
import os
from PIL import Image

In [14]:
# Classify the test images with CLIP (no training in this cell): each image is
# compared against one text prompt per class.
model, preprocess = clip.load("ViT-B/32", device=DEVICE)
classes = ["muffin", "chihuahua"]

# One tokenised prompt per class, in the same order as `classes`.
prompts = [f"a photo of a {c}" for c in classes]
text_inputs = torch.cat([clip.tokenize(prompt) for prompt in prompts]).to(DEVICE)

# Collect the test image paths: all muffins first, then all chihuahuas.
muffin_dir = 'data/test/muffin'
chihuahua_dir = 'data/test/chihuahua'
muffin_img = [os.path.join(muffin_dir, name) for name in os.listdir(muffin_dir)]
chihuahua_img = [os.path.join(chihuahua_dir, name) for name in os.listdir(chihuahua_dir)]

# Preprocess every image and stack them into a single batch on DEVICE.
image_paths = muffin_img + chihuahua_img
image_inputs = torch.stack([preprocess(Image.open(path)).to(DEVICE) for path in image_paths])

# Ground-truth labels index into `classes`: 0 = muffin, 1 = chihuahua.
label = [0] * len(muffin_img) + [1] * len(chihuahua_img)

with torch.no_grad():
    image_features = model.encode_image(image_inputs)
    text_features = model.encode_text(text_inputs)

    # Scale both embedding sets to unit length so the dot product below is a
    # cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Scaled similarities turned into a per-image distribution over the classes.
    logits = 100.0 * image_features @ text_features.T
    similarity = logits.softmax(dim=-1)

    print(similarity)

    # A prediction counts as correct when the true class gets more than half of
    # the probability mass.
    correct = 0
    for probs, target in zip(similarity, label):
        if probs[target] > 0.5:
            correct += 1
    print(f'Accuracy {correct/len(label)}')



100%|███████████████████████████████████████| 338M/338M [00:03<00:00, 97.3MiB/s]


tensor([[1.0000e+00, 0.0000e+00],
        [7.0264e-01, 2.9736e-01],
        [5.7764e-01, 4.2261e-01],
        ...,
        [2.2697e-04, 1.0000e+00],
        [4.3983e-03, 9.9561e-01],
        [1.7679e-04, 1.0000e+00]], device='cuda:0', dtype=torch.float16)
Accuracy 0.9831081081081081
